khive_runtime/embedder_registry.rs
1//! EmbedderRegistry — pack-extensible embedding provider surface (ADR-031 extension).
2//!
3//! Packs can contribute custom embedding models at registration time by
4//! implementing [`EmbedderProvider`] and calling
5//! [`KhiveRuntime::register_embedder`]. The built-in lattice models
6//! (MiniLM, paraphrase, etc.) are pre-registered as [`LatticeEmbedderProvider`]
7//! instances during runtime construction and require no opt-in.
8//!
9//! # Extension contract
10//!
11//! ```ignore
12//! use khive_runtime::embedder_registry::{EmbedderProvider, EmbedderRegistry};
13//! use khive_runtime::{KhiveRuntime, RuntimeError};
14//! use lattice_embed::EmbeddingService;
15//! use std::sync::Arc;
16//! use async_trait::async_trait;
17//!
18//! struct MyProvider;
19//!
20//! #[async_trait]
21//! impl EmbedderProvider for MyProvider {
22//! fn name(&self) -> &str { "my-custom-encoder" }
23//! fn dimensions(&self) -> usize { 384 }
24//! async fn build(&self) -> Result<Arc<dyn EmbeddingService>, RuntimeError> {
25//! // construct and return your EmbeddingService impl
26//! todo!()
27//! }
28//! }
29//! ```
30//!
31//! Register via `PackRuntime::register_embedders` (called by the transport after
32//! the pack is instantiated) or directly via `KhiveRuntime::register_embedder`.
33
34use std::collections::HashMap;
35use std::sync::Arc;
36
37use async_trait::async_trait;
38use lattice_embed::{
39 CachedEmbeddingService, EmbeddingModel, EmbeddingService, NativeEmbeddingService,
40};
41use tokio::sync::OnceCell;
42
43use crate::error::{RuntimeError, RuntimeResult};
44
45/// A source that can produce an [`EmbeddingService`] by name.
46///
47/// Packs implement this trait to register custom embedding backends.
48/// The runtime calls [`build`](EmbedderProvider::build) lazily — once per
49/// process per model — and caches the result. Subsequent calls to
50/// `KhiveRuntime::embedder(name)` are cheap.
51///
52/// Built-in lattice models are registered automatically via
53/// [`LatticeEmbedderProvider`]; packs need not re-register them.
54#[async_trait]
55pub trait EmbedderProvider: Send + Sync {
56 /// Stable, case-sensitive name for this embedder.
57 ///
58 /// Must be unique across all registered providers. The name is used as
59 /// the key in `KhiveRuntime::embedder(name)` lookups and as the storage
60 /// table suffix for vector indices. Use the model's canonical short form
61 /// (e.g. `"all-minilm-l6-v2"`, `"my-custom-encoder"`).
62 fn name(&self) -> &str;
63
64 /// Output vector dimension for this embedder.
65 ///
66 /// Must be consistent with what [`build`](Self::build) produces.
67 /// The runtime uses this to pre-register the vector store columns.
68 fn dimensions(&self) -> usize;
69
70 /// Construct the underlying [`EmbeddingService`].
71 ///
72 /// Called at most once per process. The result is cached in a
73 /// [`OnceCell`]; concurrent callers block on the first call and share
74 /// the result thereafter.
75 async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>>;
76}
77
78/// An entry in the [`EmbedderRegistry`] combining a provider with its
79/// lazy-initialized service.
80pub(crate) struct EmbedderEntry {
81 provider: Arc<dyn EmbedderProvider>,
82 cell: Arc<OnceCell<Arc<dyn EmbeddingService>>>,
83}
84
85impl Clone for EmbedderEntry {
86 fn clone(&self) -> Self {
87 Self {
88 provider: Arc::clone(&self.provider),
89 cell: Arc::clone(&self.cell),
90 }
91 }
92}
93
94/// Registry of named [`EmbedderProvider`] instances.
95///
96/// Built during `KhiveRuntime` construction and optionally extended by packs
97/// via [`KhiveRuntime::register_embedder`]. The registry is internally
98/// reference-counted so `KhiveRuntime::clone()` shares the same providers
99/// and cached service instances.
100#[derive(Clone, Default)]
101pub struct EmbedderRegistry {
102 entries: HashMap<String, EmbedderEntry>,
103}
104
105impl EmbedderRegistry {
106 /// Create an empty registry.
107 pub fn new() -> Self {
108 Self {
109 entries: HashMap::new(),
110 }
111 }
112
113 /// Register a provider.
114 ///
115 /// If a provider with the same [`name`](EmbedderProvider::name) already
116 /// exists, the new provider **replaces** the existing one (last-writer wins).
117 /// The previously cached service instance (if any) is discarded — the
118 /// replacement provider's [`build`](EmbedderProvider::build) will be
119 /// called lazily on the next access.
120 ///
121 /// **Last-wins** is chosen over rejection because pack registration order
122 /// is not guaranteed and packs may legitimately override a default model
123 /// with a custom implementation of the same logical name. Operators who
124 /// need strict collision detection should inspect
125 /// [`names`](Self::names) before registering.
126 pub fn register<P: EmbedderProvider + 'static>(&mut self, provider: P) {
127 let name = provider.name().to_owned();
128 self.entries.insert(
129 name,
130 EmbedderEntry {
131 provider: Arc::new(provider),
132 cell: Arc::new(OnceCell::new()),
133 },
134 );
135 }
136
137 /// Look up a provider by name.
138 pub fn get_provider(&self, name: &str) -> Option<&dyn EmbedderProvider> {
139 self.entries.get(name).map(|e| e.provider.as_ref())
140 }
141
142 /// Returns `true` if a provider with this name is registered.
143 pub fn contains(&self, name: &str) -> bool {
144 self.entries.contains_key(name)
145 }
146
147 /// Names of all registered providers, in unspecified order.
148 pub fn names(&self) -> Vec<String> {
149 self.entries.keys().cloned().collect()
150 }
151
152 /// Return a cloned entry for `name` without holding any lock.
153 ///
154 /// The caller can then call [`EmbedderEntry::resolve`] without holding
155 /// a lock — this avoids holding a `RwLockGuard` across `await` points.
156 /// Returns `None` if `name` is not registered.
157 pub(crate) fn get_entry(&self, name: &str) -> Option<EmbedderEntry> {
158 self.entries.get(name).cloned()
159 }
160
161 /// Lazily resolve a registered provider to its live [`EmbeddingService`].
162 ///
163 /// Returns [`RuntimeError::UnknownModel`] if `name` is not registered.
164 /// The first call for a given name triggers [`EmbedderProvider::build`];
165 /// subsequent calls return the cached `Arc`.
166 ///
167 /// Prefer [`KhiveRuntime::embedder`] over calling this directly from pack
168 /// handlers — the runtime method handles alias resolution and error mapping.
169 pub async fn get_service(&self, name: &str) -> RuntimeResult<Arc<dyn EmbeddingService>> {
170 let entry = self
171 .entries
172 .get(name)
173 .ok_or_else(|| RuntimeError::UnknownModel(name.to_string()))?
174 .clone();
175
176 entry.resolve().await
177 }
178}
179
180impl EmbedderEntry {
181 /// Lazily initialise and return the embedding service for this entry.
182 ///
183 /// The `OnceCell` guarantees that `build` is called at most once even
184 /// under concurrent access. Callers hold no external lock while awaiting.
185 ///
186 /// Returns `RuntimeError` if `build()` fails, rather than panicking.
187 pub(crate) async fn resolve(self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
188 // `OnceCell` does not expose a fallible init; we work around this by
189 // checking if the cell is already populated (cheap), and if not, calling
190 // `build()` ourselves, storing on success, and propagating on failure.
191 // Races are benign: at worst two callers both call `build()` and the
192 // loser's result is discarded — both outcomes are identical services.
193 if let Some(svc) = self.cell.get() {
194 return Ok(Arc::clone(svc));
195 }
196 let svc = self.provider.build().await.map_err(|e| {
197 crate::error::RuntimeError::Internal(format!(
198 "EmbedderProvider '{}' build() failed: {e}",
199 self.provider.name()
200 ))
201 })?;
202 // `set` may fail if another task raced us to initialise; that is fine —
203 // we just return our freshly-built instance (functionally identical).
204 let _ = self.cell.set(Arc::clone(&svc));
205 Ok(svc)
206 }
207}
208
209// ── LatticeEmbedderProvider ───────────────────────────────────────────────────
210
211/// Adapter that wraps a [`lattice_embed::EmbeddingModel`] as an
212/// [`EmbedderProvider`].
213///
214/// All built-in models (MiniLM, paraphrase-multilingual, BGE variants, etc.)
215/// are registered as `LatticeEmbedderProvider` instances during
216/// `KhiveRuntime` construction. External callers do not need to use this type
217/// unless they are constructing a custom registry from scratch.
218pub struct LatticeEmbedderProvider {
219 model: EmbeddingModel,
220 /// Cached `to_string()` result so `name()` can return `&str`.
221 name: String,
222}
223
224impl LatticeEmbedderProvider {
225 /// Create a new provider wrapping the given lattice model.
226 pub fn new(model: EmbeddingModel) -> Self {
227 let name = model.to_string();
228 Self { model, name }
229 }
230}
231
232#[async_trait]
233impl EmbedderProvider for LatticeEmbedderProvider {
234 fn name(&self) -> &str {
235 &self.name
236 }
237
238 fn dimensions(&self) -> usize {
239 self.model.dimensions()
240 }
241
242 async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
243 let native = Arc::new(NativeEmbeddingService::with_model(self.model));
244 let cached = CachedEmbeddingService::with_default_cache(native);
245 Ok(Arc::new(cached) as Arc<dyn EmbeddingService>)
246 }
247}
248
249// ── Unit tests ────────────────────────────────────────────────────────────────
250
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // ---- minimal mock provider ----

    /// Mock provider that counts how many times `build` has been invoked.
    struct ConstVecProvider {
        name: String,
        dims: usize,
        build_calls: Arc<AtomicUsize>,
    }

    impl ConstVecProvider {
        /// Create a mock provider with a fresh, zeroed build counter.
        fn new(name: &str, dims: usize) -> Self {
            let build_calls = Arc::new(AtomicUsize::new(0));
            Self {
                name: name.to_owned(),
                dims,
                build_calls,
            }
        }
    }

    /// Embedding service stub: every input text maps to a vector of `dims`
    /// ones. The requested model is ignored entirely.
    struct ConstVecService {
        dims: usize,
    }

    #[async_trait]
    impl EmbeddingService for ConstVecService {
        async fn embed(
            &self,
            texts: &[String],
            _model: EmbeddingModel,
        ) -> std::result::Result<Vec<Vec<f32>>, lattice_embed::EmbedError> {
            // One identical all-ones vector per input text.
            Ok(vec![vec![1.0_f32; self.dims]; texts.len()])
        }

        fn supports_model(&self, _model: EmbeddingModel) -> bool {
            true
        }

        fn name(&self) -> &'static str {
            "const-vec-service"
        }
    }

    #[async_trait]
    impl EmbedderProvider for ConstVecProvider {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn dimensions(&self) -> usize {
            self.dims
        }

        async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
            // Record the invocation so tests can assert on the build count.
            self.build_calls.fetch_add(1, Ordering::SeqCst);
            let service = ConstVecService { dims: self.dims };
            Ok(Arc::new(service))
        }
    }

    // ---- test: register + get round-trip ----

    #[test]
    fn register_and_get_provider_round_trip() {
        let mut registry = EmbedderRegistry::new();
        registry.register(ConstVecProvider::new("mock-384", 384));

        assert!(
            registry.contains("mock-384"),
            "registered name must be present"
        );
        let found = registry
            .get_provider("mock-384")
            .expect("provider must exist");
        assert_eq!(found.name(), "mock-384");
        assert_eq!(found.dimensions(), 384);
    }

    // ---- test: duplicate name is last-wins (not an error) ----

    #[test]
    fn duplicate_name_last_wins() {
        let mut registry = EmbedderRegistry::new();
        for dims in [128, 256] {
            registry.register(ConstVecProvider::new("shared", dims));
        }

        let survivor = registry.get_provider("shared").expect("provider must exist");
        assert_eq!(
            survivor.dimensions(),
            256,
            "last registration must win; expected dims=256"
        );
    }

    // ---- test: names() returns all registered names ----

    #[test]
    fn names_returns_all_registered() {
        let mut registry = EmbedderRegistry::new();
        registry.register(ConstVecProvider::new("model-a", 64));
        registry.register(ConstVecProvider::new("model-b", 128));
        registry.register(ConstVecProvider::new("model-c", 256));

        // `names()` has no defined order, so sort before comparing.
        let mut listed = registry.names();
        listed.sort();
        assert_eq!(listed, vec!["model-a", "model-b", "model-c"]);
    }

    // ---- test: get_service returns UnknownModel for unregistered name ----

    #[tokio::test]
    async fn get_service_unknown_name_returns_error() {
        let registry = EmbedderRegistry::new();
        let outcome = registry.get_service("does-not-exist").await;
        let err = outcome
            .err()
            .expect("expected Err for unknown name, got Ok");
        assert!(
            matches!(err, RuntimeError::UnknownModel(ref n) if n == "does-not-exist"),
            "expected UnknownModel, got {err:?}"
        );
    }

    // ---- test: get_service calls build once (lazy, cached) ----

    #[tokio::test]
    async fn get_service_calls_build_once() {
        let provider = ConstVecProvider::new("cached-model", 32);
        // Keep a handle on the counter; the provider itself moves into the registry.
        let build_count = Arc::clone(&provider.build_calls);
        let mut registry = EmbedderRegistry::new();
        registry.register(provider);

        for _ in 0..3 {
            let _ = registry.get_service("cached-model").await.unwrap();
        }

        assert_eq!(
            build_count.load(Ordering::SeqCst),
            1,
            "build must be called exactly once regardless of get_service call count"
        );
    }
}
395}