pub mod anthropic;
pub mod gemini;
pub mod openai;
pub mod registry;
pub mod router;
mod stream;
use async_trait::async_trait;
use cersei_types::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio::sync::mpsc;
pub use anthropic::Anthropic;
pub use gemini::Gemini;
pub use openai::OpenAi;
pub use router::from_model_string;
pub use stream::StreamAccumulator;
/// Common interface implemented by every completion backend in this module.
///
/// Implementors supply the metadata accessors and the streaming
/// [`Provider::complete`] call; the remaining methods have default bodies
/// that can be overridden.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Returns the provider's name.
    fn name(&self) -> &str;

    /// Returns the context window size reported for `model`.
    fn context_window(&self, model: &str) -> u64;

    /// Returns the capability flags reported for `model`.
    fn capabilities(&self, model: &str) -> ProviderCapabilities;

    /// Starts a completion for `request` and returns its event stream.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionStream>;

    /// Runs [`Provider::complete`] and gathers the whole stream into a single
    /// response via [`CompletionStream::collect`].
    async fn complete_blocking(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let stream = self.complete(request).await?;
        stream.collect().await
    }

    /// Estimates the token count of `messages`.
    ///
    /// The default is a heuristic: the summed `len()` of every message's
    /// `get_all_text()` divided by four (integer division). The model name
    /// is ignored.
    async fn count_tokens(&self, messages: &[Message], _model: &str) -> Result<u64> {
        let mut total_len: usize = 0;
        for message in messages {
            total_len += message.get_all_text().len();
        }
        Ok((total_len as u64) / 4)
    }
}
#[async_trait]
impl Provider for Box<dyn Provider> {
fn name(&self) -> &str {
(**self).name()
}
fn context_window(&self, model: &str) -> u64 {
(**self).context_window(model)
}
fn capabilities(&self, model: &str) -> ProviderCapabilities {
(**self).capabilities(model)
}
async fn complete(&self, request: CompletionRequest) -> Result<CompletionStream> {
(**self).complete(request).await
}
async fn complete_blocking(&self, request: CompletionRequest) -> Result<CompletionResponse> {
(**self).complete_blocking(request).await
}
async fn count_tokens(&self, messages: &[Message], model: &str) -> Result<u64> {
(**self).count_tokens(messages, model).await
}
}
/// Credentials a provider can be configured with.
#[derive(Debug, Clone)]
pub enum Auth {
/// A single credential string.
/// NOTE(review): presumably a vendor API key — confirm how each provider transmits it.
ApiKey(String),
/// A single credential string.
/// NOTE(review): presumably sent as a bearer token — confirm against the providers.
Bearer(String),
/// An OAuth client identifier together with its current [`OAuthToken`].
OAuth {
client_id: String,
token: OAuthToken,
},
/// A shared, caller-supplied [`AuthProvider`] implementation.
Custom(std::sync::Arc<dyn AuthProvider>),
}
/// Serializable OAuth token data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthToken {
/// The access token string.
pub access_token: String,
/// The refresh token string, if one is available.
pub refresh_token: Option<String>,
/// Expiry instant in milliseconds, compared by `is_expired` against
/// `chrono::Utc::now().timestamp_millis()`; `None` is never treated as expired.
pub expires_at_ms: Option<i64>,
/// Scope strings associated with the token.
pub scopes: Vec<String>,
}
impl OAuthToken {
    /// Reports whether the token's expiry instant has been reached.
    ///
    /// Returns `true` when the current UTC time in milliseconds is at or past
    /// `expires_at_ms`. A token without an expiry is never considered expired.
    pub fn is_expired(&self) -> bool {
        match self.expires_at_ms {
            Some(deadline_ms) => deadline_ms <= chrono::Utc::now().timestamp_millis(),
            None => false,
        }
    }
}
/// Pluggable source of credentials, usable through [`Auth::Custom`].
#[async_trait]
pub trait AuthProvider: Send + Sync + std::fmt::Debug {
/// Asynchronously produces a pair of strings.
/// NOTE(review): presumably (header name, header value) — confirm against implementors.
async fn get_credentials(&self) -> Result<(String, String)>;
/// NOTE(review): presumably renews the underlying credentials — confirm against implementors.
async fn refresh(&self) -> Result<()>;
}
/// Input to [`Provider::complete`] and [`Provider::complete_blocking`].
#[derive(Debug, Clone)]
pub struct CompletionRequest {
/// Model identifier string.
pub model: String,
/// Messages to send; empty when built with `CompletionRequest::new`.
pub messages: Vec<Message>,
/// Optional system prompt text; `None` when built with `CompletionRequest::new`.
pub system: Option<String>,
/// Tool definitions attached to the request; empty by default.
pub tools: Vec<ToolDefinition>,
/// NOTE(review): presumably the cap on generated tokens — confirm.
/// `CompletionRequest::new` initializes it to 16384.
pub max_tokens: u32,
/// Optional sampling temperature; `None` when built with `CompletionRequest::new`.
pub temperature: Option<f32>,
/// Stop sequence strings; empty by default.
pub stop_sequences: Vec<String>,
/// Free-form, provider-specific settings keyed by name.
pub options: ProviderOptions,
}
impl CompletionRequest {
pub fn new(model: impl Into<String>) -> Self {
Self {
model: model.into(),
messages: Vec::new(),
system: None,
tools: Vec::new(),
max_tokens: 16384,
temperature: None,
stop_sequences: Vec::new(),
options: ProviderOptions::default(),
}
}
}
/// String-keyed bag of JSON values carried on a [`CompletionRequest`].
///
/// Values are stored as `serde_json::Value` and accessed through the typed
/// `set` / `get` / `has` methods.
#[derive(Debug, Clone, Default)]
pub struct ProviderOptions {
// Private so that all access goes through the typed accessors.
entries: HashMap<String, serde_json::Value>,
}
impl ProviderOptions {
    /// Stores `value` under `key` after converting it to a JSON value,
    /// replacing any previous entry for that key.
    ///
    /// This is best-effort: if `value` cannot be serialized, the call is a
    /// no-op and any existing entry for `key` is left untouched.
    pub fn set(&mut self, key: impl Into<String>, value: impl Serialize) {
        if let Ok(json) = serde_json::to_value(value) {
            self.entries.insert(key.into(), json);
        }
    }

    /// Reads the entry stored under `key` and decodes it as `T`.
    ///
    /// Returns `None` when the key is absent or when the stored JSON does not
    /// deserialize into `T`.
    pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
        let stored = self.entries.get(key)?;
        // `&serde_json::Value` implements `Deserializer`, so decode straight
        // from the borrowed value instead of deep-cloning it on every lookup.
        T::deserialize(stored).ok()
    }

    /// Reports whether an entry exists for `key`, regardless of its type.
    pub fn has(&self, key: &str) -> bool {
        self.entries.contains_key(key)
    }
}
/// Result of a completion once its stream has been fully collected
/// (see `CompletionStream::collect`).
#[derive(Debug, Clone)]
pub struct CompletionResponse {
/// The resulting message.
pub message: Message,
/// Usage data for the call. NOTE(review): presumably token accounting — confirm against `Usage`.
pub usage: Usage,
/// The stop reason recorded for the completion.
pub stop_reason: StopReason,
}
/// Boolean feature flags returned by [`Provider::capabilities`].
///
/// The derived `Default` leaves every flag `false`.
/// NOTE(review): the per-flag meanings below are inferred from the field
/// names — confirm against the provider implementations.
#[derive(Debug, Clone, Default)]
pub struct ProviderCapabilities {
/// Presumably: streamed responses are supported.
pub streaming: bool,
/// Presumably: tool definitions / tool calls are supported.
pub tool_use: bool,
/// Presumably: image input is supported.
pub vision: bool,
/// Presumably: a reasoning/"thinking" mode is supported.
pub thinking: bool,
/// Presumably: a separate system prompt is supported.
pub system_prompt: bool,
/// Presumably: prompt caching is supported.
pub caching: bool,
}
/// Handle to the stream of `StreamEvent`s produced by a completion.
///
/// Wraps the receiving half of a `tokio::sync::mpsc` channel.
pub struct CompletionStream {
// Receiving end of the event channel; consumed by `collect` or handed out by `into_receiver`.
rx: mpsc::Receiver<StreamEvent>,
}
impl CompletionStream {
pub fn new(rx: mpsc::Receiver<StreamEvent>) -> Self {
Self { rx }
}
pub async fn collect(mut self) -> Result<CompletionResponse> {
let mut acc = StreamAccumulator::new();
while let Some(event) = self.rx.recv().await {
if let StreamEvent::Error { message } = &event {
return Err(CerseiError::Provider(message.clone()));
}
acc.process_event(event);
}
acc.into_response()
}
pub fn into_receiver(self) -> mpsc::Receiver<StreamEvent> {
self.rx
}
}