#![forbid(unsafe_code)]
#![warn(missing_docs)]
use axum::{
Json, Router,
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
routing::post,
};
use serde_json::Value;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Static configuration for the proxy: where to forward requests, how to
/// authenticate upstream, and which system prefix to inject.
///
/// `Debug` is implemented by hand so that `upstream_token` is never printed;
/// logging the config with `{:?}` must not leak the upstream credential.
#[derive(Clone)]
pub struct ProxyConfig {
    /// Base URL of the upstream server; the handler appends
    /// `/v1/chat/completions` to it.
    pub upstream: String,
    /// Optional token sent as a bearer `Authorization` header on forwarded
    /// requests. Redacted from `Debug` output.
    pub upstream_token: Option<String>,
    /// Text inserted as a leading `system` message into every forwarded chat
    /// request.
    pub prefix: String,
}

impl std::fmt::Debug for ProxyConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProxyConfig")
            .field("upstream", &self.upstream)
            // Only reveal whether a token is configured, never its value.
            .field(
                "upstream_token",
                &self.upstream_token.as_ref().map(|_| "<redacted>"),
            )
            .field("prefix", &self.prefix)
            .finish()
    }
}
pub type SignalSink = Arc<Mutex<Vec<Value>>>;
/// State shared by all request handlers through axum's `State` extractor,
/// which is why the type must be `Clone`.
#[derive(Clone)]
pub struct AppState {
    /// HTTP client used to forward requests to the upstream.
    pub http: reqwest::Client,
    /// Upstream location, credentials and the system prefix to inject.
    pub config: ProxyConfig,
    /// Buffer that receives an event for each forwarded request.
    pub signals: SignalSink,
}
/// Builds the proxy's HTTP router bound to `state`.
///
/// Routes:
/// * `POST /v1/chat/completions` — injects the system prefix and forwards
///   the request upstream.
/// * `GET /healthz` — health check that always answers `ok`.
pub fn router(state: AppState) -> Router {
    let completions = post(chat_completions);
    let health = axum::routing::get(|| async { "ok" });
    Router::new()
        .route("/v1/chat/completions", completions)
        .route("/healthz", health)
        .with_state(state)
}
/// Binds `addr` and serves the proxy router with `state` until the server
/// stops.
///
/// The address logged on startup is the one actually bound, so binding to
/// port `0` reports the OS-assigned port instead of `:0`.
///
/// # Errors
///
/// Returns any I/O error from binding the listener or from the server loop.
pub async fn serve(addr: SocketAddr, state: AppState) -> Result<(), std::io::Error> {
    let listener = tokio::net::TcpListener::bind(addr).await?;
    // Fall back to the requested address in the (unlikely) case the OS
    // cannot report the bound one; this is only used for logging.
    let bound = listener.local_addr().unwrap_or(addr);
    tracing::info!(addr = ?bound, "evolve-proxy listening");
    axum::serve(listener, router(state)).await
}
async fn chat_completions(
State(state): State<AppState>,
Json(mut body): Json<Value>,
) -> Result<Response, ProxyHandlerError> {
if let Some(messages) = body.get_mut("messages").and_then(|m| m.as_array_mut()) {
messages.insert(
0,
serde_json::json!({
"role": "system",
"content": state.config.prefix,
}),
);
} else {
body["messages"] = serde_json::json!([
{"role": "system", "content": state.config.prefix}
]);
}
let url = format!("{}/v1/chat/completions", state.config.upstream);
let mut req = state.http.post(&url).json(&body);
if let Some(token) = &state.config.upstream_token {
req = req.bearer_auth(token);
}
let upstream_resp = req
.send()
.await
.map_err(|e| ProxyHandlerError::Upstream(format!("forward failed: {e}")))?;
let status = upstream_resp.status();
let upstream_body_bytes = upstream_resp
.bytes()
.await
.map_err(|e| ProxyHandlerError::Upstream(format!("upstream body read failed: {e}")))?;
{
let mut sink = state.signals.lock().await;
sink.push(serde_json::json!({
"event": "proxy_request_forwarded",
"status": status.as_u16(),
"prefix_injected": true,
}));
}
let mut response = Response::new(axum::body::Body::from(upstream_body_bytes));
*response.status_mut() = axum::http::StatusCode::from_u16(status.as_u16())
.unwrap_or(axum::http::StatusCode::BAD_GATEWAY);
response
.headers_mut()
.insert("content-type", "application/json".parse().unwrap());
Ok(response)
}
#[derive(Debug)]
pub enum ProxyHandlerError {
Upstream(String),
}
impl IntoResponse for ProxyHandlerError {
fn into_response(self) -> Response {
match self {
ProxyHandlerError::Upstream(msg) => (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({"error": msg})),
)
.into_response(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::to_bytes;
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Creates an empty signal sink.
    fn fresh_sink() -> SignalSink {
        Arc::new(Mutex::new(Vec::new()))
    }

    /// Builds a router pointed at `upstream_uri` that authenticates with
    /// `test-token` and injects `prefix`; also returns the sink it records into.
    async fn setup(upstream_uri: String, prefix: &str) -> (Router, SignalSink) {
        let sink = fresh_sink();
        let state = AppState {
            config: ProxyConfig {
                upstream: upstream_uri,
                upstream_token: Some("test-token".into()),
                prefix: prefix.to_string(),
            },
            signals: sink.clone(),
            http: reqwest::Client::new(),
        };
        (router(state), sink)
    }

    /// Builds a JSON `POST /v1/chat/completions` request carrying `body`.
    fn chat_request(body: &Value) -> Request<axum::body::Body> {
        Request::builder()
            .method("POST")
            .uri("/v1/chat/completions")
            .header("content-type", "application/json")
            .body(axum::body::Body::from(serde_json::to_vec(body).unwrap()))
            .unwrap()
    }

    /// Returns the parsed JSON bodies of every request the mock upstream saw.
    async fn forwarded_bodies(upstream: &MockServer) -> Vec<Value> {
        upstream
            .received_requests()
            .await
            .expect("wiremock records requests by default")
            .iter()
            .map(|r| serde_json::from_slice(&r.body).expect("forwarded body is JSON"))
            .collect()
    }

    #[tokio::test]
    async fn healthz_returns_ok() {
        let (app, _) = setup("http://localhost:0".into(), "").await;
        let resp = app
            .oneshot(
                Request::builder()
                    .uri("/healthz")
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn forwards_and_injects_system_prefix() {
        let upstream = MockServer::start().await;
        // The header matcher makes the mock answer only if the bearer token
        // was forwarded; otherwise wiremock replies 404 and the status check fails.
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .and(header("authorization", "Bearer test-token"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string(r#"{"choices":[{"message":{"content":"pong"}}]}"#),
            )
            .expect(1)
            .mount(&upstream)
            .await;
        let (app, sink) = setup(upstream.uri(), "INJECTED PREFIX").await;
        let body = serde_json::json!({
            "model": "gpt-4",
            "messages": [{"role":"user","content":"ping"}],
        });
        let resp = app.oneshot(chat_request(&body)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        {
            let recorded = sink.lock().await;
            assert_eq!(recorded.len(), 1);
            assert_eq!(recorded[0]["event"], "proxy_request_forwarded");
            assert_eq!(recorded[0]["status"], 200);
        }
        let body_bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let text = String::from_utf8_lossy(&body_bytes);
        assert!(text.contains("pong"));

        let forwarded = forwarded_bodies(&upstream).await;
        assert_eq!(forwarded.len(), 1);
        assert_eq!(forwarded[0]["model"], "gpt-4");
        let messages = forwarded[0]["messages"].as_array().expect("messages array");
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0]["role"], "system");
        assert_eq!(messages[0]["content"], "INJECTED PREFIX");
        assert_eq!(messages[1]["role"], "user");
        assert_eq!(messages[1]["content"], "ping");
    }

    #[tokio::test]
    async fn creates_messages_when_missing() {
        let upstream = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_string("{}"))
            .mount(&upstream)
            .await;
        let (app, _) = setup(upstream.uri(), "P").await;
        let body = serde_json::json!({"model": "gpt-4"});
        let resp = app.oneshot(chat_request(&body)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let forwarded = forwarded_bodies(&upstream).await;
        assert_eq!(forwarded.len(), 1);
        assert_eq!(
            forwarded[0]["messages"],
            serde_json::json!([{"role": "system", "content": "P"}])
        );
    }

    #[tokio::test]
    async fn relays_upstream_error_status() {
        let upstream = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .respond_with(
                ResponseTemplate::new(429).set_body_string(r#"{"error":"slow down"}"#),
            )
            .mount(&upstream)
            .await;
        let (app, sink) = setup(upstream.uri(), "x").await;
        let body = serde_json::json!({"messages": [{"role":"user","content":"x"}]});
        let resp = app.oneshot(chat_request(&body)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        {
            let recorded = sink.lock().await;
            assert_eq!(recorded.len(), 1);
            assert_eq!(recorded[0]["status"], 429);
        }
        let body_bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        assert!(String::from_utf8_lossy(&body_bytes).contains("slow down"));
    }

    #[tokio::test]
    async fn returns_502_on_upstream_failure() {
        let (app, sink) = setup("http://127.0.0.1:1".into(), "x").await;
        let body = serde_json::json!({"messages": [{"role":"user","content":"x"}]});
        let resp = app.oneshot(chat_request(&body)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
        // Nothing was forwarded, so no signal should have been recorded.
        assert!(sink.lock().await.is_empty());
    }
}