1use std::sync::{Arc, OnceLock, RwLock};
26
27use crate::crypto::sha256::Sha256;
28use crate::json::{parse_json, Value as JsonValue};
29use crate::runtime::RedDBRuntime;
30use crate::storage::schema::Value;
31use crate::storage::unified::RedDB;
32use crate::{RedDBError, RedDBResult};
33
/// KV collection that stores the AI model registry entries read by this module.
const RED_CONFIG_COLLECTION: &str = "red_config";
/// Key prefix for model registry entries; the full key is this prefix followed by the model name.
const AI_MODEL_KEY_PREFIX: &str = "red.config.ai.models.";
/// Registry `status` a model must have before local embeddings are served from it.
const STATUS_INSTALLED: &str = "installed";
/// The only registry `task` the local provider serves.
const TASK_EMBEDDING: &str = "embedding";
/// Provider handled by this module; also assumed when a registry entry omits `provider`.
const PROVIDER_LOCAL: &str = "local";

/// Canonical pull policy forbidding runtime acquisition of model artifacts
/// (stored aliases: `never`, `manual`).
const PULL_POLICY_NEVER: &str = "never";
/// Canonical pull policy used for `if_missing` and for any absent or unrecognized stored value.
const PULL_POLICY_IF_MISSING: &str = "if_missing";
/// Canonical pull policy for stored aliases `always` / `eager`. Query-time auto-pull is not
/// performed by this module.
const PULL_POLICY_ALWAYS: &str = "always";
47
48fn normalize_stored_pull_policy(raw: &str) -> &'static str {
53 match raw.trim().to_ascii_lowercase().as_str() {
54 "never" | "manual" => PULL_POLICY_NEVER,
55 "always" | "eager" => PULL_POLICY_ALWAYS,
56 _ => PULL_POLICY_IF_MISSING,
61 }
62}
63
/// User-facing text carried by `RedDBError::FeatureNotEnabled` when local embeddings are
/// requested but no backend is installed and the engine lacks the `local-models` feature.
const LOCAL_MODELS_DISABLED_MESSAGE: &str =
    "local embeddings require the `local-models` feature flag at engine build time. \
     Build with: cargo build --features local-models. Alternatively, install a backend \
     via runtime::ai::local_embedding::install_local_embedding_backend, or use the \
     'ollama' provider with a local Ollama server.";
69
/// A batch embedding request handed to a [`LocalEmbeddingBackend`].
///
/// The model metadata fields are copied from the model's registry entry.
#[derive(Debug, Clone)]
pub struct LocalEmbeddingRequest {
    /// Registered model name.
    pub name: String,
    /// Artifact source recorded in the registry (e.g. `sentence-transformers/all-MiniLM-L6-v2`
    /// in this module's tests); empty when the entry omits it.
    pub source: String,
    /// Model revision recorded in the registry (e.g. `main`); empty when the entry omits it.
    pub revision: String,
    /// Inference engine recorded in the registry (e.g. `candle`); empty when the entry omits it.
    pub engine: String,
    /// Length every returned vector must have.
    pub dimensions: u32,
    /// Texts to embed; exactly one vector is expected back per entry.
    pub inputs: Vec<String>,
}
86
/// Inference backend that turns a batch of texts into embedding vectors.
///
/// Backends are shared across threads through a process-wide slot, which is why
/// `Send + Sync` is required.
pub trait LocalEmbeddingBackend: Send + Sync {
    /// Embeds the strings in `request.inputs`.
    ///
    /// Callers in this module reject output that does not contain exactly one vector per
    /// input, each `request.dimensions` long. Results are presumably positionally aligned with
    /// the inputs; that alignment is not verified.
    fn embed(&self, request: &LocalEmbeddingRequest) -> RedDBResult<Vec<Vec<f32>>>;
}
93
/// Result of a local embedding call: the model's registry metadata plus the vectors
/// returned by the backend.
#[derive(Debug, Clone)]
pub struct LocalEmbeddingResponse {
    /// Always `"local"` for responses produced by this module.
    pub provider: &'static str,
    /// Registered model name.
    pub name: String,
    /// Artifact source recorded in the registry.
    pub source: String,
    /// Model revision recorded in the registry.
    pub revision: String,
    /// Inference engine recorded in the registry.
    pub engine: String,
    /// Length of every vector in `embeddings`.
    pub dimensions: u32,
    /// One vector per input, in the order the backend returned them.
    pub embeddings: Vec<Vec<f32>>,
}
106
107#[derive(Debug, Default, Clone, Copy)]
112pub struct DeterministicFakeBackend;
113
114impl LocalEmbeddingBackend for DeterministicFakeBackend {
115 fn embed(&self, request: &LocalEmbeddingRequest) -> RedDBResult<Vec<Vec<f32>>> {
116 let dim = request.dimensions as usize;
117 let mut out = Vec::with_capacity(request.inputs.len());
118 for text in &request.inputs {
119 out.push(deterministic_embedding(&request.name, text, dim));
120 }
121 Ok(out)
122 }
123}
124
125fn deterministic_embedding(model: &str, text: &str, dim: usize) -> Vec<f32> {
126 let mut out = Vec::with_capacity(dim);
127 let mut counter: u32 = 0;
128 while out.len() < dim {
129 let mut hasher = Sha256::new();
130 hasher.update(model.as_bytes());
131 hasher.update(&[0u8]);
132 hasher.update(text.as_bytes());
133 hasher.update(&[0u8]);
134 hasher.update(&counter.to_le_bytes());
135 let digest = hasher.finalize();
136 for chunk in digest.chunks(4) {
137 if out.len() >= dim {
138 break;
139 }
140 let mut bytes = [0u8; 4];
141 bytes.copy_from_slice(chunk);
142 let raw = u32::from_le_bytes(bytes) as f32 / u32::MAX as f32;
143 out.push(raw * 2.0 - 1.0);
146 }
147 counter = counter.wrapping_add(1);
148 }
149 out
150}
151
/// Shared handle to an installed backend; cloning it only bumps the reference count.
type BackendSlot = Arc<dyn LocalEmbeddingBackend>;
153
154fn backend_slot() -> &'static RwLock<Option<BackendSlot>> {
155 static SLOT: OnceLock<RwLock<Option<BackendSlot>>> = OnceLock::new();
156 SLOT.get_or_init(|| RwLock::new(None))
157}
158
159pub fn install_local_embedding_backend(backend: Arc<dyn LocalEmbeddingBackend>) {
166 let mut guard = backend_slot().write().expect("backend slot poisoned");
167 *guard = Some(backend);
168}
169
170#[doc(hidden)]
173pub fn clear_local_embedding_backend_for_tests() {
174 let mut guard = backend_slot().write().expect("backend slot poisoned");
175 *guard = None;
176}
177
178fn current_backend() -> Option<BackendSlot> {
179 backend_slot()
180 .read()
181 .expect("backend slot poisoned")
182 .as_ref()
183 .map(Arc::clone)
184}
185
186pub fn ensure_local_embedding_available() -> RedDBResult<()> {
191 if current_backend().is_none() && !cfg!(feature = "local-models") {
192 return Err(RedDBError::FeatureNotEnabled(
193 LOCAL_MODELS_DISABLED_MESSAGE.to_string(),
194 ));
195 }
196 Ok(())
197}
198
199pub fn embed_local(
208 runtime: &RedDBRuntime,
209 name: &str,
210 inputs: Vec<String>,
211) -> RedDBResult<LocalEmbeddingResponse> {
212 embed_local_with_db(&runtime.db(), name, inputs)
213}
214
215pub fn preflight_local_embedding(db: &RedDB, name: &str) -> RedDBResult<u32> {
230 let name = name.trim();
231 if name.is_empty() {
232 return Err(RedDBError::Query(
233 "local embedding 'model' field cannot be empty; pass the registered local model name"
234 .to_string(),
235 ));
236 }
237
238 ensure_local_embedding_available()?;
242
243 let descriptor = read_model_descriptor(db, name)?;
244 if descriptor.provider != PROVIDER_LOCAL {
245 return Err(RedDBError::Query(format!(
246 "model '{name}' has provider '{}'; only '{PROVIDER_LOCAL}' is supported by local embedding routing",
247 descriptor.provider
248 )));
249 }
250 if descriptor.task != TASK_EMBEDDING {
251 return Err(RedDBError::Query(format!(
252 "model '{name}' has task '{}'; only '{TASK_EMBEDDING}' is supported by the local provider \
253 (prompt/generation are out of scope)",
254 descriptor.task
255 )));
256 }
257 if descriptor.status != STATUS_INSTALLED {
258 let message = match descriptor.pull_policy {
259 PULL_POLICY_NEVER => format!(
260 "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
261 pull_policy='never' forbids runtime acquisition. An operator must explicitly install \
262 the model via `POST /ai/models/{name}/pull`.",
263 descriptor.status
264 ),
265 PULL_POLICY_ALWAYS => format!(
266 "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
267 pull_policy='always' is configured but query-time auto-pull is not implemented in this slice. \
268 Trigger a refresh via `POST /ai/models/{name}/pull` before requesting embeddings.",
269 descriptor.status
270 ),
271 _ => format!(
272 "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
273 pull_policy='if_missing' permits acquisition only via the explicit pull endpoint \
274 (query-time auto-pull is not implemented). Run `POST /ai/models/{name}/pull` to install.",
275 descriptor.status
276 ),
277 };
278 return Err(RedDBError::NotFound(message));
279 }
280 if descriptor.dimensions == 0 {
281 return Err(RedDBError::Query(format!(
282 "model '{name}' registry entry has dimensions=0; re-register with the model's true output width"
283 )));
284 }
285 Ok(descriptor.dimensions)
286}
287
288pub fn embed_local_with_db(
293 db: &RedDB,
294 name: &str,
295 inputs: Vec<String>,
296) -> RedDBResult<LocalEmbeddingResponse> {
297 if inputs.is_empty() {
298 return Err(RedDBError::Query(
299 "at least one input is required for local embeddings".to_string(),
300 ));
301 }
302 let name = name.trim();
303 if name.is_empty() {
304 return Err(RedDBError::Query(
305 "local embedding 'model' field cannot be empty; pass the registered local model name"
306 .to_string(),
307 ));
308 }
309
310 let backend = match current_backend() {
311 Some(b) => b,
312 None => {
313 if cfg!(feature = "local-models") {
314 let fake: Arc<dyn LocalEmbeddingBackend> = Arc::new(DeterministicFakeBackend);
318 install_local_embedding_backend(Arc::clone(&fake));
319 fake
320 } else {
321 return Err(RedDBError::FeatureNotEnabled(
322 LOCAL_MODELS_DISABLED_MESSAGE.to_string(),
323 ));
324 }
325 }
326 };
327
328 let descriptor = read_model_descriptor(db, name)?;
329
330 if descriptor.provider != PROVIDER_LOCAL {
331 return Err(RedDBError::Query(format!(
332 "model '{name}' has provider '{}'; only '{PROVIDER_LOCAL}' is supported by local embedding routing",
333 descriptor.provider
334 )));
335 }
336 if descriptor.task != TASK_EMBEDDING {
337 return Err(RedDBError::Query(format!(
338 "model '{name}' has task '{}'; only '{TASK_EMBEDDING}' is supported by the local provider \
339 (prompt/generation are out of scope)",
340 descriptor.task
341 )));
342 }
343 if descriptor.status != STATUS_INSTALLED {
344 let message = match descriptor.pull_policy {
349 PULL_POLICY_NEVER => format!(
350 "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
351 pull_policy='never' forbids runtime acquisition. An operator must explicitly install \
352 the model via `POST /ai/models/{name}/pull`.",
353 descriptor.status
354 ),
355 PULL_POLICY_ALWAYS => format!(
356 "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
357 pull_policy='always' is configured but query-time auto-pull is not implemented in this slice. \
358 Trigger a refresh via `POST /ai/models/{name}/pull` before requesting embeddings.",
359 descriptor.status
360 ),
361 _ => format!(
363 "local model '{name}' is registered (status='{}') but its artifacts are not installed; \
364 pull_policy='if_missing' permits acquisition only via the explicit pull endpoint \
365 (query-time auto-pull is not implemented). Run `POST /ai/models/{name}/pull` to install.",
366 descriptor.status
367 ),
368 };
369 return Err(RedDBError::NotFound(message));
370 }
371
372 let request = LocalEmbeddingRequest {
373 name: descriptor.name.clone(),
374 source: descriptor.source.clone(),
375 revision: descriptor.revision.clone(),
376 engine: descriptor.engine.clone(),
377 dimensions: descriptor.dimensions,
378 inputs,
379 };
380 let embeddings = backend.embed(&request)?;
381
382 if embeddings.len() != request.inputs.len() {
383 return Err(RedDBError::Query(format!(
384 "local backend returned {} embeddings for {} inputs",
385 embeddings.len(),
386 request.inputs.len()
387 )));
388 }
389 for (idx, row) in embeddings.iter().enumerate() {
390 if row.len() != descriptor.dimensions as usize {
391 return Err(RedDBError::Query(format!(
392 "local backend returned embedding[{idx}] of length {} but model '{name}' \
393 was registered with dimensions={}",
394 row.len(),
395 descriptor.dimensions
396 )));
397 }
398 }
399
400 Ok(LocalEmbeddingResponse {
401 provider: PROVIDER_LOCAL,
402 name: descriptor.name,
403 source: descriptor.source,
404 revision: descriptor.revision,
405 engine: descriptor.engine,
406 dimensions: descriptor.dimensions,
407 embeddings,
408 })
409}
410
/// Decoded model registry entry, as produced by `read_model_descriptor`.
#[derive(Debug, Clone)]
struct ModelDescriptor {
    /// Model name from the entry, or the lookup name when the entry has none.
    name: String,
    /// Serving provider; `local` when the entry omits it.
    provider: String,
    /// Artifact source; empty when the entry omits it.
    source: String,
    /// Model revision; empty when the entry omits it.
    revision: String,
    /// Inference engine; empty when the entry omits it.
    engine: String,
    /// Model task (only `embedding` is served locally); empty when the entry omits it.
    task: String,
    /// Installation status (must be `installed` to serve); empty when the entry omits it.
    status: String,
    /// Output vector width.
    dimensions: u32,
    /// Pull policy normalized to one of the `PULL_POLICY_*` constants.
    pull_policy: &'static str,
}
426
427fn read_model_descriptor(db: &RedDB, name: &str) -> RedDBResult<ModelDescriptor> {
428 let key = format!("{AI_MODEL_KEY_PREFIX}{name}");
429 let raw = match db.get_kv(RED_CONFIG_COLLECTION, &key) {
430 Some((Value::Text(text), _)) => text.to_string(),
431 Some(_) => {
432 return Err(RedDBError::Query(format!(
433 "local model registry entry for '{name}' is not a JSON text payload"
434 )));
435 }
436 None => {
437 return Err(RedDBError::NotFound(format!(
438 "local model '{name}' is not registered; POST /ai/models to register it first"
439 )));
440 }
441 };
442 let parsed = parse_json(&raw).map_err(|err| {
443 RedDBError::Query(format!(
444 "local model registry entry for '{name}' is corrupted: {err}"
445 ))
446 })?;
447 let value = JsonValue::from(parsed);
448 let object = value
449 .as_object()
450 .ok_or_else(|| RedDBError::Query(format!("model entry for '{name}' is not an object")))?;
451
452 let pick = |key: &str| -> Option<String> {
453 object
454 .get(key)
455 .and_then(JsonValue::as_str)
456 .map(str::to_string)
457 };
458
459 let provider = pick("provider").unwrap_or_else(|| PROVIDER_LOCAL.to_string());
460 let source = pick("source").unwrap_or_default();
461 let revision = pick("revision").unwrap_or_default();
462 let engine = pick("engine").unwrap_or_default();
463 let task = pick("task").unwrap_or_default();
464 let status = pick("status").unwrap_or_default();
465 let dimensions = object
466 .get("dimensions")
467 .and_then(JsonValue::as_u64)
468 .ok_or_else(|| {
469 RedDBError::Query(format!("model entry for '{name}' is missing 'dimensions'"))
470 })? as u32;
471 let pull_policy = normalize_stored_pull_policy(
472 pick("pull_policy")
473 .as_deref()
474 .unwrap_or(PULL_POLICY_IF_MISSING),
475 );
476
477 Ok(ModelDescriptor {
478 name: pick("name").unwrap_or_else(|| name.to_string()),
479 provider,
480 source,
481 revision,
482 engine,
483 task,
484 status,
485 dimensions,
486 pull_policy,
487 })
488}
489
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deterministic_fake_is_pure_and_correct_length() {
        let request = LocalEmbeddingRequest {
            name: "mini".into(),
            source: "sentence-transformers/all-MiniLM-L6-v2".into(),
            revision: "main".into(),
            engine: "candle".into(),
            dimensions: 16,
            inputs: ["hello", "world"].map(String::from).to_vec(),
        };

        let first = DeterministicFakeBackend.embed(&request).expect("embed");
        let second = DeterministicFakeBackend.embed(&request).expect("embed twice");

        assert_eq!(first, second, "deterministic backend must be pure");
        assert_eq!(first.len(), 2);
        assert!(first.iter().all(|row| row.len() == 16));
        assert_ne!(
            first[0], first[1],
            "different inputs must produce different vectors"
        );
    }

    #[test]
    fn deterministic_fake_changes_with_model_name() {
        // Embeds the single input "x" at width 8 under the given model name.
        let embed_x = |model: &str| {
            let request = LocalEmbeddingRequest {
                name: model.to_owned(),
                source: String::new(),
                revision: String::new(),
                engine: String::new(),
                dimensions: 8,
                inputs: vec![String::from("x")],
            };
            DeterministicFakeBackend.embed(&request).unwrap()
        };

        assert_ne!(embed_x("alpha")[0], embed_x("beta")[0]);
    }
}