engram/embedding/
provider.rs1use std::collections::HashMap;
13use std::sync::Arc;
14
15use crate::error::{EngramError, Result};
16
17use super::Embedder;
18
/// Descriptive metadata reported by an embedding provider.
///
/// Values of this type are returned by `EmbeddingProvider::provider_info` and
/// are what `EmbeddingRegistry::list` hands back to callers.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EmbeddingProviderInfo {
    /// Identifier of the provider; the registry uses it as the lookup key.
    pub id: String,
    /// Name of the provider (presumably a human-readable label — confirm with callers).
    pub name: String,
    /// Name of the model behind the provider.
    pub model: String,
    /// Dimension count reported for the provider (presumably the length of the
    /// vectors it produces — confirm against the `Embedder` implementation).
    pub dimensions: usize,
    /// Whether the provider needs an API key (semantics are defined by the
    /// consumers of this flag, which are not visible here).
    pub requires_api_key: bool,
    /// Whether the provider is local (presumably as opposed to a remote
    /// service — confirm with callers).
    pub is_local: bool,
}
40
41pub trait EmbeddingProvider: Embedder {
48 fn provider_info(&self) -> EmbeddingProviderInfo;
50}
51
52pub struct EmbeddingRegistry {
61 providers: HashMap<String, Arc<dyn EmbeddingProvider>>,
62 order: Vec<String>,
64 default_id: Option<String>,
66}
67
68impl EmbeddingRegistry {
69 pub fn new() -> Self {
71 Self {
72 providers: HashMap::new(),
73 order: Vec::new(),
74 default_id: None,
75 }
76 }
77
78 pub fn register(&mut self, provider: Arc<dyn EmbeddingProvider>) {
84 let id = provider.provider_info().id.clone();
85 if !self.providers.contains_key(&id) {
86 self.order.push(id.clone());
87 }
88 self.providers.insert(id, provider);
89 }
90
91 pub fn get(&self, id: &str) -> Option<Arc<dyn EmbeddingProvider>> {
95 self.providers.get(id).cloned()
96 }
97
98 pub fn list(&self) -> Vec<EmbeddingProviderInfo> {
100 self.order
101 .iter()
102 .filter_map(|id| self.providers.get(id))
103 .map(|p| p.provider_info())
104 .collect()
105 }
106
107 pub fn default_provider(&self) -> Option<Arc<dyn EmbeddingProvider>> {
113 if let Some(ref id) = self.default_id {
114 if let Some(p) = self.providers.get(id.as_str()) {
116 return Some(p.clone());
117 }
118 }
119 self.order
121 .first()
122 .and_then(|id| self.providers.get(id.as_str()))
123 .cloned()
124 }
125
126 pub fn set_default(&mut self, id: &str) -> Result<()> {
130 if self.providers.contains_key(id) {
131 self.default_id = Some(id.to_string());
132 Ok(())
133 } else {
134 Err(EngramError::InvalidInput(format!(
135 "No embedding provider registered with id '{id}'"
136 )))
137 }
138 }
139
140 pub fn count(&self) -> usize {
142 self.providers.len()
143 }
144}
145
146impl Default for EmbeddingRegistry {
147 fn default() -> Self {
148 Self::new()
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result;

    /// Test double that reports a fixed `EmbeddingProviderInfo` and embeds
    /// every text as a zero vector of `info.dimensions` elements.
    struct MockProvider {
        info: EmbeddingProviderInfo,
    }

    impl MockProvider {
        /// Builds a mock whose name and model are derived from `id`.
        fn new(id: &str, dimensions: usize) -> Self {
            let info = EmbeddingProviderInfo {
                id: id.to_string(),
                name: format!("Mock ({id})"),
                model: format!("mock-{id}"),
                dimensions,
                requires_api_key: false,
                is_local: true,
            };
            Self { info }
        }
    }

    impl Embedder for MockProvider {
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            let zeros = vec![0.0_f32; self.info.dimensions];
            Ok(zeros)
        }

        fn dimensions(&self) -> usize {
            self.info.dimensions
        }

        fn model_name(&self) -> &str {
            self.info.model.as_str()
        }
    }

    impl EmbeddingProvider for MockProvider {
        fn provider_info(&self) -> EmbeddingProviderInfo {
            self.info.clone()
        }
    }

    /// Shorthand for a 64-dimensional mock provider behind a trait object.
    fn make_provider(id: &str) -> Arc<dyn EmbeddingProvider> {
        let mock = MockProvider::new(id, 64);
        Arc::new(mock)
    }

    #[test]
    fn test_register_and_get_by_id() {
        let mut reg = EmbeddingRegistry::new();
        reg.register(make_provider("alpha"));

        let found = reg.get("alpha");
        assert!(found.is_some(), "registered provider should be found");
        assert_eq!(found.unwrap().provider_info().id, "alpha");
    }

    #[test]
    fn test_get_unknown_returns_none() {
        let reg = EmbeddingRegistry::new();
        let missing = reg.get("nonexistent");
        assert!(missing.is_none());
    }

    #[test]
    fn test_list_returns_all_providers() {
        let expected = ["alpha", "beta", "gamma"];

        let mut reg = EmbeddingRegistry::new();
        for id in expected {
            reg.register(make_provider(id));
        }

        let infos = reg.list();
        assert_eq!(infos.len(), 3);

        let ids: Vec<&str> = infos.iter().map(|info| info.id.as_str()).collect();
        for id in expected {
            assert!(ids.contains(&id));
        }
    }

    #[test]
    fn test_list_preserves_insertion_order() {
        let mut reg = EmbeddingRegistry::new();
        for id in ["first", "second", "third"] {
            reg.register(make_provider(id));
        }

        let ids: Vec<String> = reg.list().into_iter().map(|info| info.id).collect();
        assert_eq!(ids, vec!["first", "second", "third"]);
    }

    #[test]
    fn test_default_returns_first_registered() {
        let mut reg = EmbeddingRegistry::new();
        assert!(
            reg.default_provider().is_none(),
            "empty registry has no default"
        );

        reg.register(make_provider("first"));
        reg.register(make_provider("second"));

        let chosen = reg.default_provider().expect("should have a default");
        assert_eq!(chosen.provider_info().id, "first");
    }

    #[test]
    fn test_set_default_changes_default() {
        let mut reg = EmbeddingRegistry::new();
        for id in ["alpha", "beta"] {
            reg.register(make_provider(id));
        }

        reg.set_default("beta").expect("beta is registered");

        let chosen = reg.default_provider().expect("should have a default");
        assert_eq!(chosen.provider_info().id, "beta");
    }

    #[test]
    fn test_set_default_unknown_returns_error() {
        let mut reg = EmbeddingRegistry::new();
        let outcome = reg.set_default("does-not-exist");
        assert!(outcome.is_err(), "unknown id should return an error");
    }

    #[test]
    fn test_count() {
        let mut reg = EmbeddingRegistry::new();
        assert_eq!(reg.count(), 0);

        reg.register(make_provider("a"));
        assert_eq!(reg.count(), 1);

        reg.register(make_provider("b"));
        assert_eq!(reg.count(), 2);
    }

    #[test]
    fn test_register_replaces_existing_id() {
        let mut reg = EmbeddingRegistry::new();
        reg.register(make_provider("a"));

        // Same id, different metadata: the second registration must win.
        let replacement = EmbeddingProviderInfo {
            id: "a".to_string(),
            name: "Updated A".to_string(),
            model: "updated-model".to_string(),
            dimensions: 128,
            requires_api_key: true,
            is_local: false,
        };
        reg.register(Arc::new(MockProvider { info: replacement }));

        assert_eq!(reg.count(), 1);

        let stored = reg.get("a").unwrap().provider_info();
        assert_eq!(stored.name, "Updated A");
        assert_eq!(stored.dimensions, 128);
    }

    #[test]
    fn test_provider_info_fields() {
        let sample = EmbeddingProviderInfo {
            id: "test".to_string(),
            name: "Test Provider".to_string(),
            model: "test-model-v1".to_string(),
            dimensions: 256,
            requires_api_key: true,
            is_local: false,
        };

        assert_eq!(sample.id, "test");
        assert_eq!(sample.dimensions, 256);
        assert!(sample.requires_api_key);
        assert!(!sample.is_local);
    }

    #[test]
    fn test_embed_via_registry_provider() {
        let mut reg = EmbeddingRegistry::new();
        reg.register(make_provider("mock"));

        let mock = reg.get("mock").expect("mock is registered");
        let vector = mock.embed("hello world").expect("embed should succeed");
        assert_eq!(vector.len(), 64);
    }
}