// agent_io/llm/openai/mod.rs

1//! OpenAI Chat Model implementation
2
3mod request;
4mod response;
5mod types;
6
7use async_trait::async_trait;
8use derive_builder::Builder;
9use futures::StreamExt;
10use reqwest::Client;
11use std::time::Duration;
12
13use crate::llm::{
14    BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, ToolChoice, ToolDefinition,
15};
16
17use types::*;
18
19const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
20
/// OpenAI Chat Model
///
/// Configuration for the OpenAI Chat Completions API (or an API-compatible
/// proxy via `base_url`). Construct with [`ChatOpenAI::new`] (reads the key
/// from `OPENAI_API_KEY`) or [`ChatOpenAI::builder`]. Note `build_fn(skip)`:
/// the builder's `build` is hand-written, so the `default = "..."` attributes
/// below document intent but are applied by that hand-written `build`.
#[derive(Builder, Clone)]
#[builder(pattern = "owned", build_fn(skip))]
pub struct ChatOpenAI {
    /// Model identifier, e.g. "gpt-4o"; setter accepts any `impl Into<String>`.
    #[builder(setter(into))]
    pub(super) model: String,
    /// API key sent as the `Authorization: Bearer` header.
    pub(super) api_key: String,
    /// Base URL override (for proxies); `None` uses the official endpoint.
    #[builder(setter(into, strip_option), default = "None")]
    pub(super) base_url: Option<String>,
    /// Temperature for sampling (default 0.2).
    #[builder(default = "0.2")]
    pub(super) temperature: f32,
    /// Maximum completion tokens; `None` leaves the limit to the API.
    #[builder(default = "Some(4096)")]
    pub(super) max_completion_tokens: Option<u64>,
    /// Reasoning effort for o1+ models (default Low).
    #[builder(default = "ReasoningEffort::Low")]
    pub(super) reasoning_effort: ReasoningEffort,
    /// HTTP client; created by `build` with a 120-second timeout, never set
    /// by callers (`setter(skip)`).
    #[builder(setter(skip))]
    pub(super) client: Client,
    /// Context window size in tokens, derived from the model name in `build`.
    #[builder(setter(skip))]
    pub(super) context_window: u64,
}
49
50impl ChatOpenAI {
51    /// Create a new OpenAI chat model
52    pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
53        let api_key = std::env::var("OPENAI_API_KEY")
54            .map_err(|_| LlmError::Config("OPENAI_API_KEY not set".into()))?;
55
56        Self::builder().model(model).api_key(api_key).build()
57    }
58
59    /// Create a builder for configuration
60    pub fn builder() -> ChatOpenAIBuilder {
61        ChatOpenAIBuilder::default()
62    }
63
64    /// Check if this is a reasoning model (o1, o3, o4, gpt-5)
65    fn is_reasoning_model(&self) -> bool {
66        let model_lower = self.model.to_lowercase();
67        model_lower.starts_with("o1")
68            || model_lower.starts_with("o3")
69            || model_lower.starts_with("o4")
70            || model_lower.starts_with("gpt-5")
71    }
72
73    /// Get the API URL
74    fn api_url(&self) -> &str {
75        self.base_url.as_deref().unwrap_or(OPENAI_API_URL)
76    }
77
78    /// Build the HTTP client
79    fn build_client() -> Client {
80        Client::builder()
81            .timeout(Duration::from_secs(120))
82            .build()
83            .expect("Failed to create HTTP client")
84    }
85
86    /// Get context window for model
87    fn get_context_window(model: &str) -> u64 {
88        let model_lower = model.to_lowercase();
89
90        // GPT-4o family
91        if model_lower.contains("gpt-4o") || model_lower.contains("gpt-4-turbo") {
92            128_000
93        }
94        // GPT-4
95        else if model_lower.starts_with("gpt-4") {
96            8_192
97        }
98        // GPT-3.5
99        else if model_lower.starts_with("gpt-3.5") {
100            16_385
101        }
102        // O1/O3/O4 reasoning models
103        else if model_lower.starts_with("o1")
104            || model_lower.starts_with("o3")
105            || model_lower.starts_with("o4")
106        {
107            200_000
108        }
109        // Default
110        else {
111            128_000
112        }
113    }
114}
115
116impl ChatOpenAIBuilder {
117    pub fn build(&self) -> Result<ChatOpenAI, LlmError> {
118        let model = self
119            .model
120            .clone()
121            .ok_or_else(|| LlmError::Config("model is required".into()))?;
122        let api_key = self
123            .api_key
124            .clone()
125            .ok_or_else(|| LlmError::Config("api_key is required".into()))?;
126
127        Ok(ChatOpenAI {
128            context_window: ChatOpenAI::get_context_window(&model),
129            client: ChatOpenAI::build_client(),
130            model,
131            api_key,
132            base_url: self.base_url.clone().flatten(),
133            temperature: self.temperature.unwrap_or(0.2),
134            max_completion_tokens: self.max_completion_tokens.flatten(),
135            reasoning_effort: self.reasoning_effort.clone().unwrap_or_default(),
136        })
137    }
138}
139
#[async_trait]
impl BaseChatModel for ChatOpenAI {
    /// Configured model identifier.
    fn model(&self) -> &str {
        &self.model
    }

    /// Provider name; always "openai".
    fn provider(&self) -> &str {
        "openai"
    }

    /// Context window derived from the model name at build time.
    fn context_window(&self) -> Option<u64> {
        Some(self.context_window)
    }

    /// Send a non-streaming chat completion request.
    ///
    /// Builds the request payload (via `build_request` with `stream =
    /// false`), POSTs it with bearer auth, and parses the JSON body into a
    /// [`ChatCompletion`] via `parse_response`.
    ///
    /// # Errors
    /// Propagates transport errors from reqwest; non-2xx responses become
    /// [`LlmError::Api`] carrying the status and the raw response body.
    async fn invoke(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatCompletion, LlmError> {
        let request = self.build_request(messages, tools, tool_choice, false)?;

        let response = self
            .client
            .post(self.api_url())
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            // Body read is best-effort; an empty string is used if it fails.
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "OpenAI API error ({}): {}",
                status, body
            )));
        }

        let completion: OpenAIResponse = response.json().await?;
        Ok(self.parse_response(completion))
    }

    /// Send a streaming chat completion request.
    ///
    /// Same request path as [`invoke`] but with `stream = true`; each raw
    /// network chunk is decoded lossily to UTF-8 and handed to
    /// `parse_stream_chunk`, and transport errors surface as
    /// [`LlmError::Stream`] items in the returned stream.
    ///
    /// NOTE(review): chunks from `bytes_stream` are network-sized, not
    /// SSE-event-sized — an event split across two chunks (or a multi-byte
    /// UTF-8 character cut at a boundary) reaches `parse_stream_chunk` in
    /// pieces. Confirm that `parse_stream_chunk` buffers partial input;
    /// otherwise such events may be dropped or misparsed.
    ///
    /// # Errors
    /// Returns an error immediately for transport failures or non-2xx
    /// responses (same [`LlmError::Api`] shape as `invoke`).
    async fn invoke_stream(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatStream, LlmError> {
        let request = self.build_request(messages, tools, tool_choice, true)?;

        let response = self
            .client
            .post(self.api_url())
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "OpenAI API error ({}): {}",
                status, body
            )));
        }

        // filter_map lets parse_stream_chunk swallow keep-alive/empty chunks
        // by returning None while still yielding parsed items and errors.
        let stream = response.bytes_stream().filter_map(|result| async move {
            match result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);
                    Self::parse_stream_chunk(&text)
                }
                Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
            }
        });

        Ok(Box::pin(stream))
    }

    /// Whether the configured model accepts image inputs, judged by name
    /// (gpt-4o / gpt-4-turbo / gpt-4-vision / gpt-4.1 substrings).
    fn supports_vision(&self) -> bool {
        let model_lower = self.model.to_lowercase();
        model_lower.contains("gpt-4o")
            || model_lower.contains("gpt-4-turbo")
            || model_lower.contains("gpt-4-vision")
            || model_lower.contains("gpt-4.1")
    }
}
231
232// Re-export ReasoningEffort for public API
233pub use types::ReasoningEffort;