Skip to main content

lean_ctx/proxy/
forward.rs

1use axum::{
2    body::Body,
3    extract::State,
4    http::{request::Parts, Request, StatusCode},
5    response::Response,
6};
7
8use super::ProxyState;
9
/// Upper bound, in bytes, on the inbound request body that `forward_request`
/// will buffer (10 MiB).
const MAX_BODY_BYTES: usize = 10 << 20;

/// Signature of a provider-specific body compressor.
///
/// Takes the raw request body and returns a tuple that `forward_request`
/// destructures as `(body_to_send, original_size, compressed_size)`.
pub type CompressFn = fn(&[u8]) -> (Vec<u8>, usize, usize);
13
14pub async fn forward_request(
15    State(state): State<ProxyState>,
16    req: Request<Body>,
17    upstream_base: &str,
18    default_path: &str,
19    compress_body: CompressFn,
20    provider_label: &str,
21    extra_stream_types: &[&str],
22) -> Result<Response, StatusCode> {
23    let (parts, body) = req.into_parts();
24    let body_bytes = axum::body::to_bytes(body, MAX_BODY_BYTES)
25        .await
26        .map_err(|_| StatusCode::BAD_REQUEST)?;
27
28    state.stats.record_request();
29
30    // Introspect the request for context analysis
31    if let Ok(parsed) = serde_json::from_slice::<serde_json::Value>(&body_bytes) {
32        let provider = match provider_label {
33            "Anthropic" => super::introspect::Provider::Anthropic,
34            "OpenAI" => super::introspect::Provider::OpenAi,
35            _ => super::introspect::Provider::Gemini,
36        };
37        let breakdown = super::introspect::analyze_request(&parsed, provider);
38        state.introspect.record(breakdown);
39    }
40
41    let (compressed_body, original_size, compressed_size) = compress_body(&body_bytes);
42
43    if compressed_size < original_size {
44        state
45            .stats
46            .record_compression(original_size, compressed_size);
47    }
48
49    let tokens_saved = original_size.saturating_sub(compressed_size) as u64 / 4;
50    super::metrics::record_request(tokens_saved, compressed_size as u64);
51
52    let upstream_url = build_upstream_url(&parts, upstream_base, default_path);
53    let response = send_upstream(
54        &state,
55        &parts,
56        &upstream_url,
57        compressed_body,
58        provider_label,
59    )
60    .await?;
61
62    build_response(response, extra_stream_types).await
63}
64
65fn build_upstream_url(parts: &Parts, base: &str, default_path: &str) -> String {
66    format!(
67        "{base}{}",
68        parts
69            .uri
70            .path_and_query()
71            .map_or(default_path, axum::http::uri::PathAndQuery::as_str)
72    )
73}
74
75async fn send_upstream(
76    state: &ProxyState,
77    parts: &Parts,
78    url: &str,
79    body: Vec<u8>,
80    provider_label: &str,
81) -> Result<reqwest::Response, StatusCode> {
82    let mut req = state.client.request(parts.method.clone(), url);
83
84    const ALLOWED_HEADERS: &[&str] = &[
85        "authorization",
86        "x-api-key",
87        "content-type",
88        "accept",
89        "user-agent",
90        "anthropic-version",
91        "anthropic-beta",
92        "anthropic-dangerous-direct-browser-access",
93        "openai-organization",
94        "openai-beta",
95        "x-goog-api-key",
96        "x-goog-api-client",
97    ];
98    for (key, value) in &parts.headers {
99        let k = key.as_str().to_lowercase();
100        if ALLOWED_HEADERS.contains(&k.as_str()) {
101            req = req.header(key.clone(), value.clone());
102        }
103    }
104
105    req.body(body).send().await.map_err(|e| {
106        tracing::error!("lean-ctx proxy: {provider_label} upstream error: {e}");
107        StatusCode::BAD_GATEWAY
108    })
109}
110
/// Upstream response headers that `build_response` relays to the client.
///
/// Entries must be lowercase: they are matched against the names of the
/// upstream response headers. Order is irrelevant (membership test only).
///
/// The rate-limit `*-reset` headers and Anthropic's `request-id` are included
/// alongside the `limit`/`remaining` and `x-request-id` entries so that clients
/// see the complete rate-limit picture and can correlate requests regardless
/// of provider.
const FORWARDED_HEADERS: &[&str] = &[
    // Generic
    "content-type",
    "content-encoding",
    "cache-control",
    "retry-after",
    "x-request-id",
    "request-id",
    // OpenAI
    "openai-organization",
    "openai-processing-ms",
    "openai-version",
    "x-ratelimit-limit-requests",
    "x-ratelimit-remaining-requests",
    "x-ratelimit-reset-requests",
    "x-ratelimit-limit-tokens",
    "x-ratelimit-remaining-tokens",
    "x-ratelimit-reset-tokens",
    // Anthropic
    "anthropic-ratelimit-requests-limit",
    "anthropic-ratelimit-requests-remaining",
    "anthropic-ratelimit-requests-reset",
    "anthropic-ratelimit-tokens-limit",
    "anthropic-ratelimit-tokens-remaining",
    "anthropic-ratelimit-tokens-reset",
];
129
130async fn build_response(
131    response: reqwest::Response,
132    extra_stream_types: &[&str],
133) -> Result<Response, StatusCode> {
134    let status = StatusCode::from_u16(response.status().as_u16()).unwrap_or(StatusCode::OK);
135    let resp_headers = response.headers().clone();
136
137    let is_stream = resp_headers
138        .get("content-type")
139        .and_then(|v| v.to_str().ok())
140        .is_some_and(|ct| {
141            ct.contains("text/event-stream") || extra_stream_types.iter().any(|t| ct.contains(t))
142        });
143
144    if is_stream {
145        let stream = response.bytes_stream();
146        let body = Body::from_stream(stream);
147        let mut resp = Response::builder().status(status);
148        for (k, v) in &resp_headers {
149            let ks = k.as_str().to_lowercase();
150            if FORWARDED_HEADERS.contains(&ks.as_str()) {
151                resp = resp.header(k, v);
152            }
153        }
154        return resp
155            .body(body)
156            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR);
157    }
158
159    let resp_bytes = response
160        .bytes()
161        .await
162        .map_err(|_| StatusCode::BAD_GATEWAY)?;
163
164    let mut resp = Response::builder().status(status);
165    for (k, v) in &resp_headers {
166        let ks = k.as_str().to_lowercase();
167        if FORWARDED_HEADERS.contains(&ks.as_str()) {
168            resp = resp.header(k, v);
169        }
170    }
171    resp.body(Body::from(resp_bytes))
172        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
173}