use crate::llm::attachments::{request_has_attachments, validate_request_attachments};
use crate::llm::{
ChatOutcome, ChatRequest, ChatResponse, Content, ContentBlock, Effort, LlmProvider, StopReason,
StreamBox, StreamDelta, ThinkingConfig, ThinkingMode, Usage,
};
use anyhow::Result;
use async_trait::async_trait;
use futures::StreamExt;
use reqwest::StatusCode;
use serde::de::Error as _;
use serde::{Deserialize, Serialize};
use super::openai_responses::OpenAIResponsesProvider;
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
fn requires_responses_api(model: &str) -> bool {
model == MODEL_GPT52_CODEX
}
fn is_official_openai_base_url(base_url: &str) -> bool {
base_url == DEFAULT_BASE_URL || base_url.contains("api.openai.com")
}
fn request_is_agentic(request: &ChatRequest) -> bool {
request
.tools
.as_ref()
.is_some_and(|tools| !tools.is_empty()) || request.messages.iter().any(|message| {
matches!(
&message.content,
Content::Blocks(blocks)
if blocks.iter().any(|block| {
matches!(block, ContentBlock::ToolUse { .. } | ContentBlock::ToolResult { .. })
})
)
})
}
fn should_use_responses_api(base_url: &str, model: &str, request: &ChatRequest) -> bool {
requires_responses_api(model)
|| request_has_attachments(request)
|| (is_official_openai_base_url(base_url) && request_is_agentic(request))
}
pub const MODEL_GPT54: &str = "gpt-5.4";
pub const MODEL_GPT53_CODEX: &str = "gpt-5.3-codex";
pub const MODEL_GPT52_INSTANT: &str = "gpt-5.2-instant";
pub const MODEL_GPT52_THINKING: &str = "gpt-5.2-thinking";
pub const MODEL_GPT52_PRO: &str = "gpt-5.2-pro";
pub const MODEL_GPT52_CODEX: &str = "gpt-5.2-codex";
pub const MODEL_GPT5: &str = "gpt-5";
pub const MODEL_GPT5_MINI: &str = "gpt-5-mini";
pub const MODEL_GPT5_NANO: &str = "gpt-5-nano";
pub const MODEL_O3: &str = "o3";
pub const MODEL_O3_MINI: &str = "o3-mini";
pub const MODEL_O4_MINI: &str = "o4-mini";
pub const MODEL_O1: &str = "o1";
pub const MODEL_O1_MINI: &str = "o1-mini";
pub const MODEL_GPT41: &str = "gpt-4.1";
pub const MODEL_GPT41_MINI: &str = "gpt-4.1-mini";
pub const MODEL_GPT41_NANO: &str = "gpt-4.1-nano";
pub const MODEL_GPT4O: &str = "gpt-4o";
pub const MODEL_GPT4O_MINI: &str = "gpt-4o-mini";
pub const BASE_URL_KIMI: &str = "https://api.moonshot.ai/v1";
pub const BASE_URL_ZAI: &str = "https://api.z.ai/api/paas/v4";
pub const BASE_URL_MINIMAX: &str = "https://api.minimax.io/v1";
pub const MODEL_KIMI_K2_5: &str = "kimi-k2.5";
pub const MODEL_KIMI_K2_THINKING: &str = "kimi-k2-thinking";
pub const MODEL_ZAI_GLM5: &str = "glm-5";
pub const MODEL_MINIMAX_M2_5: &str = "MiniMax-M2.5";
#[derive(Clone)]
pub struct OpenAIProvider {
client: reqwest::Client,
api_key: String,
model: String,
base_url: String,
thinking: Option<ThinkingConfig>,
extra_headers: Vec<(String, String)>,
}
impl OpenAIProvider {
#[must_use]
pub fn new(api_key: String, model: String) -> Self {
Self {
client: reqwest::Client::new(),
api_key,
model,
base_url: DEFAULT_BASE_URL.to_owned(),
thinking: None,
extra_headers: Vec::new(),
}
}
#[must_use]
pub fn with_base_url(api_key: String, model: String, base_url: String) -> Self {
Self {
client: reqwest::Client::new(),
api_key,
model,
base_url,
thinking: None,
extra_headers: Vec::new(),
}
}
#[must_use]
pub fn kimi(api_key: String, model: String) -> Self {
Self::with_base_url(api_key, model, BASE_URL_KIMI.to_owned())
}
#[must_use]
pub fn kimi_k2_5(api_key: String) -> Self {
Self::kimi(api_key, MODEL_KIMI_K2_5.to_owned())
}
#[must_use]
pub fn kimi_k2_thinking(api_key: String) -> Self {
Self::kimi(api_key, MODEL_KIMI_K2_THINKING.to_owned())
}
#[must_use]
pub fn zai(api_key: String, model: String) -> Self {
Self::with_base_url(api_key, model, BASE_URL_ZAI.to_owned())
}
#[must_use]
pub fn zai_glm5(api_key: String) -> Self {
Self::zai(api_key, MODEL_ZAI_GLM5.to_owned())
}
#[must_use]
pub fn minimax(api_key: String, model: String) -> Self {
Self::with_base_url(api_key, model, BASE_URL_MINIMAX.to_owned())
}
#[must_use]
pub fn minimax_m2_5(api_key: String) -> Self {
Self::minimax(api_key, MODEL_MINIMAX_M2_5.to_owned())
}
#[must_use]
pub fn gpt52_instant(api_key: String) -> Self {
Self::new(api_key, MODEL_GPT52_INSTANT.to_owned())
}
#[must_use]
pub fn gpt54(api_key: String) -> Self {
Self::new(api_key, MODEL_GPT54.to_owned())
}
#[must_use]
pub fn gpt53_codex(api_key: String) -> Self {
Self::new(api_key, MODEL_GPT53_CODEX.to_owned())
}
#[must_use]
pub fn gpt52_thinking(api_key: String) -> Self {
Self::new(api_key, MODEL_GPT52_THINKING.to_owned())
}
#[must_use]
pub fn gpt52_pro(api_key: String) -> Self {
Self::new(api_key, MODEL_GPT52_PRO.to_owned())
}
#[must_use]
pub fn codex(api_key: String) -> Self {
Self::gpt53_codex(api_key)
}
#[must_use]
pub fn gpt5(api_key: String) -> Self {
Self::new(api_key, MODEL_GPT5.to_owned())
}
#[must_use]
pub fn gpt5_mini(api_key: String) -> Self {
Self::new(api_key, MODEL_GPT5_MINI.to_owned())
}
#[must_use]
pub fn gpt5_nano(api_key: String) -> Self {
Self::new(api_key, MODEL_GPT5_NANO.to_owned())
}
#[must_use]
pub fn o3(api_key: String) -> Self {
Self::new(api_key, MODEL_O3.to_owned())
}
#[must_use]
pub fn o3_mini(api_key: String) -> Self {
Self::new(api_key, MODEL_O3_MINI.to_owned())
}
#[must_use]
pub fn o4_mini(api_key: String) -> Self {
Self::new(api_key, MODEL_O4_MINI.to_owned())
}
#[must_use]
pub fn o1(api_key: String) -> Self {
Self::new(api_key, MODEL_O1.to_owned())
}
#[must_use]
pub fn o1_mini(api_key: String) -> Self {
Self::new(api_key, MODEL_O1_MINI.to_owned())
}
#[must_use]
pub fn gpt41(api_key: String) -> Self {
Self::new(api_key, MODEL_GPT41.to_owned())
}
#[must_use]
pub fn gpt41_mini(api_key: String) -> Self {
Self::new(api_key, MODEL_GPT41_MINI.to_owned())
}
#[must_use]
pub fn gpt4o(api_key: String) -> Self {
Self::new(api_key, MODEL_GPT4O.to_owned())
}
#[must_use]
pub fn gpt4o_mini(api_key: String) -> Self {
Self::new(api_key, MODEL_GPT4O_MINI.to_owned())
}
#[must_use]
pub const fn with_thinking(mut self, thinking: ThinkingConfig) -> Self {
self.thinking = Some(thinking);
self
}
#[must_use]
pub fn with_extra_headers(mut self, headers: Vec<(String, String)>) -> Self {
self.extra_headers = headers;
self
}
fn apply_headers(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
let builder = if self.api_key.is_empty() {
builder
} else {
builder.header("Authorization", format!("Bearer {}", self.api_key))
};
self.extra_headers
.iter()
.fold(builder, |b, (k, v)| b.header(k.as_str(), v.as_str()))
}
}
#[async_trait]
impl LlmProvider for OpenAIProvider {
async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
if should_use_responses_api(&self.base_url, &self.model, &request) {
let mut responses_provider = OpenAIResponsesProvider::with_base_url(
self.api_key.clone(),
self.model.clone(),
self.base_url.clone(),
);
if let Some(thinking) = self.thinking.clone() {
responses_provider = responses_provider.with_thinking(thinking);
}
return responses_provider.chat(request).await;
}
let thinking_config = match self.resolve_thinking_config(request.thinking.as_ref()) {
Ok(thinking) => thinking,
Err(error) => return Ok(ChatOutcome::InvalidRequest(error.to_string())),
};
if let Err(error) = validate_request_attachments(self.provider(), self.model(), &request) {
return Ok(ChatOutcome::InvalidRequest(error.to_string()));
}
let reasoning = build_api_reasoning(thinking_config.as_ref());
let messages = build_api_messages(&request);
let tools: Option<Vec<ApiTool>> = request
.tools
.map(|ts| ts.into_iter().map(convert_tool).collect());
let api_request = build_api_chat_request(
&self.model,
&messages,
request.max_tokens,
tools.as_deref(),
reasoning,
use_max_tokens_alias(&self.base_url),
);
log::debug!(
"OpenAI LLM request model={} max_tokens={}",
self.model,
request.max_tokens
);
let builder = self
.client
.post(format!("{}/chat/completions", self.base_url))
.header("Content-Type", "application/json");
let response = self
.apply_headers(builder)
.json(&api_request)
.send()
.await
.map_err(|e| anyhow::anyhow!("request failed: {e}"))?;
let status = response.status();
let bytes = response
.bytes()
.await
.map_err(|e| anyhow::anyhow!("failed to read response body: {e}"))?;
log::debug!(
"OpenAI LLM response status={} body_len={}",
status,
bytes.len()
);
if status == StatusCode::TOO_MANY_REQUESTS {
return Ok(ChatOutcome::RateLimited);
}
if status.is_server_error() {
let body = String::from_utf8_lossy(&bytes);
log::error!("OpenAI server error status={status} body={body}");
return Ok(ChatOutcome::ServerError(body.into_owned()));
}
if status.is_client_error() {
let body = String::from_utf8_lossy(&bytes);
log::warn!("OpenAI client error status={status} body={body}");
return Ok(ChatOutcome::InvalidRequest(body.into_owned()));
}
let api_response: ApiChatResponse = serde_json::from_slice(&bytes)
.map_err(|e| anyhow::anyhow!("failed to parse response: {e}"))?;
let choice = api_response
.choices
.into_iter()
.next()
.ok_or_else(|| anyhow::anyhow!("no choices in response"))?;
let content = build_content_blocks(&choice.message);
let stop_reason = choice.finish_reason.as_deref().map(map_finish_reason);
Ok(ChatOutcome::Success(ChatResponse {
id: api_response.id,
content,
model: api_response.model,
stop_reason,
usage: Usage {
input_tokens: api_response.usage.prompt_tokens,
output_tokens: api_response.usage.completion_tokens,
cached_input_tokens: api_response
.usage
.prompt_tokens_details
.as_ref()
.map_or(0, |details| details.cached_tokens),
},
}))
}
#[allow(clippy::too_many_lines)]
fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
if should_use_responses_api(&self.base_url, &self.model, &request) {
let api_key = self.api_key.clone();
let model = self.model.clone();
let base_url = self.base_url.clone();
let thinking = self.thinking.clone();
return Box::pin(async_stream::stream! {
let mut responses_provider =
OpenAIResponsesProvider::with_base_url(api_key, model, base_url);
if let Some(thinking) = thinking {
responses_provider = responses_provider.with_thinking(thinking);
}
let mut stream = std::pin::pin!(responses_provider.chat_stream(request));
while let Some(item) = futures::StreamExt::next(&mut stream).await {
yield item;
}
});
}
Box::pin(async_stream::stream! {
let thinking_config = match self.resolve_thinking_config(request.thinking.as_ref()) {
Ok(thinking) => thinking,
Err(error) => {
yield Ok(StreamDelta::Error {
message: error.to_string(),
recoverable: false,
});
return;
}
};
if let Err(error) = validate_request_attachments(self.provider(), self.model(), &request) {
yield Ok(StreamDelta::Error {
message: error.to_string(),
recoverable: false,
});
return;
}
let reasoning = build_api_reasoning(thinking_config.as_ref());
let messages = build_api_messages(&request);
let tools: Option<Vec<ApiTool>> = request
.tools
.map(|ts| ts.into_iter().map(convert_tool).collect());
let api_request = build_api_chat_request_streaming(
&self.model,
&messages,
request.max_tokens,
tools.as_deref(),
reasoning,
use_max_tokens_alias(&self.base_url),
use_stream_usage_options(&self.base_url),
);
log::debug!("OpenAI streaming LLM request model={} max_tokens={}", self.model, request.max_tokens);
let stream_builder = self.client
.post(format!("{}/chat/completions", self.base_url))
.header("Content-Type", "application/json");
let Ok(response) = self
.apply_headers(stream_builder)
.json(&api_request)
.send()
.await
else {
yield Err(anyhow::anyhow!("request failed"));
return;
};
let status = response.status();
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
let (recoverable, level) = if status == StatusCode::TOO_MANY_REQUESTS {
(true, "rate_limit")
} else if status.is_server_error() {
(true, "server_error")
} else {
(false, "client_error")
};
log::warn!("OpenAI error status={status} body={body} kind={level}");
yield Ok(StreamDelta::Error { message: body, recoverable });
return;
}
let mut tool_calls: std::collections::HashMap<usize, ToolCallAccumulator> =
std::collections::HashMap::new();
let mut usage: Option<Usage> = None;
let mut buffer = String::new();
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
let Ok(chunk) = chunk_result else {
yield Err(anyhow::anyhow!("stream error: {}", chunk_result.unwrap_err()));
return;
};
buffer.push_str(&String::from_utf8_lossy(&chunk));
while let Some(pos) = buffer.find('\n') {
let line = buffer[..pos].trim().to_string();
buffer = buffer[pos + 1..].to_string();
if line.is_empty() { continue; }
let Some(data) = line.strip_prefix("data: ") else { continue; };
for result in process_sse_data(data) {
match result {
SseProcessResult::TextDelta(c) => yield Ok(StreamDelta::TextDelta { delta: c, block_index: 0 }),
SseProcessResult::ToolCallUpdate { index, id, name, arguments } => apply_tool_call_update(&mut tool_calls, index, id, name, arguments),
SseProcessResult::Usage(u) => usage = Some(u),
SseProcessResult::Done(sr) => {
for d in build_stream_end_deltas(&tool_calls, usage.take(), sr) { yield Ok(d); }
return;
}
SseProcessResult::Sentinel => {
let sr = if tool_calls.is_empty() { StopReason::EndTurn } else { StopReason::ToolUse };
for d in build_stream_end_deltas(&tool_calls, usage.take(), sr) { yield Ok(d); }
return;
}
}
}
}
}
for delta in build_stream_end_deltas(&tool_calls, usage, StopReason::EndTurn) {
yield Ok(delta);
}
})
}
fn model(&self) -> &str {
&self.model
}
fn provider(&self) -> &'static str {
"openai"
}
fn configured_thinking(&self) -> Option<&ThinkingConfig> {
self.thinking.as_ref()
}
}
fn apply_tool_call_update(
tool_calls: &mut std::collections::HashMap<usize, ToolCallAccumulator>,
index: usize,
id: Option<String>,
name: Option<String>,
arguments: Option<String>,
) {
let entry = tool_calls
.entry(index)
.or_insert_with(|| ToolCallAccumulator {
id: String::new(),
name: String::new(),
arguments: String::new(),
});
if let Some(id) = id {
entry.id = id;
}
if let Some(name) = name {
entry.name = name;
}
if let Some(args) = arguments {
entry.arguments.push_str(&args);
}
}
fn build_stream_end_deltas(
tool_calls: &std::collections::HashMap<usize, ToolCallAccumulator>,
usage: Option<Usage>,
stop_reason: StopReason,
) -> Vec<StreamDelta> {
let mut deltas = Vec::new();
for (idx, tool) in tool_calls {
deltas.push(StreamDelta::ToolUseStart {
id: tool.id.clone(),
name: tool.name.clone(),
block_index: *idx + 1,
thought_signature: None,
});
deltas.push(StreamDelta::ToolInputDelta {
id: tool.id.clone(),
delta: tool.arguments.clone(),
block_index: *idx + 1,
});
}
if let Some(u) = usage {
deltas.push(StreamDelta::Usage(u));
}
deltas.push(StreamDelta::Done {
stop_reason: Some(stop_reason),
});
deltas
}
enum SseProcessResult {
TextDelta(String),
ToolCallUpdate {
index: usize,
id: Option<String>,
name: Option<String>,
arguments: Option<String>,
},
Usage(Usage),
Done(StopReason),
Sentinel,
}
fn process_sse_data(data: &str) -> Vec<SseProcessResult> {
if data == "[DONE]" {
return vec![SseProcessResult::Sentinel];
}
let Ok(chunk) = serde_json::from_str::<SseChunk>(data) else {
return vec![];
};
let mut results = Vec::new();
if let Some(u) = chunk.usage {
results.push(SseProcessResult::Usage(Usage {
input_tokens: u.prompt_tokens,
output_tokens: u.completion_tokens,
cached_input_tokens: u
.prompt_tokens_details
.as_ref()
.map_or(0, |details| details.cached_tokens),
}));
}
if let Some(choice) = chunk.choices.into_iter().next() {
if let Some(content) = choice.delta.content
&& !content.is_empty()
{
results.push(SseProcessResult::TextDelta(content));
}
if let Some(tc_deltas) = choice.delta.tool_calls {
for tc in tc_deltas {
results.push(SseProcessResult::ToolCallUpdate {
index: tc.index,
id: tc.id,
name: tc.function.as_ref().and_then(|f| f.name.clone()),
arguments: tc.function.as_ref().and_then(|f| f.arguments.clone()),
});
}
}
if let Some(finish_reason) = choice.finish_reason {
results.push(SseProcessResult::Done(map_finish_reason(&finish_reason)));
}
}
results
}
fn use_max_tokens_alias(base_url: &str) -> bool {
base_url.contains("moonshot.ai")
|| base_url.contains("api.z.ai")
|| base_url.contains("minimax.io")
}
fn use_stream_usage_options(base_url: &str) -> bool {
base_url == DEFAULT_BASE_URL || base_url.contains("api.openai.com")
}
fn map_finish_reason(finish_reason: &str) -> StopReason {
match finish_reason {
"stop" => StopReason::EndTurn,
"tool_calls" => StopReason::ToolUse,
"length" => StopReason::MaxTokens,
"content_filter" | "network_error" => StopReason::StopSequence,
"sensitive" => StopReason::Refusal,
unknown => {
log::debug!("Unknown finish_reason from OpenAI-compatible API: {unknown}");
StopReason::StopSequence
}
}
}
fn build_api_chat_request<'a>(
model: &'a str,
messages: &'a [ApiMessage],
max_tokens: u32,
tools: Option<&'a [ApiTool]>,
reasoning: Option<ApiReasoning>,
include_max_tokens_alias: bool,
) -> ApiChatRequest<'a> {
ApiChatRequest {
model,
messages,
max_completion_tokens: Some(max_tokens),
max_tokens: include_max_tokens_alias.then_some(max_tokens),
tools,
reasoning,
}
}
fn build_api_chat_request_streaming<'a>(
model: &'a str,
messages: &'a [ApiMessage],
max_tokens: u32,
tools: Option<&'a [ApiTool]>,
reasoning: Option<ApiReasoning>,
include_max_tokens_alias: bool,
include_stream_usage: bool,
) -> ApiChatRequestStreaming<'a> {
ApiChatRequestStreaming {
model,
messages,
max_completion_tokens: Some(max_tokens),
max_tokens: include_max_tokens_alias.then_some(max_tokens),
tools,
reasoning,
stream_options: include_stream_usage.then_some(ApiStreamOptions {
include_usage: true,
}),
stream: true,
}
}
fn build_api_reasoning(thinking: Option<&ThinkingConfig>) -> Option<ApiReasoning> {
thinking
.and_then(resolve_reasoning_effort)
.map(|effort| ApiReasoning { effort })
}
const fn resolve_reasoning_effort(config: &ThinkingConfig) -> Option<ReasoningEffort> {
if let Some(effort) = config.effort {
return Some(map_effort(effort));
}
match &config.mode {
ThinkingMode::Adaptive => None,
ThinkingMode::Enabled { budget_tokens } => Some(map_budget_to_reasoning(*budget_tokens)),
}
}
const fn map_effort(effort: Effort) -> ReasoningEffort {
match effort {
Effort::Low => ReasoningEffort::Low,
Effort::Medium => ReasoningEffort::Medium,
Effort::High => ReasoningEffort::High,
Effort::Max => ReasoningEffort::XHigh,
}
}
const fn map_budget_to_reasoning(budget_tokens: u32) -> ReasoningEffort {
if budget_tokens <= 4_096 {
ReasoningEffort::Low
} else if budget_tokens <= 16_384 {
ReasoningEffort::Medium
} else if budget_tokens <= 32_768 {
ReasoningEffort::High
} else {
ReasoningEffort::XHigh
}
}
fn build_api_messages(request: &ChatRequest) -> Vec<ApiMessage> {
let mut messages = Vec::new();
if !request.system.is_empty() {
messages.push(ApiMessage {
role: ApiRole::System,
content: Some(request.system.clone()),
tool_calls: None,
tool_call_id: None,
});
}
for msg in &request.messages {
match &msg.content {
Content::Text(text) => {
messages.push(ApiMessage {
role: match msg.role {
crate::llm::Role::User => ApiRole::User,
crate::llm::Role::Assistant => ApiRole::Assistant,
},
content: Some(text.clone()),
tool_calls: None,
tool_call_id: None,
});
}
Content::Blocks(blocks) => {
let mut text_parts = Vec::new();
let mut tool_calls = Vec::new();
for block in blocks {
match block {
ContentBlock::Text { text } => {
text_parts.push(text.clone());
}
ContentBlock::Thinking { .. }
| ContentBlock::RedactedThinking { .. }
| ContentBlock::Image { .. }
| ContentBlock::Document { .. } => {
}
ContentBlock::ToolUse {
id, name, input, ..
} => {
tool_calls.push(ApiToolCall {
id: id.clone(),
r#type: "function".to_owned(),
function: ApiFunctionCall {
name: name.clone(),
arguments: serde_json::to_string(input)
.unwrap_or_else(|_| "{}".to_owned()),
},
});
}
ContentBlock::ToolResult {
tool_use_id,
content,
..
} => {
messages.push(ApiMessage {
role: ApiRole::Tool,
content: Some(content.clone()),
tool_calls: None,
tool_call_id: Some(tool_use_id.clone()),
});
}
}
}
if !text_parts.is_empty() || !tool_calls.is_empty() {
let role = match msg.role {
crate::llm::Role::User => ApiRole::User,
crate::llm::Role::Assistant => ApiRole::Assistant,
};
if role == ApiRole::Assistant || !text_parts.is_empty() {
messages.push(ApiMessage {
role,
content: if text_parts.is_empty() {
None
} else {
Some(text_parts.join("\n"))
},
tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
tool_call_id: None,
});
}
}
}
}
}
messages
}
fn convert_tool(t: crate::llm::Tool) -> ApiTool {
ApiTool {
r#type: "function".to_owned(),
function: ApiFunction {
name: t.name,
description: t.description,
parameters: t.input_schema,
},
}
}
fn build_content_blocks(message: &ApiResponseMessage) -> Vec<ContentBlock> {
let mut blocks = Vec::new();
if let Some(content) = &message.content
&& !content.is_empty()
{
blocks.push(ContentBlock::Text {
text: content.clone(),
});
}
if let Some(tool_calls) = &message.tool_calls {
for tc in tool_calls {
let input: serde_json::Value = serde_json::from_str(&tc.function.arguments)
.unwrap_or_else(|_| serde_json::json!({}));
blocks.push(ContentBlock::ToolUse {
id: tc.id.clone(),
name: tc.function.name.clone(),
input,
thought_signature: None,
});
}
}
blocks
}
#[derive(Serialize)]
struct ApiChatRequest<'a> {
model: &'a str,
messages: &'a [ApiMessage],
#[serde(skip_serializing_if = "Option::is_none")]
max_completion_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<&'a [ApiTool]>,
#[serde(skip_serializing_if = "Option::is_none")]
reasoning: Option<ApiReasoning>,
}
#[derive(Serialize)]
struct ApiChatRequestStreaming<'a> {
model: &'a str,
messages: &'a [ApiMessage],
#[serde(skip_serializing_if = "Option::is_none")]
max_completion_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<&'a [ApiTool]>,
#[serde(skip_serializing_if = "Option::is_none")]
reasoning: Option<ApiReasoning>,
#[serde(skip_serializing_if = "Option::is_none")]
stream_options: Option<ApiStreamOptions>,
stream: bool,
}
#[derive(Clone, Copy, Serialize)]
struct ApiStreamOptions {
include_usage: bool,
}
#[derive(Clone, Copy, Serialize)]
#[serde(rename_all = "lowercase")]
enum ReasoningEffort {
Low,
Medium,
High,
#[serde(rename = "xhigh")]
XHigh,
}
#[derive(Serialize)]
struct ApiReasoning {
effort: ReasoningEffort,
}
#[derive(Serialize)]
struct ApiMessage {
role: ApiRole,
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ApiToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
}
#[derive(Debug, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
enum ApiRole {
System,
User,
Assistant,
Tool,
}
#[derive(Serialize)]
struct ApiToolCall {
id: String,
r#type: String,
function: ApiFunctionCall,
}
#[derive(Serialize)]
struct ApiFunctionCall {
name: String,
arguments: String,
}
#[derive(Serialize)]
struct ApiTool {
r#type: String,
function: ApiFunction,
}
#[derive(Serialize)]
struct ApiFunction {
name: String,
description: String,
parameters: serde_json::Value,
}
#[derive(Deserialize)]
struct ApiChatResponse {
id: String,
choices: Vec<ApiChoice>,
model: String,
usage: ApiUsage,
}
#[derive(Deserialize)]
struct ApiChoice {
message: ApiResponseMessage,
finish_reason: Option<String>,
}
#[derive(Deserialize)]
struct ApiResponseMessage {
content: Option<String>,
tool_calls: Option<Vec<ApiResponseToolCall>>,
}
#[derive(Deserialize)]
struct ApiResponseToolCall {
id: String,
function: ApiResponseFunctionCall,
}
#[derive(Deserialize)]
struct ApiResponseFunctionCall {
name: String,
arguments: String,
}
#[derive(Deserialize)]
struct ApiUsage {
#[serde(deserialize_with = "deserialize_u32_from_number")]
prompt_tokens: u32,
#[serde(deserialize_with = "deserialize_u32_from_number")]
completion_tokens: u32,
#[serde(default)]
prompt_tokens_details: Option<ApiPromptTokensDetails>,
}
#[derive(Deserialize)]
struct ApiPromptTokensDetails {
#[serde(default, deserialize_with = "deserialize_u32_from_number")]
cached_tokens: u32,
}
struct ToolCallAccumulator {
id: String,
name: String,
arguments: String,
}
#[derive(Deserialize)]
struct SseChunk {
choices: Vec<SseChoice>,
#[serde(default)]
usage: Option<SseUsage>,
}
#[derive(Deserialize)]
struct SseChoice {
delta: SseDelta,
finish_reason: Option<String>,
}
#[derive(Deserialize)]
struct SseDelta {
content: Option<String>,
tool_calls: Option<Vec<SseToolCallDelta>>,
}
#[derive(Deserialize)]
struct SseToolCallDelta {
index: usize,
id: Option<String>,
function: Option<SseFunctionDelta>,
}
#[derive(Deserialize)]
struct SseFunctionDelta {
name: Option<String>,
arguments: Option<String>,
}
#[derive(Deserialize)]
struct SseUsage {
#[serde(deserialize_with = "deserialize_u32_from_number")]
prompt_tokens: u32,
#[serde(deserialize_with = "deserialize_u32_from_number")]
completion_tokens: u32,
#[serde(default)]
prompt_tokens_details: Option<ApiPromptTokensDetails>,
}
fn deserialize_u32_from_number<'de, D>(deserializer: D) -> std::result::Result<u32, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(untagged)]
enum NumberLike {
U64(u64),
F64(f64),
}
match NumberLike::deserialize(deserializer)? {
NumberLike::U64(v) => u32::try_from(v)
.map_err(|_| D::Error::custom(format!("token count out of range for u32: {v}"))),
NumberLike::F64(v) => {
if v.is_finite() && v >= 0.0 && v.fract() == 0.0 && v <= f64::from(u32::MAX) {
v.to_string().parse::<u32>().map_err(|e| {
D::Error::custom(format!(
"failed to convert integer-compatible token count {v} to u32: {e}"
))
})
} else {
Err(D::Error::custom(format!(
"token count must be a non-negative integer-compatible number, got {v}"
)))
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_new_creates_provider_with_custom_model() {
let provider = OpenAIProvider::new("test-api-key".to_string(), "custom-model".to_string());
assert_eq!(provider.model(), "custom-model");
assert_eq!(provider.provider(), "openai");
assert_eq!(provider.base_url, DEFAULT_BASE_URL);
}
#[test]
fn test_with_base_url_creates_provider_with_custom_url() {
let provider = OpenAIProvider::with_base_url(
"test-api-key".to_string(),
"llama3".to_string(),
"http://localhost:11434/v1".to_string(),
);
assert_eq!(provider.model(), "llama3");
assert_eq!(provider.base_url, "http://localhost:11434/v1");
}
#[test]
fn test_gpt4o_factory_creates_gpt4o_provider() {
let provider = OpenAIProvider::gpt4o("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_GPT4O);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_gpt4o_mini_factory_creates_gpt4o_mini_provider() {
let provider = OpenAIProvider::gpt4o_mini("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_GPT4O_MINI);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_gpt52_thinking_factory_creates_provider() {
let provider = OpenAIProvider::gpt52_thinking("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_GPT52_THINKING);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_gpt54_factory_creates_provider() {
let provider = OpenAIProvider::gpt54("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_GPT54);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_gpt53_codex_factory_creates_provider() {
let provider = OpenAIProvider::gpt53_codex("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_GPT53_CODEX);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_codex_factory_points_to_latest_codex_model() {
let provider = OpenAIProvider::codex("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_GPT53_CODEX);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_gpt5_factory_creates_gpt5_provider() {
let provider = OpenAIProvider::gpt5("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_GPT5);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_gpt5_mini_factory_creates_provider() {
let provider = OpenAIProvider::gpt5_mini("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_GPT5_MINI);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_o3_factory_creates_o3_provider() {
let provider = OpenAIProvider::o3("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_O3);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_o4_mini_factory_creates_o4_mini_provider() {
let provider = OpenAIProvider::o4_mini("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_O4_MINI);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_o1_factory_creates_o1_provider() {
let provider = OpenAIProvider::o1("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_O1);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_gpt41_factory_creates_gpt41_provider() {
let provider = OpenAIProvider::gpt41("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_GPT41);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_kimi_factory_creates_provider_with_kimi_base_url() {
let provider = OpenAIProvider::kimi("test-api-key".to_string(), "kimi-custom".to_string());
assert_eq!(provider.model(), "kimi-custom");
assert_eq!(provider.base_url, BASE_URL_KIMI);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_kimi_k2_5_factory_creates_provider() {
let provider = OpenAIProvider::kimi_k2_5("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_KIMI_K2_5);
assert_eq!(provider.base_url, BASE_URL_KIMI);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_kimi_k2_thinking_factory_creates_provider() {
let provider = OpenAIProvider::kimi_k2_thinking("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_KIMI_K2_THINKING);
assert_eq!(provider.base_url, BASE_URL_KIMI);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_zai_factory_creates_provider_with_zai_base_url() {
let provider = OpenAIProvider::zai("test-api-key".to_string(), "glm-custom".to_string());
assert_eq!(provider.model(), "glm-custom");
assert_eq!(provider.base_url, BASE_URL_ZAI);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_zai_glm5_factory_creates_provider() {
let provider = OpenAIProvider::zai_glm5("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_ZAI_GLM5);
assert_eq!(provider.base_url, BASE_URL_ZAI);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_minimax_factory_creates_provider_with_minimax_base_url() {
let provider =
OpenAIProvider::minimax("test-api-key".to_string(), "minimax-custom".to_string());
assert_eq!(provider.model(), "minimax-custom");
assert_eq!(provider.base_url, BASE_URL_MINIMAX);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_minimax_m2_5_factory_creates_provider() {
let provider = OpenAIProvider::minimax_m2_5("test-api-key".to_string());
assert_eq!(provider.model(), MODEL_MINIMAX_M2_5);
assert_eq!(provider.base_url, BASE_URL_MINIMAX);
assert_eq!(provider.provider(), "openai");
}
#[test]
fn test_model_constants_have_expected_values() {
assert_eq!(MODEL_GPT54, "gpt-5.4");
assert_eq!(MODEL_GPT53_CODEX, "gpt-5.3-codex");
assert_eq!(MODEL_GPT52_INSTANT, "gpt-5.2-instant");
assert_eq!(MODEL_GPT52_THINKING, "gpt-5.2-thinking");
assert_eq!(MODEL_GPT52_PRO, "gpt-5.2-pro");
assert_eq!(MODEL_GPT52_CODEX, "gpt-5.2-codex");
assert_eq!(MODEL_GPT5, "gpt-5");
assert_eq!(MODEL_GPT5_MINI, "gpt-5-mini");
assert_eq!(MODEL_GPT5_NANO, "gpt-5-nano");
assert_eq!(MODEL_O3, "o3");
assert_eq!(MODEL_O3_MINI, "o3-mini");
assert_eq!(MODEL_O4_MINI, "o4-mini");
assert_eq!(MODEL_O1, "o1");
assert_eq!(MODEL_O1_MINI, "o1-mini");
assert_eq!(MODEL_GPT41, "gpt-4.1");
assert_eq!(MODEL_GPT41_MINI, "gpt-4.1-mini");
assert_eq!(MODEL_GPT41_NANO, "gpt-4.1-nano");
assert_eq!(MODEL_GPT4O, "gpt-4o");
assert_eq!(MODEL_GPT4O_MINI, "gpt-4o-mini");
assert_eq!(MODEL_KIMI_K2_5, "kimi-k2.5");
assert_eq!(MODEL_KIMI_K2_THINKING, "kimi-k2-thinking");
assert_eq!(MODEL_ZAI_GLM5, "glm-5");
assert_eq!(MODEL_MINIMAX_M2_5, "MiniMax-M2.5");
assert_eq!(BASE_URL_KIMI, "https://api.moonshot.ai/v1");
assert_eq!(BASE_URL_ZAI, "https://api.z.ai/api/paas/v4");
assert_eq!(BASE_URL_MINIMAX, "https://api.minimax.io/v1");
}
#[test]
fn test_provider_is_cloneable() {
let provider = OpenAIProvider::new("test-api-key".to_string(), "test-model".to_string());
let cloned = provider.clone();
assert_eq!(provider.model(), cloned.model());
assert_eq!(provider.provider(), cloned.provider());
assert_eq!(provider.base_url, cloned.base_url);
}
#[test]
fn test_api_role_serialization() {
let system_role = ApiRole::System;
let user_role = ApiRole::User;
let assistant_role = ApiRole::Assistant;
let tool_role = ApiRole::Tool;
assert_eq!(serde_json::to_string(&system_role).unwrap(), "\"system\"");
assert_eq!(serde_json::to_string(&user_role).unwrap(), "\"user\"");
assert_eq!(
serde_json::to_string(&assistant_role).unwrap(),
"\"assistant\""
);
assert_eq!(serde_json::to_string(&tool_role).unwrap(), "\"tool\"");
}
#[test]
fn test_api_message_serialization_simple() {
let message = ApiMessage {
role: ApiRole::User,
content: Some("Hello, world!".to_string()),
tool_calls: None,
tool_call_id: None,
};
let json = serde_json::to_string(&message).unwrap();
assert!(json.contains("\"role\":\"user\""));
assert!(json.contains("\"content\":\"Hello, world!\""));
assert!(!json.contains("tool_calls"));
assert!(!json.contains("tool_call_id"));
}
#[test]
fn test_api_message_serialization_with_tool_calls() {
let message = ApiMessage {
role: ApiRole::Assistant,
content: Some("Let me help.".to_string()),
tool_calls: Some(vec![ApiToolCall {
id: "call_123".to_string(),
r#type: "function".to_string(),
function: ApiFunctionCall {
name: "read_file".to_string(),
arguments: "{\"path\": \"/test.txt\"}".to_string(),
},
}]),
tool_call_id: None,
};
let json = serde_json::to_string(&message).unwrap();
assert!(json.contains("\"role\":\"assistant\""));
assert!(json.contains("\"tool_calls\""));
assert!(json.contains("\"id\":\"call_123\""));
assert!(json.contains("\"type\":\"function\""));
assert!(json.contains("\"name\":\"read_file\""));
}
#[test]
fn test_api_tool_message_serialization() {
let message = ApiMessage {
role: ApiRole::Tool,
content: Some("File contents here".to_string()),
tool_calls: None,
tool_call_id: Some("call_123".to_string()),
};
let json = serde_json::to_string(&message).unwrap();
assert!(json.contains("\"role\":\"tool\""));
assert!(json.contains("\"tool_call_id\":\"call_123\""));
assert!(json.contains("\"content\":\"File contents here\""));
}
#[test]
fn test_api_tool_serialization() {
let tool = ApiTool {
r#type: "function".to_string(),
function: ApiFunction {
name: "test_tool".to_string(),
description: "A test tool".to_string(),
parameters: serde_json::json!({
"type": "object",
"properties": {
"arg": {"type": "string"}
}
}),
},
};
let json = serde_json::to_string(&tool).unwrap();
assert!(json.contains("\"type\":\"function\""));
assert!(json.contains("\"name\":\"test_tool\""));
assert!(json.contains("\"description\":\"A test tool\""));
assert!(json.contains("\"parameters\""));
}
#[test]
fn test_api_response_deserialization() {
let json = r#"{
"id": "chatcmpl-123",
"choices": [
{
"message": {
"content": "Hello!"
},
"finish_reason": "stop"
}
],
"model": "gpt-4o",
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50
}
}"#;
let response: ApiChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(response.id, "chatcmpl-123");
assert_eq!(response.model, "gpt-4o");
assert_eq!(response.usage.prompt_tokens, 100);
assert_eq!(response.usage.completion_tokens, 50);
assert_eq!(response.choices.len(), 1);
assert_eq!(
response.choices[0].message.content,
Some("Hello!".to_string())
);
}
#[test]
fn test_api_response_with_tool_calls_deserialization() {
let json = r#"{
"id": "chatcmpl-456",
"choices": [
{
"message": {
"content": null,
"tool_calls": [
{
"id": "call_abc",
"type": "function",
"function": {
"name": "read_file",
"arguments": "{\"path\": \"test.txt\"}"
}
}
]
},
"finish_reason": "tool_calls"
}
],
"model": "gpt-4o",
"usage": {
"prompt_tokens": 150,
"completion_tokens": 30
}
}"#;
let response: ApiChatResponse = serde_json::from_str(json).unwrap();
let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].id, "call_abc");
assert_eq!(tool_calls[0].function.name, "read_file");
}
#[test]
fn test_api_response_with_unknown_finish_reason_deserialization() {
let json = r#"{
"id": "chatcmpl-789",
"choices": [
{
"message": {
"content": "ok"
},
"finish_reason": "vendor_custom_reason"
}
],
"model": "glm-5",
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5
}
}"#;
let response: ApiChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(
response.choices[0].finish_reason.as_deref(),
Some("vendor_custom_reason")
);
assert_eq!(
map_finish_reason(response.choices[0].finish_reason.as_deref().unwrap()),
StopReason::StopSequence
);
}
#[test]
fn test_map_finish_reason_covers_vendor_specific_values() {
assert_eq!(map_finish_reason("stop"), StopReason::EndTurn);
assert_eq!(map_finish_reason("tool_calls"), StopReason::ToolUse);
assert_eq!(map_finish_reason("length"), StopReason::MaxTokens);
assert_eq!(
map_finish_reason("content_filter"),
StopReason::StopSequence
);
assert_eq!(map_finish_reason("sensitive"), StopReason::Refusal);
assert_eq!(map_finish_reason("network_error"), StopReason::StopSequence);
assert_eq!(
map_finish_reason("some_new_reason"),
StopReason::StopSequence
);
}
#[test]
fn test_build_api_messages_with_system() {
let request = ChatRequest {
system: "You are helpful.".to_string(),
messages: vec![crate::llm::Message::user("Hello")],
tools: None,
max_tokens: 1024,
max_tokens_explicit: true,
session_id: None,
cached_content: None,
thinking: None,
};
let api_messages = build_api_messages(&request);
assert_eq!(api_messages.len(), 2);
assert_eq!(api_messages[0].role, ApiRole::System);
assert_eq!(
api_messages[0].content,
Some("You are helpful.".to_string())
);
assert_eq!(api_messages[1].role, ApiRole::User);
assert_eq!(api_messages[1].content, Some("Hello".to_string()));
}
#[test]
fn test_build_api_messages_empty_system() {
let request = ChatRequest {
system: String::new(),
messages: vec![crate::llm::Message::user("Hello")],
tools: None,
max_tokens: 1024,
max_tokens_explicit: true,
session_id: None,
cached_content: None,
thinking: None,
};
let api_messages = build_api_messages(&request);
assert_eq!(api_messages.len(), 1);
assert_eq!(api_messages[0].role, ApiRole::User);
}
#[test]
fn test_convert_tool() {
let tool = crate::llm::Tool {
name: "test_tool".to_string(),
description: "A test tool".to_string(),
input_schema: serde_json::json!({"type": "object"}),
};
let api_tool = convert_tool(tool);
assert_eq!(api_tool.r#type, "function");
assert_eq!(api_tool.function.name, "test_tool");
assert_eq!(api_tool.function.description, "A test tool");
}
#[test]
fn test_build_content_blocks_text_only() {
let message = ApiResponseMessage {
content: Some("Hello!".to_string()),
tool_calls: None,
};
let blocks = build_content_blocks(&message);
assert_eq!(blocks.len(), 1);
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello!"));
}
#[test]
fn test_build_content_blocks_with_tool_calls() {
let message = ApiResponseMessage {
content: Some("Let me help.".to_string()),
tool_calls: Some(vec![ApiResponseToolCall {
id: "call_123".to_string(),
function: ApiResponseFunctionCall {
name: "read_file".to_string(),
arguments: "{\"path\": \"test.txt\"}".to_string(),
},
}]),
};
let blocks = build_content_blocks(&message);
assert_eq!(blocks.len(), 2);
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Let me help."));
assert!(
matches!(&blocks[1], ContentBlock::ToolUse { id, name, .. } if id == "call_123" && name == "read_file")
);
}
#[test]
fn test_sse_chunk_text_delta_deserialization() {
let json = r#"{
"choices": [{
"delta": {
"content": "Hello"
},
"finish_reason": null
}]
}"#;
let chunk: SseChunk = serde_json::from_str(json).unwrap();
assert_eq!(chunk.choices.len(), 1);
assert_eq!(chunk.choices[0].delta.content, Some("Hello".to_string()));
assert!(chunk.choices[0].finish_reason.is_none());
}
#[test]
fn test_sse_chunk_tool_call_delta_deserialization() {
let json = r#"{
"choices": [{
"delta": {
"tool_calls": [{
"index": 0,
"id": "call_abc",
"function": {
"name": "read_file",
"arguments": ""
}
}]
},
"finish_reason": null
}]
}"#;
let chunk: SseChunk = serde_json::from_str(json).unwrap();
let tool_calls = chunk.choices[0].delta.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].index, 0);
assert_eq!(tool_calls[0].id, Some("call_abc".to_string()));
assert_eq!(
tool_calls[0].function.as_ref().unwrap().name,
Some("read_file".to_string())
);
}
#[test]
fn test_sse_chunk_tool_call_arguments_delta_deserialization() {
let json = r#"{
"choices": [{
"delta": {
"tool_calls": [{
"index": 0,
"function": {
"arguments": "{\"path\":"
}
}]
},
"finish_reason": null
}]
}"#;
let chunk: SseChunk = serde_json::from_str(json).unwrap();
let tool_calls = chunk.choices[0].delta.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls[0].id, None);
assert_eq!(
tool_calls[0].function.as_ref().unwrap().arguments,
Some("{\"path\":".to_string())
);
}
#[test]
fn test_sse_chunk_with_finish_reason_deserialization() {
let json = r#"{
"choices": [{
"delta": {},
"finish_reason": "stop"
}]
}"#;
let chunk: SseChunk = serde_json::from_str(json).unwrap();
assert_eq!(chunk.choices[0].finish_reason.as_deref(), Some("stop"));
}
#[test]
fn test_sse_chunk_with_usage_deserialization() {
let json = r#"{
"choices": [{
"delta": {},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50
}
}"#;
let chunk: SseChunk = serde_json::from_str(json).unwrap();
let usage = chunk.usage.unwrap();
assert_eq!(usage.prompt_tokens, 100);
assert_eq!(usage.completion_tokens, 50);
}
#[test]
fn test_sse_chunk_with_float_usage_deserialization() {
let json = r#"{
"choices": [{
"delta": {},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 100.0,
"completion_tokens": 50.0
}
}"#;
let chunk: SseChunk = serde_json::from_str(json).unwrap();
let usage = chunk.usage.unwrap();
assert_eq!(usage.prompt_tokens, 100);
assert_eq!(usage.completion_tokens, 50);
}
#[test]
fn test_api_usage_deserializes_integer_compatible_numbers() {
let json = r#"{
"prompt_tokens": 42.0,
"completion_tokens": 7
}"#;
let usage: ApiUsage = serde_json::from_str(json).unwrap();
assert_eq!(usage.prompt_tokens, 42);
assert_eq!(usage.completion_tokens, 7);
}
#[test]
fn test_api_usage_deserializes_cached_tokens() {
let json = r#"{
"prompt_tokens": 42,
"completion_tokens": 7,
"prompt_tokens_details": {
"cached_tokens": 10
}
}"#;
let usage: ApiUsage = serde_json::from_str(json).unwrap();
assert_eq!(usage.prompt_tokens, 42);
assert_eq!(usage.completion_tokens, 7);
assert_eq!(usage.prompt_tokens_details.unwrap().cached_tokens, 10);
}
#[test]
fn test_api_usage_rejects_fractional_numbers() {
let json = r#"{
"prompt_tokens": 42.5,
"completion_tokens": 7
}"#;
let usage: std::result::Result<ApiUsage, _> = serde_json::from_str(json);
assert!(usage.is_err());
}
#[test]
fn test_use_max_tokens_alias_for_vendor_urls() {
assert!(!use_max_tokens_alias(DEFAULT_BASE_URL));
assert!(use_max_tokens_alias(BASE_URL_KIMI));
assert!(use_max_tokens_alias(BASE_URL_ZAI));
assert!(use_max_tokens_alias(BASE_URL_MINIMAX));
}
#[test]
fn test_requires_responses_api_only_for_legacy_codex_model() {
assert!(requires_responses_api(MODEL_GPT52_CODEX));
assert!(!requires_responses_api(MODEL_GPT53_CODEX));
assert!(!requires_responses_api(MODEL_GPT54));
}
#[test]
fn test_should_use_responses_api_for_official_agentic_requests() {
let request = ChatRequest {
system: String::new(),
messages: vec![crate::llm::Message::user("Hello")],
tools: Some(vec![crate::llm::Tool {
name: "read_file".to_string(),
description: "Read a file".to_string(),
input_schema: serde_json::json!({"type": "object"}),
}]),
max_tokens: 1024,
max_tokens_explicit: true,
session_id: Some("thread-1".to_string()),
cached_content: None,
thinking: None,
};
assert!(should_use_responses_api(
DEFAULT_BASE_URL,
MODEL_GPT54,
&request
));
assert!(!should_use_responses_api(
BASE_URL_KIMI,
MODEL_GPT54,
&request
));
}
#[test]
fn test_build_api_reasoning_maps_enabled_budget_to_effort() {
let reasoning = build_api_reasoning(Some(&ThinkingConfig::new(40_000))).unwrap();
assert!(matches!(reasoning.effort, ReasoningEffort::XHigh));
}
#[test]
fn test_build_api_reasoning_uses_explicit_effort() {
let reasoning =
build_api_reasoning(Some(&ThinkingConfig::adaptive_with_effort(Effort::High))).unwrap();
assert!(matches!(reasoning.effort, ReasoningEffort::High));
}
#[test]
fn test_build_api_reasoning_omits_adaptive_without_effort() {
assert!(build_api_reasoning(Some(&ThinkingConfig::adaptive())).is_none());
}
#[test]
fn test_openai_rejects_adaptive_thinking() {
let provider = OpenAIProvider::gpt54("test-key".to_string());
let error = provider
.validate_thinking_config(Some(&ThinkingConfig::adaptive()))
.unwrap_err();
assert!(
error
.to_string()
.contains("adaptive thinking is not supported")
);
}
#[test]
fn test_openai_non_reasoning_models_reject_thinking() {
let provider = OpenAIProvider::gpt4o("test-key".to_string());
let error = provider
.validate_thinking_config(Some(&ThinkingConfig::new(10_000)))
.unwrap_err();
assert!(error.to_string().contains("thinking is not supported"));
}
#[test]
fn test_request_serialization_openai_uses_max_completion_tokens_only() {
let messages = vec![ApiMessage {
role: ApiRole::User,
content: Some("Hello".to_string()),
tool_calls: None,
tool_call_id: None,
}];
let request = ApiChatRequest {
model: "gpt-4o",
messages: &messages,
max_completion_tokens: Some(1024),
max_tokens: None,
tools: None,
reasoning: None,
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("\"max_completion_tokens\":1024"));
assert!(!json.contains("\"max_tokens\""));
}
#[test]
fn test_request_serialization_with_max_tokens_alias() {
let messages = vec![ApiMessage {
role: ApiRole::User,
content: Some("Hello".to_string()),
tool_calls: None,
tool_call_id: None,
}];
let request = ApiChatRequest {
model: "glm-5",
messages: &messages,
max_completion_tokens: Some(1024),
max_tokens: Some(1024),
tools: None,
reasoning: None,
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("\"max_completion_tokens\":1024"));
assert!(json.contains("\"max_tokens\":1024"));
}
#[test]
fn test_streaming_request_serialization_openai_default() {
let messages = vec![ApiMessage {
role: ApiRole::User,
content: Some("Hello".to_string()),
tool_calls: None,
tool_call_id: None,
}];
let request = ApiChatRequestStreaming {
model: "gpt-4o",
messages: &messages,
max_completion_tokens: Some(1024),
max_tokens: None,
tools: None,
reasoning: None,
stream_options: Some(ApiStreamOptions {
include_usage: true,
}),
stream: true,
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("\"stream\":true"));
assert!(json.contains("\"model\":\"gpt-4o\""));
assert!(json.contains("\"max_completion_tokens\":1024"));
assert!(json.contains("\"stream_options\":{\"include_usage\":true}"));
assert!(!json.contains("\"max_tokens\""));
}
#[test]
fn test_streaming_request_serialization_with_max_tokens_alias() {
let messages = vec![ApiMessage {
role: ApiRole::User,
content: Some("Hello".to_string()),
tool_calls: None,
tool_call_id: None,
}];
let request = ApiChatRequestStreaming {
model: "kimi-k2-thinking",
messages: &messages,
max_completion_tokens: Some(1024),
max_tokens: Some(1024),
tools: None,
reasoning: None,
stream_options: None,
stream: true,
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("\"max_completion_tokens\":1024"));
assert!(json.contains("\"max_tokens\":1024"));
assert!(!json.contains("\"stream_options\""));
}
#[test]
fn test_request_serialization_includes_reasoning_when_present() {
let messages = vec![ApiMessage {
role: ApiRole::User,
content: Some("Hello".to_string()),
tool_calls: None,
tool_call_id: None,
}];
let request = ApiChatRequest {
model: MODEL_GPT54,
messages: &messages,
max_completion_tokens: Some(1024),
max_tokens: None,
tools: None,
reasoning: Some(ApiReasoning {
effort: ReasoningEffort::High,
}),
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("\"reasoning\":{\"effort\":\"high\"}"));
}
}