use std::{
io,
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use axum::{
Json, Router,
extract::State,
http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Uri},
response::{
IntoResponse, Response,
sse::{Event, Sse},
},
routing::{get, post},
};
use serde_json::{Value, json};
use thiserror::Error;
use tokio::net::TcpListener;
use tracing::{debug, error, info, warn};
use crate::{
attestation::{AttestationError, AttestationVerifier},
config::{NvidiaRequirement, ProxyConfig},
e2ee::{E2eeCodec, E2eeCodecError},
keys::ProxyInstanceKey,
openai::{
ErrorResponse,
chat::{
ChatCompletionRequest, ChatConstructionError, ChatRequestError, NormalizedChatMessage,
},
},
sessions::{AttestedModelState, SessionContext, SessionError, SessionManager, SessionRequest},
tools::{ToolEmulationContext, ToolOutputClassification, ValidatedToolCall},
venice::{VeniceClient, VeniceClientError},
};
// Names of the `X-Venice-Proxy-*` HTTP headers used for proxy metadata.
// NOTE(review): the code that writes these headers (e.g. `ProxyMetadataHeaders`)
// is outside this view; the groupings below follow the header names only.
// End-to-end encryption.
pub const HEADER_PROXY_E2EE: &str = "X-Venice-Proxy-E2EE";
// Model attestation / TEE verification.
pub const HEADER_PROXY_ATTESTATION_MODE: &str = "X-Venice-Proxy-Attestation-Mode";
pub const HEADER_PROXY_ATTESTED_MODEL: &str = "X-Venice-Proxy-Attested-Model";
pub const HEADER_PROXY_TEE_PROVIDER: &str = "X-Venice-Proxy-TEE-Provider";
pub const HEADER_PROXY_TDX_VERIFIED: &str = "X-Venice-Proxy-TDX-Verified";
pub const HEADER_PROXY_TDX_DEBUG: &str = "X-Venice-Proxy-TDX-Debug";
pub const HEADER_PROXY_NVIDIA_VERIFIED: &str = "X-Venice-Proxy-NVIDIA-Verified";
pub const HEADER_PROXY_KEY_BINDING: &str = "X-Venice-Proxy-Key-Binding";
// Session tracking.
pub const HEADER_PROXY_SESSION_ID: &str = "X-Venice-Proxy-Session-Id";
pub const HEADER_PROXY_SESSION_SCOPE: &str = "X-Venice-Proxy-Session-Scope";
// Tool-call emulation.
pub const HEADER_PROXY_TOOL_MODE: &str = "X-Venice-Proxy-Tool-Mode";
pub const HEADER_PROXY_TOOL_RETRIES: &str = "X-Venice-Proxy-Tool-Retries";
// Error reporting.
pub const HEADER_PROXY_ERROR_CODE: &str = "X-Venice-Proxy-Error-Code";
/// State shared by every axum handler.
///
/// Bundles the proxy configuration (behind an [`Arc`] so clones are cheap), the
/// upstream Venice client, the optional proxy instance key, the session manager
/// and the attestation verifier.
#[derive(Debug, Clone)]
pub struct AppState {
    config: Arc<ProxyConfig>,
    venice_client: VeniceClient,
    proxy_instance_key: Option<ProxyInstanceKey>,
    session_manager: SessionManager,
    attestation_verifier: AttestationVerifier,
}

impl AppState {
    /// Builds the state, constructing the Venice client from `config`.
    ///
    /// # Errors
    /// Returns the [`VeniceClientError`] produced by `VeniceClient::from_config`.
    pub fn new(config: ProxyConfig) -> Result<Self, VeniceClientError> {
        VeniceClient::from_config(&config)
            .map(|venice_client| Self::from_parts(config, venice_client))
    }

    /// Builds the state around an already-constructed `venice_client`.
    ///
    /// The instance key, session manager and attestation verifier are all
    /// derived from `config`; the verifier gets its own clone of the client.
    pub fn from_parts(config: ProxyConfig, venice_client: VeniceClient) -> Self {
        let instance_key = ProxyInstanceKey::generate_from_config(&config.keys);
        let sessions = SessionManager::new(config.session.clone());
        let verifier = AttestationVerifier::from_config(&config, venice_client.clone());
        Self {
            config: Arc::new(config),
            venice_client,
            proxy_instance_key: instance_key,
            session_manager: sessions,
            attestation_verifier: verifier,
        }
    }

    /// The proxy configuration.
    pub fn config(&self) -> &ProxyConfig {
        &*self.config
    }

    /// The upstream Venice client.
    pub fn venice_client(&self) -> &VeniceClient {
        &self.venice_client
    }

    /// The proxy instance key, when `generate_from_config` produced one.
    pub fn proxy_instance_key(&self) -> Option<&ProxyInstanceKey> {
        self.proxy_instance_key.as_ref()
    }

    /// The session manager.
    pub fn session_manager(&self) -> &SessionManager {
        &self.session_manager
    }

    /// The model attestation verifier.
    pub fn attestation_verifier(&self) -> &AttestationVerifier {
        &self.attestation_verifier
    }
}
pub fn router(config: ProxyConfig) -> Result<Router, VeniceClientError> {
Ok(router_from_state(AppState::new(config)?))
}
/// Builds the proxy [`Router`] around a caller-supplied [`VeniceClient`].
pub fn router_with_venice_client(config: ProxyConfig, venice_client: VeniceClient) -> Router {
    let state = AppState::from_parts(config, venice_client);
    router_from_state(state)
}
/// Wires the HTTP routes onto `state`.
///
/// Only `GET /v1/models` and `POST /v1/chat/completions` are served. Other
/// methods on those paths go to `method_not_allowed`; any other path goes to
/// `not_found`.
fn router_from_state(state: AppState) -> Router {
    let models = get(list_models).fallback(method_not_allowed);
    let chat_completions = post(create_chat_completion).fallback(method_not_allowed);
    Router::new()
        .route("/v1/models", models)
        .route("/v1/chat/completions", chat_completions)
        .fallback(not_found)
        .with_state(state)
}
/// Runs `axum::serve` for `router` on `listener` and resolves when it completes.
///
/// # Errors
/// Returns the I/O error reported by `axum::serve`.
pub async fn serve(listener: TcpListener, router: Router) -> io::Result<()> {
    let server = axum::serve(listener, router);
    server.await
}
/// `GET /v1/models`: returns Venice's model list as JSON.
///
/// The response carries the proxy metadata headers built from the configuration.
///
/// # Errors
/// Propagates the Venice client error (converted to [`ProxyError`]).
async fn list_models(State(state): State<AppState>) -> Result<Response, ProxyError> {
    info!(route = "/v1/models", "listing Venice models");
    let venice_models = state.venice_client().list_models().await?;
    let mut response = Json(venice_models).into_response();
    let response_headers = response.headers_mut();
    ProxyMetadataHeaders::from_config(state.config()).apply(response_headers);
    info!(route = "/v1/models", "Venice models response proxied");
    Ok(response)
}
/// `POST /v1/chat/completions` handler.
///
/// Parses the OpenAI-style JSON body, resolves (or creates) a session, ensures
/// that session carries a verified model attestation, and forwards the request
/// to Venice through the E2EE codec using the attested model public key.
/// The client response is produced in one of three ways:
/// - tool emulation, when `ToolEmulationContext::from_request` returns `Some`;
/// - an SSE stream, when the prepared request has `client_stream` set;
/// - a single buffered JSON completion otherwise.
/// In every case the `ProxyMetadataHeaders` built for the verified session are
/// applied to the response.
///
/// # Errors
/// Returns a [`ProxyError`] when any of the following fails:
/// - parsing the body;
/// - getting a proxy instance key (none is available);
/// - session resolution or attestation;
/// - finding a model public key on the attested session (it has none);
/// - codec or tool-emulation setup;
/// - the upstream request or response handling.
async fn create_chat_completion(
State(state): State<AppState>,
headers: HeaderMap,
Json(body): Json<Value>,
) -> Result<Response, ProxyError> {
let request = ChatCompletionRequest::parse(&body)?;
// The instance key's public half is sent upstream (`public_key_hex`) and its
// private half decrypts response fields, so every chat request needs it.
let proxy_instance_key = state
.proxy_instance_key()
.ok_or(ProxyError::ProxyInstanceKeyUnavailable)?;
// Resolve the session for this model/headers/body, creating one if needed.
let session_resolution = state
.session_manager()
.get_or_create(SessionRequest::new(&request.model, &headers).with_body(&body))?;
// Copy out the bookkeeping used by the log line below before the session is
// moved into `ensure_attested_session`.
let session_created = session_resolution.created;
let session_replaced_expired = session_resolution.replaced_expired;
let session_scope = session_resolution.session.scope;
let session = ensure_attested_session(&state, session_resolution.session).await?;
// Public key of the attested model; every upstream request below is built for it.
let model_public_key = session
.attested_model_public_key
.as_deref()
.ok_or(ProxyError::MissingAttestedModelKey)?;
let codec =
E2eeCodec::from_config(&state.config().e2ee).map_err(ChatConstructionError::E2ee)?;
// `Some` only when this request should go through tool-call emulation.
let tool_context = ToolEmulationContext::from_request(&state.config().tools, &request)?;
let metadata = ProxyMetadataHeaders::for_verified_chat(state.config(), &session);
info!(
route = "/v1/chat/completions",
model = %request.model,
stream = request.stream,
message_count = request.messages.len(),
tool_count = request.tools.len(),
tool_mode = tool_context.is_some(),
session_created,
session_replaced_expired = ?session_replaced_expired,
session_scope = %session_scope,
"chat completion request accepted"
);
// Tool emulation drives its own upstream request(s) and response shaping.
if let Some(tool_context) = tool_context {
info!(model = %request.model, "using tool-emulated chat completion");
return openai_tool_emulated_chat_response(
&state,
&request,
&tool_context,
codec,
proxy_instance_key.clone(),
model_public_key,
metadata,
)
.await;
}
// Plain path: encrypt the request once and open the upstream chat stream.
let prepared = request.to_venice_e2ee_request(&codec, model_public_key)?;
info!(
model = %request.model,
client_stream = prepared.client_stream,
"forwarding encrypted chat completion to Venice"
);
let upstream = state
.venice_client()
.create_chat_completion_stream(
&prepared.upstream,
proxy_instance_key.public_key_hex(),
model_public_key,
)
.await?;
if prepared.client_stream {
// Re-emit the transformed upstream events to the client as SSE.
info!(model = %request.model, "streaming chat completion response to client");
let include_usage_requested = request.stream_options.include_usage.unwrap_or(false);
let transformer = OpenAiChatStreamTransformer::new(
codec,
proxy_instance_key.clone(),
request.model.clone(),
include_usage_requested,
);
Ok(chat_sse_response(
upstream,
transformer,
request.model,
include_usage_requested,
&CHAT_SSE_LOG,
metadata,
))
} else {
// Collect the whole upstream stream into one JSON completion.
info!(model = %request.model, "buffering chat completion response for client");
openai_chat_buffered_response(
upstream,
codec,
proxy_instance_key.clone(),
request.model,
metadata,
)
.await
}
}
/// Returns `session` with a verified model attestation attached.
///
/// - If the session already carries an attested model public key, it is returned
///   unchanged and the cached attestation is reused.
/// - Otherwise the model's attestation is fetched and verified.
///   - The resulting [`AttestedModelState`] is stored via the session manager.
///   - The updated [`SessionContext`] the session manager returns is passed back.
///
/// # Errors
/// Propagates attestation-verification and session-manager errors as [`ProxyError`].
async fn ensure_attested_session(
    state: &AppState,
    session: SessionContext,
) -> Result<SessionContext, ProxyError> {
    let already_attested = session.attested_model_public_key.is_some();
    if already_attested {
        info!(model = %session.model_id, session_scope = %session.scope, "using cached model attestation");
        return Ok(session);
    }

    info!(model = %session.model_id, session_scope = %session.scope, "fetching model attestation");
    let verified = state
        .attestation_verifier()
        .verify_model_attestation(&session.model_id)
        .await?;
    info!(
        model = %verified.model_id,
        tee_provider = verified.tee_provider.as_deref().unwrap_or("unknown"),
        tdx_verified = verified.tdx.verified,
        nvidia_verified = verified.nvidia.verified.as_header_value(),
        "model attestation verified"
    );

    let session_manager = state.session_manager();
    session_manager
        .set_attested_model_state(
            &session.session_key,
            AttestedModelState {
                model_public_key: verified.model_public_key,
                attestation_report: verified.attestation_report,
                verified_at: verified.verified_at,
            },
        )
        .map_err(Into::into)
}
/// Buffers the encrypted upstream stream into one OpenAI chat completion.
///
/// The completion is returned as a JSON response carrying `metadata` as headers.
///
/// # Errors
/// Propagates the buffering error (converted to [`ProxyError`]).
async fn openai_chat_buffered_response(
    upstream: reqwest::Response,
    codec: E2eeCodec,
    proxy_instance_key: ProxyInstanceKey,
    fallback_model: String,
    metadata: ProxyMetadataHeaders,
) -> Result<Response, ProxyError> {
    let buffering =
        buffer_openai_chat_completion(upstream, codec, proxy_instance_key, fallback_model);
    let mut response = Json(buffering.await?).into_response();
    let response_headers = response.headers_mut();
    metadata.apply(response_headers);
    Ok(response)
}
/// Produces a chat completion through tool-call emulation.
///
/// Each upstream attempt re-sends the request with the tool-emulation prompt
/// merged into its system message (`tool_emulated_upstream_stream`).
///
/// - Streaming requests (`request.stream`) make a single upstream attempt whose
///   events are converted by `OpenAiToolEmulatedChatStreamTransformer`. This path
///   does not retry.
/// - Non-streaming requests buffer each attempt, bounded by
///   `tool_context.marker_timeout()`, then classify the assistant text:
///   - normal text is returned as-is;
///   - valid tool calls are rewritten into an OpenAI `tool_calls` completion;
///   - an invalid tool call or a timeout triggers a retry whose prompt carries a
///     correction, up to `tool_context.max_retries()` retries.
///   A non-zero retry count is recorded in `metadata.tool_retries`.
///
/// # Errors
/// Returns [`ProxyError::ToolCallRetryExhausted`] once retries run out.
/// Upstream, encryption and stream-transformation errors are propagated.
async fn openai_tool_emulated_chat_response(
state: &AppState,
request: &ChatCompletionRequest,
tool_context: &ToolEmulationContext,
codec: E2eeCodec,
proxy_instance_key: ProxyInstanceKey,
model_public_key: &str,
metadata: ProxyMetadataHeaders,
) -> Result<Response, ProxyError> {
info!(
model = %request.model,
max_retries = tool_context.max_retries(),
"starting tool-emulated chat completion"
);
// Streaming: one attempt, no correction; the transformer handles markers.
if request.stream {
let upstream = tool_emulated_upstream_stream(
state,
request,
tool_context,
&codec,
&proxy_instance_key,
model_public_key,
None,
)
.await?;
let include_usage_requested = request.stream_options.include_usage.unwrap_or(false);
let transformer = OpenAiToolEmulatedChatStreamTransformer::new(
tool_context,
codec,
proxy_instance_key,
request.model.clone(),
include_usage_requested,
)
.map_err(ProxyError::ChatStream)?;
return Ok(chat_sse_response(
upstream,
transformer,
request.model.clone(),
include_usage_requested,
&TOOL_EMULATED_CHAT_SSE_LOG,
metadata,
));
}
// Non-streaming: retry loop. `correction` is `(validation_error, invalid_output)`
// from the previous failed attempt and is folded into the next prompt.
let mut retries = 0;
let mut correction: Option<(String, String)> = None;
loop {
let upstream = tool_emulated_upstream_stream(
state,
request,
tool_context,
&codec,
&proxy_instance_key,
model_public_key,
correction.as_ref(),
)
.await?;
// The whole buffering of one attempt is bounded by the marker timeout.
let completion = match tokio::time::timeout(
tool_context.marker_timeout(),
buffer_openai_chat_completion(
upstream,
codec.clone(),
proxy_instance_key.clone(),
request.model.clone(),
),
)
.await
{
Ok(completion) => completion?,
Err(_) => {
let validation_error = format!(
"tool-emulated completion did not finish within {}",
humantime::format_duration(tool_context.config().tool_call_marker_timeout)
);
if retries >= tool_context.max_retries() {
return Err(ProxyError::ToolCallRetryExhausted {
max_retries: tool_context.max_retries(),
last_validation_error: validation_error,
});
}
warn!(
model = %request.model,
retry = retries + 1,
max_retries = tool_context.max_retries(),
"tool call marker timed out; retrying with correction"
);
retries += 1;
// No output was captured on timeout, so the correction has empty output.
correction = Some((validation_error, String::new()));
continue;
}
};
// Text of the first choice; missing content is treated as "".
let assistant_content = completion
.get("choices")
.and_then(Value::as_array)
.and_then(|choices| choices.first())
.and_then(|choice| choice.get("message"))
.and_then(|message| message.get("content"))
.and_then(Value::as_str)
.unwrap_or_default();
let mut metadata = metadata.clone();
if retries > 0 {
metadata.tool_retries = Some(retries);
}
match tool_context.classify_assistant_output(assistant_content) {
ToolOutputClassification::NormalText => {
info!(model = %request.model, retries, "tool emulation produced normal text");
let mut response = Json(completion).into_response();
metadata.apply(response.headers_mut());
return Ok(response);
}
ToolOutputClassification::ToolCalls(tool_calls) => {
info!(
model = %request.model,
tool_calls = tool_calls.len(),
retries,
"tool emulation produced tool calls"
);
let body = openai_tool_call_completion(completion, tool_calls);
let mut response = Json(body).into_response();
metadata.apply(response.headers_mut());
return Ok(response);
}
ToolOutputClassification::InvalidToolCall {
error,
invalid_output,
} => {
if retries >= tool_context.max_retries() {
warn!(
model = %request.model,
max_retries = tool_context.max_retries(),
validation_error = %error,
"tool call validation failed and retries were exhausted"
);
return Err(ProxyError::ToolCallRetryExhausted {
max_retries: tool_context.max_retries(),
last_validation_error: error.to_string(),
});
}
warn!(
model = %request.model,
retry = retries + 1,
max_retries = tool_context.max_retries(),
validation_error = %error,
"tool call validation failed; retrying with correction"
);
retries += 1;
// Feed the validation error and the rejected output back to the model.
correction = Some((error.to_string(), invalid_output));
}
}
}
}
/// Sends one tool-emulated attempt upstream.
///
/// Steps:
/// 1. Clone `request` and replace its messages with the tool-emulation prompt,
///    including the optional `correction` from a previous failed attempt.
/// 2. Encrypt the request for `model_public_key`.
/// 3. Open the Venice chat stream.
///
/// # Errors
/// Propagates request-construction and Venice client errors as [`ProxyError`].
async fn tool_emulated_upstream_stream(
    state: &AppState,
    request: &ChatCompletionRequest,
    tool_context: &ToolEmulationContext,
    codec: &E2eeCodec,
    proxy_instance_key: &ProxyInstanceKey,
    model_public_key: &str,
    correction: Option<&(String, String)>,
) -> Result<reqwest::Response, ProxyError> {
    let mut emulated = request.clone();
    emulated.messages = tool_emulated_messages(request, tool_context, correction);
    let encrypted = emulated.to_venice_e2ee_request(codec, model_public_key)?;
    let stream = state
        .venice_client()
        .create_chat_completion_stream(
            &encrypted.upstream,
            proxy_instance_key.public_key_hex(),
            model_public_key,
        )
        .await?;
    Ok(stream)
}
/// Returns `request.messages` with the tool-emulation controller prompt merged
/// into the system message (via `append_to_system_message`).
///
/// When `correction` holds `(validation_error, invalid_output)` from a previous
/// attempt, the tool context's correction message is appended to that prompt
/// after a blank line.
fn tool_emulated_messages(
    request: &ChatCompletionRequest,
    tool_context: &ToolEmulationContext,
    correction: Option<&(String, String)>,
) -> Vec<NormalizedChatMessage> {
    let mut messages = request.messages.clone();
    let controller_text = tool_context.controller_message().content;
    let system_addition = match correction {
        None => controller_text,
        Some((validation_error, invalid_output)) => {
            let correction_text = tool_context
                .correction_message(validation_error, invalid_output)
                .content;
            format!("{controller_text}\n\n{correction_text}")
        }
    };
    append_to_system_message(&mut messages, system_addition);
    messages
}
/// Appends `content` to the first `system` message, separated by a blank line.
///
/// If there is no system message, a new one holding `content` is inserted at
/// the front.
fn append_to_system_message(messages: &mut Vec<NormalizedChatMessage>, content: String) {
    match messages.iter_mut().find(|message| message.role == "system") {
        Some(system_message) => {
            let existing = &mut system_message.content;
            existing.push_str("\n\n");
            existing.push_str(&content);
        }
        None => messages.insert(0, NormalizedChatMessage::new("system", content)),
    }
}
/// Rewrites a buffered completion whose text was recognised as tool calls into
/// an OpenAI `tool_calls` chat completion.
///
/// - Carried over from the first choice: its `index` (default `0`) and
///   `message.reasoning_content`.
/// - `content` is `null` and `finish_reason` is `"tool_calls"`.
/// - `id`, `created`, `model` and `usage` come from `completion`, falling back
///   to local values.
fn openai_tool_call_completion(completion: Value, tool_calls: Vec<ValidatedToolCall>) -> Value {
    // Borrow the first choice instead of cloning it; absent means "no choice".
    let first_choice = completion
        .get("choices")
        .and_then(Value::as_array)
        .and_then(|choices| choices.first());
    let index = first_choice
        .and_then(|choice| choice.get("index"))
        .and_then(Value::as_u64)
        .unwrap_or(0);
    let reasoning = first_choice
        .and_then(|choice| choice.get("message"))
        .and_then(|message| message.get("reasoning_content"))
        .and_then(Value::as_str);

    let mut message = serde_json::Map::new();
    message.insert("role".to_owned(), Value::String("assistant".to_owned()));
    message.insert("content".to_owned(), Value::Null);
    if let Some(reasoning) = reasoning {
        message.insert(
            "reasoning_content".to_owned(),
            Value::String(reasoning.to_owned()),
        );
    }
    let calls: Vec<Value> = tool_calls
        .iter()
        .map(ValidatedToolCall::to_openai_value)
        .collect();
    message.insert("tool_calls".to_owned(), Value::Array(calls));

    json!({
        "id": string_field(&completion, "id").unwrap_or("chatcmpl-local"),
        "object": "chat.completion",
        "created": integer_field(&completion, "created").unwrap_or_else(unix_timestamp_now),
        "model": string_field(&completion, "model").unwrap_or("unknown"),
        "choices": [{
            "index": index,
            "message": Value::Object(message),
            "finish_reason": "tool_calls",
        }],
        "usage": completion.get("usage").cloned().unwrap_or(Value::Null),
    })
}
/// Drains the encrypted upstream SSE stream and folds it into one OpenAI
/// `chat.completion` JSON value.
///
/// Upstream bytes are decoded as UTF-8 incrementally. A multi-byte character
/// split across two network chunks is carried over to the next chunk instead of
/// being rejected. Parsed events are fed to [`OpenAiChatCompletionBuffer`]
/// until the upstream `data: [DONE]` marker arrives.
///
/// # Errors
/// Returns a [`ChatStreamError`] when:
/// - reading the upstream body fails;
/// - the bytes are not valid UTF-8;
/// - an SSE event is malformed or rejected by the buffer;
/// - the stream ends before `[DONE]`.
async fn buffer_openai_chat_completion(
    mut upstream: reqwest::Response,
    codec: E2eeCodec,
    proxy_instance_key: ProxyInstanceKey,
    fallback_model: String,
) -> Result<Value, ChatStreamError> {
    /// Appends `chunk` to `pending` and takes back the longest prefix that is
    /// complete UTF-8, leaving any trailing partial character in `pending`.
    fn take_complete_utf8(
        pending: &mut Vec<u8>,
        chunk: &[u8],
    ) -> Result<String, std::str::Utf8Error> {
        pending.extend_from_slice(chunk);
        let complete_len = match std::str::from_utf8(pending.as_slice()) {
            Ok(_) => pending.len(),
            // `error_len() == None`: the bytes end mid-character, which is
            // expected when a character straddles two network chunks.
            Err(err) if err.error_len().is_none() => err.valid_up_to(),
            Err(err) => return Err(err),
        };
        let remainder = pending.split_off(complete_len);
        let complete = std::mem::replace(pending, remainder);
        String::from_utf8(complete).map_err(|err| err.utf8_error())
    }

    info!(model = %fallback_model, "buffering upstream chat stream");
    let mut parser = SseEventParser::default();
    let mut transformer =
        OpenAiChatCompletionBuffer::new(codec, proxy_instance_key, fallback_model.clone());
    let mut pending_utf8: Vec<u8> = Vec::new();
    let mut upstream_done = false;
    let mut chunk_count = 0_u64;
    let mut event_count = 0_u64;
    while let Some(chunk) = upstream
        .chunk()
        .await
        .map_err(ChatStreamError::upstream_stream)?
    {
        chunk_count += 1;
        let text =
            take_complete_utf8(&mut pending_utf8, &chunk).map_err(ChatStreamError::invalid_utf8)?;
        let events = parser.push(&text)?;
        event_count += events.len() as u64;
        debug!(
            model = %fallback_model,
            chunk_count,
            parsed_events = events.len(),
            total_events = event_count,
            "parsed buffered upstream SSE chunk"
        );
        for event in events {
            if transformer.handle_event(event)? {
                upstream_done = true;
                break;
            }
        }
        if upstream_done {
            break;
        }
    }
    if !upstream_done {
        warn!(
            model = %fallback_model,
            chunk_count,
            event_count,
            pending_utf8_bytes = pending_utf8.len(),
            "buffered upstream stream ended before DONE"
        );
        parser.finish()?;
        return Err(ChatStreamError::malformed_event(
            "upstream stream ended before data: [DONE]",
        ));
    }
    let completion = transformer.into_response();
    info!(
        model = %fallback_model,
        chunk_count,
        event_count,
        "buffered upstream chat stream transformed"
    );
    Ok(completion)
}
/// Log message texts used by `chat_sse_event_stream`, so the same streaming loop
/// can report plain and tool-emulated streams with distinct messages.
struct ChatSseLogMessages {
/// Logged (info) when the stream transformation starts.
start: &'static str,
/// Logged (debug) after each upstream chunk is parsed into events.
parsed_chunk: &'static str,
/// Logged (debug) after each event has been passed through the transformer.
transformed_event: &'static str,
/// Logged (info) when the transformer emits `StreamOutput::Done`.
completed: &'static str,
/// Logged (warn) when upstream closes before `[DONE]`.
ended_early: &'static str,
}
/// Messages for the plain chat SSE stream (used by `create_chat_completion`).
const CHAT_SSE_LOG: ChatSseLogMessages = ChatSseLogMessages {
start: "starting upstream chat SSE transformation",
parsed_chunk: "parsed streaming upstream SSE chunk",
transformed_event: "transformed streaming upstream SSE event",
completed: "completed upstream chat SSE transformation",
ended_early: "streaming upstream stream ended before DONE",
};
/// Messages for the tool-emulated chat SSE stream
/// (used by `openai_tool_emulated_chat_response`).
const TOOL_EMULATED_CHAT_SSE_LOG: ChatSseLogMessages = ChatSseLogMessages {
start: "starting tool-emulated upstream chat SSE transformation",
parsed_chunk: "parsed tool-emulated upstream SSE chunk",
transformed_event: "transformed tool-emulated upstream SSE event",
completed: "completed tool-emulated upstream chat SSE transformation",
ended_early: "tool-emulated upstream stream ended before DONE",
};
/// Per-request converter from parsed upstream SSE events to the
/// [`StreamOutput`] items that `chat_sse_event_stream` sends to the client.
trait ChatSseTransformer {
/// Handles one upstream event and returns zero or more outputs to emit.
/// An error aborts the client stream.
fn handle_event(&mut self, event: RawSseEvent) -> Result<Vec<StreamOutput>, ChatStreamError>;
}
/// Wraps the upstream chat stream in an axum [`Sse`] response.
///
/// Events come from `chat_sse_event_stream`, which runs `transformer` over the
/// parsed upstream SSE events. `metadata` is applied to the response headers.
fn chat_sse_response<T>(
    upstream: reqwest::Response,
    transformer: T,
    fallback_model: String,
    include_usage_requested: bool,
    log: &'static ChatSseLogMessages,
    metadata: ProxyMetadataHeaders,
) -> Response
where
    T: ChatSseTransformer + Send + 'static,
{
    let events = chat_sse_event_stream(
        upstream,
        transformer,
        fallback_model,
        include_usage_requested,
        log,
    );
    let mut sse_response = Sse::new(events).into_response();
    let response_headers = sse_response.headers_mut();
    metadata.apply(response_headers);
    sse_response
}
/// Turns the encrypted upstream SSE body into the client-facing SSE event stream.
///
/// Processing:
/// - Upstream bytes are decoded as UTF-8 incrementally. A multi-byte character
///   split across two network chunks is carried over rather than rejected.
/// - The decoded text is split into SSE events and passed through `transformer`.
/// - Each `StreamOutput::Json` becomes a `data:` event.
/// - `StreamOutput::Done` emits `data: [DONE]` and ends the stream.
///
/// If upstream closes before `[DONE]`, the stream ends with an error. All errors
/// are logged and boxed by `box_chat_stream_error`.
fn chat_sse_event_stream<T>(
    mut upstream: reqwest::Response,
    mut transformer: T,
    fallback_model: String,
    include_usage_requested: bool,
    log: &'static ChatSseLogMessages,
) -> impl futures_core::Stream<Item = Result<Event, axum::BoxError>>
where
    T: ChatSseTransformer + Send + 'static,
{
    /// Appends `chunk` to `pending` and takes back the longest prefix that is
    /// complete UTF-8, leaving any trailing partial character in `pending`.
    fn take_complete_utf8(
        pending: &mut Vec<u8>,
        chunk: &[u8],
    ) -> Result<String, std::str::Utf8Error> {
        pending.extend_from_slice(chunk);
        let complete_len = match std::str::from_utf8(pending.as_slice()) {
            Ok(_) => pending.len(),
            // `error_len() == None`: the bytes end mid-character, which is
            // expected when a character straddles two network chunks.
            Err(err) if err.error_len().is_none() => err.valid_up_to(),
            Err(err) => return Err(err),
        };
        let remainder = pending.split_off(complete_len);
        let complete = std::mem::replace(pending, remainder);
        String::from_utf8(complete).map_err(|err| err.utf8_error())
    }

    async_stream::try_stream! {
        info!(
            model = %fallback_model,
            include_usage_requested,
            "{}", log.start
        );
        let mut parser = SseEventParser::default();
        let mut pending_utf8: Vec<u8> = Vec::new();
        let mut upstream_done = false;
        let mut chunk_count = 0_u64;
        let mut event_count = 0_u64;
        let mut output_count = 0_u64;
        while let Some(chunk) = upstream
            .chunk()
            .await
            .map_err(ChatStreamError::upstream_stream)
            .map_err(box_chat_stream_error)?
        {
            chunk_count += 1;
            let text = take_complete_utf8(&mut pending_utf8, &chunk)
                .map_err(ChatStreamError::invalid_utf8)
                .map_err(box_chat_stream_error)?;
            let events = parser.push(&text).map_err(box_chat_stream_error)?;
            event_count += events.len() as u64;
            debug!(
                model = %fallback_model,
                chunk_count,
                parsed_events = events.len(),
                total_events = event_count,
                "{}", log.parsed_chunk
            );
            for event in events {
                let outputs = transformer.handle_event(event).map_err(box_chat_stream_error)?;
                output_count += outputs.len() as u64;
                debug!(
                    model = %fallback_model,
                    emitted_outputs = outputs.len(),
                    total_outputs = output_count,
                    "{}", log.transformed_event
                );
                for output in outputs {
                    match output {
                        StreamOutput::Json(value) => yield Event::default().data(value.to_string()),
                        StreamOutput::Done => {
                            upstream_done = true;
                            info!(
                                model = %fallback_model,
                                chunk_count,
                                event_count,
                                output_count,
                                "{}", log.completed
                            );
                            yield Event::default().data("[DONE]");
                            break;
                        }
                    }
                }
                if upstream_done {
                    break;
                }
            }
            if upstream_done {
                break;
            }
        }
        if !upstream_done {
            warn!(
                model = %fallback_model,
                chunk_count,
                event_count,
                output_count,
                pending_utf8_bytes = pending_utf8.len(),
                "{}", log.ended_early
            );
            parser.finish().map_err(box_chat_stream_error)?;
            Err::<(), axum::BoxError>(box_chat_stream_error(ChatStreamError::malformed_event(
                "upstream stream ended before data: [DONE]",
            )))?;
        }
    }
}
/// Logs a chat-stream failure and converts it into axum's boxed error type, so
/// it can terminate the SSE stream.
fn box_chat_stream_error(error: ChatStreamError) -> axum::BoxError {
    error!(error = %error, "chat stream transformation failed");
    error.into()
}
/// Incremental parser that splits an SSE text stream into [`RawSseEvent`]s.
///
/// Text accumulates across `push` calls. Each complete event (terminated by a
/// blank line, see `sse_event_boundary`) is parsed with `parse_sse_event`. An
/// unterminated tail stays buffered for the next chunk.
#[derive(Debug, Default)]
struct SseEventParser {
    buffer: String,
}

impl SseEventParser {
    /// Appends `chunk` and returns every event that is now complete.
    ///
    /// Comment/heartbeat-only events are skipped.
    ///
    /// # Errors
    /// Returns the first `parse_sse_event` error. Events already parsed from
    /// this call are dropped along with it.
    fn push(&mut self, chunk: &str) -> Result<Vec<RawSseEvent>, ChatStreamError> {
        self.buffer.push_str(chunk);
        let mut parsed = Vec::new();
        loop {
            let Some((event_len, delimiter_len)) = sse_event_boundary(&self.buffer) else {
                break;
            };
            // Remove the event and its delimiter from the buffer, then drop the
            // delimiter from the extracted text.
            let mut raw: String = self.buffer.drain(..event_len + delimiter_len).collect();
            raw.truncate(event_len);
            if let Some(event) = parse_sse_event(&raw)? {
                parsed.push(event);
            }
        }
        debug!(
            chunk_bytes = chunk.len(),
            buffered_bytes = self.buffer.len(),
            parsed_events = parsed.len(),
            "SSE parser processed upstream chunk"
        );
        Ok(parsed)
    }

    /// Checks that only whitespace remains buffered at end of stream.
    ///
    /// # Errors
    /// Returns a malformed-event error if an unterminated event is left over.
    fn finish(&self) -> Result<(), ChatStreamError> {
        if !self.buffer.trim().is_empty() {
            warn!(
                buffered_bytes = self.buffer.len(),
                "upstream SSE stream ended with incomplete event"
            );
            return Err(ChatStreamError::malformed_event(
                "upstream stream ended with an incomplete SSE event",
            ));
        }
        Ok(())
    }
}
/// One parsed upstream SSE event (output of `parse_sse_event`).
#[derive(Debug, Clone, PartialEq, Eq)]
struct RawSseEvent {
/// Value of the `event:` field, if present; `classify_upstream_event` treats
/// an absent value as `"message"`.
event: Option<String>,
/// All `data:` lines of the event joined with `\n`.
data: String,
}
/// Log message texts used by `classify_upstream_event`, so each caller (buffered,
/// streaming, tool-emulated) reports upstream events with its own wording.
struct UpstreamEventLogMessages {
/// Logged (debug) for every event, with its type and whether it is `[DONE]`.
event: &'static str,
/// Logged (warn) for an `event: error` SSE event.
sse_error: &'static str,
/// Logged (info) when the `[DONE]` marker arrives.
done: &'static str,
/// Optional debug message logged before the JSON payload is parsed.
parsing: Option<&'static str>,
/// Logged (warn) when the JSON payload has an `error` key.
json_error: &'static str,
/// Logged (warn) when the JSON payload has no `choices` array.
missing_choices: &'static str,
/// Optional debug message logged with the choice count after parsing.
parsed: Option<&'static str>,
/// Logged (warn) when a chunk has more than one choice.
unexpected_choice_count: &'static str,
}
/// Messages used by `OpenAiChatCompletionBuffer` (non-streaming responses).
const BUFFERED_UPSTREAM_EVENT_LOG: UpstreamEventLogMessages = UpstreamEventLogMessages {
event: "buffering upstream SSE event",
sse_error: "upstream SSE error event while buffering response",
done: "received upstream DONE while buffering response",
parsing: Some("parsing buffered upstream chat JSON chunk"),
json_error: "upstream JSON error chunk while buffering response",
missing_choices: "buffered upstream chat chunk is missing choices array",
parsed: Some("parsed buffered upstream chat chunk"),
unexpected_choice_count: "unexpected buffered upstream choice count",
};
/// Messages for the streaming path.
/// NOTE(review): its user is outside this view — presumably the streaming
/// transformer.
const STREAMING_UPSTREAM_EVENT_LOG: UpstreamEventLogMessages = UpstreamEventLogMessages {
event: "transforming streaming upstream SSE event",
sse_error: "upstream SSE error event while streaming response",
done: "received upstream DONE while streaming response",
parsing: Some("parsing streaming upstream chat JSON chunk"),
json_error: "upstream JSON error chunk while streaming response",
missing_choices: "streaming upstream chat chunk is missing choices array",
parsed: Some("parsed streaming upstream chat chunk"),
unexpected_choice_count: "unexpected streaming upstream choice count",
};
/// Messages for the tool-emulated streaming path; the optional debug messages
/// are disabled here.
/// NOTE(review): its user is outside this view — confirm.
const TOOL_EMULATED_UPSTREAM_EVENT_LOG: UpstreamEventLogMessages = UpstreamEventLogMessages {
event: "transforming tool-emulated streaming upstream SSE event",
sse_error: "upstream SSE error event while streaming tool-emulated response",
done: "received upstream DONE while streaming tool-emulated response",
parsing: None,
json_error: "upstream JSON error chunk while streaming tool-emulated response",
missing_choices: "tool-emulated upstream chat chunk is missing choices array",
parsed: None,
unexpected_choice_count: "unexpected tool-emulated upstream choice count",
};
/// Result of `classify_upstream_event` for one upstream SSE event.
enum UpstreamEventKind {
/// The `data: [DONE]` end-of-stream marker.
Done,
/// A JSON chunk whose `choices` array is empty (the whole chunk is kept;
/// callers look for its `usage` field).
Usage(Value),
/// A JSON chunk with exactly one choice: `value` is the whole chunk and
/// `choice` a copy of `choices[0]`.
Choice { value: Value, choice: Value },
}
fn classify_upstream_event(
event: RawSseEvent,
log: &UpstreamEventLogMessages,
) -> Result<UpstreamEventKind, ChatStreamError> {
let event_type = event.event.as_deref().unwrap_or("message");
let is_done = event.data.trim() == "[DONE]";
debug!(event_type, is_done, "{}", log.event);
if event.event.as_deref() == Some("error") {
warn!("{}", log.sse_error);
return Err(ChatStreamError::upstream_event(event.data));
}
if is_done {
info!("{}", log.done);
return Ok(UpstreamEventKind::Done);
}
if let Some(parsing) = log.parsing {
debug!("{}", parsing);
}
let value: Value = serde_json::from_str(&event.data).map_err(ChatStreamError::json_event)?;
if let Some(error) = value.get("error") {
warn!("{}", log.json_error);
return Err(ChatStreamError::upstream_event(error.to_string()));
}
let Some(choices) = value.get("choices").and_then(Value::as_array) else {
warn!("{}", log.missing_choices);
return Err(ChatStreamError::malformed_event(
"upstream chat chunk is missing choices array",
));
};
if let Some(parsed) = log.parsed {
debug!(choice_count = choices.len(), "{}", parsed);
}
if choices.is_empty() {
return Ok(UpstreamEventKind::Usage(value));
}
if choices.len() != 1 {
warn!(
choice_count = choices.len(),
"{}", log.unexpected_choice_count
);
return Err(ChatStreamError::malformed_event(format!(
"expected exactly one upstream choice, got {}",
choices.len(),
)));
}
let choice = choices[0].clone();
Ok(UpstreamEventKind::Choice { value, choice })
}
struct ChunkContext {
codec: E2eeCodec,
proxy_instance_key: ProxyInstanceKey,
fallback_id: String,
fallback_created: i64,
fallback_model: String,
}
impl ChunkContext {
fn new(codec: E2eeCodec, proxy_instance_key: ProxyInstanceKey, fallback_model: String) -> Self {
Self {
codec,
proxy_instance_key,
fallback_id: format!("chatcmpl-local-{}", uuid::Uuid::new_v4()),
fallback_created: unix_timestamp_now(),
fallback_model,
}
}
fn decrypt(&self, content: Option<&str>) -> Result<Option<String>, ChatStreamError> {
self.codec
.decrypt_response_content(content, self.proxy_instance_key.private_key())
.map_err(ChatStreamError::decryption)
}
fn chunk_with_choice(
&self,
upstream: &Value,
index: u64,
delta: Value,
finish_reason: Value,
) -> Value {
json!({
"id": string_field(upstream, "id").unwrap_or(&self.fallback_id),
"object": string_field(upstream, "object").unwrap_or("chat.completion.chunk"),
"created": integer_field(upstream, "created").unwrap_or(self.fallback_created),
"model": string_field(upstream, "model").unwrap_or(&self.fallback_model),
"choices": [{
"index": index,
"delta": delta,
"finish_reason": finish_reason,
}],
})
}
fn usage_chunk(&self, upstream: &Value, usage: &Value) -> Value {
json!({
"id": string_field(upstream, "id").unwrap_or(&self.fallback_id),
"object": string_field(upstream, "object").unwrap_or("chat.completion.chunk"),
"created": integer_field(upstream, "created").unwrap_or(self.fallback_created),
"model": string_field(upstream, "model").unwrap_or(&self.fallback_model),
"choices": [],
"usage": usage,
})
}
}
/// Accumulates an encrypted, streamed upstream chat completion into a single
/// OpenAI `chat.completion` object with exactly one choice.
struct OpenAiChatCompletionBuffer {
/// Decryption state and fallback `id`/`created`/`model`.
ctx: ChunkContext,
/// First `id`, `created` and `model` seen in any upstream chunk.
id: Option<String>,
created: Option<i64>,
model: Option<String>,
/// Index of the single choice; it must not change across chunks.
choice_index: Option<u64>,
/// Set once any encrypted `content`/`reasoning_content` field was decrypted.
saw_encrypted_response_field: bool,
/// Concatenated decrypted `content` deltas.
content: String,
/// Concatenated decrypted `reasoning_content` deltas.
reasoning_content: String,
/// Last non-null `finish_reason` seen.
finish_reason: Option<Value>,
/// `usage` object from a choices-less chunk, if any.
usage: Option<Value>,
}
impl OpenAiChatCompletionBuffer {
/// Creates an empty buffer.
fn new(codec: E2eeCodec, proxy_instance_key: ProxyInstanceKey, fallback_model: String) -> Self {
Self {
ctx: ChunkContext::new(codec, proxy_instance_key, fallback_model),
id: None,
created: None,
model: None,
choice_index: None,
saw_encrypted_response_field: false,
content: String::new(),
reasoning_content: String::new(),
finish_reason: None,
usage: None,
}
}
/// Feeds one upstream SSE event into the buffer.
///
/// Returns `Ok(true)` once the upstream `[DONE]` marker has been seen, and
/// `Ok(false)` otherwise.
///
/// # Errors
/// Propagates classification, decryption and malformed-chunk errors.
fn handle_event(&mut self, event: RawSseEvent) -> Result<bool, ChatStreamError> {
match classify_upstream_event(event, &BUFFERED_UPSTREAM_EVENT_LOG)? {
UpstreamEventKind::Done => {
// NOTE(review): if no encrypted field was ever decrypted, the codec
// is asked to handle an absent value. Presumably this lets it reject
// responses that should have been encrypted — confirm against
// `E2eeCodec::decrypt_response_content`.
if !self.saw_encrypted_response_field {
self.ctx.decrypt(None)?;
}
// Default the finish reason when upstream never sent one.
if self.finish_reason.is_none() {
self.finish_reason = Some(Value::String("stop".to_owned()));
}
Ok(true)
}
UpstreamEventKind::Usage(value) => {
self.record_metadata(&value);
self.handle_usage_chunk(&value).map(|()| false)
}
UpstreamEventKind::Choice { value, choice } => {
self.record_metadata(&value);
self.handle_choice_chunk(&choice)?;
Ok(false)
}
}
}
/// Stores the `usage` object of a chunk whose `choices` array is empty.
///
/// # Errors
/// Fails when such a chunk also lacks `usage`.
fn handle_usage_chunk(&mut self, value: &Value) -> Result<(), ChatStreamError> {
let Some(usage) = value.get("usage") else {
warn!("buffered upstream chunk has no choices and no usage");
return Err(ChatStreamError::malformed_event(
"upstream chunk has no choices and no usage",
));
};
info!("buffered upstream usage chunk");
self.usage = Some(usage.clone());
Ok(())
}
/// Decrypts and appends the `content` / `reasoning_content` deltas of one
/// choice, and records its non-null `finish_reason`.
///
/// # Errors
/// Fails when:
/// - the choice is not an object;
/// - its index differs from earlier chunks;
/// - a field is malformed;
/// - decryption fails.
fn handle_choice_chunk(&mut self, choice: &Value) -> Result<(), ChatStreamError> {
let choice = choice.as_object().ok_or_else(|| {
ChatStreamError::malformed_event("upstream choice must be a JSON object")
})?;
let index = normalized_choice_index(choice.get("index"))?;
// Only a single choice is supported: lock onto the first index seen.
match self.choice_index {
Some(existing) if existing != index => {
return Err(ChatStreamError::malformed_event(
"upstream choice index changed while buffering a completion",
));
}
None => self.choice_index = Some(index),
Some(_) => {}
}
let finish_reason = normalized_finish_reason(choice.get("finish_reason"))?;
let delta = choice.get("delta").unwrap_or(&Value::Null);
let content = encrypted_delta_content(delta)?;
let reasoning_content = encrypted_delta_reasoning_content(delta)?;
debug!(
choice_index = index,
has_encrypted_content = content.is_some(),
has_encrypted_reasoning_content = reasoning_content.is_some(),
has_finish_reason = !finish_reason.is_null(),
"transforming buffered upstream choice chunk"
);
if let Some(content) = content {
let decrypted = self.ctx.decrypt(Some(content))?;
self.saw_encrypted_response_field = true;
debug!(
choice_index = index,
has_decrypted_content = decrypted.is_some(),
"decrypted buffered upstream content chunk"
);
if let Some(content) = decrypted {
self.content.push_str(&content);
}
}
if let Some(reasoning_content) = reasoning_content {
let decrypted = self.ctx.decrypt(Some(reasoning_content))?;
self.saw_encrypted_response_field = true;
debug!(
choice_index = index,
has_decrypted_reasoning_content = decrypted.is_some(),
"decrypted buffered upstream reasoning content chunk"
);
if let Some(reasoning_content) = decrypted {
self.reasoning_content.push_str(&reasoning_content);
}
}
// A later non-null finish reason overwrites an earlier one.
if !finish_reason.is_null() {
self.finish_reason = Some(finish_reason);
}
Ok(())
}
/// Keeps the first `id`, `created` and `model` values seen upstream.
fn record_metadata(&mut self, value: &Value) {
if self.id.is_none()
&& let Some(id) = string_field(value, "id")
{
self.id = Some(id.to_owned());
}
if self.created.is_none()
&& let Some(created) = integer_field(value, "created")
{
self.created = Some(created);
}
if self.model.is_none()
&& let Some(model) = string_field(value, "model")
{
self.model = Some(model.to_owned());
}
}
/// Builds the final `chat.completion` JSON.
///
/// - `reasoning_content` is included only when non-empty.
/// - Missing metadata uses the context fallbacks.
/// - A missing finish reason becomes `"stop"`.
/// - A missing usage becomes `null`.
fn into_response(self) -> Value {
let mut message = serde_json::Map::new();
message.insert("role".to_owned(), Value::String("assistant".to_owned()));
if !self.reasoning_content.is_empty() {
message.insert(
"reasoning_content".to_owned(),
Value::String(self.reasoning_content),
);
}
message.insert("content".to_owned(), Value::String(self.content));
json!({
"id": self.id.unwrap_or(self.ctx.fallback_id),
"object": "chat.completion",
"created": self.created.unwrap_or(self.ctx.fallback_created),
"model": self.model.unwrap_or(self.ctx.fallback_model),
"choices": [{
"index": self.choice_index.unwrap_or(0),
"message": Value::Object(message),
"finish_reason": self.finish_reason.unwrap_or_else(|| Value::String("stop".to_owned())),
}],
"usage": self.usage.unwrap_or(Value::Null),
})
}
}
/// Locates the earliest SSE event terminator in `buffer`.
///
/// Recognised blank-line separators: `\r\n\r\n`, `\n\n` and `\r\r`.
///
/// Returns `(byte offset where the terminator starts, terminator length)` for
/// whichever separator occurs first, or `None` when no complete event is
/// buffered yet.
fn sse_event_boundary(buffer: &str) -> Option<(usize, usize)> {
    const DELIMITERS: [&str; 3] = ["\r\n\r\n", "\n\n", "\r\r"];
    let mut earliest: Option<(usize, usize)> = None;
    for delimiter in DELIMITERS {
        if let Some(start) = buffer.find(delimiter) {
            match earliest {
                // An equal-or-earlier candidate was already found first; keep it.
                Some((best, _)) if best <= start => {}
                _ => earliest = Some((start, delimiter.len())),
            }
        }
    }
    earliest
}
/// Parses one raw SSE event block (the text before a blank-line separator).
///
/// Lines may end in `\r\n`, `\n` or a lone `\r`, matching the separators that
/// `sse_event_boundary` recognises.
/// - Comment lines (`:`-prefixed) and blank lines are ignored.
/// - `event` and `data` fields are collected; multiple `data` lines are joined
///   with `\n`.
/// - `id` and `retry` are accepted and ignored.
///
/// Returns `Ok(None)` for comment-only / heartbeat blocks.
///
/// # Errors
/// Returns a malformed-event error when:
/// - a field name is not one of the above;
/// - the block has non-comment fields but no `data` field.
fn parse_sse_event(raw: &str) -> Result<Option<RawSseEvent>, ChatStreamError> {
    let mut event = None;
    let mut data_lines = Vec::new();
    let mut saw_non_comment_field = false;
    // Split on either terminator character so CR-only lines are separated too
    // (`str::lines` does not split on a lone `\r`). The empty piece between the
    // `\r` and `\n` of a CRLF is skipped like a blank line.
    for line in raw.split(['\r', '\n']) {
        if line.is_empty() || line.starts_with(':') {
            continue;
        }
        saw_non_comment_field = true;
        let (field, value) = line.split_once(':').unwrap_or((line, ""));
        // A single space after the colon is part of the syntax, not the value.
        let value = value.strip_prefix(' ').unwrap_or(value);
        match field {
            "event" => event = Some(value.to_owned()),
            "data" => data_lines.push(value.to_owned()),
            "id" | "retry" => {}
            other => {
                warn!(field = other, "unsupported upstream SSE field");
                return Err(ChatStreamError::malformed_event(format!(
                    "unsupported upstream SSE field {other:?}",
                )));
            }
        }
    }
    if data_lines.is_empty() {
        return if saw_non_comment_field {
            warn!("upstream SSE event did not contain a data field");
            Err(ChatStreamError::malformed_event(
                "upstream SSE event did not contain a data field",
            ))
        } else {
            debug!("ignored upstream SSE comment or heartbeat event");
            Ok(None)
        };
    }
    debug!(
        event_type = event.as_deref().unwrap_or("message"),
        data_line_count = data_lines.len(),
        "parsed upstream SSE event"
    );
    Ok(Some(RawSseEvent {
        event,
        data: data_lines.join("\n"),
    }))
}
/// Turns the encrypted upstream chat stream into OpenAI-style chunks for requests that do not
/// use tool emulation.
struct OpenAiChatStreamTransformer {
    /// Set once a chunk carrying a non-null `finish_reason` has been emitted.
    sent_final_finish: bool,
    /// Set once `role: "assistant"` has been included in an emitted delta.
    sent_role: bool,
    /// Whether upstream usage chunks are forwarded to the client.
    include_usage_requested: bool,
    /// Shared decryption and chunk-building state.
    ctx: ChunkContext,
}
impl OpenAiChatStreamTransformer {
    /// Creates a transformer for streamed chat completions without tool emulation.
    ///
    /// `codec`, `proxy_instance_key` and `fallback_model` are handed to the shared
    /// `ChunkContext`. `include_usage_requested` controls whether upstream usage chunks are
    /// forwarded to the client.
    fn new(
        codec: E2eeCodec,
        proxy_instance_key: ProxyInstanceKey,
        fallback_model: String,
        include_usage_requested: bool,
    ) -> Self {
        Self {
            ctx: ChunkContext::new(codec, proxy_instance_key, fallback_model),
            include_usage_requested,
            sent_role: false,
            sent_final_finish: false,
        }
    }

    /// Converts one upstream choice object into zero or more client chunks.
    ///
    /// Encrypted `reasoning_content` and `content` deltas are decrypted and emitted together in
    /// one chunk. The first such chunk also carries `role: "assistant"`. A non-null
    /// `finish_reason` is always emitted afterwards as its own empty-delta chunk. That marks the
    /// stream as finished, so the done handler does not synthesize a second finish chunk.
    ///
    /// # Errors
    /// Returns `ChatStreamError` when the choice, its index, `finish_reason` or delta fields are
    /// malformed, or when decryption fails.
    fn handle_choice_chunk(
        &mut self,
        value: &Value,
        choice: &Value,
    ) -> Result<Vec<StreamOutput>, ChatStreamError> {
        let choice = choice.as_object().ok_or_else(|| {
            ChatStreamError::malformed_event("upstream choice must be a JSON object")
        })?;
        let finish_reason = normalized_finish_reason(choice.get("finish_reason"))?;
        let delta = choice.get("delta").unwrap_or(&Value::Null);
        let content = encrypted_delta_content(delta)?;
        let reasoning_content = encrypted_delta_reasoning_content(delta)?;
        debug!(
            has_encrypted_content = content.is_some(),
            has_encrypted_reasoning_content = reasoning_content.is_some(),
            has_finish_reason = !finish_reason.is_null(),
            "transforming streaming upstream choice chunk"
        );
        let mut output = Vec::new();
        if content.is_some() || reasoning_content.is_some() {
            let decrypted_content = match content {
                Some(content) => self.ctx.decrypt(Some(content))?,
                None => None,
            };
            let decrypted_reasoning_content = match reasoning_content {
                Some(reasoning_content) => self.ctx.decrypt(Some(reasoning_content))?,
                None => None,
            };
            debug!(
                has_decrypted_content = decrypted_content.is_some(),
                has_decrypted_reasoning_content = decrypted_reasoning_content.is_some(),
                "decrypted streaming upstream content chunk"
            );
            if decrypted_content.is_some() || decrypted_reasoning_content.is_some() {
                let mut delta = serde_json::Map::new();
                if !self.sent_role {
                    delta.insert("role".to_owned(), Value::String("assistant".to_owned()));
                    self.sent_role = true;
                }
                if let Some(reasoning_content) = decrypted_reasoning_content {
                    delta.insert(
                        "reasoning_content".to_owned(),
                        Value::String(reasoning_content),
                    );
                }
                if let Some(content) = decrypted_content {
                    delta.insert("content".to_owned(), Value::String(content));
                }
                // Any finish reason is sent in its own chunk below, never on a content chunk.
                output.push(StreamOutput::Json(self.chunk_with_choice(
                    value,
                    choice.get("index"),
                    Value::Object(delta),
                    Value::Null,
                )?));
            }
        }
        // Emit the finish reason even when the ciphertext decrypted to nothing (`decrypt` can
        // yield `None`). Dropping it here would make the done handler synthesize "stop" and
        // silently replace an upstream reason such as "length".
        if !finish_reason.is_null() {
            output.push(StreamOutput::Json(self.chunk_with_choice(
                value,
                choice.get("index"),
                json!({}),
                finish_reason,
            )?));
            self.sent_final_finish = true;
        }
        Ok(output)
    }

    /// Forwards an upstream usage chunk when the client asked for usage, otherwise drops it.
    ///
    /// # Errors
    /// Returns `ChatStreamError::MalformedEvent` when the chunk has no `usage` field either.
    fn handle_usage_chunk(&self, value: &Value) -> Result<Vec<StreamOutput>, ChatStreamError> {
        let Some(usage) = value.get("usage") else {
            warn!("streaming upstream chunk has no choices and no usage");
            return Err(ChatStreamError::malformed_event(
                "upstream chunk has no choices and no usage",
            ));
        };
        if !self.include_usage_requested {
            debug!("streaming upstream usage chunk ignored because client did not request usage");
            return Ok(Vec::new());
        }
        info!("streaming upstream usage chunk forwarded");
        Ok(vec![StreamOutput::Json(self.ctx.usage_chunk(value, usage))])
    }

    /// Builds a synthetic `finish_reason: "stop"` chunk for choice 0 with no upstream metadata.
    fn finish_chunk(&self) -> Value {
        self.ctx
            .chunk_with_choice(&Value::Null, 0, json!({}), Value::String("stop".to_owned()))
    }

    /// Validates the upstream choice index and builds a chunk through the shared context.
    fn chunk_with_choice(
        &self,
        upstream: &Value,
        index: Option<&Value>,
        delta: Value,
        finish_reason: Value,
    ) -> Result<Value, ChatStreamError> {
        let index = normalized_choice_index(index)?;
        Ok(self
            .ctx
            .chunk_with_choice(upstream, index, delta, finish_reason))
    }
}
impl ChatSseTransformer for OpenAiChatStreamTransformer {
    /// Dispatches one upstream SSE event to the matching handler.
    ///
    /// On the upstream done event, a `finish_reason: "stop"` chunk is synthesized first if no
    /// finish chunk has been sent yet. Clients therefore always see a terminal chunk before
    /// `StreamOutput::Done`.
    fn handle_event(&mut self, event: RawSseEvent) -> Result<Vec<StreamOutput>, ChatStreamError> {
        let kind = classify_upstream_event(event, &STREAMING_UPSTREAM_EVENT_LOG)?;
        match kind {
            UpstreamEventKind::Choice { value, choice } => {
                self.handle_choice_chunk(&value, &choice)
            }
            UpstreamEventKind::Usage(value) => self.handle_usage_chunk(&value),
            UpstreamEventKind::Done => {
                let mut output = Vec::with_capacity(2);
                // `replace` flips the flag to true and reports whether it was already set.
                let already_finished = std::mem::replace(&mut self.sent_final_finish, true);
                if !already_finished {
                    debug!("synthesizing final streaming finish chunk before DONE");
                    output.push(StreamOutput::Json(self.finish_chunk()));
                }
                output.push(StreamOutput::Done);
                Ok(output)
            }
        }
    }
}
/// Opening tag that switches the tool-emulated stream from plain text into tool-call buffering.
/// `streamable_pending_text_len` holds back `len() - 1` trailing bytes so a tag split across
/// upstream chunks is still detected.
const TOOL_CALL_START_MARKER: &str = "<tool_call>";
/// Turns the encrypted upstream chat stream into OpenAI-style chunks while detecting emulated
/// tool calls (text that begins with `TOOL_CALL_START_MARKER`).
struct OpenAiToolEmulatedChatStreamTransformer {
    /// Decrypted text not yet streamed because it may be the start of the tool-call marker.
    pending_text: String,
    /// Output captured from the tool-call marker onwards, classified when the choice finishes.
    tool_buffer: String,
    /// True once the marker has been seen; later content is appended to `tool_buffer`.
    buffering_tool_call: bool,
    /// True once validated tool calls have been emitted (forces `finish_reason: "tool_calls"`).
    emitted_tool_calls: bool,
    /// Set once `role: "assistant"` has been included in an emitted delta.
    sent_role: bool,
    /// Set once the final finish chunk has been emitted; later content is ignored.
    sent_final_finish: bool,
    /// Whether upstream usage chunks are forwarded to the client.
    include_usage_requested: bool,
    /// Tool definitions and limits used to classify buffered output.
    tool_context: ToolEmulationContext,
    /// Shared decryption and chunk-building state.
    ctx: ChunkContext,
}
impl OpenAiToolEmulatedChatStreamTransformer {
    /// Creates a transformer that streams decrypted text while watching for emulated tool calls.
    ///
    /// Construction currently cannot fail; the `Result` is part of the existing signature.
    fn new(
        tool_context: &ToolEmulationContext,
        codec: E2eeCodec,
        proxy_instance_key: ProxyInstanceKey,
        fallback_model: String,
        include_usage_requested: bool,
    ) -> Result<Self, ChatStreamError> {
        Ok(Self {
            ctx: ChunkContext::new(codec, proxy_instance_key, fallback_model),
            tool_context: tool_context.clone(),
            include_usage_requested,
            sent_role: false,
            sent_final_finish: false,
            pending_text: String::new(),
            tool_buffer: String::new(),
            buffering_tool_call: false,
            emitted_tool_calls: false,
        })
    }

    /// Converts one upstream choice object into client chunks.
    ///
    /// Reasoning deltas are forwarded immediately. Content deltas go through the tool-call
    /// detector in `push_decrypted_content`. A non-null `finish_reason` flushes buffered text or
    /// tool calls and emits the final chunk. After the final chunk, decrypted output is dropped.
    /// Decryption still runs first, so a decryption error is still reported.
    fn handle_choice_chunk(
        &mut self,
        value: &Value,
        choice: &Value,
    ) -> Result<Vec<StreamOutput>, ChatStreamError> {
        let choice = choice.as_object().ok_or_else(|| {
            ChatStreamError::malformed_event("upstream choice must be a JSON object")
        })?;
        let index = normalized_choice_index(choice.get("index"))?;
        let finish_reason = normalized_finish_reason(choice.get("finish_reason"))?;
        let delta = choice.get("delta").unwrap_or(&Value::Null);
        let content = encrypted_delta_content(delta)?;
        let reasoning_content = encrypted_delta_reasoning_content(delta)?;
        let mut output = Vec::new();
        if let Some(reasoning_content) = reasoning_content
            && let Some(reasoning_content) = self.ctx.decrypt(Some(reasoning_content))?
            && !self.sent_final_finish
        {
            output.push(self.reasoning_chunk(value, index, reasoning_content));
        }
        if let Some(content) = content
            && let Some(content) = self.ctx.decrypt(Some(content))?
            && !self.sent_final_finish
        {
            output.extend(self.push_decrypted_content(value, index, &content)?);
        }
        if !finish_reason.is_null() && !self.sent_final_finish {
            output.extend(self.finish_buffered_content(value, index, finish_reason)?);
        }
        Ok(output)
    }

    /// Feeds decrypted content into the tool-call detector.
    ///
    /// Once buffering, everything goes to `tool_buffer`. Otherwise the text before a complete
    /// `TOOL_CALL_START_MARKER` is streamed, and buffering starts at the marker. Without a
    /// marker, text is streamed except for a trailing tail that could still be a marker prefix.
    ///
    /// # Errors
    /// Fails when the tool buffer exceeds `tool_call_max_bytes`.
    fn push_decrypted_content(
        &mut self,
        upstream: &Value,
        index: u64,
        content: &str,
    ) -> Result<Vec<StreamOutput>, ChatStreamError> {
        if self.buffering_tool_call {
            self.tool_buffer.push_str(content);
            self.ensure_tool_buffer_within_limit()?;
            return Ok(Vec::new());
        }
        self.pending_text.push_str(content);
        if let Some(marker_index) = self.pending_text.find(TOOL_CALL_START_MARKER) {
            let text = self.pending_text[..marker_index].to_owned();
            self.tool_buffer = self.pending_text[marker_index..].to_owned();
            self.pending_text.clear();
            self.buffering_tool_call = true;
            self.ensure_tool_buffer_within_limit()?;
            return Ok(self.text_chunk_if_not_empty(upstream, index, text));
        }
        let streamable_len = streamable_pending_text_len(&self.pending_text);
        if streamable_len == 0 {
            return Ok(Vec::new());
        }
        let text = self.pending_text[..streamable_len].to_owned();
        self.pending_text.drain(..streamable_len);
        Ok(vec![
            self.text_field_chunk(upstream, index, "content", text),
        ])
    }

    /// Flushes buffered tool calls or held-back text, then emits the final finish chunk.
    ///
    /// The finish reason is overridden to `"tool_calls"` when tool calls were emitted.
    fn finish_buffered_content(
        &mut self,
        upstream: &Value,
        index: u64,
        finish_reason: Value,
    ) -> Result<Vec<StreamOutput>, ChatStreamError> {
        let mut output = Vec::new();
        if self.buffering_tool_call {
            output.extend(self.buffered_tool_call_chunks(upstream, index)?);
        } else if !self.pending_text.is_empty() {
            let text = std::mem::take(&mut self.pending_text);
            output.push(self.text_field_chunk(upstream, index, "content", text));
        }
        let finish_reason = if self.emitted_tool_calls {
            Value::String("tool_calls".to_owned())
        } else {
            finish_reason
        };
        output.push(StreamOutput::Json(self.ctx.chunk_with_choice(
            upstream,
            index,
            json!({}),
            finish_reason,
        )));
        self.sent_final_finish = true;
        Ok(output)
    }

    /// Classifies the buffered tool-call output and turns it into client chunks.
    ///
    /// Valid tool calls become one `tool_calls` chunk each. Output classified as normal text is
    /// streamed as content.
    ///
    /// # Errors
    /// Fails for an oversized buffer or a tool call that fails validation.
    fn buffered_tool_call_chunks(
        &mut self,
        upstream: &Value,
        index: u64,
    ) -> Result<Vec<StreamOutput>, ChatStreamError> {
        self.ensure_tool_buffer_within_limit()?;
        match self
            .tool_context
            .classify_assistant_output(&self.tool_buffer)
        {
            ToolOutputClassification::ToolCalls(tool_calls) => {
                self.emitted_tool_calls = true;
                Ok(tool_calls
                    .iter()
                    .enumerate()
                    .map(|(tool_index, tool_call)| {
                        self.full_tool_call_chunk(upstream, index, tool_index, tool_call)
                    })
                    .collect())
            }
            ToolOutputClassification::NormalText => {
                let text = std::mem::take(&mut self.tool_buffer);
                self.buffering_tool_call = false;
                Ok(self.text_chunk_if_not_empty(upstream, index, text))
            }
            ToolOutputClassification::InvalidToolCall { error, .. } => {
                // Log only the payload size: the buffer holds decrypted E2EE plaintext, and the
                // other log sites in this streaming path record flags and sizes, never content.
                // NOTE(review): confirm `error`'s Display does not echo payload text either.
                error!(
                    validation_error = %error,
                    payload_bytes = self.tool_buffer.len(),
                    "buffered streamed tool-call payload failed validation"
                );
                Err(ChatStreamError::malformed_event(format!(
                    "tool call parsing failed: {error}"
                )))
            }
        }
    }

    /// Errors when the tool buffer exceeds the configured `tool_call_max_bytes`.
    fn ensure_tool_buffer_within_limit(&self) -> Result<(), ChatStreamError> {
        if self.tool_buffer.len() > self.tool_context.config().tool_call_max_bytes {
            return Err(ChatStreamError::malformed_event(format!(
                "tool call output exceeded max size of {} bytes",
                self.tool_context.config().tool_call_max_bytes
            )));
        }
        Ok(())
    }

    /// Returns a single content chunk for `text`, or nothing when `text` is empty.
    fn text_chunk_if_not_empty(
        &mut self,
        upstream: &Value,
        index: u64,
        text: String,
    ) -> Vec<StreamOutput> {
        if text.is_empty() {
            Vec::new()
        } else {
            vec![self.text_field_chunk(upstream, index, "content", text)]
        }
    }

    /// Builds a `reasoning_content` delta chunk.
    fn reasoning_chunk(
        &mut self,
        upstream: &Value,
        index: u64,
        reasoning_content: String,
    ) -> StreamOutput {
        self.text_field_chunk(upstream, index, "reasoning_content", reasoning_content)
    }

    /// Builds a delta chunk with one text field, adding the assistant role on first use.
    fn text_field_chunk(
        &mut self,
        upstream: &Value,
        index: u64,
        field: &'static str,
        text: String,
    ) -> StreamOutput {
        let mut delta = serde_json::Map::new();
        self.insert_role_if_needed(&mut delta);
        delta.insert(field.to_owned(), Value::String(text));
        StreamOutput::Json(self.ctx.chunk_with_choice(
            upstream,
            index,
            Value::Object(delta),
            Value::Null,
        ))
    }

    /// Adds `role: "assistant"` to the first emitted delta only.
    fn insert_role_if_needed(&mut self, delta: &mut serde_json::Map<String, Value>) {
        if !self.sent_role {
            delta.insert("role".to_owned(), Value::String("assistant".to_owned()));
            self.sent_role = true;
        }
    }

    /// Builds a chunk carrying one complete tool call, tagged with its position `tool_index`.
    fn full_tool_call_chunk(
        &mut self,
        upstream: &Value,
        index: u64,
        tool_index: usize,
        tool_call: &ValidatedToolCall,
    ) -> StreamOutput {
        let mut delta = serde_json::Map::new();
        self.insert_role_if_needed(&mut delta);
        let mut tool_call_value = tool_call.to_openai_value();
        if let Some(tool_call_object) = tool_call_value.as_object_mut() {
            tool_call_object.insert("index".to_owned(), json!(tool_index));
        }
        delta.insert("tool_calls".to_owned(), Value::Array(vec![tool_call_value]));
        StreamOutput::Json(self.ctx.chunk_with_choice(
            upstream,
            index,
            Value::Object(delta),
            Value::Null,
        ))
    }

    /// Forwards an upstream usage chunk when the client asked for usage, otherwise drops it.
    fn handle_usage_chunk(&self, value: &Value) -> Result<Vec<StreamOutput>, ChatStreamError> {
        let Some(usage) = value.get("usage") else {
            warn!("tool-emulated upstream chunk has no choices and no usage");
            return Err(ChatStreamError::malformed_event(
                "upstream chunk has no choices and no usage",
            ));
        };
        if !self.include_usage_requested {
            return Ok(Vec::new());
        }
        Ok(vec![StreamOutput::Json(self.ctx.usage_chunk(value, usage))])
    }

    /// Handles the upstream done event.
    ///
    /// If no finish chunk was sent, buffered content is flushed with a default `"stop"` reason
    /// for choice 0. The stream then ends with `StreamOutput::Done`.
    fn finish_stream(&mut self) -> Result<Vec<StreamOutput>, ChatStreamError> {
        let upstream = &Value::Null;
        let mut output = Vec::new();
        if !self.sent_final_finish {
            output.extend(self.finish_buffered_content(
                upstream,
                0,
                Value::String("stop".to_owned()),
            )?);
        }
        output.push(StreamOutput::Done);
        Ok(output)
    }
}
/// Number of leading bytes of `pending_text` that can be streamed safely.
///
/// The last `TOOL_CALL_START_MARKER.len() - 1` bytes are held back in case they are the start
/// of a marker completed by the next chunk. The cut is moved left onto a UTF-8 char boundary.
/// Returns 0 when nothing can be released yet.
fn streamable_pending_text_len(pending_text: &str) -> usize {
    let held_back = TOOL_CALL_START_MARKER.len().saturating_sub(1);
    let Some(candidate) = pending_text
        .len()
        .checked_sub(held_back)
        .filter(|&releasable| releasable > 0)
    else {
        return 0;
    };
    // Index 0 is always a char boundary, so this search always succeeds.
    (0..=candidate)
        .rev()
        .find(|&at| pending_text.is_char_boundary(at))
        .unwrap_or(0)
}
impl ChatSseTransformer for OpenAiToolEmulatedChatStreamTransformer {
    /// Classifies one upstream SSE event and routes it to the choice, usage or done handler.
    fn handle_event(&mut self, event: RawSseEvent) -> Result<Vec<StreamOutput>, ChatStreamError> {
        let kind = classify_upstream_event(event, &TOOL_EMULATED_UPSTREAM_EVENT_LOG)?;
        match kind {
            UpstreamEventKind::Choice { value, choice } => {
                self.handle_choice_chunk(&value, &choice)
            }
            UpstreamEventKind::Usage(value) => self.handle_usage_chunk(&value),
            UpstreamEventKind::Done => self.finish_stream(),
        }
    }
}
/// One item produced by a chat SSE transformer for the client-facing stream.
#[derive(Debug, Clone, PartialEq, Eq)]
enum StreamOutput {
    /// A JSON chunk (content, finish, tool-call, or usage chunk).
    Json(Value),
    /// End-of-stream marker, pushed after the final chunk (the route tests expect `[DONE]`).
    Done,
}
/// Reads an upstream choice `index`, defaulting to 0 when absent.
///
/// # Errors
/// Any present value that is not a non-negative integer (including `null`, floats and negative
/// numbers) is rejected as malformed.
fn normalized_choice_index(index: Option<&Value>) -> Result<u64, ChatStreamError> {
    let Some(raw_index) = index else {
        return Ok(0);
    };
    // `Value::as_u64` is `None` for every non-number and for numbers outside `u64`.
    raw_index.as_u64().ok_or_else(|| {
        ChatStreamError::malformed_event("upstream choice index must be a non-negative integer")
    })
}
/// Normalizes an upstream `finish_reason` to either `Value::Null` or a `Value::String`.
///
/// # Errors
/// Any other JSON type is rejected as malformed.
fn normalized_finish_reason(value: Option<&Value>) -> Result<Value, ChatStreamError> {
    // An absent field is treated exactly like an explicit `null`.
    match value.unwrap_or(&Value::Null) {
        Value::Null => Ok(Value::Null),
        reason @ Value::String(_) => Ok(reason.clone()),
        _ => Err(ChatStreamError::malformed_event(
            "upstream finish_reason must be a string or null",
        )),
    }
}
/// Returns the still-encrypted `delta.content` string, or `None` when absent, null or empty.
fn encrypted_delta_content(delta: &Value) -> Result<Option<&str>, ChatStreamError> {
    const FIELD: &str = "content";
    encrypted_delta_text_field(delta, FIELD)
}
/// Returns the still-encrypted `delta.reasoning_content` string, or `None` when absent, null or
/// empty.
fn encrypted_delta_reasoning_content(delta: &Value) -> Result<Option<&str>, ChatStreamError> {
    const FIELD: &str = "reasoning_content";
    encrypted_delta_text_field(delta, FIELD)
}
/// Extracts an optional text field from an upstream `delta` object.
///
/// A missing field, `null`, or an empty string all yield `Ok(None)`; null and empty are logged
/// at debug level. A non-empty string is returned borrowed.
///
/// # Errors
/// Any other JSON type is rejected as malformed.
fn encrypted_delta_text_field<'a>(
    delta: &'a Value,
    field: &'static str,
) -> Result<Option<&'a str>, ChatStreamError> {
    let Some(raw) = delta.get(field) else {
        return Ok(None);
    };
    match raw {
        Value::String(text) if !text.is_empty() => Ok(Some(text.as_str())),
        Value::String(_) => {
            debug!(field, "ignoring empty upstream delta text field");
            Ok(None)
        }
        Value::Null => {
            debug!(field, "ignoring null upstream delta text field");
            Ok(None)
        }
        _ => Err(ChatStreamError::malformed_event(format!(
            "upstream delta.{field} must be a string or null"
        ))),
    }
}
fn string_field<'a>(value: &'a Value, field: &str) -> Option<&'a str> {
value.get(field).and_then(Value::as_str)
}
fn integer_field(value: &Value, field: &str) -> Option<i64> {
value.get(field).and_then(Value::as_i64)
}
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn unix_timestamp_now() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
/// Handler producing the proxy's 405 error for an unsupported method on `uri`.
async fn method_not_allowed(method: Method, uri: Uri) -> ProxyError {
    ProxyError::MethodNotAllowed { uri, method }
}
async fn not_found(uri: Uri) -> ProxyError {
ProxyError::NotFound { uri }
}
/// Failures while consuming and transforming the upstream Venice SSE chat stream.
///
/// Each variant maps to an OpenAI-style error `type`/`code` via `api_error_type` and
/// `api_error_code`.
#[derive(Debug, Error)]
pub enum ChatStreamError {
    /// Transport failure reading the upstream body (built from a stringified `reqwest::Error`).
    #[error("Venice upstream stream failed: {message}")]
    UpstreamStream { message: String },
    /// An error event reported inside the upstream stream (built via `upstream_event`).
    #[error("Venice upstream stream emitted an error event: {message}")]
    UpstreamEvent { message: String },
    /// Upstream data did not have the expected shape. Also used for invalid UTF-8, invalid JSON,
    /// oversized or invalid emulated tool calls.
    #[error("Venice upstream stream event is malformed: {message}")]
    MalformedEvent { message: String },
    /// The E2EE codec failed while decrypting a response chunk.
    #[error("failed to decrypt Venice E2EE response chunk: {source}")]
    Decryption { source: E2eeCodecError },
}
impl ChatStreamError {
    /// Wraps a transport error from reading the upstream response stream.
    fn upstream_stream(source: reqwest::Error) -> Self {
        let message = source.to_string();
        Self::UpstreamStream { message }
    }

    /// Builds an error for an error event emitted by upstream.
    fn upstream_event(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::UpstreamEvent { message }
    }

    /// Builds a malformed-event error with the given description.
    fn malformed_event(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::MalformedEvent { message }
    }

    /// Malformed-event error for upstream bytes that are not valid UTF-8.
    fn invalid_utf8(source: std::str::Utf8Error) -> Self {
        Self::malformed_event(format!("upstream SSE bytes are not valid UTF-8: {source}"))
    }

    /// Malformed-event error for SSE data that is not valid JSON.
    fn json_event(source: serde_json::Error) -> Self {
        Self::malformed_event(format!("upstream SSE data is not valid JSON: {source}"))
    }

    /// Wraps an E2EE codec failure while decrypting a response chunk.
    fn decryption(source: E2eeCodecError) -> Self {
        Self::Decryption { source }
    }

    /// OpenAI-style error `type` for this failure.
    fn api_error_type(&self) -> &'static str {
        match self {
            Self::Decryption { .. } => "proxy_e2ee_error",
            Self::MalformedEvent { .. }
            | Self::UpstreamEvent { .. }
            | Self::UpstreamStream { .. } => "proxy_upstream_error",
        }
    }

    /// OpenAI-style error `code` for this failure.
    fn api_error_code(&self) -> &'static str {
        match self {
            Self::Decryption { .. } => "e2ee_response_decryption_failed",
            Self::MalformedEvent { .. } => "upstream_malformed_response",
            Self::UpstreamEvent { .. } | Self::UpstreamStream { .. } => "upstream_stream_error",
        }
    }
}
/// Top-level error returned by proxy handlers.
///
/// It is rendered into an OpenAI-style JSON error response by its `IntoResponse` impl.
#[derive(Debug, Error)]
pub enum ProxyError {
    /// Venice client failure (502).
    #[error(transparent)]
    Venice(#[from] VeniceClientError),
    /// Attestation failure (503 when the verifier is unavailable, otherwise 502).
    #[error(transparent)]
    Attestation(#[from] AttestationError),
    /// Session failure (400 for missing or invalid session headers, otherwise 500).
    #[error(transparent)]
    Session(#[from] SessionError),
    /// Invalid client chat request (400).
    #[error(transparent)]
    ChatRequest(#[from] ChatRequestError),
    /// Failure building the upstream chat request (502).
    #[error(transparent)]
    ChatConstruction(#[from] ChatConstructionError),
    /// Failure consuming the upstream chat stream (502).
    #[error(transparent)]
    ChatStream(#[from] ChatStreamError),
    /// The model kept producing invalid tool calls after all correction attempts (502). The
    /// response body includes both fields under `error.details`.
    #[error("The model failed to produce a valid tool call after correction attempts.")]
    ToolCallRetryExhausted {
        max_retries: u32,
        last_validation_error: String,
    },
    /// E2EE chat was requested but no proxy instance key was generated at startup (500).
    #[error(
        "proxy instance key is unavailable; keys.generate_proxy_instance_key_on_startup must be enabled for E2EE chat requests"
    )]
    ProxyInstanceKeyUnavailable,
    /// A verified session has no attested model public key (500).
    #[error("session does not contain an attested model public key after attestation verification")]
    MissingAttestedModelKey,
    /// Unsupported HTTP method for a route (405).
    #[error("method {method} is not supported for {uri}")]
    MethodNotAllowed { method: Method, uri: Uri },
    /// Unknown route (404).
    #[error("route {uri} was not found")]
    NotFound { uri: Uri },
}
impl ProxyError {
    /// HTTP status code for this error.
    fn status(&self) -> StatusCode {
        match self {
            Self::NotFound { .. } => StatusCode::NOT_FOUND,
            Self::MethodNotAllowed { .. } => StatusCode::METHOD_NOT_ALLOWED,
            Self::ChatRequest(_) => StatusCode::BAD_REQUEST,
            Self::Session(session_error) => match session_error {
                SessionError::MissingSessionIdentifier | SessionError::InvalidHeaderValue { .. } => {
                    StatusCode::BAD_REQUEST
                }
                _ => StatusCode::INTERNAL_SERVER_ERROR,
            },
            Self::Attestation(error) => {
                if error.verifier_unavailable() {
                    StatusCode::SERVICE_UNAVAILABLE
                } else {
                    StatusCode::BAD_GATEWAY
                }
            }
            Self::Venice(_)
            | Self::ChatConstruction(_)
            | Self::ChatStream(_)
            | Self::ToolCallRetryExhausted { .. } => StatusCode::BAD_GATEWAY,
            Self::ProxyInstanceKeyUnavailable | Self::MissingAttestedModelKey => {
                StatusCode::INTERNAL_SERVER_ERROR
            }
        }
    }

    /// OpenAI-style error `type` string for this error.
    fn error_type(&self) -> &'static str {
        match self {
            Self::Venice(error) => error.api_error_type(),
            Self::Attestation(error) => error.api_error_type(),
            Self::ChatStream(error) => error.api_error_type(),
            Self::Session(
                SessionError::MissingSessionIdentifier | SessionError::InvalidHeaderValue { .. },
            )
            | Self::ChatRequest(_)
            | Self::MethodNotAllowed { .. }
            | Self::NotFound { .. } => "invalid_request_error",
            Self::Session(_) => "proxy_session_error",
            Self::ChatConstruction(_) => "proxy_e2ee_error",
            Self::ToolCallRetryExhausted { .. } => "proxy_tool_call_error",
            Self::ProxyInstanceKeyUnavailable => "proxy_configuration_error",
            Self::MissingAttestedModelKey => "proxy_attestation_error",
        }
    }

    /// Machine-readable error `code`, also sent in the error-code response header.
    fn code(&self) -> &'static str {
        match self {
            Self::Venice(error) => error.api_error_code(),
            Self::Attestation(error) => error.api_error_code(),
            Self::ChatRequest(error) => error.api_error_code(),
            Self::ChatConstruction(error) => error.api_error_code(),
            Self::ChatStream(error) => error.api_error_code(),
            Self::Session(session_error) => match session_error {
                SessionError::MissingSessionIdentifier => "session_identifier_missing",
                SessionError::InvalidHeaderValue { .. } => "invalid_session_header",
                _ => "session_error",
            },
            Self::ToolCallRetryExhausted { .. } => "invalid_tool_call",
            Self::ProxyInstanceKeyUnavailable => "proxy_instance_key_unavailable",
            Self::MissingAttestedModelKey => "attestation_failed",
            Self::MethodNotAllowed { .. } => "method_not_allowed",
            Self::NotFound { .. } => "not_found",
        }
    }
}
impl IntoResponse for ProxyError {
    /// Logs the error and renders it as an OpenAI-style JSON error with the error-code header.
    ///
    /// 5xx errors are logged at error level and 4xx at warn level.
    /// `ToolCallRetryExhausted` additionally carries `error.details` with the retry count and
    /// the last validation error.
    fn into_response(self) -> Response {
        let status = self.status();
        let error_code = self.code();
        let error_type = self.error_type();
        if status.is_server_error() {
            error!(
                status = status.as_u16(),
                error_code,
                error_type,
                error = %self,
                "proxy request failed"
            );
        } else {
            warn!(
                status = status.as_u16(),
                error_code,
                error_type,
                error = %self,
                "proxy request rejected"
            );
        }
        let message = self.to_string();
        let mut response = match &self {
            Self::ToolCallRetryExhausted {
                max_retries,
                last_validation_error,
            } => {
                let detailed_body = json!({
                    "error": {
                        "message": message,
                        "type": error_type,
                        "code": error_code,
                        "details": {
                            "max_retries": max_retries,
                            "last_validation_error": last_validation_error,
                        },
                    }
                });
                (status, Json(detailed_body)).into_response()
            }
            _ => {
                let standard_body = ErrorResponse::new(message, error_type, error_code);
                (status, Json(standard_body)).into_response()
            }
        };
        apply_error_headers(response.headers_mut(), error_code);
        response
    }
}
/// Proxy metadata exposed to clients as `X-Venice-Proxy-*` response headers.
///
/// `None` fields are omitted by `apply`; booleans are rendered as `"true"`/`"false"`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProxyMetadataHeaders {
    /// `X-Venice-Proxy-E2EE`.
    pub e2ee: Option<String>,
    /// `X-Venice-Proxy-Attestation-Mode`.
    pub attestation_mode: Option<String>,
    /// `X-Venice-Proxy-Attested-Model`.
    pub attested_model: Option<String>,
    /// `X-Venice-Proxy-TEE-Provider`.
    pub tee_provider: Option<String>,
    /// `X-Venice-Proxy-TDX-Verified`.
    pub tdx_verified: Option<bool>,
    /// `X-Venice-Proxy-TDX-Debug`.
    pub tdx_debug: Option<bool>,
    /// `X-Venice-Proxy-NVIDIA-Verified` (`verified`, `ignored` or `not-present`).
    pub nvidia_verified: Option<String>,
    /// `X-Venice-Proxy-Key-Binding`.
    pub key_binding: Option<bool>,
    /// `X-Venice-Proxy-Session-Id`.
    pub session_id: Option<String>,
    /// `X-Venice-Proxy-Session-Scope`.
    pub session_scope: Option<String>,
    /// `X-Venice-Proxy-Tool-Mode`.
    pub tool_mode: Option<String>,
    /// `X-Venice-Proxy-Tool-Retries`, rendered as a decimal integer.
    pub tool_retries: Option<u32>,
}
impl ProxyMetadataHeaders {
pub fn from_config(config: &ProxyConfig) -> Self {
Self {
attestation_mode: Some(config.attestation.mode.as_str().to_owned()),
tool_mode: Some(config.tools.mode.as_str().to_owned()),
..Self::default()
}
}
pub fn for_verified_chat(config: &ProxyConfig, session: &SessionContext) -> Self {
let evidence = session
.attestation_report
.as_ref()
.and_then(|report| report.get("attestation"))
.and_then(Value::as_object);
let tee_provider = evidence
.and_then(|evidence| evidence.get("tee_provider"))
.and_then(Value::as_str)
.unwrap_or("unknown")
.to_owned();
let tdx_debug = evidence.and_then(|evidence| {
evidence
.get("debug")
.or_else(|| evidence.get("tdx_debug"))
.and_then(Value::as_bool)
});
let nvidia_payload_present = evidence
.and_then(|evidence| evidence.get("nvidia_payload"))
.is_some_and(|value| !value.is_null());
let nvidia_verified = match (config.attestation.require_nvidia, nvidia_payload_present) {
(_, false) => "not-present",
(NvidiaRequirement::Never, true) => "ignored",
(_, true) => "verified",
}
.to_owned();
Self {
e2ee: Some("verified".to_owned()),
attestation_mode: Some(config.attestation.mode.as_str().to_owned()),
attested_model: Some(session.model_id.clone()),
tee_provider: Some(tee_provider),
tdx_verified: config.attestation.require_tdx.then_some(true),
tdx_debug,
nvidia_verified: Some(nvidia_verified),
key_binding: Some(true),
session_id: Some(session.agent_session_id.clone()),
session_scope: Some(session.scope.as_str().to_owned()),
tool_mode: Some(config.tools.mode.as_str().to_owned()),
tool_retries: None,
}
}
pub fn apply(&self, headers: &mut HeaderMap) {
insert_optional_header(headers, HEADER_PROXY_E2EE, self.e2ee.as_deref());
insert_optional_header(
headers,
HEADER_PROXY_ATTESTATION_MODE,
self.attestation_mode.as_deref(),
);
insert_optional_header(
headers,
HEADER_PROXY_ATTESTED_MODEL,
self.attested_model.as_deref(),
);
insert_optional_header(
headers,
HEADER_PROXY_TEE_PROVIDER,
self.tee_provider.as_deref(),
);
insert_optional_bool_header(headers, HEADER_PROXY_TDX_VERIFIED, self.tdx_verified);
insert_optional_bool_header(headers, HEADER_PROXY_TDX_DEBUG, self.tdx_debug);
insert_optional_header(
headers,
HEADER_PROXY_NVIDIA_VERIFIED,
self.nvidia_verified.as_deref(),
);
insert_optional_bool_header(headers, HEADER_PROXY_KEY_BINDING, self.key_binding);
insert_optional_header(headers, HEADER_PROXY_SESSION_ID, self.session_id.as_deref());
insert_optional_header(
headers,
HEADER_PROXY_SESSION_SCOPE,
self.session_scope.as_deref(),
);
insert_optional_header(headers, HEADER_PROXY_TOOL_MODE, self.tool_mode.as_deref());
if let Some(tool_retries) = self.tool_retries {
insert_header(
headers,
HEADER_PROXY_TOOL_RETRIES,
&tool_retries.to_string(),
);
}
}
}
pub fn apply_error_headers(headers: &mut HeaderMap, error_code: &str) {
insert_header(headers, HEADER_PROXY_ERROR_CODE, error_code);
}
/// Inserts `name: value` only when a value is present.
fn insert_optional_header(headers: &mut HeaderMap, name: &'static str, value: Option<&str>) {
    let Some(text) = value else {
        return;
    };
    insert_header(headers, name, text);
}
/// Inserts `name` as `"true"`/`"false"` only when a value is present.
fn insert_optional_bool_header(headers: &mut HeaderMap, name: &'static str, value: Option<bool>) {
    match value {
        Some(true) => insert_header(headers, name, "true"),
        Some(false) => insert_header(headers, name, "false"),
        None => {}
    }
}
fn insert_header(headers: &mut HeaderMap, name: &'static str, value: &str) {
let Ok(name) = HeaderName::from_bytes(name.as_bytes()) else {
return;
};
let Ok(value) = HeaderValue::from_str(value) else {
return;
};
headers.insert(name, value);
}
#[cfg(test)]
mod tests {
use super::*;
use std::{
collections::{HashMap, VecDeque},
sync::{Arc, Mutex},
time::Duration,
};
use axum::{
body::Body,
extract::Query,
http::Request,
routing::{get, post},
};
use serde_json::json;
use crate::config::NvidiaRequirement;
use tower::ServiceExt;
fn test_app() -> Router {
router_with_venice_client(ProxyConfig::default(), test_venice_client())
}
fn test_venice_client() -> VeniceClient {
test_venice_client_for_base_url("http://127.0.0.1:1/api/v1")
}
fn test_venice_client_for_base_url(base_url: impl AsRef<str>) -> VeniceClient {
VeniceClient::new(base_url.as_ref(), "test-api-key", Duration::from_secs(1))
.expect("test Venice client should build")
}
fn chat_config_with_basic_test_attestation() -> ProxyConfig {
let mut config = ProxyConfig::default();
config.attestation.require_tdx = false;
config.attestation.require_nvidia = NvidiaRequirement::Never;
config
}
#[test]
fn app_state_initializes_key_and_session_managers_from_config() {
let state = AppState::from_parts(ProxyConfig::default(), test_venice_client());
let key = state
.proxy_instance_key()
.expect("default config should generate startup key");
assert_eq!(key.public_key_hex().len(), 130);
assert!(state.session_manager().is_empty());
assert_eq!(
state.attestation_verifier().policy(),
&ProxyConfig::default().attestation
);
let mut config = ProxyConfig::default();
config.keys.generate_proxy_instance_key_on_startup = false;
let state = AppState::from_parts(config, test_venice_client());
assert!(state.proxy_instance_key().is_none());
}
async fn error_body(response: Response) -> ErrorResponse {
let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.expect("response body should buffer");
serde_json::from_slice(&bytes).expect("response should be OpenAI-style error JSON")
}
#[tokio::test]
async fn chat_route_ignores_upstream_role_only_chunk_before_encrypted_content() {
let response = streaming_chat_response(
"chat-route-role-only",
r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
vec![
MockStreamFrame::Role,
MockStreamFrame::Text("Hello"),
MockStreamFrame::Finish("stop"),
MockStreamFrame::Done,
],
)
.await;
assert_eq!(response.status(), StatusCode::OK);
let body = response_body(response).await;
let data = sse_data(&body);
assert_eq!(data.len(), 3);
let first: Value = serde_json::from_str(data[0]).expect("first chunk should be JSON");
assert_eq!(first["choices"][0]["delta"]["role"], "assistant");
assert_eq!(first["choices"][0]["delta"]["content"], "Hello");
assert_eq!(data[2], "[DONE]");
}
#[tokio::test]
async fn chat_route_streams_decrypted_normal_assistant_text() {
let response = streaming_chat_response(
"chat-route-test",
r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
vec![
MockStreamFrame::NullContent,
MockStreamFrame::EmptyContent,
MockStreamFrame::Text("Hello"),
MockStreamFrame::Finish("stop"),
MockStreamFrame::Done,
],
)
.await;
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(
response.headers().get(HEADER_PROXY_E2EE).unwrap(),
"verified"
);
assert_eq!(
response.headers().get(HEADER_PROXY_ATTESTED_MODEL).unwrap(),
"e2ee-test"
);
let body = response_body(response).await;
let data = sse_data(&body);
assert_eq!(data.len(), 3);
let first: Value = serde_json::from_str(data[0]).expect("first chunk should be JSON");
assert_eq!(first["object"], "chat.completion.chunk");
assert_eq!(first["model"], "e2ee-test");
assert_eq!(first["choices"][0]["delta"]["role"], "assistant");
assert_eq!(first["choices"][0]["delta"]["content"], "Hello");
assert!(first["choices"][0]["finish_reason"].is_null());
let final_chunk: Value = serde_json::from_str(data[1]).expect("final chunk should be JSON");
assert_eq!(final_chunk["choices"][0]["delta"], json!({}));
assert_eq!(final_chunk["choices"][0]["finish_reason"], "stop");
assert_eq!(data[2], "[DONE]");
}
#[tokio::test]
async fn chat_route_streams_decrypted_reasoning_content() {
let response = streaming_chat_response(
"chat-route-reasoning-stream",
r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true,"reasoning":{"effort":"high"}}"#,
vec![
MockStreamFrame::Reasoning("Thinking"),
MockStreamFrame::Text("Answer"),
MockStreamFrame::Finish("stop"),
MockStreamFrame::Done,
],
)
.await;
assert_eq!(response.status(), StatusCode::OK);
let body = response_body(response).await;
let data = sse_data(&body);
assert_eq!(data.len(), 4);
let reasoning: Value =
serde_json::from_str(data[0]).expect("reasoning chunk should be JSON");
let answer: Value = serde_json::from_str(data[1]).expect("answer chunk should be JSON");
assert_eq!(reasoning["choices"][0]["delta"]["role"], "assistant");
assert_eq!(
reasoning["choices"][0]["delta"]["reasoning_content"],
"Thinking"
);
assert!(answer["choices"][0]["delta"].get("role").is_none());
assert_eq!(answer["choices"][0]["delta"]["content"], "Answer");
assert_eq!(data.last().copied(), Some("[DONE]"));
}
#[tokio::test]
async fn chat_route_streams_multiple_decrypted_content_chunks() {
let response = streaming_chat_response(
"chat-route-multiple-chunks",
r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
vec![
MockStreamFrame::Text("Hello"),
MockStreamFrame::Text(" world"),
MockStreamFrame::Finish("stop"),
MockStreamFrame::Done,
],
)
.await;
assert_eq!(response.status(), StatusCode::OK);
let body = response_body(response).await;
let data = sse_data(&body);
let first: Value = serde_json::from_str(data[0]).expect("first chunk should be JSON");
let second: Value = serde_json::from_str(data[1]).expect("second chunk should be JSON");
assert_eq!(first["choices"][0]["delta"]["role"], "assistant");
assert_eq!(first["choices"][0]["delta"]["content"], "Hello");
assert!(second["choices"][0]["delta"].get("role").is_none());
assert_eq!(second["choices"][0]["delta"]["content"], " world");
assert_eq!(data.last().copied(), Some("[DONE]"));
}
#[tokio::test]
async fn chat_route_passes_through_usage_chunk_when_requested_and_upstream_provides_it() {
let response = streaming_chat_response(
"chat-route-usage",
r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true,"stream_options":{"include_usage":true}}"#,
vec![
MockStreamFrame::Text("Hello"),
MockStreamFrame::Finish("stop"),
MockStreamFrame::Usage,
MockStreamFrame::Done,
],
)
.await;
assert_eq!(response.status(), StatusCode::OK);
let body = response_body(response).await;
let data = sse_data(&body);
assert_eq!(data.len(), 4);
let usage_chunk: Value = serde_json::from_str(data[2]).expect("usage chunk should be JSON");
assert_eq!(usage_chunk["choices"], json!([]));
assert_eq!(usage_chunk["usage"]["total_tokens"], 3);
assert_eq!(data[3], "[DONE]");
}
/// A `stream: false` request is answered with one buffered `chat.completion` object.
/// Null and empty content frames contribute nothing to the assembled text.
#[tokio::test]
async fn chat_route_returns_buffered_non_streaming_completion() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#;
    let script = vec![
        MockStreamFrame::NullContent,
        MockStreamFrame::EmptyContent,
        MockStreamFrame::Text("Hello"),
        MockStreamFrame::Text(" world"),
        MockStreamFrame::Finish("stop"),
        MockStreamFrame::Done,
    ];
    let resp = chat_response("chat-route-non-streaming-success", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);
    let e2ee_marker = resp.headers().get(HEADER_PROXY_E2EE).unwrap();
    assert_eq!(e2ee_marker, "verified");

    let completion = json_body(resp).await;
    assert_eq!(completion["object"], "chat.completion");
    assert_eq!(completion["id"], "chatcmpl-upstream-test");
    assert_eq!(completion["created"], 1_717_171_717);
    assert_eq!(completion["model"], "e2ee-test");

    let choice = &completion["choices"][0];
    assert_eq!(choice["index"], 0);
    assert_eq!(choice["message"]["role"], "assistant");
    assert_eq!(choice["message"]["content"], "Hello world");
    assert_eq!(choice["finish_reason"], "stop");
    assert!(completion["usage"].is_null());
}
/// Buffered responses concatenate reasoning deltas into `message.reasoning_content` and
/// keep them separate from the answer in `message.content`.
#[tokio::test]
async fn chat_route_returns_buffered_reasoning_content() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false,"reasoning_effort":"medium"}"#;
    let script = vec![
        MockStreamFrame::Reasoning("Think "),
        MockStreamFrame::Reasoning("first."),
        MockStreamFrame::Text("Answer"),
        MockStreamFrame::Finish("stop"),
        MockStreamFrame::Done,
    ];
    let resp = chat_response("chat-route-reasoning-non-streaming", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let completion = json_body(resp).await;
    let message = &completion["choices"][0]["message"];
    assert_eq!(message["reasoning_content"], "Think first.");
    assert_eq!(message["content"], "Answer");
}
/// Omitting `stream` is treated like `stream: false`. Even without an upstream finish
/// frame, the buffered choice reports `finish_reason: "stop"`.
#[tokio::test]
async fn chat_route_treats_omitted_stream_as_buffered_non_streaming() {
    const REQUEST: &str =
        r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}]}"#;
    let script = vec![MockStreamFrame::Text("Hello"), MockStreamFrame::Done];
    let resp = chat_response("chat-route-omitted-stream", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let completion = json_body(resp).await;
    let choice = &completion["choices"][0];
    assert_eq!(completion["object"], "chat.completion");
    assert_eq!(choice["message"]["content"], "Hello");
    assert_eq!(choice["finish_reason"], "stop");
}
/// An emulated tool call is streamed as OpenAI-style `tool_calls` deltas. Only the first
/// fragment carries `id`, `type` and `function.name`; later fragments carry argument text
/// only. The stream ends with an empty delta whose `finish_reason` is `"tool_calls"`.
#[tokio::test]
async fn chat_route_streams_incremental_tool_call_chunks() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#;
    let script = vec![
        MockStreamFrame::Text("<tool_call>\n{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}\n</tool_call>"),
        MockStreamFrame::Finish("stop"),
        MockStreamFrame::Done,
    ];
    let resp = streaming_chat_response("chat-route-tool-stream", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let raw = response_body(resp).await;
    let chunks = sse_json_chunks(&raw);
    assert_eq!(chunks[0]["choices"][0]["delta"]["role"], "assistant");

    let deltas = streamed_tool_call_deltas(&chunks);
    assert!(!deltas.is_empty());
    let (opening, continuations) = (deltas[0], &deltas[1..]);
    assert_eq!(opening["index"], 0);
    assert!(opening["id"].as_str().unwrap().starts_with("call_"));
    assert_eq!(opening["type"], "function");
    assert_eq!(opening["function"]["name"], "search_web");
    for continuation in continuations {
        assert!(continuation.get("id").is_none());
        assert!(continuation.get("type").is_none());
        assert!(continuation["function"].get("name").is_none());
    }
    assert_eq!(
        streamed_tool_call_arguments(&chunks, 0),
        r#"{"query":"example"}"#
    );

    let closing = chunks.last().expect("stream should have chunks");
    assert_eq!(closing["choices"][0]["delta"], json!({}));
    assert_eq!(closing["choices"][0]["finish_reason"], "tool_calls");
}
/// Plain text before a `<tool_call>` marker is streamed as ordinary content. The call
/// itself, with open and close markers split across frames, becomes `tool_calls` deltas.
#[tokio::test]
async fn chat_route_streams_text_then_incremental_tool_call() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#;
    let script = vec![
        MockStreamFrame::NullContent,
        MockStreamFrame::EmptyContent,
        MockStreamFrame::Text("I'll check that. "),
        MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}"),
        MockStreamFrame::Text("</tool_call>"),
        MockStreamFrame::Finish("stop"),
        MockStreamFrame::Done,
    ];
    let resp = streaming_chat_response("chat-route-tool-stream-mixed-text", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let raw = response_body(resp).await;
    let chunks = sse_json_chunks(&raw);
    assert_eq!(chunks[0]["choices"][0]["delta"]["role"], "assistant");
    assert_eq!(streamed_content(&chunks), "I'll check that. ");

    let deltas = streamed_tool_call_deltas(&chunks);
    assert!(!deltas.is_empty());
    assert_eq!(deltas[0]["function"]["name"], "search_web");
    assert_eq!(
        streamed_tool_call_arguments(&chunks, 0),
        r#"{"query":"example"}"#
    );

    let closing = chunks.last().expect("stream should have chunks");
    assert_eq!(closing["choices"][0]["finish_reason"], "tool_calls");
}
/// A `<tool_call>` that is opened but never completed aborts the SSE body instead of
/// ending it cleanly.
#[tokio::test]
async fn chat_route_fails_closed_on_unterminated_streamed_tool_call() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#;
    let script = vec![
        MockStreamFrame::Text("I'll check that. "),
        MockStreamFrame::Text("<tool_call>{\"name\":"),
        MockStreamFrame::Finish("stop"),
        MockStreamFrame::Done,
    ];
    assert_stream_body_fails(
        streaming_chat_response("chat-route-tool-stream-missing-close", REQUEST, script).await,
    )
    .await;
}
/// A Hermes-style tool call (JSON on its own line inside the markers) is parsed for a
/// GLM model id, even when the JSON is split across frames.
#[tokio::test]
async fn chat_route_streams_hermes_format_tool_call_from_glm_model() {
    const REQUEST: &str = r#"{"model":"e2ee-glm-5-1","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#;
    let script = vec![
        MockStreamFrame::Text("<tool_call>\n{\"name\":\"search_web\",\"arguments\":"),
        MockStreamFrame::Text("{\"query\":\"example\"}}\n</tool_call>"),
        MockStreamFrame::Finish("stop"),
        MockStreamFrame::Done,
    ];
    let resp = streaming_chat_response("chat-route-tool-stream-glm-hermes", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let raw = response_body(resp).await;
    let chunks = sse_json_chunks(&raw);
    let deltas = streamed_tool_call_deltas(&chunks);
    assert!(!deltas.is_empty());
    assert_eq!(deltas[0]["function"]["name"], "search_web");
    assert_eq!(
        streamed_tool_call_arguments(&chunks, 0),
        r#"{"query":"example"}"#
    );

    let closing = chunks.last().expect("stream should have chunks");
    assert_eq!(closing["choices"][0]["finish_reason"], "tool_calls");
}
/// A complete tool-call JSON whose closing `</tool_call>` marker never arrives is still
/// emitted as a tool call for this model id, rather than failing the stream.
#[tokio::test]
async fn chat_route_recovers_streamed_tool_call_with_truncated_closing_marker() {
    const REQUEST: &str = r#"{"model":"e2ee-glm-4-7-flash-p","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#;
    let script = vec![
        MockStreamFrame::Text("<tool_call>\n{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}\n"),
        MockStreamFrame::Finish("stop"),
        MockStreamFrame::Done,
    ];
    let resp =
        streaming_chat_response("chat-route-tool-stream-truncated-close", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let raw = response_body(resp).await;
    let chunks = sse_json_chunks(&raw);
    let deltas = streamed_tool_call_deltas(&chunks);
    assert!(!deltas.is_empty());
    assert_eq!(deltas[0]["function"]["name"], "search_web");
    assert_eq!(
        streamed_tool_call_arguments(&chunks, 0),
        r#"{"query":"example"}"#
    );

    let closing = chunks.last().expect("stream should have chunks");
    assert_eq!(closing["choices"][0]["finish_reason"], "tool_calls");
}
#[tokio::test]
async fn chat_route_streams_multiple_tool_calls_split_across_chunks() {
let response = streaming_chat_response(
"chat-route-tool-stream-multiple-calls",
r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
vec![
MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"first\"}}"),
MockStreamFrame::Text("</tool_call><tool_call>{\"name\":\"search_web\",\"arguments\":"),
MockStreamFrame::Text("{\"query\":\"second\"}}</tool_call>"),
MockStreamFrame::Finish("stop"),
MockStreamFrame::Done,
],
)
.await;
assert_eq!(response.status(), StatusCode::OK);
let body = response_body(response).await;
let chunks = sse_json_chunks(&body);
assert_eq!(chunks[0]["choices"][0]["delta"]["role"], "assistant");
let tool_calls = streamed_tool_call_deltas(&chunks);
let first = tool_calls
.iter()
.find(|tool_call| tool_call["index"] == 0 && tool_call.get("id").is_some())
.expect("first call should have an id-bearing fragment");
let second = tool_calls
.iter()
.find(|tool_call| tool_call["index"] == 1 && tool_call.get("id").is_some())
.expect("second call should have an id-bearing fragment");
assert_eq!(first["function"]["name"], "search_web");
assert_eq!(second["function"]["name"], "search_web");
assert_ne!(first["id"], second["id"]);
assert_eq!(
streamed_tool_call_arguments(&chunks, 0),
r#"{"query":"first"}"#
);
assert_eq!(
streamed_tool_call_arguments(&chunks, 1),
r#"{"query":"second"}"#
);
let final_chunk = chunks.last().expect("stream should have chunks");
assert_eq!(final_chunk["choices"][0]["delta"], json!({}));
assert_eq!(final_chunk["choices"][0]["finish_reason"], "tool_calls");
}
/// Tool-call streams honor `stream_options.include_usage`. The usage chunk (empty
/// `choices`) is the last JSON chunk, and the chunk just before it carries
/// `finish_reason: "tool_calls"`.
#[tokio::test]
async fn chat_route_tool_stream_passes_through_usage_chunk_when_requested() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"stream_options":{"include_usage":true},"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#;
    let script = vec![
        MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}</tool_call>"),
        MockStreamFrame::Finish("stop"),
        MockStreamFrame::Usage,
        MockStreamFrame::Done,
    ];
    let resp = streaming_chat_response("chat-route-tool-stream-usage", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let raw = response_body(resp).await;
    let chunks = sse_json_chunks(&raw);

    // Walk from the end so a stream with fewer than two chunks fails with a descriptive
    // message rather than an arithmetic-underflow / out-of-bounds panic.
    let mut from_end = chunks.iter().rev();
    let usage_chunk = from_end.next().expect("stream should have chunks");
    assert_eq!(usage_chunk["choices"], json!([]));
    assert_eq!(usage_chunk["usage"]["total_tokens"], 3);

    let finish_chunk = from_end
        .next()
        .expect("stream should have a finish chunk before the usage chunk");
    assert_eq!(finish_chunk["choices"][0]["finish_reason"], "tool_calls");
}
/// When `tools.tool_call_max_bytes` is smaller than the streamed tool call, the SSE body
/// is aborted instead of completing.
#[tokio::test]
async fn chat_route_fails_closed_when_streamed_tool_call_exceeds_max_bytes() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#;
    let script = vec![
        MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"this argument body is much longer than the configured cap\"}}</tool_call>"),
        MockStreamFrame::Finish("stop"),
        MockStreamFrame::Done,
    ];
    let signing_key = ProxyInstanceKey::generate().public_key_hex().to_owned();
    let upstream = spawn_streaming_venice_server(signing_key, true, script).await;

    // A 16-byte cap is far below the size of the scripted tool call.
    let mut tight_config = chat_config_with_basic_test_attestation();
    tight_config.tools.tool_call_max_bytes = 16;

    let resp = request_chat_with_config(
        tight_config,
        "chat-route-tool-stream-max-bytes",
        REQUEST,
        upstream,
    )
    .await;
    assert_stream_body_fails(resp).await;
}
/// With `parallel_tool_calls: false`, both scripted tool calls are still forwarded, at
/// indexes 0 and 1.
#[tokio::test]
async fn chat_route_streams_all_tool_calls_when_parallel_tool_calls_false() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"parallel_tool_calls":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#;
    let script = vec![
        MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"first\"}}</tool_call>"),
        MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"second\"}}</tool_call>"),
        MockStreamFrame::Finish("stop"),
        MockStreamFrame::Done,
    ];
    let resp =
        streaming_chat_response("chat-route-tool-stream-parallel-false", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let raw = response_body(resp).await;
    let chunks = sse_json_chunks(&raw);
    for (position, expected) in [(0, r#"{"query":"first"}"#), (1, r#"{"query":"second"}"#)] {
        assert_eq!(streamed_tool_call_arguments(&chunks, position), expected);
    }

    let closing = chunks.last().expect("stream should have chunks");
    assert_eq!(closing["choices"][0]["finish_reason"], "tool_calls");
}
/// A buffered response whose text mixes prose with a `<tool_call>` block is reported
/// with `finish_reason: "tool_calls"` and the parsed call.
#[tokio::test]
async fn chat_route_returns_non_streaming_tool_call_body_from_mixed_text() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}}]}"#;
    let script = vec![
        MockStreamFrame::Text("I'll check that. <tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}</tool_call>"),
        MockStreamFrame::Done,
    ];
    let resp = chat_response("chat-route-tool-non-stream-mixed-text", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let completion = json_body(resp).await;
    let choice = &completion["choices"][0];
    assert_eq!(choice["finish_reason"], "tool_calls");
    let function = &choice["message"]["tool_calls"][0]["function"];
    assert_eq!(function["name"], "search_web");
    assert_eq!(function["arguments"], r#"{"query":"example"}"#);
}
/// A buffered response consisting only of a tool call has `content: null` and one fully
/// populated `tool_calls` entry (`call_` id, `function` type, name and JSON arguments).
#[tokio::test]
async fn chat_route_returns_non_streaming_tool_call_body() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}}]}"#;
    let script = vec![
        MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}</tool_call>"),
        MockStreamFrame::Done,
    ];
    let resp = chat_response("chat-route-tool-non-stream", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let completion = json_body(resp).await;
    let choice = &completion["choices"][0];
    assert_eq!(completion["object"], "chat.completion");
    assert!(choice["message"]["content"].is_null());
    assert_eq!(choice["finish_reason"], "tool_calls");

    let call = &choice["message"]["tool_calls"][0];
    assert!(call["id"].as_str().unwrap().starts_with("call_"));
    assert_eq!(call["type"], "function");
    assert_eq!(call["function"]["name"], "search_web");
    assert_eq!(call["function"]["arguments"], r#"{"query":"example"}"#);
}
/// Two newline-separated tool calls in a buffered response become two `tool_calls`
/// entries with their own arguments and distinct ids.
#[tokio::test]
async fn chat_route_returns_non_streaming_multiple_tool_calls() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}}]}"#;
    let script = vec![
        MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"first\"}}</tool_call>\n<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"second\"}}</tool_call>"),
        MockStreamFrame::Done,
    ];
    let resp = chat_response("chat-route-tool-non-stream-multiple-calls", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let completion = json_body(resp).await;
    let choice = &completion["choices"][0];
    assert_eq!(choice["finish_reason"], "tool_calls");
    assert!(choice["message"]["content"].is_null());

    let calls = choice["message"]["tool_calls"]
        .as_array()
        .expect("tool_calls should be an array");
    assert_eq!(calls.len(), 2);
    let (first, second) = (&calls[0], &calls[1]);
    assert_eq!(first["function"]["name"], "search_web");
    assert_eq!(first["function"]["arguments"], r#"{"query":"first"}"#);
    assert_eq!(second["function"]["arguments"], r#"{"query":"second"}"#);
    assert_ne!(first["id"], second["id"]);
}
/// When tools are declared but the model answers in plain text, the text streams through
/// unchanged and no `tool_calls` deltas are produced.
#[tokio::test]
async fn chat_route_tool_mode_leaves_normal_text_unaffected() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object"}}}]}"#;
    let script = vec![
        MockStreamFrame::Text("Hello without tools"),
        MockStreamFrame::Finish("stop"),
        MockStreamFrame::Done,
    ];
    let resp = streaming_chat_response("chat-route-tool-normal-text", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let raw = response_body(resp).await;
    let chunks = sse_json_chunks(&raw);
    assert_eq!(chunks[0]["choices"][0]["delta"]["role"], "assistant");
    assert_eq!(streamed_content(&chunks), "Hello without tools");
    assert!(streamed_tool_call_deltas(&chunks).is_empty());
}
/// Text that merely resembles the marker (`<tool_cal>`) is passed through verbatim as
/// content and is not parsed as a tool call.
#[tokio::test]
async fn chat_route_treats_marker_like_non_protocol_text_as_normal_text() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object"}}}]}"#;
    let script = vec![
        MockStreamFrame::Text("<tool_cal>{not actually a marker}"),
        MockStreamFrame::Finish("stop"),
        MockStreamFrame::Done,
    ];
    let resp = streaming_chat_response("chat-route-tool-marker-like-text", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let raw = response_body(resp).await;
    let chunks = sse_json_chunks(&raw);
    let visible_text = streamed_content(&chunks);
    assert_eq!(visible_text, "<tool_cal>{not actually a marker}");
    assert!(streamed_tool_call_deltas(&chunks).is_empty());
}
/// A first attempt that names an unknown tool is retried. The valid second attempt is
/// returned, and the retry count is reported in `HEADER_PROXY_TOOL_RETRIES`.
#[tokio::test]
async fn chat_route_retries_invalid_tool_call_and_returns_success() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}}]}"#;
    let rejected_attempt = vec![
        MockStreamFrame::Text("<tool_call>{\"name\":\"unknown\",\"arguments\":{\"query\":\"example\"}}</tool_call>"),
        MockStreamFrame::Done,
    ];
    let accepted_attempt = vec![
        MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}</tool_call>"),
        MockStreamFrame::Done,
    ];
    let resp = chat_response_sequence(
        "chat-route-tool-retry-success",
        REQUEST,
        vec![rejected_attempt, accepted_attempt],
    )
    .await;
    assert_eq!(resp.status(), StatusCode::OK);
    let retries = resp.headers().get(HEADER_PROXY_TOOL_RETRIES).unwrap();
    assert_eq!(retries, "1");

    let completion = json_body(resp).await;
    let choice = &completion["choices"][0];
    assert_eq!(choice["finish_reason"], "tool_calls");
    assert_eq!(
        choice["message"]["tool_calls"][0]["function"]["name"],
        "search_web"
    );
}
/// When every attempt yields an invalid tool call, the proxy returns 502 with an
/// `invalid_tool_call` error whose details include `max_retries` and the last
/// validation error.
#[tokio::test]
async fn chat_route_returns_retry_failure_error_shape() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}}]}"#;
    let script = vec![
        MockStreamFrame::Text("<tool_call>{\"name\":\"unknown\",\"arguments\":{}}</tool_call>"),
        MockStreamFrame::Done,
    ];
    let resp = chat_response("chat-route-tool-retry-failure", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
    let code_header = resp.headers().get(HEADER_PROXY_ERROR_CODE).unwrap();
    assert_eq!(code_header, "invalid_tool_call");

    let payload = json_body(resp).await;
    let error = &payload["error"];
    assert_eq!(error["type"], "proxy_tool_call_error");
    assert_eq!(error["code"], "invalid_tool_call");
    assert_eq!(error["details"]["max_retries"], 2);
    let last_error = error["details"]["last_validation_error"].as_str().unwrap();
    assert!(last_error.contains("unknown tool name"));
}
/// A non-2xx upstream chat status is surfaced as 502 `upstream_status_error`.
#[tokio::test]
async fn chat_route_non_streaming_fails_closed_on_upstream_error_response() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#;
    let resp = chat_response_with_upstream_status(
        "chat-route-non-streaming-upstream-error",
        REQUEST,
        StatusCode::INTERNAL_SERVER_ERROR,
    )
    .await;
    assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
    let code_header = resp.headers().get(HEADER_PROXY_ERROR_CODE).unwrap();
    assert_eq!(code_header, "upstream_status_error");

    let envelope = error_body(resp).await;
    let detail = &envelope.error;
    assert_eq!(detail.kind, "proxy_upstream_error");
    assert_eq!(detail.code, "upstream_status_error");
}
/// An upstream SSE chunk with an invalid shape (`choices` is a string) is surfaced as
/// 502 `upstream_malformed_response`.
#[tokio::test]
async fn chat_route_non_streaming_fails_closed_on_malformed_upstream_payload() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#;
    let script = vec![MockStreamFrame::Raw("data: {\"choices\":\"bad\"}\n\n")];
    let resp = chat_response("chat-route-non-streaming-malformed", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
    let code_header = resp.headers().get(HEADER_PROXY_ERROR_CODE).unwrap();
    assert_eq!(code_header, "upstream_malformed_response");

    let envelope = error_body(resp).await;
    let detail = &envelope.error;
    assert_eq!(detail.kind, "proxy_upstream_error");
    assert_eq!(detail.code, "upstream_malformed_response");
}
/// An upstream stream that finishes without any content frame is treated as an E2EE
/// response failure (502 `e2ee_response_decryption_failed`).
#[tokio::test]
async fn chat_route_non_streaming_fails_closed_on_missing_encrypted_content() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#;
    let script = vec![MockStreamFrame::Finish("stop"), MockStreamFrame::Done];
    let resp = chat_response("chat-route-non-streaming-missing-content", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
    let code_header = resp.headers().get(HEADER_PROXY_ERROR_CODE).unwrap();
    assert_eq!(code_header, "e2ee_response_decryption_failed");

    let envelope = error_body(resp).await;
    let detail = &envelope.error;
    assert_eq!(detail.kind, "proxy_e2ee_error");
    assert_eq!(detail.code, "e2ee_response_decryption_failed");
}
/// Content the proxy cannot decrypt (`TextForWrongRecipient`) yields 502
/// `e2ee_response_decryption_failed` for a buffered request.
#[tokio::test]
async fn chat_route_non_streaming_fails_closed_on_decryption_failure() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#;
    let script = vec![
        MockStreamFrame::TextForWrongRecipient(" secret"),
        MockStreamFrame::Done,
    ];
    let resp =
        chat_response("chat-route-non-streaming-decryption-failure", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
    let code_header = resp.headers().get(HEADER_PROXY_ERROR_CODE).unwrap();
    assert_eq!(code_header, "e2ee_response_decryption_failed");

    let envelope = error_body(resp).await;
    let detail = &envelope.error;
    assert_eq!(detail.kind, "proxy_e2ee_error");
    assert_eq!(detail.code, "e2ee_response_decryption_failed");
}
/// Upstream usage numbers are copied into the buffered completion's `usage` object.
#[tokio::test]
async fn chat_route_non_streaming_passes_through_usage_when_available() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#;
    let script = vec![
        MockStreamFrame::Text("Hello"),
        MockStreamFrame::Finish("stop"),
        MockStreamFrame::Usage,
        MockStreamFrame::Done,
    ];
    let resp = chat_response("chat-route-non-streaming-usage", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let completion = json_body(resp).await;
    assert_eq!(completion["choices"][0]["message"]["content"], "Hello");
    let usage = &completion["usage"];
    assert_eq!(usage["prompt_tokens"], 1);
    assert_eq!(usage["completion_tokens"], 2);
    assert_eq!(usage["total_tokens"], 3);
}
#[tokio::test]
async fn chat_route_fails_closed_on_upstream_stream_error_event() {
let response = streaming_chat_response(
"chat-route-upstream-error",
r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
vec![MockStreamFrame::Error("model failed")],
)
.await;
assert_stream_body_fails(response).await;
}
/// A truncated, unparseable upstream `data:` line aborts the proxied SSE body.
#[tokio::test]
async fn chat_route_fails_closed_on_malformed_upstream_event() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#;
    let script = vec![MockStreamFrame::Raw("data: {\"choices\":\n\n")];
    let resp = streaming_chat_response("chat-route-malformed-event", REQUEST, script).await;
    assert_stream_body_fails(resp).await;
}
/// An undecryptable frame after a good one aborts the stream rather than skipping the
/// bad frame.
#[tokio::test]
async fn chat_route_fails_closed_on_decryption_failure_mid_stream() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#;
    let script = vec![
        MockStreamFrame::Text("Hello"),
        MockStreamFrame::TextForWrongRecipient(" secret"),
        MockStreamFrame::Done,
    ];
    let resp = streaming_chat_response("chat-route-decryption-failure", REQUEST, script).await;
    assert_stream_body_fails(resp).await;
}
/// If upstream sends `[DONE]` without a finish frame, the proxy inserts an empty-delta
/// chunk with `finish_reason: "stop"` before forwarding `[DONE]`.
#[tokio::test]
async fn chat_route_synthesizes_final_finish_chunk_before_done_when_needed() {
    const REQUEST: &str = r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#;
    let script = vec![MockStreamFrame::Text("Hello"), MockStreamFrame::Done];
    let resp = streaming_chat_response("chat-route-final-done", REQUEST, script).await;
    assert_eq!(resp.status(), StatusCode::OK);

    let raw = response_body(resp).await;
    let events = sse_data(&raw);
    // Layout: content chunk, synthesized finish chunk, [DONE].
    assert_eq!(events.len(), 3);
    let synthesized: Value =
        serde_json::from_str(events[1]).expect("final chunk should be JSON");
    assert_eq!(synthesized["choices"][0]["delta"], json!({}));
    assert_eq!(synthesized["choices"][0]["finish_reason"], "stop");
    assert_eq!(events[2], "[DONE]");
}
/// When the upstream attestation reports `verified: false`, the chat request is rejected
/// with 502 `attestation_upstream_not_verified`.
#[tokio::test]
async fn chat_route_attestation_failure_prevents_request_construction() {
    let signing_key = ProxyInstanceKey::generate().public_key_hex().to_owned();
    let upstream = spawn_attestation_server(signing_key, false).await;

    let request = Request::builder()
        .method(Method::POST)
        .uri("/v1/chat/completions")
        .header("content-type", "application/json")
        .header(HEADER_PROXY_SESSION_ID, "chat-route-attestation-failure")
        .body(Body::from(
            r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#,
        ))
        .expect("request should build");
    let resp = router_with_venice_client(
        chat_config_with_basic_test_attestation(),
        test_venice_client_for_base_url(upstream),
    )
    .oneshot(request)
    .await
    .expect("request should complete");

    assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
    let code_header = resp.headers().get(HEADER_PROXY_ERROR_CODE).unwrap();
    assert_eq!(code_header, "attestation_upstream_not_verified");

    let envelope = error_body(resp).await;
    let detail = &envelope.error;
    assert_eq!(detail.kind, "proxy_attestation_error");
    assert_eq!(detail.code, "attestation_upstream_not_verified");
}
/// Unrouted paths produce a 404 with an OpenAI-style `not_found` error body and header.
#[tokio::test]
async fn unknown_route_returns_openai_style_not_found() {
    let request = Request::builder()
        .uri("/v1/unknown")
        .body(Body::empty())
        .expect("request should build");
    let resp = test_app()
        .oneshot(request)
        .await
        .expect("request should complete");

    assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    let code_header = resp.headers().get(HEADER_PROXY_ERROR_CODE).unwrap();
    assert_eq!(code_header, "not_found");

    let envelope = error_body(resp).await;
    let detail = &envelope.error;
    assert_eq!(detail.kind, "invalid_request_error");
    assert_eq!(detail.code, "not_found");
}
/// POSTing to the GET-only `/v1/models` produces a 405 with a `method_not_allowed`
/// error body and header.
#[tokio::test]
async fn unsupported_method_returns_openai_style_method_error() {
    let request = Request::builder()
        .method(Method::POST)
        .uri("/v1/models")
        .body(Body::empty())
        .expect("request should build");
    let resp = test_app()
        .oneshot(request)
        .await
        .expect("request should complete");

    assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
    let code_header = resp.headers().get(HEADER_PROXY_ERROR_CODE).unwrap();
    assert_eq!(code_header, "method_not_allowed");

    let envelope = error_body(resp).await;
    let detail = &envelope.error;
    assert_eq!(detail.kind, "invalid_request_error");
    assert_eq!(detail.code, "method_not_allowed");
}
/// Syntactically invalid JSON is rejected by axum's `Json` extractor (400) before the
/// proxy handler runs, so no proxy error-code header is set.
#[tokio::test]
async fn malformed_chat_json_uses_axum_extractor_rejection() {
    let request = Request::builder()
        .method(Method::POST)
        .uri("/v1/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from("{"))
        .expect("request should build");
    let resp = test_app()
        .oneshot(request)
        .await
        .expect("request should complete");

    assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    let code_header = resp.headers().get(HEADER_PROXY_ERROR_CODE);
    assert!(code_header.is_none());
}
/// Well-formed JSON that is not an object (`[]`) gets a structured 400 with code
/// `invalid_request`.
#[tokio::test]
async fn non_object_chat_json_returns_structured_invalid_request() {
    let request = Request::builder()
        .method(Method::POST)
        .uri("/v1/chat/completions")
        .header("content-type", "application/json")
        .body(Body::from("[]"))
        .expect("request should build");
    let resp = test_app()
        .oneshot(request)
        .await
        .expect("request should complete");

    assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    let code_header = resp.headers().get(HEADER_PROXY_ERROR_CODE).unwrap();
    assert_eq!(code_header, "invalid_request");

    let envelope = error_body(resp).await;
    let detail = &envelope.error;
    assert_eq!(detail.kind, "invalid_request_error");
    assert_eq!(detail.code, "invalid_request");
}
/// One scripted frame that the mock Venice upstream emits on its chat SSE stream.
///
/// The wire encoding of each variant is produced by the mock server further down in
/// this module. The notes below describe how the tests use each variant.
#[derive(Debug, Clone)]
enum MockStreamFrame {
    /// NOTE(review): presumably a role-only delta; no visible test here uses it.
    Role,
    /// A delta whose `content` is null; tests expect it to add no text.
    NullContent,
    /// A delta whose `content` is empty; tests expect it to add no text.
    EmptyContent,
    /// A content delta that the proxy is expected to surface as this text
    /// (presumably encrypted for the proxy's client key by the mock; confirm in server).
    Text(&'static str),
    /// A reasoning delta; buffered responses collect these into `reasoning_content`.
    Reasoning(&'static str),
    /// Content the proxy must fail to decrypt; tests expect a decryption error.
    TextForWrongRecipient(&'static str),
    /// A chunk carrying this upstream `finish_reason`.
    Finish(&'static str),
    /// A usage chunk; tests expect prompt/completion/total tokens of 1/2/3.
    Usage,
    /// The end-of-stream sentinel (presumably `data: [DONE]`).
    Done,
    /// An upstream error event; tests expect the proxied stream to fail.
    Error(&'static str),
    /// Bytes written verbatim to the stream, e.g. malformed `data:` lines.
    Raw(&'static str),
}
/// Sends a chat request through the proxy against a fresh mock upstream scripted with
/// `frames`.
///
/// Delegates to `chat_response`. The separate name marks call sites that expect an SSE
/// response.
async fn streaming_chat_response(
    session_id: &'static str,
    request_body: &'static str,
    frames: Vec<MockStreamFrame>,
) -> Response {
    let pending = chat_response(session_id, request_body, frames);
    pending.await
}
/// Spawns a mock upstream that attests as verified with a freshly generated signing key
/// and replays `frames`, then sends one chat request through the proxy to it.
async fn chat_response(
    session_id: &'static str,
    request_body: &'static str,
    frames: Vec<MockStreamFrame>,
) -> Response {
    let upstream_url = spawn_streaming_venice_server(
        ProxyInstanceKey::generate().public_key_hex().to_owned(),
        true,
        frames,
    )
    .await;
    request_chat(session_id, request_body, upstream_url).await
}
/// Like `chat_response`, but the mock upstream answers successive chat requests with
/// successive entries of `attempts` (used to exercise proxy-side retries).
async fn chat_response_sequence(
    session_id: &'static str,
    request_body: &'static str,
    attempts: Vec<Vec<MockStreamFrame>>,
) -> Response {
    let upstream_url = spawn_streaming_venice_server_sequence(
        ProxyInstanceKey::generate().public_key_hex().to_owned(),
        true,
        attempts,
    )
    .await;
    request_chat(session_id, request_body, upstream_url).await
}
/// Sends a chat request through the proxy to a mock upstream whose chat endpoint
/// answers with `upstream_status`.
async fn chat_response_with_upstream_status(
    session_id: &'static str,
    request_body: &'static str,
    upstream_status: StatusCode,
) -> Response {
    let upstream_url = spawn_venice_server_with_chat_status(
        ProxyInstanceKey::generate().public_key_hex().to_owned(),
        upstream_status,
    )
    .await;
    request_chat(session_id, request_body, upstream_url).await
}
/// Sends a chat request to `base_url` using the default basic-attestation test config.
async fn request_chat(
    session_id: &'static str,
    request_body: &'static str,
    base_url: String,
) -> Response {
    let default_config = chat_config_with_basic_test_attestation();
    request_chat_with_config(default_config, session_id, request_body, base_url).await
}
/// POSTs `request_body` to `/v1/chat/completions` on a proxy router built from `config`
/// and a Venice client pointed at `base_url`, tagging the request with `session_id`.
async fn request_chat_with_config(
    config: ProxyConfig,
    session_id: &'static str,
    request_body: &'static str,
    base_url: String,
) -> Response {
    let request = Request::builder()
        .method(Method::POST)
        .uri("/v1/chat/completions")
        .header("content-type", "application/json")
        .header(HEADER_PROXY_SESSION_ID, session_id)
        .body(Body::from(request_body))
        .expect("request should build");
    let venice_client = test_venice_client_for_base_url(base_url);
    router_with_venice_client(config, venice_client)
        .oneshot(request)
        .await
        .expect("request should complete")
}
/// Buffers the whole response body and parses it as JSON. Panics if either step fails.
async fn json_body(response: Response) -> Value {
    let body = response.into_body();
    let buffered = axum::body::to_bytes(body, usize::MAX)
        .await
        .expect("response body should buffer");
    serde_json::from_slice::<Value>(&buffered).expect("response should be JSON")
}
/// Buffers the whole response body as UTF-8 text. Panics if either step fails.
async fn response_body(response: Response) -> String {
    let body = response.into_body();
    let buffered = axum::body::to_bytes(body, usize::MAX)
        .await
        .expect("response body should buffer");
    let owned_bytes = buffered.to_vec();
    String::from_utf8(owned_bytes).expect("response body should be UTF-8")
}
/// Asserts that the response has status 200 but that reading its body yields an error,
/// i.e. the proxy aborted the stream instead of ending it cleanly.
async fn assert_stream_body_fails(response: Response) {
    assert_eq!(response.status(), StatusCode::OK);
    let drained = axum::body::to_bytes(response.into_body(), usize::MAX).await;
    let Err(_) = drained else {
        panic!("stream body should fail closed instead of completing successfully");
    };
}
/// Returns the value of every SSE `data` field in `body`, in order.
///
/// The SSE spec makes the single space after `data:` optional and strips it when
/// present, so both `data: x` and `data:x` yield `x`. Any further spaces are kept.
/// Other fields (`event:`, `id:`), comment lines (`:`) and blank separator lines are
/// skipped. Multi-line `data` fields are not joined; each line is returned separately.
fn sse_data(body: &str) -> Vec<&str> {
    body.lines()
        .filter_map(|line| line.strip_prefix("data:"))
        .map(|value| value.strip_prefix(' ').unwrap_or(value))
        .collect()
}
/// Parses every SSE `data` payload except the trailing `[DONE]` sentinel as JSON.
/// Panics if the stream does not end with `[DONE]` or any payload is not valid JSON.
fn sse_json_chunks(body: &str) -> Vec<Value> {
    let events = sse_data(body);
    assert_eq!(events.last().copied(), Some("[DONE]"));
    let Some((_done, json_events)) = events.split_last() else {
        unreachable!("the assertion above guarantees a trailing [DONE] event");
    };
    json_events
        .iter()
        .map(|raw| serde_json::from_str(raw).expect("SSE chunk should be JSON"))
        .collect()
}
/// Concatenates the string `delta.content` of the first choice across all chunks,
/// skipping chunks where it is absent or not a string.
fn streamed_content(chunks: &[Value]) -> String {
    let mut text = String::new();
    for chunk in chunks {
        if let Some(piece) = chunk["choices"][0]["delta"]["content"].as_str() {
            text.push_str(piece);
        }
    }
    text
}
/// Flattens the first choice's `delta.tool_calls` arrays from all chunks into one list
/// of tool-call fragments, in stream order.
fn streamed_tool_call_deltas(chunks: &[Value]) -> Vec<&Value> {
    let mut fragments = Vec::new();
    for chunk in chunks {
        if let Some(batch) = chunk["choices"][0]["delta"]["tool_calls"].as_array() {
            fragments.extend(batch.iter());
        }
    }
    fragments
}
/// Reassembles the `function.arguments` text for the tool call at `index` by
/// concatenating its streamed fragments in order.
fn streamed_tool_call_arguments(chunks: &[Value], index: u64) -> String {
    let wanted_index = json!(index);
    let mut arguments = String::new();
    for fragment in streamed_tool_call_deltas(chunks) {
        if fragment["index"] != wanted_index {
            continue;
        }
        if let Some(piece) = fragment["function"]["arguments"].as_str() {
            arguments.push_str(piece);
        }
    }
    arguments
}
/// Spawns a mock upstream that answers every chat request with the same `frames`.
/// Returns its base URL.
async fn spawn_streaming_venice_server(
    model_public_key: String,
    verified: bool,
    frames: Vec<MockStreamFrame>,
) -> String {
    let single_attempt = vec![frames];
    spawn_streaming_venice_server_sequence(model_public_key, verified, single_attempt).await
}
async fn spawn_streaming_venice_server_sequence(
model_public_key: String,
verified: bool,
attempts: Vec<Vec<MockStreamFrame>>,
) -> String {
let chat_attempts = Arc::new(Mutex::new(VecDeque::from(attempts)));
let attestation_key = model_public_key.clone();
let app = Router::new()
.route(
"/api/v1/tee/attestation",
get(move |Query(query): Query<HashMap<String, String>>| {
let model_public_key = attestation_key.clone();
async move {
Json(json!({
"attestation": {
"verified": verified,
"nonce": query.get("nonce").cloned().unwrap_or_default(),
"model": query.get("model").cloned().unwrap_or_default(),
"tee_provider": "tdx",
"signing_key": model_public_key,
}
}))
}
}),
)
.route(
"/api/v1/chat/completions",
post(move |headers: HeaderMap, Json(body): Json<Value>| {
let chat_attempts = chat_attempts.clone();
async move {
let Some(client_public_key) = headers
.get(crate::venice::HEADER_VENICE_TEE_CLIENT_PUB_KEY)
.and_then(|value| value.to_str().ok())
else {
return (
StatusCode::BAD_REQUEST,
[("content-type", "text/plain")],
"missing client key".to_owned(),
);
};
if body.get("stream").and_then(Value::as_bool) != Some(true) {
return (
StatusCode::BAD_REQUEST,
[("content-type", "text/plain")],
"upstream request must stream".to_owned(),
);
}
let messages = body.get("messages").and_then(Value::as_array);
if messages.is_none_or(|messages| {
messages.is_empty()
|| !messages.iter().all(|message| {
message.get("role").and_then(Value::as_str).is_some()
&& message
.get("content")
.and_then(Value::as_str)
.is_some_and(|content| {
!content.is_empty()
&& content
.chars()
.all(|ch| ch.is_ascii_hexdigit())
})
})
}) {
return (
StatusCode::BAD_REQUEST,
[("content-type", "text/plain")],
"messages must be encrypted message objects".to_owned(),
);
}
let frames = {
let mut attempts = chat_attempts
.lock()
.expect("mock chat attempts mutex should not be poisoned");
if attempts.len() > 1 {
attempts.pop_front().expect("attempts length checked above")
} else {
attempts.front().cloned().unwrap_or_default()
}
};
(
StatusCode::OK,
[("content-type", "text/event-stream")],
render_mock_sse(&frames, client_public_key),
)
}
}),
);
let listener = TcpListener::bind(("127.0.0.1", 0))
.await
.expect("mock Venice listener should bind");
let addr = listener
.local_addr()
.expect("mock Venice listener should have local address");
tokio::spawn(async move {
axum::serve(listener, app)
.await
.expect("mock Venice server should run");
});
format!("http://{addr}/api/v1")
}
/// Starts a mock Venice server on an ephemeral localhost port and returns its `/api/v1` base URL.
///
/// The attestation endpoint always reports `verified: true`, a `tdx` provider and
/// `model_public_key` as the signing key, echoing the `nonce`/`model` query parameters.
/// Every `POST /api/v1/chat/completions` responds with only `upstream_status`.
async fn spawn_venice_server_with_chat_status(
    model_public_key: String,
    upstream_status: StatusCode,
) -> String {
    let app = Router::new()
        .route(
            "/api/v1/chat/completions",
            post(move || async move { upstream_status }),
        )
        .route(
            "/api/v1/tee/attestation",
            get(move |Query(query): Query<HashMap<String, String>>| {
                let signing_key = model_public_key.clone();
                async move {
                    let nonce = query.get("nonce").cloned().unwrap_or_default();
                    let model = query.get("model").cloned().unwrap_or_default();
                    Json(json!({
                        "attestation": {
                            "verified": true,
                            "nonce": nonce,
                            "model": model,
                            "tee_provider": "tdx",
                            "signing_key": signing_key,
                        }
                    }))
                }
            }),
        );

    let listener = TcpListener::bind(("127.0.0.1", 0))
        .await
        .expect("mock Venice listener should bind");
    let addr = listener
        .local_addr()
        .expect("mock Venice listener should have local address");
    tokio::spawn(async move {
        axum::serve(listener, app)
            .await
            .expect("mock Venice server should run");
    });
    format!("http://{addr}/api/v1")
}
fn render_mock_sse(frames: &[MockStreamFrame], client_public_key: &str) -> String {
let codec = E2eeCodec::default();
let mut output = String::new();
for frame in frames {
match frame {
MockStreamFrame::Role => {
output.push_str(&format!("data: {}\n\n", upstream_role_chunk()));
}
MockStreamFrame::NullContent => {
output.push_str(&format!("data: {}\n\n", upstream_null_content_chunk()));
}
MockStreamFrame::EmptyContent => {
output.push_str(&format!(
"data: {}\n\n",
upstream_content_chunk(String::new())
));
}
MockStreamFrame::Text(content) => {
let encrypted = codec
.encrypt_content(content, client_public_key)
.expect("mock content should encrypt")
.into_hex();
output.push_str(&format!("data: {}\n\n", upstream_content_chunk(encrypted)));
}
MockStreamFrame::Reasoning(content) => {
let encrypted = codec
.encrypt_content(content, client_public_key)
.expect("mock reasoning content should encrypt")
.into_hex();
output.push_str(&format!(
"data: {}\n\n",
upstream_reasoning_content_chunk(encrypted)
));
}
MockStreamFrame::TextForWrongRecipient(content) => {
let wrong_key = ProxyInstanceKey::generate();
let encrypted = codec
.encrypt_content(content, wrong_key.public_key_hex())
.expect("mock content should encrypt")
.into_hex();
output.push_str(&format!("data: {}\n\n", upstream_content_chunk(encrypted)));
}
MockStreamFrame::Finish(reason) => {
output.push_str(&format!("data: {}\n\n", upstream_finish_chunk(reason)));
}
MockStreamFrame::Usage => {
output.push_str(&format!("data: {}\n\n", upstream_usage_chunk()));
}
MockStreamFrame::Done => output.push_str("data: [DONE]\n\n"),
MockStreamFrame::Error(message) => {
output.push_str(&format!(
"event: error\ndata: {}\n\n",
json!({ "message": message })
));
}
MockStreamFrame::Raw(raw) => output.push_str(raw),
}
}
output
}
/// Wraps `choices` in the fixed `chat.completion.chunk` envelope shared by every mock upstream
/// frame (`id`, `object`, `created`, `model`).
///
/// The envelope constants live here only, so changing them touches one place.
fn upstream_chunk_envelope(choices: Value) -> Value {
    json!({
        "id": "chatcmpl-upstream-test",
        "object": "chat.completion.chunk",
        "created": 1_717_171_717,
        "model": "e2ee-test",
        "choices": choices,
    })
}

/// Builds a mock upstream chunk with exactly one choice at index 0.
///
/// The choice carries `delta` and `finish_reason`. A `None` finish reason serializes as JSON
/// `null`.
fn upstream_single_choice_chunk(delta: Value, finish_reason: Option<&str>) -> Value {
    upstream_chunk_envelope(json!([{
        "index": 0,
        "delta": delta,
        "finish_reason": finish_reason,
    }]))
}

/// Upstream chunk announcing the `assistant` role, with no finish reason.
fn upstream_role_chunk() -> Value {
    upstream_single_choice_chunk(json!({ "role": "assistant" }), None)
}

/// Upstream chunk whose delta `content` is `encrypted_content`, passed through unchanged.
fn upstream_content_chunk(encrypted_content: String) -> Value {
    upstream_single_choice_chunk(json!({ "content": encrypted_content }), None)
}

/// Upstream chunk whose delta `reasoning_content` is `encrypted_content`, passed through
/// unchanged.
fn upstream_reasoning_content_chunk(encrypted_content: String) -> Value {
    upstream_single_choice_chunk(json!({ "reasoning_content": encrypted_content }), None)
}

/// Upstream chunk whose delta `content` is an explicit JSON `null`.
fn upstream_null_content_chunk() -> Value {
    upstream_single_choice_chunk(json!({ "content": Value::Null }), None)
}

/// Upstream chunk with an empty delta and the given `finish_reason`.
fn upstream_finish_chunk(reason: &str) -> Value {
    upstream_single_choice_chunk(json!({}), Some(reason))
}

/// Upstream usage-only chunk: an empty `choices` array plus fixed token counts (1/2/3).
fn upstream_usage_chunk() -> Value {
    let mut chunk = upstream_chunk_envelope(json!([]));
    // Appended after `choices`, matching the original key order.
    chunk["usage"] = json!({
        "prompt_tokens": 1,
        "completion_tokens": 2,
        "total_tokens": 3,
    });
    chunk
}
/// Starts an attestation-only mock server on an ephemeral localhost port and returns its
/// `/api/v1` base URL.
///
/// `GET /api/v1/tee/attestation` echoes the `nonce`/`model` query parameters and reports the
/// given `verified` flag with `model_public_key` as the signing key. Unlike the other Venice
/// mocks in this module, the response has no `tee_provider` field.
async fn spawn_attestation_server(model_public_key: String, verified: bool) -> String {
    let app = Router::new().route(
        "/api/v1/tee/attestation",
        get(move |Query(query): Query<HashMap<String, String>>| {
            let signing_key = model_public_key.clone();
            async move {
                let nonce = query.get("nonce").cloned().unwrap_or_default();
                let model = query.get("model").cloned().unwrap_or_default();
                Json(json!({
                    "attestation": {
                        "verified": verified,
                        "nonce": nonce,
                        "model": model,
                        "signing_key": signing_key,
                    }
                }))
            }
        }),
    );

    let listener = TcpListener::bind(("127.0.0.1", 0))
        .await
        .expect("mock attestation listener should bind");
    let addr = listener
        .local_addr()
        .expect("mock attestation listener should have local address");
    tokio::spawn(async move {
        axum::serve(listener, app)
            .await
            .expect("mock attestation server should run");
    });
    format!("http://{addr}/api/v1")
}
/// With a default `ProxyConfig`, the metadata headers include the attestation-mode and
/// tool-mode headers. They do not include the E2EE or key-binding headers.
#[test]
fn metadata_header_helper_only_emits_safe_config_headers_by_default() {
    let default_config = ProxyConfig::default();
    let mut rendered = HeaderMap::new();
    ProxyMetadataHeaders::from_config(&default_config).apply(&mut rendered);

    assert_eq!(
        rendered.get(HEADER_PROXY_ATTESTATION_MODE).unwrap(),
        "independent"
    );
    assert_eq!(rendered.get(HEADER_PROXY_TOOL_MODE).unwrap(), "emulated");
    assert!(!rendered.contains_key(HEADER_PROXY_E2EE));
    assert!(!rendered.contains_key(HEADER_PROXY_KEY_BINDING));
}
}