khive_runtime/embedder_registry.rs
1//! EmbedderRegistry — pack-extensible embedding provider surface.
2//!
3//! Packs implement [`EmbedderProvider`] and register custom models via
4//! [`crate::KhiveRuntime::register_embedder`]. Built-in lattice models are pre-registered
5//! during runtime construction and require no opt-in.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use lattice_embed::{
12 CachedEmbeddingService, EmbeddingModel, EmbeddingService, NativeEmbeddingService,
13};
14use tokio::sync::OnceCell;
15
16use crate::error::{RuntimeError, RuntimeResult};
17
18/// A source that can produce an [`EmbeddingService`] by name.
19///
20/// Packs implement this trait to register custom embedding backends.
21/// The runtime calls [`build`](EmbedderProvider::build) lazily — once per
22/// process per model — and caches the result. Subsequent calls to
23/// `KhiveRuntime::embedder(name)` are cheap.
24///
25/// Built-in lattice models are registered automatically via
26/// [`LatticeEmbedderProvider`]; packs need not re-register them.
27#[async_trait]
28pub trait EmbedderProvider: Send + Sync {
29 /// Stable, case-sensitive name for this embedder.
30 ///
31 /// Must be unique across all registered providers. The name is used as
32 /// the key in `KhiveRuntime::embedder(name)` lookups and as the storage
33 /// table suffix for vector indices. Use the model's canonical short form
34 /// (e.g. `"all-minilm-l6-v2"`, `"my-custom-encoder"`).
35 fn name(&self) -> &str;
36
37 /// Output vector dimension for this embedder.
38 ///
39 /// Must be consistent with what [`build`](Self::build) produces.
40 /// The runtime uses this to pre-register the vector store columns.
41 fn dimensions(&self) -> usize;
42
43 /// Construct the underlying [`EmbeddingService`].
44 ///
45 /// Called at most once per process. The result is cached in a
46 /// [`OnceCell`]; concurrent callers block on the first call and share
47 /// the result thereafter.
48 async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>>;
49}
50
51/// An entry in the [`EmbedderRegistry`] combining a provider with its
52/// lazy-initialized service.
53pub(crate) struct EmbedderEntry {
54 provider: Arc<dyn EmbedderProvider>,
55 cell: Arc<OnceCell<Arc<dyn EmbeddingService>>>,
56}
57
58impl Clone for EmbedderEntry {
59 fn clone(&self) -> Self {
60 Self {
61 provider: Arc::clone(&self.provider),
62 cell: Arc::clone(&self.cell),
63 }
64 }
65}
66
67/// Registry of named [`EmbedderProvider`] instances.
68///
69/// Built during `KhiveRuntime` construction and optionally extended by packs
70/// via [`crate::KhiveRuntime::register_embedder`]. The registry is internally
71/// reference-counted so `KhiveRuntime::clone()` shares the same providers
72/// and cached service instances.
73#[derive(Clone, Default)]
74pub struct EmbedderRegistry {
75 entries: HashMap<String, EmbedderEntry>,
76}
77
78impl EmbedderRegistry {
79 /// Create an empty registry.
80 pub fn new() -> Self {
81 Self {
82 entries: HashMap::new(),
83 }
84 }
85
86 /// Register a provider.
87 ///
88 /// If a provider with the same [`name`](EmbedderProvider::name) already
89 /// exists, the new provider **replaces** the existing one (last-writer wins).
90 /// The previously cached service instance (if any) is discarded — the
91 /// replacement provider's [`build`](EmbedderProvider::build) will be
92 /// called lazily on the next access.
93 ///
94 /// **Last-wins** is chosen over rejection because pack registration order
95 /// is not guaranteed and packs may legitimately override a default model
96 /// with a custom implementation of the same logical name. Operators who
97 /// need strict collision detection should inspect
98 /// [`names`](Self::names) before registering.
99 pub fn register<P: EmbedderProvider + 'static>(&mut self, provider: P) {
100 let name = provider.name().to_owned();
101 self.entries.insert(
102 name,
103 EmbedderEntry {
104 provider: Arc::new(provider),
105 cell: Arc::new(OnceCell::new()),
106 },
107 );
108 }
109
110 /// Look up a provider by name.
111 pub fn get_provider(&self, name: &str) -> Option<&dyn EmbedderProvider> {
112 self.entries.get(name).map(|e| e.provider.as_ref())
113 }
114
115 /// Returns `true` if a provider with this name is registered.
116 pub fn contains(&self, name: &str) -> bool {
117 self.entries.contains_key(name)
118 }
119
120 /// Names of all registered providers, in unspecified order.
121 pub fn names(&self) -> Vec<String> {
122 self.entries.keys().cloned().collect()
123 }
124
125 /// Return a cloned entry for `name` without holding any lock.
126 ///
127 /// The caller can then call [`EmbedderEntry::resolve`] without holding
128 /// a lock — this avoids holding a `RwLockGuard` across `await` points.
129 /// Returns `None` if `name` is not registered.
130 pub(crate) fn get_entry(&self, name: &str) -> Option<EmbedderEntry> {
131 self.entries.get(name).cloned()
132 }
133
134 /// Lazily resolve a registered provider to its live [`EmbeddingService`].
135 ///
136 /// Returns [`RuntimeError::UnknownModel`] if `name` is not registered.
137 /// The first call for a given name triggers [`EmbedderProvider::build`];
138 /// subsequent calls return the cached `Arc`.
139 ///
140 /// Prefer [`crate::KhiveRuntime::embedder`] over calling this directly from pack
141 /// handlers — the runtime method handles alias resolution and error mapping.
142 pub async fn get_service(&self, name: &str) -> RuntimeResult<Arc<dyn EmbeddingService>> {
143 let entry = self
144 .entries
145 .get(name)
146 .ok_or_else(|| RuntimeError::UnknownModel(name.to_string()))?
147 .clone();
148
149 entry.resolve().await
150 }
151}
152
153impl EmbedderEntry {
154 /// Lazily initialise and return the embedding service for this entry.
155 ///
156 /// The `OnceCell` guarantees that `build` is called at most once even
157 /// under concurrent access. Callers hold no external lock while awaiting.
158 ///
159 /// Returns `RuntimeError` if `build()` fails, rather than panicking.
160 pub(crate) async fn resolve(self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
161 // `OnceCell` does not expose a fallible init; we work around this by
162 // checking if the cell is already populated (cheap), and if not, calling
163 // `build()` ourselves, storing on success, and propagating on failure.
164 // Races are benign: at worst two callers both call `build()` and the
165 // loser's result is discarded — both outcomes are identical services.
166 if let Some(svc) = self.cell.get() {
167 return Ok(Arc::clone(svc));
168 }
169 let svc = self.provider.build().await.map_err(|e| {
170 crate::error::RuntimeError::Internal(format!(
171 "EmbedderProvider '{}' build() failed: {e}",
172 self.provider.name()
173 ))
174 })?;
175 // `set` may fail if another task raced us to initialise; that is fine —
176 // we just return our freshly-built instance (functionally identical).
177 let _ = self.cell.set(Arc::clone(&svc));
178 Ok(svc)
179 }
180}
181
182// ── LatticeEmbedderProvider ───────────────────────────────────────────────────
183
184/// Adapter that wraps a [`lattice_embed::EmbeddingModel`] as an
185/// [`EmbedderProvider`].
186///
187/// All built-in models (MiniLM, paraphrase-multilingual, BGE variants, etc.)
188/// are registered as `LatticeEmbedderProvider` instances during
189/// `KhiveRuntime` construction. External callers do not need to use this type
190/// unless they are constructing a custom registry from scratch.
191pub struct LatticeEmbedderProvider {
192 model: EmbeddingModel,
193 /// Cached `to_string()` result so `name()` can return `&str`.
194 name: String,
195}
196
197impl LatticeEmbedderProvider {
198 /// Create a new provider wrapping the given lattice model.
199 pub fn new(model: EmbeddingModel) -> Self {
200 let name = model.to_string();
201 Self { model, name }
202 }
203}
204
205#[async_trait]
206impl EmbedderProvider for LatticeEmbedderProvider {
207 fn name(&self) -> &str {
208 &self.name
209 }
210
211 fn dimensions(&self) -> usize {
212 self.model.dimensions()
213 }
214
215 async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
216 let native = Arc::new(NativeEmbeddingService::with_model(self.model));
217 let cached = CachedEmbeddingService::with_default_cache(native);
218 Ok(Arc::new(cached) as Arc<dyn EmbeddingService>)
219 }
220}
221
222// ── Unit tests ────────────────────────────────────────────────────────────────
223
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // ---- minimal mock providers ----

    /// Provider whose `build` always succeeds and counts its invocations.
    struct ConstVecProvider {
        name: String,
        dims: usize,
        build_calls: Arc<AtomicUsize>,
    }

    impl ConstVecProvider {
        /// Provider with a private build counter the test does not observe.
        fn new(name: &str, dims: usize) -> Self {
            Self::with_counter(name, dims, Arc::new(AtomicUsize::new(0)))
        }

        /// Provider that records its builds in a counter shared with the test.
        fn with_counter(name: &str, dims: usize, build_calls: Arc<AtomicUsize>) -> Self {
            Self {
                name: name.to_owned(),
                dims,
                build_calls,
            }
        }
    }

    /// A trivial embedding service that returns a constant vector of `1.0`s.
    /// The `model` parameter is ignored — this service always returns the
    /// same synthetic vector regardless of which model is requested.
    struct ConstVecService {
        dims: usize,
    }

    #[async_trait]
    impl EmbeddingService for ConstVecService {
        async fn embed(
            &self,
            texts: &[String],
            _model: EmbeddingModel,
        ) -> std::result::Result<Vec<Vec<f32>>, lattice_embed::EmbedError> {
            Ok(texts.iter().map(|_| vec![1.0_f32; self.dims]).collect())
        }

        fn supports_model(&self, _model: EmbeddingModel) -> bool {
            true
        }

        fn name(&self) -> &'static str {
            "const-vec-service"
        }
    }

    #[async_trait]
    impl EmbedderProvider for ConstVecProvider {
        fn name(&self) -> &str {
            &self.name
        }

        fn dimensions(&self) -> usize {
            self.dims
        }

        async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
            self.build_calls.fetch_add(1, Ordering::SeqCst);
            Ok(Arc::new(ConstVecService { dims: self.dims }))
        }
    }

    /// Provider whose `build` always fails and counts its attempts.
    struct FailingProvider {
        build_calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl EmbedderProvider for FailingProvider {
        fn name(&self) -> &str {
            "failing-model"
        }

        fn dimensions(&self) -> usize {
            8
        }

        async fn build(&self) -> RuntimeResult<Arc<dyn EmbeddingService>> {
            self.build_calls.fetch_add(1, Ordering::SeqCst);
            Err(RuntimeError::Internal("backend unavailable".to_owned()))
        }
    }

    // ---- test: register + get round-trip ----

    #[test]
    fn register_and_get_provider_round_trip() {
        let mut reg = EmbedderRegistry::new();
        reg.register(ConstVecProvider::new("mock-384", 384));

        assert!(reg.contains("mock-384"), "registered name must be present");
        let provider = reg.get_provider("mock-384").expect("provider must exist");
        assert_eq!(provider.name(), "mock-384");
        assert_eq!(provider.dimensions(), 384);
    }

    // ---- test: unknown name is absent ----

    #[test]
    fn get_provider_unknown_name_returns_none() {
        let mut reg = EmbedderRegistry::new();
        reg.register(ConstVecProvider::new("present", 16));

        assert!(!reg.contains("absent"), "unregistered name must not be present");
        assert!(
            reg.get_provider("absent").is_none(),
            "unregistered name must yield None"
        );
    }

    // ---- test: duplicate name is last-wins (not an error) ----

    #[test]
    fn duplicate_name_last_wins() {
        let mut reg = EmbedderRegistry::new();
        reg.register(ConstVecProvider::new("shared", 128));
        reg.register(ConstVecProvider::new("shared", 256));

        let provider = reg.get_provider("shared").expect("provider must exist");
        assert_eq!(
            provider.dimensions(),
            256,
            "last registration must win; expected dims=256"
        );
        assert_eq!(
            reg.names(),
            vec!["shared"],
            "replacement must not leave a second entry behind"
        );
    }

    // ---- test: names() returns all registered names ----

    #[test]
    fn names_returns_all_registered() {
        let mut reg = EmbedderRegistry::new();
        reg.register(ConstVecProvider::new("model-a", 64));
        reg.register(ConstVecProvider::new("model-b", 128));
        reg.register(ConstVecProvider::new("model-c", 256));

        let mut names = reg.names();
        names.sort();
        assert_eq!(names, vec!["model-a", "model-b", "model-c"]);
    }

    // ---- test: get_service returns UnknownModel for unregistered name ----

    #[tokio::test]
    async fn get_service_unknown_name_returns_error() {
        let reg = EmbedderRegistry::new();
        let result = reg.get_service("does-not-exist").await;
        let err = result.err().expect("expected Err for unknown name, got Ok");
        assert!(
            matches!(err, RuntimeError::UnknownModel(ref n) if n == "does-not-exist"),
            "expected UnknownModel, got {err:?}"
        );
    }

    // ---- test: get_service calls build once (lazy, cached) ----

    #[tokio::test]
    async fn get_service_calls_build_once() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut reg = EmbedderRegistry::new();
        reg.register(ConstVecProvider::with_counter(
            "cached-model",
            32,
            Arc::clone(&counter),
        ));

        let _ = reg.get_service("cached-model").await.unwrap();
        let _ = reg.get_service("cached-model").await.unwrap();
        let _ = reg.get_service("cached-model").await.unwrap();

        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "build must be called exactly once regardless of get_service call count"
        );
    }

    // ---- test: re-registering discards the cached service ----

    #[tokio::test]
    async fn reregistering_discards_cached_service() {
        let first = Arc::new(AtomicUsize::new(0));
        let second = Arc::new(AtomicUsize::new(0));
        let mut reg = EmbedderRegistry::new();

        reg.register(ConstVecProvider::with_counter("swap", 16, Arc::clone(&first)));
        let _ = reg.get_service("swap").await.unwrap();

        reg.register(ConstVecProvider::with_counter("swap", 16, Arc::clone(&second)));
        let _ = reg.get_service("swap").await.unwrap();
        let _ = reg.get_service("swap").await.unwrap();

        assert_eq!(
            first.load(Ordering::SeqCst),
            1,
            "replaced provider must not be rebuilt"
        );
        assert_eq!(
            second.load(Ordering::SeqCst),
            1,
            "replacement provider must be built exactly once on its next access"
        );
    }

    // ---- test: failed build is reported as Internal and not cached ----

    #[tokio::test]
    async fn failed_build_is_reported_and_not_cached() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut reg = EmbedderRegistry::new();
        reg.register(FailingProvider {
            build_calls: Arc::clone(&counter),
        });

        for attempt in 1..=2 {
            let err = reg
                .get_service("failing-model")
                .await
                .err()
                .expect("expected Err from a failing build, got Ok");
            assert!(
                matches!(err, RuntimeError::Internal(ref msg) if msg.contains("failing-model")),
                "expected Internal error naming the provider, got {err:?}"
            );
            assert_eq!(
                counter.load(Ordering::SeqCst),
                attempt,
                "a failed build must not be cached; each access retries"
            );
        }
    }
}