use crate::context::{Message, MessageRole};
use crate::error::AgentError;
use crate::provider::{LLMProvider, ProviderConfig, StreamChunk};
use async_openai::{
config::OpenAIConfig,
types::{
ChatCompletionRequestAssistantMessage, ChatCompletionRequestMessage,
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequestArgs,
},
Client,
};
use async_trait::async_trait;
use futures::Stream;
use futures::StreamExt;
use std::pin::Pin;
use tracing::{debug, info, trace, warn};
/// LLM provider that talks to the OpenAI chat API through an `async_openai`
/// client.
///
/// Instances are created with [`OpenAIProvider::new`] or
/// [`OpenAIProvider::from_env`] and tuned with the `with_*` builder methods.
pub struct OpenAIProvider {
/// `async_openai` client carrying the API key supplied at construction.
client: Client<OpenAIConfig>,
/// Model name and sampling settings read whenever a request is built.
config: ProviderConfig,
}
impl OpenAIProvider {
    /// Model name a freshly constructed provider starts with.
    const DEFAULT_MODEL: &'static str = "gpt-4";

    /// Creates a provider that authenticates with `api_key` and starts out
    /// on the default model.
    pub fn new(api_key: impl Into<String>) -> Self {
        let credentials = OpenAIConfig::new().with_api_key(api_key);
        Self {
            client: Client::with_config(credentials),
            config: ProviderConfig::new(Self::DEFAULT_MODEL),
        }
    }

    /// Creates a provider from the `OPENAI_API_KEY` environment variable.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Configuration`] when the variable cannot be read.
    pub fn from_env() -> Result<Self, AgentError> {
        match std::env::var("OPENAI_API_KEY") {
            Ok(api_key) => Ok(Self::new(api_key)),
            Err(_) => Err(AgentError::Configuration(
                "OPENAI_API_KEY environment variable not set".to_string(),
            )),
        }
    }

    /// Returns the provider with its model name replaced by `model`.
    pub fn with_model(self, model: impl Into<String>) -> Self {
        let Self { client, mut config } = self;
        config.model = model.into();
        Self { client, config }
    }

    /// Returns the provider with `temperature` applied to its configuration.
    pub fn with_temperature(self, temperature: f32) -> Self {
        let Self { client, config } = self;
        Self {
            client,
            config: config.with_temperature(temperature),
        }
    }

    /// Returns the provider with `max_tokens` applied to its configuration.
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        let Self { client, config } = self;
        Self {
            client,
            config: config.with_max_tokens(max_tokens),
        }
    }

    /// Converts crate-level messages into `async_openai` request messages,
    /// preserving their order.
    ///
    /// Tool messages are sent as system messages whose text is prefixed with
    /// `"Tool result: "`.
    fn convert_messages(&self, messages: Vec<Message>) -> Vec<ChatCompletionRequestMessage> {
        use async_openai::types::{
            ChatCompletionRequestAssistantMessageContent as AssistantContent,
            ChatCompletionRequestSystemMessageContent as SystemContent,
            ChatCompletionRequestUserMessageContent as UserContent,
        };

        // System and tool messages share the same request shape; only the
        // text differs.
        let system_message = |text: String| {
            ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
                content: SystemContent::Text(text),
                name: None,
            })
        };

        let mut converted = Vec::with_capacity(messages.len());
        for message in messages {
            let request_message = match message.role {
                MessageRole::System => system_message(message.content),
                MessageRole::User => {
                    ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
                        content: UserContent::Text(message.content),
                        name: None,
                    })
                }
                MessageRole::Assistant => {
                    ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
                        content: Some(AssistantContent::Text(message.content)),
                        name: None,
                        tool_calls: None,
                        refusal: None,
                        // The struct literal must still name this deprecated
                        // field, so silence the lint for this one initializer.
                        #[allow(deprecated)]
                        function_call: None,
                    })
                }
                MessageRole::Tool => {
                    system_message(format!("Tool result: {}", message.content))
                }
            };
            converted.push(request_message);
        }
        converted
    }
}
#[async_trait]
impl LLMProvider for OpenAIProvider {
async fn complete(&self, messages: Vec<Message>) -> std::result::Result<String, AgentError> {
info!(
model = %self.config.model,
message_count = messages.len(),
"Requesting OpenAI completion"
);
let openai_messages = self.convert_messages(messages);
let mut request_builder = CreateChatCompletionRequestArgs::default();
request_builder
.model(&self.config.model)
.messages(openai_messages)
.temperature(self.config.temperature);
if let Some(max_tokens) = self.config.max_tokens {
request_builder.max_tokens(max_tokens);
}
if let Some(top_p) = self.config.top_p {
request_builder.top_p(top_p);
}
if let Some(frequency_penalty) = self.config.frequency_penalty {
request_builder.frequency_penalty(frequency_penalty);
}
if let Some(presence_penalty) = self.config.presence_penalty {
request_builder.presence_penalty(presence_penalty);
}
let request = request_builder
.build()
.map_err(|e| AgentError::ProviderError(format!("Failed to build request: {}", e)))?;
trace!("Sending request to OpenAI");
let response = self.client.chat().create(request).await.map_err(|e| {
warn!(error = %e, "OpenAI API error");
AgentError::ProviderError(format!("OpenAI API error: {}", e))
})?;
let message = response
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.ok_or_else(|| {
warn!("No content in OpenAI response");
AgentError::ProviderError("No content in OpenAI response".to_string())
})?;
debug!(
response_length = message.len(),
"OpenAI completion successful"
);
Ok(message)
}
async fn stream(
&self,
messages: Vec<Message>,
) -> std::result::Result<Pin<Box<dyn Stream<Item = StreamChunk> + Send>>, AgentError> {
info!(
model = %self.config.model,
message_count = messages.len(),
"Requesting OpenAI streaming completion"
);
let openai_messages = self.convert_messages(messages);
let mut request_builder = CreateChatCompletionRequestArgs::default();
request_builder
.model(&self.config.model)
.messages(openai_messages)
.temperature(self.config.temperature);
if let Some(max_tokens) = self.config.max_tokens {
request_builder.max_tokens(max_tokens);
}
if let Some(top_p) = self.config.top_p {
request_builder.top_p(top_p);
}
let request = request_builder
.build()
.map_err(|e| AgentError::ProviderError(format!("Failed to build request: {}", e)))?;
trace!("Sending streaming request to OpenAI");
let stream = self
.client
.chat()
.create_stream(request)
.await
.map_err(|e| {
warn!(error = %e, "OpenAI streaming error");
AgentError::ProviderError(format!("OpenAI streaming error: {}", e))
})?;
let mapped_stream = stream.map(|result| {
result
.map_err(|e| AgentError::ProviderError(format!("Stream error: {}", e)))
.and_then(|response| {
response
.choices
.first()
.and_then(|choice| choice.delta.content.clone())
.ok_or_else(|| {
AgentError::ProviderError("No content in stream chunk".to_string())
})
})
});
Ok(Box::pin(mapped_stream))
}
fn name(&self) -> &str {
"OpenAI"
}
fn config(&self) -> &ProviderConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new provider reports its name and the default configuration.
    #[test]
    fn test_openai_provider_creation() {
        let provider = OpenAIProvider::new("test-api-key");
        assert_eq!(provider.name(), "OpenAI");
        assert_eq!(provider.config().model, "gpt-4");
        assert_eq!(provider.config().temperature, 0.7);
    }

    /// `with_model` replaces the configured model name.
    #[test]
    fn test_openai_provider_with_model() {
        let provider = OpenAIProvider::new("test-api-key").with_model("gpt-3.5-turbo");
        assert_eq!(provider.config().model, "gpt-3.5-turbo");
    }

    /// `with_temperature` replaces the configured temperature.
    #[test]
    fn test_openai_provider_with_temperature() {
        let provider = OpenAIProvider::new("test-api-key").with_temperature(0.5);
        assert_eq!(provider.config().temperature, 0.5);
    }

    /// `with_max_tokens` sets the configured token limit.
    #[test]
    fn test_openai_provider_with_max_tokens() {
        let provider = OpenAIProvider::new("test-api-key").with_max_tokens(1000);
        assert_eq!(provider.config().max_tokens, Some(1000));
    }

    /// Conversion keeps every message, in order, and maps each role to the
    /// matching request variant.
    #[test]
    fn test_message_conversion() {
        let provider = OpenAIProvider::new("test-api-key");
        let messages = vec![
            Message::system("You are a helpful assistant"),
            Message::user("Hello"),
            Message::assistant("Hi there!"),
        ];
        let converted = provider.convert_messages(messages);
        assert_eq!(converted.len(), 3);
        // Checking only the length would let a wrong role mapping or a
        // reordering slip through, so pin each variant as well.
        assert!(matches!(
            &converted[0],
            ChatCompletionRequestMessage::System(_)
        ));
        assert!(matches!(
            &converted[1],
            ChatCompletionRequestMessage::User(_)
        ));
        assert!(matches!(
            &converted[2],
            ChatCompletionRequestMessage::Assistant(_)
        ));
    }
}