Skip to main content

ai_agents_llm/
registry.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use ai_agents_core::{LLMError, LLMProvider};
5
6#[derive(Clone)]
7pub struct LLMRegistry {
8    providers: HashMap<String, Arc<dyn LLMProvider>>,
9    default_alias: String,
10    router_alias: Option<String>,
11}
12
13impl std::fmt::Debug for LLMRegistry {
14    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15        f.debug_struct("LLMRegistry")
16            .field("providers", &self.providers.keys().collect::<Vec<_>>())
17            .field("default_alias", &self.default_alias)
18            .field("router_alias", &self.router_alias)
19            .finish()
20    }
21}
22
23impl LLMRegistry {
24    pub fn new() -> Self {
25        Self {
26            providers: HashMap::new(),
27            default_alias: "default".to_string(),
28            router_alias: None,
29        }
30    }
31
32    pub fn register(&mut self, alias: impl Into<String>, provider: Arc<dyn LLMProvider>) {
33        self.providers.insert(alias.into(), provider);
34    }
35
36    pub fn set_default(&mut self, alias: impl Into<String>) {
37        self.default_alias = alias.into();
38    }
39
40    pub fn set_router(&mut self, alias: impl Into<String>) {
41        self.router_alias = Some(alias.into());
42    }
43
44    pub fn get(&self, alias: &str) -> Result<Arc<dyn LLMProvider>, LLMError> {
45        self.providers
46            .get(alias)
47            .cloned()
48            .ok_or_else(|| LLMError::Config(format!("LLM alias not found: {}", alias)))
49    }
50
51    pub fn default(&self) -> Result<Arc<dyn LLMProvider>, LLMError> {
52        self.get(&self.default_alias)
53    }
54
55    pub fn router(&self) -> Result<Arc<dyn LLMProvider>, LLMError> {
56        match &self.router_alias {
57            Some(alias) => self.get(alias),
58            None => self.default(),
59        }
60    }
61
62    pub fn has(&self, alias: &str) -> bool {
63        self.providers.contains_key(alias)
64    }
65
66    pub fn aliases(&self) -> Vec<String> {
67        self.providers.keys().cloned().collect()
68    }
69
70    pub fn len(&self) -> usize {
71        self.providers.len()
72    }
73
74    pub fn is_empty(&self) -> bool {
75        self.providers.is_empty()
76    }
77}
78
79impl Default for LLMRegistry {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use ai_agents_core::{ChatMessage, FinishReason, LLMChunk, LLMConfig, LLMFeature, LLMResponse};
    use async_trait::async_trait;

    /// Minimal provider used to tell registered entries apart by name.
    struct MockProvider {
        name: String,
    }

    #[async_trait]
    impl LLMProvider for MockProvider {
        /// Returns a canned response that embeds the provider's name.
        async fn complete(
            &self,
            _messages: &[ChatMessage],
            _config: Option<&LLMConfig>,
        ) -> Result<LLMResponse, LLMError> {
            Ok(LLMResponse::new(
                format!("Response from {}", self.name),
                FinishReason::Stop,
            ))
        }

        /// Streaming is not exercised by these tests; always returns an error.
        async fn complete_stream(
            &self,
            _messages: &[ChatMessage],
            _config: Option<&LLMConfig>,
        ) -> Result<
            Box<dyn futures::Stream<Item = Result<LLMChunk, LLMError>> + Unpin + Send>,
            LLMError,
        > {
            Err(LLMError::Other("Not implemented".into()))
        }

        fn provider_name(&self) -> &str {
            &self.name
        }

        /// The mock reports no optional feature as supported.
        fn supports(&self, _feature: LLMFeature) -> bool {
            false
        }
    }

    /// Registering a provider makes its alias visible and counted.
    #[test]
    fn test_registry_basic() {
        let mut registry = LLMRegistry::new();
        let provider = Arc::new(MockProvider {
            name: "test".into(),
        });

        registry.register("default", provider);
        assert!(registry.has("default"));
        assert!(!registry.has("unknown"));
        assert_eq!(registry.len(), 1);
    }

    /// Explicit default and router aliases resolve to their own providers.
    #[test]
    fn test_registry_default_and_router() {
        let mut registry = LLMRegistry::new();
        registry.register(
            "main",
            Arc::new(MockProvider {
                name: "main".into(),
            }),
        );
        registry.register(
            "router",
            Arc::new(MockProvider {
                name: "router".into(),
            }),
        );

        registry.set_default("main");
        registry.set_router("router");

        assert!(registry.default().is_ok());
        assert!(registry.router().is_ok());
        assert_eq!(registry.default().unwrap().provider_name(), "main");
        assert_eq!(registry.router().unwrap().provider_name(), "router");
    }

    /// Without a router alias, `router()` resolves the default provider.
    #[test]
    fn test_registry_router_fallback() {
        let mut registry = LLMRegistry::new();
        registry.register(
            "default",
            Arc::new(MockProvider {
                name: "default".into(),
            }),
        );

        let router = registry.router().unwrap();
        assert_eq!(router.provider_name(), "default");
    }

    /// Lookups against an empty registry fail instead of panicking.
    #[test]
    fn test_registry_missing_alias() {
        let registry = LLMRegistry::new();
        assert!(registry.is_empty());
        assert!(registry.get("nonexistent").is_err());
        // The initial "default" alias has no provider yet, so both the default
        // lookup and the router fallback must report an error too.
        assert!(registry.default().is_err());
        assert!(registry.router().is_err());
    }
}
185}