use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use super::model::{ModelInfo, ProviderInfo};
use super::provider::LlmProvider;
use crate::session::SessionCapabilitiesConfig;
/// One provider handle bundled with the models it declares and the session
/// capability settings configured alongside it.
#[derive(Clone)]
pub struct ProviderEntry {
/// Shared handle to the provider implementation.
provider: Arc<dyn LlmProvider>,
/// Models declared for this provider, in declaration order. Ids may repeat;
/// `ProviderRegistry::new` indexes only the first occurrence of each id.
models: Vec<ModelInfo>,
/// Session capability settings stored with this entry.
capabilities: SessionCapabilitiesConfig,
}
impl std::fmt::Debug for ProviderEntry {
    /// Renders the entry as a struct. The provider is shown through the value
    /// returned by its `info()` method rather than through the trait object itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ProviderEntry");
        builder.field("provider", &self.provider.info());
        builder.field("models", &self.models);
        builder.field("capabilities", &self.capabilities);
        builder.finish()
    }
}
impl ProviderEntry {
    /// Bundles a provider handle with its declared models and capability settings.
    ///
    /// No validation happens here. Repeated model ids are accepted as-is and are
    /// dealt with later by `ProviderRegistry::new`.
    #[must_use]
    pub fn new(
        provider: Arc<dyn LlmProvider>,
        models: Vec<ModelInfo>,
        capabilities: SessionCapabilitiesConfig,
    ) -> Self {
        Self { capabilities, models, provider }
    }

    /// Shared handle to the underlying provider.
    #[must_use]
    pub fn provider(&self) -> &Arc<dyn LlmProvider> {
        let handle = &self.provider;
        handle
    }

    /// Models declared for this entry, in declaration order.
    #[must_use]
    pub fn models(&self) -> &[ModelInfo] {
        self.models.as_slice()
    }

    /// Capability settings for this entry, returned by value.
    #[must_use]
    pub fn capabilities(&self) -> SessionCapabilitiesConfig {
        let caps = self.capabilities;
        caps
    }
}
/// Reasons `ProviderRegistry::new` can reject its input.
#[derive(Debug, thiserror::Error)]
pub enum ProviderRegistryError {
/// The `entries` vector passed to `ProviderRegistry::new` was empty.
#[error("provider registry requires at least one entry")]
Empty,
/// Two entries whose providers report the same vendor both declare the same model id.
/// `provider` is that vendor and `model` is the repeated id.
/// NOTE(review): a model id repeated inside a *single* entry is skipped by
/// `ProviderRegistry::new` rather than reported through this variant.
#[error(
"duplicate model id `{model}` declared twice by provider `{provider}`; \
the same (provider, model) pair must be unique within a build"
)]
DuplicateSelection { provider: String, model: String },
/// No entry whose vendor equals the requested default provider declares the
/// requested default model id. The fields echo the requested pair.
#[error(
"default model `{model}` is not declared by provider `{provider}`; \
add it under that provider, or point `default.provider` at the one that has it"
)]
UnknownDefaultModel { provider: String, model: String },
}
/// Validated collection of provider entries, with a `(vendor, model id)` lookup
/// index and a default selection. Instances are built by `ProviderRegistry::new`
/// or `ProviderRegistry::single`.
#[derive(Debug)]
pub struct ProviderRegistry {
/// Entries in the order they were passed to `new`.
entries: Vec<ProviderEntry>,
/// Maps `(vendor reported by the provider's info(), model id)` to the index of
/// the owning entry in `entries`.
model_index: HashMap<(String, String), usize>,
/// `(entry index, model index within that entry's models)` of the default model.
default: (usize, usize),
}
impl ProviderRegistry {
    /// Builds a registry holding exactly one provider that serves exactly one model.
    /// That model also becomes the default selection.
    ///
    /// The entry is created with `SessionCapabilitiesConfig::default()`.
    ///
    /// # Panics
    ///
    /// Panics if [`ProviderRegistry::new`] rejects the single entry. With one entry
    /// and one model, that can only happen if `provider.info()` reports a different
    /// vendor on the call made here than on the call made inside `new`.
    #[must_use]
    pub fn single(provider: Arc<dyn LlmProvider>, default_model: ModelInfo) -> Arc<Self> {
        let vendor = provider.info().vendor;
        let model_id = default_model.id.clone();
        let entries = vec![ProviderEntry::new(
            provider,
            vec![default_model],
            SessionCapabilitiesConfig::default(),
        )];
        Arc::new(
            Self::new(entries, &vendor, &model_id)
                .expect("single-entry registry with matching default model is always valid"),
        )
    }

    /// Validates `entries`, builds the `(vendor, model id)` lookup index and
    /// resolves the default selection.
    ///
    /// Each entry is keyed by the `vendor` field of its provider's `info()`. A model
    /// id repeated inside a single entry is tolerated, and only its first occurrence
    /// is indexed.
    ///
    /// # Errors
    ///
    /// * [`ProviderRegistryError::Empty`] if `entries` is empty.
    /// * [`ProviderRegistryError::DuplicateSelection`] if two entries reporting the
    ///   same vendor both declare the same model id.
    /// * [`ProviderRegistryError::UnknownDefaultModel`] if no entry whose vendor
    ///   equals `default_provider` declares `default_model`.
    pub fn new(
        entries: Vec<ProviderEntry>,
        default_provider: &str,
        default_model: &str,
    ) -> Result<Self, ProviderRegistryError> {
        if entries.is_empty() {
            return Err(ProviderRegistryError::Empty);
        }
        // Every non-repeated model becomes one index entry, so size the map up front.
        let total_models: usize = entries.iter().map(|entry| entry.models.len()).sum();
        let mut model_index = HashMap::with_capacity(total_models);
        let mut default_pos = None;
        for (entry_idx, entry) in entries.iter().enumerate() {
            let provider_vendor = entry.provider.info().vendor;
            // Borrow ids from `entry.models` instead of cloning them. The set is only
            // needed for the duration of this entry's iteration.
            let mut seen_in_entry: HashSet<&str> = HashSet::with_capacity(entry.models.len());
            for (model_idx, model) in entry.models.iter().enumerate() {
                // A repeat within the same entry is the same selection; keep the first.
                if !seen_in_entry.insert(model.id.as_str()) {
                    continue;
                }
                let key = (provider_vendor.clone(), model.id.clone());
                if model_index.insert(key, entry_idx).is_some() {
                    return Err(ProviderRegistryError::DuplicateSelection {
                        provider: provider_vendor,
                        model: model.id.clone(),
                    });
                }
                if default_pos.is_none()
                    && provider_vendor == default_provider
                    && model.id == default_model
                {
                    default_pos = Some((entry_idx, model_idx));
                }
            }
        }
        let default = default_pos.ok_or_else(|| ProviderRegistryError::UnknownDefaultModel {
            provider: default_provider.to_string(),
            model: default_model.to_string(),
        })?;
        Ok(Self {
            entries,
            model_index,
            default,
        })
    }

    /// Returns the entry that declares the default model resolved in `new`.
    ///
    /// # Panics
    ///
    /// Panics only if the stored default position no longer points into `entries`.
    /// `new` records that position while iterating the same `entries` it stores.
    #[must_use]
    pub fn default_entry(&self) -> &ProviderEntry {
        let (entry_idx, _) = self.default;
        self.entries
            .get(entry_idx)
            .expect("default index validated in `new`")
    }

    /// Returns the id of the default model resolved in `new`.
    ///
    /// # Panics
    ///
    /// Same invariant as [`ProviderRegistry::default_entry`], extended to the model
    /// position within that entry.
    #[must_use]
    pub fn default_model(&self) -> &str {
        let (_, model_idx) = self.default;
        self.default_entry()
            .models
            .get(model_idx)
            .map(|m| m.id.as_str())
            .expect("default model index validated in `new`")
    }

    /// Looks up the entry registered for the exact `(vendor, model_id)` pair.
    ///
    /// Returns `None` when no entry with that vendor declares that model.
    #[must_use]
    pub fn entry_for(&self, vendor: &str, model_id: &str) -> Option<&ProviderEntry> {
        // The index is keyed by owned `(String, String)`, so the lookup key is owned too.
        self.model_index
            .get(&(vendor.to_string(), model_id.to_string()))
            .and_then(|&idx| self.entries.get(idx))
    }

    /// Returns the first entry, in registration order, that declares `model_id`,
    /// whatever its vendor.
    #[must_use]
    pub fn first_entry_for_model(&self, model_id: &str) -> Option<&ProviderEntry> {
        self.entries
            .iter()
            .find(|entry| entry.models.iter().any(|m| m.id == model_id))
    }

    /// All entries, in registration order.
    #[must_use]
    pub fn entries(&self) -> &[ProviderEntry] {
        &self.entries
    }

    /// Lists one candidate per declared model, entry by entry in registration order.
    ///
    /// NOTE(review): a model id declared more than once inside a single entry yields
    /// one candidate per declaration, even though `new` indexes it only once.
    #[must_use]
    pub fn list_candidates(&self) -> Vec<ModelCandidate> {
        let total: usize = self.entries.iter().map(|entry| entry.models.len()).sum();
        let mut out = Vec::with_capacity(total);
        for entry in &self.entries {
            // Query the provider once per entry, then share the info across its models.
            let info = entry.provider.info();
            out.extend(entry.models.iter().map(|model| ModelCandidate {
                provider: info.clone(),
                model: model.clone(),
            }));
        }
        out
    }

    /// Builds the candidate for the exact `(vendor, model_id)` pair.
    ///
    /// Returns `None` if the pair is not registered.
    #[must_use]
    pub fn candidate_for(&self, vendor: &str, model_id: &str) -> Option<ModelCandidate> {
        let entry = self.entry_for(vendor, model_id)?;
        let model = entry.models.iter().find(|m| m.id == model_id)?.clone();
        Some(ModelCandidate {
            provider: entry.provider.info(),
            model,
        })
    }
}
/// A (provider, model) pair, as produced by `ProviderRegistry::list_candidates`
/// and `ProviderRegistry::candidate_for`.
#[derive(Debug, Clone)]
pub struct ModelCandidate {
/// Provider metadata, taken from the provider's `info()` when the candidate was built.
pub provider: ProviderInfo,
/// Clone of the model declaration stored in the owning provider entry.
pub model: ModelInfo,
}
#[cfg(test)]
mod tests;