1use anyhow::{Context, Result};
2use axum::body::Body;
3use axum::http::{HeaderMap, HeaderName, HeaderValue, Response};
4use bytes::Bytes;
5use reqwest::Client;
6use std::str::FromStr;
7use uuid::Uuid;
8
9use crate::config::AccountConfig;
10
/// Header names that are never copied from an upstream response to the client.
///
/// The first eight are the hop-by-hop headers enumerated in RFC 2616 §13.5.1
/// (`connection`, `keep-alive`, `proxy-authenticate`, `proxy-authorization`,
/// `te`, `trailers`, `transfer-encoding`, `upgrade`).
///
/// `host` and `content-length` are end-to-end fields but are dropped as well.
/// NOTE(review): presumably this is because the response body is re-streamed and
/// the server recomputes the framing — confirm.
///
/// In this file the list is consulted only through `is_hop_by_hop`, when
/// filtering upstream response headers. That helper lowercases its input before
/// comparing, so entries must stay lowercase.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
    "host",
    "content-length",
];
24
/// Allowlist of client request header names copied onto the upstream request.
///
/// Any client header not named here is dropped. Because `authorization` and
/// similar credential headers are absent, client-supplied credentials never
/// reach upstream. Authentication is added separately by the account's provider
/// (`inject_auth_headers`) after this list has been applied.
const ALLOWED_REQUEST_HEADERS: &[&str] = &[
    "content-type",
    "accept",
    "anthropic-version",
    "anthropic-beta",
    "anthropic-dangerous-direct-browser-access",
    "x-request-id",
    "user-agent",
    "openai-sentinel-chat-requirements-token",
];
38
/// Upstream response headers suppressed before the response reaches the client.
///
/// These are cookie-setting headers (`set-cookie`, `set-cookie2`) and CORS
/// `access-control-*` headers. They are compared against the lowercase header
/// name, so entries must stay lowercase.
///
/// NOTE(review): presumably dropped so upstream cookies and CORS policy are not
/// applied to the client's origin — confirm where the proxy's own CORS handling
/// lives.
const BLOCKED_RESPONSE_HEADERS: &[&str] = &[
    "set-cookie",
    "set-cookie2",
    "access-control-allow-origin",
    "access-control-allow-credentials",
    "access-control-allow-methods",
    "access-control-allow-headers",
];
48
49fn is_hop_by_hop(name: &str) -> bool {
50 HOP_BY_HOP.contains(&name.to_ascii_lowercase().as_str())
51}
52
/// Relays client requests to an upstream API and streams the responses back.
///
/// Holds a single reqwest `Client`, built once in `Forwarder::new` and reused
/// for every `forward` call. Headers are filtered in both directions: an
/// allowlist on the way up, and hop-by-hop/blocked names on the way down.
pub struct Forwarder {
    // HTTP client used for every upstream request; configured in `new`.
    client: Client,
}
56
57impl Forwarder {
58 pub fn new(timeout_secs: u64) -> Result<Self> {
59 let client = Client::builder()
60 .timeout(std::time::Duration::from_secs(timeout_secs))
61 .redirect(reqwest::redirect::Policy::none())
62 .build()
63 .context("Failed to build HTTP client")?;
64
65 Ok(Self { client })
66 }
67
68 pub async fn forward(
75 &self,
76 upstream: &str,
77 method: &str,
78 path: &str,
79 body: Bytes,
80 client_headers: &HeaderMap,
81 account: &AccountConfig,
82 token: &str,
83 ) -> Result<Response<Body>> {
84 let _request_id = &Uuid::new_v4().to_string()[..8];
85 let url = format!("{}{}", upstream, path);
86
87 let mut upstream_headers = reqwest::header::HeaderMap::new();
88
89 for &name in ALLOWED_REQUEST_HEADERS {
91 if let Some(value) = client_headers.get(name) {
92 if let Ok(n) = reqwest::header::HeaderName::from_str(name) {
93 if let Ok(v) = reqwest::header::HeaderValue::from_bytes(value.as_bytes()) {
94 upstream_headers.insert(n, v);
95 }
96 }
97 }
98 }
99
100 account.provider.inject_auth_headers(&mut upstream_headers, token)
102 .context("failed to inject auth headers")?;
103
104 let upstream_resp = self
105 .client
106 .request(
107 reqwest::Method::from_str(method).context("invalid method")?,
108 &url,
109 )
110 .headers(upstream_headers)
111 .body(body.clone())
112 .send()
113 .await
114 .context("upstream request failed")?;
115
116 let status = upstream_resp.status();
117
118 let mut builder = Response::builder().status(status.as_u16());
119
120 for (name, value) in upstream_resp.headers().iter() {
121 let lower = name.as_str().to_ascii_lowercase();
122 if is_hop_by_hop(&lower) || BLOCKED_RESPONSE_HEADERS.contains(&lower.as_str()) {
124 continue;
125 }
126 if let (Ok(n), Ok(v)) = (
127 HeaderName::from_str(name.as_str()),
128 HeaderValue::from_bytes(value.as_bytes()),
129 ) {
130 builder = builder.header(n, v);
131 }
132 }
133
134 let body = Body::from_stream(upstream_resp.bytes_stream());
135 Ok(builder.body(body).expect("response builder invariant"))
136 }
137}