Skip to main content

evolve_proxy/
lib.rs

1//! OpenAI-compat HTTP proxy.
2//!
3//! Started via `evolve proxy --for cursor`. Sits between Cursor (or similar)
4//! and the real upstream provider. On each request:
5//!   1. Inject the active config's `system_prompt_prefix` into the `messages`.
6//!   2. Forward to upstream.
7//!   3. Emit a signal candidate describing the interaction.
8
9#![forbid(unsafe_code)]
10#![warn(missing_docs)]
11
12use axum::{
13    Json, Router,
14    extract::State,
15    http::StatusCode,
16    response::{IntoResponse, Response},
17    routing::post,
18};
19use serde_json::Value;
20use std::net::SocketAddr;
21use std::sync::Arc;
22use tokio::sync::Mutex;
23
/// Proxy configuration.
///
/// `Debug` is implemented by hand rather than derived so that the upstream
/// bearer token never ends up in logs or panic messages: only whether a
/// token is present is shown.
#[derive(Clone)]
pub struct ProxyConfig {
    /// Upstream base URL, e.g. `https://api.openai.com`.
    pub upstream: String,
    /// Auth token to forward to upstream (Bearer token).
    pub upstream_token: Option<String>,
    /// System prompt prefix to inject into every request.
    pub prefix: String,
}

impl std::fmt::Debug for ProxyConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the `Some`/`None` shape visible while hiding the secret itself.
        let token = self.upstream_token.as_ref().map(|_| "<redacted>");
        f.debug_struct("ProxyConfig")
            .field("upstream", &self.upstream)
            .field("upstream_token", &token)
            .field("prefix", &self.prefix)
            .finish()
    }
}
34
/// Handle for emitting signal events from the proxy.
///
/// A reference-counted, async-mutex-guarded list of JSON values. Cloning the
/// handle shares the same underlying list; producers lock it and push one
/// `Value` per event.
pub type SignalSink = Arc<Mutex<Vec<Value>>>;
37
/// Application state shared by all handlers.
///
/// The struct is `Clone`; it is handed to the router via `with_state`.
/// Cloning shares the same signal sink, since that field is an `Arc`.
#[derive(Clone)]
pub struct AppState {
    /// Current config (swap-able at runtime in future).
    pub config: ProxyConfig,
    /// Sink where proxy events land — the CLI later flushes these to storage.
    pub signals: SignalSink,
    /// Reqwest client for upstream forwarding.
    pub http: reqwest::Client,
}
48
49/// Build the axum router.
50pub fn router(state: AppState) -> Router {
51    Router::new()
52        .route("/v1/chat/completions", post(chat_completions))
53        .route("/healthz", axum::routing::get(|| async { "ok" }))
54        .with_state(state)
55}
56
57/// Bind a listener and serve until the OS signals shutdown.
58pub async fn serve(addr: SocketAddr, state: AppState) -> Result<(), std::io::Error> {
59    let app = router(state);
60    let listener = tokio::net::TcpListener::bind(addr).await?;
61    tracing::info!(?addr, "evolve-proxy listening");
62    axum::serve(listener, app).await
63}
64
65async fn chat_completions(
66    State(state): State<AppState>,
67    Json(mut body): Json<Value>,
68) -> Result<Response, ProxyHandlerError> {
69    // Inject the system prefix as a prepended system message.
70    if let Some(messages) = body.get_mut("messages").and_then(|m| m.as_array_mut()) {
71        messages.insert(
72            0,
73            serde_json::json!({
74                "role": "system",
75                "content": state.config.prefix,
76            }),
77        );
78    } else {
79        body["messages"] = serde_json::json!([
80            {"role": "system", "content": state.config.prefix}
81        ]);
82    }
83
84    let url = format!("{}/v1/chat/completions", state.config.upstream);
85    let mut req = state.http.post(&url).json(&body);
86    if let Some(token) = &state.config.upstream_token {
87        req = req.bearer_auth(token);
88    }
89    let upstream_resp = req
90        .send()
91        .await
92        .map_err(|e| ProxyHandlerError::Upstream(format!("forward failed: {e}")))?;
93    let status = upstream_resp.status();
94    let upstream_body_bytes = upstream_resp
95        .bytes()
96        .await
97        .map_err(|e| ProxyHandlerError::Upstream(format!("upstream body read failed: {e}")))?;
98
99    // Emit a signal candidate. (CLI flushes these.)
100    {
101        let mut sink = state.signals.lock().await;
102        sink.push(serde_json::json!({
103            "event": "proxy_request_forwarded",
104            "status": status.as_u16(),
105            "prefix_injected": true,
106        }));
107    }
108
109    let mut response = Response::new(axum::body::Body::from(upstream_body_bytes));
110    *response.status_mut() = axum::http::StatusCode::from_u16(status.as_u16())
111        .unwrap_or(axum::http::StatusCode::BAD_GATEWAY);
112    response
113        .headers_mut()
114        .insert("content-type", "application/json".parse().unwrap());
115    Ok(response)
116}
117
/// Internal handler error.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// logged directly or converted into boxed/dynamic error types with `?`.
#[derive(Debug)]
pub enum ProxyHandlerError {
    /// Upstream-side failure.
    Upstream(String),
}

impl std::fmt::Display for ProxyHandlerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The message already describes the failure; print it verbatim.
            ProxyHandlerError::Upstream(msg) => f.write_str(msg),
        }
    }
}

impl std::error::Error for ProxyHandlerError {}
124
125impl IntoResponse for ProxyHandlerError {
126    fn into_response(self) -> Response {
127        match self {
128            ProxyHandlerError::Upstream(msg) => (
129                StatusCode::BAD_GATEWAY,
130                Json(serde_json::json!({"error": msg})),
131            )
132                .into_response(),
133        }
134    }
135}
136
#[cfg(test)]
mod tests {
    //! Router-level tests: requests are driven through the axum `Router`
    //! with `oneshot`, and the upstream is a `wiremock` server (or a closed
    //! port for the failure case).

    use super::*;
    use axum::body::{Body, to_bytes};
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Empty signal sink for a single test.
    fn fresh_sink() -> SignalSink {
        Arc::new(Mutex::new(Vec::new()))
    }

    /// Build a router pointed at `upstream_uri` with the given prefix, and
    /// return it together with the sink it writes events to.
    async fn setup(upstream_uri: String, prefix: &str) -> (Router, SignalSink) {
        let sink = fresh_sink();
        let state = AppState {
            config: ProxyConfig {
                upstream: upstream_uri,
                upstream_token: Some("test-token".into()),
                prefix: prefix.to_string(),
            },
            signals: sink.clone(),
            http: reqwest::Client::new(),
        };
        (router(state), sink)
    }

    /// JSON `POST /v1/chat/completions` request carrying `payload`.
    fn post_completions(payload: &Value) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri("/v1/chat/completions")
            .header("content-type", "application/json")
            .body(Body::from(serde_json::to_vec(payload).unwrap()))
            .unwrap()
    }

    #[tokio::test]
    async fn healthz_returns_ok() {
        let (app, _) = setup("http://localhost:0".into(), "").await;
        let resp = app
            .oneshot(
                Request::builder()
                    .uri("/healthz")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn forwards_and_injects_system_prefix() {
        let upstream = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string(r#"{"choices":[{"message":{"content":"pong"}}]}"#),
            )
            .mount(&upstream)
            .await;

        let (app, sink) = setup(upstream.uri(), "INJECTED PREFIX").await;
        let body = serde_json::json!({
            "model": "gpt-4",
            "messages": [{"role":"user","content":"ping"}],
        });
        let resp = app.oneshot(post_completions(&body)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        // Check the proxy recorded a forwarded event.
        {
            let recorded = sink.lock().await;
            assert_eq!(recorded.len(), 1);
            assert_eq!(recorded[0]["event"], "proxy_request_forwarded");
            assert_eq!(recorded[0]["status"], 200);
        }

        // Check what the upstream actually received: the system prefix must
        // come first, followed by the caller's original message.
        let received = upstream
            .received_requests()
            .await
            .expect("wiremock records requests by default");
        assert_eq!(received.len(), 1);
        let sent: Value = serde_json::from_slice(&received[0].body).unwrap();
        assert_eq!(sent["model"], "gpt-4");
        assert_eq!(sent["messages"].as_array().map(Vec::len), Some(2));
        assert_eq!(sent["messages"][0]["role"], "system");
        assert_eq!(sent["messages"][0]["content"], "INJECTED PREFIX");
        assert_eq!(sent["messages"][1]["role"], "user");
        assert_eq!(sent["messages"][1]["content"], "ping");

        // Sanity: body contains upstream's output
        let body_bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let text = String::from_utf8_lossy(&body_bytes);
        assert!(text.contains("pong"));
    }

    #[tokio::test]
    async fn returns_502_on_upstream_failure() {
        // Use a known-closed port for upstream.
        let (app, _) = setup("http://127.0.0.1:1".into(), "x").await;
        let body = serde_json::json!({"messages": [{"role":"user","content":"x"}]});
        let resp = app.oneshot(post_completions(&body)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
    }
}