// agent_io/llm/openai.rs

//! OpenAI Chat Model implementation
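//!
//! # Example
//!
//! A minimal usage sketch. It is not compiled here: it assumes the crate is
//! named `agent_io` and that `Message` offers a `user` convenience
//! constructor; adjust both to the real API if they differ.
//!
//! ```ignore
//! use agent_io::llm::openai::ChatOpenAI;
//! use agent_io::llm::{BaseChatModel, Message};
//!
//! async fn demo() -> Result<(), Box<dyn std::error::Error>> {
//!     // `new` reads the API key from the OPENAI_API_KEY environment variable.
//!     let model = ChatOpenAI::new("gpt-4o")?;
//!     let completion = model
//!         .invoke(vec![Message::user("Hello!")], None, None)
//!         .await?;
//!     println!("{:?}", completion.content);
//!     Ok(())
//! }
//! ```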

use async_trait::async_trait;
use derive_builder::Builder;
use futures::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;

use crate::llm::{
    BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, StopReason, ToolChoice,
    ToolDefinition, Usage,
};

const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";

/// Reasoning effort levels for o1+ models
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    Low,
    Medium,
    High,
    #[default]
    Minimal,
}

/// OpenAI Chat Model
#[derive(Builder, Clone)]
#[builder(pattern = "owned", build_fn(skip))]
pub struct ChatOpenAI {
    /// Model identifier
    #[builder(setter(into))]
    model: String,
    /// API key
    api_key: String,
    /// Optional override for the full chat-completions endpoint URL
    /// (proxies, OpenAI-compatible servers)
    #[builder(setter(into, strip_option), default = "None")]
    base_url: Option<String>,
    /// Temperature for sampling (ignored by reasoning models)
    #[builder(default = "0.2")]
    temperature: f32,
    /// Maximum completion tokens
    #[builder(default = "Some(4096)")]
    max_completion_tokens: Option<u64>,
    /// Reasoning effort, sent only for reasoning models
    #[builder(default)]
    reasoning_effort: ReasoningEffort,
    /// HTTP client
    #[builder(setter(skip))]
    client: Client,
    /// Context window size
    #[builder(setter(skip))]
    context_window: u64,
}

impl ChatOpenAI {
    /// Create a new OpenAI chat model, reading the API key from `OPENAI_API_KEY`
    pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| LlmError::Config("OPENAI_API_KEY not set".into()))?;

        Self::builder().model(model).api_key(api_key).build()
    }

    /// Create a builder for configuration
    pub fn builder() -> ChatOpenAIBuilder {
        ChatOpenAIBuilder::default()
    }

    /// Check if this is a reasoning model (o1, o3, o4, gpt-5)
    fn is_reasoning_model(&self) -> bool {
        let model_lower = self.model.to_lowercase();
        model_lower.starts_with("o1")
            || model_lower.starts_with("o3")
            || model_lower.starts_with("o4")
            || model_lower.starts_with("gpt-5")
    }

    /// Resolve the request URL, preferring the configured override
    fn api_url(&self) -> &str {
        self.base_url.as_deref().unwrap_or(OPENAI_API_URL)
    }

    /// Build the HTTP client with a generous timeout for long completions
    fn build_client() -> Client {
        Client::builder()
            .timeout(Duration::from_secs(120))
            .build()
            .expect("Failed to create HTTP client")
    }

    /// Get the context window size (in tokens) for a model
    fn get_context_window(model: &str) -> u64 {
        let model_lower = model.to_lowercase();

        // GPT-4o family and GPT-4 Turbo
        if model_lower.contains("gpt-4o") || model_lower.contains("gpt-4-turbo") {
            128_000
        }
        // GPT-4.1 family (advertised at roughly one million tokens)
        else if model_lower.starts_with("gpt-4.1") {
            1_047_576
        }
        // Original GPT-4
        else if model_lower.starts_with("gpt-4") {
            8_192
        }
        // GPT-3.5
        else if model_lower.starts_with("gpt-3.5") {
            16_385
        }
        // o1/o3/o4 reasoning models
        else if model_lower.starts_with("o1")
            || model_lower.starts_with("o3")
            || model_lower.starts_with("o4")
        {
            200_000
        }
        // Conservative default for unrecognized models
        else {
            128_000
        }
    }
}

impl ChatOpenAIBuilder {
    // NOTE: `build_fn(skip)` disables the generated `build`, so the
    // `#[builder(default = ...)]` attributes above are informational only;
    // the actual defaults are applied here
    pub fn build(&self) -> Result<ChatOpenAI, LlmError> {
        let model = self
            .model
            .clone()
            .ok_or_else(|| LlmError::Config("model is required".into()))?;
        let api_key = self
            .api_key
            .clone()
            .ok_or_else(|| LlmError::Config("api_key is required".into()))?;

        Ok(ChatOpenAI {
            context_window: ChatOpenAI::get_context_window(&model),
            client: ChatOpenAI::build_client(),
            model,
            api_key,
            base_url: self.base_url.clone().flatten(),
            temperature: self.temperature.unwrap_or(0.2),
            // `unwrap_or` (not `flatten`) so an unset field falls back to the
            // documented Some(4096) while an explicit `None` disables the cap
            max_completion_tokens: self.max_completion_tokens.unwrap_or(Some(4096)),
            reasoning_effort: self.reasoning_effort.clone().unwrap_or_default(),
        })
    }
}
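
// Builder sketch for non-default configuration, e.g. an OpenAI-compatible
// server. Note that `base_url` replaces the *full* chat-completions endpoint
// URL, not just the host (the URL and model name below are illustrative):
//
//     let model = ChatOpenAI::builder()
//         .model("local-model")
//         .api_key("unused".to_string())
//         .base_url("http://localhost:8000/v1/chat/completions")
//         .temperature(0.0)
//         .build()?;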

#[async_trait]
impl BaseChatModel for ChatOpenAI {
    fn model(&self) -> &str {
        &self.model
    }

    fn provider(&self) -> &str {
        "openai"
    }

    fn context_window(&self) -> Option<u64> {
        Some(self.context_window)
    }

    async fn invoke(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatCompletion, LlmError> {
        let request = self.build_request(messages, tools, tool_choice, false)?;

        let response = self
            .client
            .post(self.api_url())
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "OpenAI API error ({}): {}",
                status, body
            )));
        }

        let completion: OpenAIResponse = response.json().await?;
        Ok(self.parse_response(completion))
    }

    async fn invoke_stream(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatStream, LlmError> {
        let request = self.build_request(messages, tools, tool_choice, true)?;

        let response = self
            .client
            .post(self.api_url())
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "OpenAI API error ({}): {}",
                status, body
            )));
        }

        // Each network chunk is parsed independently, so an SSE event split
        // across two chunks is dropped rather than reassembled; see
        // `parse_stream_chunk` for details.
        let stream = response.bytes_stream().filter_map(|result| async move {
            match result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);
                    Self::parse_stream_chunk(&text)
                }
                Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
            }
        });

        Ok(Box::pin(stream))
    }

    fn supports_vision(&self) -> bool {
        let model_lower = self.model.to_lowercase();
        model_lower.contains("gpt-4o")
            || model_lower.contains("gpt-4-turbo")
            || model_lower.contains("gpt-4-vision")
            || model_lower.contains("gpt-4.1")
    }
}
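
// Consuming the stream (sketch; assumes `futures::StreamExt` is imported at
// the call site, and relies on `ChatStream` yielding
// `Result<ChatCompletion, LlmError>` items as produced above):
//
//     let mut stream = model.invoke_stream(messages, None, None).await?;
//     while let Some(delta) = stream.next().await {
//         if let Some(text) = delta?.content {
//             print!("{text}");
//         }
//     }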

// =============================================================================
// Request/Response Types
// =============================================================================

#[derive(Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAITool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

#[derive(Serialize)]
struct OpenAIMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<crate::llm::ToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

#[derive(Serialize)]
struct OpenAITool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunction,
}

#[derive(Serialize)]
struct OpenAIFunction {
    name: String,
    description: String,
    parameters: serde_json::Map<String, serde_json::Value>,
    strict: bool,
}

#[derive(Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
    usage: Option<OpenAIUsage>,
}

#[derive(Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct OpenAIMessageResponse {
    content: Option<String>,
    tool_calls: Option<Vec<crate::llm::ToolCall>>,
    /// Returned by some OpenAI-compatible backends for reasoning models;
    /// standard OpenAI responses omit it
    reasoning_content: Option<String>,
}

#[derive(Deserialize)]
struct OpenAIUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
    #[serde(default)]
    prompt_tokens_details: Option<OpenAIPromptTokenDetails>,
}

#[derive(Deserialize, Default)]
struct OpenAIPromptTokenDetails {
    cached_tokens: u64,
}
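
// For reference, a non-streaming request with one tool serializes roughly as
// below; fields marked `skip_serializing_if` disappear when `None`, and the
// model, message, and tool values are illustrative:
//
//     {
//       "model": "gpt-4o",
//       "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
//       "tools": [{
//         "type": "function",
//         "function": {
//           "name": "get_weather",
//           "description": "Look up current weather",
//           "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
//           "strict": true
//         }
//       }],
//       "tool_choice": "auto",
//       "temperature": 0.2,
//       "max_completion_tokens": 4096
//     }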

impl ChatOpenAI {
    fn build_request(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
        stream: bool,
    ) -> Result<OpenAIRequest, LlmError> {
        let openai_messages: Vec<OpenAIMessage> =
            messages.into_iter().map(Self::convert_message).collect();

        let openai_tools = tools.map(|ts| {
            ts.into_iter()
                .map(|t| OpenAITool {
                    tool_type: "function".to_string(),
                    function: OpenAIFunction {
                        name: t.name,
                        description: t.description,
                        parameters: t.parameters,
                        strict: t.strict,
                    },
                })
                .collect()
        });

        let tool_choice_value = tool_choice.map(|tc| match tc {
            ToolChoice::Auto => serde_json::json!("auto"),
            ToolChoice::Required => serde_json::json!("required"),
            ToolChoice::None => serde_json::json!("none"),
            ToolChoice::Named(name) => {
                serde_json::json!({"type": "function", "function": {"name": name}})
            }
        });

        // Reasoning models reject the `temperature` parameter; omit it for them
        let temperature = if self.is_reasoning_model() {
            None
        } else {
            Some(self.temperature)
        };

        // Conversely, `reasoning_effort` is only accepted by reasoning models
        let reasoning_effort = if self.is_reasoning_model() {
            Some(self.reasoning_effort.clone())
        } else {
            None
        };

        Ok(OpenAIRequest {
            model: self.model.clone(),
            messages: openai_messages,
            tools: openai_tools,
            tool_choice: tool_choice_value,
            temperature,
            max_completion_tokens: self.max_completion_tokens,
            reasoning_effort,
            stream: if stream { Some(true) } else { None },
        })
    }

    fn convert_message(message: Message) -> OpenAIMessage {
        match message {
            Message::User(u) => {
                // Collapse a single text part to a plain string; otherwise
                // send the structured content array
                let content = if u.content.len() == 1 && u.content[0].is_text() {
                    serde_json::json!(u.content[0].as_text().unwrap())
                } else {
                    serde_json::json!(u.content)
                };
                OpenAIMessage {
                    role: "user".to_string(),
                    content: Some(content),
                    name: u.name,
                    tool_calls: None,
                    tool_call_id: None,
                }
            }
            Message::Assistant(a) => OpenAIMessage {
                role: "assistant".to_string(),
                content: a.content.map(|c| serde_json::json!(c)),
                name: None,
                tool_calls: if a.tool_calls.is_empty() {
                    None
                } else {
                    Some(a.tool_calls)
                },
                tool_call_id: None,
            },
            Message::System(s) => OpenAIMessage {
                role: "system".to_string(),
                content: Some(serde_json::json!(s.content)),
                name: None,
                tool_calls: None,
                tool_call_id: None,
            },
            Message::Developer(d) => OpenAIMessage {
                role: "developer".to_string(),
                content: Some(serde_json::json!(d.content)),
                name: None,
                tool_calls: None,
                tool_call_id: None,
            },
            Message::Tool(t) => OpenAIMessage {
                role: "tool".to_string(),
                content: Some(serde_json::json!(t.content)),
                name: None,
                tool_calls: None,
                tool_call_id: Some(t.tool_call_id),
            },
        }
    }
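
    // For example, a tool-result message converts to (values illustrative):
    //
    //     {"role": "tool", "content": "18C, clear", "tool_call_id": "call_abc123"}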

    fn parse_response(&self, response: OpenAIResponse) -> ChatCompletion {
        let stop_reason = response
            .choices
            .first()
            .and_then(|c| c.finish_reason.as_ref())
            .and_then(|r| match r.as_str() {
                "stop" => Some(StopReason::EndTurn),
                "tool_calls" => Some(StopReason::ToolUse),
                "length" => Some(StopReason::MaxTokens),
                _ => None,
            });

        let choice = response.choices.into_iter().next();

        let (content, tool_calls) = choice
            .map(|c| {
                // Fall back to `reasoning_content` when `content` is absent
                // (some OpenAI-compatible backends put the answer there)
                let reasoning = c.message.reasoning_content;
                let content = c.message.content.or(reasoning);
                (content, c.message.tool_calls.unwrap_or_default())
            })
            .unwrap_or((None, Vec::new()));

        let usage = response.usage.map(|u| Usage {
            prompt_tokens: u.prompt_tokens,
            completion_tokens: u.completion_tokens,
            total_tokens: u.total_tokens,
            prompt_cached_tokens: u.prompt_tokens_details.map(|d| d.cached_tokens),
            ..Default::default()
        });

        ChatCompletion {
            content,
            thinking: None,
            redacted_thinking: None,
            tool_calls,
            usage,
            stop_reason,
        }
    }

    /// Parse one raw SSE chunk into at most one merged completion delta.
    ///
    /// A single chunk may carry several `data:` events; their content deltas
    /// are concatenated and their tool calls collected. Limitations: events
    /// split across chunk boundaries are not reassembled, and tool-call
    /// fragments lacking an id/name (incremental argument deltas) are skipped.
    fn parse_stream_chunk(text: &str) -> Option<Result<ChatCompletion, LlmError>> {
        let mut content_acc = String::new();
        let mut tool_calls: Vec<crate::llm::ToolCall> = Vec::new();

        for line in text.lines() {
            let Some(data) = line.trim().strip_prefix("data:") else {
                continue;
            };
            let data = data.trim();
            if data == "[DONE]" {
                break;
            }

            // Silently skip unparseable payloads (typically partial events)
            let chunk: serde_json::Value = match serde_json::from_str(data) {
                Ok(v) => v,
                Err(_) => continue,
            };

            let Some(delta) = chunk
                .get("choices")
                .and_then(|c| c.as_array())
                .and_then(|a| a.first())
                .and_then(|c| c.get("delta"))
            else {
                continue;
            };

            if let Some(s) = delta.get("content").and_then(|c| c.as_str()) {
                content_acc.push_str(s);
            }

            if let Some(arr) = delta.get("tool_calls").and_then(|tc| tc.as_array()) {
                tool_calls.extend(arr.iter().filter_map(|tc| {
                    let id = tc.get("id")?.as_str()?.to_string();
                    let func = tc.get("function")?;
                    let name = func.get("name")?.as_str()?.to_string();
                    let arguments = func.get("arguments")?.as_str()?.to_string();
                    Some(crate::llm::ToolCall::new(id, name, arguments))
                }));
            }
        }

        if content_acc.is_empty() && tool_calls.is_empty() {
            return None;
        }

        Some(Ok(ChatCompletion {
            content: (!content_acc.is_empty()).then_some(content_acc),
            thinking: None,
            redacted_thinking: None,
            tool_calls,
            usage: None,
            stop_reason: None,
        }))
    }
}
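
#[cfg(test)]
mod tests {
    use super::*;

    // Sketch tests exercising only the pure helpers in this file; they make
    // no network calls.

    #[test]
    fn context_window_matches_known_models() {
        assert_eq!(ChatOpenAI::get_context_window("gpt-4o-mini"), 128_000);
        assert_eq!(ChatOpenAI::get_context_window("gpt-4"), 8_192);
        assert_eq!(ChatOpenAI::get_context_window("gpt-3.5-turbo"), 16_385);
        assert_eq!(ChatOpenAI::get_context_window("o3-mini"), 200_000);
    }

    #[test]
    fn reasoning_effort_serializes_lowercase() {
        let json = serde_json::to_string(&ReasoningEffort::Medium).unwrap();
        assert_eq!(json, "\"medium\"");
    }

    #[test]
    fn stream_chunk_merges_multiple_events() {
        let chunk = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hel\"}}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\"lo\"}}]}\n\n",
        );
        let parsed = ChatOpenAI::parse_stream_chunk(chunk)
            .expect("chunk should yield a delta")
            .expect("delta should be Ok");
        assert_eq!(parsed.content.as_deref(), Some("Hello"));
    }
}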