Skip to main content

khive_runtime/
embedder_registry.rs

1//! EmbedderRegistry — pack-extensible embedding provider surface (ADR-031 extension).
2//!
3//! Packs can contribute custom embedding models at registration time by
4//! implementing [`EmbedderProvider`] and calling
5//! [`KhiveRuntime::register_embedder`]. The built-in lattice models
6//! (MiniLM, paraphrase, etc.) are pre-registered as [`LatticeEmbedderProvider`]
7//! instances during runtime construction and require no opt-in.
8//!
9//! # Extension contract
10//!
11//! ```ignore
12//! use khive_runtime::embedder_registry::{EmbedderProvider, EmbedderRegistry};
13//! use khive_runtime::{KhiveRuntime, RuntimeError};
14//! use lattice_embed::EmbeddingService;
15//! use std::sync::Arc;
16//! use async_trait::async_trait;
17//!
18//! struct MyProvider;
19//!
20//! #[async_trait]
21//! impl EmbedderProvider for MyProvider {
22//!     fn name(&self) -> &str { "my-custom-encoder" }
23//!     fn dimensions(&self) -> usize { 384 }
24//!     async fn build(&self) -> Result<Arc<dyn EmbeddingService>, RuntimeError> {
25//!         // construct and return your EmbeddingService impl
26//!         todo!()
27//!     }
28//! }
29//! ```
30//!
31//! Register via `PackRuntime::register_embedders` (called by the transport after
32//! the pack is instantiated) or directly via `KhiveRuntime::register_embedder`.
33
34use std::collections::HashMap;
35use std::sync::Arc;
36
37use async_trait::async_trait;
38use lattice_embed::{
39    CachedEmbeddingService, EmbeddingModel, EmbeddingService, NativeEmbeddingService,
40};
41use tokio::sync::OnceCell;
42
43use crate::error::{RuntimeError, RuntimeResult};
44
45/// A source that can produce an [`EmbeddingService`] by name.
46///
47/// Packs implement this trait to register custom embedding backends.
48/// The runtime calls [`build`](EmbedderProvider::build) lazily — once per
49/// process per model — and caches the result. Subsequent calls to
50/// `KhiveRuntime::embedder(name)` are cheap.
51///
52/// Built-in lattice models are registered automatically via
53/// [`LatticeEmbedderProvider`]; packs need not re-register them.
54#[async_trait]
55pub trait EmbedderProvider: Send + Sync {
56    /// Stable, case-sensitive name for this embedder.
57    ///
58    /// Must be unique across all registered providers. The name is used as
59    /// the key in `KhiveRuntime::embedder(name)` lookups and as the storage
60    /// table suffix for vector indices. Use the model's canonical short form
61    /// (e.g. `"all-minilm-l6-v2"`, `"my-custom-encoder"`).
62    fn name(&self) -> &str;
63
64    /// Output vector dimension for this embedder.
65    ///
66    /// Must be consistent with what [`build`](Self::build) produces.
67    /// The runtime uses this to pre-register the vector store columns.
68    fn dimensions(&self) -> usize;
69
70    /// Construct the underlying [`EmbeddingService`].
71    ///
72    /// Called at most once per process. The result is cached in a
73    /// [`OnceCell`]; concurrent callers block on the first call and share
74    /// the result thereafter.
75    async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>>;
76}
77
78/// An entry in the [`EmbedderRegistry`] combining a provider with its
79/// lazy-initialized service.
80pub(crate) struct EmbedderEntry {
81    provider: Arc<dyn EmbedderProvider>,
82    cell: Arc<OnceCell<Arc<dyn EmbeddingService>>>,
83}
84
85impl Clone for EmbedderEntry {
86    fn clone(&self) -> Self {
87        Self {
88            provider: Arc::clone(&self.provider),
89            cell: Arc::clone(&self.cell),
90        }
91    }
92}
93
94/// Registry of named [`EmbedderProvider`] instances.
95///
96/// Built during `KhiveRuntime` construction and optionally extended by packs
97/// via [`KhiveRuntime::register_embedder`]. The registry is internally
98/// reference-counted so `KhiveRuntime::clone()` shares the same providers
99/// and cached service instances.
100#[derive(Clone, Default)]
101pub struct EmbedderRegistry {
102    entries: HashMap<String, EmbedderEntry>,
103}
104
105impl EmbedderRegistry {
106    /// Create an empty registry.
107    pub fn new() -> Self {
108        Self {
109            entries: HashMap::new(),
110        }
111    }
112
113    /// Register a provider.
114    ///
115    /// If a provider with the same [`name`](EmbedderProvider::name) already
116    /// exists, the new provider **replaces** the existing one (last-writer wins).
117    /// The previously cached service instance (if any) is discarded — the
118    /// replacement provider's [`build`](EmbedderProvider::build) will be
119    /// called lazily on the next access.
120    ///
121    /// **Last-wins** is chosen over rejection because pack registration order
122    /// is not guaranteed and packs may legitimately override a default model
123    /// with a custom implementation of the same logical name. Operators who
124    /// need strict collision detection should inspect
125    /// [`names`](Self::names) before registering.
126    pub fn register<P: EmbedderProvider + 'static>(&mut self, provider: P) {
127        let name = provider.name().to_owned();
128        self.entries.insert(
129            name,
130            EmbedderEntry {
131                provider: Arc::new(provider),
132                cell: Arc::new(OnceCell::new()),
133            },
134        );
135    }
136
137    /// Look up a provider by name.
138    pub fn get_provider(&self, name: &str) -> Option<&dyn EmbedderProvider> {
139        self.entries.get(name).map(|e| e.provider.as_ref())
140    }
141
142    /// Returns `true` if a provider with this name is registered.
143    pub fn contains(&self, name: &str) -> bool {
144        self.entries.contains_key(name)
145    }
146
147    /// Names of all registered providers, in unspecified order.
148    pub fn names(&self) -> Vec<String> {
149        self.entries.keys().cloned().collect()
150    }
151
152    /// Return a cloned entry for `name` without holding any lock.
153    ///
154    /// The caller can then call [`EmbedderEntry::resolve`] without holding
155    /// a lock — this avoids holding a `RwLockGuard` across `await` points.
156    /// Returns `None` if `name` is not registered.
157    pub(crate) fn get_entry(&self, name: &str) -> Option<EmbedderEntry> {
158        self.entries.get(name).cloned()
159    }
160
161    /// Lazily resolve a registered provider to its live [`EmbeddingService`].
162    ///
163    /// Returns [`RuntimeError::UnknownModel`] if `name` is not registered.
164    /// The first call for a given name triggers [`EmbedderProvider::build`];
165    /// subsequent calls return the cached `Arc`.
166    ///
167    /// Prefer [`KhiveRuntime::embedder`] over calling this directly from pack
168    /// handlers — the runtime method handles alias resolution and error mapping.
169    pub async fn get_service(&self, name: &str) -> RuntimeResult<Arc<dyn EmbeddingService>> {
170        let entry = self
171            .entries
172            .get(name)
173            .ok_or_else(|| RuntimeError::UnknownModel(name.to_string()))?
174            .clone();
175
176        entry.resolve().await
177    }
178}
179
180impl EmbedderEntry {
181    /// Lazily initialise and return the embedding service for this entry.
182    ///
183    /// The `OnceCell` guarantees that `build` is called at most once even
184    /// under concurrent access. Callers hold no external lock while awaiting.
185    ///
186    /// Returns `RuntimeError` if `build()` fails, rather than panicking.
187    pub(crate) async fn resolve(self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
188        // `OnceCell` does not expose a fallible init; we work around this by
189        // checking if the cell is already populated (cheap), and if not, calling
190        // `build()` ourselves, storing on success, and propagating on failure.
191        // Races are benign: at worst two callers both call `build()` and the
192        // loser's result is discarded — both outcomes are identical services.
193        if let Some(svc) = self.cell.get() {
194            return Ok(Arc::clone(svc));
195        }
196        let svc = self.provider.build().await.map_err(|e| {
197            crate::error::RuntimeError::Internal(format!(
198                "EmbedderProvider '{}' build() failed: {e}",
199                self.provider.name()
200            ))
201        })?;
202        // `set` may fail if another task raced us to initialise; that is fine —
203        // we just return our freshly-built instance (functionally identical).
204        let _ = self.cell.set(Arc::clone(&svc));
205        Ok(svc)
206    }
207}
208
209// ── LatticeEmbedderProvider ───────────────────────────────────────────────────
210
211/// Adapter that wraps a [`lattice_embed::EmbeddingModel`] as an
212/// [`EmbedderProvider`].
213///
214/// All built-in models (MiniLM, paraphrase-multilingual, BGE variants, etc.)
215/// are registered as `LatticeEmbedderProvider` instances during
216/// `KhiveRuntime` construction. External callers do not need to use this type
217/// unless they are constructing a custom registry from scratch.
218pub struct LatticeEmbedderProvider {
219    model: EmbeddingModel,
220    /// Cached `to_string()` result so `name()` can return `&str`.
221    name: String,
222}
223
224impl LatticeEmbedderProvider {
225    /// Create a new provider wrapping the given lattice model.
226    pub fn new(model: EmbeddingModel) -> Self {
227        let name = model.to_string();
228        Self { model, name }
229    }
230}
231
232#[async_trait]
233impl EmbedderProvider for LatticeEmbedderProvider {
234    fn name(&self) -> &str {
235        &self.name
236    }
237
238    fn dimensions(&self) -> usize {
239        self.model.dimensions()
240    }
241
242    async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
243        let native = Arc::new(NativeEmbeddingService::with_model(self.model));
244        let cached = CachedEmbeddingService::with_default_cache(native);
245        Ok(Arc::new(cached) as Arc<dyn EmbeddingService>)
246    }
247}
248
249// ── Unit tests ────────────────────────────────────────────────────────────────
250
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // ---- minimal mock provider ----

    struct ConstVecProvider {
        name: String,
        dims: usize,
        build_calls: Arc<AtomicUsize>,
    }

    impl ConstVecProvider {
        fn new(name: &str, dims: usize) -> Self {
            Self {
                name: name.to_owned(),
                dims,
                build_calls: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    /// A trivial embedding service that returns a constant vector of `1.0`s.
    /// The `model` parameter is ignored — this service always returns the
    /// same synthetic vector regardless of which model is requested.
    struct ConstVecService {
        dims: usize,
    }

    #[async_trait]
    impl EmbeddingService for ConstVecService {
        async fn embed(
            &self,
            texts: &[String],
            _model: EmbeddingModel,
        ) -> std::result::Result<Vec<Vec<f32>>, lattice_embed::EmbedError> {
            Ok(texts.iter().map(|_| vec![1.0_f32; self.dims]).collect())
        }

        fn supports_model(&self, _model: EmbeddingModel) -> bool {
            true
        }

        fn name(&self) -> &'static str {
            "const-vec-service"
        }
    }

    #[async_trait]
    impl EmbedderProvider for ConstVecProvider {
        fn name(&self) -> &str {
            &self.name
        }

        fn dimensions(&self) -> usize {
            self.dims
        }

        async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
            self.build_calls.fetch_add(1, Ordering::SeqCst);
            Ok(Arc::new(ConstVecService { dims: self.dims }))
        }
    }

    // ---- mock provider whose build always fails ----

    /// A provider whose `build()` always returns an error, counting attempts.
    struct FailingProvider {
        attempts: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl EmbedderProvider for FailingProvider {
        fn name(&self) -> &str {
            "always-fails"
        }

        fn dimensions(&self) -> usize {
            8
        }

        async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
            self.attempts.fetch_add(1, Ordering::SeqCst);
            Err(RuntimeError::Internal("backend unavailable".to_owned()))
        }
    }

    // ---- test: register + get round-trip ----

    #[test]
    fn register_and_get_provider_round_trip() {
        let mut reg = EmbedderRegistry::new();
        reg.register(ConstVecProvider::new("mock-384", 384));

        assert!(reg.contains("mock-384"), "registered name must be present");
        let provider = reg.get_provider("mock-384").expect("provider must exist");
        assert_eq!(provider.name(), "mock-384");
        assert_eq!(provider.dimensions(), 384);
    }

    // ---- test: duplicate name is last-wins (not an error) ----

    #[test]
    fn duplicate_name_last_wins() {
        let mut reg = EmbedderRegistry::new();
        reg.register(ConstVecProvider::new("shared", 128));
        reg.register(ConstVecProvider::new("shared", 256));

        let provider = reg.get_provider("shared").expect("provider must exist");
        assert_eq!(
            provider.dimensions(),
            256,
            "last registration must win; expected dims=256"
        );
    }

    // ---- test: names() returns all registered names ----

    #[test]
    fn names_returns_all_registered() {
        let mut reg = EmbedderRegistry::new();
        reg.register(ConstVecProvider::new("model-a", 64));
        reg.register(ConstVecProvider::new("model-b", 128));
        reg.register(ConstVecProvider::new("model-c", 256));

        let mut names = reg.names();
        names.sort();
        assert_eq!(names, vec!["model-a", "model-b", "model-c"]);
    }

    // ---- test: get_service returns UnknownModel for unregistered name ----

    #[tokio::test]
    async fn get_service_unknown_name_returns_error() {
        let reg = EmbedderRegistry::new();
        let result = reg.get_service("does-not-exist").await;
        let err = result.err().expect("expected Err for unknown name, got Ok");
        assert!(
            matches!(err, RuntimeError::UnknownModel(ref n) if n == "does-not-exist"),
            "expected UnknownModel, got {err:?}"
        );
    }

    // ---- test: get_service calls build once (lazy, cached) ----

    #[tokio::test]
    async fn get_service_calls_build_once() {
        let counter = Arc::new(AtomicUsize::new(0));
        let provider = ConstVecProvider {
            name: "cached-model".to_owned(),
            dims: 32,
            build_calls: Arc::clone(&counter),
        };
        let mut reg = EmbedderRegistry::new();
        reg.register(provider);

        let _ = reg.get_service("cached-model").await.unwrap();
        let _ = reg.get_service("cached-model").await.unwrap();
        let _ = reg.get_service("cached-model").await.unwrap();

        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "build must be called exactly once regardless of get_service call count"
        );
    }

    // ---- test: a failed build is reported as Internal and is not cached ----

    #[tokio::test]
    async fn failed_build_is_reported_and_retried() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let mut reg = EmbedderRegistry::new();
        reg.register(FailingProvider {
            attempts: Arc::clone(&attempts),
        });

        for expected_attempts in 1..=2 {
            let err = reg
                .get_service("always-fails")
                .await
                .err()
                .expect("a failing build must surface as Err");
            assert!(
                matches!(err, RuntimeError::Internal(ref msg) if msg.contains("always-fails")),
                "expected Internal naming the provider, got {err:?}"
            );
            assert_eq!(
                attempts.load(Ordering::SeqCst),
                expected_attempts,
                "a failed build must not be cached; the next lookup must retry"
            );
        }
    }

    // ---- test: re-registering a name discards the previously cached service ----

    #[tokio::test]
    async fn reregistration_discards_cached_service() {
        let first = ConstVecProvider::new("swappable", 16);
        let first_calls = Arc::clone(&first.build_calls);
        let second = ConstVecProvider::new("swappable", 16);
        let second_calls = Arc::clone(&second.build_calls);

        let mut reg = EmbedderRegistry::new();
        reg.register(first);
        let _ = reg.get_service("swappable").await.unwrap();

        reg.register(second);
        let _ = reg.get_service("swappable").await.unwrap();
        let _ = reg.get_service("swappable").await.unwrap();

        assert_eq!(
            first_calls.load(Ordering::SeqCst),
            1,
            "the replaced provider must not be built again"
        );
        assert_eq!(
            second_calls.load(Ordering::SeqCst),
            1,
            "the replacement must be built lazily, exactly once"
        );
    }
}