1use async_trait::async_trait;
4use derive_builder::Builder;
5use futures::StreamExt;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9
10use crate::llm::{
11 BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, StopReason, ToolChoice,
12 ToolDefinition, Usage,
13};
14
/// Default OpenAI chat-completions endpoint; overridable per-client via `base_url`.
const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
16
/// Reasoning effort level sent to OpenAI reasoning models (o-series, gpt-5).
/// Serializes as the lowercase variant name (e.g. `"low"`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    Low,
    Medium,
    High,
    // Derived `Default` is `Minimal`; NOTE(review): the builder attribute on
    // `ChatOpenAI.reasoning_effort` declares `Low` as the default — confirm
    // which default is intended.
    #[default]
    Minimal,
}
27
/// Configuration and HTTP state for one OpenAI chat-completions client.
///
/// Construct via [`ChatOpenAI::new`] or [`ChatOpenAI::builder`]. Note that
/// `build_fn(skip)` means the `#[builder(default = ...)]` attributes below are
/// not applied by derive_builder; the hand-written `ChatOpenAIBuilder::build`
/// is responsible for the actual defaults.
#[derive(Builder, Clone)]
#[builder(pattern = "owned", build_fn(skip))]
pub struct ChatOpenAI {
    /// Model identifier, e.g. `"gpt-4o"` or `"o3-mini"`.
    #[builder(setter(into))]
    model: String,
    /// Bearer token for the `Authorization` header.
    api_key: String,
    /// Optional endpoint override; falls back to `OPENAI_API_URL` when `None`.
    #[builder(setter(into, strip_option), default = "None")]
    base_url: Option<String>,
    /// Sampling temperature; omitted from requests for reasoning models.
    #[builder(default = "0.2")]
    temperature: f32,
    /// Completion-token cap; `None` lets the API choose.
    #[builder(default = "Some(4096)")]
    max_completion_tokens: Option<u64>,
    /// Effort hint; only sent to reasoning models.
    #[builder(default = "ReasoningEffort::Low")]
    reasoning_effort: ReasoningEffort,
    /// Shared HTTP client, created in `build()` (not settable by callers).
    #[builder(setter(skip))]
    client: Client,
    /// Context window in tokens, derived from the model name in `build()`.
    #[builder(setter(skip))]
    context_window: u64,
}
56
57impl ChatOpenAI {
58 pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
60 let api_key = std::env::var("OPENAI_API_KEY")
61 .map_err(|_| LlmError::Config("OPENAI_API_KEY not set".into()))?;
62
63 Self::builder().model(model).api_key(api_key).build()
64 }
65
66 pub fn builder() -> ChatOpenAIBuilder {
68 ChatOpenAIBuilder::default()
69 }
70
71 fn is_reasoning_model(&self) -> bool {
73 let model_lower = self.model.to_lowercase();
74 model_lower.starts_with("o1")
75 || model_lower.starts_with("o3")
76 || model_lower.starts_with("o4")
77 || model_lower.starts_with("gpt-5")
78 }
79
80 fn api_url(&self) -> &str {
82 self.base_url.as_deref().unwrap_or(OPENAI_API_URL)
83 }
84
85 fn build_client() -> Client {
87 Client::builder()
88 .timeout(Duration::from_secs(120))
89 .build()
90 .expect("Failed to create HTTP client")
91 }
92
93 fn get_context_window(model: &str) -> u64 {
95 let model_lower = model.to_lowercase();
96
97 if model_lower.contains("gpt-4o") || model_lower.contains("gpt-4-turbo") {
99 128_000
100 }
101 else if model_lower.starts_with("gpt-4") {
103 8_192
104 }
105 else if model_lower.starts_with("gpt-3.5") {
107 16_385
108 }
109 else if model_lower.starts_with("o1")
111 || model_lower.starts_with("o3")
112 || model_lower.starts_with("o4")
113 {
114 200_000
115 }
116 else {
118 128_000
119 }
120 }
121}
122
123impl ChatOpenAIBuilder {
124 pub fn build(&self) -> Result<ChatOpenAI, LlmError> {
125 let model = self
126 .model
127 .clone()
128 .ok_or_else(|| LlmError::Config("model is required".into()))?;
129 let api_key = self
130 .api_key
131 .clone()
132 .ok_or_else(|| LlmError::Config("api_key is required".into()))?;
133
134 Ok(ChatOpenAI {
135 context_window: ChatOpenAI::get_context_window(&model),
136 client: ChatOpenAI::build_client(),
137 model,
138 api_key,
139 base_url: self.base_url.clone().flatten(),
140 temperature: self.temperature.unwrap_or(0.2),
141 max_completion_tokens: self.max_completion_tokens.flatten(),
142 reasoning_effort: self.reasoning_effort.clone().unwrap_or_default(),
143 })
144 }
145}
146
#[async_trait]
impl BaseChatModel for ChatOpenAI {
    /// Model identifier this client was configured with.
    fn model(&self) -> &str {
        &self.model
    }

    /// Provider tag for this implementation.
    fn provider(&self) -> &str {
        "openai"
    }

    /// Context window in tokens, resolved once at build time from the model name.
    fn context_window(&self) -> Option<u64> {
        Some(self.context_window)
    }

    /// Sends a non-streaming chat-completions request and parses the full
    /// response into a `ChatCompletion`.
    ///
    /// # Errors
    /// Returns `LlmError::Api` with the status code and raw body on any
    /// non-2xx reply; transport and JSON-decode failures propagate via `?`.
    async fn invoke(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatCompletion, LlmError> {
        let request = self.build_request(messages, tools, tool_choice, false)?;

        let response = self
            .client
            .post(self.api_url())
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            // Body read is best-effort: an unreadable body still yields a
            // status-coded error rather than masking the HTTP failure.
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "OpenAI API error ({}): {}",
                status, body
            )));
        }

        let completion: OpenAIResponse = response.json().await?;
        Ok(self.parse_response(completion))
    }

    /// Sends a streaming (SSE) chat-completions request and yields incremental
    /// `ChatCompletion` deltas as they arrive.
    ///
    /// NOTE(review): each transport chunk is decoded independently by
    /// `parse_stream_chunk`; an SSE `data:` line split across two chunks is
    /// not reassembled here — confirm whether partial-line chunks occur with
    /// this endpoint before relying on byte-exact streaming.
    async fn invoke_stream(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatStream, LlmError> {
        let request = self.build_request(messages, tools, tool_choice, true)?;

        let response = self
            .client
            .post(self.api_url())
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "OpenAI API error ({}): {}",
                status, body
            )));
        }

        // Map each raw byte chunk to zero or one completion delta; transport
        // errors surface as `LlmError::Stream` items in the stream itself.
        let stream = response.bytes_stream().filter_map(|result| async move {
            match result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);
                    Self::parse_stream_chunk(&text)
                }
                Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
            }
        });

        Ok(Box::pin(stream))
    }

    /// Whether the configured model accepts image content parts.
    /// NOTE(review): gpt-4.1 is listed but gpt-5 and the o-series are not —
    /// confirm that exclusion is intended.
    fn supports_vision(&self) -> bool {
        let model_lower = self.model.to_lowercase();
        model_lower.contains("gpt-4o")
            || model_lower.contains("gpt-4-turbo")
            || model_lower.contains("gpt-4-vision")
            || model_lower.contains("gpt-4.1")
    }
}
238
/// Serialized request body for the chat-completions endpoint.
/// `None` fields are omitted from the JSON entirely.
#[derive(Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAITool>>,
    /// Either a mode string ("auto"/"required"/"none") or a named-function object.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<serde_json::Value>,
    /// Omitted for reasoning models (see `build_request`).
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u64>,
    /// Only sent to reasoning models (see `build_request`).
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<ReasoningEffort>,
    /// `Some(true)` for SSE streaming; omitted for blocking calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}
260
/// One message in OpenAI wire format; produced by `convert_message`.
#[derive(Serialize)]
struct OpenAIMessage {
    /// "user" | "assistant" | "system" | "developer" | "tool".
    role: String,
    /// Plain JSON string for simple text, or an array of content parts.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<serde_json::Value>,
    /// Optional participant name (user messages only in this file).
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    /// Assistant-issued tool calls, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<crate::llm::ToolCall>>,
    /// Set only on role "tool": the call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
273
/// Tool wrapper in OpenAI wire format; `tool_type` is always "function" here.
#[derive(Serialize)]
struct OpenAITool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunction,
}
280
/// Function declaration inside an `OpenAITool`, mapped 1:1 from `ToolDefinition`.
#[derive(Serialize)]
struct OpenAIFunction {
    name: String,
    description: String,
    /// JSON-schema object describing the function parameters.
    parameters: serde_json::Map<String, serde_json::Value>,
    /// Requests strict schema adherence from the API.
    strict: bool,
}
288
/// Top-level non-streaming response; only the fields this client reads.
#[derive(Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
    usage: Option<OpenAIUsage>,
}
294
/// One completion choice; only the first choice is used by `parse_response`.
#[derive(Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
    /// "stop" | "tool_calls" | "length" are mapped to `StopReason`; others dropped.
    finish_reason: Option<String>,
}
300
/// Assistant message payload of a choice. All fields are optional; serde
/// deserializes missing `Option` fields as `None`.
#[derive(Deserialize)]
struct OpenAIMessageResponse {
    content: Option<String>,
    tool_calls: Option<Vec<crate::llm::ToolCall>>,
    /// Some providers return reasoning text here; used as a content fallback
    /// in `parse_response`.
    reasoning_content: Option<String>,
}
307
/// Token accounting block of a response.
#[derive(Deserialize)]
struct OpenAIUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
    /// Optional breakdown; absent on older API versions, hence the default.
    #[serde(default)]
    prompt_tokens_details: Option<OpenAIPromptTokenDetails>,
}
316
317#[derive(Deserialize, Default)]
318struct OpenAIPromptTokenDetails {
319 cached_tokens: u64,
320}
321
322impl ChatOpenAI {
323 fn build_request(
324 &self,
325 messages: Vec<Message>,
326 tools: Option<Vec<ToolDefinition>>,
327 tool_choice: Option<ToolChoice>,
328 stream: bool,
329 ) -> Result<OpenAIRequest, LlmError> {
330 let openai_messages: Vec<OpenAIMessage> =
331 messages.into_iter().map(Self::convert_message).collect();
332
333 let openai_tools = tools.map(|ts| {
334 ts.into_iter()
335 .map(|t| OpenAITool {
336 tool_type: "function".to_string(),
337 function: OpenAIFunction {
338 name: t.name,
339 description: t.description,
340 parameters: t.parameters,
341 strict: t.strict,
342 },
343 })
344 .collect()
345 });
346
347 let tool_choice_value = tool_choice.map(|tc| match tc {
348 ToolChoice::Auto => serde_json::json!("auto"),
349 ToolChoice::Required => serde_json::json!("required"),
350 ToolChoice::None => serde_json::json!("none"),
351 ToolChoice::Named(name) => {
352 serde_json::json!({"type": "function", "function": {"name": name}})
353 }
354 });
355
356 let temperature = if self.is_reasoning_model() {
358 None
359 } else {
360 Some(self.temperature)
361 };
362
363 let reasoning_effort = if self.is_reasoning_model() {
365 Some(self.reasoning_effort.clone())
366 } else {
367 None
368 };
369
370 Ok(OpenAIRequest {
371 model: self.model.clone(),
372 messages: openai_messages,
373 tools: openai_tools,
374 tool_choice: tool_choice_value,
375 temperature,
376 max_completion_tokens: self.max_completion_tokens,
377 reasoning_effort,
378 stream: if stream { Some(true) } else { None },
379 })
380 }
381
382 fn convert_message(message: Message) -> OpenAIMessage {
383 match message {
384 Message::User(u) => {
385 let content = if u.content.len() == 1 && u.content[0].is_text() {
386 serde_json::json!(u.content[0].as_text().unwrap())
387 } else {
388 serde_json::json!(u.content)
389 };
390 OpenAIMessage {
391 role: "user".to_string(),
392 content: Some(content),
393 name: u.name,
394 tool_calls: None,
395 tool_call_id: None,
396 }
397 }
398 Message::Assistant(a) => OpenAIMessage {
399 role: "assistant".to_string(),
400 content: a.content.map(|c| serde_json::json!(c)),
401 name: None,
402 tool_calls: if a.tool_calls.is_empty() {
403 None
404 } else {
405 Some(a.tool_calls)
406 },
407 tool_call_id: None,
408 },
409 Message::System(s) => OpenAIMessage {
410 role: "system".to_string(),
411 content: Some(serde_json::json!(s.content)),
412 name: None,
413 tool_calls: None,
414 tool_call_id: None,
415 },
416 Message::Developer(d) => OpenAIMessage {
417 role: "developer".to_string(),
418 content: Some(serde_json::json!(d.content)),
419 name: None,
420 tool_calls: None,
421 tool_call_id: None,
422 },
423 Message::Tool(t) => OpenAIMessage {
424 role: "tool".to_string(),
425 content: Some(serde_json::json!(t.content)),
426 name: None,
427 tool_calls: None,
428 tool_call_id: Some(t.tool_call_id),
429 },
430 }
431 }
432
433 fn parse_response(&self, response: OpenAIResponse) -> ChatCompletion {
434 let stop_reason = response
435 .choices
436 .first()
437 .and_then(|c| c.finish_reason.as_ref())
438 .and_then(|r| match r.as_str() {
439 "stop" => Some(StopReason::EndTurn),
440 "tool_calls" => Some(StopReason::ToolUse),
441 "length" => Some(StopReason::MaxTokens),
442 _ => None,
443 });
444
445 let choice = response.choices.into_iter().next();
446
447 let (content, tool_calls) = choice
448 .map(|c| {
449 let reasoning = c.message.reasoning_content;
450 let content = c.message.content.or(reasoning);
451 (content, c.message.tool_calls.unwrap_or_default())
452 })
453 .unwrap_or((None, Vec::new()));
454
455 let usage = response.usage.map(|u| Usage {
456 prompt_tokens: u.prompt_tokens,
457 completion_tokens: u.completion_tokens,
458 total_tokens: u.total_tokens,
459 prompt_cached_tokens: u.prompt_tokens_details.map(|d| d.cached_tokens),
460 ..Default::default()
461 });
462
463 ChatCompletion {
464 content,
465 thinking: None,
466 redacted_thinking: None,
467 tool_calls,
468 usage,
469 stop_reason,
470 }
471 }
472
473 fn parse_stream_chunk(text: &str) -> Option<Result<ChatCompletion, LlmError>> {
474 for line in text.lines() {
475 let line = line.trim();
476 if line.is_empty() || !line.starts_with("data:") {
477 continue;
478 }
479
480 let data = line.strip_prefix("data:").unwrap().trim();
481 if data == "[DONE]" {
482 return None;
483 }
484
485 let chunk: serde_json::Value = match serde_json::from_str(data) {
486 Ok(v) => v,
487 Err(_) => continue,
488 };
489
490 let delta = chunk
491 .get("choices")
492 .and_then(|c| c.as_array())
493 .and_then(|a| a.first())
494 .and_then(|c| c.get("delta"));
495
496 if let Some(delta) = delta {
497 let content = delta
498 .get("content")
499 .and_then(|c| c.as_str())
500 .map(|s| s.to_string());
501
502 let tool_calls: Vec<crate::llm::ToolCall> = delta
503 .get("tool_calls")
504 .and_then(|tc| tc.as_array())
505 .map(|arr| {
506 arr.iter()
507 .filter_map(|tc| {
508 let id = tc.get("id")?.as_str()?.to_string();
509 let func = tc.get("function")?;
510 let name = func.get("name")?.as_str()?.to_string();
511 let arguments = func.get("arguments")?.as_str()?.to_string();
512 Some(crate::llm::ToolCall::new(id, name, arguments))
513 })
514 .collect()
515 })
516 .unwrap_or_default();
517
518 if content.is_some() || !tool_calls.is_empty() {
519 return Some(Ok(ChatCompletion {
520 content,
521 thinking: None,
522 redacted_thinking: None,
523 tool_calls,
524 usage: None,
525 stop_reason: None,
526 }));
527 }
528 }
529 }
530
531 None
532 }
533}