1use std::{
15 collections::HashMap,
16 path::{Path, PathBuf},
17 sync::{Arc, RwLock},
18};
19
20use thiserror::Error;
21use tracing::{info, warn};
22
23use ruvector_core::{
24 types::{DbOptions, HnswConfig as RuvHnswConfig},
25 DistanceMetric, SearchQuery, VectorDB, VectorEntry,
26};
27
/// Embedding width, in components, exported for callers of this module.
pub const VECTOR_DIM: usize = 768;

/// Degeneracy threshold for vector norms.
///
/// A norm measure that is non-finite or at/below this value is treated as
/// unusable and triggers the deterministic fallback vector.
const VECTOR_NORM_EPS: f32 = 1.0e-12;

/// Magnitude of the id-derived perturbation applied to a vector at insert time.
///
/// One component is shifted by this amount and a second one by half of it.
const INSERT_JITTER_EPS: f32 = 0.01;
33
34#[derive(Debug, Error)]
37pub enum RuVectorError {
38 #[error("Vector DB error: {0}")]
39 Db(String),
40
41 #[error("Table not found: {0}")]
42 TableNotFound(String),
43
44 #[error("IO error: {0}")]
45 Io(#[from] std::io::Error),
46
47 #[error("Lock poisoned")]
48 LockPoisoned,
49}
50
51impl From<ruvector_core::error::RuvectorError> for RuVectorError {
52 fn from(e: ruvector_core::error::RuvectorError) -> Self {
53 RuVectorError::Db(e.to_string())
54 }
55}
56
/// A single hit returned from a vector search.
#[derive(Debug, Clone)]
pub struct VectorResult {
    /// Identifier of the matching stored entry.
    pub id: String,
    /// Score reported for the hit, as produced by the search that built this value.
    pub distance: f32,
}
67
/// HNSW index tuning values carried by a store and handed to the vector database
/// when a table is opened.
///
/// NOTE(review): the precise meaning of each knob is defined by `ruvector_core`;
/// this type only transports the numbers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HnswConfig {
    /// The `m` parameter of the index.
    pub m: u32,
    /// The `ef_construction` parameter of the index.
    pub ef_construction: u32,
    /// The `ef_search` parameter of the index.
    pub ef_search: u32,
    /// The `max_elements` parameter of the index.
    pub max_elements: u32,
}

impl Default for HnswConfig {
    /// Returns the stock tuning: `m = 16`, `ef_construction = 200`,
    /// `ef_search = 50`, `max_elements = 100_000`.
    fn default() -> Self {
        HnswConfig {
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            max_elements: 100_000,
        }
    }
}
96
97#[derive(Clone)]
101pub struct RuVectorStore {
102 root: PathBuf,
103 dimensions: usize,
105 hnsw: HnswConfig,
106 tables: Arc<RwLock<HashMap<String, VectorDB>>>,
107}
108
109impl RuVectorStore {
110 pub async fn open(path: &Path, dimensions: usize) -> Result<Self, RuVectorError> {
117 Self::open_with_config(path, dimensions, HnswConfig::default()).await
118 }
119
120 pub async fn open_with_config(
125 path: &Path,
126 dimensions: usize,
127 hnsw: HnswConfig,
128 ) -> Result<Self, RuVectorError> {
129 std::fs::create_dir_all(path)?;
130 info!(
131 m = hnsw.m,
132 ef_construction = hnsw.ef_construction,
133 ef_search = hnsw.ef_search,
134 max_elements = hnsw.max_elements,
135 "RuVector store opened at {} (dim={})",
136 path.display(),
137 dimensions
138 );
139 Ok(Self {
140 root: path.to_path_buf(),
141 dimensions,
142 hnsw,
143 tables: Arc::new(RwLock::new(HashMap::new())),
144 })
145 }
146
147 fn make_db(&self, table_name: &str) -> Result<VectorDB, RuVectorError> {
148 let db_path = self.root.join(format!("{table_name}.db"));
149 let options = DbOptions {
150 dimensions: self.dimensions,
151 distance_metric: DistanceMetric::Cosine,
152 storage_path: db_path.to_string_lossy().into_owned(),
153 hnsw_config: Some(RuvHnswConfig {
154 m: self.hnsw.m as usize,
155 ef_construction: self.hnsw.ef_construction as usize,
156 ef_search: self.hnsw.ef_search as usize,
157 max_elements: self.hnsw.max_elements as usize,
158 }),
159 quantization: None,
160 };
161 VectorDB::new(options).map_err(Into::into)
162 }
163
164 fn get_or_create_db(&self, table_name: &str) -> Result<(), RuVectorError> {
165 let has = self
166 .tables
167 .read()
168 .map_err(|_| RuVectorError::LockPoisoned)?
169 .contains_key(table_name);
170
171 if !has {
172 let db = self.make_db(table_name)?;
173 self.tables
174 .write()
175 .map_err(|_| RuVectorError::LockPoisoned)?
176 .insert(table_name.to_string(), db);
177 }
178 Ok(())
179 }
180
181 pub async fn ensure_tables(&self) -> Result<(), RuVectorError> {
187 const MAX_RETRIES: u32 = 5;
188 const BASE_DELAY_MS: u64 = 200;
189
190 for name in &["facts_vec", "episodes_vec", "graph_vec"] {
191 let mut last_err = None;
192 for attempt in 0..=MAX_RETRIES {
193 match self.get_or_create_db(name) {
194 Ok(()) => {
195 if attempt > 0 {
196 info!("RuVector table '{name}' opened after {attempt} retries");
197 } else {
198 info!("Ensured RuVector table: {name}");
199 }
200 last_err = None;
201 break;
202 }
203 Err(e) if attempt < MAX_RETRIES => {
204 let delay_ms = BASE_DELAY_MS * 2u64.pow(attempt);
205 warn!(
206 table = name,
207 attempt = attempt + 1,
208 delay_ms,
209 error = %e,
210 "RuVector table lock contention, retrying"
211 );
212 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
213 last_err = Some(e);
214 }
215 Err(e) => {
216 last_err = Some(e);
217 }
218 }
219 }
220 if let Some(e) = last_err {
221 return Err(e);
222 }
223 }
224 Ok(())
225 }
226
227 pub async fn add_vectors(
229 &self,
230 table_name: &str,
231 ids: Vec<String>,
232 _contents: Vec<String>,
233 vectors: Vec<Vec<f32>>,
234 _timestamps: Vec<String>,
235 _source_type: &str,
236 ) -> Result<(), RuVectorError> {
237 self.get_or_create_db(table_name)?;
238 let tables = self
239 .tables
240 .read()
241 .map_err(|_| RuVectorError::LockPoisoned)?;
242 let db = tables
243 .get(table_name)
244 .ok_or_else(|| RuVectorError::TableNotFound(table_name.to_string()))?;
245
246 let count = ids.len();
247 for (id, vector) in ids.into_iter().zip(vectors) {
248 let safe_vector = sanitize_vector_for_insert(vector, self.dimensions, &id);
249 let entry = VectorEntry {
250 id: Some(id),
251 vector: safe_vector,
252 metadata: None,
253 };
254 db.insert(entry)?;
255 }
256 info!("Added {count} vectors to '{table_name}'");
257 Ok(())
258 }
259
260 pub async fn search(
264 &self,
265 table_name: &str,
266 query_vector: Vec<f32>,
267 top_k: usize,
268 ) -> Result<Vec<VectorResult>, RuVectorError> {
269 self.get_or_create_db(table_name)?;
272
273 let tables = self
274 .tables
275 .read()
276 .map_err(|_| RuVectorError::LockPoisoned)?;
277 let db = tables
278 .get(table_name)
279 .ok_or_else(|| RuVectorError::TableNotFound(table_name.to_string()))?;
280
281 let safe_query = sanitize_vector_for_query(query_vector, self.dimensions, table_name);
282 let results = db.search(SearchQuery {
283 vector: safe_query,
284 k: top_k,
285 filter: None,
286 ef_search: None,
287 })?;
288
289 Ok(results
290 .into_iter()
291 .map(|r| VectorResult {
292 id: r.id,
293 distance: sanitize_distance(r.score),
294 })
295 .collect())
296 }
297
298 pub async fn delete(&self, table_name: &str, id: &str) -> Result<(), RuVectorError> {
300 let tables = self
301 .tables
302 .read()
303 .map_err(|_| RuVectorError::LockPoisoned)?;
304 if let Some(db) = tables.get(table_name) {
305 db.delete(id)?;
306 }
307 Ok(())
308 }
309
310 pub async fn delete_batch(
318 &self,
319 table_name: &str,
320 ids: &[&str],
321 ) -> Result<Vec<(String, RuVectorError)>, RuVectorError> {
322 let tables = self
323 .tables
324 .read()
325 .map_err(|_| RuVectorError::LockPoisoned)?;
326 let mut failures = Vec::new();
327 if let Some(db) = tables.get(table_name) {
328 for id in ids {
329 if let Err(e) = db.delete(id) {
330 failures.push(((*id).to_string(), RuVectorError::from(e)));
331 }
332 }
333 }
334 Ok(failures)
335 }
336
337 pub async fn table_count(&self, table_name: &str) -> Result<usize, RuVectorError> {
339 let tables = self
340 .tables
341 .read()
342 .map_err(|_| RuVectorError::LockPoisoned)?;
343 Ok(tables
344 .get(table_name)
345 .map(|db| db.len().unwrap_or(0))
346 .unwrap_or(0))
347 }
348
349 pub async fn table_names(&self) -> Result<Vec<String>, RuVectorError> {
351 Ok(self
352 .tables
353 .read()
354 .map_err(|_| RuVectorError::LockPoisoned)?
355 .keys()
356 .cloned()
357 .collect())
358 }
359}
360
/// Clamps a raw score into a finite, non-negative distance.
///
/// NaN and both infinities become `f32::MAX`, negative values become `0.0`,
/// and everything else is returned unchanged.
fn sanitize_distance(score: f32) -> f32 {
    match score {
        s if !s.is_finite() => f32::MAX,
        s if s < 0.0 => 0.0,
        s => s,
    }
}
370
371fn sanitize_vector_for_insert(vector: Vec<f32>, dimensions: usize, id: &str) -> Vec<f32> {
372 let mut out = sanitize_vector_for_query(vector, dimensions, id);
373 apply_insert_jitter(&mut out, id);
374 normalize_in_place_or_fallback(&mut out, id);
375 out
376}
377
378fn sanitize_vector_for_query(vector: Vec<f32>, dimensions: usize, seed: &str) -> Vec<f32> {
379 if dimensions == 0 {
380 return Vec::new();
381 }
382 if vector.len() != dimensions || vector.iter().any(|x| !x.is_finite()) {
383 warn!(
384 expected_dim = dimensions,
385 got_dim = vector.len(),
386 "Invalid embedding shape/value; using deterministic fallback"
387 );
388 return deterministic_fallback_vector(seed, dimensions);
389 }
390
391 let mut out = vector;
392 if !normalize_in_place_or_fallback(&mut out, seed) {
393 return deterministic_fallback_vector(seed, dimensions);
394 }
395 out
396}
397
398fn normalize_in_place_or_fallback(vector: &mut [f32], seed: &str) -> bool {
399 if vector.is_empty() {
400 return true;
401 }
402
403 let norm_sq: f32 = vector.iter().map(|x| x * x).sum();
404 if !norm_sq.is_finite() || norm_sq <= VECTOR_NORM_EPS {
405 let fallback = deterministic_fallback_vector(seed, vector.len());
406 vector.copy_from_slice(&fallback);
407 return false;
408 }
409
410 let norm = norm_sq.sqrt();
411 for v in vector.iter_mut() {
412 *v /= norm;
413 }
414 true
415}
416
417fn apply_insert_jitter(vector: &mut [f32], id: &str) {
418 if vector.is_empty() {
419 return;
420 }
421
422 let mut hash: u64 = 0xcbf29ce484222325;
424 for b in id.as_bytes() {
425 hash ^= u64::from(*b);
426 hash = hash.wrapping_mul(0x100000001b3);
427 }
428
429 let idx_a = (hash as usize) % vector.len();
430 let idx_b = (hash.rotate_left(17) as usize) % vector.len();
431 let sign_a = if (hash & 1) == 0 { 1.0 } else { -1.0 };
432 let sign_b = if ((hash >> 1) & 1) == 0 { -1.0 } else { 1.0 };
433 vector[idx_a] += sign_a * INSERT_JITTER_EPS;
434 vector[idx_b] += sign_b * INSERT_JITTER_EPS * 0.5;
435}
436
437fn deterministic_fallback_vector(seed: &str, dimensions: usize) -> Vec<f32> {
438 if dimensions == 0 {
439 return Vec::new();
440 }
441
442 let mut state: u64 = 0xcbf29ce484222325;
443 for b in seed.as_bytes() {
444 state ^= u64::from(*b);
445 state = state.wrapping_mul(0x100000001b3);
446 }
447 if state == 0 {
448 state = 1;
449 }
450
451 let mut out = Vec::with_capacity(dimensions);
452 for _ in 0..dimensions {
453 state ^= state >> 12;
454 state ^= state << 25;
455 state ^= state >> 27;
456 let r = state.wrapping_mul(0x2545f4914f6cdd1d);
457 let unit = (r as f64 / u64::MAX as f64) as f32;
458 out.push(unit * 2.0 - 1.0);
459 }
460
461 let norm = out.iter().map(|x| x * x).sum::<f32>().sqrt();
462 if !norm.is_finite() || norm <= VECTOR_NORM_EPS {
463 let mut unit = vec![0.0_f32; dimensions];
464 unit[0] = 1.0;
465 return unit;
466 }
467 for v in &mut out {
468 *v /= norm;
469 }
470 out
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a default-tuned store in a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside the store so the directory is not
    /// removed while the test is still using it.
    async fn temp_store() -> (RuVectorStore, tempfile::TempDir) {
        let dir = tempfile::tempdir().unwrap();
        let store = RuVectorStore::open(dir.path(), VECTOR_DIM).await.unwrap();
        (store, dir)
    }

    /// Like [`temp_store`], with the standard tables already opened.
    async fn ready_store() -> (RuVectorStore, tempfile::TempDir) {
        let (store, dir) = temp_store().await;
        store.ensure_tables().await.unwrap();
        (store, dir)
    }

    /// Returns the basis vector with a `1.0` at `axis`.
    fn unit_vec(axis: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; VECTOR_DIM];
        v[axis] = 1.0;
        v
    }

    /// Inserts `vectors` under `ids`, leaving the unused payload arguments empty.
    async fn insert(
        store: &RuVectorStore,
        table: &str,
        ids: &[&str],
        vectors: Vec<Vec<f32>>,
        source_type: &str,
    ) {
        let ids = ids.iter().map(|id| (*id).to_string()).collect();
        store
            .add_vectors(table, ids, vec![], vectors, vec![], source_type)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn open_with_config_persists_tuning() {
        let dir = tempfile::tempdir().unwrap();
        let custom = HnswConfig {
            m: 32,
            ef_construction: 400,
            ef_search: 100,
            max_elements: 5_000_000,
        };
        let store = RuVectorStore::open_with_config(dir.path(), VECTOR_DIM, custom)
            .await
            .unwrap();
        assert_eq!(store.hnsw, custom);

        let default_store = RuVectorStore::open(dir.path(), VECTOR_DIM).await.unwrap();
        assert_eq!(default_store.hnsw, HnswConfig::default());
    }

    #[tokio::test]
    async fn test_open_and_ensure_tables() {
        let (store, _dir) = ready_store().await;

        let tables = store.table_names().await.unwrap();
        // Every table that `ensure_tables` opens must be listed, `graph_vec` included.
        for expected in ["facts_vec", "episodes_vec", "graph_vec"] {
            assert!(
                tables.iter().any(|t| t.as_str() == expected),
                "missing table {expected}"
            );
        }
    }

    #[tokio::test]
    async fn test_ensure_tables_idempotent() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();
        store.ensure_tables().await.unwrap();
    }

    #[tokio::test]
    async fn test_add_and_count() {
        let (store, _dir) = ready_store().await;

        insert(&store, "episodes_vec", &["ep001"], vec![unit_vec(0)], "episodic").await;

        assert_eq!(store.table_count("episodes_vec").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_vector_search() {
        let (store, _dir) = ready_store().await;

        let v1 = unit_vec(0);
        let v2 = unit_vec(1);
        let mut v3 = vec![0.0f32; VECTOR_DIM];
        v3[0] = 0.9;
        v3[1] = 0.1;

        insert(
            &store,
            "facts_vec",
            &["f1", "f2", "f3"],
            vec![v1.clone(), v2, v3],
            "semantic",
        )
        .await;

        let results = store.search("facts_vec", v1, 2).await.unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "f1");
    }

    #[tokio::test]
    async fn test_delete() {
        let (store, _dir) = ready_store().await;

        insert(&store, "facts_vec", &["f1"], vec![unit_vec(0)], "semantic").await;

        assert_eq!(store.table_count("facts_vec").await.unwrap(), 1);
        store.delete("facts_vec", "f1").await.unwrap();
        assert_eq!(store.table_count("facts_vec").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_identical_vectors_with_different_ids_do_not_panic() {
        let (store, _dir) = ready_store().await;

        let repeated = unit_vec(0);
        for i in 0..64 {
            let id = format!("dup-{i}");
            insert(
                &store,
                "facts_vec",
                &[id.as_str()],
                vec![repeated.clone()],
                "semantic",
            )
            .await;
        }

        let results = store.search("facts_vec", unit_vec(0), 5).await.unwrap();
        assert!(!results.is_empty());
        assert!(results.iter().all(|r| r.distance.is_finite()));
    }

    #[tokio::test]
    async fn test_invalid_or_zero_vectors_are_sanitized() {
        let (store, _dir) = ready_store().await;

        insert(
            &store,
            "facts_vec",
            &["zero", "nan"],
            vec![vec![0.0_f32; VECTOR_DIM], vec![f32::NAN; VECTOR_DIM]],
            "semantic",
        )
        .await;

        let results = store
            .search("facts_vec", vec![0.0_f32; VECTOR_DIM], 2)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.distance.is_finite()));
        assert!(results.iter().all(|r| r.distance >= 0.0));
    }
}