1use std::sync::{
6 Arc,
7 atomic::{AtomicU64, Ordering},
8};
9
10use axum::{
11 Json, Router,
12 body::Body,
13 extract::State,
14 http::{Request, Response, StatusCode},
15 middleware,
16 response::IntoResponse,
17 routing::{get, post},
18};
19use tokio::task::JoinHandle;
20use tower_http::limit::RequestBodyLimitLayer;
21
22use secrecy::ExposeSecret;
23
24use crate::{
25 auth::{self, AgentRole, AuthState},
26 config::ProxyConfig,
27 error::ProxyError,
28 middleware::{CostRecorder, ProxyMiddleware, run_on_request_chain, run_on_response_chain},
29 types::{ConnectionContext, ProxyRequest, ProxyResponse, detect_agent_type, detect_api_format},
30};
31
/// Shared state made available to the request handlers.
///
/// Every heavyweight field sits behind an `Arc`, so cloning the state shares
/// the underlying data instead of copying it.
#[derive(Clone)]
pub struct ProxyState {
    /// Proxy configuration.
    pub config: Arc<ProxyConfig>,
    /// Middleware chain run on each proxied request and on its response.
    pub middlewares: Arc<Vec<Box<dyn ProxyMiddleware>>>,
    /// HTTP client used to forward requests to the upstream.
    pub client: reqwest::Client,
    /// Optional recorder that is handed the response JSON after the response chain ran.
    pub cost_recorder: Option<Arc<dyn CostRecorder>>,
    /// Counter backing `next_request_id`; clones of the state share the same counter.
    next_request_id: Arc<AtomicU64>,
}
45
46impl ProxyState {
47 fn next_request_id(&self) -> u64 {
48 self.next_request_id.fetch_add(1, Ordering::SeqCst)
49 }
50}
51
52impl std::fmt::Debug for ProxyState {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 let mw_names: Vec<&str> = self.middlewares.iter().map(|m| m.name()).collect();
55 f.debug_struct("ProxyState")
56 .field("config", &self.config)
57 .field("middlewares", &mw_names)
58 .field("client", &self.client)
59 .field(
60 "cost_recorder",
61 &self.cost_recorder.as_ref().map(|_| "CostRecorder"),
62 )
63 .field("next_request_id", &self.next_request_id)
64 .finish()
65 }
66}
67
/// A configured proxy that can be turned into an axum router or served directly.
///
/// Construct one through [`AgentProxy::builder`].
pub struct AgentProxy {
    /// Proxy configuration (listen address, timeouts, body size limit, ...).
    config: ProxyConfig,
    /// Middleware chain, in the order the middlewares were registered on the builder.
    middlewares: Arc<Vec<Box<dyn ProxyMiddleware>>>,
    /// Optional cost recorder passed on to the request handlers.
    cost_recorder: Option<Arc<dyn CostRecorder>>,
}
76
77impl AgentProxy {
78 #[must_use]
80 pub fn builder() -> AgentProxyBuilder {
81 AgentProxyBuilder::default()
82 }
83
84 pub fn into_router(self) -> Result<Router, ProxyError> {
92 let client = build_reqwest_client(&self.config)?;
93 let state = Arc::new(ProxyState {
94 config: Arc::new(self.config),
95 middlewares: self.middlewares,
96 client,
97 cost_recorder: self.cost_recorder,
98 next_request_id: Arc::new(AtomicU64::new(1)),
99 });
100 Ok(build_router(state))
101 }
102
103 pub async fn serve(self) -> Result<JoinHandle<()>, ProxyError> {
111 let client = build_reqwest_client(&self.config)?;
112
113 let state = Arc::new(ProxyState {
114 config: Arc::new(self.config),
115 middlewares: self.middlewares,
116 client,
117 cost_recorder: self.cost_recorder,
118 next_request_id: Arc::new(AtomicU64::new(1)),
119 });
120
121 for mw in state.middlewares.iter() {
123 mw.on_init().await?;
124 }
125
126 let app = build_router(state.clone());
127 let listener = tokio::net::TcpListener::bind(state.config.listen)
128 .await
129 .map_err(|e| ProxyError::Internal(e.into()))?;
130
131 tracing::warn!("agent-proxy listening on {}", state.config.listen);
132
133 let handle = tokio::spawn(async move {
134 if let Err(e) = axum::serve(listener, app).await {
135 tracing::error!("server error: {e}");
136 }
137 });
138
139 Ok(handle)
140 }
141}
142
/// Builder for [`AgentProxy`].
///
/// The default builder has no config, no middlewares and no cost recorder;
/// `build` fails unless a config has been supplied.
#[derive(Default)]
pub struct AgentProxyBuilder {
    /// Proxy configuration; required by `build`.
    config: Option<ProxyConfig>,
    /// Middlewares in registration order.
    middlewares: Vec<Box<dyn ProxyMiddleware>>,
    /// Optional cost recorder.
    cost_recorder: Option<Arc<dyn CostRecorder>>,
}
162
163impl AgentProxyBuilder {
164 #[must_use]
166 pub fn cost_recorder(mut self, cr: Arc<dyn CostRecorder>) -> Self {
167 self.cost_recorder = Some(cr);
168 self
169 }
170
171 #[must_use]
173 pub fn config(mut self, config: ProxyConfig) -> Self {
174 self.config = Some(config);
175 self
176 }
177
178 #[must_use]
180 pub fn middleware<M: ProxyMiddleware + 'static>(mut self, m: M) -> Self {
181 self.middlewares.push(Box::new(m));
182 self
183 }
184
185 pub fn build(self) -> Result<AgentProxy, ProxyError> {
191 let config = self
192 .config
193 .ok_or_else(|| ProxyError::Internal(anyhow::anyhow!("config is required")))?;
194 Ok(AgentProxy {
195 config,
196 middlewares: Arc::new(self.middlewares),
197 cost_recorder: self.cost_recorder,
198 })
199 }
200}
201
202impl std::fmt::Debug for AgentProxyBuilder {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 let mw_names: Vec<&str> = self.middlewares.iter().map(|m| m.name()).collect();
205 f.debug_struct("AgentProxyBuilder")
206 .field("config", &self.config)
207 .field("middlewares", &mw_names)
208 .field("cost_recorder", &self.cost_recorder.is_some())
209 .finish()
210 }
211}
212
213impl std::fmt::Debug for AgentProxy {
214 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215 let mw_names: Vec<&str> = self.middlewares.iter().map(|m| m.name()).collect();
216 f.debug_struct("AgentProxy")
217 .field("config", &self.config)
218 .field("middlewares", &mw_names)
219 .field("cost_recorder", &self.cost_recorder.is_some())
220 .finish()
221 }
222}
223
224fn build_reqwest_client(config: &ProxyConfig) -> Result<reqwest::Client, ProxyError> {
226 reqwest::Client::builder()
227 .connect_timeout(config.upstream_connect_timeout)
228 .read_timeout(config.upstream_read_timeout)
229 .http1_only()
230 .build()
231 .map_err(|e| ProxyError::Internal(e.into()))
232}
233
234fn build_router(state: Arc<ProxyState>) -> Router {
236 let auth_state = AuthState::from_config(&state.config);
237
238 Router::new()
239 .route("/v1/messages", post(handle_proxy_request))
240 .route("/v1/chat/completions", post(handle_proxy_request))
241 .route("/v1/responses", post(handle_proxy_request))
242 .route("/health", get(handle_health))
243 .layer(middleware::from_fn_with_state(
244 auth_state,
245 auth::auth_middleware,
246 ))
247 .layer(RequestBodyLimitLayer::new(state.config.max_body_size))
248 .with_state(state)
249}
250
251async fn handle_health() -> Json<serde_json::Value> {
253 Json(serde_json::json!({"status": "ok"}))
254}
255
256#[allow(clippy::too_many_lines)]
266async fn handle_proxy_request(
267 State(state): State<Arc<ProxyState>>,
268 req: Request<Body>,
269) -> Response<Body> {
270 let request_id = state.next_request_id();
271 let path = req.uri().path().to_string();
272 let detected_format = detect_api_format(&path);
273
274 if detected_format.is_none() {
276 return (
277 StatusCode::NOT_FOUND,
278 Json(serde_json::json!({
279 "error": {"code": "not_found", "message": format!("no route for {path}")}
280 })),
281 )
282 .into_response();
283 }
284
285 let (parts, body) = req.into_parts();
287
288 let body_too_large = parts
290 .headers
291 .get("content-length")
292 .and_then(|cl| cl.to_str().ok())
293 .and_then(|s| s.parse::<usize>().ok())
294 .is_some_and(|len| len > state.config.max_body_size);
295
296 if body_too_large {
297 return (
298 StatusCode::PAYLOAD_TOO_LARGE,
299 Json(serde_json::json!({
300 "error": {
301 "code": "body_too_large",
302 "message": format!("request body exceeds size limit of {}", state.config.max_body_size)
303 }
304 })),
305 )
306 .into_response();
307 }
308
309 let body_bytes = match axum::body::to_bytes(body, state.config.max_body_size).await {
310 Ok(b) => b,
311 Err(_e) => {
312 return (
313 StatusCode::PAYLOAD_TOO_LARGE,
314 Json(serde_json::json!({
315 "error": {
316 "code": "body_too_large",
317 "message": "request body exceeds size limit"
318 }
319 })),
320 )
321 .into_response();
322 }
323 };
324
325 let agent_type = detect_agent_type(&parts.headers, &path);
326 let agent_role = parts.extensions.get::<AgentRole>().map(|r| r.0.clone());
327
328 let mut proxy_req = ProxyRequest::new(parts.method, path, parts.headers, body_bytes);
329
330 if let Err(e) = validate_proxy_request(&proxy_req) {
332 log_error(
333 &e,
334 &ConnectionContext::new(request_id, agent_type, agent_role.clone(), detected_format),
335 );
336 return e.to_response();
337 }
338
339 let mut ctx = ConnectionContext::new(request_id, agent_type, agent_role, detected_format);
340
341 let session_id = proxy_req
343 .headers
344 .iter()
345 .find(|(k, _)| k.as_str().eq_ignore_ascii_case("x-claude-code-session-id"))
346 .and_then(|(_, v)| v.to_str().ok())
347 .map(ToString::to_string);
348
349 let mut project_path = proxy_req
350 .headers
351 .iter()
352 .find(|(k, _)| {
353 let key = k.as_str().to_lowercase();
354 key == "x-claude-code-project-path" || key == "x-project-path"
355 })
356 .and_then(|(_, v)| v.to_str().ok())
357 .map(ToString::to_string);
358
359 let billing_headers: Vec<String> = proxy_req
361 .headers
362 .iter()
363 .filter(|(k, _)| {
364 let key = k.as_str().to_lowercase();
365 key.starts_with("x-")
366 })
367 .map(|(k, v)| format!("{}={}", k.as_str(), v.to_str().unwrap_or("<binary>")))
368 .collect();
369 tracing::info!(
370 request_id = ctx.request_id,
371 session_id = ?session_id,
372 project_path = ?project_path,
373 agent_type = %agent_type,
374 headers = %billing_headers.join(", "),
375 "billing correlation headers"
376 );
377
378 if let Some(ref sid) = session_id {
379 ctx.session_id = Some(sid.clone());
381
382 if let Some(acc) = crate::report::consume_report(sid) {
383 ctx.tokenless_saved_tokens = acc.total_saved;
384 ctx.tokenless_rtk_saved = acc.rtk_saved;
385 ctx.tokenless_response_saved = acc.response_saved;
386 ctx.tokenless_schema_saved = acc.schema_saved;
387 ctx.tokenless_breakdown_json = Some(acc.breakdown_json);
388 if project_path.is_none() {
390 project_path = acc.project_path;
391 }
392 if ctx.user_name.is_none() {
394 ctx.user_name = acc.user_name;
395 }
396 }
397 }
398
399 if let Some(ref proj) = project_path {
400 ctx.project_path = Some(proj.clone());
401 }
402
403 let compression_stats = crate::compression::read_tokenless_stats();
405 if compression_stats.total_saved() > 0 {
406 ctx.insert(crate::extensions::EXT_COMPRESSION_STATS, compression_stats);
407 }
408
409 if let Err(e) = run_on_request_chain(&state.middlewares, &mut proxy_req, &mut ctx).await {
411 log_error(&e, &ctx);
412 return e.to_response();
413 }
414
415 let channel = ctx.get::<crate::types::ChannelConfig>(crate::extensions::EXT_SELECTED_CHANNEL);
417
418 if let Some(ch) = channel {
419 let is_streaming = proxy_req.is_streaming();
420
421 match forward_to_upstream(&state.client, &proxy_req, ch).await {
422 Ok(upstream_resp) => {
423 if is_streaming {
424 handle_streaming_response(upstream_resp, &state, &ctx).await
425 } else {
426 handle_non_streaming_response(upstream_resp, &state, &ctx).await
427 }
428 }
429 Err(e) => {
430 log_error(&e, &ctx);
431 e.to_response()
432 }
433 }
434 } else {
435 let err = ProxyError::ChannelSelection {
436 model: "unknown".into(),
437 };
438 log_error(&err, &ctx);
439 err.to_response()
440 }
441}
442
443async fn handle_non_streaming_response(
445 upstream_resp: reqwest::Response,
446 state: &Arc<ProxyState>,
447 ctx: &ConnectionContext,
448) -> Response<Body> {
449 let status = upstream_resp.status();
450 let headers = upstream_resp.headers().clone();
451
452 let body_bytes = match upstream_resp.bytes().await {
453 Ok(b) => b,
454 Err(e) => {
455 let err = ProxyError::Upstream {
456 source: format!("failed to read upstream response: {e}"),
457 inner: Some(e.into()),
458 };
459 log_error(&err, ctx);
460 return err.to_response();
461 }
462 };
463
464 let body_text = String::from_utf8_lossy(&body_bytes);
465 tracing::warn!(
466 request_id = ctx.request_id,
467 upstream_status = %status,
468 upstream_body = %body_text,
469 target_protocol = ?ctx.target_protocol,
470 channel = ?ctx.get::<crate::types::ChannelConfig>(crate::extensions::EXT_SELECTED_CHANNEL).map(|ch| ch.name.clone()),
471 "upstream response received"
472 );
473
474 let mut proxy_resp = ProxyResponse::new(status, headers, body_bytes, false);
475
476 if let Err(e) = run_on_response_chain(&state.middlewares, &mut proxy_resp, ctx).await {
477 log_error(&e, ctx);
478 return e.to_response();
479 }
480
481 if let Some(ref cr) = state.cost_recorder
483 && let Ok(body_json) = serde_json::from_slice::<serde_json::Value>(&proxy_resp.body)
484 && let Err(e) = cr.record(ctx, &body_json).await
485 {
486 tracing::warn!(
487 request_id = ctx.request_id,
488 error = %e,
489 "cost recording failed"
490 );
491 }
492
493 build_axum_response(proxy_resp)
494}
495
496async fn handle_streaming_response(
501 upstream_resp: reqwest::Response,
502 state: &Arc<ProxyState>,
503 ctx: &ConnectionContext,
504) -> Response<Body> {
505 let status = upstream_resp.status();
506 let headers = upstream_resp.headers().clone();
507
508 let body_bytes = match upstream_resp.bytes().await {
510 Ok(b) => b,
511 Err(e) => {
512 let err = ProxyError::Upstream {
513 source: format!("failed to read streaming response: {e}"),
514 inner: Some(e.into()),
515 };
516 log_error(&err, ctx);
517 return err.to_response();
518 }
519 };
520
521 let body_text = String::from_utf8_lossy(&body_bytes);
522 tracing::warn!(
523 request_id = ctx.request_id,
524 upstream_status = %status,
525 upstream_body = %body_text,
526 target_protocol = ?ctx.target_protocol,
527 channel = ?ctx.get::<crate::types::ChannelConfig>(crate::extensions::EXT_SELECTED_CHANNEL).map(|ch| ch.name.clone()),
528 "upstream streaming response received"
529 );
530
531 let mut proxy_resp = ProxyResponse::new(status, headers, body_bytes, true);
532
533 if let Err(e) = run_on_response_chain(&state.middlewares, &mut proxy_resp, ctx).await {
534 log_error(&e, ctx);
535 return e.to_response();
536 }
537
538 if let Some(ref cr) = state.cost_recorder {
540 let body_json = extract_usage_from_sse(&proxy_resp.body);
541 if let Err(e) = cr.record(ctx, &body_json).await {
542 tracing::warn!(
543 request_id = ctx.request_id,
544 error = %e,
545 "cost recording failed for streaming response"
546 );
547 }
548 }
549
550 build_axum_response(proxy_resp)
551}
552
553fn validate_proxy_request(req: &ProxyRequest) -> Result<(), ProxyError> {
559 if let Some(ct) = req
561 .headers
562 .get("content-type")
563 .and_then(|v| v.to_str().ok())
564 && !ct.starts_with("application/json")
565 {
566 return Err(ProxyError::BadRequest(format!(
567 "unsupported content-type: {ct}. expected application/json"
568 )));
569 }
570
571 if req.body.is_empty() {
573 return Err(ProxyError::BadRequest("empty request body".into()));
574 }
575
576 Ok(())
577}
578
579async fn forward_to_upstream(
584 client: &reqwest::Client,
585 proxy_req: &ProxyRequest,
586 channel: &crate::types::ChannelConfig,
587) -> Result<reqwest::Response, ProxyError> {
588 let api_key_str = channel.api_key.expose_secret().to_owned();
589
590 let path = channel
592 .rewrite_path
593 .as_deref()
594 .filter(|p| !p.is_empty())
595 .unwrap_or(&proxy_req.path);
596 let url = format!("{}{}", channel.url.trim_end_matches('/'), path);
597
598 let mut req_builder = client
599 .request(proxy_req.method.clone(), &url)
600 .body(proxy_req.body.to_vec());
601
602 for (key, value) in &proxy_req.headers {
604 let key_str = key.as_str().to_lowercase();
605 let should_drop = matches!(
606 key_str.as_str(),
607 "transfer-encoding"
608 | "connection"
609 | "keep-alive"
610 | "accept-encoding"
611 | "host"
612 | "content-length"
613 | "authorization"
614 | "x-api-key"
615 );
616 if !should_drop {
617 req_builder = req_builder.header(key.clone(), value.clone());
618 }
619 }
620
621 if !api_key_str.is_empty() {
623 req_builder = req_builder.header("Authorization", format!("Bearer {api_key_str}"));
624 }
625
626 req_builder.send().await.map_err(|e| {
627 if e.is_timeout() {
628 ProxyError::Upstream {
629 source: format!("upstream timeout: {e}"),
630 inner: Some(e.into()),
631 }
632 } else if e.is_connect() {
633 ProxyError::Upstream {
634 source: format!("upstream connection failed: {e}"),
635 inner: Some(e.into()),
636 }
637 } else {
638 ProxyError::Upstream {
639 source: format!("upstream request failed: {e}"),
640 inner: Some(e.into()),
641 }
642 }
643 })
644}
645
646fn build_axum_response(proxy_resp: ProxyResponse) -> Response<Body> {
648 let mut response = Response::new(Body::from(proxy_resp.body));
649 *response.status_mut() = proxy_resp.status;
650 for (key, value) in &proxy_resp.headers {
651 if is_forward_header(key.as_str()) {
652 response.headers_mut().insert(key.clone(), value.clone());
653 }
654 }
655 response
656}
657
/// Returns `true` when an upstream response header may be passed on to the client.
///
/// Connection-level headers, `content-length`, `host` and credential headers
/// are withheld. Names are compared ASCII case-insensitively, as HTTP header
/// names require, without allocating a lower-cased copy.
fn is_forward_header(name: &str) -> bool {
    const WITHHELD: [&str; 7] = [
        "transfer-encoding",
        "connection",
        "keep-alive",
        "content-length",
        "host",
        "authorization",
        "x-api-key",
    ];

    !WITHHELD
        .iter()
        .any(|withheld| name.eq_ignore_ascii_case(withheld))
}
672
673fn log_error(err: &ProxyError, ctx: &ConnectionContext) {
675 match err {
676 ProxyError::Internal(e) => {
677 tracing::error!(
678 request_id = ctx.request_id,
679 error = %e,
680 "internal error"
681 );
682 }
683 ProxyError::Upstream { source, .. } => {
684 tracing::warn!(
685 request_id = ctx.request_id,
686 error = %source,
687 "upstream error"
688 );
689 }
690 _ => {
691 tracing::debug!(
692 request_id = ctx.request_id,
693 error = %err,
694 "request error"
695 );
696 }
697 }
698}
699
700fn extract_usage_from_sse(body: &[u8]) -> serde_json::Value {
712 let Ok(text) = std::str::from_utf8(body) else {
713 return serde_json::Value::Null;
714 };
715
716 let normalized = normalize_sse_format(text);
721
722 let mut merged: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
723
724 for line in normalized.lines() {
725 let Some(data) = line.strip_prefix("data: ") else {
726 continue;
727 };
728 if data.is_empty() || data == "[DONE]" {
729 continue;
730 }
731 let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else {
732 continue;
733 };
734
735 if event.get("type").and_then(|v| v.as_str()) == Some("message_start")
737 && let Some(u) = event.get("message").and_then(|m| m.get("usage"))
738 {
739 merge_usage_fields(&mut merged, u);
740 }
741 if event.get("type").and_then(|v| v.as_str()) == Some("message_delta")
743 && let Some(u) = event.get("usage")
744 {
745 merge_usage_fields(&mut merged, u);
746 }
747 if event.get("type").and_then(|v| v.as_str()) == Some("response.completed")
749 && let Some(u) = event.get("response").and_then(|r| r.get("usage"))
750 {
751 merge_usage_fields(&mut merged, u);
752 }
753 if event.get("choices").is_some()
755 && let Some(u) = event.get("usage")
756 {
757 merge_usage_fields(&mut merged, u);
758 }
759 if let Some(u) = event.get("usage")
761 && event.get("choices").is_none()
762 && event.get("type").is_none()
763 {
764 merge_usage_fields(&mut merged, u);
765 }
766 }
767
768 if merged.is_empty() {
769 serde_json::Value::Null
770 } else {
771 serde_json::json!({"usage": serde_json::Value::Object(merged)})
772 }
773}
774
/// Rewrites SSE text so that every `data:` and `event:` field is followed by
/// a single space.
///
/// Each line is stripped of trailing whitespace. A line beginning with
/// `data:` or `event:` whose next character is not a space gets one
/// inserted; all other lines are kept unchanged. Lines are re-joined with
/// `\n`, so a trailing line terminator in the input is not preserved.
#[must_use]
fn normalize_sse_format(text: &str) -> String {
    /// Normalizes a single line.
    fn respace(raw: &str) -> String {
        let line = raw.trim_end();
        for field in ["data:", "event:"] {
            match line.strip_prefix(field) {
                Some(rest) if !rest.starts_with(' ') => return format!("{field} {rest}"),
                _ => {}
            }
        }
        line.to_owned()
    }

    let mut out = String::with_capacity(text.len());
    for (idx, raw) in text.lines().enumerate() {
        if idx > 0 {
            out.push('\n');
        }
        out.push_str(&respace(raw));
    }
    out
}
803
804fn merge_usage_fields(
810 acc: &mut serde_json::Map<String, serde_json::Value>,
811 usage: &serde_json::Value,
812) {
813 if let Some(obj) = usage.as_object() {
814 for (k, v) in obj {
815 let is_nonzero_number =
816 v.as_u64().is_some_and(|n| n > 0) || v.as_f64().is_some_and(|f| f > 0.0);
817 if is_nonzero_number || !acc.contains_key(k) {
818 acc.insert(k.clone(), v.clone());
819 }
820 }
821 }
822}
823
// Tests for the router, the proxy handler and the SSE usage helpers.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use async_trait::async_trait;
    use axum::{body::Body, http::StatusCode};
    use tower::ServiceExt;

    use super::*;
    use crate::{
        middleware::ProxyMiddleware,
        types::{ApiFormat, ChannelConfig},
    };

    /// Test middleware that selects a fixed channel pointing at `url` for every request.
    struct UpstreamMiddleware {
        /// Base URL of the upstream to select.
        url: String,
    }

    #[async_trait]
    impl ProxyMiddleware for UpstreamMiddleware {
        // Stores the channel under the key the handler reads after the request chain.
        async fn on_request(
            &self,
            _req: &mut ProxyRequest,
            ctx: &mut ConnectionContext,
        ) -> Result<(), ProxyError> {
            ctx.insert(
                crate::extensions::EXT_SELECTED_CHANNEL,
                ChannelConfig {
                    url: self.url.clone(),
                    api_key: secrecy::SecretString::from("sk-test"),
                    protocol: ApiFormat::AnthropicMessages,
                    name: "test".into(),
                    rewrite_path: None,
                },
            );
            Ok(())
        }

        // Responses pass through untouched.
        async fn on_response(
            &self,
            _res: &mut ProxyResponse,
            _ctx: &ConnectionContext,
        ) -> Result<(), ProxyError> {
            Ok(())
        }

        fn name(&self) -> &'static str {
            "upstream"
        }
    }

    /// Builds a router from `config` and `middlewares` without a cost recorder
    /// and without running any middleware `on_init`.
    fn build_test_router(
        config: ProxyConfig,
        middlewares: Vec<Box<dyn ProxyMiddleware>>,
    ) -> Router {
        let client = reqwest::Client::builder()
            .http1_only()
            .build()
            .expect("build test client");

        let state = Arc::new(ProxyState {
            config: Arc::new(config),
            middlewares: Arc::new(middlewares),
            client,
            cost_recorder: None,
            next_request_id: Arc::new(AtomicU64::new(1)),
        });

        build_router(state)
    }

    // With the default config, GET /health answers 200.
    #[tokio::test]
    async fn test_health_endpoint_returns_200() {
        let config = ProxyConfig::default();
        let router = build_test_router(config, vec![]);

        let response = router
            .oneshot(
                Request::builder()
                    .uri("/health")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    // A path with no registered route answers 404.
    #[tokio::test]
    async fn test_unknown_path_returns_404() {
        let config = ProxyConfig::default();
        let router = build_test_router(config, vec![]);

        let response = router
            .oneshot(
                Request::builder()
                    .uri("/unknown/path")
                    .method("POST")
                    .header("content-type", "application/json")
                    .body(Body::from(r#"{"model":"test"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    // With a proxy API key configured, a request without credentials is rejected,
    // even on /health.
    #[tokio::test]
    async fn test_auth_failure_returns_401() {
        let config = ProxyConfig {
            proxy_api_key: Some(secrecy::SecretString::new("sk-secret".into())),
            ..Default::default()
        };
        let router = build_test_router(config, vec![]);

        let response = router
            .oneshot(
                Request::builder()
                    .uri("/health")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    // The matching bearer token is accepted.
    #[tokio::test]
    async fn test_auth_success_passes_through() {
        let config = ProxyConfig {
            proxy_api_key: Some(secrecy::SecretString::new("sk-secret".into())),
            ..Default::default()
        };
        let router = build_test_router(config, vec![]);

        let response = router
            .oneshot(
                Request::builder()
                    .uri("/health")
                    .header("authorization", "Bearer sk-secret")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    // A 2048-byte body against a 1024-byte limit answers 413.
    #[tokio::test]
    async fn test_body_too_large_returns_413() {
        let config = ProxyConfig {
            max_body_size: 1024, ..Default::default()
        };
        let router = build_test_router(config, vec![]);

        let big_body = "x".repeat(2048);
        let response = router
            .oneshot(
                Request::builder()
                    .uri("/v1/messages")
                    .method("POST")
                    .header("content-type", "application/json")
                    .body(Body::from(big_body))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    // Without any middleware no channel is selected, which answers 503.
    #[tokio::test]
    async fn test_no_channel_returns_503() {
        let config = ProxyConfig::default();
        let router = build_test_router(config, vec![]);

        let response = router
            .oneshot(
                Request::builder()
                    .uri("/v1/messages")
                    .method("POST")
                    .header("content-type", "application/json")
                    .body(Body::from(
                        r#"{"model":"claude-sonnet","messages":[{"role":"user","content":"hi"}]}"#,
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    /// Starts a mock upstream on an ephemeral local port that answers
    /// `POST /v1/messages` with a fixed JSON message.
    ///
    /// Returns the base URL and the handle of the server task.
    async fn start_mock_upstream() -> (String, JoinHandle<()>) {
        use axum::routing::post;

        async fn mock_messages_handler() -> Json<serde_json::Value> {
            Json(serde_json::json!({
                "id": "msg_123",
                "type": "message",
                "role": "assistant",
                "content": [{"type": "text", "text": "Hello from upstream!"}],
                "model": "claude-sonnet",
                "usage": {"input_tokens": 10, "output_tokens": 20}
            }))
        }

        let app = Router::new().route("/v1/messages", post(mock_messages_handler));

        // Port 0 lets the OS pick a free port.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind");
        let addr = listener.local_addr().expect("local addr");

        let handle = tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });

        (format!("http://{addr}"), handle)
    }

    // End to end: the request is forwarded to the mock upstream and its 200 is returned.
    #[tokio::test]
    async fn test_successful_proxy_returns_200() {
        let (upstream_url, _upstream_handle) = start_mock_upstream().await;

        let config = ProxyConfig::default();
        let middlewares: Vec<Box<dyn ProxyMiddleware>> =
            vec![Box::new(UpstreamMiddleware { url: upstream_url })];

        let router = build_test_router(config, middlewares);

        let response = router
            .oneshot(
                Request::builder()
                    .uri("/v1/messages")
                    .method("POST")
                    .header("content-type", "application/json")
                    .body(Body::from(
                        r#"{"model":"claude-sonnet","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}"#,
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    // Usage is merged across events when fields are written as `data: ` / `event: `.
    #[test]
    fn test_extract_usage_from_sse_with_space() {
        let body = b"event: message_start\n\
            data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":100,\"output_tokens\":0}}}\n\n\
            event: message_delta\n\
            data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":50,\"cache_read_input_tokens\":30}}\n\n";
        let result = extract_usage_from_sse(body);
        let usage = result.get("usage").unwrap();
        assert_eq!(usage.get("input_tokens").unwrap().as_u64().unwrap(), 100);
        assert_eq!(usage.get("output_tokens").unwrap().as_u64().unwrap(), 50);
        assert_eq!(
            usage
                .get("cache_read_input_tokens")
                .unwrap()
                .as_u64()
                .unwrap(),
            30
        );
    }

    // Same as above, with no space after `data:` / `event:`.
    #[test]
    fn test_extract_usage_from_sse_without_space() {
        let body = b"event:message_start\n\
            data:{\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":200,\"output_tokens\":0}}}\n\n\
            event:message_delta\n\
            data:{\"type\":\"message_delta\",\"usage\":{\"output_tokens\":80,\"cache_read_input_tokens\":60}}\n\n";
        let result = extract_usage_from_sse(body);
        let usage = result.get("usage").unwrap();
        assert_eq!(usage.get("input_tokens").unwrap().as_u64().unwrap(), 200);
        assert_eq!(usage.get("output_tokens").unwrap().as_u64().unwrap(), 80);
        assert_eq!(
            usage
                .get("cache_read_input_tokens")
                .unwrap()
                .as_u64()
                .unwrap(),
            60
        );
    }

    // Both spellings mixed within one stream.
    #[test]
    fn test_extract_usage_from_sse_mixed_format() {
        let body = b"event:message_start\n\
            data:{\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":150,\"output_tokens\":0}}}\n\n\
            event: message_delta\n\
            data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":90}}\n\n";
        let result = extract_usage_from_sse(body);
        let usage = result.get("usage").unwrap();
        assert_eq!(usage.get("input_tokens").unwrap().as_u64().unwrap(), 150);
        assert_eq!(usage.get("output_tokens").unwrap().as_u64().unwrap(), 90);
    }

    #[test]
    fn test_normalize_sse_format() {
        // Missing space is inserted.
        let input = "event:message_start\ndata:{\"type\":\"message_start\"}\n\n";
        let output = normalize_sse_format(input);
        assert!(output.contains("event: message_start"));
        assert!(output.contains("data: {\"type\":\"message_start\"}"));

        // Already-spaced input is unchanged apart from trailing blank lines.
        let input2 = "event: message_start\ndata: {\"type\":\"message_start\"}\n\n";
        let output2 = normalize_sse_format(input2);
        assert_eq!(output2.trim(), input2.trim());

        // Mixed input: every field ends up in the spaced form.
        let input3 = "event:message_start\ndata: {\"type\":\"message_start\"}\n\nevent: message_delta\ndata:{\"type\":\"message_delta\"}";
        let output3 = normalize_sse_format(input3);
        assert!(output3.contains("event: message_start"));
        assert!(output3.contains("data: {\"type\":\"message_start\"}"));
        assert!(output3.contains("event: message_delta"));
        assert!(output3.contains("data: {\"type\":\"message_delta\"}"));
    }
}