use crate::{Error, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::env;
use super::{
Message, GenerateOptions, GenerateResult, StreamChunk, FinishReason,
Usage, ToolDefinition, MessageRole, MessageContent
};
/// A model backend that can be stored in a [`ProviderRegistry`].
///
/// The registry keeps implementations as `Arc<dyn Provider>`, which is why
/// the trait requires `Send + Sync`.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Identifier of this provider; the registry uses it as its map key.
    fn id(&self) -> &str;

    /// Name of this provider (presumably a human-readable label; confirm
    /// with the implementors).
    fn name(&self) -> &str;

    /// Base URL reported by this provider.
    fn base_url(&self) -> &str;

    /// API version string reported by this provider.
    fn api_version(&self) -> &str;

    /// Lists the models this provider reports, as [`ModelInfo`] values.
    async fn list_models(&self) -> Result<Vec<ModelInfo>>;

    /// Returns a shared handle to the model identified by `model_id`.
    async fn get_model(&self, model_id: &str) -> Result<Arc<dyn Model>>;

    /// Reports the provider's current [`ProviderHealth`].
    async fn health_check(&self) -> Result<ProviderHealth>;

    /// Borrows the provider's configuration.
    fn get_config(&self) -> &ProviderConfig;

    /// Replaces the provider's configuration.
    ///
    /// Takes `&mut self`; note that the registry only holds shared
    /// `Arc<dyn Provider>` handles, which do not hand out `&mut` access.
    async fn update_config(&mut self, config: ProviderConfig) -> Result<()>;

    /// Returns rate-limit information for this provider.
    async fn get_rate_limits(&self) -> Result<RateLimitInfo>;

    /// Returns usage statistics for this provider.
    async fn get_usage(&self) -> Result<UsageStats>;
}
/// A single model exposed by a [`Provider`].
///
/// The registry caches implementations as `Arc<dyn Model>`, which is why the
/// trait requires `Send + Sync`.
#[async_trait]
pub trait Model: Send + Sync {
    /// Identifier of this model.
    fn id(&self) -> &str;

    /// Name of this model (presumably a human-readable label; confirm with
    /// the implementors).
    fn name(&self) -> &str;

    /// Identifier of the provider this model belongs to.
    fn provider_id(&self) -> &str;

    /// Borrows the capability flags of this model.
    fn capabilities(&self) -> &ModelCapabilities;

    /// Borrows the configuration of this model.
    fn config(&self) -> &ModelConfig;

    /// Produces a complete [`GenerateResult`] for `messages` using `options`.
    async fn generate(
        &self,
        messages: Vec<Message>,
        options: GenerateOptions,
    ) -> Result<GenerateResult>;

    /// Like [`Model::generate`], but returns a pinned, boxed stream whose
    /// items are `Result<StreamChunk>`.
    async fn stream(
        &self,
        messages: Vec<Message>,
        options: GenerateOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>>;

    /// Returns a token count for `messages`.
    async fn count_tokens(&self, messages: &[Message]) -> Result<u32>;

    /// Returns a cost estimate for the given token counts. The currency is
    /// not visible from this signature.
    async fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> Result<f64>;

    /// Borrows descriptive metadata about this model.
    fn metadata(&self) -> &ModelMetadata;
}
use futures::Stream;
use std::pin::Pin;
/// Description of a model, as returned by [`Provider::list_models`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model identifier; `ProviderRegistry::get_default_model` passes it to
    /// [`Provider::get_model`].
    pub id: String,
    /// Name of the model.
    pub name: String,
    /// Optional free-text description.
    pub description: Option<String>,
    /// Capability flags.
    pub capabilities: ModelCapabilities,
    /// Size and rate limits.
    pub limits: ModelLimits,
    /// Pricing information.
    pub pricing: ModelPricing,
    /// Optional release timestamp (UTC).
    pub release_date: Option<chrono::DateTime<chrono::Utc>>,
    /// Status; `ProviderRegistry::get_default_model` prefers the first
    /// model whose status is [`ModelStatus::Active`].
    pub status: ModelStatus,
}
/// Boolean capability flags for a model, plus a free-form `custom` map.
///
/// See the `Default` impl for which flags start out enabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCapabilities {
    pub text_generation: bool,
    pub tool_calling: bool,
    pub vision: bool,
    pub streaming: bool,
    pub caching: bool,
    pub json_mode: bool,
    pub reasoning: bool,
    pub code_generation: bool,
    pub multilingual: bool,
    /// Additional capabilities keyed by name, stored as arbitrary JSON.
    pub custom: HashMap<String, serde_json::Value>,
}
impl Default for ModelCapabilities {
    /// Enables only `text_generation` and `streaming`; every other flag is
    /// `false` and `custom` is empty.
    fn default() -> Self {
        Self {
            // Enabled out of the box.
            text_generation: true,
            streaming: true,
            // Everything else is opt-in.
            tool_calling: false,
            vision: false,
            caching: false,
            json_mode: false,
            reasoning: false,
            code_generation: false,
            multilingual: false,
            custom: HashMap::default(),
        }
    }
}
/// Size and rate limits attached to a [`ModelInfo`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLimits {
    pub max_context_tokens: u32,
    pub max_output_tokens: u32,
    /// `None` when no image size limit is recorded.
    pub max_image_size_bytes: Option<u64>,
    /// `None` when no per-request image limit is recorded.
    pub max_images_per_request: Option<u32>,
    /// `None` when no tool-call limit is recorded.
    pub max_tool_calls: Option<u32>,
    /// Rate limits for this model.
    pub rate_limits: RateLimitInfo,
}
/// Pricing figures attached to a [`ModelInfo`].
///
/// The `_per_1k` suffix presumably means "per 1,000 tokens"; confirm against
/// the code that fills these fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
    pub input_cost_per_1k: f64,
    pub output_cost_per_1k: f64,
    /// `None` when no cache-read price is recorded.
    pub cache_read_cost_per_1k: Option<f64>,
    /// `None` when no cache-write price is recorded.
    pub cache_write_cost_per_1k: Option<f64>,
    /// Currency label; the `Default` impl uses `"USD"`.
    pub currency: String,
}
impl Default for ModelPricing {
    /// Zero input/output cost, no cache prices, currency `"USD"`.
    fn default() -> Self {
        let currency = String::from("USD");
        Self {
            currency,
            input_cost_per_1k: 0.0,
            output_cost_per_1k: 0.0,
            cache_read_cost_per_1k: None,
            cache_write_cost_per_1k: None,
        }
    }
}
/// Status of a model. Variant names are (de)serialized in lowercase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelStatus {
    /// Preferred by `ProviderRegistry::get_default_model` when picking a model.
    Active,
    Deprecated,
    Beta,
    Unavailable,
    Discontinued,
}
/// Rate-limit figures, returned by [`Provider::get_rate_limits`] and embedded
/// in [`ModelLimits`]. Every field is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitInfo {
    pub requests_per_minute: Option<u32>,
    pub tokens_per_minute: Option<u32>,
    pub tokens_per_day: Option<u32>,
    pub concurrent_requests: Option<u32>,
    /// Usage counters, when available.
    pub current_usage: Option<CurrentUsage>,
}
/// Usage counters carried by [`RateLimitInfo::current_usage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CurrentUsage {
    pub requests_this_minute: u32,
    pub tokens_this_minute: u32,
    pub tokens_today: u32,
    pub active_requests: u32,
}
/// Result of [`Provider::health_check`].
///
/// `ProviderRegistry::get_all_provider_health` also builds one of these by
/// hand (with `available: false`) when a health check returns an error.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderHealth {
    /// Whether the provider is usable; `ProviderRegistry::available`
    /// filters on this flag.
    pub available: bool,
    /// Latency in milliseconds, when recorded.
    pub latency_ms: Option<u64>,
    /// Error description, when there is one.
    pub error: Option<String>,
    /// Timestamp (UTC) of this check.
    pub last_check: chrono::DateTime<chrono::Utc>,
    /// Extra details stored as arbitrary JSON.
    pub details: HashMap<String, serde_json::Value>,
}
/// Configuration of a [`Provider`].
///
/// `ProviderRegistry::load_configs` deserializes a JSON object mapping
/// provider ids to values of this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    pub provider_id: String,
    pub api_key: Option<String>,
    /// Presumably replaces [`Provider::base_url`] when set; confirm with the
    /// implementors.
    pub base_url_override: Option<String>,
    /// Presumably replaces [`Provider::api_version`] when set; confirm with
    /// the implementors.
    pub api_version_override: Option<String>,
    /// Defaults to 60.
    pub timeout_seconds: u64,
    /// Defaults to 3.
    pub max_retries: u32,
    /// Defaults to 1000.
    pub retry_delay_ms: u64,
    pub custom_headers: HashMap<String, String>,
    pub organization_id: Option<String>,
    pub project_id: Option<String>,
    /// Additional settings stored as arbitrary JSON.
    pub extra: HashMap<String, serde_json::Value>,
}
impl Default for ProviderConfig {
    /// Empty provider id, no credentials or overrides, a 60 s timeout, and
    /// 3 retries with a 1000 ms retry delay.
    fn default() -> Self {
        Self {
            // Timing and retry defaults.
            timeout_seconds: 60,
            max_retries: 3,
            retry_delay_ms: 1000,
            // Identity, credentials and overrides all start out unset.
            provider_id: String::new(),
            api_key: None,
            base_url_override: None,
            api_version_override: None,
            organization_id: None,
            project_id: None,
            // Open-ended maps start out empty.
            custom_headers: HashMap::default(),
            extra: HashMap::default(),
        }
    }
}
/// Configuration of a [`Model`], returned by [`Model::config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    pub model_id: String,
    pub default_temperature: Option<f32>,
    pub default_max_tokens: Option<u32>,
    pub default_top_p: Option<f32>,
    pub default_stop_sequences: Vec<String>,
    pub use_caching: bool,
    /// Additional options stored as arbitrary JSON.
    pub options: HashMap<String, serde_json::Value>,
}
impl Default for ModelConfig {
    /// Empty model id, no sampling defaults, no stop sequences, caching off,
    /// and no extra options.
    fn default() -> Self {
        let model_id = String::new();
        let default_stop_sequences = Vec::new();
        let options = HashMap::new();
        Self {
            model_id,
            default_temperature: None,
            default_max_tokens: None,
            default_top_p: None,
            default_stop_sequences,
            use_caching: false,
            options,
        }
    }
}
/// Descriptive metadata about a model, returned by [`Model::metadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    pub family: String,
    /// Stored as free text; the expected format is not visible here.
    pub parameters: Option<String>,
    /// Optional training cutoff timestamp (UTC).
    pub training_cutoff: Option<chrono::DateTime<chrono::Utc>>,
    pub version: Option<String>,
    /// Additional metadata stored as arbitrary JSON.
    pub extra: HashMap<String, serde_json::Value>,
}
impl Default for ModelMetadata {
    /// Empty family, every optional field `None`, and an empty `extra` map.
    fn default() -> Self {
        let family = String::new();
        let extra = HashMap::new();
        Self {
            family,
            parameters: None,
            training_cutoff: None,
            version: None,
            extra,
        }
    }
}
/// Aggregated usage figures, returned by [`Provider::get_usage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageStats {
    pub total_requests: u64,
    pub total_tokens: u64,
    pub total_cost: f64,
    /// Currency label for `total_cost`.
    pub currency: String,
    /// Per-model breakdown; keys are presumably model ids (verify against
    /// the producer).
    pub by_model: HashMap<String, ModelUsage>,
    /// Per-period breakdown; the key format is not visible here.
    pub by_period: HashMap<String, PeriodUsage>,
}
/// Usage figures for one entry of [`UsageStats::by_model`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelUsage {
    pub requests: u64,
    pub input_tokens: u64,
    pub output_tokens: u64,
    pub cache_hits: u64,
    pub cost: f64,
    pub avg_latency_ms: f64,
}
/// Usage figures for one entry of [`UsageStats::by_period`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PeriodUsage {
    /// Start of the period (UTC).
    pub start: chrono::DateTime<chrono::Utc>,
    /// End of the period (UTC); whether it is inclusive is not visible here.
    pub end: chrono::DateTime<chrono::Utc>,
    pub requests: u64,
    pub tokens: u64,
    pub cost: f64,
}
/// Input/output price pair with a currency label.
///
/// NOTE(review): overlaps with [`ModelPricing`]; the `_per_1k` suffix
/// presumably means "per 1,000 tokens" — confirm with the users of this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cost {
    pub input_per_1k: f64,
    pub output_per_1k: f64,
    pub currency: String,
}
/// Context and output token limits.
///
/// NOTE(review): duplicates the first two fields of [`ModelLimits`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Limits {
    pub max_context_tokens: u32,
    pub max_output_tokens: u32,
}
/// Origin of a provider. Variant names are (de)serialized in lowercase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderSource {
    Official,
    Community,
    Custom,
}
/// Status of a provider. Variant names are (de)serialized in lowercase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderStatus {
    Active,
    Beta,
    Deprecated,
    Unavailable,
}
/// Backoff parameters consumed by [`retry_with_backoff`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Number of retries allowed after the first failed attempt.
    pub max_retries: u32,
    /// Wait before the first retry, in milliseconds.
    pub initial_delay_ms: u64,
    /// Upper bound applied to the wait after each multiplication.
    pub max_delay_ms: u64,
    /// Factor the wait is multiplied by after each retry.
    pub multiplier: f32,
}
impl Default for RetryConfig {
    /// Three retries, starting at 1000 ms, doubling each time, capped at
    /// 10000 ms.
    fn default() -> Self {
        let (initial_delay_ms, max_delay_ms) = (1000, 10000);
        Self {
            max_retries: 3,
            initial_delay_ms,
            max_delay_ms,
            multiplier: 2.0,
        }
    }
}
pub async fn retry_with_backoff<F, T, E>(
config: &RetryConfig,
operation: F,
) -> Result<T>
where
F: Fn() -> futures::future::BoxFuture<'static, Result<T>>,
{
use tokio::time::{sleep, Duration};
let mut attempts = 0;
let mut delay = config.initial_delay_ms;
loop {
match operation().await {
Ok(result) => return Ok(result),
Err(e) if attempts < config.max_retries => {
attempts += 1;
sleep(Duration::from_millis(delay)).await;
delay = (delay as f32 * config.multiplier) as u64;
delay = delay.min(config.max_delay_ms);
}
Err(e) => return Err(e),
}
}
}
/// In-memory registry of providers and a cache of the models obtained from
/// them.
pub struct ProviderRegistry {
    /// Registered providers, keyed by [`Provider::id`].
    providers: HashMap<String, Arc<dyn Provider>>,
    /// Model cache, keyed by `"{provider_id}/{model_id}"`.
    models: HashMap<String, Arc<dyn Model>>,
    /// Provider used by `parse_model_string` when the input names none.
    default_provider: Option<String>,
    /// Credential storage consulted by `discover_from_storage`.
    storage: Arc<dyn crate::auth::AuthStorage>,
}
impl ProviderRegistry {
/// Creates an empty registry backed by `storage`: no providers, no cached
/// models and no default provider.
pub fn new(storage: Arc<dyn crate::auth::AuthStorage>) -> Self {
    let providers = HashMap::new();
    let models = HashMap::new();
    Self {
        providers,
        models,
        default_provider: None,
        storage,
    }
}
pub fn register_provider(&mut self, provider: Arc<dyn Provider>) -> Result<()> {
let provider_id = provider.id().to_string();
if self.providers.contains_key(&provider_id) {
return Err(Error::Other(anyhow::anyhow!(
"Provider {} is already registered",
provider_id
)));
}
self.providers.insert(provider_id, provider);
Ok(())
}
pub fn get_provider(&self, provider_id: &str) -> Result<Arc<dyn Provider>> {
self.providers
.get(provider_id)
.cloned()
.ok_or_else(|| Error::Other(anyhow::anyhow!("Provider {} not found", provider_id)))
}
/// Returns the ids of all registered providers, in the map's iteration
/// order (i.e. unspecified order).
pub fn list_providers(&self) -> Vec<String> {
    self.providers.keys().map(|id| id.to_owned()).collect()
}
/// Returns the model `model_id` of provider `provider_id`, consulting the
/// cache first.
///
/// On a cache miss the model is requested from the provider and stored
/// under the key `"{provider_id}/{model_id}"`.
///
/// # Errors
///
/// Fails when the provider is not registered or when the provider cannot
/// return the model.
pub async fn get_model(&mut self, provider_id: &str, model_id: &str) -> Result<Arc<dyn Model>> {
    let cache_key = format!("{}/{}", provider_id, model_id);
    if let Some(cached) = self.models.get(&cache_key).cloned() {
        return Ok(cached);
    }

    let fetched = self.get_provider(provider_id)?.get_model(model_id).await?;
    self.models.insert(cache_key, Arc::clone(&fetched));
    Ok(fetched)
}
pub fn parse_model_string(&self, model_string: &str) -> Result<(String, String)> {
if let Some((provider, model)) = model_string.split_once('/') {
Ok((provider.to_string(), model.to_string()))
} else if let Some((provider, model)) = model_string.split_once(':') {
Ok((provider.to_string(), model.to_string()))
} else {
if let Some(default_provider) = &self.default_provider {
Ok((default_provider.clone(), model_string.to_string()))
} else {
Err(Error::Other(anyhow::anyhow!(
"Invalid model string format: {}. Expected 'provider/model' or 'provider:model'",
model_string
)))
}
}
}
pub fn set_default_provider(&mut self, provider_id: &str) -> Result<()> {
if !self.providers.contains_key(provider_id) {
return Err(Error::Other(anyhow::anyhow!(
"Provider {} is not registered",
provider_id
)));
}
self.default_provider = Some(provider_id.to_string());
Ok(())
}
/// Returns the id of the default provider, if one has been set.
pub fn get_default_provider(&self) -> Option<&str> {
    self.default_provider.as_ref().map(String::as_str)
}
/// Collects the model lists of every registered provider, one provider at
/// a time.
///
/// A provider whose listing fails is logged with a warning and skipped, so
/// this method itself always returns `Ok`.
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
    let mut collected = Vec::new();
    for provider in self.providers.values() {
        let models = match provider.list_models().await {
            Ok(models) => models,
            Err(e) => {
                tracing::warn!("Failed to list models for provider {}: {}", provider.id(), e);
                continue;
            }
        };
        collected.extend(models);
    }
    Ok(collected)
}
/// Runs a health check on every registered provider, one at a time, and
/// returns the results keyed by provider id.
///
/// A failed check is reported as an unavailable [`ProviderHealth`] carrying
/// the error text and the current UTC time.
pub async fn get_all_provider_health(&self) -> HashMap<String, ProviderHealth> {
    let mut report = HashMap::new();
    for (id, provider) in &self.providers {
        let health = provider
            .health_check()
            .await
            .unwrap_or_else(|e| ProviderHealth {
                available: false,
                latency_ms: None,
                error: Some(e.to_string()),
                last_check: chrono::Utc::now(),
                details: HashMap::new(),
            });
        report.insert(id.clone(), health);
    }
    report
}
/// Drops every cached model handle; registered providers are untouched.
pub fn clear_model_cache(&mut self) {
    self.models.clear();
}
/// Unregisters `provider_id`, evicts its cached models, and clears the
/// default provider if it pointed at the removed one.
///
/// # Errors
///
/// Fails when `provider_id` is not registered.
pub fn remove_provider(&mut self, provider_id: &str) -> Result<()> {
    // A single lookup: `remove` already reports whether the key existed.
    if self.providers.remove(provider_id).is_none() {
        return Err(Error::Other(anyhow::anyhow!(
            "Provider {} is not registered",
            provider_id
        )));
    }

    // Cache keys have the form "{provider_id}/{model_id}". Build the prefix
    // once instead of once per cached entry.
    let cache_prefix = format!("{}/", provider_id);
    self.models.retain(|key, _| !key.starts_with(&cache_prefix));

    if self.default_provider.as_deref() == Some(provider_id) {
        self.default_provider = None;
    }
    Ok(())
}
pub async fn discover_from_env(&mut self) -> Result<()> {
if env::var("ANTHROPIC_API_KEY").is_ok() {
if let Ok(provider) = self.create_anthropic_provider().await {
self.register_provider(provider)?;
}
}
if env::var("OPENAI_API_KEY").is_ok() {
if let Ok(provider) = self.create_openai_provider().await {
self.register_provider(provider)?;
}
}
if env::var("GITHUB_TOKEN").is_ok() || env::var("GITHUB_COPILOT_TOKEN").is_ok() {
if let Ok(provider) = self.create_github_copilot_provider().await {
self.register_provider(provider)?;
}
}
Ok(())
}
pub async fn discover_from_storage(&mut self) -> Result<()> {
if let Ok(Some(_)) = self.storage.get("anthropic").await {
if let Ok(provider) = self.create_anthropic_provider().await {
self.register_provider(provider)?;
}
}
if let Ok(Some(_)) = self.storage.get("openai").await {
if let Ok(provider) = self.create_openai_provider().await {
self.register_provider(provider)?;
}
}
if let Ok(Some(_)) = self.storage.get("github-copilot").await {
if let Ok(provider) = self.create_github_copilot_provider().await {
self.register_provider(provider)?;
}
}
Ok(())
}
/// Runs a health check on every registered provider, one at a time.
///
/// A failing check is only logged with a warning, so this method always
/// returns `Ok`.
pub async fn initialize_all(&mut self) -> Result<()> {
    for (provider_id, provider) in &self.providers {
        if let Err(e) = provider.health_check().await {
            tracing::warn!("Failed to initialize provider {}: {}", provider_id, e);
        }
    }
    Ok(())
}
/// Placeholder: logs an informational message and returns `Ok` without
/// loading or changing anything.
pub async fn load_models_dev(&mut self) -> Result<()> {
    tracing::info!("Loading models from models.dev (using built-in configs for now)");

    Ok(())
}
/// Reads a JSON file mapping provider ids to [`ProviderConfig`] values.
///
/// The parsed configurations are currently not applied: for every id that
/// matches a registered provider a warning is logged, because providers are
/// held behind `Arc` and cannot be mutated from here.
///
/// # Errors
///
/// Fails when the file does not exist, cannot be read, or does not contain
/// valid JSON of the expected shape.
pub async fn load_configs(&mut self, path: &str) -> Result<()> {
    use std::path::Path;
    use tokio::fs;

    let path = Path::new(path);
    // Read directly instead of probing with the blocking `Path::exists()`:
    // this keeps the async task non-blocking and avoids a check-then-read
    // race. A missing file still produces the dedicated "not found" error.
    let contents = match fs::read_to_string(path).await {
        Ok(contents) => contents,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
            return Err(Error::Other(anyhow::anyhow!(
                "Configuration file not found: {}",
                path.display()
            )));
        }
        Err(e) => return Err(e.into()),
    };

    let configs: HashMap<String, ProviderConfig> = serde_json::from_str(&contents)?;
    // Only the ids are needed; the configs themselves cannot be applied.
    for provider_id in configs.keys() {
        if self.providers.contains_key(provider_id) {
            tracing::warn!("Cannot update config for provider {} - providers are immutable through Arc", provider_id);
        }
    }
    Ok(())
}
/// Returns a shared handle to the provider registered as `provider_id`, or
/// `None` when there is no such provider.
///
/// Declared `async`, but it performs no awaiting.
pub async fn get(&self, provider_id: &str) -> Option<Arc<dyn Provider>> {
    self.providers.get(provider_id).map(Arc::clone)
}
/// Splits `model_str` into `(provider, model)` without consulting a
/// registry.
///
/// The string is split at the first `/`; when it contains no `/`, at the
/// first `:`. With neither separator the provider is `"anthropic"` and the
/// whole string is the model.
pub fn parse_model(model_str: &str) -> (String, String) {
    let (provider, model) = model_str
        .split_once('/')
        .or_else(|| model_str.split_once(':'))
        .unwrap_or(("anthropic", model_str));
    (provider.to_string(), model.to_string())
}
pub async fn get_default_model(&self, provider_id: &str) -> Result<Arc<dyn Model>> {
let provider = self.get_provider(provider_id)?;
let models = provider.list_models().await?;
if let Some(default_model) = models.iter().find(|m| m.status == ModelStatus::Active) {
provider.get_model(&default_model.id).await
} else if let Some(first_model) = models.first() {
provider.get_model(&first_model.id).await
} else {
Err(Error::Other(anyhow::anyhow!(
"Provider {} has no available models",
provider_id
)))
}
}
/// Returns the ids of the providers whose health check succeeds and
/// reports `available == true`. Checks run one at a time.
pub async fn available(&self) -> Vec<String> {
    let mut healthy_ids = Vec::new();
    for (id, provider) in &self.providers {
        let is_up = matches!(provider.health_check().await, Ok(health) if health.available);
        if is_up {
            healthy_ids.push(id.clone());
        }
    }
    healthy_ids
}
/// Returns the ids of all registered providers, in unspecified order.
///
/// Declared `async`, but it performs no awaiting.
pub async fn list(&self) -> Vec<String> {
    self.providers.keys().map(|id| id.to_owned()).collect()
}
/// Registers `provider` under its own [`Provider::id`], replacing any
/// provider already registered under that id.
///
/// Unlike `register_provider`, this never fails. When an existing provider
/// is replaced, its cached models are evicted so that later lookups go to
/// the new provider instead of returning handles created by the old one.
///
/// Declared `async`, but it performs no awaiting.
pub async fn register(&mut self, provider: Arc<dyn Provider>) {
    let provider_id = provider.id().to_string();
    // Cache keys have the form "{provider_id}/{model_id}".
    let cache_prefix = format!("{}/", provider_id);

    let replaced = self.providers.insert(provider_id, provider).is_some();
    if replaced {
        self.models.retain(|key, _| !key.starts_with(&cache_prefix));
    }
}
async fn create_anthropic_provider(&self) -> Result<Arc<dyn Provider>> {
Err(Error::Other(anyhow::anyhow!("Anthropic provider creation not implemented in this context")))
}
async fn create_openai_provider(&self) -> Result<Arc<dyn Provider>> {
Err(Error::Other(anyhow::anyhow!("OpenAI provider creation not implemented in this context")))
}
async fn create_github_copilot_provider(&self) -> Result<Arc<dyn Provider>> {
Err(Error::Other(anyhow::anyhow!("GitHub Copilot provider creation not implemented in this context")))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `ProviderRegistry::parse_model` accepts `/` and `:` separators and
    /// falls back to the `anthropic` provider for bare model names.
    #[test]
    fn test_parse_model_string() {
        let cases = [
            ("anthropic/claude-3-opus", "anthropic", "claude-3-opus"),
            ("openai:gpt-4", "openai", "gpt-4"),
            ("claude-3-opus", "anthropic", "claude-3-opus"),
        ];
        for (input, expected_provider, expected_model) in cases {
            let (provider, model) = ProviderRegistry::parse_model(input);
            assert_eq!(provider, expected_provider);
            assert_eq!(model, expected_model);
        }
    }

    /// Only text generation and streaming are enabled by default.
    #[test]
    fn test_model_capabilities_default() {
        let ModelCapabilities {
            text_generation,
            tool_calling,
            vision,
            streaming,
            ..
        } = ModelCapabilities::default();
        assert!(text_generation);
        assert!(!tool_calling);
        assert!(!vision);
        assert!(streaming);
    }

    /// Timeout and retry defaults of `ProviderConfig`.
    #[test]
    fn test_provider_config_default() {
        let defaults = ProviderConfig::default();
        assert_eq!(defaults.timeout_seconds, 60);
        assert_eq!(defaults.max_retries, 3);
        assert_eq!(defaults.retry_delay_ms, 1000);
    }
}