1#![forbid(unsafe_code)]
10#![warn(missing_docs)]
11
12use axum::{
13 Json, Router,
14 extract::State,
15 http::StatusCode,
16 response::{IntoResponse, Response},
17 routing::post,
18};
19use serde_json::Value;
20use std::net::SocketAddr;
21use std::sync::Arc;
22use tokio::sync::Mutex;
23
/// Static configuration for the proxy.
///
/// `Debug` is implemented by hand so that the upstream credential is never
/// printed (e.g. when the config is logged with `?config`).
#[derive(Clone)]
pub struct ProxyConfig {
    /// Base URL of the upstream server; the chat handler appends
    /// `/v1/chat/completions` to it.
    pub upstream: String,
    /// Optional token sent to the upstream as a bearer `Authorization` header.
    pub upstream_token: Option<String>,
    /// Text injected as a leading `system` message into every forwarded chat request.
    pub prefix: String,
}

impl std::fmt::Debug for ProxyConfig {
    /// Formats like a derived `Debug`, except a present token is shown as
    /// `"<redacted>"` so secrets do not leak into logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProxyConfig")
            .field("upstream", &self.upstream)
            .field(
                "upstream_token",
                &self.upstream_token.as_ref().map(|_| "<redacted>"),
            )
            .field("prefix", &self.prefix)
            .finish()
    }
}
34
35pub type SignalSink = Arc<Mutex<Vec<Value>>>;
37
38#[derive(Clone)]
40pub struct AppState {
41 pub config: ProxyConfig,
43 pub signals: SignalSink,
45 pub http: reqwest::Client,
47}
48
49pub fn router(state: AppState) -> Router {
51 Router::new()
52 .route("/v1/chat/completions", post(chat_completions))
53 .route("/healthz", axum::routing::get(|| async { "ok" }))
54 .with_state(state)
55}
56
57pub async fn serve(addr: SocketAddr, state: AppState) -> Result<(), std::io::Error> {
59 let app = router(state);
60 let listener = tokio::net::TcpListener::bind(addr).await?;
61 tracing::info!(?addr, "evolve-proxy listening");
62 axum::serve(listener, app).await
63}
64
65async fn chat_completions(
66 State(state): State<AppState>,
67 Json(mut body): Json<Value>,
68) -> Result<Response, ProxyHandlerError> {
69 if let Some(messages) = body.get_mut("messages").and_then(|m| m.as_array_mut()) {
71 messages.insert(
72 0,
73 serde_json::json!({
74 "role": "system",
75 "content": state.config.prefix,
76 }),
77 );
78 } else {
79 body["messages"] = serde_json::json!([
80 {"role": "system", "content": state.config.prefix}
81 ]);
82 }
83
84 let url = format!("{}/v1/chat/completions", state.config.upstream);
85 let mut req = state.http.post(&url).json(&body);
86 if let Some(token) = &state.config.upstream_token {
87 req = req.bearer_auth(token);
88 }
89 let upstream_resp = req
90 .send()
91 .await
92 .map_err(|e| ProxyHandlerError::Upstream(format!("forward failed: {e}")))?;
93 let status = upstream_resp.status();
94 let upstream_body_bytes = upstream_resp
95 .bytes()
96 .await
97 .map_err(|e| ProxyHandlerError::Upstream(format!("upstream body read failed: {e}")))?;
98
99 {
101 let mut sink = state.signals.lock().await;
102 sink.push(serde_json::json!({
103 "event": "proxy_request_forwarded",
104 "status": status.as_u16(),
105 "prefix_injected": true,
106 }));
107 }
108
109 let mut response = Response::new(axum::body::Body::from(upstream_body_bytes));
110 *response.status_mut() = axum::http::StatusCode::from_u16(status.as_u16())
111 .unwrap_or(axum::http::StatusCode::BAD_GATEWAY);
112 response
113 .headers_mut()
114 .insert("content-type", "application/json".parse().unwrap());
115 Ok(response)
116}
117
/// Errors the chat-completions handler can return; converted into an HTTP
/// response by its `IntoResponse` implementation.
#[derive(Debug)]
pub enum ProxyHandlerError {
    /// Sending the request to the upstream, or reading its response body, failed.
    /// The string is a human-readable description that includes the underlying cause.
    Upstream(String),
}

impl std::fmt::Display for ProxyHandlerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Upstream(msg) => write!(f, "upstream error: {msg}"),
        }
    }
}

impl std::error::Error for ProxyHandlerError {}
124
125impl IntoResponse for ProxyHandlerError {
126 fn into_response(self) -> Response {
127 match self {
128 ProxyHandlerError::Upstream(msg) => (
129 StatusCode::BAD_GATEWAY,
130 Json(serde_json::json!({"error": msg})),
131 )
132 .into_response(),
133 }
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::to_bytes;
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a router pointed at `upstream_uri` plus a handle on its (empty) signal sink.
    fn setup(upstream_uri: String, prefix: &str) -> (Router, SignalSink) {
        let sink: SignalSink = Arc::new(Mutex::new(Vec::new()));
        let state = AppState {
            config: ProxyConfig {
                upstream: upstream_uri,
                upstream_token: Some("test-token".into()),
                prefix: prefix.to_string(),
            },
            signals: Arc::clone(&sink),
            http: reqwest::Client::new(),
        };
        (router(state), sink)
    }

    /// Builds a JSON `POST /v1/chat/completions` request carrying `payload`.
    fn chat_request(payload: &Value) -> Request<axum::body::Body> {
        Request::builder()
            .method("POST")
            .uri("/v1/chat/completions")
            .header("content-type", "application/json")
            .body(axum::body::Body::from(serde_json::to_vec(payload).unwrap()))
            .unwrap()
    }

    #[tokio::test]
    async fn healthz_returns_ok() {
        let (app, _) = setup("http://localhost:0".into(), "");
        let request = Request::builder()
            .uri("/healthz")
            .body(axum::body::Body::empty())
            .unwrap();
        let resp = app.oneshot(request).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn forwards_and_injects_system_prefix() {
        let upstream = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string(r#"{"choices":[{"message":{"content":"pong"}}]}"#),
            )
            .mount(&upstream)
            .await;

        let (app, sink) = setup(upstream.uri(), "INJECTED PREFIX");
        let payload = serde_json::json!({
            "model": "gpt-4",
            "messages": [{"role":"user","content":"ping"}],
        });
        let resp = app.oneshot(chat_request(&payload)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        // The test name promises injection, so verify what the upstream actually received.
        let received = upstream
            .received_requests()
            .await
            .expect("request recording is enabled by default");
        assert_eq!(received.len(), 1);
        let forwarded: Value = serde_json::from_slice(&received[0].body).unwrap();
        assert_eq!(forwarded["messages"][0]["role"], "system");
        assert_eq!(forwarded["messages"][0]["content"], "INJECTED PREFIX");
        assert_eq!(forwarded["messages"][1]["content"], "ping");

        {
            let recorded = sink.lock().await;
            assert_eq!(recorded.len(), 1);
            assert_eq!(recorded[0]["event"], "proxy_request_forwarded");
            assert_eq!(recorded[0]["status"], 200);
        }

        let body_bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        assert!(String::from_utf8_lossy(&body_bytes).contains("pong"));
    }

    #[tokio::test]
    async fn returns_502_on_upstream_failure() {
        // Port 1 is expected to refuse connections, forcing the send to fail.
        let (app, _) = setup("http://127.0.0.1:1".into(), "x");
        let payload = serde_json::json!({"messages": [{"role":"user","content":"x"}]});
        let resp = app.oneshot(chat_request(&payload)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
    }
}