Skip to main content

cersei_provider/
lib.rs

//! cersei-provider: Provider trait and built-in LLM providers.
//!
//! Providers abstract over different LLM backends (Anthropic, Gemini, OpenAI, local models).
//! Each provider implements streaming completion, token counting, and capability discovery.
5
6pub mod anthropic;
7pub mod gemini;
8pub mod openai;
9pub mod registry;
10pub mod router;
11mod stream;
12
13use async_trait::async_trait;
14use cersei_types::*;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use tokio::sync::mpsc;
18
19// Re-exports
20pub use anthropic::Anthropic;
21pub use gemini::Gemini;
22pub use openai::OpenAi;
23pub use router::from_model_string;
24pub use stream::StreamAccumulator;
25
26// ─── Provider trait ──────────────────────────────────────────────────────────
27
28#[async_trait]
29pub trait Provider: Send + Sync {
30    /// Human-readable provider name (e.g., "anthropic", "openai").
31    fn name(&self) -> &str;
32
33    /// Context window size for the given model.
34    fn context_window(&self, model: &str) -> u64;
35
36    /// Capabilities supported by the given model.
37    fn capabilities(&self, model: &str) -> ProviderCapabilities;
38
39    /// Send a streaming completion request.
40    async fn complete(&self, request: CompletionRequest) -> Result<CompletionStream>;
41
42    /// Send a blocking (non-streaming) completion request.
43    async fn complete_blocking(&self, request: CompletionRequest) -> Result<CompletionResponse> {
44        self.complete(request).await?.collect().await
45    }
46
47    /// Count tokens for a message list. Returns an estimate if exact counting is unavailable.
48    async fn count_tokens(&self, messages: &[Message], _model: &str) -> Result<u64> {
49        // Default: rough estimate based on character count
50        let chars: usize = messages.iter().map(|m| m.get_all_text().len()).sum();
51        Ok((chars as u64) / 4) // ~4 chars per token
52    }
53}
54
55// Blanket impl: Box<dyn Provider> is itself a Provider.
56#[async_trait]
57impl Provider for Box<dyn Provider> {
58    fn name(&self) -> &str {
59        (**self).name()
60    }
61    fn context_window(&self, model: &str) -> u64 {
62        (**self).context_window(model)
63    }
64    fn capabilities(&self, model: &str) -> ProviderCapabilities {
65        (**self).capabilities(model)
66    }
67    async fn complete(&self, request: CompletionRequest) -> Result<CompletionStream> {
68        (**self).complete(request).await
69    }
70    async fn complete_blocking(&self, request: CompletionRequest) -> Result<CompletionResponse> {
71        (**self).complete_blocking(request).await
72    }
73    async fn count_tokens(&self, messages: &[Message], model: &str) -> Result<u64> {
74        (**self).count_tokens(messages, model).await
75    }
76}
77
78// ─── Authentication ──────────────────────────────────────────────────────────
79
/// Credential variants a provider can be configured with.
///
/// NOTE(review): `Debug` is derived, so formatting an `ApiKey`, `Bearer` or
/// `OAuth` value (e.g. with `{:?}` in a log line) prints the secret string
/// verbatim. Avoid logging `Auth` values directly.
#[derive(Debug, Clone)]
pub enum Auth {
    /// API key sent as `x-api-key` header (Anthropic Console) or `Authorization: Bearer` (OpenAI).
    ApiKey(String),
    /// Bearer token sent as `Authorization: Bearer <token>`.
    Bearer(String),
    /// OAuth flow with client ID and token.
    OAuth {
        client_id: String,
        token: OAuthToken,
    },
    /// Custom auth provider for non-standard flows.
    ///
    /// Held behind an `Arc` so cloning `Auth` shares the same provider instance.
    Custom(std::sync::Arc<dyn AuthProvider>),
}
94
/// OAuth token data carried by [`Auth::OAuth`]; serializable via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthToken {
    /// The access token string.
    pub access_token: String,
    /// Refresh token, if one is available.
    pub refresh_token: Option<String>,
    /// Expiry instant in milliseconds, compared against
    /// `chrono::Utc::now().timestamp_millis()` by [`OAuthToken::is_expired`].
    /// `None` means the token is never reported as expired.
    pub expires_at_ms: Option<i64>,
    /// Scopes associated with the token.
    pub scopes: Vec<String>,
}
102
103impl OAuthToken {
104    pub fn is_expired(&self) -> bool {
105        if let Some(exp) = self.expires_at_ms {
106            chrono::Utc::now().timestamp_millis() >= exp
107        } else {
108            false
109        }
110    }
111}
112
/// Pluggable credential source, used through [`Auth::Custom`].
///
/// `Debug` is a supertrait because `Auth` derives `Debug` and stores an
/// `Arc<dyn AuthProvider>`; `Send + Sync` let that `Arc` be shared across
/// threads.
#[async_trait]
pub trait AuthProvider: Send + Sync + std::fmt::Debug {
    /// Returns (header_name, header_value) for the request.
    async fn get_credentials(&self) -> Result<(String, String)>;

    /// Refresh credentials if they have expired.
    async fn refresh(&self) -> Result<()>;
}
121
122// ─── Completion request/response ─────────────────────────────────────────────
123
/// Input to [`Provider::complete`] and [`Provider::complete_blocking`].
///
/// Build one with [`CompletionRequest::new`], then fill in the public fields.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Model identifier the request is addressed to.
    pub model: String,
    /// Messages to send; empty by default.
    pub messages: Vec<Message>,
    /// Optional system prompt; `None` by default.
    pub system: Option<String>,
    /// Tool definitions offered with the request; empty by default.
    pub tools: Vec<ToolDefinition>,
    /// Token limit for the request; `new` sets it to 16384.
    /// (Presumably caps generated output tokens — confirm against the providers.)
    pub max_tokens: u32,
    /// Sampling temperature; `None` by default.
    pub temperature: Option<f32>,
    /// Stop sequences; empty by default.
    pub stop_sequences: Vec<String>,
    /// Provider-specific options (thinking budget, top_p, etc.)
    pub options: ProviderOptions,
}
136
137impl CompletionRequest {
138    pub fn new(model: impl Into<String>) -> Self {
139        Self {
140            model: model.into(),
141            messages: Vec::new(),
142            system: None,
143            tools: Vec::new(),
144            max_tokens: 16384,
145            temperature: None,
146            stop_sequences: Vec::new(),
147            options: ProviderOptions::default(),
148        }
149    }
150}
151
/// String-keyed bag of provider-specific settings, stored as JSON values.
///
/// `Default` yields an empty bag. The map is private; use the `set`, `get`
/// and `has` methods to access it.
#[derive(Debug, Clone, Default)]
pub struct ProviderOptions {
    // Option name -> JSON-encoded value.
    entries: HashMap<String, serde_json::Value>,
}
156
157impl ProviderOptions {
158    pub fn set(&mut self, key: impl Into<String>, value: impl Serialize) {
159        if let Ok(v) = serde_json::to_value(value) {
160            self.entries.insert(key.into(), v);
161        }
162    }
163
164    pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
165        self.entries
166            .get(key)
167            .and_then(|v| serde_json::from_value(v.clone()).ok())
168    }
169
170    pub fn has(&self, key: &str) -> bool {
171        self.entries.contains_key(key)
172    }
173}
174
/// A fully assembled completion, as returned by
/// [`Provider::complete_blocking`] and [`CompletionStream::collect`].
#[derive(Debug, Clone)]
pub struct CompletionResponse {
    /// The message produced by the completion.
    pub message: Message,
    /// Usage figures reported for the completion.
    pub usage: Usage,
    /// Why the completion stopped.
    pub stop_reason: StopReason,
}
181
/// Per-model feature flags returned by [`Provider::capabilities`].
///
/// `Default` yields every flag set to `false`.
#[derive(Debug, Clone, Default)]
pub struct ProviderCapabilities {
    /// Streaming responses are supported.
    pub streaming: bool,
    /// Tool use is supported.
    pub tool_use: bool,
    /// Vision input is supported.
    pub vision: bool,
    /// Thinking is supported.
    pub thinking: bool,
    /// A system prompt is supported.
    pub system_prompt: bool,
    /// Caching is supported (presumably prompt caching — confirm).
    pub caching: bool,
}
191
192// ─── Completion stream ───────────────────────────────────────────────────────
193
/// A streaming response from a provider. Wraps a channel of StreamEvents.
///
/// Consume it either all at once with [`CompletionStream::collect`] or event
/// by event via [`CompletionStream::into_receiver`].
pub struct CompletionStream {
    // Receiving half of the tokio mpsc channel the events arrive on.
    rx: mpsc::Receiver<StreamEvent>,
}
198
199impl CompletionStream {
200    pub fn new(rx: mpsc::Receiver<StreamEvent>) -> Self {
201        Self { rx }
202    }
203
204    /// Consume the stream and collect into a complete response.
205    pub async fn collect(mut self) -> Result<CompletionResponse> {
206        let mut acc = StreamAccumulator::new();
207        while let Some(event) = self.rx.recv().await {
208            if let StreamEvent::Error { message } = &event {
209                return Err(CerseiError::Provider(message.clone()));
210            }
211            acc.process_event(event);
212        }
213        acc.into_response()
214    }
215
216    /// Access the underlying receiver for real-time event processing.
217    pub fn into_receiver(self) -> mpsc::Receiver<StreamEvent> {
218        self.rx
219    }
220}