use std::sync::Arc;
use std::time::{Duration, Instant};
use axum::{
body::{Body, Bytes},
extract::{DefaultBodyLimit, State},
http::{header, HeaderMap, StatusCode},
response::Response,
routing::{get, post},
Json, Router,
};
use gradatum_core::event_sink::{EngineEvent, EventSink};
use crate::{health::HealthState, metrics::EngineMetrics, runtime::ForwardProxy};
/// Shared state handed to every handler of the API router built by `EngineServer::router`.
///
/// `Clone` is needed because axum's `with_state` / `State` extractor clone it; the `Arc`
/// fields make those clones share the same health, metrics and sink instances.
#[derive(Clone)]
pub struct AppState {
/// Forwards request bodies to the child server (used by `proxy_request`).
pub proxy: ForwardProxy,
/// Health state whose snapshot is served on `/health`.
pub health: Arc<HealthState>,
/// Per-route request metrics (status code and latency) recorded by `proxy_request`.
pub metrics: Arc<EngineMetrics>,
/// Receives a `RequestServed` event for each response relayed from the child.
pub sink: Arc<dyn EventSink>,
/// Model name reported in emitted `RequestServed` events.
pub model_name: String,
/// Provider label reported in emitted `RequestServed` events.
pub provider: String,
/// Timeout, in seconds, applied to each forwarded call; exceeding it yields 504.
pub timeout_secs: u64,
/// Maximum accepted request body size in bytes (enforced through `DefaultBodyLimit`).
pub body_limit_bytes: usize,
}
/// Namespace for the constructors of the engine's HTTP routers.
pub struct EngineServer;

impl EngineServer {
    /// Builds the public API router.
    ///
    /// Serves `GET /health`, `POST /v1/chat/completions` and `POST /v1/embeddings`.
    /// Request bodies are capped at `state.body_limit_bytes` via [`DefaultBodyLimit`].
    pub fn router(state: AppState) -> Router {
        // Build the limit layer first: `state` is moved into the router afterwards.
        let limit_layer = DefaultBodyLimit::max(state.body_limit_bytes);
        Router::new()
            .route("/health", get(health_handler))
            .route("/v1/chat/completions", post(chat_handler))
            .route("/v1/embeddings", post(embed_handler))
            .layer(limit_layer)
            .with_state(state)
    }

    /// Builds the router exposing `GET /metrics`, backed only by the metrics registry.
    pub fn metrics_router(metrics: Arc<EngineMetrics>) -> Router {
        let routes = Router::new().route("/metrics", get(metrics_handler));
        routes.with_state(metrics)
    }

    /// Builds an API router wired to a child server at `child_base_url`, for tests.
    ///
    /// The health state is marked ready up front and events go to an in-memory sink.
    ///
    /// # Panics
    ///
    /// Panics if the `reqwest` client cannot be built.
    #[cfg(any(test, feature = "test-utils"))]
    pub fn test_app_with_child(
        child_base_url: String,
        timeout_secs: u64,
        body_limit_bytes: usize,
    ) -> Router {
        use gradatum_core::event_sink::InMemorySink;

        // Same name for the health state and for the events' `model` field.
        const MODEL: &str = "test-model";

        let http_client = reqwest::Client::builder()
            .build()
            .expect("client reqwest de test");
        let health = Arc::new(HealthState::new(MODEL));
        health.set_ready();

        Self::router(AppState {
            proxy: ForwardProxy::new(http_client, child_base_url),
            health,
            metrics: Arc::new(EngineMetrics::new()),
            sink: Arc::new(InMemorySink::default()),
            model_name: MODEL.to_string(),
            provider: String::from("engine-test"),
            timeout_secs,
            body_limit_bytes,
        })
    }
}
/// `GET /health`: returns the current health snapshot serialized as JSON.
async fn health_handler(State(state): State<AppState>) -> Json<crate::health::HealthSnapshot> {
    let snapshot = state.health.snapshot();
    Json(snapshot)
}
async fn metrics_handler(State(m): State<Arc<EngineMetrics>>) -> (StatusCode, Body) {
let text = m.render();
(StatusCode::OK, Body::from(text))
}
/// `POST /v1/chat/completions`: relays the raw request body to the child's chat endpoint.
async fn chat_handler(
    State(state): State<AppState>,
    headers: HeaderMap,
    body: Bytes,
) -> Result<Response, (StatusCode, String)> {
    const CHILD_PATH: &str = "/v1/chat/completions";
    proxy_request(&state, CHILD_PATH, headers, body).await
}
/// `POST /v1/embeddings`: relays the raw request body to the child's embeddings endpoint.
async fn embed_handler(
    State(state): State<AppState>,
    headers: HeaderMap,
    body: Bytes,
) -> Result<Response, (StatusCode, String)> {
    const CHILD_PATH: &str = "/v1/embeddings";
    proxy_request(&state, CHILD_PATH, headers, body).await
}
/// Media type assumed when a request or a child response carries no usable `Content-Type`.
const DEFAULT_CONTENT_TYPE: &str = "application/json";

/// Returns `raw` as an owned string, or [`DEFAULT_CONTENT_TYPE`] when it is `None`.
fn content_type_or_default(raw: Option<&str>) -> String {
    raw.unwrap_or(DEFAULT_CONTENT_TYPE).to_string()
}

/// Whole milliseconds elapsed since `start`.
///
/// The result saturates at `u64::MAX` instead of silently truncating the `u128`
/// returned by [`Duration::as_millis`].
fn elapsed_ms(start: Instant) -> u64 {
    u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX)
}

/// Forwards `body` to the child server at `subpath` and relays the child's response.
///
/// The request `Content-Type` is passed through when it is valid visible ASCII. Otherwise
/// `application/json` is used. Every outcome is recorded in `s.metrics` together with the
/// latency measured until `forward` resolves.
///
/// On success, a `RequestServed` event is emitted to `s.sink` and the child's status,
/// `Content-Type` and streamed body are returned. The child's other headers are not
/// copied.
///
/// # Errors
///
/// - `504 Gateway Timeout` when `forward` does not resolve within `s.timeout_secs`.
/// - The status reported by the forwarding error, or `500` if that is not a valid status
///   code, with the error text as body.
/// - `500` if the relayed response cannot be assembled.
async fn proxy_request(
    s: &AppState,
    subpath: &str,
    headers: HeaderMap,
    body: Bytes,
) -> Result<Response, (StatusCode, String)> {
    let t0 = Instant::now();
    let content_type = content_type_or_default(
        headers
            .get(header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok()),
    );
    // The timeout only bounds the `forward` future. The child's body is streamed to the
    // client afterwards (below) and is not covered by it.
    let sent = tokio::time::timeout(
        Duration::from_secs(s.timeout_secs),
        s.proxy.forward(subpath, &content_type, body),
    )
    .await;
    let ms = elapsed_ms(t0);
    match sent {
        Err(_elapsed) => {
            s.metrics.record_request(subpath, 504, ms);
            Err((StatusCode::GATEWAY_TIMEOUT, "timeout".into()))
        }
        Ok(Err(e)) => {
            let status = e.status();
            s.metrics.record_request(subpath, status, ms);
            Err((
                StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
                e.to_string(),
            ))
        }
        Ok(Ok(resp)) => {
            let status = resp.status().as_u16();
            s.metrics.record_request(subpath, status, ms);
            s.sink
                .emit(EngineEvent::RequestServed {
                    route: subpath.to_string(),
                    model: s.model_name.clone(),
                    provider: s.provider.clone(),
                    latency_ms: ms,
                })
                .await;
            // Extract the header before `resp` is consumed by `bytes_stream`.
            let resp_ct = content_type_or_default(
                resp.headers()
                    .get(reqwest::header::CONTENT_TYPE)
                    .and_then(|v| v.to_str().ok()),
            );
            // Status codes axum cannot represent fall back to 502.
            let axum_status = StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_GATEWAY);
            Response::builder()
                .status(axum_status)
                .header(header::CONTENT_TYPE, resp_ct)
                .body(Body::from_stream(resp.bytes_stream()))
                .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))
        }
    }
}
/// Transparency tests covering how the engine relays traffic to and from the child.
///
/// The engine must forward request bodies to the child byte-for-byte and relay the
/// child's response untouched. Timeouts must map to 504 and oversized bodies to 413.
#[cfg(test)]
mod transparent_handler {
    use super::*;
    use axum::body::{to_bytes, Body, Bytes as AxBytes};
    use axum::http::{Request, StatusCode};
    use axum::response::Response as AxResponse;
    use axum::routing::post;
    use axum::Router;
    use std::sync::Arc;
    use tokio::net::TcpListener;
    use tokio::sync::Mutex;
    use tower::ServiceExt;

    /// Body limit for every test that is not exercising the limit itself (32 MiB).
    const GENEROUS_LIMIT: usize = 32 * 1024 * 1024;

    /// Spawns a stub child server on an ephemeral localhost port.
    ///
    /// - `/v1/chat/completions` sleeps `delay_secs`, stores the request body, then answers
    ///   `resp_body`.
    /// - `/v1/embeddings` stores the request body and answers a fixed embedding payload.
    ///
    /// Both routes write into the same buffer. Returns `(port, last_captured_body)`.
    async fn start_child_stub(
        delay_secs: u64,
        resp_body: &'static str,
    ) -> (u16, Arc<Mutex<Vec<u8>>>) {
        let captured = Arc::new(Mutex::new(Vec::<u8>::new()));
        let c1 = captured.clone();
        let c2 = captured.clone();
        let handler = move |body: AxBytes| {
            let cap = c1.clone();
            async move {
                if delay_secs > 0 {
                    tokio::time::sleep(std::time::Duration::from_secs(delay_secs)).await;
                }
                *cap.lock().await = body.to_vec();
                AxResponse::builder()
                    .status(StatusCode::OK)
                    .header("content-type", "application/json")
                    .body(Body::from(resp_body))
                    .unwrap()
            }
        };
        let embed_handler = move |body: AxBytes| {
            let cap = c2.clone();
            async move {
                *cap.lock().await = body.to_vec();
                AxResponse::builder()
                    .status(StatusCode::OK)
                    .header("content-type", "application/json")
                    .body(Body::from("{\"data\":[{\"embedding\":[0.1],\"index\":0}]}"))
                    .unwrap()
            }
        };
        let app = Router::new()
            .route("/v1/chat/completions", post(handler))
            .route("/v1/embeddings", post(embed_handler));
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        (port, captured)
    }

    /// Builds the engine router pointed at the stub child listening on `port`.
    fn engine_for(port: u16, timeout_secs: u64, body_limit_bytes: usize) -> Router {
        EngineServer::test_app_with_child(
            format!("http://127.0.0.1:{port}"),
            timeout_secs,
            body_limit_bytes,
        )
    }

    /// Builds a `POST` request to `uri` with `content-type: application/json`.
    fn post_json(uri: &str, body: impl Into<Body>) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(body.into())
            .unwrap()
    }

    #[tokio::test]
    async fn chat_forwards_body_with_sampling_and_slot_preserved() {
        let (port, captured) =
            start_child_stub(0, "{\"choices\":[{\"message\":{\"content\":\"ok\"}}]}").await;
        let app = engine_for(port, 30, GENEROUS_LIMIT);
        let raw = br#"{"messages":[{"role":"user","content":"hi"}],"temperature":0.7,"slot_id":2,"tools":[],"seed":7}"#;
        let resp = app
            .oneshot(post_json("/v1/chat/completions", raw.to_vec()))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let got = captured.lock().await.clone();
        assert_eq!(
            got.as_slice(),
            raw.as_slice(),
            "body transparent : sampling/slot_id/tools/seed préservés"
        );
    }

    #[tokio::test]
    async fn chat_determinism_temperature_zero_preserved() {
        let (port, captured) =
            start_child_stub(0, "{\"choices\":[{\"message\":{\"content\":\"{}\"}}]}").await;
        let app = engine_for(port, 30, GENEROUS_LIMIT);
        let raw = br#"{"messages":[{"role":"user","content":"classify"}],"temperature":0.0}"#;
        let _ = app
            .oneshot(post_json("/v1/chat/completions", raw.to_vec()))
            .await
            .unwrap();
        let got = captured.lock().await.clone();
        let v: serde_json::Value = serde_json::from_slice(&got).unwrap();
        assert_eq!(
            v["temperature"], 0.0,
            "temperature:0.0 doit être forwardé (déterminisme curator non régressé)"
        );
    }

    #[tokio::test]
    async fn chat_response_body_passed_through() {
        let (port, _) = start_child_stub(
            0,
            "{\"id\":\"child-1\",\"choices\":[{\"message\":{\"content\":\"hi\"}}]}",
        )
        .await;
        let app = engine_for(port, 30, GENEROUS_LIMIT);
        let resp = app
            .oneshot(post_json(
                "/v1/chat/completions",
                "{\"messages\":[{\"role\":\"user\",\"content\":\"x\"}]}",
            ))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(
            v["id"], "child-1",
            "réponse child renvoyée telle quelle (pas de réécriture)"
        );
    }

    #[tokio::test]
    async fn chat_returns_504_on_timeout() {
        // Child answers after 5 s; the engine gives up after 1 s.
        let (port, _) = start_child_stub(5, "{}").await;
        let app = engine_for(port, 1, GENEROUS_LIMIT);
        let resp = app
            .oneshot(post_json(
                "/v1/chat/completions",
                "{\"messages\":[{\"role\":\"user\",\"content\":\"x\"}]}",
            ))
            .await
            .unwrap();
        assert_eq!(
            resp.status(),
            StatusCode::GATEWAY_TIMEOUT,
            "timeout < latence child → 504"
        );
    }

    #[tokio::test]
    async fn body_limit_returns_413_over_limit() {
        let (port, _) = start_child_stub(0, "{}").await;
        let app = engine_for(port, 30, 64);
        let big = vec![b'a'; 1024];
        let resp = app
            .oneshot(post_json("/v1/chat/completions", big))
            .await
            .unwrap();
        assert_eq!(
            resp.status(),
            StatusCode::PAYLOAD_TOO_LARGE,
            "body > body_limit_bytes → 413"
        );
    }

    #[tokio::test]
    async fn embed_forwards_and_returns() {
        let (port, captured) = start_child_stub(0, "{}").await;
        let app = engine_for(port, 30, GENEROUS_LIMIT);
        let resp = app
            .oneshot(post_json("/v1/embeddings", "{\"input\":\"hello\"}"))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let got = captured.lock().await.clone();
        let v: serde_json::Value = serde_json::from_slice(&got).unwrap();
        assert_eq!(v["input"], "hello", "embed body forwardé transparent");
    }
}