Skip to main content

shunt/
forwarder.rs

1use anyhow::{Context, Result};
2use axum::body::Body;
3use axum::http::{HeaderMap, HeaderName, HeaderValue, Response};
4use bytes::Bytes;
5use reqwest::Client;
6use std::str::FromStr;
7use uuid::Uuid;
8
9use crate::config::AccountConfig;
10
/// Header names the proxy strips instead of relaying.
///
/// Contains the connection-scoped (hop-by-hop) HTTP headers, plus `host` and
/// `content-length`, which are not hop-by-hop in the RFC sense but are also
/// dropped here. Looked up through [`is_hop_by_hop`].
///
/// Entries must be lowercase ASCII.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    // The actual field name is `Trailer`. RFC 2616 §13.5.1 misspelled it as
    // `Trailers`; both spellings are listed so the real header is stripped
    // while the historical entry keeps matching as before.
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
    "host",
    "content-length",
];
24
/// Request-header allowlist (#15).
///
/// Only client headers whose name appears here are copied onto the upstream
/// request; every other client-supplied header is dropped.
const ALLOWED_REQUEST_HEADERS: &[&str] = &[
    // Content negotiation.
    "content-type",
    "accept",
    // `anthropic-*` protocol headers.
    "anthropic-version",
    "anthropic-beta",
    "anthropic-dangerous-direct-browser-access",
    // Request tracing and client identification.
    "x-request-id",
    "user-agent",
    // chatgpt.com sentinel token — injected by proxy, pass through
    "openai-sentinel-chat-requirements-token",
];
38
/// Response-header denylist (#21).
///
/// Upstream response headers with these names are never copied into the
/// response returned to the client. Entries must be lowercase: they are
/// compared against the lowercased header name.
const BLOCKED_RESPONSE_HEADERS: &[&str] = &[
    // Cookie-setting headers.
    "set-cookie",
    "set-cookie2",
    // CORS grants issued by the upstream.
    "access-control-allow-origin",
    "access-control-allow-credentials",
    "access-control-allow-methods",
    "access-control-allow-headers",
];
48
49fn is_hop_by_hop(name: &str) -> bool {
50    HOP_BY_HOP.contains(&name.to_ascii_lowercase().as_str())
51}
52
53pub struct Forwarder {
54    client: Client,
55}
56
57impl Forwarder {
58    pub fn new(timeout_secs: u64) -> Result<Self> {
59        let client = Client::builder()
60            .timeout(std::time::Duration::from_secs(timeout_secs))
61            .redirect(reqwest::redirect::Policy::none())
62            .build()
63            .context("Failed to build HTTP client")?;
64
65        Ok(Self { client })
66    }
67
68    /// Forward a request to the upstream using the given account's OAuth credential.
69    ///
70    /// - `upstream` overrides the base URL for this account (per-provider routing).
71    /// - Strips `Authorization` and `x-api-key` from the client request.
72    /// - Injects `Authorization: Bearer <token>` (live token, may differ from account.credential).
73    /// - Keeps the upstream TCP connection alive for streaming responses.
74    pub async fn forward(
75        &self,
76        upstream: &str,
77        method: &str,
78        path: &str,
79        body: Bytes,
80        client_headers: &HeaderMap,
81        account: &AccountConfig,
82        token: &str,
83    ) -> Result<Response<Body>> {
84        let _request_id = &Uuid::new_v4().to_string()[..8];
85        let url = format!("{}{}", upstream, path);
86
87        let mut upstream_headers = reqwest::header::HeaderMap::new();
88
89        // #15: allowlist — only forward explicitly permitted client headers.
90        for &name in ALLOWED_REQUEST_HEADERS {
91            if let Some(value) = client_headers.get(name) {
92                if let Ok(n) = reqwest::header::HeaderName::from_str(name) {
93                    if let Ok(v) = reqwest::header::HeaderValue::from_bytes(value.as_bytes()) {
94                        upstream_headers.insert(n, v);
95                    }
96                }
97            }
98        }
99
100        // Inject provider-specific auth headers (Bearer token + any required protocol headers).
101        account.provider.inject_auth_headers(&mut upstream_headers, token)
102            .context("failed to inject auth headers")?;
103
104        let upstream_resp = self
105            .client
106            .request(
107                reqwest::Method::from_str(method).context("invalid method")?,
108                &url,
109            )
110            .headers(upstream_headers)
111            .body(body.clone())
112            .send()
113            .await
114            .context("upstream request failed")?;
115
116        let status = upstream_resp.status();
117
118        let mut builder = Response::builder().status(status.as_u16());
119
120        for (name, value) in upstream_resp.headers().iter() {
121            let lower = name.as_str().to_ascii_lowercase();
122            // #21: drop hop-by-hop and sensitive response headers.
123            if is_hop_by_hop(&lower) || BLOCKED_RESPONSE_HEADERS.contains(&lower.as_str()) {
124                continue;
125            }
126            if let (Ok(n), Ok(v)) = (
127                HeaderName::from_str(name.as_str()),
128                HeaderValue::from_bytes(value.as_bytes()),
129            ) {
130                builder = builder.header(n, v);
131            }
132        }
133
134        let body = Body::from_stream(upstream_resp.bytes_stream());
135        Ok(builder.body(body).expect("response builder invariant"))
136    }
137}