//! OpenAI provider for `llm_stack` (`llm_stack_openai/provider.rs`).

use std::collections::HashSet;
4
5use llm_stack::ChatResponse;
6use llm_stack::error::LlmError;
7use llm_stack::provider::{Capability, ChatParams, Provider, ProviderMetadata};
8use llm_stack::stream::ChatStream;
9use reqwest::header::{HeaderMap, HeaderValue};
10use tracing::instrument;
11
12use crate::config::OpenAiConfig;
13use crate::convert;
14
/// Provider backed by the OpenAI chat-completions HTTP API.
#[derive(Debug)]
pub struct OpenAiProvider {
    // Connection settings: API key, base URL, model, optional org/timeout.
    config: OpenAiConfig,
    // HTTP client; either supplied through the config or built in `new`.
    client: reqwest::Client,
}
44
45impl OpenAiProvider {
46 pub fn new(config: OpenAiConfig) -> Self {
51 let client = config.client.clone().unwrap_or_else(|| {
52 let mut builder = reqwest::Client::builder();
53 if let Some(timeout) = config.timeout {
54 builder = builder.timeout(timeout);
55 }
56 builder.build().expect("failed to build HTTP client")
57 });
58 Self { config, client }
59 }
60
61 fn default_headers(&self) -> Result<HeaderMap, LlmError> {
63 let mut headers = HeaderMap::new();
64
65 let auth_value = format!("Bearer {}", self.config.api_key);
66 headers.insert(
67 "authorization",
68 HeaderValue::from_str(&auth_value)
69 .map_err(|_| LlmError::Auth("API key contains invalid header characters".into()))?,
70 );
71 headers.insert("content-type", HeaderValue::from_static("application/json"));
72
73 if let Some(org) = &self.config.organization {
74 headers.insert(
75 "openai-organization",
76 HeaderValue::from_str(org).map_err(|_| {
77 LlmError::InvalidRequest(
78 "Organization ID contains invalid header characters".into(),
79 )
80 })?,
81 );
82 }
83
84 Ok(headers)
85 }
86
87 fn completions_url(&self) -> String {
89 let base = self.config.base_url.trim_end_matches('/');
90 format!("{base}/chat/completions")
91 }
92
93 async fn send_request(
95 &self,
96 params: &ChatParams,
97 stream: bool,
98 ) -> Result<reqwest::Response, LlmError> {
99 let request_body = convert::build_request(params, &self.config, stream)?;
100
101 let mut headers = self.default_headers()?;
102 if let Some(extra) = ¶ms.extra_headers {
103 headers.extend(extra.iter().map(|(k, v)| (k.clone(), v.clone())));
104 }
105
106 let mut req = self
107 .client
108 .post(self.completions_url())
109 .headers(headers)
110 .json(&request_body);
111
112 if let Some(timeout) = params.timeout {
113 req = req.timeout(timeout);
114 }
115
116 let response = req.send().await.map_err(|e| {
117 if e.is_timeout() {
118 LlmError::Timeout {
119 elapsed_ms: params
120 .timeout
121 .or(self.config.timeout)
122 .map_or(0, |d| u64::try_from(d.as_millis()).unwrap_or(u64::MAX)),
123 }
124 } else {
125 LlmError::Http {
126 status: e.status().map(|s| {
127 http::StatusCode::from_u16(s.as_u16())
128 .unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR)
129 }),
130 message: e.to_string(),
131 retryable: e.is_connect() || e.is_timeout(),
132 }
133 }
134 })?;
135
136 let status = response.status();
137 if !status.is_success() {
138 let body = response.text().await.unwrap_or_default();
139 let http_status = http::StatusCode::from_u16(status.as_u16())
140 .unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR);
141 return Err(convert::convert_error(http_status, &body));
142 }
143
144 Ok(response)
145 }
146}
147
148impl Provider for OpenAiProvider {
149 #[instrument(skip_all, fields(model = %self.config.model))]
150 async fn generate(&self, params: &ChatParams) -> Result<ChatResponse, LlmError> {
151 let response = self.send_request(params, false).await?;
152
153 let body = response
154 .text()
155 .await
156 .map_err(|e| LlmError::ResponseFormat {
157 message: format!("Failed to read OpenAI response body: {e}"),
158 raw: String::new(),
159 })?;
160
161 let api_response: crate::types::Response =
162 serde_json::from_str(&body).map_err(|e| LlmError::ResponseFormat {
163 message: format!("Failed to parse OpenAI response: {e}"),
164 raw: body,
165 })?;
166
167 Ok(convert::convert_response(api_response))
168 }
169
170 #[instrument(skip_all, fields(model = %self.config.model))]
171 async fn stream(&self, params: &ChatParams) -> Result<ChatStream, LlmError> {
172 let response = self.send_request(params, true).await?;
173 Ok(crate::stream::into_stream(response))
174 }
175
176 fn metadata(&self) -> ProviderMetadata {
177 let mut capabilities = HashSet::new();
178 capabilities.insert(Capability::Tools);
179 capabilities.insert(Capability::Vision);
180 capabilities.insert(Capability::StructuredOutput);
181
182 if self.config.model.starts_with("o1")
184 || self.config.model.starts_with("o3")
185 || self.config.model.starts_with("o4")
186 {
187 capabilities.insert(Capability::Reasoning);
188 }
189
190 ProviderMetadata {
191 name: "openai".into(),
192 model: self.config.model.clone(),
193 context_window: context_window_for_model(&self.config.model),
194 capabilities,
195 }
196 }
197}
198
/// Returns the context-window size (in tokens) for `model`, matched by
/// prefix. Longer prefixes come first so `gpt-4o`/`gpt-4.1` take precedence
/// over the generic `gpt-4` entry; unrecognized models default to 128k.
fn context_window_for_model(model: &str) -> u64 {
    // (prefix, window) pairs in match-priority order.
    const WINDOWS: &[(&str, u64)] = &[
        ("gpt-4o", 128_000),
        ("gpt-4.1", 128_000),
        ("o1", 200_000),
        ("o3", 200_000),
        ("o4", 200_000),
        ("gpt-4", 128_000),
        ("gpt-3.5", 16_385),
    ];

    WINDOWS
        .iter()
        .find(|(prefix, _)| model.starts_with(prefix))
        .map_or(128_000, |&(_, window)| window)
}
213
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    #[test]
    fn test_metadata() {
        let meta = OpenAiProvider::new(OpenAiConfig {
            model: "gpt-4o".into(),
            ..Default::default()
        })
        .metadata();

        assert_eq!(meta.name, "openai");
        assert_eq!(meta.model, "gpt-4o");
        assert_eq!(meta.context_window, 128_000);
        // Baseline capabilities are always present; reasoning is not.
        for cap in [Capability::Tools, Capability::Vision, Capability::StructuredOutput] {
            assert!(meta.capabilities.contains(&cap));
        }
        assert!(!meta.capabilities.contains(&Capability::Reasoning));
    }

    #[test]
    fn test_metadata_reasoning_model() {
        let meta = OpenAiProvider::new(OpenAiConfig {
            model: "o1-mini".into(),
            ..Default::default()
        })
        .metadata();

        assert!(meta.capabilities.contains(&Capability::Reasoning));
        assert_eq!(meta.context_window, 200_000);
    }

    #[test]
    fn test_context_window_gpt4o() {
        for model in ["gpt-4o", "gpt-4o-mini"] {
            assert_eq!(context_window_for_model(model), 128_000);
        }
    }

    #[test]
    fn test_context_window_gpt35() {
        assert_eq!(context_window_for_model("gpt-3.5-turbo"), 16_385);
    }

    #[test]
    fn test_context_window_unknown() {
        assert_eq!(context_window_for_model("some-future-model"), 128_000);
    }

    #[test]
    fn test_completions_url() {
        let p = OpenAiProvider::new(OpenAiConfig {
            base_url: "https://api.openai.com/v1".into(),
            ..Default::default()
        });
        assert_eq!(
            p.completions_url(),
            "https://api.openai.com/v1/chat/completions"
        );
    }

    #[test]
    fn test_completions_url_trailing_slash() {
        let p = OpenAiProvider::new(OpenAiConfig {
            base_url: "https://proxy.example.com/v1/".into(),
            ..Default::default()
        });
        assert_eq!(
            p.completions_url(),
            "https://proxy.example.com/v1/chat/completions"
        );
    }

    #[test]
    fn test_default_headers() {
        let headers = OpenAiProvider::new(OpenAiConfig {
            api_key: "sk-test123".into(),
            ..Default::default()
        })
        .default_headers()
        .unwrap();

        assert_eq!(headers.get("authorization").unwrap(), "Bearer sk-test123");
        assert_eq!(headers.get("content-type").unwrap(), "application/json");
    }

    #[test]
    fn test_default_headers_with_org() {
        let headers = OpenAiProvider::new(OpenAiConfig {
            api_key: "sk-test123".into(),
            organization: Some("org-abc".into()),
            ..Default::default()
        })
        .default_headers()
        .unwrap();

        assert_eq!(headers.get("openai-organization").unwrap(), "org-abc");
    }

    #[test]
    fn test_default_headers_invalid_key() {
        // A newline is never valid inside a header value.
        let result = OpenAiProvider::new(OpenAiConfig {
            api_key: "invalid\nkey".into(),
            ..Default::default()
        })
        .default_headers();

        assert!(matches!(result, Err(LlmError::Auth(_))));
    }

    #[test]
    fn test_new_with_custom_client() {
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .build()
            .unwrap();

        let p = OpenAiProvider::new(OpenAiConfig {
            client: Some(client),
            ..Default::default()
        });
        assert_eq!(p.metadata().name, "openai");
    }

    #[test]
    fn test_new_with_timeout() {
        let p = OpenAiProvider::new(OpenAiConfig {
            timeout: Some(Duration::from_secs(30)),
            ..Default::default()
        });
        assert_eq!(p.metadata().name, "openai");
    }
}