1use std::collections::HashMap;
4use std::sync::Arc;
5
6use lellm_core::LlmError;
7
8use crate::LlmProvider;
9
/// Task tier used as the routing key of a `ModelRouter`.
///
/// A router keeps at most one `RouteEntry` per tier. No ordering between the
/// tiers is defined on this type; any "Flash < Standard < Pro" relationship is
/// implied only by the names.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TaskLevel {
    /// Lowest tier by naming convention (presumably fastest/cheapest — TODO confirm).
    Flash,
    /// Middle tier (presumably the default for ordinary tasks — TODO confirm).
    Standard,
    /// Highest tier by naming convention (presumably most capable — TODO confirm).
    Pro,
}
20
/// Routing target for one task tier: which registered provider to use and
/// which model to request from it.
///
/// `provider_id` is the key `ProviderRegistry::resolve` looks up; `model` is
/// copied verbatim into the resulting `ResolvedModel`.
#[derive(Clone, Debug)]
pub struct RouteEntry {
    /// Id the provider was registered under via `ProviderRegistry::register`.
    pub provider_id: String,
    /// Model identifier copied into `ResolvedModel::model` on resolution.
    pub model: String,
}
27
28#[derive(Clone)]
30pub struct ResolvedModel {
31 pub provider: Arc<dyn LlmProvider>,
32 pub model: String,
33}
34
35pub struct ModelRouter {
37 routes: HashMap<TaskLevel, RouteEntry>,
38}
39
40impl ModelRouter {
41 pub fn new() -> Self {
42 Self {
43 routes: HashMap::new(),
44 }
45 }
46
47 pub fn add_route(&mut self, level: TaskLevel, entry: RouteEntry) {
48 self.routes.insert(level, entry);
49 }
50
51 pub fn resolve(&self, level: TaskLevel) -> Option<&RouteEntry> {
53 self.routes.get(&level)
54 }
55}
56
57impl Default for ModelRouter {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63pub struct ProviderRegistry {
65 providers: HashMap<String, Arc<dyn LlmProvider>>,
66}
67
68impl ProviderRegistry {
69 pub fn new() -> Self {
70 Self {
71 providers: HashMap::new(),
72 }
73 }
74
75 pub fn register(&mut self, id: &str, provider: Arc<dyn LlmProvider>) {
76 self.providers.insert(id.to_string(), provider);
77 }
78
79 pub fn get(&self, id: &str) -> Option<Arc<dyn LlmProvider>> {
80 self.providers.get(id).cloned()
81 }
82
83 pub fn resolve(&self, route: &RouteEntry) -> Result<ResolvedModel, LlmError> {
85 let provider = self
86 .get(&route.provider_id)
87 .ok_or_else(|| LlmError::ApiError {
88 provider: route.provider_id.clone(),
89 status: 0,
90 code: None,
91 message: format!("provider not registered: {}", route.provider_id),
92 })?;
93 Ok(ResolvedModel {
94 provider,
95 model: route.model.clone(),
96 })
97 }
98}
99
100impl Default for ProviderRegistry {
101 fn default() -> Self {
102 Self::new()
103 }
104}