use std::sync::Arc;
use super::BaseLlm;
/// Boxed constructor that builds a `BaseLlm` for the model name it is given.
///
/// The `Send + Sync` bounds keep `LlmRegistry` itself `Send + Sync`.
type LlmFactory = Box<dyn Fn(&str) -> Arc<dyn BaseLlm> + Send + Sync>;

/// Registry mapping model-name prefix patterns to factories that build `BaseLlm` instances.
///
/// Entries are stored in registration order. `LlmRegistry::resolve` uses the first
/// entry whose pattern is a prefix of the requested model name.
pub struct LlmRegistry {
    /// `(prefix pattern, factory)` pairs, in registration order.
    factories: Vec<(String, LlmFactory)>,
}

// `Box<dyn Fn>` has no `Debug` impl, so `#[derive(Debug)]` is not available.
// Show the registered patterns and mark the factories as elided with `..`.
impl std::fmt::Debug for LlmRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let patterns: Vec<&str> = self
            .factories
            .iter()
            .map(|(pattern, _)| pattern.as_str())
            .collect();
        f.debug_struct("LlmRegistry")
            .field("patterns", &patterns)
            .finish_non_exhaustive()
    }
}
impl LlmRegistry {
    /// Creates a registry with no factories registered.
    pub fn new() -> Self {
        Self { factories: vec![] }
    }

    /// Registers `factory` for model names that start with `pattern`.
    ///
    /// Entries are only ever appended, never replaced. If an earlier entry's pattern
    /// also matches a name, that earlier entry takes precedence in [`Self::resolve`].
    pub fn register(
        &mut self,
        pattern: impl Into<String>,
        factory: impl Fn(&str) -> Arc<dyn BaseLlm> + Send + Sync + 'static,
    ) {
        let prefix: String = pattern.into();
        self.factories.push((prefix, Box::new(factory)));
    }

    /// Builds an LLM for `model_name` using the first registered factory whose
    /// pattern is a prefix of it (a case-sensitive `str::starts_with` check).
    ///
    /// The matching factory receives the full `model_name` and is invoked on every
    /// call, so results are not cached. Returns `None` when no pattern matches.
    pub fn resolve(&self, model_name: &str) -> Option<Arc<dyn BaseLlm>> {
        self.factories
            .iter()
            .find(|(prefix, _)| model_name.starts_with(prefix.as_str()))
            .map(|(_, build)| build(model_name))
    }

    /// Number of registered `(pattern, factory)` entries, counting duplicate patterns.
    pub fn len(&self) -> usize {
        self.factories.len()
    }

    /// Returns `true` when no factory has been registered.
    pub fn is_empty(&self) -> bool {
        self.factories.is_empty()
    }
}
impl Default for LlmRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::{LlmError, LlmRequest, LlmResponse};
    use async_trait::async_trait;

    /// Test double whose `model_id` is whatever string it was built with.
    struct MockLlm {
        model: String,
    }

    #[async_trait]
    impl BaseLlm for MockLlm {
        fn model_id(&self) -> &str {
            &self.model
        }

        async fn generate(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
            Err(LlmError::Other("mock".into()))
        }
    }

    /// Factory producing a `MockLlm` whose id is `tag` followed by the requested
    /// model name, so a test can tell which registered factory handled a lookup.
    fn mock_factory(
        tag: &'static str,
    ) -> impl Fn(&str) -> Arc<dyn BaseLlm> + Send + Sync + 'static {
        move |name: &str| -> Arc<dyn BaseLlm> {
            Arc::new(MockLlm {
                model: format!("{tag}{name}"),
            })
        }
    }

    #[test]
    fn register_and_resolve() {
        let mut registry = LlmRegistry::new();
        registry.register("gemini", mock_factory(""));
        let llm = registry.resolve("gemini-2.5-flash").unwrap();
        assert_eq!(llm.model_id(), "gemini-2.5-flash");
    }

    #[test]
    fn resolve_unknown_returns_none() {
        let registry = LlmRegistry::new();
        assert!(registry.resolve("gpt-4").is_none());
    }

    #[test]
    fn resolve_without_matching_prefix_returns_none() {
        let mut registry = LlmRegistry::new();
        registry.register("gemini", mock_factory(""));
        assert!(registry.resolve("gpt-4").is_none());
        // Patterns are matched as prefixes, not as substrings.
        assert!(registry.resolve("my-gemini").is_none());
    }

    #[test]
    fn first_match_wins() {
        let mut registry = LlmRegistry::new();
        registry.register("gemini-2.5", mock_factory("v2.5:"));
        registry.register("gemini", mock_factory("generic:"));
        let llm = registry.resolve("gemini-2.5-flash").unwrap();
        assert_eq!(llm.model_id(), "v2.5:gemini-2.5-flash");
        let llm2 = registry.resolve("gemini-1.5-pro").unwrap();
        assert_eq!(llm2.model_id(), "generic:gemini-1.5-pro");
    }

    #[test]
    fn earlier_broad_pattern_shadows_later_specific_one() {
        // Registration order, not pattern length, picks the winner. `first_match_wins`
        // alone would also pass under longest-prefix matching.
        let mut registry = LlmRegistry::new();
        registry.register("gemini", mock_factory("generic:"));
        registry.register("gemini-2.5", mock_factory("v2.5:"));
        let llm = registry.resolve("gemini-2.5-flash").unwrap();
        assert_eq!(llm.model_id(), "generic:gemini-2.5-flash");
    }

    #[test]
    fn len_and_is_empty() {
        let mut registry = LlmRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        registry.register("test", mock_factory(""));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }
}