llm_daemon/
proxy.rs

1use std::fmt::Debug;
2use std::sync::Arc;
3
4use axum::body::{Body, BodyDataStream};
5use axum::extract::State;
6use axum::http::HeaderValue;
7use axum::response::IntoResponse;
8use axum::routing::{get, post};
9use axum::Router;
10use futures::FutureExt;
11use hyper::body::Incoming;
12use hyper::{Response, StatusCode};
13use hyper_util::client::legacy::connect::HttpConnector;
14use hyper_util::client::legacy::Client;
15use hyper_util::rt::TokioExecutor;
16use serde::{Deserialize, Serialize};
17use serde_json::json;
18use tokio::net::TcpListener;
19use tokio::sync::Semaphore;
20use tokio::time::Instant;
21use tracing::{debug, info, warn};
22
23use crate::daemon_trait::LlmConfig;
24use crate::LlmDaemon;
25
/// Settings for the local proxy server.
///
/// The only setting is the TCP port on `127.0.0.1` that the proxy listens on
/// and that [`LlmConfig::endpoint`] / [`LlmConfig::health_url`] point at.
#[derive(Debug)]
pub struct ProxyConfig {
    /// Loopback port the proxy binds to.
    port: u16,
}
30
31impl LlmConfig for ProxyConfig {
32    fn endpoint(&self) -> url::Url {
33        url::Url::parse(&format!(
34            "http://127.0.0.1:{}/v1/completions",
35            self.port
36        ))
37        .expect("failed to parse url")
38    }
39
40    fn health_url(&self) -> url::Url {
41        url::Url::parse(&format!("http://127.0.0.1:{}/health", self.port))
42            .expect("failed to parse url")
43    }
44}
45
46impl Default for ProxyConfig {
47    fn default() -> Self {
48        Self { port: 8282 }
49    }
50}
51
52/// A proxy to actual LLM server, and only send the last pending request.
53/// If the LLM generation is slower than requests, then processing the oldest
54/// request does not make sense, as the user is more interested in the later
55/// requests. But current server (AFAIK) does not handle this properly...
56#[derive(Debug)]
57pub struct Proxy<D: LlmDaemon + Debug> {
58    config: ProxyConfig,
59    inner: D,
60}
61
62impl<D: LlmDaemon + Debug> Proxy<D> {
63    pub fn new(config: ProxyConfig, inner: D) -> Self {
64        Self { config, inner }
65    }
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69struct Completion {
70    content: String,
71}
72
73impl<D: LlmDaemon + Debug> LlmDaemon for Proxy<D> {
74    type Config = ProxyConfig;
75
76    fn fork_daemon(&self) -> anyhow::Result<()> {
77        info!("Fork inner daemon {:?}", self.inner);
78        self.inner.fork_daemon()
79    }
80
81    fn heartbeat<'a, 'b>(
82        &'b self,
83    ) -> impl futures::prelude::Future<Output = anyhow::Result<()>> + Send + 'a
84    where
85        'a: 'b,
86    {
87        let port = self.config.port;
88        // boxed() is due to https://github.com/rust-lang/rust/issues/100013
89        let hb = self.inner.heartbeat().boxed();
90        let proxy = run_proxy(port).boxed();
91
92        async move {
93            let (r0, r1) = futures::join!(hb, proxy);
94            r0?;
95            r1?;
96            Ok(())
97        }
98    }
99
100    fn config(&self) -> &Self::Config {
101        &self.config
102    }
103
104    fn ping(&self) -> anyhow::Result<()> {
105        self.inner.ping()
106    }
107}
108
109async fn handle_health(
110    State((_sem, client)): State<(
111        Arc<Semaphore>,
112        Client<HttpConnector, BodyDataStream>,
113    )>,
114) -> Result<impl IntoResponse, StatusCode> {
115    let req_builder = hyper::Request::builder()
116        .uri("http://127.0.0.1:28282/health")
117        .method("GET");
118
119    let request = req_builder.body(Body::empty().into_data_stream()).unwrap();
120
121    client.request(request).await.map_err(|e| {
122        warn!("error: {} {}", e, e.is_connect());
123        if e.is_connect() {
124            StatusCode::SERVICE_UNAVAILABLE
125        } else {
126            StatusCode::INTERNAL_SERVER_ERROR
127        }
128    })
129}
130
131async fn inner(
132    client: &Client<HttpConnector, BodyDataStream>,
133    req: axum::extract::Request,
134) -> Result<Response<Incoming>, hyper_util::client::legacy::Error> {
135    // FIXME: 28282 -> configured port
136    let mut req_builder = hyper::Request::builder()
137        .uri("http://127.0.0.1:28282/v1/completions")
138        .method(req.method());
139    let headers = req_builder.headers_mut().unwrap();
140    req.headers().into_iter().for_each(|(name, value)| {
141        headers.append(name, value.clone());
142    });
143
144    let request = req_builder
145        .body(req.into_body().into_data_stream())
146        .unwrap();
147
148    client.request(request).await
149}
150
151async fn handle_proxy(
152    State((sem, client)): State<(
153        Arc<Semaphore>,
154        Client<HttpConnector, BodyDataStream>,
155    )>,
156    req: axum::extract::Request,
157) -> Result<impl IntoResponse, StatusCode> {
158    info!("Processing request {:?}", req);
159    let clock = Instant::now();
160    let acquired = sem
161        .clone()
162        .acquire_owned()
163        .await
164        .expect("failed to acquire semaphore");
165    acquired.forget();
166    let lock_wait_ms = clock.elapsed().as_millis();
167    inner(&client, req)
168        .await
169        .map(|mut res| {
170            let gen_latency = clock.elapsed().as_millis();
171            debug!("generated: lock wait: {lock_wait_ms}ms, total gen latency: {gen_latency}ms");
172            res.headers_mut().append(
173                "x-tracing-info",
174                HeaderValue::from_str(
175                    &json!({"lock_wait_ms": lock_wait_ms, "gen_latency": gen_latency}).to_string(),
176                )
177                .expect("fail to create header"),
178            );
179            res
180        })
181        .map_err(|e| {
182            warn!("error: {} {}", e, e.is_connect());
183            if e.is_connect() {
184                StatusCode::SERVICE_UNAVAILABLE
185            } else {
186                StatusCode::INTERNAL_SERVER_ERROR
187            }
188        })
189        .inspect(|_| {
190            sem.add_permits(1);
191        })
192}
193
194pub async fn run_proxy(port: u16) -> anyhow::Result<()> {
195    let client =
196        hyper_util::client::legacy::Client::builder(TokioExecutor::new())
197            .build_http();
198    let app = Router::new()
199        .route("/completions", post(handle_proxy))
200        .route("/v1/completions", post(handle_proxy))
201        .route("/health", get(handle_health))
202        .with_state((Arc::new(Semaphore::new(1)), client));
203    debug!("Creating listener on {}", port);
204    let listener = TcpListener::bind(format!("127.0.0.1:{}", port)).await?;
205
206    axum::serve(listener, app).await?;
207
208    Ok(())
209}
210
#[cfg(test)]
#[cfg(feature = "llama-daemon")]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use futures::future;
    use tokio::runtime::Runtime;
    use tokio::sync::Mutex;
    use tracing::error;
    use tracing_test::traced_test;

    use crate::daemon_trait::LlmConfig as _;
    use crate::proxy::Proxy;
    use crate::{Daemon2, Generator, LlmDaemon};

    /// End-to-end: forks a `Daemon2` on port 28282, runs the proxy heartbeat
    /// in the background and sends one generation request to the proxy
    /// endpoint.
    #[traced_test]
    #[test]
    fn proxy_trait_test() -> anyhow::Result<()> {
        type Target = Proxy<Daemon2>;
        let proxy_config = <Target as LlmDaemon>::Config::default();
        let endpoint = proxy_config.endpoint();
        let model = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF".to_string();
        let proxy = Target::new(proxy_config, Daemon2::from((model, 28282)));

        proxy.fork_daemon()?;
        let runtime = Runtime::new()?;
        runtime.spawn(proxy.heartbeat());
        runtime.block_on(async {
            proxy.ready().await;
            let generator = Generator::new(endpoint, None);
            let prompt = "<bos>Sum of 7 and 8 is ".to_string();
            let resp = generator.generate(prompt).await.inspect_err(|err| {
                error!("error: {:?}", err);
            });
            assert!(resp.is_ok());
            assert!(resp.unwrap().contains("15"));
        });
        Ok(())
    }

    /// Two un-spawned futures contend for one mutex; the second aborts the
    /// first (wrapped in `abortable`) before taking the lock itself. The
    /// first must report that it was aborted.
    #[tokio::test]
    async fn without_spawn() -> anyhow::Result<()> {
        let shared = Arc::new(Mutex::new(1));

        let first_mutex = shared.clone();
        let first = async move {
            let _guard = first_mutex.lock().await;
            tokio::time::sleep(Duration::from_secs(1)).await;
        };
        let (first, first_abort) = future::abortable(first);

        let second_mutex = shared.clone();
        let second = async move {
            first_abort.abort();
            let _guard = second_mutex.lock().await;
            tokio::time::sleep(Duration::from_secs(1)).await;
        };

        let (first_outcome, _) = tokio::join!(first, second);
        first_outcome.expect_err("Should be aborted");

        Ok(())
    }

    /// Same contention as `without_spawn`, but the first lock holder runs in
    /// a spawned task and only the future awaiting that task is abortable.
    #[tokio::test]
    async fn pass_the_mutexguard() -> anyhow::Result<()> {
        let shared = Arc::new(Mutex::new(1));

        let first_mutex = shared.clone();
        let first = async move {
            tokio::spawn(async move {
                let _guard = first_mutex.lock().await;
                tokio::time::sleep(Duration::from_secs(1)).await;
            })
            .await
        };
        let (first, first_abort) = future::abortable(first);

        let second_mutex = shared.clone();
        let second = async move {
            // Give the spawned task a moment to start before aborting.
            tokio::time::sleep(Duration::from_millis(5)).await;
            first_abort.abort();
            let _guard = second_mutex.lock().await;
            tokio::time::sleep(Duration::from_secs(1)).await;
        };

        let (first_outcome, _) = tokio::join!(first, second);
        first_outcome.expect_err("Should be aborted");
        Ok(())
    }

    /// Exercises four ways of cancelling a lock-holding task in sequence.
    /// There are no assertions; the test only has to run to completion.
    #[tokio::test]
    async fn scoped_locks() -> anyhow::Result<()> {
        let shared = Arc::new(Mutex::new(1));

        // 1) Abort through the `JoinHandle` itself.
        let first_mutex = shared.clone();
        let first_task = tokio::spawn(async move {
            let _guard = first_mutex.lock().await;
            tokio::time::sleep(Duration::from_secs(1)).await;
        });

        tokio::time::sleep(Duration::from_millis(5)).await;

        first_task.abort();

        // 2) Abort through an `AbortHandle` obtained from the `JoinHandle`.
        let second_mutex = shared.clone();
        let second_task = tokio::spawn(async move {
            let _guard = second_mutex.lock().await;
            tokio::time::sleep(Duration::from_secs(1)).await;
        });

        let second_abort = second_task.abort_handle();

        tokio::time::sleep(Duration::from_millis(5)).await;

        second_abort.abort();

        // 3) Abort only an `abortable` wrapper around the `JoinHandle`.
        // NOTE(review): presumably this leaves the spawned task itself
        // running -- nothing here verifies it.
        let third_mutex = shared.clone();
        let third_task = tokio::spawn(async move {
            let _guard = third_mutex.lock().await;
            tokio::time::sleep(Duration::from_secs(1)).await;
        });

        tokio::time::sleep(Duration::from_millis(5)).await;

        let (third_wrapped, third_abort) = future::abortable(third_task);
        third_abort.abort();
        drop(third_wrapped);

        // 4) Spawn an `abortable` future and abort it after 1005 ms.
        let fourth_mutex = shared.clone();
        let fourth = async move {
            let _guard = fourth_mutex.lock().await;
            tokio::time::sleep(Duration::from_secs(1)).await;
        };
        let (fourth, fourth_abort) = future::abortable(fourth);

        tokio::spawn(fourth);

        tokio::time::sleep(Duration::from_millis(1005)).await;

        fourth_abort.abort();

        tokio::time::sleep(Duration::from_millis(1000)).await;

        Ok(())
    }
}
365}