Skip to main content

llm_manager/
serve_api.rs

1use std::net::SocketAddr;
2use std::time::Instant;
3
4use axum::Json;
5use axum::Router;
6use axum::body::Body;
7use axum::extract::State;
8use axum::http::StatusCode;
9use axum::response::IntoResponse;
10use axum::routing::{get, post};
11use futures_util::{StreamExt, stream};
12use tower_http::cors::CorsLayer;
13use tower_http::trace::TraceLayer;
14use tracing::info;
15
16use reqwest::Client;
17
18#[derive(Clone)]
19pub struct ApiState {
20    pub server_url: String,
21    pub api_key: Option<String>,
22    pub model_name: String,
23    pub pid: u32,
24    pub start_time: Instant,
25    pub port: u16,
26    pub client: reqwest::Client,
27}
28
29fn extract_api_key(headers: &axum::http::HeaderMap) -> Option<String> {
30    headers
31        .get("Authorization")
32        .and_then(|v| v.to_str().ok())
33        .and_then(|v| v.strip_prefix("Bearer "))
34        .map(|s| s.to_string())
35}
36
37async fn auth_middleware(
38    State(state): State<ApiState>,
39    req: axum::extract::Request,
40    next: axum::middleware::Next,
41) -> axum::response::Response {
42    if let Some(expected) = &state.api_key {
43        let provided = extract_api_key(req.headers());
44        if provided.as_deref() != Some(expected) {
45            return (
46                StatusCode::UNAUTHORIZED,
47                Json(serde_json::json!({"error": "Unauthorized"})),
48            )
49                .into_response();
50        }
51    }
52    next.run(req).await
53}
54
55/// Proxy a request to the llama-server backend with SSE streaming support.
56/// Checks Content-Type: if text/event-stream, streams the body; otherwise buffers.
57async fn proxy_streaming(
58    State(state): State<ApiState>,
59    req: axum::extract::Request,
60) -> impl IntoResponse {
61    let path = req.uri().path().to_string();
62    let method = req.method().clone();
63    let headers = req.headers().clone();
64
65    let url = format!("{}{}", state.server_url, path);
66
67    // Convert request body to a stream for reqwest
68    let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
69        Ok(b) => b,
70        Err(e) => {
71            info!("Failed to read request body for {}: {}", path, e);
72            return (
73                StatusCode::BAD_REQUEST,
74                Json(serde_json::json!({"error": format!("Failed to read request body: {}", e)})),
75            )
76                .into_response();
77        }
78    };
79    let body_stream = stream::iter(vec![Ok::<_, reqwest::Error>(body_bytes)]);
80
81    let mut request_builder = match method {
82        axum::http::Method::GET => state.client.get(&url),
83        axum::http::Method::POST => state.client.post(&url),
84        axum::http::Method::PUT => state.client.put(&url),
85        axum::http::Method::DELETE => state.client.delete(&url),
86        _ => return (
87            StatusCode::METHOD_NOT_ALLOWED,
88            Json(serde_json::json!({"error": "Method not supported"})),
89        )
90            .into_response(),
91    };
92
93    const HOP_BY_HOP: &[&str] = &[
94        "connection",
95        "keep-alive",
96        "proxy-authenticate",
97        "proxy-authorization",
98        "te",
99        "trailer",
100        "transfer-encoding",
101        "upgrade",
102        "host",
103    ];
104    for (name, value) in headers.iter() {
105        let name_str = name.as_str();
106        if !HOP_BY_HOP.contains(&name_str) && name_str != "authorization" {
107            request_builder = request_builder.header(name, value);
108        }
109    }
110
111    let response = request_builder
112        .body(reqwest::Body::wrap_stream(body_stream))
113        .send()
114        .await;
115
116    match response {
117        Ok(resp) => {
118            let status = resp.status();
119            let headers = resp.headers().clone();
120            let is_sse = resp
121                .headers()
122                .get(axum::http::header::CONTENT_TYPE)
123                .and_then(|v| v.to_str().ok())
124                .map(|v| v.contains("text/event-stream"))
125                .unwrap_or(false);
126
127            if is_sse {
128                let mut response = axum::response::Response::new(Body::from_stream(
129                    resp.bytes_stream().map(|result| {
130                        result.map_err(std::io::Error::other)
131                    }),
132                ));
133                *response.status_mut() = status;
134                for (name, value) in headers.iter() {
135                    response.headers_mut().insert(name, value.clone());
136                }
137                response
138            } else {
139                let bytes = match resp.bytes().await {
140                    Ok(b) => b,
141                    Err(e) => {
142                        info!("Failed to read response body for {}: {}", path, e);
143                        return (
144                            StatusCode::BAD_GATEWAY,
145                            Json(serde_json::json!({"error": format!("Failed to read backend response: {}", e)})),
146                        )
147                            .into_response();
148                    }
149                };
150                (status, headers, bytes).into_response()
151            }
152        }
153        Err(e) => {
154            info!("Proxy error for {}: {}", path, e);
155            (
156                StatusCode::BAD_GATEWAY,
157                Json(serde_json::json!({"error": format!("Backend unavailable: {}", e)})),
158            )
159                .into_response()
160        }
161    }
162}
163
164/// Simple health check endpoint - no auth, verifies backend
165async fn health(State(state): State<ApiState>) -> impl IntoResponse {
166    let resp = state
167        .client
168        .get(format!("{}/health", state.server_url))
169        .send()
170        .await;
171
172    match resp {
173        Ok(response) if response.status().is_success() => Json(serde_json::json!({
174            "status": "ok",
175            "backend": "healthy"
176        })),
177        Ok(_) => Json(serde_json::json!({
178            "status": "degraded",
179            "backend": "unreachable"
180        })),
181        Err(_) => Json(serde_json::json!({
182            "status": "degraded",
183            "backend": "unreachable"
184        })),
185    }
186}
187
188/// Custom status endpoint.
189async fn status(State(state): State<ApiState>) -> impl IntoResponse {
190    let uptime = state.start_time.elapsed();
191    let uptime_secs = uptime.as_secs();
192
193    // Try to get loaded models from llama-server
194    let loaded_models = match state
195        .client
196        .get(format!("{}/models", state.server_url))
197        .send()
198        .await
199    {
200        Ok(resp) if resp.status().is_success() => {
201            let json: serde_json::Value = match resp.json().await {
202                Ok(v) => v,
203                Err(_) => serde_json::json!([]),
204            };
205            json.get("data")
206                .and_then(|d| d.as_array())
207                .map(|a| a.len())
208                .unwrap_or(0)
209        }
210        _ => 0,
211    };
212
213    Json(serde_json::json!({
214        "status": "running",
215        "pid": state.pid,
216        "port": state.port,
217        "model": state.model_name,
218        "uptime_seconds": uptime_secs,
219        "loaded_models": loaded_models,
220    }))
221}
222
223pub async fn start_api_server(
224    addr: SocketAddr,
225    api_key: Option<String>,
226    server_port: u16,
227    model_name: String,
228    pid: u32,
229    mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
230    host: String,
231    tls_config: Option<axum_server::tls_rustls::RustlsConfig>,
232) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
233    let bind = addr;
234    let start_time = Instant::now();
235    let client = Client::builder()
236        .pool_max_idle_per_host(20)
237        .timeout(std::time::Duration::from_secs(300))
238        .build()?;
239    let state = ApiState {
240        server_url: format!("http://127.0.0.1:{}", server_port),
241        api_key,
242        model_name,
243        pid,
244        start_time,
245        port: bind.port(),
246        client,
247    };
248
249    let cors = CorsLayer::new()
250        .allow_origin(tower_http::cors::Any)
251        .allow_methods([
252            axum::http::Method::GET,
253            axum::http::Method::POST,
254            axum::http::Method::PUT,
255            axum::http::Method::DELETE,
256            axum::http::Method::OPTIONS,
257        ])
258        .allow_headers([
259            axum::http::header::CONTENT_TYPE,
260            axum::http::header::AUTHORIZATION,
261        ]);
262
263    let api_key_clone = state.api_key.clone();
264    let protocol = if tls_config.is_some() {
265        "https"
266    } else {
267        "http"
268    };
269    info!(
270        "API server starting on {protocol}://{} (proxying to http://127.0.0.1:{})",
271        host, server_port
272    );
273    if api_key_clone.is_some() {
274        info!("API key authentication is ENABLED");
275    }
276
277    let app = Router::new()
278        .route("/health", get(health))
279        .route("/metrics", get(proxy_streaming))
280        .merge(
281            Router::new()
282                .route("/v1/chat/completions", post(proxy_streaming))
283                .route("/v1/completions", post(proxy_streaming))
284                .route("/v1/embeddings", post(proxy_streaming))
285                .route("/v1/models", get(proxy_streaming))
286                .route("/api/status", get(status))
287                .fallback(proxy_streaming)
288                .layer(cors)
289                .layer(TraceLayer::new_for_http())
290                .layer(axum::middleware::from_fn_with_state(
291                    state.clone(),
292                    auth_middleware,
293                )),
294        )
295        .with_state(state);
296
297    match tls_config {
298        Some(tls_cfg) => {
299            let tls_listener = axum_server::bind_rustls(bind, tls_cfg);
300            let shutdown_fut = async {
301                let _ = shutdown_rx.wait_for(|v| *v).await;
302            };
303            let _ = tokio::select! {
304                result = tls_listener.serve(app.into_make_service()) => result,
305                _ = shutdown_fut => Ok(()),
306            };
307        }
308        None => {
309            axum::serve(tokio::net::TcpListener::bind(bind).await?, app)
310                .with_graceful_shutdown(async move {
311                    let _ = shutdown_rx.wait_for(|v| *v).await;
312                })
313                .await?;
314        }
315    }
316    Ok(())
317}