1use axum::body::Body;
7use axum::http::{HeaderMap, HeaderValue, StatusCode};
8use axum::response::{IntoResponse, Response};
9use reqwest::Client;
10
/// Upstream response headers that are never copied onto the client-facing response.
///
/// Entries must be lowercase: they are compared against header names coming from the
/// upstream response.
const STRIP_RESPONSE_HEADERS: &[&str] = &[
    // Describe how the *upstream* body was encoded/framed. The body is re-streamed into a
    // fresh axum `Body`, so these presumably no longer match what the client receives.
    "content-encoding",
    "content-length",
    "transfer-encoding",
    // Connection-scoped (hop-by-hop) headers that must not cross the proxy.
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "upgrade",
];
31
/// Inbound request headers that are not forwarded to the upstream.
///
/// Entries must be lowercase: they are compared against `HeaderName::as_str()`, which is
/// always lowercase.
const STRIP_REQUEST_HEADERS: &[&str] = &[
    // Re-derived by reqwest for the outbound request.
    "host",
    "content-length",
    // Stripped so reqwest negotiates its own encoding with the upstream.
    "accept-encoding",
    // The body handed to the forwarder has already been decompressed.
    "content-encoding",
    // Hop-by-hop headers (RFC 9110 §7.6.1): they describe the client's connection to this
    // proxy, not the proxy's connection to the upstream. This mirrors the set already
    // stripped from upstream responses.
    "transfer-encoding",
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "upgrade",
];
49
50pub async fn forward_request(
56 client: &Client,
57 method: &str,
58 target_url: &str,
59 request_headers: &HeaderMap,
60 body: Option<String>,
61) -> Response {
62 let forwarded_headers = forward_headers(request_headers);
63
64 let mut req = match method.to_uppercase().as_str() {
65 "POST" => client.post(target_url),
66 "GET" => client.get(target_url),
67 "PUT" => client.put(target_url),
68 "DELETE" => client.delete(target_url),
69 "PATCH" => client.patch(target_url),
70 _ => client.post(target_url),
71 };
72
73 req = req.headers(forwarded_headers);
74
75 if let Some(body) = body {
76 req = req.body(body);
77 }
78
79 match req.send().await {
80 Ok(upstream) => stream_response(upstream),
81 Err(err) => {
82 tracing::error!("upstream error: {}", err);
83 (
84 StatusCode::BAD_GATEWAY,
85 axum::Json(serde_json::json!({
86 "error": "Bad Gateway",
87 "detail": "Upstream provider unreachable"
88 })),
89 )
90 .into_response()
91 }
92 }
93}
94
95fn stream_response(upstream: reqwest::Response) -> Response {
98 let status = StatusCode::from_u16(upstream.status().as_u16()).unwrap_or(StatusCode::OK);
99
100 let mut response_headers = HeaderMap::new();
101 for (name, value) in upstream.headers() {
102 let name_str = name.as_str().to_lowercase();
103 if STRIP_RESPONSE_HEADERS
104 .iter()
105 .any(|h| h == &name_str.as_str())
106 {
107 continue;
108 }
109 if let Ok(v) = HeaderValue::from_bytes(value.as_bytes()) {
110 response_headers.insert(name.clone(), v);
111 }
112 }
113
114 let body = Body::from_stream(upstream.bytes_stream());
117
118 let mut response = Response::new(body);
119 *response.status_mut() = status;
120 *response.headers_mut() = response_headers;
121 response
122}
123
124fn forward_headers(original: &HeaderMap) -> HeaderMap {
126 let strip: std::collections::HashSet<&str> = STRIP_REQUEST_HEADERS.iter().copied().collect();
127
128 let mut result = HeaderMap::new();
129 for (name, value) in original {
130 let name_lower = name.as_str().to_lowercase();
131 if !strip.contains(name_lower.as_str()) {
132 result.insert(name.clone(), value.clone());
133 }
134 }
135 result
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::header;

    #[test]
    fn forward_headers_strips_host_content_length_and_accept_encoding() {
        // A representative inbound header set: framing/encoding headers the proxy must
        // drop, plus provider credentials and versioning it must keep.
        let inbound: HeaderMap = [
            (header::HOST, "example.com"),
            (header::CONTENT_LENGTH, "42"),
            (header::ACCEPT_ENCODING, "gzip, br"),
            (header::CONTENT_ENCODING, "gzip"),
            (header::AUTHORIZATION, "Bearer sk-test"),
            (header::HeaderName::from_static("x-api-key"), "sk-ant-test"),
            (header::HeaderName::from_static("anthropic-version"), "2023-06-01"),
        ]
        .into_iter()
        .map(|(name, value)| (name, HeaderValue::from_static(value)))
        .collect();

        let outbound = forward_headers(&inbound);

        assert!(!outbound.contains_key(header::HOST));
        assert!(!outbound.contains_key(header::CONTENT_LENGTH));
        assert!(
            !outbound.contains_key(header::ACCEPT_ENCODING),
            "accept-encoding should be stripped so reqwest negotiates its own"
        );
        assert!(
            !outbound.contains_key(header::CONTENT_ENCODING),
            "content-encoding should be stripped — body was already decompressed"
        );

        // Headers that must survive untouched.
        let expected_kept = [
            ("authorization", "Bearer sk-test"),
            ("x-api-key", "sk-ant-test"),
            ("anthropic-version", "2023-06-01"),
        ];
        for (name, value) in expected_kept {
            assert_eq!(outbound.get(name).unwrap(), value);
        }
    }
}