Skip to main content

brainos_storage/
ruvector.rs

1//! RuVector — backed by ruvector-core HNSW vector database.
2//!
3//! Wraps [`ruvector_core::VectorDB`] with the multi-table interface that the
4//! rest of Brain uses.  Each logical table maps to one `VectorDB` persisted at
5//! `<root>/<table_name>.db`.
6//!
7//! # Storage layout
8//! ```text
9//! ~/.brain/ruvector/
10//!   facts_vec.db     -- semantic fact vectors (HNSW)
11//!   episodes_vec.db  -- episode vectors (HNSW)
12//! ```
13
14use std::{
15    collections::HashMap,
16    path::{Path, PathBuf},
17    sync::{Arc, RwLock},
18};
19
20use thiserror::Error;
21use tracing::{info, warn};
22
23use ruvector_core::{
24    types::{DbOptions, HnswConfig as RuvHnswConfig},
25    DistanceMetric, SearchQuery, VectorDB, VectorEntry,
26};
27
/// Default embedding width used when the real model dimension is unknown.
///
/// Override by passing the actual embedding model dimension to [`RuVectorStore::open`].
pub const VECTOR_DIM: usize = 768;

/// Threshold at or below which a vector's norm (squared norm during
/// normalisation) is considered zero, triggering the deterministic fallback.
const VECTOR_NORM_EPS: f32 = 1.0e-12;

/// Size of the id-derived perturbation added to every vector on insert.
const INSERT_JITTER_EPS: f32 = 0.01;
33
34// ─── Errors ──────────────────────────────────────────────────────────────────
35
36#[derive(Debug, Error)]
37pub enum RuVectorError {
38    #[error("Vector DB error: {0}")]
39    Db(String),
40
41    #[error("Table not found: {0}")]
42    TableNotFound(String),
43
44    #[error("IO error: {0}")]
45    Io(#[from] std::io::Error),
46
47    #[error("Lock poisoned")]
48    LockPoisoned,
49}
50
51impl From<ruvector_core::error::RuvectorError> for RuVectorError {
52    fn from(e: ruvector_core::error::RuvectorError) -> Self {
53        RuVectorError::Db(e.to_string())
54    }
55}
56
57// ─── Public result type ───────────────────────────────────────────────────────
58
/// One hit returned by a vector similarity search.
#[derive(Clone, Debug)]
pub struct VectorResult {
    /// Identifier the vector was stored under (matches the fact/episode ID in SQLite).
    pub id: String,
    /// Cosine distance to the query; smaller values mean more similar.
    pub distance: f32,
}
67
68// ─── Store ───────────────────────────────────────────────────────────────────
69
70/// RuVector store — manages multiple per-table `VectorDB` instances.
71#[derive(Clone)]
72pub struct RuVectorStore {
73    root: PathBuf,
74    /// Dimension of the embedding vectors (must match the active embedding model).
75    dimensions: usize,
76    tables: Arc<RwLock<HashMap<String, VectorDB>>>,
77}
78
79impl RuVectorStore {
80    /// Open (or create) a RuVector store at the given directory.
81    ///
82    /// `dimensions` must equal the output dimension of the embedding model in use.
83    /// Passing the wrong dimension will cause `Dimension mismatch` errors on insert.
84    /// Use [`VECTOR_DIM`] as the default (384) when the embedding provider is not
85    /// yet known, and prefer probing the actual embedder output at startup.
86    pub async fn open(path: &Path, dimensions: usize) -> Result<Self, RuVectorError> {
87        std::fs::create_dir_all(path)?;
88        info!(
89            "RuVector store opened at {} (dim={})",
90            path.display(),
91            dimensions
92        );
93        Ok(Self {
94            root: path.to_path_buf(),
95            dimensions,
96            tables: Arc::new(RwLock::new(HashMap::new())),
97        })
98    }
99
100    fn make_db(&self, table_name: &str) -> Result<VectorDB, RuVectorError> {
101        let db_path = self.root.join(format!("{table_name}.db"));
102        let options = DbOptions {
103            dimensions: self.dimensions,
104            distance_metric: DistanceMetric::Cosine,
105            storage_path: db_path.to_string_lossy().into_owned(),
106            hnsw_config: Some(RuvHnswConfig {
107                m: 16,
108                ef_construction: 200,
109                ef_search: 50,
110                max_elements: 10_000_000,
111            }),
112            quantization: None,
113        };
114        VectorDB::new(options).map_err(Into::into)
115    }
116
117    fn get_or_create_db(&self, table_name: &str) -> Result<(), RuVectorError> {
118        let has = self
119            .tables
120            .read()
121            .map_err(|_| RuVectorError::LockPoisoned)?
122            .contains_key(table_name);
123
124        if !has {
125            let db = self.make_db(table_name)?;
126            self.tables
127                .write()
128                .map_err(|_| RuVectorError::LockPoisoned)?
129                .insert(table_name.to_string(), db);
130        }
131        Ok(())
132    }
133
134    /// Ensure the standard vector tables exist (idempotent).
135    ///
136    /// Retries up to 5 times with exponential backoff (200ms → 3.2s) to handle
137    /// transient file-lock contention from other `brain` processes (e.g. a
138    /// standalone `brain mcp` holding the redb lock when the daemon starts).
139    pub async fn ensure_tables(&self) -> Result<(), RuVectorError> {
140        const MAX_RETRIES: u32 = 5;
141        const BASE_DELAY_MS: u64 = 200;
142
143        for name in &["facts_vec", "episodes_vec"] {
144            let mut last_err = None;
145            for attempt in 0..=MAX_RETRIES {
146                match self.get_or_create_db(name) {
147                    Ok(()) => {
148                        if attempt > 0 {
149                            info!("RuVector table '{name}' opened after {attempt} retries");
150                        } else {
151                            info!("Ensured RuVector table: {name}");
152                        }
153                        last_err = None;
154                        break;
155                    }
156                    Err(e) if attempt < MAX_RETRIES => {
157                        let delay_ms = BASE_DELAY_MS * 2u64.pow(attempt);
158                        warn!(
159                            table = name,
160                            attempt = attempt + 1,
161                            delay_ms,
162                            error = %e,
163                            "RuVector table lock contention, retrying"
164                        );
165                        tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
166                        last_err = Some(e);
167                    }
168                    Err(e) => {
169                        last_err = Some(e);
170                    }
171                }
172            }
173            if let Some(e) = last_err {
174                return Err(e);
175            }
176        }
177        Ok(())
178    }
179
180    /// Add vectors to a table. `ids` and `vectors` must have the same length.
181    pub async fn add_vectors(
182        &self,
183        table_name: &str,
184        ids: Vec<String>,
185        _contents: Vec<String>,
186        vectors: Vec<Vec<f32>>,
187        _timestamps: Vec<String>,
188        _source_type: &str,
189    ) -> Result<(), RuVectorError> {
190        self.get_or_create_db(table_name)?;
191        let tables = self
192            .tables
193            .read()
194            .map_err(|_| RuVectorError::LockPoisoned)?;
195        let db = tables
196            .get(table_name)
197            .ok_or_else(|| RuVectorError::TableNotFound(table_name.to_string()))?;
198
199        let count = ids.len();
200        for (id, vector) in ids.into_iter().zip(vectors) {
201            let safe_vector = sanitize_vector_for_insert(vector, self.dimensions, &id);
202            let entry = VectorEntry {
203                id: Some(id),
204                vector: safe_vector,
205                metadata: None,
206            };
207            db.insert(entry)?;
208        }
209        info!("Added {count} vectors to '{table_name}'");
210        Ok(())
211    }
212
213    /// Search for the most similar vectors using cosine distance.
214    ///
215    /// Returns results sorted by distance ascending (closest first).
216    pub async fn search(
217        &self,
218        table_name: &str,
219        query_vector: Vec<f32>,
220        top_k: usize,
221    ) -> Result<Vec<VectorResult>, RuVectorError> {
222        // Lazy-open on first use so search does not fail with TableNotFound
223        // when callers skipped an explicit ensure_tables() step.
224        self.get_or_create_db(table_name)?;
225
226        let tables = self
227            .tables
228            .read()
229            .map_err(|_| RuVectorError::LockPoisoned)?;
230        let db = tables
231            .get(table_name)
232            .ok_or_else(|| RuVectorError::TableNotFound(table_name.to_string()))?;
233
234        let safe_query = sanitize_vector_for_query(query_vector, self.dimensions, table_name);
235        let results = db.search(SearchQuery {
236            vector: safe_query,
237            k: top_k,
238            filter: None,
239            ef_search: None,
240        })?;
241
242        Ok(results
243            .into_iter()
244            .map(|r| VectorResult {
245                id: r.id,
246                distance: sanitize_distance(r.score),
247            })
248            .collect())
249    }
250
251    /// Delete a vector by ID from a table.
252    pub async fn delete(&self, table_name: &str, id: &str) -> Result<(), RuVectorError> {
253        let tables = self
254            .tables
255            .read()
256            .map_err(|_| RuVectorError::LockPoisoned)?;
257        if let Some(db) = tables.get(table_name) {
258            db.delete(id)?;
259        }
260        Ok(())
261    }
262
263    /// Get the row count for a table.
264    pub async fn table_count(&self, table_name: &str) -> Result<usize, RuVectorError> {
265        let tables = self
266            .tables
267            .read()
268            .map_err(|_| RuVectorError::LockPoisoned)?;
269        Ok(tables
270            .get(table_name)
271            .map(|db| db.len().unwrap_or(0))
272            .unwrap_or(0))
273    }
274
275    /// List all open table names.
276    pub async fn table_names(&self) -> Result<Vec<String>, RuVectorError> {
277        Ok(self
278            .tables
279            .read()
280            .map_err(|_| RuVectorError::LockPoisoned)?
281            .keys()
282            .cloned()
283            .collect())
284    }
285}
286
/// Clamp a raw distance reported by the index into a usable range.
///
/// Non-finite scores (NaN, ±∞) become `f32::MAX` so they rank last, and
/// negative scores are clamped to `0.0`. Everything else passes through unchanged.
fn sanitize_distance(score: f32) -> f32 {
    match score {
        s if !s.is_finite() => f32::MAX,
        s if s < 0.0 => 0.0,
        s => s,
    }
}
296
297fn sanitize_vector_for_insert(vector: Vec<f32>, dimensions: usize, id: &str) -> Vec<f32> {
298    let mut out = sanitize_vector_for_query(vector, dimensions, id);
299    apply_insert_jitter(&mut out, id);
300    normalize_in_place_or_fallback(&mut out, id);
301    out
302}
303
304fn sanitize_vector_for_query(vector: Vec<f32>, dimensions: usize, seed: &str) -> Vec<f32> {
305    if dimensions == 0 {
306        return Vec::new();
307    }
308    if vector.len() != dimensions || vector.iter().any(|x| !x.is_finite()) {
309        warn!(
310            expected_dim = dimensions,
311            got_dim = vector.len(),
312            "Invalid embedding shape/value; using deterministic fallback"
313        );
314        return deterministic_fallback_vector(seed, dimensions);
315    }
316
317    let mut out = vector;
318    if !normalize_in_place_or_fallback(&mut out, seed) {
319        return deterministic_fallback_vector(seed, dimensions);
320    }
321    out
322}
323
324fn normalize_in_place_or_fallback(vector: &mut [f32], seed: &str) -> bool {
325    if vector.is_empty() {
326        return true;
327    }
328
329    let norm_sq: f32 = vector.iter().map(|x| x * x).sum();
330    if !norm_sq.is_finite() || norm_sq <= VECTOR_NORM_EPS {
331        let fallback = deterministic_fallback_vector(seed, vector.len());
332        vector.copy_from_slice(&fallback);
333        return false;
334    }
335
336    let norm = norm_sq.sqrt();
337    for v in vector.iter_mut() {
338        *v /= norm;
339    }
340    true
341}
342
343fn apply_insert_jitter(vector: &mut [f32], id: &str) {
344    if vector.is_empty() {
345        return;
346    }
347
348    // Deterministic id-based perturbation to avoid pathological duplicate vectors.
349    let mut hash: u64 = 0xcbf29ce484222325;
350    for b in id.as_bytes() {
351        hash ^= u64::from(*b);
352        hash = hash.wrapping_mul(0x100000001b3);
353    }
354
355    let idx_a = (hash as usize) % vector.len();
356    let idx_b = (hash.rotate_left(17) as usize) % vector.len();
357    let sign_a = if (hash & 1) == 0 { 1.0 } else { -1.0 };
358    let sign_b = if ((hash >> 1) & 1) == 0 { -1.0 } else { 1.0 };
359    vector[idx_a] += sign_a * INSERT_JITTER_EPS;
360    vector[idx_b] += sign_b * INSERT_JITTER_EPS * 0.5;
361}
362
363fn deterministic_fallback_vector(seed: &str, dimensions: usize) -> Vec<f32> {
364    if dimensions == 0 {
365        return Vec::new();
366    }
367
368    let mut state: u64 = 0xcbf29ce484222325;
369    for b in seed.as_bytes() {
370        state ^= u64::from(*b);
371        state = state.wrapping_mul(0x100000001b3);
372    }
373    if state == 0 {
374        state = 1;
375    }
376
377    let mut out = Vec::with_capacity(dimensions);
378    for _ in 0..dimensions {
379        state ^= state >> 12;
380        state ^= state << 25;
381        state ^= state >> 27;
382        let r = state.wrapping_mul(0x2545f4914f6cdd1d);
383        let unit = (r as f64 / u64::MAX as f64) as f32;
384        out.push(unit * 2.0 - 1.0);
385    }
386
387    let norm = out.iter().map(|x| x * x).sum::<f32>().sqrt();
388    if !norm.is_finite() || norm <= VECTOR_NORM_EPS {
389        let mut unit = vec![0.0_f32; dimensions];
390        unit[0] = 1.0;
391        return unit;
392    }
393    for v in &mut out {
394        *v /= norm;
395    }
396    out
397}
398
399// ─── Tests ───────────────────────────────────────────────────────────────────
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Open a store in a fresh temporary directory. Keep the returned `TempDir`
    /// alive for as long as the store is used.
    async fn temp_store() -> (RuVectorStore, tempfile::TempDir) {
        let tmp = tempfile::tempdir().unwrap();
        let store = RuVectorStore::open(tmp.path(), VECTOR_DIM).await.unwrap();
        (store, tmp)
    }

    /// Like `temp_store`, but with the standard tables already ensured.
    async fn ready_store() -> (RuVectorStore, tempfile::TempDir) {
        let (store, tmp) = temp_store().await;
        store.ensure_tables().await.unwrap();
        (store, tmp)
    }

    /// `VECTOR_DIM`-long vector with a single `1.0` at `axis`.
    fn unit_vec(axis: usize) -> Vec<f32> {
        let mut basis = vec![0.0f32; VECTOR_DIM];
        basis[axis] = 1.0;
        basis
    }

    /// Insert `vectors` under `ids` with empty contents/timestamps, panicking on error.
    async fn insert(
        store: &RuVectorStore,
        table: &str,
        ids: &[&str],
        vectors: Vec<Vec<f32>>,
        source_type: &str,
    ) {
        store
            .add_vectors(
                table,
                ids.iter().map(|&s| s.to_owned()).collect(),
                vec![],
                vectors,
                vec![],
                source_type,
            )
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_open_and_ensure_tables() {
        let (store, _dir) = ready_store().await;

        let names = store.table_names().await.unwrap();
        for expected in ["episodes_vec", "facts_vec"] {
            assert!(names.iter().any(|n| n == expected));
        }
    }

    #[tokio::test]
    async fn test_ensure_tables_idempotent() {
        let (store, _dir) = temp_store().await;
        for _ in 0..2 {
            store.ensure_tables().await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_add_and_count() {
        let (store, _dir) = ready_store().await;

        insert(&store, "episodes_vec", &["ep001"], vec![unit_vec(0)], "episodic").await;

        assert_eq!(store.table_count("episodes_vec").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_vector_search() {
        let (store, _dir) = ready_store().await;

        let query = unit_vec(0);
        let mut near = vec![0.0f32; VECTOR_DIM];
        near[0] = 0.9;
        near[1] = 0.1;

        insert(
            &store,
            "facts_vec",
            &["f1", "f2", "f3"],
            vec![query.clone(), unit_vec(1), near],
            "semantic",
        )
        .await;

        let hits = store.search("facts_vec", query, 2).await.unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, "f1");
    }

    #[tokio::test]
    async fn test_delete() {
        let (store, _dir) = ready_store().await;

        insert(&store, "facts_vec", &["f1"], vec![unit_vec(0)], "semantic").await;

        assert_eq!(store.table_count("facts_vec").await.unwrap(), 1);
        store.delete("facts_vec", "f1").await.unwrap();
        assert_eq!(store.table_count("facts_vec").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_identical_vectors_with_different_ids_do_not_panic() {
        let (store, _dir) = ready_store().await;

        let repeated = unit_vec(0);
        for i in 0..64 {
            let id = format!("dup-{i}");
            insert(&store, "facts_vec", &[id.as_str()], vec![repeated.clone()], "semantic").await;
        }

        let hits = store.search("facts_vec", unit_vec(0), 5).await.unwrap();
        assert!(!hits.is_empty());
        assert!(hits.iter().all(|r| r.distance.is_finite()));
    }

    #[tokio::test]
    async fn test_invalid_or_zero_vectors_are_sanitized() {
        let (store, _dir) = ready_store().await;

        insert(
            &store,
            "facts_vec",
            &["zero", "nan"],
            vec![vec![0.0_f32; VECTOR_DIM], vec![f32::NAN; VECTOR_DIM]],
            "semantic",
        )
        .await;

        let hits = store
            .search("facts_vec", vec![0.0_f32; VECTOR_DIM], 2)
            .await
            .unwrap();
        assert_eq!(hits.len(), 2);
        assert!(hits
            .iter()
            .all(|r| r.distance.is_finite() && r.distance >= 0.0));
    }
}