Skip to main content

shift_proxy/
forward.rs

//! Forward requests to upstream provider APIs and stream responses back.
//!
//! Handles header forwarding (auth passthrough), hop-by-hop header stripping
//! (RFC 9110 §7.6.1), and transparent SSE/chunked response streaming.
5
6use axum::body::Body;
7use axum::http::{HeaderMap, HeaderValue, StatusCode};
8use axum::response::{IntoResponse, Response};
9use reqwest::Client;
10
/// Response headers that are dropped before an upstream response is relayed
/// to the client. Names are lowercase; the list is only used for membership
/// checks, so ordering carries no meaning.
///
/// Two groups live here:
///
/// - Body-framing headers (`content-encoding`, `content-length`). reqwest
///   transparently decompresses response bodies, which leaves these values
///   stale; relaying them would make the client try to decompress again.
///   NOTE: this is only correct while the `gzip`, `brotli`, and `deflate`
///   features are enabled on reqwest. Without them reqwest does NOT
///   decompress, and removing `content-encoding` would hand the client raw
///   compressed bytes with no indication of how to decode them.
/// - Hop-by-hop headers as described by RFC 9110 §7.6.1, which apply to a
///   single connection and must not be relayed by an intermediary.
const STRIP_RESPONSE_HEADERS: &[&str] = &[
    "connection",
    "content-encoding",
    "content-length",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
31
/// Request headers that are removed before the request is sent upstream.
/// Names are lowercase; the list is only used for membership checks, so
/// ordering carries no meaning.
///
/// - `accept-encoding`: left for reqwest to set according to the
///   decompression features it was built with (`gzip`, `brotli`, `deflate`).
///   Passing the client's value along could ask the upstream for an encoding
///   reqwest cannot decode (e.g., `zstd`); since `content-encoding` is removed
///   from the response, the client would then receive undecodable bytes.
/// - `content-encoding`: the proxy has already decompressed the request body
///   (see `body::extract_body`), so advertising it as compressed would make
///   the upstream fail to decode it.
/// - `content-length` / `host`: describe the client-to-proxy connection and
///   are stale for the proxy-to-upstream connection.
const STRIP_REQUEST_HEADERS: &[&str] = &[
    "accept-encoding",
    "content-encoding",
    "content-length",
    "host",
];
49
50/// Forward a request to an upstream URL, streaming the response back.
51///
52/// Auth headers (`authorization`, `x-api-key`, `anthropic-version`, `x-goog-api-key`)
53/// pass through unchanged. The response body is streamed directly — SSE and
54/// chunked responses are not buffered.
55pub async fn forward_request(
56    client: &Client,
57    method: &str,
58    target_url: &str,
59    request_headers: &HeaderMap,
60    body: Option<String>,
61) -> Response {
62    let forwarded_headers = forward_headers(request_headers);
63
64    let mut req = match method.to_uppercase().as_str() {
65        "POST" => client.post(target_url),
66        "GET" => client.get(target_url),
67        "PUT" => client.put(target_url),
68        "DELETE" => client.delete(target_url),
69        "PATCH" => client.patch(target_url),
70        _ => client.post(target_url),
71    };
72
73    req = req.headers(forwarded_headers);
74
75    if let Some(body) = body {
76        req = req.body(body);
77    }
78
79    match req.send().await {
80        Ok(upstream) => stream_response(upstream),
81        Err(err) => {
82            tracing::error!("upstream error: {}", err);
83            (
84                StatusCode::BAD_GATEWAY,
85                axum::Json(serde_json::json!({
86                    "error": "Bad Gateway",
87                    "detail": "Upstream provider unreachable"
88                })),
89            )
90                .into_response()
91        }
92    }
93}
94
95/// Convert a reqwest Response into an axum Response, streaming the body
96/// and stripping hop-by-hop headers.
97fn stream_response(upstream: reqwest::Response) -> Response {
98    let status = StatusCode::from_u16(upstream.status().as_u16()).unwrap_or(StatusCode::OK);
99
100    let mut response_headers = HeaderMap::new();
101    for (name, value) in upstream.headers() {
102        let name_str = name.as_str().to_lowercase();
103        if STRIP_RESPONSE_HEADERS
104            .iter()
105            .any(|h| h == &name_str.as_str())
106        {
107            continue;
108        }
109        if let Ok(v) = HeaderValue::from_bytes(value.as_bytes()) {
110            response_headers.insert(name.clone(), v);
111        }
112    }
113
114    // Stream the response body directly without buffering.
115    // This is critical for SSE (Anthropic/OpenAI streaming) to work correctly.
116    let body = Body::from_stream(upstream.bytes_stream());
117
118    let mut response = Response::new(body);
119    *response.status_mut() = status;
120    *response.headers_mut() = response_headers;
121    response
122}
123
124/// Forward request headers, stripping host/content-length but passing auth through.
125fn forward_headers(original: &HeaderMap) -> HeaderMap {
126    let strip: std::collections::HashSet<&str> = STRIP_REQUEST_HEADERS.iter().copied().collect();
127
128    let mut result = HeaderMap::new();
129    for (name, value) in original {
130        let name_lower = name.as_str().to_lowercase();
131        if !strip.contains(name_lower.as_str()) {
132            result.insert(name.clone(), value.clone());
133        }
134    }
135    result
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::header;

    /// Client request carrying every header on the strip list plus the
    /// provider auth headers that must survive forwarding.
    fn client_request_headers() -> HeaderMap {
        let mut map = HeaderMap::new();
        map.insert(header::HOST, HeaderValue::from_static("example.com"));
        map.insert(header::CONTENT_LENGTH, HeaderValue::from_static("42"));
        map.insert(header::ACCEPT_ENCODING, HeaderValue::from_static("gzip, br"));
        map.insert(header::CONTENT_ENCODING, HeaderValue::from_static("gzip"));
        map.insert(header::AUTHORIZATION, HeaderValue::from_static("Bearer sk-test"));
        map.insert("x-api-key", HeaderValue::from_static("sk-ant-test"));
        map.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));
        map
    }

    /// Stripped headers must be absent from the forwarded map while auth
    /// headers pass through with their values intact.
    #[test]
    fn forward_headers_strips_host_content_length_and_accept_encoding() {
        let forwarded = forward_headers(&client_request_headers());

        // Connection-specific and encoding headers are removed.
        assert!(!forwarded.contains_key(header::HOST));
        assert!(!forwarded.contains_key(header::CONTENT_LENGTH));
        assert!(
            !forwarded.contains_key(header::ACCEPT_ENCODING),
            "accept-encoding should be stripped so reqwest negotiates its own"
        );
        assert!(
            !forwarded.contains_key(header::CONTENT_ENCODING),
            "content-encoding should be stripped — body was already decompressed"
        );

        // Auth headers are forwarded untouched.
        assert_eq!(forwarded.get(header::AUTHORIZATION).unwrap(), "Bearer sk-test");
        assert_eq!(forwarded.get("x-api-key").unwrap(), "sk-ant-test");
        assert_eq!(forwarded.get("anthropic-version").unwrap(), "2023-06-01");
    }
}