use crate::config::Config;
use anyhow::{Context, Result};
use std::collections::HashMap;
use std::sync::RwLock;
#[derive(thiserror::Error, Debug)]
#[error("Registry lock error")]
pub struct LockError;
/// Acquires a read guard on `$lock`, turning a poisoned lock into `LockError`.
///
/// `$field` is used only to name the lock in the error log.
macro_rules! read_lock {
    ($lock:expr, $field:ident) => {
        $lock.read().map_err(|_poisoned| {
            let lock_name = stringify!($field);
            tracing::error!("{} lock is poisoned", lock_name);
            LockError
        })
    };
}
/// Acquires a write guard on `$lock`, turning a poisoned lock into `LockError`.
///
/// `$field` is used only to name the lock in the error log.
macro_rules! write_lock {
    ($lock:expr, $field:ident) => {
        $lock.write().map_err(|_poisoned| {
            let lock_name = stringify!($field);
            tracing::error!("{} lock is poisoned", lock_name);
            LockError
        })
    };
}
/// Factory for one AI provider backend, as stored in `ProviderRegistry`.
pub trait ProviderBuilder: Send + Sync {
    /// Primary name of the provider; the registry keys its maps by this value.
    fn name(&self) -> &'static str;

    /// Extra names that resolve to this provider. Defaults to an empty list.
    fn aliases(&self) -> Vec<&'static str> {
        Vec::new()
    }

    /// Grouping for this provider. Defaults to `ProviderCategory::Standard`.
    fn category(&self) -> ProviderCategory {
        ProviderCategory::Standard
    }

    /// Builds a provider instance from `config`.
    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>>;

    /// Whether the provider reports needing an API key. Defaults to `true`.
    fn requires_api_key(&self) -> bool {
        true
    }

    /// Optional default model identifier. Defaults to `None`.
    fn default_model(&self) -> Option<&'static str> {
        None
    }
}
/// Coarse grouping of providers, queried via `ProviderRegistry::by_category`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[allow(dead_code)]
pub enum ProviderCategory {
    /// Default category for builders that do not override `category()`.
    Standard,
    /// Providers grouped as OpenAI-compatible.
    OpenAICompatible,
    /// Providers grouped as local.
    Local,
    /// Providers grouped as cloud.
    Cloud,
}
/// Snapshot of a provider's metadata, as reported by its `ProviderBuilder`.
///
/// Derives `Debug` (every field type supports it) so entries can be logged and
/// compared in assertions.
#[derive(Clone, Debug)]
pub struct ProviderEntry {
    /// Primary name the provider is registered under.
    pub name: &'static str,
    /// Alternative names that also resolve to this provider.
    pub aliases: Vec<&'static str>,
    /// Grouping used by `ProviderRegistry::by_category`.
    pub category: ProviderCategory,
    /// Whether the builder reports that it needs an API key.
    #[allow(dead_code)]
    pub requires_api_key: bool,
    /// Optional default model identifier reported by the builder.
    #[allow(dead_code)]
    pub default_model: Option<&'static str>,
}
impl ProviderEntry {
    /// Captures the metadata `builder` reports (name, aliases, category,
    /// API-key requirement and default model) into an owned entry.
    pub fn from_builder(builder: &dyn ProviderBuilder) -> Self {
        Self {
            name: builder.name(),
            aliases: builder.aliases(),
            category: builder.category(),
            requires_api_key: builder.requires_api_key(),
            default_model: builder.default_model(),
        }
    }

    /// Returns `true` if `provider` names this entry, either by its primary
    /// name or by one of its aliases.
    ///
    /// Both sides are normalised with `str::to_lowercase`, which is the same
    /// normalisation `ProviderRegistry` applies to queries. Matching is
    /// therefore case-insensitive for non-ASCII letters too. Previously only the
    /// query was lowercased and then compared ASCII-insensitively, so an alias
    /// such as "Émile" could not match even its own exact spelling.
    #[allow(dead_code)]
    pub fn matches(&self, provider: &str) -> bool {
        let wanted = provider.to_lowercase();
        std::iter::once(self.name)
            .chain(self.aliases.iter().copied())
            // Exact check first avoids allocating for already-lowercase names.
            .any(|candidate| candidate == wanted || candidate.to_lowercase() == wanted)
    }
}
/// Registry of AI provider builders and their metadata.
///
/// Each map sits behind its own `RwLock`, so all operations take `&self`.
pub struct ProviderRegistry {
    /// Metadata per provider, keyed by the builder's primary name.
    entries: RwLock<HashMap<&'static str, ProviderEntry>>,
    /// Builder per provider, keyed by the builder's primary name.
    builders: RwLock<HashMap<&'static str, Box<dyn ProviderBuilder>>>,
    /// Alias -> primary name it resolves to.
    by_alias: RwLock<HashMap<&'static str, &'static str>>,
}
impl ProviderRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            entries: RwLock::new(HashMap::new()),
            builders: RwLock::new(HashMap::new()),
            by_alias: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `builder` under its name and every alias it reports.
    ///
    /// Re-registering an existing name replaces the previous entry and builder.
    /// It also drops the aliases added by the previous registration, so stale
    /// aliases cannot keep resolving to the replacement.
    ///
    /// All three maps are updated while their write locks are held together.
    /// A concurrent reader therefore never sees an entry without its builder,
    /// or the reverse.
    ///
    /// # Errors
    ///
    /// Returns an error if any internal lock is poisoned. In that case nothing
    /// is registered.
    pub fn register(&self, builder: Box<dyn ProviderBuilder>) -> Result<()> {
        let name = builder.name();
        let entry = ProviderEntry::from_builder(&*builder);

        // Fixed acquisition order. `builders` comes first because `create`
        // read-holds it while a builder runs. Waiting for it before taking
        // `entries` keeps `get` unblocked meanwhile. `by_alias` comes last,
        // matching every reader that holds two locks, so no lock cycle forms.
        let mut builders = write_lock!(self.builders, builders)?;
        let mut entries = write_lock!(self.entries, entries)?;
        let mut by_alias = write_lock!(self.by_alias, by_alias)?;

        by_alias.retain(|_, primary| *primary != name);
        by_alias.extend(entry.aliases.iter().map(|&alias| (alias, name)));
        entries.insert(name, entry);
        builders.insert(name, builder);
        Ok(())
    }

    /// Looks up a provider entry by name or alias.
    ///
    /// The query is normalised with `str::to_lowercase` and compared against
    /// equally-normalised keys, so names registered with uppercase letters are
    /// found too. Primary names take precedence over aliases.
    ///
    /// Returns `None` when nothing matches or a needed lock is poisoned. The
    /// poisoning is logged by `read_lock!`.
    #[allow(dead_code)]
    pub fn get(&self, provider: &str) -> Option<ProviderEntry> {
        let lower = provider.to_lowercase();
        let entries = read_lock!(self.entries, entries).ok()?;
        if let Some((_, entry)) = Self::lookup_key(&*entries, &lower) {
            return Some(entry.clone());
        }
        let by_alias = read_lock!(self.by_alias, by_alias).ok()?;
        let (_, &primary) = Self::lookup_key(&*by_alias, &lower)?;
        entries.get(primary).cloned()
    }

    /// Returns clones of all registered entries in unspecified order.
    ///
    /// Returns `None` if the entries lock is poisoned.
    pub fn all(&self) -> Option<Vec<ProviderEntry>> {
        let entries = read_lock!(self.entries, entries).ok()?;
        Some(entries.values().cloned().collect())
    }

    /// Returns clones of the entries whose category equals `category`.
    ///
    /// Returns `None` if the entries lock is poisoned.
    pub fn by_category(&self, category: ProviderCategory) -> Option<Vec<ProviderEntry>> {
        let entries = read_lock!(self.entries, entries).ok()?;
        Some(
            entries
                .values()
                .filter(|e| e.category == category)
                .cloned()
                .collect(),
        )
    }

    /// Returns `true` if no providers are registered, or if the lock is poisoned.
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        match read_lock!(self.entries, entries) {
            Ok(entries) => entries.is_empty(),
            Err(_) => true,
        }
    }

    /// Number of registered providers, or `0` if the lock is poisoned.
    #[allow(dead_code)]
    pub fn len(&self) -> usize {
        match read_lock!(self.entries, entries) {
            Ok(entries) => entries.len(),
            Err(_) => 0,
        }
    }

    /// Instantiates the provider registered under `name` or one of its aliases.
    ///
    /// Resolution is the same case-insensitive lookup used by `get`. The alias
    /// lock is only taken when the name is not a primary key, and it is
    /// released before the builder runs. The builders lock stays read-held for
    /// the duration of `ProviderBuilder::create`, because the builder is
    /// borrowed from the map.
    ///
    /// # Errors
    ///
    /// Fails if a needed lock is poisoned or if the builder itself fails.
    /// Returns `Ok(None)` when no provider matches.
    pub fn create(
        &self,
        name: &str,
        config: &Config,
    ) -> Result<Option<Box<dyn super::AIProvider>>> {
        let lower = name.to_lowercase();
        let builders = read_lock!(self.builders, builders).context("Failed to read builders")?;
        let builder = match Self::lookup_key(&*builders, &lower) {
            Some((_, builder)) => builder,
            None => {
                let by_alias =
                    read_lock!(self.by_alias, by_alias).context("Failed to read aliases")?;
                let primary = match Self::lookup_key(&*by_alias, &lower) {
                    Some((_, &primary)) => primary,
                    None => return Ok(None),
                };
                match builders.get(primary) {
                    Some(builder) => builder,
                    None => return Ok(None),
                }
            }
        };
        Ok(Some(builder.create(config)?))
    }

    /// Finds the `(key, value)` pair of `map` whose key, lowercased, equals `lower`.
    ///
    /// `lower` must already be the `str::to_lowercase` form of the query. The
    /// exact probe covers keys registered in lowercase. The linear fallback
    /// covers keys registered with uppercase letters, which the probe alone
    /// could never hit. If several keys differ only by case, which one the
    /// fallback picks is unspecified.
    fn lookup_key<'m, V>(
        map: &'m HashMap<&'static str, V>,
        lower: &str,
    ) -> Option<(&'static str, &'m V)> {
        map.get_key_value(lower)
            .or_else(|| map.iter().find(|(key, _)| key.to_lowercase() == lower))
            .map(|(&key, value)| (key, value))
    }
}
impl Default for ProviderRegistry {
fn default() -> Self {
Self::new()
}
}