// agent_io/llm/openai/mod.rs
mod request;
mod response;
mod types;
6
7use async_trait::async_trait;
8use derive_builder::Builder;
9use futures::StreamExt;
10use reqwest::Client;
11use std::time::Duration;
12
13use crate::llm::{
14 BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, ToolChoice, ToolDefinition,
15};
16
17use types::*;
18
/// Default chat-completions endpoint; overridden per instance via `base_url`.
const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
20
/// Chat-completions client for the OpenAI API.
///
/// Built with `derive_builder`; `build_fn(skip)` means the hand-written
/// `ChatOpenAIBuilder::build` below performs validation and applies
/// defaults — the `#[builder(default = ...)]` attributes here are not
/// applied automatically and must be mirrored there.
#[derive(Builder, Clone)]
#[builder(pattern = "owned", build_fn(skip))]
pub struct ChatOpenAI {
    /// Model identifier sent with every request (e.g. "gpt-4o").
    #[builder(setter(into))]
    pub(super) model: String,
    /// Bearer token placed in the `Authorization` header.
    pub(super) api_key: String,
    /// Optional endpoint override; `None` falls back to the public
    /// OpenAI URL (see `api_url`).
    #[builder(setter(into, strip_option), default = "None")]
    pub(super) base_url: Option<String>,
    /// Sampling temperature (default 0.2, applied in the manual `build`).
    #[builder(default = "0.2")]
    pub(super) temperature: f32,
    /// Upper bound on generated tokens; `None` lets the API decide.
    #[builder(default = "Some(4096)")]
    pub(super) max_completion_tokens: Option<u64>,
    /// Effort hint, presumably sent only for reasoning (o-series / gpt-5)
    /// models — confirm against `build_request`.
    #[builder(default = "ReasoningEffort::Low")]
    pub(super) reasoning_effort: ReasoningEffort,
    /// Shared HTTP client; created in `build`, not settable by callers.
    #[builder(setter(skip))]
    pub(super) client: Client,
    /// Context-window size derived from the model name in `build`.
    #[builder(setter(skip))]
    pub(super) context_window: u64,
}
49
50impl ChatOpenAI {
51 pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
53 let api_key = std::env::var("OPENAI_API_KEY")
54 .map_err(|_| LlmError::Config("OPENAI_API_KEY not set".into()))?;
55
56 Self::builder().model(model).api_key(api_key).build()
57 }
58
59 pub fn builder() -> ChatOpenAIBuilder {
61 ChatOpenAIBuilder::default()
62 }
63
64 fn is_reasoning_model(&self) -> bool {
66 let model_lower = self.model.to_lowercase();
67 model_lower.starts_with("o1")
68 || model_lower.starts_with("o3")
69 || model_lower.starts_with("o4")
70 || model_lower.starts_with("gpt-5")
71 }
72
73 fn api_url(&self) -> &str {
75 self.base_url.as_deref().unwrap_or(OPENAI_API_URL)
76 }
77
78 fn build_client() -> Client {
80 Client::builder()
81 .timeout(Duration::from_secs(120))
82 .build()
83 .expect("Failed to create HTTP client")
84 }
85
86 fn get_context_window(model: &str) -> u64 {
88 let model_lower = model.to_lowercase();
89
90 if model_lower.contains("gpt-4o") || model_lower.contains("gpt-4-turbo") {
92 128_000
93 }
94 else if model_lower.starts_with("gpt-4") {
96 8_192
97 }
98 else if model_lower.starts_with("gpt-3.5") {
100 16_385
101 }
102 else if model_lower.starts_with("o1")
104 || model_lower.starts_with("o3")
105 || model_lower.starts_with("o4")
106 {
107 200_000
108 }
109 else {
111 128_000
112 }
113 }
114}
115
116impl ChatOpenAIBuilder {
117 pub fn build(&self) -> Result<ChatOpenAI, LlmError> {
118 let model = self
119 .model
120 .clone()
121 .ok_or_else(|| LlmError::Config("model is required".into()))?;
122 let api_key = self
123 .api_key
124 .clone()
125 .ok_or_else(|| LlmError::Config("api_key is required".into()))?;
126
127 Ok(ChatOpenAI {
128 context_window: ChatOpenAI::get_context_window(&model),
129 client: ChatOpenAI::build_client(),
130 model,
131 api_key,
132 base_url: self.base_url.clone().flatten(),
133 temperature: self.temperature.unwrap_or(0.2),
134 max_completion_tokens: self.max_completion_tokens.flatten(),
135 reasoning_effort: self.reasoning_effort.clone().unwrap_or_default(),
136 })
137 }
138}
139
#[async_trait]
impl BaseChatModel for ChatOpenAI {
    /// Model identifier this client was configured with.
    fn model(&self) -> &str {
        &self.model
    }

    /// Provider name: always "openai".
    fn provider(&self) -> &str {
        "openai"
    }

    /// Context-window size computed from the model name at build time.
    fn context_window(&self) -> Option<u64> {
        Some(self.context_window)
    }

    /// Sends one non-streaming chat-completion request and parses the reply.
    ///
    /// # Errors
    /// - `LlmError::Api` on a non-2xx status (response body included verbatim).
    /// - Transport and JSON-decoding failures propagate via `?`.
    async fn invoke(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatCompletion, LlmError> {
        // Final `false` flag requests a non-streaming completion.
        let request = self.build_request(messages, tools, tool_choice, false)?;

        let response = self
            .client
            .post(self.api_url())
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        // Surface API-level failures with the body text for diagnosis;
        // body read errors degrade to an empty string rather than masking
        // the original status.
        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "OpenAI API error ({}): {}",
                status, body
            )));
        }

        let completion: OpenAIResponse = response.json().await?;
        Ok(self.parse_response(completion))
    }

    /// Sends a streaming chat-completion request and returns a stream of
    /// parsed chunks.
    ///
    /// # Errors
    /// - `LlmError::Api` on a non-2xx status before streaming begins.
    /// - Mid-stream transport failures surface as `LlmError::Stream` items.
    async fn invoke_stream(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatStream, LlmError> {
        // Final `true` flag requests a streaming (SSE) completion.
        let request = self.build_request(messages, tools, tool_choice, true)?;

        let response = self
            .client
            .post(self.api_url())
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "OpenAI API error ({}): {}",
                status, body
            )));
        }

        // NOTE(review): each network chunk is parsed independently — an SSE
        // event split across two chunks reaches `parse_stream_chunk` in
        // halves. Confirm the parser buffers partial events, or add
        // carry-over buffering here.
        let stream = response.bytes_stream().filter_map(|result| async move {
            match result {
                Ok(bytes) => {
                    // Lossy conversion: a multi-byte UTF-8 sequence split at
                    // a chunk boundary becomes replacement characters.
                    let text = String::from_utf8_lossy(&bytes);
                    Self::parse_stream_chunk(&text)
                }
                Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
            }
        });

        Ok(Box::pin(stream))
    }

    /// Whether the configured model family accepts image inputs.
    fn supports_vision(&self) -> bool {
        let model_lower = self.model.to_lowercase();
        model_lower.contains("gpt-4o")
            || model_lower.contains("gpt-4-turbo")
            || model_lower.contains("gpt-4-vision")
            || model_lower.contains("gpt-4.1")
    }
}
231
232pub use types::ReasoningEffort;