use std::collections::HashMap;
use super::{LlmConfig, LlmError, RigLlmProvider};
/// The role an LLM provider is registered under in an [`LlmRegistry`].
///
/// Each variant renders (via `Display` / `AsRef<str>`) and parses (via `FromStr`)
/// as its lowercase name, e.g. `LlmRole::Extraction` <-> `"extraction"`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
#[derive(strum::Display, strum::EnumString, strum::AsRefStr)]
#[strum(serialize_all = "lowercase")]
pub enum LlmRole {
    /// Serialized as `"extraction"`.
    Extraction,
    /// Serialized as `"contradiction"`.
    Contradiction,
    /// Serialized as `"relational"`.
    Relational,
}
/// A lookup table from [`LlmRole`] to the provider configured for that role.
///
/// Each role maps to at most one provider; registering a role again replaces
/// the provider previously stored for it.
#[derive(Debug, Clone, Default)]
pub struct LlmRegistry {
    providers: HashMap<LlmRole, RigLlmProvider>,
}

impl LlmRegistry {
    /// Creates a registry with no providers registered.
    #[must_use]
    pub fn new() -> Self {
        Self {
            providers: HashMap::new(),
        }
    }

    /// Returns the provider registered for `role`, or `None` if there is none.
    #[must_use]
    pub fn get(&self, role: LlmRole) -> Option<&RigLlmProvider> {
        self.providers.get(&role)
    }

    /// Returns the provider for `primary`, falling back to the provider for
    /// `fallback` only when nothing is registered for `primary`.
    ///
    /// Returns `None` when neither role has a provider.
    #[must_use]
    pub fn get_with_fallback(
        &self,
        primary: LlmRole,
        fallback: LlmRole,
    ) -> Option<&RigLlmProvider> {
        // The closure defers the second lookup until the first one misses.
        self.get(primary).or_else(|| self.get(fallback))
    }

    /// Registers `provider` under `role`, silently replacing any provider
    /// that was already stored for that role.
    pub fn insert(&mut self, role: LlmRole, provider: RigLlmProvider) {
        self.providers.insert(role, provider);
    }

    /// Builds a provider from `config` and registers it under `role`.
    ///
    /// On success, emits an INFO-level `tracing` event named
    /// `memoir.llm.configured` carrying `role` and `provider` fields.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `RigLlmProvider::new` if the provider
    /// cannot be built. In that case the registry is left unchanged.
    pub fn install(&mut self, role: LlmRole, config: LlmConfig) -> Result<(), LlmError> {
        // Read the kind up front: `config` is moved into the constructor below.
        let provider_kind = config.kind();
        let built = RigLlmProvider::new(config)?;
        self.providers.insert(role, built);

        // The doubled braces are escapes: the emitted message text contains the
        // literal `{provider}` / `{role}` placeholders, while the actual values
        // travel as structured fields.
        tracing::event!(
            name: "memoir.llm.configured",
            tracing::Level::INFO,
            role = role.as_ref(),
            provider = provider_kind.as_ref(),
            "configured {{provider}} provider for {{role}}",
        );

        Ok(())
    }

    /// Returns `true` when no role has a provider registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.providers.is_empty()
    }
}