//! Reqwest-backed streaming reverse-proxy handler.
//!
//! Every inbound HTTP request is forwarded to `state.upstream` unchanged
//! (minus hop-by-hop headers and `content-length`). The request body is
//! buffered in full (up to `MAX_REQUEST_BODY`) before forwarding. The response
//! is streamed back to the client as it arrives (no buffering), and a tee'd
//! copy of the byte stream is fed to a background task. That task reassembles
//! the SSE into a final JSON message and writes a `CaptureRecord` via
//! `state.store`.

use crate::AppState;
use crate::capture::{CaptureEvent, CaptureRecord, RequestPart, ResponsePart, Usage};
use crate::proxy::sse_tap::{self, TapReceiver};
use axum::body::Body;
use axum::extract::{OriginalUri, Request, State};
use axum::http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode};
use axum::response::{IntoResponse, Response};
use bytes::Bytes;
use chrono::{DateTime, Utc};
use serde_json::Value;
use std::collections::BTreeMap;
use std::sync::Arc;
use url::Url;

/// Headers that must not be forwarded in either direction.
///
/// These are the hop-by-hop headers of RFC 9110 §7.6.1 / RFC 7230 §6.1, which
/// describe a single transport connection. `host` is also listed because the
/// outbound client derives it from the upstream URL.
///
/// Entries are lowercase and are compared case-insensitively:
/// - `trailer` is the real header name.
/// - `trailers` is kept because RFC 2616 §13.5.1 listed it under that
///   misspelled name.
/// - `proxy-connection` is a non-standard header that some clients still send.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "proxy-connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
    "host",
];

/// Largest inbound request body, in bytes, that is buffered before forwarding (32 MiB).
const MAX_REQUEST_BODY: usize = 32 << 20;

37/// Returns a process-global `reqwest::Client` so that successive forwarded
38/// requests reuse the connection pool, DNS cache, and TLS session cache.
39/// Rebuilding a `Client` per request defeats keep-alive and adds measurable
40/// TTFT overhead for a proxy.
41fn upstream_client() -> &'static reqwest::Client {
42    static CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
43    CLIENT.get_or_init(|| {
44        reqwest::Client::builder()
45            .no_proxy()
46            .build()
47            .unwrap_or_else(|_| reqwest::Client::new())
48    })
49}
50
/// Headers prepared for the inbound side of the response, plus the same
/// headers projected to a `BTreeMap` for the capture record.
///
/// Produced by `collect_response_headers`. `dispatch` sends `.0` to the client
/// and stores `.1` in `CaptureCtx::resp_headers_map`.
type ResponseHeaderPair = (HeaderMap, BTreeMap<String, String>);

/// Inputs collected from the inbound request before we hit the network.
struct PreparedRequest {
    /// Inbound HTTP method, reused for the upstream call.
    method: Method,
    /// Upstream base URL joined with the inbound path and query.
    upstream_url: Url,
    /// Inbound path and query as recorded in the capture (`"/"` if absent).
    path_for_capture: String,
    /// Inbound headers as received; filtering happens when the upstream
    /// request is built in `send_upstream`.
    req_headers: HeaderMap,
    /// Inbound body, fully buffered (at most `MAX_REQUEST_BODY` bytes).
    body_bytes: Bytes,
    /// `body_bytes` parsed as JSON, or `Value::Null` when it is not valid JSON.
    req_body_json: Value,
}

/// All inputs the background capture task needs after the response status +
/// headers are known.
///
/// Built by `dispatch` and consumed (destructured) by `run_capture`.
struct CaptureCtx {
    /// Shared app state; `run_capture` reads `store`, `events`, `session_id`,
    /// `provider` and `redact` from it.
    state: AppState,
    /// Start timestamp recorded by `dispatch`; basis for `duration_ms`.
    started_at: DateTime<Utc>,
    /// Inbound HTTP method.
    method: Method,
    /// Inbound path and query string.
    path: String,
    /// Inbound request headers as a string map (redacted later, in
    /// `run_capture`, when redaction is enabled).
    req_headers_map: BTreeMap<String, String>,
    /// Request body as JSON (`Value::Null` if it did not parse).
    req_body_json: Value,
    /// Upstream response status code.
    resp_status: u16,
    /// Upstream response headers (hop-by-hop removed) as a string map.
    resp_headers_map: BTreeMap<String, String>,
    /// Model name extracted from the request body, if any.
    model: Option<String>,
}

79pub async fn forward(
80    State(state): State<AppState>,
81    OriginalUri(uri): OriginalUri,
82    req: Request,
83) -> Response {
84    let method = req.method().clone();
85    let req_headers = req.headers().clone();
86    let path_for_capture = uri
87        .path_and_query()
88        .map(|pq| pq.as_str().to_string())
89        .unwrap_or_else(|| "/".into());
90    let upstream_url = build_upstream_url(&state.upstream, &uri);
91
92    let body_bytes = match read_request_body(req).await {
93        Ok(bytes) => bytes,
94        Err(resp) => return resp,
95    };
96
97    let req_body_json = serde_json::from_slice::<Value>(&body_bytes).unwrap_or(Value::Null);
98    let prepared = PreparedRequest {
99        method,
100        upstream_url,
101        path_for_capture,
102        req_headers,
103        body_bytes,
104        req_body_json,
105    };
106    dispatch(state, prepared).await
107}
108
109async fn dispatch(state: AppState, prepared: PreparedRequest) -> Response {
110    let upstream_resp = match send_upstream(&prepared).await {
111        Ok(resp) => resp,
112        Err(err_resp) => return err_resp,
113    };
114
115    let status = upstream_resp.status();
116    let (resp_headers, resp_headers_map) = collect_response_headers(upstream_resp.headers());
117    let started_at = chrono::Utc::now();
118    let model = crate::capture::extract::extract_model_from_request_body(&prepared.req_body_json);
119    let req_headers_map = headers_to_map(&prepared.req_headers);
120
121    let byte_stream = upstream_resp.bytes_stream();
122    let (client_stream, tap_rx) = sse_tap::tee(byte_stream);
123
124    let ctx = CaptureCtx {
125        state,
126        started_at,
127        method: prepared.method,
128        path: prepared.path_for_capture,
129        req_headers_map,
130        req_body_json: prepared.req_body_json,
131        resp_status: status.as_u16(),
132        resp_headers_map,
133        model,
134    };
135    tokio::spawn(run_capture(ctx, tap_rx));
136
137    build_streaming_response(status, resp_headers, client_stream)
138}
139
140fn build_upstream_url(upstream: &Url, uri: &axum::http::Uri) -> Url {
141    let mut url = upstream.clone();
142    let base_path = upstream.path().trim_end_matches('/');
143    let req_path = uri.path();
144    let combined = format!("{base_path}{req_path}");
145    url.set_path(&combined);
146    url.set_query(uri.query());
147    url
148}
149
150async fn read_request_body(req: Request) -> Result<Bytes, Response> {
151    match axum::body::to_bytes(req.into_body(), MAX_REQUEST_BODY).await {
152        Ok(bytes) => Ok(bytes),
153        Err(err) => {
154            tracing::warn!(?err, "failed to read request body");
155            Err((
156                StatusCode::BAD_REQUEST,
157                "request body too large or unreadable",
158            )
159                .into_response())
160        }
161    }
162}
163
164async fn send_upstream(prepared: &PreparedRequest) -> Result<reqwest::Response, Response> {
165    let mut rb = upstream_client()
166        .request(
167            reqwest_method(&prepared.method),
168            prepared.upstream_url.clone(),
169        )
170        .body(prepared.body_bytes.to_vec());
171    for (name, value) in prepared.req_headers.iter() {
172        let kn = name.as_str();
173        if HOP_BY_HOP.iter().any(|h| h.eq_ignore_ascii_case(kn)) {
174            continue;
175        }
176        if kn.eq_ignore_ascii_case("content-length") {
177            continue;
178        }
179        if let (Ok(rname), Ok(rval)) = (
180            reqwest::header::HeaderName::from_bytes(name.as_str().as_bytes()),
181            reqwest::header::HeaderValue::from_bytes(value.as_bytes()),
182        ) {
183            rb = rb.header(rname, rval);
184        }
185    }
186
187    match rb.send().await {
188        Ok(resp) => Ok(resp),
189        Err(err) => {
190            let kind = classify_reqwest_err(&err);
191            tracing::warn!(?err, kind, "upstream request failed");
192            let body = serde_json::json!({
193                "error": {
194                    "type": kind,
195                    "message": err.to_string(),
196                }
197            });
198            Err((StatusCode::BAD_GATEWAY, axum::Json(body)).into_response())
199        }
200    }
201}
202
203fn collect_response_headers(upstream: &reqwest::header::HeaderMap) -> ResponseHeaderPair {
204    let mut axum_headers = HeaderMap::new();
205    let mut as_map: BTreeMap<String, String> = BTreeMap::new();
206    for (name, value) in upstream.iter() {
207        if HOP_BY_HOP
208            .iter()
209            .any(|h| name.as_str().eq_ignore_ascii_case(h))
210        {
211            continue;
212        }
213        if let (Ok(an), Ok(av)) = (
214            HeaderName::from_bytes(name.as_str().as_bytes()),
215            HeaderValue::from_bytes(value.as_bytes()),
216        ) {
217            axum_headers.insert(an, av);
218        }
219        if let Ok(text) = value.to_str() {
220            as_map.insert(name.as_str().to_string(), text.to_string());
221        }
222    }
223    (axum_headers, as_map)
224}
225
226fn build_streaming_response<S>(status: StatusCode, headers: HeaderMap, client_stream: S) -> Response
227where
228    S: futures::Stream<Item = Result<Bytes, std::io::Error>> + Send + 'static,
229{
230    let body = Body::from_stream(client_stream);
231    let mut builder = Response::builder().status(status);
232    for (name, value) in headers.iter() {
233        builder = builder.header(name, value);
234    }
235    builder
236        .body(body)
237        .unwrap_or_else(|_| StatusCode::BAD_GATEWAY.into_response())
238}
239
/// Background capture task for one proxied request.
///
/// Spawned by `dispatch` after the streaming response has been built. Steps:
/// 1. Allocate a sequence number via `next_seq` and emit
///    `CaptureEvent::RequestStarted`.
/// 2. Await `sse_tap::reassemble` on the tee'd copy of the upstream body
///    (`tap_rx`). This yields the reassembled body, the SSE frame count and an
///    optional error; a `Some` error marks the record `partial`.
/// 3. When `state.redact` is set, redact the request headers, request body
///    and response headers. The reassembled response body is not redacted.
/// 4. Append a `CaptureRecord` to `state.store`.
/// 5. Emit `CaptureEvent::RequestCompleted`.
///
/// Event-send and store-append failures are only logged; nothing is
/// propagated back to the client-facing response.
async fn run_capture(ctx: CaptureCtx, tap_rx: TapReceiver) {
    let CaptureCtx {
        state,
        started_at,
        method,
        path,
        mut req_headers_map,
        mut req_body_json,
        resp_status,
        mut resp_headers_map,
        model,
    } = ctx;

    // Read from the response headers before any redaction is applied.
    let request_id = crate::capture::extract::extract_request_id(&resp_headers_map);
    let seq = next_seq(&state.store, state.session_id.as_str()).await;

    // A send error is only traced; the message assumes it means nobody is subscribed.
    if let Err(err) = state.events.send(CaptureEvent::RequestStarted {
        session_id: state.session_id.as_str().to_string(),
        seq,
        started_at,
        model: model.clone(),
    }) {
        tracing::trace!(?err, "no subscribers for RequestStarted");
    }

    // Resolves once the tapped stream is done. NOTE(review): presumably that
    // is when the upstream body ends or the client side is dropped; confirm
    // against `sse_tap::tee`.
    let (body_reassembled, frames_count, partial_err) =
        sse_tap::reassemble(state.provider, tap_rx).await;

    let ended_at = chrono::Utc::now();
    let duration_ms = duration_ms_since(started_at, ended_at);
    let usage = usage_from_reassembled(body_reassembled.as_ref());

    if state.redact {
        crate::capture::redact::redact_headers(&mut req_headers_map);
        crate::capture::redact::redact_body(&mut req_body_json);
        crate::capture::redact::redact_headers(&mut resp_headers_map);
    }

    let rec = CaptureRecord {
        seq,
        session_id: state.session_id.as_str().to_string(),
        request_id: request_id.clone(),
        started_at,
        ended_at: Some(ended_at),
        duration_ms: Some(duration_ms),
        // Time-to-first-token is not measured on this path.
        ttft_ms: None,
        request: RequestPart {
            method: method.as_str().to_string(),
            path,
            headers: req_headers_map,
            body: req_body_json,
        },
        response: Some(ResponsePart {
            status: resp_status,
            headers: resp_headers_map,
            body_reassembled,
            // Only the frame count is kept; the raw SSE text is not stored.
            raw_sse_text: None,
            raw_sse_frames_count: frames_count,
        }),
        usage: usage.clone(),
        model,
        error: partial_err.clone(),
        partial: partial_err.is_some(),
        schema_version: 1,
    };
    if let Err(err) = state.store.append(rec).await {
        tracing::warn!(?err, "store append failed");
    }

    // Flagged as an error for a truncated stream or any 4xx/5xx upstream status.
    let has_error = partial_err.is_some() || resp_status >= 400;
    if let Err(err) = state.events.send(CaptureEvent::RequestCompleted {
        session_id: state.session_id.as_str().to_string(),
        seq,
        duration_ms,
        status: resp_status,
        request_id,
        usage,
        has_error,
    }) {
        tracing::trace!(?err, "no subscribers for RequestCompleted");
    }
}

323fn headers_to_map(headers: &HeaderMap) -> BTreeMap<String, String> {
324    let mut out: BTreeMap<String, String> = BTreeMap::new();
325    for (name, value) in headers.iter() {
326        if let Ok(text) = value.to_str() {
327            out.insert(name.as_str().to_string(), text.to_string());
328        }
329    }
330    out
331}
332
333fn reqwest_method(method: &Method) -> reqwest::Method {
334    reqwest::Method::from_bytes(method.as_str().as_bytes()).unwrap_or(reqwest::Method::GET)
335}
336
337fn classify_reqwest_err(err: &reqwest::Error) -> &'static str {
338    if err.is_timeout() {
339        return "upstream_timeout";
340    }
341    if err.is_connect() {
342        return "upstream_unreachable";
343    }
344    if err.to_string().to_lowercase().contains("tls") {
345        return "tls_handshake_failed";
346    }
347    "upstream_error"
348}
349
350fn usage_from_reassembled(value: Option<&Value>) -> Option<Usage> {
351    let value = value?;
352    let usage = value.get("usage")?;
353    Some(Usage {
354        input_tokens: usage
355            .get("input_tokens")
356            .and_then(Value::as_u64)
357            .unwrap_or(0),
358        output_tokens: usage
359            .get("output_tokens")
360            .and_then(Value::as_u64)
361            .unwrap_or(0),
362        cache_creation_input_tokens: usage
363            .get("cache_creation_input_tokens")
364            .and_then(Value::as_u64)
365            .unwrap_or(0),
366        cache_read_input_tokens: usage
367            .get("cache_read_input_tokens")
368            .and_then(Value::as_u64)
369            .unwrap_or(0),
370    })
371}
372
373fn duration_ms_since(started_at: DateTime<Utc>, ended_at: DateTime<Utc>) -> u64 {
374    let millis = (ended_at - started_at).num_milliseconds().max(0);
375    u64::try_from(millis).unwrap_or(0)
376}
377
378async fn next_seq(store: &Arc<dyn crate::store::Store>, sid: &str) -> u64 {
379    let highest = store
380        .list_requests(sid)
381        .await
382        .map(|list| list.iter().map(|item| item.seq).max().unwrap_or(0))
383        .unwrap_or(0);
384    highest.saturating_add(1)
385}