1use std::fmt::Debug;
2use std::sync::Arc;
3
4use axum::body::{Body, BodyDataStream};
5use axum::extract::State;
6use axum::http::HeaderValue;
7use axum::response::IntoResponse;
8use axum::routing::{get, post};
9use axum::Router;
10use futures::FutureExt;
11use hyper::body::Incoming;
12use hyper::{Response, StatusCode};
13use hyper_util::client::legacy::connect::HttpConnector;
14use hyper_util::client::legacy::Client;
15use hyper_util::rt::TokioExecutor;
16use serde::{Deserialize, Serialize};
17use serde_json::json;
18use tokio::net::TcpListener;
19use tokio::sync::Semaphore;
20use tokio::time::Instant;
21use tracing::{debug, info, warn};
22
23use crate::daemon_trait::LlmConfig;
24use crate::LlmDaemon;
25
/// Settings for the proxy front-end.
///
/// The only setting is the loopback TCP port used to build the proxy's
/// endpoint and health URLs.
#[derive(Debug)]
pub struct ProxyConfig {
    /// Port on `127.0.0.1` that the proxy is reachable on.
    port: u16,
}
30
31impl LlmConfig for ProxyConfig {
32 fn endpoint(&self) -> url::Url {
33 url::Url::parse(&format!(
34 "http://127.0.0.1:{}/v1/completions",
35 self.port
36 ))
37 .expect("failed to parse url")
38 }
39
40 fn health_url(&self) -> url::Url {
41 url::Url::parse(&format!("http://127.0.0.1:{}/health", self.port))
42 .expect("failed to parse url")
43 }
44}
45
46impl Default for ProxyConfig {
47 fn default() -> Self {
48 Self { port: 8282 }
49 }
50}
51
52#[derive(Debug)]
57pub struct Proxy<D: LlmDaemon + Debug> {
58 config: ProxyConfig,
59 inner: D,
60}
61
62impl<D: LlmDaemon + Debug> Proxy<D> {
63 pub fn new(config: ProxyConfig, inner: D) -> Self {
64 Self { config, inner }
65 }
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69struct Completion {
70 content: String,
71}
72
73impl<D: LlmDaemon + Debug> LlmDaemon for Proxy<D> {
74 type Config = ProxyConfig;
75
76 fn fork_daemon(&self) -> anyhow::Result<()> {
77 info!("Fork inner daemon {:?}", self.inner);
78 self.inner.fork_daemon()
79 }
80
81 fn heartbeat<'a, 'b>(
82 &'b self,
83 ) -> impl futures::prelude::Future<Output = anyhow::Result<()>> + Send + 'a
84 where
85 'a: 'b,
86 {
87 let port = self.config.port;
88 let hb = self.inner.heartbeat().boxed();
90 let proxy = run_proxy(port).boxed();
91
92 async move {
93 let (r0, r1) = futures::join!(hb, proxy);
94 r0?;
95 r1?;
96 Ok(())
97 }
98 }
99
100 fn config(&self) -> &Self::Config {
101 &self.config
102 }
103
104 fn ping(&self) -> anyhow::Result<()> {
105 self.inner.ping()
106 }
107}
108
109async fn handle_health(
110 State((_sem, client)): State<(
111 Arc<Semaphore>,
112 Client<HttpConnector, BodyDataStream>,
113 )>,
114) -> Result<impl IntoResponse, StatusCode> {
115 let req_builder = hyper::Request::builder()
116 .uri("http://127.0.0.1:28282/health")
117 .method("GET");
118
119 let request = req_builder.body(Body::empty().into_data_stream()).unwrap();
120
121 client.request(request).await.map_err(|e| {
122 warn!("error: {} {}", e, e.is_connect());
123 if e.is_connect() {
124 StatusCode::SERVICE_UNAVAILABLE
125 } else {
126 StatusCode::INTERNAL_SERVER_ERROR
127 }
128 })
129}
130
131async fn inner(
132 client: &Client<HttpConnector, BodyDataStream>,
133 req: axum::extract::Request,
134) -> Result<Response<Incoming>, hyper_util::client::legacy::Error> {
135 let mut req_builder = hyper::Request::builder()
137 .uri("http://127.0.0.1:28282/v1/completions")
138 .method(req.method());
139 let headers = req_builder.headers_mut().unwrap();
140 req.headers().into_iter().for_each(|(name, value)| {
141 headers.append(name, value.clone());
142 });
143
144 let request = req_builder
145 .body(req.into_body().into_data_stream())
146 .unwrap();
147
148 client.request(request).await
149}
150
151async fn handle_proxy(
152 State((sem, client)): State<(
153 Arc<Semaphore>,
154 Client<HttpConnector, BodyDataStream>,
155 )>,
156 req: axum::extract::Request,
157) -> Result<impl IntoResponse, StatusCode> {
158 info!("Processing request {:?}", req);
159 let clock = Instant::now();
160 let acquired = sem
161 .clone()
162 .acquire_owned()
163 .await
164 .expect("failed to acquire semaphore");
165 acquired.forget();
166 let lock_wait_ms = clock.elapsed().as_millis();
167 inner(&client, req)
168 .await
169 .map(|mut res| {
170 let gen_latency = clock.elapsed().as_millis();
171 debug!("generated: lock wait: {lock_wait_ms}ms, total gen latency: {gen_latency}ms");
172 res.headers_mut().append(
173 "x-tracing-info",
174 HeaderValue::from_str(
175 &json!({"lock_wait_ms": lock_wait_ms, "gen_latency": gen_latency}).to_string(),
176 )
177 .expect("fail to create header"),
178 );
179 res
180 })
181 .map_err(|e| {
182 warn!("error: {} {}", e, e.is_connect());
183 if e.is_connect() {
184 StatusCode::SERVICE_UNAVAILABLE
185 } else {
186 StatusCode::INTERNAL_SERVER_ERROR
187 }
188 })
189 .inspect(|_| {
190 sem.add_permits(1);
191 })
192}
193
194pub async fn run_proxy(port: u16) -> anyhow::Result<()> {
195 let client =
196 hyper_util::client::legacy::Client::builder(TokioExecutor::new())
197 .build_http();
198 let app = Router::new()
199 .route("/completions", post(handle_proxy))
200 .route("/v1/completions", post(handle_proxy))
201 .route("/health", get(handle_health))
202 .with_state((Arc::new(Semaphore::new(1)), client));
203 debug!("Creating listener on {}", port);
204 let listener = TcpListener::bind(format!("127.0.0.1:{}", port)).await?;
205
206 axum::serve(listener, app).await?;
207
208 Ok(())
209}
210
#[cfg(all(test, feature = "llama-daemon"))]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use futures::future;
    use tokio::runtime::Runtime;
    use tokio::sync::Mutex;
    use tracing::error;
    use tracing_test::traced_test;

    use crate::daemon_trait::LlmConfig as _;
    use crate::proxy::Proxy;
    use crate::{Daemon2, Generator, LlmDaemon};

    /// End-to-end check: fork the inner daemon, run the proxy heartbeat, and
    /// generate a completion through the proxy endpoint.
    #[traced_test]
    #[test]
    fn proxy_trait_test() -> anyhow::Result<()> {
        type Target = Proxy<Daemon2>;
        let conf = <Target as LlmDaemon>::Config::default();
        // Take the endpoint before `conf` is moved into the proxy.
        let endpoint = conf.endpoint();
        let daemon = Daemon2::from((
            "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF".to_string(),
            28282,
        ));
        let inst = Target::new(conf, daemon);

        inst.fork_daemon()?;
        let runtime = Runtime::new()?;
        runtime.spawn(inst.heartbeat());
        runtime.block_on(async {
            inst.ready().await;
            let generator = Generator::new(endpoint, None);
            let outcome = generator
                .generate("<bos>Sum of 7 and 8 is ".to_string())
                .await
                .inspect_err(|err| {
                    error!("error: {:?}", err);
                });
            assert!(outcome.is_ok());
            assert!(outcome.unwrap().contains("15"));
        });
        Ok(())
    }

    /// A non-spawned future wrapped in `future::abortable` resolves to an
    /// error once its abort handle is triggered from a sibling future.
    #[tokio::test]
    async fn without_spawn() -> anyhow::Result<()> {
        let shared = Arc::new(Mutex::new(1));

        let first_lock = Arc::clone(&shared);
        let (first, abort_first) = future::abortable(async move {
            let _guard = first_lock.lock().await;
            tokio::time::sleep(Duration::from_secs(1)).await;
        });

        let second_lock = Arc::clone(&shared);
        let second = async move {
            abort_first.abort();
            let _guard = second_lock.lock().await;
            tokio::time::sleep(Duration::from_secs(1)).await;
        };

        let (first_outcome, _) = tokio::join!(first, second);
        first_outcome.expect_err("Should be aborted");

        Ok(())
    }

    /// Same as `without_spawn`, but the locking work runs inside a spawned
    /// task and only the future awaiting that task is abortable.
    #[tokio::test]
    async fn pass_the_mutexguard() -> anyhow::Result<()> {
        let shared = Arc::new(Mutex::new(1));

        let first_lock = Arc::clone(&shared);
        let (first, abort_first) = future::abortable(async move {
            tokio::spawn(async move {
                let _guard = first_lock.lock().await;
                tokio::time::sleep(Duration::from_secs(1)).await;
            })
            .await
        });

        let second_lock = Arc::clone(&shared);
        let second = async move {
            tokio::time::sleep(Duration::from_millis(5)).await;
            abort_first.abort();
            let _guard = second_lock.lock().await;
            tokio::time::sleep(Duration::from_secs(1)).await;
        };

        let (first_outcome, _) = tokio::join!(first, second);
        first_outcome.expect_err("Should be aborted");
        Ok(())
    }

    /// Exploratory scenario with no assertions: it tries four different ways
    /// of aborting work that takes the shared mutex and only checks that the
    /// sequence runs to completion.
    #[tokio::test]
    async fn scoped_locks() -> anyhow::Result<()> {
        let shared = Arc::new(Mutex::new(1));

        // 1) Abort through the `JoinHandle` itself.
        let lock_one = Arc::clone(&shared);
        let task_one = tokio::spawn(async move {
            let _guard = lock_one.lock().await;
            tokio::time::sleep(Duration::from_secs(1)).await;
        });

        tokio::time::sleep(Duration::from_millis(5)).await;

        task_one.abort();

        // 2) Abort through an `AbortHandle` derived from the `JoinHandle`.
        let lock_two = Arc::clone(&shared);
        let task_two = tokio::spawn(async move {
            let _guard = lock_two.lock().await;
            tokio::time::sleep(Duration::from_secs(1)).await;
        });

        let abort_two = task_two.abort_handle();

        tokio::time::sleep(Duration::from_millis(5)).await;

        abort_two.abort();

        // 3) Wrap the `JoinHandle` in `abortable` and abort the wrapper.
        let lock_three = Arc::clone(&shared);
        let task_three = tokio::spawn(async move {
            let _guard = lock_three.lock().await;
            tokio::time::sleep(Duration::from_secs(1)).await;
        });

        tokio::time::sleep(Duration::from_millis(5)).await;

        let (wrapped_three, abort_three) = future::abortable(task_three);
        abort_three.abort();
        drop(wrapped_three);

        // 4) Spawn an `abortable` future and abort it after it has had time
        //    to run.
        let lock_four = Arc::clone(&shared);
        let (work_four, abort_four) = future::abortable(async move {
            let _guard = lock_four.lock().await;
            tokio::time::sleep(Duration::from_secs(1)).await;
        });

        tokio::spawn(work_four);

        tokio::time::sleep(Duration::from_millis(1005)).await;

        abort_four.abort();

        tokio::time::sleep(Duration::from_millis(1000)).await;

        Ok(())
    }
}