agent_io/llm/openai/
mod.rs1mod request;
4mod response;
5mod types;
6
7use async_trait::async_trait;
8use derive_builder::Builder;
9use futures::StreamExt;
10use reqwest::Client;
11use std::time::Duration;
12
13use crate::llm::{
14 BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, ToolChoice, ToolDefinition,
15};
16
17use types::*;
18
19const OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
20const CHAT_COMPLETIONS_PATH: &str = "/chat/completions";
21
/// Chat-completions client for the OpenAI API.
///
/// Construct via `ChatOpenAI::new` (reads `OPENAI_API_KEY` and the optional
/// `OPENAI_BASE_URL` from the environment) or via `ChatOpenAI::builder()`.
#[derive(Builder, Clone)]
#[builder(pattern = "owned", build_fn(skip))] // build() is hand-written on ChatOpenAIBuilder below
pub struct ChatOpenAI {
    /// Model identifier sent in the request body, e.g. "gpt-4o".
    #[builder(setter(into))]
    pub(super) model: String,
    /// Bearer token for the `Authorization` header.
    pub(super) api_key: String,
    /// Optional API-root override; `None` means the public OpenAI endpoint.
    #[builder(setter(into, strip_option), default = "None")]
    pub(super) base_url: Option<String>,
    /// Sampling temperature (builder default 0.2).
    #[builder(default = "0.2")]
    pub(super) temperature: f32,
    /// Completion-token cap; `None` lets the API decide (builder default 4096).
    #[builder(default = "Some(4096)")]
    pub(super) max_completion_tokens: Option<u64>,
    /// Shared HTTP client; created inside the hand-written build(), not settable.
    #[builder(setter(skip))]
    pub(super) client: Client,
    /// Context-window size derived from the model name at build time.
    #[builder(setter(skip))]
    pub(super) context_window: u64,
}
47
48impl ChatOpenAI {
49 pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
51 let api_key = std::env::var("OPENAI_API_KEY")
52 .map_err(|_| LlmError::Config("OPENAI_API_KEY not set".into()))?;
53 let base_url = std::env::var("OPENAI_BASE_URL").ok();
54
55 let mut builder = Self::builder().model(model).api_key(api_key);
56 if let Some(url) = base_url {
57 builder = builder.base_url(url);
58 }
59 builder.build()
60 }
61
62 pub fn builder() -> ChatOpenAIBuilder {
64 ChatOpenAIBuilder::default()
65 }
66
67 fn is_reasoning_model(&self) -> bool {
69 let model_lower = self.model.to_lowercase();
70 model_lower.starts_with("o1")
71 || model_lower.starts_with("o3")
72 || model_lower.starts_with("o4")
73 || model_lower.starts_with("gpt-5")
74 }
75
76 fn api_url(&self) -> String {
78 let base = self.base_url.as_deref().unwrap_or(OPENAI_BASE_URL);
79 format!("{}{}", base.trim_end_matches('/'), CHAT_COMPLETIONS_PATH)
80 }
81
82 fn build_client() -> Client {
84 Client::builder()
85 .timeout(Duration::from_secs(120))
86 .build()
87 .expect("Failed to create HTTP client")
88 }
89
90 fn get_context_window(model: &str) -> u64 {
92 let model_lower = model.to_lowercase();
93
94 if model_lower.contains("gpt-4o") || model_lower.contains("gpt-4-turbo") {
96 128_000
97 }
98 else if model_lower.starts_with("gpt-4") {
100 8_192
101 }
102 else if model_lower.starts_with("gpt-3.5") {
104 16_385
105 }
106 else if model_lower.starts_with("o1")
108 || model_lower.starts_with("o3")
109 || model_lower.starts_with("o4")
110 {
111 200_000
112 }
113 else {
115 128_000
116 }
117 }
118}
119
120impl ChatOpenAIBuilder {
121 pub fn build(&self) -> Result<ChatOpenAI, LlmError> {
122 let model = self
123 .model
124 .clone()
125 .ok_or_else(|| LlmError::Config("model is required".into()))?;
126 let api_key = self
127 .api_key
128 .clone()
129 .ok_or_else(|| LlmError::Config("api_key is required".into()))?;
130
131 Ok(ChatOpenAI {
132 context_window: ChatOpenAI::get_context_window(&model),
133 client: ChatOpenAI::build_client(),
134 model,
135 api_key,
136 base_url: self.base_url.clone().flatten(),
137 temperature: self.temperature.unwrap_or(0.2),
138 max_completion_tokens: self.max_completion_tokens.flatten(),
139 })
140 }
141}
142
#[async_trait]
impl BaseChatModel for ChatOpenAI {
    fn model(&self) -> &str {
        &self.model
    }

    fn provider(&self) -> &str {
        "openai"
    }

    fn context_window(&self) -> Option<u64> {
        // Precomputed from the model name when the struct was built.
        Some(self.context_window)
    }

    // Non-streaming chat completion. HTTP 429 maps to LlmError::RateLimit;
    // any other non-success status becomes LlmError::Api carrying the body.
    async fn invoke(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatCompletion, LlmError> {
        // `false` = do not request server-sent events.
        let request = self.build_request(messages, tools, tool_choice, false)?;

        let response = self
            .client
            .post(self.api_url())
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            if status.as_u16() == 429 {
                return Err(LlmError::RateLimit);
            }
            // Best-effort body read for the error message; a failed read
            // yields an empty string rather than masking the status error.
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "OpenAI API error ({}): {}",
                status, body
            )));
        }
        let body = response.text().await?;
        tracing::debug!("OpenAI raw response: {}", body);

        // Presumably guards against servers/proxies that reply with SSE
        // frames even though stream=false was requested — TODO confirm;
        // the frames are then collapsed into a single completion.
        if body.trim_start().starts_with("data:") {
            return self.parse_sse_as_completion(&body);
        }

        let completion: OpenAIResponse = serde_json::from_str(&body).map_err(|e| {
            // Echo at most 500 bytes of the body to keep errors readable.
            // NOTE(review): byte-slicing at 500 can panic if that index
            // falls inside a multi-byte UTF-8 character — confirm/guard.
            LlmError::Api(format!(
                "Failed to parse response: {}\nBody: {}",
                e,
                &body[..body.len().min(500)]
            ))
        })?;
        Ok(self.parse_response(completion))
    }

    // Streaming chat completion: returns a pinned stream of parsed events.
    // Error statuses are handled identically to invoke().
    async fn invoke_stream(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatStream, LlmError> {
        // `true` = request server-sent events from the API.
        let request = self.build_request(messages, tools, tool_choice, true)?;

        let response = self
            .client
            .post(self.api_url())
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            if status.as_u16() == 429 {
                return Err(LlmError::RateLimit);
            }
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "OpenAI API error ({}): {}",
                status, body
            )));
        }

        // Each network chunk is decoded lossily (invalid UTF-8 cannot panic)
        // and handed to parse_stream_chunk; chunks yielding no event are
        // dropped by filter_map. Transport errors surface as LlmError::Stream.
        // NOTE(review): an SSE event split across two network chunks is
        // parsed piecewise here — confirm parse_stream_chunk tolerates that
        // (otherwise a cross-chunk buffer is needed).
        let stream = response.bytes_stream().filter_map(|result| async move {
            match result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);
                    Self::parse_stream_chunk(&text)
                }
                Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
            }
        });

        Ok(Box::pin(stream))
    }

    // Vision support decided by model-name heuristic only.
    fn supports_vision(&self) -> bool {
        let model_lower = self.model.to_lowercase();
        model_lower.contains("gpt-4o")
            || model_lower.contains("gpt-4-turbo")
            || model_lower.contains("gpt-4-vision")
            || model_lower.contains("gpt-4.1")
    }
}