1use std::{
7 io,
8 sync::Arc,
9 time::{SystemTime, UNIX_EPOCH},
10};
11
12use axum::{
13 Json, Router,
14 extract::State,
15 http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Uri},
16 response::{
17 IntoResponse, Response,
18 sse::{Event, Sse},
19 },
20 routing::{get, post},
21};
22use serde_json::{Value, json};
23use thiserror::Error;
24use tokio::net::TcpListener;
25use tracing::{debug, error, info, warn};
26
27use crate::{
28 attestation::{AttestationError, AttestationVerifier},
29 config::{NvidiaRequirement, ProxyConfig},
30 e2ee::{E2eeCodec, E2eeCodecError},
31 keys::ProxyInstanceKey,
32 openai::{
33 ErrorResponse,
34 chat::{
35 ChatCompletionRequest, ChatConstructionError, ChatRequestError, NormalizedChatMessage,
36 },
37 },
38 sessions::{AttestedModelState, SessionContext, SessionError, SessionManager, SessionRequest},
39 tools::{ToolEmulationContext, ToolOutputClassification, ValidatedToolCall},
40 venice::{VeniceClient, VeniceClientError},
41};
42
// Names of the `X-Venice-Proxy-*` HTTP headers used by the proxy.
// NOTE(review): the per-header descriptions below are inferred from the
// constant names; the code that writes the values (`ProxyMetadataHeaders`)
// lives outside this section — confirm against it.

/// Header name for the proxy's end-to-end-encryption indicator.
pub const HEADER_PROXY_E2EE: &str = "X-Venice-Proxy-E2EE";
/// Header name for the attestation mode.
pub const HEADER_PROXY_ATTESTATION_MODE: &str = "X-Venice-Proxy-Attestation-Mode";
/// Header name for the attested model.
pub const HEADER_PROXY_ATTESTED_MODEL: &str = "X-Venice-Proxy-Attested-Model";
/// Header name for the TEE provider.
pub const HEADER_PROXY_TEE_PROVIDER: &str = "X-Venice-Proxy-TEE-Provider";
/// Header name for the TDX verification result.
pub const HEADER_PROXY_TDX_VERIFIED: &str = "X-Venice-Proxy-TDX-Verified";
/// Header name for the TDX debug indicator.
pub const HEADER_PROXY_TDX_DEBUG: &str = "X-Venice-Proxy-TDX-Debug";
/// Header name for the NVIDIA verification result.
pub const HEADER_PROXY_NVIDIA_VERIFIED: &str = "X-Venice-Proxy-NVIDIA-Verified";
/// Header name for the key-binding indicator.
pub const HEADER_PROXY_KEY_BINDING: &str = "X-Venice-Proxy-Key-Binding";
/// Header name for the session identifier.
pub const HEADER_PROXY_SESSION_ID: &str = "X-Venice-Proxy-Session-Id";
/// Header name for the session scope.
pub const HEADER_PROXY_SESSION_SCOPE: &str = "X-Venice-Proxy-Session-Scope";
/// Header name for the tool mode.
pub const HEADER_PROXY_TOOL_MODE: &str = "X-Venice-Proxy-Tool-Mode";
/// Header name for the tool-call retry count.
pub const HEADER_PROXY_TOOL_RETRIES: &str = "X-Venice-Proxy-Tool-Retries";
/// Header name for the proxy error code.
pub const HEADER_PROXY_ERROR_CODE: &str = "X-Venice-Proxy-Error-Code";
56
/// Shared state handed to the route handlers through axum's `State` extractor.
///
/// The type is `Clone`; the configuration is held behind an `Arc`, so clones
/// share a single `ProxyConfig`.
#[derive(Debug, Clone)]
pub struct AppState {
    /// Proxy configuration, reference-counted so clones share one copy.
    config: Arc<ProxyConfig>,
    /// Client used for calls to the upstream Venice API.
    venice_client: VeniceClient,
    /// Result of `ProxyInstanceKey::generate_from_config`; chat completions
    /// are rejected when this is `None`.
    proxy_instance_key: Option<ProxyInstanceKey>,
    /// Resolves, creates and updates sessions for chat requests.
    session_manager: SessionManager,
    /// Verifies model attestation before a session is used.
    attestation_verifier: AttestationVerifier,
}
66
67impl AppState {
68 pub fn new(config: ProxyConfig) -> Result<Self, VeniceClientError> {
70 let venice_client = VeniceClient::from_config(&config)?;
71 Ok(Self::from_parts(config, venice_client))
72 }
73
74 pub fn from_parts(config: ProxyConfig, venice_client: VeniceClient) -> Self {
76 let proxy_instance_key = ProxyInstanceKey::generate_from_config(&config.keys);
77 let session_manager = SessionManager::new(config.session.clone());
78 let attestation_verifier = AttestationVerifier::from_config(&config, venice_client.clone());
79
80 Self {
81 config: Arc::new(config),
82 venice_client,
83 proxy_instance_key,
84 session_manager,
85 attestation_verifier,
86 }
87 }
88
89 pub fn config(&self) -> &ProxyConfig {
91 &self.config
92 }
93
94 pub fn venice_client(&self) -> &VeniceClient {
96 &self.venice_client
97 }
98
99 pub fn proxy_instance_key(&self) -> Option<&ProxyInstanceKey> {
101 self.proxy_instance_key.as_ref()
102 }
103
104 pub fn session_manager(&self) -> &SessionManager {
106 &self.session_manager
107 }
108
109 pub fn attestation_verifier(&self) -> &AttestationVerifier {
111 &self.attestation_verifier
112 }
113}
114
115pub fn router(config: ProxyConfig) -> Result<Router, VeniceClientError> {
118 Ok(router_from_state(AppState::new(config)?))
119}
120
121pub fn router_with_venice_client(config: ProxyConfig, venice_client: VeniceClient) -> Router {
126 router_from_state(AppState::from_parts(config, venice_client))
127}
128
129fn router_from_state(state: AppState) -> Router {
131 Router::new()
132 .route("/v1/models", get(list_models).fallback(method_not_allowed))
133 .route(
134 "/v1/chat/completions",
135 post(create_chat_completion).fallback(method_not_allowed),
136 )
137 .fallback(not_found)
138 .with_state(state)
139}
140
141pub async fn serve(listener: TcpListener, router: Router) -> io::Result<()> {
143 axum::serve(listener, router).await
144}
145
146async fn list_models(State(state): State<AppState>) -> Result<Response, ProxyError> {
148 info!(route = "/v1/models", "listing Venice models");
149 let models = state.venice_client().list_models().await?;
150 let mut response = Json(models).into_response();
151 ProxyMetadataHeaders::from_config(state.config()).apply(response.headers_mut());
152 info!(route = "/v1/models", "Venice models response proxied");
153 Ok(response)
154}
155
/// Handles `POST /v1/chat/completions`.
///
/// Steps, in order:
/// 1. parse the JSON body into a `ChatCompletionRequest`;
/// 2. require a proxy instance key;
/// 3. resolve or create a session and make sure it carries an attested model
///    public key (see `ensure_attested_session`);
/// 4. build the E2EE codec, the optional tool-emulation context and the
///    response metadata headers;
/// 5. dispatch to tool emulation, SSE streaming, or a buffered JSON response.
///
/// # Errors
/// Returns a `ProxyError` for any failure in the steps above, including
/// `ProxyInstanceKeyUnavailable` and `MissingAttestedModelKey`.
async fn create_chat_completion(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(body): Json<Value>,
) -> Result<Response, ProxyError> {
    let request = ChatCompletionRequest::parse(&body)?;
    // Without an instance key the proxy cannot take part in the E2EE exchange.
    let proxy_instance_key = state
        .proxy_instance_key()
        .ok_or(ProxyError::ProxyInstanceKeyUnavailable)?;

    let session_resolution = state
        .session_manager()
        .get_or_create(SessionRequest::new(&request.model, &headers).with_body(&body))?;
    // Copy out the bookkeeping fields before the session itself is moved below.
    let session_created = session_resolution.created;
    let session_replaced_expired = session_resolution.replaced_expired;
    let session_scope = session_resolution.session.scope;
    let session = ensure_attested_session(&state, session_resolution.session).await?;
    let model_public_key = session
        .attested_model_public_key
        .as_deref()
        .ok_or(ProxyError::MissingAttestedModelKey)?;

    let codec =
        E2eeCodec::from_config(&state.config().e2ee).map_err(ChatConstructionError::E2ee)?;
    // `Some` selects the tool-emulated path further down.
    let tool_context = ToolEmulationContext::from_request(&state.config().tools, &request)?;
    let metadata = ProxyMetadataHeaders::for_verified_chat(state.config(), &session);

    info!(
        route = "/v1/chat/completions",
        model = %request.model,
        stream = request.stream,
        message_count = request.messages.len(),
        tool_count = request.tools.len(),
        tool_mode = tool_context.is_some(),
        session_created,
        session_replaced_expired = ?session_replaced_expired,
        session_scope = %session_scope,
        "chat completion request accepted"
    );

    if let Some(tool_context) = tool_context {
        info!(model = %request.model, "using tool-emulated chat completion");
        return openai_tool_emulated_chat_response(
            &state,
            &request,
            &tool_context,
            codec,
            proxy_instance_key.clone(),
            model_public_key,
            metadata,
        )
        .await;
    }

    let prepared = request.to_venice_e2ee_request(&codec, model_public_key)?;
    info!(
        model = %request.model,
        client_stream = prepared.client_stream,
        "forwarding encrypted chat completion to Venice"
    );

    // The upstream call is always the streaming one; whether the client gets a
    // stream or a single JSON body is decided below from `client_stream`.
    let upstream = state
        .venice_client()
        .create_chat_completion_stream(
            &prepared.upstream,
            proxy_instance_key.public_key_hex(),
            model_public_key,
        )
        .await?;

    if prepared.client_stream {
        info!(model = %request.model, "streaming chat completion response to client");
        let include_usage_requested = request.stream_options.include_usage.unwrap_or(false);
        let transformer = OpenAiChatStreamTransformer::new(
            codec,
            proxy_instance_key.clone(),
            request.model.clone(),
            include_usage_requested,
        );
        Ok(chat_sse_response(
            upstream,
            transformer,
            request.model,
            include_usage_requested,
            &CHAT_SSE_LOG,
            metadata,
        ))
    } else {
        info!(model = %request.model, "buffering chat completion response for client");
        openai_chat_buffered_response(
            upstream,
            codec,
            proxy_instance_key.clone(),
            request.model,
            metadata,
        )
        .await
    }
}
256
257async fn ensure_attested_session(
259 state: &AppState,
260 session: SessionContext,
261) -> Result<SessionContext, ProxyError> {
262 if session.attested_model_public_key.is_some() {
263 info!(model = %session.model_id, session_scope = %session.scope, "using cached model attestation");
264 return Ok(session);
265 }
266
267 info!(model = %session.model_id, session_scope = %session.scope, "fetching model attestation");
268 let attestation = state
269 .attestation_verifier()
270 .verify_model_attestation(&session.model_id)
271 .await?;
272
273 info!(
274 model = %attestation.model_id,
275 tee_provider = attestation.tee_provider.as_deref().unwrap_or("unknown"),
276 tdx_verified = attestation.tdx.verified,
277 nvidia_verified = attestation.nvidia.verified.as_header_value(),
278 "model attestation verified"
279 );
280
281 let state_update = AttestedModelState {
282 model_public_key: attestation.model_public_key,
283 attestation_report: attestation.attestation_report,
284 verified_at: attestation.verified_at,
285 };
286
287 Ok(state
288 .session_manager()
289 .set_attested_model_state(&session.session_key, state_update)?)
290}
291
292async fn openai_chat_buffered_response(
294 upstream: reqwest::Response,
295 codec: E2eeCodec,
296 proxy_instance_key: ProxyInstanceKey,
297 fallback_model: String,
298 metadata: ProxyMetadataHeaders,
299) -> Result<Response, ProxyError> {
300 let completion =
301 buffer_openai_chat_completion(upstream, codec, proxy_instance_key, fallback_model).await?;
302 let mut response = Json(completion).into_response();
303 metadata.apply(response.headers_mut());
304 Ok(response)
305}
306
307async fn openai_tool_emulated_chat_response(
309 state: &AppState,
310 request: &ChatCompletionRequest,
311 tool_context: &ToolEmulationContext,
312 codec: E2eeCodec,
313 proxy_instance_key: ProxyInstanceKey,
314 model_public_key: &str,
315 metadata: ProxyMetadataHeaders,
316) -> Result<Response, ProxyError> {
317 info!(
318 model = %request.model,
319 max_retries = tool_context.max_retries(),
320 "starting tool-emulated chat completion"
321 );
322 if request.stream {
323 let upstream = tool_emulated_upstream_stream(
324 state,
325 request,
326 tool_context,
327 &codec,
328 &proxy_instance_key,
329 model_public_key,
330 None,
331 )
332 .await?;
333
334 let include_usage_requested = request.stream_options.include_usage.unwrap_or(false);
335 let transformer = OpenAiToolEmulatedChatStreamTransformer::new(
336 tool_context,
337 codec,
338 proxy_instance_key,
339 request.model.clone(),
340 include_usage_requested,
341 )
342 .map_err(ProxyError::ChatStream)?;
343 return Ok(chat_sse_response(
344 upstream,
345 transformer,
346 request.model.clone(),
347 include_usage_requested,
348 &TOOL_EMULATED_CHAT_SSE_LOG,
349 metadata,
350 ));
351 }
352
353 let mut retries = 0;
354 let mut correction: Option<(String, String)> = None;
355
356 loop {
357 let upstream = tool_emulated_upstream_stream(
358 state,
359 request,
360 tool_context,
361 &codec,
362 &proxy_instance_key,
363 model_public_key,
364 correction.as_ref(),
365 )
366 .await?;
367
368 let completion = match tokio::time::timeout(
369 tool_context.marker_timeout(),
370 buffer_openai_chat_completion(
371 upstream,
372 codec.clone(),
373 proxy_instance_key.clone(),
374 request.model.clone(),
375 ),
376 )
377 .await
378 {
379 Ok(completion) => completion?,
380 Err(_) => {
381 let validation_error = format!(
382 "tool-emulated completion did not finish within {}",
383 humantime::format_duration(tool_context.config().tool_call_marker_timeout)
384 );
385 if retries >= tool_context.max_retries() {
386 return Err(ProxyError::ToolCallRetryExhausted {
387 max_retries: tool_context.max_retries(),
388 last_validation_error: validation_error,
389 });
390 }
391 warn!(
392 model = %request.model,
393 retry = retries + 1,
394 max_retries = tool_context.max_retries(),
395 "tool call marker timed out; retrying with correction"
396 );
397 retries += 1;
398 correction = Some((validation_error, String::new()));
399 continue;
400 }
401 };
402 let assistant_content = completion
403 .get("choices")
404 .and_then(Value::as_array)
405 .and_then(|choices| choices.first())
406 .and_then(|choice| choice.get("message"))
407 .and_then(|message| message.get("content"))
408 .and_then(Value::as_str)
409 .unwrap_or_default();
410
411 let mut metadata = metadata.clone();
412 if retries > 0 {
413 metadata.tool_retries = Some(retries);
414 }
415
416 match tool_context.classify_assistant_output(assistant_content) {
417 ToolOutputClassification::NormalText => {
418 info!(model = %request.model, retries, "tool emulation produced normal text");
419 let mut response = Json(completion).into_response();
420 metadata.apply(response.headers_mut());
421 return Ok(response);
422 }
423 ToolOutputClassification::ToolCalls(tool_calls) => {
424 info!(
425 model = %request.model,
426 tool_calls = tool_calls.len(),
427 retries,
428 "tool emulation produced tool calls"
429 );
430 let body = openai_tool_call_completion(completion, tool_calls);
431 let mut response = Json(body).into_response();
432 metadata.apply(response.headers_mut());
433 return Ok(response);
434 }
435 ToolOutputClassification::InvalidToolCall {
436 error,
437 invalid_output,
438 } => {
439 if retries >= tool_context.max_retries() {
440 warn!(
441 model = %request.model,
442 max_retries = tool_context.max_retries(),
443 validation_error = %error,
444 "tool call validation failed and retries were exhausted"
445 );
446 return Err(ProxyError::ToolCallRetryExhausted {
447 max_retries: tool_context.max_retries(),
448 last_validation_error: error.to_string(),
449 });
450 }
451 warn!(
452 model = %request.model,
453 retry = retries + 1,
454 max_retries = tool_context.max_retries(),
455 validation_error = %error,
456 "tool call validation failed; retrying with correction"
457 );
458 retries += 1;
459 correction = Some((error.to_string(), invalid_output));
460 }
461 }
462 }
463}
464
465async fn tool_emulated_upstream_stream(
468 state: &AppState,
469 request: &ChatCompletionRequest,
470 tool_context: &ToolEmulationContext,
471 codec: &E2eeCodec,
472 proxy_instance_key: &ProxyInstanceKey,
473 model_public_key: &str,
474 correction: Option<&(String, String)>,
475) -> Result<reqwest::Response, ProxyError> {
476 let messages = tool_emulated_messages(request, tool_context, correction);
477 let mut tool_request = request.clone();
478 tool_request.messages = messages;
479
480 let prepared = tool_request.to_venice_e2ee_request(codec, model_public_key)?;
481
482 Ok(state
483 .venice_client()
484 .create_chat_completion_stream(
485 &prepared.upstream,
486 proxy_instance_key.public_key_hex(),
487 model_public_key,
488 )
489 .await?)
490}
491
492fn tool_emulated_messages(
494 request: &ChatCompletionRequest,
495 tool_context: &ToolEmulationContext,
496 correction: Option<&(String, String)>,
497) -> Vec<NormalizedChatMessage> {
498 let mut messages = request.messages.clone();
499 let mut tool_system_content = tool_context.controller_message().content;
500
501 if let Some((validation_error, invalid_output)) = correction {
502 tool_system_content.push_str("\n\n");
503 tool_system_content.push_str(
504 &tool_context
505 .correction_message(validation_error, invalid_output)
506 .content,
507 );
508 }
509
510 append_to_system_message(&mut messages, tool_system_content);
511 messages
512}
513
514fn append_to_system_message(messages: &mut Vec<NormalizedChatMessage>, content: String) {
516 if let Some(system_message) = messages.iter_mut().find(|message| message.role == "system") {
517 system_message.content.push_str("\n\n");
518 system_message.content.push_str(&content);
519 } else {
520 messages.insert(0, NormalizedChatMessage::new("system", content));
521 }
522}
523
524fn openai_tool_call_completion(completion: Value, tool_calls: Vec<ValidatedToolCall>) -> Value {
526 let choice = completion
527 .get("choices")
528 .and_then(Value::as_array)
529 .and_then(|choices| choices.first())
530 .cloned()
531 .unwrap_or(Value::Null);
532 let index = choice.get("index").and_then(Value::as_u64).unwrap_or(0);
533 let tool_call_values: Vec<Value> = tool_calls
534 .iter()
535 .map(ValidatedToolCall::to_openai_value)
536 .collect();
537 let reasoning_content = choice
538 .get("message")
539 .and_then(|message| message.get("reasoning_content"))
540 .and_then(Value::as_str);
541 let mut message = serde_json::Map::new();
542 message.insert("role".to_owned(), Value::String("assistant".to_owned()));
543 message.insert("content".to_owned(), Value::Null);
544 if let Some(reasoning_content) = reasoning_content {
545 message.insert(
546 "reasoning_content".to_owned(),
547 Value::String(reasoning_content.to_owned()),
548 );
549 }
550 message.insert("tool_calls".to_owned(), Value::Array(tool_call_values));
551
552 json!({
553 "id": string_field(&completion, "id").unwrap_or("chatcmpl-local"),
554 "object": "chat.completion",
555 "created": integer_field(&completion, "created").unwrap_or_else(unix_timestamp_now),
556 "model": string_field(&completion, "model").unwrap_or("unknown"),
557 "choices": [{
558 "index": index,
559 "message": Value::Object(message),
560 "finish_reason": "tool_calls",
561 }],
562 "usage": completion.get("usage").cloned().unwrap_or(Value::Null),
563 })
564}
565
566async fn buffer_openai_chat_completion(
568 mut upstream: reqwest::Response,
569 codec: E2eeCodec,
570 proxy_instance_key: ProxyInstanceKey,
571 fallback_model: String,
572) -> Result<Value, ChatStreamError> {
573 info!(model = %fallback_model, "buffering upstream chat stream");
574 let mut parser = SseEventParser::default();
575 let mut transformer =
576 OpenAiChatCompletionBuffer::new(codec, proxy_instance_key, fallback_model.clone());
577 let mut upstream_done = false;
578 let mut chunk_count = 0_u64;
579 let mut event_count = 0_u64;
580
581 while let Some(chunk) = upstream
582 .chunk()
583 .await
584 .map_err(ChatStreamError::upstream_stream)?
585 {
586 chunk_count += 1;
587 let chunk = std::str::from_utf8(&chunk).map_err(ChatStreamError::invalid_utf8)?;
588 let events = parser.push(chunk)?;
589 event_count += events.len() as u64;
590 debug!(
591 model = %fallback_model,
592 chunk_count,
593 parsed_events = events.len(),
594 total_events = event_count,
595 "parsed buffered upstream SSE chunk"
596 );
597
598 for event in events {
599 if transformer.handle_event(event)? {
600 upstream_done = true;
601 break;
602 }
603 }
604
605 if upstream_done {
606 break;
607 }
608 }
609
610 if !upstream_done {
611 warn!(
612 model = %fallback_model,
613 chunk_count,
614 event_count,
615 "buffered upstream stream ended before DONE"
616 );
617 parser.finish()?;
618 return Err(ChatStreamError::malformed_event(
619 "upstream stream ended before data: [DONE]",
620 ));
621 }
622
623 let completion = transformer.into_response();
624 info!(
625 model = %fallback_model,
626 chunk_count,
627 event_count,
628 "buffered upstream chat stream transformed"
629 );
630 Ok(completion)
631}
632
/// Log message texts used by `chat_sse_event_stream`, so the plain and the
/// tool-emulated streaming paths can share one implementation while logging
/// distinct messages.
struct ChatSseLogMessages {
    /// Logged at `info` when the stream transformation starts.
    start: &'static str,
    /// Logged at `debug` after each upstream chunk is parsed.
    parsed_chunk: &'static str,
    /// Logged at `debug` after each upstream event is transformed.
    transformed_event: &'static str,
    /// Logged at `info` when the transformer reports completion.
    completed: &'static str,
    /// Logged at `warn` when the upstream ends without completion.
    ended_early: &'static str,
}
642
/// Log messages for the plain (non-tool-emulated) streaming chat path.
const CHAT_SSE_LOG: ChatSseLogMessages = ChatSseLogMessages {
    start: "starting upstream chat SSE transformation",
    parsed_chunk: "parsed streaming upstream SSE chunk",
    transformed_event: "transformed streaming upstream SSE event",
    completed: "completed upstream chat SSE transformation",
    ended_early: "streaming upstream stream ended before DONE",
};
650
/// Log messages for the tool-emulated streaming chat path.
const TOOL_EMULATED_CHAT_SSE_LOG: ChatSseLogMessages = ChatSseLogMessages {
    start: "starting tool-emulated upstream chat SSE transformation",
    parsed_chunk: "parsed tool-emulated upstream SSE chunk",
    transformed_event: "transformed tool-emulated upstream SSE event",
    completed: "completed tool-emulated upstream chat SSE transformation",
    ended_early: "tool-emulated upstream stream ended before DONE",
};
658
/// Converts upstream SSE events into the outputs sent to the client; the
/// generic parameter of `chat_sse_response` and `chat_sse_event_stream`.
trait ChatSseTransformer {
    /// Handles one upstream event and returns the outputs to emit for it, in
    /// order. An empty vector emits nothing for this event.
    ///
    /// # Errors
    /// A returned `ChatStreamError` ends the client stream with that error.
    fn handle_event(&mut self, event: RawSseEvent) -> Result<Vec<StreamOutput>, ChatStreamError>;
}
664
665fn chat_sse_response<T>(
667 upstream: reqwest::Response,
668 transformer: T,
669 fallback_model: String,
670 include_usage_requested: bool,
671 log: &'static ChatSseLogMessages,
672 metadata: ProxyMetadataHeaders,
673) -> Response
674where
675 T: ChatSseTransformer + Send + 'static,
676{
677 let stream = chat_sse_event_stream(
678 upstream,
679 transformer,
680 fallback_model,
681 include_usage_requested,
682 log,
683 );
684 let mut response = Sse::new(stream).into_response();
685 metadata.apply(response.headers_mut());
686 response
687}
688
689fn chat_sse_event_stream<T>(
691 mut upstream: reqwest::Response,
692 mut transformer: T,
693 fallback_model: String,
694 include_usage_requested: bool,
695 log: &'static ChatSseLogMessages,
696) -> impl futures_core::Stream<Item = Result<Event, axum::BoxError>>
697where
698 T: ChatSseTransformer + Send + 'static,
699{
700 async_stream::try_stream! {
701 info!(
702 model = %fallback_model,
703 include_usage_requested,
704 "{}", log.start
705 );
706 let mut parser = SseEventParser::default();
707 let mut upstream_done = false;
708 let mut chunk_count = 0_u64;
709 let mut event_count = 0_u64;
710 let mut output_count = 0_u64;
711
712 while let Some(chunk) = upstream
713 .chunk()
714 .await
715 .map_err(ChatStreamError::upstream_stream)
716 .map_err(box_chat_stream_error)?
717 {
718 chunk_count += 1;
719 let chunk = std::str::from_utf8(&chunk)
720 .map_err(ChatStreamError::invalid_utf8)
721 .map_err(box_chat_stream_error)?;
722 let events = parser.push(chunk).map_err(box_chat_stream_error)?;
723 event_count += events.len() as u64;
724 debug!(
725 model = %fallback_model,
726 chunk_count,
727 parsed_events = events.len(),
728 total_events = event_count,
729 "{}", log.parsed_chunk
730 );
731
732 for event in events {
733 let outputs = transformer.handle_event(event).map_err(box_chat_stream_error)?;
734 output_count += outputs.len() as u64;
735 debug!(
736 model = %fallback_model,
737 emitted_outputs = outputs.len(),
738 total_outputs = output_count,
739 "{}", log.transformed_event
740 );
741
742 for output in outputs {
743 match output {
744 StreamOutput::Json(value) => yield Event::default().data(value.to_string()),
745 StreamOutput::Done => {
746 upstream_done = true;
747 info!(
748 model = %fallback_model,
749 chunk_count,
750 event_count,
751 output_count,
752 "{}", log.completed
753 );
754 yield Event::default().data("[DONE]");
755 break;
756 }
757 }
758 }
759
760 if upstream_done {
761 break;
762 }
763 }
764
765 if upstream_done {
766 break;
767 }
768 }
769
770 if !upstream_done {
771 warn!(
772 model = %fallback_model,
773 chunk_count,
774 event_count,
775 output_count,
776 "{}", log.ended_early
777 );
778 parser.finish().map_err(box_chat_stream_error)?;
779 Err::<(), axum::BoxError>(box_chat_stream_error(ChatStreamError::malformed_event(
780 "upstream stream ended before data: [DONE]",
781 )))?;
782 }
783 }
784}
785
786fn box_chat_stream_error(error: ChatStreamError) -> axum::BoxError {
788 error!(error = %error, "chat stream transformation failed");
789 Box::new(error)
790}
791
/// Incremental parser that reassembles SSE events from text chunks arriving
/// in arbitrary pieces.
#[derive(Debug, Default)]
struct SseEventParser {
    /// Text received so far that has not yet been terminated by an event boundary.
    buffer: String,
}
797
798impl SseEventParser {
799 fn push(&mut self, chunk: &str) -> Result<Vec<RawSseEvent>, ChatStreamError> {
801 self.buffer.push_str(chunk);
802 let mut events = Vec::new();
803
804 while let Some((boundary_start, boundary_len)) = sse_event_boundary(&self.buffer) {
805 let raw = self.buffer[..boundary_start].to_owned();
806 self.buffer.drain(..boundary_start + boundary_len);
807 if let Some(event) = parse_sse_event(&raw)? {
808 events.push(event);
809 }
810 }
811
812 debug!(
813 chunk_bytes = chunk.len(),
814 buffered_bytes = self.buffer.len(),
815 parsed_events = events.len(),
816 "SSE parser processed upstream chunk"
817 );
818 Ok(events)
819 }
820
821 fn finish(&self) -> Result<(), ChatStreamError> {
823 if self.buffer.trim().is_empty() {
824 Ok(())
825 } else {
826 warn!(
827 buffered_bytes = self.buffer.len(),
828 "upstream SSE stream ended with incomplete event"
829 );
830 Err(ChatStreamError::malformed_event(
831 "upstream stream ended with an incomplete SSE event",
832 ))
833 }
834 }
835}
836
/// One parsed server-sent event, before any interpretation of its payload.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RawSseEvent {
    /// Event type, if the event named one. `classify_upstream_event` treats
    /// `Some("error")` as an upstream error and logs `None` as `"message"`.
    event: Option<String>,
    /// Event payload; either the literal `[DONE]` or a JSON document.
    data: String,
}
843
/// Log message texts used by `classify_upstream_event`, so each caller can
/// share the classification logic while logging its own messages.
struct UpstreamEventLogMessages {
    /// Logged at `debug` for every event received.
    event: &'static str,
    /// Logged at `warn` when the SSE event type is `error`.
    sse_error: &'static str,
    /// Logged at `info` when the `[DONE]` marker is received.
    done: &'static str,
    /// Logged at `debug` before the payload is parsed as JSON; `None` skips the log.
    parsing: Option<&'static str>,
    /// Logged at `warn` when the JSON payload contains an `error` member.
    json_error: &'static str,
    /// Logged at `warn` when the payload has no `choices` array.
    missing_choices: &'static str,
    /// Logged at `debug` once the `choices` array was found; `None` skips the log.
    parsed: Option<&'static str>,
    /// Logged at `warn` when the payload has more than one choice.
    unexpected_choice_count: &'static str,
}
857
/// Event-classification log messages for the buffered (non-streaming) path;
/// used by `OpenAiChatCompletionBuffer::handle_event`.
const BUFFERED_UPSTREAM_EVENT_LOG: UpstreamEventLogMessages = UpstreamEventLogMessages {
    event: "buffering upstream SSE event",
    sse_error: "upstream SSE error event while buffering response",
    done: "received upstream DONE while buffering response",
    parsing: Some("parsing buffered upstream chat JSON chunk"),
    json_error: "upstream JSON error chunk while buffering response",
    missing_choices: "buffered upstream chat chunk is missing choices array",
    parsed: Some("parsed buffered upstream chat chunk"),
    unexpected_choice_count: "unexpected buffered upstream choice count",
};
868
/// Event-classification log messages for the streaming path.
// NOTE(review): the user of this constant is outside this section —
// presumably the plain streaming transformer; confirm.
const STREAMING_UPSTREAM_EVENT_LOG: UpstreamEventLogMessages = UpstreamEventLogMessages {
    event: "transforming streaming upstream SSE event",
    sse_error: "upstream SSE error event while streaming response",
    done: "received upstream DONE while streaming response",
    parsing: Some("parsing streaming upstream chat JSON chunk"),
    json_error: "upstream JSON error chunk while streaming response",
    missing_choices: "streaming upstream chat chunk is missing choices array",
    parsed: Some("parsed streaming upstream chat chunk"),
    unexpected_choice_count: "unexpected streaming upstream choice count",
};
879
/// Event-classification log messages for the tool-emulated streaming path.
///
/// `parsing` and `parsed` are `None`, so the two per-chunk `debug` logs are
/// skipped on this path.
// NOTE(review): the user of this constant is outside this section —
// presumably the tool-emulated streaming transformer; confirm.
const TOOL_EMULATED_UPSTREAM_EVENT_LOG: UpstreamEventLogMessages = UpstreamEventLogMessages {
    event: "transforming tool-emulated streaming upstream SSE event",
    sse_error: "upstream SSE error event while streaming tool-emulated response",
    done: "received upstream DONE while streaming tool-emulated response",
    parsing: None,
    json_error: "upstream JSON error chunk while streaming tool-emulated response",
    missing_choices: "tool-emulated upstream chat chunk is missing choices array",
    parsed: None,
    unexpected_choice_count: "unexpected tool-emulated upstream choice count",
};
890
/// Result of classifying one upstream SSE event (see `classify_upstream_event`).
enum UpstreamEventKind {
    /// The event payload, trimmed, was the literal `[DONE]`.
    Done,
    /// A JSON chunk whose `choices` array is empty; carries the whole chunk.
    Usage(Value),
    /// A JSON chunk with exactly one choice: `value` is the whole chunk and
    /// `choice` is a copy of its only `choices` entry.
    Choice { value: Value, choice: Value },
}
900
901fn classify_upstream_event(
905 event: RawSseEvent,
906 log: &UpstreamEventLogMessages,
907) -> Result<UpstreamEventKind, ChatStreamError> {
908 let event_type = event.event.as_deref().unwrap_or("message");
909 let is_done = event.data.trim() == "[DONE]";
910 debug!(event_type, is_done, "{}", log.event);
911
912 if event.event.as_deref() == Some("error") {
913 warn!("{}", log.sse_error);
914 return Err(ChatStreamError::upstream_event(event.data));
915 }
916
917 if is_done {
918 info!("{}", log.done);
919 return Ok(UpstreamEventKind::Done);
920 }
921
922 if let Some(parsing) = log.parsing {
923 debug!("{}", parsing);
924 }
925 let value: Value = serde_json::from_str(&event.data).map_err(ChatStreamError::json_event)?;
926 if let Some(error) = value.get("error") {
927 warn!("{}", log.json_error);
928 return Err(ChatStreamError::upstream_event(error.to_string()));
929 }
930
931 let Some(choices) = value.get("choices").and_then(Value::as_array) else {
932 warn!("{}", log.missing_choices);
933 return Err(ChatStreamError::malformed_event(
934 "upstream chat chunk is missing choices array",
935 ));
936 };
937 if let Some(parsed) = log.parsed {
938 debug!(choice_count = choices.len(), "{}", parsed);
939 }
940
941 if choices.is_empty() {
942 return Ok(UpstreamEventKind::Usage(value));
943 }
944 if choices.len() != 1 {
945 warn!(
946 choice_count = choices.len(),
947 "{}", log.unexpected_choice_count
948 );
949 return Err(ChatStreamError::malformed_event(format!(
950 "expected exactly one upstream choice, got {}",
951 choices.len(),
952 )));
953 }
954
955 let choice = choices[0].clone();
956 Ok(UpstreamEventKind::Choice { value, choice })
957}
958
/// Shared decryption material and fallback metadata used while transforming
/// upstream chat chunks.
struct ChunkContext {
    /// Codec used to decrypt response content.
    codec: E2eeCodec,
    /// Proxy key whose private half is passed to the codec for decryption.
    proxy_instance_key: ProxyInstanceKey,
    /// Completion id used when an upstream chunk has no string `id`;
    /// generated once as `chatcmpl-local-<uuid>`.
    fallback_id: String,
    /// Creation timestamp used when an upstream chunk has no integer
    /// `created`; taken from `unix_timestamp_now` at construction.
    fallback_created: i64,
    /// Model name used when an upstream chunk has no string `model`.
    fallback_model: String,
}
968
969impl ChunkContext {
970 fn new(codec: E2eeCodec, proxy_instance_key: ProxyInstanceKey, fallback_model: String) -> Self {
972 Self {
973 codec,
974 proxy_instance_key,
975 fallback_id: format!("chatcmpl-local-{}", uuid::Uuid::new_v4()),
976 fallback_created: unix_timestamp_now(),
977 fallback_model,
978 }
979 }
980
981 fn decrypt(&self, content: Option<&str>) -> Result<Option<String>, ChatStreamError> {
983 self.codec
984 .decrypt_response_content(content, self.proxy_instance_key.private_key())
985 .map_err(ChatStreamError::decryption)
986 }
987
988 fn chunk_with_choice(
990 &self,
991 upstream: &Value,
992 index: u64,
993 delta: Value,
994 finish_reason: Value,
995 ) -> Value {
996 json!({
997 "id": string_field(upstream, "id").unwrap_or(&self.fallback_id),
998 "object": string_field(upstream, "object").unwrap_or("chat.completion.chunk"),
999 "created": integer_field(upstream, "created").unwrap_or(self.fallback_created),
1000 "model": string_field(upstream, "model").unwrap_or(&self.fallback_model),
1001 "choices": [{
1002 "index": index,
1003 "delta": delta,
1004 "finish_reason": finish_reason,
1005 }],
1006 })
1007 }
1008
1009 fn usage_chunk(&self, upstream: &Value, usage: &Value) -> Value {
1011 json!({
1012 "id": string_field(upstream, "id").unwrap_or(&self.fallback_id),
1013 "object": string_field(upstream, "object").unwrap_or("chat.completion.chunk"),
1014 "created": integer_field(upstream, "created").unwrap_or(self.fallback_created),
1015 "model": string_field(upstream, "model").unwrap_or(&self.fallback_model),
1016 "choices": [],
1017 "usage": usage,
1018 })
1019 }
1020}
1021
/// Accumulates a streamed upstream chat completion into a single response.
struct OpenAiChatCompletionBuffer {
    /// Decryption material and fallback metadata.
    ctx: ChunkContext,
    /// First `id` seen in an upstream chunk, if any.
    id: Option<String>,
    /// First `created` timestamp seen in an upstream chunk, if any.
    created: Option<i64>,
    /// Upstream model name, if any.
    // NOTE(review): the code that records this is outside the visible
    // section — presumably first-seen like `id`; confirm.
    model: Option<String>,
    /// Index of the single choice being buffered; a later chunk with a
    /// different index is rejected as malformed.
    choice_index: Option<u64>,
    /// Set once any content or reasoning field has been decrypted. When still
    /// false at `[DONE]`, `ctx.decrypt(None)` is invoked.
    saw_encrypted_response_field: bool,
    /// Concatenation of all decrypted content deltas.
    content: String,
    /// Concatenation of all decrypted reasoning-content deltas.
    reasoning_content: String,
    /// Last non-null finish reason; set to `"stop"` at `[DONE]` if none was seen.
    finish_reason: Option<Value>,
    /// `usage` value of the most recent usage chunk, if any.
    usage: Option<Value>,
}
1035
1036impl OpenAiChatCompletionBuffer {
1037 fn new(codec: E2eeCodec, proxy_instance_key: ProxyInstanceKey, fallback_model: String) -> Self {
1039 Self {
1040 ctx: ChunkContext::new(codec, proxy_instance_key, fallback_model),
1041 id: None,
1042 created: None,
1043 model: None,
1044 choice_index: None,
1045 saw_encrypted_response_field: false,
1046 content: String::new(),
1047 reasoning_content: String::new(),
1048 finish_reason: None,
1049 usage: None,
1050 }
1051 }
1052
1053 fn handle_event(&mut self, event: RawSseEvent) -> Result<bool, ChatStreamError> {
1055 match classify_upstream_event(event, &BUFFERED_UPSTREAM_EVENT_LOG)? {
1056 UpstreamEventKind::Done => {
1057 if !self.saw_encrypted_response_field {
1058 self.ctx.decrypt(None)?;
1059 }
1060 if self.finish_reason.is_none() {
1061 self.finish_reason = Some(Value::String("stop".to_owned()));
1062 }
1063 Ok(true)
1064 }
1065 UpstreamEventKind::Usage(value) => {
1066 self.record_metadata(&value);
1067 self.handle_usage_chunk(&value).map(|()| false)
1068 }
1069 UpstreamEventKind::Choice { value, choice } => {
1070 self.record_metadata(&value);
1071 self.handle_choice_chunk(&choice)?;
1072 Ok(false)
1073 }
1074 }
1075 }
1076
1077 fn handle_usage_chunk(&mut self, value: &Value) -> Result<(), ChatStreamError> {
1079 let Some(usage) = value.get("usage") else {
1080 warn!("buffered upstream chunk has no choices and no usage");
1081 return Err(ChatStreamError::malformed_event(
1082 "upstream chunk has no choices and no usage",
1083 ));
1084 };
1085
1086 info!("buffered upstream usage chunk");
1087 self.usage = Some(usage.clone());
1088 Ok(())
1089 }
1090
1091 fn handle_choice_chunk(&mut self, choice: &Value) -> Result<(), ChatStreamError> {
1093 let choice = choice.as_object().ok_or_else(|| {
1094 ChatStreamError::malformed_event("upstream choice must be a JSON object")
1095 })?;
1096 let index = normalized_choice_index(choice.get("index"))?;
1097 match self.choice_index {
1098 Some(existing) if existing != index => {
1099 return Err(ChatStreamError::malformed_event(
1100 "upstream choice index changed while buffering a completion",
1101 ));
1102 }
1103 None => self.choice_index = Some(index),
1104 Some(_) => {}
1105 }
1106
1107 let finish_reason = normalized_finish_reason(choice.get("finish_reason"))?;
1108 let delta = choice.get("delta").unwrap_or(&Value::Null);
1109 let content = encrypted_delta_content(delta)?;
1110 let reasoning_content = encrypted_delta_reasoning_content(delta)?;
1111 debug!(
1112 choice_index = index,
1113 has_encrypted_content = content.is_some(),
1114 has_encrypted_reasoning_content = reasoning_content.is_some(),
1115 has_finish_reason = !finish_reason.is_null(),
1116 "transforming buffered upstream choice chunk"
1117 );
1118
1119 if let Some(content) = content {
1120 let decrypted = self.ctx.decrypt(Some(content))?;
1121 self.saw_encrypted_response_field = true;
1122 debug!(
1123 choice_index = index,
1124 has_decrypted_content = decrypted.is_some(),
1125 "decrypted buffered upstream content chunk"
1126 );
1127 if let Some(content) = decrypted {
1128 self.content.push_str(&content);
1129 }
1130 }
1131
1132 if let Some(reasoning_content) = reasoning_content {
1133 let decrypted = self.ctx.decrypt(Some(reasoning_content))?;
1134 self.saw_encrypted_response_field = true;
1135 debug!(
1136 choice_index = index,
1137 has_decrypted_reasoning_content = decrypted.is_some(),
1138 "decrypted buffered upstream reasoning content chunk"
1139 );
1140 if let Some(reasoning_content) = decrypted {
1141 self.reasoning_content.push_str(&reasoning_content);
1142 }
1143 }
1144
1145 if !finish_reason.is_null() {
1146 self.finish_reason = Some(finish_reason);
1147 }
1148
1149 Ok(())
1150 }
1151
1152 fn record_metadata(&mut self, value: &Value) {
1154 if self.id.is_none()
1155 && let Some(id) = string_field(value, "id")
1156 {
1157 self.id = Some(id.to_owned());
1158 }
1159 if self.created.is_none()
1160 && let Some(created) = integer_field(value, "created")
1161 {
1162 self.created = Some(created);
1163 }
1164 if self.model.is_none()
1165 && let Some(model) = string_field(value, "model")
1166 {
1167 self.model = Some(model.to_owned());
1168 }
1169 }
1170
1171 fn into_response(self) -> Value {
1173 let mut message = serde_json::Map::new();
1174 message.insert("role".to_owned(), Value::String("assistant".to_owned()));
1175 if !self.reasoning_content.is_empty() {
1176 message.insert(
1177 "reasoning_content".to_owned(),
1178 Value::String(self.reasoning_content),
1179 );
1180 }
1181 message.insert("content".to_owned(), Value::String(self.content));
1182
1183 json!({
1184 "id": self.id.unwrap_or(self.ctx.fallback_id),
1185 "object": "chat.completion",
1186 "created": self.created.unwrap_or(self.ctx.fallback_created),
1187 "model": self.model.unwrap_or(self.ctx.fallback_model),
1188 "choices": [{
1189 "index": self.choice_index.unwrap_or(0),
1190 "message": Value::Object(message),
1191 "finish_reason": self.finish_reason.unwrap_or_else(|| Value::String("stop".to_owned())),
1192 }],
1193 "usage": self.usage.unwrap_or(Value::Null),
1194 })
1195 }
1196}
1197
/// Locates the earliest blank-line delimiter that terminates an SSE event.
///
/// Recognised delimiters are `"\r\n\r\n"`, `"\n\n"`, and `"\r\r"`. Returns the
/// byte offset at which the delimiter starts together with the delimiter's
/// length in bytes, or `None` when `buffer` holds no complete event yet.
fn sse_event_boundary(buffer: &str) -> Option<(usize, usize)> {
    const DELIMITERS: [&str; 3] = ["\r\n\r\n", "\n\n", "\r\r"];

    let mut earliest: Option<(usize, usize)> = None;
    for delimiter in DELIMITERS {
        let Some(start) = buffer.find(delimiter) else {
            continue;
        };
        match earliest {
            // Keep the candidate already found when it starts no later.
            Some((best, _)) if best <= start => {}
            _ => earliest = Some((start, delimiter.len())),
        }
    }
    earliest
}
1205
1206fn parse_sse_event(raw: &str) -> Result<Option<RawSseEvent>, ChatStreamError> {
1208 let mut event = None;
1209 let mut data_lines = Vec::new();
1210 let mut saw_non_comment_field = false;
1211
1212 for line in raw.lines() {
1213 let line = line.strip_suffix('\r').unwrap_or(line);
1214 if line.is_empty() || line.starts_with(':') {
1215 continue;
1216 }
1217
1218 saw_non_comment_field = true;
1219 let (field, value) = line.split_once(':').unwrap_or((line, ""));
1220 let value = value.strip_prefix(' ').unwrap_or(value);
1221 match field {
1222 "event" => event = Some(value.to_owned()),
1223 "data" => data_lines.push(value.to_owned()),
1224 "id" | "retry" => {}
1225 other => {
1226 warn!(field = other, "unsupported upstream SSE field");
1227 return Err(ChatStreamError::malformed_event(format!(
1228 "unsupported upstream SSE field {other:?}",
1229 )));
1230 }
1231 }
1232 }
1233
1234 if data_lines.is_empty() {
1235 return if saw_non_comment_field {
1236 warn!("upstream SSE event did not contain a data field");
1237 Err(ChatStreamError::malformed_event(
1238 "upstream SSE event did not contain a data field",
1239 ))
1240 } else {
1241 debug!("ignored upstream SSE comment or heartbeat event");
1242 Ok(None)
1243 };
1244 }
1245
1246 debug!(
1247 event_type = event.as_deref().unwrap_or("message"),
1248 data_line_count = data_lines.len(),
1249 "parsed upstream SSE event"
1250 );
1251
1252 Ok(Some(RawSseEvent {
1253 event,
1254 data: data_lines.join("\n"),
1255 }))
1256}
1257
1258struct OpenAiChatStreamTransformer {
1260 ctx: ChunkContext,
1261 include_usage_requested: bool,
1262 sent_role: bool,
1263 sent_final_finish: bool,
1264}
1265
1266impl OpenAiChatStreamTransformer {
1267 fn new(
1269 codec: E2eeCodec,
1270 proxy_instance_key: ProxyInstanceKey,
1271 fallback_model: String,
1272 include_usage_requested: bool,
1273 ) -> Self {
1274 Self {
1275 ctx: ChunkContext::new(codec, proxy_instance_key, fallback_model),
1276 include_usage_requested,
1277 sent_role: false,
1278 sent_final_finish: false,
1279 }
1280 }
1281
1282 fn handle_choice_chunk(
1284 &mut self,
1285 value: &Value,
1286 choice: &Value,
1287 ) -> Result<Vec<StreamOutput>, ChatStreamError> {
1288 let choice = choice.as_object().ok_or_else(|| {
1289 ChatStreamError::malformed_event("upstream choice must be a JSON object")
1290 })?;
1291 let finish_reason = normalized_finish_reason(choice.get("finish_reason"))?;
1292 let delta = choice.get("delta").unwrap_or(&Value::Null);
1293 let content = encrypted_delta_content(delta)?;
1294 let reasoning_content = encrypted_delta_reasoning_content(delta)?;
1295 debug!(
1296 has_encrypted_content = content.is_some(),
1297 has_encrypted_reasoning_content = reasoning_content.is_some(),
1298 has_finish_reason = !finish_reason.is_null(),
1299 "transforming streaming upstream choice chunk"
1300 );
1301
1302 let mut output = Vec::new();
1303
1304 if content.is_none() && reasoning_content.is_none() {
1305 if !finish_reason.is_null() {
1306 output.push(StreamOutput::Json(self.chunk_with_choice(
1307 value,
1308 choice.get("index"),
1309 json!({}),
1310 finish_reason,
1311 )?));
1312 self.sent_final_finish = true;
1313 }
1314 return Ok(output);
1315 }
1316
1317 let decrypted_content = match content {
1318 Some(content) => self.ctx.decrypt(Some(content))?,
1319 None => None,
1320 };
1321 let decrypted_reasoning_content = match reasoning_content {
1322 Some(reasoning_content) => self.ctx.decrypt(Some(reasoning_content))?,
1323 None => None,
1324 };
1325 debug!(
1326 has_decrypted_content = decrypted_content.is_some(),
1327 has_decrypted_reasoning_content = decrypted_reasoning_content.is_some(),
1328 "decrypted streaming upstream content chunk"
1329 );
1330
1331 if decrypted_content.is_some() || decrypted_reasoning_content.is_some() {
1332 let mut delta = serde_json::Map::new();
1333
1334 if !self.sent_role {
1335 delta.insert("role".to_owned(), Value::String("assistant".to_owned()));
1336 self.sent_role = true;
1337 }
1338
1339 if let Some(reasoning_content) = decrypted_reasoning_content {
1340 delta.insert(
1341 "reasoning_content".to_owned(),
1342 Value::String(reasoning_content),
1343 );
1344 }
1345
1346 if let Some(content) = decrypted_content {
1347 delta.insert("content".to_owned(), Value::String(content));
1348 }
1349
1350 let final_finish = !finish_reason.is_null();
1351 let content_finish_reason = if final_finish {
1352 Value::Null
1353 } else {
1354 finish_reason.clone()
1355 };
1356 output.push(StreamOutput::Json(self.chunk_with_choice(
1357 value,
1358 choice.get("index"),
1359 Value::Object(delta),
1360 content_finish_reason,
1361 )?));
1362 if final_finish {
1363 output.push(StreamOutput::Json(self.chunk_with_choice(
1364 value,
1365 choice.get("index"),
1366 json!({}),
1367 finish_reason,
1368 )?));
1369 self.sent_final_finish = true;
1370 }
1371 return Ok(output);
1372 }
1373
1374 Ok(output)
1375 }
1376
1377 fn handle_usage_chunk(&self, value: &Value) -> Result<Vec<StreamOutput>, ChatStreamError> {
1379 let Some(usage) = value.get("usage") else {
1380 warn!("streaming upstream chunk has no choices and no usage");
1381 return Err(ChatStreamError::malformed_event(
1382 "upstream chunk has no choices and no usage",
1383 ));
1384 };
1385
1386 if !self.include_usage_requested {
1390 debug!("streaming upstream usage chunk ignored because client did not request usage");
1391 return Ok(Vec::new());
1392 }
1393
1394 info!("streaming upstream usage chunk forwarded");
1395 Ok(vec![StreamOutput::Json(self.ctx.usage_chunk(value, usage))])
1396 }
1397
1398 fn finish_chunk(&self) -> Value {
1400 self.ctx
1401 .chunk_with_choice(&Value::Null, 0, json!({}), Value::String("stop".to_owned()))
1402 }
1403
1404 fn chunk_with_choice(
1406 &self,
1407 upstream: &Value,
1408 index: Option<&Value>,
1409 delta: Value,
1410 finish_reason: Value,
1411 ) -> Result<Value, ChatStreamError> {
1412 let index = normalized_choice_index(index)?;
1413 Ok(self
1414 .ctx
1415 .chunk_with_choice(upstream, index, delta, finish_reason))
1416 }
1417}
1418
1419impl ChatSseTransformer for OpenAiChatStreamTransformer {
1420 fn handle_event(&mut self, event: RawSseEvent) -> Result<Vec<StreamOutput>, ChatStreamError> {
1422 match classify_upstream_event(event, &STREAMING_UPSTREAM_EVENT_LOG)? {
1423 UpstreamEventKind::Done => {
1424 let mut output = Vec::new();
1425 if !self.sent_final_finish {
1426 debug!("synthesizing final streaming finish chunk before DONE");
1427 output.push(StreamOutput::Json(self.finish_chunk()));
1428 self.sent_final_finish = true;
1429 }
1430 output.push(StreamOutput::Done);
1431 Ok(output)
1432 }
1433 UpstreamEventKind::Usage(value) => self.handle_usage_chunk(&value),
1434 UpstreamEventKind::Choice { value, choice } => {
1435 self.handle_choice_chunk(&value, &choice)
1436 }
1437 }
1438 }
1439}
1440
/// Text that marks the start of an emulated tool call in decrypted model
/// output; the tool-emulated stream transformer starts buffering when it sees
/// this marker.
const TOOL_CALL_START_MARKER: &str = "<tool_call>";
1442
1443struct OpenAiToolEmulatedChatStreamTransformer {
1449 ctx: ChunkContext,
1450 tool_context: ToolEmulationContext,
1451 include_usage_requested: bool,
1452 sent_role: bool,
1453 sent_final_finish: bool,
1454 pending_text: String,
1455 tool_buffer: String,
1456 buffering_tool_call: bool,
1457 emitted_tool_calls: bool,
1458}
1459
1460impl OpenAiToolEmulatedChatStreamTransformer {
1461 fn new(
1463 tool_context: &ToolEmulationContext,
1464 codec: E2eeCodec,
1465 proxy_instance_key: ProxyInstanceKey,
1466 fallback_model: String,
1467 include_usage_requested: bool,
1468 ) -> Result<Self, ChatStreamError> {
1469 Ok(Self {
1470 ctx: ChunkContext::new(codec, proxy_instance_key, fallback_model),
1471 tool_context: tool_context.clone(),
1472 include_usage_requested,
1473 sent_role: false,
1474 sent_final_finish: false,
1475 pending_text: String::new(),
1476 tool_buffer: String::new(),
1477 buffering_tool_call: false,
1478 emitted_tool_calls: false,
1479 })
1480 }
1481
1482 fn handle_choice_chunk(
1484 &mut self,
1485 value: &Value,
1486 choice: &Value,
1487 ) -> Result<Vec<StreamOutput>, ChatStreamError> {
1488 let choice = choice.as_object().ok_or_else(|| {
1489 ChatStreamError::malformed_event("upstream choice must be a JSON object")
1490 })?;
1491 let index = normalized_choice_index(choice.get("index"))?;
1492 let finish_reason = normalized_finish_reason(choice.get("finish_reason"))?;
1493 let delta = choice.get("delta").unwrap_or(&Value::Null);
1494 let content = encrypted_delta_content(delta)?;
1495 let reasoning_content = encrypted_delta_reasoning_content(delta)?;
1496
1497 let mut output = Vec::new();
1498
1499 if let Some(reasoning_content) = reasoning_content
1500 && let Some(reasoning_content) = self.ctx.decrypt(Some(reasoning_content))?
1501 && !self.sent_final_finish
1502 {
1503 output.push(self.reasoning_chunk(value, index, reasoning_content));
1504 }
1505
1506 if let Some(content) = content
1507 && let Some(content) = self.ctx.decrypt(Some(content))?
1508 && !self.sent_final_finish
1509 {
1510 output.extend(self.push_decrypted_content(value, index, &content)?);
1511 }
1512
1513 if !finish_reason.is_null() && !self.sent_final_finish {
1514 output.extend(self.finish_buffered_content(value, index, finish_reason)?);
1515 }
1516
1517 Ok(output)
1518 }
1519
1520 fn push_decrypted_content(
1522 &mut self,
1523 upstream: &Value,
1524 index: u64,
1525 content: &str,
1526 ) -> Result<Vec<StreamOutput>, ChatStreamError> {
1527 if self.buffering_tool_call {
1528 self.tool_buffer.push_str(content);
1529 self.ensure_tool_buffer_within_limit()?;
1530 return Ok(Vec::new());
1531 }
1532
1533 self.pending_text.push_str(content);
1534 if let Some(marker_index) = self.pending_text.find(TOOL_CALL_START_MARKER) {
1535 let text = self.pending_text[..marker_index].to_owned();
1536 self.tool_buffer = self.pending_text[marker_index..].to_owned();
1537 self.pending_text.clear();
1538 self.buffering_tool_call = true;
1539 self.ensure_tool_buffer_within_limit()?;
1540 return Ok(self.text_chunk_if_not_empty(upstream, index, text));
1541 }
1542
1543 let streamable_len = streamable_pending_text_len(&self.pending_text);
1544 if streamable_len == 0 {
1545 return Ok(Vec::new());
1546 }
1547
1548 let text = self.pending_text[..streamable_len].to_owned();
1549 self.pending_text.drain(..streamable_len);
1550 Ok(vec![
1551 self.text_field_chunk(upstream, index, "content", text),
1552 ])
1553 }
1554
1555 fn finish_buffered_content(
1557 &mut self,
1558 upstream: &Value,
1559 index: u64,
1560 finish_reason: Value,
1561 ) -> Result<Vec<StreamOutput>, ChatStreamError> {
1562 let mut output = Vec::new();
1563
1564 if self.buffering_tool_call {
1565 output.extend(self.buffered_tool_call_chunks(upstream, index)?);
1566 } else if !self.pending_text.is_empty() {
1567 let text = std::mem::take(&mut self.pending_text);
1568 output.push(self.text_field_chunk(upstream, index, "content", text));
1569 }
1570
1571 let finish_reason = if self.emitted_tool_calls {
1572 Value::String("tool_calls".to_owned())
1573 } else {
1574 finish_reason
1575 };
1576 output.push(StreamOutput::Json(self.ctx.chunk_with_choice(
1577 upstream,
1578 index,
1579 json!({}),
1580 finish_reason,
1581 )));
1582 self.sent_final_finish = true;
1583 Ok(output)
1584 }
1585
1586 fn buffered_tool_call_chunks(
1588 &mut self,
1589 upstream: &Value,
1590 index: u64,
1591 ) -> Result<Vec<StreamOutput>, ChatStreamError> {
1592 self.ensure_tool_buffer_within_limit()?;
1593 match self
1594 .tool_context
1595 .classify_assistant_output(&self.tool_buffer)
1596 {
1597 ToolOutputClassification::ToolCalls(tool_calls) => {
1598 self.emitted_tool_calls = true;
1599 Ok(tool_calls
1600 .iter()
1601 .enumerate()
1602 .map(|(tool_index, tool_call)| {
1603 self.full_tool_call_chunk(upstream, index, tool_index, tool_call)
1604 })
1605 .collect())
1606 }
1607 ToolOutputClassification::NormalText => {
1608 let text = std::mem::take(&mut self.tool_buffer);
1609 self.buffering_tool_call = false;
1610 Ok(self.text_chunk_if_not_empty(upstream, index, text))
1611 }
1612 ToolOutputClassification::InvalidToolCall { error, .. } => {
1613 error!(
1614 validation_error = %error,
1615 payload_bytes = self.tool_buffer.len(),
1616 payload = %self.tool_buffer,
1617 "buffered streamed tool-call payload failed validation"
1618 );
1619 Err(ChatStreamError::malformed_event(format!(
1620 "tool call parsing failed: {error}"
1621 )))
1622 }
1623 }
1624 }
1625
1626 fn ensure_tool_buffer_within_limit(&self) -> Result<(), ChatStreamError> {
1628 if self.tool_buffer.len() > self.tool_context.config().tool_call_max_bytes {
1629 return Err(ChatStreamError::malformed_event(format!(
1630 "tool call output exceeded max size of {} bytes",
1631 self.tool_context.config().tool_call_max_bytes
1632 )));
1633 }
1634 Ok(())
1635 }
1636
1637 fn text_chunk_if_not_empty(
1639 &mut self,
1640 upstream: &Value,
1641 index: u64,
1642 text: String,
1643 ) -> Vec<StreamOutput> {
1644 if text.is_empty() {
1645 Vec::new()
1646 } else {
1647 vec![self.text_field_chunk(upstream, index, "content", text)]
1648 }
1649 }
1650
1651 fn reasoning_chunk(
1653 &mut self,
1654 upstream: &Value,
1655 index: u64,
1656 reasoning_content: String,
1657 ) -> StreamOutput {
1658 self.text_field_chunk(upstream, index, "reasoning_content", reasoning_content)
1659 }
1660
1661 fn text_field_chunk(
1663 &mut self,
1664 upstream: &Value,
1665 index: u64,
1666 field: &'static str,
1667 text: String,
1668 ) -> StreamOutput {
1669 let mut delta = serde_json::Map::new();
1670 self.insert_role_if_needed(&mut delta);
1671 delta.insert(field.to_owned(), Value::String(text));
1672
1673 StreamOutput::Json(self.ctx.chunk_with_choice(
1674 upstream,
1675 index,
1676 Value::Object(delta),
1677 Value::Null,
1678 ))
1679 }
1680
1681 fn insert_role_if_needed(&mut self, delta: &mut serde_json::Map<String, Value>) {
1683 if !self.sent_role {
1684 delta.insert("role".to_owned(), Value::String("assistant".to_owned()));
1685 self.sent_role = true;
1686 }
1687 }
1688
1689 fn full_tool_call_chunk(
1691 &mut self,
1692 upstream: &Value,
1693 index: u64,
1694 tool_index: usize,
1695 tool_call: &ValidatedToolCall,
1696 ) -> StreamOutput {
1697 let mut delta = serde_json::Map::new();
1698 self.insert_role_if_needed(&mut delta);
1699
1700 let mut tool_call_value = tool_call.to_openai_value();
1701
1702 if let Some(tool_call_object) = tool_call_value.as_object_mut() {
1703 tool_call_object.insert("index".to_owned(), json!(tool_index));
1704 }
1705 delta.insert("tool_calls".to_owned(), Value::Array(vec![tool_call_value]));
1706
1707 StreamOutput::Json(self.ctx.chunk_with_choice(
1708 upstream,
1709 index,
1710 Value::Object(delta),
1711 Value::Null,
1712 ))
1713 }
1714
1715 fn handle_usage_chunk(&self, value: &Value) -> Result<Vec<StreamOutput>, ChatStreamError> {
1717 let Some(usage) = value.get("usage") else {
1718 warn!("tool-emulated upstream chunk has no choices and no usage");
1719 return Err(ChatStreamError::malformed_event(
1720 "upstream chunk has no choices and no usage",
1721 ));
1722 };
1723
1724 if !self.include_usage_requested {
1726 return Ok(Vec::new());
1727 }
1728
1729 Ok(vec![StreamOutput::Json(self.ctx.usage_chunk(value, usage))])
1730 }
1731
1732 fn finish_stream(&mut self) -> Result<Vec<StreamOutput>, ChatStreamError> {
1734 let upstream = &Value::Null;
1735 let mut output = Vec::new();
1736
1737 if !self.sent_final_finish {
1738 output.extend(self.finish_buffered_content(
1739 upstream,
1740 0,
1741 Value::String("stop".to_owned()),
1742 )?);
1743 }
1744
1745 output.push(StreamOutput::Done);
1746 Ok(output)
1747 }
1748}
1749
1750fn streamable_pending_text_len(pending_text: &str) -> usize {
1752 let protected_suffix_len = TOOL_CALL_START_MARKER.len().saturating_sub(1);
1753 if pending_text.len() <= protected_suffix_len {
1754 return 0;
1755 }
1756
1757 let mut split_at = pending_text.len() - protected_suffix_len;
1758 while !pending_text.is_char_boundary(split_at) {
1759 split_at -= 1;
1760 }
1761 split_at
1762}
1763
1764impl ChatSseTransformer for OpenAiToolEmulatedChatStreamTransformer {
1765 fn handle_event(&mut self, event: RawSseEvent) -> Result<Vec<StreamOutput>, ChatStreamError> {
1767 match classify_upstream_event(event, &TOOL_EMULATED_UPSTREAM_EVENT_LOG)? {
1768 UpstreamEventKind::Done => self.finish_stream(),
1769 UpstreamEventKind::Usage(value) => self.handle_usage_chunk(&value),
1770 UpstreamEventKind::Choice { value, choice } => {
1771 self.handle_choice_chunk(&value, &choice)
1772 }
1773 }
1774 }
1775}
1776
1777#[derive(Debug, Clone, PartialEq, Eq)]
1779enum StreamOutput {
1780 Json(Value),
1781 Done,
1782}
1783
1784fn normalized_choice_index(index: Option<&Value>) -> Result<u64, ChatStreamError> {
1786 match index {
1787 Some(Value::Number(number)) => number.as_u64().ok_or_else(|| {
1788 ChatStreamError::malformed_event("upstream choice index must be a non-negative integer")
1789 }),
1790 Some(_) => Err(ChatStreamError::malformed_event(
1791 "upstream choice index must be a non-negative integer",
1792 )),
1793 None => Ok(0),
1794 }
1795}
1796
1797fn normalized_finish_reason(value: Option<&Value>) -> Result<Value, ChatStreamError> {
1799 match value {
1800 Some(Value::Null) | None => Ok(Value::Null),
1801 Some(Value::String(reason)) => Ok(Value::String(reason.clone())),
1802 Some(_) => Err(ChatStreamError::malformed_event(
1803 "upstream finish_reason must be a string or null",
1804 )),
1805 }
1806}
1807
1808fn encrypted_delta_content(delta: &Value) -> Result<Option<&str>, ChatStreamError> {
1810 encrypted_delta_text_field(delta, "content")
1811}
1812
1813fn encrypted_delta_reasoning_content(delta: &Value) -> Result<Option<&str>, ChatStreamError> {
1815 encrypted_delta_text_field(delta, "reasoning_content")
1816}
1817
1818fn encrypted_delta_text_field<'a>(
1820 delta: &'a Value,
1821 field: &'static str,
1822) -> Result<Option<&'a str>, ChatStreamError> {
1823 match delta.get(field) {
1824 Some(Value::Null) => {
1825 debug!(field, "ignoring null upstream delta text field");
1826 Ok(None)
1827 }
1828 Some(Value::String(content)) if content.is_empty() => {
1829 debug!(field, "ignoring empty upstream delta text field");
1830 Ok(None)
1831 }
1832 Some(Value::String(content)) => Ok(Some(content.as_str())),
1833 Some(_) => Err(ChatStreamError::malformed_event(format!(
1834 "upstream delta.{field} must be a string or null"
1835 ))),
1836 None => Ok(None),
1837 }
1838}
1839
1840fn string_field<'a>(value: &'a Value, field: &str) -> Option<&'a str> {
1842 value.get(field).and_then(Value::as_str)
1843}
1844
1845fn integer_field(value: &Value, field: &str) -> Option<i64> {
1847 value.get(field).and_then(Value::as_i64)
1848}
1849
/// Returns the current time as whole seconds since the Unix epoch.
///
/// Returns `0` when the system clock reports a time before the epoch. The
/// second count is converted with a checked conversion and saturates at
/// `i64::MAX` instead of wrapping to a negative value.
fn unix_timestamp_now() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => i64::try_from(elapsed.as_secs()).unwrap_or(i64::MAX),
        Err(_) => 0,
    }
}
1857
1858async fn method_not_allowed(method: Method, uri: Uri) -> ProxyError {
1860 ProxyError::MethodNotAllowed { method, uri }
1861}
1862
1863async fn not_found(uri: Uri) -> ProxyError {
1865 ProxyError::NotFound { uri }
1866}
1867
1868#[derive(Debug, Error)]
1870pub enum ChatStreamError {
1871 #[error("Venice upstream stream failed: {message}")]
1872 UpstreamStream { message: String },
1873 #[error("Venice upstream stream emitted an error event: {message}")]
1874 UpstreamEvent { message: String },
1875 #[error("Venice upstream stream event is malformed: {message}")]
1876 MalformedEvent { message: String },
1877 #[error("failed to decrypt Venice E2EE response chunk: {source}")]
1878 Decryption { source: E2eeCodecError },
1879}
1880
1881impl ChatStreamError {
1882 fn upstream_stream(source: reqwest::Error) -> Self {
1884 Self::UpstreamStream {
1885 message: source.to_string(),
1886 }
1887 }
1888
1889 fn upstream_event(message: impl Into<String>) -> Self {
1891 Self::UpstreamEvent {
1892 message: message.into(),
1893 }
1894 }
1895
1896 fn malformed_event(message: impl Into<String>) -> Self {
1898 Self::MalformedEvent {
1899 message: message.into(),
1900 }
1901 }
1902
1903 fn invalid_utf8(source: std::str::Utf8Error) -> Self {
1905 Self::MalformedEvent {
1906 message: format!("upstream SSE bytes are not valid UTF-8: {source}"),
1907 }
1908 }
1909
1910 fn json_event(source: serde_json::Error) -> Self {
1912 Self::MalformedEvent {
1913 message: format!("upstream SSE data is not valid JSON: {source}"),
1914 }
1915 }
1916
1917 fn decryption(source: E2eeCodecError) -> Self {
1919 Self::Decryption { source }
1920 }
1921
1922 fn api_error_type(&self) -> &'static str {
1924 match self {
1925 Self::UpstreamStream { .. }
1926 | Self::UpstreamEvent { .. }
1927 | Self::MalformedEvent { .. } => "proxy_upstream_error",
1928 Self::Decryption { .. } => "proxy_e2ee_error",
1929 }
1930 }
1931
1932 fn api_error_code(&self) -> &'static str {
1934 match self {
1935 Self::UpstreamStream { .. } => "upstream_stream_error",
1936 Self::UpstreamEvent { .. } => "upstream_stream_error",
1937 Self::MalformedEvent { .. } => "upstream_malformed_response",
1938 Self::Decryption { .. } => "e2ee_response_decryption_failed",
1939 }
1940 }
1941}
1942
1943#[derive(Debug, Error)]
1945pub enum ProxyError {
1946 #[error(transparent)]
1947 Venice(#[from] VeniceClientError),
1948 #[error(transparent)]
1949 Attestation(#[from] AttestationError),
1950 #[error(transparent)]
1951 Session(#[from] SessionError),
1952 #[error(transparent)]
1953 ChatRequest(#[from] ChatRequestError),
1954 #[error(transparent)]
1955 ChatConstruction(#[from] ChatConstructionError),
1956 #[error(transparent)]
1957 ChatStream(#[from] ChatStreamError),
1958 #[error("The model failed to produce a valid tool call after correction attempts.")]
1959 ToolCallRetryExhausted {
1960 max_retries: u32,
1961 last_validation_error: String,
1962 },
1963 #[error(
1964 "proxy instance key is unavailable; keys.generate_proxy_instance_key_on_startup must be enabled for E2EE chat requests"
1965 )]
1966 ProxyInstanceKeyUnavailable,
1967 #[error("session does not contain an attested model public key after attestation verification")]
1968 MissingAttestedModelKey,
1969 #[error("method {method} is not supported for {uri}")]
1970 MethodNotAllowed { method: Method, uri: Uri },
1971 #[error("route {uri} was not found")]
1972 NotFound { uri: Uri },
1973}
1974
1975impl ProxyError {
1976 fn status(&self) -> StatusCode {
1978 match self {
1979 Self::Venice(_) => StatusCode::BAD_GATEWAY,
1980 Self::Attestation(error) if error.verifier_unavailable() => {
1981 StatusCode::SERVICE_UNAVAILABLE
1982 }
1983 Self::Attestation(_) => StatusCode::BAD_GATEWAY,
1984 Self::Session(
1985 SessionError::MissingSessionIdentifier | SessionError::InvalidHeaderValue { .. },
1986 ) => StatusCode::BAD_REQUEST,
1987 Self::Session(_) => StatusCode::INTERNAL_SERVER_ERROR,
1988 Self::ChatRequest(_) => StatusCode::BAD_REQUEST,
1989 Self::ChatConstruction(_)
1990 | Self::ChatStream(_)
1991 | Self::ToolCallRetryExhausted { .. } => StatusCode::BAD_GATEWAY,
1992 Self::ProxyInstanceKeyUnavailable | Self::MissingAttestedModelKey => {
1993 StatusCode::INTERNAL_SERVER_ERROR
1994 }
1995 Self::MethodNotAllowed { .. } => StatusCode::METHOD_NOT_ALLOWED,
1996 Self::NotFound { .. } => StatusCode::NOT_FOUND,
1997 }
1998 }
1999
2000 fn error_type(&self) -> &'static str {
2002 match self {
2003 Self::Venice(error) => error.api_error_type(),
2004 Self::Attestation(error) => error.api_error_type(),
2005 Self::Session(
2006 SessionError::MissingSessionIdentifier | SessionError::InvalidHeaderValue { .. },
2007 ) => "invalid_request_error",
2008 Self::Session(_) => "proxy_session_error",
2009 Self::ChatRequest(_) => "invalid_request_error",
2010 Self::ChatConstruction(_) => "proxy_e2ee_error",
2011 Self::ChatStream(error) => error.api_error_type(),
2012 Self::ToolCallRetryExhausted { .. } => "proxy_tool_call_error",
2013 Self::ProxyInstanceKeyUnavailable => "proxy_configuration_error",
2014 Self::MissingAttestedModelKey => "proxy_attestation_error",
2015 Self::MethodNotAllowed { .. } | Self::NotFound { .. } => "invalid_request_error",
2016 }
2017 }
2018
2019 fn code(&self) -> &'static str {
2021 match self {
2022 Self::Venice(error) => error.api_error_code(),
2023 Self::Attestation(error) => error.api_error_code(),
2024 Self::Session(SessionError::MissingSessionIdentifier) => "session_identifier_missing",
2025 Self::Session(SessionError::InvalidHeaderValue { .. }) => "invalid_session_header",
2026 Self::Session(_) => "session_error",
2027 Self::ChatRequest(error) => error.api_error_code(),
2028 Self::ChatConstruction(error) => error.api_error_code(),
2029 Self::ChatStream(error) => error.api_error_code(),
2030 Self::ToolCallRetryExhausted { .. } => "invalid_tool_call",
2031 Self::ProxyInstanceKeyUnavailable => "proxy_instance_key_unavailable",
2032 Self::MissingAttestedModelKey => "attestation_failed",
2033 Self::MethodNotAllowed { .. } => "method_not_allowed",
2034 Self::NotFound { .. } => "not_found",
2035 }
2036 }
2037}
2038
2039impl IntoResponse for ProxyError {
2040 fn into_response(self) -> Response {
2042 let status = self.status();
2043 let error_code = self.code();
2044 let error_type = self.error_type();
2045
2046 if status.is_server_error() {
2047 error!(
2048 status = status.as_u16(),
2049 error_code,
2050 error_type,
2051 error = %self,
2052 "proxy request failed"
2053 );
2054 } else {
2055 warn!(
2056 status = status.as_u16(),
2057 error_code,
2058 error_type,
2059 error = %self,
2060 "proxy request rejected"
2061 );
2062 }
2063
2064 let mut response = if let Self::ToolCallRetryExhausted {
2065 max_retries,
2066 last_validation_error,
2067 } = &self
2068 {
2069 let body = json!({
2070 "error": {
2071 "message": self.to_string(),
2072 "type": error_type,
2073 "code": error_code,
2074 "details": {
2075 "max_retries": max_retries,
2076 "last_validation_error": last_validation_error,
2077 },
2078 }
2079 });
2080 (status, Json(body)).into_response()
2081 } else {
2082 let body = ErrorResponse::new(self.to_string(), error_type, error_code);
2083 (status, Json(body)).into_response()
2084 };
2085
2086 apply_error_headers(response.headers_mut(), error_code);
2087 response
2088 }
2089}
2090
/// Values for the `X-Venice-Proxy-*` metadata response headers.
///
/// A field that is `None` produces no header when the set is applied.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProxyMetadataHeaders {
    /// Value for the `HEADER_PROXY_E2EE` header.
    pub e2ee: Option<String>,
    /// Value for the `HEADER_PROXY_ATTESTATION_MODE` header.
    pub attestation_mode: Option<String>,
    /// Value for the `HEADER_PROXY_ATTESTED_MODEL` header.
    pub attested_model: Option<String>,
    /// Value for the `HEADER_PROXY_TEE_PROVIDER` header.
    pub tee_provider: Option<String>,
    /// Value for the `HEADER_PROXY_TDX_VERIFIED` header (`true` / `false`).
    pub tdx_verified: Option<bool>,
    /// Value for the `HEADER_PROXY_TDX_DEBUG` header (`true` / `false`).
    pub tdx_debug: Option<bool>,
    /// Value for the `HEADER_PROXY_NVIDIA_VERIFIED` header.
    pub nvidia_verified: Option<String>,
    /// Value for the `HEADER_PROXY_KEY_BINDING` header (`true` / `false`).
    pub key_binding: Option<bool>,
    /// Value for the `HEADER_PROXY_SESSION_ID` header.
    pub session_id: Option<String>,
    /// Value for the `HEADER_PROXY_SESSION_SCOPE` header.
    pub session_scope: Option<String>,
    /// Value for the `HEADER_PROXY_TOOL_MODE` header.
    pub tool_mode: Option<String>,
    /// Value for the `HEADER_PROXY_TOOL_RETRIES` header, written as a decimal
    /// number.
    pub tool_retries: Option<u32>,
}
2110
2111impl ProxyMetadataHeaders {
2112 pub fn from_config(config: &ProxyConfig) -> Self {
2115 Self {
2116 attestation_mode: Some(config.attestation.mode.as_str().to_owned()),
2117 tool_mode: Some(config.tools.mode.as_str().to_owned()),
2118 ..Self::default()
2119 }
2120 }
2121
2122 pub fn for_verified_chat(config: &ProxyConfig, session: &SessionContext) -> Self {
2124 let evidence = session
2125 .attestation_report
2126 .as_ref()
2127 .and_then(|report| report.get("attestation"))
2128 .and_then(Value::as_object);
2129 let tee_provider = evidence
2130 .and_then(|evidence| evidence.get("tee_provider"))
2131 .and_then(Value::as_str)
2132 .unwrap_or("unknown")
2133 .to_owned();
2134 let tdx_debug = evidence.and_then(|evidence| {
2135 evidence
2136 .get("debug")
2137 .or_else(|| evidence.get("tdx_debug"))
2138 .and_then(Value::as_bool)
2139 });
2140 let nvidia_payload_present = evidence
2141 .and_then(|evidence| evidence.get("nvidia_payload"))
2142 .is_some_and(|value| !value.is_null());
2143 let nvidia_verified = match (config.attestation.require_nvidia, nvidia_payload_present) {
2144 (_, false) => "not-present",
2145 (NvidiaRequirement::Never, true) => "ignored",
2146 (_, true) => "verified",
2147 }
2148 .to_owned();
2149
2150 Self {
2151 e2ee: Some("verified".to_owned()),
2152 attestation_mode: Some(config.attestation.mode.as_str().to_owned()),
2153 attested_model: Some(session.model_id.clone()),
2154 tee_provider: Some(tee_provider),
2155 tdx_verified: config.attestation.require_tdx.then_some(true),
2156 tdx_debug,
2157 nvidia_verified: Some(nvidia_verified),
2158 key_binding: Some(true),
2159 session_id: Some(session.agent_session_id.clone()),
2160 session_scope: Some(session.scope.as_str().to_owned()),
2161 tool_mode: Some(config.tools.mode.as_str().to_owned()),
2162 tool_retries: None,
2163 }
2164 }
2165
2166 pub fn apply(&self, headers: &mut HeaderMap) {
2168 insert_optional_header(headers, HEADER_PROXY_E2EE, self.e2ee.as_deref());
2169 insert_optional_header(
2170 headers,
2171 HEADER_PROXY_ATTESTATION_MODE,
2172 self.attestation_mode.as_deref(),
2173 );
2174 insert_optional_header(
2175 headers,
2176 HEADER_PROXY_ATTESTED_MODEL,
2177 self.attested_model.as_deref(),
2178 );
2179 insert_optional_header(
2180 headers,
2181 HEADER_PROXY_TEE_PROVIDER,
2182 self.tee_provider.as_deref(),
2183 );
2184 insert_optional_bool_header(headers, HEADER_PROXY_TDX_VERIFIED, self.tdx_verified);
2185 insert_optional_bool_header(headers, HEADER_PROXY_TDX_DEBUG, self.tdx_debug);
2186 insert_optional_header(
2187 headers,
2188 HEADER_PROXY_NVIDIA_VERIFIED,
2189 self.nvidia_verified.as_deref(),
2190 );
2191 insert_optional_bool_header(headers, HEADER_PROXY_KEY_BINDING, self.key_binding);
2192 insert_optional_header(headers, HEADER_PROXY_SESSION_ID, self.session_id.as_deref());
2193 insert_optional_header(
2194 headers,
2195 HEADER_PROXY_SESSION_SCOPE,
2196 self.session_scope.as_deref(),
2197 );
2198 insert_optional_header(headers, HEADER_PROXY_TOOL_MODE, self.tool_mode.as_deref());
2199 if let Some(tool_retries) = self.tool_retries {
2200 insert_header(
2201 headers,
2202 HEADER_PROXY_TOOL_RETRIES,
2203 &tool_retries.to_string(),
2204 );
2205 }
2206 }
2207}
2208
2209pub fn apply_error_headers(headers: &mut HeaderMap, error_code: &str) {
2211 insert_header(headers, HEADER_PROXY_ERROR_CODE, error_code);
2212}
2213
2214fn insert_optional_header(headers: &mut HeaderMap, name: &'static str, value: Option<&str>) {
2216 if let Some(value) = value {
2217 insert_header(headers, name, value);
2218 }
2219}
2220
2221fn insert_optional_bool_header(headers: &mut HeaderMap, name: &'static str, value: Option<bool>) {
2223 if let Some(value) = value {
2224 insert_header(headers, name, if value { "true" } else { "false" });
2225 }
2226}
2227
2228fn insert_header(headers: &mut HeaderMap, name: &'static str, value: &str) {
2230 let Ok(name) = HeaderName::from_bytes(name.as_bytes()) else {
2231 return;
2232 };
2233 let Ok(value) = HeaderValue::from_str(value) else {
2234 return;
2235 };
2236 headers.insert(name, value);
2237}
2238
2239#[cfg(test)]
2240mod tests {
2241 use super::*;
2242 use std::{
2243 collections::{HashMap, VecDeque},
2244 sync::{Arc, Mutex},
2245 time::Duration,
2246 };
2247
2248 use axum::{
2249 body::Body,
2250 extract::Query,
2251 http::Request,
2252 routing::{get, post},
2253 };
2254 use serde_json::json;
2255
2256 use crate::config::NvidiaRequirement;
2257 use tower::ServiceExt;
2258
2259 fn test_app() -> Router {
2260 router_with_venice_client(ProxyConfig::default(), test_venice_client())
2261 }
2262
2263 fn test_venice_client() -> VeniceClient {
2264 test_venice_client_for_base_url("http://127.0.0.1:1/api/v1")
2265 }
2266
2267 fn test_venice_client_for_base_url(base_url: impl AsRef<str>) -> VeniceClient {
2268 VeniceClient::new(base_url.as_ref(), "test-api-key", Duration::from_secs(1))
2269 .expect("test Venice client should build")
2270 }
2271
2272 fn chat_config_with_basic_test_attestation() -> ProxyConfig {
2273 let mut config = ProxyConfig::default();
2274 config.attestation.require_tdx = false;
2275 config.attestation.require_nvidia = NvidiaRequirement::Never;
2276 config
2277 }
2278
/// Checks what `AppState::from_parts` sets up: a startup key (unless disabled
/// in the config), an empty session manager, and a verifier using the
/// configured attestation policy.
#[test]
fn app_state_initializes_key_and_session_managers_from_config() {
    let default_state = AppState::from_parts(ProxyConfig::default(), test_venice_client());

    let startup_key = default_state
        .proxy_instance_key()
        .expect("default config should generate startup key");
    assert_eq!(startup_key.public_key_hex().len(), 130);
    assert!(default_state.session_manager().is_empty());

    let expected_policy = ProxyConfig::default().attestation;
    assert_eq!(
        default_state.attestation_verifier().policy(),
        &expected_policy
    );

    // With key generation switched off no instance key is available.
    let mut keyless_config = ProxyConfig::default();
    keyless_config.keys.generate_proxy_instance_key_on_startup = false;
    let keyless_state = AppState::from_parts(keyless_config, test_venice_client());
    assert!(keyless_state.proxy_instance_key().is_none());
}
2298
2299 async fn error_body(response: Response) -> ErrorResponse {
2300 let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
2301 .await
2302 .expect("response body should buffer");
2303 serde_json::from_slice(&bytes).expect("response should be OpenAI-style error JSON")
2304 }
2305
2306 #[tokio::test]
2307 async fn chat_route_ignores_upstream_role_only_chunk_before_encrypted_content() {
2308 let response = streaming_chat_response(
2309 "chat-route-role-only",
2310 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
2311 vec![
2312 MockStreamFrame::Role,
2313 MockStreamFrame::Text("Hello"),
2314 MockStreamFrame::Finish("stop"),
2315 MockStreamFrame::Done,
2316 ],
2317 )
2318 .await;
2319
2320 assert_eq!(response.status(), StatusCode::OK);
2321 let body = response_body(response).await;
2322 let data = sse_data(&body);
2323 assert_eq!(data.len(), 3);
2324 let first: Value = serde_json::from_str(data[0]).expect("first chunk should be JSON");
2325 assert_eq!(first["choices"][0]["delta"]["role"], "assistant");
2326 assert_eq!(first["choices"][0]["delta"]["content"], "Hello");
2327 assert_eq!(data[2], "[DONE]");
2328 }
2329
2330 #[tokio::test]
2331 async fn chat_route_streams_decrypted_normal_assistant_text() {
2332 let response = streaming_chat_response(
2333 "chat-route-test",
2334 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
2335 vec![
2336 MockStreamFrame::NullContent,
2337 MockStreamFrame::EmptyContent,
2338 MockStreamFrame::Text("Hello"),
2339 MockStreamFrame::Finish("stop"),
2340 MockStreamFrame::Done,
2341 ],
2342 )
2343 .await;
2344
2345 assert_eq!(response.status(), StatusCode::OK);
2346 assert_eq!(
2347 response.headers().get(HEADER_PROXY_E2EE).unwrap(),
2348 "verified"
2349 );
2350 assert_eq!(
2351 response.headers().get(HEADER_PROXY_ATTESTED_MODEL).unwrap(),
2352 "e2ee-test"
2353 );
2354
2355 let body = response_body(response).await;
2356 let data = sse_data(&body);
2357 assert_eq!(data.len(), 3);
2358
2359 let first: Value = serde_json::from_str(data[0]).expect("first chunk should be JSON");
2360 assert_eq!(first["object"], "chat.completion.chunk");
2361 assert_eq!(first["model"], "e2ee-test");
2362 assert_eq!(first["choices"][0]["delta"]["role"], "assistant");
2363 assert_eq!(first["choices"][0]["delta"]["content"], "Hello");
2364 assert!(first["choices"][0]["finish_reason"].is_null());
2365
2366 let final_chunk: Value = serde_json::from_str(data[1]).expect("final chunk should be JSON");
2367 assert_eq!(final_chunk["choices"][0]["delta"], json!({}));
2368 assert_eq!(final_chunk["choices"][0]["finish_reason"], "stop");
2369 assert_eq!(data[2], "[DONE]");
2370 }
2371
2372 #[tokio::test]
2373 async fn chat_route_streams_decrypted_reasoning_content() {
2374 let response = streaming_chat_response(
2375 "chat-route-reasoning-stream",
2376 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true,"reasoning":{"effort":"high"}}"#,
2377 vec![
2378 MockStreamFrame::Reasoning("Thinking"),
2379 MockStreamFrame::Text("Answer"),
2380 MockStreamFrame::Finish("stop"),
2381 MockStreamFrame::Done,
2382 ],
2383 )
2384 .await;
2385
2386 assert_eq!(response.status(), StatusCode::OK);
2387 let body = response_body(response).await;
2388 let data = sse_data(&body);
2389 assert_eq!(data.len(), 4);
2390 let reasoning: Value =
2391 serde_json::from_str(data[0]).expect("reasoning chunk should be JSON");
2392 let answer: Value = serde_json::from_str(data[1]).expect("answer chunk should be JSON");
2393
2394 assert_eq!(reasoning["choices"][0]["delta"]["role"], "assistant");
2395 assert_eq!(
2396 reasoning["choices"][0]["delta"]["reasoning_content"],
2397 "Thinking"
2398 );
2399 assert!(answer["choices"][0]["delta"].get("role").is_none());
2400 assert_eq!(answer["choices"][0]["delta"]["content"], "Answer");
2401 assert_eq!(data.last().copied(), Some("[DONE]"));
2402 }
2403
2404 #[tokio::test]
2405 async fn chat_route_streams_multiple_decrypted_content_chunks() {
2406 let response = streaming_chat_response(
2407 "chat-route-multiple-chunks",
2408 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
2409 vec![
2410 MockStreamFrame::Text("Hello"),
2411 MockStreamFrame::Text(" world"),
2412 MockStreamFrame::Finish("stop"),
2413 MockStreamFrame::Done,
2414 ],
2415 )
2416 .await;
2417
2418 assert_eq!(response.status(), StatusCode::OK);
2419 let body = response_body(response).await;
2420 let data = sse_data(&body);
2421 let first: Value = serde_json::from_str(data[0]).expect("first chunk should be JSON");
2422 let second: Value = serde_json::from_str(data[1]).expect("second chunk should be JSON");
2423
2424 assert_eq!(first["choices"][0]["delta"]["role"], "assistant");
2425 assert_eq!(first["choices"][0]["delta"]["content"], "Hello");
2426 assert!(second["choices"][0]["delta"].get("role").is_none());
2427 assert_eq!(second["choices"][0]["delta"]["content"], " world");
2428 assert_eq!(data.last().copied(), Some("[DONE]"));
2429 }
2430
2431 #[tokio::test]
2432 async fn chat_route_passes_through_usage_chunk_when_requested_and_upstream_provides_it() {
2433 let response = streaming_chat_response(
2434 "chat-route-usage",
2435 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true,"stream_options":{"include_usage":true}}"#,
2436 vec![
2437 MockStreamFrame::Text("Hello"),
2438 MockStreamFrame::Finish("stop"),
2439 MockStreamFrame::Usage,
2440 MockStreamFrame::Done,
2441 ],
2442 )
2443 .await;
2444
2445 assert_eq!(response.status(), StatusCode::OK);
2446 let body = response_body(response).await;
2447 let data = sse_data(&body);
2448 assert_eq!(data.len(), 4);
2449 let usage_chunk: Value = serde_json::from_str(data[2]).expect("usage chunk should be JSON");
2450 assert_eq!(usage_chunk["choices"], json!([]));
2451 assert_eq!(usage_chunk["usage"]["total_tokens"], 3);
2452 assert_eq!(data[3], "[DONE]");
2453 }
2454
2455 #[tokio::test]
2456 async fn chat_route_returns_buffered_non_streaming_completion() {
2457 let response = chat_response(
2458 "chat-route-non-streaming-success",
2459 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#,
2460 vec![
2461 MockStreamFrame::NullContent,
2462 MockStreamFrame::EmptyContent,
2463 MockStreamFrame::Text("Hello"),
2464 MockStreamFrame::Text(" world"),
2465 MockStreamFrame::Finish("stop"),
2466 MockStreamFrame::Done,
2467 ],
2468 )
2469 .await;
2470
2471 assert_eq!(response.status(), StatusCode::OK);
2472 assert_eq!(
2473 response.headers().get(HEADER_PROXY_E2EE).unwrap(),
2474 "verified"
2475 );
2476 let body = json_body(response).await;
2477 assert_eq!(body["object"], "chat.completion");
2478 assert_eq!(body["id"], "chatcmpl-upstream-test");
2479 assert_eq!(body["created"], 1_717_171_717);
2480 assert_eq!(body["model"], "e2ee-test");
2481 assert_eq!(body["choices"][0]["index"], 0);
2482 assert_eq!(body["choices"][0]["message"]["role"], "assistant");
2483 assert_eq!(body["choices"][0]["message"]["content"], "Hello world");
2484 assert_eq!(body["choices"][0]["finish_reason"], "stop");
2485 assert!(body["usage"].is_null());
2486 }
2487
2488 #[tokio::test]
2489 async fn chat_route_returns_buffered_reasoning_content() {
2490 let response = chat_response(
2491 "chat-route-reasoning-non-streaming",
2492 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false,"reasoning_effort":"medium"}"#,
2493 vec![
2494 MockStreamFrame::Reasoning("Think "),
2495 MockStreamFrame::Reasoning("first."),
2496 MockStreamFrame::Text("Answer"),
2497 MockStreamFrame::Finish("stop"),
2498 MockStreamFrame::Done,
2499 ],
2500 )
2501 .await;
2502
2503 assert_eq!(response.status(), StatusCode::OK);
2504 let body = json_body(response).await;
2505 assert_eq!(
2506 body["choices"][0]["message"]["reasoning_content"],
2507 "Think first."
2508 );
2509 assert_eq!(body["choices"][0]["message"]["content"], "Answer");
2510 }
2511
2512 #[tokio::test]
2513 async fn chat_route_treats_omitted_stream_as_buffered_non_streaming() {
2514 let response = chat_response(
2515 "chat-route-omitted-stream",
2516 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}]}"#,
2517 vec![MockStreamFrame::Text("Hello"), MockStreamFrame::Done],
2518 )
2519 .await;
2520
2521 assert_eq!(response.status(), StatusCode::OK);
2522 let body = json_body(response).await;
2523 assert_eq!(body["object"], "chat.completion");
2524 assert_eq!(body["choices"][0]["message"]["content"], "Hello");
2525 assert_eq!(body["choices"][0]["finish_reason"], "stop");
2526 }
2527
2528 #[tokio::test]
2529 async fn chat_route_streams_incremental_tool_call_chunks() {
2530 let response = streaming_chat_response(
2531 "chat-route-tool-stream",
2532 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2533 vec![
2534 MockStreamFrame::Text("<tool_call>\n{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}\n</tool_call>"),
2535 MockStreamFrame::Finish("stop"),
2536 MockStreamFrame::Done,
2537 ],
2538 )
2539 .await;
2540
2541 assert_eq!(response.status(), StatusCode::OK);
2542 let body = response_body(response).await;
2543 let chunks = sse_json_chunks(&body);
2544
2545 assert_eq!(chunks[0]["choices"][0]["delta"]["role"], "assistant");
2546
2547 let tool_calls = streamed_tool_call_deltas(&chunks);
2548 assert!(!tool_calls.is_empty());
2549 let first = tool_calls[0];
2550 assert_eq!(first["index"], 0);
2551 assert!(first["id"].as_str().unwrap().starts_with("call_"));
2552 assert_eq!(first["type"], "function");
2553 assert_eq!(first["function"]["name"], "search_web");
2554 for later in &tool_calls[1..] {
2555 assert!(later.get("id").is_none());
2556 assert!(later.get("type").is_none());
2557 assert!(later["function"].get("name").is_none());
2558 }
2559 assert_eq!(
2560 streamed_tool_call_arguments(&chunks, 0),
2561 r#"{"query":"example"}"#
2562 );
2563
2564 let final_chunk = chunks.last().expect("stream should have chunks");
2565 assert_eq!(final_chunk["choices"][0]["delta"], json!({}));
2566 assert_eq!(final_chunk["choices"][0]["finish_reason"], "tool_calls");
2567 }
2568
2569 #[tokio::test]
2570 async fn chat_route_streams_text_then_incremental_tool_call() {
2571 let response = streaming_chat_response(
2572 "chat-route-tool-stream-mixed-text",
2573 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2574 vec![
2575 MockStreamFrame::NullContent,
2576 MockStreamFrame::EmptyContent,
2577 MockStreamFrame::Text("I'll check that. "),
2578 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}"),
2579 MockStreamFrame::Text("</tool_call>"),
2580 MockStreamFrame::Finish("stop"),
2581 MockStreamFrame::Done,
2582 ],
2583 )
2584 .await;
2585
2586 assert_eq!(response.status(), StatusCode::OK);
2587 let body = response_body(response).await;
2588 let chunks = sse_json_chunks(&body);
2589
2590 assert_eq!(chunks[0]["choices"][0]["delta"]["role"], "assistant");
2591 assert_eq!(streamed_content(&chunks), "I'll check that. ");
2592
2593 let tool_calls = streamed_tool_call_deltas(&chunks);
2594 assert!(!tool_calls.is_empty());
2595 assert_eq!(tool_calls[0]["function"]["name"], "search_web");
2596 assert_eq!(
2597 streamed_tool_call_arguments(&chunks, 0),
2598 r#"{"query":"example"}"#
2599 );
2600
2601 let final_chunk = chunks.last().expect("stream should have chunks");
2602 assert_eq!(final_chunk["choices"][0]["finish_reason"], "tool_calls");
2603 }
2604
2605 #[tokio::test]
2606 async fn chat_route_fails_closed_on_unterminated_streamed_tool_call() {
2607 let response = streaming_chat_response(
2610 "chat-route-tool-stream-missing-close",
2611 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2612 vec![
2613 MockStreamFrame::Text("I'll check that. "),
2614 MockStreamFrame::Text("<tool_call>{\"name\":"),
2615 MockStreamFrame::Finish("stop"),
2616 MockStreamFrame::Done,
2617 ],
2618 )
2619 .await;
2620
2621 assert_stream_body_fails(response).await;
2622 }
2623
2624 #[tokio::test]
2625 async fn chat_route_streams_hermes_format_tool_call_from_glm_model() {
2626 let response = streaming_chat_response(
2629 "chat-route-tool-stream-glm-hermes",
2630 r#"{"model":"e2ee-glm-5-1","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2631 vec![
2632 MockStreamFrame::Text("<tool_call>\n{\"name\":\"search_web\",\"arguments\":"),
2633 MockStreamFrame::Text("{\"query\":\"example\"}}\n</tool_call>"),
2634 MockStreamFrame::Finish("stop"),
2635 MockStreamFrame::Done,
2636 ],
2637 )
2638 .await;
2639
2640 assert_eq!(response.status(), StatusCode::OK);
2641 let body = response_body(response).await;
2642 let chunks = sse_json_chunks(&body);
2643
2644 let tool_calls = streamed_tool_call_deltas(&chunks);
2645 assert!(!tool_calls.is_empty());
2646 assert_eq!(tool_calls[0]["function"]["name"], "search_web");
2647 assert_eq!(
2648 streamed_tool_call_arguments(&chunks, 0),
2649 r#"{"query":"example"}"#
2650 );
2651
2652 let final_chunk = chunks.last().expect("stream should have chunks");
2653 assert_eq!(final_chunk["choices"][0]["finish_reason"], "tool_calls");
2654 }
2655
2656 #[tokio::test]
2657 async fn chat_route_recovers_streamed_tool_call_with_truncated_closing_marker() {
2658 let response = streaming_chat_response(
2661 "chat-route-tool-stream-truncated-close",
2662 r#"{"model":"e2ee-glm-4-7-flash-p","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2663 vec![
2664 MockStreamFrame::Text("<tool_call>\n{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}\n"),
2665 MockStreamFrame::Finish("stop"),
2666 MockStreamFrame::Done,
2667 ],
2668 )
2669 .await;
2670
2671 assert_eq!(response.status(), StatusCode::OK);
2672 let body = response_body(response).await;
2673 let chunks = sse_json_chunks(&body);
2674
2675 let tool_calls = streamed_tool_call_deltas(&chunks);
2676 assert!(!tool_calls.is_empty());
2677 assert_eq!(tool_calls[0]["function"]["name"], "search_web");
2678 assert_eq!(
2679 streamed_tool_call_arguments(&chunks, 0),
2680 r#"{"query":"example"}"#
2681 );
2682
2683 let final_chunk = chunks.last().expect("stream should have chunks");
2684 assert_eq!(final_chunk["choices"][0]["finish_reason"], "tool_calls");
2685 }
2686
2687 #[tokio::test]
2688 async fn chat_route_streams_multiple_tool_calls_split_across_chunks() {
2689 let response = streaming_chat_response(
2690 "chat-route-tool-stream-multiple-calls",
2691 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2692 vec![
2693 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"first\"}}"),
2694 MockStreamFrame::Text("</tool_call><tool_call>{\"name\":\"search_web\",\"arguments\":"),
2695 MockStreamFrame::Text("{\"query\":\"second\"}}</tool_call>"),
2696 MockStreamFrame::Finish("stop"),
2697 MockStreamFrame::Done,
2698 ],
2699 )
2700 .await;
2701
2702 assert_eq!(response.status(), StatusCode::OK);
2703 let body = response_body(response).await;
2704 let chunks = sse_json_chunks(&body);
2705
2706 assert_eq!(chunks[0]["choices"][0]["delta"]["role"], "assistant");
2707 let tool_calls = streamed_tool_call_deltas(&chunks);
2708 let first = tool_calls
2709 .iter()
2710 .find(|tool_call| tool_call["index"] == 0 && tool_call.get("id").is_some())
2711 .expect("first call should have an id-bearing fragment");
2712 let second = tool_calls
2713 .iter()
2714 .find(|tool_call| tool_call["index"] == 1 && tool_call.get("id").is_some())
2715 .expect("second call should have an id-bearing fragment");
2716 assert_eq!(first["function"]["name"], "search_web");
2717 assert_eq!(second["function"]["name"], "search_web");
2718 assert_ne!(first["id"], second["id"]);
2719 assert_eq!(
2720 streamed_tool_call_arguments(&chunks, 0),
2721 r#"{"query":"first"}"#
2722 );
2723 assert_eq!(
2724 streamed_tool_call_arguments(&chunks, 1),
2725 r#"{"query":"second"}"#
2726 );
2727
2728 let final_chunk = chunks.last().expect("stream should have chunks");
2729 assert_eq!(final_chunk["choices"][0]["delta"], json!({}));
2730 assert_eq!(final_chunk["choices"][0]["finish_reason"], "tool_calls");
2731 }
2732
2733 #[tokio::test]
2734 async fn chat_route_tool_stream_passes_through_usage_chunk_when_requested() {
2735 let response = streaming_chat_response(
2736 "chat-route-tool-stream-usage",
2737 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"stream_options":{"include_usage":true},"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2738 vec![
2739 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}</tool_call>"),
2740 MockStreamFrame::Finish("stop"),
2741 MockStreamFrame::Usage,
2742 MockStreamFrame::Done,
2743 ],
2744 )
2745 .await;
2746
2747 assert_eq!(response.status(), StatusCode::OK);
2748 let body = response_body(response).await;
2749 let chunks = sse_json_chunks(&body);
2750
2751 let usage_chunk = chunks.last().expect("stream should have chunks");
2753 assert_eq!(usage_chunk["choices"], json!([]));
2754 assert_eq!(usage_chunk["usage"]["total_tokens"], 3);
2755 let finish_chunk = &chunks[chunks.len() - 2];
2756 assert_eq!(finish_chunk["choices"][0]["finish_reason"], "tool_calls");
2757 }
2758
2759 #[tokio::test]
2760 async fn chat_route_fails_closed_when_streamed_tool_call_exceeds_max_bytes() {
2761 let model_public_key = ProxyInstanceKey::generate().public_key_hex().to_owned();
2762 let base_url = spawn_streaming_venice_server(
2763 model_public_key,
2764 true,
2765 vec![
2766 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"this argument body is much longer than the configured cap\"}}</tool_call>"),
2767 MockStreamFrame::Finish("stop"),
2768 MockStreamFrame::Done,
2769 ],
2770 )
2771 .await;
2772 let mut config = chat_config_with_basic_test_attestation();
2773 config.tools.tool_call_max_bytes = 16;
2774
2775 let response = request_chat_with_config(
2776 config,
2777 "chat-route-tool-stream-max-bytes",
2778 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2779 base_url,
2780 )
2781 .await;
2782
2783 assert_stream_body_fails(response).await;
2784 }
2785
2786 #[tokio::test]
2787 async fn chat_route_streams_all_tool_calls_when_parallel_tool_calls_false() {
2788 let response = streaming_chat_response(
2791 "chat-route-tool-stream-parallel-false",
2792 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"parallel_tool_calls":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2793 vec![
2794 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"first\"}}</tool_call>"),
2795 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"second\"}}</tool_call>"),
2796 MockStreamFrame::Finish("stop"),
2797 MockStreamFrame::Done,
2798 ],
2799 )
2800 .await;
2801
2802 assert_eq!(response.status(), StatusCode::OK);
2803 let body = response_body(response).await;
2804 let chunks = sse_json_chunks(&body);
2805
2806 assert_eq!(
2807 streamed_tool_call_arguments(&chunks, 0),
2808 r#"{"query":"first"}"#
2809 );
2810 assert_eq!(
2811 streamed_tool_call_arguments(&chunks, 1),
2812 r#"{"query":"second"}"#
2813 );
2814
2815 let final_chunk = chunks.last().expect("stream should have chunks");
2816 assert_eq!(final_chunk["choices"][0]["finish_reason"], "tool_calls");
2817 }
2818
2819 #[tokio::test]
2820 async fn chat_route_returns_non_streaming_tool_call_body_from_mixed_text() {
2821 let response = chat_response(
2822 "chat-route-tool-non-stream-mixed-text",
2823 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}}]}"#,
2824 vec![
2825 MockStreamFrame::Text("I'll check that. <tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}</tool_call>"),
2826 MockStreamFrame::Done,
2827 ],
2828 )
2829 .await;
2830
2831 assert_eq!(response.status(), StatusCode::OK);
2832 let body = json_body(response).await;
2833 assert_eq!(body["choices"][0]["finish_reason"], "tool_calls");
2834 let tool_call = &body["choices"][0]["message"]["tool_calls"][0];
2835 assert_eq!(tool_call["function"]["name"], "search_web");
2836 assert_eq!(tool_call["function"]["arguments"], r#"{"query":"example"}"#);
2837 }
2838
2839 #[tokio::test]
2840 async fn chat_route_returns_non_streaming_tool_call_body() {
2841 let response = chat_response(
2842 "chat-route-tool-non-stream",
2843 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}}]}"#,
2844 vec![
2845 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}</tool_call>"),
2846 MockStreamFrame::Done,
2847 ],
2848 )
2849 .await;
2850
2851 assert_eq!(response.status(), StatusCode::OK);
2852 let body = json_body(response).await;
2853 assert_eq!(body["object"], "chat.completion");
2854 assert!(body["choices"][0]["message"]["content"].is_null());
2855 assert_eq!(body["choices"][0]["finish_reason"], "tool_calls");
2856 let tool_call = &body["choices"][0]["message"]["tool_calls"][0];
2857 assert!(tool_call["id"].as_str().unwrap().starts_with("call_"));
2858 assert_eq!(tool_call["type"], "function");
2859 assert_eq!(tool_call["function"]["name"], "search_web");
2860 assert_eq!(tool_call["function"]["arguments"], r#"{"query":"example"}"#);
2861 }
2862
2863 #[tokio::test]
2864 async fn chat_route_returns_non_streaming_multiple_tool_calls() {
2865 let response = chat_response(
2866 "chat-route-tool-non-stream-multiple-calls",
2867 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}}]}"#,
2868 vec![
2869 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"first\"}}</tool_call>\n<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"second\"}}</tool_call>"),
2870 MockStreamFrame::Done,
2871 ],
2872 )
2873 .await;
2874
2875 assert_eq!(response.status(), StatusCode::OK);
2876 let body = json_body(response).await;
2877 assert_eq!(body["choices"][0]["finish_reason"], "tool_calls");
2878 assert!(body["choices"][0]["message"]["content"].is_null());
2879 let tool_calls = body["choices"][0]["message"]["tool_calls"]
2880 .as_array()
2881 .expect("tool_calls should be an array");
2882 assert_eq!(tool_calls.len(), 2);
2883 assert_eq!(tool_calls[0]["function"]["name"], "search_web");
2884 assert_eq!(
2885 tool_calls[0]["function"]["arguments"],
2886 r#"{"query":"first"}"#
2887 );
2888 assert_eq!(
2889 tool_calls[1]["function"]["arguments"],
2890 r#"{"query":"second"}"#
2891 );
2892 assert_ne!(tool_calls[0]["id"], tool_calls[1]["id"]);
2893 }
2894
2895 #[tokio::test]
2896 async fn chat_route_tool_mode_leaves_normal_text_unaffected() {
2897 let response = streaming_chat_response(
2898 "chat-route-tool-normal-text",
2899 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object"}}}]}"#,
2900 vec![
2901 MockStreamFrame::Text("Hello without tools"),
2902 MockStreamFrame::Finish("stop"),
2903 MockStreamFrame::Done,
2904 ],
2905 )
2906 .await;
2907
2908 assert_eq!(response.status(), StatusCode::OK);
2909 let body = response_body(response).await;
2910 let chunks = sse_json_chunks(&body);
2911 assert_eq!(chunks[0]["choices"][0]["delta"]["role"], "assistant");
2912 assert_eq!(streamed_content(&chunks), "Hello without tools");
2913 assert!(streamed_tool_call_deltas(&chunks).is_empty());
2914 }
2915
2916 #[tokio::test]
2917 async fn chat_route_treats_marker_like_non_protocol_text_as_normal_text() {
2918 let response = streaming_chat_response(
2919 "chat-route-tool-marker-like-text",
2920 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object"}}}]}"#,
2921 vec![
2922 MockStreamFrame::Text("<tool_cal>{not actually a marker}"),
2923 MockStreamFrame::Finish("stop"),
2924 MockStreamFrame::Done,
2925 ],
2926 )
2927 .await;
2928
2929 assert_eq!(response.status(), StatusCode::OK);
2930 let body = response_body(response).await;
2931 let chunks = sse_json_chunks(&body);
2932 assert_eq!(
2933 streamed_content(&chunks),
2934 "<tool_cal>{not actually a marker}"
2935 );
2936 assert!(streamed_tool_call_deltas(&chunks).is_empty());
2937 }
2938
2939 #[tokio::test]
2940 async fn chat_route_retries_invalid_tool_call_and_returns_success() {
2941 let response = chat_response_sequence(
2942 "chat-route-tool-retry-success",
2943 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}}]}"#,
2944 vec![
2945 vec![
2946 MockStreamFrame::Text("<tool_call>{\"name\":\"unknown\",\"arguments\":{\"query\":\"example\"}}</tool_call>"),
2947 MockStreamFrame::Done,
2948 ],
2949 vec![
2950 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}</tool_call>"),
2951 MockStreamFrame::Done,
2952 ],
2953 ],
2954 )
2955 .await;
2956
2957 assert_eq!(response.status(), StatusCode::OK);
2958 assert_eq!(
2959 response.headers().get(HEADER_PROXY_TOOL_RETRIES).unwrap(),
2960 "1"
2961 );
2962 let body = json_body(response).await;
2963 assert_eq!(body["choices"][0]["finish_reason"], "tool_calls");
2964 assert_eq!(
2965 body["choices"][0]["message"]["tool_calls"][0]["function"]["name"],
2966 "search_web"
2967 );
2968 }
2969
2970 #[tokio::test]
2971 async fn chat_route_returns_retry_failure_error_shape() {
2972 let response = chat_response(
2973 "chat-route-tool-retry-failure",
2974 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}}]}"#,
2975 vec![
2976 MockStreamFrame::Text("<tool_call>{\"name\":\"unknown\",\"arguments\":{}}</tool_call>"),
2977 MockStreamFrame::Done,
2978 ],
2979 )
2980 .await;
2981
2982 assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
2983 assert_eq!(
2984 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
2985 "invalid_tool_call"
2986 );
2987 let body = json_body(response).await;
2988 assert_eq!(body["error"]["type"], "proxy_tool_call_error");
2989 assert_eq!(body["error"]["code"], "invalid_tool_call");
2990 assert_eq!(body["error"]["details"]["max_retries"], 2);
2991 assert!(
2992 body["error"]["details"]["last_validation_error"]
2993 .as_str()
2994 .unwrap()
2995 .contains("unknown tool name")
2996 );
2997 }
2998
2999 #[tokio::test]
3000 async fn chat_route_non_streaming_fails_closed_on_upstream_error_response() {
3001 let response = chat_response_with_upstream_status(
3002 "chat-route-non-streaming-upstream-error",
3003 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#,
3004 StatusCode::INTERNAL_SERVER_ERROR,
3005 )
3006 .await;
3007
3008 assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
3009 assert_eq!(
3010 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
3011 "upstream_status_error"
3012 );
3013 let body = error_body(response).await;
3014 assert_eq!(body.error.kind, "proxy_upstream_error");
3015 assert_eq!(body.error.code, "upstream_status_error");
3016 }
3017
3018 #[tokio::test]
3019 async fn chat_route_non_streaming_fails_closed_on_malformed_upstream_payload() {
3020 let response = chat_response(
3021 "chat-route-non-streaming-malformed",
3022 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#,
3023 vec![MockStreamFrame::Raw("data: {\"choices\":\"bad\"}\n\n")],
3024 )
3025 .await;
3026
3027 assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
3028 assert_eq!(
3029 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
3030 "upstream_malformed_response"
3031 );
3032 let body = error_body(response).await;
3033 assert_eq!(body.error.kind, "proxy_upstream_error");
3034 assert_eq!(body.error.code, "upstream_malformed_response");
3035 }
3036
3037 #[tokio::test]
3038 async fn chat_route_non_streaming_fails_closed_on_missing_encrypted_content() {
3039 let response = chat_response(
3040 "chat-route-non-streaming-missing-content",
3041 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#,
3042 vec![MockStreamFrame::Finish("stop"), MockStreamFrame::Done],
3043 )
3044 .await;
3045
3046 assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
3047 assert_eq!(
3048 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
3049 "e2ee_response_decryption_failed"
3050 );
3051 let body = error_body(response).await;
3052 assert_eq!(body.error.kind, "proxy_e2ee_error");
3053 assert_eq!(body.error.code, "e2ee_response_decryption_failed");
3054 }
3055
3056 #[tokio::test]
3057 async fn chat_route_non_streaming_fails_closed_on_decryption_failure() {
3058 let response = chat_response(
3059 "chat-route-non-streaming-decryption-failure",
3060 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#,
3061 vec![MockStreamFrame::TextForWrongRecipient(" secret"), MockStreamFrame::Done],
3062 )
3063 .await;
3064
3065 assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
3066 assert_eq!(
3067 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
3068 "e2ee_response_decryption_failed"
3069 );
3070 let body = error_body(response).await;
3071 assert_eq!(body.error.kind, "proxy_e2ee_error");
3072 assert_eq!(body.error.code, "e2ee_response_decryption_failed");
3073 }
3074
3075 #[tokio::test]
3076 async fn chat_route_non_streaming_passes_through_usage_when_available() {
3077 let response = chat_response(
3078 "chat-route-non-streaming-usage",
3079 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#,
3080 vec![
3081 MockStreamFrame::Text("Hello"),
3082 MockStreamFrame::Finish("stop"),
3083 MockStreamFrame::Usage,
3084 MockStreamFrame::Done,
3085 ],
3086 )
3087 .await;
3088
3089 assert_eq!(response.status(), StatusCode::OK);
3090 let body = json_body(response).await;
3091 assert_eq!(body["choices"][0]["message"]["content"], "Hello");
3092 assert_eq!(body["usage"]["prompt_tokens"], 1);
3093 assert_eq!(body["usage"]["completion_tokens"], 2);
3094 assert_eq!(body["usage"]["total_tokens"], 3);
3095 }
3096
3097 #[tokio::test]
3098 async fn chat_route_fails_closed_on_upstream_stream_error_event() {
3099 let response = streaming_chat_response(
3100 "chat-route-upstream-error",
3101 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
3102 vec![MockStreamFrame::Error("model failed")],
3103 )
3104 .await;
3105
3106 assert_stream_body_fails(response).await;
3107 }
3108
3109 #[tokio::test]
3110 async fn chat_route_fails_closed_on_malformed_upstream_event() {
3111 let response = streaming_chat_response(
3112 "chat-route-malformed-event",
3113 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
3114 vec![MockStreamFrame::Raw("data: {\"choices\":\n\n")],
3115 )
3116 .await;
3117
3118 assert_stream_body_fails(response).await;
3119 }
3120
3121 #[tokio::test]
3122 async fn chat_route_fails_closed_on_decryption_failure_mid_stream() {
3123 let response = streaming_chat_response(
3124 "chat-route-decryption-failure",
3125 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
3126 vec![
3127 MockStreamFrame::Text("Hello"),
3128 MockStreamFrame::TextForWrongRecipient(" secret"),
3129 MockStreamFrame::Done,
3130 ],
3131 )
3132 .await;
3133
3134 assert_stream_body_fails(response).await;
3135 }
3136
3137 #[tokio::test]
3138 async fn chat_route_synthesizes_final_finish_chunk_before_done_when_needed() {
3139 let response = streaming_chat_response(
3140 "chat-route-final-done",
3141 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
3142 vec![MockStreamFrame::Text("Hello"), MockStreamFrame::Done],
3143 )
3144 .await;
3145
3146 assert_eq!(response.status(), StatusCode::OK);
3147 let body = response_body(response).await;
3148 let data = sse_data(&body);
3149 assert_eq!(data.len(), 3);
3150 let final_chunk: Value = serde_json::from_str(data[1]).expect("final chunk should be JSON");
3151 assert_eq!(final_chunk["choices"][0]["delta"], json!({}));
3152 assert_eq!(final_chunk["choices"][0]["finish_reason"], "stop");
3153 assert_eq!(data[2], "[DONE]");
3154 }
3155
3156 #[tokio::test]
3157 async fn chat_route_attestation_failure_prevents_request_construction() {
3158 let model_public_key = ProxyInstanceKey::generate().public_key_hex().to_owned();
3159 let base_url = spawn_attestation_server(model_public_key, false).await;
3160 let app = router_with_venice_client(
3161 chat_config_with_basic_test_attestation(),
3162 test_venice_client_for_base_url(base_url),
3163 );
3164
3165 let response = app
3166 .oneshot(
3167 Request::builder()
3168 .method(Method::POST)
3169 .uri("/v1/chat/completions")
3170 .header("content-type", "application/json")
3171 .header(HEADER_PROXY_SESSION_ID, "chat-route-attestation-failure")
3172 .body(Body::from(
3173 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#,
3174 ))
3175 .expect("request should build"),
3176 )
3177 .await
3178 .expect("request should complete");
3179
3180 assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
3181 assert_eq!(
3182 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
3183 "attestation_upstream_not_verified"
3184 );
3185 let body = error_body(response).await;
3186 assert_eq!(body.error.kind, "proxy_attestation_error");
3187 assert_eq!(body.error.code, "attestation_upstream_not_verified");
3188 }
3189
3190 #[tokio::test]
3191 async fn unknown_route_returns_openai_style_not_found() {
3192 let response = test_app()
3193 .oneshot(
3194 Request::builder()
3195 .uri("/v1/unknown")
3196 .body(Body::empty())
3197 .expect("request should build"),
3198 )
3199 .await
3200 .expect("request should complete");
3201
3202 assert_eq!(response.status(), StatusCode::NOT_FOUND);
3203 assert_eq!(
3204 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
3205 "not_found"
3206 );
3207 let body = error_body(response).await;
3208 assert_eq!(body.error.kind, "invalid_request_error");
3209 assert_eq!(body.error.code, "not_found");
3210 }
3211
3212 #[tokio::test]
3213 async fn unsupported_method_returns_openai_style_method_error() {
3214 let response = test_app()
3215 .oneshot(
3216 Request::builder()
3217 .method(Method::POST)
3218 .uri("/v1/models")
3219 .body(Body::empty())
3220 .expect("request should build"),
3221 )
3222 .await
3223 .expect("request should complete");
3224
3225 assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
3226 assert_eq!(
3227 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
3228 "method_not_allowed"
3229 );
3230 let body = error_body(response).await;
3231 assert_eq!(body.error.kind, "invalid_request_error");
3232 assert_eq!(body.error.code, "method_not_allowed");
3233 }
3234
3235 #[tokio::test]
3236 async fn malformed_chat_json_uses_axum_extractor_rejection() {
3237 let response = test_app()
3238 .oneshot(
3239 Request::builder()
3240 .method(Method::POST)
3241 .uri("/v1/chat/completions")
3242 .header("content-type", "application/json")
3243 .body(Body::from("{"))
3244 .expect("request should build"),
3245 )
3246 .await
3247 .expect("request should complete");
3248
3249 assert_eq!(response.status(), StatusCode::BAD_REQUEST);
3250 assert!(response.headers().get(HEADER_PROXY_ERROR_CODE).is_none());
3251 }
3252
3253 #[tokio::test]
3254 async fn non_object_chat_json_returns_structured_invalid_request() {
3255 let response = test_app()
3256 .oneshot(
3257 Request::builder()
3258 .method(Method::POST)
3259 .uri("/v1/chat/completions")
3260 .header("content-type", "application/json")
3261 .body(Body::from("[]"))
3262 .expect("request should build"),
3263 )
3264 .await
3265 .expect("request should complete");
3266
3267 assert_eq!(response.status(), StatusCode::BAD_REQUEST);
3268 assert_eq!(
3269 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
3270 "invalid_request"
3271 );
3272 let body = error_body(response).await;
3273 assert_eq!(body.error.kind, "invalid_request_error");
3274 assert_eq!(body.error.code, "invalid_request");
3275 }
3276
/// One scripted piece of a mock upstream SSE response.
///
/// `render_mock_sse` turns a slice of these frames into the raw event-stream
/// text served by the mock chat-completions endpoint.
#[derive(Debug, Clone)]
enum MockStreamFrame {
    /// Chunk whose delta carries only `"role": "assistant"`.
    Role,
    /// Chunk whose delta carries `"content": null`.
    NullContent,
    /// Chunk whose delta carries an empty content string (not encrypted).
    EmptyContent,
    /// Content chunk; the text is encrypted for the requesting client's key.
    Text(&'static str),
    /// `reasoning_content` chunk; the text is encrypted for the requesting
    /// client's key.
    Reasoning(&'static str),
    /// Content chunk encrypted for a freshly generated, unrelated key.
    TextForWrongRecipient(&'static str),
    /// Chunk with an empty delta and the given `finish_reason`.
    Finish(&'static str),
    /// Chunk with no choices and a fixed `usage` object (1/2/3 tokens).
    Usage,
    /// The `data: [DONE]` terminator.
    Done,
    /// An `event: error` frame whose data is `{"message": <text>}`.
    Error(&'static str),
    /// Text appended to the stream verbatim, without any framing.
    Raw(&'static str),
}
3291
3292 async fn streaming_chat_response(
3293 session_id: &'static str,
3294 request_body: &'static str,
3295 frames: Vec<MockStreamFrame>,
3296 ) -> Response {
3297 chat_response(session_id, request_body, frames).await
3298 }
3299
3300 async fn chat_response(
3301 session_id: &'static str,
3302 request_body: &'static str,
3303 frames: Vec<MockStreamFrame>,
3304 ) -> Response {
3305 let model_public_key = ProxyInstanceKey::generate().public_key_hex().to_owned();
3306 let base_url = spawn_streaming_venice_server(model_public_key, true, frames).await;
3307 request_chat(session_id, request_body, base_url).await
3308 }
3309
3310 async fn chat_response_sequence(
3311 session_id: &'static str,
3312 request_body: &'static str,
3313 attempts: Vec<Vec<MockStreamFrame>>,
3314 ) -> Response {
3315 let model_public_key = ProxyInstanceKey::generate().public_key_hex().to_owned();
3316 let base_url =
3317 spawn_streaming_venice_server_sequence(model_public_key, true, attempts).await;
3318 request_chat(session_id, request_body, base_url).await
3319 }
3320
3321 async fn chat_response_with_upstream_status(
3322 session_id: &'static str,
3323 request_body: &'static str,
3324 upstream_status: StatusCode,
3325 ) -> Response {
3326 let model_public_key = ProxyInstanceKey::generate().public_key_hex().to_owned();
3327 let base_url =
3328 spawn_venice_server_with_chat_status(model_public_key, upstream_status).await;
3329 request_chat(session_id, request_body, base_url).await
3330 }
3331
3332 async fn request_chat(
3333 session_id: &'static str,
3334 request_body: &'static str,
3335 base_url: String,
3336 ) -> Response {
3337 request_chat_with_config(
3338 chat_config_with_basic_test_attestation(),
3339 session_id,
3340 request_body,
3341 base_url,
3342 )
3343 .await
3344 }
3345
3346 async fn request_chat_with_config(
3347 config: ProxyConfig,
3348 session_id: &'static str,
3349 request_body: &'static str,
3350 base_url: String,
3351 ) -> Response {
3352 let app = router_with_venice_client(config, test_venice_client_for_base_url(base_url));
3353
3354 app.oneshot(
3355 Request::builder()
3356 .method(Method::POST)
3357 .uri("/v1/chat/completions")
3358 .header("content-type", "application/json")
3359 .header(HEADER_PROXY_SESSION_ID, session_id)
3360 .body(Body::from(request_body))
3361 .expect("request should build"),
3362 )
3363 .await
3364 .expect("request should complete")
3365 }
3366
3367 async fn json_body(response: Response) -> Value {
3368 let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
3369 .await
3370 .expect("response body should buffer");
3371 serde_json::from_slice(&bytes).expect("response should be JSON")
3372 }
3373
3374 async fn response_body(response: Response) -> String {
3375 let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
3376 .await
3377 .expect("response body should buffer");
3378 String::from_utf8(bytes.to_vec()).expect("response body should be UTF-8")
3379 }
3380
3381 async fn assert_stream_body_fails(response: Response) {
3382 assert_eq!(response.status(), StatusCode::OK);
3383 let result = axum::body::to_bytes(response.into_body(), usize::MAX).await;
3384 assert!(
3385 result.is_err(),
3386 "stream body should fail closed instead of completing successfully"
3387 );
3388 }
3389
/// Collects the payload of every `data` field line in an SSE body, in order.
///
/// The event-stream format makes the single space after `data:` optional,
/// so both `data: x` and `data:x` yield `x`. Only one leading space is
/// removed; any further whitespace is part of the payload. Lines for other
/// fields (for example `event:`) and blank lines are skipped.
fn sse_data(body: &str) -> Vec<&str> {
    body.lines()
        .filter_map(|line| line.strip_prefix("data:"))
        .map(|payload| payload.strip_prefix(' ').unwrap_or(payload))
        .collect()
}
3395
3396 fn sse_json_chunks(body: &str) -> Vec<Value> {
3398 let data = sse_data(body);
3399 assert_eq!(data.last().copied(), Some("[DONE]"));
3400 data[..data.len() - 1]
3401 .iter()
3402 .map(|chunk| serde_json::from_str(chunk).expect("SSE chunk should be JSON"))
3403 .collect()
3404 }
3405
3406 fn streamed_content(chunks: &[Value]) -> String {
3408 chunks
3409 .iter()
3410 .filter_map(|chunk| chunk["choices"][0]["delta"]["content"].as_str())
3411 .collect()
3412 }
3413
3414 fn streamed_tool_call_deltas(chunks: &[Value]) -> Vec<&Value> {
3416 chunks
3417 .iter()
3418 .filter_map(|chunk| chunk["choices"][0]["delta"]["tool_calls"].as_array())
3419 .flatten()
3420 .collect()
3421 }
3422
3423 fn streamed_tool_call_arguments(chunks: &[Value], index: u64) -> String {
3425 streamed_tool_call_deltas(chunks)
3426 .iter()
3427 .filter(|tool_call| tool_call["index"] == json!(index))
3428 .filter_map(|tool_call| tool_call["function"]["arguments"].as_str())
3429 .collect()
3430 }
3431
3432 async fn spawn_streaming_venice_server(
3433 model_public_key: String,
3434 verified: bool,
3435 frames: Vec<MockStreamFrame>,
3436 ) -> String {
3437 spawn_streaming_venice_server_sequence(model_public_key, verified, vec![frames]).await
3438 }
3439
/// Spawns a mock Venice upstream on an ephemeral localhost port and returns
/// its base URL (`http://<addr>/api/v1`).
///
/// The server exposes two routes:
/// - `GET /api/v1/tee/attestation` echoes the `nonce` and `model` query
///   parameters and reports `verified`, a `tdx` provider and
///   `model_public_key` as the signing key.
/// - `POST /api/v1/chat/completions` validates the request and answers with
///   an SSE body rendered from the next scripted attempt.
///
/// Each element of `attempts` is the frame script for one chat request.
/// While more than one attempt remains the front one is consumed; the last
/// remaining attempt is replayed for every further request.
async fn spawn_streaming_venice_server_sequence(
    model_public_key: String,
    verified: bool,
    attempts: Vec<Vec<MockStreamFrame>>,
) -> String {
    // Shared queue of scripted attempts; each handler invocation clones the Arc.
    let chat_attempts = Arc::new(Mutex::new(VecDeque::from(attempts)));
    let attestation_key = model_public_key.clone();
    let app = Router::new()
        .route(
            "/api/v1/tee/attestation",
            get(move |Query(query): Query<HashMap<String, String>>| {
                let model_public_key = attestation_key.clone();
                async move {
                    Json(json!({
                        "attestation": {
                            "verified": verified,
                            "nonce": query.get("nonce").cloned().unwrap_or_default(),
                            "model": query.get("model").cloned().unwrap_or_default(),
                            "tee_provider": "tdx",
                            "signing_key": model_public_key,
                        }
                    }))
                }
            }),
        )
        .route(
            "/api/v1/chat/completions",
            post(move |headers: HeaderMap, Json(body): Json<Value>| {
                let chat_attempts = chat_attempts.clone();
                async move {
                    // The client's public key is needed to encrypt the mock content.
                    let Some(client_public_key) = headers
                        .get(crate::venice::HEADER_VENICE_TEE_CLIENT_PUB_KEY)
                        .and_then(|value| value.to_str().ok())
                    else {
                        return (
                            StatusCode::BAD_REQUEST,
                            [("content-type", "text/plain")],
                            "missing client key".to_owned(),
                        );
                    };
                    // Only requests with `"stream": true` are accepted.
                    if body.get("stream").and_then(Value::as_bool) != Some(true) {
                        return (
                            StatusCode::BAD_REQUEST,
                            [("content-type", "text/plain")],
                            "upstream request must stream".to_owned(),
                        );
                    }
                    // Reject when `messages` is missing, empty, or any message lacks
                    // a string `role` or a non-empty, all-hex-digit string `content`.
                    let messages = body.get("messages").and_then(Value::as_array);
                    if messages.is_none_or(|messages| {
                        messages.is_empty()
                            || !messages.iter().all(|message| {
                                message.get("role").and_then(Value::as_str).is_some()
                                    && message
                                        .get("content")
                                        .and_then(Value::as_str)
                                        .is_some_and(|content| {
                                            !content.is_empty()
                                                && content
                                                    .chars()
                                                    .all(|ch| ch.is_ascii_hexdigit())
                                        })
                            })
                    }) {
                        return (
                            StatusCode::BAD_REQUEST,
                            [("content-type", "text/plain")],
                            "messages must be encrypted message objects".to_owned(),
                        );
                    }

                    // Pick this request's script inside a block so the mutex guard
                    // is released before the response is rendered.
                    let frames = {
                        let mut attempts = chat_attempts
                            .lock()
                            .expect("mock chat attempts mutex should not be poisoned");
                        if attempts.len() > 1 {
                            attempts.pop_front().expect("attempts length checked above")
                        } else {
                            // Last (or no) attempt left: replay it without consuming;
                            // an empty queue yields an empty frame list.
                            attempts.front().cloned().unwrap_or_default()
                        }
                    };

                    (
                        StatusCode::OK,
                        [("content-type", "text/event-stream")],
                        render_mock_sse(&frames, client_public_key),
                    )
                }
            }),
        );
    // Port 0 asks the OS for any free port.
    let listener = TcpListener::bind(("127.0.0.1", 0))
        .await
        .expect("mock Venice listener should bind");
    let addr = listener
        .local_addr()
        .expect("mock Venice listener should have local address");

    // The server runs on a detached task; its handle is not kept.
    tokio::spawn(async move {
        axum::serve(listener, app)
            .await
            .expect("mock Venice server should run");
    });

    format!("http://{addr}/api/v1")
}
3544
3545 async fn spawn_venice_server_with_chat_status(
3546 model_public_key: String,
3547 upstream_status: StatusCode,
3548 ) -> String {
3549 let attestation_key = model_public_key.clone();
3550 let app = Router::new()
3551 .route(
3552 "/api/v1/tee/attestation",
3553 get(move |Query(query): Query<HashMap<String, String>>| {
3554 let model_public_key = attestation_key.clone();
3555 async move {
3556 Json(json!({
3557 "attestation": {
3558 "verified": true,
3559 "nonce": query.get("nonce").cloned().unwrap_or_default(),
3560 "model": query.get("model").cloned().unwrap_or_default(),
3561 "tee_provider": "tdx",
3562 "signing_key": model_public_key,
3563 }
3564 }))
3565 }
3566 }),
3567 )
3568 .route(
3569 "/api/v1/chat/completions",
3570 post(move || async move { upstream_status }),
3571 );
3572 let listener = TcpListener::bind(("127.0.0.1", 0))
3573 .await
3574 .expect("mock Venice listener should bind");
3575 let addr = listener
3576 .local_addr()
3577 .expect("mock Venice listener should have local address");
3578
3579 tokio::spawn(async move {
3580 axum::serve(listener, app)
3581 .await
3582 .expect("mock Venice server should run");
3583 });
3584
3585 format!("http://{addr}/api/v1")
3586 }
3587
3588 fn render_mock_sse(frames: &[MockStreamFrame], client_public_key: &str) -> String {
3589 let codec = E2eeCodec::default();
3590 let mut output = String::new();
3591 for frame in frames {
3592 match frame {
3593 MockStreamFrame::Role => {
3594 output.push_str(&format!("data: {}\n\n", upstream_role_chunk()));
3595 }
3596 MockStreamFrame::NullContent => {
3597 output.push_str(&format!("data: {}\n\n", upstream_null_content_chunk()));
3598 }
3599 MockStreamFrame::EmptyContent => {
3600 output.push_str(&format!(
3601 "data: {}\n\n",
3602 upstream_content_chunk(String::new())
3603 ));
3604 }
3605 MockStreamFrame::Text(content) => {
3606 let encrypted = codec
3607 .encrypt_content(content, client_public_key)
3608 .expect("mock content should encrypt")
3609 .into_hex();
3610 output.push_str(&format!("data: {}\n\n", upstream_content_chunk(encrypted)));
3611 }
3612 MockStreamFrame::Reasoning(content) => {
3613 let encrypted = codec
3614 .encrypt_content(content, client_public_key)
3615 .expect("mock reasoning content should encrypt")
3616 .into_hex();
3617 output.push_str(&format!(
3618 "data: {}\n\n",
3619 upstream_reasoning_content_chunk(encrypted)
3620 ));
3621 }
3622 MockStreamFrame::TextForWrongRecipient(content) => {
3623 let wrong_key = ProxyInstanceKey::generate();
3624 let encrypted = codec
3625 .encrypt_content(content, wrong_key.public_key_hex())
3626 .expect("mock content should encrypt")
3627 .into_hex();
3628 output.push_str(&format!("data: {}\n\n", upstream_content_chunk(encrypted)));
3629 }
3630 MockStreamFrame::Finish(reason) => {
3631 output.push_str(&format!("data: {}\n\n", upstream_finish_chunk(reason)));
3632 }
3633 MockStreamFrame::Usage => {
3634 output.push_str(&format!("data: {}\n\n", upstream_usage_chunk()));
3635 }
3636 MockStreamFrame::Done => output.push_str("data: [DONE]\n\n"),
3637 MockStreamFrame::Error(message) => {
3638 output.push_str(&format!(
3639 "event: error\ndata: {}\n\n",
3640 json!({ "message": message })
3641 ));
3642 }
3643 MockStreamFrame::Raw(raw) => output.push_str(raw),
3644 }
3645 }
3646 output
3647 }
3648
3649 fn upstream_role_chunk() -> Value {
3650 json!({
3651 "id": "chatcmpl-upstream-test",
3652 "object": "chat.completion.chunk",
3653 "created": 1_717_171_717,
3654 "model": "e2ee-test",
3655 "choices": [{
3656 "index": 0,
3657 "delta": { "role": "assistant" },
3658 "finish_reason": null,
3659 }],
3660 })
3661 }
3662
3663 fn upstream_content_chunk(encrypted_content: String) -> Value {
3664 json!({
3665 "id": "chatcmpl-upstream-test",
3666 "object": "chat.completion.chunk",
3667 "created": 1_717_171_717,
3668 "model": "e2ee-test",
3669 "choices": [{
3670 "index": 0,
3671 "delta": { "content": encrypted_content },
3672 "finish_reason": null,
3673 }],
3674 })
3675 }
3676
3677 fn upstream_reasoning_content_chunk(encrypted_content: String) -> Value {
3678 json!({
3679 "id": "chatcmpl-upstream-test",
3680 "object": "chat.completion.chunk",
3681 "created": 1_717_171_717,
3682 "model": "e2ee-test",
3683 "choices": [{
3684 "index": 0,
3685 "delta": { "reasoning_content": encrypted_content },
3686 "finish_reason": null,
3687 }],
3688 })
3689 }
3690
3691 fn upstream_null_content_chunk() -> Value {
3692 json!({
3693 "id": "chatcmpl-upstream-test",
3694 "object": "chat.completion.chunk",
3695 "created": 1_717_171_717,
3696 "model": "e2ee-test",
3697 "choices": [{
3698 "index": 0,
3699 "delta": { "content": Value::Null },
3700 "finish_reason": null,
3701 }],
3702 })
3703 }
3704
3705 fn upstream_finish_chunk(reason: &str) -> Value {
3706 json!({
3707 "id": "chatcmpl-upstream-test",
3708 "object": "chat.completion.chunk",
3709 "created": 1_717_171_717,
3710 "model": "e2ee-test",
3711 "choices": [{
3712 "index": 0,
3713 "delta": {},
3714 "finish_reason": reason,
3715 }],
3716 })
3717 }
3718
3719 fn upstream_usage_chunk() -> Value {
3720 json!({
3721 "id": "chatcmpl-upstream-test",
3722 "object": "chat.completion.chunk",
3723 "created": 1_717_171_717,
3724 "model": "e2ee-test",
3725 "choices": [],
3726 "usage": {
3727 "prompt_tokens": 1,
3728 "completion_tokens": 2,
3729 "total_tokens": 3,
3730 },
3731 })
3732 }
3733
3734 async fn spawn_attestation_server(model_public_key: String, verified: bool) -> String {
3735 let app = Router::new().route(
3736 "/api/v1/tee/attestation",
3737 get(move |Query(query): Query<HashMap<String, String>>| {
3738 let model_public_key = model_public_key.clone();
3739 async move {
3740 Json(json!({
3741 "attestation": {
3742 "verified": verified,
3743 "nonce": query.get("nonce").cloned().unwrap_or_default(),
3744 "model": query.get("model").cloned().unwrap_or_default(),
3745 "signing_key": model_public_key,
3746 }
3747 }))
3748 }
3749 }),
3750 );
3751 let listener = TcpListener::bind(("127.0.0.1", 0))
3752 .await
3753 .expect("mock attestation listener should bind");
3754 let addr = listener
3755 .local_addr()
3756 .expect("mock attestation listener should have local address");
3757
3758 tokio::spawn(async move {
3759 axum::serve(listener, app)
3760 .await
3761 .expect("mock attestation server should run");
3762 });
3763
3764 format!("http://{addr}/api/v1")
3765 }
3766
/// Applies metadata headers derived from the default config to an empty
/// header map and checks that attestation mode and tool mode are set while
/// the E2EE and key-binding headers are absent.
#[test]
fn metadata_header_helper_only_emits_safe_config_headers_by_default() {
    let default_config = ProxyConfig::default();
    let metadata = ProxyMetadataHeaders::from_config(&default_config);

    let mut emitted = HeaderMap::new();
    metadata.apply(&mut emitted);

    // Headers expected for the default config.
    let attestation_mode = emitted.get(HEADER_PROXY_ATTESTATION_MODE).unwrap();
    assert_eq!(attestation_mode, "independent");
    let tool_mode = emitted.get(HEADER_PROXY_TOOL_MODE).unwrap();
    assert_eq!(tool_mode, "emulated");

    // Headers that must not be emitted for the default config.
    assert!(emitted.get(HEADER_PROXY_E2EE).is_none());
    assert!(emitted.get(HEADER_PROXY_KEY_BINDING).is_none());
}
3783}