use async_openai::{
config::OpenAIConfig,
types::chat::{
ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
ChatCompletionNamedToolChoice, ChatCompletionRequestAssistantMessageArgs,
ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImage,
ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessageArgs,
ChatCompletionRequestToolMessageArgs, ChatCompletionRequestUserMessageArgs,
ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
ChatCompletionStreamOptions, ChatCompletionTool, ChatCompletionToolChoiceOption,
ChatCompletionTools, CompletionUsage, CreateChatCompletionRequestArgs, FinishReason,
FunctionCall, FunctionName, FunctionObjectArgs, ImageDetail, ImageUrl, ToolChoiceOptions,
},
types::embeddings::{CreateEmbeddingRequestArgs, EmbeddingInput},
Client,
};
use async_trait::async_trait;
use futures::StreamExt;
use std::collections::HashMap;
use tracing::debug;
use crate::error::{LlmError, Result};
use crate::traits::FunctionCall as TraitFunctionCall;
use crate::traits::ImageData;
use crate::traits::ToolCall;
use crate::traits::{
ChatMessage, ChatRole, CompletionOptions, EmbeddingProvider, LLMProvider, LLMResponse,
StreamChunk, StreamUsage, ToolChoice, ToolDefinition,
};
/// Chat-completion and embedding provider for OpenAI (and OpenAI-compatible
/// servers), backed by the `async-openai` client.
pub struct OpenAIProvider {
    // HTTP client carrying API key / base URL configuration.
    client: Client<OpenAIConfig>,
    // Chat model identifier (default "gpt-5-mini", see `with_config`).
    model: String,
    // Embedding model identifier (default "text-embedding-3-small").
    embedding_model: String,
    // Context window in tokens, derived from the chat model name.
    max_context_length: usize,
    // Output vector width, derived from the embedding model name.
    embedding_dimension: usize,
}
impl OpenAIProvider {
    /// Creates a provider for the official OpenAI endpoint with the given API key.
    pub fn new(api_key: impl Into<String>) -> Self {
        let config = OpenAIConfig::new().with_api_key(api_key);
        Self::with_config(config)
    }

    /// Creates a provider from an explicit [`OpenAIConfig`].
    ///
    /// Defaults to the "gpt-5-mini" chat model and the
    /// "text-embedding-3-small" embedding model. Context length and embedding
    /// dimension are derived through the same lookup tables the builder
    /// methods use, so the defaults cannot drift out of sync with them.
    pub fn with_config(config: OpenAIConfig) -> Self {
        let model = "gpt-5-mini".to_string();
        let embedding_model = "text-embedding-3-small".to_string();
        let max_context_length = Self::context_length_for_model(&model);
        let embedding_dimension = Self::dimension_for_model(&embedding_model);
        Self {
            client: Client::with_config(config),
            model,
            embedding_model,
            max_context_length,
            embedding_dimension,
        }
    }

    /// Creates a provider pointed at an OpenAI-compatible server at `base_url`.
    pub fn compatible(api_key: impl Into<String>, base_url: impl Into<String>) -> Self {
        let config = OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(base_url);
        Self::with_config(config)
    }

    /// Builds a provider from environment variables.
    ///
    /// Reads `OPENAI_API_KEY` (required), `OPENAI_BASE_URL` (optional) and
    /// `OPENAI_MODEL` (optional). A `.env` file is loaded first when present;
    /// its absence is deliberately ignored.
    ///
    /// # Errors
    /// Returns [`LlmError::ConfigError`] when `OPENAI_API_KEY` is not set.
    pub fn from_env() -> Result<Self> {
        // Best effort: a missing .env file is not an error.
        let _ = dotenvy::dotenv();
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| LlmError::ConfigError("OPENAI_API_KEY not set".into()))?;
        let mut config = OpenAIConfig::new().with_api_key(api_key);
        if let Ok(base_url) = std::env::var("OPENAI_BASE_URL") {
            config = config.with_api_base(base_url);
        }
        let mut provider = Self::with_config(config);
        if let Ok(model) = std::env::var("OPENAI_MODEL") {
            provider = provider.with_model(model);
        }
        Ok(provider)
    }

    /// Sets the chat model and recomputes the context window for it.
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self.max_context_length = Self::context_length_for_model(&self.model);
        self
    }

    /// Sets the embedding model and recomputes the output dimension for it.
    pub fn with_embedding_model(mut self, model: impl Into<String>) -> Self {
        self.embedding_model = model.into();
        self.embedding_dimension = Self::dimension_for_model(&self.embedding_model);
        self
    }

    /// Maps a chat model name to its context window in tokens.
    ///
    /// Matching is substring-based and ordered most-specific-first (e.g.
    /// "gpt-5-nano" must be tested before the bare "gpt-5" catch-all, and
    /// "gpt-4-32k" before "gpt-4"). Unknown models fall back to 128k.
    fn context_length_for_model(model: &str) -> usize {
        match model {
            m if m.contains("gpt-5.2") || m.contains("gpt-5.1") => 200000,
            m if m.contains("gpt-5-nano") => 128000,
            m if m.contains("gpt-5-mini") || m.contains("gpt-5") => 200000,
            m if m.contains("gpt-4.1") => 128000,
            m if m.contains("o4") || m.contains("o3") => 200000,
            m if m.contains("o1") => 200000,
            m if m.contains("gpt-4o") => 128000,
            m if m.contains("gpt-4-turbo") => 128000,
            m if m.contains("gpt-4-32k") => 32768,
            m if m.contains("gpt-4") => 8192,
            m if m.contains("gpt-3.5-turbo-16k") => 16384,
            m if m.contains("gpt-3.5") => 4096,
            m if m.contains("codex") => 200000,
            m if m.contains("gpt-realtime") || m.contains("gpt-audio") => 128000,
            _ => 128000,
        }
    }

    /// Maps an embedding model name to its output vector dimension.
    /// Unknown models default to 1536.
    fn dimension_for_model(model: &str) -> usize {
        match model {
            m if m.contains("text-embedding-3-large") => 3072,
            m if m.contains("text-embedding-3-small") => 1536,
            m if m.contains("text-embedding-ada") => 1536,
            _ => 1536,
        }
    }

    /// Flattens a [`CompletionUsage`] into
    /// `(prompt, completion, total, cached_prompt, reasoning)` token counts.
    ///
    /// A missing usage object yields all-zero counts with `None` details.
    fn extract_usage(
        usage: Option<CompletionUsage>,
    ) -> (usize, usize, usize, Option<usize>, Option<usize>) {
        let usage = usage.unwrap_or(CompletionUsage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        });
        // Prompt tokens served from the provider-side prompt cache, if reported.
        let cache_hit_tokens = usage
            .prompt_tokens_details
            .as_ref()
            .and_then(|d| d.cached_tokens)
            .map(|t| t as usize);
        // Reasoning ("thinking") tokens, if reported.
        let thinking_tokens = usage
            .completion_tokens_details
            .as_ref()
            .and_then(|d| d.reasoning_tokens)
            .map(|t| t as usize);
        (
            usage.prompt_tokens as usize,
            usage.completion_tokens as usize,
            usage.total_tokens as usize,
            cache_hit_tokens,
            thinking_tokens,
        )
    }

    /// Converts usage into a [`StreamUsage`], or `None` when the chunk carried
    /// no usage information at all (all counts zero and no details).
    fn extract_stream_usage(usage: Option<CompletionUsage>) -> Option<StreamUsage> {
        let (prompt_tokens, completion_tokens, _total_tokens, cache_hit_tokens, thinking_tokens) =
            Self::extract_usage(usage);
        if prompt_tokens == 0
            && completion_tokens == 0
            && cache_hit_tokens.is_none()
            && thinking_tokens.is_none()
        {
            return None;
        }
        let mut usage = StreamUsage::new(prompt_tokens, completion_tokens);
        if let Some(tokens) = cache_hit_tokens {
            usage = usage.with_cache_hit_tokens(tokens);
        }
        if let Some(tokens) = thinking_tokens {
            usage = usage.with_thinking_tokens(tokens);
        }
        Some(usage)
    }

    /// Converts provider-agnostic [`ChatMessage`]s into OpenAI request messages.
    ///
    /// - `System` / `User` / `Assistant` map directly; assistant tool calls are
    ///   forwarded as function tool calls.
    /// - `Tool` requires `tool_call_id`; a missing id is an
    ///   [`LlmError::InvalidRequest`].
    /// - `Function` (legacy role) is downgraded to a plain user message.
    ///
    /// # Errors
    /// Returns [`LlmError::InvalidRequest`] when a message fails to build.
    fn convert_messages(messages: &[ChatMessage]) -> Result<Vec<ChatCompletionRequestMessage>> {
        messages
            .iter()
            .map(|msg| {
                match msg.role {
                    ChatRole::System => ChatCompletionRequestSystemMessageArgs::default()
                        .content(msg.content.as_str())
                        .build()
                        .map(Into::into)
                        .map_err(|e| LlmError::InvalidRequest(e.to_string())),
                    ChatRole::User => {
                        // Text-only or text+image (vision) content.
                        let content = Self::build_user_content(msg);
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()
                            .map(Into::into)
                            .map_err(|e| LlmError::InvalidRequest(e.to_string()))
                    }
                    ChatRole::Assistant => {
                        let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                        // Content is optional for assistant messages that only
                        // carry tool calls.
                        if !msg.content.is_empty() {
                            builder.content(msg.content.clone());
                        }
                        if let Some(ref tool_calls) = msg.tool_calls {
                            let openai_calls: Vec<ChatCompletionMessageToolCalls> = tool_calls
                                .iter()
                                .map(|tc| {
                                    ChatCompletionMessageToolCalls::Function(
                                        ChatCompletionMessageToolCall {
                                            id: tc.id.clone(),
                                            function: FunctionCall {
                                                name: tc.function.name.clone(),
                                                arguments: tc.function.arguments.clone(),
                                            },
                                        },
                                    )
                                })
                                .collect();
                            builder.tool_calls(openai_calls);
                        }
                        builder
                            .build()
                            .map(Into::into)
                            .map_err(|e| LlmError::InvalidRequest(e.to_string()))
                    }
                    ChatRole::Tool => {
                        let tool_call_id = msg.tool_call_id.clone().ok_or_else(|| {
                            LlmError::InvalidRequest(
                                "Tool message missing required tool_call_id".into(),
                            )
                        })?;
                        ChatCompletionRequestToolMessageArgs::default()
                            .content(msg.content.clone())
                            .tool_call_id(tool_call_id)
                            .build()
                            .map(Into::into)
                            .map_err(|e| LlmError::InvalidRequest(e.to_string()))
                    }
                    ChatRole::Function => {
                        // Legacy role: forwarded as a plain user message.
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(msg.content.as_str())
                            .build()
                            .map(Into::into)
                            .map_err(|e| LlmError::InvalidRequest(e.to_string()))
                    }
                }
            })
            .collect()
    }

    /// Builds user-message content: plain text, or an array of text + image
    /// parts when the message carries images. The text part is omitted when
    /// the message text is empty.
    fn build_user_content(msg: &ChatMessage) -> ChatCompletionRequestUserMessageContent {
        if msg.has_images() {
            let mut parts: Vec<ChatCompletionRequestUserMessageContentPart> = Vec::new();
            if !msg.content.is_empty() {
                parts.push(ChatCompletionRequestUserMessageContentPart::Text(
                    ChatCompletionRequestMessageContentPartText {
                        text: msg.content.clone(),
                    },
                ));
            }
            if let Some(ref images) = msg.images {
                for img in images {
                    let detail = Self::parse_image_detail(img);
                    parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(
                        ChatCompletionRequestMessageContentPartImage {
                            image_url: ImageUrl {
                                // to_api_url() yields a data: URI for inline images.
                                url: img.to_api_url(),
                                detail,
                            },
                        },
                    ));
                }
            }
            ChatCompletionRequestUserMessageContent::Array(parts)
        } else {
            ChatCompletionRequestUserMessageContent::Text(msg.content.clone())
        }
    }

    /// Parses an optional image detail hint ("low" | "high" | "auto").
    /// Any other value yields `None`, leaving the provider default.
    fn parse_image_detail(img: &ImageData) -> Option<ImageDetail> {
        match img.detail.as_deref() {
            Some("low") => Some(ImageDetail::Low),
            Some("high") => Some(ImageDetail::High),
            Some("auto") => Some(ImageDetail::Auto),
            _ => None,
        }
    }
}
#[async_trait]
impl LLMProvider for OpenAIProvider {
    /// Provider identifier.
    fn name(&self) -> &str {
        "openai"
    }

    /// The configured chat model name.
    fn model(&self) -> &str {
        &self.model
    }

    /// Context window (tokens) for the configured chat model.
    fn max_context_length(&self) -> usize {
        self.max_context_length
    }

    /// Single-turn completion with default options.
    async fn complete(&self, prompt: &str) -> Result<LLMResponse> {
        self.complete_with_options(prompt, &CompletionOptions::default())
            .await
    }

    /// Single-turn completion; `options.system_prompt`, when present, becomes
    /// a leading system message.
    async fn complete_with_options(
        &self,
        prompt: &str,
        options: &CompletionOptions,
    ) -> Result<LLMResponse> {
        let mut messages = Vec::new();
        if let Some(system) = &options.system_prompt {
            messages.push(ChatMessage::system(system));
        }
        messages.push(ChatMessage::user(prompt));
        self.chat(&messages, Some(options)).await
    }

    /// Multi-turn chat without tools.
    ///
    /// # Errors
    /// [`LlmError::InvalidRequest`] when the request cannot be built;
    /// [`LlmError::ApiError`] when the API returns no choices or the response
    /// was blocked by the content filter.
    async fn chat(
        &self,
        messages: &[ChatMessage],
        options: Option<&CompletionOptions>,
    ) -> Result<LLMResponse> {
        let openai_messages = Self::convert_messages(messages)?;
        let options = options.cloned().unwrap_or_default();
        let mut request_builder = CreateChatCompletionRequestArgs::default();
        request_builder.model(&self.model).messages(openai_messages);
        if let Some(max_tokens) = options.max_tokens {
            // max_completion_tokens replaces the deprecated max_tokens field.
            request_builder.max_completion_tokens(max_tokens as u32);
        }
        if let Some(temp) = options.temperature {
            // The default of 1.0 is treated as "not set" and never forwarded.
            if (temp - 1.0_f32).abs() > f32::EPSILON {
                request_builder.temperature(temp);
            }
        }
        if let Some(top_p) = options.top_p {
            request_builder.top_p(top_p);
        }
        if let Some(stop) = options.stop {
            request_builder.stop(stop);
        }
        if let Some(freq_penalty) = options.frequency_penalty {
            request_builder.frequency_penalty(freq_penalty);
        }
        if let Some(pres_penalty) = options.presence_penalty {
            request_builder.presence_penalty(pres_penalty);
        }
        let request = request_builder
            .build()
            .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
        let response = self.client.chat().create(request).await?;
        debug!(
            "OpenAI response - usage: {:?}, model: {}",
            response.usage, response.model
        );
        let choice = response
            .choices
            .first()
            .ok_or_else(|| LlmError::ApiError("No choices in response".to_string()))?;
        // Surface content-filter blocks as errors instead of empty content.
        if let Some(FinishReason::ContentFilter) = choice.finish_reason {
            return Err(LlmError::ApiError(
                "Response blocked by OpenAI content filter (finish_reason=content_filter)".into(),
            ));
        }
        let content = choice.message.content.clone().unwrap_or_default();
        let (prompt_tokens, completion_tokens, total_tokens, cache_hit_tokens, thinking_tokens) =
            Self::extract_usage(response.usage.clone());
        debug!(
            "OpenAI token usage - prompt: {}, completion: {}, total: {}, cached: {:?}, reasoning: {:?}",
            prompt_tokens, completion_tokens, total_tokens,
            cache_hit_tokens, thinking_tokens
        );
        let mut metadata = HashMap::new();
        metadata.insert("response_id".to_string(), serde_json::json!(response.id));
        Ok(LLMResponse {
            content,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            model: response.model,
            finish_reason: choice.finish_reason.map(|r| format!("{:?}", r)),
            // No tools were offered, so no tool calls are surfaced here.
            tool_calls: Vec::new(),
            metadata,
            cache_hit_tokens,
            thinking_tokens,
            thinking_content: None,
        })
    }

    /// Multi-turn chat with function tools.
    ///
    /// # Errors
    /// [`LlmError::InvalidRequest`] when a message or tool definition cannot
    /// be built (previously this panicked on an invalid tool definition);
    /// [`LlmError::ApiError`] for empty-choice or content-filtered responses.
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        tool_choice: Option<ToolChoice>,
        options: Option<&CompletionOptions>,
    ) -> Result<LLMResponse> {
        let openai_messages = Self::convert_messages(messages)?;
        let opts = options.cloned().unwrap_or_default();
        // Propagate invalid tool definitions as errors instead of panicking:
        // a malformed caller-supplied tool is a recoverable condition.
        let openai_tools = tools
            .iter()
            .map(|t| -> Result<ChatCompletionTools> {
                let function = FunctionObjectArgs::default()
                    .name(&t.function.name)
                    .description(&t.function.description)
                    .parameters(t.function.parameters.clone())
                    .build()
                    .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
                Ok(ChatCompletionTools::Function(ChatCompletionTool { function }))
            })
            .collect::<Result<Vec<ChatCompletionTools>>>()?;
        let mut request_builder = CreateChatCompletionRequestArgs::default();
        request_builder
            .model(&self.model)
            .messages(openai_messages)
            .tools(openai_tools);
        if let Some(tc) = tool_choice {
            match tc {
                ToolChoice::Auto(_) => {
                    request_builder.tool_choice(ChatCompletionToolChoiceOption::Mode(
                        ToolChoiceOptions::Auto,
                    ));
                }
                ToolChoice::Required(_) => {
                    request_builder.tool_choice(ChatCompletionToolChoiceOption::Mode(
                        ToolChoiceOptions::Required,
                    ));
                }
                ToolChoice::Function { ref function, .. } => {
                    // Force the model to call one specific function by name.
                    request_builder.tool_choice(ChatCompletionToolChoiceOption::Function(
                        ChatCompletionNamedToolChoice {
                            function: FunctionName {
                                name: function.name.clone(),
                            },
                        },
                    ));
                }
            }
        }
        if let Some(max_tokens) = opts.max_tokens {
            request_builder.max_completion_tokens(max_tokens as u32);
        }
        if let Some(temp) = opts.temperature {
            // The default of 1.0 is treated as "not set" and never forwarded.
            if (temp - 1.0_f32).abs() > f32::EPSILON {
                request_builder.temperature(temp);
            }
        }
        let request = request_builder
            .build()
            .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
        let response = self.client.chat().create(request).await?;
        debug!(
            "OpenAI chat_with_tools response id={} model={}",
            response.id, response.model
        );
        let choice = response
            .choices
            .first()
            .ok_or_else(|| LlmError::ApiError("No choices in response".to_string()))?;
        if let Some(FinishReason::ContentFilter) = choice.finish_reason {
            return Err(LlmError::ApiError(
                "Response blocked by OpenAI content filter (finish_reason=content_filter)".into(),
            ));
        }
        // Only function-type tool calls are surfaced; other variants are dropped.
        let tool_calls: Vec<ToolCall> = choice
            .message
            .tool_calls
            .as_deref()
            .unwrap_or_default()
            .iter()
            .filter_map(|tc| {
                if let ChatCompletionMessageToolCalls::Function(f) = tc {
                    Some(ToolCall {
                        id: f.id.clone(),
                        call_type: "function".to_string(),
                        function: TraitFunctionCall {
                            name: f.function.name.clone(),
                            arguments: f.function.arguments.clone(),
                        },
                        thought_signature: None,
                    })
                } else {
                    None
                }
            })
            .collect();
        let content = choice.message.content.clone().unwrap_or_default();
        let (prompt_tokens, completion_tokens, total_tokens, cache_hit_tokens, thinking_tokens) =
            Self::extract_usage(response.usage.clone());
        let mut metadata = HashMap::new();
        metadata.insert("response_id".to_string(), serde_json::json!(response.id));
        Ok(LLMResponse {
            content,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            model: response.model,
            finish_reason: choice.finish_reason.map(|r| format!("{:?}", r)),
            tool_calls,
            metadata,
            cache_hit_tokens,
            thinking_tokens,
            thinking_content: None,
        })
    }

    fn supports_function_calling(&self) -> bool {
        true
    }

    /// Streams content deltas for a single user prompt.
    async fn stream(
        &self,
        prompt: &str,
    ) -> Result<futures::stream::BoxStream<'static, Result<String>>> {
        let request = ChatCompletionRequestUserMessageArgs::default()
            .content(prompt)
            .build()
            .map(Into::into)
            .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
        let request = CreateChatCompletionRequestArgs::default()
            .model(&self.model)
            .messages(vec![request])
            .stream(true)
            .build()
            .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
        let stream = self.client.chat().create_stream(request).await?;
        // Chunks without a content delta are surfaced as empty strings.
        let mapped_stream = stream.map(|res| match res {
            Ok(response) => {
                let content = response
                    .choices
                    .first()
                    .and_then(|c| c.delta.content.clone())
                    .unwrap_or_default();
                Ok(content)
            }
            Err(e) => Err(LlmError::from(e)),
        });
        Ok(mapped_stream.boxed())
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    /// Streams a tool-enabled chat, yielding content deltas, tool-call deltas
    /// and a final `Finished` chunk (with usage when the API reports it).
    ///
    /// # Errors
    /// [`LlmError::InvalidRequest`] when a message or tool definition cannot
    /// be built (previously this panicked on an invalid tool definition).
    async fn chat_with_tools_stream(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        tool_choice: Option<ToolChoice>,
        options: Option<&CompletionOptions>,
    ) -> Result<futures::stream::BoxStream<'static, Result<StreamChunk>>> {
        let openai_messages = Self::convert_messages(messages)?;
        let options = options.cloned().unwrap_or_default();
        // Propagate invalid tool definitions as errors instead of panicking,
        // mirroring chat_with_tools.
        let openai_tools = tools
            .iter()
            .map(|tool| -> Result<ChatCompletionTools> {
                let function = FunctionObjectArgs::default()
                    .name(&tool.function.name)
                    .description(&tool.function.description)
                    .parameters(tool.function.parameters.clone())
                    .build()
                    .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
                Ok(ChatCompletionTools::Function(ChatCompletionTool { function }))
            })
            .collect::<Result<Vec<ChatCompletionTools>>>()?;
        let mut request_builder = CreateChatCompletionRequestArgs::default();
        request_builder
            .model(&self.model)
            .messages(openai_messages)
            .tools(openai_tools)
            .stream(true)
            // Ask the API to include token usage in the stream.
            .stream_options(ChatCompletionStreamOptions {
                include_usage: Some(true),
                include_obfuscation: None,
            });
        if let Some(tc) = tool_choice {
            match tc {
                ToolChoice::Auto(_) => {
                    request_builder.tool_choice(ChatCompletionToolChoiceOption::Mode(
                        ToolChoiceOptions::Auto,
                    ));
                }
                ToolChoice::Required(_) => {
                    request_builder.tool_choice(ChatCompletionToolChoiceOption::Mode(
                        ToolChoiceOptions::Required,
                    ));
                }
                ToolChoice::Function { ref function, .. } => {
                    request_builder.tool_choice(ChatCompletionToolChoiceOption::Function(
                        ChatCompletionNamedToolChoice {
                            function: FunctionName {
                                name: function.name.clone(),
                            },
                        },
                    ));
                }
            }
        }
        if let Some(temp) = options.temperature {
            // The default of 1.0 is treated as "not set" and never forwarded.
            if (temp - 1.0_f32).abs() > f32::EPSILON {
                request_builder.temperature(temp);
            }
        }
        if let Some(max_tokens) = options.max_tokens {
            request_builder.max_completion_tokens(max_tokens as u32);
        }
        let request = request_builder
            .build()
            .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
        let stream = self.client.chat().create_stream(request).await?;
        let mapped_stream = stream.map(|result| {
            match result {
                Ok(response) => {
                    let stream_usage = Self::extract_stream_usage(response.usage.clone());
                    let choice = response.choices.first();
                    if let Some(choice) = choice {
                        if let Some(content) = &choice.delta.content {
                            return Ok(StreamChunk::Content(content.clone()));
                        }
                        // NOTE(review): only the first tool-call delta in a
                        // chunk is forwarded — assumes at most one per chunk.
                        if let Some(tool_call_chunks) = &choice.delta.tool_calls {
                            if let Some(chunk) = tool_call_chunks.first() {
                                return Ok(StreamChunk::ToolCallDelta {
                                    index: chunk.index as usize,
                                    id: chunk.id.clone(),
                                    function_name: chunk
                                        .function
                                        .as_ref()
                                        .and_then(|f| f.name.clone()),
                                    function_arguments: chunk
                                        .function
                                        .as_ref()
                                        .and_then(|f| f.arguments.clone()),
                                    thought_signature: None,
                                });
                            }
                        }
                        if let Some(finish_reason) = &choice.finish_reason {
                            let reason = match finish_reason {
                                FinishReason::Stop => "stop",
                                FinishReason::Length => "length",
                                FinishReason::ToolCalls => "tool_calls",
                                FinishReason::ContentFilter => "content_filter",
                                FinishReason::FunctionCall => "function_call",
                            };
                            return Ok(StreamChunk::Finished {
                                reason: reason.to_string(),
                                ttft_ms: None,
                                usage: stream_usage,
                            });
                        }
                    }
                    // A usage-only chunk (no choices) closes the stream.
                    if stream_usage.is_some() {
                        return Ok(StreamChunk::Finished {
                            reason: "stop".to_string(),
                            ttft_ms: None,
                            usage: stream_usage,
                        });
                    }
                    // Chunks with neither delta nor usage yield empty content.
                    Ok(StreamChunk::Content(String::new()))
                }
                Err(e) => Err(LlmError::from(e)),
            }
        });
        Ok(mapped_stream.boxed())
    }

    fn supports_tool_streaming(&self) -> bool {
        true
    }

    /// JSON mode is advertised for GPT-4/3.5-turbo/GPT-5 families and the
    /// o1/o3/o4 reasoning models; everything else reports `false`.
    fn supports_json_mode(&self) -> bool {
        let m = &self.model;
        m.contains("gpt-4")
            || m.contains("gpt-3.5-turbo")
            || m.contains("gpt-5")
            || m.starts_with("o1")
            || m.starts_with("o3")
            || m.starts_with("o4")
    }
}
#[async_trait]
impl EmbeddingProvider for OpenAIProvider {
    /// Provider identifier for the embedding side.
    fn name(&self) -> &str {
        "openai"
    }

    /// The configured embedding model (distinct from the chat model).
    #[allow(clippy::misnamed_getters)]
    fn model(&self) -> &str {
        self.embedding_model.as_str()
    }

    /// Output vector width for the configured embedding model.
    fn dimension(&self) -> usize {
        self.embedding_dimension
    }

    /// Maximum input size in tokens accepted per embedding input.
    fn max_tokens(&self) -> usize {
        8191
    }

    /// Embeds every text in `texts`, preserving order; an empty slice yields
    /// an empty result without calling the API.
    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let request = CreateEmbeddingRequestArgs::default()
            .model(&self.embedding_model)
            .input(EmbeddingInput::StringArray(texts.to_vec()))
            .build()
            .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
        let response = self.client.embeddings().create(request).await?;
        Ok(response
            .data
            .into_iter()
            .map(|item| item.embedding)
            .collect())
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests covering model-metadata lookup, message conversion,
    //! vision content building, tool serialization, and token-usage
    //! extraction. No network calls are made: request/response types are
    //! exercised directly via serde.
    use super::*;

    // --- model metadata lookup ---

    #[test]
    fn test_context_length_detection() {
        assert_eq!(OpenAIProvider::context_length_for_model("gpt-4o"), 128000);
        assert_eq!(OpenAIProvider::context_length_for_model("gpt-4"), 8192);
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-3.5-turbo"),
            4096
        );
    }

    #[test]
    fn test_embedding_dimension_detection() {
        assert_eq!(
            OpenAIProvider::dimension_for_model("text-embedding-3-large"),
            3072
        );
        assert_eq!(
            OpenAIProvider::dimension_for_model("text-embedding-3-small"),
            1536
        );
    }

    // Builder methods keep derived values (context length, dimension) in sync.
    #[test]
    fn test_provider_builder() {
        let provider = OpenAIProvider::new("test-key")
            .with_model("gpt-4")
            .with_embedding_model("text-embedding-3-large");
        assert_eq!(LLMProvider::model(&provider), "gpt-4");
        assert_eq!(provider.dimension(), 3072);
    }

    #[test]
    fn test_message_conversion() {
        let messages = vec![
            ChatMessage::system("You are helpful"),
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there!"),
        ];
        let converted = OpenAIProvider::convert_messages(&messages).unwrap();
        assert_eq!(converted.len(), 3);
    }

    // Ordering matters in context_length_for_model: nano before the gpt-5
    // catch-all, 32k before plain gpt-4, 16k before plain gpt-3.5.
    #[test]
    fn test_context_length_gpt5_series() {
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-5.2-turbo"),
            200000
        );
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-5.1-preview"),
            200000
        );
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-5-nano"),
            128000
        );
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-5-mini"),
            200000
        );
        assert_eq!(OpenAIProvider::context_length_for_model("gpt-5"), 200000);
    }

    #[test]
    fn test_context_length_o_series() {
        assert_eq!(OpenAIProvider::context_length_for_model("o4-mini"), 200000);
        assert_eq!(
            OpenAIProvider::context_length_for_model("o3-preview"),
            200000
        );
        assert_eq!(
            OpenAIProvider::context_length_for_model("o1-preview"),
            200000
        );
    }

    #[test]
    fn test_context_length_gpt4_variants() {
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-4-turbo-preview"),
            128000
        );
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-4-32k-0613"),
            32768
        );
        assert_eq!(OpenAIProvider::context_length_for_model("gpt-4-0613"), 8192);
    }

    #[test]
    fn test_context_length_gpt35_variants() {
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-3.5-turbo-16k"),
            16384
        );
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-3.5-turbo-1106"),
            4096
        );
    }

    // Unknown models fall back to the 128k default.
    #[test]
    fn test_context_length_unknown_defaults_high() {
        assert_eq!(
            OpenAIProvider::context_length_for_model("unknown-future-model"),
            128000
        );
    }

    #[test]
    fn test_dimension_ada_model() {
        assert_eq!(
            OpenAIProvider::dimension_for_model("text-embedding-ada-002"),
            1536
        );
    }

    #[test]
    fn test_dimension_unknown_defaults() {
        assert_eq!(
            OpenAIProvider::dimension_for_model("unknown-embedding"),
            1536
        );
    }

    // --- provider accessors ---

    #[test]
    fn test_provider_name() {
        let provider = OpenAIProvider::new("test-key");
        assert_eq!(LLMProvider::name(&provider), "openai");
    }

    #[test]
    fn test_provider_max_context_length() {
        let provider = OpenAIProvider::new("test-key").with_model("gpt-4");
        assert_eq!(provider.max_context_length(), 8192);
    }

    #[test]
    fn test_provider_dimension() {
        let provider =
            OpenAIProvider::new("test-key").with_embedding_model("text-embedding-3-large");
        assert_eq!(provider.dimension(), 3072);
    }

    // EmbeddingProvider::model returns the embedding model, not the chat model.
    #[test]
    fn test_provider_embedding_model() {
        let provider =
            OpenAIProvider::new("test-key").with_embedding_model("text-embedding-3-small");
        assert_eq!(
            EmbeddingProvider::model(&provider),
            "text-embedding-3-small"
        );
    }

    // --- message conversion: tool and assistant roles ---

    #[test]
    fn test_message_conversion_tool_role() {
        let messages = vec![ChatMessage::tool_result("call_abc", "result data")];
        let converted = OpenAIProvider::convert_messages(&messages).unwrap();
        assert_eq!(converted.len(), 1);
        match &converted[0] {
            ChatCompletionRequestMessage::Tool(m) => {
                assert_eq!(m.tool_call_id, "call_abc");
            }
            other => panic!("Expected Tool message, got {:?}", other),
        }
    }

    // A Tool-role message without tool_call_id must be rejected, not sent.
    #[test]
    fn test_tool_message_missing_id_returns_err() {
        let mut msg = ChatMessage::user("orphan");
        msg.role = ChatRole::Tool;
        msg.tool_call_id = None;
        let r = OpenAIProvider::convert_messages(&[msg]);
        assert!(
            r.is_err(),
            "Expected Err for tool message without tool_call_id"
        );
    }

    #[test]
    fn test_assistant_with_tool_calls_serialized() {
        let calls = vec![ToolCall {
            id: "call_xyz".to_string(),
            call_type: "function".to_string(),
            function: TraitFunctionCall {
                name: "get_weather".to_string(),
                arguments: r#"{"city":"Paris"}"#.to_string(),
            },
            thought_signature: None,
        }];
        let msg = ChatMessage::assistant_with_tools("", calls);
        let converted = OpenAIProvider::convert_messages(&[msg]).unwrap();
        assert_eq!(converted.len(), 1);
        match &converted[0] {
            ChatCompletionRequestMessage::Assistant(m) => {
                let tcs = m.tool_calls.as_ref().expect("tool_calls must be present");
                assert_eq!(tcs.len(), 1);
                if let ChatCompletionMessageToolCalls::Function(f) = &tcs[0] {
                    assert_eq!(f.id, "call_xyz");
                    assert_eq!(f.function.name, "get_weather");
                } else {
                    panic!("Expected Function tool call");
                }
            }
            other => panic!("Expected Assistant message, got {:?}", other),
        }
    }

    // --- capability flags ---

    #[test]
    fn test_supports_streaming() {
        let provider = OpenAIProvider::new("test-key");
        assert!(provider.supports_streaming());
    }

    #[test]
    fn test_supports_json_mode_gpt4() {
        let provider = OpenAIProvider::new("test-key").with_model("gpt-4o");
        assert!(provider.supports_json_mode());
    }

    #[test]
    fn test_supports_json_mode_gpt35() {
        let provider = OpenAIProvider::new("test-key").with_model("gpt-3.5-turbo");
        assert!(provider.supports_json_mode());
    }

    #[test]
    fn test_supports_json_mode_default_is_false() {
        let provider = OpenAIProvider::new("test-key").with_model("davinci-002");
        assert!(!provider.supports_json_mode());
    }

    // --- vision content building ---

    #[test]
    fn test_build_user_content_text_only() {
        let msg = ChatMessage::user("Hello");
        let content = OpenAIProvider::build_user_content(&msg);
        match content {
            ChatCompletionRequestUserMessageContent::Text(t) => assert_eq!(t, "Hello"),
            _ => panic!("Expected text content"),
        }
    }

    #[test]
    fn test_build_user_content_with_image() {
        use crate::traits::ImageData;
        let img = ImageData::new("base64data", "image/png");
        let msg = ChatMessage::user_with_images("Describe this", vec![img]);
        let content = OpenAIProvider::build_user_content(&msg);
        match content {
            ChatCompletionRequestUserMessageContent::Array(parts) => {
                assert_eq!(parts.len(), 2, "Should have text + image parts");
                assert!(
                    matches!(
                        parts[0],
                        ChatCompletionRequestUserMessageContentPart::Text(_)
                    ),
                    "First part should be text"
                );
                assert!(
                    matches!(
                        parts[1],
                        ChatCompletionRequestUserMessageContentPart::ImageUrl(_)
                    ),
                    "Second part should be image_url"
                );
            }
            _ => panic!("Expected array content for vision message"),
        }
    }

    // Inline images are sent as data: URIs built by ImageData::to_api_url.
    #[test]
    fn test_build_user_content_image_data_uri() {
        use crate::traits::ImageData;
        let img = ImageData::new("abc123", "image/jpeg");
        let msg = ChatMessage::user_with_images("What's here?", vec![img]);
        let content = OpenAIProvider::build_user_content(&msg);
        if let ChatCompletionRequestUserMessageContent::Array(parts) = content {
            if let ChatCompletionRequestUserMessageContentPart::ImageUrl(img_part) = &parts[1] {
                assert_eq!(
                    img_part.image_url.url, "data:image/jpeg;base64,abc123",
                    "Data URI should be correct"
                );
            } else {
                panic!("Expected ImageUrl part");
            }
        } else {
            panic!("Expected array content");
        }
    }

    #[test]
    fn test_build_user_content_image_with_detail() {
        use crate::traits::ImageData;
        let img = ImageData::new("data", "image/png").with_detail("high");
        let _msg = ChatMessage::user_with_images("Analyze", vec![img]);
        let detail = OpenAIProvider::parse_image_detail(
            &ImageData::new("x", "image/png").with_detail("high"),
        );
        assert!(matches!(detail, Some(ImageDetail::High)));
    }

    #[test]
    fn test_parse_image_detail_low() {
        use crate::traits::ImageData;
        let img = ImageData::new("x", "image/png").with_detail("low");
        let d = OpenAIProvider::parse_image_detail(&img);
        assert!(matches!(d, Some(ImageDetail::Low)));
    }

    #[test]
    fn test_parse_image_detail_auto() {
        use crate::traits::ImageData;
        let img = ImageData::new("x", "image/png").with_detail("auto");
        let d = OpenAIProvider::parse_image_detail(&img);
        assert!(matches!(d, Some(ImageDetail::Auto)));
    }

    // No detail hint means None: the provider default is used.
    #[test]
    fn test_parse_image_detail_none() {
        use crate::traits::ImageData;
        let img = ImageData::new("x", "image/png");
        let d = OpenAIProvider::parse_image_detail(&img);
        assert!(d.is_none());
    }

    // Wire-format checks: vision messages serialize to a part array, plain
    // text to a JSON string.
    #[test]
    fn test_convert_messages_with_image_produces_array_content() {
        use crate::traits::ImageData;
        let img = ImageData::new("iVBORw0KGgo", "image/png");
        let messages = vec![
            ChatMessage::system("You are a vision assistant"),
            ChatMessage::user_with_images("What is in this image?", vec![img]),
        ];
        let converted = OpenAIProvider::convert_messages(&messages).unwrap();
        assert_eq!(converted.len(), 2);
        let json = serde_json::to_value(&converted[1]).unwrap();
        let content = &json["content"];
        assert!(
            content.is_array(),
            "Vision user message content must be a JSON array, got: {:?}",
            content
        );
        let parts = content.as_array().unwrap();
        assert_eq!(parts.len(), 2, "Should have text + image parts");
        assert_eq!(parts[0]["type"], "text");
        assert_eq!(parts[1]["type"], "image_url");
        assert!(parts[1]["image_url"]["url"]
            .as_str()
            .unwrap()
            .starts_with("data:image/png;base64,"));
    }

    #[test]
    fn test_convert_messages_without_image_produces_text_content() {
        let messages = vec![ChatMessage::user("Just text")];
        let converted = OpenAIProvider::convert_messages(&messages).unwrap();
        let json = serde_json::to_value(&converted[0]).unwrap();
        let content = &json["content"];
        assert!(
            content.is_string(),
            "Plain text user message content must be a JSON string"
        );
        assert_eq!(content.as_str().unwrap(), "Just text");
    }

    // --- tool / request serialization ---

    #[test]
    fn test_chat_completion_tools_function_wrapping() {
        use crate::traits::FunctionDefinition;
        let tool_def = ToolDefinition {
            tool_type: "function".to_string(),
            function: FunctionDefinition {
                name: "get_weather".to_string(),
                description: "Get the current weather".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "location": { "type": "string" }
                    },
                    "required": ["location"]
                }),
                strict: None,
            },
        };
        let openai_tool = ChatCompletionTools::Function(ChatCompletionTool {
            function: FunctionObjectArgs::default()
                .name(&tool_def.function.name)
                .description(&tool_def.function.description)
                .parameters(tool_def.function.parameters.clone())
                .build()
                .unwrap(),
        });
        let json = serde_json::to_value(&openai_tool).unwrap();
        assert_eq!(json["type"], "function");
        assert_eq!(json["function"]["name"], "get_weather");
        assert_eq!(json["function"]["description"], "Get the current weather");
    }

    // tool_choice modes serialize to the bare strings the API expects.
    #[test]
    fn test_tool_choice_auto_serialization() {
        let choice = ChatCompletionToolChoiceOption::Mode(ToolChoiceOptions::Auto);
        let json = serde_json::to_value(&choice).unwrap();
        assert_eq!(json, "auto");
    }

    #[test]
    fn test_tool_choice_required_serialization() {
        let choice = ChatCompletionToolChoiceOption::Mode(ToolChoiceOptions::Required);
        let json = serde_json::to_value(&choice).unwrap();
        assert_eq!(json, "required");
    }

    #[test]
    fn test_tool_choice_none_serialization() {
        let choice = ChatCompletionToolChoiceOption::Mode(ToolChoiceOptions::None);
        let json = serde_json::to_value(&choice).unwrap();
        assert_eq!(json, "none");
    }

    // The provider uses max_completion_tokens, never the deprecated max_tokens.
    #[test]
    fn test_max_completion_tokens_in_request_serialization() {
        let request = CreateChatCompletionRequestArgs::default()
            .model("o3-mini")
            .messages(vec![ChatCompletionRequestUserMessageArgs::default()
                .content("Hello")
                .build()
                .unwrap()
                .into()])
            .max_completion_tokens(1024u32)
            .build()
            .unwrap();
        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(
            json["max_completion_tokens"], 1024,
            "max_completion_tokens should be set in request"
        );
        assert!(
            json["max_tokens"].is_null(),
            "deprecated max_tokens should NOT be set"
        );
    }

    #[test]
    fn test_max_completion_tokens_works_for_all_models() {
        for model in &[
            "gpt-4o",
            "gpt-3.5-turbo",
            "o1-preview",
            "o3-mini",
            "gpt-4.1-nano",
        ] {
            let request = CreateChatCompletionRequestArgs::default()
                .model(*model)
                .messages(vec![ChatCompletionRequestUserMessageArgs::default()
                    .content("Test")
                    .build()
                    .unwrap()
                    .into()])
                .max_completion_tokens(512u32)
                .build()
                .unwrap();
            let json = serde_json::to_value(&request).unwrap();
            assert_eq!(
                json["max_completion_tokens"], 512,
                "max_completion_tokens should be set for model {}",
                model
            );
        }
    }

    // --- token-usage detail extraction ---

    #[test]
    fn test_cache_hit_token_extraction() {
        use async_openai::types::chat::PromptTokensDetails;
        let usage = CompletionUsage {
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
            prompt_tokens_details: Some(PromptTokensDetails {
                cached_tokens: Some(80),
                audio_tokens: None,
            }),
            completion_tokens_details: None,
        };
        let cache_hit_tokens = usage
            .prompt_tokens_details
            .as_ref()
            .and_then(|d| d.cached_tokens)
            .map(|t| t as usize);
        assert_eq!(cache_hit_tokens, Some(80));
    }

    #[test]
    fn test_reasoning_token_extraction() {
        use async_openai::types::chat::CompletionTokensDetails;
        let usage = CompletionUsage {
            prompt_tokens: 50,
            completion_tokens: 200,
            total_tokens: 250,
            prompt_tokens_details: None,
            completion_tokens_details: Some(CompletionTokensDetails {
                reasoning_tokens: Some(150),
                audio_tokens: None,
                accepted_prediction_tokens: None,
                rejected_prediction_tokens: None,
            }),
        };
        let thinking_tokens = usage
            .completion_tokens_details
            .as_ref()
            .and_then(|d| d.reasoning_tokens)
            .map(|t| t as usize);
        assert_eq!(thinking_tokens, Some(150));
    }

    // Absent detail structs must not panic and must yield None.
    #[test]
    fn test_token_details_none_is_safe() {
        let usage = CompletionUsage {
            prompt_tokens: 10,
            completion_tokens: 20,
            total_tokens: 30,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };
        let cache_hit = usage
            .prompt_tokens_details
            .as_ref()
            .and_then(|d| d.cached_tokens)
            .map(|t| t as usize);
        let reasoning = usage
            .completion_tokens_details
            .as_ref()
            .and_then(|d| d.reasoning_tokens)
            .map(|t| t as usize);
        assert_eq!(cache_hit, None);
        assert_eq!(reasoning, None);
    }

    // The provider surfaces finish reasons via Debug formatting; pin it.
    #[test]
    fn test_finish_reason_variants() {
        let cases = vec![
            (FinishReason::Stop, "Stop"),
            (FinishReason::Length, "Length"),
            (FinishReason::ToolCalls, "ToolCalls"),
            (FinishReason::ContentFilter, "ContentFilter"),
            (FinishReason::FunctionCall, "FunctionCall"),
        ];
        for (reason, expected_debug) in cases {
            let formatted = format!("{:?}", reason);
            assert_eq!(
                formatted, expected_debug,
                "FinishReason::{} should format as {:?}",
                expected_debug, expected_debug
            );
        }
    }

    #[test]
    fn test_json_deserialize_error_conversion() {
        use crate::error::LlmError;
        let serde_err = serde_json::from_str::<serde_json::Value>("invalid json {{").unwrap_err();
        let openai_err = async_openai::error::OpenAIError::JSONDeserialize(
            serde_err,
            "invalid json {{".to_string(),
        );
        let llm_err = LlmError::from(openai_err);
        assert!(
            matches!(llm_err, LlmError::SerializationError(_)),
            "JSONDeserialize error should convert to SerializationError"
        );
    }

    #[test]
    fn test_chat_completion_tool_serialization() {
        let tool = ChatCompletionTool {
            function: FunctionObjectArgs::default()
                .name("my_func")
                .description("A test function")
                .parameters(serde_json::json!({"type": "object"}))
                .build()
                .unwrap(),
        };
        let wrapped = ChatCompletionTools::Function(tool);
        let json = serde_json::to_value(&wrapped).unwrap();
        assert_eq!(json["type"], "function");
        assert_eq!(json["function"]["name"], "my_func");
    }
}