use super::*;
/// Checks that token usage is still recorded when the usage-bearing SSE event
/// only shows up after a large amount of earlier stream data.
///
/// The fake upstream replies with 260 filler events (each with a 4096-byte
/// event name) followed by a single `response.completed` event whose usage
/// reports `total_tokens = 3`. The request is sent through the proxy up to
/// three times until one response is `200 OK` and its body drains cleanly;
/// afterwards the proxy state is polled (up to 100 times, 20 ms apart) until a
/// finished request with parsed usage appears.
#[tokio::test]
async fn proxy_streaming_parses_usage_even_when_usage_is_late_in_stream() {
    // How many filler events precede the usage event in the upstream body.
    let filler_count = 260usize;
    let filler_event = Bytes::from(format!("event: {}\n\n", "x".repeat(4096)));
    let usage_event = Bytes::from(
        "event: response.completed\n\
         data: {\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"total_tokens\":3}}}\n\n",
    );

    // Upstream: one big SSE body, usage event at the very end.
    let upstream = axum::Router::new().route(
        "/v1/responses",
        post(move || {
            let filler_event = filler_event.clone();
            let usage_event = usage_event.clone();
            async move {
                let capacity =
                    filler_event.len().saturating_mul(filler_count) + usage_event.len();
                let mut payload = Vec::with_capacity(capacity);
                for _ in 0..filler_count {
                    payload.extend_from_slice(filler_event.as_ref());
                }
                payload.extend_from_slice(usage_event.as_ref());

                let mut resp = Response::new(Body::from(payload));
                resp.headers_mut().insert(
                    axum::http::header::CONTENT_TYPE,
                    HeaderValue::from_static("text/event-stream"),
                );
                *resp.status_mut() = StatusCode::OK;
                resp
            }
        }),
    );
    let (u_addr, u_handle) = spawn_axum_server(upstream);

    // Proxy in front of the single upstream.
    let retry = retry_config(1, "502", Vec::new(), RetryStrategy::Failover);
    let only_upstream = UpstreamConfig {
        base_url: format!("http://{}/v1", u_addr),
        auth: UpstreamAuth {
            auth_token: None,
            auth_token_env: None,
            api_key: None,
            api_key_env: None,
        },
        tags: HashMap::new(),
        supported_models: HashMap::new(),
        model_mapping: HashMap::new(),
    };
    let cfg = make_proxy_config(vec![only_upstream], retry);
    let proxy = ProxyService::new(
        Client::new(),
        Arc::new(cfg),
        "codex",
        Arc::new(std::sync::Mutex::new(HashMap::new())),
    );
    let state = proxy.state.clone();
    let (proxy_addr, proxy_handle) = spawn_axum_server(crate::proxy::router(proxy));

    // Send the request (at most three attempts) and drain the whole body.
    let client = reqwest::Client::new();
    let url = format!("http://{}/v1/responses", proxy_addr);
    let mut last_status: Option<StatusCode> = None;
    let mut drained_ok = false;
    for _ in 0..3 {
        let resp = client
            .post(url.as_str())
            .header("content-type", "application/json")
            .header("accept", "text/event-stream")
            .body(r#"{"model":"gpt","input":"hi"}"#)
            .send()
            .await
            .expect("send");
        let status = resp.status();
        last_status = Some(status);
        // The body is only read when the status is OK, as a failed read
        // counts as a failed attempt.
        if status == StatusCode::OK && resp.bytes().await.is_ok() {
            drained_ok = true;
            break;
        }
        sleep(Duration::from_millis(20)).await;
    }
    assert_eq!(last_status, Some(StatusCode::OK));
    assert!(
        drained_ok,
        "expected to drain SSE body without decode error"
    );

    // Poll the finished-request list until an entry carries usage.
    let mut finished = Vec::new();
    for _ in 0..100 {
        finished = state.list_recent_finished(10).await;
        let has_usage = finished.iter().any(|f| f.usage.is_some());
        if has_usage {
            break;
        }
        sleep(Duration::from_millis(20)).await;
    }
    assert!(
        !finished.is_empty(),
        "expected finished request to be recorded"
    );
    let usage = finished
        .iter()
        .find_map(|f| f.usage.as_ref())
        .expect("usage should be parsed");
    assert_eq!(usage.total_tokens, 3);

    u_handle.abort();
    proxy_handle.abort();
}
#[tokio::test]
async fn proxy_does_not_retry_or_failover_on_400() {
let upstream1_hits = Arc::new(AtomicUsize::new(0));
let upstream2_hits = Arc::new(AtomicUsize::new(0));
let u1_hits = upstream1_hits.clone();
let upstream1 = axum::Router::new().route(
"/v1/responses",
post(move || async move {
u1_hits.fetch_add(1, Ordering::SeqCst);
(
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "err": "bad request" })),
)
}),
);
let (u1_addr, u1_handle) = spawn_axum_server(upstream1);
let u2_hits = upstream2_hits.clone();
let upstream2 = axum::Router::new().route(
"/v1/responses",
post(move || async move {
u2_hits.fetch_add(1, Ordering::SeqCst);
(StatusCode::OK, Json(serde_json::json!({ "ok": true })))
}),
);
let (u2_addr, u2_handle) = spawn_axum_server(upstream2);
let proxy_client = Client::new();
let retry = RetryConfig {
upstream: Some(crate::config::RetryLayerConfig {
max_attempts: Some(2),
backoff_ms: Some(0),
backoff_max_ms: Some(0),
jitter_ms: Some(0),
on_status: Some("502".to_string()),
on_class: Some(Vec::new()),
strategy: Some(RetryStrategy::SameUpstream),
}),
provider: Some(crate::config::RetryLayerConfig {
max_attempts: Some(2),
backoff_ms: Some(0),
backoff_max_ms: Some(0),
jitter_ms: Some(0),
on_status: Some("502".to_string()),
on_class: Some(Vec::new()),
strategy: Some(RetryStrategy::Failover),
}),
allow_cross_station_before_first_output: Some(true),
cloudflare_challenge_cooldown_secs: Some(0),
cloudflare_timeout_cooldown_secs: Some(0),
transport_cooldown_secs: Some(0),
cooldown_backoff_factor: Some(1),
cooldown_backoff_max_secs: Some(0),
..Default::default()
};
let cfg = make_proxy_config(
vec![
UpstreamConfig {
base_url: format!("http://{}/v1", u1_addr),
auth: UpstreamAuth {
auth_token: None,
auth_token_env: None,
api_key: None,
api_key_env: None,
},
tags: HashMap::new(),
supported_models: HashMap::new(),
model_mapping: HashMap::new(),
},
UpstreamConfig {
base_url: format!("http://{}/v1", u2_addr),
auth: UpstreamAuth {
auth_token: None,
auth_token_env: None,
api_key: None,
api_key_env: None,
},
tags: HashMap::new(),
supported_models: HashMap::new(),
model_mapping: HashMap::new(),
},
],
retry,
);
let proxy = ProxyService::new(
proxy_client,
Arc::new(cfg),
"codex",
Arc::new(std::sync::Mutex::new(HashMap::new())),
);
let app = crate::proxy::router(proxy);
let (proxy_addr, proxy_handle) = spawn_axum_server(app);
let client = reqwest::Client::new();
let resp = client
.post(format!("http://{}/v1/responses", proxy_addr))
.header("content-type", "application/json")
.body(r#"{"model":"gpt","input":"hi"}"#)
.send()
.await
.expect("send");
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
assert_eq!(upstream1_hits.load(Ordering::SeqCst), 1);
assert_eq!(upstream2_hits.load(Ordering::SeqCst), 0);
proxy_handle.abort();
u1_handle.abort();
u2_handle.abort();
}
/// Checks that a `404` from the first upstream leads to a successful answer
/// from the second one when the retry config is built with the status range
/// `"400-599"` and the failover strategy.
///
/// Expected outcome: the client sees `200 OK`, the first upstream is hit
/// exactly twice, and the second upstream is hit once or twice (the extra hit
/// is tolerated as a transport flake).
#[tokio::test]
async fn proxy_failover_retries_404_when_enabled() {
    let upstream1_hits = Arc::new(AtomicUsize::new(0));
    let upstream2_hits = Arc::new(AtomicUsize::new(0));

    // First upstream: always 404.
    let first_counter = upstream1_hits.clone();
    let upstream1 = axum::Router::new().route(
        "/v1/responses",
        post(move || async move {
            first_counter.fetch_add(1, Ordering::SeqCst);
            StatusCode::NOT_FOUND
        }),
    );
    let (u1_addr, u1_handle) = spawn_axum_server(upstream1);

    // Second upstream: always succeeds.
    let second_counter = upstream2_hits.clone();
    let upstream2 = axum::Router::new().route(
        "/v1/responses",
        post(move || async move {
            second_counter.fetch_add(1, Ordering::SeqCst);
            (StatusCode::OK, Json(serde_json::json!({ "ok": true })))
        }),
    );
    let (u2_addr, u2_handle) = spawn_axum_server(upstream2);

    // Upstream entry with no auth, tags, model filter or model mapping.
    let plain_upstream = |base_url: String| UpstreamConfig {
        base_url,
        auth: UpstreamAuth {
            auth_token: None,
            auth_token_env: None,
            api_key: None,
            api_key_env: None,
        },
        tags: HashMap::new(),
        supported_models: HashMap::new(),
        model_mapping: HashMap::new(),
    };
    let retry = retry_config(2, "400-599", Vec::new(), RetryStrategy::Failover);
    let cfg = make_proxy_config(
        vec![
            plain_upstream(format!("http://{}/v1", u1_addr)),
            plain_upstream(format!("http://{}/v1", u2_addr)),
        ],
        retry,
    );

    let proxy = ProxyService::new(
        Client::new(),
        Arc::new(cfg),
        "codex",
        Arc::new(std::sync::Mutex::new(HashMap::new())),
    );
    let (proxy_addr, proxy_handle) = spawn_axum_server(crate::proxy::router(proxy));

    let url = format!("http://{}/v1/responses", proxy_addr);
    let resp = reqwest::Client::new()
        .post(url)
        .header("content-type", "application/json")
        .body(r#"{"model":"gpt","input":"hi"}"#)
        .send()
        .await
        .expect("send");

    assert_eq!(resp.status(), StatusCode::OK);
    assert_eq!(upstream1_hits.load(Ordering::SeqCst), 2);
    let u2 = upstream2_hits.load(Ordering::SeqCst);
    assert!(
        (1..=2).contains(&u2),
        "expected upstream2 hits to be 1..=2 (transport flake tolerance), got {u2}"
    );

    u2_handle.abort();
    u1_handle.abort();
    proxy_handle.abort();
}
/// Checks that a `400` carrying an `invalid_request_error` payload is returned
/// to the client without trying the second upstream, even though the retry
/// config is built with the status range `"400-599"` and the failover
/// strategy.
///
/// Expected outcome: the client sees `400`, the first upstream is hit exactly
/// once, and the second upstream is never contacted.
#[tokio::test]
async fn proxy_does_not_failover_on_non_retryable_client_error_class() {
    let upstream1_hits = Arc::new(AtomicUsize::new(0));
    let upstream2_hits = Arc::new(AtomicUsize::new(0));

    // First upstream: rejects the request as invalid.
    let first_counter = upstream1_hits.clone();
    let upstream1 = axum::Router::new().route(
        "/v1/responses",
        post(move || async move {
            first_counter.fetch_add(1, Ordering::SeqCst);
            let error_body = serde_json::json!({
                "error": {
                    "type": "invalid_request_error",
                    "message": "`tool_use` ids must be unique"
                }
            });
            (StatusCode::BAD_REQUEST, Json(error_body))
        }),
    );
    let (u1_addr, u1_handle) = spawn_axum_server(upstream1);

    // Second upstream: would succeed, but must stay untouched.
    let second_counter = upstream2_hits.clone();
    let upstream2 = axum::Router::new().route(
        "/v1/responses",
        post(move || async move {
            second_counter.fetch_add(1, Ordering::SeqCst);
            (StatusCode::OK, Json(serde_json::json!({ "ok": true })))
        }),
    );
    let (u2_addr, u2_handle) = spawn_axum_server(upstream2);

    // Upstream entry with no auth, tags, model filter or model mapping.
    let plain_upstream = |base_url: String| UpstreamConfig {
        base_url,
        auth: UpstreamAuth {
            auth_token: None,
            auth_token_env: None,
            api_key: None,
            api_key_env: None,
        },
        tags: HashMap::new(),
        supported_models: HashMap::new(),
        model_mapping: HashMap::new(),
    };
    let retry = retry_config(2, "400-599", Vec::new(), RetryStrategy::Failover);
    let cfg = make_proxy_config(
        vec![
            plain_upstream(format!("http://{}/v1", u1_addr)),
            plain_upstream(format!("http://{}/v1", u2_addr)),
        ],
        retry,
    );

    let proxy = ProxyService::new(
        Client::new(),
        Arc::new(cfg),
        "codex",
        Arc::new(std::sync::Mutex::new(HashMap::new())),
    );
    let (proxy_addr, proxy_handle) = spawn_axum_server(crate::proxy::router(proxy));

    let url = format!("http://{}/v1/responses", proxy_addr);
    let resp = reqwest::Client::new()
        .post(url)
        .header("content-type", "application/json")
        .body(r#"{"model":"gpt","input":"hi"}"#)
        .send()
        .await
        .expect("send");

    assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    assert_eq!(upstream1_hits.load(Ordering::SeqCst), 1);
    assert_eq!(upstream2_hits.load(Ordering::SeqCst), 0);

    u2_handle.abort();
    u1_handle.abort();
    proxy_handle.abort();
}
/// Checks that a request is routed past an upstream whose `supported_models`
/// patterns do not match the requested model.
///
/// The first upstream lists only `other-*`, the second lists `gpt-*`, and the
/// request asks for `gpt-4`. Expected outcome: the client sees `200 OK`, the
/// first upstream is never hit, and the second is hit exactly once.
#[tokio::test]
async fn proxy_skips_upstreams_that_do_not_support_model() {
    let upstream1_hits = Arc::new(AtomicUsize::new(0));
    let upstream2_hits = Arc::new(AtomicUsize::new(0));

    // First upstream: answers 500 if it is ever reached.
    let first_counter = upstream1_hits.clone();
    let upstream1 = axum::Router::new().route(
        "/v1/responses",
        post(move || async move {
            first_counter.fetch_add(1, Ordering::SeqCst);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "err": "should not hit" })),
            )
        }),
    );
    let (u1_addr, u1_handle) = spawn_axum_server(upstream1);

    // Second upstream: the one expected to serve the request.
    let second_counter = upstream2_hits.clone();
    let upstream2 = axum::Router::new().route(
        "/v1/responses",
        post(move || async move {
            second_counter.fetch_add(1, Ordering::SeqCst);
            (
                StatusCode::OK,
                Json(serde_json::json!({ "ok": true, "upstream": 2 })),
            )
        }),
    );
    let (u2_addr, u2_handle) = spawn_axum_server(upstream2);

    // Upstream entry whose only distinguishing setting is one enabled
    // `supported_models` pattern.
    let upstream_supporting = |base_url: String, pattern: &str| UpstreamConfig {
        base_url,
        auth: UpstreamAuth {
            auth_token: None,
            auth_token_env: None,
            api_key: None,
            api_key_env: None,
        },
        tags: HashMap::new(),
        supported_models: HashMap::from([(pattern.to_string(), true)]),
        model_mapping: HashMap::new(),
    };
    let retry = retry_config(1, "502", Vec::new(), RetryStrategy::Failover);
    let cfg = make_proxy_config(
        vec![
            upstream_supporting(format!("http://{}/v1", u1_addr), "other-*"),
            upstream_supporting(format!("http://{}/v1", u2_addr), "gpt-*"),
        ],
        retry,
    );

    let proxy = ProxyService::new(
        Client::new(),
        Arc::new(cfg),
        "codex",
        Arc::new(std::sync::Mutex::new(HashMap::new())),
    );
    let (proxy_addr, proxy_handle) = spawn_axum_server(crate::proxy::router(proxy));

    let url = format!("http://{}/v1/responses", proxy_addr);
    let resp = reqwest::Client::new()
        .post(url)
        .header("content-type", "application/json")
        .body(r#"{"model":"gpt-4","input":"hi"}"#)
        .send()
        .await
        .expect("send");

    assert_eq!(resp.status(), StatusCode::OK);
    assert_eq!(upstream1_hits.load(Ordering::SeqCst), 0);
    assert_eq!(upstream2_hits.load(Ordering::SeqCst), 1);

    u2_handle.abort();
    u1_handle.abort();
    proxy_handle.abort();
}
/// Checks that the configured `model_mapping` is applied to the `model` field
/// of the JSON body before it reaches the upstream.
///
/// The upstream maps `claude-*` to `anthropic/claude-*` and answers `200 OK`
/// only when the body's model is `anthropic/claude-sonnet-4`; any other model
/// yields `400` echoing what was received. The client requests
/// `claude-sonnet-4` and expects `200 OK` after exactly one upstream hit.
#[tokio::test]
async fn proxy_applies_model_mapping_to_request_body() {
    let upstream_hits = Arc::new(AtomicUsize::new(0));

    // Upstream: inspects the model name it actually receives.
    let counter = upstream_hits.clone();
    let upstream = axum::Router::new().route(
        "/v1/responses",
        post(move |body: axum::body::Bytes| async move {
            counter.fetch_add(1, Ordering::SeqCst);
            let parsed: serde_json::Value =
                serde_json::from_slice(&body).expect("json body should parse");
            // A missing or non-string model is treated as the empty string.
            let model = parsed.get("model").and_then(|m| m.as_str()).unwrap_or("");
            let (status, payload) = if model == "anthropic/claude-sonnet-4" {
                (StatusCode::OK, serde_json::json!({ "ok": true }))
            } else {
                (StatusCode::BAD_REQUEST, serde_json::json!({ "model": model }))
            };
            (status, Json(payload))
        }),
    );
    let (u_addr, u_handle) = spawn_axum_server(upstream);

    // Single upstream that accepts the mapped name and rewrites the short one.
    let mapped_upstream = UpstreamConfig {
        base_url: format!("http://{}/v1", u_addr),
        auth: UpstreamAuth {
            auth_token: None,
            auth_token_env: None,
            api_key: None,
            api_key_env: None,
        },
        tags: HashMap::new(),
        supported_models: HashMap::from([("anthropic/claude-*".to_string(), true)]),
        model_mapping: HashMap::from([(
            "claude-*".to_string(),
            "anthropic/claude-*".to_string(),
        )]),
    };
    let retry = retry_config(1, "502", Vec::new(), RetryStrategy::Failover);
    let cfg = make_proxy_config(vec![mapped_upstream], retry);

    let proxy = ProxyService::new(
        Client::new(),
        Arc::new(cfg),
        "codex",
        Arc::new(std::sync::Mutex::new(HashMap::new())),
    );
    let (proxy_addr, proxy_handle) = spawn_axum_server(crate::proxy::router(proxy));

    let url = format!("http://{}/v1/responses", proxy_addr);
    let resp = reqwest::Client::new()
        .post(url)
        .header("content-type", "application/json")
        .body(r#"{"model":"claude-sonnet-4","input":"hi"}"#)
        .send()
        .await
        .expect("send");

    assert_eq!(resp.status(), StatusCode::OK);
    assert_eq!(upstream_hits.load(Ordering::SeqCst), 1);

    u_handle.abort();
    proxy_handle.abort();
}