use super::adapter::{OpenResponsesAdapter, PendingToolCall};
use super::schemas::chat_completions::{
ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, FunctionCall,
FunctionDefinition, Tool as ChatTool, ToolCall, Usage, generated_chat_completion_id,
normalize_chat_completion_chunk_value, normalize_chat_completion_response_value,
};
use super::schemas::completions::{
CompletionChunk, CompletionRequest, CompletionResponse, generated_completion_id,
normalize_completion_chunk_value, normalize_completion_response_value,
};
use super::schemas::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
use super::schemas::responses::{
InputTokensDetails, OutputTokensDetails, ResponseUsage, ResponsesRequest, ResponsesResponse,
ResponsesStreamingEvent, generated_response_id, normalize_responses_response_value,
normalize_responses_streaming_event_value,
};
use super::streaming::{StreamingState, parse_chat_chunk};
use crate::AppState;
use crate::client::HttpClient;
use crate::extract_model_from_request;
use crate::handlers::{ResolvedTrust, target_message_handler};
use crate::traits::RequestContext;
use axum::Json;
use axum::body::Body;
use axum::extract::{FromRequest, State};
use axum::http::{HeaderMap, Request, StatusCode, header};
use axum::response::{IntoResponse, Response};
use futures_util::StreamExt;
use http_body_util::BodyExt;
use serde_json::json;
use std::collections::{HashMap, HashSet};
use tracing::{debug, error, info, trace, warn};
/// Outcome of forwarding a request upstream via `forward_request`.
struct ForwardResult {
    // The upstream response (or a locally generated error response).
    response: Response,
    // True when the resolved target is marked trusted; trusted providers
    // bypass error-body sanitization.
    trusted: bool,
    // True when the failure originated inside this proxy (e.g. request
    // build error) rather than from the upstream provider.
    internal_error: bool,
}
/// Returns true when a Content-Type header value denotes `text/event-stream`,
/// ignoring any media-type parameters (e.g. `; charset=utf-8`), surrounding
/// whitespace, and ASCII case.
fn is_sse_content_type(content_type: &str) -> bool {
    let media_type = match content_type.split(';').next() {
        Some(part) => part.trim(),
        None => return false,
    };
    media_type.eq_ignore_ascii_case("text/event-stream")
}
/// Checks whether a response advertises an SSE (`text/event-stream`) body
/// via its Content-Type header. Missing or non-UTF-8 header values count
/// as non-SSE.
fn response_is_sse(response: &Response) -> bool {
    match response.headers().get(header::CONTENT_TYPE) {
        Some(value) => value.to_str().map(is_sse_content_type).unwrap_or(false),
        None => false,
    }
}
/// `/models` endpoint: thin wrapper that delegates directly to the shared
/// models handler in `crate::handlers`.
pub async fn models_handler<T: HttpClient + Clone + Send + Sync + 'static>(
    State(state): State<AppState<T>>,
    req: Request<Body>,
) -> impl IntoResponse {
    crate::handlers::models(State(state), req).await
}
/// `POST /chat/completions`: serializes the validated request, forwards it
/// upstream, and sanitizes the response. Streaming sanitization is used when
/// the client asked for a stream OR when the upstream unexpectedly returns
/// SSE. Error bodies pass through unmodified for trusted providers and for
/// internally generated errors.
pub async fn chat_completions_handler<T: HttpClient + Clone + Send + Sync + 'static>(
    State(state): State<AppState<T>>,
    headers: HeaderMap,
    Json(request): Json<ChatCompletionRequest>,
) -> Response {
    let original_model = request.model.clone();
    let is_streaming = request.stream.unwrap_or(false);
    debug!(
        model = %original_model,
        messages_count = request.messages.len(),
        stream = is_streaming,
        "Chat completions request validated"
    );
    let body_bytes = match serde_json::to_vec(&request) {
        Ok(bytes) => bytes,
        Err(e) => {
            error!(error = %e, "Failed to serialize chat completions request");
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "server_error",
                "Failed to process request",
            );
        }
    };
    // Routing headers/body may override the model; otherwise keep the
    // client's. `original_model` is not used again, so it is moved here —
    // the previous `.clone()` was an unnecessary eager copy (and it matches
    // how `completions_handler` already does this).
    let resolved_model =
        extract_model_from_request(&headers, &body_bytes).unwrap_or(original_model);
    let ForwardResult {
        response,
        trusted,
        internal_error,
    } = forward_request(state, headers, "/chat/completions", body_bytes).await;
    if response.status().is_success() {
        let response_is_sse = response_is_sse(&response);
        // Some providers stream even when the client did not request it;
        // log that we are switching to the streaming sanitizer.
        if response_is_sse && !is_streaming {
            debug!(
                model = %resolved_model,
                "Using streaming chat sanitizer because upstream returned SSE"
            );
        }
        if is_streaming || response_is_sse {
            sanitize_streaming_chat_response(response, resolved_model, trusted).await
        } else {
            sanitize_chat_response(response, resolved_model).await
        }
    } else if trusted || internal_error {
        debug!(model = %resolved_model, "Bypassing error sanitization for trusted provider");
        response
    } else {
        sanitize_error_response(response).await
    }
}
pub async fn responses_handler<T: HttpClient + Clone + Send + Sync + 'static>(
State(state): State<AppState<T>>,
headers: HeaderMap,
req: Request<Body>,
) -> Response {
let (mut parts, body) = req.into_parts();
let extensions = std::mem::take(&mut parts.extensions);
let req = Request::from_parts(parts, body);
let request: ResponsesRequest = match axum::extract::Json::from_request(req, &state).await {
Ok(Json(r)) => r,
Err(e) => {
error!(error = %e, "Failed to parse responses request");
return error_response(
StatusCode::BAD_REQUEST,
"invalid_request_error",
&format!("Invalid request: {}", e),
);
}
};
debug!(
model = %request.model,
has_previous_response_id = request.previous_response_id.is_some(),
stream = ?request.stream,
"Responses request validated"
);
let bearer_token = headers
.get("authorization")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.strip_prefix("Bearer "));
let use_adapter = should_use_adapter(&state, &request.model, bearer_token);
if use_adapter {
debug!(model = %request.model, "Using Open Responses adapter");
return handle_adapter_request(state, headers, request, extensions).await;
}
debug!(model = %request.model, "Passthrough mode for responses request");
let original_model = request.model.clone();
let is_streaming = request.stream.unwrap_or(false);
let mut request = request;
if let Some(ref mut tools) = request.tools {
for tool in tools.iter_mut() {
if let super::schemas::responses::Tool::Function { parameters, .. } = tool
&& let Some(obj) = parameters.as_object_mut()
&& !obj.contains_key("additionalProperties")
{
obj.insert(
"additionalProperties".to_string(),
serde_json::Value::Bool(false),
);
}
}
}
let body_bytes = match serde_json::to_vec(&request) {
Ok(bytes) => bytes,
Err(e) => {
error!(error = %e, "Failed to serialize responses request");
return error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"server_error",
"Failed to process request",
);
}
};
let resolved_model =
extract_model_from_request(&headers, &body_bytes).unwrap_or(original_model.clone());
let ForwardResult {
response,
trusted,
internal_error,
} = forward_request(state, headers, "/responses", body_bytes).await;
if response.status().is_success() {
let response_is_sse = response_is_sse(&response);
if is_streaming || response_is_sse {
sanitize_streaming_responses_response(response, resolved_model, trusted).await
} else {
sanitize_responses_response(response, resolved_model).await
}
} else if trusted || internal_error {
debug!(model = %resolved_model, "Bypassing error sanitization for trusted provider");
response
} else {
sanitize_error_response(response).await
}
}
/// `POST /embeddings`: forwards the request upstream and sanitizes the
/// (always non-streaming) response. Error bodies pass through unmodified for
/// trusted providers and internally generated errors.
pub async fn embeddings_handler<T: HttpClient + Clone + Send + Sync + 'static>(
    State(state): State<AppState<T>>,
    headers: HeaderMap,
    Json(request): Json<EmbeddingsRequest>,
) -> Response {
    let original_model = request.model.clone();
    debug!(
        model = %original_model,
        "Embeddings request validated"
    );
    let body_bytes = match serde_json::to_vec(&request) {
        Ok(bytes) => bytes,
        Err(e) => {
            error!(error = %e, "Failed to serialize embeddings request");
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "server_error",
                "Failed to process request",
            );
        }
    };
    // Routing headers/body may override the model; otherwise keep the
    // client's. `original_model` is not used again, so it is moved here —
    // the previous `.clone()` was an unnecessary eager copy.
    let resolved_model =
        extract_model_from_request(&headers, &body_bytes).unwrap_or(original_model);
    let ForwardResult {
        response,
        trusted,
        internal_error,
    } = forward_request(state, headers, "/embeddings", body_bytes).await;
    if response.status().is_success() {
        sanitize_embeddings_response(response, resolved_model).await
    } else if trusted || internal_error {
        debug!(model = %resolved_model, "Bypassing error sanitization for trusted provider");
        response
    } else {
        sanitize_error_response(response).await
    }
}
pub async fn completions_handler<T: HttpClient + Clone + Send + Sync + 'static>(
State(state): State<AppState<T>>,
headers: HeaderMap,
Json(request): Json<CompletionRequest>,
) -> Response {
let original_model = request.model.clone();
let is_streaming = request.stream.unwrap_or(false);
debug!(
model = %original_model,
stream = is_streaming,
"Completions request validated"
);
let body_bytes = match serde_json::to_vec(&request) {
Ok(bytes) => bytes,
Err(e) => {
error!(error = %e, "Failed to serialize completions request");
return error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"server_error",
"Failed to process request",
);
}
};
let resolved_model =
extract_model_from_request(&headers, &body_bytes).unwrap_or(original_model);
let ForwardResult {
response,
trusted,
internal_error,
} = forward_request(state, headers, "/completions", body_bytes).await;
if response.status().is_success() {
let response_is_sse = response_is_sse(&response);
if is_streaming || response_is_sse {
sanitize_streaming_completions_response(response, resolved_model, trusted).await
} else {
sanitize_completions_response(response, resolved_model).await
}
} else if trusted || internal_error {
debug!(model = %resolved_model, "Bypassing error sanitization for trusted provider");
response
} else {
sanitize_error_response(response).await
}
}
/// Decides whether the Open Responses adapter should handle a request for
/// `model`, based on the target pool's `open_responses.adapter` flag.
/// Defaults to `false` (passthrough) whenever the setting cannot be resolved.
fn should_use_adapter<T: HttpClient + Clone + Send + Sync + 'static>(
    state: &AppState<T>,
    model: &str,
    bearer_token: Option<&str>,
) -> bool {
    // Look up the pool registered for this model; no pool means passthrough.
    let pool = match state.targets.targets.get(model) {
        Some(pool) => pool.clone(),
        None => {
            debug!(model = %model, "No target found, cannot determine adapter setting");
            return false;
        }
    };
    // Fast path: a concrete first target carries the adapter flag directly.
    if let Some(target) = pool.first_target() {
        return target
            .open_responses
            .as_ref()
            .map(|config| config.adapter)
            .unwrap_or(false);
    }
    // No direct target: consult routing rules, which may key off labels
    // associated with the caller's bearer token.
    let labels = bearer_token
        .and_then(|token| {
            state
                .targets
                .key_labels
                .get(token)
                .map(|r| r.value().clone())
        })
        .unwrap_or_default();
    match pool.evaluate_routing_rules(&labels) {
        Some(crate::target::RoutingAction::Redirect { target }) => {
            debug!(
                model = %model,
                redirect_target = %target,
                "Following routing rule redirect for adapter check"
            );
            // Follow the redirect and read the adapter flag from the
            // redirected pool's first target; false if anything is missing.
            state
                .targets
                .targets
                .get(target.as_str())
                .and_then(|p| p.first_target().cloned())
                .and_then(|t| t.open_responses)
                .map(|config| config.adapter)
                .unwrap_or(false)
        }
        Some(crate::target::RoutingAction::Deny) => {
            debug!(model = %model, "Routing rule denied, defaulting to passthrough");
            false
        }
        None => {
            debug!(model = %model, "Pool is empty with no matching routing rules, cannot determine adapter setting");
            false
        }
    }
}
/// Runs a Responses-API request through the chat-completions adapter: the
/// request is converted to a chat request and executed in a tool loop —
/// server-side tools are run locally and their results fed back to the model
/// until it produces a final answer or `max_iterations` is reached.
async fn handle_adapter_request<T: HttpClient + Clone + Send + Sync + 'static>(
    state: AppState<T>,
    headers: HeaderMap,
    mut request: ResponsesRequest,
    extensions: axum::http::Extensions,
) -> Response {
    let adapter =
        OpenResponsesAdapter::new(state.response_store.clone(), state.tool_executor.clone());
    let mut ctx = RequestContext::new().with_model(&request.model);
    // Carry forward the original request's extensions (auth/trace context).
    ctx.extensions = extensions;
    // Resolve which server-side tools exist for this request context and
    // index them by name for validation below.
    let available_server_tools = state.tool_executor.tools(&ctx).await;
    let available_server_tool_map: HashMap<String, &crate::traits::ToolSchema> =
        available_server_tools
            .iter()
            .map(|t| (t.name.clone(), t))
            .collect();
    debug!(
        available_server_tool_count = available_server_tools.len(),
        "Resolved available server-side tools"
    );
    // Validate client-requested hosted tools against the available set.
    let mut requested_server_tools: Vec<&crate::traits::ToolSchema> = Vec::new();
    if let Some(ref client_tools) = request.tools {
        for tool in client_tools {
            if let super::schemas::responses::Tool::HostedTool { name } = tool {
                match available_server_tool_map.get(name.as_str()) {
                    Some(schema) => requested_server_tools.push(schema),
                    None => {
                        // Unknown hosted tool: reject with the list of
                        // tools that are actually available.
                        return error_response(
                            StatusCode::BAD_REQUEST,
                            "invalid_request_error",
                            &format!(
                                "Unknown hosted tool: '{}'. Available tools: [{}]",
                                name,
                                available_server_tools
                                    .iter()
                                    .map(|t| t.name.as_str())
                                    .collect::<Vec<_>>()
                                    .join(", ")
                            ),
                        );
                    }
                }
            }
        }
        if !requested_server_tools.is_empty() {
            // A tool name provided by both the client (as a function) and
            // the server (as a hosted tool) would be ambiguous — reject.
            let client_fn_names: HashSet<String> = client_tools
                .iter()
                .filter_map(|t| match t {
                    super::schemas::responses::Tool::Function { name, .. } => Some(name.clone()),
                    _ => None,
                })
                .collect();
            let collisions: Vec<&str> = requested_server_tools
                .iter()
                .filter(|t| client_fn_names.contains(&t.name))
                .map(|t| t.name.as_str())
                .collect();
            if !collisions.is_empty() {
                return error_response(
                    StatusCode::BAD_REQUEST,
                    "invalid_request_error",
                    &format!(
                        "Tool name collision: the following tools are provided by both the client and server: {}",
                        collisions.join(", ")
                    ),
                );
            }
        }
        // Hosted tools are executed locally; strip them from the tool list
        // forwarded upstream.
        request.tools = Some(
            client_tools
                .iter()
                .filter(|t| !matches!(t, super::schemas::responses::Tool::HostedTool { .. }))
                .cloned()
                .collect(),
        );
    }
    let server_tool_names: HashSet<String> = requested_server_tools
        .iter()
        .map(|t| t.name.clone())
        .collect();
    // Convert the Responses request into a chat-completions request.
    let mut chat_request = match adapter.to_chat_request(&request).await {
        Ok(req) => req,
        Err(e) => {
            error!(error = %e, "Failed to convert responses request to chat completions");
            return error_response(
                StatusCode::BAD_REQUEST,
                "invalid_request_error",
                &format!("Failed to process request: {}", e),
            );
        }
    };
    // Expose the requested server tools to the model as ordinary functions.
    if !requested_server_tools.is_empty() {
        let schemas: Vec<crate::traits::ToolSchema> = requested_server_tools
            .iter()
            .map(|t| (*t).clone())
            .collect();
        merge_server_tools(&mut chat_request, &schemas);
    }
    // Streaming can be requested explicitly in the body or forced via a
    // configured header whose value is the literal string "true".
    let header_stream = state
        .streaming_header
        .as_ref()
        .and_then(|name| headers.get(name.as_str()))
        .and_then(|v| v.to_str().ok())
        == Some("true");
    if request.stream == Some(true) || header_stream {
        debug!(
            explicit_stream = request.stream == Some(true),
            header_stream = header_stream,
            "Using streaming adapter mode"
        );
        return handle_streaming_adapter_request(
            state,
            headers,
            request,
            chat_request,
            server_tool_names,
            ctx,
        )
        .await;
    }
    // Non-streaming tool loop: call upstream, execute requested server
    // tools, feed results back, repeat until a final answer or the cap.
    let max_iterations = adapter.max_iterations();
    let mut iteration = 0;
    // Token usage accumulated across all loop iterations.
    let mut accumulated_usage: Option<Usage> = None;
    loop {
        iteration += 1;
        debug!(
            iteration = iteration,
            max = max_iterations,
            "Tool loop iteration"
        );
        let body_bytes = match serde_json::to_vec(&chat_request) {
            Ok(bytes) => bytes,
            Err(e) => {
                error!(error = %e, "Failed to serialize chat completions request");
                return error_response(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "server_error",
                    "Failed to process request",
                );
            }
        };
        // Span per model call; token counts are recorded once usage arrives.
        let model_call_span = tracing::info_span!(
            "model.call",
            llm.token_count.input = tracing::field::Empty,
            llm.token_count.output = tracing::field::Empty,
            loop.iteration = iteration as i64,
        );
        let response = forward_request_raw(
            state.clone(),
            headers.clone(),
            "/chat/completions",
            body_bytes,
        )
        .await;
        if !response.status().is_success() {
            // Upstream errors abort the loop and pass through unchanged.
            return response;
        }
        let (parts, body) = response.into_parts();
        let body_bytes = match body.collect().await {
            Ok(collected) => collected.to_bytes(),
            Err(e) => {
                error!(error = %e, "Failed to read response body");
                return error_response(
                    StatusCode::BAD_GATEWAY,
                    "upstream_error",
                    "Failed to read upstream response",
                );
            }
        };
        let chat_response: ChatCompletionResponse = match serde_json::from_slice(&body_bytes) {
            Ok(resp) => resp,
            Err(e) => {
                error!(error = %e, "Failed to parse chat completions response");
                // Log a truncated preview of the unparseable body to help
                // diagnose provider quirks.
                if let Ok(text) = std::str::from_utf8(&body_bytes) {
                    debug!(
                        response_preview = &text[..text.len().min(500)],
                        "Response body preview"
                    );
                }
                return error_response(
                    StatusCode::BAD_GATEWAY,
                    "upstream_error",
                    "Failed to parse upstream response",
                );
            }
        };
        // Record usage on the span and fold it into the running total.
        if let Some(ref usage) = chat_response.usage {
            model_call_span.record("llm.token_count.input", usage.prompt_tokens as i64);
            model_call_span.record("llm.token_count.output", usage.completion_tokens as i64);
            accumulated_usage = Some(match accumulated_usage.take() {
                None => usage.clone(),
                Some(prev) => Usage {
                    prompt_tokens: prev.prompt_tokens + usage.prompt_tokens,
                    completion_tokens: prev.completion_tokens + usage.completion_tokens,
                    total_tokens: prev.total_tokens + usage.total_tokens,
                    prompt_tokens_details: None,
                    completion_tokens_details: None,
                },
            });
        }
        drop(model_call_span);
        if OpenResponsesAdapter::requires_tool_action(&chat_response) && iteration < max_iterations
        {
            debug!("Response requires tool action");
            let tool_calls = OpenResponsesAdapter::extract_tool_calls(&chat_response);
            debug!(tool_count = tool_calls.len(), "Extracted tool calls");
            let results = adapter
                .execute_tool_calls(&tool_calls, &server_tool_names, &ctx)
                .await;
            if adapter.has_unhandled_tools(&results) {
                // Tool calls we cannot execute locally are returned to the
                // client as-is, along with usage accumulated so far.
                debug!("Some tools are unhandled, returning to client");
                let aggregate_response_usage = accumulated_usage.as_ref().map(|u| ResponseUsage {
                    input_tokens: u.prompt_tokens,
                    output_tokens: u.completion_tokens,
                    total_tokens: u.total_tokens,
                    input_tokens_details: InputTokensDetails { cached_tokens: 0 },
                    output_tokens_details: OutputTokensDetails {
                        reasoning_tokens: 0,
                    },
                });
                let responses_response = adapter.to_responses_response_with_usage(
                    &chat_response,
                    &request,
                    aggregate_response_usage,
                );
                let response_bytes = match serde_json::to_vec(&responses_response) {
                    Ok(bytes) => bytes,
                    Err(e) => {
                        error!(error = %e, "Failed to serialize responses response");
                        return error_response(
                            StatusCode::INTERNAL_SERVER_ERROR,
                            "server_error",
                            "Failed to serialize response",
                        );
                    }
                };
                return Response::builder()
                    .status(parts.status)
                    .header("content-type", "application/json")
                    .body(Body::from(response_bytes))
                    .unwrap_or_else(|_| {
                        error_response(
                            StatusCode::INTERNAL_SERVER_ERROR,
                            "server_error",
                            "Failed to build response",
                        )
                    });
            }
            debug!("All tools handled, continuing loop");
            // Append the assistant's tool-call message plus tool results to
            // the conversation, then loop for the next model call.
            if let Some(choice) = chat_response.choices.first() {
                OpenResponsesAdapter::add_tool_results_to_messages(
                    &mut chat_request.messages,
                    &choice.message,
                    &results,
                );
            }
            continue;
        }
        // Final answer (or iteration cap reached): convert to a Responses
        // payload with aggregate usage and return it.
        let aggregate_response_usage = accumulated_usage.as_ref().map(|u| ResponseUsage {
            input_tokens: u.prompt_tokens,
            output_tokens: u.completion_tokens,
            total_tokens: u.total_tokens,
            input_tokens_details: InputTokensDetails { cached_tokens: 0 },
            output_tokens_details: OutputTokensDetails {
                reasoning_tokens: 0,
            },
        });
        let responses_response = adapter.to_responses_response_with_usage(
            &chat_response,
            &request,
            aggregate_response_usage,
        );
        info!(
            response_id = %responses_response.id,
            status = ?responses_response.status,
            output_items = responses_response.output.len(),
            iterations = iteration,
            "Adapter conversion complete"
        );
        let response_bytes = match serde_json::to_vec(&responses_response) {
            Ok(bytes) => bytes,
            Err(e) => {
                error!(error = %e, "Failed to serialize responses response");
                return error_response(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "server_error",
                    "Failed to serialize response",
                );
            }
        };
        return Response::builder()
            .status(parts.status)
            .header("content-type", "application/json")
            .body(Body::from(response_bytes))
            .unwrap_or_else(|_| {
                error_response(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "server_error",
                    "Failed to build response",
                )
            });
    }
}
/// Streaming variant of the adapter flow: proxies the upstream SSE stream
/// through `StreamingState`, and when an upstream turn finishes with
/// `finish_reason == "tool_calls"`, executes the server-side tools and
/// re-invokes the model with their results — all within a single outgoing
/// SSE response to the client.
async fn handle_streaming_adapter_request<T: HttpClient + Clone + Send + Sync + 'static>(
    state: AppState<T>,
    headers: HeaderMap,
    request: ResponsesRequest,
    mut chat_request: ChatCompletionRequest,
    server_tool_names: HashSet<String>,
    ctx: RequestContext,
) -> Response {
    // Force streaming upstream regardless of what the body said.
    chat_request.stream = Some(true);
    let body_bytes = match serde_json::to_vec(&chat_request) {
        Ok(bytes) => bytes,
        Err(e) => {
            error!(error = %e, "Failed to serialize streaming chat completions request");
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "server_error",
                "Failed to process request",
            );
        }
    };
    let response = forward_request_raw(
        state.clone(),
        headers.clone(),
        "/chat/completions",
        body_bytes,
    )
    .await;
    if !response.status().is_success() {
        return response;
    }
    // If the upstream did not actually stream, pass its response through
    // untouched rather than trying to re-interpret it as SSE.
    let content_type = response
        .headers()
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("");
    if !is_sse_content_type(content_type) {
        warn!(
            content_type = content_type,
            "Expected SSE stream but got different content type"
        );
        return response;
    }
    let (parts, first_body) = response.into_parts();
    let adapter =
        OpenResponsesAdapter::new(state.response_store.clone(), state.tool_executor.clone());
    let max_iterations = adapter.max_iterations();
    // The generated stream owns the tool loop: each outer-loop pass consumes
    // one upstream SSE body; a tool_calls finish triggers the next request.
    let transformed_stream = async_stream::stream! {
        let mut streaming_state = StreamingState::new(&request);
        // Conversation so far; grows with tool-call/result messages.
        let mut messages = chat_request.messages.clone();
        let mut current_body = Some(first_body);
        let mut iteration = 0u32;
        while let Some(body) = current_body.take() {
            iteration += 1;
            let byte_stream = body.into_data_stream();
            // SSE events may be split across network chunks; buffer until a
            // blank-line event terminator ("\n\n") is seen.
            let mut buffer = String::new();
            let mut last_finish_reason: Option<String> = None;
            let pinned_stream = std::pin::pin!(byte_stream);
            let mut stream = pinned_stream;
            while let Some(chunk_result) = stream.next().await {
                match chunk_result {
                    Ok(bytes) => {
                        // Non-UTF-8 chunks are skipped; only text is expected.
                        if let Ok(text) = std::str::from_utf8(&bytes) {
                            buffer.push_str(text);
                        } else {
                            continue;
                        }
                        // Drain every complete event currently in the buffer.
                        while let Some(event_end) = buffer.find("\n\n") {
                            let event_text = buffer[..event_end].to_string();
                            buffer = buffer[event_end + 2..].to_string();
                            for line in event_text.lines() {
                                if let Some(data) = line.strip_prefix("data: ") {
                                    if data.trim() == "[DONE]" {
                                        trace!("Received [DONE] marker");
                                        continue;
                                    }
                                    if let Some(chunk) = parse_chat_chunk(data) {
                                        // Remember the last finish_reason so the
                                        // tool loop can decide whether to continue.
                                        for choice in &chunk.choices {
                                            if let Some(ref reason) = choice.finish_reason {
                                                last_finish_reason = Some(reason.clone());
                                            }
                                        }
                                        trace!(chunk_id = %chunk.id, "Processing chat chunk");
                                        // Translate chat chunks into Responses
                                        // streaming events for the client.
                                        let events = streaming_state.process_chunk(&chunk);
                                        for event in events {
                                            yield Ok::<_, std::io::Error>(event.to_sse().into_bytes());
                                        }
                                    }
                                }
                            }
                        }
                    }
                    Err(e) => {
                        error!(error = %e, "Error reading stream");
                        break;
                    }
                }
            }
            if last_finish_reason.as_deref() == Some("tool_calls") && iteration < max_iterations {
                let pending = streaming_state.extract_tool_calls();
                if pending.is_empty() {
                    debug!("finish_reason was tool_calls but no tool calls found");
                    break;
                }
                debug!(
                    iteration,
                    tool_count = pending.len(),
                    "Streaming tool loop: executing tools"
                );
                let tool_calls: Vec<PendingToolCall> = pending
                    .into_iter()
                    .map(|(id, name, args)| PendingToolCall {
                        id,
                        name,
                        arguments: args,
                    })
                    .collect();
                let results = adapter.execute_tool_calls(&tool_calls, &server_tool_names, &ctx).await;
                if adapter.has_unhandled_tools(&results) {
                    // Unlike the non-streaming path, we cannot hand unhandled
                    // tools back mid-stream; just stop looping.
                    debug!("Unhandled tools in streaming mode, stopping loop");
                    break;
                }
                // Reconstruct the assistant message carrying the tool calls
                // so the next request contains the full conversation.
                let assistant_msg = ChatMessage {
                    role: "assistant".to_string(),
                    content: None,
                    name: None,
                    tool_calls: Some(
                        tool_calls
                            .iter()
                            .map(|tc| ToolCall {
                                id: tc.id.clone(),
                                call_type: "function".to_string(),
                                function: FunctionCall {
                                    name: tc.name.clone(),
                                    arguments: tc.arguments.clone(),
                                },
                            })
                            .collect(),
                    ),
                    tool_call_id: None,
                    reasoning: None,
                    reasoning_content: None,
                    reasoning_details: None,
                    extra: None,
                };
                OpenResponsesAdapter::add_tool_results_to_messages(&mut messages, &assistant_msg, &results);
                streaming_state.prepare_next_iteration();
                let mut next_request = chat_request.clone();
                next_request.messages = messages.clone();
                let body_bytes = match serde_json::to_vec(&next_request) {
                    Ok(bytes) => bytes,
                    Err(e) => {
                        error!(error = %e, "Failed to serialize next tool-loop iteration");
                        break;
                    }
                };
                let next_response = forward_request_raw(
                    state.clone(),
                    headers.clone(),
                    "/chat/completions",
                    body_bytes,
                )
                .await;
                if !next_response.status().is_success() {
                    error!(
                        status = %next_response.status(),
                        "Upstream error during streaming tool loop"
                    );
                    break;
                }
                // Feed the next upstream SSE body into the outer loop.
                let (_, next_body) = next_response.into_parts();
                current_body = Some(next_body);
            } else {
                break;
            }
        }
        // Emit terminal events (response.completed etc.) after the loop ends.
        let done_events = streaming_state.finalize();
        for event in done_events {
            yield Ok::<_, std::io::Error>(event.to_sse().into_bytes());
        }
    };
    Response::builder()
        .status(parts.status)
        .header("content-type", "text/event-stream")
        .header("cache-control", "no-cache")
        .header("connection", "keep-alive")
        .body(Body::from_stream(transformed_stream))
        .unwrap_or_else(|_| {
            error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "server_error",
                "Failed to build streaming response",
            )
        })
}
async fn forward_request_raw<T: HttpClient + Clone + Send + Sync + 'static>(
state: AppState<T>,
mut headers: HeaderMap,
path: &str,
body_bytes: Vec<u8>,
) -> Response {
headers.insert(
axum::http::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
headers.remove(axum::http::header::ACCEPT_ENCODING);
let request = match Request::builder()
.method("POST")
.uri(path)
.body(Body::from(body_bytes))
{
Ok(mut req) => {
*req.headers_mut() = headers;
req
}
Err(e) => {
error!(error = %e, "Failed to build upstream request");
return error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"server_error",
"Failed to process request",
);
}
};
match target_message_handler(State(state), request).await {
Ok(response) => response,
Err(err) => err.into_response(),
}
}
async fn forward_request<T: HttpClient + Clone + Send + Sync + 'static>(
state: AppState<T>,
mut headers: HeaderMap,
path: &str,
body_bytes: Vec<u8>,
) -> ForwardResult {
headers.insert(
axum::http::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
headers.remove(axum::http::header::ACCEPT_ENCODING);
let mut request_builder = Request::builder().method("POST").uri(path);
for (name, value) in headers.iter() {
request_builder = request_builder.header(name, value);
}
let request = match request_builder.body(Body::from(body_bytes)) {
Ok(req) => req,
Err(e) => {
error!(error = %e, "Failed to build request");
let response = error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"server_error",
"Failed to build request",
);
return ForwardResult {
response,
trusted: false,
internal_error: true,
};
}
};
let (response, internal_error) = match target_message_handler(State(state), request).await {
Ok(response) => (response, false),
Err(err) => (err.into_response(), true),
};
let trusted = response
.extensions()
.get::<ResolvedTrust>()
.map(|t| t.0)
.unwrap_or(false);
ForwardResult {
response,
trusted,
internal_error,
}
}
fn merge_server_tools(
chat_request: &mut ChatCompletionRequest,
server_tools: &[crate::traits::ToolSchema],
) {
let chat_tools: Vec<ChatTool> = server_tools
.iter()
.map(|ts| {
let mut params = ts.parameters.clone();
if let Some(obj) = params.as_object_mut()
&& !obj.contains_key("additionalProperties")
{
obj.insert(
"additionalProperties".to_string(),
serde_json::Value::Bool(false),
);
}
ChatTool {
tool_type: "function".to_string(),
function: FunctionDefinition {
name: ts.name.clone(),
description: Some(ts.description.clone()),
parameters: Some(params),
strict: Some(ts.strict),
},
}
})
.collect();
match chat_request.tools {
Some(ref mut existing) => existing.extend(chat_tools),
None => chat_request.tools = Some(chat_tools),
}
}
/// Builds an OpenAI-style error response: a JSON body of the shape
/// `{"error": {"message", "type", "param", "code"}}` with the given status.
fn error_response(status: StatusCode, error_type: &str, message: &str) -> Response {
    let payload = json!({
        "error": {
            "message": message,
            "type": error_type,
            "param": null,
            "code": null
        }
    });
    (status, Json(payload)).into_response()
}
/// Re-serializes a non-streaming chat completion through the strict schema:
/// buffers the body, normalizes provider quirks on the raw JSON, coerces it
/// into `ChatCompletionResponse`, restores the client's model name, and
/// writes the sanitized bytes back with a corrected Content-Length.
async fn sanitize_chat_response(mut response: Response, original_model: String) -> Response {
    // Buffer the complete body (this path is non-streaming by construction).
    let body_bytes = match axum::body::to_bytes(std::mem::take(response.body_mut()), usize::MAX)
        .await
    {
        Ok(bytes) => {
            debug!(
                bytes_read = bytes.len(),
                body_sample = ?String::from_utf8_lossy(&bytes).chars().take(100).collect::<String>(),
                "Read upstream response body for sanitization"
            );
            bytes
        }
        Err(e) => {
            error!(error = %e, "Failed to read response body for sanitization");
            return error_response(
                StatusCode::BAD_GATEWAY,
                "api_error",
                "Failed to read upstream response",
            );
        }
    };
    // Parse leniently as a JSON value first so provider-specific fields can
    // be normalized before the strict typed pass below.
    let mut raw_response: serde_json::Value = match serde_json::from_slice(&body_bytes) {
        Ok(resp) => resp,
        Err(e) => {
            error!(
                error = %e,
                body_sample = ?String::from_utf8_lossy(&body_bytes).chars().take(200).collect::<String>(),
                "Failed to deserialize chat response from provider, returning standard error"
            );
            return error_response(StatusCode::BAD_GATEWAY, "api_error", "Bad gateway");
        }
    };
    normalize_chat_completion_response_value(&mut raw_response, &original_model);
    let mut chat_response: ChatCompletionResponse = match serde_json::from_value(raw_response) {
        Ok(resp) => resp,
        Err(e) => {
            error!(
                error = %e,
                "Failed to coerce chat response into strict schema"
            );
            return error_response(StatusCode::BAD_GATEWAY, "api_error", "Bad gateway");
        }
    };
    // Always report the model name the client originally asked for.
    chat_response.model = original_model;
    match serde_json::to_vec(&chat_response) {
        Ok(sanitized_bytes) => {
            let content_length = sanitized_bytes.len();
            debug!(
                content_length = content_length,
                body_sample = ?String::from_utf8_lossy(&sanitized_bytes).chars().take(100).collect::<String>(),
                "Setting sanitized response body"
            );
            *response.body_mut() = Body::from(sanitized_bytes);
            // The body length changed: drop any Transfer-Encoding and set an
            // accurate Content-Length.
            response
                .headers_mut()
                .remove(axum::http::header::TRANSFER_ENCODING);
            response.headers_mut().insert(
                header::CONTENT_LENGTH,
                header::HeaderValue::from(content_length),
            );
            debug!("Sanitized non-streaming chat completion response");
            response
        }
        Err(e) => {
            error!(
                error = %e,
                "Failed to serialize sanitized chat response, returning standard error"
            );
            error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "api_error",
                "Internal server error",
            )
        }
    }
}
/// Wraps a streaming chat completion body in a sanitizing transform: every
/// SSE `data:` line is parsed, normalized, coerced into the strict
/// `ChatCompletionChunk` schema, and re-serialized with the client's model
/// name. Malformed chunks terminate the stream unless they are recognizable
/// provider error objects (which are forwarded as error events).
async fn sanitize_streaming_chat_response(
    mut response: Response,
    original_model: String,
    trusted: bool,
) -> Response {
    let body_stream =
        http_body_util::BodyExt::into_data_stream(std::mem::take(response.body_mut()));
    // Re-chunk the raw byte stream on SSE event boundaries so each closure
    // invocation sees whole lines.
    let buffered_stream = crate::sse::SseBufferedStream::new(body_stream);
    // Fallback id used when the provider omits a chunk id.
    let stream_fallback_id = generated_chat_completion_id();
    let sanitized_stream = buffered_stream.map(move |chunk_result| {
        match chunk_result {
            Ok(chunk) => {
                let chunk_str = String::from_utf8_lossy(&chunk);
                let mut sanitized_lines = Vec::new();
                for line in chunk_str.lines() {
                    if let Some(data_part) = line.strip_prefix("data: ") {
                        // Pass the terminator marker through untouched.
                        if data_part.trim() == "[DONE]" {
                            sanitized_lines.push(line.to_string());
                            continue;
                        }
                        let mut raw_chunk: serde_json::Value = match serde_json::from_str(data_part)
                        {
                            Ok(chunk) => chunk,
                            Err(e) => {
                                // Not JSON at all — maybe an inline provider
                                // error object; otherwise kill the stream.
                                if let Some(error_event) = try_format_sse_error(data_part, trusted) {
                                    error!(
                                        data_sample = ?data_part.chars().take(200).collect::<String>(),
                                        "Provider returned error object inside SSE stream, forwarding as error event"
                                    );
                                    sanitized_lines.push(error_event);
                                    continue;
                                }
                                error!(
                                    error = %e,
                                    data_sample = ?data_part.chars().take(200).collect::<String>(),
                                    "Failed to parse SSE data line from provider, terminating stream"
                                );
                                return Err(std::io::Error::other(
                                    "Malformed SSE data from provider",
                                ));
                            }
                        };
                        // Normalize provider quirks before the strict pass.
                        normalize_chat_completion_chunk_value(
                            &mut raw_chunk,
                            &original_model,
                            &stream_fallback_id,
                        );
                        match serde_json::from_value::<ChatCompletionChunk>(raw_chunk) {
                            Ok(mut chunk_data) => {
                                // Always report the client's model name.
                                chunk_data.model = original_model.clone();
                                match serde_json::to_string(&chunk_data) {
                                    Ok(sanitized_json) => {
                                        sanitized_lines.push(format!("data: {}", sanitized_json));
                                    }
                                    Err(e) => {
                                        error!(error = %e, "Failed to serialize chunk, terminating stream");
                                        return Err(std::io::Error::other(
                                            "Failed to serialize chunk",
                                        ));
                                    }
                                }
                            }
                            Err(e) => {
                                // Valid JSON but not a chunk — same error-
                                // object fallback as above.
                                if let Some(error_event) = try_format_sse_error(data_part, trusted) {
                                    error!(
                                        data_sample = ?data_part.chars().take(200).collect::<String>(),
                                        "Provider returned error object inside SSE stream, forwarding as error event"
                                    );
                                    sanitized_lines.push(error_event);
                                    continue;
                                }
                                error!(
                                    error = %e,
                                    data_sample = ?data_part.chars().take(200).collect::<String>(),
                                    "Failed to parse SSE data line from provider, terminating stream"
                                );
                                return Err(std::io::Error::other(
                                    "Malformed SSE data from provider",
                                ));
                            }
                        }
                    } else if line.is_empty() {
                        // Preserve blank lines (SSE event separators).
                        sanitized_lines.push(String::new());
                    }
                }
                let mut sanitized_chunk = sanitized_lines.join("\n");
                // `lines()` drops trailing newlines; restore exactly as many
                // as the input chunk had so event framing stays intact.
                let input_trailing = chunk_str.chars().rev().take_while(|&c| c == '\n').count();
                let output_trailing = sanitized_chunk
                    .chars()
                    .rev()
                    .take_while(|&c| c == '\n')
                    .count();
                for _ in output_trailing..input_trailing {
                    sanitized_chunk.push('\n');
                }
                Ok::<_, std::io::Error>(axum::body::Bytes::from(sanitized_chunk))
            }
            Err(e) => {
                error!(error = %e, "Stream error");
                Err(std::io::Error::other(e))
            }
        }
    });
    *response.body_mut() = Body::from_stream(sanitized_stream);
    // Length is now unknown; let the transport use chunked transfer.
    response.headers_mut().remove(header::CONTENT_LENGTH);
    debug!("Set up streaming chat completion response sanitization");
    response
}
/// Re-serializes a non-streaming legacy completions response through the
/// strict schema: buffers the body, normalizes provider quirks, coerces it
/// into `CompletionResponse`, restores the client's model name, and writes
/// the sanitized bytes back with a corrected Content-Length.
async fn sanitize_completions_response(mut response: Response, original_model: String) -> Response {
    let body_bytes =
        match axum::body::to_bytes(std::mem::take(response.body_mut()), usize::MAX).await {
            Ok(bytes) => bytes,
            Err(e) => {
                error!(error = %e, "Failed to read completions response body");
                return error_response(
                    StatusCode::BAD_GATEWAY,
                    "api_error",
                    "Failed to read upstream response",
                );
            }
        };
    // Parse leniently as a JSON value first so provider-specific fields can
    // be normalized before the strict typed pass below.
    let mut raw_response: serde_json::Value = match serde_json::from_slice(&body_bytes) {
        Ok(resp) => resp,
        Err(e) => {
            error!(
                error = %e,
                body_sample = ?String::from_utf8_lossy(&body_bytes).chars().take(200).collect::<String>(),
                "Failed to deserialize completions response from provider, returning standard error"
            );
            return error_response(StatusCode::BAD_GATEWAY, "api_error", "Bad gateway");
        }
    };
    normalize_completion_response_value(&mut raw_response, &original_model);
    let mut completion_response: CompletionResponse = match serde_json::from_value(raw_response) {
        Ok(resp) => resp,
        Err(e) => {
            error!(
                error = %e,
                "Failed to coerce completions response into strict schema"
            );
            return error_response(StatusCode::BAD_GATEWAY, "api_error", "Bad gateway");
        }
    };
    // Always report the model name the client originally asked for.
    completion_response.model = original_model;
    match serde_json::to_vec(&completion_response) {
        Ok(sanitized_bytes) => {
            let content_length = sanitized_bytes.len();
            *response.body_mut() = Body::from(sanitized_bytes);
            // The body length changed: drop any Transfer-Encoding and set an
            // accurate Content-Length.
            response
                .headers_mut()
                .remove(axum::http::header::TRANSFER_ENCODING);
            response.headers_mut().insert(
                header::CONTENT_LENGTH,
                header::HeaderValue::from(content_length),
            );
            debug!("Sanitized non-streaming completions response");
            response
        }
        Err(e) => {
            error!(error = %e, "Failed to serialize sanitized completions response");
            error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "api_error",
                "Internal server error",
            )
        }
    }
}
/// Wraps a streaming completions body in a sanitizing stream.
///
/// Each SSE `data:` line is parsed, normalized, and re-validated against the
/// strict `CompletionChunk` schema so unknown provider fields are dropped and
/// the client-facing model name is restored. `[DONE]` markers pass through
/// untouched; provider error objects are forwarded as sanitized error events
/// via `try_format_sse_error`; any other malformed data terminates the stream
/// with an io error.
async fn sanitize_streaming_completions_response(
    mut response: Response,
    original_model: String,
    trusted: bool,
) -> Response {
    let body_stream =
        http_body_util::BodyExt::into_data_stream(std::mem::take(response.body_mut()));
    // Re-chunk the byte stream on SSE event boundaries so each map() call
    // sees whole `data:` lines.
    let buffered_stream = crate::sse::SseBufferedStream::new(body_stream);
    // One fallback id shared by every chunk that arrives without an `id`, so
    // the whole stream agrees on a single completion id.
    let stream_fallback_id = generated_completion_id();
    let sanitized_stream = buffered_stream.map(move |chunk_result| {
        match chunk_result {
            Ok(chunk) => {
                let chunk_str = String::from_utf8_lossy(&chunk);
                let mut sanitized_lines = Vec::new();
                for line in chunk_str.lines() {
                    if let Some(data_part) = line.strip_prefix("data: ") {
                        // Forward the stream terminator verbatim.
                        if data_part.trim() == "[DONE]" {
                            sanitized_lines.push(line.to_string());
                            continue;
                        }
                        let mut raw_chunk: serde_json::Value = match serde_json::from_str(data_part)
                        {
                            Ok(chunk) => chunk,
                            Err(e) => {
                                // NOTE(review): data_part failed to parse as JSON
                                // here, so try_format_sse_error (which re-parses
                                // the same text) presumably always returns None
                                // in this branch — confirm whether it is needed.
                                if let Some(error_event) = try_format_sse_error(data_part, trusted) {
                                    error!(
                                        data_sample = ?data_part.chars().take(200).collect::<String>(),
                                        "Provider returned error object inside SSE stream, forwarding as error event"
                                    );
                                    sanitized_lines.push(error_event);
                                    continue;
                                }
                                error!(
                                    error = %e,
                                    raw = %data_part.chars().take(200).collect::<String>(),
                                    "Failed to parse completions chunk, terminating stream"
                                );
                                return Err(std::io::Error::other(
                                    "Malformed SSE data from provider",
                                ));
                            }
                        };
                        // Backfill non-critical fields before strict validation.
                        normalize_completion_chunk_value(
                            &mut raw_chunk,
                            &original_model,
                            &stream_fallback_id,
                        );
                        match serde_json::from_value::<CompletionChunk>(raw_chunk) {
                            Ok(mut chunk_data) => {
                                // Never leak the onwards model name.
                                chunk_data.model = original_model.clone();
                                match serde_json::to_string(&chunk_data) {
                                    Ok(json) => sanitized_lines.push(format!("data: {json}")),
                                    Err(e) => {
                                        error!(error = %e, "Failed to serialize completions chunk, terminating stream");
                                        return Err(std::io::Error::other(
                                            "Failed to serialize completions chunk",
                                        ));
                                    }
                                }
                            }
                            Err(e) => {
                                // Valid JSON that fails the strict schema may
                                // still be a provider error object worth
                                // forwarding as a sanitized error event.
                                if let Some(error_event) = try_format_sse_error(data_part, trusted) {
                                    error!(
                                        data_sample = ?data_part.chars().take(200).collect::<String>(),
                                        "Provider returned error object inside SSE stream, forwarding as error event"
                                    );
                                    sanitized_lines.push(error_event);
                                    continue;
                                }
                                error!(
                                    error = %e,
                                    raw = %data_part.chars().take(200).collect::<String>(),
                                    "Failed to parse completions chunk, terminating stream"
                                );
                                return Err(std::io::Error::other(
                                    "Malformed SSE data from provider",
                                ));
                            }
                        }
                    } else if line.is_empty() {
                        // Keep blank lines: they delimit SSE events.
                        sanitized_lines.push(String::new());
                    }
                }
                // `lines()` swallows trailing newlines; restore exactly as many
                // as the input chunk carried so SSE event framing survives.
                let mut sanitized_chunk = sanitized_lines.join("\n");
                let input_trailing =
                    chunk_str.chars().rev().take_while(|&c| c == '\n').count();
                let output_trailing = sanitized_chunk
                    .chars()
                    .rev()
                    .take_while(|&c| c == '\n')
                    .count();
                for _ in output_trailing..input_trailing {
                    sanitized_chunk.push('\n');
                }
                Ok::<_, std::io::Error>(axum::body::Bytes::from(sanitized_chunk))
            }
            Err(e) => {
                error!(error = %e, "Stream error");
                Err(std::io::Error::other(e))
            }
        }
    });
    *response.body_mut() = Body::from_stream(sanitized_stream);
    // Body length is no longer known up front.
    response.headers_mut().remove(header::CONTENT_LENGTH);
    debug!("Set up streaming completions response sanitization");
    response
}
/// Strips unknown provider fields from a non-streaming embeddings response
/// and rewrites the model name back to the one the client requested.
///
/// Read/parse/serialize failures are logged and replaced with a generic
/// sanitized error so provider internals never reach the client.
async fn sanitize_embeddings_response(mut response: Response, original_model: String) -> Response {
    let upstream_body = std::mem::take(response.body_mut());
    let body_bytes = match axum::body::to_bytes(upstream_body, usize::MAX).await {
        Ok(bytes) => bytes,
        Err(e) => {
            error!(error = %e, "Failed to read embeddings response body");
            return error_response(
                StatusCode::BAD_GATEWAY,
                "api_error",
                "Failed to read upstream response",
            );
        }
    };
    // Deserializing straight into the strict schema drops unknown fields.
    let mut embeddings_response = match serde_json::from_slice::<EmbeddingsResponse>(&body_bytes) {
        Ok(resp) => resp,
        Err(e) => {
            error!(
                error = %e,
                body_sample = ?String::from_utf8_lossy(&body_bytes).chars().take(200).collect::<String>(),
                "Failed to deserialize embeddings response from provider, returning standard error"
            );
            return error_response(StatusCode::BAD_GATEWAY, "api_error", "Bad gateway");
        }
    };
    // Hide any internal/onwards model name from the client.
    embeddings_response.model = original_model;
    let sanitized_bytes = match serde_json::to_vec(&embeddings_response) {
        Ok(bytes) => bytes,
        Err(e) => {
            error!(
                error = %e,
                "Failed to serialize sanitized embeddings response, returning standard error"
            );
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "api_error",
                "Internal server error",
            );
        }
    };
    // Replace the body and repair the framing headers to match it.
    let content_length = sanitized_bytes.len();
    *response.body_mut() = Body::from(sanitized_bytes);
    response
        .headers_mut()
        .remove(axum::http::header::TRANSFER_ENCODING);
    response.headers_mut().insert(
        header::CONTENT_LENGTH,
        header::HeaderValue::from(content_length),
    );
    debug!("Sanitized embeddings response");
    response
}
/// Strict-schema pass over a non-streaming Responses API body: normalize,
/// re-validate, drop unknown provider fields, and restore the client-facing
/// model name before re-serializing.
async fn sanitize_responses_response(mut response: Response, original_model: String) -> Response {
    let upstream_body = std::mem::take(response.body_mut());
    let body_bytes = match axum::body::to_bytes(upstream_body, usize::MAX).await {
        Ok(bytes) => bytes,
        Err(e) => {
            error!(error = %e, "Failed to read responses API response body");
            return error_response(
                StatusCode::BAD_GATEWAY,
                "api_error",
                "Failed to read upstream response",
            );
        }
    };
    // Loose parse first so normalization can backfill non-critical fields.
    let mut raw_response = match serde_json::from_slice::<serde_json::Value>(&body_bytes) {
        Ok(value) => value,
        Err(e) => {
            error!(
                error = %e,
                body_sample = ?String::from_utf8_lossy(&body_bytes).chars().take(200).collect::<String>(),
                "Failed to deserialize responses API response from provider, returning standard error"
            );
            return error_response(StatusCode::BAD_GATEWAY, "api_error", "Bad gateway");
        }
    };
    normalize_responses_response_value(&mut raw_response, &original_model);
    // The strict schema drops anything it does not recognize.
    let mut responses_response = match serde_json::from_value::<ResponsesResponse>(raw_response) {
        Ok(resp) => resp,
        Err(e) => {
            error!(error = %e, "Failed to coerce responses API response into strict schema");
            return error_response(StatusCode::BAD_GATEWAY, "api_error", "Bad gateway");
        }
    };
    // Never expose the onwards model name.
    responses_response.model = original_model;
    let sanitized_bytes = match serde_json::to_vec(&responses_response) {
        Ok(bytes) => bytes,
        Err(e) => {
            error!(
                error = %e,
                "Failed to serialize sanitized responses API response, returning standard error"
            );
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "api_error",
                "Internal server error",
            );
        }
    };
    let content_length = sanitized_bytes.len();
    *response.body_mut() = Body::from(sanitized_bytes);
    response
        .headers_mut()
        .remove(axum::http::header::TRANSFER_ENCODING);
    response.headers_mut().insert(
        header::CONTENT_LENGTH,
        header::HeaderValue::from(content_length),
    );
    debug!("Sanitized responses API response");
    response
}
async fn sanitize_streaming_responses_response(
mut response: Response,
original_model: String,
trusted: bool,
) -> Response {
let body_stream =
http_body_util::BodyExt::into_data_stream(std::mem::take(response.body_mut()));
let buffered_stream = crate::sse::SseBufferedStream::new(body_stream);
let stream_fallback_response_id = generated_response_id();
let sanitized_stream = buffered_stream.map(move |chunk_result| {
match chunk_result {
Ok(chunk) => {
let chunk_str = String::from_utf8_lossy(&chunk);
let mut sanitized_lines = Vec::new();
for line in chunk_str.lines() {
if let Some(data_part) = line.strip_prefix("data: ") {
if data_part.trim() == "[DONE]" {
sanitized_lines.push(line.to_string());
continue;
}
let mut raw_event: serde_json::Value = match serde_json::from_str(data_part)
{
Ok(event) => event,
Err(e) => {
if let Some(error_event) = try_format_sse_error(data_part, trusted) {
error!(
data_sample = ?data_part.chars().take(200).collect::<String>(),
"Provider returned error object inside SSE stream, forwarding as error event"
);
sanitized_lines.push(error_event);
continue;
}
error!(
error = %e,
data_sample = ?data_part.chars().take(200).collect::<String>(),
"Failed to parse responses SSE data line from provider, terminating stream"
);
return Err(std::io::Error::other(
"Malformed SSE data from provider",
));
}
};
normalize_responses_streaming_event_value(
&mut raw_event,
&original_model,
&stream_fallback_response_id,
);
match serde_json::from_value::<ResponsesStreamingEvent>(raw_event) {
Ok(mut event) => {
if let Some(ref mut response) = event.response {
response.model = original_model.clone();
}
match serde_json::to_string(&event) {
Ok(sanitized_json) => {
sanitized_lines.push(format!("data: {}", sanitized_json));
}
Err(e) => {
error!(error = %e, "Failed to serialize responses chunk, terminating stream");
return Err(std::io::Error::other(
"Failed to serialize chunk",
));
}
}
}
Err(e) => {
if let Some(error_event) = try_format_sse_error(data_part, trusted) {
error!(
data_sample = ?data_part.chars().take(200).collect::<String>(),
"Provider returned error object inside SSE stream, forwarding as error event"
);
sanitized_lines.push(error_event);
continue;
}
error!(
error = %e,
data_sample = ?data_part.chars().take(200).collect::<String>(),
"Failed to parse responses SSE data line from provider, terminating stream"
);
return Err(std::io::Error::other(
"Malformed SSE data from provider",
));
}
}
} else if line.is_empty() {
sanitized_lines.push(String::new());
}
}
let mut sanitized_chunk = sanitized_lines.join("\n");
let input_trailing = chunk_str.chars().rev().take_while(|&c| c == '\n').count();
let output_trailing = sanitized_chunk
.chars()
.rev()
.take_while(|&c| c == '\n')
.count();
for _ in output_trailing..input_trailing {
sanitized_chunk.push('\n');
}
Ok::<_, std::io::Error>(axum::body::Bytes::from(sanitized_chunk))
}
Err(e) => {
error!(error = %e, "Stream error while sanitizing responses");
Err(std::io::Error::other("Stream error"))
}
}
});
*response.body_mut() = Body::from_stream(sanitized_stream);
response.headers_mut().remove(header::CONTENT_LENGTH);
debug!("Set up streaming responses response sanitization");
response
}
/// If `data_part` parses as JSON containing an `error` object, formats it as
/// an SSE `data:` line suitable for forwarding to the client.
///
/// Trusted targets receive the provider's error object verbatim. Untrusted
/// targets receive a sanitized OpenAI-style error derived only from the
/// numeric status code; a missing, non-numeric, or out-of-range code falls
/// back to 500 so provider internals never leak.
///
/// Returns `None` when the payload is not JSON or has no `error` key.
fn try_format_sse_error(data_part: &str, trusted: bool) -> Option<String> {
    let value = serde_json::from_str::<serde_json::Value>(data_part).ok()?;
    let error_obj = value.get("error")?;
    if trusted {
        let sanitized = serde_json::to_string(&value).ok()?;
        return Some(format!("data: {sanitized}"));
    }
    // `as u16` would silently truncate codes above u16::MAX (e.g. provider
    // internal codes like 100000) into a meaningless status; treat anything
    // that is not a valid u16 exactly like a missing code.
    let code = match error_obj
        .get("code")
        .and_then(|c| c.as_u64())
        .and_then(|c| u16::try_from(c).ok())
    {
        Some(c) => c,
        None => {
            warn!(
                code = ?error_obj.get("code"),
                "Provider error object has non-numeric or missing code, defaulting to 500"
            );
            500
        }
    };
    let (error_type, message) = sanitized_error_for_status(code);
    let sanitized = json!({
        "error": {
            "message": message,
            "type": error_type,
            "param": null,
            "code": code,
        }
    });
    Some(format!("data: {sanitized}"))
}
/// Logs the upstream error body for operators and replaces it with the
/// standard OpenAI-style error for the same status code; the third-party
/// error text is never forwarded to the client.
async fn sanitize_error_response(mut response: Response) -> Response {
    let status = response.status();
    let upstream_body = std::mem::take(response.body_mut());
    let body_bytes = match axum::body::to_bytes(upstream_body, usize::MAX).await {
        Ok(bytes) => bytes,
        Err(e) => {
            error!(error = %e, "Failed to read error response body");
            return standard_error_response(status);
        }
    };
    // Record the provider's error for diagnostics only.
    error!(
        status = %status,
        third_party_error = ?String::from_utf8_lossy(&body_bytes),
        "Third-party error response (logged, not forwarded)"
    );
    standard_error_response(status)
}
/// Maps an upstream HTTP status code to an OpenAI-style
/// `(error_type, message)` pair that reveals nothing provider-specific.
fn sanitized_error_for_status(status: u16) -> (&'static str, &'static str) {
    // Generic message per recognized status; anything else gets a catch-all.
    let message = match status {
        400 => "Invalid request",
        401 => "Authentication failed",
        403 => "Permission denied",
        404 => "Not found",
        429 => "Rate limit exceeded",
        500 => "Internal server error",
        502 => "Bad gateway",
        503 => "Service unavailable",
        _ => "An error occurred",
    };
    // 4xx statuses keep their specific OpenAI error types; everything else
    // (5xx and unrecognized codes) collapses into a generic api_error.
    let error_type = match status {
        400 => "invalid_request_error",
        401 => "authentication_error",
        403 => "permission_error",
        404 => "not_found_error",
        429 => "rate_limit_error",
        _ => "api_error",
    };
    (error_type, message)
}
/// Builds the sanitized OpenAI-style error response for `status`, delegating
/// the type/message mapping to `sanitized_error_for_status`.
fn standard_error_response(status: StatusCode) -> Response {
    let code = status.as_u16();
    let (error_type, message) = sanitized_error_for_status(code);
    error_response(status, error_type, message)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::target::{Target, Targets};
use crate::test_utils::MockHttpClient;
use axum::body::Body;
use axum::http::Request;
use dashmap::DashMap;
use std::sync::Arc;
use tower::ServiceExt;
// error_response() must preserve the status code it is given.
#[test]
fn test_error_response_format() {
    let response = error_response(
        StatusCode::BAD_REQUEST,
        "invalid_request_error",
        "Test error",
    );
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
// The sanitized error body must match OpenAI's error envelope exactly:
// string message/type plus explicit null param and code.
#[tokio::test]
async fn test_error_response_matches_openai_format() {
    let response = error_response(
        StatusCode::BAD_REQUEST,
        "invalid_request_error",
        "Invalid request",
    );
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let our_error: serde_json::Value = serde_json::from_slice(&body).unwrap();
    let openai_format = json!({
        "error": {
            "message": "Invalid request",
            "type": "invalid_request_error",
            "param": null,
            "code": null
        }
    });
    assert_eq!(our_error, openai_format);
    assert!(our_error["error"]["message"].is_string());
    assert!(our_error["error"]["type"].is_string());
    assert!(our_error["error"]["param"].is_null());
    assert!(our_error["error"]["code"].is_null());
}
// Each sanitized status code must map to its OpenAI-conventional error type.
#[tokio::test]
async fn test_error_types_match_openai_conventions() {
    let test_cases = vec![
        (StatusCode::BAD_REQUEST, "invalid_request_error"),
        (StatusCode::UNAUTHORIZED, "authentication_error"),
        (StatusCode::FORBIDDEN, "permission_error"),
        (StatusCode::NOT_FOUND, "not_found_error"),
        (StatusCode::TOO_MANY_REQUESTS, "rate_limit_error"),
        (StatusCode::INTERNAL_SERVER_ERROR, "api_error"),
        (StatusCode::BAD_GATEWAY, "api_error"),
        (StatusCode::SERVICE_UNAVAILABLE, "api_error"),
    ];
    for (status, expected_type) in test_cases {
        let response = standard_error_response(status);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let error: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(
            error["error"]["type"].as_str().unwrap(),
            expected_type,
            "Status {} should map to error type {}",
            status,
            expected_type
        );
    }
}
// Strict-mode end-to-end: unknown provider fields in a non-streaming chat
// completion (provider, cost, internal ids) must be stripped from the body.
#[tokio::test]
async fn test_strict_sanitize_non_streaming_removes_unknown_fields() {
    let targets = Arc::new(DashMap::new());
    targets.insert(
        "gpt-4".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Upstream reply carrying extra non-schema fields that must not leak.
    let mock_response = r#"{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-4",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello!"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
},
"provider": "custom-llm-provider",
"cost": 0.00123,
"internal_id": "xyz-123",
"custom_field": "should_be_removed"
}"#;
    let mock_client = MockHttpClient::new(StatusCode::OK, mock_response);
    let state = AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    let request_body = r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(request_body))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    // Known fields survive; everything non-schema is gone.
    assert!(body_str.contains("\"id\":\"chatcmpl-123\""));
    assert!(body_str.contains("\"model\":\"gpt-4\""));
    assert!(body_str.contains("\"choices\""));
    assert!(!body_str.contains("provider"));
    assert!(!body_str.contains("cost"));
    assert!(!body_str.contains("internal_id"));
    assert!(!body_str.contains("custom_field"));
}
// When a target maps the client model to an onwards model, the sanitized
// body must show the client's model name, never the onwards one.
#[tokio::test]
async fn test_strict_sanitize_rewrites_model_field() {
    let targets = Arc::new(DashMap::new());
    targets.insert(
        "gpt-4".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .onwards_model("gpt-4-turbo-2024-04-09".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Upstream echoes the onwards model name.
    let mock_response = r#"{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-4-turbo-2024-04-09",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello!"
},
"finish_reason": "stop"
}]
}"#;
    let mock_client = MockHttpClient::new(StatusCode::OK, mock_response);
    let state = AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    let request_body = r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(request_body))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(body_str.contains("\"model\":\"gpt-4\""));
    assert!(!body_str.contains("gpt-4-turbo-2024-04-09"));
}
// A minimal upstream body (only choices/message/content) must be backfilled
// with id, object, model, index, and role during strict sanitization.
#[tokio::test]
async fn test_strict_sanitize_chat_backfills_missing_noncritical_fields() {
    let targets = Arc::new(DashMap::new());
    targets.insert(
        "gpt-4".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let mock_response = r#"{
"choices": [{
"message": {
"content": "Hello from downstream"
}
}]
}"#;
    let mock_client = MockHttpClient::new(StatusCode::OK, mock_response);
    let state = AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(json["model"], "gpt-4");
    assert_eq!(json["object"], "chat.completion");
    let id = json["id"].as_str().expect("id should be a string");
    assert!(id.starts_with("chatcmpl-"));
    assert_eq!(json["choices"][0]["index"], 0);
    assert_eq!(json["choices"][0]["message"]["role"], "assistant");
    assert_eq!(
        json["choices"][0]["message"]["content"],
        "Hello from downstream"
    );
}
// Choices arriving without an index must get indexes assigned from their
// position in the array.
#[tokio::test]
async fn test_strict_sanitize_chat_backfills_choice_indexes_from_position() {
    let targets = Arc::new(DashMap::new());
    targets.insert(
        "gpt-4".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let mock_response = r#"{
"choices": [
{"message": {"content": "first"}},
{"message": {"content": "second"}}
]
}"#;
    let mock_client = MockHttpClient::new(StatusCode::OK, mock_response);
    let state = AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(json["choices"][0]["index"], 0);
    assert_eq!(json["choices"][1]["index"], 1);
}
// Streaming strict mode: unknown fields in SSE chunks are stripped while the
// [DONE] terminator is passed through.
#[tokio::test]
async fn test_strict_sanitize_streaming_removes_unknown_fields() {
    let targets = Arc::new(DashMap::new());
    targets.insert(
        "gpt-4".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let streaming_chunks = vec![
        r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}],"provider":"custom-provider","cost":0.001}
"#.to_string(),
        "data: [DONE]\n\n".to_string(),
    ];
    let mock_client = MockHttpClient::new_streaming(StatusCode::OK, streaming_chunks);
    let state = AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    let request_body =
        r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}],"stream":true}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(request_body))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(!body_str.contains("provider"));
    assert!(!body_str.contains("cost"));
    assert!(body_str.contains("[DONE]"));
}
// Streaming strict mode: the onwards model name in SSE chunks is rewritten
// back to the client's model name.
#[tokio::test]
async fn test_strict_sanitize_streaming_rewrites_model() {
    let targets = Arc::new(DashMap::new());
    targets.insert(
        "gpt-4".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .onwards_model("gpt-4-turbo".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let streaming_chunks = vec![
        r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4-turbo","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
"#.to_string(),
        "data: [DONE]\n\n".to_string(),
    ];
    let mock_client = MockHttpClient::new_streaming(StatusCode::OK, streaming_chunks);
    let state = AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    let request_body =
        r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}],"stream":true}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(request_body))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(body_str.contains("\"model\":\"gpt-4\""));
    assert!(!body_str.contains("gpt-4-turbo"));
}
// A bare streaming chunk (delta/content only) must come out with object,
// model, index, and a generated chatcmpl- id backfilled.
#[tokio::test]
async fn test_strict_sanitize_streaming_chat_backfills_missing_noncritical_fields() {
    let targets = Arc::new(DashMap::new());
    targets.insert(
        "gpt-4".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let streaming_chunks = vec![
        r#"data: {"choices":[{"delta":{"content":"Hello"}}]}
"#
        .to_string(),
        "data: [DONE]\n\n".to_string(),
    ];
    let mock_client = MockHttpClient::new_streaming(StatusCode::OK, streaming_chunks);
    let state = AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}],"stream":true}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(body_str.contains("\"object\":\"chat.completion.chunk\""));
    assert!(body_str.contains("\"model\":\"gpt-4\""));
    assert!(body_str.contains("\"index\":0"));
    assert!(body_str.contains("\"content\":\"Hello\""));
    assert!(body_str.contains("\"id\":\"chatcmpl-"));
    assert!(body_str.contains("[DONE]"));
}
// Streaming choices arriving without indexes get positional indexes.
#[tokio::test]
async fn test_strict_sanitize_streaming_chat_backfills_chunk_indexes_from_position() {
    let targets = Arc::new(DashMap::new());
    targets.insert(
        "gpt-4".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let streaming_chunks = vec![
        r#"data: {"choices":[{"delta":{"content":"first"}},{"delta":{"content":"second"}}]}
"#
        .to_string(),
        "data: [DONE]\n\n".to_string(),
    ];
    let mock_client = MockHttpClient::new_streaming(StatusCode::OK, streaming_chunks);
    let state = AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}],"stream":true}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(body_str.contains("\"index\":0"));
    assert!(body_str.contains("\"index\":1"));
}
// Chunks without ids must all receive the SAME generated fallback id, not a
// fresh one per chunk.
#[tokio::test]
async fn test_strict_sanitize_streaming_chat_reuses_fallback_id_across_chunks() {
    let targets = Arc::new(DashMap::new());
    targets.insert(
        "gpt-4".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Two id-less chunks delivered separately.
    let streaming_chunks = vec![
        "data: {\"choices\":[{\"delta\":{\"content\":\"first\"}}]}\n\n".to_string(),
        "data: {\"choices\":[{\"delta\":{\"content\":\"second\"}}]}\n\n".to_string(),
        "data: [DONE]\n\n".to_string(),
    ];
    let mock_client = MockHttpClient::new_streaming(StatusCode::OK, streaming_chunks);
    let state = AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}],"stream":true}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    // Collect every id emitted across SSE data lines.
    let ids: Vec<String> = body_str
        .lines()
        .filter_map(|line| line.strip_prefix("data: "))
        .filter(|data| *data != "[DONE]")
        .map(|data| serde_json::from_str::<serde_json::Value>(data).unwrap())
        .filter_map(|json| json.get("id").and_then(|v| v.as_str()).map(str::to_string))
        .collect();
    assert!(ids.len() >= 2);
    assert!(ids.iter().all(|id| id == &ids[0]));
}
// A body-transform hook that flips a validated non-streaming request into a
// streaming one must still route the response through the streaming
// sanitizer (SSE content type, unknown fields stripped).
#[tokio::test]
async fn test_strict_chat_uses_streaming_sanitizer_when_body_transform_enables_streaming_after_validation()
{
    let targets = Arc::new(DashMap::new());
    targets.insert(
        "gpt-4".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Transform: when the opt-in header is set, force stream=true and request
    // usage in the final chunk.
    let transform_fn: crate::BodyTransformFn = Arc::new(|path, headers, body_bytes| {
        if path != "/chat/completions" {
            return None;
        }
        let should_stream = headers
            .get("x-fusillade-stream")
            .and_then(|value| value.to_str().ok())
            .map(|value| value.eq_ignore_ascii_case("true"))
            .unwrap_or(false);
        if !should_stream {
            return None;
        }
        let mut json_body = serde_json::from_slice::<serde_json::Value>(body_bytes).ok()?;
        let obj = json_body.as_object_mut()?;
        obj.insert("stream".to_string(), serde_json::Value::Bool(true));
        obj.insert(
            "stream_options".to_string(),
            serde_json::json!({
                "include_usage": true
            }),
        );
        serde_json::to_vec(&json_body)
            .ok()
            .map(axum::body::Bytes::from)
    });
    let streaming_chunks = vec![
        r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}],"provider":"custom-provider"}
"#
        .to_string(),
        "data: [DONE]\n\n".to_string(),
    ];
    let mock_client = MockHttpClient::new_streaming(StatusCode::OK, streaming_chunks);
    let state = AppState::with_client_and_transform(targets, mock_client.clone(), transform_fn);
    let router = crate::strict::build_strict_router(state);
    let request_body = r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .header("x-fusillade-stream", "true")
        .body(Body::from(request_body))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    // The client must see an SSE response even though it did not ask to stream.
    assert_eq!(
        response
            .headers()
            .get(header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok()),
        Some("text/event-stream")
    );
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(body_str.contains("data: {"));
    assert!(body_str.contains("[DONE]"));
    assert!(!body_str.contains("custom-provider"));
    // The transform's additions must reach the upstream request.
    let requests = mock_client.get_requests();
    assert_eq!(requests.len(), 1);
    let forwarded_body: serde_json::Value = serde_json::from_slice(&requests[0].body).unwrap();
    assert_eq!(forwarded_body["stream"], true);
    assert_eq!(forwarded_body["stream_options"]["include_usage"], true);
}
// Strict-mode embeddings: unknown provider fields must be stripped while the
// schema fields survive.
#[tokio::test]
async fn test_strict_sanitize_embeddings_removes_unknown_fields() {
    let targets = Arc::new(DashMap::new());
    targets.insert(
        "text-embedding-3-small".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Upstream reply carrying non-schema fields that must not leak.
    let mock_response = r#"{
"object": "list",
"data": [{
"object": "embedding",
"embedding": [0.1, 0.2, 0.3],
"index": 0
}],
"model": "text-embedding-3-small",
"usage": {
"prompt_tokens": 5,
"total_tokens": 5
},
"provider": "custom-embeddings",
"cost": 0.0001,
"cache_hit": true
}"#;
    let mock_client = MockHttpClient::new(StatusCode::OK, mock_response);
    let state = AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    let request_body = r#"{"model":"text-embedding-3-small","input":"Hello world"}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/embeddings")
        .header("content-type", "application/json")
        .body(Body::from(request_body))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(body_str.contains("\"object\":\"list\""));
    assert!(body_str.contains("\"model\":\"text-embedding-3-small\""));
    assert!(body_str.contains("\"data\""));
    assert!(!body_str.contains("provider"));
    assert!(!body_str.contains("cost"));
    assert!(!body_str.contains("cache_hit"));
}
#[tokio::test]
async fn test_strict_sanitize_embeddings_rewrites_model() {
    // When the target maps the public alias to an internal onwards model,
    // the sanitized embeddings response must report the alias only.
    let pool_map = Arc::new(DashMap::new());
    let target = Target::builder()
        .url("https://api.openai.com/v1/".parse().unwrap())
        .onwards_key("sk-test".to_string())
        .onwards_model("text-embedding-3-small-internal".to_string())
        .build();
    pool_map.insert("text-embedding-3-small".to_string(), target.into_pool());
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Upstream answers with the internal model name.
    let upstream_body = r#"{
"object": "list",
"data": [{
"object": "embedding",
"embedding": [0.1, 0.2],
"index": 0
}],
"model": "text-embedding-3-small-internal",
"usage": {
"prompt_tokens": 5,
"total_tokens": 5
}
}"#;
    let client = MockHttpClient::new(StatusCode::OK, upstream_body);
    let router = crate::strict::build_strict_router(AppState::with_client(targets, client));
    let request = Request::builder()
        .method("POST")
        .uri("/embeddings")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"text-embedding-3-small","input":"Hello"}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = String::from_utf8(bytes.to_vec()).unwrap();
    assert!(text.contains("\"model\":\"text-embedding-3-small\""));
    assert!(!text.contains("text-embedding-3-small-internal"));
}
#[tokio::test]
async fn test_strict_error_returns_standard_message() {
    // A 4xx from the provider is replaced with a generic OpenAI-style
    // invalid_request_error: connection strings, trace ids and debug dumps
    // from the provider's error payload must never reach the caller.
    let pool_map = Arc::new(DashMap::new());
    let target = Target::builder()
        .url("https://api.openai.com/v1/".parse().unwrap())
        .onwards_key("sk-test".to_string())
        .build();
    pool_map.insert("gpt-4".to_string(), target.into_pool());
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Provider error stuffed with internal details that must be scrubbed.
    let upstream_error = r#"{
"error": {
"message": "Database connection failed at postgresql://internal-db:5432",
"type": "internal_error",
"provider": "custom-llm-backend",
"internal_trace_id": "xyz-123-abc",
"debug_info": "Stack trace: error at line 42"
}
}"#;
    let client = MockHttpClient::new(StatusCode::BAD_REQUEST, upstream_error);
    let router = crate::strict::build_strict_router(AppState::with_client(targets, client));
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    // The status code is preserved; the body is replaced wholesale.
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = String::from_utf8(bytes.to_vec()).unwrap();
    assert!(text.contains("\"error\""));
    assert!(text.contains("\"type\":\"invalid_request_error\""));
    assert!(text.contains("\"message\":\"Invalid request\""));
    assert!(!text.contains("Database"));
    assert!(!text.contains("postgresql"));
    assert!(!text.contains("provider"));
    assert!(!text.contains("internal_trace_id"));
    assert!(!text.contains("debug_info"));
}
#[tokio::test]
async fn test_strict_error_500_returns_standard_message() {
    // 5xx provider errors are mapped to a generic api_error; support
    // addresses and trace ids in the upstream message must be scrubbed.
    let pool_map = Arc::new(DashMap::new());
    let target = Target::builder()
        .url("https://api.openai.com/v1/".parse().unwrap())
        .onwards_key("sk-test".to_string())
        .build();
    pool_map.insert("gpt-4".to_string(), target.into_pool());
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let upstream_error = r#"{
"error": {
"message": "Contact support@provider.com with trace ID abc-123",
"type": "server_error"
}
}"#;
    let client = MockHttpClient::new(StatusCode::INTERNAL_SERVER_ERROR, upstream_error);
    let router = crate::strict::build_strict_router(AppState::with_client(targets, client));
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = String::from_utf8(bytes.to_vec()).unwrap();
    assert!(text.contains("\"type\":\"api_error\""));
    assert!(text.contains("\"message\":\"Internal server error\""));
    assert!(!text.contains("support@provider.com"));
    assert!(!text.contains("trace ID abc-123"));
}
#[tokio::test]
async fn test_strict_handle_malformed_error_response() {
    // Even when the provider returns non-JSON (an HTML error page), strict
    // mode must emit a valid JSON error body with no provider bytes in it.
    let pool_map = Arc::new(DashMap::new());
    let target = Target::builder()
        .url("https://api.openai.com/v1/".parse().unwrap())
        .onwards_key("sk-test".to_string())
        .build();
    pool_map.insert("gpt-4".to_string(), target.into_pool());
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let upstream_error = "<html><body>Internal Server Error</body></html>";
    let client = MockHttpClient::new(StatusCode::INTERNAL_SERVER_ERROR, upstream_error);
    let router = crate::strict::build_strict_router(AppState::with_client(targets, client));
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = String::from_utf8(bytes.to_vec()).unwrap();
    // No HTML may leak, and the replacement body must itself parse as JSON.
    assert!(!text.contains("<html>"));
    assert!(!text.contains("<body>"));
    assert!(text.contains("\"error\""));
    let _: serde_json::Value = serde_json::from_str(&text).expect("Should be valid JSON");
}
#[tokio::test]
async fn test_strict_sanitize_responses_removes_unknown_fields() {
    // A full /responses payload with extra vendor fields appended: the
    // sanitizer keeps every spec field and drops the unknown trailers.
    let pool_map = Arc::new(DashMap::new());
    let target = Target::builder()
        .url("https://api.openai.com/v1/".parse().unwrap())
        .onwards_key("sk-test".to_string())
        .build();
    pool_map.insert("gpt-4o".to_string(), target.into_pool());
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let upstream_body = r#"{
"id": "resp_abc123",
"object": "response",
"created_at": 1677652288,
"completed_at": 1677652290,
"status": "completed",
"incomplete_details": null,
"model": "gpt-4o",
"previous_response_id": null,
"instructions": null,
"output": [],
"error": null,
"tools": [],
"tool_choice": "auto",
"truncation": "auto",
"parallel_tool_calls": true,
"text": {},
"top_p": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"top_logprobs": 0,
"temperature": 1.0,
"reasoning": null,
"usage": null,
"max_output_tokens": null,
"max_tool_calls": null,
"store": false,
"background": false,
"service_tier": "default",
"metadata": null,
"safety_identifier": null,
"prompt_cache_key": null,
"provider": "custom-provider",
"cost": 0.0123,
"internal_trace_id": "xyz-789"
}"#;
    let client = MockHttpClient::new(StatusCode::OK, upstream_body);
    let router = crate::strict::build_strict_router(AppState::with_client(targets, client));
    let request = Request::builder()
        .method("POST")
        .uri("/responses")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"model":"gpt-4o","input":"Hello"}"#))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = String::from_utf8(bytes.to_vec()).unwrap();
    // Spec fields survive...
    assert!(text.contains("\"id\":\"resp_abc123\""));
    assert!(text.contains("\"model\":\"gpt-4o\""));
    assert!(text.contains("\"status\":\"completed\""));
    // ...vendor extras are gone.
    assert!(!text.contains("provider"));
    assert!(!text.contains("cost"));
    assert!(!text.contains("internal_trace_id"));
}
#[tokio::test]
async fn test_strict_sanitize_responses_rewrites_model() {
    // /responses sanitization must replace the provider's dated model name
    // with the alias the client asked for.
    let pool_map = Arc::new(DashMap::new());
    let target = Target::builder()
        .url("https://api.openai.com/v1/".parse().unwrap())
        .onwards_key("sk-test".to_string())
        .onwards_model("gpt-4o-2024-05-13".to_string())
        .build();
    pool_map.insert("gpt-4o".to_string(), target.into_pool());
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let upstream_body = r#"{
"id": "resp_abc123",
"object": "response",
"created_at": 1677652288,
"completed_at": 1677652290,
"status": "completed",
"incomplete_details": null,
"model": "gpt-4o-2024-05-13",
"previous_response_id": null,
"instructions": null,
"output": [],
"error": null,
"tools": [],
"tool_choice": "auto",
"truncation": "auto",
"parallel_tool_calls": true,
"text": {},
"top_p": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"top_logprobs": 0,
"temperature": 1.0,
"reasoning": null,
"usage": null,
"max_output_tokens": null,
"max_tool_calls": null,
"store": false,
"background": false,
"service_tier": "default",
"metadata": null,
"safety_identifier": null,
"prompt_cache_key": null
}"#;
    let client = MockHttpClient::new(StatusCode::OK, upstream_body);
    let router = crate::strict::build_strict_router(AppState::with_client(targets, client));
    let request = Request::builder()
        .method("POST")
        .uri("/responses")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"model":"gpt-4o","input":"Hello"}"#))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = String::from_utf8(bytes.to_vec()).unwrap();
    assert!(text.contains("\"model\":\"gpt-4o\""));
    assert!(!text.contains("gpt-4o-2024-05-13"));
}
#[tokio::test]
async fn test_strict_sanitize_responses_backfills_missing_noncritical_fields() {
    // A bare-bones upstream /responses payload (only id, created_at, model,
    // output): the sanitizer must backfill the non-critical spec fields with
    // sensible defaults while preserving the actual output text.
    let pool_map = Arc::new(DashMap::new());
    let target = Target::builder()
        .url("https://api.openai.com/v1/".parse().unwrap())
        .onwards_key("sk-test".to_string())
        .build();
    pool_map.insert("gpt-4o".to_string(), target.into_pool());
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let upstream_body = r#"{
"id": "resp_abc123",
"created_at": 1677652288,
"model": "provider-model",
"output": [{
"type": "message",
"role": "assistant",
"content": [{
"type": "output_text",
"text": "Hello from downstream"
}]
}]
}"#;
    let client = MockHttpClient::new(StatusCode::OK, upstream_body);
    let router = crate::strict::build_strict_router(AppState::with_client(targets, client));
    let request = Request::builder()
        .method("POST")
        .uri("/responses")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"model":"gpt-4o","input":"Hello"}"#))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    // Provided fields are kept (model is rewritten to the alias)...
    assert_eq!(json["id"], "resp_abc123");
    assert_eq!(json["model"], "gpt-4o");
    // ...missing non-critical fields are defaulted.
    assert_eq!(json["object"], "response");
    assert_eq!(json["status"], "completed");
    assert_eq!(json["tool_choice"], "auto");
    assert_eq!(json["parallel_tool_calls"], true);
    assert_eq!(json["usage"], serde_json::Value::Null);
    assert_eq!(
        json["output"][0]["content"][0]["text"],
        "Hello from downstream"
    );
}
#[tokio::test]
async fn test_strict_sanitize_responses_error() {
    // Errors on the /responses endpoint get the same treatment as chat:
    // replace with a generic api_error, leak nothing from the provider.
    let pool_map = Arc::new(DashMap::new());
    let target = Target::builder()
        .url("https://api.openai.com/v1/".parse().unwrap())
        .onwards_key("sk-test".to_string())
        .build();
    pool_map.insert("gpt-4o".to_string(), target.into_pool());
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let upstream_error = r#"{
"error": {
"message": "Internal error in provider backend at server xyz-123.internal.com",
"type": "server_error",
"provider": "custom-provider",
"trace_id": "abc-def-ghi"
}
}"#;
    let client = MockHttpClient::new(StatusCode::INTERNAL_SERVER_ERROR, upstream_error);
    let router = crate::strict::build_strict_router(AppState::with_client(targets, client));
    let request = Request::builder()
        .method("POST")
        .uri("/responses")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"model":"gpt-4o","input":"Hello"}"#))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = String::from_utf8(bytes.to_vec()).unwrap();
    assert!(text.contains("\"type\":\"api_error\""));
    assert!(text.contains("\"message\":\"Internal server error\""));
    assert!(!text.contains("Internal error in provider backend"));
    assert!(!text.contains("xyz-123.internal.com"));
    assert!(!text.contains("provider"));
    assert!(!text.contains("trace_id"));
}
#[tokio::test]
async fn test_strict_sanitize_responses_streaming_removes_unknown_fields() {
    // End-to-end check of the strict /responses SSE sanitizer: the provider's
    // dated model name must be rewritten to the requested alias, while
    // sequence numbers, delta payloads, and explicit nulls (e.g. inside
    // "reasoning") pass through unchanged.
    let targets = Arc::new(DashMap::new());
    targets.insert(
        "gpt-4o".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Three SSE events — response.created (full snapshot carrying the
    // provider model name), an output_text delta, response.completed —
    // followed by the [DONE] terminator.
    let mock_stream = concat!(
        "data: {\"type\":\"response.created\",\"sequence_number\":0,\"response\":{",
        "\"id\":\"resp-123\",\"object\":\"response\",\"created_at\":1677652288,",
        "\"completed_at\":null,\"status\":\"in_progress\",\"incomplete_details\":null,",
        "\"model\":\"gpt-4o-2024-08-06\",\"previous_response_id\":null,\"instructions\":null,",
        "\"output\":[],\"error\":null,\"tools\":[],\"tool_choice\":\"auto\",",
        "\"truncation\":\"disabled\",\"parallel_tool_calls\":true,",
        "\"text\":{\"format\":{\"type\":\"text\"}},\"top_p\":1.0,\"presence_penalty\":0.0,",
        "\"frequency_penalty\":0.0,\"top_logprobs\":0,\"temperature\":1.0,",
        "\"reasoning\":{\"effort\":null,\"summary\":null},\"usage\":null,",
        "\"max_output_tokens\":null,\"max_tool_calls\":null,\"store\":false,",
        "\"background\":false,\"service_tier\":\"default\",\"metadata\":null,",
        "\"safety_identifier\":null,\"prompt_cache_key\":null}}\n\n",
        "data: {\"type\":\"response.output_text.delta\",\"sequence_number\":1,",
        "\"item_id\":\"msg_abc\",\"output_index\":0,\"content_index\":0,\"delta\":\"Hello\"}\n\n",
        "data: {\"type\":\"response.completed\",\"sequence_number\":2,\"response\":{",
        "\"id\":\"resp-123\",\"object\":\"response\",\"created_at\":1677652288,",
        "\"completed_at\":1677652290,\"status\":\"completed\",\"incomplete_details\":null,",
        "\"model\":\"gpt-4o-2024-08-06\",\"previous_response_id\":null,\"instructions\":null,",
        "\"output\":[],\"error\":null,\"tools\":[],\"tool_choice\":\"auto\",",
        "\"truncation\":\"disabled\",\"parallel_tool_calls\":true,",
        "\"text\":{\"format\":{\"type\":\"text\"}},\"top_p\":1.0,\"presence_penalty\":0.0,",
        "\"frequency_penalty\":0.0,\"top_logprobs\":0,\"temperature\":1.0,",
        "\"reasoning\":{\"effort\":null,\"summary\":null},\"usage\":null,",
        "\"max_output_tokens\":null,\"max_tool_calls\":null,\"store\":false,",
        "\"background\":false,\"service_tier\":\"default\",\"metadata\":null,",
        "\"safety_identifier\":null,\"prompt_cache_key\":null}}\n\n",
        "data: [DONE]\n\n"
    );
    let mock_client = MockHttpClient::new(StatusCode::OK, mock_stream);
    let state = AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    // Client explicitly opts into streaming via the request body.
    let request_body = r#"{"model":"gpt-4o","input":"Hello","stream":true}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/responses")
        .header("content-type", "application/json")
        .body(Body::from(request_body))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    // Collect the whole sanitized stream and inspect it as text.
    let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let response_str = String::from_utf8(body_bytes.to_vec()).unwrap();
    assert!(
        response_str.contains("\"type\":\"response.created\""),
        "response.created event should be present"
    );
    assert!(
        response_str.contains("\"type\":\"response.completed\""),
        "response.completed event should be present"
    );
    assert!(
        response_str.contains("\"type\":\"response.output_text.delta\""),
        "delta event should pass through"
    );
    assert!(
        !response_str.contains("gpt-4o-2024-08-06"),
        "Provider model name should be rewritten"
    );
    assert!(
        response_str.contains("\"model\":\"gpt-4o\""),
        "Model should be rewritten to the requested model name"
    );
    assert!(
        response_str.contains("\"sequence_number\":0"),
        "sequence_number should be preserved"
    );
    assert!(
        response_str.contains("\"reasoning\":{\"effort\":null,\"summary\":null}"),
        "reasoning null fields must not be collapsed to {{}}"
    );
    assert!(
        response_str.contains("\"delta\":\"Hello\""),
        "delta field should pass through on delta events"
    );
}
#[tokio::test]
async fn test_strict_sanitize_responses_streaming_backfills_missing_noncritical_fields() {
    // Minimal streaming snapshots (only id / created_at / model / output):
    // the SSE sanitizer must backfill object, status, tool_choice, etc.,
    // rewrite the model, and leave delta events alone.
    let pool_map = Arc::new(DashMap::new());
    let target = Target::builder()
        .url("https://api.openai.com/v1/".parse().unwrap())
        .onwards_key("sk-test".to_string())
        .build();
    pool_map.insert("gpt-4o".to_string(), target.into_pool());
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let sse_fixture = concat!(
        "data: {\"type\":\"response.created\",\"sequence_number\":0,\"response\":{",
        "\"id\":\"resp-123\",\"created_at\":1677652288,\"model\":\"provider-model\"}}\n\n",
        "data: {\"type\":\"response.output_text.delta\",\"sequence_number\":1,",
        "\"item_id\":\"msg_abc\",\"output_index\":0,\"content_index\":0,\"delta\":\"Hello\"}\n\n",
        "data: {\"type\":\"response.completed\",\"sequence_number\":2,\"response\":{",
        "\"id\":\"resp-123\",\"created_at\":1677652288,\"output\":[{",
        "\"type\":\"message\",\"role\":\"assistant\",\"content\":[{",
        "\"type\":\"output_text\",\"text\":\"Hello\"}]}]}}\n\n",
        "data: [DONE]\n\n"
    );
    let client = MockHttpClient::new(StatusCode::OK, sse_fixture);
    let router = crate::strict::build_strict_router(AppState::with_client(targets, client));
    let request = Request::builder()
        .method("POST")
        .uri("/responses")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-4o","input":"Hello","stream":true}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = String::from_utf8(bytes.to_vec()).unwrap();
    assert!(text.contains("\"type\":\"response.created\""));
    assert!(text.contains("\"type\":\"response.completed\""));
    assert!(text.contains("\"model\":\"gpt-4o\""));
    assert!(text.contains("\"id\":\"resp-123\""));
    // Backfilled statuses: in_progress on created, completed at the end.
    assert!(text.contains("\"status\":\"in_progress\""));
    assert!(text.contains("\"status\":\"completed\""));
    assert!(text.contains("\"tool_choice\":\"auto\""));
    assert!(text.contains("\"delta\":\"Hello\""));
    assert!(text.contains("[DONE]"));
}
#[tokio::test]
async fn test_strict_sanitize_responses_streaming_backfills_incomplete_details() {
    // A snapshot missing `incomplete_details` must come back with an
    // explicit null for it after sanitization.
    let pool_map = Arc::new(DashMap::new());
    let target = Target::builder()
        .url("https://api.openai.com/v1/".parse().unwrap())
        .onwards_key("sk-test".to_string())
        .build();
    pool_map.insert("gpt-4o".to_string(), target.into_pool());
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let sse_fixture = concat!(
        "data: {\"type\":\"response.created\",\"sequence_number\":0,\"response\":{",
        "\"id\":\"resp-123\",\"created_at\":1677652288,\"model\":\"provider-model\"}}\n\n",
        "data: [DONE]\n\n"
    );
    let client = MockHttpClient::new(StatusCode::OK, sse_fixture);
    let router = crate::strict::build_strict_router(AppState::with_client(targets, client));
    let request = Request::builder()
        .method("POST")
        .uri("/responses")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-4o","input":"Hello","stream":true}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = String::from_utf8(bytes.to_vec()).unwrap();
    assert!(text.contains("\"incomplete_details\":null"));
}
#[tokio::test]
async fn test_strict_sanitize_responses_streaming_reuses_fallback_id_across_snapshots() {
    // Neither snapshot carries an id, so the sanitizer must generate one —
    // and it must generate it ONCE, reusing the same fallback id for every
    // snapshot in the stream.
    let pool_map = Arc::new(DashMap::new());
    let target = Target::builder()
        .url("https://api.openai.com/v1/".parse().unwrap())
        .onwards_key("sk-test".to_string())
        .build();
    pool_map.insert("gpt-4o".to_string(), target.into_pool());
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let sse_fixture = concat!(
        "data: {\"type\":\"response.created\",\"sequence_number\":0,\"response\":{",
        "\"created_at\":1677652288,\"model\":\"provider-model\"}}\n\n",
        "data: {\"type\":\"response.completed\",\"sequence_number\":1,\"response\":{",
        "\"created_at\":1677652288,\"output\":[]}}\n\n",
        "data: [DONE]\n\n"
    );
    let client = MockHttpClient::new(StatusCode::OK, sse_fixture);
    let router = crate::strict::build_strict_router(AppState::with_client(targets, client));
    let request = Request::builder()
        .method("POST")
        .uri("/responses")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-4o","input":"Hello","stream":true}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = String::from_utf8(bytes.to_vec()).unwrap();
    // Pull the response.id out of every non-terminal SSE data line.
    let mut ids: Vec<String> = Vec::new();
    for line in text.lines() {
        if let Some(data) = line.strip_prefix("data: ") {
            if data == "[DONE]" {
                continue;
            }
            let json: serde_json::Value = serde_json::from_str(data).unwrap();
            if let Some(id) = json
                .get("response")
                .and_then(|response| response.get("id"))
                .and_then(|id| id.as_str())
            {
                ids.push(id.to_string());
            }
        }
    }
    assert!(ids.len() >= 2);
    assert!(ids.iter().all(|id| id == &ids[0]));
}
#[tokio::test]
async fn test_strict_sanitize_responses_streaming_forwards_error_object() {
    // Feed a raw SSE response containing an upstream error object directly
    // to the streaming sanitizer: the error (message included) must be
    // forwarded rather than scrubbed.
    let upstream = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "text/event-stream")
        .body(Body::from(
            "data: {\"error\":{\"message\":\"Input too long\",\"type\":\"invalid_request_error\",\"code\":\"context_length_exceeded\"}}\n\n",
        ))
        .unwrap();
    let sanitized =
        sanitize_streaming_responses_response(upstream, "test-model".to_string(), true).await;
    assert_eq!(sanitized.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(sanitized.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = String::from_utf8_lossy(&bytes);
    assert!(text.contains("\"error\""));
    assert!(text.contains("Input too long"));
}
#[tokio::test]
async fn test_strict_responses_uses_streaming_sanitizer_when_body_transform_enables_streaming_after_validation()
{
    // The client does NOT set "stream" in its body; a BodyTransformFn flips
    // it on (driven by the x-fusillade-stream header) after request
    // validation. The strict router must notice this and route the reply
    // through the streaming sanitizer — SSE content type, vendor fields
    // stripped — and forward "stream": true upstream.
    let targets = Arc::new(DashMap::new());
    targets.insert(
        "gpt-4o".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Transform: only for /responses, and only when the opt-in header is
    // present, inject "stream": true into the JSON body.
    let transform_fn: crate::BodyTransformFn = Arc::new(|path, headers, body_bytes| {
        if path != "/responses" {
            return None;
        }
        let should_stream = headers
            .get("x-fusillade-stream")
            .and_then(|value| value.to_str().ok())
            .map(|value| value.eq_ignore_ascii_case("true"))
            .unwrap_or(false);
        if !should_stream {
            return None;
        }
        let mut json_body = serde_json::from_slice::<serde_json::Value>(body_bytes).ok()?;
        let obj = json_body.as_object_mut()?;
        obj.insert("stream".to_string(), serde_json::Value::Bool(true));
        serde_json::to_vec(&json_body)
            .ok()
            .map(axum::body::Bytes::from)
    });
    // Upstream SSE fixture carries a non-spec "provider" field that the
    // streaming sanitizer must strip.
    let mock_client = MockHttpClient::new_streaming(
        StatusCode::OK,
        vec![
            concat!(
                "data: {\"type\":\"response.created\",\"sequence_number\":0,\"response\":{",
                "\"id\":\"resp-123\",\"object\":\"response\",\"created_at\":1677652288,",
                "\"completed_at\":null,\"status\":\"in_progress\",\"incomplete_details\":null,",
                "\"model\":\"gpt-4o\",\"previous_response_id\":null,\"instructions\":null,",
                "\"output\":[],\"error\":null,\"tools\":[],\"tool_choice\":\"auto\",",
                "\"truncation\":\"disabled\",\"parallel_tool_calls\":true,",
                "\"text\":{\"format\":{\"type\":\"text\"}},\"top_p\":1.0,\"presence_penalty\":0.0,",
                "\"frequency_penalty\":0.0,\"top_logprobs\":0,\"temperature\":1.0,",
                "\"reasoning\":{\"effort\":null,\"summary\":null},\"usage\":null,",
                "\"max_output_tokens\":null,\"max_tool_calls\":null,\"store\":false,",
                "\"background\":false,\"service_tier\":\"default\",\"metadata\":null,",
                "\"safety_identifier\":null,\"prompt_cache_key\":null,",
                "\"provider\":\"custom-provider\"}}\n\n"
            )
            .to_string(),
            "data: [DONE]\n\n".to_string(),
        ],
    );
    let state = AppState::with_client_and_transform(targets, mock_client.clone(), transform_fn)
        .with_streaming_header("x-fusillade-stream");
    let router = crate::strict::build_strict_router(state);
    // Note: no "stream" key in the body; only the header requests streaming.
    let request = Request::builder()
        .method("POST")
        .uri("/responses")
        .header("content-type", "application/json")
        .header("x-fusillade-stream", "true")
        .body(Body::from(r#"{"model":"gpt-4o","input":"Hello"}"#))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    // Streaming sanitizer was used: response is SSE.
    assert_eq!(
        response
            .headers()
            .get(header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok()),
        Some("text/event-stream")
    );
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(body_str.contains("\"type\":\"response.created\""));
    assert!(body_str.contains("[DONE]"));
    assert!(!body_str.contains("custom-provider"));
    // The transformed (stream-enabled) body is what went upstream.
    let requests = mock_client.get_requests();
    assert_eq!(requests.len(), 1);
    let forwarded_body: serde_json::Value = serde_json::from_slice(&requests[0].body).unwrap();
    assert_eq!(forwarded_body["stream"], true);
}
#[tokio::test]
async fn test_chat_sanitization_updates_content_length_header() {
let targets = Arc::new(DashMap::new());
targets.insert(
"gpt-4".to_string(),
Target::builder()
.url("https://api.openai.com/v1/".parse().unwrap())
.onwards_key("sk-test".to_string())
.build()
.into_pool(),
);
let targets = Targets {
targets,
key_rate_limiters: Arc::new(DashMap::new()),
key_concurrency_limiters: Arc::new(DashMap::new()),
key_labels: Arc::new(DashMap::new()),
strict_mode: true,
http_pool_config: None,
};
let mock_response = r#"{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "provider-model-name",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello!"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
},
"provider": "custom-provider",
"cost": 0.00123,
"trace_id": "abc-xyz-123",
"custom_metadata": "will be dropped"
}"#;
let mut mock_client = MockHttpClient::new(StatusCode::OK, mock_response);
mock_client.set_header("content-length", mock_response.len().to_string());
let state = AppState::with_client(targets, mock_client);
let router = crate::strict::build_strict_router(state);
let request_body = r#"{"model":"gpt-4","messages":[{"role":"user","content":"test"}]}"#;
let request = Request::builder()
.method("POST")
.uri("/chat/completions")
.header("content-type", "application/json")
.body(Body::from(request_body))
.unwrap();
let response = router.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(body_json["model"], "gpt-4"); assert!(body_json.get("provider").is_none()); assert!(body_json.get("cost").is_none()); assert!(body_json.get("trace_id").is_none()); }
#[tokio::test]
async fn test_embeddings_sanitization_updates_content_length_header() {
    // Route the embeddings model to a single mock target.
    let pool_map = Arc::new(DashMap::new());
    pool_map.insert(
        "text-embedding-ada-002".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Upstream reply with extra provider-only fields; its content-length
    // header is sized to the ORIGINAL (unsanitized) body.
    let upstream_body = r#"{
"object": "list",
"data": [{
"object": "embedding",
"index": 0,
"embedding": [0.1, 0.2, 0.3]
}],
"model": "provider-embedding-model",
"usage": {
"prompt_tokens": 5,
"total_tokens": 5
},
"provider_metadata": "extra_data",
"cost": 0.0001,
"processing_time_ms": 123
}"#;
    let mut upstream = MockHttpClient::new(StatusCode::OK, upstream_body);
    upstream.set_header("content-length", upstream_body.len().to_string());
    let router = crate::strict::build_strict_router(AppState::with_client(targets, upstream));
    let req = Request::builder()
        .method("POST")
        .uri("/embeddings")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"text-embedding-ada-002","input":"test"}"#,
        ))
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    // If the stale content-length header had been forwarded, reading the
    // (shorter) sanitized body here would fail or truncate.
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    // Sanitization rewrites the model name and strips provider-only fields.
    assert_eq!(parsed["model"], "text-embedding-ada-002");
    assert!(parsed.get("provider_metadata").is_none());
    assert!(parsed.get("cost").is_none());
    assert!(parsed.get("processing_time_ms").is_none());
}
#[tokio::test]
async fn test_trusted_target_sanitizes_success_responses() {
    use crate::load_balancer::{Provider, ProviderPool};
    use crate::target::{LoadBalanceStrategy, Target, Targets};
    use crate::test_utils::MockHttpClient;
    use axum::body::Body;
    use axum::http::Request;
    use dashmap::DashMap;
    use std::sync::Arc;
    use tower::ServiceExt;

    // Single-provider pool; the `true` positional argument marks it trusted
    // (compare the trusted/untrusted pools in the header-bypass test below).
    let provider = Provider::new(
        Target::builder()
            .url("https://api.openai.com".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build(),
        1,
    );
    let pool = ProviderPool::with_config(
        vec![provider],
        None,
        None,
        None,
        None,
        LoadBalanceStrategy::default(),
        true,
        Vec::new(),
    );
    let pool_map = Arc::new(DashMap::new());
    pool_map.insert("gpt-4".to_string(), pool);
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Successful upstream reply leaking the provider's real model name and
    // a provider_metadata object.
    let upstream_body = r#"{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "gpt-4-actual-provider-model",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello!"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
},
"provider_metadata": {
"cost": 0.001,
"trace_id": "trace-123"
}
}"#;
    let upstream = MockHttpClient::new(StatusCode::OK, upstream_body);
    let router =
        crate::strict::build_strict_router(crate::AppState::with_client(targets, upstream));
    let req_body = r#"{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello"}]
}"#;
    let req = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(req_body))
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    // Trust only bypasses ERROR sanitization; success bodies are still scrubbed.
    assert_eq!(
        parsed["model"], "gpt-4",
        "Trusted target should still sanitize success responses - model rewritten"
    );
    assert!(
        parsed.get("provider_metadata").is_none(),
        "Trusted target should still sanitize success responses - metadata removed"
    );
    assert_eq!(parsed["object"], "chat.completion");
    assert_eq!(parsed["choices"][0]["message"]["content"], "Hello!");
    assert_eq!(parsed["usage"]["total_tokens"], 15);
}
#[tokio::test]
async fn test_trusted_target_bypasses_error_sanitization() {
    use crate::load_balancer::{Provider, ProviderPool};
    use crate::target::{LoadBalanceStrategy, Target, Targets};
    use crate::test_utils::MockHttpClient;
    use axum::body::Body;
    use axum::http::Request;
    use dashmap::DashMap;
    use std::sync::Arc;
    use tower::ServiceExt;

    // Trusted single-provider pool (the `true` positional argument).
    let provider = Provider::new(
        Target::builder()
            .url("https://api.openai.com".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build(),
        1,
    );
    let pool = ProviderPool::with_config(
        vec![provider],
        None,
        None,
        None,
        None,
        LoadBalanceStrategy::default(),
        true,
        Vec::new(),
    );
    let pool_map = Arc::new(DashMap::new());
    pool_map.insert("gpt-4".to_string(), pool);
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Upstream 500 whose error payload carries provider-internal details.
    let upstream_error = r#"{
"error": {
"message": "Internal provider error: GPU cluster unavailable in eu-west-3",
"type": "provider_error",
"code": "internal_error",
"metadata": {
"provider": "openai",
"region": "eu-west-3",
"trace_id": "trace-abc-123"
}
}
}"#;
    let upstream = MockHttpClient::new(StatusCode::INTERNAL_SERVER_ERROR, upstream_error);
    let router =
        crate::strict::build_strict_router(crate::AppState::with_client(targets, upstream));
    let req_body = r#"{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello"}]
}"#;
    let req = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(req_body))
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    // For a trusted target the error body passes through untouched.
    assert_eq!(
        parsed["error"]["message"],
        "Internal provider error: GPU cluster unavailable in eu-west-3",
        "Trusted target should preserve original error message"
    );
    assert_eq!(parsed["error"]["code"], "internal_error");
    assert!(parsed["error"]["metadata"].is_object());
    assert_eq!(parsed["error"]["metadata"]["provider"], "openai");
    assert_eq!(parsed["error"]["metadata"]["region"], "eu-west-3");
    assert_eq!(parsed["error"]["metadata"]["trace_id"], "trace-abc-123");
}
#[tokio::test]
async fn test_trusted_target_bypasses_streaming_error_sanitization() {
    use crate::load_balancer::{Provider, ProviderPool};
    use crate::target::{LoadBalanceStrategy, Target, Targets};
    use crate::test_utils::MockHttpClient;
    use axum::body::Body;
    use axum::http::Request;
    use dashmap::DashMap;
    use std::sync::Arc;
    use tower::ServiceExt;

    // Trusted single-provider pool (the `true` positional argument).
    let provider = Provider::new(
        Target::builder()
            .url("https://api.openai.com".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build(),
        1,
    );
    let pool = ProviderPool::with_config(
        vec![provider],
        None,
        None,
        None,
        None,
        LoadBalanceStrategy::default(),
        true,
        Vec::new(),
    );
    let pool_map = Arc::new(DashMap::new());
    pool_map.insert("gpt-4".to_string(), pool);
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // The upstream rejects the stream with a 429 carrying internal details.
    let upstream_error = r#"{
"error": {
"message": "Rate limit exceeded in streaming mode: too many concurrent streams",
"type": "provider_error",
"code": "rate_limit_exceeded",
"metadata": {
"provider": "openai",
"concurrent_streams": 150,
"max_streams": 100,
"trace_id": "trace-stream-456"
}
}
}"#;
    let upstream = MockHttpClient::new(StatusCode::TOO_MANY_REQUESTS, upstream_error);
    let router =
        crate::strict::build_strict_router(crate::AppState::with_client(targets, upstream));
    // Note the `"stream": true` in the request body.
    let req_body = r#"{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello"}],
"stream": true
}"#;
    let req = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(req_body))
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    // The error pass-through for trusted targets also covers streaming requests.
    assert_eq!(
        parsed["error"]["message"],
        "Rate limit exceeded in streaming mode: too many concurrent streams",
        "Trusted target should bypass error sanitization for streaming requests"
    );
    assert_eq!(parsed["error"]["type"], "provider_error");
    assert!(parsed["error"]["metadata"].is_object());
    assert_eq!(parsed["error"]["metadata"]["provider"], "openai");
    assert_eq!(parsed["error"]["metadata"]["concurrent_streams"], 150);
    assert_eq!(parsed["error"]["metadata"]["max_streams"], 100);
    assert_eq!(parsed["error"]["metadata"]["trace_id"], "trace-stream-456");
}
#[tokio::test]
async fn test_trusted_target_bypasses_error_sanitization_responses() {
    use crate::load_balancer::{Provider, ProviderPool};
    use crate::target::{LoadBalanceStrategy, Target, Targets};
    use crate::test_utils::MockHttpClient;
    use axum::body::Body;
    use axum::http::Request;
    use dashmap::DashMap;
    use std::sync::Arc;
    use tower::ServiceExt;

    // Trusted single-provider pool (the `true` positional argument).
    let provider = Provider::new(
        Target::builder()
            .url("https://api.openai.com".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build(),
        1,
    );
    let pool = ProviderPool::with_config(
        vec![provider],
        None,
        None,
        None,
        None,
        LoadBalanceStrategy::default(),
        true,
        Vec::new(),
    );
    let pool_map = Arc::new(DashMap::new());
    pool_map.insert("gpt-4o-mini".to_string(), pool);
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Upstream 500 on the Responses API with provider-internal details.
    let upstream_error = r#"{
"error": {
"message": "Provider-specific error: vLLM out of memory",
"type": "provider_error",
"code": "oom_error",
"metadata": {
"provider": "vllm",
"gpu_id": "3",
"trace_id": "vllm-trace-456"
}
}
}"#;
    let upstream = MockHttpClient::new(StatusCode::INTERNAL_SERVER_ERROR, upstream_error);
    let router =
        crate::strict::build_strict_router(crate::AppState::with_client(targets, upstream));
    let req = Request::builder()
        .method("POST")
        .uri("/responses")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"model":"gpt-4o-mini","input":"Test message"}"#))
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    // Trusted-target error pass-through also applies on the /responses route.
    assert_eq!(
        parsed["error"]["message"], "Provider-specific error: vLLM out of memory",
        "Trusted target should preserve original error message for /v1/responses"
    );
    assert_eq!(parsed["error"]["code"], "oom_error");
    assert!(parsed["error"]["metadata"].is_object());
    assert_eq!(parsed["error"]["metadata"]["provider"], "vllm");
    assert_eq!(parsed["error"]["metadata"]["trace_id"], "vllm-trace-456");
}
#[tokio::test]
async fn test_trusted_target_bypasses_error_sanitization_embeddings() {
    use crate::load_balancer::{Provider, ProviderPool};
    use crate::target::{LoadBalanceStrategy, Target, Targets};
    use crate::test_utils::MockHttpClient;
    use axum::body::Body;
    use axum::http::Request;
    use dashmap::DashMap;
    use std::sync::Arc;
    use tower::ServiceExt;

    // Trusted single-provider pool (the `true` positional argument).
    let provider = Provider::new(
        Target::builder()
            .url("https://api.openai.com".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build(),
        1,
    );
    let pool = ProviderPool::with_config(
        vec![provider],
        None,
        None,
        None,
        None,
        LoadBalanceStrategy::default(),
        true,
        Vec::new(),
    );
    let pool_map = Arc::new(DashMap::new());
    pool_map.insert("text-embedding-ada-002".to_string(), pool);
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Upstream 400 on the embeddings route with provider-internal details.
    let upstream_error = r#"{
"error": {
"message": "Embedding service error: Token limit 8192 exceeded for input text",
"type": "invalid_request_error",
"code": "context_length_exceeded",
"metadata": {
"provider": "openai",
"max_tokens": 8192,
"actual_tokens": 9500,
"trace_id": "emb-trace-789"
}
}
}"#;
    let upstream = MockHttpClient::new(StatusCode::BAD_REQUEST, upstream_error);
    let router =
        crate::strict::build_strict_router(crate::AppState::with_client(targets, upstream));
    let req_body = r#"{
"model": "text-embedding-ada-002",
"input": "Test text for embedding"
}"#;
    let req = Request::builder()
        .method("POST")
        .uri("/embeddings")
        .header("content-type", "application/json")
        .body(Body::from(req_body))
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    // Trusted-target error pass-through also applies on the /embeddings route.
    assert_eq!(
        parsed["error"]["message"],
        "Embedding service error: Token limit 8192 exceeded for input text",
        "Trusted target should preserve original error message for /v1/embeddings"
    );
    assert_eq!(parsed["error"]["code"], "context_length_exceeded");
    assert!(parsed["error"]["metadata"].is_object());
    assert_eq!(parsed["error"]["metadata"]["max_tokens"], 8192);
    assert_eq!(parsed["error"]["metadata"]["actual_tokens"], 9500);
    assert_eq!(parsed["error"]["metadata"]["trace_id"], "emb-trace-789");
}
#[tokio::test]
async fn test_trusted_target_sanitizes_success_responses_api() {
    use crate::load_balancer::{Provider, ProviderPool};
    use crate::target::{LoadBalanceStrategy, Target, Targets};
    use crate::test_utils::MockHttpClient;
    use axum::body::Body;
    use axum::http::Request;
    use dashmap::DashMap;
    use std::sync::Arc;
    use tower::ServiceExt;

    // Trusted single-provider pool (the `true` positional argument).
    let provider = Provider::new(
        Target::builder()
            .url("https://api.openai.com".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build(),
        1,
    );
    let pool = ProviderPool::with_config(
        vec![provider],
        None,
        None,
        None,
        None,
        LoadBalanceStrategy::default(),
        true,
        Vec::new(),
    );
    let pool_map = Arc::new(DashMap::new());
    pool_map.insert("gpt-4o-mini".to_string(), pool);
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Successful /responses reply leaking the provider's model name and
    // a provider_metadata object.
    let upstream_body = r#"{
"id": "resp-123",
"object": "response",
"created_at": 1677652288,
"completed_at": 1677652290,
"status": "completed",
"incomplete_details": null,
"model": "gpt-4o-mini-actual-provider",
"previous_response_id": null,
"instructions": null,
"output": [],
"error": null,
"tools": [],
"tool_choice": "auto",
"truncation": "auto",
"parallel_tool_calls": true,
"text": {},
"top_p": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"top_logprobs": 0,
"temperature": 1.0,
"reasoning": null,
"usage": null,
"max_output_tokens": null,
"max_tool_calls": null,
"store": false,
"background": false,
"service_tier": "default",
"metadata": null,
"safety_identifier": null,
"prompt_cache_key": null,
"provider_metadata": {
"cost": 0.002,
"trace_id": "trace-responses-456"
}
}"#;
    let upstream = MockHttpClient::new(StatusCode::OK, upstream_body);
    let router =
        crate::strict::build_strict_router(crate::AppState::with_client(targets, upstream));
    let req = Request::builder()
        .method("POST")
        .uri("/responses")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"model":"gpt-4o-mini","input":"Test message"}"#))
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    // Trust only bypasses ERROR sanitization; /responses successes are scrubbed.
    assert_eq!(
        parsed["model"], "gpt-4o-mini",
        "Trusted target should still sanitize /v1/responses success - model rewritten"
    );
    assert!(
        parsed.get("provider_metadata").is_none(),
        "Trusted target should still sanitize /v1/responses success - metadata removed"
    );
    assert_eq!(parsed["object"], "response");
    assert_eq!(parsed["status"], "completed");
    assert!(parsed["output"].is_array());
}
#[tokio::test]
async fn test_trusted_target_sanitizes_success_embeddings() {
    use crate::load_balancer::{Provider, ProviderPool};
    use crate::target::{LoadBalanceStrategy, Target, Targets};
    use crate::test_utils::MockHttpClient;
    use axum::body::Body;
    use axum::http::Request;
    use dashmap::DashMap;
    use std::sync::Arc;
    use tower::ServiceExt;

    // Trusted single-provider pool (the `true` positional argument).
    let provider = Provider::new(
        Target::builder()
            .url("https://api.openai.com".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build(),
        1,
    );
    let pool = ProviderPool::with_config(
        vec![provider],
        None,
        None,
        None,
        None,
        LoadBalanceStrategy::default(),
        true,
        Vec::new(),
    );
    let pool_map = Arc::new(DashMap::new());
    pool_map.insert("text-embedding-ada-002".to_string(), pool);
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Successful embeddings reply leaking the provider's model name and
    // a provider_metadata object.
    let upstream_body = r#"{
"object": "list",
"data": [
{
"object": "embedding",
"embedding": [0.1, 0.2, 0.3],
"index": 0
}
],
"model": "text-embedding-ada-002-actual-provider",
"usage": {
"prompt_tokens": 8,
"total_tokens": 8
},
"provider_metadata": {
"cost": 0.0001,
"trace_id": "trace-emb-789",
"region": "us-east-1"
}
}"#;
    let upstream = MockHttpClient::new(StatusCode::OK, upstream_body);
    let router =
        crate::strict::build_strict_router(crate::AppState::with_client(targets, upstream));
    let req_body = r#"{
"model": "text-embedding-ada-002",
"input": "Test text for embedding"
}"#;
    let req = Request::builder()
        .method("POST")
        .uri("/embeddings")
        .header("content-type", "application/json")
        .body(Body::from(req_body))
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    // Trust only bypasses ERROR sanitization; embeddings successes are scrubbed.
    assert_eq!(
        parsed["model"], "text-embedding-ada-002",
        "Trusted target should still sanitize /v1/embeddings success - model rewritten"
    );
    assert!(
        parsed.get("provider_metadata").is_none(),
        "Trusted target should still sanitize /v1/embeddings success - metadata removed"
    );
    assert_eq!(parsed["object"], "list");
    assert_eq!(parsed["data"][0]["embedding"][0], 0.1);
    assert_eq!(parsed["usage"]["total_tokens"], 8);
}
#[tokio::test]
async fn test_responses_sanitization_updates_content_length_header() {
    // Route the model to a single mock target.
    let pool_map = Arc::new(DashMap::new());
    pool_map.insert(
        "gpt-4o".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Upstream /responses reply with extra provider-only fields; the
    // content-length header matches the ORIGINAL (unsanitized) body.
    let upstream_body = r#"{
"id": "resp_123",
"object": "response",
"created_at": 1234567890,
"completed_at": 1234567900,
"status": "completed",
"incomplete_details": null,
"model": "provider-model",
"previous_response_id": null,
"instructions": null,
"output": [],
"error": null,
"tools": [],
"tool_choice": "auto",
"truncation": "disabled",
"parallel_tool_calls": true,
"text": {
"format": {
"type": "text"
}
},
"top_p": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"top_logprobs": 0,
"temperature": 1.0,
"reasoning": null,
"usage": null,
"max_output_tokens": null,
"max_tool_calls": null,
"store": false,
"background": false,
"service_tier": "default",
"metadata": null,
"safety_identifier": null,
"prompt_cache_key": null,
"provider_trace_id": "xyz-123",
"internal_cost": 0.456,
"custom_field": "should_be_removed"
}"#;
    let mut upstream = MockHttpClient::new(StatusCode::OK, upstream_body);
    upstream.set_header("content-length", upstream_body.len().to_string());
    let router = crate::strict::build_strict_router(AppState::with_client(targets, upstream));
    let req = Request::builder()
        .method("POST")
        .uri("/responses")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"model":"gpt-4o","input":"test"}"#))
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    // Sanitization rewrites the model name and drops unknown provider fields.
    assert_eq!(parsed["model"], "gpt-4o");
    assert!(parsed.get("provider_trace_id").is_none());
    assert!(parsed.get("internal_cost").is_none());
    assert!(parsed.get("custom_field").is_none());
}
#[tokio::test]
async fn test_streaming_multiline_sse_events_are_sanitized() {
    // Route the model to a single mock target.
    let pool_map = Arc::new(DashMap::new());
    pool_map.insert(
        "gpt-4".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // One SSE chunk carrying provider/cost fields that must be stripped.
    let chunk1 = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"provider-model","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}],"provider":"leaked-provider","cost":0.001}
"#;
    let upstream = MockHttpClient::new_streaming(StatusCode::OK, vec![chunk1.to_string()]);
    let router = crate::strict::build_strict_router(AppState::with_client(targets, upstream));
    let req = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-4","messages":[{"role":"user","content":"test"}],"stream":true}"#,
        ))
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let sse_text = String::from_utf8(bytes.to_vec()).unwrap();
    // Model is rewritten; provider/cost never appear anywhere in the stream.
    assert!(sse_text.contains("\"model\":\"gpt-4\""));
    assert!(
        !sse_text.contains("provider"),
        "Provider field should be removed"
    );
    assert!(!sse_text.contains("cost"), "Cost field should be removed");
    assert!(
        !sse_text.contains("leaked-provider"),
        "Provider name should not leak"
    );
}
#[tokio::test]
async fn test_streaming_sse_comments_dont_leak_metadata() {
    // Route the model to a single mock target.
    let pool_map = Arc::new(DashMap::new());
    pool_map.insert(
        "gpt-4".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // First chunk is an SSE comment line (": ..." prefix) with provider
    // metadata; the second is a normal data event.
    let sse_chunks = vec![
        ": provider=custom-llm cost=0.001 trace_id=xyz-123\n".to_string(),
        r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"provider-model","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
"#
        .to_string(),
    ];
    let upstream = MockHttpClient::new_streaming(StatusCode::OK, sse_chunks);
    let router = crate::strict::build_strict_router(AppState::with_client(targets, upstream));
    let req = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-4","messages":[{"role":"user","content":"test"}],"stream":true}"#,
        ))
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let sse_text = String::from_utf8(bytes.to_vec()).unwrap();
    // The comment line must be dropped entirely while data events survive.
    assert!(
        !sse_text.contains("provider=custom-llm"),
        "Provider metadata in comments should be stripped"
    );
    assert!(
        !sse_text.contains("trace_id=xyz-123"),
        "Trace ID should be stripped"
    );
    assert!(!sse_text.contains("cost=0.001"), "Cost should be stripped");
    assert!(sse_text.contains("\"model\":\"gpt-4\""));
}
#[tokio::test]
async fn test_responses_api_omitted_fields_not_injected_as_null() {
    // Route the model to a single mock target.
    let pool_map = Arc::new(DashMap::new());
    pool_map.insert(
        "gpt-4o".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Minimal valid /responses reply; this test only inspects the request
    // that gets forwarded upstream, not this body.
    let upstream_body = r#"{
"id": "resp_123",
"object": "response",
"created_at": 1234567890,
"completed_at": 1234567900,
"status": "completed",
"incomplete_details": null,
"model": "gpt-4o",
"previous_response_id": null,
"instructions": null,
"output": [],
"error": null,
"tools": [],
"tool_choice": "auto",
"truncation": "disabled",
"parallel_tool_calls": true,
"text": { "format": { "type": "text" } },
"top_p": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"top_logprobs": 0,
"temperature": 1.0,
"reasoning": null,
"usage": null,
"max_output_tokens": null,
"max_tool_calls": null,
"store": false,
"background": false,
"service_tier": "default",
"metadata": null,
"safety_identifier": null,
"prompt_cache_key": null
}"#;
    // Keep a clone so the recorded outbound requests can be read afterwards.
    let upstream = MockHttpClient::new(StatusCode::OK, upstream_body);
    let router =
        crate::strict::build_strict_router(AppState::with_client(targets, upstream.clone()));
    // The message item deliberately omits optional `id` and `status` fields.
    let req_body = r#"{
"model": "gpt-4o",
"input": [{
"type": "message",
"role": "user",
"content": "Hello"
}]
}"#;
    let req = Request::builder()
        .method("POST")
        .uri("/responses")
        .header("content-type", "application/json")
        .body(Body::from(req_body))
        .unwrap();
    let _resp = router.oneshot(req).await.unwrap();
    let recorded = upstream.get_requests();
    assert_eq!(recorded.len(), 1);
    let forwarded: serde_json::Value =
        serde_json::from_str(&String::from_utf8(recorded[0].body.clone()).unwrap()).unwrap();
    let items = forwarded["input"].as_array().unwrap();
    assert_eq!(items.len(), 1);
    let message_item = &items[0];
    // Round-tripping through the request schema must not materialize
    // omitted optional fields as explicit nulls.
    assert!(
        !message_item.as_object().unwrap().contains_key("id"),
        "Optional 'id' field should not be present when omitted in request, found: {:?}",
        message_item
    );
    assert!(
        !message_item.as_object().unwrap().contains_key("status"),
        "Optional 'status' field should not be present when omitted in request, found: {:?}",
        message_item
    );
}
#[tokio::test]
async fn test_responses_api_unknown_item_types_preserved() {
    // Route the model to a single mock target.
    let pool_map = Arc::new(DashMap::new());
    pool_map.insert(
        "gpt-4o".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Minimal valid /responses reply; this test only inspects the request
    // that gets forwarded upstream, not this body.
    let upstream_body = r#"{
"id": "resp_123",
"object": "response",
"created_at": 1234567890,
"completed_at": 1234567900,
"status": "completed",
"incomplete_details": null,
"model": "gpt-4o",
"previous_response_id": null,
"instructions": null,
"output": [],
"error": null,
"tools": [],
"tool_choice": "auto",
"truncation": "disabled",
"parallel_tool_calls": true,
"text": { "format": { "type": "text" } },
"top_p": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"top_logprobs": 0,
"temperature": 1.0,
"reasoning": null,
"usage": null,
"max_output_tokens": null,
"max_tool_calls": null,
"store": false,
"background": false,
"service_tier": "default",
"metadata": null,
"safety_identifier": null,
"prompt_cache_key": null
}"#;
    // Keep a clone so the recorded outbound requests can be read afterwards.
    let upstream = MockHttpClient::new(StatusCode::OK, upstream_body);
    let router =
        crate::strict::build_strict_router(AppState::with_client(targets, upstream.clone()));
    // An input item whose `type` is not one of the modeled variants.
    let req_body = r#"{
"model": "gpt-4o",
"input": [{
"type": "web_search",
"query": "latest news",
"max_results": 10,
"custom_field": "should be preserved"
}]
}"#;
    let req = Request::builder()
        .method("POST")
        .uri("/responses")
        .header("content-type", "application/json")
        .body(Body::from(req_body))
        .unwrap();
    let _resp = router.oneshot(req).await.unwrap();
    let recorded = upstream.get_requests();
    assert_eq!(recorded.len(), 1);
    let forwarded: serde_json::Value =
        serde_json::from_str(&String::from_utf8(recorded[0].body.clone()).unwrap()).unwrap();
    let items = forwarded["input"].as_array().unwrap();
    assert_eq!(items.len(), 1);
    let unknown_item = &items[0];
    // Unknown item types must be forwarded verbatim, fields and all.
    assert_eq!(unknown_item["type"], "web_search");
    assert_eq!(unknown_item["query"], "latest news");
    assert_eq!(unknown_item["max_results"], 10);
    assert_eq!(
        unknown_item["custom_field"], "should be preserved",
        "Unknown item fields should be preserved, but got: {:?}",
        unknown_item
    );
}
#[tokio::test]
async fn test_strict_mode_ignores_target_sanitize_response_flag() {
    use crate::target::{Target, Targets};
    use crate::test_utils::MockHttpClient;
    use axum::body::Body;
    use axum::http::Request;
    use dashmap::DashMap;
    use std::sync::Arc;
    use tower::ServiceExt;

    // Target explicitly sets sanitize_response(true); strict mode must
    // sanitize regardless of this per-target flag.
    let pool_map = Arc::new(DashMap::new());
    pool_map.insert(
        "gpt-4".to_string(),
        Target::builder()
            .url("https://api.openai.com".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .sanitize_response(true)
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Upstream reply leaking the provider's model name and metadata.
    let upstream_body = r#"{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "gpt-4-actual-provider-model",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello!"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
},
"provider_metadata": {
"cost": 0.001,
"trace_id": "trace-123"
}
}"#;
    let upstream = MockHttpClient::new(StatusCode::OK, upstream_body);
    // An app-level response transform is installed alongside strict mode.
    let app_state = crate::AppState::with_client(targets, upstream)
        .with_response_transform(crate::create_openai_sanitizer());
    let router = crate::strict::build_strict_router(app_state);
    let req_body = r#"{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello"}]
}"#;
    let req = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(req_body))
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    assert_eq!(parsed["model"], "gpt-4");
    assert!(
        parsed.get("provider_metadata").is_none(),
        "Provider metadata should be removed by strict mode sanitization"
    );
    assert_eq!(parsed["object"], "chat.completion");
    assert_eq!(parsed["choices"][0]["message"]["content"], "Hello!");
}
#[tokio::test]
async fn test_untrusted_target_still_sanitized() {
    use crate::target::{Target, Targets};
    use crate::test_utils::MockHttpClient;
    use axum::body::Body;
    use axum::http::Request;
    use dashmap::DashMap;
    use std::sync::Arc;
    use tower::ServiceExt;

    // Plain target built via into_pool() — no trust flag involved.
    let pool_map = Arc::new(DashMap::new());
    pool_map.insert(
        "third-party".to_string(),
        Target::builder()
            .url("https://third-party.com".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets: pool_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Upstream reply leaking the provider's internal model name and metadata.
    let upstream_body = r#"{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "provider-internal-model",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello!"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
},
"provider_metadata": {
"should": "be removed"
}
}"#;
    let upstream = MockHttpClient::new(StatusCode::OK, upstream_body);
    let router =
        crate::strict::build_strict_router(crate::AppState::with_client(targets, upstream));
    let req_body = r#"{
"model": "third-party",
"messages": [{"role": "user", "content": "Hello"}]
}"#;
    let req = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(req_body))
        .unwrap();
    let resp = router.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    assert_eq!(
        parsed["model"], "third-party",
        "Untrusted target should have model field rewritten"
    );
    assert!(
        parsed.get("provider_metadata").is_none(),
        "Untrusted target should have provider metadata removed"
    );
    assert_eq!(parsed["choices"][0]["message"]["content"], "Hello!");
}
// Trust-bypass attempt: the request body names the trusted pool, but a
// `model-override` header redirects routing to the untrusted pool. Strict
// mode must still apply untrusted-style sanitization to the response.
#[tokio::test]
async fn test_model_override_header_prevents_trust_bypass() {
    use crate::load_balancer::{Provider, ProviderPool};
    use crate::target::{LoadBalanceStrategy, Target, Targets};
    use crate::test_utils::MockHttpClient;
    use axum::body::Body;
    use axum::http::Request;
    use dashmap::DashMap;
    use std::sync::Arc;
    use tower::ServiceExt;
    let targets_map = Arc::new(DashMap::new());
    // Pool with the trusted flag set to `true` in `with_config`.
    let trusted_pool = ProviderPool::with_config(
        vec![Provider::new(
            Target::builder()
                .url("https://trusted.com".parse().unwrap())
                .onwards_key("sk-trusted".to_string())
                .build(),
            1,
        )],
        None,
        None,
        None,
        None,
        LoadBalanceStrategy::default(),
        true,
        Vec::new(),
    );
    targets_map.insert("trusted-pool".to_string(), trusted_pool);
    // Pool with the trusted flag set to `false`.
    let untrusted_pool = ProviderPool::with_config(
        vec![Provider::new(
            Target::builder()
                .url("https://untrusted.com".parse().unwrap())
                .onwards_key("sk-untrusted".to_string())
                .build(),
            1,
        )],
        None,
        None,
        None,
        None,
        LoadBalanceStrategy::default(),
        false,
        Vec::new(),
    );
    targets_map.insert("untrusted-pool".to_string(), untrusted_pool);
    let targets = Targets {
        targets: targets_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Upstream reply simulating a provider that leaks internal metadata.
    let mock_response = r#"{
"id": "chatcmpl-bypass-attempt",
"object": "chat.completion",
"created": 1234567890,
"model": "untrusted-internal-model",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Response from untrusted provider"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
},
"untrusted_metadata": {
"should_be_removed": "yes",
"cost": "$0.001",
"trace_id": "leak-attempt-123"
}
}"#;
    let mock_client = MockHttpClient::new(StatusCode::OK, mock_response);
    let state = crate::AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    // Body requests the trusted pool...
    let request_body = r#"{
"model": "trusted-pool",
"messages": [{"role": "user", "content": "Try to bypass"}]
}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        // ...but the override header points at the untrusted pool.
        .header("model-override", "untrusted-pool")
        .body(Body::from(request_body))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
    assert!(
        response_json.get("untrusted_metadata").is_none(),
        "SECURITY: Untrusted metadata MUST be removed - bypass attempt prevented"
    );
    assert_ne!(
        response_json["model"], "untrusted-internal-model",
        "Model should not be provider's internal model name"
    );
    // Message content itself is preserved.
    assert_eq!(
        response_json["choices"][0]["message"]["content"],
        "Response from untrusted provider"
    );
}
// Even for a trusted pool, successful streaming responses are sanitized:
// the model in each SSE chunk is rewritten to the pool name and unknown
// fields (`provider_cost`) are stripped.
#[tokio::test]
async fn test_trusted_streaming_success_responses_still_sanitized() {
    use crate::load_balancer::{Provider, ProviderPool};
    use crate::target::{LoadBalanceStrategy, Target, Targets};
    use crate::test_utils::MockHttpClient;
    use axum::body::Body;
    use axum::http::Request;
    use dashmap::DashMap;
    use std::sync::Arc;
    use tower::ServiceExt;
    let targets_map = Arc::new(DashMap::new());
    // Pool with the trusted flag set to `true` in `with_config`.
    let pool = ProviderPool::with_config(
        vec![Provider::new(
            Target::builder()
                .url("https://api.openai.com".parse().unwrap())
                .onwards_key("sk-test".to_string())
                .build(),
            1,
        )],
        None,
        None,
        None,
        None,
        LoadBalanceStrategy::default(),
        true,
        Vec::new(),
    );
    targets_map.insert("gpt-4".to_string(), pool);
    let targets = Targets {
        targets: targets_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // One SSE data chunk (with provider model + `provider_cost`) then [DONE].
    let mock_stream = "data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4-provider\",\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"index\":0,\"finish_reason\":null}],\"provider_cost\":0.001}\n\ndata: [DONE]\n\n";
    let mock_client = MockHttpClient::new(StatusCode::OK, mock_stream);
    let state = crate::AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    let request_body = r#"{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello"}],
"stream": true
}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(request_body))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    // Collect the whole SSE stream and check it as text.
    let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let response_str = String::from_utf8(body_bytes.to_vec()).unwrap();
    assert!(
        response_str.contains("\"model\":\"gpt-4\""),
        "Trusted pool streaming success should still sanitize - model rewritten"
    );
    assert!(
        !response_str.contains("\"model\":\"gpt-4-provider\""),
        "Original provider model should be replaced"
    );
    assert!(
        !response_str.contains("\"provider_cost\""),
        "Trusted pool streaming success should still sanitize - metadata removed"
    );
}
// A provider-level `trusted(true)` wins over an untrusted pool flag: error
// responses from that provider pass through unsanitized.
#[tokio::test]
async fn test_provider_trusted_overrides_untrusted_pool() {
    use crate::load_balancer::{Provider, ProviderPool};
    use crate::target::{LoadBalanceStrategy, Target, Targets};
    use crate::test_utils::MockHttpClient;
    use axum::body::Body;
    use axum::http::Request;
    use dashmap::DashMap;
    use std::sync::Arc;
    use tower::ServiceExt;
    let targets_map = Arc::new(DashMap::new());
    // Target explicitly marked trusted, inside a pool whose trusted flag
    // (the `with_config` bool argument) is false.
    let pool = ProviderPool::with_config(
        vec![Provider::new(
            Target::builder()
                .url("https://api.example.com".parse().unwrap())
                .trusted(true)
                .build(),
            1,
        )],
        None,
        None,
        None,
        None,
        LoadBalanceStrategy::default(),
        false,
        Vec::new(),
    );
    targets_map.insert("gpt-4".to_string(), pool);
    let targets = Targets {
        targets: targets_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Upstream 429 with provider-specific fields that would normally be stripped.
    let error_body = r#"{"error": {"message": "provider specific error", "type": "rate_limit_error", "provider_trace": "trace-xyz"}}"#;
    let mock_client = MockHttpClient::new(StatusCode::TOO_MANY_REQUESTS, error_body);
    let state = AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    let request_body = r#"{"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(request_body))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = std::str::from_utf8(&body_bytes).unwrap();
    assert!(
        body_str.contains("provider_trace"),
        "Provider-specific fields should pass through for trusted provider (even with untrusted pool)"
    );
    assert!(
        body_str.contains("provider specific error"),
        "Original provider message should pass through for trusted provider"
    );
}
// The mirror case of the test above: a provider-level `trusted(false)` wins
// over a trusted pool flag, so error responses are replaced with a
// sanitized standard message.
#[tokio::test]
async fn test_provider_untrusted_overrides_trusted_pool() {
    use crate::load_balancer::{Provider, ProviderPool};
    use crate::target::{LoadBalanceStrategy, Target, Targets};
    use crate::test_utils::MockHttpClient;
    use axum::body::Body;
    use axum::http::Request;
    use dashmap::DashMap;
    use std::sync::Arc;
    use tower::ServiceExt;
    let targets_map = Arc::new(DashMap::new());
    // Target explicitly marked untrusted, inside a pool whose trusted flag
    // (the `with_config` bool argument) is true.
    let pool = ProviderPool::with_config(
        vec![Provider::new(
            Target::builder()
                .url("https://api.example.com".parse().unwrap())
                .trusted(false)
                .build(),
            1,
        )],
        None,
        None,
        None,
        None,
        LoadBalanceStrategy::default(),
        true,
        Vec::new(),
    );
    targets_map.insert("gpt-4".to_string(), pool);
    let targets = Targets {
        targets: targets_map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let error_body = r#"{"error": {"message": "provider specific error", "type": "rate_limit_error", "provider_trace": "trace-xyz"}}"#;
    let mock_client = MockHttpClient::new(StatusCode::TOO_MANY_REQUESTS, error_body);
    let state = AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    let request_body = r#"{"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]}"#;
    let request = Request::builder()
        .method("POST")
        .uri("/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from(request_body))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    // Status code is preserved even though the body is replaced.
    assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = std::str::from_utf8(&body_bytes).unwrap();
    assert!(
        !body_str.contains("provider_trace"),
        "Provider-specific fields should be stripped for explicitly untrusted provider (even with trusted pool)"
    );
    assert!(
        !body_str.contains("provider specific error"),
        "Original provider message should be stripped for explicitly untrusted provider"
    );
    assert!(
        body_str.contains("Rate limit exceeded"),
        "Sanitized standard message should be present"
    );
}
/// Builds a canned OpenAI-style `text_completion` response body with `model`
/// interpolated into the `model` field.
///
/// NOTE: callers compare the returned bytes (e.g. against `content-length`
/// after sanitization), so the template text must not be reformatted.
fn completions_mock_response(model: &str) -> String {
    format!(
        r#"{{
"id": "cmpl-abc123",
"object": "text_completion",
"created": 1677652288,
"model": "{model}",
"choices": [{{
"text": "Hello, world!",
"index": 0,
"logprobs": null,
"finish_reason": "stop"
}}],
"usage": {{
"prompt_tokens": 5,
"completion_tokens": 7,
"total_tokens": 12
}}
}}"#
    )
}
/// Builds a strict-mode `Targets` table holding a single pool for `model`
/// that points at the OpenAI base URL with a test key.
fn completions_test_targets(model: &str) -> Targets {
    let map = Arc::new(DashMap::new());
    let pool = Target::builder()
        .url("https://api.openai.com/v1/".parse().unwrap())
        .onwards_key("sk-test".to_string())
        .build()
        .into_pool();
    map.insert(model.to_string(), pool);
    Targets {
        targets: map,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    }
}
/// A `/completions` request that omits `prompt` is still accepted and the
/// upstream 200 is returned to the caller.
#[tokio::test]
async fn test_completions_accepts_missing_prompt() {
    let client = MockHttpClient::new(
        StatusCode::OK,
        &completions_mock_response("gpt-3.5-turbo-instruct"),
    );
    let state = AppState::with_client(completions_test_targets("gpt-3.5-turbo-instruct"), client);
    let app = crate::strict::build_strict_router(state);
    let req = Request::builder()
        .uri("/completions")
        .method("POST")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"model":"gpt-3.5-turbo-instruct"}"#))
        .unwrap();
    let res = app.oneshot(req).await.unwrap();
    assert_eq!(res.status(), StatusCode::OK);
}
/// A `/completions` request without a `model` field is rejected with
/// 422 Unprocessable Entity.
#[tokio::test]
async fn test_completions_rejects_missing_model() {
    let client = MockHttpClient::new(StatusCode::OK, "{}");
    let state = AppState::with_client(completions_test_targets("gpt-3.5-turbo-instruct"), client);
    let app = crate::strict::build_strict_router(state);
    let req = Request::builder()
        .uri("/completions")
        .method("POST")
        .header("content-type", "application/json")
        .body(Body::from(r#"{"prompt":"Say hello"}"#))
        .unwrap();
    let res = app.oneshot(req).await.unwrap();
    assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
}
// Routing check: `/completions` traffic must be forwarded to the upstream
// completions endpoint, never to `/chat/completions`.
#[tokio::test]
async fn test_completions_forwards_to_completions_endpoint() {
    let mock_client = MockHttpClient::new(
        StatusCode::OK,
        &completions_mock_response("gpt-3.5-turbo-instruct"),
    );
    // Keep a clone of the mock so we can inspect captured requests later.
    let state = AppState::with_client(
        completions_test_targets("gpt-3.5-turbo-instruct"),
        mock_client.clone(),
    );
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Say hello"}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let requests = mock_client.get_requests();
    assert_eq!(requests.len(), 1);
    assert!(
        requests[0].uri.contains("completions"),
        "Request should be forwarded to /completions, got: {}",
        requests[0].uri
    );
    assert!(
        !requests[0].uri.contains("chat"),
        "Request must NOT be forwarded to /chat/completions"
    );
}
// Strict-mode sanitization for non-streaming completions: unknown top-level
// fields (`provider`, `cost`, `internal_id`) are removed while the known
// schema fields survive.
#[tokio::test]
async fn test_strict_sanitize_completions_removes_unknown_fields() {
    let mock_response = r#"{
"id": "cmpl-abc123",
"object": "text_completion",
"created": 1677652288,
"model": "gpt-3.5-turbo-instruct",
"choices": [{"text": "Hello!", "index": 0, "logprobs": null, "finish_reason": "stop"}],
"usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12},
"provider": "custom-provider",
"cost": 0.0001,
"internal_id": "xyz-456"
}"#;
    let mock_client = MockHttpClient::new(StatusCode::OK, mock_response);
    let state = AppState::with_client(
        completions_test_targets("gpt-3.5-turbo-instruct"),
        mock_client,
    );
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Say hello"}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    // Known fields are preserved...
    assert!(body_str.contains("\"object\":\"text_completion\""));
    assert!(body_str.contains("\"id\":\"cmpl-abc123\""));
    assert!(body_str.contains("\"choices\""));
    assert!(body_str.contains("\"text\":\"Hello!\""));
    // ...unknown provider fields are stripped.
    assert!(!body_str.contains("\"provider\""));
    assert!(!body_str.contains("\"cost\""));
    assert!(!body_str.contains("\"internal_id\""));
}
// When a target maps the public model name to an internal `onwards_model`,
// sanitization must rewrite the upstream's internal model name back to the
// public one in the response.
#[tokio::test]
async fn test_strict_sanitize_completions_rewrites_model() {
    let targets = Arc::new(DashMap::new());
    targets.insert(
        "gpt-3.5-turbo-instruct".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            // Upstream sees this internal name instead of the public one.
            .onwards_model("gpt-3.5-turbo-instruct-internal-v2".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Upstream answers with the internal model name.
    let mock_response = completions_mock_response("gpt-3.5-turbo-instruct-internal-v2");
    let mock_client = MockHttpClient::new(StatusCode::OK, &mock_response);
    let state = AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Hello"}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(body_str.contains("\"model\":\"gpt-3.5-turbo-instruct\""));
    // The internal suffix must not leak anywhere in the body.
    assert!(!body_str.contains("internal-v2"));
}
// A minimal upstream body (only `choices`) is still usable: sanitization
// backfills `model`, `object`, a generated `cmpl-` id, and choice indexes.
#[tokio::test]
async fn test_strict_sanitize_completions_backfills_missing_noncritical_fields() {
    let mock_response = r#"{
"choices": [{"text": "Hello!", "finish_reason": "stop"}]
}"#;
    let mock_client = MockHttpClient::new(StatusCode::OK, mock_response);
    let state = AppState::with_client(
        completions_test_targets("gpt-3.5-turbo-instruct"),
        mock_client,
    );
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Say hello"}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(json["model"], "gpt-3.5-turbo-instruct");
    assert_eq!(json["object"], "text_completion");
    // The id is synthesized; only its prefix is stable.
    let id = json["id"].as_str().expect("id should be a string");
    assert!(id.starts_with("cmpl-"));
    assert_eq!(json["choices"][0]["index"], 0);
    assert_eq!(json["choices"][0]["text"], "Hello!");
}
/// When upstream choices omit `index`, sanitization assigns each choice an
/// index equal to its position in the array.
#[tokio::test]
async fn test_strict_sanitize_completions_backfills_choice_indexes_from_position() {
    let upstream_body = r#"{
"choices": [{"text": "first"}, {"text": "second"}]
}"#;
    let client = MockHttpClient::new(StatusCode::OK, upstream_body);
    let state = AppState::with_client(completions_test_targets("gpt-3.5-turbo-instruct"), client);
    let app = crate::strict::build_strict_router(state);
    let req = Request::builder()
        .uri("/completions")
        .method("POST")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Say hello"}"#,
        ))
        .unwrap();
    let res = app.oneshot(req).await.unwrap();
    let bytes = axum::body::to_bytes(res.into_body(), usize::MAX)
        .await
        .unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    assert_eq!(parsed["choices"][0]["index"], 0);
    assert_eq!(parsed["choices"][1]["index"], 1);
}
// A choice with no `text` at all cannot be repaired: instead of inventing
// content, the proxy returns 502 with a standard `api_error` body.
#[tokio::test]
async fn test_strict_sanitize_completions_does_not_backfill_generated_text() {
    let mock_response = r#"{
"choices": [{"finish_reason": "stop"}]
}"#;
    let mock_client = MockHttpClient::new(StatusCode::OK, mock_response);
    let state = AppState::with_client(
        completions_test_targets("gpt-3.5-turbo-instruct"),
        mock_client,
    );
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Say hello"}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    // Upstream said 200 but the payload is unusable, so the proxy reports 502.
    assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(json["error"]["type"], "api_error");
}
// Sanitization shrinks the body, so the `content-length` header set by the
// upstream must be recomputed to match the sanitized payload.
#[tokio::test]
async fn test_completions_sanitization_updates_content_length_header() {
    let mock_response = r#"{
"id": "cmpl-abc123",
"object": "text_completion",
"created": 1677652288,
"model": "gpt-3.5-turbo-instruct",
"choices": [{"text": "Hi", "index": 0, "logprobs": null, "finish_reason": "stop"}],
"usage": {"prompt_tokens": 5, "completion_tokens": 1, "total_tokens": 6},
"provider_metadata": "extra_data_that_will_be_stripped",
"cost": 0.0001,
"processing_time_ms": 123
}"#;
    let mut mock_client = MockHttpClient::new(StatusCode::OK, mock_response);
    // Simulate an upstream that sets content-length for the unsanitized body.
    mock_client.set_header("content-length", mock_response.len().to_string());
    let state = AppState::with_client(
        completions_test_targets("gpt-3.5-turbo-instruct"),
        mock_client,
    );
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Hi"}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    // Read the header before consuming the body.
    let content_length: usize = response
        .headers()
        .get("content-length")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.parse().ok())
        .expect("content-length header should be present");
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
    // Header must match the actual (sanitized, smaller) body.
    assert_eq!(content_length, body.len());
    assert!(
        content_length < mock_response.len(),
        "sanitized body should be smaller"
    );
    assert_eq!(body_json["model"], "gpt-3.5-turbo-instruct");
    assert!(body_json.get("provider_metadata").is_none());
    assert!(body_json.get("cost").is_none());
}
/// `prompt` may be a JSON array; the proxy accepts it and forwards the array
/// form unchanged to the upstream request body.
#[tokio::test]
async fn test_completions_array_prompt_accepted() {
    let client = MockHttpClient::new(
        StatusCode::OK,
        &completions_mock_response("gpt-3.5-turbo-instruct"),
    );
    // Clone the mock so we can inspect the captured upstream request below.
    let state = AppState::with_client(
        completions_test_targets("gpt-3.5-turbo-instruct"),
        client.clone(),
    );
    let app = crate::strict::build_strict_router(state);
    let req = Request::builder()
        .uri("/completions")
        .method("POST")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":["Hello","World"]}"#,
        ))
        .unwrap();
    let res = app.oneshot(req).await.unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    let sent = client.get_requests();
    let forwarded: serde_json::Value = serde_json::from_slice(&sent[0].body).unwrap();
    assert!(forwarded["prompt"].is_array());
}
// Streaming variant of unknown-field stripping: SSE completion chunks lose
// `provider` / `cost` while known fields and the [DONE] marker survive.
#[tokio::test]
async fn test_strict_sanitize_completions_streaming_removes_unknown_fields() {
    let mock_client = MockHttpClient::new_streaming(
        StatusCode::OK,
        vec![
            "data: {\"id\":\"cmpl-abc\",\"object\":\"text_completion\",\"created\":1677652288,\"model\":\"gpt-3.5-turbo-instruct\",\"choices\":[{\"text\":\"Hello\",\"index\":0,\"logprobs\":null,\"finish_reason\":null}],\"provider\":\"custom\",\"cost\":0.001}\n\n".to_string(),
            "data: [DONE]\n\n".to_string(),
        ],
    );
    let state = AppState::with_client(
        completions_test_targets("gpt-3.5-turbo-instruct"),
        mock_client,
    );
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Hello","stream":true}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(body_str.contains("\"object\":\"text_completion\""));
    assert!(body_str.contains("\"text\":\"Hello\""));
    assert!(body_str.contains("[DONE]"));
    assert!(!body_str.contains("\"provider\""));
    assert!(!body_str.contains("\"cost\""));
}
// Streaming variant of model rewriting: every chunk carrying the internal
// `onwards_model` name ("...-0914") is rewritten to the public model name.
#[tokio::test]
async fn test_strict_sanitize_completions_streaming_rewrites_model() {
    let targets = Arc::new(DashMap::new());
    targets.insert(
        "gpt-3.5-turbo-instruct".to_string(),
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .onwards_model("gpt-3.5-turbo-instruct-0914".to_string())
            .build()
            .into_pool(),
    );
    let targets = Targets {
        targets,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Two content chunks (both tagged with the internal model) then [DONE].
    let mock_client = MockHttpClient::new_streaming(
        StatusCode::OK,
        vec![
            "data: {\"id\":\"cmpl-abc\",\"object\":\"text_completion\",\"created\":1677652288,\"model\":\"gpt-3.5-turbo-instruct-0914\",\"choices\":[{\"text\":\"Hi\",\"index\":0,\"logprobs\":null,\"finish_reason\":null}]}\n\n".to_string(),
            "data: {\"id\":\"cmpl-abc\",\"object\":\"text_completion\",\"created\":1677652288,\"model\":\"gpt-3.5-turbo-instruct-0914\",\"choices\":[{\"text\":\"\",\"index\":0,\"logprobs\":null,\"finish_reason\":\"stop\"}]}\n\n".to_string(),
            "data: [DONE]\n\n".to_string(),
        ],
    );
    let state = AppState::with_client(targets, mock_client);
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Hello","stream":true}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(body_str.contains("\"model\":\"gpt-3.5-turbo-instruct\""));
    // The internal suffix must not appear anywhere in the stream.
    assert!(!body_str.contains("0914"));
}
// Streaming backfill: a bare chunk containing only `choices` gets `object`,
// `model`, `index`, and a generated `cmpl-` id filled in.
#[tokio::test]
async fn test_strict_sanitize_completions_streaming_backfills_missing_noncritical_fields() {
    let mock_client = MockHttpClient::new_streaming(
        StatusCode::OK,
        vec![
            "data: {\"choices\":[{\"text\":\"Hello\"}]}\n\n".to_string(),
            "data: [DONE]\n\n".to_string(),
        ],
    );
    let state = AppState::with_client(
        completions_test_targets("gpt-3.5-turbo-instruct"),
        mock_client,
    );
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Hello","stream":true}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(body_str.contains("\"object\":\"text_completion\""));
    assert!(body_str.contains("\"model\":\"gpt-3.5-turbo-instruct\""));
    assert!(body_str.contains("\"index\":0"));
    assert!(body_str.contains("\"text\":\"Hello\""));
    // Only the generated id's prefix is deterministic.
    assert!(body_str.contains("\"id\":\"cmpl-"));
    assert!(body_str.contains("[DONE]"));
}
// Unlike the non-streaming case (which 502s on a missing `text`), a streaming
// chunk without `text` gets an empty-string `text` backfilled.
#[tokio::test]
async fn test_strict_sanitize_completions_streaming_backfills_missing_chunk_text() {
    let mock_client = MockHttpClient::new_streaming(
        StatusCode::OK,
        vec![
            "data: {\"choices\":[{\"finish_reason\":null}]}\n\n".to_string(),
            "data: [DONE]\n\n".to_string(),
        ],
    );
    let state = AppState::with_client(
        completions_test_targets("gpt-3.5-turbo-instruct"),
        mock_client,
    );
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Hello","stream":true}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(body_str.contains("\"text\":\"\""));
    assert!(body_str.contains("\"index\":0"));
    assert!(body_str.contains("[DONE]"));
}
// Streaming variant of positional index backfill: choices within a chunk
// that lack `index` are numbered by array position.
#[tokio::test]
async fn test_strict_sanitize_completions_streaming_backfills_chunk_indexes_from_position() {
    let mock_client = MockHttpClient::new_streaming(
        StatusCode::OK,
        vec![
            "data: {\"choices\":[{\"text\":\"first\"},{\"text\":\"second\"}]}\n\n".to_string(),
            "data: [DONE]\n\n".to_string(),
        ],
    );
    let state = AppState::with_client(
        completions_test_targets("gpt-3.5-turbo-instruct"),
        mock_client,
    );
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Hello","stream":true}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(body_str.contains("\"index\":0"));
    assert!(body_str.contains("\"index\":1"));
}
// When chunks arrive without an `id`, the sanitizer generates one and must
// reuse the same generated id for every chunk of the stream.
#[tokio::test]
async fn test_strict_sanitize_completions_streaming_reuses_fallback_id_across_chunks() {
    let mock_client = MockHttpClient::new_streaming(
        StatusCode::OK,
        vec![
            "data: {\"choices\":[{\"text\":\"first\"}]}\n\n".to_string(),
            "data: {\"choices\":[{\"text\":\"second\"}]}\n\n".to_string(),
            "data: [DONE]\n\n".to_string(),
        ],
    );
    let state = AppState::with_client(
        completions_test_targets("gpt-3.5-turbo-instruct"),
        mock_client,
    );
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Hello","stream":true}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    // Parse each SSE `data:` payload (except [DONE]) and collect the ids.
    let ids: Vec<String> = body_str
        .lines()
        .filter_map(|line| line.strip_prefix("data: "))
        .filter(|data| *data != "[DONE]")
        .map(|data| serde_json::from_str::<serde_json::Value>(data).unwrap())
        .filter_map(|json| json.get("id").and_then(|v| v.as_str()).map(str::to_string))
        .collect();
    // Both chunks carry an id, and all ids are identical.
    assert!(ids.len() >= 2);
    assert!(ids.iter().all(|id| id == &ids[0]));
}
// A body-transform hook may flip `stream` to true AFTER request validation.
// The proxy must then use the streaming sanitizer (SSE content type, chunk
// sanitization) rather than the non-streaming path.
#[tokio::test]
async fn test_strict_completions_uses_streaming_sanitizer_when_body_transform_enables_streaming_after_validation()
{
    // Transform: for /completions requests carrying the `x-fusillade-stream`
    // header, inject `"stream": true` into the JSON body.
    let transform_fn: crate::BodyTransformFn = Arc::new(|path, headers, body_bytes| {
        if path != "/completions" {
            return None;
        }
        let should_stream = headers
            .get("x-fusillade-stream")
            .and_then(|value| value.to_str().ok())
            .map(|value| value.eq_ignore_ascii_case("true"))
            .unwrap_or(false);
        if !should_stream {
            return None;
        }
        let mut json_body = serde_json::from_slice::<serde_json::Value>(body_bytes).ok()?;
        let obj = json_body.as_object_mut()?;
        obj.insert("stream".to_string(), serde_json::Value::Bool(true));
        serde_json::to_vec(&json_body)
            .ok()
            .map(axum::body::Bytes::from)
    });
    let mock_client = MockHttpClient::new_streaming(
        StatusCode::OK,
        vec![
            "data: {\"id\":\"cmpl-abc\",\"object\":\"text_completion\",\"created\":1677652288,\"model\":\"gpt-3.5-turbo-instruct\",\"choices\":[{\"text\":\"Hello\",\"index\":0,\"logprobs\":null,\"finish_reason\":null}],\"provider\":\"custom-provider\"}\n\n".to_string(),
            "data: [DONE]\n\n".to_string(),
        ],
    );
    let state = AppState::with_client_and_transform(
        completions_test_targets("gpt-3.5-turbo-instruct"),
        mock_client.clone(),
        transform_fn,
    );
    let router = crate::strict::build_strict_router(state);
    // Note: the client body itself does NOT request streaming.
    let request = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .header("x-fusillade-stream", "true")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Hello"}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    // Response must be served as SSE.
    assert_eq!(
        response
            .headers()
            .get(header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok()),
        Some("text/event-stream")
    );
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(body_str.contains("\"text\":\"Hello\""));
    assert!(body_str.contains("[DONE]"));
    // Streaming sanitizer was applied: provider field stripped.
    assert!(!body_str.contains("custom-provider"));
    let requests = mock_client.get_requests();
    assert_eq!(requests.len(), 1);
    // The transformed body (with stream=true) is what went upstream.
    let forwarded_body: serde_json::Value = serde_json::from_slice(&requests[0].body).unwrap();
    assert_eq!(forwarded_body["stream"], true);
}
// Error bodies from an untrusted completions target are sanitized: internal
// details (GPU OOM message, provider name) are removed while an error
// envelope is still returned with the original status.
#[tokio::test]
async fn test_completions_untrusted_error_sanitized() {
    let mock_error = r#"{"error":{"message":"Provider internal error: OOM on GPU 3","code":"oom","provider":"custom-llm"}}"#;
    let mock_client = MockHttpClient::new(StatusCode::INTERNAL_SERVER_ERROR, mock_error);
    let state = AppState::with_client(
        completions_test_targets("gpt-3.5-turbo-instruct"),
        mock_client,
    );
    let router = crate::strict::build_strict_router(state);
    let request = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Hello"}"#,
        ))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    // Status is preserved; only the body is scrubbed.
    assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_str = String::from_utf8(body.to_vec()).unwrap();
    assert!(!body_str.contains("OOM on GPU 3"));
    assert!(!body_str.contains("custom-llm"));
    assert!(body_str.contains("error"));
}
#[tokio::test]
async fn test_completions_trusted_error_passed_through() {
    use crate::load_balancer::{Provider, ProviderPool};
    use crate::target::LoadBalanceStrategy;
    // A single provider pointing at api.openai.com, which the proxy treats
    // as trusted, so upstream error bodies are forwarded verbatim.
    let provider = Provider::new(
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build(),
        1,
    );
    let pool = ProviderPool::with_config(
        vec![provider],
        None,
        None,
        None,
        None,
        LoadBalanceStrategy::default(),
        true,
        Vec::new(),
    );
    let pools = Arc::new(DashMap::new());
    pools.insert("gpt-3.5-turbo-instruct".to_string(), pool);
    let targets = Targets {
        targets: pools,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    let upstream_error = r#"{"error":{"message":"Context length exceeded","type":"invalid_request_error","code":"context_length_exceeded"}}"#;
    let app = crate::strict::build_strict_router(AppState::with_client(
        targets,
        MockHttpClient::new(StatusCode::BAD_REQUEST, upstream_error),
    ));
    let req = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Hello"}"#,
        ))
        .unwrap();
    let resp = app.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    // The trusted provider's error details survive untouched.
    assert_eq!(json["error"]["message"], "Context length exceeded");
    assert_eq!(json["error"]["code"], "context_length_exceeded");
}
#[tokio::test]
async fn test_completions_trusted_success_still_sanitized() {
    use crate::load_balancer::{Provider, ProviderPool};
    use crate::target::LoadBalanceStrategy;
    // Even for a trusted provider, successful responses are normalized:
    // the requested model name is restored and provider-specific extra
    // fields are stripped.
    let provider = Provider::new(
        Target::builder()
            .url("https://api.openai.com/v1/".parse().unwrap())
            .onwards_key("sk-test".to_string())
            .build(),
        1,
    );
    let pool = ProviderPool::with_config(
        vec![provider],
        None,
        None,
        None,
        None,
        LoadBalanceStrategy::default(),
        true,
        Vec::new(),
    );
    let pools = Arc::new(DashMap::new());
    pools.insert("gpt-3.5-turbo-instruct".to_string(), pool);
    let targets = Targets {
        targets: pools,
        key_rate_limiters: Arc::new(DashMap::new()),
        key_concurrency_limiters: Arc::new(DashMap::new()),
        key_labels: Arc::new(DashMap::new()),
        strict_mode: true,
        http_pool_config: None,
    };
    // Upstream reports a different model name and a non-standard field.
    let upstream_body = r#"{
        "id": "cmpl-abc",
        "object": "text_completion",
        "created": 1677652288,
        "model": "gpt-3.5-turbo-instruct-actual",
        "choices": [{"text": "Hi!", "index": 0, "logprobs": null, "finish_reason": "stop"}],
        "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5},
        "provider_metadata": {"cost": 0.0001}
    }"#;
    let app = crate::strict::build_strict_router(AppState::with_client(
        targets,
        MockHttpClient::new(StatusCode::OK, upstream_body),
    ));
    let req = Request::builder()
        .method("POST")
        .uri("/completions")
        .header("content-type", "application/json")
        .body(Body::from(
            r#"{"model":"gpt-3.5-turbo-instruct","prompt":"Hi"}"#,
        ))
        .unwrap();
    let resp = app.oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    assert_eq!(json["model"], "gpt-3.5-turbo-instruct");
    assert_ne!(json["model"], "gpt-3.5-turbo-instruct-actual");
    assert!(json.get("provider_metadata").is_none());
}
#[test]
fn test_try_format_sse_error_trusted_forwards_verbatim() {
    // For a trusted target, an in-stream error payload is re-emitted as an
    // SSE data line with its message intact.
    let payload =
        r#"{"error": {"message": "Input too long", "type": "BadRequestError", "code": 400}}"#;
    let maybe_line = try_format_sse_error(payload, true);
    assert!(maybe_line.is_some());
    let line = maybe_line.unwrap();
    assert!(line.starts_with("data: "));
    assert!(line.contains("Input too long"));
}
#[test]
fn test_try_format_sse_error_untrusted_sanitizes() {
    // For an untrusted target, the upstream error message is replaced with
    // a generic one so provider internals never reach the client.
    let payload =
        r#"{"error": {"message": "OOM on GPU 3", "type": "BadRequestError", "code": 400}}"#;
    let maybe_line = try_format_sse_error(payload, false);
    assert!(maybe_line.is_some());
    let line = maybe_line.unwrap();
    assert!(!line.contains("OOM on GPU 3"));
    assert!(line.contains("Invalid request"));
}
#[test]
fn test_try_format_sse_error_ignores_valid_chunk() {
    // An ordinary streaming chunk is not an error payload; the formatter
    // must decline so the chunk passes through untouched.
    let chunk = r#"{"id": "chatcmpl-123", "object": "chat.completion.chunk"}"#;
    assert!(try_format_sse_error(chunk, true).is_none());
}
#[test]
fn test_try_format_sse_error_ignores_non_json() {
    // Non-JSON SSE data (such as the [DONE] terminator) is left alone.
    assert!(try_format_sse_error("[DONE]", true).is_none());
}
#[tokio::test]
async fn test_streaming_chat_forwards_trusted_error_in_sse() {
    // An error event embedded in a trusted target's SSE stream is forwarded
    // to the client with its message unchanged.
    let sse_body = "data: {\"error\": {\"message\": \"Input too long\", \"type\": \"BadRequestError\", \"code\": 400}}\n\n";
    let upstream = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "text/event-stream")
        .body(Body::from(sse_body))
        .unwrap();
    let sanitized =
        sanitize_streaming_chat_response(upstream, "test-model".to_string(), true).await;
    assert_eq!(sanitized.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(sanitized.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = String::from_utf8_lossy(&bytes);
    assert!(text.contains("Input too long"));
}
#[tokio::test]
async fn test_streaming_chat_sanitizes_untrusted_error_in_sse() {
    // An error event in an untrusted target's SSE stream is rewritten: the
    // original message is dropped while the error envelope is preserved.
    let sse_body = "data: {\"error\": {\"message\": \"OOM on GPU 3\", \"type\": \"server_error\", \"code\": 400}}\n\n";
    let upstream = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "text/event-stream")
        .body(Body::from(sse_body))
        .unwrap();
    let sanitized =
        sanitize_streaming_chat_response(upstream, "test-model".to_string(), false).await;
    assert_eq!(sanitized.status(), StatusCode::OK);
    let bytes = axum::body::to_bytes(sanitized.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = String::from_utf8_lossy(&bytes);
    assert!(!text.contains("OOM on GPU 3"));
    assert!(text.contains("\"error\""));
}
}