llmvm_protocol/
service.rs1use serde::{Deserialize, Serialize};
2
3use std::{
4 sync::Arc,
5 task::{Context, Poll},
6};
7
8use futures::stream::StreamExt;
9use multilink::{tower::Service, ServiceResponse};
10
11pub use multilink::{BoxedService, ServiceError, ServiceFuture};
12
13use crate::{
14 Backend, BackendGenerationRequest, BackendGenerationResponse, Core, GenerationRequest,
15 GenerationResponse, Message, ThreadInfo,
16};
17
18#[derive(Clone, Serialize, Deserialize)]
20pub enum BackendRequest {
21 Generation(BackendGenerationRequest),
22 GenerationStream(BackendGenerationRequest),
23}
24
25#[derive(Clone, Serialize, Deserialize)]
27pub enum BackendResponse {
28 Generation(BackendGenerationResponse),
29 GenerationStream(BackendGenerationResponse),
30}
31
32#[derive(Clone, Serialize, Deserialize)]
34pub enum CoreRequest {
35 Generation(GenerationRequest),
36 GenerationStream(GenerationRequest),
37 GetLastThreadInfo,
38 GetAllThreadInfos,
39 GetThreadMessages { id: String },
40 InitProject,
41}
42
43#[derive(Clone, Serialize, Deserialize)]
45pub enum CoreResponse {
46 Generation(GenerationResponse),
47 GenerationStream(GenerationResponse),
48 GetLastThreadInfo(Option<ThreadInfo>),
49 GetAllThreadInfos(Vec<ThreadInfo>),
50 GetThreadMessages(Vec<Message>),
51 InitProject,
52}
53
/// A service adapter that dispatches [`BackendRequest`]s to a shared
/// [`Backend`] implementation.
///
/// The backend is held behind an [`Arc`], so cloning the service only bumps a
/// reference count and every clone talks to the same backend instance.
///
/// The `B: Backend` bound is deliberately kept off the struct definition and
/// placed on the `impl` blocks that need it (Rust API guideline
/// C-STRUCT-BOUNDS); this only makes the type more permissive.
pub struct BackendService<B> {
    backend: Arc<B>,
}
62
63impl<B> Clone for BackendService<B>
64where
65 B: Backend,
66{
67 fn clone(&self) -> Self {
68 Self {
69 backend: self.backend.clone(),
70 }
71 }
72}
73
74impl<B> BackendService<B>
75where
76 B: Backend,
77{
78 pub fn new(backend: Arc<B>) -> Self {
79 Self { backend }
80 }
81}
82
/// A service adapter that dispatches [`CoreRequest`]s to a shared [`Core`]
/// implementation.
///
/// The core is held behind an [`Arc`], so cloning the service only bumps a
/// reference count and every clone talks to the same core instance.
///
/// The `C: Core` bound is deliberately kept off the struct definition and
/// placed on the `impl` blocks that need it (Rust API guideline
/// C-STRUCT-BOUNDS); this only makes the type more permissive.
pub struct CoreService<C> {
    core: Arc<C>,
}
91
92impl<C> Clone for CoreService<C>
93where
94 C: Core,
95{
96 fn clone(&self) -> Self {
97 Self {
98 core: self.core.clone(),
99 }
100 }
101}
102
103impl<C> CoreService<C>
104where
105 C: Core,
106{
107 pub fn new(core: Arc<C>) -> Self {
108 Self { core }
109 }
110}
111
112impl<B> Service<BackendRequest> for BackendService<B>
113where
114 B: Backend + 'static,
115{
116 type Response = ServiceResponse<BackendResponse>;
117 type Error = ServiceError;
118 type Future = ServiceFuture<ServiceResponse<BackendResponse>>;
119
120 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
121 Poll::Ready(Ok(()))
122 }
123
124 fn call(&mut self, req: BackendRequest) -> Self::Future {
125 let backend = self.backend.clone();
126 Box::pin(async move {
127 Ok(match req {
128 BackendRequest::Generation(req) => backend
129 .generate(req)
130 .await
131 .map(|v| ServiceResponse::Single(BackendResponse::Generation(v))),
132 BackendRequest::GenerationStream(req) => {
133 backend.generate_stream(req).await.map(|s| {
134 ServiceResponse::Multiple(
135 s.map(|resp| resp.map(|resp| BackendResponse::GenerationStream(resp)))
136 .boxed(),
137 )
138 })
139 }
140 }?)
141 })
142 }
143}
144
145impl<C> Service<CoreRequest> for CoreService<C>
146where
147 C: Core + 'static,
148{
149 type Response = ServiceResponse<CoreResponse>;
150 type Error = ServiceError;
151 type Future = ServiceFuture<ServiceResponse<CoreResponse>>;
152
153 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
154 Poll::Ready(Ok(()))
155 }
156
157 fn call(&mut self, req: CoreRequest) -> Self::Future {
158 let core = self.core.clone();
159 Box::pin(async move {
160 Ok(match req {
161 CoreRequest::Generation(req) => core
162 .generate(req)
163 .await
164 .map(|v| ServiceResponse::Single(CoreResponse::Generation(v))),
165 CoreRequest::GenerationStream(req) => core.generate_stream(req).await.map(|s| {
166 ServiceResponse::Multiple(
167 s.map(|resp| resp.map(|resp| CoreResponse::GenerationStream(resp)))
168 .boxed(),
169 )
170 }),
171 CoreRequest::GetLastThreadInfo => core
172 .get_last_thread_info()
173 .await
174 .map(|i| ServiceResponse::Single(CoreResponse::GetLastThreadInfo(i))),
175 CoreRequest::GetAllThreadInfos => core
176 .get_all_thread_infos()
177 .await
178 .map(|i| ServiceResponse::Single(CoreResponse::GetAllThreadInfos(i))),
179 CoreRequest::GetThreadMessages { id } => core
180 .get_thread_messages(id)
181 .await
182 .map(|m| ServiceResponse::Single(CoreResponse::GetThreadMessages(m))),
183 CoreRequest::InitProject => core
184 .init_project()
185 .map(|_| ServiceResponse::Single(CoreResponse::InitProject)),
186 }?)
187 })
188 }
189}
190
/// Helpers for building a client-side handle to the llmvm core service.
///
/// Only compiled when both the `http-client` and `stdio-client` features are
/// enabled.
#[cfg(all(feature = "http-client", feature = "stdio-client"))]
pub mod util {
    use multilink::{
        http::client::HttpClientConfig, stdio::client::StdioClientConfig,
        util::service::build_service_from_config, BoxedService, ServiceError,
    };

    use super::{CoreRequest, CoreResponse};

    /// Command name handed to `build_service_from_config` for the stdio client.
    pub const LLMVM_CORE_CLI_COMMAND: &str = "llmvm-core";

    /// Arguments passed along with [`LLMVM_CORE_CLI_COMMAND`].
    // `'static` is implied for references in a `const`, so it is not spelled out.
    pub const LLMVM_CORE_CLI_ARGS: [&str; 2] = ["--log-to-file", "stdio-server"];

    /// Builds a boxed [`CoreRequest`]/[`CoreResponse`] service from the optional
    /// stdio and HTTP client configurations.
    ///
    /// This is a thin wrapper over multilink's `build_service_from_config`,
    /// supplying [`LLMVM_CORE_CLI_COMMAND`] and [`LLMVM_CORE_CLI_ARGS`]. How the
    /// two optional configs are prioritized is decided by that function —
    /// NOTE(review): confirm against multilink's docs.
    ///
    /// # Errors
    ///
    /// Returns any [`ServiceError`] produced while building the service.
    pub async fn build_core_service_from_config(
        stdio_client_config: Option<StdioClientConfig>,
        http_client_config: Option<HttpClientConfig>,
    ) -> Result<BoxedService<CoreRequest, CoreResponse>, ServiceError> {
        build_service_from_config::<CoreRequest, CoreResponse>(
            LLMVM_CORE_CLI_COMMAND,
            &LLMVM_CORE_CLI_ARGS,
            stdio_client_config,
            http_client_config,
        )
        .await
    }
}