Skip to main content

engram/embedding/
provider.rs

1//! EmbeddingProvider trait and registry (RML-1231)
2//!
3//! Extends the [`Embedder`] trait with provider metadata and a runtime registry
4//! for discovering and selecting embedding backends.
5//!
6//! # Overview
7//!
8//! - [`EmbeddingProviderInfo`] — static metadata about a provider (id, model, dimensions, …)
9//! - [`EmbeddingProvider`] — supertrait of [`Embedder`] that exposes provider metadata
10//! - [`EmbeddingRegistry`] — a runtime map of named providers with default selection
11
12use std::collections::HashMap;
13use std::sync::Arc;
14
15use crate::error::{EngramError, Result};
16
17use super::Embedder;
18
19// ── Provider metadata ─────────────────────────────────────────────────────────
20
/// Descriptive metadata for one embedding backend that can be read without
/// performing an embedding call.
///
/// Produced by [`EmbeddingProvider::provider_info`] and surfaced by the
/// [`EmbeddingRegistry`] (see `EmbeddingRegistry::list`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EmbeddingProviderInfo {
    /// Machine-readable key the registry indexes by, e.g. `"tfidf"` or `"openai-3-small"`.
    pub id: String,
    /// Display name for humans, e.g. `"TF-IDF (local)"` or `"OpenAI text-embedding-3-small"`.
    pub name: String,
    /// Identifier of the underlying model, e.g. `"tfidf"` or `"text-embedding-3-small"`.
    pub model: String,
    /// Length of each vector this provider produces.
    pub dimensions: usize,
    /// `true` when the provider cannot operate without an API key.
    pub requires_api_key: bool,
    /// `true` when the provider runs entirely on the local machine (no network calls).
    pub is_local: bool,
}
40
41// ── EmbeddingProvider trait ───────────────────────────────────────────────────
42
43/// An [`Embedder`] that also exposes self-describing metadata.
44///
45/// Implement both [`Embedder`] and this trait to participate in the
46/// [`EmbeddingRegistry`].
47pub trait EmbeddingProvider: Embedder {
48    /// Return static metadata for this provider.
49    fn provider_info(&self) -> EmbeddingProviderInfo;
50}
51
52// ── EmbeddingRegistry ─────────────────────────────────────────────────────────
53
54/// A runtime registry of named [`EmbeddingProvider`] implementations.
55///
56/// Providers are keyed by [`EmbeddingProviderInfo::id`]. An optional default
57/// can be set explicitly via [`EmbeddingRegistry::set_default`]; if none is
58/// set, [`EmbeddingRegistry::default_provider`] returns the first registered
59/// provider.
60pub struct EmbeddingRegistry {
61    providers: HashMap<String, Arc<dyn EmbeddingProvider>>,
62    /// Insertion order — used to determine the first-registered provider.
63    order: Vec<String>,
64    /// Explicit default id, if set.
65    default_id: Option<String>,
66}
67
68impl EmbeddingRegistry {
69    /// Create an empty registry.
70    pub fn new() -> Self {
71        Self {
72            providers: HashMap::new(),
73            order: Vec::new(),
74            default_id: None,
75        }
76    }
77
78    /// Register a provider.
79    ///
80    /// If a provider with the same id already exists it is replaced.
81    /// The insertion order of new ids is preserved for use by
82    /// [`EmbeddingRegistry::default_provider`].
83    pub fn register(&mut self, provider: Arc<dyn EmbeddingProvider>) {
84        let id = provider.provider_info().id.clone();
85        if !self.providers.contains_key(&id) {
86            self.order.push(id.clone());
87        }
88        self.providers.insert(id, provider);
89    }
90
91    /// Look up a provider by id.
92    ///
93    /// Returns `None` if no provider with that id has been registered.
94    pub fn get(&self, id: &str) -> Option<Arc<dyn EmbeddingProvider>> {
95        self.providers.get(id).cloned()
96    }
97
98    /// Return metadata for all registered providers, in registration order.
99    pub fn list(&self) -> Vec<EmbeddingProviderInfo> {
100        self.order
101            .iter()
102            .filter_map(|id| self.providers.get(id))
103            .map(|p| p.provider_info())
104            .collect()
105    }
106
107    /// Return the default provider.
108    ///
109    /// - If a default was set via [`EmbeddingRegistry::set_default`], that provider is returned.
110    /// - Otherwise the first registered provider is returned.
111    /// - Returns `None` if the registry is empty.
112    pub fn default_provider(&self) -> Option<Arc<dyn EmbeddingProvider>> {
113        if let Some(ref id) = self.default_id {
114            // Explicit default may have been de-registered; fall through if so.
115            if let Some(p) = self.providers.get(id.as_str()) {
116                return Some(p.clone());
117            }
118        }
119        // Fallback: first registered.
120        self.order
121            .first()
122            .and_then(|id| self.providers.get(id.as_str()))
123            .cloned()
124    }
125
126    /// Change the default provider.
127    ///
128    /// Returns [`EngramError::InvalidInput`] if `id` is not registered.
129    pub fn set_default(&mut self, id: &str) -> Result<()> {
130        if self.providers.contains_key(id) {
131            self.default_id = Some(id.to_string());
132            Ok(())
133        } else {
134            Err(EngramError::InvalidInput(format!(
135                "No embedding provider registered with id '{id}'"
136            )))
137        }
138    }
139
140    /// Return the number of registered providers.
141    pub fn count(&self) -> usize {
142        self.providers.len()
143    }
144}
145
146impl Default for EmbeddingRegistry {
147    fn default() -> Self {
148        Self::new()
149    }
150}
151
152// ── Tests ─────────────────────────────────────────────────────────────────────
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{EngramError, Result};

    // ── Mock provider ─────────────────────────────────────────────────────────

    /// In-memory provider that embeds every text as an all-zero vector of the
    /// configured length and reports fixed, local-only metadata.
    struct MockProvider {
        info: EmbeddingProviderInfo,
    }

    impl MockProvider {
        fn new(id: &str, dimensions: usize) -> Self {
            Self {
                info: EmbeddingProviderInfo {
                    id: id.to_string(),
                    name: format!("Mock ({id})"),
                    model: format!("mock-{id}"),
                    dimensions,
                    requires_api_key: false,
                    is_local: true,
                },
            }
        }
    }

    impl Embedder for MockProvider {
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Ok(vec![0.0_f32; self.info.dimensions])
        }

        fn dimensions(&self) -> usize {
            self.info.dimensions
        }

        fn model_name(&self) -> &str {
            &self.info.model
        }
    }

    impl EmbeddingProvider for MockProvider {
        fn provider_info(&self) -> EmbeddingProviderInfo {
            self.info.clone()
        }
    }

    /// Build a 64-dimensional mock provider as the trait object the registry stores.
    fn make_provider(id: &str) -> Arc<dyn EmbeddingProvider> {
        Arc::new(MockProvider::new(id, 64))
    }

    /// Ids of all registered providers, in the order `list()` reports them.
    fn listed_ids(registry: &EmbeddingRegistry) -> Vec<String> {
        registry.list().into_iter().map(|info| info.id).collect()
    }

    // ── Tests ─────────────────────────────────────────────────────────────────

    #[test]
    fn test_register_and_get_by_id() {
        let mut registry = EmbeddingRegistry::new();
        registry.register(make_provider("alpha"));

        let provider = registry.get("alpha");
        assert!(provider.is_some(), "registered provider should be found");
        assert_eq!(provider.unwrap().provider_info().id, "alpha");
    }

    #[test]
    fn test_get_unknown_returns_none() {
        let registry = EmbeddingRegistry::new();
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_list_empty_registry() {
        let registry = EmbeddingRegistry::default();
        assert!(registry.list().is_empty());
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_list_returns_all_providers() {
        let mut registry = EmbeddingRegistry::new();
        registry.register(make_provider("alpha"));
        registry.register(make_provider("beta"));
        registry.register(make_provider("gamma"));

        let list = registry.list();
        assert_eq!(list.len(), 3);
        let ids: Vec<&str> = list.iter().map(|i| i.id.as_str()).collect();
        assert!(ids.contains(&"alpha"));
        assert!(ids.contains(&"beta"));
        assert!(ids.contains(&"gamma"));
    }

    #[test]
    fn test_list_preserves_insertion_order() {
        let mut registry = EmbeddingRegistry::new();
        registry.register(make_provider("first"));
        registry.register(make_provider("second"));
        registry.register(make_provider("third"));

        assert_eq!(listed_ids(&registry), vec!["first", "second", "third"]);
    }

    #[test]
    fn test_default_returns_first_registered() {
        let mut registry = EmbeddingRegistry::new();
        assert!(
            registry.default_provider().is_none(),
            "empty registry has no default"
        );

        registry.register(make_provider("first"));
        registry.register(make_provider("second"));

        let default = registry.default_provider().expect("should have a default");
        assert_eq!(default.provider_info().id, "first");
    }

    #[test]
    fn test_set_default_changes_default() {
        let mut registry = EmbeddingRegistry::new();
        registry.register(make_provider("alpha"));
        registry.register(make_provider("beta"));

        registry.set_default("beta").expect("beta is registered");

        let default = registry.default_provider().expect("should have a default");
        assert_eq!(default.provider_info().id, "beta");
    }

    #[test]
    fn test_set_default_unknown_returns_error() {
        let mut registry = EmbeddingRegistry::new();
        let result = registry.set_default("does-not-exist");
        assert!(result.is_err(), "unknown id should return an error");
    }

    #[test]
    fn test_set_default_unknown_is_invalid_input_naming_the_id() {
        let mut registry = EmbeddingRegistry::new();
        match registry.set_default("ghost") {
            Err(EngramError::InvalidInput(msg)) => {
                assert!(msg.contains("ghost"), "message should name the id: {msg}")
            }
            Err(_) => panic!("expected EngramError::InvalidInput"),
            Ok(()) => panic!("an unregistered id must not be accepted"),
        }
    }

    #[test]
    fn test_failed_set_default_keeps_previous_default() {
        let mut registry = EmbeddingRegistry::new();
        registry.register(make_provider("alpha"));
        registry.register(make_provider("beta"));
        registry.set_default("beta").expect("beta is registered");

        assert!(registry.set_default("missing").is_err());

        let default = registry.default_provider().expect("should have a default");
        assert_eq!(default.provider_info().id, "beta");
    }

    #[test]
    fn test_count() {
        let mut registry = EmbeddingRegistry::new();
        assert_eq!(registry.count(), 0);

        registry.register(make_provider("a"));
        assert_eq!(registry.count(), 1);

        registry.register(make_provider("b"));
        assert_eq!(registry.count(), 2);
    }

    #[test]
    fn test_register_replaces_existing_id() {
        let mut registry = EmbeddingRegistry::new();
        registry.register(make_provider("a"));
        // Register a different provider under the same id.
        registry.register(Arc::new(MockProvider {
            info: EmbeddingProviderInfo {
                id: "a".to_string(),
                name: "Updated A".to_string(),
                model: "updated-model".to_string(),
                dimensions: 128,
                requires_api_key: true,
                is_local: false,
            },
        }));

        // Count must remain 1 (no duplicate id).
        assert_eq!(registry.count(), 1);

        let info = registry.get("a").unwrap().provider_info();
        assert_eq!(info.name, "Updated A");
        assert_eq!(info.dimensions, 128);
    }

    #[test]
    fn test_register_replacement_preserves_order_and_fallback_default() {
        let mut registry = EmbeddingRegistry::new();
        registry.register(make_provider("a"));
        registry.register(make_provider("b"));
        // Re-registering "a" must not move it behind "b".
        registry.register(Arc::new(MockProvider::new("a", 128)));

        assert_eq!(listed_ids(&registry), vec!["a", "b"]);
        let default = registry.default_provider().expect("should have a default");
        assert_eq!(default.provider_info().id, "a");
        assert_eq!(default.provider_info().dimensions, 128);
    }

    #[test]
    fn test_replacing_explicit_default_resolves_to_new_provider() {
        let mut registry = EmbeddingRegistry::new();
        registry.register(make_provider("a"));
        registry.register(make_provider("b"));
        registry.set_default("b").expect("b is registered");

        registry.register(Arc::new(MockProvider::new("b", 32)));

        let default = registry.default_provider().expect("should have a default");
        assert_eq!(default.provider_info().id, "b");
        assert_eq!(default.provider_info().dimensions, 32);
    }

    #[test]
    fn test_provider_info_fields() {
        let info = EmbeddingProviderInfo {
            id: "test".to_string(),
            name: "Test Provider".to_string(),
            model: "test-model-v1".to_string(),
            dimensions: 256,
            requires_api_key: true,
            is_local: false,
        };
        assert_eq!(info.id, "test");
        assert_eq!(info.dimensions, 256);
        assert!(info.requires_api_key);
        assert!(!info.is_local);
    }

    #[test]
    fn test_embed_via_registry_provider() {
        let mut registry = EmbeddingRegistry::new();
        registry.register(make_provider("mock"));

        let provider = registry.get("mock").expect("mock is registered");
        let embedding = provider.embed("hello world").expect("embed should succeed");
        assert_eq!(embedding.len(), 64);
    }
}