llmvm_protocol/
service.rs

1use serde::{Deserialize, Serialize};
2
3use std::{
4    sync::Arc,
5    task::{Context, Poll},
6};
7
8use futures::stream::StreamExt;
9use multilink::{tower::Service, ServiceResponse};
10
11pub use multilink::{BoxedService, ServiceError, ServiceFuture};
12
13use crate::{
14    Backend, BackendGenerationRequest, BackendGenerationResponse, Core, GenerationRequest,
15    GenerationResponse, Message, ThreadInfo,
16};
17
18/// Enum containing all types of backend requests.
19#[derive(Clone, Serialize, Deserialize)]
20pub enum BackendRequest {
21    Generation(BackendGenerationRequest),
22    GenerationStream(BackendGenerationRequest),
23}
24
25/// Enum containing all types of backend responses.
26#[derive(Clone, Serialize, Deserialize)]
27pub enum BackendResponse {
28    Generation(BackendGenerationResponse),
29    GenerationStream(BackendGenerationResponse),
30}
31
32/// Enum containing all types of core requests.
33#[derive(Clone, Serialize, Deserialize)]
34pub enum CoreRequest {
35    Generation(GenerationRequest),
36    GenerationStream(GenerationRequest),
37    GetLastThreadInfo,
38    GetAllThreadInfos,
39    GetThreadMessages { id: String },
40    InitProject,
41}
42
43/// Enum containing all types of core responses.
44#[derive(Clone, Serialize, Deserialize)]
45pub enum CoreResponse {
46    Generation(GenerationResponse),
47    GenerationStream(GenerationResponse),
48    GetLastThreadInfo(Option<ThreadInfo>),
49    GetAllThreadInfos(Vec<ThreadInfo>),
50    GetThreadMessages(Vec<Message>),
51    InitProject,
52}
53
54/// Service that receives [`BackendRequest`] values,
55/// calls a [`Backend`] and responds with [`BackendResponse`].
56pub struct BackendService<B>
57where
58    B: Backend,
59{
60    backend: Arc<B>,
61}
62
63impl<B> Clone for BackendService<B>
64where
65    B: Backend,
66{
67    fn clone(&self) -> Self {
68        Self {
69            backend: self.backend.clone(),
70        }
71    }
72}
73
74impl<B> BackendService<B>
75where
76    B: Backend,
77{
78    pub fn new(backend: Arc<B>) -> Self {
79        Self { backend }
80    }
81}
82
83/// Service that receives [`CoreRequest`] values,
84/// calls a [`Core`] and responds with [`CoreResponse`].
85pub struct CoreService<C>
86where
87    C: Core,
88{
89    core: Arc<C>,
90}
91
92impl<C> Clone for CoreService<C>
93where
94    C: Core,
95{
96    fn clone(&self) -> Self {
97        Self {
98            core: self.core.clone(),
99        }
100    }
101}
102
103impl<C> CoreService<C>
104where
105    C: Core,
106{
107    pub fn new(core: Arc<C>) -> Self {
108        Self { core }
109    }
110}
111
112impl<B> Service<BackendRequest> for BackendService<B>
113where
114    B: Backend + 'static,
115{
116    type Response = ServiceResponse<BackendResponse>;
117    type Error = ServiceError;
118    type Future = ServiceFuture<ServiceResponse<BackendResponse>>;
119
120    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
121        Poll::Ready(Ok(()))
122    }
123
124    fn call(&mut self, req: BackendRequest) -> Self::Future {
125        let backend = self.backend.clone();
126        Box::pin(async move {
127            Ok(match req {
128                BackendRequest::Generation(req) => backend
129                    .generate(req)
130                    .await
131                    .map(|v| ServiceResponse::Single(BackendResponse::Generation(v))),
132                BackendRequest::GenerationStream(req) => {
133                    backend.generate_stream(req).await.map(|s| {
134                        ServiceResponse::Multiple(
135                            s.map(|resp| resp.map(|resp| BackendResponse::GenerationStream(resp)))
136                                .boxed(),
137                        )
138                    })
139                }
140            }?)
141        })
142    }
143}
144
145impl<C> Service<CoreRequest> for CoreService<C>
146where
147    C: Core + 'static,
148{
149    type Response = ServiceResponse<CoreResponse>;
150    type Error = ServiceError;
151    type Future = ServiceFuture<ServiceResponse<CoreResponse>>;
152
153    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
154        Poll::Ready(Ok(()))
155    }
156
157    fn call(&mut self, req: CoreRequest) -> Self::Future {
158        let core = self.core.clone();
159        Box::pin(async move {
160            Ok(match req {
161                CoreRequest::Generation(req) => core
162                    .generate(req)
163                    .await
164                    .map(|v| ServiceResponse::Single(CoreResponse::Generation(v))),
165                CoreRequest::GenerationStream(req) => core.generate_stream(req).await.map(|s| {
166                    ServiceResponse::Multiple(
167                        s.map(|resp| resp.map(|resp| CoreResponse::GenerationStream(resp)))
168                            .boxed(),
169                    )
170                }),
171                CoreRequest::GetLastThreadInfo => core
172                    .get_last_thread_info()
173                    .await
174                    .map(|i| ServiceResponse::Single(CoreResponse::GetLastThreadInfo(i))),
175                CoreRequest::GetAllThreadInfos => core
176                    .get_all_thread_infos()
177                    .await
178                    .map(|i| ServiceResponse::Single(CoreResponse::GetAllThreadInfos(i))),
179                CoreRequest::GetThreadMessages { id } => core
180                    .get_thread_messages(id)
181                    .await
182                    .map(|m| ServiceResponse::Single(CoreResponse::GetThreadMessages(m))),
183                CoreRequest::InitProject => core
184                    .init_project()
185                    .map(|_| ServiceResponse::Single(CoreResponse::InitProject)),
186            }?)
187        })
188    }
189}
190
/// Helpers for building a client connection to the llmvm core.
///
/// Only compiled when both the `http-client` and `stdio-client` features are enabled.
#[cfg(all(feature = "http-client", feature = "stdio-client"))]
pub mod util {
    use multilink::{
        http::client::HttpClientConfig, stdio::client::StdioClientConfig,
        util::service::build_service_from_config, BoxedService, ServiceError,
    };

    use super::{CoreRequest, CoreResponse};

    /// The default name of the core CLI binary.
    pub const LLMVM_CORE_CLI_COMMAND: &str = "llmvm-core";

    /// CLI arguments passed to the core binary when it is spawned for stdio communication.
    ///
    /// NOTE(review): `--log-to-file` presumably keeps logs off stdout, which carries the
    /// stdio protocol. TODO confirm.
    // `&str` in a `const` is already `'static`; spelling it out trips
    // clippy::redundant_static_lifetimes. The type is unchanged.
    pub const LLMVM_CORE_CLI_ARGS: [&str; 2] = ["--log-to-file", "stdio-server"];

    /// Create a core service client that communicates over stdio or HTTP.
    ///
    /// If `http_client_config` is provided, an HTTP client is created.
    /// Otherwise, a stdio client is created. Useful for frontends.
    ///
    /// A stdio client launches [`LLMVM_CORE_CLI_COMMAND`] with [`LLMVM_CORE_CLI_ARGS`].
    ///
    /// # Errors
    ///
    /// Propagates any [`ServiceError`] returned by multilink's `build_service_from_config`.
    pub async fn build_core_service_from_config(
        stdio_client_config: Option<StdioClientConfig>,
        http_client_config: Option<HttpClientConfig>,
    ) -> Result<BoxedService<CoreRequest, CoreResponse>, ServiceError> {
        build_service_from_config::<CoreRequest, CoreResponse>(
            LLMVM_CORE_CLI_COMMAND,
            &LLMVM_CORE_CLI_ARGS,
            stdio_client_config,
            http_client_config,
        )
        .await
    }
}