use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures_core::Stream;
use futures_util::StreamExt;
use reqwest::Client;
use serde_json::{json, Value};
use std::pin::Pin;
use crate::types::{AgentResult, AgentError, ChatMessage, ImageAttachment, ImageDetail, ResponseFormat, ToolCallMessage};
use super::{LlmCapabilities, LlmClient, StreamChunk, UsageInfo};
/// Client for the OpenAI chat-completions API, or any OpenAI-compatible
/// endpoint selected via `base_url`.
pub struct OpenAiClient {
// Bearer token sent in the `Authorization` header of every request.
api_key: String,
// Model identifier forwarded in the request body (e.g. "gpt-4o").
model: String,
// API root without a trailing slash; "/chat/completions" is appended.
base_url: String,
// Shared reqwest HTTP client, reused across requests.
client: Client,
}
impl OpenAiClient {
pub fn new(api_key: String, model: String, base_url: Option<String>) -> Self {
Self {
api_key,
model,
base_url: base_url
.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
client: Client::new(),
}
}
fn chat_message_to_json(msg: &ChatMessage) -> Value {
match msg {
ChatMessage::System { content } => json!({
"role": "system",
"content": content,
}),
ChatMessage::User { content, images } => {
if images.is_empty() {
json!({
"role": "user",
"content": content,
})
} else {
let mut content_parts: Vec<Value> = Vec::new();
content_parts.push(json!({"type": "text", "text": content}));
for img in images {
content_parts.push(Self::image_to_json(img));
}
json!({
"role": "user",
"content": content_parts,
})
}
}
ChatMessage::Assistant { content, reasoning_content, tool_calls } => {
let mut obj = serde_json::Map::new();
obj.insert("role".to_string(), json!("assistant"));
obj.insert("content".to_string(), json!(content));
if let Some(reasoning) = reasoning_content {
obj.insert("reasoning_content".to_string(), json!(reasoning));
}
if let Some(tc) = tool_calls {
let tool_calls_json: Vec<Value> = tc
.iter()
.map(|t| Self::tool_call_to_json(t))
.collect();
obj.insert("tool_calls".to_string(), json!(tool_calls_json));
}
Value::Object(obj)
}
ChatMessage::Tool { tool_call_id, content } => json!({
"role": "tool",
"tool_call_id": tool_call_id,
"content": content,
}),
}
}
fn tool_call_to_json(tc: &ToolCallMessage) -> Value {
json!({
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": tc.arguments,
}
})
}
fn image_to_json(img: &ImageAttachment) -> Value {
match img {
ImageAttachment::Url { url, detail } => {
let mut obj = serde_json::Map::new();
obj.insert("url".to_string(), json!(url));
if let Some(d) = detail {
let detail_str = match d {
ImageDetail::Low => "low",
ImageDetail::High => "high",
ImageDetail::Auto => "auto",
};
obj.insert("detail".to_string(), json!(detail_str));
}
json!({
"type": "image_url",
"image_url": Value::Object(obj),
})
}
ImageAttachment::Base64 { data, media_type, detail } => {
let mime = media_type.as_deref().unwrap_or("image/jpeg");
let data_url = format!("data:{mime};base64,{data}");
let mut obj = serde_json::Map::new();
obj.insert("url".to_string(), json!(data_url));
if let Some(d) = detail {
let detail_str = match d {
ImageDetail::Low => "low",
ImageDetail::High => "high",
ImageDetail::Auto => "auto",
};
obj.insert("detail".to_string(), json!(detail_str));
}
json!({
"type": "image_url",
"image_url": Value::Object(obj),
})
}
}
}
fn messages_to_json(messages: &[ChatMessage]) -> Vec<Value> {
messages.iter().map(Self::chat_message_to_json).collect()
}
}
#[async_trait]
impl LlmClient for OpenAiClient {
    /// Sends a non-streaming chat-completion request.
    ///
    /// Returns the raw response body as JSON on success. An HTTP error
    /// status, or an `error` object embedded in a 200 response, is surfaced
    /// as `AgentError::LlmApi`.
    async fn chat(
        &self,
        messages: &[ChatMessage],
        tools: &[Value],
        enable_thinking: Option<bool>,
        response_format: Option<&ResponseFormat>,
    ) -> AgentResult<Value> {
        let url = format!("{}/chat/completions", self.base_url);
        let mut request_body = json!({
            "model": self.model,
            "messages": Self::messages_to_json(messages),
        });
        if let Some(obj) = request_body.as_object_mut() {
            // The API rejects an empty `tools` array, so omit the key
            // entirely when no tools are registered.
            if !tools.is_empty() {
                obj.insert("tools".to_string(), json!(tools));
            }
            if let Some(thinking) = enable_thinking {
                obj.insert("enable_thinking".to_string(), json!(thinking));
            }
            if let Some(rf) = response_format {
                obj.insert("response_format".to_string(), rf.to_api_value());
            }
        }
        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| AgentError::llm(format!("HTTP request failed: {e}")))?;
        // Fail fast on HTTP-level errors (mirrors chat_stream); otherwise a
        // non-JSON error body would surface as a confusing parse failure.
        if !response.status().is_success() {
            let err_text = response.text().await
                .map_err(|e| AgentError::llm(format!("Failed to read error response: {e}")))?;
            return Err(AgentError::LlmApi { message: err_text });
        }
        let res_json: Value = response.json().await
            .map_err(|e| AgentError::json(format!("Response JSON parse failed: {e}")))?;
        // Some providers return 200 with an embedded `error` object.
        if let Some(error) = res_json.get("error") {
            return Err(AgentError::LlmApi {
                message: format!("{error:#?}"),
            });
        }
        Ok(res_json)
    }

    /// Sends a streaming chat-completion request and returns a stream of
    /// parsed SSE events: text deltas, reasoning ("thought") deltas, tool
    /// calls, a final usage report, and a stop marker.
    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
        tools: &[Value],
        enable_thinking: Option<bool>,
        response_format: Option<&ResponseFormat>,
    ) -> AgentResult<Pin<Box<dyn Stream<Item = AgentResult<StreamChunk>> + Send>>> {
        let url = format!("{}/chat/completions", self.base_url);
        let mut request_body = json!({
            "model": self.model,
            "messages": Self::messages_to_json(messages),
            "stream": true,
            // Ask the server to append a final chunk carrying token usage.
            "stream_options": { "include_usage": true },
        });
        if let Some(obj) = request_body.as_object_mut() {
            // See `chat`: an empty `tools` array is rejected by the API.
            if !tools.is_empty() {
                obj.insert("tools".to_string(), json!(tools));
            }
            if let Some(thinking) = enable_thinking {
                obj.insert("enable_thinking".to_string(), json!(thinking));
            }
            if let Some(rf) = response_format {
                obj.insert("response_format".to_string(), rf.to_api_value());
            }
        }
        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| AgentError::llm(format!("HTTP request failed: {e}")))?;
        if !response.status().is_success() {
            let err_text = response.text().await
                .map_err(|e| AgentError::llm(format!("Failed to read error response: {e}")))?;
            return Err(AgentError::LlmApi { message: err_text });
        }
        let stream = response.bytes_stream().eventsource().map(|event| match event {
            Ok(event) => {
                // The terminal sentinel is a literal "[DONE]" data payload.
                if event.data == "[DONE]" {
                    return Ok(StreamChunk::Stop);
                }
                let data: Value = serde_json::from_str(&event.data)
                    .map_err(|e| AgentError::json(format!("JSON Parse error: {e}")))?;
                // With `include_usage`, the final data chunk has an empty
                // `choices` array and carries only the `usage` object.
                let choices = data.get("choices").and_then(Value::as_array);
                let Some(choice) = choices.and_then(|c| c.first()) else {
                    if let Some(usage) = data.get("usage") {
                        let to_u32 = |key: &str| {
                            usage.get(key).and_then(Value::as_u64).map(|v| v as u32)
                        };
                        return Ok(StreamChunk::Usage(UsageInfo {
                            prompt_tokens: to_u32("prompt_tokens"),
                            completion_tokens: to_u32("completion_tokens"),
                            total_tokens: to_u32("total_tokens"),
                        }));
                    }
                    // Keep-alive / empty chunk: forward as empty text.
                    return Ok(StreamChunk::Text(String::new()));
                };
                let delta = &choice["delta"];
                let finish_reason = choice["finish_reason"].as_str().unwrap_or("");
                // Tool-call deltas are forwarded as the whole choice so the
                // caller can accumulate streamed argument fragments.
                if finish_reason == "tool_calls" || delta.get("tool_calls").is_some() {
                    return Ok(StreamChunk::ToolCall(choice.clone()));
                }
                if let Some(text) = delta.get("reasoning_content").and_then(Value::as_str) {
                    return Ok(StreamChunk::Thought(text.to_string()));
                }
                if let Some(text) = delta.get("content").and_then(Value::as_str) {
                    return Ok(StreamChunk::Text(text.to_string()));
                }
                if finish_reason == "stop" {
                    return Ok(StreamChunk::Stop);
                }
                Ok(StreamChunk::Text(String::new()))
            }
            Err(e) => Err(AgentError::LlmStream(format!("SSE Stream error: {e}"))),
        });
        Ok(Box::pin(stream))
    }

    /// Static capability flags advertised for this provider.
    fn capabilities(&self) -> LlmCapabilities {
        LlmCapabilities {
            supports_streaming: true,
            supports_tools: true,
            supports_vision: true,
            supports_thinking: false,
            // NOTE(review): assumes a 128k-context / 16k-output model family;
            // confirm these limits match the configured `model`.
            max_context_tokens: Some(128_000),
            max_output_tokens: Some(16_384),
        }
    }
}