1pub mod anthropic;
7pub mod gemini;
8pub mod openai;
9pub mod registry;
10pub mod router;
11mod stream;
12
13use async_trait::async_trait;
14use cersei_types::*;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use tokio::sync::mpsc;
18
19pub use anthropic::Anthropic;
21pub use gemini::Gemini;
22pub use openai::OpenAi;
23pub use router::from_model_string;
24pub use stream::StreamAccumulator;
25
26#[async_trait]
29pub trait Provider: Send + Sync {
30 fn name(&self) -> &str;
32
33 fn context_window(&self, model: &str) -> u64;
35
36 fn capabilities(&self, model: &str) -> ProviderCapabilities;
38
39 async fn complete(&self, request: CompletionRequest) -> Result<CompletionStream>;
41
42 async fn complete_blocking(&self, request: CompletionRequest) -> Result<CompletionResponse> {
44 self.complete(request).await?.collect().await
45 }
46
47 async fn count_tokens(&self, messages: &[Message], _model: &str) -> Result<u64> {
49 let chars: usize = messages.iter().map(|m| m.get_all_text().len()).sum();
51 Ok((chars as u64) / 4) }
53}
54
55#[async_trait]
57impl Provider for Box<dyn Provider> {
58 fn name(&self) -> &str {
59 (**self).name()
60 }
61 fn context_window(&self, model: &str) -> u64 {
62 (**self).context_window(model)
63 }
64 fn capabilities(&self, model: &str) -> ProviderCapabilities {
65 (**self).capabilities(model)
66 }
67 async fn complete(&self, request: CompletionRequest) -> Result<CompletionStream> {
68 (**self).complete(request).await
69 }
70 async fn complete_blocking(&self, request: CompletionRequest) -> Result<CompletionResponse> {
71 (**self).complete_blocking(request).await
72 }
73 async fn count_tokens(&self, messages: &[Message], model: &str) -> Result<u64> {
74 (**self).count_tokens(messages, model).await
75 }
76}
77
78#[derive(Debug, Clone)]
81pub enum Auth {
82 ApiKey(String),
84 Bearer(String),
86 OAuth {
88 client_id: String,
89 token: OAuthToken,
90 },
91 Custom(std::sync::Arc<dyn AuthProvider>),
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct OAuthToken {
97 pub access_token: String,
98 pub refresh_token: Option<String>,
99 pub expires_at_ms: Option<i64>,
100 pub scopes: Vec<String>,
101}
102
103impl OAuthToken {
104 pub fn is_expired(&self) -> bool {
105 if let Some(exp) = self.expires_at_ms {
106 chrono::Utc::now().timestamp_millis() >= exp
107 } else {
108 false
109 }
110 }
111}
112
/// A pluggable source of credentials, used through `Auth::Custom`.
///
/// Implementors must be `Send + Sync` and `Debug` (the `Debug` bound lets `Auth`
/// print a `Custom` variant).
#[async_trait]
pub trait AuthProvider: Send + Sync + std::fmt::Debug {
    /// Returns a pair of credential strings.
    // NOTE(review): the meaning of the two strings is not visible here; presumably
    // (header name, header value) — confirm against the providers that consume it.
    async fn get_credentials(&self) -> Result<(String, String)>;

    /// Asks the provider to refresh its credentials (presumably before a retry
    /// after an auth failure — confirm with callers).
    async fn refresh(&self) -> Result<()>;
}
121
/// A completion request handed to any [`Provider`] via `Provider::complete`.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Model identifier for the target provider.
    pub model: String,
    /// Conversation messages to complete.
    pub messages: Vec<Message>,
    /// Optional system prompt, kept separate from `messages`.
    pub system: Option<String>,
    /// Tool definitions made available to the model.
    pub tools: Vec<ToolDefinition>,
    /// Output-token cap; `CompletionRequest::new` sets 16384.
    pub max_tokens: u32,
    /// Sampling temperature; `None` presumably defers to the provider default — confirm.
    pub temperature: Option<f32>,
    /// Sequences intended to stop generation.
    pub stop_sequences: Vec<String>,
    /// Free-form, provider-specific settings.
    pub options: ProviderOptions,
}
136
137impl CompletionRequest {
138 pub fn new(model: impl Into<String>) -> Self {
139 Self {
140 model: model.into(),
141 messages: Vec::new(),
142 system: None,
143 tools: Vec::new(),
144 max_tokens: 16384,
145 temperature: None,
146 stop_sequences: Vec::new(),
147 options: ProviderOptions::default(),
148 }
149 }
150}
151
/// Provider-specific request options: a string-keyed map of JSON values, so each
/// provider can read the keys it understands and ignore the rest.
#[derive(Debug, Clone, Default)]
pub struct ProviderOptions {
    // Values are stored already serialized; typed access goes through `get`.
    entries: HashMap<String, serde_json::Value>,
}
156
157impl ProviderOptions {
158 pub fn set(&mut self, key: impl Into<String>, value: impl Serialize) {
159 if let Ok(v) = serde_json::to_value(value) {
160 self.entries.insert(key.into(), v);
161 }
162 }
163
164 pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
165 self.entries
166 .get(key)
167 .and_then(|v| serde_json::from_value(v.clone()).ok())
168 }
169
170 pub fn has(&self, key: &str) -> bool {
171 self.entries.contains_key(key)
172 }
173}
174
/// A finished completion, as produced by `CompletionStream::collect` (and thus
/// by `Provider::complete_blocking`'s default body).
#[derive(Debug, Clone)]
pub struct CompletionResponse {
    /// The assembled reply message.
    pub message: Message,
    /// Usage accounting reported for the request.
    pub usage: Usage,
    /// Why generation ended.
    pub stop_reason: StopReason,
}
181
/// Feature flags a provider reports for a model via `Provider::capabilities`.
///
/// `Default` yields every flag set to `false`.
#[derive(Debug, Clone, Default)]
pub struct ProviderCapabilities {
    pub streaming: bool,
    pub tool_use: bool,
    pub vision: bool,
    pub thinking: bool,
    pub system_prompt: bool,
    pub caching: bool,
}
191
/// The streaming result of `Provider::complete`: a wrapper around the receiving
/// half of a tokio mpsc channel carrying [`StreamEvent`]s.
pub struct CompletionStream {
    rx: mpsc::Receiver<StreamEvent>,
}
198
199impl CompletionStream {
200 pub fn new(rx: mpsc::Receiver<StreamEvent>) -> Self {
201 Self { rx }
202 }
203
204 pub async fn collect(mut self) -> Result<CompletionResponse> {
206 let mut acc = StreamAccumulator::new();
207 while let Some(event) = self.rx.recv().await {
208 if let StreamEvent::Error { message } = &event {
209 return Err(CerseiError::Provider(message.clone()));
210 }
211 acc.process_event(event);
212 }
213 acc.into_response()
214 }
215
216 pub fn into_receiver(self) -> mpsc::Receiver<StreamEvent> {
218 self.rx
219 }
220}