Skip to main content

reddb_server/runtime/ai/
local_embedding.rs

1//! Local embedding routing (#680).
2//!
3//! Wires the `local` AI provider into HTTP and gRPC embedding surfaces.
4//! Resolves a registered+installed local model (registry from #678,
5//! cache from #679), routes through a swappable
6//! [`LocalEmbeddingBackend`], and returns deterministic, provider-tagged
7//! embeddings.
8//!
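//! The default backend is the in-process [`DeterministicFakeBackend`],
//! which produces stable f32 vectors from
//! `SHA-256(model || \0 || text || \0 || counter)`, where `counter` is a
//! little-endian `u32` block index bumped until enough values exist.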
11//! It exists so the end-to-end contract (registry lookup → backend
12//! dispatch → response shape) can be exercised without downloading a
13//! real model. Real candle/onnx engines slot in by calling
14//! [`install_local_embedding_backend`] at server boot.
15//!
16//! Errors are intentionally distinct so callers can disambiguate:
17//!
18//! * `FeatureNotEnabled` — `local-models` is off and no backend is
19//!   installed (routes through HTTP 501 / gRPC `feature_not_enabled`).
20//! * `NotFound` — the named model is not in the registry, or it is
21//!   registered but not installed in the cache.
22//! * `Query` — the registered task is not `embedding`, or the backend
23//!   produced a shape that disagrees with the registered dimensions.
24
25use std::sync::{Arc, OnceLock, RwLock};
26
27use crate::crypto::sha256::Sha256;
28use crate::json::{parse_json, Value as JsonValue};
29use crate::runtime::RedDBRuntime;
30use crate::storage::schema::Value;
31use crate::storage::unified::RedDB;
32use crate::{RedDBError, RedDBResult};
33
/// Collection queried for model registry records.
const RED_CONFIG_COLLECTION: &str = "red_config";
/// Key prefix of a registry record; the registered model name is appended.
const AI_MODEL_KEY_PREFIX: &str = "red.config.ai.models.";
/// `status` value a registry record must carry before embeddings are served.
const STATUS_INSTALLED: &str = "installed";
/// The only `task` value accepted by local embedding routing.
const TASK_EMBEDDING: &str = "embedding";
/// Provider tag required on registry records and stamped on responses.
const PROVIDER_LOCAL: &str = "local";
39
/// Canonical pull-policy spellings, mirrored from the model-registry
/// contract (`crate::server::handlers_ai`). They are duplicated here on
/// purpose: the embed path only reads the registry and should not pull
/// in the HTTP handler module just to share three strings.
const PULL_POLICY_NEVER: &str = "never";
const PULL_POLICY_IF_MISSING: &str = "if_missing";
const PULL_POLICY_ALWAYS: &str = "always";

/// Map a stored `pull_policy` string onto its canonical spelling.
///
/// Surrounding whitespace and ASCII case are ignored. Registry entries
/// written before the rename may still carry the legacy aliases
/// `manual` / `on_demand` / `eager`; they resolve to `never` /
/// `if_missing` / `always` respectively so existing installs keep
/// working.
///
/// Any value that is not recognised falls back to `if_missing`, the
/// safest default for query-time routing: artifacts are never acquired
/// silently, but operator-triggered pulls remain allowed.
fn normalize_stored_pull_policy(raw: &str) -> &'static str {
    let policy = raw.trim();
    let is_one_of = |names: &[&str]| names.iter().any(|name| policy.eq_ignore_ascii_case(name));

    if is_one_of(&["never", "manual"]) {
        PULL_POLICY_NEVER
    } else if is_one_of(&["always", "eager"]) {
        PULL_POLICY_ALWAYS
    } else {
        // Includes the legacy `on_demand` alias and the canonical
        // `if_missing` itself.
        PULL_POLICY_IF_MISSING
    }
}
63
/// Text of the `FeatureNotEnabled` error raised when no backend is
/// installed and the crate was built without the `local-models` feature.
const LOCAL_MODELS_DISABLED_MESSAGE: &str = concat!(
    "local embeddings require the `local-models` feature flag at engine build time. ",
    "Build with: cargo build --features local-models. Alternatively, install a backend ",
    "via runtime::ai::local_embedding::install_local_embedding_backend, or use the ",
    "'ollama' provider with a local Ollama server.",
);
69
/// One fully-resolved embedding request, as handed to a
/// [`LocalEmbeddingBackend`]. All metadata fields are copied from the
/// model's registry record.
#[derive(Clone, Debug)]
pub struct LocalEmbeddingRequest {
    /// Name the model was registered under (the registry key is
    /// `red.config.ai.models.{name}`).
    pub name: String,
    /// Source identifier from the registry, e.g. a HuggingFace repo id.
    pub source: String,
    /// Git revision or tag the registry entry is pinned to.
    pub revision: String,
    /// Engine identifier from the registry (e.g. `candle`).
    pub engine: String,
    /// Output width declared when the model was registered.
    pub dimensions: u32,
    /// Texts to embed. Callers are expected to have rejected empty
    /// entries before building the request.
    pub inputs: Vec<String>,
}
86
87/// Backend abstraction so HTTP/gRPC routing does not depend on a
88/// specific model engine. A future candle/onnx backend implements this
89/// trait and is installed via [`install_local_embedding_backend`].
90pub trait LocalEmbeddingBackend: Send + Sync {
91    fn embed(&self, request: &LocalEmbeddingRequest) -> RedDBResult<Vec<Vec<f32>>>;
92}
93
/// Result of a successful local embedding call: the vectors plus the
/// provider/model metadata that the HTTP and gRPC encoders put on the
/// wire.
#[derive(Clone, Debug)]
pub struct LocalEmbeddingResponse {
    /// Provider tag of the route that produced the vectors.
    pub provider: &'static str,
    /// Model name from the registry record.
    pub name: String,
    /// Source identifier from the registry record.
    pub source: String,
    /// Pinned revision from the registry record.
    pub revision: String,
    /// Engine identifier from the registry record.
    pub engine: String,
    /// Registered output width; every vector in `embeddings` has this length.
    pub dimensions: u32,
    /// One vector per requested input, in request order.
    pub embeddings: Vec<Vec<f32>>,
}
106
107/// Deterministic, dependency-free backend used to prove the wire
108/// contract end-to-end. The output of `embed(model, text, dim)` is a
109/// pure function of `(model, text, dim)` — no I/O, no clocks, no RNGs
110/// — so tests get byte-identical embeddings across runs.
111#[derive(Debug, Default, Clone, Copy)]
112pub struct DeterministicFakeBackend;
113
114impl LocalEmbeddingBackend for DeterministicFakeBackend {
115    fn embed(&self, request: &LocalEmbeddingRequest) -> RedDBResult<Vec<Vec<f32>>> {
116        let dim = request.dimensions as usize;
117        let mut out = Vec::with_capacity(request.inputs.len());
118        for text in &request.inputs {
119            out.push(deterministic_embedding(&request.name, text, dim));
120        }
121        Ok(out)
122    }
123}
124
125fn deterministic_embedding(model: &str, text: &str, dim: usize) -> Vec<f32> {
126    let mut out = Vec::with_capacity(dim);
127    let mut counter: u32 = 0;
128    while out.len() < dim {
129        let mut hasher = Sha256::new();
130        hasher.update(model.as_bytes());
131        hasher.update(&[0u8]);
132        hasher.update(text.as_bytes());
133        hasher.update(&[0u8]);
134        hasher.update(&counter.to_le_bytes());
135        let digest = hasher.finalize();
136        for chunk in digest.chunks(4) {
137            if out.len() >= dim {
138                break;
139            }
140            let mut bytes = [0u8; 4];
141            bytes.copy_from_slice(chunk);
142            let raw = u32::from_le_bytes(bytes) as f32 / u32::MAX as f32;
143            // Map [0, 1] → [-1, 1) so the fake produces sign-mixed
144            // vectors (the property tests look for both signs).
145            out.push(raw * 2.0 - 1.0);
146        }
147        counter = counter.wrapping_add(1);
148    }
149    out
150}
151
152type BackendSlot = Arc<dyn LocalEmbeddingBackend>;
153
154fn backend_slot() -> &'static RwLock<Option<BackendSlot>> {
155    static SLOT: OnceLock<RwLock<Option<BackendSlot>>> = OnceLock::new();
156    SLOT.get_or_init(|| RwLock::new(None))
157}
158
159/// Install (or replace) the process-global local embedding backend.
160///
161/// Production servers built with `--features local-models` should call
162/// this once at boot with their real engine. Tests use it to swap in
163/// a deterministic stub. Safe to call from any thread; the most recent
164/// install wins.
165pub fn install_local_embedding_backend(backend: Arc<dyn LocalEmbeddingBackend>) {
166    let mut guard = backend_slot().write().expect("backend slot poisoned");
167    *guard = Some(backend);
168}
169
170/// Test-only: clear the installed backend so a subsequent call exercises
171/// the `FeatureNotEnabled` path again.
172#[doc(hidden)]
173pub fn clear_local_embedding_backend_for_tests() {
174    let mut guard = backend_slot().write().expect("backend slot poisoned");
175    *guard = None;
176}
177
178fn current_backend() -> Option<BackendSlot> {
179    backend_slot()
180        .read()
181        .expect("backend slot poisoned")
182        .as_ref()
183        .map(Arc::clone)
184}
185
186/// Return the deterministic feature-disabled error before callers do
187/// request-shape validation. Tests may install a backend without the
188/// Cargo feature, so a present backend still means local embeddings are
189/// available for that process.
190pub fn ensure_local_embedding_available() -> RedDBResult<()> {
191    if current_backend().is_none() && !cfg!(feature = "local-models") {
192        return Err(RedDBError::FeatureNotEnabled(
193            LOCAL_MODELS_DISABLED_MESSAGE.to_string(),
194        ));
195    }
196    Ok(())
197}
198
199/// Resolve and run a local embedding request end-to-end.
200///
201/// Performs, in order:
202/// 1. Backend availability gate (or feature-off error).
203/// 2. Registry lookup for `name` in `red_config`.
204/// 3. Task / status / engine validation.
205/// 4. Backend dispatch.
206/// 5. Shape validation against the registered dimensions.
207pub fn embed_local(
208    runtime: &RedDBRuntime,
209    name: &str,
210    inputs: Vec<String>,
211) -> RedDBResult<LocalEmbeddingResponse> {
212    embed_local_with_db(&runtime.db(), name, inputs)
213}
214
215/// Validate that a local embedding request for `name` would resolve a
216/// registered+installed model and an available backend, without sending
217/// any inputs.
218///
219/// Used by write paths (e.g. INSERT ... WITH AUTO EMBED) that need a
220/// deterministic pre-flight to fail the statement before any side
221/// effect on the target collection, satisfying the
222/// "embedding failures leave the target collection unchanged" contract
223/// for the failure modes the local provider owns: feature disabled,
224/// missing model, uninstalled artifacts, unsupported task, wrong
225/// provider tag, missing dimensions, corrupted registry entry.
226///
227/// Returns the resolved descriptor's `dimensions` so callers can pin
228/// the expected output shape before any backend round-trip.
229pub fn preflight_local_embedding(db: &RedDB, name: &str) -> RedDBResult<u32> {
230    let name = name.trim();
231    if name.is_empty() {
232        return Err(RedDBError::Query(
233            "local embedding 'model' field cannot be empty; pass the registered local model name"
234                .to_string(),
235        ));
236    }
237
238    // Mirror the backend-availability gate from `embed_local_with_db`
239    // so a feature-off build fails before the write phase rather than
240    // after we have already inserted rows.
241    ensure_local_embedding_available()?;
242
243    let descriptor = read_model_descriptor(db, name)?;
244    if descriptor.provider != PROVIDER_LOCAL {
245        return Err(RedDBError::Query(format!(
246            "model '{name}' has provider '{}'; only '{PROVIDER_LOCAL}' is supported by local embedding routing",
247            descriptor.provider
248        )));
249    }
250    if descriptor.task != TASK_EMBEDDING {
251        return Err(RedDBError::Query(format!(
252            "model '{name}' has task '{}'; only '{TASK_EMBEDDING}' is supported by the local provider \
253             (prompt/generation are out of scope)",
254            descriptor.task
255        )));
256    }
257    if descriptor.status != STATUS_INSTALLED {
258        let message = match descriptor.pull_policy {
259            PULL_POLICY_NEVER => format!(
260                "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
261                 pull_policy='never' forbids runtime acquisition. An operator must explicitly install \
262                 the model via `POST /ai/models/{name}/pull`.",
263                descriptor.status
264            ),
265            PULL_POLICY_ALWAYS => format!(
266                "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
267                 pull_policy='always' is configured but query-time auto-pull is not implemented in this slice. \
268                 Trigger a refresh via `POST /ai/models/{name}/pull` before requesting embeddings.",
269                descriptor.status
270            ),
271            _ => format!(
272                "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
273                 pull_policy='if_missing' permits acquisition only via the explicit pull endpoint \
274                 (query-time auto-pull is not implemented). Run `POST /ai/models/{name}/pull` to install.",
275                descriptor.status
276            ),
277        };
278        return Err(RedDBError::NotFound(message));
279    }
280    if descriptor.dimensions == 0 {
281        return Err(RedDBError::Query(format!(
282            "model '{name}' registry entry has dimensions=0; re-register with the model's true output width"
283        )));
284    }
285    Ok(descriptor.dimensions)
286}
287
288/// Variant of [`embed_local`] that operates against a `RedDB` handle
289/// directly. The runtime query executor only carries `&RedDB`, so the
290/// text-vector-search routing path calls this rather than the runtime
291/// wrapper above.
292pub fn embed_local_with_db(
293    db: &RedDB,
294    name: &str,
295    inputs: Vec<String>,
296) -> RedDBResult<LocalEmbeddingResponse> {
297    if inputs.is_empty() {
298        return Err(RedDBError::Query(
299            "at least one input is required for local embeddings".to_string(),
300        ));
301    }
302    let name = name.trim();
303    if name.is_empty() {
304        return Err(RedDBError::Query(
305            "local embedding 'model' field cannot be empty; pass the registered local model name"
306                .to_string(),
307        ));
308    }
309
310    let backend = match current_backend() {
311        Some(b) => b,
312        None => {
313            if cfg!(feature = "local-models") {
314                // Feature is on but no engine was installed by the
315                // server boot path — fall back to the deterministic
316                // fake so the surface stays usable in dev builds.
317                let fake: Arc<dyn LocalEmbeddingBackend> = Arc::new(DeterministicFakeBackend);
318                install_local_embedding_backend(Arc::clone(&fake));
319                fake
320            } else {
321                return Err(RedDBError::FeatureNotEnabled(
322                    LOCAL_MODELS_DISABLED_MESSAGE.to_string(),
323                ));
324            }
325        }
326    };
327
328    let descriptor = read_model_descriptor(db, name)?;
329
330    if descriptor.provider != PROVIDER_LOCAL {
331        return Err(RedDBError::Query(format!(
332            "model '{name}' has provider '{}'; only '{PROVIDER_LOCAL}' is supported by local embedding routing",
333            descriptor.provider
334        )));
335    }
336    if descriptor.task != TASK_EMBEDDING {
337        return Err(RedDBError::Query(format!(
338            "model '{name}' has task '{}'; only '{TASK_EMBEDDING}' is supported by the local provider \
339             (prompt/generation are out of scope)",
340            descriptor.task
341        )));
342    }
343    if descriptor.status != STATUS_INSTALLED {
344        // Operator-safe contract: query-time routing never silently
345        // acquires artifacts and never falls back to a remote provider.
346        // Each policy surfaces a clear, distinct error so the operator
347        // knows which knob to turn.
348        let message = match descriptor.pull_policy {
349            PULL_POLICY_NEVER => format!(
350                "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
351                 pull_policy='never' forbids runtime acquisition. An operator must explicitly install \
352                 the model via `POST /ai/models/{name}/pull`.",
353                descriptor.status
354            ),
355            PULL_POLICY_ALWAYS => format!(
356                "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
357                 pull_policy='always' is configured but query-time auto-pull is not implemented in this slice. \
358                 Trigger a refresh via `POST /ai/models/{name}/pull` before requesting embeddings.",
359                descriptor.status
360            ),
361            // PULL_POLICY_IF_MISSING (default)
362            _ => format!(
363                "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
364                 pull_policy='if_missing' permits acquisition only via the explicit pull endpoint \
365                 (query-time auto-pull is not implemented). Run `POST /ai/models/{name}/pull` to install.",
366                descriptor.status
367            ),
368        };
369        return Err(RedDBError::NotFound(message));
370    }
371
372    let request = LocalEmbeddingRequest {
373        name: descriptor.name.clone(),
374        source: descriptor.source.clone(),
375        revision: descriptor.revision.clone(),
376        engine: descriptor.engine.clone(),
377        dimensions: descriptor.dimensions,
378        inputs,
379    };
380    let embeddings = backend.embed(&request)?;
381
382    if embeddings.len() != request.inputs.len() {
383        return Err(RedDBError::Query(format!(
384            "local backend returned {} embeddings for {} inputs",
385            embeddings.len(),
386            request.inputs.len()
387        )));
388    }
389    for (idx, row) in embeddings.iter().enumerate() {
390        if row.len() != descriptor.dimensions as usize {
391            return Err(RedDBError::Query(format!(
392                "local backend returned embedding[{idx}] of length {} but model '{name}' \
393                 was registered with dimensions={}",
394                row.len(),
395                descriptor.dimensions
396            )));
397        }
398    }
399
400    Ok(LocalEmbeddingResponse {
401        provider: PROVIDER_LOCAL,
402        name: descriptor.name,
403        source: descriptor.source,
404        revision: descriptor.revision,
405        engine: descriptor.engine,
406        dimensions: descriptor.dimensions,
407        embeddings,
408    })
409}
410
/// Registry record for one model, as decoded by `read_model_descriptor`.
#[derive(Clone, Debug)]
struct ModelDescriptor {
    /// Name stored in the record, or the lookup name when absent.
    name: String,
    /// Provider tag; routing requires `local`.
    provider: String,
    /// Source identifier of the model.
    source: String,
    /// Pinned revision of the model.
    revision: String,
    /// Engine identifier of the model.
    engine: String,
    /// Registered task; routing requires `embedding`.
    task: String,
    /// Install status; routing requires `installed`.
    status: String,
    /// Registered output width.
    dimensions: u32,
    /// Canonical pull policy (`never` / `if_missing` / `always`). It is
    /// normalised when the record is read, so the gate logic never sees
    /// legacy alias spellings.
    pull_policy: &'static str,
}
426
427fn read_model_descriptor(db: &RedDB, name: &str) -> RedDBResult<ModelDescriptor> {
428    let key = format!("{AI_MODEL_KEY_PREFIX}{name}");
429    let raw = match db.get_kv(RED_CONFIG_COLLECTION, &key) {
430        Some((Value::Text(text), _)) => text.to_string(),
431        Some(_) => {
432            return Err(RedDBError::Query(format!(
433                "local model registry entry for '{name}' is not a JSON text payload"
434            )));
435        }
436        None => {
437            return Err(RedDBError::NotFound(format!(
438                "local model '{name}' is not registered; POST /ai/models to register it first"
439            )));
440        }
441    };
442    let parsed = parse_json(&raw).map_err(|err| {
443        RedDBError::Query(format!(
444            "local model registry entry for '{name}' is corrupted: {err}"
445        ))
446    })?;
447    let value = JsonValue::from(parsed);
448    let object = value
449        .as_object()
450        .ok_or_else(|| RedDBError::Query(format!("model entry for '{name}' is not an object")))?;
451
452    let pick = |key: &str| -> Option<String> {
453        object
454            .get(key)
455            .and_then(JsonValue::as_str)
456            .map(str::to_string)
457    };
458
459    let provider = pick("provider").unwrap_or_else(|| PROVIDER_LOCAL.to_string());
460    let source = pick("source").unwrap_or_default();
461    let revision = pick("revision").unwrap_or_default();
462    let engine = pick("engine").unwrap_or_default();
463    let task = pick("task").unwrap_or_default();
464    let status = pick("status").unwrap_or_default();
465    let dimensions = object
466        .get("dimensions")
467        .and_then(JsonValue::as_u64)
468        .ok_or_else(|| {
469            RedDBError::Query(format!("model entry for '{name}' is missing 'dimensions'"))
470        })? as u32;
471    let pull_policy = normalize_stored_pull_policy(
472        pick("pull_policy")
473            .as_deref()
474            .unwrap_or(PULL_POLICY_IF_MISSING),
475    );
476
477    Ok(ModelDescriptor {
478        name: pick("name").unwrap_or_else(|| name.to_string()),
479        provider,
480        source,
481        revision,
482        engine,
483        task,
484        status,
485        dimensions,
486        pull_policy,
487    })
488}
489
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a request for the fake backend, which reads only `name`,
    /// `dimensions` and `inputs`; the other metadata is left empty.
    fn request(name: &str, dimensions: u32, inputs: &[&str]) -> LocalEmbeddingRequest {
        LocalEmbeddingRequest {
            name: name.to_string(),
            source: String::new(),
            revision: String::new(),
            engine: String::new(),
            dimensions,
            inputs: inputs.iter().map(|text| text.to_string()).collect(),
        }
    }

    #[test]
    fn deterministic_fake_is_pure_and_correct_length() {
        let backend = DeterministicFakeBackend;
        let req = request("mini", 16, &["hello", "world"]);

        let first = backend.embed(&req).expect("embed");
        let second = backend.embed(&req).expect("embed twice");

        assert_eq!(first, second, "deterministic backend must be pure");
        assert_eq!(first.len(), 2);
        for vector in &first {
            assert_eq!(vector.len(), 16);
        }
        assert_ne!(
            first[0], first[1],
            "different inputs must produce different vectors"
        );
    }

    #[test]
    fn deterministic_fake_changes_with_model_name() {
        let backend = DeterministicFakeBackend;
        let alpha = backend.embed(&request("alpha", 8, &["x"])).unwrap();
        let beta = backend.embed(&request("beta", 8, &["x"])).unwrap();
        assert_ne!(alpha[0], beta[0]);
    }
}