Skip to main content

khive_runtime/
embedder_registry.rs

1//! EmbedderRegistry — pack-extensible embedding provider surface.
2//!
3//! Packs implement [`EmbedderProvider`] and register custom models via
4//! [`crate::KhiveRuntime::register_embedder`]. Built-in lattice models are pre-registered
5//! during runtime construction and require no opt-in.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use lattice_embed::{
12    CachedEmbeddingService, EmbeddingModel, EmbeddingService, NativeEmbeddingService,
13};
14use tokio::sync::OnceCell;
15
16use crate::error::{RuntimeError, RuntimeResult};
17
18/// A source that can produce an [`EmbeddingService`] by name.
19///
20/// Packs implement this trait to register custom embedding backends.
21/// The runtime calls [`build`](EmbedderProvider::build) lazily — once per
22/// process per model — and caches the result. Subsequent calls to
23/// `KhiveRuntime::embedder(name)` are cheap.
24///
25/// Built-in lattice models are registered automatically via
26/// [`LatticeEmbedderProvider`]; packs need not re-register them.
27#[async_trait]
28pub trait EmbedderProvider: Send + Sync {
29    /// Stable, case-sensitive name for this embedder.
30    ///
31    /// Must be unique across all registered providers. The name is used as
32    /// the key in `KhiveRuntime::embedder(name)` lookups and as the storage
33    /// table suffix for vector indices. Use the model's canonical short form
34    /// (e.g. `"all-minilm-l6-v2"`, `"my-custom-encoder"`).
35    fn name(&self) -> &str;
36
37    /// Output vector dimension for this embedder.
38    ///
39    /// Must be consistent with what [`build`](Self::build) produces.
40    /// The runtime uses this to pre-register the vector store columns.
41    fn dimensions(&self) -> usize;
42
43    /// Construct the underlying [`EmbeddingService`].
44    ///
45    /// Called at most once per process. The result is cached in a
46    /// [`OnceCell`]; concurrent callers block on the first call and share
47    /// the result thereafter.
48    async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>>;
49}
50
51/// An entry in the [`EmbedderRegistry`] combining a provider with its
52/// lazy-initialized service.
53pub(crate) struct EmbedderEntry {
54    provider: Arc<dyn EmbedderProvider>,
55    cell: Arc<OnceCell<Arc<dyn EmbeddingService>>>,
56}
57
58impl Clone for EmbedderEntry {
59    fn clone(&self) -> Self {
60        Self {
61            provider: Arc::clone(&self.provider),
62            cell: Arc::clone(&self.cell),
63        }
64    }
65}
66
67/// Registry of named [`EmbedderProvider`] instances.
68///
69/// Built during `KhiveRuntime` construction and optionally extended by packs
70/// via [`crate::KhiveRuntime::register_embedder`]. The registry is internally
71/// reference-counted so `KhiveRuntime::clone()` shares the same providers
72/// and cached service instances.
73#[derive(Clone, Default)]
74pub struct EmbedderRegistry {
75    entries: HashMap<String, EmbedderEntry>,
76}
77
78impl EmbedderRegistry {
79    /// Create an empty registry.
80    pub fn new() -> Self {
81        Self {
82            entries: HashMap::new(),
83        }
84    }
85
86    /// Register a provider.
87    ///
88    /// If a provider with the same [`name`](EmbedderProvider::name) already
89    /// exists, the new provider **replaces** the existing one (last-writer wins).
90    /// The previously cached service instance (if any) is discarded — the
91    /// replacement provider's [`build`](EmbedderProvider::build) will be
92    /// called lazily on the next access.
93    ///
94    /// **Last-wins** is chosen over rejection because pack registration order
95    /// is not guaranteed and packs may legitimately override a default model
96    /// with a custom implementation of the same logical name. Operators who
97    /// need strict collision detection should inspect
98    /// [`names`](Self::names) before registering.
99    pub fn register<P: EmbedderProvider + 'static>(&mut self, provider: P) {
100        let name = provider.name().to_owned();
101        self.entries.insert(
102            name,
103            EmbedderEntry {
104                provider: Arc::new(provider),
105                cell: Arc::new(OnceCell::new()),
106            },
107        );
108    }
109
110    /// Look up a provider by name.
111    pub fn get_provider(&self, name: &str) -> Option<&dyn EmbedderProvider> {
112        self.entries.get(name).map(|e| e.provider.as_ref())
113    }
114
115    /// Returns `true` if a provider with this name is registered.
116    pub fn contains(&self, name: &str) -> bool {
117        self.entries.contains_key(name)
118    }
119
120    /// Names of all registered providers, in unspecified order.
121    pub fn names(&self) -> Vec<String> {
122        self.entries.keys().cloned().collect()
123    }
124
125    /// Return a cloned entry for `name` without holding any lock.
126    ///
127    /// The caller can then call [`EmbedderEntry::resolve`] without holding
128    /// a lock — this avoids holding a `RwLockGuard` across `await` points.
129    /// Returns `None` if `name` is not registered.
130    pub(crate) fn get_entry(&self, name: &str) -> Option<EmbedderEntry> {
131        self.entries.get(name).cloned()
132    }
133
134    /// Lazily resolve a registered provider to its live [`EmbeddingService`].
135    ///
136    /// Returns [`RuntimeError::UnknownModel`] if `name` is not registered.
137    /// The first call for a given name triggers [`EmbedderProvider::build`];
138    /// subsequent calls return the cached `Arc`.
139    ///
140    /// Prefer [`crate::KhiveRuntime::embedder`] over calling this directly from pack
141    /// handlers — the runtime method handles alias resolution and error mapping.
142    pub async fn get_service(&self, name: &str) -> RuntimeResult<Arc<dyn EmbeddingService>> {
143        let entry = self
144            .entries
145            .get(name)
146            .ok_or_else(|| RuntimeError::UnknownModel(name.to_string()))?
147            .clone();
148
149        entry.resolve().await
150    }
151}
152
153impl EmbedderEntry {
154    /// Lazily initialise and return the embedding service for this entry.
155    ///
156    /// The `OnceCell` guarantees that `build` is called at most once even
157    /// under concurrent access. Callers hold no external lock while awaiting.
158    ///
159    /// Returns `RuntimeError` if `build()` fails, rather than panicking.
160    pub(crate) async fn resolve(self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
161        // `OnceCell` does not expose a fallible init; we work around this by
162        // checking if the cell is already populated (cheap), and if not, calling
163        // `build()` ourselves, storing on success, and propagating on failure.
164        // Races are benign: at worst two callers both call `build()` and the
165        // loser's result is discarded — both outcomes are identical services.
166        if let Some(svc) = self.cell.get() {
167            return Ok(Arc::clone(svc));
168        }
169        let svc = self.provider.build().await.map_err(|e| {
170            crate::error::RuntimeError::Internal(format!(
171                "EmbedderProvider '{}' build() failed: {e}",
172                self.provider.name()
173            ))
174        })?;
175        // `set` may fail if another task raced us to initialise; that is fine —
176        // we just return our freshly-built instance (functionally identical).
177        let _ = self.cell.set(Arc::clone(&svc));
178        Ok(svc)
179    }
180}
181
182// ── LatticeEmbedderProvider ───────────────────────────────────────────────────
183
184/// Adapter that wraps a [`lattice_embed::EmbeddingModel`] as an
185/// [`EmbedderProvider`].
186///
187/// All built-in models (MiniLM, paraphrase-multilingual, BGE variants, etc.)
188/// are registered as `LatticeEmbedderProvider` instances during
189/// `KhiveRuntime` construction. External callers do not need to use this type
190/// unless they are constructing a custom registry from scratch.
191pub struct LatticeEmbedderProvider {
192    model: EmbeddingModel,
193    /// Cached `to_string()` result so `name()` can return `&str`.
194    name: String,
195}
196
197impl LatticeEmbedderProvider {
198    /// Create a new provider wrapping the given lattice model.
199    pub fn new(model: EmbeddingModel) -> Self {
200        let name = model.to_string();
201        Self { model, name }
202    }
203}
204
205#[async_trait]
206impl EmbedderProvider for LatticeEmbedderProvider {
207    fn name(&self) -> &str {
208        &self.name
209    }
210
211    fn dimensions(&self) -> usize {
212        self.model.dimensions()
213    }
214
215    async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
216        let native = Arc::new(NativeEmbeddingService::with_model(self.model));
217        let cached = CachedEmbeddingService::with_default_cache(native);
218        Ok(Arc::new(cached) as Arc<dyn EmbeddingService>)
219    }
220}
221
222// ── Unit tests ────────────────────────────────────────────────────────────────
223
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // ---- minimal mock providers ----

    /// Mock provider that builds a [`ConstVecService`] and counts how many
    /// times its `build` is invoked.
    struct ConstVecProvider {
        name: String,
        dims: usize,
        build_calls: Arc<AtomicUsize>,
    }

    impl ConstVecProvider {
        /// Provider with a private build counter.
        fn new(name: &str, dims: usize) -> Self {
            Self::with_counter(name, dims, Arc::new(AtomicUsize::new(0)))
        }

        /// Provider whose build counter is shared with the caller, so a test
        /// can observe how often `build` ran.
        fn with_counter(name: &str, dims: usize, build_calls: Arc<AtomicUsize>) -> Self {
            Self {
                name: name.to_owned(),
                dims,
                build_calls,
            }
        }
    }

    /// A trivial embedding service that returns a constant vector of `1.0`s.
    /// The `model` parameter is ignored — this service always returns the
    /// same synthetic vector regardless of which model is requested.
    struct ConstVecService {
        dims: usize,
    }

    #[async_trait]
    impl EmbeddingService for ConstVecService {
        async fn embed(
            &self,
            texts: &[String],
            _model: EmbeddingModel,
        ) -> std::result::Result<Vec<Vec<f32>>, lattice_embed::EmbedError> {
            Ok(texts.iter().map(|_| vec![1.0_f32; self.dims]).collect())
        }

        fn supports_model(&self, _model: EmbeddingModel) -> bool {
            true
        }

        fn name(&self) -> &'static str {
            "const-vec-service"
        }
    }

    #[async_trait]
    impl EmbedderProvider for ConstVecProvider {
        fn name(&self) -> &str {
            &self.name
        }

        fn dimensions(&self) -> usize {
            self.dims
        }

        async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
            self.build_calls.fetch_add(1, Ordering::SeqCst);
            Ok(Arc::new(ConstVecService { dims: self.dims }))
        }
    }

    /// Mock provider whose first `build` fails and whose later builds
    /// succeed. Used to check that failures are surfaced and never cached.
    struct FlakyProvider {
        build_calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl EmbedderProvider for FlakyProvider {
        fn name(&self) -> &str {
            "flaky"
        }

        fn dimensions(&self) -> usize {
            8
        }

        async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
            let attempt = self.build_calls.fetch_add(1, Ordering::SeqCst);
            if attempt == 0 {
                return Err(RuntimeError::Internal("transient failure".to_owned()));
            }
            Ok(Arc::new(ConstVecService { dims: 8 }))
        }
    }

    // ---- test: register + get round-trip ----

    #[test]
    fn register_and_get_provider_round_trip() {
        let mut reg = EmbedderRegistry::new();
        reg.register(ConstVecProvider::new("mock-384", 384));

        assert!(reg.contains("mock-384"), "registered name must be present");
        let provider = reg.get_provider("mock-384").expect("provider must exist");
        assert_eq!(provider.name(), "mock-384");
        assert_eq!(provider.dimensions(), 384);
    }

    // ---- test: duplicate name is last-wins (not an error) ----

    #[test]
    fn duplicate_name_last_wins() {
        let mut reg = EmbedderRegistry::new();
        reg.register(ConstVecProvider::new("shared", 128));
        reg.register(ConstVecProvider::new("shared", 256));

        let provider = reg.get_provider("shared").expect("provider must exist");
        assert_eq!(
            provider.dimensions(),
            256,
            "last registration must win; expected dims=256"
        );
    }

    // ---- test: names() returns all registered names ----

    #[test]
    fn names_returns_all_registered() {
        let mut reg = EmbedderRegistry::new();
        reg.register(ConstVecProvider::new("model-a", 64));
        reg.register(ConstVecProvider::new("model-b", 128));
        reg.register(ConstVecProvider::new("model-c", 256));

        let mut names = reg.names();
        names.sort();
        assert_eq!(names, vec!["model-a", "model-b", "model-c"]);
    }

    // ---- test: get_service returns UnknownModel for unregistered name ----

    #[tokio::test]
    async fn get_service_unknown_name_returns_error() {
        let reg = EmbedderRegistry::new();
        let result = reg.get_service("does-not-exist").await;
        let err = result.err().expect("expected Err for unknown name, got Ok");
        assert!(
            matches!(err, RuntimeError::UnknownModel(ref n) if n == "does-not-exist"),
            "expected UnknownModel, got {err:?}"
        );
    }

    // ---- test: get_service calls build once (lazy, cached) ----

    #[tokio::test]
    async fn get_service_calls_build_once() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut reg = EmbedderRegistry::new();
        reg.register(ConstVecProvider::with_counter(
            "cached-model",
            32,
            Arc::clone(&counter),
        ));

        let _ = reg.get_service("cached-model").await.unwrap();
        let _ = reg.get_service("cached-model").await.unwrap();
        let _ = reg.get_service("cached-model").await.unwrap();

        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "build must be called exactly once regardless of get_service call count"
        );
    }

    // ---- test: re-registering a name discards the cached service ----

    #[tokio::test]
    async fn reregistration_discards_cached_service() {
        let first = Arc::new(AtomicUsize::new(0));
        let second = Arc::new(AtomicUsize::new(0));
        let mut reg = EmbedderRegistry::new();

        reg.register(ConstVecProvider::with_counter("swap", 16, Arc::clone(&first)));
        let _ = reg.get_service("swap").await.unwrap();

        reg.register(ConstVecProvider::with_counter("swap", 32, Arc::clone(&second)));
        let _ = reg.get_service("swap").await.unwrap();

        assert_eq!(
            first.load(Ordering::SeqCst),
            1,
            "replaced provider must not be built again"
        );
        assert_eq!(
            second.load(Ordering::SeqCst),
            1,
            "replacement provider must be built on its first access"
        );
    }

    // ---- test: a failed build is reported and retried, not cached ----

    #[tokio::test]
    async fn failed_build_is_reported_and_retried() {
        let calls = Arc::new(AtomicUsize::new(0));
        let mut reg = EmbedderRegistry::new();
        reg.register(FlakyProvider {
            build_calls: Arc::clone(&calls),
        });

        let err = reg
            .get_service("flaky")
            .await
            .err()
            .expect("first build is scripted to fail");
        assert!(
            matches!(err, RuntimeError::Internal(ref msg) if msg.contains("flaky")),
            "expected Internal error naming the provider, got {err:?}"
        );

        let _ = reg
            .get_service("flaky")
            .await
            .expect("second build must succeed");
        let _ = reg
            .get_service("flaky")
            .await
            .expect("cached service must be returned");

        assert_eq!(
            calls.load(Ordering::SeqCst),
            2,
            "a failed build must not be cached; a successful one must be"
        );
    }
}