1use std::net::SocketAddr;
2use std::time::Instant;
3
4use axum::Json;
5use axum::Router;
6use axum::body::Body;
7use axum::extract::State;
8use axum::http::StatusCode;
9use axum::response::IntoResponse;
10use axum::routing::{get, post};
11use futures_util::{StreamExt, stream};
12use tower_http::cors::CorsLayer;
13use tower_http::trace::TraceLayer;
14use tracing::info;
15
16use reqwest::Client;
17
18#[derive(Clone)]
19pub struct ApiState {
20 pub server_url: String,
21 pub api_key: Option<String>,
22 pub model_name: String,
23 pub pid: u32,
24 pub start_time: Instant,
25 pub port: u16,
26 pub client: reqwest::Client,
27}
28
29fn extract_api_key(headers: &axum::http::HeaderMap) -> Option<String> {
30 headers
31 .get("Authorization")
32 .and_then(|v| v.to_str().ok())
33 .and_then(|v| v.strip_prefix("Bearer "))
34 .map(|s| s.to_string())
35}
36
37async fn auth_middleware(
38 State(state): State<ApiState>,
39 req: axum::extract::Request,
40 next: axum::middleware::Next,
41) -> axum::response::Response {
42 if let Some(expected) = &state.api_key {
43 let provided = extract_api_key(req.headers());
44 if provided.as_deref() != Some(expected) {
45 return (
46 StatusCode::UNAUTHORIZED,
47 Json(serde_json::json!({"error": "Unauthorized"})),
48 )
49 .into_response();
50 }
51 }
52 next.run(req).await
53}
54
55async fn proxy_streaming(
58 State(state): State<ApiState>,
59 req: axum::extract::Request,
60) -> impl IntoResponse {
61 let path = req.uri().path().to_string();
62 let method = req.method().clone();
63 let headers = req.headers().clone();
64
65 let url = format!("{}{}", state.server_url, path);
66
67 let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
69 Ok(b) => b,
70 Err(e) => {
71 info!("Failed to read request body for {}: {}", path, e);
72 return (
73 StatusCode::BAD_REQUEST,
74 Json(serde_json::json!({"error": format!("Failed to read request body: {}", e)})),
75 )
76 .into_response();
77 }
78 };
79 let body_stream = stream::iter(vec![Ok::<_, reqwest::Error>(body_bytes)]);
80
81 let mut request_builder = match method {
82 axum::http::Method::GET => state.client.get(&url),
83 axum::http::Method::POST => state.client.post(&url),
84 axum::http::Method::PUT => state.client.put(&url),
85 axum::http::Method::DELETE => state.client.delete(&url),
86 _ => return (
87 StatusCode::METHOD_NOT_ALLOWED,
88 Json(serde_json::json!({"error": "Method not supported"})),
89 )
90 .into_response(),
91 };
92
93 const HOP_BY_HOP: &[&str] = &[
94 "connection",
95 "keep-alive",
96 "proxy-authenticate",
97 "proxy-authorization",
98 "te",
99 "trailer",
100 "transfer-encoding",
101 "upgrade",
102 "host",
103 ];
104 for (name, value) in headers.iter() {
105 let name_str = name.as_str();
106 if !HOP_BY_HOP.contains(&name_str) && name_str != "authorization" {
107 request_builder = request_builder.header(name, value);
108 }
109 }
110
111 let response = request_builder
112 .body(reqwest::Body::wrap_stream(body_stream))
113 .send()
114 .await;
115
116 match response {
117 Ok(resp) => {
118 let status = resp.status();
119 let headers = resp.headers().clone();
120 let is_sse = resp
121 .headers()
122 .get(axum::http::header::CONTENT_TYPE)
123 .and_then(|v| v.to_str().ok())
124 .map(|v| v.contains("text/event-stream"))
125 .unwrap_or(false);
126
127 if is_sse {
128 let mut response = axum::response::Response::new(Body::from_stream(
129 resp.bytes_stream().map(|result| {
130 result.map_err(std::io::Error::other)
131 }),
132 ));
133 *response.status_mut() = status;
134 for (name, value) in headers.iter() {
135 response.headers_mut().insert(name, value.clone());
136 }
137 response
138 } else {
139 let bytes = match resp.bytes().await {
140 Ok(b) => b,
141 Err(e) => {
142 info!("Failed to read response body for {}: {}", path, e);
143 return (
144 StatusCode::BAD_GATEWAY,
145 Json(serde_json::json!({"error": format!("Failed to read backend response: {}", e)})),
146 )
147 .into_response();
148 }
149 };
150 (status, headers, bytes).into_response()
151 }
152 }
153 Err(e) => {
154 info!("Proxy error for {}: {}", path, e);
155 (
156 StatusCode::BAD_GATEWAY,
157 Json(serde_json::json!({"error": format!("Backend unavailable: {}", e)})),
158 )
159 .into_response()
160 }
161 }
162}
163
164async fn health(State(state): State<ApiState>) -> impl IntoResponse {
166 let resp = state
167 .client
168 .get(format!("{}/health", state.server_url))
169 .send()
170 .await;
171
172 match resp {
173 Ok(response) if response.status().is_success() => Json(serde_json::json!({
174 "status": "ok",
175 "backend": "healthy"
176 })),
177 Ok(_) => Json(serde_json::json!({
178 "status": "degraded",
179 "backend": "unreachable"
180 })),
181 Err(_) => Json(serde_json::json!({
182 "status": "degraded",
183 "backend": "unreachable"
184 })),
185 }
186}
187
188async fn status(State(state): State<ApiState>) -> impl IntoResponse {
190 let uptime = state.start_time.elapsed();
191 let uptime_secs = uptime.as_secs();
192
193 let loaded_models = match state
195 .client
196 .get(format!("{}/models", state.server_url))
197 .send()
198 .await
199 {
200 Ok(resp) if resp.status().is_success() => {
201 let json: serde_json::Value = match resp.json().await {
202 Ok(v) => v,
203 Err(_) => serde_json::json!([]),
204 };
205 json.get("data")
206 .and_then(|d| d.as_array())
207 .map(|a| a.len())
208 .unwrap_or(0)
209 }
210 _ => 0,
211 };
212
213 Json(serde_json::json!({
214 "status": "running",
215 "pid": state.pid,
216 "port": state.port,
217 "model": state.model_name,
218 "uptime_seconds": uptime_secs,
219 "loaded_models": loaded_models,
220 }))
221}
222
223pub async fn start_api_server(
224 addr: SocketAddr,
225 api_key: Option<String>,
226 server_port: u16,
227 model_name: String,
228 pid: u32,
229 mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
230 host: String,
231 tls_config: Option<axum_server::tls_rustls::RustlsConfig>,
232) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
233 let bind = addr;
234 let start_time = Instant::now();
235 let client = Client::builder()
236 .pool_max_idle_per_host(20)
237 .timeout(std::time::Duration::from_secs(300))
238 .build()?;
239 let state = ApiState {
240 server_url: format!("http://127.0.0.1:{}", server_port),
241 api_key,
242 model_name,
243 pid,
244 start_time,
245 port: bind.port(),
246 client,
247 };
248
249 let cors = CorsLayer::new()
250 .allow_origin(tower_http::cors::Any)
251 .allow_methods([
252 axum::http::Method::GET,
253 axum::http::Method::POST,
254 axum::http::Method::PUT,
255 axum::http::Method::DELETE,
256 axum::http::Method::OPTIONS,
257 ])
258 .allow_headers([
259 axum::http::header::CONTENT_TYPE,
260 axum::http::header::AUTHORIZATION,
261 ]);
262
263 let api_key_clone = state.api_key.clone();
264 let protocol = if tls_config.is_some() {
265 "https"
266 } else {
267 "http"
268 };
269 info!(
270 "API server starting on {protocol}://{} (proxying to http://127.0.0.1:{})",
271 host, server_port
272 );
273 if api_key_clone.is_some() {
274 info!("API key authentication is ENABLED");
275 }
276
277 let app = Router::new()
278 .route("/health", get(health))
279 .route("/metrics", get(proxy_streaming))
280 .merge(
281 Router::new()
282 .route("/v1/chat/completions", post(proxy_streaming))
283 .route("/v1/completions", post(proxy_streaming))
284 .route("/v1/embeddings", post(proxy_streaming))
285 .route("/v1/models", get(proxy_streaming))
286 .route("/api/status", get(status))
287 .fallback(proxy_streaming)
288 .layer(cors)
289 .layer(TraceLayer::new_for_http())
290 .layer(axum::middleware::from_fn_with_state(
291 state.clone(),
292 auth_middleware,
293 )),
294 )
295 .with_state(state);
296
297 match tls_config {
298 Some(tls_cfg) => {
299 let tls_listener = axum_server::bind_rustls(bind, tls_cfg);
300 let shutdown_fut = async {
301 let _ = shutdown_rx.wait_for(|v| *v).await;
302 };
303 let _ = tokio::select! {
304 result = tls_listener.serve(app.into_make_service()) => result,
305 _ = shutdown_fut => Ok(()),
306 };
307 }
308 None => {
309 axum::serve(tokio::net::TcpListener::bind(bind).await?, app)
310 .with_graceful_shutdown(async move {
311 let _ = shutdown_rx.wait_for(|v| *v).await;
312 })
313 .await?;
314 }
315 }
316 Ok(())
317}