1use std::sync::{Arc, OnceLock, RwLock};
26
27use crate::crypto::sha256::Sha256;
28use crate::json::{parse_json, Value as JsonValue};
29use crate::runtime::RedDBRuntime;
30use crate::storage::schema::Value;
31use crate::storage::unified::RedDB;
32use crate::{RedDBError, RedDBResult};
33
/// KV collection that holds the AI model registry entries read by this module.
const RED_CONFIG_COLLECTION: &str = "red_config";
/// Key prefix for a registry entry; the model name is appended verbatim.
const AI_MODEL_KEY_PREFIX: &str = "red.config.ai.models.";

/// Registry `status` value required before a model may serve embeddings.
const STATUS_INSTALLED: &str = "installed";
/// The only registry `task` the local provider accepts.
const TASK_EMBEDDING: &str = "embedding";
/// Provider tag for local models; also assumed when an entry omits `provider`.
const PROVIDER_LOCAL: &str = "local";
39
/// Canonical stored spelling of the "never" pull policy.
const PULL_POLICY_NEVER: &str = "never";
/// Canonical stored spelling of the "if missing" pull policy; the fallback for
/// absent or unrecognised values.
const PULL_POLICY_IF_MISSING: &str = "if_missing";
/// Canonical stored spelling of the "always" pull policy.
const PULL_POLICY_ALWAYS: &str = "always";

/// Maps a stored `pull_policy` string onto one of the canonical `PULL_POLICY_*` values.
///
/// Surrounding whitespace and ASCII case are ignored. `"manual"` is an alias for
/// `never` and `"eager"` an alias for `always`; anything else (including an empty
/// string) becomes `if_missing`.
fn normalize_stored_pull_policy(raw: &str) -> &'static str {
    let policy = raw.trim();
    let matches_any =
        |aliases: &[&str]| aliases.iter().any(|alias| policy.eq_ignore_ascii_case(alias));

    if matches_any(&["never", "manual"]) {
        PULL_POLICY_NEVER
    } else if matches_any(&["always", "eager"]) {
        PULL_POLICY_ALWAYS
    } else {
        PULL_POLICY_IF_MISSING
    }
}
63
/// Error text used when local embeddings are requested, no backend is installed,
/// and the engine was built without the `local-models` feature.
const LOCAL_MODELS_DISABLED_MESSAGE: &str = concat!(
    "local embeddings require the `local-models` feature flag at engine build time. ",
    "Build with: cargo build --features local-models. ",
    "Alternatively, install a backend via ",
    "runtime::ai::local_embedding::install_local_embedding_backend, ",
    "or use the 'ollama' provider with a local Ollama server."
);
69
/// Input handed to a [`LocalEmbeddingBackend`]: the identity of the registered
/// model plus the batch of texts to embed.
#[derive(Clone, Debug)]
pub struct LocalEmbeddingRequest {
    /// Model name taken from the registry entry.
    pub name: String,
    /// The registry entry's `source` value.
    pub source: String,
    /// The registry entry's `revision` value.
    pub revision: String,
    /// The registry entry's `engine` value.
    pub engine: String,
    /// Expected length of every returned vector.
    pub dimensions: u32,
    /// Texts to embed; `embed_local_with_db` requires exactly one vector per entry.
    pub inputs: Vec<String>,
}
86
87pub trait LocalEmbeddingBackend: Send + Sync {
91 fn embed(&self, request: &LocalEmbeddingRequest) -> RedDBResult<Vec<Vec<f32>>>;
92}
93
/// Result of a successful local embedding call.
#[derive(Clone, Debug)]
pub struct LocalEmbeddingResponse {
    /// Provider tag; `embed_local_with_db` always sets it to `"local"`.
    pub provider: &'static str,
    /// Model name from the registry entry.
    pub name: String,
    /// The registry entry's `source` value.
    pub source: String,
    /// The registry entry's `revision` value.
    pub revision: String,
    /// The registry entry's `engine` value.
    pub engine: String,
    /// Length of every vector in `embeddings`.
    pub dimensions: u32,
    /// Vectors as returned by the backend, one per input.
    pub embeddings: Vec<Vec<f32>>,
}
106
/// Backend that derives pseudo-embeddings from SHA-256 hashes of the model name and
/// input text, so identical requests always yield identical vectors.
///
/// `embed_local_with_db` installs it as a fallback when the `local-models` feature is
/// enabled and no other backend is present.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeterministicFakeBackend;
113
114impl LocalEmbeddingBackend for DeterministicFakeBackend {
115 fn embed(&self, request: &LocalEmbeddingRequest) -> RedDBResult<Vec<Vec<f32>>> {
116 let dim = request.dimensions as usize;
117 let mut out = Vec::with_capacity(request.inputs.len());
118 for text in &request.inputs {
119 out.push(deterministic_embedding(&request.name, text, dim));
120 }
121 Ok(out)
122 }
123}
124
125fn deterministic_embedding(model: &str, text: &str, dim: usize) -> Vec<f32> {
126 let mut out = Vec::with_capacity(dim);
127 let mut counter: u32 = 0;
128 while out.len() < dim {
129 let mut hasher = Sha256::new();
130 hasher.update(model.as_bytes());
131 hasher.update(&[0u8]);
132 hasher.update(text.as_bytes());
133 hasher.update(&[0u8]);
134 hasher.update(&counter.to_le_bytes());
135 let digest = hasher.finalize();
136 for chunk in digest.chunks(4) {
137 if out.len() >= dim {
138 break;
139 }
140 let mut bytes = [0u8; 4];
141 bytes.copy_from_slice(chunk);
142 let raw = u32::from_le_bytes(bytes) as f32 / u32::MAX as f32;
143 out.push(raw * 2.0 - 1.0);
146 }
147 counter = counter.wrapping_add(1);
148 }
149 out
150}
151
152type BackendSlot = Arc<dyn LocalEmbeddingBackend>;
153
154fn backend_slot() -> &'static RwLock<Option<BackendSlot>> {
155 static SLOT: OnceLock<RwLock<Option<BackendSlot>>> = OnceLock::new();
156 SLOT.get_or_init(|| RwLock::new(None))
157}
158
159pub fn install_local_embedding_backend(backend: Arc<dyn LocalEmbeddingBackend>) {
166 let mut guard = backend_slot().write().expect("backend slot poisoned");
167 *guard = Some(backend);
168}
169
170#[doc(hidden)]
173pub fn clear_local_embedding_backend_for_tests() {
174 let mut guard = backend_slot().write().expect("backend slot poisoned");
175 *guard = None;
176}
177
178fn current_backend() -> Option<BackendSlot> {
179 backend_slot()
180 .read()
181 .expect("backend slot poisoned")
182 .as_ref()
183 .map(Arc::clone)
184}
185
186pub fn embed_local(
195 runtime: &RedDBRuntime,
196 name: &str,
197 inputs: Vec<String>,
198) -> RedDBResult<LocalEmbeddingResponse> {
199 embed_local_with_db(&runtime.db(), name, inputs)
200}
201
202pub fn preflight_local_embedding(db: &RedDB, name: &str) -> RedDBResult<u32> {
217 let name = name.trim();
218 if name.is_empty() {
219 return Err(RedDBError::Query(
220 "local embedding 'model' field cannot be empty; pass the registered local model name"
221 .to_string(),
222 ));
223 }
224
225 if current_backend().is_none() && !cfg!(feature = "local-models") {
229 return Err(RedDBError::FeatureNotEnabled(
230 LOCAL_MODELS_DISABLED_MESSAGE.to_string(),
231 ));
232 }
233
234 let descriptor = read_model_descriptor(db, name)?;
235 if descriptor.provider != PROVIDER_LOCAL {
236 return Err(RedDBError::Query(format!(
237 "model '{name}' has provider '{}'; only '{PROVIDER_LOCAL}' is supported by local embedding routing",
238 descriptor.provider
239 )));
240 }
241 if descriptor.task != TASK_EMBEDDING {
242 return Err(RedDBError::Query(format!(
243 "model '{name}' has task '{}'; only '{TASK_EMBEDDING}' is supported by the local provider \
244 (prompt/generation are out of scope)",
245 descriptor.task
246 )));
247 }
248 if descriptor.status != STATUS_INSTALLED {
249 let message = match descriptor.pull_policy {
250 PULL_POLICY_NEVER => format!(
251 "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
252 pull_policy='never' forbids runtime acquisition. An operator must explicitly install \
253 the model via `POST /ai/models/{name}/pull`.",
254 descriptor.status
255 ),
256 PULL_POLICY_ALWAYS => format!(
257 "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
258 pull_policy='always' is configured but query-time auto-pull is not implemented in this slice. \
259 Trigger a refresh via `POST /ai/models/{name}/pull` before requesting embeddings.",
260 descriptor.status
261 ),
262 _ => format!(
263 "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
264 pull_policy='if_missing' permits acquisition only via the explicit pull endpoint \
265 (query-time auto-pull is not implemented). Run `POST /ai/models/{name}/pull` to install.",
266 descriptor.status
267 ),
268 };
269 return Err(RedDBError::NotFound(message));
270 }
271 if descriptor.dimensions == 0 {
272 return Err(RedDBError::Query(format!(
273 "model '{name}' registry entry has dimensions=0; re-register with the model's true output width"
274 )));
275 }
276 Ok(descriptor.dimensions)
277}
278
279pub fn embed_local_with_db(
284 db: &RedDB,
285 name: &str,
286 inputs: Vec<String>,
287) -> RedDBResult<LocalEmbeddingResponse> {
288 if inputs.is_empty() {
289 return Err(RedDBError::Query(
290 "at least one input is required for local embeddings".to_string(),
291 ));
292 }
293 let name = name.trim();
294 if name.is_empty() {
295 return Err(RedDBError::Query(
296 "local embedding 'model' field cannot be empty; pass the registered local model name"
297 .to_string(),
298 ));
299 }
300
301 let backend = match current_backend() {
302 Some(b) => b,
303 None => {
304 if cfg!(feature = "local-models") {
305 let fake: Arc<dyn LocalEmbeddingBackend> = Arc::new(DeterministicFakeBackend);
309 install_local_embedding_backend(Arc::clone(&fake));
310 fake
311 } else {
312 return Err(RedDBError::FeatureNotEnabled(
313 LOCAL_MODELS_DISABLED_MESSAGE.to_string(),
314 ));
315 }
316 }
317 };
318
319 let descriptor = read_model_descriptor(db, name)?;
320
321 if descriptor.provider != PROVIDER_LOCAL {
322 return Err(RedDBError::Query(format!(
323 "model '{name}' has provider '{}'; only '{PROVIDER_LOCAL}' is supported by local embedding routing",
324 descriptor.provider
325 )));
326 }
327 if descriptor.task != TASK_EMBEDDING {
328 return Err(RedDBError::Query(format!(
329 "model '{name}' has task '{}'; only '{TASK_EMBEDDING}' is supported by the local provider \
330 (prompt/generation are out of scope)",
331 descriptor.task
332 )));
333 }
334 if descriptor.status != STATUS_INSTALLED {
335 let message = match descriptor.pull_policy {
340 PULL_POLICY_NEVER => format!(
341 "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
342 pull_policy='never' forbids runtime acquisition. An operator must explicitly install \
343 the model via `POST /ai/models/{name}/pull`.",
344 descriptor.status
345 ),
346 PULL_POLICY_ALWAYS => format!(
347 "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
348 pull_policy='always' is configured but query-time auto-pull is not implemented in this slice. \
349 Trigger a refresh via `POST /ai/models/{name}/pull` before requesting embeddings.",
350 descriptor.status
351 ),
352 _ => format!(
354 "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
355 pull_policy='if_missing' permits acquisition only via the explicit pull endpoint \
356 (query-time auto-pull is not implemented). Run `POST /ai/models/{name}/pull` to install.",
357 descriptor.status
358 ),
359 };
360 return Err(RedDBError::NotFound(message));
361 }
362
363 let request = LocalEmbeddingRequest {
364 name: descriptor.name.clone(),
365 source: descriptor.source.clone(),
366 revision: descriptor.revision.clone(),
367 engine: descriptor.engine.clone(),
368 dimensions: descriptor.dimensions,
369 inputs,
370 };
371 let embeddings = backend.embed(&request)?;
372
373 if embeddings.len() != request.inputs.len() {
374 return Err(RedDBError::Query(format!(
375 "local backend returned {} embeddings for {} inputs",
376 embeddings.len(),
377 request.inputs.len()
378 )));
379 }
380 for (idx, row) in embeddings.iter().enumerate() {
381 if row.len() != descriptor.dimensions as usize {
382 return Err(RedDBError::Query(format!(
383 "local backend returned embedding[{idx}] of length {} but model '{name}' \
384 was registered with dimensions={}",
385 row.len(),
386 descriptor.dimensions
387 )));
388 }
389 }
390
391 Ok(LocalEmbeddingResponse {
392 provider: PROVIDER_LOCAL,
393 name: descriptor.name,
394 source: descriptor.source,
395 revision: descriptor.revision,
396 engine: descriptor.engine,
397 dimensions: descriptor.dimensions,
398 embeddings,
399 })
400}
401
/// Parsed model registry entry, as produced by `read_model_descriptor`.
#[derive(Clone, Debug)]
struct ModelDescriptor {
    /// The entry's `name`, falling back to the name used for the lookup.
    name: String,
    /// The entry's `provider`; `"local"` when absent.
    provider: String,
    /// The entry's `source`; empty when absent.
    source: String,
    /// The entry's `revision`; empty when absent.
    revision: String,
    /// The entry's `engine`; empty when absent.
    engine: String,
    /// The entry's `task`; empty when absent.
    task: String,
    /// The entry's `status`; empty when absent.
    status: String,
    /// Mandatory output width of the model.
    dimensions: u32,
    /// One of the canonical `PULL_POLICY_*` values.
    pull_policy: &'static str,
}
417
418fn read_model_descriptor(db: &RedDB, name: &str) -> RedDBResult<ModelDescriptor> {
419 let key = format!("{AI_MODEL_KEY_PREFIX}{name}");
420 let raw = match db.get_kv(RED_CONFIG_COLLECTION, &key) {
421 Some((Value::Text(text), _)) => text.to_string(),
422 Some(_) => {
423 return Err(RedDBError::Query(format!(
424 "local model registry entry for '{name}' is not a JSON text payload"
425 )));
426 }
427 None => {
428 return Err(RedDBError::NotFound(format!(
429 "local model '{name}' is not registered; POST /ai/models to register it first"
430 )));
431 }
432 };
433 let parsed = parse_json(&raw).map_err(|err| {
434 RedDBError::Query(format!(
435 "local model registry entry for '{name}' is corrupted: {err}"
436 ))
437 })?;
438 let value = JsonValue::from(parsed);
439 let object = value
440 .as_object()
441 .ok_or_else(|| RedDBError::Query(format!("model entry for '{name}' is not an object")))?;
442
443 let pick = |key: &str| -> Option<String> {
444 object
445 .get(key)
446 .and_then(JsonValue::as_str)
447 .map(str::to_string)
448 };
449
450 let provider = pick("provider").unwrap_or_else(|| PROVIDER_LOCAL.to_string());
451 let source = pick("source").unwrap_or_default();
452 let revision = pick("revision").unwrap_or_default();
453 let engine = pick("engine").unwrap_or_default();
454 let task = pick("task").unwrap_or_default();
455 let status = pick("status").unwrap_or_default();
456 let dimensions = object
457 .get("dimensions")
458 .and_then(JsonValue::as_u64)
459 .ok_or_else(|| {
460 RedDBError::Query(format!("model entry for '{name}' is missing 'dimensions'"))
461 })? as u32;
462 let pull_policy = normalize_stored_pull_policy(
463 pick("pull_policy")
464 .as_deref()
465 .unwrap_or(PULL_POLICY_IF_MISSING),
466 );
467
468 Ok(ModelDescriptor {
469 name: pick("name").unwrap_or_else(|| name.to_string()),
470 provider,
471 source,
472 revision,
473 engine,
474 task,
475 status,
476 dimensions,
477 pull_policy,
478 })
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request for `model` with empty source/revision/engine metadata.
    fn bare_request(model: &str, dimensions: u32, inputs: &[&str]) -> LocalEmbeddingRequest {
        LocalEmbeddingRequest {
            name: model.to_owned(),
            source: String::new(),
            revision: String::new(),
            engine: String::new(),
            dimensions,
            inputs: inputs.iter().map(|s| s.to_string()).collect(),
        }
    }

    #[test]
    fn deterministic_fake_is_pure_and_correct_length() {
        let request = LocalEmbeddingRequest {
            source: "sentence-transformers/all-MiniLM-L6-v2".to_owned(),
            revision: "main".to_owned(),
            engine: "candle".to_owned(),
            ..bare_request("mini", 16, &["hello", "world"])
        };
        let first = DeterministicFakeBackend.embed(&request).expect("embed");
        let second = DeterministicFakeBackend.embed(&request).expect("embed twice");

        assert_eq!(first, second, "deterministic backend must be pure");
        assert_eq!(first.len(), 2);
        assert!(first.iter().all(|row| row.len() == 16));
        assert_ne!(
            first[0], first[1],
            "different inputs must produce different vectors"
        );
    }

    #[test]
    fn deterministic_fake_changes_with_model_name() {
        let embed_x = |model: &str| {
            DeterministicFakeBackend
                .embed(&bare_request(model, 8, &["x"]))
                .unwrap()
        };
        assert_ne!(embed_x("alpha")[0], embed_x("beta")[0]);
    }
}