use super::{ChatStreamChunk, LlmProvider, Message, MessagePart, ToolCallChunk};
use anyhow::Result;
use async_openai::{
config::OpenAIConfig,
types::{
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
ChatCompletionRequestMessageContentPart, ChatCompletionRequestMessageContentPartImageArgs,
ChatCompletionRequestMessageContentPartTextArgs, ChatCompletionRequestSystemMessageArgs,
ChatCompletionRequestToolMessageArgs, ChatCompletionRequestUserMessageArgs,
ChatCompletionRequestUserMessageContent, ChatCompletionTool,
CreateChatCompletionRequestArgs, ImageUrlArgs,
},
Client,
};
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use std::pin::Pin;
use std::time::Duration;
/// Chat provider that talks to the OpenAI API through an `async_openai` client.
pub struct OpenAIProvider {
// Client built in `new` from the API key, base URL and HTTP timeouts.
client: Client<OpenAIConfig>,
}
impl OpenAIProvider {
pub fn new() -> Self {
let raw_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
let api_key = raw_key.trim().to_string();
let raw_base = std::env::var("OPENAI_API_BASE")
.unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
let api_base = raw_base.trim().trim_end_matches('/').to_string();
let config = OpenAIConfig::new()
.with_api_key(api_key)
.with_api_base(api_base);
let http_client = reqwest::Client::builder()
.connect_timeout(Duration::from_secs(10))
.timeout(Duration::from_secs(300))
.build()
.unwrap_or_default();
Self {
client: Client::with_config(config).with_http_client(http_client),
}
}
fn build_user_content(
&self,
parts_json: &serde_json::Value,
) -> ChatCompletionRequestUserMessageContent {
let mut content_parts = Vec::new();
if let Ok(parts) = serde_json::from_value::<Vec<MessagePart>>(parts_json.clone()) {
for part in parts {
if let Some(text) = part.content {
content_parts.push(ChatCompletionRequestMessageContentPart::Text(
ChatCompletionRequestMessageContentPartTextArgs::default()
.text(text)
.build()
.unwrap(),
));
}
}
}
if content_parts.is_empty() {
ChatCompletionRequestUserMessageContent::Text("".to_string())
} else {
ChatCompletionRequestUserMessageContent::Array(content_parts)
}
}
}
#[async_trait]
impl LlmProvider for OpenAIProvider {
async fn stream_chat(
&self,
model: &str,
system_prompt: Option<String>, history: Vec<Message>,
tools: Option<Vec<serde_json::Value>>,
) -> Result<Pin<Box<dyn Stream<Item = Result<ChatStreamChunk>> + Send>>> {
let mut request_messages: Vec<ChatCompletionRequestMessage> = Vec::new();
if let Some(sp) = system_prompt {
if !sp.is_empty() {
request_messages.push(
ChatCompletionRequestSystemMessageArgs::default()
.content(sp)
.build()?
.into(),
);
}
}
for msg in history {
match msg.role.as_str() {
"system" => {
let text = if let Ok(parts) =
serde_json::from_value::<Vec<MessagePart>>(msg.parts.clone())
{
parts
.first()
.and_then(|p| p.content.clone())
.unwrap_or_default()
} else {
String::new()
};
request_messages.push(
ChatCompletionRequestSystemMessageArgs::default()
.content(text)
.build()?
.into(),
);
}
"user" => {
let content = self.build_user_content(&msg.parts);
request_messages.push(
ChatCompletionRequestUserMessageArgs::default()
.content(content)
.build()?
.into(),
);
}
"assistant" => {
let text = if let Ok(parts) =
serde_json::from_value::<Vec<MessagePart>>(msg.parts.clone())
{
parts
.first()
.and_then(|p| p.content.clone())
.unwrap_or_default()
} else {
String::new()
};
let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
let mut has_content = false;
if !text.is_empty() {
builder.content(text);
has_content = true;
}
if let Some(tc_json) = &msg.tool_calls {
if let Ok(tc_vec) = serde_json::from_value::<
Vec<async_openai::types::ChatCompletionMessageToolCall>,
>(tc_json.clone())
{
if !tc_vec.is_empty() {
builder.tool_calls(tc_vec);
has_content = true; }
}
}
if has_content {
request_messages.push(builder.build()?.into());
}
}
"tool" => {
let tool_call_id = msg
.tool_call_id
.clone()
.ok_or_else(|| anyhow::anyhow!("Tool message missing ID"))?;
let text = if let Ok(parts) =
serde_json::from_value::<Vec<MessagePart>>(msg.parts.clone())
{
parts
.first()
.and_then(|p| p.content.clone())
.unwrap_or_default()
} else {
String::new()
};
request_messages.push(
ChatCompletionRequestToolMessageArgs::default()
.content(text)
.tool_call_id(tool_call_id)
.build()?
.into(),
);
}
_ => {}
}
}
let mut request_tools: Option<Vec<ChatCompletionTool>> = None;
if let Some(t) = tools {
let mut converted_tools = Vec::new();
for tool_json in t {
if let Ok(tool) = serde_json::from_value::<ChatCompletionTool>(tool_json) {
converted_tools.push(tool);
}
}
if !converted_tools.is_empty() {
request_tools = Some(converted_tools);
}
}
let mut args = CreateChatCompletionRequestArgs::default();
args.model(model).messages(request_messages).stream(true);
if let Some(t) = request_tools {
args.tools(t);
}
let request = args.build()?;
let stream = self.client.chat().create_stream(request).await?;
let mapped_stream = stream.map(|item| match item {
Ok(resp) => {
let choice = resp.choices.first();
let content = choice.and_then(|c| c.delta.content.clone());
let mut tool_chunks = Vec::new();
if let Some(c) = choice {
if let Some(tool_calls) = &c.delta.tool_calls {
for tc in tool_calls {
tool_chunks.push(ToolCallChunk {
index: tc.index,
id: tc.id.clone(),
name: tc.function.as_ref().and_then(|f| f.name.clone()),
arguments: tc.function.as_ref().and_then(|f| f.arguments.clone()),
signature: None,
});
}
}
}
Ok(ChatStreamChunk {
content,
tool_calls: tool_chunks,
usage: None,
finish_reason: None,
})
}
Err(e) => Err(anyhow::anyhow!("ChatGPT Provider Error: {}", e)),
});
Ok(Box::pin(mapped_stream))
}
}