use std::collections::HashMap;
use std::sync::Arc;
use lellm_core::LlmError;
use crate::LlmProvider;
/// Tier label for a task; [`ModelRouter`] keys its routing table on this value.
///
/// The type is `Copy`, comparable and hashable so it can be passed by value
/// and used directly as a `HashMap` key.
// NOTE(review): the variant names suggest increasing capability/cost tiers,
// but nothing in this file defines their meaning — confirm with the callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskLevel {
    Flash,
    Standard,
    Pro,
}
/// A routing target: which provider to use and which model to request from it.
#[derive(Debug, Clone)]
pub struct RouteEntry {
    /// Identifier used to look the provider up in [`ProviderRegistry`].
    pub provider_id: String,
    /// Model name; copied unchanged into the resolved model.
    pub model: String,
}
/// A route whose provider id has been replaced by the provider itself.
///
/// Cloning is cheap for the provider part: the derived `Clone` copies the
/// `Arc` handle, so every clone shares the same provider instance.
#[derive(Clone)]
pub struct ResolvedModel {
    /// Shared handle to the provider that serves this model.
    pub provider: Arc<dyn LlmProvider>,
    /// Model name taken from the originating route.
    pub model: String,
}
/// Maps each [`TaskLevel`] to the [`RouteEntry`] that should serve it.
///
/// At most one route is stored per level.
///
/// `Debug` is derived so the router can be logged or embedded in other
/// `Debug` types; both the key and value types already implement it.
#[derive(Debug)]
pub struct ModelRouter {
    /// Routing table, keyed by task level.
    routes: HashMap<TaskLevel, RouteEntry>,
}
impl ModelRouter {
    /// Creates a router with an empty routing table.
    pub fn new() -> Self {
        let routes = HashMap::new();
        Self { routes }
    }

    /// Stores `entry` as the route for `level`.
    ///
    /// If a route was already present for `level` it is replaced and the
    /// previous entry is dropped.
    pub fn add_route(&mut self, level: TaskLevel, entry: RouteEntry) {
        // The displaced entry (if any) is deliberately discarded.
        let _replaced = self.routes.insert(level, entry);
    }

    /// Returns the route configured for `level`, or `None` when no route
    /// has been added for it.
    pub fn resolve(&self, level: TaskLevel) -> Option<&RouteEntry> {
        let table = &self.routes;
        table.get(&level)
    }
}
impl Default for ModelRouter {
fn default() -> Self {
Self::new()
}
}
/// Holds the registered providers, addressable by string id.
pub struct ProviderRegistry {
    /// Provider handles keyed by the id they were registered under.
    providers: HashMap<String, Arc<dyn LlmProvider>>,
}
impl ProviderRegistry {
    /// Creates a registry with no providers.
    pub fn new() -> Self {
        let providers = HashMap::new();
        Self { providers }
    }

    /// Registers `provider` under `id`.
    ///
    /// A provider previously registered under the same id is replaced.
    pub fn register(&mut self, id: &str, provider: Arc<dyn LlmProvider>) {
        let key = String::from(id);
        // The displaced handle (if any) is deliberately discarded.
        let _replaced = self.providers.insert(key, provider);
    }

    /// Returns a new handle to the provider registered under `id`, or `None`
    /// when the id is unknown.
    pub fn get(&self, id: &str) -> Option<Arc<dyn LlmProvider>> {
        self.providers.get(id).map(Arc::clone)
    }

    /// Turns a route into a [`ResolvedModel`] by looking up its provider.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::ApiError`] with `status` 0 and no `code` when
    /// `route.provider_id` has not been registered.
    pub fn resolve(&self, route: &RouteEntry) -> Result<ResolvedModel, LlmError> {
        match self.get(route.provider_id.as_str()) {
            Some(provider) => Ok(ResolvedModel {
                provider,
                model: route.model.clone(),
            }),
            None => Err(LlmError::ApiError {
                provider: route.provider_id.clone(),
                status: 0,
                code: None,
                message: format!("provider not registered: {}", route.provider_id),
            }),
        }
    }
}
impl Default for ProviderRegistry {
fn default() -> Self {
Self::new()
}
}