use std::collections::HashMap;
use crate::provider::{ModelInfo, Provider};
/// Errors produced while resolving a `"provider:model"` spec to a provider and model
/// (see `ProviderRegistry::resolve` and `ProviderRegistry::capabilities`).
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum RegistryError {
    /// The spec string is malformed: it has no `':'` separator, or the provider or model
    /// part is empty. As constructed in this module, the payload is a human-readable
    /// description that includes the offending spec.
    #[error("invalid model spec: {0}")]
    InvalidSpec(String),
    /// No registered provider has the given id.
    #[error("unknown provider: {0}")]
    UnknownProvider(String),
    /// The provider is registered, but neither a registered model override nor the
    /// provider's own model list contains the requested model id.
    #[error("unknown model '{model}' for provider '{provider}'")]
    UnknownModel { provider: String, model: String },
}
/// Errors returned when registering providers or model overrides with a
/// `ProviderRegistry`.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum RegistrationError {
    /// The provider id is the empty string.
    #[error("provider id cannot be empty")]
    EmptyProviderId,
    /// A model was submitted with an empty `id`; `provider` is the provider it was
    /// being registered for.
    #[error("model id cannot be empty for provider '{provider}'")]
    EmptyModelId { provider: String },
    /// A model override for this `(provider, model)` pair is already registered; the
    /// existing override is left in place.
    #[error("model '{model}' already registered for provider '{provider}'")]
    DuplicateModel { provider: String, model: String },
}
/// Capability summary for a single model, as returned by `ProviderRegistry::capabilities`.
///
/// Every field is copied verbatim from the resolved `ModelInfo`; consult that type for the
/// exact units of the numeric limits. Equality and hashing are derived so snapshots can be
/// compared directly or used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ModelCapabilities {
    /// Mirrors `ModelInfo::context_window` (presumably a token count — confirm against
    /// `ModelInfo`).
    pub context_window: u64,
    /// Mirrors `ModelInfo::max_output_tokens`.
    pub max_output_tokens: u64,
    /// Mirrors `ModelInfo::supports_images`.
    pub supports_images: bool,
    /// Mirrors `ModelInfo::supports_streaming`.
    pub supports_streaming: bool,
    /// Mirrors `ModelInfo::supports_thinking`.
    pub supports_thinking: bool,
}
/// Registry of model providers plus explicitly registered per-provider model overrides.
///
/// Models are addressed by a `"provider:model"` spec string (see
/// `ProviderRegistry::resolve`).
pub struct ProviderRegistry {
    /// Registered providers. Ids are unique because every registration path first removes
    /// any existing provider with the same id before appending the new one.
    providers: Vec<Box<dyn Provider>>,
    /// Explicitly registered models keyed by `(provider_id, model_id)`. During lookup these
    /// take precedence over a model with the same id in the provider's own model list.
    model_overrides: HashMap<(String, String), ModelInfo>,
}
impl ProviderRegistry {
    /// Creates an empty registry with no providers and no model overrides.
    pub fn new() -> Self {
        Self {
            providers: Vec::new(),
            model_overrides: HashMap::new(),
        }
    }

    /// Registers `provider`, replacing any already-registered provider with the same id.
    ///
    /// # Errors
    ///
    /// Returns [`RegistrationError::EmptyProviderId`] if `provider.id()` is empty. The
    /// registry is left unchanged in that case.
    pub fn register_provider(
        &mut self,
        provider: Box<dyn Provider>,
    ) -> Result<(), RegistrationError> {
        if provider.id().is_empty() {
            return Err(RegistrationError::EmptyProviderId);
        }
        self.insert_provider(provider);
        Ok(())
    }

    /// Registers `provider` without validation, replacing any provider with the same id.
    ///
    /// Unlike [`Self::register_provider`], an empty provider id is accepted here. A provider
    /// with an empty id can never be reached through [`Self::resolve`], because spec parsing
    /// rejects an empty provider part; prefer `register_provider` in new code.
    pub fn register(&mut self, provider: Box<dyn Provider>) {
        self.insert_provider(provider);
    }

    /// Replace-or-append shared by both registration paths: removes any provider whose id
    /// equals `provider.id()`, then appends `provider` at the end of the list.
    fn insert_provider(&mut self, provider: Box<dyn Provider>) {
        // Comparing borrowed ids avoids allocating an owned copy of the id.
        self.providers
            .retain(|existing| existing.id() != provider.id());
        self.providers.push(provider);
    }

    /// Registers `model` as an override for `provider_id`.
    ///
    /// An override takes precedence over a model with the same id in the provider's own
    /// model list, both in [`Self::resolve`] and in [`Self::all_models`]. The provider does
    /// not need to be registered yet, but `resolve` only consults overrides once it is.
    ///
    /// # Errors
    ///
    /// - [`RegistrationError::EmptyProviderId`] if `provider_id` is empty (such an override
    ///   could never be resolved).
    /// - [`RegistrationError::EmptyModelId`] if `model.id` is empty.
    /// - [`RegistrationError::DuplicateModel`] if an override for the same
    ///   `(provider_id, model.id)` pair already exists; the existing override is kept.
    pub fn register_model(
        &mut self,
        provider_id: &str,
        model: ModelInfo,
    ) -> Result<(), RegistrationError> {
        use std::collections::hash_map::Entry;

        if provider_id.is_empty() {
            return Err(RegistrationError::EmptyProviderId);
        }
        if model.id.is_empty() {
            return Err(RegistrationError::EmptyModelId {
                provider: provider_id.to_owned(),
            });
        }
        // One hash lookup: either reject the duplicate or fill the vacant slot.
        match self
            .model_overrides
            .entry((provider_id.to_owned(), model.id.clone()))
        {
            Entry::Occupied(_) => Err(RegistrationError::DuplicateModel {
                provider: provider_id.to_owned(),
                model: model.id,
            }),
            Entry::Vacant(slot) => {
                slot.insert(model);
                Ok(())
            }
        }
    }

    /// Returns the ids of all registered providers, sorted in ascending order.
    pub fn provider_ids(&self) -> Vec<&str> {
        let mut ids: Vec<&str> = self.providers.iter().map(|p| p.id()).collect();
        // Ids are unique (registration replaces duplicates), so stability is irrelevant.
        ids.sort_unstable();
        ids
    }

    /// Resolves a `"provider:model"` spec to its provider and model metadata.
    ///
    /// The spec is split at its first `':'`, so model ids may themselves contain colons.
    /// A model registered through [`Self::register_model`] takes precedence over a model
    /// with the same id in the provider's own model list.
    ///
    /// # Errors
    ///
    /// - [`RegistryError::InvalidSpec`] if `spec` is not of the form `provider:model`.
    /// - [`RegistryError::UnknownProvider`] if no provider with that id is registered.
    /// - [`RegistryError::UnknownModel`] if the provider is registered but neither an
    ///   override nor the provider's model list contains the model id.
    pub fn resolve(&self, spec: &str) -> Result<(&dyn Provider, &ModelInfo), RegistryError> {
        let (provider_id, model_id) = split_spec(spec)?;
        let provider = self
            .get_provider(provider_id)
            .ok_or_else(|| RegistryError::UnknownProvider(provider_id.to_owned()))?;
        // The override map is keyed by owned strings, so the lookup key must be allocated.
        let key = (provider_id.to_owned(), model_id.to_owned());
        if let Some(model) = self.model_overrides.get(&key) {
            return Ok((provider, model));
        }
        let model = provider
            .models()
            .iter()
            .find(|m| m.id == model_id)
            .ok_or_else(|| RegistryError::UnknownModel {
                provider: provider_id.to_owned(),
                model: model_id.to_owned(),
            })?;
        Ok((provider, model))
    }

    /// Returns the [`ModelCapabilities`] of the model addressed by `spec`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::resolve`].
    pub fn capabilities(&self, spec: &str) -> Result<ModelCapabilities, RegistryError> {
        let (_, model) = self.resolve(spec)?;
        Ok(ModelCapabilities {
            context_window: model.context_window,
            max_output_tokens: model.max_output_tokens,
            supports_images: model.supports_images,
            supports_streaming: model.supports_streaming,
            supports_thinking: model.supports_thinking,
        })
    }

    /// Returns the registered provider whose id equals `id`, if any.
    pub fn get_provider(&self, id: &str) -> Option<&dyn Provider> {
        self.providers
            .iter()
            .find(|p| p.id() == id)
            .map(|p| p.as_ref())
    }

    /// Lists every model known to the registry as `(provider_id, model)` pairs.
    ///
    /// Built-in models come first, in the order providers are stored (re-registering a
    /// provider moves it to the end) and then each provider's own order. Built-ins that are
    /// shadowed by an override are omitted. All overrides follow, sorted by
    /// `(provider_id, model_id)`. Overrides are listed even when their provider has not
    /// been registered, so not every entry is guaranteed to [`Self::resolve`].
    pub fn all_models(&self) -> Vec<(&str, &ModelInfo)> {
        let mut result = Vec::new();
        for provider in &self.providers {
            for model in provider.models() {
                // A shadowed built-in is reported via its override below instead.
                let key = (provider.id().to_owned(), model.id.clone());
                if !self.model_overrides.contains_key(&key) {
                    result.push((provider.id(), model));
                }
            }
        }
        let mut overrides: Vec<_> = self.model_overrides.iter().collect();
        // HashMap iteration order is unspecified; sort so the output is deterministic.
        overrides.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        result.extend(
            overrides
                .into_iter()
                .map(|((provider_id, _), model)| (provider_id.as_str(), model)),
        );
        result
    }
}
impl Default for ProviderRegistry {
fn default() -> Self {
Self::new()
}
}
/// Splits a `"provider:model"` spec at its first `':'`.
///
/// Everything after the first colon is the model id, so model ids may themselves contain
/// colons. No whitespace trimming is performed.
///
/// # Errors
///
/// Returns [`RegistryError::InvalidSpec`] when `spec` contains no colon, or when the text
/// on either side of the first colon is empty.
fn split_spec(spec: &str) -> Result<(&str, &str), RegistryError> {
    match spec.split_once(':') {
        Some((provider, model)) if !provider.is_empty() && !model.is_empty() => {
            Ok((provider, model))
        }
        // Missing separator and empty halves share one error message.
        _ => Err(RegistryError::InvalidSpec(format!(
            "spec must be 'provider:model', got: {spec:?}"
        ))),
    }
}