lean_ctx/proxy/
forward.rs1use axum::{
2 body::Body,
3 extract::State,
4 http::{request::Parts, Request, StatusCode},
5 response::Response,
6};
7
8use super::ProxyState;
9
/// Upper bound, in bytes, on a request body the proxy will buffer (10 MiB).
/// Bodies that cannot be read within this limit are answered with `400 Bad Request`.
const MAX_BODY_BYTES: usize = 10 << 20;

/// Body-rewriting hook applied before a request is forwarded upstream.
///
/// Receives the raw request body and returns
/// `(body_to_send, original_size, compressed_size)`.
pub type CompressFn = fn(&[u8]) -> (Vec<u8>, usize, usize);
13
14pub async fn forward_request(
15 State(state): State<ProxyState>,
16 req: Request<Body>,
17 upstream_base: &str,
18 default_path: &str,
19 compress_body: CompressFn,
20 provider_label: &str,
21 extra_stream_types: &[&str],
22) -> Result<Response, StatusCode> {
23 let (parts, body) = req.into_parts();
24 let body_bytes = axum::body::to_bytes(body, MAX_BODY_BYTES)
25 .await
26 .map_err(|_| StatusCode::BAD_REQUEST)?;
27
28 state.stats.record_request();
29
30 let (compressed_body, original_size, compressed_size) = compress_body(&body_bytes);
31
32 if compressed_size < original_size {
33 state
34 .stats
35 .record_compression(original_size, compressed_size);
36 }
37
38 let upstream_url = build_upstream_url(&parts, upstream_base, default_path);
39 let response = send_upstream(
40 &state,
41 &parts,
42 &upstream_url,
43 compressed_body,
44 provider_label,
45 )
46 .await?;
47
48 build_response(response, extra_stream_types).await
49}
50
51fn build_upstream_url(parts: &Parts, base: &str, default_path: &str) -> String {
52 format!(
53 "{base}{}",
54 parts
55 .uri
56 .path_and_query()
57 .map_or(default_path, axum::http::uri::PathAndQuery::as_str)
58 )
59}
60
61async fn send_upstream(
62 state: &ProxyState,
63 parts: &Parts,
64 url: &str,
65 body: Vec<u8>,
66 provider_label: &str,
67) -> Result<reqwest::Response, StatusCode> {
68 let mut req = state.client.request(parts.method.clone(), url);
69
70 for (key, value) in &parts.headers {
71 let k = key.as_str().to_lowercase();
72 if k == "host" || k == "content-length" || k == "transfer-encoding" {
73 continue;
74 }
75 req = req.header(key.clone(), value.clone());
76 }
77
78 req.body(body).send().await.map_err(|e| {
79 tracing::error!("lean-ctx proxy: {provider_label} upstream error: {e}");
80 StatusCode::BAD_GATEWAY
81 })
82}
83
84async fn build_response(
85 response: reqwest::Response,
86 extra_stream_types: &[&str],
87) -> Result<Response, StatusCode> {
88 let status = StatusCode::from_u16(response.status().as_u16()).unwrap_or(StatusCode::OK);
89 let resp_headers = response.headers().clone();
90
91 let is_stream = resp_headers
92 .get("content-type")
93 .and_then(|v| v.to_str().ok())
94 .is_some_and(|ct| {
95 ct.contains("text/event-stream") || extra_stream_types.iter().any(|t| ct.contains(t))
96 });
97
98 if is_stream {
99 let stream = response.bytes_stream();
100 let body = Body::from_stream(stream);
101 let mut resp = Response::builder().status(status);
102 for (k, v) in &resp_headers {
103 resp = resp.header(k, v);
104 }
105 return resp
106 .body(body)
107 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR);
108 }
109
110 let resp_bytes = response
111 .bytes()
112 .await
113 .map_err(|_| StatusCode::BAD_GATEWAY)?;
114
115 let mut resp = Response::builder().status(status);
116 for (k, v) in &resp_headers {
117 let ks = k.as_str().to_lowercase();
118 if ks == "transfer-encoding" || ks == "content-length" {
119 continue;
120 }
121 resp = resp.header(k, v);
122 }
123 resp.body(Body::from(resp_bytes))
124 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
125}