1use std::env;
41use std::pin::Pin;
42
43use async_trait::async_trait;
44use futures::Stream;
45
46use super::error::BackendError;
47use super::openai_compat::{OpenAICompatConfig, OpenAICompatibleBackend};
48use super::{Backend, Capability, ChatRequest, ChatResponse, ChatStream};
49
/// Environment variable that [`OpenAIBackend::from_env`] reads the API key from.
///
/// The tests also expect [`BackendError::Auth`] to name this same variable in
/// `api_key_env` when a request is made without a key.
const API_KEY_ENV: &str = "OPENAI_API_KEY";
51
/// OpenAI chat backend.
///
/// A thin wrapper around [`OpenAICompatibleBackend`] built with
/// `OpenAICompatConfig::openai()`. Every [`Backend`] method is forwarded to it, except
/// that vision support is decided locally from the model name.
pub struct OpenAIBackend {
    /// Wrapped OpenAI-compatible implementation that does the actual work.
    inner: OpenAICompatibleBackend,
}
58
59impl OpenAIBackend {
60 pub fn from_env() -> Self {
63 let api_key = env::var(API_KEY_ENV).ok();
64 Self::with_api_key(api_key)
65 }
66
67 pub fn with_api_key(api_key: Option<String>) -> Self {
69 Self {
70 inner: OpenAICompatibleBackend::new(OpenAICompatConfig::openai(), api_key),
71 }
72 }
73
74 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
78 self.inner = self.inner.with_base_url(base_url);
79 self
80 }
81
82 pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
84 self.inner = self.inner.with_default_model(model);
85 self
86 }
87
88 pub fn inner(&self) -> &OpenAICompatibleBackend {
91 &self.inner
92 }
93}
94
95impl Default for OpenAIBackend {
96 fn default() -> Self {
97 Self::from_env()
98 }
99}
100
101#[async_trait]
102impl Backend for OpenAIBackend {
103 fn name(&self) -> &str {
104 self.inner.name()
105 }
106
107 fn default_model(&self) -> &str {
108 self.inner.default_model()
109 }
110
111 async fn complete(&self, request: ChatRequest) -> Result<ChatResponse, BackendError> {
112 self.inner.complete(request).await
113 }
114
115 async fn stream(&self, request: ChatRequest) -> Result<ChatStream, BackendError> {
116 self.inner.stream(request).await
117 }
118
119 fn count_tokens(&self, model: &str, text: &str) -> usize {
120 self.inner.count_tokens(model, text)
121 }
122
123 fn supports(&self, capability: Capability, model: &str) -> bool {
124 match capability {
125 Capability::Vision => model.to_lowercase().starts_with("gpt-4o"),
130 other => self.inner.supports(other, model),
134 }
135 }
136}
137
/// Module-level shorthand for [`OpenAIBackend::from_env`]: builds a backend whose API key
/// comes from the `OPENAI_API_KEY` environment variable.
pub fn from_env() -> OpenAIBackend {
    OpenAIBackend::from_env()
}
149
/// Module-level shorthand for [`OpenAIBackend::with_api_key`]: builds a backend with an
/// explicit API key (`None` means no key is configured).
pub fn with_api_key(api_key: Option<String>) -> OpenAIBackend {
    OpenAIBackend::with_api_key(api_key)
}
154
/// Boxed, `Send` stream of chat chunks (or backend errors) for this backend.
// The alias is not referenced anywhere in this file's visible code, hence
// `allow(dead_code)`.
// NOTE(review): presumably kept to document the stream item type; it is the only visible
// user of the `Pin` and `Stream` imports, so confirm whether all three can be removed
// together.
#[allow(dead_code)]
type OpenAIChatStream =
    Pin<Box<dyn Stream<Item = Result<crate::backends::ChatChunk, BackendError>> + Send>>;
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::Message;

    // --- construction ------------------------------------------------------------

    #[test]
    fn from_env_constructs_openai_backend() {
        // Construction must succeed whether or not OPENAI_API_KEY is set.
        let b = OpenAIBackend::from_env();
        assert_eq!(b.name(), "openai");
        assert_eq!(b.default_model(), "gpt-4o-mini");
    }

    #[test]
    fn module_factory_from_env_works() {
        let b = from_env();
        assert_eq!(b.name(), "openai");
    }

    #[test]
    fn module_factory_with_api_key_explicit() {
        let b = with_api_key(Some("sk-test".into()));
        assert_eq!(b.name(), "openai");
    }

    #[test]
    fn with_base_url_overrides() {
        // The builder must hand back a usable OpenAI backend and keep chaining intact.
        // That the URL is actually used is covered end-to-end by the transport-failure
        // test below.
        let b = OpenAIBackend::with_api_key(Some("k".into()))
            .with_base_url("http://localhost:1234")
            .with_default_model("o1-mini");
        assert_eq!(b.name(), "openai");
        assert_eq!(b.default_model(), "o1-mini");
    }

    #[test]
    fn with_default_model_overrides() {
        let b = OpenAIBackend::with_api_key(Some("k".into()))
            .with_default_model("o1-mini");
        assert_eq!(b.default_model(), "o1-mini");
    }

    // --- vision capability (decided by this wrapper) ------------------------------

    #[test]
    fn supports_vision_for_gpt_4o_family() {
        let b = OpenAIBackend::with_api_key(Some("k".into()));
        assert!(b.supports(Capability::Vision, "gpt-4o"));
        assert!(b.supports(Capability::Vision, "gpt-4o-mini"));
        assert!(b.supports(Capability::Vision, "gpt-4o-2024-08-06"));
    }

    #[test]
    fn does_not_support_vision_for_older_models() {
        let b = OpenAIBackend::with_api_key(Some("k".into()));
        assert!(!b.supports(Capability::Vision, "gpt-3.5-turbo"));
        assert!(!b.supports(Capability::Vision, "gpt-4"));
        assert!(!b.supports(Capability::Vision, "gpt-4-turbo"));
    }

    #[test]
    fn does_not_support_vision_for_reasoning_models() {
        let b = OpenAIBackend::with_api_key(Some("k".into()));
        assert!(!b.supports(Capability::Vision, "o1"));
        assert!(!b.supports(Capability::Vision, "o1-mini"));
        assert!(!b.supports(Capability::Vision, "o3-mini"));
    }

    #[test]
    fn vision_is_case_insensitive() {
        let b = OpenAIBackend::with_api_key(Some("k".into()));
        assert!(b.supports(Capability::Vision, "GPT-4o-mini"));
    }

    // --- capabilities delegated to the compatible base ----------------------------

    #[test]
    fn supports_lockedparams_for_o1_o3() {
        let b = OpenAIBackend::with_api_key(Some("k".into()));
        assert!(b.supports(Capability::LockedParams, "o1"));
        assert!(b.supports(Capability::LockedParams, "o1-mini"));
        assert!(b.supports(Capability::LockedParams, "o1-preview"));
        assert!(b.supports(Capability::LockedParams, "o3"));
        assert!(b.supports(Capability::LockedParams, "o3-mini"));
    }

    #[test]
    fn does_not_support_lockedparams_for_chat_models() {
        let b = OpenAIBackend::with_api_key(Some("k".into()));
        assert!(!b.supports(Capability::LockedParams, "gpt-4o-mini"));
        assert!(!b.supports(Capability::LockedParams, "gpt-3.5-turbo"));
        assert!(!b.supports(Capability::LockedParams, "gpt-4"));
    }

    #[test]
    fn supports_streaming_tooluse_structured_via_base() {
        let b = OpenAIBackend::with_api_key(Some("k".into()));
        assert!(b.supports(Capability::Streaming, "gpt-4o-mini"));
        assert!(b.supports(Capability::ToolUse, "gpt-4o-mini"));
        assert!(b.supports(Capability::StructuredOutput, "gpt-4o-mini"));
    }

    #[test]
    fn does_not_support_anthropic_or_gemini_only_caps() {
        let b = OpenAIBackend::with_api_key(Some("k".into()));
        assert!(!b.supports(Capability::PromptCaching, "gpt-4o-mini"));
        assert!(!b.supports(Capability::SafetySettings, "gpt-4o-mini"));
    }

    // --- token counting -------------------------------------------------------------

    #[test]
    fn count_tokens_uses_o200k_for_gpt_4o() {
        let b = OpenAIBackend::with_api_key(Some("k".into()));
        let n = b.count_tokens("gpt-4o-mini", "hello world");
        assert!(n > 0);
        // A real tokenizer splits "hello world" into only a handful of tokens.
        assert!(n <= 5);
    }

    #[test]
    fn count_tokens_uses_o200k_for_o1() {
        let b = OpenAIBackend::with_api_key(Some("k".into()));
        let n = b.count_tokens("o1-mini", "hello world");
        assert!(n > 0);
    }

    // --- request paths ----------------------------------------------------------------

    #[tokio::test]
    async fn complete_without_api_key_returns_auth_error() {
        // With no key the call must fail with Auth naming OPENAI_API_KEY. The base URL
        // points at port 0 so that no real server could be reached even if a request
        // were attempted.
        let b = OpenAIBackend::with_api_key(None).with_base_url("http://127.0.0.1:0");
        let err = b
            .complete(ChatRequest {
                messages: vec![Message::user("hi")],
                ..Default::default()
            })
            .await
            .unwrap_err();
        match err {
            BackendError::Auth { api_key_env, .. } => {
                assert_eq!(api_key_env.as_deref(), Some(API_KEY_ENV));
            }
            other => panic!("expected Auth, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn stream_delegates_to_base_real_sse_implementation() {
        // Nothing is expected to listen on loopback port 1, so a real (non-stub)
        // streaming implementation has to surface a transport failure.
        let b = OpenAIBackend::with_api_key(Some("k".into()))
            .with_base_url("http://127.0.0.1:1");
        match b.stream(ChatRequest::default()).await {
            Err(BackendError::Generic { ref message, .. }) => {
                assert!(
                    message.contains("streaming transport failure")
                        || message.contains("transport"),
                    "unexpected message: {message}"
                );
            }
            Err(other) => panic!("expected Generic, got {other:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    #[test]
    fn inner_accessor_returns_compat_backend() {
        let b = OpenAIBackend::with_api_key(Some("k".into()));
        assert_eq!(b.inner().name(), "openai");
    }
}
345}