use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use sha2::{Digest, Sha256};
use std::collections::HashSet;
use std::sync::{Arc, Mutex};
use crate::error::{AgentLoopError, LlmErrorKind, Result};
use crate::llm_driver_registry::{
ChatDriver, LlmCallConfig, LlmCompletionMetadata, LlmContentPart, LlmMessage,
LlmMessageContent, LlmMessageRole, LlmResponseStream, LlmStreamEvent,
OpenRouterProviderRouting,
};
use crate::llm_retry::{
LlmRetryConfig, RateLimitInfo, RetryMetadata, is_rate_limit_status, is_transient_error,
};
use crate::openai_protocol::{
apply_openai_api_auth, is_openai_model_not_found, is_openai_request_too_large,
};
use crate::openresponses_types::{self as types, StreamingEvent};
use crate::provider::DriverId;
use crate::tool_types::{ToolCall, ToolDefinition};
use crate::user_facing_error::is_provider_quota_message;
/// Responses endpoint used when a driver is constructed without an explicit URL.
const DEFAULT_API_URL: &str = "https://api.openai.com/v1/responses";
/// Total length (prefix included) of the prompt cache keys generated by this driver.
const OPENAI_PROMPT_CACHE_KEY_MAX_LEN: usize = 64;
/// Prefix prepended to every generated prompt cache key.
const PROMPT_CACHE_KEY_PREFIX: &str = "everruns:";
/// Chat driver that streams completions from a Responses-style HTTP endpoint.
///
/// `Debug` is implemented by hand further down so the API key is never printed.
#[derive(Clone)]
pub struct OpenResponsesProtocolChatDriver {
// HTTP client used for every outgoing request.
client: Client,
// Credential handed to `apply_openai_api_auth` on each request.
api_key: String,
// Full URL of the responses endpoint; requests are POSTed to it as-is.
api_url: String,
// Provider identity; used for model-profile lookup and OpenRouter-only routing.
provider_type: DriverId,
// Retry limits/backoff applied to transient HTTP failures.
retry_config: LlmRetryConfig,
}
impl OpenResponsesProtocolChatDriver {
pub fn new(api_key: impl Into<String>) -> Self {
Self {
client: Client::new(),
api_key: api_key.into(),
api_url: DEFAULT_API_URL.to_string(),
provider_type: DriverId::OpenAI,
retry_config: LlmRetryConfig::default(),
}
}
/// Creates a driver using the key stored in the `OPENAI_API_KEY` environment variable.
///
/// # Errors
///
/// Returns an LLM error when the variable cannot be read.
pub fn from_env() -> Result<Self> {
    match std::env::var("OPENAI_API_KEY") {
        Ok(key) => Ok(Self::new(key)),
        Err(_) => Err(AgentLoopError::llm(
            "OPENAI_API_KEY environment variable not set",
        )),
    }
}
/// Creates a driver that posts to `api_url` instead of the default endpoint.
///
/// Provider type starts as `DriverId::OpenAI` and retry behaviour uses the
/// default configuration; both can be overridden with the `with_*` builders.
pub fn with_base_url(api_key: impl Into<String>, api_url: impl Into<String>) -> Self {
    let api_key: String = api_key.into();
    let api_url: String = api_url.into();
    Self {
        client: Client::new(),
        api_key,
        api_url,
        provider_type: DriverId::OpenAI,
        retry_config: LlmRetryConfig::default(),
    }
}
pub fn with_provider_type(mut self, provider_type: DriverId) -> Self {
self.provider_type = provider_type;
self
}
pub fn with_retry_config(mut self, config: LlmRetryConfig) -> Self {
self.retry_config = config;
self
}
pub fn api_url(&self) -> &str {
&self.api_url
}
pub fn api_key(&self) -> &str {
&self.api_key
}
pub fn client(&self) -> &Client {
&self.client
}
pub fn provider_type(&self) -> &DriverId {
&self.provider_type
}
/// Maps an internal message role to the role string sent on the wire.
fn convert_role(role: &LlmMessageRole) -> &'static str {
    match role {
        LlmMessageRole::Tool => "tool",
        LlmMessageRole::Assistant => "assistant",
        LlmMessageRole::User => "user",
        // System messages are sent under the "developer" role.
        LlmMessageRole::System => "developer",
    }
}
/// Converts one message into a Responses input item.
///
/// A tool message that carries a `tool_call_id` becomes a
/// `function_call_output` item whose output is the concatenation of its text
/// parts; non-text parts are dropped (with a warning when an image is lost).
/// Every other message becomes a `message` item. The `phase` field is only
/// populated for assistant messages when `supports_phases` is true.
fn convert_message(msg: &LlmMessage, supports_phases: bool) -> ResponsesInputItem {
    let tool_result_id = match &msg.tool_call_id {
        Some(id) if msg.role == LlmMessageRole::Tool => Some(id),
        _ => None,
    };
    if let Some(call_id) = tool_result_id {
        let (output, dropped_images) = match &msg.content {
            LlmMessageContent::Text(text) => (text.clone(), false),
            LlmMessageContent::Parts(parts) => {
                let mut joined = String::new();
                let mut saw_image = false;
                for part in parts {
                    match part {
                        LlmContentPart::Text { text } => joined.push_str(text),
                        LlmContentPart::Image { .. } => saw_image = true,
                        _ => {}
                    }
                }
                (joined, saw_image)
            }
        };
        if dropped_images {
            tracing::warn!(
                tool_call_id = %call_id,
                "OpenResponses API does not support images in tool results; images dropped"
            );
        }
        return ResponsesInputItem::FunctionCallOutput {
            r#type: "function_call_output".to_string(),
            call_id: call_id.clone(),
            output,
        };
    }

    let content = match &msg.content {
        LlmMessageContent::Text(text) => ResponsesContent::Text(text.clone()),
        LlmMessageContent::Parts(parts) => {
            let mut converted: Vec<ResponsesContentPart> = Vec::with_capacity(parts.len());
            for part in parts {
                let wire_part = match part {
                    LlmContentPart::Text { text } => ResponsesContentPart::InputText {
                        r#type: "input_text".to_string(),
                        text: text.clone(),
                    },
                    LlmContentPart::Image { url } => ResponsesContentPart::InputImage {
                        r#type: "input_image".to_string(),
                        image_url: url.clone(),
                    },
                    // Audio is always declared as "wav"; the URL field is sent as the data payload.
                    LlmContentPart::Audio { url } => ResponsesContentPart::InputAudio {
                        r#type: "input_audio".to_string(),
                        input_audio: ResponsesInputAudio {
                            data: url.clone(),
                            format: "wav".to_string(),
                        },
                    },
                };
                converted.push(wire_part);
            }
            ResponsesContent::Parts(converted)
        }
    };

    let phase = match msg.phase {
        Some(p) if supports_phases && msg.role == LlmMessageRole::Assistant => {
            Some(p.as_provider_str().to_string())
        }
        _ => None,
    };

    ResponsesInputItem::Message {
        r#type: "message".to_string(),
        role: Self::convert_role(&msg.role).to_string(),
        content,
        phase,
    }
}
fn sanitize_parameters(params: &Value) -> Value {
let mut p = params.clone();
if let Some(obj) = p.as_object_mut()
&& obj.get("type").and_then(|v| v.as_str()) == Some("object")
&& !obj.contains_key("properties")
{
obj.insert(
"properties".to_string(),
serde_json::Value::Object(serde_json::Map::new()),
);
}
p
}
/// Converts tool definitions into plain `function` tools, preserving order.
/// `defer_loading` is never set on this path.
fn convert_tools(tools: &[ToolDefinition]) -> Vec<ResponsesTool> {
    let mut converted = Vec::with_capacity(tools.len());
    for definition in tools {
        converted.push(ResponsesTool::Function {
            r#type: "function".to_string(),
            name: definition.name().to_string(),
            description: definition.description().to_string(),
            parameters: Self::sanitize_parameters(definition.parameters()),
            defer_loading: None,
        });
    }
    converted
}
/// Converts tool definitions for a request that uses provider-side tool search.
///
/// When fewer than `threshold` tools are supplied this is identical to
/// [`Self::convert_tools`]. Otherwise the output is laid out as:
///
/// 1. tools whose policy is `DeferrablePolicy::Never` (sent eagerly),
/// 2. one `namespace` entry per tool category, holding its deferred tools,
/// 3. deferred tools without a category,
/// 4. a trailing `tool_search` entry.
///
/// Namespaces are emitted in ascending name order. The serialized tool list
/// feeds the prompt cache key fingerprint, so the order must be the same on
/// every call; iterating a `HashMap` here would make it vary between processes.
fn convert_tools_with_search(tools: &[ToolDefinition], threshold: usize) -> Vec<ResponsesTool> {
    use crate::tool_types::DeferrablePolicy;
    use std::collections::BTreeMap;

    if tools.len() < threshold {
        return Self::convert_tools(tools);
    }

    // BTreeMap keeps namespaces sorted by name, giving a deterministic order.
    let mut namespaces: BTreeMap<String, Vec<ResponsesTool>> = BTreeMap::new();
    let mut ungrouped = Vec::new();
    let mut never_defer = Vec::new();

    for tool in tools {
        let should_defer = match tool.deferrable() {
            DeferrablePolicy::Never => false,
            DeferrablePolicy::Automatic | DeferrablePolicy::Always => true,
        };
        let func = ResponsesTool::Function {
            r#type: "function".to_string(),
            name: tool.name().to_string(),
            description: tool.description().to_string(),
            parameters: Self::sanitize_parameters(tool.parameters()),
            defer_loading: should_defer.then_some(true),
        };
        if !should_defer {
            never_defer.push(func);
            continue;
        }
        match tool.category() {
            Some(cat) => namespaces.entry(cat.to_string()).or_default().push(func),
            None => ungrouped.push(func),
        }
    }

    let mut result: Vec<ResponsesTool> =
        Vec::with_capacity(never_defer.len() + namespaces.len() + ungrouped.len() + 1);
    result.extend(never_defer);
    for (name, members) in namespaces {
        let description = format!("Tools for {name}");
        result.push(ResponsesTool::Namespace {
            r#type: "namespace".to_string(),
            name,
            description,
            tools: members,
        });
    }
    result.extend(ungrouped);
    result.push(ResponsesTool::ToolSearch {
        r#type: "tool_search".to_string(),
    });
    result
}
/// Derives the prompt cache key for a request, or `None` when prompt caching
/// is absent or disabled in `config`.
///
/// The key is [`PROMPT_CACHE_KEY_PREFIX`] followed by a truncated hex SHA-256
/// of a JSON fingerprint built from the cache strategy, model, cache family,
/// instructions and tools. The cache family is the first metadata entry found
/// among `session_id`, `agent_id`, `harness_id` and `org_id`. The input items
/// are not part of the fingerprint.
fn build_prompt_cache_key(
    config: &LlmCallConfig,
    _input_items: &[ResponsesInputItem],
    instructions: &Option<String>,
    tools: &Option<Vec<ResponsesTool>>,
) -> Option<String> {
    let prompt_cache = match config.prompt_cache.as_ref() {
        Some(cfg) if cfg.enabled => cfg,
        _ => return None,
    };

    let cache_family = ["session_id", "agent_id", "harness_id", "org_id"]
        .iter()
        .find_map(|key| config.metadata.get(*key));

    let fingerprint = json!({
        "strategy": prompt_cache.strategy,
        "model": config.model,
        "cache_family": cache_family,
        "instructions": instructions,
        "tools": tools,
    });
    let payload = match serde_json::to_vec(&fingerprint) {
        Ok(bytes) => bytes,
        Err(_) => return None,
    };
    let digest = format!("{:x}", Sha256::digest(payload));

    // Keep the whole key (prefix + digest slice) at the maximum allowed length.
    let digest_budget = OPENAI_PROMPT_CACHE_KEY_MAX_LEN - PROMPT_CACHE_KEY_PREFIX.len();
    let mut key = String::with_capacity(OPENAI_PROMPT_CACHE_KEY_MAX_LEN);
    key.push_str(PROMPT_CACHE_KEY_PREFIX);
    key.push_str(&digest[..digest_budget]);
    Some(key)
}
/// Sends `request` to the `/compact` sub-resource of the configured endpoint.
///
/// Transient HTTP statuses are retried up to `retry_config.max_retries`
/// times, waiting for the rate-limit recommended delay on rate-limit statuses
/// and for the configured backoff otherwise.
///
/// # Errors
///
/// * the request could not be sent,
/// * the provider reports the model as unavailable or the request as too large,
/// * any other non-success status (the message mentions prior retries, if any),
/// * the success body could not be parsed as a `CompactResponse`.
pub async fn compact(&self, request: CompactRequest) -> Result<CompactResponse> {
    // Exactly one slash between the endpoint and "compact", whatever the URL ends with.
    let compact_url = format!("{}/compact", self.api_url.trim_end_matches('/'));
    let mut retries = RetryMetadata::default();
    let mut previous_error: Option<String> = None;

    let response = loop {
        let outcome =
            apply_openai_api_auth(self.client.post(&compact_url), &compact_url, &self.api_key)
                .header("Content-Type", "application/json")
                .json(&request)
                .send()
                .await;
        let response = match outcome {
            Ok(response) => response,
            Err(e) => {
                return Err(AgentLoopError::llm(format!(
                    "Failed to send compact request: {}",
                    e
                )));
            }
        };

        let status = response.status();
        if status.is_success() {
            break response;
        }

        let can_retry =
            is_transient_error(status) && retries.attempts < self.retry_config.max_retries;
        if !can_retry {
            let error_text = response.text().await.unwrap_or_default();
            if is_openai_model_not_found(status, &error_text) {
                return Err(AgentLoopError::model_not_available(request.model.clone()));
            }
            if is_openai_request_too_large(status, &error_text) {
                return Err(AgentLoopError::request_too_large(format!(
                    "OpenAI Responses compact API ({}): {}",
                    status, error_text
                )));
            }
            let error_msg = format!(
                "OpenAI Responses compact API error ({}): {}",
                status, error_text
            );
            let final_msg = if retries.attempts > 0 {
                format!(
                    "{} (after {} retries, last error: {})",
                    error_msg,
                    retries.attempts,
                    previous_error.unwrap_or_default()
                )
            } else {
                error_msg
            };
            return Err(AgentLoopError::llm(final_msg));
        }

        // Headers must be read before `text()` consumes the response.
        let rate_limit_info = if is_rate_limit_status(status) {
            Some(RateLimitInfo::from_openai_headers(response.headers()))
        } else {
            None
        };
        let error_text = response.text().await.unwrap_or_default();
        let wait_duration = match rate_limit_info.as_ref() {
            Some(info) => info.recommended_wait(&self.retry_config, retries.attempts),
            None => self.retry_config.calculate_backoff(retries.attempts),
        };
        tracing::warn!(
            status = %status,
            attempt = retries.attempts + 1,
            max_retries = self.retry_config.max_retries,
            wait_secs = wait_duration.as_secs_f64(),
            "OpenResponsesDriver: compact rate limit or transient error, retrying"
        );
        retries.record_retry(wait_duration, rate_limit_info);
        previous_error = Some(error_text);
        tokio::time::sleep(wait_duration).await;
    };

    if retries.had_retries() {
        tracing::info!(
            attempts = retries.attempts,
            total_wait_secs = retries.total_retry_wait.as_secs_f64(),
            "OpenResponsesDriver: compact request succeeded after retries"
        );
    }

    response
        .json::<CompactResponse>()
        .await
        .map_err(|e| AgentLoopError::llm(format!("Failed to parse compact response: {}", e)))
}
/// Whether the compact endpoint may be used: true only when the configured
/// URL begins with `https://api.openai.com/`.
pub fn supports_compact(&self) -> bool {
    self.api_url
        .strip_prefix("https://api.openai.com/")
        .is_some()
}
/// Splits a conversation into the `instructions` string and the ordered list
/// of Responses input items.
///
/// * System messages are not emitted as items; their text parts become the
///   instructions. Each system message replaces the previous one, so only
///   the last system message is kept.
/// * An assistant message with a thinking signature is preceded by a
///   `reasoning` item carrying a locally generated sequential id.
/// * An assistant message with tool calls yields an optional `message` item
///   (only when it has content) followed by one `function_call` per call.
/// * All other messages go through [`Self::convert_message`].
fn build_input(
    messages: &[LlmMessage],
    supports_phases: bool,
) -> (Option<String>, Vec<ResponsesInputItem>) {
    let mut instructions: Option<String> = None;
    let mut items = Vec::new();
    let mut reasoning_seq = 0u32;

    for msg in messages {
        match &msg.role {
            LlmMessageRole::System => {
                let text = match &msg.content {
                    LlmMessageContent::Text(text) => text.clone(),
                    LlmMessageContent::Parts(parts) => {
                        let mut joined = String::new();
                        for part in parts {
                            if let LlmContentPart::Text { text } = part {
                                joined.push_str(text);
                            }
                        }
                        joined
                    }
                };
                instructions = Some(text);
            }
            LlmMessageRole::Assistant => {
                if let Some(signature) = &msg.thinking_signature {
                    reasoning_seq += 1;
                    items.push(ResponsesInputItem::Reasoning {
                        r#type: "reasoning".to_string(),
                        id: format!("rs_{:08x}", reasoning_seq),
                        encrypted_content: signature.clone(),
                    });
                    tracing::debug!(
                        encrypted_len = signature.len(),
                        "OpenResponses: including reasoning item in request"
                    );
                }

                let pending_calls = msg.tool_calls.as_ref().filter(|calls| !calls.is_empty());
                match pending_calls {
                    Some(calls) => {
                        let has_content = match &msg.content {
                            LlmMessageContent::Text(text) => !text.is_empty(),
                            LlmMessageContent::Parts(parts) => !parts.is_empty(),
                        };
                        if has_content {
                            items.push(Self::convert_message(msg, supports_phases));
                        }
                        for call in calls {
                            items.push(ResponsesInputItem::FunctionCall {
                                r#type: "function_call".to_string(),
                                call_id: call.id.clone(),
                                name: call.name.clone(),
                                arguments: call.arguments.to_string(),
                            });
                        }
                    }
                    None => items.push(Self::convert_message(msg, supports_phases)),
                }
            }
            _ => items.push(Self::convert_message(msg, supports_phases)),
        }
    }

    (instructions, items)
}
}
/// Keeps only the items that follow the last assistant-produced item.
///
/// An item counts as assistant-produced when it is an assistant `message`, a
/// `reasoning` item or a `function_call`. If there is none, the list is
/// returned untouched.
fn compute_delta_input_items(items: Vec<ResponsesInputItem>) -> Vec<ResponsesInputItem> {
    let is_assistant_output = |item: &ResponsesInputItem| match item {
        ResponsesInputItem::Message { role, .. } => role == "assistant",
        ResponsesInputItem::Reasoning { .. } | ResponsesInputItem::FunctionCall { .. } => true,
        _ => false,
    };

    let mut items = items;
    let last_assistant_idx = items.iter().rposition(is_assistant_output);
    match last_assistant_idx {
        Some(idx) => items.split_off(idx + 1),
        None => items,
    }
}
/// Applies the last transformation to the input list before it is sent.
///
/// With a previous response id only the delta since the last assistant output
/// is kept; without one, function-call outputs whose call is not in the list
/// are removed.
fn finalize_input_for_request(
    input_items: Vec<ResponsesInputItem>,
    previous_response_id: &Option<String>,
) -> Vec<ResponsesInputItem> {
    match previous_response_id {
        Some(_) => compute_delta_input_items(input_items),
        None => drop_locally_orphaned_function_call_outputs(input_items),
    }
}
/// Removes every `function_call_output` whose `call_id` has no matching
/// `function_call` in the same list. All other items keep their order.
fn drop_locally_orphaned_function_call_outputs(
    input_items: Vec<ResponsesInputItem>,
) -> Vec<ResponsesInputItem> {
    let mut visible_call_ids: HashSet<String> = HashSet::new();
    for item in &input_items {
        if let ResponsesInputItem::FunctionCall { call_id, .. } = item {
            visible_call_ids.insert(call_id.clone());
        }
    }

    // With no visible calls the lookup fails for every output, so all outputs are dropped.
    let mut kept = input_items;
    kept.retain(|item| match item {
        ResponsesInputItem::FunctionCallOutput { call_id, .. } => {
            visible_call_ids.contains(call_id.as_str())
        }
        _ => true,
    });
    kept
}
fn endpoint_persists_responses(api_url: &str) -> bool {
crate::openai_protocol::is_openai_api_url(api_url)
|| crate::openai_protocol::is_azure_openai_api_url(api_url)
}
#[async_trait]
impl ChatDriver for OpenResponsesProtocolChatDriver {
async fn chat_completion_stream(
&self,
messages: Vec<LlmMessage>,
config: &LlmCallConfig,
) -> Result<LlmResponseStream> {
let model_profile =
crate::model_profiles::get_model_profile(&self.provider_type, &config.model);
let supports_phases = model_profile
.as_ref()
.is_some_and(|profile| profile.supports_phases);
let supports_tool_search = model_profile
.as_ref()
.is_some_and(|profile| profile.tool_search);
let (instructions, input_items) = Self::build_input(&messages, supports_phases);
let previous_response_id = if endpoint_persists_responses(&self.api_url) {
config.previous_response_id.clone()
} else {
None
};
let input_items = finalize_input_for_request(input_items, &previous_response_id);
let tools = if config.tools.is_empty() {
None
} else if let Some(ref ts_config) = config.tool_search {
if ts_config.enabled && supports_tool_search {
Some(Self::convert_tools_with_search(
&config.tools,
ts_config.threshold,
))
} else {
Some(Self::convert_tools(&config.tools))
}
} else {
Some(Self::convert_tools(&config.tools))
};
let reasoning = config
.reasoning_effort
.as_ref()
.filter(|e| !e.eq_ignore_ascii_case("none"))
.map(|effort| ResponsesReasoning {
effort: effort.clone(),
summary: "detailed".to_string(),
});
let metadata = if config.metadata.is_empty() {
None
} else {
Some(config.metadata.clone())
};
let prompt_cache_key =
Self::build_prompt_cache_key(config, &input_items, &instructions, &tools);
let openrouter_routing = if self.provider_type == DriverId::OpenRouter {
config.openrouter_routing.as_ref()
} else {
None
};
if let Some(routing) = openrouter_routing {
routing
.validate_for_primary_model(&config.model)
.map_err(AgentLoopError::llm)?;
}
let preset_applied_owned: Option<crate::llm_driver_registry::OpenRouterRoutingConfig>;
let after_presets: Option<&crate::llm_driver_registry::OpenRouterRoutingConfig> =
match openrouter_routing {
None => None,
Some(r) if r.presets.is_empty() => Some(r),
Some(r) => {
preset_applied_owned = Some(r.apply_presets().map_err(AgentLoopError::llm)?);
preset_applied_owned.as_ref()
}
};
let effective_routing_cow: Option<
std::borrow::Cow<'_, crate::llm_driver_registry::OpenRouterRoutingConfig>,
> = match after_presets {
None => None,
Some(r) => match r.capacity_strategy {
None
| Some(crate::llm_driver_registry::OpenRouterCapacityStrategy::SharedCapacity) => {
Some(std::borrow::Cow::Borrowed(r))
}
_ => Some(std::borrow::Cow::Owned(
r.apply_capacity_strategy().map_err(AgentLoopError::llm)?,
)),
},
};
let effective_routing = effective_routing_cow.as_deref();
let openrouter_provider = effective_routing.and_then(|routing| {
routing
.provider
.as_ref()
.filter(|provider| !provider.is_empty())
.cloned()
});
let openrouter_plugins = effective_routing.and_then(|routing| {
routing
.plugins
.as_ref()
.filter(|p| !p.is_empty())
.and_then(plugins_to_wire)
});
let request = ResponsesRequest {
model: config.model.clone(),
models: effective_routing
.and_then(|routing| (!routing.models.is_empty()).then_some(routing.models.clone())),
route: effective_routing.and_then(|routing| routing.route),
provider: openrouter_provider,
input: input_items,
instructions,
previous_response_id,
temperature: config.temperature,
max_output_tokens: config.max_tokens,
stream: true,
tools,
reasoning,
metadata,
prompt_cache_key,
plugins: openrouter_plugins,
};
{
let tool_count = request.tools.as_ref().map_or(0, |t| t.len());
let input_count = request.input.len();
let has_instructions = request.instructions.is_some();
let has_reasoning = request.reasoning.is_some();
let has_previous_response = request.previous_response_id.is_some();
tracing::debug!(
model = %request.model,
input_items = input_count,
tool_count = tool_count,
has_instructions = has_instructions,
has_reasoning = has_reasoning,
has_previous_response = has_previous_response,
api_url = %self.api_url,
"OpenResponsesDriver: sending request"
);
}
let mut retry_metadata = RetryMetadata::default();
let mut last_error: Option<String> = None;
let response = loop {
let response = apply_openai_api_auth(
self.client.post(&self.api_url),
&self.api_url,
&self.api_key,
)
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| AgentLoopError::llm(format!("Failed to send request: {}", e)))?;
let status = response.status();
if status.is_success() {
break response;
}
if is_transient_error(status) && retry_metadata.attempts < self.retry_config.max_retries
{
let rate_limit_info = if is_rate_limit_status(status) {
Some(RateLimitInfo::from_openai_headers(response.headers()))
} else {
None
};
let error_text = response.text().await.unwrap_or_default();
if is_provider_quota_message(&error_text) {
return Err(AgentLoopError::llm_kind(
LlmErrorKind::QuotaExhausted,
format!("OpenAI Responses API error ({}): {}", status, error_text),
));
}
let wait_duration = rate_limit_info
.as_ref()
.map(|info| info.recommended_wait(&self.retry_config, retry_metadata.attempts))
.unwrap_or_else(|| {
self.retry_config.calculate_backoff(retry_metadata.attempts)
});
tracing::warn!(
status = %status,
attempt = retry_metadata.attempts + 1,
max_retries = self.retry_config.max_retries,
wait_secs = wait_duration.as_secs_f64(),
retry_after = ?rate_limit_info.as_ref().and_then(|i| i.retry_after_secs),
"OpenResponsesDriver: rate limit or transient error, retrying"
);
retry_metadata.record_retry(wait_duration, rate_limit_info);
last_error = Some(error_text);
tokio::time::sleep(wait_duration).await;
continue;
}
let error_text = response.text().await.unwrap_or_default();
if is_openai_model_not_found(status, &error_text) {
return Err(AgentLoopError::model_not_available(config.model.clone()));
}
if is_openai_request_too_large(status, &error_text) {
return Err(AgentLoopError::request_too_large(format!(
"OpenAI Responses API ({}): {}",
status, error_text
)));
}
let error_msg = format!("OpenAI Responses API error ({}): {}", status, error_text);
let kind = LlmErrorKind::from_provider_status(status.as_u16(), &error_text);
if retry_metadata.attempts > 0 {
return Err(AgentLoopError::llm_kind(
kind,
format!(
"{} (after {} retries, last error: {})",
error_msg,
retry_metadata.attempts,
last_error.unwrap_or_default()
),
));
}
return Err(AgentLoopError::llm_kind(kind, error_msg));
};
if retry_metadata.had_retries() {
tracing::info!(
attempts = retry_metadata.attempts,
total_wait_secs = retry_metadata.total_retry_wait.as_secs_f64(),
"OpenResponsesDriver: request succeeded after retries"
);
}
let byte_stream = response.bytes_stream();
let event_stream = byte_stream.eventsource();
let model = config.model.clone();
let input_tokens = Arc::new(Mutex::new(0u32));
let output_tokens = Arc::new(Mutex::new(0u32));
let cache_read_tokens = Arc::new(Mutex::new(Option::<u32>::None));
let accumulated_tool_calls = Arc::new(Mutex::new(Vec::<ToolCallAccumulator>::new()));
let finish_reason = Arc::new(Mutex::new(Option::<String>::None));
let shared_retry_metadata = if retry_metadata.had_retries() {
Some(Arc::new(retry_metadata))
} else {
None
};
let converted_stream: LlmResponseStream = Box::pin(event_stream.then(move |result| {
let model = model.clone();
let input_tokens = Arc::clone(&input_tokens);
let output_tokens = Arc::clone(&output_tokens);
let cache_read_tokens = Arc::clone(&cache_read_tokens);
let accumulated_tool_calls = Arc::clone(&accumulated_tool_calls);
let finish_reason = Arc::clone(&finish_reason);
let retry_metadata_for_done = shared_retry_metadata.clone();
async move {
match result {
Ok(event) => {
let event_data = &event.data;
if let Ok(streaming_event) =
serde_json::from_str::<StreamingEvent>(event_data)
{
return Ok(handle_streaming_event(
streaming_event,
&input_tokens,
&output_tokens,
&cache_read_tokens,
&accumulated_tool_calls,
&finish_reason,
model,
retry_metadata_for_done,
));
}
let parsed: std::result::Result<Value, _> =
serde_json::from_str(event_data);
match parsed {
Ok(json) => {
let event_type = json.get("type").and_then(|t| t.as_str());
match event_type {
Some("response.output_text.delta") => {
if let Some(delta) =
json.get("delta").and_then(|d| d.as_str())
{
Ok(LlmStreamEvent::TextDelta(delta.to_string()))
} else {
Ok(LlmStreamEvent::TextDelta(String::new()))
}
}
Some("response.function_call_arguments.delta") => {
if let (Some(item_id), Some(delta)) = (
json.get("item_id").and_then(|c| c.as_str()),
json.get("delta").and_then(|d| d.as_str()),
) {
let mut acc = accumulated_tool_calls.lock().unwrap();
if let Some(tc) =
acc.iter_mut().find(|t| t.id == item_id)
{
tc.arguments.push_str(delta);
} else {
acc.push(ToolCallAccumulator {
id: item_id.to_string(),
call_id: String::new(),
name: String::new(),
arguments: delta.to_string(),
});
}
}
Ok(LlmStreamEvent::TextDelta(String::new()))
}
Some("response.output_item.added") => {
if let Some(item) = json.get("item")
&& item.get("type").and_then(|t| t.as_str())
== Some("function_call")
{
let id = item
.get("id")
.and_then(|c| c.as_str())
.unwrap_or("")
.to_string();
let call_id = item
.get("call_id")
.and_then(|c| c.as_str())
.unwrap_or("")
.to_string();
let name = item
.get("name")
.and_then(|n| n.as_str())
.unwrap_or("")
.to_string();
let mut acc = accumulated_tool_calls.lock().unwrap();
if let Some(tc) = acc.iter_mut().find(|t| t.id == id) {
tc.name = name;
tc.call_id = call_id;
} else {
acc.push(ToolCallAccumulator {
id,
call_id,
name,
arguments: String::new(),
});
}
}
Ok(LlmStreamEvent::TextDelta(String::new()))
}
Some("response.output_item.done") => {
if let Some(item) = json.get("item")
&& item.get("type").and_then(|t| t.as_str())
== Some("function_call")
{
let acc = accumulated_tool_calls.lock().unwrap();
if !acc.is_empty() {
let tool_calls: Vec<ToolCall> = acc
.iter()
.filter(|tc| !tc.name.is_empty())
.map(|tc| {
let arguments: Value =
serde_json::from_str(&tc.arguments)
.unwrap_or(json!({}));
ToolCall {
id: tc.call_id.clone(),
name: tc.name.clone(),
arguments,
}
})
.collect();
if !tool_calls.is_empty() {
*finish_reason.lock().unwrap() =
Some("tool_calls".to_string());
return Ok(LlmStreamEvent::ToolCalls(
tool_calls,
));
}
}
}
Ok(LlmStreamEvent::TextDelta(String::new()))
}
Some("response.completed") | Some("response.done") => {
let response_obj = json.get("response").unwrap_or(&json);
let mut provider_cost_usd: Option<f64> = None;
if let Some(usage) = response_obj.get("usage") {
if let Some(input) =
usage.get("input_tokens").and_then(|t| t.as_u64())
{
*input_tokens.lock().unwrap() = input as u32;
}
if let Some(output) =
usage.get("output_tokens").and_then(|t| t.as_u64())
{
*output_tokens.lock().unwrap() = output as u32;
}
if let Some(details) = usage.get("input_tokens_details")
&& let Some(cached) = details
.get("cached_tokens")
.and_then(|t| t.as_u64())
{
*cache_read_tokens.lock().unwrap() =
Some(cached as u32);
}
provider_cost_usd =
usage.get("cost").and_then(|c| c.as_f64());
}
let status = response_obj
.get("status")
.and_then(|s| s.as_str())
.unwrap_or("completed");
let reason = match status {
"completed" => {
let existing_reason =
finish_reason.lock().unwrap().clone();
existing_reason
.unwrap_or_else(|| "stop".to_string())
}
"failed" => {
let error_detail = response_obj
.get("error")
.map(|e| e.to_string())
.unwrap_or_else(|| "no error detail".into());
tracing::warn!(
response_error = %error_detail,
"OpenResponsesDriver: response completed with 'failed' status (fallback parser)"
);
"error".to_string()
}
"cancelled" => "stop".to_string(),
_ => "stop".to_string(),
};
let phase = response_obj
.get("output")
.and_then(|o| o.as_array())
.and_then(|items| {
items.iter().rev().find_map(|item| {
if item.get("type")?.as_str()? == "message"
&& item.get("role")?.as_str()?
== "assistant"
{
item.get("phase")?
.as_str()
.map(String::from)
} else {
None
}
})
});
let input = *input_tokens.lock().unwrap();
let output = *output_tokens.lock().unwrap();
let cached = *cache_read_tokens.lock().unwrap();
Ok(LlmStreamEvent::Done(Box::new(LlmCompletionMetadata {
total_tokens: Some(input + output),
prompt_tokens: Some(input),
completion_tokens: Some(output),
cache_read_tokens: cached,
cache_creation_tokens: None,
provider_cost_usd,
model: Some(model),
finish_reason: Some(reason),
retry_metadata: retry_metadata_for_done
.map(|arc| (*arc).clone()),
response_id: None,
phase,
})))
}
Some("error") => {
let error_code = json
.get("error")
.and_then(|e| e.get("code"))
.and_then(|c| c.as_str())
.unwrap_or("unknown");
let error_msg = json
.get("error")
.and_then(|e| e.get("message"))
.and_then(|m| m.as_str())
.unwrap_or("Unknown error");
tracing::warn!(
error_code = error_code,
error_message = error_msg,
raw_error = %json.get("error").unwrap_or(&json),
"OpenResponsesDriver: received streaming error event (fallback parser)"
);
let formatted = if error_code != "unknown" {
format!("{}: {}", error_code, error_msg)
} else {
error_msg.to_string()
};
Ok(LlmStreamEvent::Error(formatted))
}
_ => {
Ok(LlmStreamEvent::TextDelta(String::new()))
}
}
}
Err(e) => Ok(LlmStreamEvent::Error(format!(
"Failed to parse event: {}",
e
))),
}
}
Err(e) => Ok(LlmStreamEvent::Error(format!("Stream error: {}", e))),
}
}
}));
Ok(converted_stream)
}
/// Trait entry point; forwards to the inherent `supports_compact` on the driver.
// The fully qualified call names the inherent method explicitly.
fn supports_compact(&self) -> bool {
OpenResponsesProtocolChatDriver::supports_compact(self)
}
/// Trait entry point for compaction: forwards to the inherent `compact` and
/// wraps a successful response in `Some`. Errors are propagated unchanged.
async fn compact(
&self,
request: crate::openresponses_protocol::CompactRequest,
) -> Result<Option<crate::openresponses_protocol::CompactResponse>> {
Ok(Some(
OpenResponsesProtocolChatDriver::compact(self, request).await?,
))
}
}
/// Manual `Debug` so the API key is never written out: it is always shown as
/// `[REDACTED]`. The HTTP client and retry configuration are omitted.
impl std::fmt::Debug for OpenResponsesProtocolChatDriver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("OpenResponsesProtocolChatDriver");
        out.field("api_url", &self.api_url);
        out.field("provider_type", &self.provider_type);
        out.field("api_key", &"[REDACTED]");
        out.finish()
    }
}
/// In-progress function call assembled from several streaming events.
#[derive(Clone, Default)]
struct ToolCallAccumulator {
// Output item id; argument deltas are matched to an accumulator through it.
id: String,
// Call id reported by the provider; becomes `ToolCall::id` when emitted.
call_id: String,
// Function name; accumulators with an empty name are not emitted.
name: String,
// Concatenated argument fragments, parsed as JSON when the call is emitted.
arguments: String,
}
#[allow(clippy::too_many_arguments)]
fn handle_streaming_event(
event: StreamingEvent,
input_tokens: &Mutex<u32>,
output_tokens: &Mutex<u32>,
cache_read_tokens: &Mutex<Option<u32>>,
accumulated_tool_calls: &Mutex<Vec<ToolCallAccumulator>>,
finish_reason: &Mutex<Option<String>>,
model: String,
retry_metadata: Option<Arc<RetryMetadata>>,
) -> LlmStreamEvent {
match event {
StreamingEvent::OutputTextDelta { delta, .. } => LlmStreamEvent::TextDelta(delta),
StreamingEvent::ReasoningDelta { delta, .. } => LlmStreamEvent::ThinkingDelta(delta),
StreamingEvent::ReasoningTextDelta { delta, .. } => LlmStreamEvent::ThinkingDelta(delta),
StreamingEvent::ReasoningSummaryDelta { delta, .. } => LlmStreamEvent::ThinkingDelta(delta),
StreamingEvent::FunctionCallArgumentsDelta { item_id, delta, .. } => {
let mut acc = accumulated_tool_calls.lock().unwrap();
if let Some(tc) = acc.iter_mut().find(|t| t.id == item_id) {
tc.arguments.push_str(&delta);
} else {
acc.push(ToolCallAccumulator {
id: item_id,
call_id: String::new(),
name: String::new(),
arguments: delta,
});
}
LlmStreamEvent::TextDelta(String::new())
}
StreamingEvent::OutputItemAdded { item, .. } => {
if let Some(types::OutputItem::FunctionCall {
id, call_id, name, ..
}) = item
{
let mut acc = accumulated_tool_calls.lock().unwrap();
if let Some(tc) = acc.iter_mut().find(|t| t.id == id) {
tc.name = name;
tc.call_id = call_id;
} else {
acc.push(ToolCallAccumulator {
id,
call_id,
name,
arguments: String::new(),
});
}
}
LlmStreamEvent::TextDelta(String::new())
}
StreamingEvent::OutputItemDone { item, .. } => {
match item {
Some(types::OutputItem::FunctionCall { .. }) => {
let acc = accumulated_tool_calls.lock().unwrap();
if !acc.is_empty() {
let tool_calls: Vec<ToolCall> = acc
.iter()
.filter(|tc| !tc.name.is_empty())
.map(|tc| {
let arguments: Value =
serde_json::from_str(&tc.arguments).unwrap_or(json!({}));
ToolCall {
id: tc.call_id.clone(),
name: tc.name.clone(),
arguments,
}
})
.collect();
if !tool_calls.is_empty() {
*finish_reason.lock().unwrap() = Some("tool_calls".to_string());
return LlmStreamEvent::ToolCalls(tool_calls);
}
}
LlmStreamEvent::TextDelta(String::new())
}
Some(types::OutputItem::Reasoning {
id,
summary,
content: _, encrypted_content,
}) => {
let safe_summary: Vec<String> = summary
.into_iter()
.filter_map(|part| match part {
types::ContentPart::SummaryText { text } => Some(text),
_ => None,
})
.collect();
tracing::debug!(
encrypted_len = encrypted_content.as_ref().map(|s| s.len()).unwrap_or(0),
summary_segments = safe_summary.len(),
"OpenResponses: received reasoning item"
);
LlmStreamEvent::ReasonItem {
provider: "openai".to_string(),
model: Some(model.clone()),
item_id: id,
encrypted_content,
summary: safe_summary,
token_count: None,
}
}
_ => LlmStreamEvent::TextDelta(String::new()),
}
}
StreamingEvent::ResponseCompleted { response, .. } => {
if let Some(usage) = &response.usage {
*input_tokens.lock().unwrap() = usage.input_tokens;
*output_tokens.lock().unwrap() = usage.output_tokens;
if let Some(details) = &usage.input_tokens_details {
*cache_read_tokens.lock().unwrap() = Some(details.cached_tokens);
}
}
let reason = match response.status {
types::ResponseStatus::Completed => {
let existing = finish_reason.lock().unwrap().clone();
existing.unwrap_or_else(|| "stop".to_string())
}
types::ResponseStatus::Failed => {
tracing::warn!(
response_id = %response.id,
error = ?response.error,
"OpenResponsesDriver: response completed with 'failed' status"
);
"error".to_string()
}
types::ResponseStatus::Cancelled => "stop".to_string(),
_ => "stop".to_string(),
};
let phase = response.output.iter().rev().find_map(|item| {
if let types::OutputItem::Message { phase, .. } = item {
phase.clone()
} else {
None
}
});
let input = *input_tokens.lock().unwrap();
let output = *output_tokens.lock().unwrap();
let cached = *cache_read_tokens.lock().unwrap();
let provider_cost_usd = response.usage.as_ref().and_then(|u| u.cost);
LlmStreamEvent::Done(Box::new(LlmCompletionMetadata {
total_tokens: Some(input + output),
prompt_tokens: Some(input),
completion_tokens: Some(output),
cache_read_tokens: cached,
cache_creation_tokens: None,
provider_cost_usd,
model: Some(model),
finish_reason: Some(reason),
retry_metadata: retry_metadata.map(|arc| (*arc).clone()),
response_id: Some(response.id),
phase,
}))
}
StreamingEvent::Error { error, .. } => {
let msg = if let Some(code) = &error.code {
format!("{}: {}", code, error.message)
} else {
error.message.clone()
};
tracing::warn!(
error_code = error.code.as_deref().unwrap_or("none"),
error_message = %error.message,
"OpenResponsesDriver: received streaming error event from provider"
);
LlmStreamEvent::Error(msg)
}
StreamingEvent::RefusalDelta { delta, .. } => {
LlmStreamEvent::Error(format!("Model refused: {}", delta))
}
_ => LlmStreamEvent::TextDelta(String::new()),
}
}
/// Serializable request body for a compaction call.
///
/// Only `model` is always present in the serialized form: `input` is left out
/// when the vector is empty, and the two optional fields are left out when
/// they are `None`.
// NOTE(review): presumably sent to the provider's compaction endpoint; the
// call site is not visible in this part of the file — confirm there.
#[derive(Debug, Clone, Serialize)]
pub struct CompactRequest {
/// Model identifier.
pub model: String,
/// Conversation items to compact; skipped during serialization when empty.
#[serde(skip_serializing_if = "Vec::is_empty")]
pub input: Vec<CompactInputItem>,
/// Optional id of an earlier response; skipped during serialization when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub previous_response_id: Option<String>,
/// Optional instructions text; skipped during serialization when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
}
/// One element of [`CompactRequest::input`].
///
/// The enum is internally tagged: every serialized object carries a `type`
/// field whose value is the per-variant `rename` string below.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum CompactInputItem {
/// A conversation message (`"type": "message"`).
#[serde(rename = "message")]
Message {
/// Role name as a plain string (e.g. `"user"`, `"assistant"`, `"developer"`).
role: String,
/// Message body, either plain text or a list of parts.
content: CompactContent,
},
/// A tool invocation requested by the assistant (`"type": "function_call"`).
#[serde(rename = "function_call")]
FunctionCall {
/// Identifier tying this call to its later output.
call_id: String,
/// Name of the tool being called.
name: String,
/// Call arguments as a string (JSON text when built by `from_llm_message`).
arguments: String,
},
/// The result of a tool invocation (`"type": "function_call_output"`).
#[serde(rename = "function_call_output")]
FunctionCallOutput { call_id: String, output: String },
/// An opaque compaction item (`"type": "compaction"`) carrying encrypted content.
#[serde(rename = "compaction")]
Compaction { encrypted_content: String },
}
/// Body of a compact message.
///
/// Untagged: serializes either as a bare JSON string (`Text`) or as a JSON
/// array of parts (`Parts`), with no wrapper object.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CompactContent {
/// Plain-text content.
Text(String),
/// Structured content made of typed parts.
Parts(Vec<CompactContentPart>),
}
/// A single typed part inside [`CompactContent::Parts`].
///
/// Internally tagged via a `type` field (`"input_text"` or `"input_image"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum CompactContentPart {
/// A text fragment.
#[serde(rename = "input_text")]
InputText { text: String },
/// An image, referenced by URL string.
#[serde(rename = "input_image")]
InputImage { image_url: String },
}
/// Deserialized response of a compaction call.
#[derive(Debug, Clone, Deserialize)]
pub struct CompactResponse {
/// Items returned by the provider (messages and/or compaction items).
pub output: Vec<CompactOutputItem>,
/// Token usage, when the provider reports it.
pub usage: Option<CompactUsage>,
}
/// One element of [`CompactResponse::output`].
///
/// Internally tagged via a `type` field (`"message"` or `"compaction"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum CompactOutputItem {
/// A conversation message returned by the provider.
#[serde(rename = "message")]
Message {
/// Role name as a plain string.
role: String,
/// Message body, either plain text or a list of parts.
content: CompactContent,
},
/// An opaque compaction item carrying encrypted content.
#[serde(rename = "compaction")]
Compaction {
encrypted_content: String,
},
}
/// Token counts attached to a [`CompactResponse`]; every counter is optional.
#[derive(Debug, Clone, Deserialize)]
pub struct CompactUsage {
/// Input token count, if reported.
pub input_tokens: Option<u32>,
/// Output token count, if reported.
pub output_tokens: Option<u32>,
/// Total token count, if reported.
pub total_tokens: Option<u32>,
}
impl CompactInputItem {
    /// Converts one [`LlmMessage`] into zero or more compact input items.
    ///
    /// * A tool message that carries a `tool_call_id` becomes exactly one
    ///   `FunctionCallOutput`; only its text parts are kept, concatenated
    ///   without a separator.
    /// * Any other message yields a `Message` item when it has content or has
    ///   no `tool_calls` at all, followed (for assistant messages only) by
    ///   one `FunctionCall` item per tool call.
    pub fn from_llm_message(msg: &LlmMessage) -> Vec<Self> {
        // Tool results with a known call id short-circuit to a single item.
        if msg.role == LlmMessageRole::Tool {
            if let Some(call_id) = &msg.tool_call_id {
                let output = match &msg.content {
                    LlmMessageContent::Text(text) => text.clone(),
                    LlmMessageContent::Parts(parts) => {
                        let mut joined = String::new();
                        for part in parts {
                            if let LlmContentPart::Text { text } = part {
                                joined.push_str(text);
                            }
                        }
                        joined
                    }
                };
                return vec![CompactInputItem::FunctionCallOutput {
                    call_id: call_id.clone(),
                    output,
                }];
            }
        }

        let mut converted = Vec::new();

        let content = Self::content_from_llm_message(msg);
        let content_is_empty = match &content {
            CompactContent::Text(text) => text.is_empty(),
            CompactContent::Parts(parts) => parts.is_empty(),
        };
        // An empty message is only dropped when tool calls are attached to it.
        if !content_is_empty || msg.tool_calls.is_none() {
            let wire_role = match msg.role {
                LlmMessageRole::System => "developer",
                LlmMessageRole::User => "user",
                LlmMessageRole::Assistant => "assistant",
                LlmMessageRole::Tool => "tool",
            };
            converted.push(CompactInputItem::Message {
                role: wire_role.to_string(),
                content,
            });
        }

        // Tool calls are only replayed for assistant messages.
        if msg.role == LlmMessageRole::Assistant {
            if let Some(calls) = &msg.tool_calls {
                converted.extend(calls.iter().map(|call| CompactInputItem::FunctionCall {
                    call_id: call.id.clone(),
                    name: call.name.clone(),
                    arguments: call.arguments.to_string(),
                }));
            }
        }

        converted
    }

    /// Maps the message content to its compact form.
    ///
    /// Text and image parts are carried over, audio parts are discarded, and a
    /// part list that ends up holding exactly one text part is flattened to
    /// plain `Text`.
    fn content_from_llm_message(msg: &LlmMessage) -> CompactContent {
        let parts = match &msg.content {
            LlmMessageContent::Text(text) => return CompactContent::Text(text.clone()),
            LlmMessageContent::Parts(parts) => parts,
        };

        let mut kept: Vec<CompactContentPart> = Vec::new();
        for part in parts {
            match part {
                LlmContentPart::Text { text } => {
                    kept.push(CompactContentPart::InputText { text: text.clone() });
                }
                LlmContentPart::Image { url } => {
                    kept.push(CompactContentPart::InputImage {
                        image_url: url.clone(),
                    });
                }
                // Audio has no compact representation here.
                LlmContentPart::Audio { .. } => {}
            }
        }

        if let [CompactContentPart::InputText { text }] = kept.as_slice() {
            return CompactContent::Text(text.clone());
        }
        CompactContent::Parts(kept)
    }
}
impl CompactOutputItem {
    /// Turns a `Message` output item into an [`LlmMessage`].
    ///
    /// Returns `None` for `Compaction` items. Role strings other than
    /// `"assistant"`, `"developer"`, `"system"` and `"tool"` map to the user
    /// role. The produced message never carries tool calls, a tool call id,
    /// a phase, or thinking data.
    pub fn to_llm_message(&self) -> Option<LlmMessage> {
        let CompactOutputItem::Message { role, content } = self else {
            return None;
        };

        let mapped_role = match role.as_str() {
            "assistant" => LlmMessageRole::Assistant,
            "developer" | "system" => LlmMessageRole::System,
            "tool" => LlmMessageRole::Tool,
            // "user" and every unrecognised role.
            _ => LlmMessageRole::User,
        };

        let mapped_content = match content {
            CompactContent::Text(text) => LlmMessageContent::Text(text.clone()),
            CompactContent::Parts(parts) => {
                let mut converted: Vec<LlmContentPart> = Vec::new();
                for part in parts {
                    converted.push(match part {
                        CompactContentPart::InputText { text } => {
                            LlmContentPart::Text { text: text.clone() }
                        }
                        CompactContentPart::InputImage { image_url } => LlmContentPart::Image {
                            url: image_url.clone(),
                        },
                    });
                }
                LlmMessageContent::Parts(converted)
            }
        };

        Some(LlmMessage {
            role: mapped_role,
            content: mapped_content,
            tool_calls: None,
            tool_call_id: None,
            phase: None,
            thinking: None,
            thinking_signature: None,
        })
    }
}
/// Flattens a message slice into compact input items, preserving order.
///
/// Each message may contribute zero, one, or several items (see
/// `CompactInputItem::from_llm_message`).
pub fn messages_to_compact_input(messages: &[LlmMessage]) -> Vec<CompactInputItem> {
    let mut flattened = Vec::new();
    for message in messages {
        flattened.extend(CompactInputItem::from_llm_message(message));
    }
    flattened
}
/// Splits compaction output into chat messages and items to carry forward.
///
/// Returns `(messages, compaction_items)`:
/// * `Message` items that convert via `to_llm_message` go to `messages`;
/// * `Compaction` items are copied into `compaction_items` as input items;
/// * a `Message` that fails to convert is copied into `compaction_items`
///   unchanged, so nothing from the provider output is lost.
pub fn compact_output_to_messages(
    output: &[CompactOutputItem],
) -> (Vec<LlmMessage>, Vec<CompactInputItem>) {
    let mut chat_messages = Vec::new();
    let mut carried_items = Vec::new();

    for entry in output {
        match entry {
            CompactOutputItem::Compaction { encrypted_content } => {
                carried_items.push(CompactInputItem::Compaction {
                    encrypted_content: encrypted_content.clone(),
                });
            }
            CompactOutputItem::Message { role, content } => match entry.to_llm_message() {
                Some(converted) => chat_messages.push(converted),
                None => carried_items.push(CompactInputItem::Message {
                    role: role.clone(),
                    content: content.clone(),
                }),
            },
        }
    }

    (chat_messages, carried_items)
}
/// Builds the JSON `plugins` array from the plugin configuration.
///
/// The `web` plugin (when configured) comes first as `{"id": "web", ...}` with
/// `max_results` / `search_prompt` included only when set; the `file` plugin
/// (when configured) follows as `{"id": "file"}`. Returns `None` when neither
/// plugin is configured.
fn plugins_to_wire(
    config: &crate::llm_driver_registry::OpenRouterPluginConfig,
) -> Option<Vec<Value>> {
    let web_plugin = config.web.as_ref().map(|web| {
        let mut fields = serde_json::Map::new();
        fields.insert("id".to_string(), json!("web"));
        if let Some(limit) = web.max_results {
            fields.insert("max_results".to_string(), json!(limit));
        }
        if let Some(prompt) = &web.search_prompt {
            fields.insert("search_prompt".to_string(), json!(prompt));
        }
        Value::Object(fields)
    });
    let file_plugin = config.file.is_some().then(|| json!({"id": "file"}));

    let wired: Vec<Value> = web_plugin.into_iter().chain(file_plugin).collect();
    if wired.is_empty() {
        return None;
    }
    Some(wired)
}
/// Serializable request body for the Responses endpoint.
///
/// `model`, `input` and `stream` are always serialized; every `Option` field
/// is left out of the JSON when it is `None`.
#[derive(Debug, Serialize)]
struct ResponsesRequest {
/// Primary model identifier.
model: String,
/// Optional list of model identifiers.
// NOTE(review): together with `route`/`provider` this looks like
// OpenRouter-specific routing (the types are named `OpenRouter*`) — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
models: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
route: Option<crate::llm_driver_registry::OpenRouterRoute>,
#[serde(skip_serializing_if = "Option::is_none")]
provider: Option<OpenRouterProviderRouting>,
/// Conversation items sent with this request.
input: Vec<ResponsesInputItem>,
/// Optional instructions text.
#[serde(skip_serializing_if = "Option::is_none")]
instructions: Option<String>,
/// Optional id of the response this request continues from.
#[serde(skip_serializing_if = "Option::is_none")]
previous_response_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
max_output_tokens: Option<u32>,
/// Whether a streamed response is requested.
stream: bool,
/// Optional tool definitions offered to the model.
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<ResponsesTool>>,
/// Optional reasoning settings.
#[serde(skip_serializing_if = "Option::is_none")]
reasoning: Option<ResponsesReasoning>,
/// Optional string-to-string metadata attached to the request.
#[serde(skip_serializing_if = "Option::is_none")]
metadata: Option<std::collections::HashMap<String, String>>,
/// Optional prompt cache key.
#[serde(skip_serializing_if = "Option::is_none")]
prompt_cache_key: Option<String>,
/// Optional plugin descriptors as raw JSON values.
#[serde(skip_serializing_if = "Option::is_none")]
plugins: Option<Vec<serde_json::Value>>,
}
/// Reasoning settings serialized under the request's `reasoning` key.
#[derive(Debug, Serialize)]
struct ResponsesReasoning {
/// Reasoning effort level as a string (the tests use `"high"`).
effort: String,
/// Reasoning summary mode as a string (the tests use `"detailed"`).
summary: String,
}
/// One element of [`ResponsesRequest`]'s `input` array.
///
/// The enum is untagged, so serde adds no discriminator of its own: the
/// `type` key on the wire comes from each variant's explicit `r#type` field,
/// which the code constructing the item must fill in (e.g. `"message"`,
/// `"function_call"`, `"function_call_output"`, `"reasoning"`).
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum ResponsesInputItem {
/// A conversation message.
Message {
r#type: String,
role: String,
content: ResponsesContent,
/// Optional phase marker; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
phase: Option<String>,
},
/// A tool invocation, with its arguments as a string.
FunctionCall {
r#type: String,
call_id: String,
name: String,
arguments: String,
},
/// The output of the tool invocation identified by `call_id`.
FunctionCallOutput {
r#type: String,
call_id: String,
output: String,
},
/// A reasoning item identified by `id`, carrying encrypted content.
Reasoning {
r#type: String,
id: String,
encrypted_content: String,
},
}
/// Body of a Responses message.
///
/// Untagged: serializes as a bare JSON string (`Text`) or as a JSON array of
/// parts (`Parts`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum ResponsesContent {
/// Plain-text content.
Text(String),
/// Structured content made of typed parts.
Parts(Vec<ResponsesContentPart>),
}
/// A single part inside [`ResponsesContent::Parts`].
///
/// Untagged: the `type` key on the wire is each variant's explicit `r#type`
/// field (the tests use `"input_text"` and `"input_image"`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
// Every variant shares the `Input` prefix on purpose, hence the lint allowance.
#[allow(clippy::enum_variant_names)]
enum ResponsesContentPart {
/// A text fragment.
InputText {
r#type: String,
text: String,
},
/// An image given as a URL string (the tests use a `data:` URL).
InputImage {
r#type: String,
image_url: String,
},
/// An audio clip described by [`ResponsesInputAudio`].
InputAudio {
r#type: String,
input_audio: ResponsesInputAudio,
},
}
/// Audio payload carried by `ResponsesContentPart::InputAudio`.
#[derive(Debug, Serialize, Deserialize)]
struct ResponsesInputAudio {
/// Audio data as a string.
// NOTE(review): presumably base64-encoded — confirm against the producer.
data: String,
/// Audio format name as a string.
format: String,
}
/// A tool entry in [`ResponsesRequest`]'s `tools` array.
///
/// Untagged: the `type` key on the wire is each variant's explicit `r#type`
/// field (the tests use `"function"` for `Function`).
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum ResponsesTool {
/// A callable function with a JSON value describing its parameters.
Function {
r#type: String,
name: String,
description: String,
parameters: Value,
/// Optional flag; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
defer_loading: Option<bool>,
},
/// A named group containing nested tools.
Namespace {
r#type: String,
name: String,
description: String,
tools: Vec<ResponsesTool>,
},
/// A tool entry that consists of its `type` only.
ToolSearch { r#type: String },
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_driver_with_api_key() {
    // The Debug output of a freshly built driver names the driver type.
    let debug_repr = format!("{:?}", OpenResponsesProtocolChatDriver::new("test-key"));
    assert!(debug_repr.contains("OpenResponsesProtocolChatDriver"));
}
#[test]
fn test_driver_with_base_url() {
    // A custom base URL is stored verbatim and exposed through `api_url()`.
    let custom_url = "https://custom.api.com/v1/responses";
    let driver = OpenResponsesProtocolChatDriver::with_base_url("test-key", custom_url);
    let debug_repr = format!("{:?}", driver);
    assert!(debug_repr.contains("OpenResponsesProtocolChatDriver"));
    assert_eq!(driver.api_url(), custom_url);
}
#[test]
fn test_request_serialization() {
    // Minimal streaming request: one user message plus instructions.
    let user_turn = ResponsesInputItem::Message {
        r#type: "message".to_string(),
        role: "user".to_string(),
        content: ResponsesContent::Text("Hello".to_string()),
        phase: None,
    };
    let request = ResponsesRequest {
        model: "gpt-4o".to_string(),
        models: None,
        route: None,
        provider: None,
        input: vec![user_turn],
        instructions: Some("You are helpful".to_string()),
        previous_response_id: None,
        temperature: None,
        max_output_tokens: None,
        stream: true,
        tools: None,
        reasoning: None,
        metadata: None,
        prompt_cache_key: None,
        plugins: None,
    };

    let wire = serde_json::to_value(&request).unwrap();

    assert_eq!(wire["model"], "gpt-4o");
    assert_eq!(wire["stream"], true);
    assert_eq!(wire["instructions"], "You are helpful");
    assert!(wire["input"].is_array());
}
#[test]
fn test_request_with_reasoning() {
    // Reasoning settings must serialize under the `reasoning` key.
    let reasoning = ResponsesReasoning {
        effort: "high".to_string(),
        summary: "detailed".to_string(),
    };
    let user_turn = ResponsesInputItem::Message {
        r#type: "message".to_string(),
        role: "user".to_string(),
        content: ResponsesContent::Text("Think about this".to_string()),
        phase: None,
    };
    let request = ResponsesRequest {
        model: "o3".to_string(),
        models: None,
        route: None,
        provider: None,
        input: vec![user_turn],
        instructions: None,
        previous_response_id: None,
        temperature: None,
        max_output_tokens: None,
        stream: true,
        tools: None,
        reasoning: Some(reasoning),
        metadata: None,
        prompt_cache_key: None,
        plugins: None,
    };

    let wire = serde_json::to_value(&request).unwrap();

    assert_eq!(wire["reasoning"]["effort"], "high");
    assert_eq!(wire["reasoning"]["summary"], "detailed");
}
#[test]
fn test_request_with_metadata() {
    // Metadata entries must appear as-is under the `metadata` key.
    let metadata = std::collections::HashMap::from([
        ("session_id".to_string(), "session_abc123".to_string()),
        ("agent_id".to_string(), "agent_xyz789".to_string()),
    ]);
    let user_turn = ResponsesInputItem::Message {
        r#type: "message".to_string(),
        role: "user".to_string(),
        content: ResponsesContent::Text("Hello".to_string()),
        phase: None,
    };
    let request = ResponsesRequest {
        model: "gpt-4o".to_string(),
        models: None,
        route: None,
        provider: None,
        input: vec![user_turn],
        instructions: None,
        previous_response_id: None,
        temperature: None,
        max_output_tokens: None,
        stream: true,
        tools: None,
        reasoning: None,
        metadata: Some(metadata),
        prompt_cache_key: None,
        plugins: None,
    };

    let wire = serde_json::to_value(&request).unwrap();

    assert_eq!(wire["metadata"]["session_id"], "session_abc123");
    assert_eq!(wire["metadata"]["agent_id"], "agent_xyz789");
}
#[test]
fn test_build_prompt_cache_key_when_enabled() {
    // With prompt caching enabled a key is produced and carries the prefix.
    let cache_settings = crate::llm_driver_registry::PromptCacheConfig {
        enabled: true,
        strategy: crate::llm_driver_registry::PromptCacheStrategy::Auto,
        gemini_cached_content: None,
    };
    let config = LlmCallConfig {
        model: "gpt-5.4".to_string(),
        temperature: None,
        max_tokens: None,
        tools: vec![],
        reasoning_effort: None,
        metadata: std::collections::HashMap::from([(
            "session_id".to_string(),
            "session_abc123".to_string(),
        )]),
        previous_response_id: None,
        tool_search: None,
        prompt_cache: Some(cache_settings),
        openrouter_routing: None,
    };
    let input = vec![ResponsesInputItem::Message {
        r#type: "message".to_string(),
        role: "user".to_string(),
        content: ResponsesContent::Text("Hello".to_string()),
        phase: None,
    }];
    let instructions = Some("You are helpful".to_string());

    let cache_key = OpenResponsesProtocolChatDriver::build_prompt_cache_key(
        &config,
        &input,
        &instructions,
        &None,
    );

    assert!(cache_key.is_some());
    assert!(cache_key.unwrap().starts_with("everruns:"));
}
#[test]
fn test_build_prompt_cache_key_ignores_changing_input() {
    // Two turns that differ only in user text must share one cache key.
    let config = LlmCallConfig {
        model: "gpt-5.4".to_string(),
        temperature: None,
        max_tokens: None,
        tools: vec![],
        reasoning_effort: None,
        metadata: std::collections::HashMap::from([(
            "session_id".to_string(),
            "session_abc123".to_string(),
        )]),
        previous_response_id: None,
        tool_search: None,
        prompt_cache: Some(crate::llm_driver_registry::PromptCacheConfig {
            enabled: true,
            strategy: crate::llm_driver_registry::PromptCacheStrategy::Auto,
            gemini_cached_content: None,
        }),
        openrouter_routing: None,
    };
    let user_turn = |text: &str| {
        vec![ResponsesInputItem::Message {
            r#type: "message".to_string(),
            role: "user".to_string(),
            content: ResponsesContent::Text(text.to_string()),
            phase: None,
        }]
    };
    let first_input = user_turn("first turn");
    let second_input = user_turn("second turn with different text");
    let instructions = Some("You are helpful".to_string());

    let first = OpenResponsesProtocolChatDriver::build_prompt_cache_key(
        &config,
        &first_input,
        &instructions,
        &None,
    );
    let second = OpenResponsesProtocolChatDriver::build_prompt_cache_key(
        &config,
        &second_input,
        &instructions,
        &None,
    );

    assert_eq!(first, second);
}
#[test]
fn test_build_prompt_cache_key_changes_with_cache_family() {
    // Same input, different `session_id` metadata: the keys must differ.
    let config_for_session = |session_id: &str| LlmCallConfig {
        model: "gpt-5.4".to_string(),
        temperature: None,
        max_tokens: None,
        tools: vec![],
        reasoning_effort: None,
        metadata: std::collections::HashMap::from([(
            "session_id".to_string(),
            session_id.to_string(),
        )]),
        previous_response_id: None,
        tool_search: None,
        prompt_cache: Some(crate::llm_driver_registry::PromptCacheConfig {
            enabled: true,
            strategy: crate::llm_driver_registry::PromptCacheStrategy::Auto,
            gemini_cached_content: None,
        }),
        openrouter_routing: None,
    };
    let input = vec![ResponsesInputItem::Message {
        r#type: "message".to_string(),
        role: "user".to_string(),
        content: ResponsesContent::Text("same turn".to_string()),
        phase: None,
    }];
    let instructions = Some("You are helpful".to_string());

    let first = OpenResponsesProtocolChatDriver::build_prompt_cache_key(
        &config_for_session("session_abc123"),
        &input,
        &instructions,
        &None,
    );
    let second = OpenResponsesProtocolChatDriver::build_prompt_cache_key(
        &config_for_session("session_xyz789"),
        &input,
        &instructions,
        &None,
    );

    assert_ne!(first, second);
}
#[test]
fn test_build_prompt_cache_key_stays_within_openai_limit() {
    // The generated key must never exceed 64 characters.
    let cache_settings = crate::llm_driver_registry::PromptCacheConfig {
        enabled: true,
        strategy: crate::llm_driver_registry::PromptCacheStrategy::Auto,
        gemini_cached_content: None,
    };
    let config = LlmCallConfig {
        model: "gpt-5.5".to_string(),
        temperature: None,
        max_tokens: None,
        tools: vec![],
        reasoning_effort: None,
        metadata: std::collections::HashMap::new(),
        previous_response_id: None,
        tool_search: None,
        prompt_cache: Some(cache_settings),
        openrouter_routing: None,
    };
    let input = vec![ResponsesInputItem::Message {
        r#type: "message".to_string(),
        role: "user".to_string(),
        content: ResponsesContent::Text("fetch chalyi.name for me".to_string()),
        phase: None,
    }];
    let instructions = Some("You are helpful".to_string());

    let key = OpenResponsesProtocolChatDriver::build_prompt_cache_key(
        &config,
        &input,
        &instructions,
        &None,
    )
    .unwrap();

    assert!(
        key.len() <= 64,
        "OpenAI prompt_cache_key limit is 64 characters, got {}",
        key.len()
    );
}
#[test]
fn test_function_call_output_serialization() {
    // The output payload stays a string on the wire, even when it holds JSON text.
    let payload = r#"{"result": 42}"#;
    let wire = serde_json::to_value(&ResponsesInputItem::FunctionCallOutput {
        r#type: "function_call_output".to_string(),
        call_id: "call_123".to_string(),
        output: payload.to_string(),
    })
    .unwrap();

    assert_eq!(wire["type"], "function_call_output");
    assert_eq!(wire["call_id"], "call_123");
    assert_eq!(wire["output"], payload);
}
#[test]
fn test_multipart_content_serialization() {
    // Multi-part content serializes as an array that keeps part order and types.
    let text_part = ResponsesContentPart::InputText {
        r#type: "input_text".to_string(),
        text: "Look at this image".to_string(),
    };
    let image_part = ResponsesContentPart::InputImage {
        r#type: "input_image".to_string(),
        image_url: "data:image/png;base64,abc123".to_string(),
    };

    let wire = serde_json::to_value(&ResponsesContent::Parts(vec![text_part, image_part])).unwrap();

    assert!(wire.is_array());
    assert_eq!(wire[0]["type"], "input_text");
    assert_eq!(wire[1]["type"], "input_image");
}
#[test]
fn test_tool_serialization() {
    // A function tool exposes type, name, and its parameter schema on the wire.
    let schema = json!({
        "type": "object",
        "properties": {
            "location": {"type": "string"}
        },
        "required": ["location"]
    });
    let weather_tool = ResponsesTool::Function {
        r#type: "function".to_string(),
        name: "get_weather".to_string(),
        description: "Get weather for a location".to_string(),
        parameters: schema,
        defer_loading: None,
    };

    let wire = serde_json::to_value(&weather_tool).unwrap();

    assert_eq!(wire["type"], "function");
    assert_eq!(wire["name"], "get_weather");
    assert!(wire["parameters"]["properties"]["location"].is_object());
}
#[test]
fn test_build_input_extracts_system_as_instructions() {
    // The system message is lifted into `instructions`; only the user turn stays in input.
    let system_text = "You are a helpful assistant";
    let transcript = vec![
        LlmMessage::text(LlmMessageRole::System, system_text),
        LlmMessage::text(LlmMessageRole::User, "Hello"),
    ];

    let (instructions, input) = OpenResponsesProtocolChatDriver::build_input(&transcript, false);

    assert_eq!(instructions, Some(system_text.to_string()));
    assert_eq!(input.len(), 1);
}
#[test]
fn test_convert_role() {
    // Each internal role and the wire name it must map to.
    let expectations = [
        (LlmMessageRole::System, "developer"),
        (LlmMessageRole::User, "user"),
        (LlmMessageRole::Assistant, "assistant"),
        (LlmMessageRole::Tool, "tool"),
    ];
    for (role, wire_name) in &expectations {
        assert_eq!(
            OpenResponsesProtocolChatDriver::convert_role(role),
            *wire_name
        );
    }
}
#[test]
fn test_function_call_serialization() {
    // Arguments are carried as a JSON-encoded string, not as a nested object.
    let raw_args = r#"{"timezone":"UTC"}"#;
    let wire = serde_json::to_value(&ResponsesInputItem::FunctionCall {
        r#type: "function_call".to_string(),
        call_id: "call_abc123".to_string(),
        name: "get_current_time".to_string(),
        arguments: raw_args.to_string(),
    })
    .unwrap();

    assert_eq!(wire["type"], "function_call");
    assert_eq!(wire["call_id"], "call_abc123");
    assert_eq!(wire["name"], "get_current_time");
    assert_eq!(wire["arguments"], raw_args);
}
#[test]
fn test_build_input_with_tool_calls() {
    use crate::tool_types::ToolCall;

    // Assistant turn with empty text and one tool call, then the tool's result.
    let assistant_call = LlmMessage {
        role: LlmMessageRole::Assistant,
        content: LlmMessageContent::Text(String::new()),
        tool_calls: Some(vec![ToolCall {
            id: "call_xyz789".to_string(),
            name: "get_current_time".to_string(),
            arguments: json!({"timezone": "UTC"}),
        }]),
        tool_call_id: None,
        phase: None,
        thinking: None,
        thinking_signature: None,
    };
    let tool_result = LlmMessage {
        role: LlmMessageRole::Tool,
        content: LlmMessageContent::Text("2025-01-19T10:30:00Z".to_string()),
        tool_calls: None,
        tool_call_id: Some("call_xyz789".to_string()),
        phase: None,
        thinking: None,
        thinking_signature: None,
    };
    let transcript = vec![
        LlmMessage::text(LlmMessageRole::System, "You are helpful"),
        LlmMessage::text(LlmMessageRole::User, "What time is it?"),
        assistant_call,
        tool_result,
    ];

    let (instructions, input) = OpenResponsesProtocolChatDriver::build_input(&transcript, false);

    assert_eq!(instructions, Some("You are helpful".to_string()));
    assert_eq!(input.len(), 3);

    let call_wire = serde_json::to_value(&input[1]).unwrap();
    assert_eq!(call_wire["type"], "function_call");
    assert_eq!(call_wire["call_id"], "call_xyz789");
    assert_eq!(call_wire["name"], "get_current_time");

    let output_wire = serde_json::to_value(&input[2]).unwrap();
    assert_eq!(output_wire["type"], "function_call_output");
    assert_eq!(output_wire["call_id"], "call_xyz789");
    assert_eq!(output_wire["output"], "2025-01-19T10:30:00Z");
}
#[test]
fn test_build_input_with_tool_calls_and_text() {
    use crate::tool_types::ToolCall;

    // An assistant turn with both text and a tool call yields a message item
    // followed by a function_call item.
    let assistant_turn = LlmMessage {
        role: LlmMessageRole::Assistant,
        content: LlmMessageContent::Text("Let me check the time for you.".to_string()),
        tool_calls: Some(vec![ToolCall {
            id: "call_abc".to_string(),
            name: "get_time".to_string(),
            arguments: json!({}),
        }]),
        tool_call_id: None,
        phase: None,
        thinking: None,
        thinking_signature: None,
    };
    let transcript = vec![
        LlmMessage::text(LlmMessageRole::User, "What time is it?"),
        assistant_turn,
    ];

    let (_, input) = OpenResponsesProtocolChatDriver::build_input(&transcript, false);

    assert_eq!(input.len(), 3);
    let user_wire = serde_json::to_value(&input[0]).unwrap();
    assert_eq!(user_wire["role"], "user");
    let assistant_wire = serde_json::to_value(&input[1]).unwrap();
    assert_eq!(assistant_wire["role"], "assistant");
    let call_wire = serde_json::to_value(&input[2]).unwrap();
    assert_eq!(call_wire["type"], "function_call");
    assert_eq!(call_wire["call_id"], "call_abc");
}
#[test]
fn openresponses_requests_should_not_mix_previous_response_id_with_full_transcript() {
    use crate::tool_types::ToolCall;

    // Full transcript: system, user, assistant (text + tool call), tool result.
    let assistant_turn = LlmMessage {
        role: LlmMessageRole::Assistant,
        content: LlmMessageContent::Text("Let me check.".to_string()),
        tool_calls: Some(vec![ToolCall {
            id: "call_xyz789".to_string(),
            name: "get_current_time".to_string(),
            arguments: json!({"timezone": "UTC"}),
        }]),
        tool_call_id: None,
        phase: None,
        thinking: None,
        thinking_signature: None,
    };
    let tool_result = LlmMessage {
        role: LlmMessageRole::Tool,
        content: LlmMessageContent::Text("2025-01-19T10:30:00Z".to_string()),
        tool_calls: None,
        tool_call_id: Some("call_xyz789".to_string()),
        phase: None,
        thinking: None,
        thinking_signature: None,
    };
    let transcript = vec![
        LlmMessage::text(LlmMessageRole::System, "You are helpful"),
        LlmMessage::text(LlmMessageRole::User, "What time is it?"),
        assistant_turn,
        tool_result,
    ];

    let (instructions, full_input) =
        OpenResponsesProtocolChatDriver::build_input(&transcript, false);
    assert!(
        full_input.len() > 1,
        "sanity: full transcript has multi items"
    );

    // Only the tool result that follows the last assistant turn may survive.
    let delta = compute_delta_input_items(full_input);
    assert_eq!(
        delta.len(),
        1,
        "stateful continuation must only send delta items; got {} items",
        delta.len()
    );
    let delta_wire = serde_json::to_value(&delta[0]).unwrap();
    assert_eq!(delta_wire["type"], "function_call_output");
    assert_eq!(delta_wire["call_id"], "call_xyz789");
    assert_eq!(delta_wire["output"], "2025-01-19T10:30:00Z");
    assert_eq!(instructions, Some("You are helpful".to_string()));
}
#[test]
fn compute_delta_keeps_tail_after_assistant_message() {
    // user -> assistant -> user: only the trailing user message is the delta.
    let message = |role: &str, text: &str| ResponsesInputItem::Message {
        r#type: "message".to_string(),
        role: role.to_string(),
        content: ResponsesContent::Text(text.to_string()),
        phase: None,
    };
    let transcript = vec![
        message("user", "hi"),
        message("assistant", "hello"),
        message("user", "follow up"),
    ];

    let delta = compute_delta_input_items(transcript);

    assert_eq!(delta.len(), 1);
    let wire = serde_json::to_value(&delta[0]).unwrap();
    assert_eq!(wire["role"], "user");
    assert_eq!(
        wire["content"], "follow up",
        "trim keeps the fresh user message that arrived after the assistant turn"
    );
}
#[test]
fn compute_delta_keeps_tool_results_after_last_assistant_turn() {
    // Two parallel tool calls: both outputs (and nothing else) form the delta.
    let message = |role: &str, text: &str| ResponsesInputItem::Message {
        r#type: "message".to_string(),
        role: role.to_string(),
        content: ResponsesContent::Text(text.to_string()),
        phase: None,
    };
    let call = |call_id: &str, tool: &str| ResponsesInputItem::FunctionCall {
        r#type: "function_call".to_string(),
        call_id: call_id.to_string(),
        name: tool.to_string(),
        arguments: "{}".to_string(),
    };
    let call_output = |call_id: &str, result: &str| ResponsesInputItem::FunctionCallOutput {
        r#type: "function_call_output".to_string(),
        call_id: call_id.to_string(),
        output: result.to_string(),
    };
    let transcript = vec![
        message("user", "do two things"),
        message("assistant", "ok"),
        call("call_a", "tool_a"),
        call("call_b", "tool_b"),
        call_output("call_a", "a result"),
        call_output("call_b", "b result"),
    ];

    let delta = compute_delta_input_items(transcript);

    assert_eq!(delta.len(), 2);
    for kept in &delta {
        let wire = serde_json::to_value(kept).unwrap();
        assert_eq!(wire["type"], "function_call_output");
    }
}
#[test]
fn compute_delta_allows_empty_input_for_stateful_continuation() {
    // Trimming an empty item list yields an empty list.
    let no_items: Vec<ResponsesInputItem> = Vec::new();
    assert!(compute_delta_input_items(no_items).is_empty());
}
#[test]
fn compute_delta_keeps_all_items_when_no_assistant_turn_present() {
    // Without any assistant turn there is nothing to trim against.
    let user_message = |text: &str| ResponsesInputItem::Message {
        r#type: "message".to_string(),
        role: "user".to_string(),
        content: ResponsesContent::Text(text.to_string()),
        phase: None,
    };
    let transcript = vec![user_message("one"), user_message("two")];

    let delta = compute_delta_input_items(transcript);

    assert_eq!(delta.len(), 2);
}
#[test]
fn compute_delta_drops_prior_reasoning_items() {
    // A reasoning item that precedes the assistant turn is not part of the delta.
    let prior_reasoning = ResponsesInputItem::Reasoning {
        r#type: "reasoning".to_string(),
        id: "rs_00000001".to_string(),
        encrypted_content: "encrypted-blob".to_string(),
    };
    let prior_assistant = ResponsesInputItem::Message {
        r#type: "message".to_string(),
        role: "assistant".to_string(),
        content: ResponsesContent::Text("prior".to_string()),
        phase: None,
    };
    let fresh_output = ResponsesInputItem::FunctionCallOutput {
        r#type: "function_call_output".to_string(),
        call_id: "call_z".to_string(),
        output: "result".to_string(),
    };

    let delta = compute_delta_input_items(vec![prior_reasoning, prior_assistant, fresh_output]);

    assert_eq!(delta.len(), 1);
    let wire = serde_json::to_value(&delta[0]).unwrap();
    assert_eq!(wire["type"], "function_call_output");
}
/// Test fixture: a three-item transcript (user, assistant, user follow-up).
fn sample_full_transcript_items() -> Vec<ResponsesInputItem> {
    [
        ("user", "first request"),
        ("assistant", "first reply"),
        ("user", "follow-up"),
    ]
    .into_iter()
    .map(|(role, text)| ResponsesInputItem::Message {
        r#type: "message".to_string(),
        role: role.to_string(),
        content: ResponsesContent::Text(text.to_string()),
        phase: None,
    })
    .collect()
}
#[test]
fn finalize_input_skips_trim_when_previous_response_id_is_none() {
    // No continuation handle: the transcript passes through at full length.
    let transcript = sample_full_transcript_items();
    let expected_len = transcript.len();

    let finalized = finalize_input_for_request(transcript, &None);

    assert_eq!(
        finalized.len(),
        expected_len,
        "stateless mode keeps the full transcript so the model has context"
    );
}
#[test]
fn finalize_input_drops_locally_orphaned_tool_output_without_previous_response_id() {
    // A tool output with no matching call in the input is removed in stateless mode.
    let fresh_message = ResponsesInputItem::Message {
        r#type: "message".to_string(),
        role: "user".to_string(),
        content: ResponsesContent::Text("fresh".to_string()),
        phase: None,
    };
    let orphaned_output = ResponsesInputItem::FunctionCallOutput {
        r#type: "function_call_output".to_string(),
        call_id: "call_trimmed".to_string(),
        output: "result".to_string(),
    };

    let finalized = finalize_input_for_request(vec![fresh_message, orphaned_output], &None);

    assert_eq!(finalized.len(), 1);
    let wire = serde_json::to_value(&finalized[0]).unwrap();
    assert_eq!(wire["type"], "message");
}
#[test]
fn finalize_input_keeps_tool_output_with_previous_response_id_even_without_local_call() {
    // With a continuation handle, a tool output is kept even though its call
    // is absent from the local input.
    let server_side_output = ResponsesInputItem::FunctionCallOutput {
        r#type: "function_call_output".to_string(),
        call_id: "call_server_side".to_string(),
        output: "stateful result".to_string(),
    };
    let follow_up = ResponsesInputItem::Message {
        r#type: "message".to_string(),
        role: "user".to_string(),
        content: ResponsesContent::Text("follow-up".to_string()),
        phase: None,
    };
    let previous_id = Some("resp_prev_42".to_string());

    let finalized = finalize_input_for_request(vec![server_side_output, follow_up], &previous_id);

    assert_eq!(finalized.len(), 2);
    let wire = serde_json::to_value(&finalized[0]).unwrap();
    assert_eq!(wire["type"], "function_call_output");
    assert_eq!(wire["call_id"], "call_server_side");
}
#[test]
fn finalize_input_trims_when_previous_response_id_is_set() {
    // With a continuation handle only the trailing user follow-up remains.
    let previous_id = Some("resp_prev_42".to_string());

    let finalized = finalize_input_for_request(sample_full_transcript_items(), &previous_id);

    assert_eq!(
        finalized.len(),
        1,
        "stateful continuation must drop everything up to and including the prior assistant message"
    );
    let wire = serde_json::to_value(&finalized[0]).unwrap();
    assert_eq!(wire["type"], "message");
    assert_eq!(wire["role"], "user");
    let text = wire["content"].as_str().unwrap_or("");
    assert_eq!(text, "follow-up");
}
#[test]
fn finalize_input_allows_empty_input_with_previous_response_id() {
    // An empty input list stays empty when a continuation handle is present.
    let previous_id = Some("resp_anything".to_string());
    let finalized = finalize_input_for_request(vec![], &previous_id);
    assert!(
        finalized.is_empty(),
        "empty delta is valid — the provider can resume purely from the response id"
    );
}
#[test]
fn endpoint_persists_responses_for_openai_and_azure() {
    // OpenAI (with and without explicit port) and both Azure host families.
    let persisting_endpoints = [
        "https://api.openai.com/v1/responses",
        "https://api.openai.com:443/v1/responses",
        "https://my-resource.openai.azure.com/openai/v1/responses",
        "https://my-resource.services.ai.azure.com/openai/v1/responses",
    ];
    for url in persisting_endpoints {
        assert!(endpoint_persists_responses(url));
    }
}
#[test]
fn endpoint_does_not_persist_for_stateless_gateways() {
    // Gateways and a look-alike host that must not be treated as persisting.
    let stateless_endpoints = [
        "https://openrouter.ai/api/v1/responses",
        "https://generativelanguage.googleapis.com/v1beta/openai/responses",
        "https://api.openai.example.com/v1/responses",
    ];
    for url in stateless_endpoints {
        assert!(!endpoint_persists_responses(url));
    }
}
#[test]
fn stateless_gateway_replays_full_transcript_despite_previous_response_id() {
    let api_url = "https://openrouter.ai/api/v1/responses";
    let prev_id: Option<String> = Some("gen-turn-1".to_string());

    // The handle is only honoured when the endpoint persists responses.
    let effective_prev_id = prev_id
        .clone()
        .filter(|_| endpoint_persists_responses(api_url));
    assert!(
        effective_prev_id.is_none(),
        "stateless gateway must not chain via previous_response_id"
    );

    let transcript = sample_full_transcript_items();
    let expected_len = transcript.len();
    let finalized = finalize_input_for_request(transcript, &effective_prev_id);
    assert_eq!(
        finalized.len(),
        expected_len,
        "stateless gateway must replay the full transcript so the model keeps context"
    );
}
#[test]
fn stateful_endpoint_still_trims_and_chains() {
    let api_url = "https://api.openai.com/v1/responses";
    let prev_id: Option<String> = Some("resp_turn_1".to_string());

    // A persisting endpoint keeps the handle, so the input is trimmed to the delta.
    let effective_prev_id = prev_id
        .clone()
        .filter(|_| endpoint_persists_responses(api_url));
    assert_eq!(
        effective_prev_id, prev_id,
        "stateful endpoint keeps the continuation handle"
    );

    let finalized = finalize_input_for_request(sample_full_transcript_items(), &effective_prev_id);
    assert_eq!(
        finalized.len(),
        1,
        "stateful endpoint trims to the delta window"
    );
}
/// Sends a request with `previous_response_id` set through a mock HTTP server
/// and inspects the recorded body: the id must be absent, the system prompt
/// must arrive as `instructions`, and `input` must hold four items including
/// the user message and the tool output.
#[tokio::test]
async fn stateless_gateway_request_replays_full_transcript_on_the_wire() {
    use crate::tool_types::ToolCall;
    use serde_json::json;
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // The mock answers every POST with an empty 200; only the recorded request matters.
    let mock_server = MockServer::start().await;
    Mock::given(method("POST"))
        .respond_with(ResponseTemplate::new(200).set_body_string(""))
        .mount(&mock_server)
        .await;
    let driver = OpenResponsesProtocolChatDriver::with_base_url(
        "test-key",
        format!("{}/v1/responses", mock_server.uri()),
    );

    let assistant_turn = LlmMessage {
        role: LlmMessageRole::Assistant,
        content: LlmMessageContent::Text("Let me look.".to_string()),
        tool_calls: Some(vec![ToolCall {
            id: "call_1".to_string(),
            name: "read_file".to_string(),
            arguments: json!({"path": "Cargo.toml"}),
        }]),
        tool_call_id: None,
        phase: None,
        thinking: None,
        thinking_signature: None,
    };
    let tool_result = LlmMessage {
        role: LlmMessageRole::Tool,
        content: LlmMessageContent::Text("[package]…".to_string()),
        tool_calls: None,
        tool_call_id: Some("call_1".to_string()),
        phase: None,
        thinking: None,
        thinking_signature: None,
    };
    let transcript = vec![
        LlmMessage::text(LlmMessageRole::System, "You are helpful"),
        LlmMessage::text(LlmMessageRole::User, "upgrade dependencies"),
        assistant_turn,
        tool_result,
    ];
    let call_config = LlmCallConfig {
        model: "some/model".to_string(),
        temperature: None,
        max_tokens: None,
        tools: vec![],
        reasoning_effort: None,
        metadata: std::collections::HashMap::new(),
        previous_response_id: Some("gen-turn-1".to_string()),
        tool_search: None,
        prompt_cache: None,
        openrouter_routing: None,
    };

    // The call's outcome is ignored; the assertions below look at the recorded request only.
    let _ = driver
        .chat_completion_stream(transcript, &call_config)
        .await;

    let recorded = mock_server
        .received_requests()
        .await
        .expect("mock server recorded requests");
    assert_eq!(recorded.len(), 1, "exactly one request should be sent");
    let body: serde_json::Value = recorded[0].body_json().expect("request body is JSON");

    assert!(
        body.get("previous_response_id").is_none(),
        "stateless gateway request must omit previous_response_id; body: {body}"
    );
    let input = body["input"].as_array().expect("input is an array");
    assert_eq!(
        input.len(),
        4,
        "full transcript must be replayed on a stateless gateway; got {input:?}"
    );
    assert_eq!(body["instructions"], "You are helpful");
    assert!(
        input
            .iter()
            .any(|entry| entry["type"] == "message" && entry["role"] == "user"),
        "the original user task must be replayed; got {input:?}"
    );
    assert!(
        input
            .iter()
            .any(|entry| entry["type"] == "function_call_output"),
        "the latest tool result must still be present; got {input:?}"
    );
}
/// With the provider type set to OpenRouter and tool search enabled over 16
/// tools, the recorded request must contain only plain `function` tools with
/// no `defer_loading` key, and the single user message as `input`.
#[tokio::test]
async fn openrouter_provider_does_not_send_hosted_tool_search() {
    use crate::tool_types::DeferrablePolicy;
    use serde_json::json;
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // The mock answers every POST with an empty 200; only the recorded request matters.
    let mock_server = MockServer::start().await;
    Mock::given(method("POST"))
        .respond_with(ResponseTemplate::new(200).set_body_string(""))
        .mount(&mock_server)
        .await;
    let endpoint = format!("{}/v1/responses", mock_server.uri());
    let driver = OpenResponsesProtocolChatDriver::with_base_url("test-key", endpoint)
        .with_provider_type(DriverId::OpenRouter);

    // Sixteen tools against a configured threshold of fifteen.
    let mut offered_tools: Vec<ToolDefinition> = Vec::with_capacity(16);
    for i in 0..16 {
        offered_tools.push(make_tool(
            &format!("tool_{i}"),
            Some("General"),
            DeferrablePolicy::Automatic,
        ));
    }
    let call_config = LlmCallConfig {
        model: "gpt-5.4".to_string(),
        temperature: None,
        max_tokens: None,
        tools: offered_tools,
        reasoning_effort: None,
        metadata: std::collections::HashMap::new(),
        previous_response_id: None,
        tool_search: Some(crate::llm_driver_registry::ToolSearchConfig {
            enabled: true,
            threshold: 15,
        }),
        prompt_cache: None,
        openrouter_routing: None,
    };
    let prompt = vec![LlmMessage::text(LlmMessageRole::User, "hello")];

    // The call's outcome is ignored; the assertions below look at the recorded request only.
    let _ = driver.chat_completion_stream(prompt, &call_config).await;

    let recorded = mock_server
        .received_requests()
        .await
        .expect("mock server recorded requests");
    assert_eq!(recorded.len(), 1, "exactly one request should be sent");
    let body: serde_json::Value = recorded[0].body_json().expect("request body is JSON");
    let tools = body["tools"].as_array().expect("tools is an array");
    assert!(
        tools.iter().all(|tool| tool["type"] == "function"),
        "OpenRouter should receive regular function tools, not hosted tool_search payloads: {tools:?}"
    );
    assert!(
        tools.iter().all(|tool| tool.get("defer_loading").is_none()),
        "OpenRouter tool schemas should not be deferred by hosted tool_search: {tools:?}"
    );
    let expected_input = json!([{"type": "message", "role": "user", "content": "hello"}]);
    assert_eq!(body["input"], expected_input);
}
/// With the provider type set to OpenRouter, a populated routing config must
/// appear in the recorded request body as the `models`, `route`, and
/// `provider` keys with the exact JSON shapes asserted below.
#[tokio::test]
async fn openrouter_provider_sends_routing_controls() {
    use crate::llm_driver_registry::{
        OpenRouterDataCollection, OpenRouterMaxPrice, OpenRouterProviderSort,
        OpenRouterProviderSortBy, OpenRouterProviderSortOptions, OpenRouterRoute,
        OpenRouterRoutingConfig, OpenRouterSortPartition,
    };
    use serde_json::json;
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // The mock answers every POST with an empty 200; only the recorded request matters.
    let mock_server = MockServer::start().await;
    Mock::given(method("POST"))
        .respond_with(ResponseTemplate::new(200).set_body_string(""))
        .mount(&mock_server)
        .await;
    let endpoint = format!("{}/v1/responses", mock_server.uri());
    let driver = OpenResponsesProtocolChatDriver::with_base_url("test-key", endpoint)
        .with_provider_type(DriverId::OpenRouter);

    let provider_routing = OpenRouterProviderRouting {
        order: vec!["openai".to_string()],
        allow_fallbacks: Some(false),
        require_parameters: Some(true),
        data_collection: Some(OpenRouterDataCollection::Deny),
        zdr: Some(true),
        sort: Some(OpenRouterProviderSort::Advanced(
            OpenRouterProviderSortOptions {
                by: OpenRouterProviderSortBy::Latency,
                partition: Some(OpenRouterSortPartition::None),
            },
        )),
        max_price: Some(OpenRouterMaxPrice {
            prompt: Some(1.0),
            completion: Some(2.0),
            ..Default::default()
        }),
        ..Default::default()
    };
    let routing = OpenRouterRoutingConfig {
        models: vec![
            "openai/gpt-5-mini".to_string(),
            "anthropic/claude-sonnet-4.5".to_string(),
        ],
        route: Some(OpenRouterRoute::Fallback),
        provider: Some(provider_routing),
        ..Default::default()
    };
    let call_config = LlmCallConfig {
        model: "openai/gpt-5-mini".to_string(),
        temperature: None,
        max_tokens: None,
        tools: vec![],
        reasoning_effort: None,
        metadata: std::collections::HashMap::new(),
        previous_response_id: None,
        tool_search: None,
        prompt_cache: None,
        openrouter_routing: Some(routing),
    };
    let prompt = vec![LlmMessage::text(LlmMessageRole::User, "hello")];

    // The call's outcome is ignored; the assertions below look at the recorded request only.
    let _ = driver.chat_completion_stream(prompt, &call_config).await;

    let recorded = mock_server
        .received_requests()
        .await
        .expect("mock server recorded requests");
    assert_eq!(recorded.len(), 1, "exactly one request should be sent");
    let body: serde_json::Value = recorded[0].body_json().expect("request body is JSON");

    let expected_models = json!(["openai/gpt-5-mini", "anthropic/claude-sonnet-4.5"]);
    assert_eq!(body["models"], expected_models);
    assert_eq!(body["route"], "fallback");
    let expected_provider = json!({
        "order": ["openai"],
        "allow_fallbacks": false,
        "require_parameters": true,
        "data_collection": "deny",
        "zdr": true,
        "sort": {
            "by": "latency",
            "partition": "none"
        },
        "max_price": {
            "prompt": 1.0,
            "completion": 2.0
        }
    });
    assert_eq!(body["provider"], expected_provider);
}
/// A driver built without `with_provider_type` must not emit the `models`,
/// `route`, or `provider` keys even when the call config carries an
/// OpenRouter routing section.
#[tokio::test]
async fn openai_provider_omits_openrouter_routing_controls() {
    use crate::llm_driver_registry::{OpenRouterRoute, OpenRouterRoutingConfig};
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // The mock answers every POST with an empty 200; only the recorded request matters.
    let mock_server = MockServer::start().await;
    Mock::given(method("POST"))
        .respond_with(ResponseTemplate::new(200).set_body_string(""))
        .mount(&mock_server)
        .await;
    let driver = OpenResponsesProtocolChatDriver::with_base_url(
        "test-key",
        format!("{}/v1/responses", mock_server.uri()),
    );

    let routing = OpenRouterRoutingConfig {
        models: vec!["openai/gpt-5-mini".to_string()],
        route: Some(OpenRouterRoute::Fallback),
        provider: None,
        ..Default::default()
    };
    let call_config = LlmCallConfig {
        model: "gpt-5-mini".to_string(),
        temperature: None,
        max_tokens: None,
        tools: vec![],
        reasoning_effort: None,
        metadata: std::collections::HashMap::new(),
        previous_response_id: None,
        tool_search: None,
        prompt_cache: None,
        openrouter_routing: Some(routing),
    };
    let prompt = vec![LlmMessage::text(LlmMessageRole::User, "hello")];

    // The call's outcome is ignored; the assertions below look at the recorded request only.
    let _ = driver.chat_completion_stream(prompt, &call_config).await;

    let recorded = mock_server
        .received_requests()
        .await
        .expect("mock server recorded requests");
    assert_eq!(recorded.len(), 1, "exactly one request should be sent");
    let body: serde_json::Value = recorded[0].body_json().expect("request body is JSON");
    assert!(body.get("models").is_none(), "body: {body}");
    assert!(body.get("route").is_none(), "body: {body}");
    assert!(body.get("provider").is_none(), "body: {body}");
}
/// Two invalid routing configs must each produce an error and no HTTP request:
/// a `models` list whose first entry differs from the primary model, and a
/// fallback route with an empty `models` list.
#[tokio::test]
async fn openrouter_provider_rejects_invalid_routing_controls() {
    use crate::llm_driver_registry::{OpenRouterRoute, OpenRouterRoutingConfig};
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    let mock_server = MockServer::start().await;
    Mock::given(method("POST"))
        .respond_with(ResponseTemplate::new(200).set_body_string(""))
        .mount(&mock_server)
        .await;
    let endpoint = format!("{}/v1/responses", mock_server.uri());
    let driver = OpenResponsesProtocolChatDriver::with_base_url("test-key", endpoint)
        .with_provider_type(DriverId::OpenRouter);

    // Both cases share every field except the routing `models` list.
    let config_with_models = |models: Vec<String>| LlmCallConfig {
        model: "openai/gpt-5-mini".to_string(),
        temperature: None,
        max_tokens: None,
        tools: vec![],
        reasoning_effort: None,
        metadata: std::collections::HashMap::new(),
        previous_response_id: None,
        tool_search: None,
        prompt_cache: None,
        openrouter_routing: Some(OpenRouterRoutingConfig {
            models,
            route: Some(OpenRouterRoute::Fallback),
            provider: None,
            ..Default::default()
        }),
    };

    // Case 1: the first routed model does not match the primary model.
    let mismatch_config = config_with_models(vec!["anthropic/claude-sonnet-4.5".to_string()]);
    let mismatch_err = driver
        .chat_completion_stream(
            vec![LlmMessage::text(LlmMessageRole::User, "hello")],
            &mismatch_config,
        )
        .await
        .err()
        .unwrap_or_else(|| panic!("invalid OpenRouter routing should fail before dispatch"));
    assert!(mismatch_err.to_string().contains("models[0]"));

    // Case 2: fallback routing requested with no models at all.
    let empty_fallback_config = config_with_models(vec![]);
    let empty_err = driver
        .chat_completion_stream(
            vec![LlmMessage::text(LlmMessageRole::User, "hello")],
            &empty_fallback_config,
        )
        .await
        .err()
        .unwrap_or_else(|| panic!("empty OpenRouter fallback routing should fail before dispatch"));
    assert!(empty_err.to_string().contains("requires at least one model"));

    let recorded = mock_server
        .received_requests()
        .await
        .expect("mock server recorded requests");
    assert!(
        recorded.is_empty(),
        "invalid routing must be rejected before request dispatch"
    );
}
/// Serializes a two-message `CompactRequest` and checks the `model`,
/// `instructions`, and `input` entries of the resulting JSON.
#[test]
fn test_compact_request_serialization() {
    let user_turn = CompactInputItem::Message {
        role: "user".to_string(),
        content: CompactContent::Text("Hello!".to_string()),
    };
    let assistant_turn = CompactInputItem::Message {
        role: "assistant".to_string(),
        content: CompactContent::Text("Hi there!".to_string()),
    };
    let request = CompactRequest {
        model: "gpt-4o".to_string(),
        input: vec![user_turn, assistant_turn],
        previous_response_id: None,
        instructions: Some("Be helpful".to_string()),
    };

    let serialized = serde_json::to_value(&request).unwrap();
    assert_eq!(serialized["model"], "gpt-4o");
    assert_eq!(serialized["instructions"], "Be helpful");
    assert!(serialized["input"].is_array());
    assert_eq!(serialized["input"].as_array().unwrap().len(), 2);
}
/// A `CompactInputItem::Message` must serialize with `type`, `role`, and a
/// plain-string `content`.
#[test]
fn test_compact_input_item_message_serialization() {
    let serialized = serde_json::to_value(&CompactInputItem::Message {
        role: "user".to_string(),
        content: CompactContent::Text("Test message".to_string()),
    })
    .unwrap();
    assert_eq!(serialized["type"], "message");
    assert_eq!(serialized["role"], "user");
    assert_eq!(serialized["content"], "Test message");
}
/// A `CompactInputItem::FunctionCall` must serialize its call id, name, and
/// the argument string unchanged under a `function_call` type tag.
#[test]
fn test_compact_input_item_function_call_serialization() {
    let raw_arguments = r#"{"city":"NYC"}"#;
    let call = CompactInputItem::FunctionCall {
        call_id: "call_123".to_string(),
        name: "get_weather".to_string(),
        arguments: raw_arguments.to_string(),
    };
    let serialized = serde_json::to_value(&call).unwrap();
    assert_eq!(serialized["type"], "function_call");
    assert_eq!(serialized["call_id"], "call_123");
    assert_eq!(serialized["name"], "get_weather");
    assert_eq!(serialized["arguments"], raw_arguments);
}
/// A `CompactInputItem::Compaction` must serialize with a `compaction` type
/// tag and its `encrypted_content` unchanged.
#[test]
fn test_compact_input_item_compaction_serialization() {
    let serialized = serde_json::to_value(&CompactInputItem::Compaction {
        encrypted_content: "encrypted_data_here".to_string(),
    })
    .unwrap();
    assert_eq!(serialized["type"], "compaction");
    assert_eq!(serialized["encrypted_content"], "encrypted_data_here");
}
/// A JSON object tagged `message` must deserialize into
/// `CompactOutputItem::Message` carrying text content.
#[test]
fn test_compact_output_item_deserialization() {
    let payload = r#"{
        "type": "message",
        "role": "user",
        "content": "Hello"
    }"#;
    let parsed: CompactOutputItem = serde_json::from_str(payload).unwrap();

    let CompactOutputItem::Message { role, content } = parsed else {
        panic!("Expected Message item");
    };
    assert_eq!(role, "user");
    let CompactContent::Text(text) = content else {
        panic!("Expected text content");
    };
    assert_eq!(text, "Hello");
}
/// A JSON object tagged `compaction` must deserialize into
/// `CompactOutputItem::Compaction` with its encrypted payload intact.
#[test]
fn test_compact_output_compaction_deserialization() {
    let payload = r#"{
        "type": "compaction",
        "encrypted_content": "abc123encrypted"
    }"#;
    let parsed: CompactOutputItem = serde_json::from_str(payload).unwrap();

    let CompactOutputItem::Compaction { encrypted_content } = parsed else {
        panic!("Expected Compaction item");
    };
    assert_eq!(encrypted_content, "abc123encrypted");
}
/// A compact response with two output items and a usage object must
/// deserialize with both items present and all three token counts populated.
#[test]
fn test_compact_response_deserialization() {
    let payload = r#"{
        "output": [
            {"type": "message", "role": "user", "content": "Hello"},
            {"type": "compaction", "encrypted_content": "xyz789"}
        ],
        "usage": {
            "input_tokens": 100,
            "output_tokens": 50,
            "total_tokens": 150
        }
    }"#;
    let parsed: CompactResponse = serde_json::from_str(payload).unwrap();
    assert_eq!(parsed.output.len(), 2);
    assert!(parsed.usage.is_some());

    let token_usage = parsed.usage.unwrap();
    assert_eq!(token_usage.input_tokens, Some(100));
    assert_eq!(token_usage.output_tokens, Some(50));
    assert_eq!(token_usage.total_tokens, Some(150));
}
/// `CompactContent::Parts` must serialize as a JSON array whose elements carry
/// the `input_text` / `input_image` type tags in order.
#[test]
fn test_compact_content_parts_serialization() {
    let text_part = CompactContentPart::InputText {
        text: "Check this image".to_string(),
    };
    let image_part = CompactContentPart::InputImage {
        image_url: "data:image/png;base64,abc".to_string(),
    };
    let parts = CompactContent::Parts(vec![text_part, image_part]);

    let serialized = serde_json::to_value(&parts).unwrap();
    assert!(serialized.is_array());
    assert_eq!(serialized[0]["type"], "input_text");
    assert_eq!(serialized[0]["text"], "Check this image");
    assert_eq!(serialized[1]["type"], "input_image");
}
/// A driver created with `new` (default API URL) must report compact support.
#[test]
fn test_supports_compact_default_url() {
    let default_driver = OpenResponsesProtocolChatDriver::new("test-key");
    let supported = default_driver.supports_compact();
    assert!(supported);
}
/// A driver pointed at a custom base URL must not report compact support.
#[test]
fn test_supports_compact_custom_url() {
    let custom_url = "https://custom.api.com/v1/responses";
    let custom_driver = OpenResponsesProtocolChatDriver::with_base_url("test-key", custom_url);
    let supported = custom_driver.supports_compact();
    assert!(!supported);
}
/// A `ResponsesInputItem::Reasoning` must serialize its `type`, `id`, and
/// `encrypted_content` fields unchanged.
#[test]
fn test_reasoning_input_item_serialization() {
    let reasoning_item = ResponsesInputItem::Reasoning {
        r#type: "reasoning".to_string(),
        id: "rs_00000001".to_string(),
        encrypted_content: "encrypted_reasoning_context_here".to_string(),
    };
    let serialized = serde_json::to_value(&reasoning_item).unwrap();
    assert_eq!(serialized["type"], "reasoning");
    assert_eq!(serialized["id"], "rs_00000001");
    assert_eq!(
        serialized["encrypted_content"],
        "encrypted_reasoning_context_here"
    );
}
/// An assistant message carrying a `thinking_signature` must expand into a
/// reasoning item followed by the assistant message, giving four input items
/// for this three-message conversation.
#[test]
fn test_build_input_with_thinking_signature() {
    let assistant_with_reasoning = LlmMessage {
        role: LlmMessageRole::Assistant,
        content: LlmMessageContent::Text("I have thought about this.".to_string()),
        tool_calls: None,
        tool_call_id: None,
        phase: None,
        thinking: Some("This is my chain of thought reasoning...".to_string()),
        thinking_signature: Some("encrypted_reasoning_token_123".to_string()),
    };
    let conversation = vec![
        LlmMessage::text(LlmMessageRole::User, "Think about this deeply"),
        assistant_with_reasoning,
        LlmMessage::text(LlmMessageRole::User, "What else?"),
    ];

    let (_, input) = OpenResponsesProtocolChatDriver::build_input(&conversation, false);
    assert_eq!(input.len(), 4);

    let first_user = serde_json::to_value(&input[0]).unwrap();
    assert_eq!(first_user["role"], "user");
    assert_eq!(first_user["content"], "Think about this deeply");

    let reasoning = serde_json::to_value(&input[1]).unwrap();
    assert_eq!(reasoning["type"], "reasoning");
    assert_eq!(reasoning["encrypted_content"], "encrypted_reasoning_token_123");

    let assistant = serde_json::to_value(&input[2]).unwrap();
    assert_eq!(assistant["role"], "assistant");
    assert_eq!(assistant["content"], "I have thought about this.");

    let second_user = serde_json::to_value(&input[3]).unwrap();
    assert_eq!(second_user["role"], "user");
}
/// An assistant message with both a `thinking_signature` and a tool call,
/// followed by the tool result, must yield five input items in the order:
/// user, reasoning, assistant, function_call, function_call_output.
#[test]
fn test_build_input_with_thinking_signature_and_tool_calls() {
    use crate::tool_types::ToolCall;

    let assistant_with_call = LlmMessage {
        role: LlmMessageRole::Assistant,
        content: LlmMessageContent::Text("Let me check.".to_string()),
        tool_calls: Some(vec![ToolCall {
            id: "call_123".to_string(),
            name: "get_time".to_string(),
            arguments: json!({}),
        }]),
        tool_call_id: None,
        phase: None,
        thinking: Some("I need to call the get_time tool...".to_string()),
        thinking_signature: Some("encrypted_token_xyz".to_string()),
    };
    let tool_reply = LlmMessage {
        role: LlmMessageRole::Tool,
        content: LlmMessageContent::Text("10:30 AM".to_string()),
        tool_calls: None,
        tool_call_id: Some("call_123".to_string()),
        phase: None,
        thinking: None,
        thinking_signature: None,
    };
    let conversation = vec![
        LlmMessage::text(LlmMessageRole::User, "What time is it? Think carefully."),
        assistant_with_call,
        tool_reply,
    ];

    let (_, input) = OpenResponsesProtocolChatDriver::build_input(&conversation, false);
    assert_eq!(input.len(), 5);

    let reasoning = serde_json::to_value(&input[1]).unwrap();
    assert_eq!(reasoning["type"], "reasoning");
    assert_eq!(reasoning["encrypted_content"], "encrypted_token_xyz");

    let assistant = serde_json::to_value(&input[2]).unwrap();
    assert_eq!(assistant["role"], "assistant");

    let call = serde_json::to_value(&input[3]).unwrap();
    assert_eq!(call["type"], "function_call");
    assert_eq!(call["call_id"], "call_123");

    let call_output = serde_json::to_value(&input[4]).unwrap();
    assert_eq!(call_output["type"], "function_call_output");
}
/// Thinking text without a `thinking_signature` must not add a reasoning
/// item: the two messages map to exactly two input items.
#[test]
fn test_build_input_without_thinking_signature() {
    let assistant_unsigned = LlmMessage {
        role: LlmMessageRole::Assistant,
        content: LlmMessageContent::Text("Hi there!".to_string()),
        tool_calls: None,
        tool_call_id: None,
        phase: None,
        thinking: Some("Some thinking...".to_string()),
        thinking_signature: None,
    };
    let conversation = vec![
        LlmMessage::text(LlmMessageRole::User, "Hello"),
        assistant_unsigned,
    ];

    let (_, input) = OpenResponsesProtocolChatDriver::build_input(&conversation, false);
    assert_eq!(input.len(), 2);

    let user = serde_json::to_value(&input[0]).unwrap();
    assert_eq!(user["role"], "user");
    let assistant = serde_json::to_value(&input[1]).unwrap();
    assert_eq!(assistant["role"], "assistant");
}
/// A completed reasoning output item with encrypted content must be turned
/// into `LlmStreamEvent::ReasonItem` carrying the provider, model, item id,
/// and encrypted payload, with an empty summary and no token count.
#[test]
fn test_handle_streaming_event_reasoning_encrypted_content() {
    use std::sync::Mutex;

    // Fresh accumulator state passed by reference to the handler.
    let input_tokens = Mutex::new(0u32);
    let output_tokens = Mutex::new(0u32);
    let cache_read_tokens = Mutex::new(None);
    let accumulated_tool_calls = Mutex::new(Vec::new());
    let finish_reason = Mutex::new(None);

    let reasoning_item = types::OutputItem::Reasoning {
        id: "rs_001".to_string(),
        summary: vec![],
        content: None,
        encrypted_content: Some("encrypted_reasoning_data".to_string()),
    };
    let item_done = StreamingEvent::OutputItemDone {
        sequence_number: 5,
        output_index: 0,
        item: Some(reasoning_item),
    };

    let emitted = handle_streaming_event(
        item_done,
        &input_tokens,
        &output_tokens,
        &cache_read_tokens,
        &accumulated_tool_calls,
        &finish_reason,
        "gpt-5".to_string(),
        None,
    );

    match emitted {
        LlmStreamEvent::ReasonItem {
            provider,
            model,
            item_id,
            encrypted_content,
            summary,
            token_count,
        } => {
            assert_eq!(provider, "openai");
            assert_eq!(model.as_deref(), Some("gpt-5"));
            assert_eq!(item_id, "rs_001");
            assert_eq!(
                encrypted_content.as_deref(),
                Some("encrypted_reasoning_data")
            );
            assert!(summary.is_empty());
            assert!(token_count.is_none());
        }
        unexpected => panic!("Expected ReasonItem event, got {unexpected:?}"),
    }
}
/// A completed reasoning item without encrypted content must still produce a
/// `ReasonItem` whose summary holds the summary text and whose encrypted
/// payload is `None`.
#[test]
fn test_handle_streaming_event_reasoning_without_encrypted_content() {
    use std::sync::Mutex;

    // Fresh accumulator state passed by reference to the handler.
    let input_tokens = Mutex::new(0u32);
    let output_tokens = Mutex::new(0u32);
    let cache_read_tokens = Mutex::new(None);
    let accumulated_tool_calls = Mutex::new(Vec::new());
    let finish_reason = Mutex::new(None);

    let reasoning_item = types::OutputItem::Reasoning {
        id: "rs_001".to_string(),
        summary: vec![types::ContentPart::SummaryText {
            text: "Some summary".to_string(),
        }],
        content: None,
        encrypted_content: None,
    };
    let item_done = StreamingEvent::OutputItemDone {
        sequence_number: 5,
        output_index: 0,
        item: Some(reasoning_item),
    };

    let emitted = handle_streaming_event(
        item_done,
        &input_tokens,
        &output_tokens,
        &cache_read_tokens,
        &accumulated_tool_calls,
        &finish_reason,
        "gpt-5".to_string(),
        None,
    );

    match emitted {
        LlmStreamEvent::ReasonItem {
            provider,
            item_id,
            encrypted_content,
            summary,
            ..
        } => {
            assert_eq!(provider, "openai");
            assert_eq!(item_id, "rs_001");
            assert!(encrypted_content.is_none());
            assert_eq!(summary, vec!["Some summary".to_string()]);
        }
        unexpected => panic!("Expected ReasonItem event, got {unexpected:?}"),
    }
}
/// Reasoning text supplied in `summary` and in `content` must not reach the
/// emitted `ReasonItem`: only the `SummaryText` entry survives in `summary`,
/// and the encrypted payload is passed through.
#[test]
fn test_handle_streaming_event_reasoning_drops_plaintext_content() {
    use std::sync::Mutex;

    // Fresh accumulator state passed by reference to the handler.
    let input_tokens = Mutex::new(0u32);
    let output_tokens = Mutex::new(0u32);
    let cache_read_tokens = Mutex::new(None);
    let accumulated_tool_calls = Mutex::new(Vec::new());
    let finish_reason = Mutex::new(None);

    let mixed_summary = vec![
        types::ContentPart::SummaryText {
            text: "safe summary".to_string(),
        },
        types::ContentPart::ReasoningText {
            text: "SECRET hidden reasoning".to_string(),
        },
    ];
    let plaintext_content = vec![types::ContentPart::ReasoningText {
        text: "SECRET hidden reasoning".to_string(),
    }];
    let item_done = StreamingEvent::OutputItemDone {
        sequence_number: 5,
        output_index: 0,
        item: Some(types::OutputItem::Reasoning {
            id: "rs_002".to_string(),
            summary: mixed_summary,
            content: Some(plaintext_content),
            encrypted_content: Some("opaque".to_string()),
        }),
    };

    let emitted = handle_streaming_event(
        item_done,
        &input_tokens,
        &output_tokens,
        &cache_read_tokens,
        &accumulated_tool_calls,
        &finish_reason,
        "gpt-5".to_string(),
        None,
    );

    match emitted {
        LlmStreamEvent::ReasonItem {
            summary,
            encrypted_content,
            ..
        } => {
            assert_eq!(summary, vec!["safe summary".to_string()]);
            assert_eq!(encrypted_content.as_deref(), Some("opaque"));
        }
        unexpected => panic!("Expected ReasonItem event, got {unexpected:?}"),
    }
}
/// A `ReasoningDelta` streaming event must map to
/// `LlmStreamEvent::ThinkingDelta` carrying the delta text unchanged.
#[test]
fn test_handle_streaming_event_reasoning_delta() {
    use std::sync::Mutex;

    // Fresh accumulator state passed by reference to the handler.
    let input_tokens = Mutex::new(0u32);
    let output_tokens = Mutex::new(0u32);
    let cache_read_tokens = Mutex::new(None);
    let accumulated_tool_calls = Mutex::new(Vec::new());
    let finish_reason = Mutex::new(None);

    let reasoning_delta = StreamingEvent::ReasoningDelta {
        sequence_number: 3,
        item_id: "rs_001".to_string(),
        output_index: 0,
        content_index: 0,
        delta: "Let me reason about this...".to_string(),
        obfuscation: None,
    };

    let emitted = handle_streaming_event(
        reasoning_delta,
        &input_tokens,
        &output_tokens,
        &cache_read_tokens,
        &accumulated_tool_calls,
        &finish_reason,
        "o3".to_string(),
        None,
    );

    match emitted {
        LlmStreamEvent::ThinkingDelta(text) => {
            assert_eq!(text, "Let me reason about this...");
        }
        unexpected => panic!("Expected ThinkingDelta, got {unexpected:?}"),
    }
}
/// A `ReasoningSummaryDelta` streaming event must map to
/// `LlmStreamEvent::ThinkingDelta` carrying the delta text unchanged.
#[test]
fn test_handle_streaming_event_reasoning_summary_delta() {
    use std::sync::Mutex;

    // Fresh accumulator state passed by reference to the handler.
    let input_tokens = Mutex::new(0u32);
    let output_tokens = Mutex::new(0u32);
    let cache_read_tokens = Mutex::new(None);
    let accumulated_tool_calls = Mutex::new(Vec::new());
    let finish_reason = Mutex::new(None);

    let summary_delta = StreamingEvent::ReasoningSummaryDelta {
        sequence_number: 4,
        item_id: "rs_002".to_string(),
        output_index: 0,
        summary_index: 0,
        delta: "Breaking down the problem...".to_string(),
        obfuscation: None,
    };

    let emitted = handle_streaming_event(
        summary_delta,
        &input_tokens,
        &output_tokens,
        &cache_read_tokens,
        &accumulated_tool_calls,
        &finish_reason,
        "gpt-5.2".to_string(),
        None,
    );

    match emitted {
        LlmStreamEvent::ThinkingDelta(text) => {
            assert_eq!(text, "Breaking down the problem...");
        }
        unexpected => panic!("Expected ThinkingDelta, got {unexpected:?}"),
    }
}
/// Applies the "skip effort == none" filter to a config whose effort is
/// `"none"` and expects no `ResponsesReasoning` value to be produced.
#[test]
fn test_request_reasoning_none_is_omitted() {
    let call_config = LlmCallConfig {
        model: "gpt-5.2".to_string(),
        temperature: None,
        max_tokens: None,
        tools: vec![],
        reasoning_effort: Some("none".to_string()),
        metadata: std::collections::HashMap::new(),
        previous_response_id: None,
        tool_search: None,
        prompt_cache: None,
        openrouter_routing: None,
    };

    let reasoning = match call_config.reasoning_effort.as_ref() {
        Some(effort) if !effort.eq_ignore_ascii_case("none") => Some(ResponsesReasoning {
            effort: effort.clone(),
            summary: "detailed".to_string(),
        }),
        _ => None,
    };

    assert!(
        reasoning.is_none(),
        "reasoning should be None for effort=none"
    );
}
/// Applies the "skip effort == none" filter to a config whose effort is
/// `"high"` and expects a `ResponsesReasoning` with that effort and a
/// `"detailed"` summary.
#[test]
fn test_request_reasoning_high_is_included() {
    let call_config = LlmCallConfig {
        model: "gpt-5.2".to_string(),
        temperature: None,
        max_tokens: None,
        tools: vec![],
        reasoning_effort: Some("high".to_string()),
        metadata: std::collections::HashMap::new(),
        previous_response_id: None,
        tool_search: None,
        prompt_cache: None,
        openrouter_routing: None,
    };

    let reasoning = match call_config.reasoning_effort.as_ref() {
        Some(effort) if !effort.eq_ignore_ascii_case("none") => Some(ResponsesReasoning {
            effort: effort.clone(),
            summary: "detailed".to_string(),
        }),
        _ => None,
    };

    assert!(
        reasoning.is_some(),
        "reasoning should be present for effort=high"
    );
    let built = reasoning.unwrap();
    assert_eq!(built.effort, "high");
    assert_eq!(built.summary, "detailed");
}
/// The "none" effort filter must ignore ASCII case: every spelling below is
/// filtered out.
#[test]
fn test_request_reasoning_none_case_insensitive() {
    for effort in ["none", "None", "NONE"] {
        let candidate = effort.to_string();
        let kept = Some(&candidate)
            .filter(|value| !value.eq_ignore_ascii_case("none"))
            .cloned();
        assert!(kept.is_none(), "effort={effort:?} should be filtered out");
    }
}
/// A plain assistant message (no thinking, no tool calls) must map to a
/// single assistant input item whose `type` is either absent or `message`.
#[test]
fn test_build_input_assistant_without_thinking_or_tools() {
    let plain_assistant = LlmMessage {
        role: LlmMessageRole::Assistant,
        content: LlmMessageContent::Text("Hi there!".to_string()),
        tool_calls: None,
        tool_call_id: None,
        phase: None,
        thinking: None,
        thinking_signature: None,
    };
    let conversation = vec![
        LlmMessage::text(LlmMessageRole::User, "Hello"),
        plain_assistant,
    ];

    let (_, input) = OpenResponsesProtocolChatDriver::build_input(&conversation, false);
    assert_eq!(input.len(), 2);

    let assistant = serde_json::to_value(&input[1]).unwrap();
    assert_eq!(assistant["role"], "assistant");
    assert!(assistant.get("type").is_none() || assistant["type"] == "message");
}
/// Two assistant turns that each carry a `thinking_signature` must produce
/// two reasoning items (at indices 1 and 4 of six) with different ids and
/// their respective encrypted payloads.
#[test]
fn test_build_input_multiple_reasoning_items_get_unique_ids() {
    let first_answer = LlmMessage {
        role: LlmMessageRole::Assistant,
        content: LlmMessageContent::Text("First answer.".to_string()),
        tool_calls: None,
        tool_call_id: None,
        phase: None,
        thinking: Some("thinking 1".to_string()),
        thinking_signature: Some("encrypted_1".to_string()),
    };
    let second_answer = LlmMessage {
        role: LlmMessageRole::Assistant,
        content: LlmMessageContent::Text("Second answer.".to_string()),
        tool_calls: None,
        tool_call_id: None,
        phase: None,
        thinking: Some("thinking 2".to_string()),
        thinking_signature: Some("encrypted_2".to_string()),
    };
    let conversation = vec![
        LlmMessage::text(LlmMessageRole::User, "First question"),
        first_answer,
        LlmMessage::text(LlmMessageRole::User, "Second question"),
        second_answer,
    ];

    let (_, input) = OpenResponsesProtocolChatDriver::build_input(&conversation, false);
    assert_eq!(input.len(), 6);

    let first_reasoning = serde_json::to_value(&input[1]).unwrap();
    let second_reasoning = serde_json::to_value(&input[4]).unwrap();
    assert_eq!(first_reasoning["type"], "reasoning");
    assert_eq!(second_reasoning["type"], "reasoning");
    assert_ne!(
        first_reasoning["id"], second_reasoning["id"],
        "Reasoning items should have unique IDs"
    );
    assert_eq!(first_reasoning["encrypted_content"], "encrypted_1");
    assert_eq!(second_reasoning["encrypted_content"], "encrypted_2");
}
/// With the second argument of `build_input` set to `true`, the assistant
/// item must carry `phase: "commentary"`; with `false`, the `phase` key must
/// be absent or null.
#[test]
fn test_build_input_with_phases_enabled() {
    use crate::message::ExecutionPhase;

    let assistant_commentary = LlmMessage {
        role: LlmMessageRole::Assistant,
        content: LlmMessageContent::Text("Working on it...".to_string()),
        tool_calls: Some(vec![crate::tool_types::ToolCall {
            id: "call_1".to_string(),
            name: "search".to_string(),
            arguments: json!({}),
        }]),
        tool_call_id: None,
        phase: Some(ExecutionPhase::Commentary),
        thinking: None,
        thinking_signature: None,
    };
    let tool_reply = LlmMessage {
        role: LlmMessageRole::Tool,
        content: LlmMessageContent::Text("result".to_string()),
        tool_calls: None,
        tool_call_id: Some("call_1".to_string()),
        phase: None,
        thinking: None,
        thinking_signature: None,
    };
    let conversation = vec![
        LlmMessage::text(LlmMessageRole::System, "You are helpful"),
        LlmMessage::text(LlmMessageRole::User, "Hello"),
        assistant_commentary,
        tool_reply,
    ];

    // Index 1 is the assistant item in both runs.
    let (_, with_phases) = OpenResponsesProtocolChatDriver::build_input(&conversation, true);
    let assistant_with = serde_json::to_value(&with_phases[1]).unwrap();
    assert_eq!(assistant_with["phase"], "commentary");

    let (_, without_phases) = OpenResponsesProtocolChatDriver::build_input(&conversation, false);
    let assistant_without = serde_json::to_value(&without_phases[1]).unwrap();
    assert!(assistant_without.get("phase").is_none() || assistant_without["phase"].is_null());
}
fn make_tool(
name: &str,
category: Option<&str>,
deferrable: crate::tool_types::DeferrablePolicy,
) -> ToolDefinition {
ToolDefinition::Builtin(crate::tool_types::BuiltinTool {
name: name.to_string(),
display_name: None,
description: format!("{} description", name),
parameters: json!({"type": "object", "properties": {}}),
policy: crate::tool_types::ToolPolicy::Auto,
category: category.map(|s| s.to_string()),
deferrable,
hints: crate::tool_types::ToolHints::default(),
full_parameters: None,
})
}
/// Five tools converted with a second argument of 15 (presumably the
/// tool-count threshold) must come back as five plain `function` tools with
/// no `defer_loading` flag.
#[test]
fn test_convert_tools_with_search_below_threshold_falls_back() {
    use crate::tool_types::DeferrablePolicy;

    let mut tools: Vec<ToolDefinition> = Vec::with_capacity(5);
    for i in 0..5 {
        tools.push(make_tool(
            &format!("tool_{i}"),
            Some("cat"),
            DeferrablePolicy::Automatic,
        ));
    }

    let converted = OpenResponsesProtocolChatDriver::convert_tools_with_search(&tools, 15);
    assert_eq!(converted.len(), 5);

    let serialized = serde_json::to_value(&converted).unwrap();
    for entry in serialized.as_array().unwrap() {
        assert_eq!(entry["type"], "function");
        assert!(entry.get("defer_loading").is_none() || entry["defer_loading"].is_null());
    }
}
/// Sixteen deferrable tools in two categories must convert to two
/// `namespace` entries (10 and 6 deferred tools) followed by a trailing
/// `tool_search` entry.
#[test]
fn test_convert_tools_with_search_groups_by_category() {
    use crate::tool_types::DeferrablePolicy;

    let fs_tools = (0..10).map(|i| {
        make_tool(
            &format!("fs_tool_{i}"),
            Some("FileSystem"),
            DeferrablePolicy::Automatic,
        )
    });
    let weather_tools = (0..6).map(|i| {
        make_tool(
            &format!("weather_tool_{i}"),
            Some("Weather"),
            DeferrablePolicy::Automatic,
        )
    });
    let tools: Vec<ToolDefinition> = fs_tools.chain(weather_tools).collect();

    let converted = OpenResponsesProtocolChatDriver::convert_tools_with_search(&tools, 15);
    let serialized = serde_json::to_value(&converted).unwrap();
    let entries = serialized.as_array().unwrap();
    assert_eq!(entries.len(), 3);
    assert_eq!(entries.last().unwrap()["type"], "tool_search");

    let namespaces: Vec<&Value> = entries
        .iter()
        .filter(|entry| entry["type"] == "namespace")
        .collect();
    assert_eq!(namespaces.len(), 2);
    let namespace_names: Vec<&str> = namespaces
        .iter()
        .map(|entry| entry["name"].as_str().unwrap())
        .collect();
    assert!(namespace_names.contains(&"FileSystem"));
    assert!(namespace_names.contains(&"Weather"));

    for namespace in &namespaces {
        let members = namespace["tools"].as_array().unwrap();
        match namespace["name"].as_str().unwrap() {
            "FileSystem" => assert_eq!(members.len(), 10),
            "Weather" => assert_eq!(members.len(), 6),
            other => panic!("Unexpected namespace: {other}"),
        }
        for member in members {
            assert_eq!(member["defer_loading"], true);
        }
    }
}
/// Tools with `DeferrablePolicy::Never` must stay as top-level, non-deferred
/// `function` entries while the 14 automatic tools are grouped into one
/// `FileSystem` namespace; four entries are expected in total.
#[test]
fn test_convert_tools_with_search_never_defer_stays_top_level() {
    use crate::tool_types::DeferrablePolicy;

    let mut tools = vec![
        make_tool(
            "write_todos",
            Some("Productivity"),
            DeferrablePolicy::Never,
        ),
        make_tool(
            "get_session_info",
            Some("Session"),
            DeferrablePolicy::Never,
        ),
    ];
    tools.extend((0..14).map(|i| {
        make_tool(
            &format!("fs_tool_{i}"),
            Some("FileSystem"),
            DeferrablePolicy::Automatic,
        )
    }));

    let converted = OpenResponsesProtocolChatDriver::convert_tools_with_search(&tools, 15);
    let serialized = serde_json::to_value(&converted).unwrap();
    let entries = serialized.as_array().unwrap();
    assert_eq!(entries.len(), 4);

    let functions: Vec<&Value> = entries
        .iter()
        .filter(|entry| entry["type"] == "function")
        .collect();
    assert_eq!(functions.len(), 2);
    for function in &functions {
        assert!(function.get("defer_loading").is_none() || function["defer_loading"].is_null());
    }

    let namespaces: Vec<&Value> = entries
        .iter()
        .filter(|entry| entry["type"] == "namespace")
        .collect();
    assert_eq!(namespaces.len(), 1);
    assert_eq!(namespaces[0]["name"], "FileSystem");
    assert_eq!(namespaces[0]["tools"].as_array().unwrap().len(), 14);
}
/// Tools without a category must be emitted as individual deferred
/// `function` entries next to the namespace for categorised tools; with 10
/// categorised and 6 uncategorised tools, eight entries are expected and the
/// last one is `tool_search`.
#[test]
fn test_convert_tools_with_search_ungrouped_tools() {
    use crate::tool_types::DeferrablePolicy;

    let categorised = (0..10).map(|i| {
        make_tool(
            &format!("cat_tool_{i}"),
            Some("Cat"),
            DeferrablePolicy::Automatic,
        )
    });
    let uncategorised = (0..6).map(|i| {
        make_tool(
            &format!("misc_tool_{i}"),
            None,
            DeferrablePolicy::Automatic,
        )
    });
    let tools: Vec<ToolDefinition> = categorised.chain(uncategorised).collect();

    let converted = OpenResponsesProtocolChatDriver::convert_tools_with_search(&tools, 15);
    let serialized = serde_json::to_value(&converted).unwrap();
    let entries = serialized.as_array().unwrap();
    assert_eq!(entries.len(), 8);

    let namespaces: Vec<&Value> = entries
        .iter()
        .filter(|entry| entry["type"] == "namespace")
        .collect();
    assert_eq!(namespaces.len(), 1);
    assert_eq!(namespaces[0]["tools"].as_array().unwrap().len(), 10);

    let functions: Vec<&Value> = entries
        .iter()
        .filter(|entry| entry["type"] == "function")
        .collect();
    assert_eq!(functions.len(), 6);
    for function in &functions {
        assert_eq!(function["defer_loading"], true);
    }
    assert_eq!(entries.last().unwrap()["type"], "tool_search");
}
/// Fourteen automatic tools plus one `DeferrablePolicy::Always` tool in the
/// same category must convert to two entries, the first being a namespace
/// whose 15 tools are all marked `defer_loading`.
#[test]
fn test_convert_tools_with_search_always_policy() {
    use crate::tool_types::DeferrablePolicy;

    let mut tools: Vec<ToolDefinition> = (0..14)
        .map(|i| {
            make_tool(
                &format!("tool_{i}"),
                Some("General"),
                DeferrablePolicy::Automatic,
            )
        })
        .collect();
    tools.push(make_tool(
        "always_tool",
        Some("General"),
        DeferrablePolicy::Always,
    ));

    let converted = OpenResponsesProtocolChatDriver::convert_tools_with_search(&tools, 15);
    let serialized = serde_json::to_value(&converted).unwrap();
    let entries = serialized.as_array().unwrap();
    assert_eq!(entries.len(), 2);

    let namespace = &entries[0];
    assert_eq!(namespace["type"], "namespace");
    let members = namespace["tools"].as_array().unwrap();
    assert_eq!(members.len(), 15);
    for member in members {
        assert_eq!(member["defer_loading"], true);
    }
}
/// `ResponsesTool::ToolSearch` must serialize to exactly
/// `{"type": "tool_search"}` with no extra fields.
#[test]
fn test_tool_search_serialization_format() {
    let expected = json!({"type": "tool_search"});
    let tool_search = ResponsesTool::ToolSearch {
        r#type: String::from("tool_search"),
    };
    let actual = serde_json::to_value(&tool_search).unwrap();
    assert_eq!(actual, expected);
}
/// `ResponsesTool::Namespace` must serialize its own `type`/`name` and nest
/// its function tools (including their `defer_loading` flag) under `tools`.
#[test]
fn test_namespace_serialization_format() {
    let read_file_tool = ResponsesTool::Function {
        r#type: String::from("function"),
        name: String::from("read_file"),
        description: String::from("Read a file"),
        parameters: json!({}),
        defer_loading: Some(true),
    };
    let namespace = ResponsesTool::Namespace {
        r#type: String::from("namespace"),
        name: String::from("FileSystem"),
        description: String::from("Tools for FileSystem"),
        tools: vec![read_file_tool],
    };

    let serialized = serde_json::to_value(&namespace).unwrap();
    assert_eq!(serialized["type"], "namespace");
    assert_eq!(serialized["name"], "FileSystem");

    // Inspect the single nested function entry.
    let nested = &serialized["tools"][0];
    assert_eq!(nested["name"], "read_file");
    assert_eq!(nested["defer_loading"], true);
}
/// A `response.completed` streaming event whose output mixes hosted
/// tool-search items (`tool_search_call` / `tool_search_output`) with a
/// namespaced `function_call` must be handled as `LlmStreamEvent::Done`,
/// and the resulting metadata must carry the response id from the event
/// plus the `tool_calls` finish reason.
#[test]
fn test_hosted_tool_search_completed_event_preserves_response_id() {
// Fixture: a completed response. The two tool-search items are marked
// `"execution": "server"` and have a null `call_id`; the final
// `function_call` targets `add` in the `Math` namespace.
let event_json = r#"{
"type": "response.completed",
"sequence_number": 8,
"response": {
"id": "resp_tool_search",
"object": "response",
"created_at": 1780000000,
"status": "completed",
"model": "gpt-5.5",
"output": [
{
"type": "tool_search_call",
"execution": "server",
"call_id": null,
"status": "completed",
"arguments": { "paths": ["Math"] }
},
{
"type": "tool_search_output",
"execution": "server",
"call_id": null,
"status": "completed",
"tools": [
{
"type": "namespace",
"name": "Math",
"description": "Tools for Math",
"tools": [
{
"type": "function",
"name": "add",
"description": "Add numbers.",
"defer_loading": true,
"parameters": {
"type": "object",
"properties": {
"a": { "type": "number" },
"b": { "type": "number" }
},
"required": ["a", "b"],
"additionalProperties": false
}
}
]
}
]
},
{
"type": "function_call",
"id": "fc_123",
"call_id": "call_123",
"name": "add",
"namespace": "Math",
"arguments": "{\"a\":7,\"b\":3}",
"status": "completed"
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15
}
}
}"#;
// Deserialization itself must succeed for the hosted tool-search item types.
let event: StreamingEvent = serde_json::from_str(event_json).unwrap();
// Fresh, test-local state is passed for every shared-state argument.
// NOTE(review): the fifth argument is presumably the previously recorded
// finish reason (it is seeded with "tool_calls" and asserted below) —
// confirm against the signature of `handle_streaming_event`.
let stream_event = handle_streaming_event(
event,
&Mutex::new(0),
&Mutex::new(0),
&Mutex::new(None),
&Mutex::new(Vec::new()),
&Mutex::new(Some("tool_calls".to_string())),
"gpt-5.5".to_string(),
None,
);
// Any event other than `Done` is a test failure.
match stream_event {
LlmStreamEvent::Done(metadata) => {
assert_eq!(metadata.response_id.as_deref(), Some("resp_tool_search"));
assert_eq!(metadata.finish_reason.as_deref(), Some("tool_calls"));
}
other => panic!("expected Done event, got {other:?}"),
}
}
/// An object schema without a `properties` key must come back from
/// `sanitize_parameters` with an empty `properties` object added and its
/// other keys left as they were.
#[test]
fn test_sanitize_parameters_adds_missing_properties() {
    let params = json!({"type": "object", "additionalProperties": false});
    // The schema is passed by reference.
    let sanitized = OpenResponsesProtocolChatDriver::sanitize_parameters(&params);
    assert_eq!(
        sanitized,
        json!({"type": "object", "properties": {}, "additionalProperties": false})
    );
}
/// An object schema that already declares `properties` must pass through
/// `sanitize_parameters` unchanged.
#[test]
fn test_sanitize_parameters_preserves_existing_properties() {
    let params = json!({
        "type": "object",
        "properties": {"x": {"type": "string"}},
        "additionalProperties": false
    });
    // The schema is passed by reference so it can be compared afterwards.
    let sanitized = OpenResponsesProtocolChatDriver::sanitize_parameters(&params);
    assert_eq!(sanitized, params);
}
/// A schema whose `type` is not `object` must be returned by
/// `sanitize_parameters` unchanged (no `properties` key is added).
#[test]
fn test_sanitize_parameters_ignores_non_object_types() {
    let params = json!({"type": "string"});
    // The schema is passed by reference so it can be compared afterwards.
    let sanitized = OpenResponsesProtocolChatDriver::sanitize_parameters(&params);
    assert_eq!(sanitized, params);
}
/// A default (empty) plugin configuration must produce no wire entries.
#[test]
fn test_plugins_to_wire_empty_is_none() {
    use crate::llm_driver_registry::OpenRouterPluginConfig;

    let empty_cfg = OpenRouterPluginConfig::default();
    let wire = plugins_to_wire(&empty_cfg);
    assert!(wire.is_none());
}
/// A web-search plugin with both options set must map to a single wire
/// entry with id `web` that carries `max_results` and `search_prompt`.
#[test]
fn test_plugins_to_wire_web_search_basic() {
    use crate::llm_driver_registry::{OpenRouterPluginConfig, OpenRouterWebSearchPlugin};

    let web_plugin = OpenRouterWebSearchPlugin {
        max_results: Some(5),
        search_prompt: Some(String::from("find recent news")),
    };
    let cfg = OpenRouterPluginConfig {
        web: Some(web_plugin),
        file: None,
    };

    let wire = plugins_to_wire(&cfg).expect("should produce wire entries");
    assert_eq!(wire.len(), 1);

    let web_entry = &wire[0];
    assert_eq!(web_entry["id"], "web");
    assert_eq!(web_entry["max_results"], 5);
    assert_eq!(web_entry["search_prompt"], "find recent news");
}
/// A web-search plugin with default options must still yield a `web` wire
/// entry, with the optional fields either absent or null.
#[test]
fn test_plugins_to_wire_web_search_no_options() {
    use crate::llm_driver_registry::{OpenRouterPluginConfig, OpenRouterWebSearchPlugin};

    let cfg = OpenRouterPluginConfig {
        web: Some(OpenRouterWebSearchPlugin::default()),
        file: None,
    };

    let wire = plugins_to_wire(&cfg).expect("should produce wire entries");
    assert_eq!(wire.len(), 1);

    let web_entry = &wire[0];
    assert_eq!(web_entry["id"], "web");
    // Unset options may be omitted or serialized as null; both are accepted.
    assert!(matches!(
        web_entry.get("max_results"),
        None | Some(Value::Null)
    ));
    assert!(matches!(
        web_entry.get("search_prompt"),
        None | Some(Value::Null)
    ));
}
/// Enabling only the file plugin must yield one wire entry with id `file`.
#[test]
fn test_plugins_to_wire_file_plugin() {
    use crate::llm_driver_registry::{OpenRouterFilePlugin, OpenRouterPluginConfig};

    let file_only = OpenRouterPluginConfig {
        web: None,
        file: Some(OpenRouterFilePlugin {}),
    };

    let wire = plugins_to_wire(&file_only).expect("should produce wire entries");
    assert_eq!(wire.len(), 1);

    let file_entry = &wire[0];
    assert_eq!(file_entry["id"], "file");
}
/// With both plugins enabled the wire list must contain the `web` entry
/// first (with its `max_results`) and the `file` entry second.
#[test]
fn test_plugins_to_wire_both_plugins() {
    use crate::llm_driver_registry::{
        OpenRouterFilePlugin, OpenRouterPluginConfig, OpenRouterWebSearchPlugin,
    };

    let web_plugin = OpenRouterWebSearchPlugin {
        max_results: Some(3),
        search_prompt: None,
    };
    let cfg = OpenRouterPluginConfig {
        web: Some(web_plugin),
        file: Some(OpenRouterFilePlugin {}),
    };

    let wire = plugins_to_wire(&cfg).expect("should produce wire entries");
    assert_eq!(wire.len(), 2);

    let (web_entry, file_entry) = (&wire[0], &wire[1]);
    assert_eq!(web_entry["id"], "web");
    assert_eq!(web_entry["max_results"], 3);
    assert_eq!(file_entry["id"], "file");
}
/// When the driver's provider type is `DriverId::OpenRouter`, the plugin
/// settings from `openrouter_routing` must be sent as a `plugins` array in
/// the request body.
#[tokio::test]
async fn openrouter_provider_includes_plugins_in_request() {
    use crate::llm_driver_registry::{
        OpenRouterPluginConfig, OpenRouterRoutingConfig, OpenRouterWebSearchPlugin,
    };
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Answer every POST with an empty 200; only the recorded request is inspected.
    let mock_server = MockServer::start().await;
    Mock::given(method("POST"))
        .respond_with(ResponseTemplate::new(200).set_body_string(""))
        .mount(&mock_server)
        .await;

    let endpoint = format!("{}/v1/responses", mock_server.uri());
    let driver = OpenResponsesProtocolChatDriver::with_base_url("test-key", endpoint)
        .with_provider_type(DriverId::OpenRouter);

    // Routing config that enables only the web-search plugin.
    let web_plugin = OpenRouterWebSearchPlugin {
        max_results: Some(5),
        search_prompt: None,
    };
    let routing = OpenRouterRoutingConfig {
        plugins: Some(OpenRouterPluginConfig {
            web: Some(web_plugin),
            file: None,
        }),
        ..Default::default()
    };
    let call_config = LlmCallConfig {
        model: "openai/gpt-5-mini".to_string(),
        temperature: None,
        max_tokens: None,
        tools: vec![],
        reasoning_effort: None,
        metadata: std::collections::HashMap::new(),
        previous_response_id: None,
        tool_search: None,
        prompt_cache: None,
        openrouter_routing: Some(routing),
    };

    let conversation = vec![LlmMessage::text(LlmMessageRole::User, "search the web")];
    // The outcome of the call is deliberately discarded; the assertions below
    // look only at what was sent.
    let _ = driver
        .chat_completion_stream(conversation, &call_config)
        .await;

    let recorded = mock_server
        .received_requests()
        .await
        .expect("mock server recorded requests");
    assert_eq!(recorded.len(), 1);

    let body: serde_json::Value = recorded[0].body_json().expect("request body is JSON");
    assert!(
        body.get("plugins").is_some(),
        "plugins field should be present: {body}"
    );

    let sent_plugins = body["plugins"].as_array().unwrap();
    assert_eq!(sent_plugins.len(), 1);
    let web_entry = &sent_plugins[0];
    assert_eq!(web_entry["id"], "web");
    assert_eq!(web_entry["max_results"], 5);
}
/// A driver created with `with_base_url` and no provider override must not
/// forward OpenRouter plugin settings: the request body may not contain a
/// `plugins` field even though the call config supplies plugins.
#[tokio::test]
async fn non_openrouter_provider_omits_plugins() {
    use crate::llm_driver_registry::{
        OpenRouterPluginConfig, OpenRouterRoutingConfig, OpenRouterWebSearchPlugin,
    };
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Answer every POST with an empty 200; only the recorded request is inspected.
    let mock_server = MockServer::start().await;
    Mock::given(method("POST"))
        .respond_with(ResponseTemplate::new(200).set_body_string(""))
        .mount(&mock_server)
        .await;

    let endpoint = format!("{}/v1/responses", mock_server.uri());
    // No `with_provider_type` call: the driver keeps the provider type set by `with_base_url`.
    let driver = OpenResponsesProtocolChatDriver::with_base_url("test-key", endpoint);

    // The config still asks for the web-search plugin.
    let routing = OpenRouterRoutingConfig {
        plugins: Some(OpenRouterPluginConfig {
            web: Some(OpenRouterWebSearchPlugin::default()),
            file: None,
        }),
        ..Default::default()
    };
    let call_config = LlmCallConfig {
        model: "gpt-5-mini".to_string(),
        temperature: None,
        max_tokens: None,
        tools: vec![],
        reasoning_effort: None,
        metadata: std::collections::HashMap::new(),
        previous_response_id: None,
        tool_search: None,
        prompt_cache: None,
        openrouter_routing: Some(routing),
    };

    let conversation = vec![LlmMessage::text(LlmMessageRole::User, "search the web")];
    // The outcome of the call is deliberately discarded; the assertions below
    // look only at what was sent.
    let _ = driver
        .chat_completion_stream(conversation, &call_config)
        .await;

    let recorded = mock_server
        .received_requests()
        .await
        .expect("mock server recorded requests");
    assert_eq!(recorded.len(), 1);

    let body: serde_json::Value = recorded[0].body_json().expect("request body is JSON");
    assert!(
        body.get("plugins").is_none(),
        "plugins must not be forwarded to non-OpenRouter providers: {body}"
    );
}
}