Skip to main content

caliban_provider_openai/
lib.rs

1//! `OpenAI` schema family for the caliban agent harness.
2//!
3//! Provides [`OpenAIProvider<T: Transport>`] generic over its transport.
4//! Direct API is supported by default; Azure `OpenAI` transport is gated
5//! behind the `azure` cargo feature.
6
7#![allow(clippy::missing_errors_doc)]
8// Transitive dependencies pull in multiple versions of some crates.
9#![allow(clippy::multiple_crate_versions)]
10
11pub mod config;
12pub mod error;
13pub mod ir_convert;
14pub mod models;
15pub mod schema;
16pub mod transport;
17
18mod stream_parse; // populated in Task 5
19
20use async_trait::async_trait;
21use caliban_provider::{
22    Capabilities, CompletionRequest, CompletionResponse, Error, MessageStream, ModelInfo, Provider,
23    Result, SystemPromptCapability,
24};
25
26use crate::config::DirectConfig;
27use crate::transport::Transport;
28use crate::transport::direct::DirectTransport;
29
30/// `OpenAI` Chat Completions provider, generic over its transport.
31pub struct OpenAIProvider<T: Transport> {
32    transport: T,
33}
34
35impl<T: Transport> std::fmt::Debug for OpenAIProvider<T> {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        f.debug_struct("OpenAIProvider").finish_non_exhaustive()
38    }
39}
40
41impl OpenAIProvider<DirectTransport> {
42    /// Construct an `OpenAIProvider` using the direct HTTPS transport.
43    ///
44    /// # Errors
45    ///
46    /// Returns `Err` if the underlying `reqwest` client cannot be built.
47    pub fn direct(cfg: DirectConfig) -> Result<Self> {
48        DirectTransport::new(cfg)
49            .map(|t| Self { transport: t })
50            .map_err(Error::adapter)
51    }
52}
53
54impl<T: Transport> OpenAIProvider<T> {
55    /// Construct an `OpenAIProvider` from an arbitrary `Transport`.
56    pub fn from_transport(transport: T) -> Self {
57        Self { transport }
58    }
59}
60
#[cfg(feature = "azure")]
impl OpenAIProvider<crate::transport::azure::AzureTransport> {
    /// Build a provider backed by the Azure `OpenAI` Service transport
    /// configured by `cfg`. Only available with the `azure` cargo feature.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the underlying `reqwest` client cannot be built. The
    /// transport's construction error is wrapped with `Error::adapter`.
    pub fn azure(cfg: crate::config::AzureConfig) -> Result<Self> {
        let transport =
            crate::transport::azure::AzureTransport::new(cfg).map_err(Error::adapter)?;
        Ok(Self { transport })
    }
}
74
75#[async_trait]
76impl<T: Transport> Provider for OpenAIProvider<T> {
77    async fn complete(&self, req: CompletionRequest) -> Result<CompletionResponse> {
78        req.validate()?;
79        let canonical_model = req.model.clone();
80        let caps = self.capabilities(&canonical_model);
81        let system_role = match caps.system_prompt {
82            SystemPromptCapability::DeveloperRole => "developer",
83            _ => "system",
84        };
85        let mut native = ir_convert::ir_to_native_request(req, false, system_role)?;
86        native.model = self.transport.wire_model_id(&canonical_model);
87        self.transport.finalize_request(&mut native);
88        let native_resp = self.transport.send(native).await.map_err(Error::from)?;
89        ir_convert::native_response_to_ir(native_resp)
90    }
91
92    async fn stream(&self, req: CompletionRequest) -> Result<MessageStream> {
93        req.validate()?;
94        let canonical_model = req.model.clone();
95        let caps = self.capabilities(&canonical_model);
96        let system_role = match caps.system_prompt {
97            SystemPromptCapability::DeveloperRole => "developer",
98            _ => "system",
99        };
100        let mut native = ir_convert::ir_to_native_request(req, true, system_role)?;
101        native.model = self.transport.wire_model_id(&canonical_model);
102        // Opt into usage reporting on the final streaming chunk.
103        native.stream_options = Some(crate::schema::request::NativeStreamOptions {
104            include_usage: true,
105        });
106        self.transport.finalize_request(&mut native);
107        let bytes_stream = self
108            .transport
109            .stream(native)
110            .await
111            .map_err(caliban_provider::Error::from)?;
112        Ok(stream_parse::map_openai_sse_to_events(bytes_stream))
113    }
114
115    fn capabilities(&self, model: &str) -> Capabilities {
116        models::capabilities_for(model)
117    }
118
119    fn list_models(&self) -> Vec<ModelInfo> {
120        models::models()
121    }
122
123    fn name(&self) -> &'static str {
124        "openai"
125    }
126}