1use crate::AppState;
10use crate::capture::{CaptureEvent, CaptureRecord, RequestPart, ResponsePart, Usage};
11use crate::proxy::sse_tap::{self, TapReceiver};
12use axum::body::Body;
13use axum::extract::{OriginalUri, Request, State};
14use axum::http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode};
15use axum::response::{IntoResponse, Response};
16use bytes::Bytes;
17use chrono::{DateTime, Utc};
18use serde_json::Value;
19use std::collections::BTreeMap;
20use std::sync::Arc;
21use url::Url;
22
/// Lowercase names of headers that describe a single connection and must not be
/// forwarded in either direction (request to upstream, response to client).
///
/// `trailer` is the actual header name; `trailers` comes from an RFC 2616 erratum
/// and is kept so existing behavior is unchanged. `proxy-connection` is
/// non-standard but still sent by some clients. `host` is stripped so the HTTP
/// client sets it for the upstream URL.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "proxy-connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
    "host",
];

/// Maximum number of request-body bytes buffered before forwarding (32 MiB).
const MAX_REQUEST_BODY: usize = 32 * 1024 * 1024;
36
37fn upstream_client() -> &'static reqwest::Client {
42 static CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
43 CLIENT.get_or_init(|| {
44 reqwest::Client::builder()
45 .no_proxy()
46 .build()
47 .unwrap_or_else(|_| reqwest::Client::new())
48 })
49}
50
51type ResponseHeaderPair = (HeaderMap, BTreeMap<String, String>);
54
55struct PreparedRequest {
57 method: Method,
58 upstream_url: Url,
59 path_for_capture: String,
60 req_headers: HeaderMap,
61 body_bytes: Bytes,
62 req_body_json: Value,
63}
64
65struct CaptureCtx {
68 state: AppState,
69 started_at: DateTime<Utc>,
70 method: Method,
71 path: String,
72 req_headers_map: BTreeMap<String, String>,
73 req_body_json: Value,
74 resp_status: u16,
75 resp_headers_map: BTreeMap<String, String>,
76 model: Option<String>,
77}
78
79pub async fn forward(
80 State(state): State<AppState>,
81 OriginalUri(uri): OriginalUri,
82 req: Request,
83) -> Response {
84 let method = req.method().clone();
85 let req_headers = req.headers().clone();
86 let path_for_capture = uri
87 .path_and_query()
88 .map(|pq| pq.as_str().to_string())
89 .unwrap_or_else(|| "/".into());
90 let upstream_url = build_upstream_url(&state.upstream, &uri);
91
92 let body_bytes = match read_request_body(req).await {
93 Ok(bytes) => bytes,
94 Err(resp) => return resp,
95 };
96
97 let req_body_json = serde_json::from_slice::<Value>(&body_bytes).unwrap_or(Value::Null);
98 let prepared = PreparedRequest {
99 method,
100 upstream_url,
101 path_for_capture,
102 req_headers,
103 body_bytes,
104 req_body_json,
105 };
106 dispatch(state, prepared).await
107}
108
109async fn dispatch(state: AppState, prepared: PreparedRequest) -> Response {
110 let upstream_resp = match send_upstream(&prepared).await {
111 Ok(resp) => resp,
112 Err(err_resp) => return err_resp,
113 };
114
115 let status = upstream_resp.status();
116 let (resp_headers, resp_headers_map) = collect_response_headers(upstream_resp.headers());
117 let started_at = chrono::Utc::now();
118 let model = crate::capture::extract::extract_model_from_request_body(&prepared.req_body_json);
119 let req_headers_map = headers_to_map(&prepared.req_headers);
120
121 let byte_stream = upstream_resp.bytes_stream();
122 let (client_stream, tap_rx) = sse_tap::tee(byte_stream);
123
124 let ctx = CaptureCtx {
125 state,
126 started_at,
127 method: prepared.method,
128 path: prepared.path_for_capture,
129 req_headers_map,
130 req_body_json: prepared.req_body_json,
131 resp_status: status.as_u16(),
132 resp_headers_map,
133 model,
134 };
135 tokio::spawn(run_capture(ctx, tap_rx));
136
137 build_streaming_response(status, resp_headers, client_stream)
138}
139
140fn build_upstream_url(upstream: &Url, uri: &axum::http::Uri) -> Url {
141 let mut url = upstream.clone();
142 let base_path = upstream.path().trim_end_matches('/');
143 let req_path = uri.path();
144 let combined = format!("{base_path}{req_path}");
145 url.set_path(&combined);
146 url.set_query(uri.query());
147 url
148}
149
150async fn read_request_body(req: Request) -> Result<Bytes, Response> {
151 match axum::body::to_bytes(req.into_body(), MAX_REQUEST_BODY).await {
152 Ok(bytes) => Ok(bytes),
153 Err(err) => {
154 tracing::warn!(?err, "failed to read request body");
155 Err((
156 StatusCode::BAD_REQUEST,
157 "request body too large or unreadable",
158 )
159 .into_response())
160 }
161 }
162}
163
164async fn send_upstream(prepared: &PreparedRequest) -> Result<reqwest::Response, Response> {
165 let mut rb = upstream_client()
166 .request(
167 reqwest_method(&prepared.method),
168 prepared.upstream_url.clone(),
169 )
170 .body(prepared.body_bytes.to_vec());
171 for (name, value) in prepared.req_headers.iter() {
172 let kn = name.as_str();
173 if HOP_BY_HOP.iter().any(|h| h.eq_ignore_ascii_case(kn)) {
174 continue;
175 }
176 if kn.eq_ignore_ascii_case("content-length") {
177 continue;
178 }
179 if let (Ok(rname), Ok(rval)) = (
180 reqwest::header::HeaderName::from_bytes(name.as_str().as_bytes()),
181 reqwest::header::HeaderValue::from_bytes(value.as_bytes()),
182 ) {
183 rb = rb.header(rname, rval);
184 }
185 }
186
187 match rb.send().await {
188 Ok(resp) => Ok(resp),
189 Err(err) => {
190 let kind = classify_reqwest_err(&err);
191 tracing::warn!(?err, kind, "upstream request failed");
192 let body = serde_json::json!({
193 "error": {
194 "type": kind,
195 "message": err.to_string(),
196 }
197 });
198 Err((StatusCode::BAD_GATEWAY, axum::Json(body)).into_response())
199 }
200 }
201}
202
203fn collect_response_headers(upstream: &reqwest::header::HeaderMap) -> ResponseHeaderPair {
204 let mut axum_headers = HeaderMap::new();
205 let mut as_map: BTreeMap<String, String> = BTreeMap::new();
206 for (name, value) in upstream.iter() {
207 if HOP_BY_HOP
208 .iter()
209 .any(|h| name.as_str().eq_ignore_ascii_case(h))
210 {
211 continue;
212 }
213 if let (Ok(an), Ok(av)) = (
214 HeaderName::from_bytes(name.as_str().as_bytes()),
215 HeaderValue::from_bytes(value.as_bytes()),
216 ) {
217 axum_headers.insert(an, av);
218 }
219 if let Ok(text) = value.to_str() {
220 as_map.insert(name.as_str().to_string(), text.to_string());
221 }
222 }
223 (axum_headers, as_map)
224}
225
226fn build_streaming_response<S>(status: StatusCode, headers: HeaderMap, client_stream: S) -> Response
227where
228 S: futures::Stream<Item = Result<Bytes, std::io::Error>> + Send + 'static,
229{
230 let body = Body::from_stream(client_stream);
231 let mut builder = Response::builder().status(status);
232 for (name, value) in headers.iter() {
233 builder = builder.header(name, value);
234 }
235 builder
236 .body(body)
237 .unwrap_or_else(|_| StatusCode::BAD_GATEWAY.into_response())
238}
239
240async fn run_capture(ctx: CaptureCtx, tap_rx: TapReceiver) {
241 let CaptureCtx {
242 state,
243 started_at,
244 method,
245 path,
246 mut req_headers_map,
247 mut req_body_json,
248 resp_status,
249 mut resp_headers_map,
250 model,
251 } = ctx;
252
253 let request_id = crate::capture::extract::extract_request_id(&resp_headers_map);
254 let seq = next_seq(&state.store, state.session_id.as_str()).await;
255
256 if let Err(err) = state.events.send(CaptureEvent::RequestStarted {
257 session_id: state.session_id.as_str().to_string(),
258 seq,
259 started_at,
260 model: model.clone(),
261 }) {
262 tracing::trace!(?err, "no subscribers for RequestStarted");
263 }
264
265 let (body_reassembled, frames_count, partial_err) =
266 sse_tap::reassemble(state.provider, tap_rx).await;
267
268 let ended_at = chrono::Utc::now();
269 let duration_ms = duration_ms_since(started_at, ended_at);
270 let usage = usage_from_reassembled(body_reassembled.as_ref());
271
272 if state.redact {
273 crate::capture::redact::redact_headers(&mut req_headers_map);
274 crate::capture::redact::redact_body(&mut req_body_json);
275 crate::capture::redact::redact_headers(&mut resp_headers_map);
276 }
277
278 let rec = CaptureRecord {
279 seq,
280 session_id: state.session_id.as_str().to_string(),
281 request_id: request_id.clone(),
282 started_at,
283 ended_at: Some(ended_at),
284 duration_ms: Some(duration_ms),
285 ttft_ms: None,
286 request: RequestPart {
287 method: method.as_str().to_string(),
288 path,
289 headers: req_headers_map,
290 body: req_body_json,
291 },
292 response: Some(ResponsePart {
293 status: resp_status,
294 headers: resp_headers_map,
295 body_reassembled,
296 raw_sse_text: None,
297 raw_sse_frames_count: frames_count,
298 }),
299 usage: usage.clone(),
300 model,
301 error: partial_err.clone(),
302 partial: partial_err.is_some(),
303 schema_version: 1,
304 };
305 if let Err(err) = state.store.append(rec).await {
306 tracing::warn!(?err, "store append failed");
307 }
308
309 let has_error = partial_err.is_some() || resp_status >= 400;
310 if let Err(err) = state.events.send(CaptureEvent::RequestCompleted {
311 session_id: state.session_id.as_str().to_string(),
312 seq,
313 duration_ms,
314 status: resp_status,
315 request_id,
316 usage,
317 has_error,
318 }) {
319 tracing::trace!(?err, "no subscribers for RequestCompleted");
320 }
321}
322
323fn headers_to_map(headers: &HeaderMap) -> BTreeMap<String, String> {
324 let mut out: BTreeMap<String, String> = BTreeMap::new();
325 for (name, value) in headers.iter() {
326 if let Ok(text) = value.to_str() {
327 out.insert(name.as_str().to_string(), text.to_string());
328 }
329 }
330 out
331}
332
333fn reqwest_method(method: &Method) -> reqwest::Method {
334 reqwest::Method::from_bytes(method.as_str().as_bytes()).unwrap_or(reqwest::Method::GET)
335}
336
337fn classify_reqwest_err(err: &reqwest::Error) -> &'static str {
338 if err.is_timeout() {
339 return "upstream_timeout";
340 }
341 if err.is_connect() {
342 return "upstream_unreachable";
343 }
344 if err.to_string().to_lowercase().contains("tls") {
345 return "tls_handshake_failed";
346 }
347 "upstream_error"
348}
349
350fn usage_from_reassembled(value: Option<&Value>) -> Option<Usage> {
351 let value = value?;
352 let usage = value.get("usage")?;
353 Some(Usage {
354 input_tokens: usage
355 .get("input_tokens")
356 .and_then(Value::as_u64)
357 .unwrap_or(0),
358 output_tokens: usage
359 .get("output_tokens")
360 .and_then(Value::as_u64)
361 .unwrap_or(0),
362 cache_creation_input_tokens: usage
363 .get("cache_creation_input_tokens")
364 .and_then(Value::as_u64)
365 .unwrap_or(0),
366 cache_read_input_tokens: usage
367 .get("cache_read_input_tokens")
368 .and_then(Value::as_u64)
369 .unwrap_or(0),
370 })
371}
372
373fn duration_ms_since(started_at: DateTime<Utc>, ended_at: DateTime<Utc>) -> u64 {
374 let millis = (ended_at - started_at).num_milliseconds().max(0);
375 u64::try_from(millis).unwrap_or(0)
376}
377
378async fn next_seq(store: &Arc<dyn crate::store::Store>, sid: &str) -> u64 {
379 let highest = store
380 .list_requests(sid)
381 .await
382 .map(|list| list.iter().map(|item| item.seq).max().unwrap_or(0))
383 .unwrap_or(0);
384 highest.saturating_add(1)
385}