1use axum::{
2 body::Body,
3 extract::State,
4 http::{request::Parts, Request, StatusCode},
5 response::Response,
6};
7
8use super::ProxyState;
9
/// Request-body size limit, in mebibytes, used when the
/// `LEAN_CTX_PROXY_MAX_BODY_MB` environment variable is unset or unusable.
const DEFAULT_MAX_BODY_MB: usize = 64;

/// Returns the maximum number of request-body bytes the proxy will buffer.
///
/// The limit is read from `LEAN_CTX_PROXY_MAX_BODY_MB` (a whole number of
/// mebibytes; surrounding whitespace is ignored) on every call. A missing,
/// non-numeric, negative, or zero value falls back to
/// [`DEFAULT_MAX_BODY_MB`]. The conversion to bytes saturates at
/// `usize::MAX` instead of overflowing.
fn max_body_bytes() -> usize {
    const BYTES_PER_MB: usize = 1024 * 1024;

    let limit_mb = match std::env::var("LEAN_CTX_PROXY_MAX_BODY_MB") {
        Ok(raw) => match raw.trim().parse::<usize>() {
            // Zero would reject every non-empty body, so treat it as "unset".
            Ok(mb) if mb > 0 => mb,
            _ => DEFAULT_MAX_BODY_MB,
        },
        Err(_) => DEFAULT_MAX_BODY_MB,
    };

    limit_mb.saturating_mul(BYTES_PER_MB)
}
24
/// Signature of the body compressor handed to `forward_request`.
///
/// It is called with the request body parsed as JSON and the byte length of
/// the original body. Of the returned tuple, `forward_request` sends the
/// first element upstream as the request body and uses the third element as
/// the compressed size; the second element is ignored there.
// NOTE(review): the second element is presumably the original size (the
// non-JSON fallback in `forward_request` fills it that way) — confirm with
// the compressor implementations.
pub type CompressFn = fn(serde_json::Value, usize) -> (Vec<u8>, usize, usize);
29
30pub async fn forward_request(
31 State(state): State<ProxyState>,
32 req: Request<Body>,
33 upstream_base: &str,
34 default_path: &str,
35 compress_body: CompressFn,
36 provider_label: &str,
37 extra_stream_types: &[&str],
38) -> Result<Response, StatusCode> {
39 let (parts, body) = req.into_parts();
40 let body_bytes = axum::body::to_bytes(body, max_body_bytes())
41 .await
42 .map_err(|_| StatusCode::PAYLOAD_TOO_LARGE)?;
43
44 state.stats.record_request();
45
46 let original_size = body_bytes.len();
47
48 let parsed = serde_json::from_slice::<serde_json::Value>(&body_bytes).ok();
52 if let Some(ref parsed) = parsed {
53 let provider = match provider_label {
54 "Anthropic" => super::introspect::Provider::Anthropic,
55 "OpenAI" => super::introspect::Provider::OpenAi,
56 _ => super::introspect::Provider::Gemini,
57 };
58 let breakdown = super::introspect::analyze_request(parsed, provider);
59 state.introspect.record(breakdown);
60 }
61
62 let (compressed_body, _, compressed_size) = if let Some(value) = parsed.clone() {
63 compress_body(value, original_size)
64 } else {
65 (body_bytes.to_vec(), original_size, original_size)
66 };
67
68 if compressed_size < original_size {
69 state
70 .stats
71 .record_compression(original_size, compressed_size);
72 }
73
74 let tokens_saved = original_size.saturating_sub(compressed_size) as u64 / 4;
75 super::metrics::record_request(tokens_saved, compressed_size as u64);
76
77 let model = parsed
78 .as_ref()
79 .and_then(|v| v.get("model"))
80 .and_then(|m| m.as_str());
81 super::cost::record(
82 model,
83 tokens_saved,
84 original_size as u64,
85 compressed_size as u64,
86 );
87
88 let upstream_url = build_upstream_url(&parts, upstream_base, default_path);
89 let response = send_upstream(
90 &state,
91 &parts,
92 &upstream_url,
93 compressed_body,
94 provider_label,
95 )
96 .await?;
97
98 build_response(response, extra_stream_types).await
99}
100
101fn build_upstream_url(parts: &Parts, base: &str, default_path: &str) -> String {
102 format!(
103 "{base}{}",
104 parts
105 .uri
106 .path_and_query()
107 .map_or(default_path, axum::http::uri::PathAndQuery::as_str)
108 )
109}
110
111async fn send_upstream(
112 state: &ProxyState,
113 parts: &Parts,
114 url: &str,
115 body: Vec<u8>,
116 provider_label: &str,
117) -> Result<reqwest::Response, StatusCode> {
118 let mut req = state.client.request(parts.method.clone(), url);
119
120 const ALLOWED_HEADERS: &[&str] = &[
121 "authorization",
122 "x-api-key",
123 "content-type",
124 "accept",
125 "user-agent",
126 "anthropic-version",
127 "anthropic-beta",
128 "anthropic-dangerous-direct-browser-access",
129 "openai-organization",
130 "openai-beta",
131 "x-goog-api-key",
132 "x-goog-api-client",
133 ];
134 for (key, value) in &parts.headers {
135 let k = key.as_str().to_lowercase();
136 if ALLOWED_HEADERS.contains(&k.as_str()) {
137 req = req.header(key.clone(), value.clone());
138 }
139 }
140
141 req.body(body).send().await.map_err(|e| {
142 tracing::error!("lean-ctx proxy: {provider_label} upstream error: {e}");
143 StatusCode::BAD_GATEWAY
144 })
145}
146
/// Upstream response headers that `build_response` copies onto the response
/// returned to the client; every other upstream header is dropped.
///
/// Entries are lowercase because they are matched against lowercase header
/// names.
const FORWARDED_HEADERS: &[&str] = &[
    "content-type",
    "content-encoding",
    "x-request-id",
    "openai-organization",
    "openai-processing-ms",
    "openai-version",
    "anthropic-ratelimit-requests-limit",
    "anthropic-ratelimit-requests-remaining",
    "anthropic-ratelimit-tokens-limit",
    "anthropic-ratelimit-tokens-remaining",
    "retry-after",
    "x-ratelimit-limit-requests",
    "x-ratelimit-remaining-requests",
    "x-ratelimit-limit-tokens",
    "x-ratelimit-remaining-tokens",
    "cache-control",
];
165
166async fn build_response(
167 response: reqwest::Response,
168 extra_stream_types: &[&str],
169) -> Result<Response, StatusCode> {
170 let status = StatusCode::from_u16(response.status().as_u16()).unwrap_or(StatusCode::OK);
171 let resp_headers = response.headers().clone();
172
173 let is_stream = resp_headers
174 .get("content-type")
175 .and_then(|v| v.to_str().ok())
176 .is_some_and(|ct| {
177 ct.contains("text/event-stream") || extra_stream_types.iter().any(|t| ct.contains(t))
178 });
179
180 if is_stream {
181 let stream = response.bytes_stream();
182 let body = Body::from_stream(stream);
183 let mut resp = Response::builder().status(status);
184 for (k, v) in &resp_headers {
185 let ks = k.as_str().to_lowercase();
186 if FORWARDED_HEADERS.contains(&ks.as_str()) {
187 resp = resp.header(k, v);
188 }
189 }
190 return resp
191 .body(body)
192 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR);
193 }
194
195 let resp_bytes = response
196 .bytes()
197 .await
198 .map_err(|_| StatusCode::BAD_GATEWAY)?;
199
200 let mut resp = Response::builder().status(status);
201 for (k, v) in &resp_headers {
202 let ks = k.as_str().to_lowercase();
203 if FORWARDED_HEADERS.contains(&ks.as_str()) {
204 resp = resp.header(k, v);
205 }
206 }
207 resp.body(Body::from(resp_bytes))
208 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Upstream base shared by every case below.
    const ANTHROPIC_BASE: &str = "https://api.anthropic.com";
    /// Default path passed to `build_upstream_url` in every case below.
    const MESSAGES_PATH: &str = "/v1/messages";

    /// Builds request `Parts` whose URI is `uri`.
    fn parts_for(uri: &str) -> Parts {
        let (parts, ()) = Request::builder()
            .uri(uri)
            .body(())
            .unwrap()
            .into_parts();
        parts
    }

    /// Runs `build_upstream_url` for a request to `uri` against the shared
    /// base and default path.
    fn upstream_for(uri: &str) -> String {
        build_upstream_url(&parts_for(uri), ANTHROPIC_BASE, MESSAGES_PATH)
    }

    #[test]
    fn upstream_url_preserves_subpath() {
        assert_eq!(
            upstream_for("/v1/messages/count_tokens"),
            "https://api.anthropic.com/v1/messages/count_tokens"
        );
    }

    #[test]
    fn upstream_url_preserves_batches_subpath() {
        assert_eq!(
            upstream_for("/v1/messages/batches/batch_123/results"),
            "https://api.anthropic.com/v1/messages/batches/batch_123/results"
        );
    }

    #[test]
    fn upstream_url_exact_path() {
        assert_eq!(
            upstream_for("/v1/messages"),
            "https://api.anthropic.com/v1/messages"
        );
    }

    #[test]
    fn upstream_url_preserves_query_params() {
        assert_eq!(
            upstream_for("/v1/messages/count_tokens?model=claude-4"),
            "https://api.anthropic.com/v1/messages/count_tokens?model=claude-4"
        );
    }
}