Skip to main content

lean_ctx/proxy/
forward.rs

1use axum::{
2    body::Body,
3    extract::State,
4    http::{request::Parts, Request, StatusCode},
5    response::Response,
6};
7
8use super::ProxyState;
9
/// Upper bound, in bytes, on a request body the proxy will buffer (10 MiB).
const MAX_BODY_BYTES: usize = 10 * 1024 * 1024;

/// Signature of a request-body compressor.
///
/// Given the raw request body, returns `(rewritten_body, original_size,
/// compressed_size)`. `forward_request` sends `rewritten_body` upstream and
/// records compression savings only when `compressed_size < original_size`.
pub type CompressFn = fn(body: &[u8]) -> (Vec<u8>, usize, usize);
13
14pub async fn forward_request(
15    State(state): State<ProxyState>,
16    req: Request<Body>,
17    upstream_base: &str,
18    default_path: &str,
19    compress_body: CompressFn,
20    provider_label: &str,
21    extra_stream_types: &[&str],
22) -> Result<Response, StatusCode> {
23    let (parts, body) = req.into_parts();
24    let body_bytes = axum::body::to_bytes(body, MAX_BODY_BYTES)
25        .await
26        .map_err(|_| StatusCode::BAD_REQUEST)?;
27
28    state.stats.record_request();
29
30    let (compressed_body, original_size, compressed_size) = compress_body(&body_bytes);
31
32    if compressed_size < original_size {
33        state
34            .stats
35            .record_compression(original_size, compressed_size);
36    }
37
38    let upstream_url = build_upstream_url(&parts, upstream_base, default_path);
39    let response = send_upstream(
40        &state,
41        &parts,
42        &upstream_url,
43        compressed_body,
44        provider_label,
45    )
46    .await?;
47
48    build_response(response, extra_stream_types).await
49}
50
51fn build_upstream_url(parts: &Parts, base: &str, default_path: &str) -> String {
52    format!(
53        "{base}{}",
54        parts
55            .uri
56            .path_and_query()
57            .map(|pq| pq.as_str())
58            .unwrap_or(default_path)
59    )
60}
61
62async fn send_upstream(
63    state: &ProxyState,
64    parts: &Parts,
65    url: &str,
66    body: Vec<u8>,
67    provider_label: &str,
68) -> Result<reqwest::Response, StatusCode> {
69    let mut req = state.client.request(parts.method.clone(), url);
70
71    for (key, value) in &parts.headers {
72        let k = key.as_str().to_lowercase();
73        if k == "host" || k == "content-length" || k == "transfer-encoding" {
74            continue;
75        }
76        req = req.header(key.clone(), value.clone());
77    }
78
79    req.body(body).send().await.map_err(|e| {
80        eprintln!("lean-ctx proxy: {provider_label} upstream error: {e}");
81        StatusCode::BAD_GATEWAY
82    })
83}
84
85async fn build_response(
86    response: reqwest::Response,
87    extra_stream_types: &[&str],
88) -> Result<Response, StatusCode> {
89    let status = StatusCode::from_u16(response.status().as_u16()).unwrap_or(StatusCode::OK);
90    let resp_headers = response.headers().clone();
91
92    let is_stream = resp_headers
93        .get("content-type")
94        .and_then(|v| v.to_str().ok())
95        .is_some_and(|ct| {
96            ct.contains("text/event-stream") || extra_stream_types.iter().any(|t| ct.contains(t))
97        });
98
99    if is_stream {
100        let stream = response.bytes_stream();
101        let body = Body::from_stream(stream);
102        let mut resp = Response::builder().status(status);
103        for (k, v) in &resp_headers {
104            resp = resp.header(k, v);
105        }
106        return resp
107            .body(body)
108            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR);
109    }
110
111    let resp_bytes = response
112        .bytes()
113        .await
114        .map_err(|_| StatusCode::BAD_GATEWAY)?;
115
116    let mut resp = Response::builder().status(status);
117    for (k, v) in &resp_headers {
118        let ks = k.as_str().to_lowercase();
119        if ks == "transfer-encoding" || ks == "content-length" {
120            continue;
121        }
122        resp = resp.header(k, v);
123    }
124    resp.body(Body::from(resp_bytes))
125        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
126}