use crate::spec_ai_core::agent::model::{
parse_thinking_tokens, GenerationConfig, ImageAttachment, ModelProvider, ModelResponse,
ProviderKind, ProviderMetadata, TokenUsage, ToolCall,
};
use anyhow::{anyhow, Result};
use async_openai::{
config::OpenAIConfig,
types::{
ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImageArgs,
ChatCompletionRequestMessageContentPartTextArgs, ChatCompletionRequestSystemMessageArgs,
ChatCompletionRequestUserMessageArgs, ChatCompletionRequestUserMessageContent,
ChatCompletionRequestUserMessageContentPart, ChatCompletionTool,
CreateChatCompletionRequestArgs, ImageUrlArgs,
},
Client,
};
use async_stream::stream;
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine as _};
use futures::Stream;
use std::pin::Pin;
/// Model provider that talks to the OpenAI chat-completions API through
/// `async_openai`.
///
/// Instances are configured builder-style: start from one of the
/// constructors and chain `with_model`, `with_system_message` and
/// `with_tools` as needed.
#[derive(Debug, Clone)]
pub struct OpenAIProvider {
/// Client used to issue chat-completion requests.
client: Client<OpenAIConfig>,
/// Model identifier placed in every request.
model: String,
/// When set, sent as a system message ahead of the user message.
system_message: Option<String>,
/// Tool definitions attached to requests; `None` when no tools were
/// configured (an empty list passed to `with_tools` is stored as `None`).
tools: Option<Vec<ChatCompletionTool>>,
}
impl OpenAIProvider {
pub fn new() -> Self {
Self {
client: Client::new(),
model: "gpt-4.1-mini".to_string(),
system_message: None,
tools: None,
}
}
pub fn with_api_key(api_key: impl Into<String>) -> Self {
let config = OpenAIConfig::new().with_api_key(api_key);
Self {
client: Client::with_config(config),
model: "gpt-4.1-mini".to_string(),
system_message: None,
tools: None,
}
}
pub fn with_config(config: OpenAIConfig) -> Self {
Self {
client: Client::with_config(config),
model: "gpt-4.1-mini".to_string(),
system_message: None,
tools: None,
}
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
self.system_message = Some(message.into());
self
}
pub fn with_tools(mut self, tools: Vec<ChatCompletionTool>) -> Self {
self.tools = if tools.is_empty() { None } else { Some(tools) };
self
}
fn build_messages(&self, prompt: &str) -> Result<Vec<ChatCompletionRequestMessage>> {
let mut messages = Vec::new();
if let Some(system_msg) = &self.system_message {
let system_message = ChatCompletionRequestSystemMessageArgs::default()
.content(system_msg.clone())
.build()
.map_err(|e| anyhow!("Failed to build system message: {}", e))?;
messages.push(ChatCompletionRequestMessage::System(system_message));
}
let user_message = ChatCompletionRequestUserMessageArgs::default()
.content(prompt)
.build()
.map_err(|e| anyhow!("Failed to build user message: {}", e))?;
messages.push(ChatCompletionRequestMessage::User(user_message));
Ok(messages)
}
fn build_messages_with_attachments(
&self,
prompt: &str,
attachments: &[ImageAttachment],
) -> Result<Vec<ChatCompletionRequestMessage>> {
let mut messages = Vec::new();
if let Some(system_msg) = &self.system_message {
let system_message = ChatCompletionRequestSystemMessageArgs::default()
.content(system_msg.clone())
.build()
.map_err(|e| anyhow!("Failed to build system message: {}", e))?;
messages.push(ChatCompletionRequestMessage::System(system_message));
}
let mut parts = Vec::new();
if !prompt.is_empty() {
let text_part = ChatCompletionRequestMessageContentPartTextArgs::default()
.text(prompt)
.build()
.map_err(|e| anyhow!("Failed to build text content part: {}", e))?;
parts.push(ChatCompletionRequestUserMessageContentPart::Text(text_part));
}
for attachment in attachments {
let encoded = general_purpose::STANDARD.encode(&attachment.data);
let data_url = format!("data:{};base64,{}", attachment.mime, encoded);
let image_url = ImageUrlArgs::default()
.url(data_url)
.build()
.map_err(|e| anyhow!("Failed to build image url: {}", e))?;
let image_part = ChatCompletionRequestMessageContentPartImageArgs::default()
.image_url(image_url)
.build()
.map_err(|e| anyhow!("Failed to build image content part: {}", e))?;
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(
image_part,
));
}
let user_message = ChatCompletionRequestUserMessageArgs::default()
.content(ChatCompletionRequestUserMessageContent::Array(parts))
.build()
.map_err(|e| anyhow!("Failed to build user message: {}", e))?;
messages.push(ChatCompletionRequestMessage::User(user_message));
Ok(messages)
}
}
impl Default for OpenAIProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ModelProvider for OpenAIProvider {
async fn generate(&self, prompt: &str, config: &GenerationConfig) -> Result<ModelResponse> {
let messages = self.build_messages(prompt)?;
let mut request_builder = CreateChatCompletionRequestArgs::default();
request_builder.model(&self.model).messages(messages);
if let Some(temp) = config.temperature {
request_builder.temperature(temp);
}
if let Some(max_tokens) = config.max_tokens {
request_builder.max_tokens(max_tokens);
}
if let Some(top_p) = config.top_p {
request_builder.top_p(top_p);
}
if let Some(freq_penalty) = config.frequency_penalty {
request_builder.frequency_penalty(freq_penalty);
}
if let Some(pres_penalty) = config.presence_penalty {
request_builder.presence_penalty(pres_penalty);
}
if let Some(stop) = &config.stop_sequences {
request_builder.stop(stop.clone());
}
if let Some(ref tools) = self.tools {
request_builder.tools(tools.clone());
}
let request = request_builder
.build()
.map_err(|e| anyhow!("Failed to build request: {}", e))?;
let response = self
.client
.chat()
.create(request)
.await
.map_err(|e| anyhow!("OpenAI API error: {}", e))?;
let choice = response
.choices
.first()
.ok_or_else(|| anyhow!("No response choices returned"))?;
let raw_content = choice.message.content.clone().unwrap_or_default();
let (reasoning, content) = parse_thinking_tokens(&raw_content);
let tool_calls = choice
.message
.tool_calls
.as_ref()
.map(|calls| {
calls
.iter()
.filter_map(|call| {
let arguments = serde_json::from_str(&call.function.arguments).ok()?;
Some(ToolCall {
id: call.id.clone(),
function_name: call.function.name.clone(),
arguments,
})
})
.collect::<Vec<_>>()
})
.filter(|calls| !calls.is_empty());
let usage = response.usage.map(|u| TokenUsage {
prompt_tokens: u.prompt_tokens,
completion_tokens: u.completion_tokens,
total_tokens: u.total_tokens,
});
Ok(ModelResponse {
content,
model: response.model,
usage,
finish_reason: choice.finish_reason.as_ref().map(|r| format!("{:?}", r)),
tool_calls,
reasoning,
})
}
async fn stream(
&self,
prompt: &str,
config: &GenerationConfig,
) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>> {
let messages = self.build_messages(prompt)?;
let mut request_builder = CreateChatCompletionRequestArgs::default();
request_builder
.model(&self.model)
.messages(messages)
.stream(true);
if let Some(temp) = config.temperature {
request_builder.temperature(temp);
}
if let Some(max_tokens) = config.max_tokens {
request_builder.max_tokens(max_tokens);
}
if let Some(top_p) = config.top_p {
request_builder.top_p(top_p);
}
if let Some(freq_penalty) = config.frequency_penalty {
request_builder.frequency_penalty(freq_penalty);
}
if let Some(pres_penalty) = config.presence_penalty {
request_builder.presence_penalty(pres_penalty);
}
if let Some(stop) = &config.stop_sequences {
request_builder.stop(stop.clone());
}
let request = request_builder
.build()
.map_err(|e| anyhow!("Failed to build streaming request: {}", e))?;
let mut response_stream = self
.client
.chat()
.create_stream(request)
.await
.map_err(|e| anyhow!("OpenAI streaming API error: {}", e))?;
let stream = stream! {
use futures::StreamExt;
let mut buffer = String::new();
let mut in_think_block = false;
let mut think_ended = false;
while let Some(result) = response_stream.next().await {
match result {
Ok(response) => {
if let Some(choice) = response.choices.first() {
if let Some(content) = &choice.delta.content {
buffer.push_str(content);
if buffer.contains("<think>") && !in_think_block {
in_think_block = true;
}
if buffer.contains("</think>") && in_think_block {
in_think_block = false;
think_ended = true;
if let Some(idx) = buffer.find("</think>") {
buffer = buffer[idx + "</think>".len()..].to_string();
}
}
if !in_think_block && (think_ended || !buffer.contains("<think>")) {
let output = buffer.clone();
buffer.clear();
if !output.is_empty() {
yield Ok(output);
}
}
}
}
}
Err(e) => {
yield Err(anyhow!("Stream error: {}", e));
break;
}
}
}
if !buffer.is_empty() && !in_think_block {
yield Ok(buffer);
}
};
Ok(Box::pin(stream))
}
async fn generate_with_attachments(
&self,
prompt: &str,
attachments: &[ImageAttachment],
config: &GenerationConfig,
) -> Result<ModelResponse> {
if attachments.is_empty() {
return self.generate(prompt, config).await;
}
let messages = self.build_messages_with_attachments(prompt, attachments)?;
let mut request_builder = CreateChatCompletionRequestArgs::default();
request_builder.model(&self.model).messages(messages);
if let Some(temp) = config.temperature {
request_builder.temperature(temp);
}
if let Some(max_tokens) = config.max_tokens {
request_builder.max_tokens(max_tokens);
}
if let Some(top_p) = config.top_p {
request_builder.top_p(top_p);
}
if let Some(freq_penalty) = config.frequency_penalty {
request_builder.frequency_penalty(freq_penalty);
}
if let Some(pres_penalty) = config.presence_penalty {
request_builder.presence_penalty(pres_penalty);
}
if let Some(stop) = &config.stop_sequences {
request_builder.stop(stop.clone());
}
if let Some(ref tools) = self.tools {
request_builder.tools(tools.clone());
}
let request = request_builder
.build()
.map_err(|e| anyhow!("Failed to build request: {}", e))?;
let response = self
.client
.chat()
.create(request)
.await
.map_err(|e| anyhow!("OpenAI API error: {}", e))?;
let choice = response
.choices
.first()
.ok_or_else(|| anyhow!("No response choices returned"))?;
let raw_content = choice.message.content.clone().unwrap_or_default();
let (reasoning, content) = parse_thinking_tokens(&raw_content);
let tool_calls = choice
.message
.tool_calls
.as_ref()
.map(|calls| {
calls
.iter()
.filter_map(|call| {
let arguments = serde_json::from_str(&call.function.arguments).ok()?;
Some(ToolCall {
id: call.id.clone(),
function_name: call.function.name.clone(),
arguments,
})
})
.collect::<Vec<_>>()
})
.filter(|calls| !calls.is_empty());
let usage = response.usage.map(|u| TokenUsage {
prompt_tokens: u.prompt_tokens,
completion_tokens: u.completion_tokens,
total_tokens: u.total_tokens,
});
Ok(ModelResponse {
content,
model: response.model,
usage,
finish_reason: choice.finish_reason.as_ref().map(|r| format!("{:?}", r)),
tool_calls,
reasoning,
})
}
async fn stream_with_attachments(
&self,
prompt: &str,
attachments: &[ImageAttachment],
config: &GenerationConfig,
) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>> {
if attachments.is_empty() {
return self.stream(prompt, config).await;
}
let messages = self.build_messages_with_attachments(prompt, attachments)?;
let mut request_builder = CreateChatCompletionRequestArgs::default();
request_builder
.model(&self.model)
.messages(messages)
.stream(true);
if let Some(temp) = config.temperature {
request_builder.temperature(temp);
}
if let Some(max_tokens) = config.max_tokens {
request_builder.max_tokens(max_tokens);
}
if let Some(top_p) = config.top_p {
request_builder.top_p(top_p);
}
if let Some(freq_penalty) = config.frequency_penalty {
request_builder.frequency_penalty(freq_penalty);
}
if let Some(pres_penalty) = config.presence_penalty {
request_builder.presence_penalty(pres_penalty);
}
if let Some(stop) = &config.stop_sequences {
request_builder.stop(stop.clone());
}
if let Some(ref tools) = self.tools {
request_builder.tools(tools.clone());
}
let request = request_builder
.build()
.map_err(|e| anyhow!("Failed to build request: {}", e))?;
let mut response_stream = self
.client
.chat()
.create_stream(request)
.await
.map_err(|e| anyhow!("OpenAI API error: {}", e))?;
let stream = stream! {
use futures::StreamExt;
let mut buffer = String::new();
let mut in_think_block = false;
let mut think_ended = false;
while let Some(result) = response_stream.next().await {
match result {
Ok(response) => {
if let Some(choice) = response.choices.first() {
if let Some(content) = &choice.delta.content {
buffer.push_str(content);
if buffer.contains("<think>") && !in_think_block {
in_think_block = true;
}
if buffer.contains("</think>") && in_think_block {
in_think_block = false;
think_ended = true;
if let Some(idx) = buffer.find("</think>") {
buffer = buffer[idx + "</think>".len()..].to_string();
}
}
#[allow(clippy::if_same_then_else)]
if !in_think_block && think_ended {
yield Ok(buffer.clone());
buffer.clear();
} else if !in_think_block && !buffer.contains("<think>") {
yield Ok(buffer.clone());
buffer.clear();
}
}
}
}
Err(err) => {
yield Err(anyhow!("Stream error: {}", err));
break;
}
}
}
};
Ok(Box::pin(stream))
}
fn metadata(&self) -> ProviderMetadata {
ProviderMetadata {
name: "OpenAI".to_string(),
supported_models: vec![
"gpt-4.1".to_string(),
"gpt-4-turbo-preview".to_string(),
"gpt-4-32k".to_string(),
"gpt-4.1-mini".to_string(),
"gpt-4.1-mini-16k".to_string(),
],
supports_streaming: true,
}
}
fn kind(&self) -> ProviderKind {
ProviderKind::OpenAI
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Every test below constructs a provider via `OpenAIProvider::new()`.
    // They are ignored on macOS for the reason given in the attribute.

    /// A fresh provider uses the default model and has no system message.
    #[test]
    #[cfg_attr(
        target_os = "macos",
        ignore = "system proxy APIs unavailable in this environment"
    )]
    fn test_openai_provider_creation() {
        let OpenAIProvider {
            model,
            system_message,
            ..
        } = OpenAIProvider::new();
        assert_eq!(model, "gpt-4.1-mini");
        assert!(system_message.is_none());
    }

    /// `with_model` replaces the default model.
    #[test]
    #[cfg_attr(
        target_os = "macos",
        ignore = "system proxy APIs unavailable in this environment"
    )]
    fn test_openai_provider_with_model() {
        let configured = OpenAIProvider::new().with_model("gpt-4.1");
        assert_eq!(configured.model, "gpt-4.1");
    }

    /// `with_system_message` stores the given text.
    #[test]
    #[cfg_attr(
        target_os = "macos",
        ignore = "system proxy APIs unavailable in this environment"
    )]
    fn test_openai_provider_with_system_message() {
        let configured = OpenAIProvider::new().with_system_message("You are a helpful assistant.");
        assert_eq!(
            configured.system_message.as_deref(),
            Some("You are a helpful assistant.")
        );
    }

    /// Metadata reports the provider name, streaming support and known models.
    #[test]
    #[cfg_attr(
        target_os = "macos",
        ignore = "system proxy APIs unavailable in this environment"
    )]
    fn test_openai_provider_metadata() {
        let info = OpenAIProvider::new().metadata();
        assert_eq!(info.name, "OpenAI");
        assert!(info.supports_streaming);
        for expected in ["gpt-4.1", "gpt-4.1-mini"] {
            assert!(info.supported_models.iter().any(|m| m == expected));
        }
    }

    /// The provider identifies itself as the OpenAI kind.
    #[test]
    #[cfg_attr(
        target_os = "macos",
        ignore = "system proxy APIs unavailable in this environment"
    )]
    fn test_openai_provider_kind() {
        let kind = OpenAIProvider::new().kind();
        assert_eq!(kind, ProviderKind::OpenAI);
    }

    /// Without a system message only the user message is produced.
    #[test]
    #[cfg_attr(
        target_os = "macos",
        ignore = "system proxy APIs unavailable in this environment"
    )]
    fn test_build_messages_without_system() {
        let message_count = OpenAIProvider::new()
            .build_messages("Hello, world!")
            .map(|messages| messages.len())
            .unwrap();
        assert_eq!(message_count, 1);
    }

    /// With a system message, both a system and a user message are produced.
    #[test]
    #[cfg_attr(
        target_os = "macos",
        ignore = "system proxy APIs unavailable in this environment"
    )]
    fn test_build_messages_with_system() {
        let message_count = OpenAIProvider::new()
            .with_system_message("You are a helpful assistant.")
            .build_messages("Hello, world!")
            .map(|messages| messages.len())
            .unwrap();
        assert_eq!(message_count, 2);
    }
}