use async_openai::{
types::{
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
ChatCompletionRequestUserMessageArgs, ChatCompletionRequestUserMessageContent,
CreateChatCompletionRequestArgs, ImageUrlArgs,
},
Client,
};
use async_trait::async_trait;
use base64::Engine as _;
use futures::stream::{BoxStream, StreamExt};
use crate::error::{AgentError, Result};
use crate::models::LLM;
use crate::types::{File, GenerationChunk, GenerationResponse, Message, Role};
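/// LLM implementation backed by the OpenAI chat completions API via `async_openai`.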
pub struct OpenAILLM {
client: Client<async_openai::config::OpenAIConfig>,
model: String,
}
impl OpenAILLM {
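/// Creates a client for `model` using the `OPENAI_API_KEY` environment variable.
///
/// The variable is checked up front so a missing key surfaces as a `ConfigError`
/// instead of a request-time failure; `Client::new()` then reads the key itself.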
pub fn new(model: impl Into<String>) -> Result<Self> {
let _ = std::env::var("OPENAI_API_KEY").map_err(|_| {
AgentError::ConfigError("OPENAI_API_KEY environment variable not set".to_string())
})?;
Ok(Self {
client: Client::new(),
model: model.into(),
})
}
pub fn with_api_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
let config = async_openai::config::OpenAIConfig::new().with_api_key(api_key);
Self {
client: Client::with_config(config),
model: model.into(),
}
}
}
impl OpenAILLM {
fn prepare_messages(
messages: Vec<Message>,
files: Option<Vec<File>>,
) -> Result<Vec<async_openai::types::ChatCompletionRequestMessage>> {
let mut chat_messages = Vec::new();
for msg in messages {
match msg.role {
Role::System => {
chat_messages.push(
ChatCompletionRequestSystemMessageArgs::default()
.content(msg.content)
.build()
.map_err(|e| {
AgentError::ModelError(format!(
"Failed to build system message: {}",
e
))
})?
.into(),
);
}
Role::User => {
chat_messages.push(
ChatCompletionRequestUserMessageArgs::default()
.content(msg.content)
.build()
.map_err(|e| {
AgentError::ModelError(format!(
"Failed to build user message: {}",
e
))
})?
.into(),
);
}
Role::Assistant => {
chat_messages.push(
ChatCompletionRequestAssistantMessageArgs::default()
.content(msg.content)
.build()
.map_err(|e| {
AgentError::ModelError(format!(
"Failed to build assistant message: {}",
e
))
})?
.into(),
);
}
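// No dedicated tool message is built here; forward tool output as a user message.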
Role::Tool => {
chat_messages.push(
ChatCompletionRequestUserMessageArgs::default()
.content(format!("Tool output: {}", msg.content))
.build()
.map_err(|e| {
AgentError::ModelError(format!(
"Failed to build tool message: {}",
e
))
})?
.into(),
);
}
}
}
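// Attach image files to the most recent user message as base64 data URLs,
// preserving any text or content parts it already carries.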
if let Some(files) = files {
if let Some(last_msg) = chat_messages.last_mut() {
if let async_openai::types::ChatCompletionRequestMessage::User(user_msg) = last_msg
{
let mut content_parts = Vec::new();
if let Some(content) = &user_msg.content {
match content {
ChatCompletionRequestUserMessageContent::Text(text) => {
content_parts.push(async_openai::types::ChatCompletionRequestMessageContentPart::Text(
async_openai::types::ChatCompletionRequestMessageContentPartTextArgs::default()
.text(text)
.build()
.map_err(|e| {
AgentError::ModelError(format!(
"Failed to build text content part: {}",
e
))
})?
));
}
ChatCompletionRequestUserMessageContent::Array(parts) => {
content_parts.extend(parts.clone());
}
}
}
for file in files {
if file.mime_type.starts_with("image/") {
let base64_image =
base64::engine::general_purpose::STANDARD.encode(&file.data);
let data_url =
format!("data:{};base64,{}", file.mime_type, base64_image);
content_parts.push(async_openai::types::ChatCompletionRequestMessageContentPart::ImageUrl(
async_openai::types::ChatCompletionRequestMessageContentPartImageArgs::default()
.image_url(
ImageUrlArgs::default()
.url(data_url)
.detail(async_openai::types::ImageDetail::Auto)
.build()
.map_err(|e| {
AgentError::ModelError(format!(
"Failed to build image URL: {}",
e
))
})?
)
.build()
.map_err(|e| {
AgentError::ModelError(format!(
"Failed to build image content part: {}",
e
))
})?
));
}
}
*user_msg = ChatCompletionRequestUserMessageArgs::default()
.content(content_parts)
.build()
.map_err(|e| {
AgentError::ModelError(format!(
"Failed to rebuild user message with images: {}",
e
))
})?;
}
}
}
Ok(chat_messages)
}
}
#[async_trait]
impl LLM for OpenAILLM {
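/// Runs a single, non-streaming chat completion and returns the first choice's content.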
async fn generate(
&self,
messages: Vec<Message>,
files: Option<Vec<File>>,
) -> Result<GenerationResponse> {
let chat_messages = Self::prepare_messages(messages, files)?;
let request = CreateChatCompletionRequestArgs::default()
.model(&self.model)
.messages(chat_messages)
.build()
.map_err(|e| AgentError::ModelError(format!("Failed to build request: {}", e)))?;
let response = self
.client
.chat()
.create(request)
.await
.map_err(|e| AgentError::ModelError(format!("OpenAI API error: {}", e)))?;
let content = response
.choices
.first()
.and_then(|c| c.message.content.clone())
.unwrap_or_default();
Ok(GenerationResponse {
content,
metadata: None,
})
}
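/// Runs a streaming chat completion, yielding each chunk's content delta as a `GenerationChunk`.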
async fn stream_generate(
&self,
messages: Vec<Message>,
files: Option<Vec<File>>,
) -> Result<BoxStream<'static, Result<GenerationChunk>>> {
let chat_messages = Self::prepare_messages(messages, files)?;
let request = CreateChatCompletionRequestArgs::default()
.model(&self.model)
.messages(chat_messages)
.stream(true)
.build()
.map_err(|e| AgentError::ModelError(format!("Failed to build request: {}", e)))?;
let stream = self
.client
.chat()
.create_stream(request)
.await
.map_err(|e| AgentError::ModelError(format!("OpenAI API error: {}", e)))?;
let mapped = stream.map(|chunk_res| match chunk_res {
Ok(chunk) => {
let content = chunk
.choices
.first()
.and_then(|c| c.delta.content.clone())
.unwrap_or_default();
Ok(GenerationChunk {
content,
metadata: None,
})
}
Err(e) => Err(AgentError::ModelError(format!("OpenAI stream error: {}", e))),
});
Ok(Box::pin(mapped))
}
fn model_name(&self) -> &str {
&self.model
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
#[ignore]
async fn test_openai_generate() {
let llm = OpenAILLM::new("gpt-3.5-turbo").unwrap();
let messages = vec![Message {
role: Role::User,
content: "Say 'Hello' and nothing else.".to_string(),
metadata: None,
}];
let response = llm.generate(messages, None).await.unwrap();
assert!(response.content.contains("Hello"));
}
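// A second, manually-run smoke test sketch for the streaming path (like the test
// above, it assumes a valid OPENAI_API_KEY and network access, so it is ignored by
// default); it drains the stream and checks that the concatenated deltas are non-empty.
#[tokio::test]
#[ignore]
async fn test_openai_stream_generate() {
let llm = OpenAILLM::new("gpt-3.5-turbo").unwrap();
let messages = vec![Message {
role: Role::User,
content: "Count from 1 to 3.".to_string(),
metadata: None,
}];
let mut stream = llm.stream_generate(messages, None).await.unwrap();
let mut output = String::new();
while let Some(chunk) = stream.next().await {
output.push_str(&chunk.unwrap().content);
}
assert!(!output.is_empty());
}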
}