use crate::core::{
EmbeddingModel, ImageModel, LanguageModel, Result,
};
use crate::core::error::ProviderError;
use std::collections::HashMap;
/// A source of models that a [`ProviderRegistry`] can look up by model name.
///
/// Implementors must supply language models. Embedding and image models are
/// optional: their default implementations report "not supported" by
/// returning `None`.
pub trait Provider: Send + Sync {
    /// Returns the language model called `model_id`, or `None` when this
    /// provider does not offer it.
    fn language_model(&self, model_id: &str) -> Option<Box<dyn LanguageModel>>;

    /// Returns the embedding model with the given name, if supported.
    ///
    /// The default implementation supports no embedding models.
    fn embedding_model(&self, _name: &str) -> Option<Box<dyn EmbeddingModel>> {
        None
    }

    /// Returns the image model with the given name, if supported.
    ///
    /// The default implementation supports no image models.
    fn image_model(&self, _name: &str) -> Option<Box<dyn ImageModel>> {
        None
    }
}
/// Maps provider names to [`Provider`]s and resolves model IDs of the form
/// `"<provider><separator><model>"`, e.g. `"openai:gpt-4o"` with the `':'`
/// separator used by `ProviderRegistry::new`.
pub struct ProviderRegistry {
    /// Character separating the provider name from the model name in an ID.
    separator: char,
    /// Registered providers, keyed by the name they were registered under.
    providers: HashMap<String, Box<dyn Provider>>,
}
impl ProviderRegistry {
    /// Separator used by [`ProviderRegistry::new`].
    const DEFAULT_SEPARATOR: char = ':';

    /// Creates an empty registry that splits model IDs on `':'`
    /// (e.g. `"openai:gpt-4o"`).
    #[must_use]
    pub fn new() -> Self {
        Self::with_separator(Self::DEFAULT_SEPARATOR)
    }

    /// Creates an empty registry that splits model IDs on `separator`.
    ///
    /// Any `char` is accepted, including multi-byte ones such as `'→'`.
    #[must_use]
    pub fn with_separator(separator: char) -> Self {
        Self {
            providers: HashMap::new(),
            separator,
        }
    }

    /// Registers `provider` under `name` and returns the registry for chaining.
    ///
    /// Registering under a name that is already taken replaces the previous
    /// provider.
    #[must_use]
    pub fn register(mut self, name: impl Into<String>, provider: impl Provider + 'static) -> Self {
        self.providers.insert(name.into(), Box::new(provider));
        self
    }

    /// Splits `id` at the first separator into `(provider_id, model_id)`.
    ///
    /// Only the first separator is significant, so the model part may itself
    /// contain the separator (`"openai:ft:gpt-4o"` -> `("openai", "ft:gpt-4o")`).
    /// Both parts must be non-empty.
    fn split_id<'i>(&self, id: &'i str) -> Result<(&'i str, &'i str)> {
        // `split_once` skips the separator by its full UTF-8 length; slicing at
        // `find(..) + 1` would panic for multi-byte separators.
        id.split_once(self.separator)
            .filter(|(provider_id, model_id)| !provider_id.is_empty() && !model_id.is_empty())
            .ok_or_else(|| {
                ProviderError::InvalidResponse(format!(
                    "Invalid model ID '{id}': expected format 'provider{sep}model'",
                    sep = self.separator
                ))
            })
    }

    /// Registered provider names, sorted so that error messages are
    /// deterministic (`HashMap` iteration order is unspecified).
    fn provider_names(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.providers.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    /// Splits `id` and looks up the named provider.
    ///
    /// Returns `(provider, provider_id, model_id)`. Shared by all the
    /// `*_model` lookups so they report malformed IDs and unknown providers
    /// identically.
    fn resolve<'i>(&self, id: &'i str) -> Result<(&dyn Provider, &'i str, &'i str)> {
        let (provider_id, model_id) = self.split_id(id)?;
        let provider = self.providers.get(provider_id).ok_or_else(|| {
            ProviderError::NotSupported(format!(
                "No provider registered with name '{provider_id}'. Available: {:?}",
                self.provider_names()
            ))
        })?;
        Ok((&**provider, provider_id, model_id))
    }

    /// Resolves `id` (e.g. `"openai:gpt-4o"`) to a language model.
    ///
    /// # Errors
    ///
    /// - [`ProviderError::InvalidResponse`] if `id` is not `provider<sep>model`
    ///   with both parts non-empty.
    /// - [`ProviderError::NotSupported`] if no provider has that name, or the
    ///   provider returns no language model for the model part.
    pub fn language_model(&self, id: &str) -> Result<Box<dyn LanguageModel>> {
        let (provider, provider_id, model_id) = self.resolve(id)?;
        provider.language_model(model_id).ok_or_else(|| {
            ProviderError::NotSupported(format!(
                "Provider '{provider_id}' does not support language model '{model_id}'"
            ))
        })
    }

    /// Resolves `id` (e.g. `"openai:text-embedding-3-small"`) to an embedding model.
    ///
    /// # Errors
    ///
    /// - [`ProviderError::InvalidResponse`] if `id` is not `provider<sep>model`
    ///   with both parts non-empty.
    /// - [`ProviderError::NotSupported`] if no provider has that name, or the
    ///   provider returns no embedding model for the model part.
    pub fn embedding_model(&self, id: &str) -> Result<Box<dyn EmbeddingModel>> {
        let (provider, provider_id, model_id) = self.resolve(id)?;
        provider.embedding_model(model_id).ok_or_else(|| {
            ProviderError::NotSupported(format!(
                "Provider '{provider_id}' does not support embedding model '{model_id}'"
            ))
        })
    }

    /// Resolves `id` (e.g. `"openai:dall-e-3"`) to an image model.
    ///
    /// # Errors
    ///
    /// - [`ProviderError::InvalidResponse`] if `id` is not `provider<sep>model`
    ///   with both parts non-empty.
    /// - [`ProviderError::NotSupported`] if no provider has that name, or the
    ///   provider returns no image model for the model part.
    pub fn image_model(&self, id: &str) -> Result<Box<dyn ImageModel>> {
        let (provider, provider_id, model_id) = self.resolve(id)?;
        provider.image_model(model_id).ok_or_else(|| {
            ProviderError::NotSupported(format!(
                "Provider '{provider_id}' does not support image model '{model_id}'"
            ))
        })
    }
}
impl Default for ProviderRegistry {
fn default() -> Self {
Self::new()
}
}