Skip to main content

agent_io/llm/openai/
mod.rs

1//! OpenAI Chat Model implementation
2
3mod request;
4mod response;
5mod types;
6
7use async_trait::async_trait;
8use derive_builder::Builder;
9use futures::StreamExt;
10use reqwest::Client;
11use std::time::Duration;
12
13use crate::llm::{
14    BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, ToolChoice, ToolDefinition,
15};
16
17use types::*;
18
19const OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
20const CHAT_COMPLETIONS_PATH: &str = "/chat/completions";
21
22/// OpenAI Chat Model
/// OpenAI Chat Model.
///
/// Construct from environment variables with [`ChatOpenAI::new`], or
/// configure manually via [`ChatOpenAI::builder`]. Because of
/// `build_fn(skip)`, the generated builder's `build` is hand-written
/// (see `impl ChatOpenAIBuilder`), which also fills the derived
/// `client` and `context_window` fields.
#[derive(Builder, Clone)]
#[builder(pattern = "owned", build_fn(skip))]
pub struct ChatOpenAI {
    /// Model identifier (e.g. "gpt-4o", "o1", "gpt-3.5-turbo").
    #[builder(setter(into))]
    pub(super) model: String,
    /// API key, sent as a `Bearer` token on every request.
    pub(super) api_key: String,
    /// Base URL override (for proxies); `None` uses the official endpoint.
    #[builder(setter(into, strip_option), default = "None")]
    pub(super) base_url: Option<String>,
    /// Temperature for sampling; the builder declares a default of 0.2.
    #[builder(default = "0.2")]
    pub(super) temperature: f32,
    /// Maximum completion tokens; the builder attribute declares a
    /// default of `Some(4096)`.
    #[builder(default = "Some(4096)")]
    pub(super) max_completion_tokens: Option<u64>,
    /// HTTP client; created during build, not settable by callers.
    #[builder(setter(skip))]
    pub(super) client: Client,
    /// Context window size in tokens, derived from the model name.
    #[builder(setter(skip))]
    pub(super) context_window: u64,
}
47
48impl ChatOpenAI {
49    /// Create a new OpenAI chat model
50    pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
51        let api_key = std::env::var("OPENAI_API_KEY")
52            .map_err(|_| LlmError::Config("OPENAI_API_KEY not set".into()))?;
53        let base_url = std::env::var("OPENAI_BASE_URL").ok();
54
55        let mut builder = Self::builder().model(model).api_key(api_key);
56        if let Some(url) = base_url {
57            builder = builder.base_url(url);
58        }
59        builder.build()
60    }
61
62    /// Create a builder for configuration
63    pub fn builder() -> ChatOpenAIBuilder {
64        ChatOpenAIBuilder::default()
65    }
66
67    /// Check if this is a reasoning model (o1, o3, o4, gpt-5)
68    fn is_reasoning_model(&self) -> bool {
69        let model_lower = self.model.to_lowercase();
70        model_lower.starts_with("o1")
71            || model_lower.starts_with("o3")
72            || model_lower.starts_with("o4")
73            || model_lower.starts_with("gpt-5")
74    }
75
76    /// Get the API URL
77    fn api_url(&self) -> String {
78        let base = self.base_url.as_deref().unwrap_or(OPENAI_BASE_URL);
79        format!("{}{}", base.trim_end_matches('/'), CHAT_COMPLETIONS_PATH)
80    }
81
82    /// Build the HTTP client
83    fn build_client() -> Client {
84        Client::builder()
85            .timeout(Duration::from_secs(120))
86            .build()
87            .expect("Failed to create HTTP client")
88    }
89
90    /// Get context window for model
91    fn get_context_window(model: &str) -> u64 {
92        let model_lower = model.to_lowercase();
93
94        // GPT-4o family
95        if model_lower.contains("gpt-4o") || model_lower.contains("gpt-4-turbo") {
96            128_000
97        }
98        // GPT-4
99        else if model_lower.starts_with("gpt-4") {
100            8_192
101        }
102        // GPT-3.5
103        else if model_lower.starts_with("gpt-3.5") {
104            16_385
105        }
106        // O1/O3/O4 reasoning models
107        else if model_lower.starts_with("o1")
108            || model_lower.starts_with("o3")
109            || model_lower.starts_with("o4")
110        {
111            200_000
112        }
113        // Default
114        else {
115            128_000
116        }
117    }
118}
119
120impl ChatOpenAIBuilder {
121    pub fn build(&self) -> Result<ChatOpenAI, LlmError> {
122        let model = self
123            .model
124            .clone()
125            .ok_or_else(|| LlmError::Config("model is required".into()))?;
126        let api_key = self
127            .api_key
128            .clone()
129            .ok_or_else(|| LlmError::Config("api_key is required".into()))?;
130
131        Ok(ChatOpenAI {
132            context_window: ChatOpenAI::get_context_window(&model),
133            client: ChatOpenAI::build_client(),
134            model,
135            api_key,
136            base_url: self.base_url.clone().flatten(),
137            temperature: self.temperature.unwrap_or(0.2),
138            max_completion_tokens: self.max_completion_tokens.flatten(),
139        })
140    }
141}
142
143#[async_trait]
144impl BaseChatModel for ChatOpenAI {
145    fn model(&self) -> &str {
146        &self.model
147    }
148
    /// Static provider identifier for this implementation.
    fn provider(&self) -> &str {
        "openai"
    }
152
    /// Context window size in tokens, derived from the model name at
    /// build time (always known for this provider, hence always `Some`).
    fn context_window(&self) -> Option<u64> {
        Some(self.context_window)
    }
156
157    async fn invoke(
158        &self,
159        messages: Vec<Message>,
160        tools: Option<Vec<ToolDefinition>>,
161        tool_choice: Option<ToolChoice>,
162    ) -> Result<ChatCompletion, LlmError> {
163        let request = self.build_request(messages, tools, tool_choice, false)?;
164
165        let response = self
166            .client
167            .post(self.api_url())
168            .header("Authorization", format!("Bearer {}", self.api_key))
169            .header("Content-Type", "application/json")
170            .json(&request)
171            .send()
172            .await?;
173
174        if !response.status().is_success() {
175            let status = response.status();
176            if status.as_u16() == 429 {
177                return Err(LlmError::RateLimit);
178            }
179            let body = response.text().await.unwrap_or_default();
180            return Err(LlmError::Api(format!(
181                "OpenAI API error ({}): {}",
182                status, body
183            )));
184        }
185        let body = response.text().await?;
186        tracing::debug!("OpenAI raw response: {}", body);
187
188        // Some proxies always return SSE format regardless of stream=false.
189        // Detect by checking if the body starts with "data:"
190        if body.trim_start().starts_with("data:") {
191            return self.parse_sse_as_completion(&body);
192        }
193
194        let completion: OpenAIResponse = serde_json::from_str(&body).map_err(|e| {
195            LlmError::Api(format!(
196                "Failed to parse response: {}\nBody: {}",
197                e,
198                &body[..body.len().min(500)]
199            ))
200        })?;
201        Ok(self.parse_response(completion))
202    }
203
    /// Send a streaming chat completion request and return the chunk stream.
    ///
    /// # Errors
    /// - [`LlmError::RateLimit`] on HTTP 429.
    /// - [`LlmError::Api`] on any other non-success status.
    /// - [`LlmError::Stream`] for transport errors while reading the body.
    async fn invoke_stream(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatStream, LlmError> {
        // stream=true: ask the API for server-sent-event chunks.
        let request = self.build_request(messages, tools, tool_choice, true)?;

        let response = self
            .client
            .post(self.api_url())
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            // Surface 429 as a dedicated variant so callers can back off.
            if status.as_u16() == 429 {
                return Err(LlmError::RateLimit);
            }
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "OpenAI API error ({}): {}",
                status, body
            )));
        }

        // Each network chunk is handed to `parse_stream_chunk` independently.
        // NOTE(review): an SSE event split across two network chunks would be
        // parsed in halves here — confirm `parse_stream_chunk` (in the
        // request/response submodules) buffers partial events, or events may
        // be dropped on chunk boundaries.
        let stream = response.bytes_stream().filter_map(|result| async move {
            match result {
                Ok(bytes) => {
                    // Lossy conversion: invalid UTF-8 bytes become U+FFFD.
                    let text = String::from_utf8_lossy(&bytes);
                    Self::parse_stream_chunk(&text)
                }
                Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
            }
        });

        Ok(Box::pin(stream))
    }
245
246    fn supports_vision(&self) -> bool {
247        let model_lower = self.model.to_lowercase();
248        model_lower.contains("gpt-4o")
249            || model_lower.contains("gpt-4-turbo")
250            || model_lower.contains("gpt-4-vision")
251            || model_lower.contains("gpt-4.1")
252    }
253}