llm_stack_openai/provider.rs

//! `OpenAI` `Provider` implementation.

use std::collections::HashSet;

use llm_stack::ChatResponse;
use llm_stack::error::LlmError;
use llm_stack::provider::{Capability, ChatParams, Provider, ProviderMetadata};
use llm_stack::stream::ChatStream;
use reqwest::header::{HeaderMap, HeaderValue};
use tracing::instrument;

use crate::config::OpenAiConfig;
use crate::convert;

/// `OpenAI` provider implementing [`Provider`].
///
/// Supports the `OpenAI` Chat Completions API with tool calling,
/// structured output, and streaming.
///
/// # Example
///
/// ```rust,no_run
/// use llm_stack_openai::{OpenAiConfig, OpenAiProvider};
/// use llm_stack::{ChatParams, ChatMessage, Provider};
///
/// # async fn example() -> Result<(), llm_stack::LlmError> {
/// let provider = OpenAiProvider::new(OpenAiConfig {
///     api_key: std::env::var("OPENAI_API_KEY").unwrap(),
///     ..Default::default()
/// });
///
/// let response = provider.generate(&ChatParams {
///     messages: vec![ChatMessage::user("Hello!")],
///     ..Default::default()
/// }).await?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct OpenAiProvider {
    config: OpenAiConfig,
    client: reqwest::Client,
}

impl OpenAiProvider {
    /// Create a new `OpenAI` provider from configuration.
    ///
    /// If `config.client` is `Some`, that client is reused for connection
    /// pooling. Otherwise a new client is built with the configured timeout.
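    ///
    /// # Example
    ///
    /// A minimal sketch of reusing one shared client across providers (only
    /// the fields shown are assumed; everything else takes its default):
    ///
    /// ```rust,no_run
    /// use llm_stack_openai::{OpenAiConfig, OpenAiProvider};
    ///
    /// let shared = reqwest::Client::new();
    /// let provider = OpenAiProvider::new(OpenAiConfig {
    ///     api_key: std::env::var("OPENAI_API_KEY").unwrap(),
    ///     client: Some(shared.clone()),
    ///     ..Default::default()
    /// });
    /// ```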
    pub fn new(config: OpenAiConfig) -> Self {
        let client = config.client.clone().unwrap_or_else(|| {
            let mut builder = reqwest::Client::builder();
            if let Some(timeout) = config.timeout {
                builder = builder.timeout(timeout);
            }
            builder.build().expect("failed to build HTTP client")
        });
        Self { config, client }
    }

    /// Build the default headers for `OpenAI` API requests.
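    ///
    /// Header names are written in lowercase; `HeaderMap` stores names
    /// case-insensitively, so lookups succeed with any casing.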
    fn default_headers(&self) -> Result<HeaderMap, LlmError> {
        let mut headers = HeaderMap::new();

        let auth_value = format!("Bearer {}", self.config.api_key);
        headers.insert(
            "authorization",
            HeaderValue::from_str(&auth_value)
                .map_err(|_| LlmError::Auth("API key contains invalid header characters".into()))?,
        );
        headers.insert("content-type", HeaderValue::from_static("application/json"));

        if let Some(org) = &self.config.organization {
            headers.insert(
                "openai-organization",
                HeaderValue::from_str(org).map_err(|_| {
                    LlmError::InvalidRequest(
                        "Organization ID contains invalid header characters".into(),
                    )
                })?,
            );
        }

        Ok(headers)
    }

    /// Build the full URL for the chat completions endpoint.
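    ///
    /// A trailing slash on `base_url` is trimmed first, so `.../v1` and
    /// `.../v1/` resolve to the same endpoint.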
    fn completions_url(&self) -> String {
        let base = self.config.base_url.trim_end_matches('/');
        format!("{base}/chat/completions")
    }

    /// Send a request to the `OpenAI` API and return the raw response.
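    ///
    /// Shared by [`Provider::generate`] and [`Provider::stream`]; the
    /// `stream` flag is forwarded to the request body builder.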
    async fn send_request(
        &self,
        params: &ChatParams,
        stream: bool,
    ) -> Result<reqwest::Response, LlmError> {
        let request_body = convert::build_request(params, &self.config, stream)?;

        let mut headers = self.default_headers()?;
        if let Some(extra) = &params.extra_headers {
            headers.extend(extra.iter().map(|(k, v)| (k.clone(), v.clone())));
        }

        let mut req = self
            .client
            .post(self.completions_url())
            .headers(headers)
            .json(&request_body);

        if let Some(timeout) = params.timeout {
            req = req.timeout(timeout);
        }

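        // Map transport failures: timeouts become `LlmError::Timeout` carrying the
        // effective deadline (a per-request timeout wins over the config default);
        // everything else becomes `LlmError::Http`, retryable for connect errors.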
        let response = req.send().await.map_err(|e| {
            if e.is_timeout() {
                LlmError::Timeout {
                    elapsed_ms: params
                        .timeout
                        .or(self.config.timeout)
                        .map_or(0, |d| u64::try_from(d.as_millis()).unwrap_or(u64::MAX)),
                }
            } else {
                LlmError::Http {
                    status: e.status().map(|s| {
                        http::StatusCode::from_u16(s.as_u16())
                            .unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR)
                    }),
                    message: e.to_string(),
                    retryable: e.is_connect() || e.is_timeout(),
                }
            }
        })?;

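        // Non-2xx responses: read the body so `convert_error` can surface the
        // API's error message alongside the status code.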
        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();
            let http_status = http::StatusCode::from_u16(status.as_u16())
                .unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR);
            return Err(convert::convert_error(http_status, &body));
        }

        Ok(response)
    }
}

impl Provider for OpenAiProvider {
    #[instrument(skip_all, fields(model = %self.config.model))]
    async fn generate(&self, params: &ChatParams) -> Result<ChatResponse, LlmError> {
        let response = self.send_request(params, false).await?;

        let body = response
            .text()
            .await
            .map_err(|e| LlmError::ResponseFormat {
                message: format!("Failed to read OpenAI response body: {e}"),
                raw: String::new(),
            })?;

        let api_response: crate::types::Response =
            serde_json::from_str(&body).map_err(|e| LlmError::ResponseFormat {
                message: format!("Failed to parse OpenAI response: {e}"),
                raw: body,
            })?;

        Ok(convert::convert_response(api_response))
    }

    #[instrument(skip_all, fields(model = %self.config.model))]
    async fn stream(&self, params: &ChatParams) -> Result<ChatStream, LlmError> {
        let response = self.send_request(params, true).await?;
        Ok(crate::stream::into_stream(response))
    }

    fn metadata(&self) -> ProviderMetadata {
        let mut capabilities = HashSet::new();
        capabilities.insert(Capability::Tools);
        capabilities.insert(Capability::Vision);
        capabilities.insert(Capability::StructuredOutput);

        // o-series models support reasoning
        if self.config.model.starts_with("o1")
            || self.config.model.starts_with("o3")
            || self.config.model.starts_with("o4")
        {
            capabilities.insert(Capability::Reasoning);
        }

        ProviderMetadata {
            name: "openai".into(),
            model: self.config.model.clone(),
            context_window: context_window_for_model(&self.config.model),
            capabilities,
        }
    }
}

/// Look up the context window size for known `OpenAI` models.
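///
/// Matching is by prefix so dated variants (e.g. `gpt-4o-2024-08-06`) are
/// covered; unknown models fall back to 128k.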
fn context_window_for_model(model: &str) -> u64 {
    if model.starts_with("gpt-4o") || model.starts_with("gpt-4.1") {
        128_000
    } else if model.starts_with("o1") || model.starts_with("o3") || model.starts_with("o4") {
        200_000
    } else if model.starts_with("gpt-4") {
        128_000
    } else if model.starts_with("gpt-3.5") {
        16_385
    } else {
        128_000
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    #[test]
    fn test_metadata() {
        let provider = OpenAiProvider::new(OpenAiConfig {
            model: "gpt-4o".into(),
            ..Default::default()
        });
        let meta = provider.metadata();

        assert_eq!(meta.name, "openai");
        assert_eq!(meta.model, "gpt-4o");
        assert_eq!(meta.context_window, 128_000);
        assert!(meta.capabilities.contains(&Capability::Tools));
        assert!(meta.capabilities.contains(&Capability::Vision));
        assert!(meta.capabilities.contains(&Capability::StructuredOutput));
        assert!(!meta.capabilities.contains(&Capability::Reasoning));
    }

    #[test]
    fn test_metadata_reasoning_model() {
        let provider = OpenAiProvider::new(OpenAiConfig {
            model: "o1-mini".into(),
            ..Default::default()
        });
        let meta = provider.metadata();

        assert!(meta.capabilities.contains(&Capability::Reasoning));
        assert_eq!(meta.context_window, 200_000);
    }

    #[test]
    fn test_context_window_gpt4o() {
        assert_eq!(context_window_for_model("gpt-4o"), 128_000);
        assert_eq!(context_window_for_model("gpt-4o-mini"), 128_000);
    }

    #[test]
    fn test_context_window_gpt35() {
        assert_eq!(context_window_for_model("gpt-3.5-turbo"), 16_385);
    }

    #[test]
    fn test_context_window_unknown() {
        assert_eq!(context_window_for_model("some-future-model"), 128_000);
    }

    #[test]
    fn test_completions_url() {
        let provider = OpenAiProvider::new(OpenAiConfig {
            base_url: "https://api.openai.com/v1".into(),
            ..Default::default()
        });
        assert_eq!(
            provider.completions_url(),
            "https://api.openai.com/v1/chat/completions"
        );
    }

    #[test]
    fn test_completions_url_trailing_slash() {
        let provider = OpenAiProvider::new(OpenAiConfig {
            base_url: "https://proxy.example.com/v1/".into(),
            ..Default::default()
        });
        assert_eq!(
            provider.completions_url(),
            "https://proxy.example.com/v1/chat/completions"
        );
    }

    #[test]
    fn test_default_headers() {
        let provider = OpenAiProvider::new(OpenAiConfig {
            api_key: "sk-test123".into(),
            ..Default::default()
        });
        let headers = provider.default_headers().unwrap();

        assert_eq!(headers.get("authorization").unwrap(), "Bearer sk-test123");
        assert_eq!(headers.get("content-type").unwrap(), "application/json");
    }

    #[test]
    fn test_default_headers_with_org() {
        let provider = OpenAiProvider::new(OpenAiConfig {
            api_key: "sk-test123".into(),
            organization: Some("org-abc".into()),
            ..Default::default()
        });
        let headers = provider.default_headers().unwrap();

        assert_eq!(headers.get("openai-organization").unwrap(), "org-abc");
    }

    #[test]
    fn test_default_headers_invalid_key() {
        let provider = OpenAiProvider::new(OpenAiConfig {
            api_key: "invalid\nkey".into(),
            ..Default::default()
        });
        let err = provider.default_headers().unwrap_err();
        assert!(matches!(err, LlmError::Auth(_)));
    }

    #[test]
    fn test_new_with_custom_client() {
        let custom_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .build()
            .unwrap();

        let provider = OpenAiProvider::new(OpenAiConfig {
            client: Some(custom_client),
            ..Default::default()
        });
        assert_eq!(provider.metadata().name, "openai");
    }

    #[test]
    fn test_new_with_timeout() {
        let provider = OpenAiProvider::new(OpenAiConfig {
            timeout: Some(Duration::from_secs(30)),
            ..Default::default()
        });
        assert_eq!(provider.metadata().name, "openai");
    }
}