use crate::spec_ai_core::agent::model::{
GenerationConfig, ImageAttachment, ModelProvider, ModelResponse, ModelStreamItem, ProviderKind,
ProviderMetadata, TokenUsage, ToolCall, parse_thinking_tokens,
};
use anyhow::{Result, anyhow};
use async_openai::{
Client,
config::OpenAIConfig,
types::chat::{
ChatCompletionMessageToolCalls, ChatCompletionRequestMessage,
ChatCompletionRequestMessageContentPartImageArgs,
ChatCompletionRequestMessageContentPartTextArgs, ChatCompletionRequestSystemMessageArgs,
ChatCompletionRequestUserMessageArgs, ChatCompletionRequestUserMessageContent,
ChatCompletionRequestUserMessageContentPart, ChatCompletionResponseStream,
ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs,
CreateChatCompletionResponse, ImageUrlArgs,
},
};
use async_stream::stream;
use async_trait::async_trait;
use base64::{Engine as _, engine::general_purpose};
use futures::Stream;
use std::pin::Pin;
/// [`ModelProvider`] implementation that talks to the OpenAI Chat Completions
/// API through an `async_openai` client.
#[derive(Debug, Clone)]
pub struct OpenAIProvider {
    /// Client used to issue every chat-completion request.
    client: Client<OpenAIConfig>,
    /// Model name sent with each request; the constructors default it to `"gpt-5-mini"`.
    model: String,
    /// Optional system prompt, sent as the first message of each request when set.
    system_message: Option<String>,
    /// Function tools offered to the model; `with_tools` stores `None` for an empty list.
    tools: Option<Vec<ChatCompletionTool>>,
}
impl OpenAIProvider {
pub fn new() -> Self {
Self {
client: Client::new(),
model: "gpt-5-mini".to_string(),
system_message: None,
tools: None,
}
}
pub fn with_api_key(api_key: impl Into<String>) -> Self {
let config = OpenAIConfig::new().with_api_key(api_key);
Self {
client: Client::with_config(config),
model: "gpt-5-mini".to_string(),
system_message: None,
tools: None,
}
}
pub fn with_config(config: OpenAIConfig) -> Self {
Self {
client: Client::with_config(config),
model: "gpt-5-mini".to_string(),
system_message: None,
tools: None,
}
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
self.system_message = Some(message.into());
self
}
pub fn with_tools(mut self, tools: Vec<ChatCompletionTool>) -> Self {
self.tools = if tools.is_empty() { None } else { Some(tools) };
self
}
fn model_requires_default_sampling(model: &str) -> bool {
model.to_ascii_lowercase().starts_with("gpt-5")
}
fn apply_generation_config_for_model(
model: &str,
request_builder: &mut CreateChatCompletionRequestArgs,
config: &GenerationConfig,
) {
if !Self::model_requires_default_sampling(model) {
if let Some(temp) = config.temperature {
request_builder.temperature(temp);
}
if let Some(top_p) = config.top_p {
request_builder.top_p(top_p);
}
}
if let Some(max_tokens) = config.max_tokens {
request_builder.max_completion_tokens(max_tokens);
}
if let Some(freq_penalty) = config.frequency_penalty {
request_builder.frequency_penalty(freq_penalty);
}
if let Some(pres_penalty) = config.presence_penalty {
request_builder.presence_penalty(pres_penalty);
}
if let Some(stop) = &config.stop_sequences {
request_builder.stop(stop.clone());
}
}
fn apply_generation_config(
&self,
request_builder: &mut CreateChatCompletionRequestArgs,
config: &GenerationConfig,
) {
Self::apply_generation_config_for_model(&self.model, request_builder, config);
}
fn build_messages(&self, prompt: &str) -> Result<Vec<ChatCompletionRequestMessage>> {
let mut messages = Vec::new();
if let Some(system_msg) = &self.system_message {
let system_message = ChatCompletionRequestSystemMessageArgs::default()
.content(system_msg.clone())
.build()
.map_err(|e| anyhow!("Failed to build system message: {}", e))?;
messages.push(ChatCompletionRequestMessage::System(system_message));
}
let user_message = ChatCompletionRequestUserMessageArgs::default()
.content(prompt)
.build()
.map_err(|e| anyhow!("Failed to build user message: {}", e))?;
messages.push(ChatCompletionRequestMessage::User(user_message));
Ok(messages)
}
fn build_messages_with_attachments(
&self,
prompt: &str,
attachments: &[ImageAttachment],
) -> Result<Vec<ChatCompletionRequestMessage>> {
let mut messages = Vec::new();
if let Some(system_msg) = &self.system_message {
let system_message = ChatCompletionRequestSystemMessageArgs::default()
.content(system_msg.clone())
.build()
.map_err(|e| anyhow!("Failed to build system message: {}", e))?;
messages.push(ChatCompletionRequestMessage::System(system_message));
}
let mut parts = Vec::new();
if !prompt.is_empty() {
let text_part = ChatCompletionRequestMessageContentPartTextArgs::default()
.text(prompt)
.build()
.map_err(|e| anyhow!("Failed to build text content part: {}", e))?;
parts.push(ChatCompletionRequestUserMessageContentPart::Text(text_part));
}
for attachment in attachments {
let encoded = general_purpose::STANDARD.encode(&attachment.data);
let data_url = format!("data:{};base64,{}", attachment.mime, encoded);
let image_url = ImageUrlArgs::default()
.url(data_url)
.build()
.map_err(|e| anyhow!("Failed to build image url: {}", e))?;
let image_part = ChatCompletionRequestMessageContentPartImageArgs::default()
.image_url(image_url)
.build()
.map_err(|e| anyhow!("Failed to build image content part: {}", e))?;
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(
image_part,
));
}
let user_message = ChatCompletionRequestUserMessageArgs::default()
.content(ChatCompletionRequestUserMessageContent::Array(parts))
.build()
.map_err(|e| anyhow!("Failed to build user message: {}", e))?;
messages.push(ChatCompletionRequestMessage::User(user_message));
Ok(messages)
}
}
impl Default for OpenAIProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ModelProvider for OpenAIProvider {
async fn generate(&self, prompt: &str, config: &GenerationConfig) -> Result<ModelResponse> {
let messages = self.build_messages(prompt)?;
let mut request_builder = CreateChatCompletionRequestArgs::default();
request_builder.model(&self.model).messages(messages);
self.apply_generation_config(&mut request_builder, config);
if let Some(ref tools) = self.tools {
let openai_tools: Vec<ChatCompletionTools> = tools
.iter()
.cloned()
.map(ChatCompletionTools::Function)
.collect();
request_builder.tools(openai_tools);
}
let request = request_builder
.build()
.map_err(|e| anyhow!("Failed to build request: {}", e))?;
let response: CreateChatCompletionResponse = self
.client
.chat()
.create(request)
.await
.map_err(|e| anyhow!("OpenAI API error: {}", e))?;
let choice = response
.choices
.first()
.ok_or_else(|| anyhow!("No response choices returned"))?;
let raw_content = choice.message.content.clone().unwrap_or_default();
let (reasoning, content) = parse_thinking_tokens(&raw_content);
let tool_calls = choice
.message
.tool_calls
.as_ref()
.map(|calls: &Vec<ChatCompletionMessageToolCalls>| {
calls
.iter()
.filter_map(|call| {
if let ChatCompletionMessageToolCalls::Function(f) = call {
let arguments = serde_json::from_str(&f.function.arguments).ok()?;
Some(ToolCall {
id: f.id.clone(),
function_name: f.function.name.clone(),
arguments,
})
} else {
None
}
})
.collect::<Vec<_>>()
})
.filter(|calls| !calls.is_empty());
let usage = response.usage.map(|u| TokenUsage {
prompt_tokens: u.prompt_tokens,
completion_tokens: u.completion_tokens,
total_tokens: u.total_tokens,
});
Ok(ModelResponse {
content,
model: response.model,
usage,
finish_reason: choice.finish_reason.as_ref().map(|r| format!("{:?}", r)),
tool_calls,
reasoning,
})
}
async fn stream(
&self,
prompt: &str,
config: &GenerationConfig,
) -> Result<Pin<Box<dyn Stream<Item = Result<ModelStreamItem>> + Send>>> {
let messages = self.build_messages(prompt)?;
let mut request_builder = CreateChatCompletionRequestArgs::default();
request_builder
.model(&self.model)
.messages(messages)
.stream(true);
if self.model.starts_with("gpt-4") || self.model.starts_with("gpt-5") {
use async_openai::types::chat::ChatCompletionStreamOptions;
let stream_options: ChatCompletionStreamOptions =
serde_json::from_value(serde_json::json!({
"include_usage": true
}))
.unwrap_or_else(|_| {
ChatCompletionStreamOptions {
include_usage: Some(true),
..unsafe { std::mem::zeroed() } }
});
request_builder.stream_options(stream_options);
}
self.apply_generation_config(&mut request_builder, config);
let request = request_builder
.build()
.map_err(|e| anyhow!("Failed to build streaming request: {}", e))?;
let mut response_stream: ChatCompletionResponseStream = self
.client
.chat()
.create_stream(request)
.await
.map_err(|e| anyhow!("OpenAI streaming API error: {}", e))?;
let stream = stream! {
use futures::StreamExt;
use crate::spec_ai_core::agent::model::ModelStreamItem;
let mut buffer = String::new();
let mut in_think_block = false;
let mut think_ended = false;
while let Some(result) = response_stream.next().await {
match result {
Ok(response) => {
if let Some(usage) = response.usage {
yield Ok(ModelStreamItem::Usage(TokenUsage {
prompt_tokens: usage.prompt_tokens,
completion_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
}));
}
if let Some(choice) = response.choices.first() {
if let Some(reason) = &choice.finish_reason {
yield Ok(ModelStreamItem::FinishReason(format!("{:?}", reason)));
}
if let Some(content) = &choice.delta.content {
buffer.push_str(content);
if buffer.contains("<think>") && !in_think_block {
in_think_block = true;
}
if buffer.contains("</think>") && in_think_block {
in_think_block = false;
think_ended = true;
if let Some(idx) = buffer.find("</think>") {
buffer = buffer[idx + "</think>".len()..].to_string();
}
}
if !in_think_block && (think_ended || !buffer.contains("<think>")) {
let output = buffer.clone();
buffer.clear();
if !output.is_empty() {
yield Ok(ModelStreamItem::Content(output));
}
}
}
}
}
Err(e) => {
yield Err(anyhow!("Stream error: {}", e));
break;
}
}
}
if !buffer.is_empty() && !in_think_block {
yield Ok(ModelStreamItem::Content(buffer));
}
};
Ok(Box::pin(stream))
}
async fn generate_with_attachments(
&self,
prompt: &str,
attachments: &[ImageAttachment],
config: &GenerationConfig,
) -> Result<ModelResponse> {
if attachments.is_empty() {
return self.generate(prompt, config).await;
}
let messages = self.build_messages_with_attachments(prompt, attachments)?;
let mut request_builder = CreateChatCompletionRequestArgs::default();
request_builder.model(&self.model).messages(messages);
self.apply_generation_config(&mut request_builder, config);
if let Some(ref tools) = self.tools {
let openai_tools: Vec<ChatCompletionTools> = tools
.iter()
.cloned()
.map(ChatCompletionTools::Function)
.collect();
request_builder.tools(openai_tools);
}
let request = request_builder
.build()
.map_err(|e| anyhow!("Failed to build request: {}", e))?;
let response: CreateChatCompletionResponse = self
.client
.chat()
.create(request)
.await
.map_err(|e| anyhow!("OpenAI API error: {}", e))?;
let choice = response
.choices
.first()
.ok_or_else(|| anyhow!("No response choices returned"))?;
let raw_content = choice.message.content.clone().unwrap_or_default();
let (reasoning, content) = parse_thinking_tokens(&raw_content);
let tool_calls = choice
.message
.tool_calls
.as_ref()
.map(|calls: &Vec<ChatCompletionMessageToolCalls>| {
calls
.iter()
.filter_map(|call| {
if let ChatCompletionMessageToolCalls::Function(f) = call {
let arguments = serde_json::from_str(&f.function.arguments).ok()?;
Some(ToolCall {
id: f.id.clone(),
function_name: f.function.name.clone(),
arguments,
})
} else {
None
}
})
.collect::<Vec<_>>()
})
.filter(|calls| !calls.is_empty());
let usage = response.usage.map(|u| TokenUsage {
prompt_tokens: u.prompt_tokens,
completion_tokens: u.completion_tokens,
total_tokens: u.total_tokens,
});
Ok(ModelResponse {
content,
model: response.model,
usage,
finish_reason: choice.finish_reason.as_ref().map(|r| format!("{:?}", r)),
tool_calls,
reasoning,
})
}
async fn stream_with_attachments(
&self,
prompt: &str,
attachments: &[ImageAttachment],
config: &GenerationConfig,
) -> Result<Pin<Box<dyn Stream<Item = Result<ModelStreamItem>> + Send>>> {
if attachments.is_empty() {
return self.stream(prompt, config).await;
}
let messages = self.build_messages_with_attachments(prompt, attachments)?;
let mut request_builder = CreateChatCompletionRequestArgs::default();
request_builder
.model(&self.model)
.messages(messages)
.stream(true);
if self.model.starts_with("gpt-4") || self.model.starts_with("gpt-5") {
use async_openai::types::chat::ChatCompletionStreamOptions;
let stream_options: ChatCompletionStreamOptions =
serde_json::from_value(serde_json::json!({
"include_usage": true
}))
.unwrap_or_else(|_| {
ChatCompletionStreamOptions {
include_usage: Some(true),
..unsafe { std::mem::zeroed() } }
});
request_builder.stream_options(stream_options);
}
self.apply_generation_config(&mut request_builder, config);
if let Some(ref tools) = self.tools {
let openai_tools: Vec<ChatCompletionTools> = tools
.iter()
.cloned()
.map(ChatCompletionTools::Function)
.collect();
request_builder.tools(openai_tools);
}
let request = request_builder
.build()
.map_err(|e| anyhow!("Failed to build request: {}", e))?;
let mut response_stream: ChatCompletionResponseStream = self
.client
.chat()
.create_stream(request)
.await
.map_err(|e| anyhow!("OpenAI API error: {}", e))?;
let stream = stream! {
use futures::StreamExt;
use crate::spec_ai_core::agent::model::ModelStreamItem;
let mut buffer = String::new();
let mut in_think_block = false;
let mut think_ended = false;
while let Some(result) = response_stream.next().await {
match result {
Ok(response) => {
if let Some(usage) = response.usage {
yield Ok(ModelStreamItem::Usage(TokenUsage {
prompt_tokens: usage.prompt_tokens,
completion_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
}));
}
if let Some(choice) = response.choices.first() {
if let Some(reason) = &choice.finish_reason {
yield Ok(ModelStreamItem::FinishReason(format!("{:?}", reason)));
}
if let Some(content) = &choice.delta.content {
buffer.push_str(content);
if buffer.contains("<think>") && !in_think_block {
in_think_block = true;
}
if buffer.contains("</think>") && in_think_block {
in_think_block = false;
think_ended = true;
if let Some(idx) = buffer.find("</think>") {
buffer = buffer[idx + "</think>".len()..].to_string();
}
}
#[allow(clippy::if_same_then_else)]
if !in_think_block && think_ended {
yield Ok(ModelStreamItem::Content(buffer.clone()));
buffer.clear();
} else if !in_think_block && !buffer.contains("<think>") {
yield Ok(ModelStreamItem::Content(buffer.clone()));
buffer.clear();
}
}
}
}
Err(err) => {
yield Err(anyhow!("Stream error: {}", err));
break;
}
}
}
};
Ok(Box::pin(stream))
}
fn metadata(&self) -> ProviderMetadata {
ProviderMetadata {
name: "OpenAI".to_string(),
supported_models: vec![
"gpt-5".to_string(),
"gpt-5-mini".to_string(),
"gpt-5-nano".to_string(),
"gpt-5-pro".to_string(),
],
supports_streaming: true,
}
}
fn kind(&self) -> ProviderKind {
ProviderKind::OpenAI
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg_attr(
        target_os = "macos",
        ignore = "system proxy APIs unavailable in this environment"
    )]
    fn test_openai_provider_creation() {
        let provider = OpenAIProvider::new();
        assert_eq!(provider.model, "gpt-5-mini");
        assert!(provider.system_message.is_none());
    }

    #[test]
    #[cfg_attr(
        target_os = "macos",
        ignore = "system proxy APIs unavailable in this environment"
    )]
    fn test_openai_provider_with_model() {
        let provider = OpenAIProvider::new().with_model("gpt-5");
        assert_eq!(provider.model, "gpt-5");
    }

    #[test]
    #[cfg_attr(
        target_os = "macos",
        ignore = "system proxy APIs unavailable in this environment"
    )]
    fn test_openai_provider_with_system_message() {
        let provider = OpenAIProvider::new().with_system_message("You are a helpful assistant.");
        assert_eq!(
            provider.system_message,
            Some("You are a helpful assistant.".to_string())
        );
    }

    #[test]
    #[cfg_attr(
        target_os = "macos",
        ignore = "system proxy APIs unavailable in this environment"
    )]
    fn test_openai_provider_metadata() {
        let provider = OpenAIProvider::new();
        let metadata = provider.metadata();
        assert_eq!(metadata.name, "OpenAI");
        assert!(metadata.supports_streaming);
        assert!(metadata.supported_models.contains(&"gpt-5".to_string()));
        assert!(
            metadata
                .supported_models
                .contains(&"gpt-5-mini".to_string())
        );
    }

    #[test]
    #[cfg_attr(
        target_os = "macos",
        ignore = "system proxy APIs unavailable in this environment"
    )]
    fn test_openai_provider_kind() {
        let provider = OpenAIProvider::new();
        assert_eq!(provider.kind(), ProviderKind::OpenAI);
    }

    #[test]
    #[cfg_attr(
        target_os = "macos",
        ignore = "system proxy APIs unavailable in this environment"
    )]
    fn test_build_messages_without_system() {
        let provider = OpenAIProvider::new();
        let messages = provider.build_messages("Hello, world!").unwrap();
        assert_eq!(messages.len(), 1);
    }

    #[test]
    #[cfg_attr(
        target_os = "macos",
        ignore = "system proxy APIs unavailable in this environment"
    )]
    fn test_build_messages_with_system() {
        let provider = OpenAIProvider::new().with_system_message("You are a helpful assistant.");
        let messages = provider.build_messages("Hello, world!").unwrap();
        assert_eq!(messages.len(), 2);
    }

    /// Minimal user message so the request builder's required `messages` field is set.
    fn test_message() -> ChatCompletionRequestMessage {
        ChatCompletionRequestMessage::User(
            ChatCompletionRequestUserMessageArgs::default()
                .content("Hello")
                .build()
                .unwrap(),
        )
    }

    #[test]
    fn test_gpt5_requests_omit_sampling_overrides() {
        let config = GenerationConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            max_tokens: Some(256),
            ..GenerationConfig::default()
        };
        let mut request_builder = CreateChatCompletionRequestArgs::default();
        request_builder
            .model("gpt-5")
            .messages(vec![test_message()]);
        OpenAIProvider::apply_generation_config_for_model("gpt-5", &mut request_builder, &config);
        let request = request_builder.build().unwrap();
        assert_eq!(request.temperature, None);
        assert_eq!(request.top_p, None);
        assert_eq!(request.max_completion_tokens, Some(256));
    }

    /// Counterpart to the gpt-5 test: other models keep their sampling overrides.
    #[test]
    fn test_non_gpt5_requests_keep_sampling_overrides() {
        let config = GenerationConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            max_tokens: Some(256),
            ..GenerationConfig::default()
        };
        let mut request_builder = CreateChatCompletionRequestArgs::default();
        request_builder
            .model("gpt-4o")
            .messages(vec![test_message()]);
        OpenAIProvider::apply_generation_config_for_model("gpt-4o", &mut request_builder, &config);
        let request = request_builder.build().unwrap();
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.top_p, Some(0.9));
        assert_eq!(request.max_completion_tokens, Some(256));
    }

    #[test]
    fn test_default_sampling_detection_is_case_insensitive() {
        assert!(OpenAIProvider::model_requires_default_sampling("gpt-5"));
        assert!(OpenAIProvider::model_requires_default_sampling("GPT-5-mini"));
        assert!(!OpenAIProvider::model_requires_default_sampling("gpt-4o"));
        assert!(!OpenAIProvider::model_requires_default_sampling(""));
    }
}