// lellm_provider/router.rs

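//! ModelRouter — task-level routing: maps a `TaskLevel` to a provider/model route, which `ProviderRegistry` resolves to a provider instance.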
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use lellm_core::LlmError;
7
8use crate::LlmProvider;
9
/// Task level used to select a route in [`ModelRouter`].
///
/// `Standard` is the default level, and [`Default::default`] returns it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum TaskLevel {
    /// Fast / cheap tasks, e.g. simple Q&A or format conversion.
    Flash,
    /// Default level, e.g. ordinary conversation.
    #[default]
    Standard,
    /// Complex reasoning, e.g. code generation or in-depth analysis.
    Pro,
}
20
/// Route entry: the combination of a provider id and a model name.
///
/// It is resolved into a [`ResolvedModel`] by [`ProviderRegistry::resolve`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RouteEntry {
    /// Id under which the provider was registered in a [`ProviderRegistry`].
    pub provider_id: String,
    /// Model name; copied into `ResolvedModel::model` on resolution.
    pub model: String,
}
27
28/// 解析后的模型 — 从 Registry 中解析 RouteEntry 得到
29#[derive(Clone)]
30pub struct ResolvedModel {
31    pub provider: Arc<dyn LlmProvider>,
32    pub model: String,
33}
34
35/// 模型路由器 — 根据任务级别选择路由。
36pub struct ModelRouter {
37    routes: HashMap<TaskLevel, RouteEntry>,
38}
39
40impl ModelRouter {
41    pub fn new() -> Self {
42        Self {
43            routes: HashMap::new(),
44        }
45    }
46
47    pub fn add_route(&mut self, level: TaskLevel, entry: RouteEntry) {
48        self.routes.insert(level, entry);
49    }
50
51    /// 根据任务级别解析路由
52    pub fn resolve(&self, level: TaskLevel) -> Option<&RouteEntry> {
53        self.routes.get(&level)
54    }
55}
56
57impl Default for ModelRouter {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63/// Provider 注册表 — 持有所有 Provider 实例。
64pub struct ProviderRegistry {
65    providers: HashMap<String, Arc<dyn LlmProvider>>,
66}
67
68impl ProviderRegistry {
69    pub fn new() -> Self {
70        Self {
71            providers: HashMap::new(),
72        }
73    }
74
75    pub fn register(&mut self, id: &str, provider: Arc<dyn LlmProvider>) {
76        self.providers.insert(id.to_string(), provider);
77    }
78
79    pub fn get(&self, id: &str) -> Option<Arc<dyn LlmProvider>> {
80        self.providers.get(id).cloned()
81    }
82
83    /// 从 RouteEntry 解析为 ResolvedModel
84    pub fn resolve(&self, route: &RouteEntry) -> Result<ResolvedModel, LlmError> {
85        let provider = self
86            .get(&route.provider_id)
87            .ok_or_else(|| LlmError::ApiError {
88                provider: route.provider_id.clone(),
89                status: 0,
90                code: None,
91                message: format!("provider not registered: {}", route.provider_id),
92            })?;
93        Ok(ResolvedModel {
94            provider,
95            model: route.model.clone(),
96        })
97    }
98}
99
100impl Default for ProviderRegistry {
101    fn default() -> Self {
102        Self::new()
103    }
104}