Skip to main content

axon/backends/
openai.rs

1//! OpenAI Chat Completions backend — Fase 24.d.
2//!
3//! Thin factory + capability override on top of [`OpenAICompatibleBackend`].
4//! The OpenAI provider is the canonical OpenAI-compat shape; everything
5//! about the wire (Bearer auth, `/v1/chat/completions`, OpenAI tool
6//! envelope) lives in the shared base.
7//!
8//! What this module adds on top of the shared base:
9//!
10//!   * [`from_env`] / [`with_api_key`] factories that pin
11//!     [`OpenAICompatConfig::openai`] (base URL, default model, env var).
12//!   * Vision support discovery — `gpt-4o*` models accept image content
13//!     blocks; older models don't. The shared base conservatively
14//!     reports `Capability::Vision = false`; OpenAI's adapter overrides
15//!     to `true` for the gpt-4o family.
16//!   * o1 / o3 reasoning models work transparently — the locked-model
17//!     dispatch in the body builder strips `temperature` / `top_p` /
18//!     `presence_penalty` / `frequency_penalty` / `logprobs` /
19//!     `logit_bias` for those families, so adopters can pass any
20//!     sampling params they like and they're silently filtered out.
21//!     `Capability::LockedParams` returns `true` for the resolved model
22//!     when this filtering would fire.
23//!
24//! # Example
25//!
26//! ```ignore
27//! use axon::backends::{openai, Backend, ChatRequest, Message};
28//!
29//! let backend = openai::from_env();
30//! let request = ChatRequest {
31//!     model: "gpt-4o-mini".into(),
32//!     messages: vec![Message::user("Hello!")],
33//!     temperature: Some(0.7),
34//!     ..Default::default()
35//! };
36//! let response = backend.complete(request).await?;
37//! println!("{}", response.content);
38//! ```
39
40use std::env;
41use std::pin::Pin;
42
43use async_trait::async_trait;
44use futures::Stream;
45
46use super::error::BackendError;
47use super::openai_compat::{OpenAICompatConfig, OpenAICompatibleBackend};
48use super::{Backend, Capability, ChatRequest, ChatResponse, ChatStream};
49
/// Environment variable that [`OpenAIBackend::from_env`] reads the API key from.
///
/// NOTE(review): presumably duplicated in the `OpenAICompatConfig::openai()`
/// preset; the auth-error test compares the error's env-var name against this
/// constant, so a drift between the two would show up there.
const API_KEY_ENV: &str = "OPENAI_API_KEY";
51
52/// OpenAI Chat Completions backend. Composes [`OpenAICompatibleBackend`]
53/// with the OpenAI preset + a capability override for `Vision` on the
54/// gpt-4o family.
55pub struct OpenAIBackend {
56    inner: OpenAICompatibleBackend,
57}
58
59impl OpenAIBackend {
60    /// Construct from env. `OPENAI_API_KEY` is read at construction time;
61    /// `None` is permitted (auth check fires at first call).
62    pub fn from_env() -> Self {
63        let api_key = env::var(API_KEY_ENV).ok();
64        Self::with_api_key(api_key)
65    }
66
67    /// Construct with an explicit API key (or `None`).
68    pub fn with_api_key(api_key: Option<String>) -> Self {
69        Self {
70            inner: OpenAICompatibleBackend::new(OpenAICompatConfig::openai(), api_key),
71        }
72    }
73
74    /// Override the base URL (test fixtures, mock servers, Azure
75    /// OpenAI-compatible deployments). Returns `self` for builder
76    /// chaining.
77    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
78        self.inner = self.inner.with_base_url(base_url);
79        self
80    }
81
82    /// Override the default model.
83    pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
84        self.inner = self.inner.with_default_model(model);
85        self
86    }
87
88    /// Borrow the underlying [`OpenAICompatibleBackend`] (for testing
89    /// fixtures that need access to the composed inner state).
90    pub fn inner(&self) -> &OpenAICompatibleBackend {
91        &self.inner
92    }
93}
94
95impl Default for OpenAIBackend {
96    fn default() -> Self {
97        Self::from_env()
98    }
99}
100
101#[async_trait]
102impl Backend for OpenAIBackend {
103    fn name(&self) -> &str {
104        self.inner.name()
105    }
106
107    fn default_model(&self) -> &str {
108        self.inner.default_model()
109    }
110
111    async fn complete(&self, request: ChatRequest) -> Result<ChatResponse, BackendError> {
112        self.inner.complete(request).await
113    }
114
115    async fn stream(&self, request: ChatRequest) -> Result<ChatStream, BackendError> {
116        self.inner.stream(request).await
117    }
118
119    fn count_tokens(&self, model: &str, text: &str) -> usize {
120        self.inner.count_tokens(model, text)
121    }
122
123    fn supports(&self, capability: Capability, model: &str) -> bool {
124        match capability {
125            // OpenAI gpt-4o family supports image content blocks.
126            // Older models (gpt-3.5, gpt-4 turbo) and reasoning models
127            // (o1*, o3*) do not. Conservative match: only gpt-4o* gets
128            // a true here.
129            Capability::Vision => model.to_lowercase().starts_with("gpt-4o"),
130            // Everything else delegates to the shared base — Streaming,
131            // ToolUse, StructuredOutput, LockedParams (for o1/o3) all
132            // return whatever the base reports.
133            other => self.inner.supports(other, model),
134        }
135    }
136}
137
138// ────────────────────────────────────────────────────────────────────
139//  Module-level factories — adopter-friendly entry points
140// ────────────────────────────────────────────────────────────────────
141
142/// Construct an OpenAI backend using the `OPENAI_API_KEY` env var.
143///
144/// Convenience over `OpenAIBackend::from_env()` — adopter writes
145/// `let b = backends::openai::from_env();`.
146pub fn from_env() -> OpenAIBackend {
147    OpenAIBackend::from_env()
148}
149
150/// Construct an OpenAI backend with an explicit API key (or `None`).
151pub fn with_api_key(api_key: Option<String>) -> OpenAIBackend {
152    OpenAIBackend::with_api_key(api_key)
153}
154
155#[allow(dead_code)]
156type OpenAIChatStream =
157    Pin<Box<dyn Stream<Item = Result<crate::backends::ChatChunk, BackendError>> + Send>>;
158
159// ────────────────────────────────────────────────────────────────────
160//  Tests
161// ────────────────────────────────────────────────────────────────────
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::Message;

    /// Backend with a placeholder API key, for tests that never need a real
    /// server.
    fn keyed_backend() -> OpenAIBackend {
        OpenAIBackend::with_api_key(Some("k".into()))
    }

    // ── Construction ────────────────────────────────────────────────

    #[test]
    fn from_env_constructs_openai_backend() {
        let b = OpenAIBackend::from_env();
        assert_eq!(b.name(), "openai");
        assert_eq!(b.default_model(), "gpt-4o-mini");
    }

    #[test]
    fn module_factory_from_env_works() {
        let b = from_env();
        assert_eq!(b.name(), "openai");
    }

    #[test]
    fn module_factory_with_api_key_explicit() {
        let b = with_api_key(Some("sk-test".into()));
        assert_eq!(b.name(), "openai");
    }

    #[test]
    fn with_base_url_overrides() {
        // There is no public getter for the base URL. The `stream()`
        // failure-path test below shows the override is actually used (it
        // expects a transport error from the unreachable local port). Here we
        // only check that the override leaves the rest of the configuration
        // intact.
        let b = keyed_backend().with_base_url("http://localhost:1234");
        assert_eq!(b.name(), "openai");
        assert_eq!(b.default_model(), "gpt-4o-mini");
        assert_eq!(b.inner().name(), "openai");
    }

    #[test]
    fn with_default_model_overrides() {
        let b = keyed_backend().with_default_model("o1-mini");
        assert_eq!(b.default_model(), "o1-mini");
    }

    // ── Capability discovery — OpenAI-specific overrides ────────────

    #[test]
    fn supports_vision_for_gpt_4o_family() {
        let b = keyed_backend();
        assert!(b.supports(Capability::Vision, "gpt-4o"));
        assert!(b.supports(Capability::Vision, "gpt-4o-mini"));
        assert!(b.supports(Capability::Vision, "gpt-4o-2024-08-06"));
    }

    #[test]
    fn does_not_support_vision_for_older_models() {
        let b = keyed_backend();
        assert!(!b.supports(Capability::Vision, "gpt-3.5-turbo"));
        assert!(!b.supports(Capability::Vision, "gpt-4"));
        assert!(!b.supports(Capability::Vision, "gpt-4-turbo"));
    }

    #[test]
    fn does_not_support_vision_for_reasoning_models() {
        // o1 / o3 are treated as text-only reasoning models.
        let b = keyed_backend();
        assert!(!b.supports(Capability::Vision, "o1"));
        assert!(!b.supports(Capability::Vision, "o1-mini"));
        assert!(!b.supports(Capability::Vision, "o3-mini"));
    }

    #[test]
    fn vision_is_case_insensitive() {
        let b = keyed_backend();
        assert!(b.supports(Capability::Vision, "GPT-4o-mini"));
        assert!(b.supports(Capability::Vision, "GPT-4O"));
    }

    #[test]
    fn vision_rejects_short_and_non_ascii_near_misses() {
        let b = keyed_backend();
        assert!(!b.supports(Capability::Vision, ""));
        assert!(!b.supports(Capability::Vision, "gpt-"));
        // A multi-byte character straddling the prefix boundary must neither
        // match nor panic.
        assert!(!b.supports(Capability::Vision, "gpt-4ö"));
    }

    // ── Locked-params reaches o1/o3 via shared base ─────────────────

    #[test]
    fn supports_lockedparams_for_o1_o3() {
        let b = keyed_backend();
        for model in ["o1", "o1-mini", "o1-preview", "o3", "o3-mini"] {
            assert!(
                b.supports(Capability::LockedParams, model),
                "expected LockedParams for {model}"
            );
        }
    }

    #[test]
    fn does_not_support_lockedparams_for_chat_models() {
        let b = keyed_backend();
        for model in ["gpt-4o-mini", "gpt-3.5-turbo", "gpt-4"] {
            assert!(
                !b.supports(Capability::LockedParams, model),
                "unexpected LockedParams for {model}"
            );
        }
    }

    // ── Capabilities passed through to base ─────────────────────────

    #[test]
    fn supports_streaming_tooluse_structured_via_base() {
        let b = keyed_backend();
        assert!(b.supports(Capability::Streaming, "gpt-4o-mini"));
        assert!(b.supports(Capability::ToolUse, "gpt-4o-mini"));
        assert!(b.supports(Capability::StructuredOutput, "gpt-4o-mini"));
    }

    #[test]
    fn does_not_support_anthropic_or_gemini_only_caps() {
        let b = keyed_backend();
        assert!(!b.supports(Capability::PromptCaching, "gpt-4o-mini"));
        assert!(!b.supports(Capability::SafetySettings, "gpt-4o-mini"));
    }

    // ── count_tokens delegates to unified dispatch ──────────────────

    #[test]
    fn count_tokens_uses_o200k_for_gpt_4o() {
        let b = keyed_backend();
        let n = b.count_tokens("gpt-4o-mini", "hello world");
        // An exact tokenizer reports a small nonzero count.
        assert!(n > 0);
        assert!(n <= 5);
    }

    #[test]
    fn count_tokens_uses_o200k_for_o1() {
        let b = keyed_backend();
        let text = "hello world";
        let n = b.count_tokens("o1-mini", text);
        assert!(n > 0);
        // If o1 shares gpt-4o's tokenizer, the counts must agree.
        assert_eq!(n, b.count_tokens("gpt-4o-mini", text));
    }

    // ── complete() — early failure paths ────────────────────────────

    #[tokio::test]
    async fn complete_without_api_key_returns_auth_error() {
        let b = OpenAIBackend::with_api_key(None).with_base_url("http://127.0.0.1:0");
        let err = b
            .complete(ChatRequest {
                messages: vec![Message::user("hi")],
                ..Default::default()
            })
            .await
            .unwrap_err();
        match err {
            BackendError::Auth { api_key_env, .. } => {
                assert_eq!(api_key_env.as_deref(), Some(API_KEY_ENV));
            }
            other => panic!("expected Auth, got {other:?}"),
        }
    }

    // ── Streaming surface ───────────────────────────────────────────

    #[tokio::test]
    async fn stream_delegates_to_base_real_sse_implementation() {
        // §Fase 33.d — OpenAI-compat implements SSE streaming natively.
        // Without a reachable server this exercises the transport-error path
        // (early failure before any chunk).
        let b = keyed_backend().with_base_url("http://127.0.0.1:1");
        match b.stream(ChatRequest::default()).await {
            Err(BackendError::Generic { ref message, .. }) => {
                // Connection refused on the unreachable port. "streaming
                // transport failure" also contains "transport", so one check
                // covers both wordings.
                assert!(message.contains("transport"), "unexpected message: {message}");
            }
            Err(other) => panic!("expected Generic, got {other:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    // ── Inner accessor exists for test fixtures ─────────────────────

    #[test]
    fn inner_accessor_returns_compat_backend() {
        let b = keyed_backend();
        assert_eq!(b.inner().name(), "openai");
    }
}