1use std::{
15 collections::HashMap,
16 path::{Path, PathBuf},
17 sync::{Arc, RwLock},
18};
19
20use thiserror::Error;
21use tracing::{info, warn};
22
23use ruvector_core::{
24 types::{DbOptions, HnswConfig as RuvHnswConfig},
25 DistanceMetric, SearchQuery, VectorDB, VectorEntry,
26};
27
/// Embedding dimensionality used by this module's tests when opening a store.
/// NOTE(review): presumably also the dimensionality production callers pass to
/// `RuVectorStore::open` — confirm against callers.
pub const VECTOR_DIM: usize = 768;
/// Threshold below which a vector is treated as having no usable length.
/// `normalize_in_place_or_fallback` compares the *squared* norm against it, while
/// `deterministic_fallback_vector` compares the norm itself.
const VECTOR_NORM_EPS: f32 = 1e-12;
/// Size of the id-derived perturbation `apply_insert_jitter` adds to one hashed
/// component; the component at a second hashed index (possibly the same one)
/// receives half of it.
const INSERT_JITTER_EPS: f32 = 1e-2;
33
/// Errors returned by [`RuVectorStore`] operations.
#[derive(Debug, Error)]
pub enum RuVectorError {
    /// An error reported by `ruvector_core`, carried as its rendered message.
    #[error("Vector DB error: {0}")]
    Db(String),

    /// The named table was not present in the in-memory table map.
    #[error("Table not found: {0}")]
    TableNotFound(String),

    /// A filesystem error, converted automatically from `std::io::Error`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// The lock guarding the table map was poisoned.
    #[error("Lock poisoned")]
    LockPoisoned,
}
50
51impl From<ruvector_core::error::RuvectorError> for RuVectorError {
52 fn from(e: ruvector_core::error::RuvectorError) -> Self {
53 RuVectorError::Db(e.to_string())
54 }
55}
56
/// One hit returned by [`RuVectorStore::search`].
#[derive(Debug, Clone)]
pub struct VectorResult {
    /// Identifier of the matched entry, as reported by the backend.
    pub id: String,
    /// The backend's score after `sanitize_distance`: non-finite values become
    /// `f32::MAX` and negative values become `0.0`.
    pub distance: f32,
}
67
/// Handle to a directory of per-table vector databases.
///
/// Clones share the same table map through the `Arc`.
#[derive(Clone)]
pub struct RuVectorStore {
    /// Directory under which each table uses the storage path `<table>.db`.
    root: PathBuf,
    /// Vector dimensionality passed to every database this store opens.
    dimensions: usize,
    /// Databases opened so far, keyed by table name.
    tables: Arc<RwLock<HashMap<String, VectorDB>>>,
}
78
79impl RuVectorStore {
80 pub async fn open(path: &Path, dimensions: usize) -> Result<Self, RuVectorError> {
87 std::fs::create_dir_all(path)?;
88 info!(
89 "RuVector store opened at {} (dim={})",
90 path.display(),
91 dimensions
92 );
93 Ok(Self {
94 root: path.to_path_buf(),
95 dimensions,
96 tables: Arc::new(RwLock::new(HashMap::new())),
97 })
98 }
99
100 fn make_db(&self, table_name: &str) -> Result<VectorDB, RuVectorError> {
101 let db_path = self.root.join(format!("{table_name}.db"));
102 let options = DbOptions {
103 dimensions: self.dimensions,
104 distance_metric: DistanceMetric::Cosine,
105 storage_path: db_path.to_string_lossy().into_owned(),
106 hnsw_config: Some(RuvHnswConfig {
107 m: 16,
108 ef_construction: 200,
109 ef_search: 50,
110 max_elements: 10_000_000,
111 }),
112 quantization: None,
113 };
114 VectorDB::new(options).map_err(Into::into)
115 }
116
117 fn get_or_create_db(&self, table_name: &str) -> Result<(), RuVectorError> {
118 let has = self
119 .tables
120 .read()
121 .map_err(|_| RuVectorError::LockPoisoned)?
122 .contains_key(table_name);
123
124 if !has {
125 let db = self.make_db(table_name)?;
126 self.tables
127 .write()
128 .map_err(|_| RuVectorError::LockPoisoned)?
129 .insert(table_name.to_string(), db);
130 }
131 Ok(())
132 }
133
134 pub async fn ensure_tables(&self) -> Result<(), RuVectorError> {
140 const MAX_RETRIES: u32 = 5;
141 const BASE_DELAY_MS: u64 = 200;
142
143 for name in &["facts_vec", "episodes_vec"] {
144 let mut last_err = None;
145 for attempt in 0..=MAX_RETRIES {
146 match self.get_or_create_db(name) {
147 Ok(()) => {
148 if attempt > 0 {
149 info!("RuVector table '{name}' opened after {attempt} retries");
150 } else {
151 info!("Ensured RuVector table: {name}");
152 }
153 last_err = None;
154 break;
155 }
156 Err(e) if attempt < MAX_RETRIES => {
157 let delay_ms = BASE_DELAY_MS * 2u64.pow(attempt);
158 warn!(
159 table = name,
160 attempt = attempt + 1,
161 delay_ms,
162 error = %e,
163 "RuVector table lock contention, retrying"
164 );
165 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
166 last_err = Some(e);
167 }
168 Err(e) => {
169 last_err = Some(e);
170 }
171 }
172 }
173 if let Some(e) = last_err {
174 return Err(e);
175 }
176 }
177 Ok(())
178 }
179
180 pub async fn add_vectors(
182 &self,
183 table_name: &str,
184 ids: Vec<String>,
185 _contents: Vec<String>,
186 vectors: Vec<Vec<f32>>,
187 _timestamps: Vec<String>,
188 _source_type: &str,
189 ) -> Result<(), RuVectorError> {
190 self.get_or_create_db(table_name)?;
191 let tables = self
192 .tables
193 .read()
194 .map_err(|_| RuVectorError::LockPoisoned)?;
195 let db = tables
196 .get(table_name)
197 .ok_or_else(|| RuVectorError::TableNotFound(table_name.to_string()))?;
198
199 let count = ids.len();
200 for (id, vector) in ids.into_iter().zip(vectors) {
201 let safe_vector = sanitize_vector_for_insert(vector, self.dimensions, &id);
202 let entry = VectorEntry {
203 id: Some(id),
204 vector: safe_vector,
205 metadata: None,
206 };
207 db.insert(entry)?;
208 }
209 info!("Added {count} vectors to '{table_name}'");
210 Ok(())
211 }
212
213 pub async fn search(
217 &self,
218 table_name: &str,
219 query_vector: Vec<f32>,
220 top_k: usize,
221 ) -> Result<Vec<VectorResult>, RuVectorError> {
222 self.get_or_create_db(table_name)?;
225
226 let tables = self
227 .tables
228 .read()
229 .map_err(|_| RuVectorError::LockPoisoned)?;
230 let db = tables
231 .get(table_name)
232 .ok_or_else(|| RuVectorError::TableNotFound(table_name.to_string()))?;
233
234 let safe_query = sanitize_vector_for_query(query_vector, self.dimensions, table_name);
235 let results = db.search(SearchQuery {
236 vector: safe_query,
237 k: top_k,
238 filter: None,
239 ef_search: None,
240 })?;
241
242 Ok(results
243 .into_iter()
244 .map(|r| VectorResult {
245 id: r.id,
246 distance: sanitize_distance(r.score),
247 })
248 .collect())
249 }
250
251 pub async fn delete(&self, table_name: &str, id: &str) -> Result<(), RuVectorError> {
253 let tables = self
254 .tables
255 .read()
256 .map_err(|_| RuVectorError::LockPoisoned)?;
257 if let Some(db) = tables.get(table_name) {
258 db.delete(id)?;
259 }
260 Ok(())
261 }
262
263 pub async fn table_count(&self, table_name: &str) -> Result<usize, RuVectorError> {
265 let tables = self
266 .tables
267 .read()
268 .map_err(|_| RuVectorError::LockPoisoned)?;
269 Ok(tables
270 .get(table_name)
271 .map(|db| db.len().unwrap_or(0))
272 .unwrap_or(0))
273 }
274
275 pub async fn table_names(&self) -> Result<Vec<String>, RuVectorError> {
277 Ok(self
278 .tables
279 .read()
280 .map_err(|_| RuVectorError::LockPoisoned)?
281 .keys()
282 .cloned()
283 .collect())
284 }
285}
286
/// Maps a backend score onto a finite, non-negative distance.
///
/// NaN and ±infinity become `f32::MAX`, negative values become `0.0`, and every
/// other value is returned unchanged.
fn sanitize_distance(score: f32) -> f32 {
    match score {
        s if !s.is_finite() => f32::MAX,
        s if s < 0.0 => 0.0,
        s => s,
    }
}
296
297fn sanitize_vector_for_insert(vector: Vec<f32>, dimensions: usize, id: &str) -> Vec<f32> {
298 let mut out = sanitize_vector_for_query(vector, dimensions, id);
299 apply_insert_jitter(&mut out, id);
300 normalize_in_place_or_fallback(&mut out, id);
301 out
302}
303
304fn sanitize_vector_for_query(vector: Vec<f32>, dimensions: usize, seed: &str) -> Vec<f32> {
305 if dimensions == 0 {
306 return Vec::new();
307 }
308 if vector.len() != dimensions || vector.iter().any(|x| !x.is_finite()) {
309 warn!(
310 expected_dim = dimensions,
311 got_dim = vector.len(),
312 "Invalid embedding shape/value; using deterministic fallback"
313 );
314 return deterministic_fallback_vector(seed, dimensions);
315 }
316
317 let mut out = vector;
318 if !normalize_in_place_or_fallback(&mut out, seed) {
319 return deterministic_fallback_vector(seed, dimensions);
320 }
321 out
322}
323
324fn normalize_in_place_or_fallback(vector: &mut [f32], seed: &str) -> bool {
325 if vector.is_empty() {
326 return true;
327 }
328
329 let norm_sq: f32 = vector.iter().map(|x| x * x).sum();
330 if !norm_sq.is_finite() || norm_sq <= VECTOR_NORM_EPS {
331 let fallback = deterministic_fallback_vector(seed, vector.len());
332 vector.copy_from_slice(&fallback);
333 return false;
334 }
335
336 let norm = norm_sq.sqrt();
337 for v in vector.iter_mut() {
338 *v /= norm;
339 }
340 true
341}
342
343fn apply_insert_jitter(vector: &mut [f32], id: &str) {
344 if vector.is_empty() {
345 return;
346 }
347
348 let mut hash: u64 = 0xcbf29ce484222325;
350 for b in id.as_bytes() {
351 hash ^= u64::from(*b);
352 hash = hash.wrapping_mul(0x100000001b3);
353 }
354
355 let idx_a = (hash as usize) % vector.len();
356 let idx_b = (hash.rotate_left(17) as usize) % vector.len();
357 let sign_a = if (hash & 1) == 0 { 1.0 } else { -1.0 };
358 let sign_b = if ((hash >> 1) & 1) == 0 { -1.0 } else { 1.0 };
359 vector[idx_a] += sign_a * INSERT_JITTER_EPS;
360 vector[idx_b] += sign_b * INSERT_JITTER_EPS * 0.5;
361}
362
363fn deterministic_fallback_vector(seed: &str, dimensions: usize) -> Vec<f32> {
364 if dimensions == 0 {
365 return Vec::new();
366 }
367
368 let mut state: u64 = 0xcbf29ce484222325;
369 for b in seed.as_bytes() {
370 state ^= u64::from(*b);
371 state = state.wrapping_mul(0x100000001b3);
372 }
373 if state == 0 {
374 state = 1;
375 }
376
377 let mut out = Vec::with_capacity(dimensions);
378 for _ in 0..dimensions {
379 state ^= state >> 12;
380 state ^= state << 25;
381 state ^= state >> 27;
382 let r = state.wrapping_mul(0x2545f4914f6cdd1d);
383 let unit = (r as f64 / u64::MAX as f64) as f32;
384 out.push(unit * 2.0 - 1.0);
385 }
386
387 let norm = out.iter().map(|x| x * x).sum::<f32>().sqrt();
388 if !norm.is_finite() || norm <= VECTOR_NORM_EPS {
389 let mut unit = vec![0.0_f32; dimensions];
390 unit[0] = 1.0;
391 return unit;
392 }
393 for v in &mut out {
394 *v /= norm;
395 }
396 out
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a store rooted in a fresh temporary directory. The `TempDir` is
    /// handed back so the directory stays alive for the duration of the test.
    async fn temp_store() -> (RuVectorStore, tempfile::TempDir) {
        let dir = tempfile::tempdir().unwrap();
        let store = RuVectorStore::open(dir.path(), VECTOR_DIM).await.unwrap();
        (store, dir)
    }

    /// `temp_store` followed by a successful `ensure_tables` call.
    async fn ready_store() -> (RuVectorStore, tempfile::TempDir) {
        let (store, dir) = temp_store().await;
        store.ensure_tables().await.unwrap();
        (store, dir)
    }

    /// A `VECTOR_DIM`-long vector that is `1.0` at `axis` and `0.0` elsewhere.
    fn unit_vec(axis: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; VECTOR_DIM];
        v[axis] = 1.0;
        v
    }

    /// Inserts `(id, vector)` pairs through `add_vectors`, passing empty contents
    /// and timestamps, and panics if the insert fails.
    async fn insert(
        store: &RuVectorStore,
        table: &str,
        source_type: &str,
        entries: Vec<(String, Vec<f32>)>,
    ) {
        let (ids, vectors): (Vec<String>, Vec<Vec<f32>>) = entries.into_iter().unzip();
        store
            .add_vectors(table, ids, vec![], vectors, vec![], source_type)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_open_and_ensure_tables() {
        let (store, _dir) = ready_store().await;

        let names = store.table_names().await.unwrap();
        for expected in &["episodes_vec", "facts_vec"] {
            assert!(names.iter().any(|n| n.as_str() == *expected));
        }
    }

    #[tokio::test]
    async fn test_ensure_tables_idempotent() {
        let (store, _dir) = ready_store().await;
        // A second call on already-loaded tables must also succeed.
        store.ensure_tables().await.unwrap();
    }

    #[tokio::test]
    async fn test_add_and_count() {
        let (store, _dir) = ready_store().await;

        insert(
            &store,
            "episodes_vec",
            "episodic",
            vec![("ep001".into(), unit_vec(0))],
        )
        .await;

        assert_eq!(store.table_count("episodes_vec").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_vector_search() {
        let (store, _dir) = ready_store().await;

        let along_x = unit_vec(0);
        let along_y = unit_vec(1);
        let mut mostly_x = vec![0.0f32; VECTOR_DIM];
        mostly_x[0] = 0.9;
        mostly_x[1] = 0.1;

        insert(
            &store,
            "facts_vec",
            "semantic",
            vec![
                ("f1".into(), along_x.clone()),
                ("f2".into(), along_y),
                ("f3".into(), mostly_x),
            ],
        )
        .await;

        let results = store.search("facts_vec", along_x, 2).await.unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "f1");
    }

    #[tokio::test]
    async fn test_delete() {
        let (store, _dir) = ready_store().await;

        insert(
            &store,
            "facts_vec",
            "semantic",
            vec![("f1".into(), unit_vec(0))],
        )
        .await;

        assert_eq!(store.table_count("facts_vec").await.unwrap(), 1);
        store.delete("facts_vec", "f1").await.unwrap();
        assert_eq!(store.table_count("facts_vec").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_identical_vectors_with_different_ids_do_not_panic() {
        let (store, _dir) = ready_store().await;

        let repeated = unit_vec(0);
        // One `add_vectors` call per id, all carrying the same vector.
        for i in 0..64 {
            insert(
                &store,
                "facts_vec",
                "semantic",
                vec![(format!("dup-{i}"), repeated.clone())],
            )
            .await;
        }

        let results = store.search("facts_vec", unit_vec(0), 5).await.unwrap();
        assert!(!results.is_empty());
        assert!(results.iter().all(|r| r.distance.is_finite()));
    }

    #[tokio::test]
    async fn test_invalid_or_zero_vectors_are_sanitized() {
        let (store, _dir) = ready_store().await;

        insert(
            &store,
            "facts_vec",
            "semantic",
            vec![
                ("zero".into(), vec![0.0_f32; VECTOR_DIM]),
                ("nan".into(), vec![f32::NAN; VECTOR_DIM]),
            ],
        )
        .await;

        let results = store
            .search("facts_vec", vec![0.0_f32; VECTOR_DIM], 2)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.distance.is_finite()));
        assert!(results.iter().all(|r| r.distance >= 0.0));
    }
}
555}