use std::env;
use std::pin::Pin;
use async_trait::async_trait;
use futures::Stream;
use super::error::BackendError;
use super::openai_compat::{OpenAICompatConfig, OpenAICompatibleBackend};
use super::{Backend, Capability, ChatRequest, ChatResponse, ChatStream};
/// Name of the environment variable that `OpenAIBackend::from_env` reads the
/// API key from.
const API_KEY_ENV: &str = "OPENAI_API_KEY";
/// OpenAI backend, implemented as a thin wrapper around
/// [`OpenAICompatibleBackend`] configured with `OpenAICompatConfig::openai()`.
pub struct OpenAIBackend {
// The wrapped OpenAI-compatible backend that calls are forwarded to.
inner: OpenAICompatibleBackend,
}
impl OpenAIBackend {
    /// Creates a backend whose API key is taken from the `OPENAI_API_KEY`
    /// environment variable.
    ///
    /// If the variable cannot be read (unset, or not valid Unicode), the
    /// backend is still constructed, just without a key.
    pub fn from_env() -> Self {
        Self::with_api_key(env::var(API_KEY_ENV).ok())
    }

    /// Creates a backend with an explicitly supplied API key.
    ///
    /// `None` builds a key-less backend. The wrapped backend is configured
    /// with the OpenAI preset from [`OpenAICompatConfig::openai`].
    pub fn with_api_key(api_key: Option<String>) -> Self {
        let inner = OpenAICompatibleBackend::new(OpenAICompatConfig::openai(), api_key);
        Self { inner }
    }

    /// Builder-style setter: forwards `base_url` to the wrapped backend's
    /// `with_base_url` and returns the updated backend.
    pub fn with_base_url(self, base_url: impl Into<String>) -> Self {
        Self {
            inner: self.inner.with_base_url(base_url),
        }
    }

    /// Builder-style setter: forwards `model` to the wrapped backend's
    /// `with_default_model` and returns the updated backend.
    pub fn with_default_model(self, model: impl Into<String>) -> Self {
        Self {
            inner: self.inner.with_default_model(model),
        }
    }

    /// Borrows the wrapped [`OpenAICompatibleBackend`].
    pub fn inner(&self) -> &OpenAICompatibleBackend {
        &self.inner
    }
}
/// The default backend is the environment-configured one.
impl Default for OpenAIBackend {
/// Equivalent to [`OpenAIBackend::from_env`]; reads the process environment.
fn default() -> Self {
Self::from_env()
}
}
/// [`Backend`] implementation that forwards everything to the wrapped
/// OpenAI-compatible backend, except the `Vision` capability check, which is
/// answered here.
#[async_trait]
impl Backend for OpenAIBackend {
    /// Backend name, as reported by the wrapped backend.
    fn name(&self) -> &str {
        self.inner.name()
    }

    /// Default model, as reported by the wrapped backend.
    fn default_model(&self) -> &str {
        self.inner.default_model()
    }

    /// Runs a chat completion by delegating to the wrapped backend.
    ///
    /// # Errors
    ///
    /// Returns whatever [`BackendError`] the wrapped backend produces.
    async fn complete(&self, request: ChatRequest) -> Result<ChatResponse, BackendError> {
        self.inner.complete(request).await
    }

    /// Opens a streaming chat completion by delegating to the wrapped backend.
    ///
    /// # Errors
    ///
    /// Returns whatever [`BackendError`] the wrapped backend produces.
    async fn stream(&self, request: ChatRequest) -> Result<ChatStream, BackendError> {
        self.inner.stream(request).await
    }

    /// Counts the tokens of `text` for `model` using the wrapped backend.
    fn count_tokens(&self, model: &str, text: &str) -> usize {
        self.inner.count_tokens(model, text)
    }

    /// Reports whether `model` supports `capability`.
    ///
    /// `Capability::Vision` is decided locally: it is supported exactly when
    /// the model name starts with `gpt-4o`, compared case-insensitively.
    /// Every other capability is answered by the wrapped backend.
    fn supports(&self, capability: Capability, model: &str) -> bool {
        if matches!(capability, Capability::Vision) {
            return model.to_lowercase().starts_with("gpt-4o");
        }
        self.inner.supports(capability, model)
    }
}
/// Module-level convenience factory; same as [`OpenAIBackend::from_env`].
pub fn from_env() -> OpenAIBackend {
OpenAIBackend::from_env()
}
/// Module-level convenience factory; same as [`OpenAIBackend::with_api_key`].
///
/// `api_key` is passed through unchanged, so `None` yields a key-less backend.
pub fn with_api_key(api_key: Option<String>) -> OpenAIBackend {
OpenAIBackend::with_api_key(api_key)
}
/// Pinned, boxed, `Send` stream yielding `Result<ChatChunk, BackendError>` items.
///
/// NOTE(review): nothing in this file refers to the alias, which is presumably
/// why it carries `allow(dead_code)` — confirm whether it is still needed.
#[allow(dead_code)]
type OpenAIChatStream =
Pin<Box<dyn Stream<Item = Result<crate::backends::ChatChunk, BackendError>> + Send>>;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::Message;

    /// Backend built with a placeholder API key, shared by the tests below.
    fn keyed_backend() -> OpenAIBackend {
        OpenAIBackend::with_api_key(Some("k".into()))
    }

    // ----- construction -----

    #[test]
    fn from_env_constructs_openai_backend() {
        let backend = OpenAIBackend::from_env();
        assert_eq!(backend.name(), "openai");
        assert_eq!(backend.default_model(), "gpt-4o-mini");
    }

    #[test]
    fn module_factory_from_env_works() {
        let backend = from_env();
        assert_eq!(backend.name(), "openai");
    }

    #[test]
    fn module_factory_with_api_key_explicit() {
        let backend = with_api_key(Some("sk-test".into()));
        assert_eq!(backend.name(), "openai");
    }

    #[test]
    fn with_base_url_overrides() {
        // Smoke test only: verifies the builder call runs without panicking.
        let _backend = keyed_backend().with_base_url("http://localhost:1234");
    }

    #[test]
    fn with_default_model_overrides() {
        let backend = keyed_backend().with_default_model("o1-mini");
        assert_eq!(backend.default_model(), "o1-mini");
    }

    // ----- capability: vision -----

    #[test]
    fn supports_vision_for_gpt_4o_family() {
        let backend = keyed_backend();
        assert!(backend.supports(Capability::Vision, "gpt-4o"));
        assert!(backend.supports(Capability::Vision, "gpt-4o-mini"));
        assert!(backend.supports(Capability::Vision, "gpt-4o-2024-08-06"));
    }

    #[test]
    fn does_not_support_vision_for_older_models() {
        let backend = keyed_backend();
        assert!(!backend.supports(Capability::Vision, "gpt-3.5-turbo"));
        assert!(!backend.supports(Capability::Vision, "gpt-4"));
        assert!(!backend.supports(Capability::Vision, "gpt-4-turbo"));
    }

    #[test]
    fn does_not_support_vision_for_reasoning_models() {
        let backend = keyed_backend();
        assert!(!backend.supports(Capability::Vision, "o1"));
        assert!(!backend.supports(Capability::Vision, "o1-mini"));
        assert!(!backend.supports(Capability::Vision, "o3-mini"));
    }

    #[test]
    fn vision_is_case_insensitive() {
        let backend = keyed_backend();
        assert!(backend.supports(Capability::Vision, "GPT-4o-mini"));
    }

    // ----- capabilities answered by the wrapped backend -----

    #[test]
    fn supports_lockedparams_for_o1_o3() {
        let backend = keyed_backend();
        assert!(backend.supports(Capability::LockedParams, "o1"));
        assert!(backend.supports(Capability::LockedParams, "o1-mini"));
        assert!(backend.supports(Capability::LockedParams, "o1-preview"));
        assert!(backend.supports(Capability::LockedParams, "o3"));
        assert!(backend.supports(Capability::LockedParams, "o3-mini"));
    }

    #[test]
    fn does_not_support_lockedparams_for_chat_models() {
        let backend = keyed_backend();
        assert!(!backend.supports(Capability::LockedParams, "gpt-4o-mini"));
        assert!(!backend.supports(Capability::LockedParams, "gpt-3.5-turbo"));
        assert!(!backend.supports(Capability::LockedParams, "gpt-4"));
    }

    #[test]
    fn supports_streaming_tooluse_structured_via_base() {
        let backend = keyed_backend();
        assert!(backend.supports(Capability::Streaming, "gpt-4o-mini"));
        assert!(backend.supports(Capability::ToolUse, "gpt-4o-mini"));
        assert!(backend.supports(Capability::StructuredOutput, "gpt-4o-mini"));
    }

    #[test]
    fn does_not_support_anthropic_or_gemini_only_caps() {
        let backend = keyed_backend();
        assert!(!backend.supports(Capability::PromptCaching, "gpt-4o-mini"));
        assert!(!backend.supports(Capability::SafetySettings, "gpt-4o-mini"));
    }

    // ----- token counting -----

    #[test]
    fn count_tokens_uses_o200k_for_gpt_4o() {
        let backend = keyed_backend();
        let token_count = backend.count_tokens("gpt-4o-mini", "hello world");
        assert!(token_count > 0);
        assert!(token_count <= 5);
    }

    #[test]
    fn count_tokens_uses_o200k_for_o1() {
        let backend = keyed_backend();
        let token_count = backend.count_tokens("o1-mini", "hello world");
        assert!(token_count > 0);
    }

    // ----- request paths -----

    #[tokio::test]
    async fn complete_without_api_key_returns_auth_error() {
        let backend = OpenAIBackend::with_api_key(None).with_base_url("http://127.0.0.1:0");
        let request = ChatRequest {
            messages: vec![Message::user("hi")],
            ..Default::default()
        };
        let err = backend.complete(request).await.unwrap_err();
        match err {
            BackendError::Auth { api_key_env, .. } => {
                assert_eq!(api_key_env.as_deref(), Some(API_KEY_ENV));
            }
            other => panic!("expected Auth, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn stream_delegates_to_base_not_implemented_path() {
        let backend = keyed_backend();
        let outcome = backend.stream(ChatRequest::default()).await;
        match outcome {
            Err(BackendError::Generic { ref message, .. }) => {
                assert!(message.contains("streaming not yet implemented"));
            }
            Err(other) => panic!("expected Generic, got {other:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    // ----- accessors -----

    #[test]
    fn inner_accessor_returns_compat_backend() {
        let backend = keyed_backend();
        assert_eq!(backend.inner().name(), "openai");
    }
}