Skip to main content

lean_ctx/proxy/
forward.rs

1use axum::{
2    body::Body,
3    extract::State,
4    http::{request::Parts, Request, StatusCode},
5    response::Response,
6};
7
8use super::ProxyState;
9
/// Fallback limit, in MiB, for the size of a request body the proxy buffers.
///
/// The previous cap of 10 MiB was too small for large-codebase refactors that
/// carry several big files in context; hitting it surfaced to the agent as a
/// hard `400` in the middle of a task. The limit was therefore raised and can
/// be overridden through the `LEAN_CTX_PROXY_MAX_BODY_MB` environment variable.
const DEFAULT_MAX_BODY_MB: usize = 64;

/// Resolves the request-body limit in bytes.
///
/// `LEAN_CTX_PROXY_MAX_BODY_MB` wins when it holds a positive integer
/// (surrounding whitespace is ignored). A missing, unparsable, or zero value
/// falls back to [`DEFAULT_MAX_BODY_MB`]. The MiB-to-bytes conversion saturates
/// at `usize::MAX` instead of overflowing.
fn max_body_bytes() -> usize {
    let megabytes = match std::env::var("LEAN_CTX_PROXY_MAX_BODY_MB") {
        Ok(raw) => match raw.trim().parse::<usize>() {
            Ok(mb) if mb > 0 => mb,
            // Zero or garbage: treat the override as absent.
            _ => DEFAULT_MAX_BODY_MB,
        },
        Err(_) => DEFAULT_MAX_BODY_MB,
    };
    megabytes.saturating_mul(1024 * 1024)
}
24
25/// Receives the already-parsed JSON value, avoiding a redundant
26/// `serde_json::from_slice` on every request. Returns the serialized (possibly
27/// compressed) body, original size, and compressed size.
28pub type CompressFn = fn(serde_json::Value, usize) -> (Vec<u8>, usize, usize);
29
30pub async fn forward_request(
31    State(state): State<ProxyState>,
32    req: Request<Body>,
33    upstream_base: &str,
34    default_path: &str,
35    compress_body: CompressFn,
36    provider_label: &str,
37    extra_stream_types: &[&str],
38) -> Result<Response, StatusCode> {
39    let (parts, body) = req.into_parts();
40    let body_bytes = axum::body::to_bytes(body, max_body_bytes())
41        .await
42        .map_err(|_| StatusCode::PAYLOAD_TOO_LARGE)?;
43
44    state.stats.record_request();
45
46    let original_size = body_bytes.len();
47
48    // Parse once; the parsed value is shared between introspection, cost
49    // attribution, and compression — eliminating the redundant re-parse that
50    // each compress_body function previously performed internally.
51    let parsed = serde_json::from_slice::<serde_json::Value>(&body_bytes).ok();
52    if let Some(ref parsed) = parsed {
53        let provider = match provider_label {
54            "Anthropic" => super::introspect::Provider::Anthropic,
55            "OpenAI" => super::introspect::Provider::OpenAi,
56            _ => super::introspect::Provider::Gemini,
57        };
58        let breakdown = super::introspect::analyze_request(parsed, provider);
59        state.introspect.record(breakdown);
60    }
61
62    let (compressed_body, _, compressed_size) = if let Some(value) = parsed.clone() {
63        compress_body(value, original_size)
64    } else {
65        (body_bytes.to_vec(), original_size, original_size)
66    };
67
68    if compressed_size < original_size {
69        state
70            .stats
71            .record_compression(original_size, compressed_size);
72    }
73
74    let tokens_saved = original_size.saturating_sub(compressed_size) as u64 / 4;
75    super::metrics::record_request(tokens_saved, compressed_size as u64);
76
77    let model = parsed
78        .as_ref()
79        .and_then(|v| v.get("model"))
80        .and_then(|m| m.as_str());
81    super::cost::record(
82        model,
83        tokens_saved,
84        original_size as u64,
85        compressed_size as u64,
86    );
87
88    let upstream_url = build_upstream_url(&parts, upstream_base, default_path);
89    let response = send_upstream(
90        &state,
91        &parts,
92        &upstream_url,
93        compressed_body,
94        provider_label,
95    )
96    .await?;
97
98    build_response(response, extra_stream_types).await
99}
100
101fn build_upstream_url(parts: &Parts, base: &str, default_path: &str) -> String {
102    format!(
103        "{base}{}",
104        parts
105            .uri
106            .path_and_query()
107            .map_or(default_path, axum::http::uri::PathAndQuery::as_str)
108    )
109}
110
111async fn send_upstream(
112    state: &ProxyState,
113    parts: &Parts,
114    url: &str,
115    body: Vec<u8>,
116    provider_label: &str,
117) -> Result<reqwest::Response, StatusCode> {
118    let mut req = state.client.request(parts.method.clone(), url);
119
120    const ALLOWED_HEADERS: &[&str] = &[
121        "authorization",
122        "x-api-key",
123        "content-type",
124        "accept",
125        "user-agent",
126        "anthropic-version",
127        "anthropic-beta",
128        "anthropic-dangerous-direct-browser-access",
129        "openai-organization",
130        "openai-beta",
131        "x-goog-api-key",
132        "x-goog-api-client",
133    ];
134    for (key, value) in &parts.headers {
135        let k = key.as_str().to_lowercase();
136        if ALLOWED_HEADERS.contains(&k.as_str()) {
137            req = req.header(key.clone(), value.clone());
138        }
139    }
140
141    req.body(body).send().await.map_err(|e| {
142        tracing::error!("lean-ctx proxy: {provider_label} upstream error: {e}");
143        StatusCode::BAD_GATEWAY
144    })
145}
146
/// Upstream response headers that `build_response` copies onto the proxied
/// response; any header not listed here is dropped.
///
/// Entries must be lower case: they are matched against lower-case header
/// names with an exact string comparison.
const FORWARDED_HEADERS: &[&str] = &[
    // Body description and request correlation.
    "content-type",
    "content-encoding",
    "x-request-id",
    // OpenAI-specific metadata.
    "openai-organization",
    "openai-processing-ms",
    "openai-version",
    // Anthropic rate-limit reporting.
    "anthropic-ratelimit-requests-limit",
    "anthropic-ratelimit-requests-remaining",
    "anthropic-ratelimit-tokens-limit",
    "anthropic-ratelimit-tokens-remaining",
    // Back-off hint and `x-ratelimit-*` rate-limit reporting.
    "retry-after",
    "x-ratelimit-limit-requests",
    "x-ratelimit-remaining-requests",
    "x-ratelimit-limit-tokens",
    "x-ratelimit-remaining-tokens",
    // Caching directives.
    "cache-control",
];
165
166async fn build_response(
167    response: reqwest::Response,
168    extra_stream_types: &[&str],
169) -> Result<Response, StatusCode> {
170    let status = StatusCode::from_u16(response.status().as_u16()).unwrap_or(StatusCode::OK);
171    let resp_headers = response.headers().clone();
172
173    let is_stream = resp_headers
174        .get("content-type")
175        .and_then(|v| v.to_str().ok())
176        .is_some_and(|ct| {
177            ct.contains("text/event-stream") || extra_stream_types.iter().any(|t| ct.contains(t))
178        });
179
180    if is_stream {
181        let stream = response.bytes_stream();
182        let body = Body::from_stream(stream);
183        let mut resp = Response::builder().status(status);
184        for (k, v) in &resp_headers {
185            let ks = k.as_str().to_lowercase();
186            if FORWARDED_HEADERS.contains(&ks.as_str()) {
187                resp = resp.header(k, v);
188            }
189        }
190        return resp
191            .body(body)
192            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR);
193    }
194
195    let resp_bytes = response
196        .bytes()
197        .await
198        .map_err(|_| StatusCode::BAD_GATEWAY)?;
199
200    let mut resp = Response::builder().status(status);
201    for (k, v) in &resp_headers {
202        let ks = k.as_str().to_lowercase();
203        if FORWARDED_HEADERS.contains(&ks.as_str()) {
204            resp = resp.header(k, v);
205        }
206    }
207    resp.body(Body::from(resp_bytes))
208        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Upstream base shared by every URL test.
    const BASE: &str = "https://api.anthropic.com";
    /// Fallback path handed to `build_upstream_url` in every URL test.
    const DEFAULT_PATH: &str = "/v1/messages";

    /// Builds request `Parts` whose URI is `uri`.
    fn parts_for(uri: &str) -> Parts {
        let (parts, ()) = Request::builder().uri(uri).body(()).unwrap().into_parts();
        parts
    }

    /// Upstream URL produced for a request to `uri` with the shared fixtures.
    fn upstream_for(uri: &str) -> String {
        build_upstream_url(&parts_for(uri), BASE, DEFAULT_PATH)
    }

    #[test]
    fn upstream_url_preserves_subpath() {
        assert_eq!(
            upstream_for("/v1/messages/count_tokens"),
            "https://api.anthropic.com/v1/messages/count_tokens"
        );
    }

    #[test]
    fn upstream_url_preserves_batches_subpath() {
        assert_eq!(
            upstream_for("/v1/messages/batches/batch_123/results"),
            "https://api.anthropic.com/v1/messages/batches/batch_123/results"
        );
    }

    #[test]
    fn upstream_url_exact_path() {
        assert_eq!(
            upstream_for("/v1/messages"),
            "https://api.anthropic.com/v1/messages"
        );
    }

    #[test]
    fn upstream_url_preserves_query_params() {
        assert_eq!(
            upstream_for("/v1/messages/count_tokens?model=claude-4"),
            "https://api.anthropic.com/v1/messages/count_tokens?model=claude-4"
        );
    }
}