use crate::AppState;
use crate::auth;
use crate::client::HttpClient;
use crate::errors::{ErrorResponseBody, OnwardsErrorResponse};
use crate::models::ListModelResponse;
use crate::sse::SseBufferedStream;
use crate::target::Target;
use axum::{
Json,
extract::Request,
extract::State,
http::{
HeaderMap, HeaderName, HeaderValue, StatusCode, Uri,
header::{CONTENT_LENGTH, TRANSFER_ENCODING},
},
response::{IntoResponse, Response},
};
use serde_json::map::Entry;
use tracing::{debug, error, instrument, trace};
/// Stores `status_code` in the `http.response.status_code` field of the current tracing span.
///
/// `tracing` ignores fields that were not declared when the span was created, so this only has
/// an effect inside a span that declares that field.
fn record_response_status(status_code: u16) {
    let span = tracing::Span::current();
    span.record("http.response.status_code", status_code);
}
/// Request extension holding the model name extracted from the client's request, before any
/// per-target rewrite of the forwarded body's `model` field.
#[derive(Debug, Clone)]
struct OriginalModel(String);
/// Response extension recording whether the provider that produced the response is trusted.
///
/// The value is the target's own `trusted` setting when present, otherwise the pool's default.
/// NOTE(review): presumably consumed by downstream middleware; confirm against its readers.
#[derive(Debug, Clone)]
pub(crate) struct ResolvedTrust(pub(crate) bool);
fn filter_headers_for_upstream(headers: &mut HeaderMap, target: &Target) {
const HEADERS_TO_STRIP: &[&str] = &[
"connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailer",
"upgrade",
"authorization",
"x-api-key",
"api-key",
"sec-fetch-site",
"sec-fetch-mode",
"sec-fetch-dest",
"sec-fetch-user",
"origin",
"referer",
"cookie",
"if-modified-since",
"if-none-match",
"if-match",
"if-unmodified-since",
"if-range",
"model-override",
];
for header in HEADERS_TO_STRIP {
headers.remove(*header);
}
let sec_ch_ua_headers: Vec<_> = headers
.keys()
.filter(|name| name.as_str().starts_with("sec-ch-ua"))
.cloned()
.collect();
for header in sec_ch_ua_headers {
headers.remove(header);
}
if let Some(key) = &target.onwards_key {
let header_name_str = target
.upstream_auth_header_name
.as_deref()
.unwrap_or("Authorization");
let header_name = HeaderName::from_bytes(header_name_str.as_bytes()).unwrap();
let prefix = target
.upstream_auth_header_prefix
.as_deref()
.unwrap_or("Bearer ");
let header_value = format!("{}{}", prefix, key);
debug!(
"Adding {} header for upstream {}: {}",
header_name_str, target.url, header_value
);
headers.insert(header_name, header_value.parse().unwrap());
} else {
debug!(
"No upstream authentication configured for target {}",
target.url
);
}
headers.insert("x-forwarded-proto", "https".parse().unwrap());
}
#[instrument(skip(state, req), fields(
gen_ai.request.model = tracing::field::Empty,
http.response.status_code = tracing::field::Empty,
))]
pub async fn target_message_handler<T: HttpClient>(
State(state): State<AppState<T>>,
mut req: axum::extract::Request,
) -> Result<Response, OnwardsErrorResponse> {
let mut body_bytes =
match axum::body::to_bytes(std::mem::take(req.body_mut()), usize::MAX).await {
Ok(bytes) => bytes,
Err(_) => return Err(OnwardsErrorResponse::internal()), };
if let Some(ref transform_fn) = state.body_transform_fn {
let path = req.uri().path();
if let Some(transformed_body) = transform_fn(path, req.headers(), &body_bytes) {
debug!("Applied body transformation for path: {}", path);
body_bytes = transformed_body;
}
}
trace!(
"Incoming request details:\n Method: {}\n URI: {}\n Headers: {:?}\n Body: {}",
req.method(),
req.uri(),
req.headers(),
String::from_utf8_lossy(&body_bytes)
);
let model_name = match crate::extract_model_from_request(req.headers(), &body_bytes) {
Some(model) => model,
None => {
record_response_status(400);
return Err(OnwardsErrorResponse::bad_request(
"Could not parse onwards model from request. 'model' parameter must be supplied in either the body or in the Model-Override header.",
Some("model"),
));
}
};
tracing::Span::current().record("gen_ai.request.model", &model_name);
req.extensions_mut()
.insert(OriginalModel(model_name.clone()));
trace!("Received request for model: {}", model_name);
trace!(
"Available targets: {:?}",
state
.targets
.targets
.iter()
.map(|entry| entry.key().clone())
.collect::<Vec<_>>()
);
let pool = match state.targets.targets.get(&model_name) {
Some(pool) => pool.clone(),
None => {
debug!("No target found for model: {}", model_name);
record_response_status(404);
return Err(OnwardsErrorResponse::model_not_found(model_name.as_str()));
}
};
if pool.is_empty() {
debug!("Pool for model '{}' has no providers", model_name);
return Err(OnwardsErrorResponse::bad_gateway());
}
let bearer_token = req
.headers()
.get("authorization")
.and_then(|auth_header| auth_header.to_str().ok())
.and_then(|auth_value| auth_value.strip_prefix("Bearer "));
if let Some(keys) = pool.keys() {
match bearer_token {
Some(token) => {
trace!("Validating bearer token against pool keys");
if auth::validate_bearer_token(keys, token) {
debug!("Bearer token validation successful");
} else {
debug!("Bearer token validation failed - token not in key set");
record_response_status(403);
return Err(OnwardsErrorResponse::forbidden());
}
}
None => {
debug!("No bearer token found in authorization header");
record_response_status(401);
return Err(OnwardsErrorResponse::unauthorized());
}
}
} else {
debug!(
"Pool '{}' has no keys configured - allowing request",
model_name
);
}
if let Some(limiter) = pool.pool_limiter()
&& limiter.check().is_err()
{
debug!("Pool-level rate limit exceeded for model: {}", model_name);
record_response_status(429);
return Err(OnwardsErrorResponse::rate_limited());
}
if let Some(token) = bearer_token
&& let Some(limiter) = state.targets.key_rate_limiters.get(token)
&& limiter.check().is_err()
{
debug!("Per-key rate limit exceeded for token: {}", token);
record_response_status(429);
return Err(OnwardsErrorResponse::rate_limited());
}
let _pool_concurrency_guard = if let Some(limiter) = pool.pool_concurrency_limiter() {
match limiter.acquire().await {
Ok(guard) => Some(guard),
Err(_) => {
debug!(
"Pool-level concurrency limit exceeded for model: {}",
model_name
);
return Err(OnwardsErrorResponse::concurrency_limited());
}
}
} else {
None
};
let _key_concurrency_guard = if let Some(token) = bearer_token {
if let Some(limiter) = state.targets.key_concurrency_limiters.get(token) {
match limiter.acquire().await {
Ok(guard) => Some(guard),
Err(_) => {
debug!("Per-key concurrency limit exceeded for token: {}", token);
return Err(OnwardsErrorResponse::concurrency_limited());
}
}
} else {
None
}
} else {
None
};
let path_and_query = req
.uri()
.path_and_query()
.map(|v| v.as_str())
.unwrap_or(req.uri().path())
.to_string();
let original_headers = req.headers().clone();
let method = req.method().clone();
let mut last_error: Option<OnwardsErrorResponse> = None;
for (_idx, target) in pool.select_ordered() {
if let Some(ref limiter) = target.limiter
&& limiter.check().is_err()
{
debug!("Provider rate limited: {:?}", target.url);
last_error = Some(OnwardsErrorResponse::rate_limited());
if pool.should_fallback_on_rate_limit() {
debug!("Fallback on rate limit enabled, trying next provider");
continue;
} else {
return Err(OnwardsErrorResponse::rate_limited());
}
}
let _target_concurrency_guard = if let Some(ref limiter) = target.concurrency_limiter {
match limiter.acquire().await {
Ok(guard) => Some(guard),
Err(_) => {
debug!("Provider concurrency limit exceeded: {:?}", target.url);
last_error = Some(OnwardsErrorResponse::concurrency_limited());
if pool.should_fallback_on_rate_limit() {
debug!("Fallback on rate limit enabled, trying next provider");
continue;
} else {
return Err(OnwardsErrorResponse::concurrency_limited());
}
}
}
} else {
None
};
let response_headers = target.response_headers.clone();
let mut attempt_body = body_bytes.clone();
if let Some(ref rewrite) = target.onwards_model
&& !attempt_body.is_empty()
{
debug!("Rewriting model key to: {}", rewrite);
let error = OnwardsErrorResponse::bad_request(
"Could not parse onwards model from request. 'model' parameter must be supplied in either the body or in the Model-Override header.",
Some("model"),
);
let mut body_serialized: serde_json::Value = match serde_json::from_slice(&attempt_body)
{
Ok(value) => value,
Err(_) => return Err(error.clone()),
};
let entry = body_serialized
.as_object_mut()
.ok_or(error.clone())?
.entry("model");
match entry {
Entry::Occupied(mut entry) => {
entry.insert(serde_json::Value::String(rewrite.clone()));
}
Entry::Vacant(_entry) => {
return Err(error.clone());
}
}
attempt_body = match serde_json::to_vec(&body_serialized) {
Ok(bytes) => axum::body::Bytes::from(bytes),
Err(_) => return Err(OnwardsErrorResponse::internal()),
};
}
let request_path = path_and_query.strip_prefix('/').unwrap_or(&path_and_query);
let target_path = target.url.path().trim_end_matches('/');
let path_to_join = if !target_path.is_empty() && target_path != "/" {
let target_path_no_slash = &target_path[1..];
if let Some(rest) = request_path.strip_prefix(target_path_no_slash) {
if rest.is_empty() || rest.starts_with('/') {
rest.strip_prefix('/').unwrap_or(rest)
} else {
request_path
}
} else {
request_path
}
} else {
request_path
};
let upstream_uri = match target.url.join(path_to_join) {
Ok(url) => url.to_string(),
Err(_) => return Err(OnwardsErrorResponse::internal()),
};
let upstream_uri_parsed = match Uri::try_from(&upstream_uri) {
Ok(uri) => uri,
Err(_) => {
error!("Invalid URI: {}", upstream_uri);
return Err(OnwardsErrorResponse::internal());
}
};
let mut attempt_headers = original_headers.clone();
if let Some(host) = upstream_uri_parsed.host() {
let host_value = if let Some(port) = upstream_uri_parsed.port_u16() {
format!("{host}:{port}")
} else {
host.to_string()
};
attempt_headers.insert("host", host_value.parse().unwrap());
}
attempt_headers.insert(
CONTENT_LENGTH,
attempt_body
.len()
.to_string()
.parse()
.expect("Content-Length should be valid"),
);
attempt_headers.remove(TRANSFER_ENCODING);
filter_headers_for_upstream(&mut attempt_headers, target);
let attempt_req = axum::extract::Request::builder()
.method(method.clone())
.uri(upstream_uri_parsed)
.body(axum::body::Body::from(attempt_body))
.unwrap();
let (mut parts, body) = attempt_req.into_parts();
parts.headers = attempt_headers;
let attempt_req = axum::extract::Request::from_parts(parts, body);
trace!(
"Outgoing request to provider:\n URI: {}\n Headers: {:?}",
upstream_uri,
attempt_req.headers()
);
let request_result = if let Some(timeout_secs) = target.request_timeout_secs {
let timeout_duration = std::time::Duration::from_secs(timeout_secs);
match tokio::time::timeout(timeout_duration, state.http_client.request(attempt_req))
.await
{
Err(_) => {
debug!(
"Request to {} timed out after {:?}",
upstream_uri, timeout_duration
);
last_error = Some(OnwardsErrorResponse::gateway_timeout());
if pool.fallback_enabled() {
continue;
} else {
record_response_status(504);
return Err(last_error.unwrap());
}
}
Ok(result) => result,
}
} else {
state.http_client.request(attempt_req).await
};
let mut response = match request_result {
Err(e) => {
error!(
"Error forwarding request to target url {}: {}",
upstream_uri, e
);
last_error = Some(OnwardsErrorResponse::bad_gateway());
if pool.fallback_enabled() {
continue;
} else {
return Err(last_error.unwrap());
}
}
Ok(response) => response,
};
let status = response.status().as_u16();
if pool.should_fallback_on_status(status) {
debug!(
"Provider returned fallback status {}, trying next: {:?}",
status, target.url
);
last_error = Some(OnwardsErrorResponse::bad_gateway());
continue;
}
if target.sanitize_response && !(200..300).contains(&status) {
let error_body = axum::body::to_bytes(response.into_body(), 64 * 1024)
.await
.ok();
error!(
status = status,
upstream = %target.url,
body = error_body
.as_ref()
.map(|b| String::from_utf8_lossy(b))
.unwrap_or_default()
.as_ref(),
"Upstream provider returned error, sanitizing before forwarding to client"
);
let sanitized_error = if (400..500).contains(&status) {
OnwardsErrorResponse::builder()
.body(ErrorResponseBody {
message: "The upstream provider rejected the request.".to_string(),
r#type: "invalid_request_error".to_string(),
param: None,
code: "upstream_error".to_string(),
})
.status(StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_REQUEST))
.build()
} else {
OnwardsErrorResponse::builder()
.body(ErrorResponseBody {
message: "An internal error occurred. Please try again later.".to_string(),
r#type: "internal_error".to_string(),
param: None,
code: "internal_error".to_string(),
})
.status(StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_GATEWAY))
.build()
};
record_response_status(status);
return Err(sanitized_error);
}
let content_type = response
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let is_sse = content_type.contains("text/event-stream");
let needs_sse_buffering = !state.targets.strict_mode
&& state.response_transform_fn.is_some()
&& target.sanitize_response
&& (200..300).contains(&status);
if is_sse && needs_sse_buffering {
debug!("Wrapping SSE response with buffered stream for non-strict sanitization");
let (parts, body) = response.into_parts();
let byte_stream = body.into_data_stream();
let buffered = SseBufferedStream::new(byte_stream);
let new_body = axum::body::Body::from_stream(buffered);
response = Response::from_parts(parts, new_body);
}
if let Some(ref transform_fn) = state.response_transform_fn
&& target.sanitize_response
&& (200..300).contains(&status)
&& !state.targets.strict_mode
{
debug!(
"Attempting response sanitization for status {}, path {}",
status, path_and_query
);
let original_model = req
.extensions()
.get::<OriginalModel>()
.map(|m| m.0.to_string());
if is_sse {
debug!("Applying streaming sanitization");
let sanitizer = crate::response_sanitizer::ResponseSanitizer {
original_model: original_model.clone(),
};
use futures_util::StreamExt;
let body_stream =
http_body_util::BodyExt::into_data_stream(std::mem::take(response.body_mut()));
let transformed_stream = body_stream.map(move |chunk_result| {
match chunk_result {
Ok(chunk) => {
match sanitizer.sanitize_streaming(&chunk) {
Ok(Some(sanitized)) => Ok::<_, std::io::Error>(sanitized),
Ok(None) => Ok(chunk),
Err(e) => {
tracing::error!("Failed to sanitize streaming chunk: {}", e);
Ok(chunk) }
}
}
Err(e) => {
tracing::error!("Stream error: {}", e);
Err(std::io::Error::other(e))
}
}
});
*response.body_mut() = axum::body::Body::from_stream(transformed_stream);
} else {
debug!("Applying non-streaming sanitization");
let response_body =
axum::body::to_bytes(std::mem::take(response.body_mut()), usize::MAX)
.await
.map_err(|e| {
error!("Failed to buffer response body: {}", e);
OnwardsErrorResponse::internal()
})?;
debug!(
"Response body buffered: {} bytes, content-type: {}",
response_body.len(),
content_type
);
trace!(
"Response body content: {}",
String::from_utf8_lossy(&response_body)
);
match transform_fn(
&path_and_query,
response.headers(),
&response_body,
original_model.as_deref(),
) {
Ok(Some(transformed_body)) => {
let content_length = transformed_body.len();
debug!(
"Sanitization successful: {} bytes -> {} bytes",
response_body.len(),
content_length
);
trace!(
"Sanitized body: {}",
String::from_utf8_lossy(&transformed_body)
);
*response.body_mut() = axum::body::Body::from(transformed_body);
response.headers_mut().remove(TRANSFER_ENCODING);
response
.headers_mut()
.insert(CONTENT_LENGTH, HeaderValue::from(content_length));
}
Ok(None) => {
debug!(
"Sanitization returned None, restoring original {} bytes",
response_body.len()
);
let content_length = response_body.len();
*response.body_mut() = axum::body::Body::from(response_body);
response.headers_mut().remove(TRANSFER_ENCODING);
response
.headers_mut()
.insert(CONTENT_LENGTH, HeaderValue::from(content_length));
}
Err(e) => {
error!("Response sanitization failed: {}", e);
return Err(OnwardsErrorResponse::internal());
}
}
}
}
if let Some(headers) = response_headers {
for (key, value) in headers.iter() {
if let (Ok(header_name), Ok(header_value)) =
(key.parse::<HeaderName>(), value.parse::<HeaderValue>())
{
response.headers_mut().insert(header_name, header_value);
}
}
trace!(
model = %model_name,
headers = ?headers,
"Added custom response headers"
);
}
record_response_status(response.status().as_u16());
debug!(
"Returning response with status {}, content-length: {:?}, strict_mode: {}",
response.status(),
response.headers().get(CONTENT_LENGTH),
state.targets.strict_mode
);
let resolved_trust = target.trusted.unwrap_or_else(|| pool.is_trusted());
response
.extensions_mut()
.insert(ResolvedTrust(resolved_trust));
return Ok(response);
}
let final_error =
last_error.unwrap_or_else(|| OnwardsErrorResponse::model_not_found(model_name.as_str()));
record_response_status(final_error.status.as_u16());
Err(final_error)
}
/// Lists the models the caller may use.
///
/// A pool with no configured keys is public and is always listed. A key-protected pool is listed
/// only when the request carries an `authorization: Bearer <token>` header whose token is
/// accepted by `auth::validate_bearer_token` for that pool's keys.
#[instrument(skip(state, req))]
pub async fn models<T: HttpClient>(
    State(state): State<AppState<T>>,
    req: Request,
) -> impl IntoResponse {
    let token = req
        .headers()
        .get("authorization")
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("Bearer "));

    let visible_models: Vec<String> = state
        .targets
        .targets
        .iter()
        .filter(|entry| match (entry.value().keys(), token) {
            (None, _) => true,
            (Some(_), None) => false,
            (Some(keys), Some(token)) => auth::validate_bearer_token(keys, token),
        })
        .map(|entry| entry.key().clone())
        .collect();

    Json(ListModelResponse::from_model_names(&visible_models))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_filter_headers_strips_hop_by_hop_headers() {
let mut headers = HeaderMap::new();
headers.insert("connection", "keep-alive".parse().unwrap());
headers.insert("keep-alive", "timeout=5".parse().unwrap());
headers.insert("proxy-authenticate", "Basic".parse().unwrap());
headers.insert("proxy-authorization", "Basic abc123".parse().unwrap());
headers.insert("te", "trailers".parse().unwrap());
headers.insert("trailer", "Expires".parse().unwrap());
headers.insert("upgrade", "HTTP/2.0".parse().unwrap());
headers.insert("content-type", "application/json".parse().unwrap());
let target = Target::builder()
.url("https://api.example.com".parse().unwrap())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(!headers.contains_key("connection"));
assert!(!headers.contains_key("keep-alive"));
assert!(!headers.contains_key("proxy-authenticate"));
assert!(!headers.contains_key("proxy-authorization"));
assert!(!headers.contains_key("te"));
assert!(!headers.contains_key("trailer"));
assert!(!headers.contains_key("upgrade"));
assert!(headers.contains_key("content-type"));
}
#[test]
fn test_filter_headers_strips_auth_headers() {
let mut headers = HeaderMap::new();
headers.insert("authorization", "Bearer client-token".parse().unwrap());
headers.insert("x-api-key", "client-api-key".parse().unwrap());
headers.insert("api-key", "another-key".parse().unwrap());
let target = Target::builder()
.url("https://api.example.com".parse().unwrap())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(!headers.contains_key("authorization"));
assert!(!headers.contains_key("x-api-key"));
assert!(!headers.contains_key("api-key"));
}
#[test]
fn test_filter_headers_strips_browser_security_headers() {
let mut headers = HeaderMap::new();
headers.insert("sec-fetch-site", "cross-site".parse().unwrap());
headers.insert("sec-fetch-mode", "cors".parse().unwrap());
headers.insert("sec-fetch-dest", "empty".parse().unwrap());
headers.insert("sec-fetch-user", "?1".parse().unwrap());
let target = Target::builder()
.url("https://api.example.com".parse().unwrap())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(!headers.contains_key("sec-fetch-site"));
assert!(!headers.contains_key("sec-fetch-mode"));
assert!(!headers.contains_key("sec-fetch-dest"));
assert!(!headers.contains_key("sec-fetch-user"));
}
#[test]
fn test_filter_headers_strips_all_sec_ch_ua_variants() {
let mut headers = HeaderMap::new();
headers.insert("sec-ch-ua", "\"Chrome\";v=\"120\"".parse().unwrap());
headers.insert("sec-ch-ua-mobile", "?0".parse().unwrap());
headers.insert("sec-ch-ua-platform", "\"macOS\"".parse().unwrap());
headers.insert("sec-ch-ua-arch", "\"arm64\"".parse().unwrap());
headers.insert("sec-ch-ua-model", "\"\"".parse().unwrap());
headers.insert("user-agent", "Mozilla/5.0...".parse().unwrap());
let target = Target::builder()
.url("https://api.example.com".parse().unwrap())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(!headers.contains_key("sec-ch-ua"));
assert!(!headers.contains_key("sec-ch-ua-mobile"));
assert!(!headers.contains_key("sec-ch-ua-platform"));
assert!(!headers.contains_key("sec-ch-ua-arch"));
assert!(!headers.contains_key("sec-ch-ua-model"));
assert!(headers.contains_key("user-agent"));
}
#[test]
fn test_filter_headers_strips_browser_context_headers() {
let mut headers = HeaderMap::new();
headers.insert("origin", "http://localhost:3000".parse().unwrap());
headers.insert("referer", "http://localhost:3000/chat".parse().unwrap());
headers.insert("cookie", "session=abc123".parse().unwrap());
let target = Target::builder()
.url("https://api.example.com".parse().unwrap())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(!headers.contains_key("origin"));
assert!(!headers.contains_key("referer"));
assert!(!headers.contains_key("cookie"));
}
#[test]
fn test_filter_headers_strips_caching_headers() {
let mut headers = HeaderMap::new();
headers.insert(
"if-modified-since",
"Wed, 21 Oct 2015 07:28:00 GMT".parse().unwrap(),
);
headers.insert("if-none-match", "\"abc123\"".parse().unwrap());
headers.insert("if-match", "\"xyz789\"".parse().unwrap());
headers.insert(
"if-unmodified-since",
"Wed, 21 Oct 2015 07:28:00 GMT".parse().unwrap(),
);
headers.insert("if-range", "\"abc123\"".parse().unwrap());
let target = Target::builder()
.url("https://api.example.com".parse().unwrap())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(!headers.contains_key("if-modified-since"));
assert!(!headers.contains_key("if-none-match"));
assert!(!headers.contains_key("if-match"));
assert!(!headers.contains_key("if-unmodified-since"));
assert!(!headers.contains_key("if-range"));
}
#[test]
fn test_filter_headers_strips_model_override_header() {
let mut headers = HeaderMap::new();
headers.insert("model-override", "gpt-4".parse().unwrap());
headers.insert("content-type", "application/json".parse().unwrap());
let target = Target::builder()
.url("https://api.example.com".parse().unwrap())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(!headers.contains_key("model-override"));
assert!(headers.contains_key("content-type"));
}
#[test]
fn test_filter_headers_keeps_safe_headers() {
let mut headers = HeaderMap::new();
headers.insert("content-type", "application/json".parse().unwrap());
headers.insert("accept", "application/json".parse().unwrap());
headers.insert("user-agent", "MyClient/1.0".parse().unwrap());
headers.insert("accept-language", "en-US,en;q=0.9".parse().unwrap());
headers.insert("accept-encoding", "gzip, deflate, br".parse().unwrap());
headers.insert("x-stainless-lang", "js".parse().unwrap());
headers.insert("x-stainless-os", "macOS".parse().unwrap());
let target = Target::builder()
.url("https://api.example.com".parse().unwrap())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(headers.contains_key("content-type"));
assert!(headers.contains_key("accept"));
assert!(headers.contains_key("user-agent"));
assert!(headers.contains_key("accept-language"));
assert!(headers.contains_key("accept-encoding"));
assert!(headers.contains_key("x-stainless-lang"));
assert!(headers.contains_key("x-stainless-os"));
}
#[test]
fn test_filter_headers_adds_authorization_when_onwards_key_present() {
let mut headers = HeaderMap::new();
headers.insert("authorization", "Bearer client-token".parse().unwrap());
let target = Target::builder()
.url("https://api.example.com".parse().unwrap())
.onwards_key("sk-upstream-key".to_string())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(headers.contains_key("authorization"));
assert_eq!(
headers.get("authorization").unwrap().to_str().unwrap(),
"Bearer sk-upstream-key"
);
}
#[test]
fn test_filter_headers_no_authorization_when_onwards_key_absent() {
let mut headers = HeaderMap::new();
headers.insert("authorization", "Bearer client-token".parse().unwrap());
let target = Target::builder()
.url("https://api.example.com".parse().unwrap())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(!headers.contains_key("authorization"));
}
#[test]
fn test_filter_headers_custom_auth_header_name() {
let mut headers = HeaderMap::new();
headers.insert("authorization", "Bearer client-token".parse().unwrap());
let target = Target::builder()
.url("https://api.example.com".parse().unwrap())
.onwards_key("my-api-key-123".to_string())
.upstream_auth_header_name("X-API-Key".to_string())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(!headers.contains_key("authorization"));
assert!(headers.contains_key("x-api-key"));
assert_eq!(
headers.get("x-api-key").unwrap().to_str().unwrap(),
"Bearer my-api-key-123"
);
}
#[test]
fn test_filter_headers_custom_auth_header_prefix() {
let mut headers = HeaderMap::new();
let target = Target::builder()
.url("https://api.example.com".parse().unwrap())
.onwards_key("token-xyz".to_string())
.upstream_auth_header_prefix("ApiKey ".to_string())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(headers.contains_key("authorization"));
assert_eq!(
headers.get("authorization").unwrap().to_str().unwrap(),
"ApiKey token-xyz"
);
}
#[test]
fn test_filter_headers_empty_auth_header_prefix() {
let mut headers = HeaderMap::new();
let target = Target::builder()
.url("https://api.example.com".parse().unwrap())
.onwards_key("plain-api-key-456".to_string())
.upstream_auth_header_prefix("".to_string())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(headers.contains_key("authorization"));
assert_eq!(
headers.get("authorization").unwrap().to_str().unwrap(),
"plain-api-key-456"
);
}
#[test]
fn test_filter_headers_custom_header_name_and_prefix() {
let mut headers = HeaderMap::new();
let target = Target::builder()
.url("https://api.example.com".parse().unwrap())
.onwards_key("secret-key".to_string())
.upstream_auth_header_name("X-Custom-Auth".to_string())
.upstream_auth_header_prefix("Token ".to_string())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(!headers.contains_key("authorization"));
assert!(headers.contains_key("x-custom-auth"));
assert_eq!(
headers.get("x-custom-auth").unwrap().to_str().unwrap(),
"Token secret-key"
);
}
#[test]
fn test_filter_headers_adds_x_forwarded_proto() {
let mut headers = HeaderMap::new();
let target = Target::builder()
.url("https://api.example.com".parse().unwrap())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(headers.contains_key("x-forwarded-proto"));
assert_eq!(
headers.get("x-forwarded-proto").unwrap().to_str().unwrap(),
"https"
);
}
#[test]
fn test_filter_headers_comprehensive_browser_request() {
let mut headers = HeaderMap::new();
headers.insert("connection", "keep-alive".parse().unwrap());
headers.insert("authorization", "Bearer client-secret".parse().unwrap());
headers.insert("sec-fetch-site", "same-origin".parse().unwrap());
headers.insert("sec-fetch-mode", "cors".parse().unwrap());
headers.insert("sec-fetch-dest", "empty".parse().unwrap());
headers.insert("sec-ch-ua", "\"Chrome\";v=\"120\"".parse().unwrap());
headers.insert("sec-ch-ua-mobile", "?0".parse().unwrap());
headers.insert("sec-ch-ua-platform", "\"macOS\"".parse().unwrap());
headers.insert("origin", "http://localhost:5173".parse().unwrap());
headers.insert(
"referer",
"http://localhost:5173/playground".parse().unwrap(),
);
headers.insert("cookie", "session=xyz; token=abc".parse().unwrap());
headers.insert("if-none-match", "\"abc123\"".parse().unwrap());
headers.insert("content-type", "application/json".parse().unwrap());
headers.insert("accept", "application/json".parse().unwrap());
headers.insert("user-agent", "Mozilla/5.0...".parse().unwrap());
let target = Target::builder()
.url("https://api.anthropic.com".parse().unwrap())
.onwards_key("sk-ant-upstream-key".to_string())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(!headers.contains_key("connection"));
assert!(!headers.contains_key("sec-fetch-site"));
assert!(!headers.contains_key("sec-fetch-mode"));
assert!(!headers.contains_key("sec-fetch-dest"));
assert!(!headers.contains_key("sec-ch-ua"));
assert!(!headers.contains_key("sec-ch-ua-mobile"));
assert!(!headers.contains_key("sec-ch-ua-platform"));
assert!(!headers.contains_key("origin"));
assert!(!headers.contains_key("referer"));
assert!(!headers.contains_key("cookie"));
assert!(!headers.contains_key("if-none-match"));
assert!(headers.contains_key("content-type"));
assert!(headers.contains_key("accept"));
assert!(headers.contains_key("user-agent"));
assert_eq!(
headers.get("authorization").unwrap().to_str().unwrap(),
"Bearer sk-ant-upstream-key"
);
assert_eq!(
headers.get("x-forwarded-proto").unwrap().to_str().unwrap(),
"https"
);
}
#[test]
fn test_filter_headers_preserves_provider_specific_headers() {
let mut headers = HeaderMap::new();
headers.insert("anthropic-version", "2023-06-01".parse().unwrap());
headers.insert("openai-organization", "org-123".parse().unwrap());
headers.insert("x-stainless-lang", "js".parse().unwrap());
headers.insert("x-stainless-runtime", "browser:chrome".parse().unwrap());
let target = Target::builder()
.url("https://api.anthropic.com".parse().unwrap())
.build();
filter_headers_for_upstream(&mut headers, &target);
assert!(headers.contains_key("anthropic-version"));
assert!(headers.contains_key("openai-organization"));
assert!(headers.contains_key("x-stainless-lang"));
assert!(headers.contains_key("x-stainless-runtime"));
}
#[test]
fn test_path_stripping_with_duplicate_prefix() {
let target_url = url::Url::parse("https://api.openai.com/v1/").unwrap();
let target_path = target_url.path().trim_end_matches('/'); let request_path = "v1/chat/completions";
let path_to_join = if !target_path.is_empty() && target_path != "/" {
let target_path_no_slash = &target_path[1..];
if let Some(rest) = request_path.strip_prefix(target_path_no_slash) {
if rest.is_empty() || rest.starts_with('/') {
rest.strip_prefix('/').unwrap_or(rest)
} else {
request_path
}
} else {
request_path
}
} else {
request_path
};
let result = target_url.join(path_to_join).unwrap();
assert_eq!(
result.as_str(),
"https://api.openai.com/v1/chat/completions"
);
}
#[test]
fn test_path_stripping_without_duplicate() {
let target_url = url::Url::parse("https://api.openai.com/").unwrap();
let target_path = target_url.path().trim_end_matches('/'); let request_path = "v1/chat/completions";
let path_to_join = if !target_path.is_empty() && target_path != "/" {
let target_path_no_slash = &target_path[1..];
if let Some(rest) = request_path.strip_prefix(target_path_no_slash) {
if rest.is_empty() || rest.starts_with('/') {
rest.strip_prefix('/').unwrap_or(rest)
} else {
request_path
}
} else {
request_path
}
} else {
request_path
};
let result = target_url.join(path_to_join).unwrap();
assert_eq!(
result.as_str(),
"https://api.openai.com/v1/chat/completions"
);
}
#[test]
fn test_path_stripping_with_actual_duplicate_paths() {
let target_url = url::Url::parse("https://api.example.com/").unwrap();
let target_path = target_url.path().trim_end_matches('/');
let request_path = "v1/v1/something";
let path_to_join = if !target_path.is_empty() && target_path != "/" {
let target_path_no_slash = &target_path[1..];
if let Some(rest) = request_path.strip_prefix(target_path_no_slash) {
if rest.is_empty() || rest.starts_with('/') {
rest.strip_prefix('/').unwrap_or(rest)
} else {
request_path
}
} else {
request_path
}
} else {
request_path
};
let result = target_url.join(path_to_join).unwrap();
assert_eq!(result.as_str(), "https://api.example.com/v1/v1/something");
}
#[test]
fn test_path_stripping_with_different_prefix() {
    // Target base `/v2` does not prefix a `v1/...` request, so nothing is
    // stripped and the request is appended under `/v2/`.
    let target_url = url::Url::parse("https://api.example.com/v2/").unwrap();
    let request_path = "v1/chat/completions";
    let base = target_url.path().trim_end_matches('/');

    // Strip the base path only on a full-segment match; fall back to the
    // untouched request path in every other case.
    let path_to_join = Some(base)
        .filter(|prefix| !prefix.is_empty() && *prefix != "/")
        .and_then(|prefix| request_path.strip_prefix(&prefix[1..]))
        .filter(|rest| rest.is_empty() || rest.starts_with('/'))
        .map(|rest| rest.strip_prefix('/').unwrap_or(rest))
        .unwrap_or(request_path);

    let joined = target_url.join(path_to_join).unwrap();
    assert_eq!(
        joined.as_str(),
        "https://api.example.com/v2/v1/chat/completions"
    );
}
#[test]
fn test_path_stripping_with_query_params() {
    // The duplicated `v1` segment is removed while the query string rides
    // along untouched.
    let target_url = url::Url::parse("https://api.openai.com/v1/").unwrap();
    let request_path = "v1/chat/completions?stream=true";
    let base = target_url.path().trim_end_matches('/');

    // Strip the base path only on a full-segment match; fall back to the
    // untouched request path in every other case.
    let path_to_join = Some(base)
        .filter(|prefix| !prefix.is_empty() && *prefix != "/")
        .and_then(|prefix| request_path.strip_prefix(&prefix[1..]))
        .filter(|rest| rest.is_empty() || rest.starts_with('/'))
        .map(|rest| rest.strip_prefix('/').unwrap_or(rest))
        .unwrap_or(request_path);

    let joined = target_url.join(path_to_join).unwrap();
    assert_eq!(
        joined.as_str(),
        "https://api.openai.com/v1/chat/completions?stream=true"
    );
}
#[test]
fn test_path_stripping_false_positive() {
    // `v1x` merely starts with the characters `v1`; since it is not a whole
    // `v1` segment it must not be stripped.
    let target_url = url::Url::parse("https://api.example.com/v1/").unwrap();
    let request_path = "v1x/something";
    let base = target_url.path().trim_end_matches('/');

    // Strip the base path only on a full-segment match; fall back to the
    // untouched request path in every other case.
    let path_to_join = Some(base)
        .filter(|prefix| !prefix.is_empty() && *prefix != "/")
        .and_then(|prefix| request_path.strip_prefix(&prefix[1..]))
        .filter(|rest| rest.is_empty() || rest.starts_with('/'))
        .map(|rest| rest.strip_prefix('/').unwrap_or(rest))
        .unwrap_or(request_path);

    let joined = target_url.join(path_to_join).unwrap();
    assert_eq!(joined.as_str(), "https://api.example.com/v1/v1x/something");
}
#[test]
fn test_timeout_config_can_be_set() {
    // A timeout passed to the builder is stored on the built target.
    let target = crate::target::Target::builder()
        .url("https://api.example.com".parse().unwrap())
        .request_timeout_secs(30)
        .build();
    assert_eq!(Some(30), target.request_timeout_secs);
}
#[test]
fn test_timeout_defaults_to_none() {
    // Without an explicit timeout the builder leaves the field unset.
    let target = crate::target::Target::builder()
        .url("https://api.example.com".parse().unwrap())
        .build();
    assert!(target.request_timeout_secs.is_none());
}
#[test]
fn test_gateway_timeout_error_response() {
    // The gateway-timeout error must render as HTTP 504.
    let status = OnwardsErrorResponse::gateway_timeout()
        .into_response()
        .status();
    assert_eq!(status, axum::http::StatusCode::GATEWAY_TIMEOUT);
}
#[tokio::test]
async fn test_gateway_timeout_error_body() {
    use http_body_util::BodyExt;

    // Render the error to a response and read the whole body back as text.
    let body = OnwardsErrorResponse::gateway_timeout()
        .into_response()
        .into_body()
        .collect()
        .await
        .unwrap()
        .to_bytes();
    let text = std::str::from_utf8(&body).unwrap();

    // Every fragment below must appear somewhere in the serialized body.
    for expected in ["message", "code", "gateway_timeout", "took too long"] {
        assert!(
            text.contains(expected),
            "missing {:?} in body: {}",
            expected,
            text
        );
    }
}
/// Test double for `HttpClient`: each request sleeps for `delay`, then gets
/// an empty-bodied response whose status is `response_status`.
#[derive(Clone, Debug)]
struct DelayedMockClient {
    /// How long `request` sleeps before producing a response.
    delay: std::time::Duration,
    /// Status code placed on the canned response.
    response_status: u16,
}
impl DelayedMockClient {
fn new(delay: std::time::Duration, response_status: u16) -> Self {
Self {
delay,
response_status,
}
}
}
#[async_trait::async_trait]
impl HttpClient for DelayedMockClient {
    /// Ignores the incoming request, sleeps for the configured delay, and
    /// then returns an empty response carrying the configured status.
    async fn request(
        &self,
        _req: axum::extract::Request,
    ) -> Result<axum::response::Response, Box<dyn std::error::Error + Send + Sync>> {
        let Self {
            delay,
            response_status,
        } = self;

        // Simulate a slow upstream before answering.
        tokio::time::sleep(*delay).await;

        let response = axum::response::Response::builder()
            .status(*response_status)
            .body(axum::body::Body::empty())
            .unwrap();
        Ok(response)
    }
}
#[tokio::test]
async fn test_mock_client_timeout_fires() {
use crate::target::Target;
let target = Target::builder()
.url("https://api.example.com/".parse().unwrap())
.request_timeout_secs(1)
.build();
let pool = target.into_pool();
let mock_client = DelayedMockClient::new(std::time::Duration::from_secs(2), 200);
let state = AppState {
targets: crate::target::Targets {
targets: std::sync::Arc::new(dashmap::DashMap::new()),
key_rate_limiters: std::sync::Arc::new(dashmap::DashMap::new()),
key_concurrency_limiters: std::sync::Arc::new(dashmap::DashMap::new()),
strict_mode: false,
http_pool_config: None,
},
http_client: mock_client,
body_transform_fn: None,
response_transform_fn: None,
tool_executor: std::sync::Arc::new(crate::NoOpToolExecutor),
response_store: std::sync::Arc::new(crate::NoOpResponseStore),
};
let req = axum::extract::Request::builder()
.uri("/v1/chat/completions")
.method("POST")
.header("content-type", "application/json")
.body(axum::body::Body::from(r#"{"model":"gpt-4","messages":[]}"#))
.unwrap();
let target = pool.first_target().unwrap();
let timeout_secs = target.request_timeout_secs.unwrap();
let timeout_duration = std::time::Duration::from_secs(timeout_secs);
let result = tokio::time::timeout(timeout_duration, state.http_client.request(req)).await;
assert!(result.is_err(), "Expected timeout but request completed");
}
#[test]
fn test_pool_with_fallback_enabled() {
    use crate::load_balancer::{Provider, ProviderPool};
    use crate::target::{FallbackConfig, LoadBalanceStrategy, Target};

    // Two equally weighted providers, each with a 1s request timeout, in the
    // same order as the URL list.
    let providers: Vec<_> = [
        "https://provider1.example.com/",
        "https://provider2.example.com/",
    ]
    .iter()
    .map(|url| Provider {
        target: Target::builder()
            .url(url.parse().unwrap())
            .request_timeout_secs(1)
            .build(),
        weight: 1,
    })
    .collect();

    // Fallback switched on; status- and rate-limit-triggered fallback left off.
    let fallback = FallbackConfig {
        enabled: true,
        on_status: Vec::new(),
        on_rate_limit: false,
        ..Default::default()
    };

    let pool = ProviderPool::with_config(
        providers,
        None,
        None,
        None,
        Some(fallback),
        LoadBalanceStrategy::Priority,
        false,
    );

    assert_eq!(pool.len(), 2, "Pool should have 2 providers");
    assert!(pool.fallback_enabled(), "Fallback should be enabled");
}
}