// brainos_storage/ruvector.rs

//! RuVector — backed by ruvector-core HNSW vector database.
//!
//! Wraps [`ruvector_core::VectorDB`] with the multi-table interface that the
//! rest of Brain uses.  Each logical table maps to one `VectorDB` persisted at
//! `<root>/<table_name>.db`.
//!
//! # Storage layout
//! ```text
//! ~/.brain/ruvector/
//!   facts_vec.db     -- semantic fact vectors (HNSW)
//!   episodes_vec.db  -- episode vectors (HNSW)
//!   graph_vec.db     -- third standard table opened by `ensure_tables` (HNSW)
//! ```

use std::{
    collections::HashMap,
    path::{Path, PathBuf},
    sync::{Arc, RwLock},
};

use thiserror::Error;
use tracing::{info, warn};

use ruvector_core::{
    types::{DbOptions, HnswConfig as RuvHnswConfig},
    DistanceMetric, SearchQuery, VectorDB, VectorEntry,
};

/// Default vector dimension.
/// Override by passing the actual embedding model dimension to [`RuVectorStore::open`].
pub const VECTOR_DIM: usize = 768;
/// Threshold at or below which a vector's norm (or squared norm, depending on
/// the call site) is treated as zero, triggering the deterministic fallback.
const VECTOR_NORM_EPS: f32 = 1e-12;
/// Magnitude of the deterministic, id-derived perturbation applied to vectors
/// on insert (the second perturbed component uses half of this value).
const INSERT_JITTER_EPS: f32 = 1e-2;

34// ─── Errors ──────────────────────────────────────────────────────────────────
35
36#[derive(Debug, Error)]
37pub enum RuVectorError {
38    #[error("Vector DB error: {0}")]
39    Db(String),
40
41    #[error("Table not found: {0}")]
42    TableNotFound(String),
43
44    #[error("IO error: {0}")]
45    Io(#[from] std::io::Error),
46
47    #[error("Lock poisoned")]
48    LockPoisoned,
49}
50
51impl From<ruvector_core::error::RuvectorError> for RuVectorError {
52    fn from(e: ruvector_core::error::RuvectorError) -> Self {
53        RuVectorError::Db(e.to_string())
54    }
55}
56
// ─── Public result type ───────────────────────────────────────────────────────

/// A single vector search result.
#[derive(Debug, Clone)]
pub struct VectorResult {
    /// ID of the stored vector (matches the fact/episode ID in SQLite).
    pub id: String,
    /// Cosine distance (lower = more similar).
    pub distance: f32,
}

/// Tuning knobs for the underlying ruvector HNSW index. Mirrors the
/// fields in `brain::HnswConfig` but lives here so the storage
/// crate can stay independent of `brain`. Callers convert at the
/// boundary.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HnswConfig {
    /// Forwarded to the `m` field of the ruvector HNSW config.
    pub m: u32,
    /// Forwarded to the `ef_construction` field of the ruvector HNSW config.
    pub ef_construction: u32,
    /// Forwarded to the `ef_search` field of the ruvector HNSW config.
    pub ef_search: u32,
    /// Forwarded to the `max_elements` field of the ruvector HNSW config.
    pub max_elements: u32,
}

impl Default for HnswConfig {
    fn default() -> Self {
        Self {
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            // 100k vectors. HNSW pre-allocates the index graph for this
            // many entries up-front so this is a real memory cost, not
            // just a cap (Wave F, Issue 71). The canonical default lives
            // in `core::config::HnswConfig::default_max_elements`; this
            // mirror exists so storage-only callers (tests, isolated
            // RuVector users) don't have to pull in the core crate.
            max_elements: 100_000,
        }
    }
}

97// ─── Store ───────────────────────────────────────────────────────────────────
98
99/// RuVector store — manages multiple per-table `VectorDB` instances.
100#[derive(Clone)]
101pub struct RuVectorStore {
102    root: PathBuf,
103    /// Dimension of the embedding vectors (must match the active embedding model).
104    dimensions: usize,
105    hnsw: HnswConfig,
106    tables: Arc<RwLock<HashMap<String, VectorDB>>>,
107}
108
109impl RuVectorStore {
110    /// Open (or create) a RuVector store at the given directory.
111    ///
112    /// `dimensions` must equal the output dimension of the embedding model in use.
113    /// Passing the wrong dimension will cause `Dimension mismatch` errors on insert.
114    /// Use [`VECTOR_DIM`] as the default (384) when the embedding provider is not
115    /// yet known, and prefer probing the actual embedder output at startup.
116    pub async fn open(path: &Path, dimensions: usize) -> Result<Self, RuVectorError> {
117        Self::open_with_config(path, dimensions, HnswConfig::default()).await
118    }
119
120    /// Open with explicit HNSW tuning (Issue 37). Threading the knobs
121    /// through here lets `brain serve` honour `storage.hnsw.*` from
122    /// config instead of locking every install to the hardcoded
123    /// defaults baked into this crate.
124    pub async fn open_with_config(
125        path: &Path,
126        dimensions: usize,
127        hnsw: HnswConfig,
128    ) -> Result<Self, RuVectorError> {
129        std::fs::create_dir_all(path)?;
130        info!(
131            m = hnsw.m,
132            ef_construction = hnsw.ef_construction,
133            ef_search = hnsw.ef_search,
134            max_elements = hnsw.max_elements,
135            "RuVector store opened at {} (dim={})",
136            path.display(),
137            dimensions
138        );
139        Ok(Self {
140            root: path.to_path_buf(),
141            dimensions,
142            hnsw,
143            tables: Arc::new(RwLock::new(HashMap::new())),
144        })
145    }
146
147    fn make_db(&self, table_name: &str) -> Result<VectorDB, RuVectorError> {
148        let db_path = self.root.join(format!("{table_name}.db"));
149        let options = DbOptions {
150            dimensions: self.dimensions,
151            distance_metric: DistanceMetric::Cosine,
152            storage_path: db_path.to_string_lossy().into_owned(),
153            hnsw_config: Some(RuvHnswConfig {
154                m: self.hnsw.m as usize,
155                ef_construction: self.hnsw.ef_construction as usize,
156                ef_search: self.hnsw.ef_search as usize,
157                max_elements: self.hnsw.max_elements as usize,
158            }),
159            quantization: None,
160        };
161        VectorDB::new(options).map_err(Into::into)
162    }
163
164    fn get_or_create_db(&self, table_name: &str) -> Result<(), RuVectorError> {
165        let has = self
166            .tables
167            .read()
168            .map_err(|_| RuVectorError::LockPoisoned)?
169            .contains_key(table_name);
170
171        if !has {
172            let db = self.make_db(table_name)?;
173            self.tables
174                .write()
175                .map_err(|_| RuVectorError::LockPoisoned)?
176                .insert(table_name.to_string(), db);
177        }
178        Ok(())
179    }
180
181    /// Ensure the standard vector tables exist (idempotent).
182    ///
183    /// Retries up to 5 times with exponential backoff (200ms → 3.2s) to handle
184    /// transient file-lock contention from other `brain` processes (e.g. a
185    /// standalone `brain mcp` holding the redb lock when the daemon starts).
186    pub async fn ensure_tables(&self) -> Result<(), RuVectorError> {
187        const MAX_RETRIES: u32 = 5;
188        const BASE_DELAY_MS: u64 = 200;
189
190        for name in &["facts_vec", "episodes_vec", "graph_vec"] {
191            let mut last_err = None;
192            for attempt in 0..=MAX_RETRIES {
193                match self.get_or_create_db(name) {
194                    Ok(()) => {
195                        if attempt > 0 {
196                            info!("RuVector table '{name}' opened after {attempt} retries");
197                        } else {
198                            info!("Ensured RuVector table: {name}");
199                        }
200                        last_err = None;
201                        break;
202                    }
203                    Err(e) if attempt < MAX_RETRIES => {
204                        let delay_ms = BASE_DELAY_MS * 2u64.pow(attempt);
205                        warn!(
206                            table = name,
207                            attempt = attempt + 1,
208                            delay_ms,
209                            error = %e,
210                            "RuVector table lock contention, retrying"
211                        );
212                        tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
213                        last_err = Some(e);
214                    }
215                    Err(e) => {
216                        last_err = Some(e);
217                    }
218                }
219            }
220            if let Some(e) = last_err {
221                return Err(e);
222            }
223        }
224        Ok(())
225    }
226
227    /// Add vectors to a table. `ids` and `vectors` must have the same length.
228    pub async fn add_vectors(
229        &self,
230        table_name: &str,
231        ids: Vec<String>,
232        _contents: Vec<String>,
233        vectors: Vec<Vec<f32>>,
234        _timestamps: Vec<String>,
235        _source_type: &str,
236    ) -> Result<(), RuVectorError> {
237        self.get_or_create_db(table_name)?;
238        let tables = self
239            .tables
240            .read()
241            .map_err(|_| RuVectorError::LockPoisoned)?;
242        let db = tables
243            .get(table_name)
244            .ok_or_else(|| RuVectorError::TableNotFound(table_name.to_string()))?;
245
246        let count = ids.len();
247        for (id, vector) in ids.into_iter().zip(vectors) {
248            let safe_vector = sanitize_vector_for_insert(vector, self.dimensions, &id);
249            let entry = VectorEntry {
250                id: Some(id),
251                vector: safe_vector,
252                metadata: None,
253            };
254            db.insert(entry)?;
255        }
256        info!("Added {count} vectors to '{table_name}'");
257        Ok(())
258    }
259
260    /// Search for the most similar vectors using cosine distance.
261    ///
262    /// Returns results sorted by distance ascending (closest first).
263    pub async fn search(
264        &self,
265        table_name: &str,
266        query_vector: Vec<f32>,
267        top_k: usize,
268    ) -> Result<Vec<VectorResult>, RuVectorError> {
269        // Lazy-open on first use so search does not fail with TableNotFound
270        // when callers skipped an explicit ensure_tables() step.
271        self.get_or_create_db(table_name)?;
272
273        let tables = self
274            .tables
275            .read()
276            .map_err(|_| RuVectorError::LockPoisoned)?;
277        let db = tables
278            .get(table_name)
279            .ok_or_else(|| RuVectorError::TableNotFound(table_name.to_string()))?;
280
281        let safe_query = sanitize_vector_for_query(query_vector, self.dimensions, table_name);
282        let results = db.search(SearchQuery {
283            vector: safe_query,
284            k: top_k,
285            filter: None,
286            ef_search: None,
287        })?;
288
289        Ok(results
290            .into_iter()
291            .map(|r| VectorResult {
292                id: r.id,
293                distance: sanitize_distance(r.score),
294            })
295            .collect())
296    }
297
298    /// Delete a vector by ID from a table.
299    pub async fn delete(&self, table_name: &str, id: &str) -> Result<(), RuVectorError> {
300        let tables = self
301            .tables
302            .read()
303            .map_err(|_| RuVectorError::LockPoisoned)?;
304        if let Some(db) = tables.get(table_name) {
305            db.delete(id)?;
306        }
307        Ok(())
308    }
309
310    /// Delete many vectors by ID in one read-lock cycle. Used by
311    /// `handle_forget` so a batch of N matched facts collapses from N
312    /// `delete` calls (each grabbing + releasing the tables lock) to
313    /// one. Per-id failures are collected and returned as a `Vec` of
314    /// `(id, error)` pairs so the caller can decide whether to abort or
315    /// continue; the SQLite side has its own batch flow that doesn't
316    /// short-circuit on a single failure either.
317    pub async fn delete_batch(
318        &self,
319        table_name: &str,
320        ids: &[&str],
321    ) -> Result<Vec<(String, RuVectorError)>, RuVectorError> {
322        let tables = self
323            .tables
324            .read()
325            .map_err(|_| RuVectorError::LockPoisoned)?;
326        let mut failures = Vec::new();
327        if let Some(db) = tables.get(table_name) {
328            for id in ids {
329                if let Err(e) = db.delete(id) {
330                    failures.push(((*id).to_string(), RuVectorError::from(e)));
331                }
332            }
333        }
334        Ok(failures)
335    }
336
337    /// Get the row count for a table.
338    pub async fn table_count(&self, table_name: &str) -> Result<usize, RuVectorError> {
339        let tables = self
340            .tables
341            .read()
342            .map_err(|_| RuVectorError::LockPoisoned)?;
343        Ok(tables
344            .get(table_name)
345            .map(|db| db.len().unwrap_or(0))
346            .unwrap_or(0))
347    }
348
349    /// List all open table names.
350    pub async fn table_names(&self) -> Result<Vec<String>, RuVectorError> {
351        Ok(self
352            .tables
353            .read()
354            .map_err(|_| RuVectorError::LockPoisoned)?
355            .keys()
356            .cloned()
357            .collect())
358    }
359}
360
/// Turn a raw search score into a well-behaved distance.
///
/// Non-finite scores (NaN, ±infinity) become `f32::MAX`, negative scores are
/// clamped to `0.0`, and everything else is returned unchanged.
fn sanitize_distance(score: f32) -> f32 {
    if !score.is_finite() {
        return f32::MAX;
    }
    if score < 0.0 {
        return 0.0;
    }
    score
}

371fn sanitize_vector_for_insert(vector: Vec<f32>, dimensions: usize, id: &str) -> Vec<f32> {
372    let mut out = sanitize_vector_for_query(vector, dimensions, id);
373    apply_insert_jitter(&mut out, id);
374    normalize_in_place_or_fallback(&mut out, id);
375    out
376}
377
378fn sanitize_vector_for_query(vector: Vec<f32>, dimensions: usize, seed: &str) -> Vec<f32> {
379    if dimensions == 0 {
380        return Vec::new();
381    }
382    if vector.len() != dimensions || vector.iter().any(|x| !x.is_finite()) {
383        warn!(
384            expected_dim = dimensions,
385            got_dim = vector.len(),
386            "Invalid embedding shape/value; using deterministic fallback"
387        );
388        return deterministic_fallback_vector(seed, dimensions);
389    }
390
391    let mut out = vector;
392    if !normalize_in_place_or_fallback(&mut out, seed) {
393        return deterministic_fallback_vector(seed, dimensions);
394    }
395    out
396}
397
398fn normalize_in_place_or_fallback(vector: &mut [f32], seed: &str) -> bool {
399    if vector.is_empty() {
400        return true;
401    }
402
403    let norm_sq: f32 = vector.iter().map(|x| x * x).sum();
404    if !norm_sq.is_finite() || norm_sq <= VECTOR_NORM_EPS {
405        let fallback = deterministic_fallback_vector(seed, vector.len());
406        vector.copy_from_slice(&fallback);
407        return false;
408    }
409
410    let norm = norm_sq.sqrt();
411    for v in vector.iter_mut() {
412        *v /= norm;
413    }
414    true
415}
416
417fn apply_insert_jitter(vector: &mut [f32], id: &str) {
418    if vector.is_empty() {
419        return;
420    }
421
422    // Deterministic id-based perturbation to avoid pathological duplicate vectors.
423    let mut hash: u64 = 0xcbf29ce484222325;
424    for b in id.as_bytes() {
425        hash ^= u64::from(*b);
426        hash = hash.wrapping_mul(0x100000001b3);
427    }
428
429    let idx_a = (hash as usize) % vector.len();
430    let idx_b = (hash.rotate_left(17) as usize) % vector.len();
431    let sign_a = if (hash & 1) == 0 { 1.0 } else { -1.0 };
432    let sign_b = if ((hash >> 1) & 1) == 0 { -1.0 } else { 1.0 };
433    vector[idx_a] += sign_a * INSERT_JITTER_EPS;
434    vector[idx_b] += sign_b * INSERT_JITTER_EPS * 0.5;
435}
436
437fn deterministic_fallback_vector(seed: &str, dimensions: usize) -> Vec<f32> {
438    if dimensions == 0 {
439        return Vec::new();
440    }
441
442    let mut state: u64 = 0xcbf29ce484222325;
443    for b in seed.as_bytes() {
444        state ^= u64::from(*b);
445        state = state.wrapping_mul(0x100000001b3);
446    }
447    if state == 0 {
448        state = 1;
449    }
450
451    let mut out = Vec::with_capacity(dimensions);
452    for _ in 0..dimensions {
453        state ^= state >> 12;
454        state ^= state << 25;
455        state ^= state >> 27;
456        let r = state.wrapping_mul(0x2545f4914f6cdd1d);
457        let unit = (r as f64 / u64::MAX as f64) as f32;
458        out.push(unit * 2.0 - 1.0);
459    }
460
461    let norm = out.iter().map(|x| x * x).sum::<f32>().sqrt();
462    if !norm.is_finite() || norm <= VECTOR_NORM_EPS {
463        let mut unit = vec![0.0_f32; dimensions];
464        unit[0] = 1.0;
465        return unit;
466    }
467    for v in &mut out {
468        *v /= norm;
469    }
470    out
471}
472
// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh store in a temporary directory. The `TempDir` is returned so the
    /// directory outlives the store for the duration of the test.
    async fn temp_store() -> (RuVectorStore, tempfile::TempDir) {
        let dir = tempfile::tempdir().unwrap();
        let store = RuVectorStore::open(dir.path(), VECTOR_DIM).await.unwrap();
        (store, dir)
    }

    /// `VECTOR_DIM`-sized vector with a single 1.0 at `axis`.
    fn unit_vec(axis: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; VECTOR_DIM];
        v[axis] = 1.0;
        v
    }

    /// Issue 37 regression: `open_with_config` preserves the supplied
    /// tuning, and `open` falls back to defaults. We don't probe the
    /// underlying ruvector internals (no public accessor) — but the
    /// stored `hnsw` field is the source of truth that `make_db` reads.
    #[tokio::test]
    async fn open_with_config_persists_tuning() {
        let dir = tempfile::tempdir().unwrap();
        let custom = HnswConfig {
            m: 32,
            ef_construction: 400,
            ef_search: 100,
            max_elements: 5_000_000,
        };
        let store = RuVectorStore::open_with_config(dir.path(), VECTOR_DIM, custom)
            .await
            .unwrap();
        assert_eq!(store.hnsw, custom);

        let default_store = RuVectorStore::open(dir.path(), VECTOR_DIM).await.unwrap();
        assert_eq!(default_store.hnsw, HnswConfig::default());
    }

    #[tokio::test]
    async fn test_open_and_ensure_tables() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();

        // `table_names` has no defined order, so sort before comparing.
        let mut tables = store.table_names().await.unwrap();
        tables.sort();
        assert_eq!(tables, ["episodes_vec", "facts_vec", "graph_vec"]);
    }

    #[tokio::test]
    async fn test_ensure_tables_idempotent() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();
        store.ensure_tables().await.unwrap();
    }

    #[tokio::test]
    async fn test_add_and_count() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();

        store
            .add_vectors(
                "episodes_vec",
                vec!["ep001".into()],
                vec![],
                vec![unit_vec(0)],
                vec![],
                "episodic",
            )
            .await
            .unwrap();

        assert_eq!(store.table_count("episodes_vec").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_vector_search() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();

        let v1 = unit_vec(0);
        let v2 = unit_vec(1);
        let mut v3 = vec![0.0f32; VECTOR_DIM];
        v3[0] = 0.9;
        v3[1] = 0.1;

        store
            .add_vectors(
                "facts_vec",
                vec!["f1".into(), "f2".into(), "f3".into()],
                vec![],
                vec![v1.clone(), v2, v3],
                vec![],
                "semantic",
            )
            .await
            .unwrap();

        let results = store.search("facts_vec", v1, 2).await.unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "f1");
    }

    #[tokio::test]
    async fn test_delete() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();

        store
            .add_vectors(
                "facts_vec",
                vec!["f1".into()],
                vec![],
                vec![unit_vec(0)],
                vec![],
                "semantic",
            )
            .await
            .unwrap();

        assert_eq!(store.table_count("facts_vec").await.unwrap(), 1);
        store.delete("facts_vec", "f1").await.unwrap();
        assert_eq!(store.table_count("facts_vec").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_delete_batch() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();

        store
            .add_vectors(
                "facts_vec",
                vec!["f1".into(), "f2".into()],
                vec![],
                vec![unit_vec(0), unit_vec(1)],
                vec![],
                "semantic",
            )
            .await
            .unwrap();
        assert_eq!(store.table_count("facts_vec").await.unwrap(), 2);

        let failures = store
            .delete_batch("facts_vec", &["f1", "f2"])
            .await
            .unwrap();
        assert!(failures.is_empty());
        assert_eq!(store.table_count("facts_vec").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_identical_vectors_with_different_ids_do_not_panic() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();

        let repeated = unit_vec(0);
        for i in 0..64 {
            store
                .add_vectors(
                    "facts_vec",
                    vec![format!("dup-{i}")],
                    vec![],
                    vec![repeated.clone()],
                    vec![],
                    "semantic",
                )
                .await
                .unwrap();
        }

        let results = store.search("facts_vec", unit_vec(0), 5).await.unwrap();
        assert!(!results.is_empty());
        assert!(results.iter().all(|r| r.distance.is_finite()));
    }

    #[tokio::test]
    async fn test_invalid_or_zero_vectors_are_sanitized() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();

        store
            .add_vectors(
                "facts_vec",
                vec!["zero".into(), "nan".into()],
                vec![],
                vec![vec![0.0_f32; VECTOR_DIM], vec![f32::NAN; VECTOR_DIM]],
                vec![],
                "semantic",
            )
            .await
            .unwrap();

        let results = store
            .search("facts_vec", vec![0.0_f32; VECTOR_DIM], 2)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.distance.is_finite()));
        assert!(results.iter().all(|r| r.distance >= 0.0));
    }
}
651}