lean_ctx/proxy/
forward.rs1use axum::{
2 body::Body,
3 extract::State,
4 http::{request::Parts, Request, StatusCode},
5 response::Response,
6};
7
8use super::ProxyState;
9
/// Maximum number of request-body bytes buffered before forwarding (10 MiB).
/// Bodies larger than this make body buffering fail, and the request is rejected.
const MAX_BODY_BYTES: usize = 10 << 20;

/// Hook that rewrites (compresses) a raw request body before it is forwarded.
///
/// Returns `(new_body, original_size, compressed_size)`; the caller records
/// compression stats only when `compressed_size < original_size`.
pub type CompressFn = fn(&[u8]) -> (Vec<u8>, usize, usize);
13
14pub async fn forward_request(
15 State(state): State<ProxyState>,
16 req: Request<Body>,
17 upstream_base: &str,
18 default_path: &str,
19 compress_body: CompressFn,
20 provider_label: &str,
21 extra_stream_types: &[&str],
22) -> Result<Response, StatusCode> {
23 let (parts, body) = req.into_parts();
24 let body_bytes = axum::body::to_bytes(body, MAX_BODY_BYTES)
25 .await
26 .map_err(|_| StatusCode::BAD_REQUEST)?;
27
28 state.stats.record_request();
29
30 if let Ok(parsed) = serde_json::from_slice::<serde_json::Value>(&body_bytes) {
32 let provider = match provider_label {
33 "Anthropic" => super::introspect::Provider::Anthropic,
34 "OpenAI" => super::introspect::Provider::OpenAi,
35 _ => super::introspect::Provider::Gemini,
36 };
37 let breakdown = super::introspect::analyze_request(&parsed, provider);
38 state.introspect.record(breakdown);
39 }
40
41 let (compressed_body, original_size, compressed_size) = compress_body(&body_bytes);
42
43 if compressed_size < original_size {
44 state
45 .stats
46 .record_compression(original_size, compressed_size);
47 }
48
49 let upstream_url = build_upstream_url(&parts, upstream_base, default_path);
50 let response = send_upstream(
51 &state,
52 &parts,
53 &upstream_url,
54 compressed_body,
55 provider_label,
56 )
57 .await?;
58
59 build_response(response, extra_stream_types).await
60}
61
62fn build_upstream_url(parts: &Parts, base: &str, default_path: &str) -> String {
63 format!(
64 "{base}{}",
65 parts
66 .uri
67 .path_and_query()
68 .map_or(default_path, axum::http::uri::PathAndQuery::as_str)
69 )
70}
71
72async fn send_upstream(
73 state: &ProxyState,
74 parts: &Parts,
75 url: &str,
76 body: Vec<u8>,
77 provider_label: &str,
78) -> Result<reqwest::Response, StatusCode> {
79 let mut req = state.client.request(parts.method.clone(), url);
80
81 const ALLOWED_HEADERS: &[&str] = &[
82 "authorization",
83 "x-api-key",
84 "content-type",
85 "accept",
86 "user-agent",
87 "anthropic-version",
88 "anthropic-beta",
89 "anthropic-dangerous-direct-browser-access",
90 "openai-organization",
91 "openai-beta",
92 "x-goog-api-key",
93 "x-goog-api-client",
94 ];
95 for (key, value) in &parts.headers {
96 let k = key.as_str().to_lowercase();
97 if ALLOWED_HEADERS.contains(&k.as_str()) {
98 req = req.header(key.clone(), value.clone());
99 }
100 }
101
102 req.body(body).send().await.map_err(|e| {
103 tracing::error!("lean-ctx proxy: {provider_label} upstream error: {e}");
104 StatusCode::BAD_GATEWAY
105 })
106}
107
108async fn build_response(
109 response: reqwest::Response,
110 extra_stream_types: &[&str],
111) -> Result<Response, StatusCode> {
112 let status = StatusCode::from_u16(response.status().as_u16()).unwrap_or(StatusCode::OK);
113 let resp_headers = response.headers().clone();
114
115 let is_stream = resp_headers
116 .get("content-type")
117 .and_then(|v| v.to_str().ok())
118 .is_some_and(|ct| {
119 ct.contains("text/event-stream") || extra_stream_types.iter().any(|t| ct.contains(t))
120 });
121
122 if is_stream {
123 let stream = response.bytes_stream();
124 let body = Body::from_stream(stream);
125 let mut resp = Response::builder().status(status);
126 for (k, v) in &resp_headers {
127 resp = resp.header(k, v);
128 }
129 return resp
130 .body(body)
131 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR);
132 }
133
134 let resp_bytes = response
135 .bytes()
136 .await
137 .map_err(|_| StatusCode::BAD_GATEWAY)?;
138
139 let mut resp = Response::builder().status(status);
140 for (k, v) in &resp_headers {
141 let ks = k.as_str().to_lowercase();
142 if ks == "transfer-encoding" || ks == "content-length" {
143 continue;
144 }
145 resp = resp.header(k, v);
146 }
147 resp.body(Body::from(resp_bytes))
148 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
149}