1use async_trait::async_trait;
2use eventsource_stream::Eventsource;
3use futures_core::Stream;
4use futures_util::StreamExt;
5use reqwest::Client;
6use serde_json::{json, Value};
7use std::pin::Pin;
8
9use crate::types::{AgentResult, AgentError, ChatMessage, ImageAttachment, ImageDetail, ResponseFormat, ToolCallMessage};
10use super::{LlmCapabilities, LlmClient, StreamChunk, UsageInfo};
11
12
/// Chat-completions client for the OpenAI API and OpenAI-compatible endpoints.
pub struct OpenAiClient {
    // Bearer token sent in the `Authorization` header of every request.
    api_key: String,
    // Model identifier placed in the `model` field of each request body.
    model: String,
    // API root without trailing slash, e.g. "https://api.openai.com/v1".
    base_url: String,
    // Reused across requests so reqwest can pool connections.
    client: Client,
}
19
20impl OpenAiClient {
21 pub fn new(api_key: String, model: String, base_url: Option<String>) -> Self {
22 Self {
23 api_key,
24 model,
25 base_url: base_url
26 .unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
27 client: Client::new(),
28 }
29 }
30
31 fn chat_message_to_json(msg: &ChatMessage) -> Value {
32 match msg {
33 ChatMessage::System { content } => json!({
34 "role": "system",
35 "content": content,
36 }),
37 ChatMessage::User { content, images } => {
38 if images.is_empty() {
39 json!({
40 "role": "user",
41 "content": content,
42 })
43 } else {
44 let mut content_parts: Vec<Value> = Vec::new();
45 content_parts.push(json!({"type": "text", "text": content}));
46 for img in images {
47 content_parts.push(Self::image_to_json(img));
48 }
49 json!({
50 "role": "user",
51 "content": content_parts,
52 })
53 }
54 }
55 ChatMessage::Assistant { content, reasoning_content, tool_calls } => {
56 let mut obj = serde_json::Map::new();
57 obj.insert("role".to_string(), json!("assistant"));
58 obj.insert("content".to_string(), json!(content));
59 if let Some(reasoning) = reasoning_content {
60 obj.insert("reasoning_content".to_string(), json!(reasoning));
61 }
62 if let Some(tc) = tool_calls {
63 let tool_calls_json: Vec<Value> = tc
64 .iter()
65 .map(|t| Self::tool_call_to_json(t))
66 .collect();
67 obj.insert("tool_calls".to_string(), json!(tool_calls_json));
68 }
69 Value::Object(obj)
70 }
71 ChatMessage::Tool { tool_call_id, content } => json!({
72 "role": "tool",
73 "tool_call_id": tool_call_id,
74 "content": content,
75 }),
76 }
77 }
78
79 fn tool_call_to_json(tc: &ToolCallMessage) -> Value {
80 json!({
81 "id": tc.id,
82 "type": "function",
83 "function": {
84 "name": tc.name,
85 "arguments": tc.arguments,
86 }
87 })
88 }
89
90 fn image_to_json(img: &ImageAttachment) -> Value {
91 match img {
92 ImageAttachment::Url { url, detail } => {
93 let mut obj = serde_json::Map::new();
94 obj.insert("url".to_string(), json!(url));
95 if let Some(d) = detail {
96 let detail_str = match d {
97 ImageDetail::Low => "low",
98 ImageDetail::High => "high",
99 ImageDetail::Auto => "auto",
100 };
101 obj.insert("detail".to_string(), json!(detail_str));
102 }
103 json!({
104 "type": "image_url",
105 "image_url": Value::Object(obj),
106 })
107 }
108 ImageAttachment::Base64 { data, media_type, detail } => {
109 let mime = media_type.as_deref().unwrap_or("image/jpeg");
110 let data_url = format!("data:{mime};base64,{data}");
111 let mut obj = serde_json::Map::new();
112 obj.insert("url".to_string(), json!(data_url));
113 if let Some(d) = detail {
114 let detail_str = match d {
115 ImageDetail::Low => "low",
116 ImageDetail::High => "high",
117 ImageDetail::Auto => "auto",
118 };
119 obj.insert("detail".to_string(), json!(detail_str));
120 }
121 json!({
122 "type": "image_url",
123 "image_url": Value::Object(obj),
124 })
125 }
126 }
127 }
128
129 fn messages_to_json(messages: &[ChatMessage]) -> Vec<Value> {
130 messages.iter().map(Self::chat_message_to_json).collect()
131 }
132}
133
134#[async_trait]
135impl LlmClient for OpenAiClient {
136 async fn chat(
137 &self,
138 messages: &[ChatMessage],
139 tools: &[Value],
140 enable_thinking: Option<bool>,
141 response_format: Option<&ResponseFormat>,
142 ) -> AgentResult<Value> {
143 let url = format!("{}/chat/completions", self.base_url);
144 let raw_messages = Self::messages_to_json(messages);
145 let mut request_body = json!({
146 "model": self.model,
147 "messages": raw_messages,
148 "tools": tools,
149 });
150
151 if let Some(thinking) = enable_thinking {
152 if let Some(obj) = request_body.as_object_mut() {
153 obj.insert("enable_thinking".to_string(), json!(thinking));
154 }
155 }
156
157 if let Some(rf) = response_format {
158 if let Some(obj) = request_body.as_object_mut() {
159 obj.insert("response_format".to_string(), rf.to_api_value());
160 }
161 }
162
163 let response = self
164 .client
165 .post(&url)
166 .header("Authorization", format!("Bearer {}", self.api_key))
167 .header("Content-Type", "application/json")
168 .json(&request_body)
169 .send()
170 .await
171 .map_err(|e| AgentError::llm(format!("HTTP request failed: {e}")))?;
172
173 let res_json: Value = response.json().await
174 .map_err(|e| AgentError::json(format!("Response JSON parse failed: {e}")))?;
175
176 if let Some(error) = res_json.get("error") {
177 return Err(AgentError::LlmApi {
178 message: format!("{error:#?}"),
179 });
180 }
181
182 Ok(res_json)
183 }
184
185 async fn chat_stream(
186 &self,
187 messages: &[ChatMessage],
188 tools: &[Value],
189 enable_thinking: Option<bool>,
190 response_format: Option<&ResponseFormat>,
191 ) -> AgentResult<Pin<Box<dyn Stream<Item = AgentResult<StreamChunk>> + Send>>> {
192 let url = format!("{}/chat/completions", self.base_url);
193 let raw_messages = Self::messages_to_json(messages);
194 let mut request_body = json!({
195 "model": self.model,
196 "messages": raw_messages,
197 "tools": tools,
198 "stream": true,
199 "stream_options": { "include_usage": true },
200 });
201
202 if let Some(thinking) = enable_thinking {
203 if let Some(obj) = request_body.as_object_mut() {
204 obj.insert("enable_thinking".to_string(), json!(thinking));
205 }
206 }
207
208 if let Some(rf) = response_format {
209 if let Some(obj) = request_body.as_object_mut() {
210 obj.insert("response_format".to_string(), rf.to_api_value());
211 }
212 }
213
214 let response = self
215 .client
216 .post(&url)
217 .header("Authorization", format!("Bearer {}", self.api_key))
218 .header("Content-Type", "application/json")
219 .json(&request_body)
220 .send()
221 .await
222 .map_err(|e| AgentError::llm(format!("HTTP request failed: {e}")))?;
223
224 if !response.status().is_success() {
225 let err_text = response.text().await
226 .map_err(|e| AgentError::llm(format!("Failed to read error response: {e}")))?;
227 return Err(AgentError::LlmApi { message: err_text });
228 }
229
230 let stream = response.bytes_stream().eventsource().map(|event| match event {
231 Ok(event) => {
232 if event.data == "[DONE]" {
233 return Ok(StreamChunk::Stop);
234 }
235
236 let data: Value = serde_json::from_str(&event.data)
237 .map_err(|e| AgentError::json(format!("JSON Parse error: {e}")))?;
238
239 let choices = data.get("choices").and_then(Value::as_array);
240
241 if choices.is_none() || choices.map_or(true, |c| c.is_empty()) {
242 if let Some(usage) = data.get("usage") {
243 return Ok(StreamChunk::Usage(UsageInfo {
244 prompt_tokens: usage.get("prompt_tokens").and_then(Value::as_u64).map(|v| v as u32),
245 completion_tokens: usage.get("completion_tokens").and_then(Value::as_u64).map(|v| v as u32),
246 total_tokens: usage.get("total_tokens").and_then(Value::as_u64).map(|v| v as u32),
247 }));
248 }
249 return Ok(StreamChunk::Text(String::new()));
250 }
251
252 let choice = &choices.unwrap()[0];
253 let delta = &choice["delta"];
254 let finish_reason = choice["finish_reason"].as_str().unwrap_or("");
255
256 if finish_reason == "tool_calls" || delta.get("tool_calls").is_some() {
257 return Ok(StreamChunk::ToolCall(choice.clone()));
258 }
259
260 if let Some(reasoning) = delta.get("reasoning_content") {
261 if let Some(text) = reasoning.as_str() {
262 return Ok(StreamChunk::Thought(text.to_string()));
263 }
264 }
265
266 if let Some(content) = delta.get("content") {
267 if let Some(text) = content.as_str() {
268 return Ok(StreamChunk::Text(text.to_string()));
269 }
270 }
271
272 if finish_reason == "stop" {
273 return Ok(StreamChunk::Stop);
274 }
275
276 Ok(StreamChunk::Text(String::new()))
277 }
278 Err(e) => Err(AgentError::LlmStream(format!("SSE Stream error: {e}"))),
279 });
280
281 Ok(Box::pin(stream))
282 }
283
284 fn capabilities(&self) -> LlmCapabilities {
285 LlmCapabilities {
286 supports_streaming: true,
287 supports_tools: true,
288 supports_vision: true,
289 supports_thinking: false,
290 max_context_tokens: Some(128_000),
291 max_output_tokens: Some(16_384),
292 }
293 }
294}