//! agent_base/llm/openai.rs — client for OpenAI-compatible `/chat/completions` endpoints.

use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures_core::Stream;
use futures_util::StreamExt;
use reqwest::Client;
use serde_json::{json, Value};
use std::pin::Pin;

use crate::types::{AgentResult, AgentError, ChatMessage, ImageAttachment, ImageDetail, ResponseFormat, ToolCallMessage};
use super::{LlmCapabilities, LlmClient, StreamChunk, UsageInfo};


13pub struct OpenAiClient {
14    api_key: String,
15    model: String,
16    base_url: String,
17    client: Client,
18}
20impl OpenAiClient {
21    pub fn new(api_key: String, model: String, base_url: Option<String>) -> Self {
22        Self {
23            api_key,
24            model,
25            base_url: base_url
26                .unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
27            client: Client::new(),
28        }
29    }
30
31    fn chat_message_to_json(msg: &ChatMessage) -> Value {
32        match msg {
33            ChatMessage::System { content } => json!({
34                "role": "system",
35                "content": content,
36            }),
37            ChatMessage::User { content, images } => {
38                if images.is_empty() {
39                    json!({
40                        "role": "user",
41                        "content": content,
42                    })
43                } else {
44                    let mut content_parts: Vec<Value> = Vec::new();
45                    content_parts.push(json!({"type": "text", "text": content}));
46                    for img in images {
47                        content_parts.push(Self::image_to_json(img));
48                    }
49                    json!({
50                        "role": "user",
51                        "content": content_parts,
52                    })
53                }
54            }
55            ChatMessage::Assistant { content, reasoning_content, tool_calls } => {
56                let mut obj = serde_json::Map::new();
57                obj.insert("role".to_string(), json!("assistant"));
58                obj.insert("content".to_string(), json!(content));
59                if let Some(reasoning) = reasoning_content {
60                    obj.insert("reasoning_content".to_string(), json!(reasoning));
61                }
62                if let Some(tc) = tool_calls {
63                    let tool_calls_json: Vec<Value> = tc
64                        .iter()
65                        .map(|t| Self::tool_call_to_json(t))
66                        .collect();
67                    obj.insert("tool_calls".to_string(), json!(tool_calls_json));
68                }
69                Value::Object(obj)
70            }
71            ChatMessage::Tool { tool_call_id, content } => json!({
72                "role": "tool",
73                "tool_call_id": tool_call_id,
74                "content": content,
75            }),
76        }
77    }
78
79    fn tool_call_to_json(tc: &ToolCallMessage) -> Value {
80        json!({
81            "id": tc.id,
82            "type": "function",
83            "function": {
84                "name": tc.name,
85                "arguments": tc.arguments,
86            }
87        })
88    }
89
90    fn image_to_json(img: &ImageAttachment) -> Value {
91        match img {
92            ImageAttachment::Url { url, detail } => {
93                let mut obj = serde_json::Map::new();
94                obj.insert("url".to_string(), json!(url));
95                if let Some(d) = detail {
96                    let detail_str = match d {
97                        ImageDetail::Low => "low",
98                        ImageDetail::High => "high",
99                        ImageDetail::Auto => "auto",
100                    };
101                    obj.insert("detail".to_string(), json!(detail_str));
102                }
103                json!({
104                    "type": "image_url",
105                    "image_url": Value::Object(obj),
106                })
107            }
108            ImageAttachment::Base64 { data, media_type, detail } => {
109                let mime = media_type.as_deref().unwrap_or("image/jpeg");
110                let data_url = format!("data:{mime};base64,{data}");
111                let mut obj = serde_json::Map::new();
112                obj.insert("url".to_string(), json!(data_url));
113                if let Some(d) = detail {
114                    let detail_str = match d {
115                        ImageDetail::Low => "low",
116                        ImageDetail::High => "high",
117                        ImageDetail::Auto => "auto",
118                    };
119                    obj.insert("detail".to_string(), json!(detail_str));
120                }
121                json!({
122                    "type": "image_url",
123                    "image_url": Value::Object(obj),
124                })
125            }
126        }
127    }
128
129    fn messages_to_json(messages: &[ChatMessage]) -> Vec<Value> {
130        messages.iter().map(Self::chat_message_to_json).collect()
131    }
132}
134#[async_trait]
135impl LlmClient for OpenAiClient {
136    async fn chat(
137        &self,
138        messages: &[ChatMessage],
139        tools: &[Value],
140        enable_thinking: Option<bool>,
141        response_format: Option<&ResponseFormat>,
142    ) -> AgentResult<Value> {
143        let url = format!("{}/chat/completions", self.base_url);
144        let raw_messages = Self::messages_to_json(messages);
145        let mut request_body = json!({
146            "model": self.model,
147            "messages": raw_messages,
148            "tools": tools,
149        });
150
151        if let Some(thinking) = enable_thinking {
152            if let Some(obj) = request_body.as_object_mut() {
153                obj.insert("enable_thinking".to_string(), json!(thinking));
154            }
155        }
156
157        if let Some(rf) = response_format {
158            if let Some(obj) = request_body.as_object_mut() {
159                obj.insert("response_format".to_string(), rf.to_api_value());
160            }
161        }
162
163        let response = self
164            .client
165            .post(&url)
166            .header("Authorization", format!("Bearer {}", self.api_key))
167            .header("Content-Type", "application/json")
168            .json(&request_body)
169            .send()
170            .await
171            .map_err(|e| AgentError::llm(format!("HTTP request failed: {e}")))?;
172
173        let res_json: Value = response.json().await
174            .map_err(|e| AgentError::json(format!("Response JSON parse failed: {e}")))?;
175
176        if let Some(error) = res_json.get("error") {
177            return Err(AgentError::LlmApi {
178                message: format!("{error:#?}"),
179            });
180        }
181
182        Ok(res_json)
183    }
184
185    async fn chat_stream(
186        &self,
187        messages: &[ChatMessage],
188        tools: &[Value],
189        enable_thinking: Option<bool>,
190        response_format: Option<&ResponseFormat>,
191    ) -> AgentResult<Pin<Box<dyn Stream<Item = AgentResult<StreamChunk>> + Send>>> {
192        let url = format!("{}/chat/completions", self.base_url);
193        let raw_messages = Self::messages_to_json(messages);
194        let mut request_body = json!({
195            "model": self.model,
196            "messages": raw_messages,
197            "tools": tools,
198            "stream": true,
199            "stream_options": { "include_usage": true },
200        });
201
202        if let Some(thinking) = enable_thinking {
203            if let Some(obj) = request_body.as_object_mut() {
204                obj.insert("enable_thinking".to_string(), json!(thinking));
205            }
206        }
207
208        if let Some(rf) = response_format {
209            if let Some(obj) = request_body.as_object_mut() {
210                obj.insert("response_format".to_string(), rf.to_api_value());
211            }
212        }
213
214        let response = self
215            .client
216            .post(&url)
217            .header("Authorization", format!("Bearer {}", self.api_key))
218            .header("Content-Type", "application/json")
219            .json(&request_body)
220            .send()
221            .await
222            .map_err(|e| AgentError::llm(format!("HTTP request failed: {e}")))?;
223
224        if !response.status().is_success() {
225            let err_text = response.text().await
226                .map_err(|e| AgentError::llm(format!("Failed to read error response: {e}")))?;
227            return Err(AgentError::LlmApi { message: err_text });
228        }
229
230        let stream = response.bytes_stream().eventsource().map(|event| match event {
231            Ok(event) => {
232                if event.data == "[DONE]" {
233                    return Ok(StreamChunk::Stop);
234                }
235
236                let data: Value = serde_json::from_str(&event.data)
237                    .map_err(|e| AgentError::json(format!("JSON Parse error: {e}")))?;
238
239                let choices = data.get("choices").and_then(Value::as_array);
240
241                if choices.is_none() || choices.map_or(true, |c| c.is_empty()) {
242                    if let Some(usage) = data.get("usage") {
243                        return Ok(StreamChunk::Usage(UsageInfo {
244                            prompt_tokens: usage.get("prompt_tokens").and_then(Value::as_u64).map(|v| v as u32),
245                            completion_tokens: usage.get("completion_tokens").and_then(Value::as_u64).map(|v| v as u32),
246                            total_tokens: usage.get("total_tokens").and_then(Value::as_u64).map(|v| v as u32),
247                        }));
248                    }
249                    return Ok(StreamChunk::Text(String::new()));
250                }
251
252                let choice = &choices.unwrap()[0];
253                let delta = &choice["delta"];
254                let finish_reason = choice["finish_reason"].as_str().unwrap_or("");
255
256                if finish_reason == "tool_calls" || delta.get("tool_calls").is_some() {
257                    return Ok(StreamChunk::ToolCall(choice.clone()));
258                }
259
260                if let Some(reasoning) = delta.get("reasoning_content") {
261                    if let Some(text) = reasoning.as_str() {
262                        return Ok(StreamChunk::Thought(text.to_string()));
263                    }
264                }
265
266                if let Some(content) = delta.get("content") {
267                    if let Some(text) = content.as_str() {
268                        return Ok(StreamChunk::Text(text.to_string()));
269                    }
270                }
271
272                if finish_reason == "stop" {
273                    return Ok(StreamChunk::Stop);
274                }
275
276                Ok(StreamChunk::Text(String::new()))
277            }
278            Err(e) => Err(AgentError::LlmStream(format!("SSE Stream error: {e}"))),
279        });
280
281        Ok(Box::pin(stream))
282    }
283
284    fn capabilities(&self) -> LlmCapabilities {
285        LlmCapabilities {
286            supports_streaming: true,
287            supports_tools: true,
288            supports_vision: true,
289            supports_thinking: false,
290            max_context_tokens: Some(128_000),
291            max_output_tokens: Some(16_384),
292        }
293    }
294}