use axum::body::Bytes;
use crate::error::EngineError;
/// Forwards raw HTTP request bodies to a single upstream ("child") server.
#[derive(Clone, Debug)]
pub struct ForwardProxy {
    /// HTTP client used for every upstream request.
    client: reqwest::Client,
    /// Base URL of the upstream server (e.g. `http://127.0.0.1:8080`); the
    /// request subpath is appended to it.
    child_base_url: String,
}
impl ForwardProxy {
    /// Creates a proxy that sends requests through `client` to the server
    /// rooted at `child_base_url`.
    pub fn new(client: reqwest::Client, child_base_url: String) -> Self {
        Self {
            client,
            child_base_url,
        }
    }

    /// Builds the upstream URL for `subpath`.
    ///
    /// The subpath is appended verbatim. The only exception: when the base URL
    /// ends with `/` and `subpath` starts with `/`, one slash is dropped. So
    /// `http://host/` + `/v1/x` yields `http://host/v1/x` rather than
    /// `http://host//v1/x`.
    fn upstream_url(&self, subpath: &str) -> String {
        let base = self.child_base_url.as_str();
        let base = match base.strip_suffix('/') {
            Some(trimmed) if subpath.starts_with('/') => trimmed,
            _ => base,
        };
        format!("{base}{subpath}")
    }

    /// POSTs `body` to `subpath` on the upstream server with the given
    /// `Content-Type` and returns the raw upstream response.
    ///
    /// The body bytes are sent unmodified. The upstream status code and headers
    /// stay on the returned [`reqwest::Response`] for the caller to relay. A
    /// non-2xx status is *not* turned into an error.
    ///
    /// # Errors
    ///
    /// Returns [`EngineError::Inference`] if the request could not be sent or
    /// no response was received (e.g. the connection was refused).
    pub(crate) async fn forward(
        &self,
        subpath: &str,
        content_type: &str,
        body: Bytes,
    ) -> Result<reqwest::Response, EngineError> {
        let url = self.upstream_url(subpath);
        self.client
            .post(url)
            .header(reqwest::header::CONTENT_TYPE, content_type)
            .body(body)
            .send()
            .await
            .map_err(|e| EngineError::Inference(format!("proxy {subpath} : {e}")))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{routing::post, Router};
    use std::sync::Arc;
    use tokio::{net::TcpListener, sync::Mutex};

    /// Starts a throwaway HTTP server on an ephemeral loopback port.
    ///
    /// The server answers POST requests on `path` with `status`,
    /// `resp_content_type` and `resp_body`. It stores the raw bytes of the most
    /// recent request body in the returned buffer, so tests can inspect exactly
    /// what the proxy sent. Returns the bound port and that buffer.
    async fn start_capture_stub(
        path: &'static str,
        status: u16,
        resp_content_type: &'static str,
        resp_body: &'static str,
    ) -> (u16, Arc<Mutex<Vec<u8>>>) {
        use axum::body::{Body, Bytes as AxBytes};
        use axum::http::StatusCode as AxStatus;
        use axum::response::Response as AxResponse;

        let recorded = Arc::new(Mutex::new(Vec::<u8>::new()));
        let shared = Arc::clone(&recorded);

        let handler = move |incoming: AxBytes| {
            let sink = Arc::clone(&shared);
            async move {
                // Record the request body before replying, so it is available
                // once the client sees the response.
                *sink.lock().await = incoming.to_vec();
                let code = AxStatus::from_u16(status).unwrap();
                AxResponse::builder()
                    .status(code)
                    .header("content-type", resp_content_type)
                    .body(Body::from(resp_body))
                    .unwrap()
            }
        };
        let app = Router::new().route(path, post(handler));

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bound_port = listener.local_addr().unwrap().port();
        // Serve in the background for the rest of the test.
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });

        (bound_port, recorded)
    }

    /// Builds a proxy pointed at `http://127.0.0.1:{port}` with a 5 s client timeout.
    fn make_forward(port: u16) -> ForwardProxy {
        let timeout = std::time::Duration::from_secs(5);
        let client = reqwest::Client::builder().timeout(timeout).build().unwrap();
        let base = format!("http://127.0.0.1:{port}");
        ForwardProxy::new(client, base)
    }

    #[tokio::test]
    async fn forward_preserves_body_byte_for_byte() {
        const ROUTE: &str = "/v1/chat/completions";
        let (port, captured) =
            start_capture_stub(ROUTE, 200, "application/json", "{\"ok\":true}").await;
        let proxy = make_forward(port);

        let raw = br#"{"messages":[{"role":"user","content":"hi"}],"slot_id":3,"temperature":0.7,"tools":[{"type":"function"}],"seed":42,"stream":false}"#;
        let payload = axum::body::Bytes::copy_from_slice(raw);
        let resp = proxy
            .forward(ROUTE, "application/json", payload)
            .await
            .unwrap();
        assert_eq!(resp.status().as_u16(), 200);

        // The upstream must receive exactly the bytes the client sent.
        let received: Vec<u8> = captured.lock().await.clone();
        assert_eq!(
            &received[..],
            &raw[..],
            "le body forwardé doit être identique byte-for-byte (slot_id/tools/sampling/seed préservés)"
        );
    }

    #[tokio::test]
    async fn forward_passes_status_and_content_type() {
        let (port, _) = start_capture_stub(
            "/v1/chat/completions",
            503,
            "text/event-stream",
            "data: x\n\n",
        )
        .await;
        let proxy = make_forward(port);

        let resp = proxy
            .forward(
                "/v1/chat/completions",
                "application/json",
                axum::body::Bytes::from_static(b"{}"),
            )
            .await
            .unwrap();

        // A 503 from upstream is still Ok(..): status and headers are relayed as-is.
        assert_eq!(resp.status().as_u16(), 503, "statut upstream propagé");
        let ct = match resp.headers().get(reqwest::header::CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap_or(""),
            None => "",
        };
        assert!(
            ct.starts_with("text/event-stream"),
            "content-type upstream propagé : {ct}"
        );
    }

    #[tokio::test]
    async fn forward_connection_refused_returns_inference_error() {
        // Assumes nothing listens on port 1, so the connection attempt fails.
        let proxy = make_forward(1);
        let outcome = proxy
            .forward(
                "/v1/chat/completions",
                "application/json",
                axum::body::Bytes::from_static(b"{}"),
            )
            .await;
        assert!(
            matches!(outcome, Err(EngineError::Inference(_))),
            "connexion refusée → Inference (pas de panic)"
        );
    }
}