Skip to main content

lean_ctx/proxy/
forward.rs

1use axum::{
2    body::Body,
3    extract::State,
4    http::{request::Parts, Request, StatusCode},
5    response::Response,
6};
7
8use super::ProxyState;
9
/// Request-body limit, in MiB, used whenever `LEAN_CTX_PROXY_MAX_BODY_MB` is
/// not set to a positive integer.
///
/// The earlier hard-coded 10 MiB cap was too small for large refactors with
/// several big files in context: the agent hit a hard `400` in the middle of
/// a task. Hence the larger default and the environment override.
const DEFAULT_MAX_BODY_MB: usize = 64;

/// Returns the maximum request body size, in bytes, that the proxy buffers.
///
/// `LEAN_CTX_PROXY_MAX_BODY_MB` is re-read on every call, and surrounding
/// whitespace is ignored. The default [`DEFAULT_MAX_BODY_MB`] is used when
/// the value is any of:
/// - missing;
/// - not valid Unicode;
/// - not an unsigned integer;
/// - zero.
///
/// The MiB-to-bytes conversion saturates at `usize::MAX` instead of
/// overflowing.
fn max_body_bytes() -> usize {
    const BYTES_PER_MIB: usize = 1024 * 1024;

    let configured_mb = match std::env::var("LEAN_CTX_PROXY_MAX_BODY_MB") {
        Ok(raw) => match raw.trim().parse::<usize>() {
            Ok(mb) if mb > 0 => Some(mb),
            _ => None,
        },
        Err(_) => None,
    };

    configured_mb
        .unwrap_or(DEFAULT_MAX_BODY_MB)
        .saturating_mul(BYTES_PER_MIB)
}

/// Provider-specific body rewriter.
///
/// Takes the raw request body and returns
/// `(rewritten_body, original_size, rewritten_size)`. The sizes are
/// presumably byte lengths; confirm against the implementations.
pub type CompressFn = fn(&[u8]) -> (Vec<u8>, usize, usize);
26
27pub async fn forward_request(
28    State(state): State<ProxyState>,
29    req: Request<Body>,
30    upstream_base: &str,
31    default_path: &str,
32    compress_body: CompressFn,
33    provider_label: &str,
34    extra_stream_types: &[&str],
35) -> Result<Response, StatusCode> {
36    let (parts, body) = req.into_parts();
37    let body_bytes = axum::body::to_bytes(body, max_body_bytes())
38        .await
39        .map_err(|_| StatusCode::PAYLOAD_TOO_LARGE)?;
40
41    state.stats.record_request();
42
43    // Parse once; reuse for introspection and per-model cost attribution.
44    let parsed = serde_json::from_slice::<serde_json::Value>(&body_bytes).ok();
45    if let Some(parsed) = &parsed {
46        let provider = match provider_label {
47            "Anthropic" => super::introspect::Provider::Anthropic,
48            "OpenAI" => super::introspect::Provider::OpenAi,
49            _ => super::introspect::Provider::Gemini,
50        };
51        let breakdown = super::introspect::analyze_request(parsed, provider);
52        state.introspect.record(breakdown);
53    }
54
55    let (compressed_body, original_size, compressed_size) = compress_body(&body_bytes);
56
57    if compressed_size < original_size {
58        state
59            .stats
60            .record_compression(original_size, compressed_size);
61    }
62
63    let tokens_saved = original_size.saturating_sub(compressed_size) as u64 / 4;
64    super::metrics::record_request(tokens_saved, compressed_size as u64);
65
66    let model = parsed
67        .as_ref()
68        .and_then(|v| v.get("model"))
69        .and_then(|m| m.as_str());
70    super::cost::record(
71        model,
72        tokens_saved,
73        original_size as u64,
74        compressed_size as u64,
75    );
76
77    let upstream_url = build_upstream_url(&parts, upstream_base, default_path);
78    let response = send_upstream(
79        &state,
80        &parts,
81        &upstream_url,
82        compressed_body,
83        provider_label,
84    )
85    .await?;
86
87    build_response(response, extra_stream_types).await
88}
89
90fn build_upstream_url(parts: &Parts, base: &str, default_path: &str) -> String {
91    format!(
92        "{base}{}",
93        parts
94            .uri
95            .path_and_query()
96            .map_or(default_path, axum::http::uri::PathAndQuery::as_str)
97    )
98}
99
100async fn send_upstream(
101    state: &ProxyState,
102    parts: &Parts,
103    url: &str,
104    body: Vec<u8>,
105    provider_label: &str,
106) -> Result<reqwest::Response, StatusCode> {
107    let mut req = state.client.request(parts.method.clone(), url);
108
109    const ALLOWED_HEADERS: &[&str] = &[
110        "authorization",
111        "x-api-key",
112        "content-type",
113        "accept",
114        "user-agent",
115        "anthropic-version",
116        "anthropic-beta",
117        "anthropic-dangerous-direct-browser-access",
118        "openai-organization",
119        "openai-beta",
120        "x-goog-api-key",
121        "x-goog-api-client",
122    ];
123    for (key, value) in &parts.headers {
124        let k = key.as_str().to_lowercase();
125        if ALLOWED_HEADERS.contains(&k.as_str()) {
126            req = req.header(key.clone(), value.clone());
127        }
128    }
129
130    req.body(body).send().await.map_err(|e| {
131        tracing::error!("lean-ctx proxy: {provider_label} upstream error: {e}");
132        StatusCode::BAD_GATEWAY
133    })
134}
135
/// Upstream response headers relayed back to the client; all others are
/// dropped.
///
/// Entries must be lowercase, because header names are matched in
/// lowercase. Each rate-limit family carries its `*-reset` header alongside
/// `*-limit` / `*-remaining`, so clients see when quota returns. Anthropic's
/// `request-id` is kept next to OpenAI-style `x-request-id` for debugging.
const FORWARDED_HEADERS: &[&str] = &[
    "content-type",
    "content-encoding",
    "cache-control",
    "retry-after",
    // Request identifiers.
    "x-request-id",
    "request-id",
    // OpenAI.
    "openai-organization",
    "openai-processing-ms",
    "openai-version",
    "x-ratelimit-limit-requests",
    "x-ratelimit-remaining-requests",
    "x-ratelimit-reset-requests",
    "x-ratelimit-limit-tokens",
    "x-ratelimit-remaining-tokens",
    "x-ratelimit-reset-tokens",
    // Anthropic.
    "anthropic-ratelimit-requests-limit",
    "anthropic-ratelimit-requests-remaining",
    "anthropic-ratelimit-requests-reset",
    "anthropic-ratelimit-tokens-limit",
    "anthropic-ratelimit-tokens-remaining",
    "anthropic-ratelimit-tokens-reset",
];
154
155async fn build_response(
156    response: reqwest::Response,
157    extra_stream_types: &[&str],
158) -> Result<Response, StatusCode> {
159    let status = StatusCode::from_u16(response.status().as_u16()).unwrap_or(StatusCode::OK);
160    let resp_headers = response.headers().clone();
161
162    let is_stream = resp_headers
163        .get("content-type")
164        .and_then(|v| v.to_str().ok())
165        .is_some_and(|ct| {
166            ct.contains("text/event-stream") || extra_stream_types.iter().any(|t| ct.contains(t))
167        });
168
169    if is_stream {
170        let stream = response.bytes_stream();
171        let body = Body::from_stream(stream);
172        let mut resp = Response::builder().status(status);
173        for (k, v) in &resp_headers {
174            let ks = k.as_str().to_lowercase();
175            if FORWARDED_HEADERS.contains(&ks.as_str()) {
176                resp = resp.header(k, v);
177            }
178        }
179        return resp
180            .body(body)
181            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR);
182    }
183
184    let resp_bytes = response
185        .bytes()
186        .await
187        .map_err(|_| StatusCode::BAD_GATEWAY)?;
188
189    let mut resp = Response::builder().status(status);
190    for (k, v) in &resp_headers {
191        let ks = k.as_str().to_lowercase();
192        if FORWARDED_HEADERS.contains(&ks.as_str()) {
193            resp = resp.header(k, v);
194        }
195    }
196    resp.body(Body::from(resp_bytes))
197        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    const ANTHROPIC_BASE: &str = "https://api.anthropic.com";
    const MESSAGES_PATH: &str = "/v1/messages";

    /// Resolves the upstream URL for a client request to `uri`, using the
    /// Anthropic base and `/v1/messages` as the fallback path.
    fn upstream_for(uri: &str) -> String {
        let (parts, ()) = Request::builder()
            .uri(uri)
            .body(())
            .unwrap()
            .into_parts();
        build_upstream_url(&parts, ANTHROPIC_BASE, MESSAGES_PATH)
    }

    #[test]
    fn upstream_url_preserves_subpath() {
        let url = upstream_for("/v1/messages/count_tokens");
        assert_eq!(url, "https://api.anthropic.com/v1/messages/count_tokens");
    }

    #[test]
    fn upstream_url_preserves_batches_subpath() {
        let url = upstream_for("/v1/messages/batches/batch_123/results");
        assert_eq!(
            url,
            "https://api.anthropic.com/v1/messages/batches/batch_123/results"
        );
    }

    #[test]
    fn upstream_url_exact_path() {
        let url = upstream_for("/v1/messages");
        assert_eq!(url, "https://api.anthropic.com/v1/messages");
    }

    #[test]
    fn upstream_url_preserves_query_params() {
        let url = upstream_for("/v1/messages/count_tokens?model=claude-4");
        assert_eq!(
            url,
            "https://api.anthropic.com/v1/messages/count_tokens?model=claude-4"
        );
    }
}