Skip to main content

agent_proxy_rust_core/
server.rs

1//! The axum-based proxy server engine.
2//!
3//! Provides the [`AgentProxy`] builder, router, and request handler.
4
5use std::sync::{
6    Arc,
7    atomic::{AtomicU64, Ordering},
8};
9
10use axum::{
11    Json, Router,
12    body::Body,
13    extract::State,
14    http::{Request, Response, StatusCode},
15    middleware,
16    response::IntoResponse,
17    routing::{get, post},
18};
19use tokio::task::JoinHandle;
20use tower_http::limit::RequestBodyLimitLayer;
21
22use secrecy::ExposeSecret;
23
24use crate::{
25    auth::{self, AgentRole, AuthState},
26    config::ProxyConfig,
27    error::ProxyError,
28    middleware::{CostRecorder, ProxyMiddleware, run_on_request_chain, run_on_response_chain},
29    types::{ConnectionContext, ProxyRequest, ProxyResponse, detect_agent_type, detect_api_format},
30};
31
32/// Shared state for the proxy server.
33#[derive(Clone)]
34pub struct ProxyState {
35    /// Proxy configuration.
36    pub config: Arc<ProxyConfig>,
37    /// Registered middleware chain.
38    pub middlewares: Arc<Vec<Box<dyn ProxyMiddleware>>>,
39    /// Reusable HTTP client for upstream forwarding.
40    pub client: reqwest::Client,
41    /// Optional cost recorder for post-response billing.
42    pub cost_recorder: Option<Arc<dyn CostRecorder>>,
43    next_request_id: Arc<AtomicU64>,
44}
45
46impl ProxyState {
47    fn next_request_id(&self) -> u64 {
48        self.next_request_id.fetch_add(1, Ordering::SeqCst)
49    }
50}
51
52impl std::fmt::Debug for ProxyState {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        let mw_names: Vec<&str> = self.middlewares.iter().map(|m| m.name()).collect();
55        f.debug_struct("ProxyState")
56            .field("config", &self.config)
57            .field("middlewares", &mw_names)
58            .field("client", &self.client)
59            .field(
60                "cost_recorder",
61                &self.cost_recorder.as_ref().map(|_| "CostRecorder"),
62            )
63            .field("next_request_id", &self.next_request_id)
64            .finish()
65    }
66}
67
68/// The proxy application.
69///
70/// Created via [`AgentProxy::builder()`] and started with [`AgentProxy::serve()`].
71pub struct AgentProxy {
72    config: ProxyConfig,
73    middlewares: Arc<Vec<Box<dyn ProxyMiddleware>>>,
74    cost_recorder: Option<Arc<dyn CostRecorder>>,
75}
76
77impl AgentProxy {
78    /// Creates a new [`AgentProxyBuilder`].
79    #[must_use]
80    pub fn builder() -> AgentProxyBuilder {
81        AgentProxyBuilder::default()
82    }
83
84    /// Returns the axum [`Router`] for this proxy without starting a server.
85    /// Useful for combining with other routers (e.g., admin API).
86    ///
87    /// # Errors
88    ///
89    /// Returns [`ProxyError::Config`] if the reqwest client cannot be built
90    /// from the proxy configuration.
91    pub fn into_router(self) -> Result<Router, ProxyError> {
92        let client = build_reqwest_client(&self.config)?;
93        let state = Arc::new(ProxyState {
94            config: Arc::new(self.config),
95            middlewares: self.middlewares,
96            client,
97            cost_recorder: self.cost_recorder,
98            next_request_id: Arc::new(AtomicU64::new(1)),
99        });
100        Ok(build_router(state))
101    }
102
103    /// Starts the proxy server and returns a [`JoinHandle`].
104    ///
105    /// Runs `on_init` on all middlewares before binding.
106    ///
107    /// # Errors
108    ///
109    /// Returns a [`ProxyError`] if binding to the listen address fails.
110    pub async fn serve(self) -> Result<JoinHandle<()>, ProxyError> {
111        let client = build_reqwest_client(&self.config)?;
112
113        let state = Arc::new(ProxyState {
114            config: Arc::new(self.config),
115            middlewares: self.middlewares,
116            client,
117            cost_recorder: self.cost_recorder,
118            next_request_id: Arc::new(AtomicU64::new(1)),
119        });
120
121        // Run on_init hooks
122        for mw in state.middlewares.iter() {
123            mw.on_init().await?;
124        }
125
126        let app = build_router(state.clone());
127        let listener = tokio::net::TcpListener::bind(state.config.listen)
128            .await
129            .map_err(|e| ProxyError::Internal(e.into()))?;
130
131        tracing::warn!("agent-proxy listening on {}", state.config.listen);
132
133        let handle = tokio::spawn(async move {
134            if let Err(e) = axum::serve(listener, app).await {
135                tracing::error!("server error: {e}");
136            }
137        });
138
139        Ok(handle)
140    }
141}
142
143/// Builder for [`AgentProxy`].
144///
145/// # Example
146///
147/// ```rust,ignore
148/// use agent_proxy_rust_core::{AgentProxy, ProxyConfig};
149///
150/// let proxy = AgentProxy::builder()
151///     .config(ProxyConfig::default())
152///     .middleware(my_middleware)
153///     .build()
154///     .unwrap();
155/// ```
156#[derive(Default)]
157pub struct AgentProxyBuilder {
158    config: Option<ProxyConfig>,
159    middlewares: Vec<Box<dyn ProxyMiddleware>>,
160    cost_recorder: Option<Arc<dyn CostRecorder>>,
161}
162
163impl AgentProxyBuilder {
164    /// Sets the cost recorder for post-response billing.
165    #[must_use]
166    pub fn cost_recorder(mut self, cr: Arc<dyn CostRecorder>) -> Self {
167        self.cost_recorder = Some(cr);
168        self
169    }
170
171    /// Sets the proxy configuration.
172    #[must_use]
173    pub fn config(mut self, config: ProxyConfig) -> Self {
174        self.config = Some(config);
175        self
176    }
177
178    /// Adds a middleware to the chain (in registration order).
179    #[must_use]
180    pub fn middleware<M: ProxyMiddleware + 'static>(mut self, m: M) -> Self {
181        self.middlewares.push(Box::new(m));
182        self
183    }
184
185    /// Builds the [`AgentProxy`].
186    ///
187    /// # Errors
188    ///
189    /// Returns a [`ProxyError`] if no config was provided.
190    pub fn build(self) -> Result<AgentProxy, ProxyError> {
191        let config = self
192            .config
193            .ok_or_else(|| ProxyError::Internal(anyhow::anyhow!("config is required")))?;
194        Ok(AgentProxy {
195            config,
196            middlewares: Arc::new(self.middlewares),
197            cost_recorder: self.cost_recorder,
198        })
199    }
200}
201
202impl std::fmt::Debug for AgentProxyBuilder {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        let mw_names: Vec<&str> = self.middlewares.iter().map(|m| m.name()).collect();
205        f.debug_struct("AgentProxyBuilder")
206            .field("config", &self.config)
207            .field("middlewares", &mw_names)
208            .field("cost_recorder", &self.cost_recorder.is_some())
209            .finish()
210    }
211}
212
213impl std::fmt::Debug for AgentProxy {
214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215        let mw_names: Vec<&str> = self.middlewares.iter().map(|m| m.name()).collect();
216        f.debug_struct("AgentProxy")
217            .field("config", &self.config)
218            .field("middlewares", &mw_names)
219            .field("cost_recorder", &self.cost_recorder.is_some())
220            .finish()
221    }
222}
223
224/// Builds the reqwest client for upstream forwarding.
225fn build_reqwest_client(config: &ProxyConfig) -> Result<reqwest::Client, ProxyError> {
226    reqwest::Client::builder()
227        .connect_timeout(config.upstream_connect_timeout)
228        .read_timeout(config.upstream_read_timeout)
229        .http1_only()
230        .build()
231        .map_err(|e| ProxyError::Internal(e.into()))
232}
233
234/// Builds the axum router.
235fn build_router(state: Arc<ProxyState>) -> Router {
236    let auth_state = AuthState::from_config(&state.config);
237
238    Router::new()
239        .route("/v1/messages", post(handle_proxy_request))
240        .route("/v1/chat/completions", post(handle_proxy_request))
241        .route("/v1/responses", post(handle_proxy_request))
242        .route("/health", get(handle_health))
243        .layer(middleware::from_fn_with_state(
244            auth_state,
245            auth::auth_middleware,
246        ))
247        .layer(RequestBodyLimitLayer::new(state.config.max_body_size))
248        .with_state(state)
249}
250
251/// Health check handler.
252async fn handle_health() -> Json<serde_json::Value> {
253    Json(serde_json::json!({"status": "ok"}))
254}
255
256/// Single dispatch handler for all AI API endpoints.
257///
258/// 1. Detects the API format from the path.
259/// 2. Detects the agent type from headers.
260/// 3. Reads the auth role from request extensions.
261/// 4. Runs the `on_request` middleware chain.
262/// 5. Forwards to upstream (streaming or non-streaming).
263/// 6. Runs the `on_response` middleware chain.
264/// 7. Returns the response to the client.
265#[allow(clippy::too_many_lines)]
266async fn handle_proxy_request(
267    State(state): State<Arc<ProxyState>>,
268    req: Request<Body>,
269) -> Response<Body> {
270    let request_id = state.next_request_id();
271    let path = req.uri().path().to_string();
272    let detected_format = detect_api_format(&path);
273
274    // Unknown path → 404
275    if detected_format.is_none() {
276        return (
277            StatusCode::NOT_FOUND,
278            Json(serde_json::json!({
279                "error": {"code": "not_found", "message": format!("no route for {path}")}
280            })),
281        )
282            .into_response();
283    }
284
285    // Read body with size check
286    let (parts, body) = req.into_parts();
287
288    // Check Content-Length header for early rejection
289    let body_too_large = parts
290        .headers
291        .get("content-length")
292        .and_then(|cl| cl.to_str().ok())
293        .and_then(|s| s.parse::<usize>().ok())
294        .is_some_and(|len| len > state.config.max_body_size);
295
296    if body_too_large {
297        return (
298            StatusCode::PAYLOAD_TOO_LARGE,
299            Json(serde_json::json!({
300                "error": {
301                    "code": "body_too_large",
302                    "message": format!("request body exceeds size limit of {}", state.config.max_body_size)
303                }
304            })),
305        )
306            .into_response();
307    }
308
309    let body_bytes = match axum::body::to_bytes(body, state.config.max_body_size).await {
310        Ok(b) => b,
311        Err(_e) => {
312            return (
313                StatusCode::PAYLOAD_TOO_LARGE,
314                Json(serde_json::json!({
315                    "error": {
316                        "code": "body_too_large",
317                        "message": "request body exceeds size limit"
318                    }
319                })),
320            )
321                .into_response();
322        }
323    };
324
325    let agent_type = detect_agent_type(&parts.headers, &path);
326    let agent_role = parts.extensions.get::<AgentRole>().map(|r| r.0.clone());
327
328    let mut proxy_req = ProxyRequest::new(parts.method, path, parts.headers, body_bytes);
329
330    // ── Pre-middleware input validation ─────────────────────────────
331    if let Err(e) = validate_proxy_request(&proxy_req) {
332        log_error(
333            &e,
334            &ConnectionContext::new(request_id, agent_type, agent_role.clone(), detected_format),
335        );
336        return e.to_response();
337    }
338
339    let mut ctx = ConnectionContext::new(request_id, agent_type, agent_role, detected_format);
340
341    // ── Session correlation: extract header + consume tokenless report ──
342    let session_id = proxy_req
343        .headers
344        .iter()
345        .find(|(k, _)| k.as_str().eq_ignore_ascii_case("x-claude-code-session-id"))
346        .and_then(|(_, v)| v.to_str().ok())
347        .map(ToString::to_string);
348
349    let mut project_path = proxy_req
350        .headers
351        .iter()
352        .find(|(k, _)| {
353            let key = k.as_str().to_lowercase();
354            key == "x-claude-code-project-path" || key == "x-project-path"
355        })
356        .and_then(|(_, v)| v.to_str().ok())
357        .map(ToString::to_string);
358
359    // Log all x-* headers and billing-relevant fields for debugging
360    let billing_headers: Vec<String> = proxy_req
361        .headers
362        .iter()
363        .filter(|(k, _)| {
364            let key = k.as_str().to_lowercase();
365            key.starts_with("x-")
366        })
367        .map(|(k, v)| format!("{}={}", k.as_str(), v.to_str().unwrap_or("<binary>")))
368        .collect();
369    tracing::info!(
370        request_id = ctx.request_id,
371        session_id = ?session_id,
372        project_path = ?project_path,
373        agent_type = %agent_type,
374        headers = %billing_headers.join(", "),
375        "billing correlation headers"
376    );
377
378    if let Some(ref sid) = session_id {
379        // Always set session_id from header (regardless of report availability)
380        ctx.session_id = Some(sid.clone());
381
382        if let Some(acc) = crate::report::consume_report(sid) {
383            ctx.tokenless_saved_tokens = acc.total_saved;
384            ctx.tokenless_rtk_saved = acc.rtk_saved;
385            ctx.tokenless_response_saved = acc.response_saved;
386            ctx.tokenless_schema_saved = acc.schema_saved;
387            ctx.tokenless_breakdown_json = Some(acc.breakdown_json);
388            // Fallback: extract project_path from report if no header
389            if project_path.is_none() {
390                project_path = acc.project_path;
391            }
392            // Fallback: extract user_name from report
393            if ctx.user_name.is_none() {
394                ctx.user_name = acc.user_name;
395            }
396        }
397    }
398
399    if let Some(ref proj) = project_path {
400        ctx.project_path = Some(proj.clone());
401    }
402
403    // ── Inject compression stats from tokenless env var ────────────
404    let compression_stats = crate::compression::read_tokenless_stats();
405    if compression_stats.total_saved() > 0 {
406        ctx.insert(crate::extensions::EXT_COMPRESSION_STATS, compression_stats);
407    }
408
409    // on_request chain (registration order)
410    if let Err(e) = run_on_request_chain(&state.middlewares, &mut proxy_req, &mut ctx).await {
411        log_error(&e, &ctx);
412        return e.to_response();
413    }
414
415    // Get upstream target from extensions (set by model-router middleware)
416    let channel = ctx.get::<crate::types::ChannelConfig>(crate::extensions::EXT_SELECTED_CHANNEL);
417
418    if let Some(ch) = channel {
419        let is_streaming = proxy_req.is_streaming();
420
421        match forward_to_upstream(&state.client, &proxy_req, ch).await {
422            Ok(upstream_resp) => {
423                if is_streaming {
424                    handle_streaming_response(upstream_resp, &state, &ctx).await
425                } else {
426                    handle_non_streaming_response(upstream_resp, &state, &ctx).await
427                }
428            }
429            Err(e) => {
430                log_error(&e, &ctx);
431                e.to_response()
432            }
433        }
434    } else {
435        let err = ProxyError::ChannelSelection {
436            model: "unknown".into(),
437        };
438        log_error(&err, &ctx);
439        err.to_response()
440    }
441}
442
443/// Handles a non-streaming upstream response.
444async fn handle_non_streaming_response(
445    upstream_resp: reqwest::Response,
446    state: &Arc<ProxyState>,
447    ctx: &ConnectionContext,
448) -> Response<Body> {
449    let status = upstream_resp.status();
450    let headers = upstream_resp.headers().clone();
451
452    let body_bytes = match upstream_resp.bytes().await {
453        Ok(b) => b,
454        Err(e) => {
455            let err = ProxyError::Upstream {
456                source: format!("failed to read upstream response: {e}"),
457                inner: Some(e.into()),
458            };
459            log_error(&err, ctx);
460            return err.to_response();
461        }
462    };
463
464    let body_text = String::from_utf8_lossy(&body_bytes);
465    tracing::warn!(
466        request_id = ctx.request_id,
467        upstream_status = %status,
468        upstream_body = %body_text,
469        target_protocol = ?ctx.target_protocol,
470        channel = ?ctx.get::<crate::types::ChannelConfig>(crate::extensions::EXT_SELECTED_CHANNEL).map(|ch| ch.name.clone()),
471        "upstream response received"
472    );
473
474    let mut proxy_resp = ProxyResponse::new(status, headers, body_bytes, false);
475
476    if let Err(e) = run_on_response_chain(&state.middlewares, &mut proxy_resp, ctx).await {
477        log_error(&e, ctx);
478        return e.to_response();
479    }
480
481    // Cost recording (fire-and-forget — failures are logged but don't block)
482    if let Some(ref cr) = state.cost_recorder
483        && let Ok(body_json) = serde_json::from_slice::<serde_json::Value>(&proxy_resp.body)
484        && let Err(e) = cr.record(ctx, &body_json).await
485    {
486        tracing::warn!(
487            request_id = ctx.request_id,
488            error = %e,
489            "cost recording failed"
490        );
491    }
492
493    build_axum_response(proxy_resp)
494}
495
496/// Handles a streaming upstream response.
497///
498/// For the MVP, buffers the full response body. SSE frame-by-frame transformation
499/// will be implemented when the bridge middleware adds `transform_stream` support.
500async fn handle_streaming_response(
501    upstream_resp: reqwest::Response,
502    state: &Arc<ProxyState>,
503    ctx: &ConnectionContext,
504) -> Response<Body> {
505    let status = upstream_resp.status();
506    let headers = upstream_resp.headers().clone();
507
508    // Buffer the full streaming response (frame-by-frame transform is Phase 2)
509    let body_bytes = match upstream_resp.bytes().await {
510        Ok(b) => b,
511        Err(e) => {
512            let err = ProxyError::Upstream {
513                source: format!("failed to read streaming response: {e}"),
514                inner: Some(e.into()),
515            };
516            log_error(&err, ctx);
517            return err.to_response();
518        }
519    };
520
521    let body_text = String::from_utf8_lossy(&body_bytes);
522    tracing::warn!(
523        request_id = ctx.request_id,
524        upstream_status = %status,
525        upstream_body = %body_text,
526        target_protocol = ?ctx.target_protocol,
527        channel = ?ctx.get::<crate::types::ChannelConfig>(crate::extensions::EXT_SELECTED_CHANNEL).map(|ch| ch.name.clone()),
528        "upstream streaming response received"
529    );
530
531    let mut proxy_resp = ProxyResponse::new(status, headers, body_bytes, true);
532
533    if let Err(e) = run_on_response_chain(&state.middlewares, &mut proxy_resp, ctx).await {
534        log_error(&e, ctx);
535        return e.to_response();
536    }
537
538    // Cost recording for streaming responses (SSE usage is extracted from buffered body)
539    if let Some(ref cr) = state.cost_recorder {
540        let body_json = extract_usage_from_sse(&proxy_resp.body);
541        if let Err(e) = cr.record(ctx, &body_json).await {
542            tracing::warn!(
543                request_id = ctx.request_id,
544                error = %e,
545                "cost recording failed for streaming response"
546            );
547        }
548    }
549
550    build_axum_response(proxy_resp)
551}
552
553/// Validates a [`ProxyRequest`] before the middleware chain runs.
554///
555/// Catches obviously malformed requests early:
556/// - Non-JSON content-type
557/// - Empty body
558fn validate_proxy_request(req: &ProxyRequest) -> Result<(), ProxyError> {
559    // Reject non-JSON content-type
560    if let Some(ct) = req
561        .headers
562        .get("content-type")
563        .and_then(|v| v.to_str().ok())
564        && !ct.starts_with("application/json")
565    {
566        return Err(ProxyError::BadRequest(format!(
567            "unsupported content-type: {ct}. expected application/json"
568        )));
569    }
570
571    // Reject empty body
572    if req.body.is_empty() {
573        return Err(ProxyError::BadRequest("empty request body".into()));
574    }
575
576    Ok(())
577}
578
579/// Forwards the proxy request to the upstream server.
580///
581/// Uses `channel.rewrite_path` if set, otherwise passes through the
582/// (possibly bridge-rewritten) `proxy_req.path`.
583async fn forward_to_upstream(
584    client: &reqwest::Client,
585    proxy_req: &ProxyRequest,
586    channel: &crate::types::ChannelConfig,
587) -> Result<reqwest::Response, ProxyError> {
588    let api_key_str = channel.api_key.expose_secret().to_owned();
589
590    // Use rewrite_path if set and non-empty, otherwise use the original request path
591    let path = channel
592        .rewrite_path
593        .as_deref()
594        .filter(|p| !p.is_empty())
595        .unwrap_or(&proxy_req.path);
596    let url = format!("{}{}", channel.url.trim_end_matches('/'), path);
597
598    let mut req_builder = client
599        .request(proxy_req.method.clone(), &url)
600        .body(proxy_req.body.to_vec());
601
602    // Apply header forwarding policy: drop hop-by-hop and auth headers
603    for (key, value) in &proxy_req.headers {
604        let key_str = key.as_str().to_lowercase();
605        let should_drop = matches!(
606            key_str.as_str(),
607            "transfer-encoding"
608                | "connection"
609                | "keep-alive"
610                | "accept-encoding"
611                | "host"
612                | "content-length"
613                | "authorization"
614                | "x-api-key"
615        );
616        if !should_drop {
617            req_builder = req_builder.header(key.clone(), value.clone());
618        }
619    }
620
621    // Inject upstream auth
622    if !api_key_str.is_empty() {
623        req_builder = req_builder.header("Authorization", format!("Bearer {api_key_str}"));
624    }
625
626    req_builder.send().await.map_err(|e| {
627        if e.is_timeout() {
628            ProxyError::Upstream {
629                source: format!("upstream timeout: {e}"),
630                inner: Some(e.into()),
631            }
632        } else if e.is_connect() {
633            ProxyError::Upstream {
634                source: format!("upstream connection failed: {e}"),
635                inner: Some(e.into()),
636            }
637        } else {
638            ProxyError::Upstream {
639                source: format!("upstream request failed: {e}"),
640                inner: Some(e.into()),
641            }
642        }
643    })
644}
645
646/// Builds an axum [`Response`] from a [`ProxyResponse`].
647fn build_axum_response(proxy_resp: ProxyResponse) -> Response<Body> {
648    let mut response = Response::new(Body::from(proxy_resp.body));
649    *response.status_mut() = proxy_resp.status;
650    for (key, value) in &proxy_resp.headers {
651        if is_forward_header(key.as_str()) {
652            response.headers_mut().insert(key.clone(), value.clone());
653        }
654    }
655    response
656}
657
/// Returns `true` if the header should be forwarded from upstream to client.
///
/// The comparison is ASCII case-insensitive and allocation-free. Blocked are:
/// - connection-specific (hop-by-hop) headers,
/// - `content-length`, which is derived from the rebuilt body,
/// - `host` and credential headers, which must never flow back to the client.
fn is_forward_header(name: &str) -> bool {
    const NON_FORWARDED: &[&str] = &[
        // Connection-specific (hop-by-hop) headers, RFC 9110 §7.6.1.
        "connection",
        "proxy-connection",
        "keep-alive",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
        // Addressed to this proxy hop, not the end client.
        "proxy-authenticate",
        // Recomputed for the rebuilt response body.
        "content-length",
        // Never meaningful for, or must never leak to, the client.
        "host",
        "authorization",
        "x-api-key",
    ];
    !NON_FORWARDED
        .iter()
        .any(|blocked| blocked.eq_ignore_ascii_case(name))
}
672
673/// Logs an error with appropriate severity.
674fn log_error(err: &ProxyError, ctx: &ConnectionContext) {
675    match err {
676        ProxyError::Internal(e) => {
677            tracing::error!(
678                request_id = ctx.request_id,
679                error = %e,
680                "internal error"
681            );
682        }
683        ProxyError::Upstream { source, .. } => {
684            tracing::warn!(
685                request_id = ctx.request_id,
686                error = %source,
687                "upstream error"
688            );
689        }
690        _ => {
691            tracing::debug!(
692                request_id = ctx.request_id,
693                error = %err,
694                "request error"
695            );
696        }
697    }
698}
699
700/// Extracts token usage from an SSE streaming response body and wraps it
701/// in a JSON value suitable for cost recording.
702///
703/// Parses `data:` lines looking for usage-bearing events:
704/// - **Anthropic**: `message_start` + `message_delta` events (merged — see below)
705/// - **`OpenAI` Chat**: last chunk with `usage` field (before `[DONE]`)
706/// - **`OpenAI` Responses**: `response.completed` event with `usage` field
707///
708/// For Anthropic streams, `message_start` carries `input_tokens` while
709/// `message_delta` carries `output_tokens` + cache fields. We merge them
710/// instead of overwriting, so both input and output tokens are captured.
711fn extract_usage_from_sse(body: &[u8]) -> serde_json::Value {
712    let Ok(text) = std::str::from_utf8(body) else {
713        return serde_json::Value::Null;
714    };
715
716    // Normalize SSE format: ensure `data:` has a space after colon
717    // Different providers use different formats:
718    // - Anthropic/DeepSeek: `data: {...}` (with space)
719    // - DashScope: `data:{...}` (no space)
720    let normalized = normalize_sse_format(text);
721
722    let mut merged: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
723
724    for line in normalized.lines() {
725        let Some(data) = line.strip_prefix("data: ") else {
726            continue;
727        };
728        if data.is_empty() || data == "[DONE]" {
729            continue;
730        }
731        let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else {
732            continue;
733        };
734
735        // Anthropic message_start — carries input_tokens
736        if event.get("type").and_then(|v| v.as_str()) == Some("message_start")
737            && let Some(u) = event.get("message").and_then(|m| m.get("usage"))
738        {
739            merge_usage_fields(&mut merged, u);
740        }
741        // Anthropic message_delta — carries output_tokens + cache fields
742        if event.get("type").and_then(|v| v.as_str()) == Some("message_delta")
743            && let Some(u) = event.get("usage")
744        {
745            merge_usage_fields(&mut merged, u);
746        }
747        // OpenAI Responses completed
748        if event.get("type").and_then(|v| v.as_str()) == Some("response.completed")
749            && let Some(u) = event.get("response").and_then(|r| r.get("usage"))
750        {
751            merge_usage_fields(&mut merged, u);
752        }
753        // OpenAI Chat: has "choices" and "usage"
754        if event.get("choices").is_some()
755            && let Some(u) = event.get("usage")
756        {
757            merge_usage_fields(&mut merged, u);
758        }
759        // Other providers: usage at top level (no choices wrapper)
760        if let Some(u) = event.get("usage")
761            && event.get("choices").is_none()
762            && event.get("type").is_none()
763        {
764            merge_usage_fields(&mut merged, u);
765        }
766    }
767
768    if merged.is_empty() {
769        serde_json::Value::Null
770    } else {
771        serde_json::json!({"usage": serde_json::Value::Object(merged)})
772    }
773}
774
/// Normalizes SSE field formatting so every provider's output parses the same.
///
/// - `data:{...}` becomes `data: {...}` and `event:x` becomes `event: x`. A
///   single space is inserted only when the value does not already start
///   with one.
/// - Trailing whitespace is stripped from every line.
/// - Lines are re-joined with `\n`, with no trailing newline.
#[must_use]
fn normalize_sse_format(text: &str) -> String {
    const FIELDS: [&str; 2] = ["data:", "event:"];

    let mut out = String::with_capacity(text.len());
    for (index, raw_line) in text.lines().enumerate() {
        if index > 0 {
            out.push('\n');
        }
        let line = raw_line.trim_end();

        // First field prefix whose value is missing the conventional space.
        let unspaced = FIELDS.iter().find_map(|&field| {
            line.strip_prefix(field)
                .filter(|value| !value.starts_with(' '))
                .map(|value| (field, value))
        });

        match unspaced {
            Some((field, value)) => {
                out.push_str(field);
                out.push(' ');
                out.push_str(value);
            }
            None => out.push_str(line),
        }
    }
    out
}
803
804/// Merges usage fields from an SSE event into the accumulator map.
805///
806/// Existing keys are overwritten only when the incoming value is a non-zero
807/// number, so later events (e.g. `message_delta` with `output_tokens`) can
808/// update fields while `message_start`'s `input_tokens` is preserved.
809fn merge_usage_fields(
810    acc: &mut serde_json::Map<String, serde_json::Value>,
811    usage: &serde_json::Value,
812) {
813    if let Some(obj) = usage.as_object() {
814        for (k, v) in obj {
815            let is_nonzero_number =
816                v.as_u64().is_some_and(|n| n > 0) || v.as_f64().is_some_and(|f| f > 0.0);
817            if is_nonzero_number || !acc.contains_key(k) {
818                acc.insert(k.clone(), v.clone());
819            }
820        }
821    }
822}
823
824#[cfg(test)]
825#[allow(clippy::unwrap_used, clippy::expect_used)]
826mod tests {
827    use async_trait::async_trait;
828    use axum::{body::Body, http::StatusCode};
829    use tower::ServiceExt;
830
831    use super::*;
832    use crate::{
833        middleware::ProxyMiddleware,
834        types::{ApiFormat, ChannelConfig},
835    };
836
837    /// Mock middleware that sets an upstream URL via extensions.
838    struct UpstreamMiddleware {
839        url: String,
840    }
841
842    #[async_trait]
843    impl ProxyMiddleware for UpstreamMiddleware {
844        async fn on_request(
845            &self,
846            _req: &mut ProxyRequest,
847            ctx: &mut ConnectionContext,
848        ) -> Result<(), ProxyError> {
849            ctx.insert(
850                crate::extensions::EXT_SELECTED_CHANNEL,
851                ChannelConfig {
852                    url: self.url.clone(),
853                    api_key: secrecy::SecretString::from("sk-test"),
854                    protocol: ApiFormat::AnthropicMessages,
855                    name: "test".into(),
856                    rewrite_path: None,
857                },
858            );
859            Ok(())
860        }
861
862        async fn on_response(
863            &self,
864            _res: &mut ProxyResponse,
865            _ctx: &ConnectionContext,
866        ) -> Result<(), ProxyError> {
867            Ok(())
868        }
869
870        fn name(&self) -> &'static str {
871            "upstream"
872        }
873    }
874
875    /// Builds a test-only router (without server binding).
876    fn build_test_router(
877        config: ProxyConfig,
878        middlewares: Vec<Box<dyn ProxyMiddleware>>,
879    ) -> Router {
880        let client = reqwest::Client::builder()
881            .http1_only()
882            .build()
883            .expect("build test client");
884
885        let state = Arc::new(ProxyState {
886            config: Arc::new(config),
887            middlewares: Arc::new(middlewares),
888            client,
889            cost_recorder: None,
890            next_request_id: Arc::new(AtomicU64::new(1)),
891        });
892
893        build_router(state)
894    }
895
896    #[tokio::test]
897    async fn test_health_endpoint_returns_200() {
898        let config = ProxyConfig::default();
899        let router = build_test_router(config, vec![]);
900
901        let response = router
902            .oneshot(
903                Request::builder()
904                    .uri("/health")
905                    .body(Body::empty())
906                    .unwrap(),
907            )
908            .await
909            .unwrap();
910
911        assert_eq!(response.status(), StatusCode::OK);
912    }
913
914    #[tokio::test]
915    async fn test_unknown_path_returns_404() {
916        let config = ProxyConfig::default();
917        let router = build_test_router(config, vec![]);
918
919        let response = router
920            .oneshot(
921                Request::builder()
922                    .uri("/unknown/path")
923                    .method("POST")
924                    .header("content-type", "application/json")
925                    .body(Body::from(r#"{"model":"test"}"#))
926                    .unwrap(),
927            )
928            .await
929            .unwrap();
930
931        assert_eq!(response.status(), StatusCode::NOT_FOUND);
932    }
933
934    #[tokio::test]
935    async fn test_auth_failure_returns_401() {
936        let config = ProxyConfig {
937            proxy_api_key: Some(secrecy::SecretString::new("sk-secret".into())),
938            ..Default::default()
939        };
940        let router = build_test_router(config, vec![]);
941
942        let response = router
943            .oneshot(
944                Request::builder()
945                    .uri("/health")
946                    .body(Body::empty())
947                    .unwrap(),
948            )
949            .await
950            .unwrap();
951
952        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
953    }
954
955    #[tokio::test]
956    async fn test_auth_success_passes_through() {
957        let config = ProxyConfig {
958            proxy_api_key: Some(secrecy::SecretString::new("sk-secret".into())),
959            ..Default::default()
960        };
961        let router = build_test_router(config, vec![]);
962
963        let response = router
964            .oneshot(
965                Request::builder()
966                    .uri("/health")
967                    .header("authorization", "Bearer sk-secret")
968                    .body(Body::empty())
969                    .unwrap(),
970            )
971            .await
972            .unwrap();
973
974        assert_eq!(response.status(), StatusCode::OK);
975    }
976
977    #[tokio::test]
978    async fn test_body_too_large_returns_413() {
979        let config = ProxyConfig {
980            max_body_size: 1024, // 1KB limit
981            ..Default::default()
982        };
983        let router = build_test_router(config, vec![]);
984
985        let big_body = "x".repeat(2048);
986        let response = router
987            .oneshot(
988                Request::builder()
989                    .uri("/v1/messages")
990                    .method("POST")
991                    .header("content-type", "application/json")
992                    .body(Body::from(big_body))
993                    .unwrap(),
994            )
995            .await
996            .unwrap();
997
998        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
999    }
1000
1001    #[tokio::test]
1002    async fn test_no_channel_returns_503() {
1003        let config = ProxyConfig::default();
1004        let router = build_test_router(config, vec![]);
1005
1006        let response = router
1007            .oneshot(
1008                Request::builder()
1009                    .uri("/v1/messages")
1010                    .method("POST")
1011                    .header("content-type", "application/json")
1012                    .body(Body::from(
1013                        r#"{"model":"claude-sonnet","messages":[{"role":"user","content":"hi"}]}"#,
1014                    ))
1015                    .unwrap(),
1016            )
1017            .await
1018            .unwrap();
1019
1020        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
1021    }
1022
1023    /// Starts a simple HTTP server for testing upstream forwarding.
1024    async fn start_mock_upstream() -> (String, JoinHandle<()>) {
1025        use axum::routing::post;
1026
1027        async fn mock_messages_handler() -> Json<serde_json::Value> {
1028            Json(serde_json::json!({
1029                "id": "msg_123",
1030                "type": "message",
1031                "role": "assistant",
1032                "content": [{"type": "text", "text": "Hello from upstream!"}],
1033                "model": "claude-sonnet",
1034                "usage": {"input_tokens": 10, "output_tokens": 20}
1035            }))
1036        }
1037
1038        let app = Router::new().route("/v1/messages", post(mock_messages_handler));
1039
1040        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
1041            .await
1042            .expect("bind");
1043        let addr = listener.local_addr().expect("local addr");
1044
1045        let handle = tokio::spawn(async move {
1046            axum::serve(listener, app).await.unwrap();
1047        });
1048
1049        (format!("http://{addr}"), handle)
1050    }
1051
1052    #[tokio::test]
1053    async fn test_successful_proxy_returns_200() {
1054        let (upstream_url, _upstream_handle) = start_mock_upstream().await;
1055
1056        let config = ProxyConfig::default();
1057        let middlewares: Vec<Box<dyn ProxyMiddleware>> =
1058            vec![Box::new(UpstreamMiddleware { url: upstream_url })];
1059
1060        let router = build_test_router(config, middlewares);
1061
1062        let response = router
1063            .oneshot(
1064                Request::builder()
1065                    .uri("/v1/messages")
1066                    .method("POST")
1067                    .header("content-type", "application/json")
1068                    .body(Body::from(
1069                        r#"{"model":"claude-sonnet","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}"#,
1070                    ))
1071                    .unwrap(),
1072            )
1073            .await
1074            .unwrap();
1075
1076        assert_eq!(response.status(), StatusCode::OK);
1077    }
1078
1079    // ── extract_usage_from_sse tests ──────────────────────────────
1080
1081    #[test]
1082    fn test_extract_usage_from_sse_with_space() {
1083        // DeepSeek format: `data: {...}` (with space after colon)
1084        let body = b"event: message_start\n\
1085            data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":100,\"output_tokens\":0}}}\n\n\
1086            event: message_delta\n\
1087            data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":50,\"cache_read_input_tokens\":30}}\n\n";
1088        let result = extract_usage_from_sse(body);
1089        let usage = result.get("usage").unwrap();
1090        assert_eq!(usage.get("input_tokens").unwrap().as_u64().unwrap(), 100);
1091        assert_eq!(usage.get("output_tokens").unwrap().as_u64().unwrap(), 50);
1092        assert_eq!(
1093            usage
1094                .get("cache_read_input_tokens")
1095                .unwrap()
1096                .as_u64()
1097                .unwrap(),
1098            30
1099        );
1100    }
1101
1102    #[test]
1103    fn test_extract_usage_from_sse_without_space() {
1104        // DashScope format: `data:{...}` (no space after colon)
1105        let body = b"event:message_start\n\
1106            data:{\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":200,\"output_tokens\":0}}}\n\n\
1107            event:message_delta\n\
1108            data:{\"type\":\"message_delta\",\"usage\":{\"output_tokens\":80,\"cache_read_input_tokens\":60}}\n\n";
1109        let result = extract_usage_from_sse(body);
1110        let usage = result.get("usage").unwrap();
1111        assert_eq!(usage.get("input_tokens").unwrap().as_u64().unwrap(), 200);
1112        assert_eq!(usage.get("output_tokens").unwrap().as_u64().unwrap(), 80);
1113        assert_eq!(
1114            usage
1115                .get("cache_read_input_tokens")
1116                .unwrap()
1117                .as_u64()
1118                .unwrap(),
1119            60
1120        );
1121    }
1122
1123    #[test]
1124    fn test_extract_usage_from_sse_mixed_format() {
1125        // Mixed format (shouldn't happen in practice, but test robustness)
1126        let body = b"event:message_start\n\
1127            data:{\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":150,\"output_tokens\":0}}}\n\n\
1128            event: message_delta\n\
1129            data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":90}}\n\n";
1130        let result = extract_usage_from_sse(body);
1131        let usage = result.get("usage").unwrap();
1132        assert_eq!(usage.get("input_tokens").unwrap().as_u64().unwrap(), 150);
1133        assert_eq!(usage.get("output_tokens").unwrap().as_u64().unwrap(), 90);
1134    }
1135
1136    #[test]
1137    fn test_normalize_sse_format() {
1138        // Test DashScope format (no space)
1139        let input = "event:message_start\ndata:{\"type\":\"message_start\"}\n\n";
1140        let output = normalize_sse_format(input);
1141        assert!(output.contains("event: message_start"));
1142        assert!(output.contains("data: {\"type\":\"message_start\"}"));
1143
1144        // Test standard format (with space) - should remain unchanged
1145        let input2 = "event: message_start\ndata: {\"type\":\"message_start\"}\n\n";
1146        let output2 = normalize_sse_format(input2);
1147        assert_eq!(output2.trim(), input2.trim());
1148
1149        // Test mixed format
1150        let input3 = "event:message_start\ndata: {\"type\":\"message_start\"}\n\nevent: message_delta\ndata:{\"type\":\"message_delta\"}";
1151        let output3 = normalize_sse_format(input3);
1152        assert!(output3.contains("event: message_start"));
1153        assert!(output3.contains("data: {\"type\":\"message_start\"}"));
1154        assert!(output3.contains("event: message_delta"));
1155        assert!(output3.contains("data: {\"type\":\"message_delta\"}"));
1156    }
1157}