Skip to main content

reddb_server/runtime/ai/
local_embedding.rs

1//! Local embedding routing (#680).
2//!
3//! Wires the `local` AI provider into HTTP and gRPC embedding surfaces.
4//! Resolves a registered+installed local model (registry from #678,
5//! cache from #679), routes through a swappable
6//! [`LocalEmbeddingBackend`], and returns deterministic, provider-tagged
7//! embeddings.
8//!
9//! The default backend is the in-process [`DeterministicFakeBackend`],
10//! which produces stable f32 vectors from `SHA-256(model || \0 || text)`.
11//! It exists so the end-to-end contract (registry lookup → backend
12//! dispatch → response shape) can be exercised without downloading a
13//! real model. Real candle/onnx engines slot in by calling
14//! [`install_local_embedding_backend`] at server boot.
15//!
16//! Errors are intentionally distinct so callers can disambiguate:
17//!
18//! * `FeatureNotEnabled` — `local-models` is off and no backend is
19//!   installed (routes through HTTP 501 / gRPC `feature_not_enabled`).
20//! * `NotFound` — the named model is not in the registry, or it is
21//!   registered but not installed in the cache.
22//! * `Query` — the registered task is not `embedding`, or the backend
23//!   produced a shape that disagrees with the registered dimensions.
24
25use std::sync::{Arc, OnceLock, RwLock};
26
27use crate::crypto::sha256::Sha256;
28use crate::json::{parse_json, Value as JsonValue};
29use crate::runtime::RedDBRuntime;
30use crate::storage::schema::Value;
31use crate::storage::unified::RedDB;
32use crate::{RedDBError, RedDBResult};
33
/// Key-value collection the AI model registry entries are read from.
const RED_CONFIG_COLLECTION: &str = "red_config";
/// Key prefix of a registry entry; the model name is appended verbatim.
const AI_MODEL_KEY_PREFIX: &str = "red.config.ai.models.";
/// Registry `status` value marking a model whose artifacts are present.
const STATUS_INSTALLED: &str = "installed";
/// The only registry `task` this module can serve.
const TASK_EMBEDDING: &str = "embedding";
/// Provider tag required on registry entries and stamped on responses.
const PROVIDER_LOCAL: &str = "local";

// Canonical pull-policy names. They mirror the model-registry contract in
// `crate::server::handlers_ai`; the copies are intentional so this
// read-side runtime code carries no dependency on the HTTP handler layer.
const PULL_POLICY_NEVER: &str = "never";
const PULL_POLICY_IF_MISSING: &str = "if_missing";
const PULL_POLICY_ALWAYS: &str = "always";
47
48/// Normalise a stored `pull_policy` value to its canonical form. Old
49/// registry entries written before the rename still carry
50/// `manual`/`on_demand`/`eager`; those continue to resolve to the
51/// matching canonical name so existing installs keep working.
52fn normalize_stored_pull_policy(raw: &str) -> &'static str {
53    match raw.trim().to_ascii_lowercase().as_str() {
54        "never" | "manual" => PULL_POLICY_NEVER,
55        "always" | "eager" => PULL_POLICY_ALWAYS,
56        // Default — anything else, including the legacy `on_demand`, is
57        // treated as `if_missing` (the safest default for query-time
58        // routing: never auto-acquire silently, but allow operator
59        // pulls).
60        _ => PULL_POLICY_IF_MISSING,
61    }
62}
63
/// Error text reported as `FeatureNotEnabled` when no backend can serve
/// local embeddings. It lists the three ways to obtain one.
const LOCAL_MODELS_DISABLED_MESSAGE: &str = concat!(
    "local embeddings require the `local-models` feature flag at engine build time. ",
    "Build with: cargo build --features local-models. ",
    "Alternatively, install a backend via ",
    "runtime::ai::local_embedding::install_local_embedding_backend, ",
    "or use the 'ollama' provider with a local Ollama server."
);
69
/// One batch of texts to embed. It carries the registry metadata a backend
/// needs to select and check its model.
#[derive(Clone, Debug)]
pub struct LocalEmbeddingRequest {
    /// Name the model was registered under (`red.config.ai.models.{name}`).
    pub name: String,
    /// Source identifier recorded in the registry, e.g. a HuggingFace repo id.
    pub source: String,
    /// Git revision or tag pinned in the registry entry.
    pub revision: String,
    /// Engine identifier from the registry (for example `candle`).
    pub engine: String,
    /// Output width declared when the model was registered.
    pub dimensions: u32,
    /// Texts to embed. Callers are expected to reject empty entries before
    /// building the request.
    pub inputs: Vec<String>,
}
86
87/// Backend abstraction so HTTP/gRPC routing does not depend on a
88/// specific model engine. A future candle/onnx backend implements this
89/// trait and is installed via [`install_local_embedding_backend`].
90pub trait LocalEmbeddingBackend: Send + Sync {
91    fn embed(&self, request: &LocalEmbeddingRequest) -> RedDBResult<Vec<Vec<f32>>>;
92}
93
/// Result of a successful local embedding call.
///
/// It echoes the registry metadata of the model that served the request,
/// so HTTP and gRPC encoders can tell clients which model produced the
/// vectors.
#[derive(Clone, Debug)]
pub struct LocalEmbeddingResponse {
    /// Provider tag; responses built by this module always use `local`.
    pub provider: &'static str,
    /// Registered model name.
    pub name: String,
    /// Source identifier from the registry entry.
    pub source: String,
    /// Revision pinned in the registry entry.
    pub revision: String,
    /// Engine identifier from the registry entry.
    pub engine: String,
    /// Registered output width; every vector in `embeddings` has this length.
    pub dimensions: u32,
    /// One vector per input text.
    pub embeddings: Vec<Vec<f32>>,
}
106
/// Dependency-free stand-in backend for exercising the routing contract
/// end to end.
///
/// Each vector depends only on the model name, the input text and the
/// requested width. There is no I/O, no clock and no randomness, so
/// repeated runs produce bit-identical embeddings.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeterministicFakeBackend;
113
114impl LocalEmbeddingBackend for DeterministicFakeBackend {
115    fn embed(&self, request: &LocalEmbeddingRequest) -> RedDBResult<Vec<Vec<f32>>> {
116        let dim = request.dimensions as usize;
117        let mut out = Vec::with_capacity(request.inputs.len());
118        for text in &request.inputs {
119            out.push(deterministic_embedding(&request.name, text, dim));
120        }
121        Ok(out)
122    }
123}
124
125fn deterministic_embedding(model: &str, text: &str, dim: usize) -> Vec<f32> {
126    let mut out = Vec::with_capacity(dim);
127    let mut counter: u32 = 0;
128    while out.len() < dim {
129        let mut hasher = Sha256::new();
130        hasher.update(model.as_bytes());
131        hasher.update(&[0u8]);
132        hasher.update(text.as_bytes());
133        hasher.update(&[0u8]);
134        hasher.update(&counter.to_le_bytes());
135        let digest = hasher.finalize();
136        for chunk in digest.chunks(4) {
137            if out.len() >= dim {
138                break;
139            }
140            let mut bytes = [0u8; 4];
141            bytes.copy_from_slice(chunk);
142            let raw = u32::from_le_bytes(bytes) as f32 / u32::MAX as f32;
143            // Map [0, 1] → [-1, 1) so the fake produces sign-mixed
144            // vectors (the property tests look for both signs).
145            out.push(raw * 2.0 - 1.0);
146        }
147        counter = counter.wrapping_add(1);
148    }
149    out
150}
151
152type BackendSlot = Arc<dyn LocalEmbeddingBackend>;
153
154fn backend_slot() -> &'static RwLock<Option<BackendSlot>> {
155    static SLOT: OnceLock<RwLock<Option<BackendSlot>>> = OnceLock::new();
156    SLOT.get_or_init(|| RwLock::new(None))
157}
158
159/// Install (or replace) the process-global local embedding backend.
160///
161/// Production servers built with `--features local-models` should call
162/// this once at boot with their real engine. Tests use it to swap in
163/// a deterministic stub. Safe to call from any thread; the most recent
164/// install wins.
165pub fn install_local_embedding_backend(backend: Arc<dyn LocalEmbeddingBackend>) {
166    let mut guard = backend_slot().write().expect("backend slot poisoned");
167    *guard = Some(backend);
168}
169
170/// Test-only: clear the installed backend so a subsequent call exercises
171/// the `FeatureNotEnabled` path again.
172#[doc(hidden)]
173pub fn clear_local_embedding_backend_for_tests() {
174    let mut guard = backend_slot().write().expect("backend slot poisoned");
175    *guard = None;
176}
177
178fn current_backend() -> Option<BackendSlot> {
179    backend_slot()
180        .read()
181        .expect("backend slot poisoned")
182        .as_ref()
183        .map(Arc::clone)
184}
185
186/// Resolve and run a local embedding request end-to-end.
187///
188/// Performs, in order:
189/// 1. Backend availability gate (or feature-off error).
190/// 2. Registry lookup for `name` in `red_config`.
191/// 3. Task / status / engine validation.
192/// 4. Backend dispatch.
193/// 5. Shape validation against the registered dimensions.
194pub fn embed_local(
195    runtime: &RedDBRuntime,
196    name: &str,
197    inputs: Vec<String>,
198) -> RedDBResult<LocalEmbeddingResponse> {
199    embed_local_with_db(&runtime.db(), name, inputs)
200}
201
202/// Validate that a local embedding request for `name` would resolve a
203/// registered+installed model and an available backend, without sending
204/// any inputs.
205///
206/// Used by write paths (e.g. INSERT ... WITH AUTO EMBED) that need a
207/// deterministic pre-flight to fail the statement before any side
208/// effect on the target collection, satisfying the
209/// "embedding failures leave the target collection unchanged" contract
210/// for the failure modes the local provider owns: feature disabled,
211/// missing model, uninstalled artifacts, unsupported task, wrong
212/// provider tag, missing dimensions, corrupted registry entry.
213///
214/// Returns the resolved descriptor's `dimensions` so callers can pin
215/// the expected output shape before any backend round-trip.
216pub fn preflight_local_embedding(db: &RedDB, name: &str) -> RedDBResult<u32> {
217    let name = name.trim();
218    if name.is_empty() {
219        return Err(RedDBError::Query(
220            "local embedding 'model' field cannot be empty; pass the registered local model name"
221                .to_string(),
222        ));
223    }
224
225    // Mirror the backend-availability gate from `embed_local_with_db`
226    // so a feature-off build fails before the write phase rather than
227    // after we have already inserted rows.
228    if current_backend().is_none() && !cfg!(feature = "local-models") {
229        return Err(RedDBError::FeatureNotEnabled(
230            LOCAL_MODELS_DISABLED_MESSAGE.to_string(),
231        ));
232    }
233
234    let descriptor = read_model_descriptor(db, name)?;
235    if descriptor.provider != PROVIDER_LOCAL {
236        return Err(RedDBError::Query(format!(
237            "model '{name}' has provider '{}'; only '{PROVIDER_LOCAL}' is supported by local embedding routing",
238            descriptor.provider
239        )));
240    }
241    if descriptor.task != TASK_EMBEDDING {
242        return Err(RedDBError::Query(format!(
243            "model '{name}' has task '{}'; only '{TASK_EMBEDDING}' is supported by the local provider \
244             (prompt/generation are out of scope)",
245            descriptor.task
246        )));
247    }
248    if descriptor.status != STATUS_INSTALLED {
249        let message = match descriptor.pull_policy {
250            PULL_POLICY_NEVER => format!(
251                "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
252                 pull_policy='never' forbids runtime acquisition. An operator must explicitly install \
253                 the model via `POST /ai/models/{name}/pull`.",
254                descriptor.status
255            ),
256            PULL_POLICY_ALWAYS => format!(
257                "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
258                 pull_policy='always' is configured but query-time auto-pull is not implemented in this slice. \
259                 Trigger a refresh via `POST /ai/models/{name}/pull` before requesting embeddings.",
260                descriptor.status
261            ),
262            _ => format!(
263                "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
264                 pull_policy='if_missing' permits acquisition only via the explicit pull endpoint \
265                 (query-time auto-pull is not implemented). Run `POST /ai/models/{name}/pull` to install.",
266                descriptor.status
267            ),
268        };
269        return Err(RedDBError::NotFound(message));
270    }
271    if descriptor.dimensions == 0 {
272        return Err(RedDBError::Query(format!(
273            "model '{name}' registry entry has dimensions=0; re-register with the model's true output width"
274        )));
275    }
276    Ok(descriptor.dimensions)
277}
278
279/// Variant of [`embed_local`] that operates against a `RedDB` handle
280/// directly. The runtime query executor only carries `&RedDB`, so the
281/// text-vector-search routing path calls this rather than the runtime
282/// wrapper above.
283pub fn embed_local_with_db(
284    db: &RedDB,
285    name: &str,
286    inputs: Vec<String>,
287) -> RedDBResult<LocalEmbeddingResponse> {
288    if inputs.is_empty() {
289        return Err(RedDBError::Query(
290            "at least one input is required for local embeddings".to_string(),
291        ));
292    }
293    let name = name.trim();
294    if name.is_empty() {
295        return Err(RedDBError::Query(
296            "local embedding 'model' field cannot be empty; pass the registered local model name"
297                .to_string(),
298        ));
299    }
300
301    let backend = match current_backend() {
302        Some(b) => b,
303        None => {
304            if cfg!(feature = "local-models") {
305                // Feature is on but no engine was installed by the
306                // server boot path — fall back to the deterministic
307                // fake so the surface stays usable in dev builds.
308                let fake: Arc<dyn LocalEmbeddingBackend> = Arc::new(DeterministicFakeBackend);
309                install_local_embedding_backend(Arc::clone(&fake));
310                fake
311            } else {
312                return Err(RedDBError::FeatureNotEnabled(
313                    LOCAL_MODELS_DISABLED_MESSAGE.to_string(),
314                ));
315            }
316        }
317    };
318
319    let descriptor = read_model_descriptor(db, name)?;
320
321    if descriptor.provider != PROVIDER_LOCAL {
322        return Err(RedDBError::Query(format!(
323            "model '{name}' has provider '{}'; only '{PROVIDER_LOCAL}' is supported by local embedding routing",
324            descriptor.provider
325        )));
326    }
327    if descriptor.task != TASK_EMBEDDING {
328        return Err(RedDBError::Query(format!(
329            "model '{name}' has task '{}'; only '{TASK_EMBEDDING}' is supported by the local provider \
330             (prompt/generation are out of scope)",
331            descriptor.task
332        )));
333    }
334    if descriptor.status != STATUS_INSTALLED {
335        // Operator-safe contract: query-time routing never silently
336        // acquires artifacts and never falls back to a remote provider.
337        // Each policy surfaces a clear, distinct error so the operator
338        // knows which knob to turn.
339        let message = match descriptor.pull_policy {
340            PULL_POLICY_NEVER => format!(
341                "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
342                 pull_policy='never' forbids runtime acquisition. An operator must explicitly install \
343                 the model via `POST /ai/models/{name}/pull`.",
344                descriptor.status
345            ),
346            PULL_POLICY_ALWAYS => format!(
347                "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
348                 pull_policy='always' is configured but query-time auto-pull is not implemented in this slice. \
349                 Trigger a refresh via `POST /ai/models/{name}/pull` before requesting embeddings.",
350                descriptor.status
351            ),
352            // PULL_POLICY_IF_MISSING (default)
353            _ => format!(
354                "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
355                 pull_policy='if_missing' permits acquisition only via the explicit pull endpoint \
356                 (query-time auto-pull is not implemented). Run `POST /ai/models/{name}/pull` to install.",
357                descriptor.status
358            ),
359        };
360        return Err(RedDBError::NotFound(message));
361    }
362
363    let request = LocalEmbeddingRequest {
364        name: descriptor.name.clone(),
365        source: descriptor.source.clone(),
366        revision: descriptor.revision.clone(),
367        engine: descriptor.engine.clone(),
368        dimensions: descriptor.dimensions,
369        inputs,
370    };
371    let embeddings = backend.embed(&request)?;
372
373    if embeddings.len() != request.inputs.len() {
374        return Err(RedDBError::Query(format!(
375            "local backend returned {} embeddings for {} inputs",
376            embeddings.len(),
377            request.inputs.len()
378        )));
379    }
380    for (idx, row) in embeddings.iter().enumerate() {
381        if row.len() != descriptor.dimensions as usize {
382            return Err(RedDBError::Query(format!(
383                "local backend returned embedding[{idx}] of length {} but model '{name}' \
384                 was registered with dimensions={}",
385                row.len(),
386                descriptor.dimensions
387            )));
388        }
389    }
390
391    Ok(LocalEmbeddingResponse {
392        provider: PROVIDER_LOCAL,
393        name: descriptor.name,
394        source: descriptor.source,
395        revision: descriptor.revision,
396        engine: descriptor.engine,
397        dimensions: descriptor.dimensions,
398        embeddings,
399    })
400}
401
/// Decoded view of one model registry entry.
#[derive(Clone, Debug)]
struct ModelDescriptor {
    /// Registered name (falls back to the lookup key when absent).
    name: String,
    /// Provider tag; only `local` entries are served by this module.
    provider: String,
    /// Source identifier, e.g. a HuggingFace repo id.
    source: String,
    /// Pinned revision or tag.
    revision: String,
    /// Engine identifier.
    engine: String,
    /// Registered task; only `embedding` is accepted here.
    task: String,
    /// Artifact status; only `installed` entries can be embedded against.
    status: String,
    /// Declared output width of the model.
    dimensions: u32,
    /// Pull policy, already normalised to `never` / `if_missing` / `always`
    /// so the gating logic never has to know about legacy spellings.
    pull_policy: &'static str,
}
417
418fn read_model_descriptor(db: &RedDB, name: &str) -> RedDBResult<ModelDescriptor> {
419    let key = format!("{AI_MODEL_KEY_PREFIX}{name}");
420    let raw = match db.get_kv(RED_CONFIG_COLLECTION, &key) {
421        Some((Value::Text(text), _)) => text.to_string(),
422        Some(_) => {
423            return Err(RedDBError::Query(format!(
424                "local model registry entry for '{name}' is not a JSON text payload"
425            )));
426        }
427        None => {
428            return Err(RedDBError::NotFound(format!(
429                "local model '{name}' is not registered; POST /ai/models to register it first"
430            )));
431        }
432    };
433    let parsed = parse_json(&raw).map_err(|err| {
434        RedDBError::Query(format!(
435            "local model registry entry for '{name}' is corrupted: {err}"
436        ))
437    })?;
438    let value = JsonValue::from(parsed);
439    let object = value
440        .as_object()
441        .ok_or_else(|| RedDBError::Query(format!("model entry for '{name}' is not an object")))?;
442
443    let pick = |key: &str| -> Option<String> {
444        object
445            .get(key)
446            .and_then(JsonValue::as_str)
447            .map(str::to_string)
448    };
449
450    let provider = pick("provider").unwrap_or_else(|| PROVIDER_LOCAL.to_string());
451    let source = pick("source").unwrap_or_default();
452    let revision = pick("revision").unwrap_or_default();
453    let engine = pick("engine").unwrap_or_default();
454    let task = pick("task").unwrap_or_default();
455    let status = pick("status").unwrap_or_default();
456    let dimensions = object
457        .get("dimensions")
458        .and_then(JsonValue::as_u64)
459        .ok_or_else(|| {
460            RedDBError::Query(format!("model entry for '{name}' is missing 'dimensions'"))
461        })? as u32;
462    let pull_policy = normalize_stored_pull_policy(
463        pick("pull_policy")
464            .as_deref()
465            .unwrap_or(PULL_POLICY_IF_MISSING),
466    );
467
468    Ok(ModelDescriptor {
469        name: pick("name").unwrap_or_else(|| name.to_string()),
470        provider,
471        source,
472        revision,
473        engine,
474        task,
475        status,
476        dimensions,
477        pull_policy,
478    })
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a request with empty registry metadata; only the name, width
    /// and inputs vary.
    fn request(name: &str, dimensions: u32, inputs: &[&str]) -> LocalEmbeddingRequest {
        LocalEmbeddingRequest {
            name: name.to_string(),
            source: String::new(),
            revision: String::new(),
            engine: String::new(),
            dimensions,
            inputs: inputs.iter().map(|text| text.to_string()).collect(),
        }
    }

    #[test]
    fn deterministic_fake_is_pure_and_correct_length() {
        let backend = DeterministicFakeBackend;
        let req = LocalEmbeddingRequest {
            name: "mini".to_string(),
            source: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            revision: "main".to_string(),
            engine: "candle".to_string(),
            dimensions: 16,
            inputs: vec!["hello".to_string(), "world".to_string()],
        };
        let a = backend.embed(&req).expect("embed");
        let b = backend.embed(&req).expect("embed twice");
        assert_eq!(a, b, "deterministic backend must be pure");
        assert_eq!(a.len(), 2);
        assert!(a.iter().all(|v| v.len() == 16));
        assert_ne!(
            a[0], a[1],
            "different inputs must produce different vectors"
        );
    }

    #[test]
    fn deterministic_fake_changes_with_model_name() {
        let backend = DeterministicFakeBackend;
        let a = backend.embed(&request("alpha", 8, &["x"])).unwrap();
        let b = backend.embed(&request("beta", 8, &["x"])).unwrap();
        assert_ne!(a[0], b[0]);
    }

    #[test]
    fn deterministic_fake_spans_multiple_digests() {
        // One 32-byte digest yields 8 floats, so a width of 37 needs five
        // digests, the last of which is only partly consumed.
        let rows = DeterministicFakeBackend
            .embed(&request("mini", 37, &["hello"]))
            .expect("embed");
        assert_eq!(rows.len(), 1);
        let row = &rows[0];
        assert_eq!(row.len(), 37);
        assert!(
            row.iter().all(|v| (-1.0..=1.0).contains(v)),
            "values must stay within [-1, 1]"
        );
        assert!(
            row.iter().any(|&v| v < 0.0) && row.iter().any(|&v| v > 0.0),
            "fake must produce sign-mixed vectors"
        );
    }

    #[test]
    fn deterministic_fake_with_zero_dimensions_returns_empty_rows() {
        let rows = DeterministicFakeBackend
            .embed(&request("mini", 0, &["a", "b"]))
            .expect("embed");
        assert_eq!(rows.len(), 2);
        assert!(rows.iter().all(Vec::is_empty));
    }

    #[test]
    fn stored_pull_policy_accepts_canonical_and_legacy_spellings() {
        assert_eq!(normalize_stored_pull_policy("never"), PULL_POLICY_NEVER);
        assert_eq!(normalize_stored_pull_policy(" Manual "), PULL_POLICY_NEVER);
        assert_eq!(normalize_stored_pull_policy("ALWAYS"), PULL_POLICY_ALWAYS);
        assert_eq!(normalize_stored_pull_policy("eager"), PULL_POLICY_ALWAYS);
        assert_eq!(normalize_stored_pull_policy("on_demand"), PULL_POLICY_IF_MISSING);
        assert_eq!(normalize_stored_pull_policy("if_missing"), PULL_POLICY_IF_MISSING);
        assert_eq!(normalize_stored_pull_policy("bogus"), PULL_POLICY_IF_MISSING);
    }
}