use crate::error::{AgentError, Result};
use crate::message::{Message, MessageRole};
use crate::provider::{ModelConfig, ModelProvider, ModelResponse, Usage};
use async_trait::async_trait;
use serde::Deserialize;
use serde_json::json;
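
/// Model provider backed by the OpenAI Chat Completions API.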
pub struct OpenAIProvider {
    api_key: String,
    base_url: String,
    client: reqwest::Client,
}

impl OpenAIProvider {
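    /// Creates a provider targeting the public OpenAI endpoint with a default client.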
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            base_url: "https://api.openai.com/v1".to_string(),
            client: reqwest::Client::new(),
        }
    }
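
    /// Overrides the base URL, e.g. to target a proxy or an OpenAI-compatible endpoint.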
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }
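
    /// Converts internal messages into the `{role, content}` objects the API expects.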
    fn convert_messages(&self, messages: Vec<Message>) -> Vec<serde_json::Value> {
        messages
            .into_iter()
            .map(|msg| {
                let role = match msg.role {
                    MessageRole::System => "system",
                    MessageRole::User => "user",
                    MessageRole::Assistant => "assistant",
                    MessageRole::Tool => "tool",
                };
                json!({
                    "role": role,
                    "content": msg.content
                })
            })
            .collect()
    }
}
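
// Deserialization targets covering only the response fields this provider consumes.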
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<Choice>,
    usage: OpenAIUsage,
    model: String,
}

#[derive(Debug, Deserialize)]
struct Choice {
    message: MessageContent,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct MessageContent {
    // `content` is nullable in the API (e.g. when the model returns tool calls),
    // so deserialize it as optional rather than failing on null.
    content: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
}

#[async_trait]
impl ModelProvider for OpenAIProvider {
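    /// Name identifying this provider.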
    fn name(&self) -> &str {
        "openai"
    }
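
    /// Sends a non-streaming chat completion request and maps the first returned
    /// choice into a `ModelResponse`.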
    async fn complete(&self, messages: Vec<Message>, config: &ModelConfig) -> Result<ModelResponse> {
        let url = format!("{}/chat/completions", self.base_url);
        let converted_messages = self.convert_messages(messages);
        let mut body = json!({
            "model": config.model,
            "messages": converted_messages,
            "temperature": config.temperature,
        });
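        // Optional sampling parameters are only included when configured.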
        if let Some(max_tokens) = config.max_tokens {
            body["max_tokens"] = json!(max_tokens);
        }
        if let Some(top_p) = config.top_p {
            body["top_p"] = json!(top_p);
        }
        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            // `.json()` sets the `Content-Type: application/json` header for us.
            .json(&body)
            .send()
            .await
            .map_err(|e| AgentError::ExecutionError(format!("OpenAI API request failed: {}", e)))?;
        if !response.status().is_success() {
            // Capture the status before `text()` consumes the response body.
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(AgentError::ExecutionError(format!(
                "OpenAI API error ({}): {}",
                status, error_text
            )));
        }
        let api_response: OpenAIResponse = response
            .json()
            .await
            .map_err(|e| AgentError::ExecutionError(format!("Failed to parse OpenAI response: {}", e)))?;
        let choice = api_response
            .choices
            .first()
            .ok_or_else(|| AgentError::ExecutionError("No choices in OpenAI response".to_string()))?;
        Ok(ModelResponse {
            // Fall back to an empty string when the response carries no text content.
            content: choice.message.content.clone().unwrap_or_default(),
            model: api_response.model,
            usage: Some(Usage {
                prompt_tokens: api_response.usage.prompt_tokens,
                completion_tokens: api_response.usage.completion_tokens,
                total_tokens: api_response.usage.total_tokens,
            }),
            finish_reason: choice.finish_reason.clone(),
        })
    }
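
    /// Streaming completions are not implemented for this provider yet.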
    async fn stream_complete(
        &self,
        _messages: Vec<Message>,
        _config: &ModelConfig,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Unpin + Send>> {
        Err(AgentError::ExecutionError(
            "Streaming not yet implemented for OpenAI".to_string(),
        ))
    }
}