// ai_agents_llm/registry.rs
use std::collections::HashMap;
use std::sync::Arc;

use ai_agents_core::{LLMError, LLMProvider};

6#[derive(Clone)]
7pub struct LLMRegistry {
8 providers: HashMap<String, Arc<dyn LLMProvider>>,
9 default_alias: String,
10 router_alias: Option<String>,
11}
12
13impl std::fmt::Debug for LLMRegistry {
14 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15 f.debug_struct("LLMRegistry")
16 .field("providers", &self.providers.keys().collect::<Vec<_>>())
17 .field("default_alias", &self.default_alias)
18 .field("router_alias", &self.router_alias)
19 .finish()
20 }
21}
22
23impl LLMRegistry {
24 pub fn new() -> Self {
25 Self {
26 providers: HashMap::new(),
27 default_alias: "default".to_string(),
28 router_alias: None,
29 }
30 }
31
32 pub fn register(&mut self, alias: impl Into<String>, provider: Arc<dyn LLMProvider>) {
33 self.providers.insert(alias.into(), provider);
34 }
35
36 pub fn set_default(&mut self, alias: impl Into<String>) {
37 self.default_alias = alias.into();
38 }
39
40 pub fn set_router(&mut self, alias: impl Into<String>) {
41 self.router_alias = Some(alias.into());
42 }
43
44 pub fn get(&self, alias: &str) -> Result<Arc<dyn LLMProvider>, LLMError> {
45 self.providers
46 .get(alias)
47 .cloned()
48 .ok_or_else(|| LLMError::Config(format!("LLM alias not found: {}", alias)))
49 }
50
51 pub fn default(&self) -> Result<Arc<dyn LLMProvider>, LLMError> {
52 self.get(&self.default_alias)
53 }
54
55 pub fn router(&self) -> Result<Arc<dyn LLMProvider>, LLMError> {
56 match &self.router_alias {
57 Some(alias) => self.get(alias),
58 None => self.default(),
59 }
60 }
61
62 pub fn has(&self, alias: &str) -> bool {
63 self.providers.contains_key(alias)
64 }
65
66 pub fn aliases(&self) -> Vec<String> {
67 self.providers.keys().cloned().collect()
68 }
69
70 pub fn len(&self) -> usize {
71 self.providers.len()
72 }
73
74 pub fn is_empty(&self) -> bool {
75 self.providers.is_empty()
76 }
77}
78
79impl Default for LLMRegistry {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use ai_agents_core::{ChatMessage, FinishReason, LLMChunk, LLMConfig, LLMFeature, LLMResponse};
    use async_trait::async_trait;

    /// Provider stub identified only by its name: `complete` echoes the name and
    /// `complete_stream` always fails.
    struct MockProvider {
        name: String,
    }

    /// Builds a shared mock provider, ready to pass to `LLMRegistry::register`.
    fn mock(name: &str) -> Arc<dyn LLMProvider> {
        Arc::new(MockProvider { name: name.into() })
    }

    #[async_trait]
    impl LLMProvider for MockProvider {
        async fn complete(
            &self,
            _messages: &[ChatMessage],
            _config: Option<&LLMConfig>,
        ) -> Result<LLMResponse, LLMError> {
            Ok(LLMResponse::new(
                format!("Response from {}", self.name),
                FinishReason::Stop,
            ))
        }

        async fn complete_stream(
            &self,
            _messages: &[ChatMessage],
            _config: Option<&LLMConfig>,
        ) -> Result<
            Box<dyn futures::Stream<Item = Result<LLMChunk, LLMError>> + Unpin + Send>,
            LLMError,
        > {
            Err(LLMError::Other("Not implemented".into()))
        }

        fn provider_name(&self) -> &str {
            &self.name
        }

        fn supports(&self, _feature: LLMFeature) -> bool {
            false
        }
    }

    #[test]
    fn test_registry_basic() {
        let mut registry = LLMRegistry::new();
        registry.register("default", mock("test"));
        assert!(registry.has("default"));
        assert!(!registry.has("unknown"));
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_registry_new_is_empty() {
        let registry = LLMRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.aliases().is_empty());
        // Nothing is registered under the implicit "default" alias yet.
        assert!(registry.default().is_err());
        assert!(registry.router().is_err());
    }

    #[test]
    fn test_registry_default_and_router() {
        let mut registry = LLMRegistry::new();
        registry.register("main", mock("main"));
        registry.register("router", mock("router"));

        registry.set_default("main");
        registry.set_router("router");

        assert!(registry.default().is_ok());
        assert!(registry.router().is_ok());
        assert_eq!(registry.default().unwrap().provider_name(), "main");
        assert_eq!(registry.router().unwrap().provider_name(), "router");
    }

    #[test]
    fn test_registry_router_fallback() {
        let mut registry = LLMRegistry::new();
        registry.register("default", mock("default"));

        let router = registry.router().unwrap();
        assert_eq!(router.provider_name(), "default");
    }

    #[test]
    fn test_registry_configured_router_does_not_fall_back() {
        let mut registry = LLMRegistry::new();
        registry.register("default", mock("default"));
        registry.set_router("missing");
        // An explicitly configured router alias is honoured even when unregistered.
        assert!(registry.router().is_err());
    }

    #[test]
    fn test_registry_register_replaces_existing_alias() {
        let mut registry = LLMRegistry::new();
        registry.register("main", mock("first"));
        registry.register("main", mock("second"));
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.get("main").unwrap().provider_name(), "second");
    }

    #[test]
    fn test_registry_aliases() {
        let mut registry = LLMRegistry::new();
        registry.register("b", mock("b"));
        registry.register("a", mock("a"));
        let mut aliases = registry.aliases();
        // Sort locally so this test does not depend on the registry's ordering.
        aliases.sort();
        assert_eq!(aliases, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn test_registry_missing_alias() {
        let registry = LLMRegistry::new();
        let result = registry.get("nonexistent");
        assert!(matches!(
            result,
            Err(LLMError::Config(ref msg)) if msg.contains("nonexistent")
        ));
    }
}