1use axum::{
2 body::Body,
3 extract::State,
4 http::{request::Parts, Request, StatusCode},
5 response::Response,
6};
7
8use super::ProxyState;
9
/// Request-body size cap, in MiB, applied when the environment provides no usable override.
const DEFAULT_MAX_BODY_MB: usize = 64;

/// Returns the maximum number of request-body bytes the proxy will buffer.
///
/// The cap is read from the `LEAN_CTX_PROXY_MAX_BODY_MB` environment variable, interpreted as
/// a whole number of MiB (surrounding whitespace is ignored). When the variable is unset, is
/// not valid Unicode, does not parse as an unsigned integer, or is `0`, the cap falls back to
/// [`DEFAULT_MAX_BODY_MB`]. The conversion to bytes saturates at `usize::MAX` instead of
/// overflowing.
fn max_body_bytes() -> usize {
    const BYTES_PER_MB: usize = 1024 * 1024;

    let limit_mb = match std::env::var("LEAN_CTX_PROXY_MAX_BODY_MB") {
        Ok(raw) => match raw.trim().parse::<usize>() {
            // Zero would reject every non-empty body, so it is treated as "not configured".
            Ok(mb) if mb > 0 => mb,
            _ => DEFAULT_MAX_BODY_MB,
        },
        Err(_) => DEFAULT_MAX_BODY_MB,
    };

    limit_mb.saturating_mul(BYTES_PER_MB)
}
24
25pub type CompressFn = fn(&[u8]) -> (Vec<u8>, usize, usize);
26
27pub async fn forward_request(
28 State(state): State<ProxyState>,
29 req: Request<Body>,
30 upstream_base: &str,
31 default_path: &str,
32 compress_body: CompressFn,
33 provider_label: &str,
34 extra_stream_types: &[&str],
35) -> Result<Response, StatusCode> {
36 let (parts, body) = req.into_parts();
37 let body_bytes = axum::body::to_bytes(body, max_body_bytes())
38 .await
39 .map_err(|_| StatusCode::PAYLOAD_TOO_LARGE)?;
40
41 state.stats.record_request();
42
43 let parsed = serde_json::from_slice::<serde_json::Value>(&body_bytes).ok();
45 if let Some(parsed) = &parsed {
46 let provider = match provider_label {
47 "Anthropic" => super::introspect::Provider::Anthropic,
48 "OpenAI" => super::introspect::Provider::OpenAi,
49 _ => super::introspect::Provider::Gemini,
50 };
51 let breakdown = super::introspect::analyze_request(parsed, provider);
52 state.introspect.record(breakdown);
53 }
54
55 let (compressed_body, original_size, compressed_size) = compress_body(&body_bytes);
56
57 if compressed_size < original_size {
58 state
59 .stats
60 .record_compression(original_size, compressed_size);
61 }
62
63 let tokens_saved = original_size.saturating_sub(compressed_size) as u64 / 4;
64 super::metrics::record_request(tokens_saved, compressed_size as u64);
65
66 let model = parsed
67 .as_ref()
68 .and_then(|v| v.get("model"))
69 .and_then(|m| m.as_str());
70 super::cost::record(
71 model,
72 tokens_saved,
73 original_size as u64,
74 compressed_size as u64,
75 );
76
77 let upstream_url = build_upstream_url(&parts, upstream_base, default_path);
78 let response = send_upstream(
79 &state,
80 &parts,
81 &upstream_url,
82 compressed_body,
83 provider_label,
84 )
85 .await?;
86
87 build_response(response, extra_stream_types).await
88}
89
90fn build_upstream_url(parts: &Parts, base: &str, default_path: &str) -> String {
91 format!(
92 "{base}{}",
93 parts
94 .uri
95 .path_and_query()
96 .map_or(default_path, axum::http::uri::PathAndQuery::as_str)
97 )
98}
99
100async fn send_upstream(
101 state: &ProxyState,
102 parts: &Parts,
103 url: &str,
104 body: Vec<u8>,
105 provider_label: &str,
106) -> Result<reqwest::Response, StatusCode> {
107 let mut req = state.client.request(parts.method.clone(), url);
108
109 const ALLOWED_HEADERS: &[&str] = &[
110 "authorization",
111 "x-api-key",
112 "content-type",
113 "accept",
114 "user-agent",
115 "anthropic-version",
116 "anthropic-beta",
117 "anthropic-dangerous-direct-browser-access",
118 "openai-organization",
119 "openai-beta",
120 "x-goog-api-key",
121 "x-goog-api-client",
122 ];
123 for (key, value) in &parts.headers {
124 let k = key.as_str().to_lowercase();
125 if ALLOWED_HEADERS.contains(&k.as_str()) {
126 req = req.header(key.clone(), value.clone());
127 }
128 }
129
130 req.body(body).send().await.map_err(|e| {
131 tracing::error!("lean-ctx proxy: {provider_label} upstream error: {e}");
132 StatusCode::BAD_GATEWAY
133 })
134}
135
136const FORWARDED_HEADERS: &[&str] = &[
137 "content-type",
138 "content-encoding",
139 "x-request-id",
140 "openai-organization",
141 "openai-processing-ms",
142 "openai-version",
143 "anthropic-ratelimit-requests-limit",
144 "anthropic-ratelimit-requests-remaining",
145 "anthropic-ratelimit-tokens-limit",
146 "anthropic-ratelimit-tokens-remaining",
147 "retry-after",
148 "x-ratelimit-limit-requests",
149 "x-ratelimit-remaining-requests",
150 "x-ratelimit-limit-tokens",
151 "x-ratelimit-remaining-tokens",
152 "cache-control",
153];
154
155async fn build_response(
156 response: reqwest::Response,
157 extra_stream_types: &[&str],
158) -> Result<Response, StatusCode> {
159 let status = StatusCode::from_u16(response.status().as_u16()).unwrap_or(StatusCode::OK);
160 let resp_headers = response.headers().clone();
161
162 let is_stream = resp_headers
163 .get("content-type")
164 .and_then(|v| v.to_str().ok())
165 .is_some_and(|ct| {
166 ct.contains("text/event-stream") || extra_stream_types.iter().any(|t| ct.contains(t))
167 });
168
169 if is_stream {
170 let stream = response.bytes_stream();
171 let body = Body::from_stream(stream);
172 let mut resp = Response::builder().status(status);
173 for (k, v) in &resp_headers {
174 let ks = k.as_str().to_lowercase();
175 if FORWARDED_HEADERS.contains(&ks.as_str()) {
176 resp = resp.header(k, v);
177 }
178 }
179 return resp
180 .body(body)
181 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR);
182 }
183
184 let resp_bytes = response
185 .bytes()
186 .await
187 .map_err(|_| StatusCode::BAD_GATEWAY)?;
188
189 let mut resp = Response::builder().status(status);
190 for (k, v) in &resp_headers {
191 let ks = k.as_str().to_lowercase();
192 if FORWARDED_HEADERS.contains(&ks.as_str()) {
193 resp = resp.header(k, v);
194 }
195 }
196 resp.body(Body::from(resp_bytes))
197 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Upstream base URL shared by every case.
    const BASE: &str = "https://api.anthropic.com";
    /// Fallback path handed to `build_upstream_url` in every case.
    const DEFAULT_PATH: &str = "/v1/messages";

    /// Builds the upstream URL for a request whose URI is `uri`.
    fn upstream_for(uri: &str) -> String {
        let (parts, ()) = Request::builder()
            .uri(uri)
            .body(())
            .unwrap()
            .into_parts();
        build_upstream_url(&parts, BASE, DEFAULT_PATH)
    }

    #[test]
    fn upstream_url_preserves_subpath() {
        let url = upstream_for("/v1/messages/count_tokens");
        assert_eq!(url, "https://api.anthropic.com/v1/messages/count_tokens");
    }

    #[test]
    fn upstream_url_preserves_batches_subpath() {
        let url = upstream_for("/v1/messages/batches/batch_123/results");
        assert_eq!(
            url,
            "https://api.anthropic.com/v1/messages/batches/batch_123/results"
        );
    }

    #[test]
    fn upstream_url_exact_path() {
        let url = upstream_for("/v1/messages");
        assert_eq!(url, "https://api.anthropic.com/v1/messages");
    }

    #[test]
    fn upstream_url_preserves_query_params() {
        let url = upstream_for("/v1/messages/count_tokens?model=claude-4");
        assert_eq!(
            url,
            "https://api.anthropic.com/v1/messages/count_tokens?model=claude-4"
        );
    }
}