use crate::llm::retry;
use crate::llm::traits::AiProvider;
use crate::llm::types::{
ChatCompletionParams, Message, ProviderExchange, ProviderResponse, ThinkingBlock, TokenUsage,
ToolCall,
};
use crate::llm::utils::{normalize_model_name, starts_with_ignore_ascii_case};
use anyhow::Result;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
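/// Pricing in USD per million tokens: (model name, input price, output price).
/// Lookups are substring matches that prefer the longest entry; see
/// `find_pricing` below.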
const PRICING: &[(&str, f64, f64)] = &[
("gpt-4.1", 2.00, 8.00),
("gpt-4.1-2025-04-14", 2.00, 8.00),
("gpt-4.1-mini", 0.40, 1.60),
("gpt-4.1-mini-2025-04-14", 0.40, 1.60),
("gpt-4.1-nano", 0.10, 0.40),
("gpt-4.1-nano-2025-04-14", 0.10, 0.40),
("gpt-4.5-preview", 75.00, 150.00),
("gpt-4.5-preview-2025-02-27", 75.00, 150.00),
("gpt-5", 1.25, 10.00),
("gpt-5-2025-08-07", 1.25, 10.00),
("gpt-5-chat-latest", 1.25, 10.00),
("gpt-5-codex", 1.25, 10.00),
("gpt-5-mini", 0.25, 2.0),
("gpt-5-mini-2025-08-07", 0.25, 2.0),
("gpt-5-nano", 0.05, 0.40),
("gpt-5-nano-2025-08-07", 0.05, 0.40),
("gpt-5-pro", 15.00, 120.00),
("gpt-5.1", 1.25, 10.00),
("gpt-5.1-2025-11-20", 1.25, 10.00),
("gpt-5.1-chat-latest", 1.25, 10.00),
("gpt-5.1-codex", 1.25, 10.00),
("gpt-5.1-codex-max", 1.25, 10.00),
("gpt-5.2", 1.75, 14.00),
("gpt-5.2-chat-latest", 1.75, 14.00),
("gpt-5.2-pro", 21.00, 168.00),
("gpt-4o", 2.50, 10.00),
("gpt-4o-2024-08-06", 2.50, 10.00),
("gpt-4o-2024-05-13", 5.00, 15.00),
("gpt-4o-realtime-preview", 5.00, 20.00),
("gpt-4o-realtime-preview-2025-06-03", 5.00, 20.00),
("gpt-4o-mini", 0.15, 0.60),
("gpt-4o-mini-2024-07-18", 0.15, 0.60),
("gpt-4o-mini-realtime-preview", 0.60, 2.40),
("gpt-4o-mini-realtime-preview-2024-12-17", 0.60, 2.40),
("gpt-4o-mini-search-preview", 0.15, 0.60),
("gpt-4o-mini-search-preview-2025-03-11", 0.15, 0.60),
("gpt-4o-search-preview", 2.50, 10.00),
("gpt-4o-search-preview-2025-03-11", 2.50, 10.00),
("o1", 15.00, 60.00),
("o1-2024-12-17", 15.00, 60.00),
("o1-pro", 150.00, 600.00),
("o1-pro-2025-03-19", 150.00, 600.00),
("o1-mini", 1.10, 4.40),
("o1-mini-2024-09-12", 1.10, 4.40),
("o3", 2.00, 8.00),
("o3-2025-04-16", 2.00, 8.00),
("o3-pro", 20.00, 80.00),
("o3-pro-2025-06-10", 20.00, 80.00),
("o3-mini", 1.10, 4.40),
("o3-mini-2025-01-31", 1.10, 4.40),
("o3-deep-research", 10.00, 40.00),
("o3-deep-research-2025-06-26", 10.00, 40.00),
("o4-mini", 1.10, 4.40),
("o4-mini-2025-04-16", 1.10, 4.40),
("o4-mini-deep-research", 2.00, 8.00),
("o4-mini-deep-research-2025-06-26", 2.00, 8.00),
("gpt-4-turbo", 10.00, 30.00),
("gpt-4-turbo-2024-04-09", 10.00, 30.00),
("gpt-4", 30.00, 60.00),
("gpt-4-0613", 30.00, 60.00),
("gpt-4-32k", 60.00, 120.00),
("gpt-3.5-turbo", 0.50, 1.50),
("gpt-3.5-turbo-0125", 0.50, 1.50),
("gpt-3.5-turbo-instruct", 1.50, 2.00),
("gpt-3.5-turbo-16k-0613", 3.00, 4.00),
];
/// Finds the pricing entry whose name is the longest substring match for
/// `model`, so specific variants (e.g. "gpt-4o-mini") are not shadowed by the
/// shorter names they contain (e.g. "gpt-4o").
fn find_pricing(model: &str) -> Option<(f64, f64)> {
    PRICING
        .iter()
        .filter(|(pricing_model, _, _)| model.contains(pricing_model))
        .max_by_key(|(pricing_model, _, _)| pricing_model.len())
        .map(|&(_, input_price, output_price)| (input_price, output_price))
}

fn calculate_cost(model: &str, prompt_tokens: u64, completion_tokens: u64) -> Option<f64> {
    let (input_price, output_price) = find_pricing(model)?;
    let input_cost = (prompt_tokens as f64 / 1_000_000.0) * input_price;
    let output_cost = (completion_tokens as f64 / 1_000_000.0) * output_price;
    Some(input_cost + output_cost)
}
fn supports_temperature(model: &str) -> bool {
!model.starts_with("o1")
&& !model.starts_with("o2")
&& !model.starts_with("o3")
&& !model.starts_with("o4")
&& !model.starts_with("gpt-5")
}
fn uses_max_completion_tokens(model: &str) -> bool {
    // gpt-5 and o-series reasoning models reject the legacy `max_tokens`
    // parameter; they expect `max_completion_tokens` instead.
    model.starts_with("gpt-5")
        || model.starts_with("o1")
        || model.starts_with("o3")
        || model.starts_with("o4")
}
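/// Discount applied to cached input tokens relative to the regular input
/// price: 10% for the gpt-5 family, 25% for everything else.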
fn get_cache_multiplier(model: &str) -> f64 {
    if model.starts_with("gpt-5") {
        0.1
    } else {
        0.25
    }
}
fn calculate_cost_with_cache(
    model: &str,
    regular_input_tokens: u64,
    cache_read_tokens: u64,
    completion_tokens: u64,
) -> Option<f64> {
    let (input_price, output_price) = find_pricing(model)?;
    let regular_input_cost = (regular_input_tokens as f64 / 1_000_000.0) * input_price;
    let cache_read_cost =
        (cache_read_tokens as f64 / 1_000_000.0) * input_price * get_cache_multiplier(model);
    let output_cost = (completion_tokens as f64 / 1_000_000.0) * output_price;
    Some(regular_input_cost + cache_read_cost + output_cost)
}
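/// Provider for the OpenAI Chat Completions API. Authenticates with either an
/// API key (`OPENAI_API_KEY`) or ChatGPT OAuth credentials
/// (`OPENAI_OAUTH_ACCESS_TOKEN` plus `OPENAI_OAUTH_ACCOUNT_ID`).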
#[derive(Debug, Clone)]
pub struct OpenAiProvider;
impl Default for OpenAiProvider {
fn default() -> Self {
Self::new()
}
}
impl OpenAiProvider {
pub fn new() -> Self {
Self
}
}
const OPENAI_API_KEY_ENV: &str = "OPENAI_API_KEY";
const OPENAI_OAUTH_ACCESS_TOKEN_ENV: &str = "OPENAI_OAUTH_ACCESS_TOKEN";
const OPENAI_OAUTH_ACCOUNT_ID_ENV: &str = "OPENAI_OAUTH_ACCOUNT_ID";
const OPENAI_API_URL_ENV: &str = "OPENAI_API_URL";
const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
#[async_trait::async_trait]
impl AiProvider for OpenAiProvider {
fn name(&self) -> &str {
"openai"
}
fn supports_model(&self, model: &str) -> bool {
starts_with_ignore_ascii_case(model, "gpt-5")
|| starts_with_ignore_ascii_case(model, "gpt-4o")
|| starts_with_ignore_ascii_case(model, "gpt-4.5")
|| starts_with_ignore_ascii_case(model, "gpt-4.1")
|| starts_with_ignore_ascii_case(model, "gpt-4")
|| starts_with_ignore_ascii_case(model, "gpt-3.5")
|| starts_with_ignore_ascii_case(model, "o1")
|| starts_with_ignore_ascii_case(model, "o3")
|| starts_with_ignore_ascii_case(model, "o4")
|| model.eq_ignore_ascii_case("chatgpt-4o-latest")
}
fn get_api_key(&self) -> Result<String> {
if env::var(OPENAI_OAUTH_ACCESS_TOKEN_ENV).is_ok() {
return Err(anyhow::anyhow!(
"Using OAuth authentication. API key not available when {} is set.",
OPENAI_OAUTH_ACCESS_TOKEN_ENV
));
}
match env::var(OPENAI_API_KEY_ENV) {
Ok(key) => Ok(key),
Err(_) => Err(anyhow::anyhow!(
"OpenAI API key not found in environment variable: {}. Set either {} for API key auth or {} + {} for OAuth.",
OPENAI_API_KEY_ENV,
OPENAI_API_KEY_ENV,
OPENAI_OAUTH_ACCESS_TOKEN_ENV,
OPENAI_OAUTH_ACCOUNT_ID_ENV
)),
}
}
    fn supports_caching(&self, model: &str) -> bool {
        let model_lower = normalize_model_name(model);
        // "o1" also covers the o1-preview and o1-mini variants.
        model_lower.contains("gpt-4o")
            || model_lower.contains("gpt-4.1")
            || model_lower.contains("gpt-5")
            || model_lower.contains("o1")
            || model_lower.contains("o3")
            || model_lower.contains("o4")
    }
    fn supports_vision(&self, model: &str) -> bool {
        let normalized = normalize_model_name(model);
        normalized.starts_with("gpt-4o")
            || normalized.starts_with("gpt-4.1")
            || normalized.starts_with("gpt-4-turbo")
            || normalized.starts_with("gpt-4-vision-preview")
            || normalized.starts_with("gpt-5")
    }
    fn get_max_input_tokens(&self, model: &str) -> usize {
        let normalized = normalize_model_name(model);
        if normalized.starts_with("gpt-5.1") {
            return 400_000;
        }
        if normalized.starts_with("gpt-5") {
            return 128_000;
        }
        if normalized.starts_with("gpt-4o") {
            return 128_000;
        }
        if normalized.starts_with("gpt-4-turbo")
            || normalized.starts_with("gpt-4.5")
            || normalized.starts_with("gpt-4.1")
        {
            return 128_000;
        }
        if normalized.starts_with("gpt-4-32k") {
            return 32_768;
        }
        if normalized.starts_with("gpt-4") {
            return 8_192;
        }
        if normalized.starts_with("o1")
            || normalized.starts_with("o2")
            || normalized.starts_with("o3")
            || normalized.starts_with("o4")
        {
            return 128_000;
        }
        if normalized.starts_with("gpt-3.5") {
            return 16_384;
        }
        8_192
    }
    fn supports_structured_output(&self, _model: &str) -> bool {
        true
    }
async fn chat_completion(&self, params: ChatCompletionParams) -> Result<ProviderResponse> {
        let oauth_credentials = if let (Ok(access_token), Ok(account_id)) = (
            env::var(OPENAI_OAUTH_ACCESS_TOKEN_ENV),
            env::var(OPENAI_OAUTH_ACCOUNT_ID_ENV),
        ) {
            Some((access_token, account_id))
        } else {
            None
        };
        let auth_token = match &oauth_credentials {
            Some((access_token, _)) => access_token.clone(),
            None => self.get_api_key()?,
        };
        let openai_messages = convert_messages(&params.messages);
let mut request_body = serde_json::json!({
"model": params.model,
"messages": openai_messages,
});
        if supports_temperature(&params.model) {
request_body["temperature"] = serde_json::json!(params.temperature);
request_body["top_p"] = serde_json::json!(params.top_p);
}
if params.max_tokens > 0 {
            if uses_max_completion_tokens(&params.model) {
request_body["max_completion_tokens"] = serde_json::json!(params.max_tokens);
} else {
request_body["max_tokens"] = serde_json::json!(params.max_tokens);
}
}
        if let Some(tools) = &params.tools {
if !tools.is_empty() {
let mut sorted_tools = tools.clone();
sorted_tools.sort_by(|a, b| a.name.cmp(&b.name));
let openai_tools = sorted_tools
.iter()
.map(|f| {
serde_json::json!({
"type": "function",
"function": {
"name": f.name,
"description": f.description,
"parameters": f.parameters
}
})
})
.collect::<Vec<_>>();
request_body["tools"] = serde_json::json!(openai_tools);
request_body["tool_choice"] = serde_json::json!("auto");
}
}
        if let Some(response_format) = &params.response_format {
match &response_format.format {
crate::llm::types::OutputFormat::Json => {
request_body["response_format"] = serde_json::json!({
"type": "json_object"
});
}
crate::llm::types::OutputFormat::JsonSchema => {
if let Some(schema) = &response_format.schema {
                        // The API requires a `name` inside `json_schema`; use
                        // a generic placeholder since callers don't supply one.
                        let mut format_obj = serde_json::json!({
                            "type": "json_schema",
                            "json_schema": {
                                "name": "structured_output",
                                "schema": schema
                            }
                        });
if matches!(
response_format.mode,
crate::llm::types::ResponseMode::Strict
) {
format_obj["json_schema"]["strict"] = serde_json::json!(true);
}
request_body["response_format"] = format_obj;
}
}
}
}
        let account_id_header = oauth_credentials.as_ref().map(|(_, id)| id.clone());
let api_url = env::var(OPENAI_API_URL_ENV).unwrap_or_else(|_| OPENAI_API_URL.to_string());
let response = execute_openai_request(
auth_token,
account_id_header,
api_url,
request_body,
params.max_retries,
params.retry_timeout,
params.cancellation_token.as_ref(),
)
.await?;
Ok(response)
}
}
#[derive(Serialize, Deserialize, Debug)]
struct OpenAiMessage {
role: String,
content: serde_json::Value,
#[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<serde_json::Value>,
}
#[derive(Deserialize, Debug)]
struct OpenAiResponse {
choices: Vec<OpenAiChoice>,
usage: OpenAiUsage,
}
#[derive(Deserialize, Debug)]
struct OpenAiChoice {
message: OpenAiResponseMessage,
finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
struct OpenAiResponseMessage {
content: Option<String>,
#[serde(default)]
reasoning_content: Option<String>,
tool_calls: Option<Vec<OpenAiToolCall>>,
}
#[derive(Deserialize, Debug)]
struct OpenAiToolCall {
id: String,
#[serde(rename = "type")]
tool_type: String,
function: OpenAiFunction,
}
#[derive(Deserialize, Debug)]
struct OpenAiFunction {
name: String,
arguments: String,
}
#[derive(Deserialize, Debug)]
struct OpenAiUsage {
prompt_tokens: u64,
completion_tokens: u64,
total_tokens: u64,
    // Chat Completions reports this object as `prompt_tokens_details`; keep
    // the Responses-style `input_tokens_details` name as a fallback alias.
    #[serde(
        default,
        rename = "prompt_tokens_details",
        alias = "input_tokens_details"
    )]
    input_tokens_details: Option<OpenAiInputTokensDetails>,
#[serde(default)]
completion_tokens_details: Option<OpenAiCompletionTokensDetails>,
}
#[derive(Deserialize, Debug)]
struct OpenAiInputTokensDetails {
#[serde(default)]
cached_tokens: u64,
}
#[derive(Deserialize, Debug)]
struct OpenAiCompletionTokensDetails {
#[serde(default)]
reasoning_tokens: Option<u64>,
}
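/// Converts internal `Message`s into Chat Completions wire messages: tool
/// results keep their `tool_call_id`, assistant messages re-emit their
/// recorded tool calls, and other roles become text plus optional base64
/// image parts. Messages marked `cached` are wrapped in content-part arrays
/// carrying an ephemeral `cache_control` marker; that field is not part of
/// the official OpenAI API and is presumably aimed at compatible gateways.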
fn convert_messages(messages: &[Message]) -> Vec<OpenAiMessage> {
let mut result = Vec::new();
for message in messages {
match message.role.as_str() {
"tool" => {
let tool_call_id = message.tool_call_id.clone();
let name = message.name.clone();
let content = if message.cached {
let mut text_content = serde_json::json!({
"type": "text",
"text": message.content
});
text_content["cache_control"] = serde_json::json!({
"type": "ephemeral"
});
serde_json::json!([text_content])
} else {
serde_json::json!(message.content)
};
result.push(OpenAiMessage {
role: message.role.clone(),
content,
tool_call_id,
name,
tool_calls: None,
});
}
"assistant" if message.tool_calls.is_some() => {
let mut content_parts = Vec::new();
if !message.content.trim().is_empty() {
let mut text_content = serde_json::json!({
"type": "text",
"text": message.content
});
if message.cached {
text_content["cache_control"] = serde_json::json!({
"type": "ephemeral"
});
}
content_parts.push(text_content);
}
let content = if content_parts.len() == 1 && !message.cached {
content_parts[0]["text"].clone()
} else if content_parts.is_empty() {
serde_json::Value::Null
} else {
serde_json::json!(content_parts)
};
let tool_calls = if let Ok(generic_calls) =
serde_json::from_value::<Vec<crate::llm::tool_calls::GenericToolCall>>(
message.tool_calls.clone().unwrap(),
) {
let openai_calls: Vec<serde_json::Value> = generic_calls
.into_iter()
.map(|call| {
serde_json::json!({
"id": call.id,
"type": "function",
"function": {
"name": call.name,
"arguments": serde_json::to_string(&call.arguments).unwrap_or_default()
}
})
})
.collect();
Some(serde_json::Value::Array(openai_calls))
} else {
panic!("Invalid tool_calls format - must be Vec<GenericToolCall>");
};
result.push(OpenAiMessage {
role: message.role.clone(),
content,
tool_call_id: None,
name: None,
tool_calls,
});
}
_ => {
let mut content_parts = vec![{
let mut text_content = serde_json::json!({
"type": "text",
"text": message.content
});
if message.cached {
text_content["cache_control"] = serde_json::json!({
"type": "ephemeral"
});
}
text_content
}];
if let Some(images) = &message.images {
for image in images {
if let crate::llm::types::ImageData::Base64(data) = &image.data {
content_parts.push(serde_json::json!({
"type": "image_url",
"image_url": {
"url": format!("data:{};base64,{}", image.media_type, data)
}
}));
}
}
}
let content = if content_parts.len() == 1 && !message.cached {
content_parts[0]["text"].clone()
} else {
serde_json::json!(content_parts)
};
result.push(OpenAiMessage {
role: message.role.clone(),
content,
tool_call_id: None,
name: None,
tool_calls: None,
});
}
}
}
result
}
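/// Sends the request with exponential-backoff retries, then assembles a
/// `ProviderResponse` from the reply: rate-limit and cache headers, token
/// usage (cached and reasoning tokens included), estimated cost, tool calls,
/// and opportunistically parsed structured output.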
async fn execute_openai_request(
auth_token: String,
account_id: Option<String>,
api_url: String,
request_body: serde_json::Value,
max_retries: u32,
base_timeout: std::time::Duration,
cancellation_token: Option<&tokio::sync::watch::Receiver<bool>>,
) -> Result<ProviderResponse> {
let client = Client::new();
let start_time = std::time::Instant::now();
let response = retry::retry_with_exponential_backoff(
|| {
let client = client.clone();
let auth_token = auth_token.clone();
let account_id = account_id.clone();
let api_url = api_url.clone();
let request_body = request_body.clone();
Box::pin(async move {
let mut req = client
.post(&api_url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", auth_token));
if let Some(id) = account_id {
req = req.header("ChatGPT-Account-ID", id);
}
req.json(&request_body).send().await
})
},
max_retries,
base_timeout,
cancellation_token,
)
.await?;
let request_time_ms = start_time.elapsed().as_millis() as u64;
let mut rate_limit_headers = std::collections::HashMap::new();
let headers = response.headers();
let cache_creation_input_tokens = headers
.get("x-cache-creation-input-tokens")
.and_then(|h| h.to_str().ok())
.and_then(|s| s.parse::<u32>().ok())
.unwrap_or(0);
let cache_read_input_tokens = headers
.get("x-cache-read-input-tokens")
.and_then(|h| h.to_str().ok())
.and_then(|s| s.parse::<u32>().ok())
.unwrap_or(0);
if let Some(requests_limit) = headers
.get("x-ratelimit-limit-requests")
.and_then(|h| h.to_str().ok())
{
rate_limit_headers.insert("requests_limit".to_string(), requests_limit.to_string());
}
if let Some(requests_remaining) = headers
.get("x-ratelimit-remaining-requests")
.and_then(|h| h.to_str().ok())
{
rate_limit_headers.insert(
"requests_remaining".to_string(),
requests_remaining.to_string(),
);
}
if let Some(tokens_limit) = headers
.get("x-ratelimit-limit-tokens")
.and_then(|h| h.to_str().ok())
{
rate_limit_headers.insert("tokens_limit".to_string(), tokens_limit.to_string());
}
if let Some(tokens_remaining) = headers
.get("x-ratelimit-remaining-tokens")
.and_then(|h| h.to_str().ok())
{
rate_limit_headers.insert("tokens_remaining".to_string(), tokens_remaining.to_string());
}
if let Some(request_reset) = headers
.get("x-ratelimit-reset-requests")
.and_then(|h| h.to_str().ok())
{
rate_limit_headers.insert("request_reset".to_string(), request_reset.to_string());
}
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(anyhow::anyhow!(
"OpenAI API error {}: {}",
status,
error_text
));
}
let response_text = response.text().await?;
let openai_response: OpenAiResponse = serde_json::from_str(&response_text)?;
let choice = openai_response
.choices
.into_iter()
.next()
.ok_or_else(|| anyhow::anyhow!("No choices in OpenAI response"))?;
let content = choice.message.content.unwrap_or_default();
let reasoning_tokens = openai_response
.usage
.completion_tokens_details
.as_ref()
.and_then(|details| details.reasoning_tokens)
.unwrap_or(0);
let reasoning_content = choice.message.reasoning_content;
let thinking = reasoning_content.as_ref().map(|rc| ThinkingBlock {
content: rc.clone(),
        tokens: reasoning_tokens,
    });
let tool_calls: Option<Vec<ToolCall>> = choice.message.tool_calls.map(|calls| {
calls
.into_iter()
.filter_map(|call| {
if call.tool_type != "function" {
eprintln!(
"Warning: Unexpected tool type '{}' from OpenAI API",
call.tool_type
);
return None;
}
let arguments: serde_json::Value =
serde_json::from_str(&call.function.arguments).unwrap_or(serde_json::json!({}));
Some(ToolCall {
id: call.id,
name: call.function.name,
arguments,
})
})
.collect()
});
let cost = request_body
.get("model")
.and_then(|m| m.as_str())
.and_then(|model| {
let cached_tokens_from_response = openai_response
.usage
.input_tokens_details
.as_ref()
.map(|details| details.cached_tokens)
.unwrap_or(0);
let effective_cached_tokens = if cached_tokens_from_response > 0 {
cached_tokens_from_response
} else {
cache_read_input_tokens as u64
};
if effective_cached_tokens > 0 || cache_creation_input_tokens > 0 {
let regular_input_tokens = openai_response
.usage
.prompt_tokens
.saturating_sub(effective_cached_tokens);
calculate_cost_with_cache(
model,
regular_input_tokens,
effective_cached_tokens,
openai_response.usage.completion_tokens,
)
} else {
calculate_cost(
model,
openai_response.usage.prompt_tokens,
openai_response.usage.completion_tokens,
)
}
});
    let usage = TokenUsage {
        prompt_tokens: openai_response.usage.prompt_tokens,
        output_tokens: openai_response.usage.completion_tokens,
        reasoning_tokens,
        // `completion_tokens` (and therefore `total_tokens`) already includes
        // reasoning tokens, so they must not be added a second time.
        total_tokens: openai_response.usage.total_tokens,
        cached_tokens: openai_response
            .usage
            .input_tokens_details
            .as_ref()
            .map(|details| details.cached_tokens)
            .unwrap_or(cache_read_input_tokens as u64),
        cost,
        request_time_ms: Some(request_time_ms),
    };
let mut response_json: serde_json::Value = serde_json::from_str(&response_text)?;
if let Some(ref tc) = tool_calls {
let generic_calls: Vec<crate::llm::tool_calls::GenericToolCall> = tc
.iter()
.map(|call| crate::llm::tool_calls::GenericToolCall {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
                meta: None,
            })
.collect();
response_json["tool_calls"] = serde_json::to_value(&generic_calls).unwrap_or_default();
}
let exchange = if rate_limit_headers.is_empty() {
ProviderExchange::new(request_body, response_json, Some(usage), "openai")
} else {
ProviderExchange::with_rate_limit_headers(
request_body,
response_json,
Some(usage),
"openai",
rate_limit_headers,
)
};
let structured_output = if content.trim().starts_with('{') || content.trim().starts_with('[') {
serde_json::from_str(&content).ok()
} else {
None
};
Ok(ProviderResponse {
content,
        thinking,
        exchange,
tool_calls,
finish_reason: choice.finish_reason,
structured_output,
})
}
#[cfg(test)]
mod tests {
use super::*;
use serial_test::serial;
#[test]
fn test_get_cache_multiplier() {
assert_eq!(get_cache_multiplier("gpt-5"), 0.1);
assert_eq!(get_cache_multiplier("gpt-5-2025-08-07"), 0.1);
assert_eq!(get_cache_multiplier("gpt-5-mini"), 0.1);
assert_eq!(get_cache_multiplier("gpt-5-mini-2025-08-07"), 0.1);
assert_eq!(get_cache_multiplier("gpt-5-nano"), 0.1);
assert_eq!(get_cache_multiplier("gpt-5-nano-2025-08-07"), 0.1);
assert_eq!(get_cache_multiplier("gpt-4o"), 0.25);
assert_eq!(get_cache_multiplier("gpt-4o-mini"), 0.25);
assert_eq!(get_cache_multiplier("gpt-4.1"), 0.25);
assert_eq!(get_cache_multiplier("gpt-4"), 0.25);
assert_eq!(get_cache_multiplier("gpt-3.5-turbo"), 0.25);
assert_eq!(get_cache_multiplier("o1"), 0.25);
assert_eq!(get_cache_multiplier("o3"), 0.25);
}
#[test]
fn test_calculate_cost_with_cache() {
let cost = calculate_cost_with_cache("gpt-5", 1000, 500, 200);
assert!(cost.is_some());
let cost_value = cost.unwrap();
assert!((cost_value - 0.0033125).abs() < 0.0000001);
let cost = calculate_cost_with_cache("gpt-4o", 1000, 500, 200);
assert!(cost.is_some());
let cost_value = cost.unwrap();
assert!((cost_value - 0.0048125).abs() < 0.0000001);
let cost = calculate_cost_with_cache("unknown-model", 1000, 500, 200);
assert!(cost.is_none());
}
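    // A regression sketch for the longest-match pricing lookup: variant names
    // such as "gpt-4o-mini" must resolve to their own table entry rather than
    // the shorter "gpt-4o"/"gpt-5" substrings they contain.
    #[test]
    fn test_calculate_cost_prefers_longest_match() {
        // gpt-4o-mini: $0.15 input / $0.60 output per 1M tokens.
        let cost = calculate_cost("gpt-4o-mini", 1_000_000, 1_000_000).unwrap();
        assert!((cost - 0.75).abs() < 1e-9);
        // gpt-5-mini: $0.25 input / $2.00 output per 1M tokens.
        let cost = calculate_cost("gpt-5-mini", 1_000_000, 1_000_000).unwrap();
        assert!((cost - 2.25).abs() < 1e-9);
        assert!(calculate_cost("unknown-model", 1_000, 1_000).is_none());
    }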
#[test]
fn test_supports_temperature() {
assert!(supports_temperature("gpt-4"));
assert!(supports_temperature("gpt-4o"));
assert!(supports_temperature("gpt-4o-mini"));
assert!(supports_temperature("gpt-3.5-turbo"));
assert!(supports_temperature("chatgpt-4o-latest"));
assert!(!supports_temperature("o1"));
assert!(!supports_temperature("o1-preview"));
assert!(!supports_temperature("o1-mini"));
assert!(!supports_temperature("o2"));
assert!(!supports_temperature("o3"));
assert!(!supports_temperature("o3-mini"));
assert!(!supports_temperature("o4"));
assert!(!supports_temperature("gpt-5"));
assert!(!supports_temperature("gpt-5-mini"));
assert!(!supports_temperature("gpt-5-nano"));
}
#[test]
fn test_uses_max_completion_tokens() {
assert!(uses_max_completion_tokens("gpt-5"));
assert!(uses_max_completion_tokens("gpt-5-2025-08-07"));
assert!(uses_max_completion_tokens("gpt-5-mini"));
assert!(uses_max_completion_tokens("gpt-5-mini-2025-08-07"));
assert!(uses_max_completion_tokens("gpt-5-nano"));
assert!(uses_max_completion_tokens("gpt-5-nano-2025-08-07"));
assert!(!uses_max_completion_tokens("gpt-4o"));
assert!(!uses_max_completion_tokens("gpt-4o-mini"));
assert!(!uses_max_completion_tokens("gpt-4.1"));
assert!(!uses_max_completion_tokens("gpt-4"));
assert!(!uses_max_completion_tokens("gpt-3.5-turbo"));
        // o-series reasoning models also require `max_completion_tokens`.
        assert!(uses_max_completion_tokens("o1"));
        assert!(uses_max_completion_tokens("o3"));
}
#[test]
fn test_supports_model_gpt5() {
let provider = OpenAiProvider::new();
assert!(provider.supports_model("gpt-5"));
assert!(provider.supports_model("gpt-5-2025-08-07"));
assert!(provider.supports_model("gpt-5-mini"));
assert!(provider.supports_model("gpt-5-mini-2025-08-07"));
assert!(provider.supports_model("gpt-5-nano"));
assert!(provider.supports_model("gpt-5-nano-2025-08-07"));
assert!(provider.supports_model("gpt-4o"));
assert!(provider.supports_model("gpt-4"));
assert!(provider.supports_model("gpt-3.5-turbo"));
assert!(provider.supports_model("o1"));
assert!(!provider.supports_model("claude-3"));
assert!(!provider.supports_model("llama-2"));
}
#[test]
fn test_supports_model_case_insensitive() {
let provider = OpenAiProvider::new();
assert!(provider.supports_model("GPT-5"));
assert!(provider.supports_model("GPT-4O"));
assert!(provider.supports_model("GPT-4"));
assert!(provider.supports_model("Gpt-5"));
assert!(provider.supports_model("gPT-4o"));
assert!(provider.supports_model("O1"));
assert!(provider.supports_model("o3-mini"));
}
#[test]
fn test_get_max_input_tokens_gpt5() {
let provider = OpenAiProvider::new();
assert_eq!(provider.get_max_input_tokens("gpt-5"), 128_000);
assert_eq!(provider.get_max_input_tokens("gpt-5-2025-08-07"), 128_000);
assert_eq!(provider.get_max_input_tokens("gpt-5-mini"), 128_000);
assert_eq!(provider.get_max_input_tokens("gpt-5-nano"), 128_000);
assert_eq!(provider.get_max_input_tokens("gpt-4o"), 128_000);
assert_eq!(provider.get_max_input_tokens("gpt-4"), 8_192);
assert_eq!(provider.get_max_input_tokens("gpt-3.5-turbo"), 16_384);
}
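    // A sketch of the remaining context-window branches: the 32k GPT-4
    // variant and the o-series (including o4) have dedicated limits, and
    // unrecognized models fall back to the conservative default.
    #[test]
    fn test_get_max_input_tokens_other_families() {
        let provider = OpenAiProvider::new();
        assert_eq!(provider.get_max_input_tokens("gpt-4-32k"), 32_768);
        assert_eq!(provider.get_max_input_tokens("gpt-4-turbo"), 128_000);
        assert_eq!(provider.get_max_input_tokens("o1"), 128_000);
        assert_eq!(provider.get_max_input_tokens("o4-mini"), 128_000);
        assert_eq!(provider.get_max_input_tokens("text-davinci-003"), 8_192);
    }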
#[test]
fn test_supports_vision() {
let provider = OpenAiProvider::new();
assert!(provider.supports_vision("gpt-4o"));
assert!(provider.supports_vision("gpt-4o-mini"));
assert!(provider.supports_vision("gpt-4o-2024-05-13"));
assert!(provider.supports_vision("gpt-4-turbo"));
assert!(provider.supports_vision("gpt-4-vision-preview"));
assert!(provider.supports_vision("gpt-4.1"));
assert!(provider.supports_vision("gpt-5-mini"));
assert!(!provider.supports_vision("gpt-3.5-turbo"));
assert!(!provider.supports_vision("gpt-4"));
assert!(!provider.supports_vision("o1-preview"));
assert!(!provider.supports_vision("o1-mini"));
assert!(!provider.supports_vision("text-davinci-003"));
}
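    // A sketch of the prompt-caching matrix as implemented above: the
    // 4o/4.1/5 families and the o-series report caching support, while
    // completion-era models do not.
    #[test]
    fn test_supports_caching() {
        let provider = OpenAiProvider::new();
        assert!(provider.supports_caching("gpt-4o"));
        assert!(provider.supports_caching("gpt-4.1-mini"));
        assert!(provider.supports_caching("gpt-5-nano"));
        assert!(provider.supports_caching("o1-mini"));
        assert!(provider.supports_caching("o3"));
        assert!(!provider.supports_caching("gpt-3.5-turbo"));
        assert!(!provider.supports_caching("gpt-4"));
    }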
#[test]
#[serial]
fn test_oauth_token_priority() {
let provider = OpenAiProvider::new();
env::set_var(OPENAI_OAUTH_ACCESS_TOKEN_ENV, "test-oauth-token");
env::set_var(OPENAI_OAUTH_ACCOUNT_ID_ENV, "test-account-id");
let result = provider.get_api_key();
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("OAuth authentication"));
env::remove_var(OPENAI_OAUTH_ACCESS_TOKEN_ENV);
env::remove_var(OPENAI_OAUTH_ACCOUNT_ID_ENV);
}
#[test]
#[serial]
fn test_api_key_fallback() {
let provider = OpenAiProvider::new();
env::remove_var(OPENAI_OAUTH_ACCESS_TOKEN_ENV);
env::remove_var(OPENAI_OAUTH_ACCOUNT_ID_ENV);
env::set_var(OPENAI_API_KEY_ENV, "test-api-key");
let result = provider.get_api_key();
assert!(result.is_ok());
assert_eq!(result.unwrap(), "test-api-key");
env::remove_var(OPENAI_API_KEY_ENV);
}
#[test]
#[serial]
fn test_no_auth_error() {
let provider = OpenAiProvider::new();
env::remove_var(OPENAI_OAUTH_ACCESS_TOKEN_ENV);
env::remove_var(OPENAI_OAUTH_ACCOUNT_ID_ENV);
env::remove_var(OPENAI_API_KEY_ENV);
let result = provider.get_api_key();
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert!(error_msg.contains("OPENAI_API_KEY") || error_msg.contains("OPENAI_OAUTH"));
}
}