use std::{
collections::HashMap,
path::{Path, PathBuf},
sync::{Arc, RwLock},
};
use thiserror::Error;
use tracing::{info, warn};
use ruvector_core::{
types::{DbOptions, HnswConfig as RuvHnswConfig},
DistanceMetric, SearchQuery, VectorDB, VectorEntry,
};
/// Embedding dimensionality exported for callers; `RuVectorStore::open`
/// itself takes the dimension as an explicit argument.
pub const VECTOR_DIM: usize = 768;
/// Degeneracy threshold for vector lengths. `normalize_in_place_or_fallback`
/// compares it against the *squared* norm, `deterministic_fallback_vector`
/// against the norm itself.
const VECTOR_NORM_EPS: f32 = 1.0e-12;
/// Magnitude of the deterministic per-id perturbation applied before insert.
const INSERT_JITTER_EPS: f32 = 0.01;
/// Errors returned by `RuVectorStore` operations.
#[derive(Debug, Error)]
pub enum RuVectorError {
    /// Failure reported by `ruvector_core`, carried as its rendered message.
    #[error("Vector DB error: {0}")]
    Db(String),
    /// The named table is not present in the store's table cache.
    #[error("Table not found: {0}")]
    TableNotFound(String),
    /// Filesystem failure (e.g. while creating the store directory).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// The `RwLock` guarding the table cache was poisoned by a panicking thread.
    #[error("Lock poisoned")]
    LockPoisoned,
}
impl From<ruvector_core::error::RuvectorError> for RuVectorError {
fn from(e: ruvector_core::error::RuvectorError) -> Self {
RuVectorError::Db(e.to_string())
}
}
/// A single nearest-neighbour hit produced by `RuVectorStore::search`.
#[derive(Clone, Debug)]
pub struct VectorResult {
    /// Identifier the vector was inserted under.
    pub id: String,
    /// Sanitized distance: negative scores are clamped to `0.0` and
    /// non-finite scores are reported as `f32::MAX`.
    pub distance: f32,
}
/// HNSW index tuning, converted into `ruvector_core`'s own HNSW config each
/// time a table is opened.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct HnswConfig {
    /// HNSW `M` parameter (graph connectivity per node).
    pub m: u32,
    /// Candidate-list size used while building the graph.
    pub ef_construction: u32,
    /// Candidate-list size used while searching.
    pub ef_search: u32,
    /// Capacity hint passed through to the index.
    pub max_elements: u32,
}

impl Default for HnswConfig {
    /// `m = 16`, `ef_construction = 200`, `ef_search = 50`,
    /// `max_elements = 100_000`.
    fn default() -> Self {
        let (m, ef_construction, ef_search, max_elements) = (16, 200, 50, 100_000);
        Self {
            m,
            ef_construction,
            ef_search,
            max_elements,
        }
    }
}
/// A set of named vector tables. Each table is backed by its own
/// `ruvector_core` [`VectorDB`] whose storage path is `<root>/<table>.db`.
///
/// Tables are opened lazily and cached in `tables`. Cloning the store is cheap,
/// and clones share the same cache through the `Arc`.
#[derive(Clone)]
pub struct RuVectorStore {
    /// Directory that holds one `<table>.db` storage path per table.
    root: PathBuf,
    /// Embedding dimensionality every table is created with.
    dimensions: usize,
    /// HNSW tuning applied to every table opened by this store.
    hnsw: HnswConfig,
    /// Tables opened so far in this process, keyed by table name.
    tables: Arc<RwLock<HashMap<String, VectorDB>>>,
}

impl RuVectorStore {
    /// Opens (creating if necessary) the store directory with the default
    /// [`HnswConfig`].
    ///
    /// # Errors
    /// Returns [`RuVectorError::Io`] if the directory cannot be created.
    pub async fn open(path: &Path, dimensions: usize) -> Result<Self, RuVectorError> {
        Self::open_with_config(path, dimensions, HnswConfig::default()).await
    }

    /// Opens (creating if necessary) the store directory with explicit HNSW tuning.
    ///
    /// Only the directory is created here. Individual tables are opened lazily
    /// on first use, or eagerly via [`Self::ensure_tables`].
    ///
    /// # Errors
    /// Returns [`RuVectorError::Io`] if the directory cannot be created.
    pub async fn open_with_config(
        path: &Path,
        dimensions: usize,
        hnsw: HnswConfig,
    ) -> Result<Self, RuVectorError> {
        std::fs::create_dir_all(path)?;
        info!(
            m = hnsw.m,
            ef_construction = hnsw.ef_construction,
            ef_search = hnsw.ef_search,
            max_elements = hnsw.max_elements,
            "RuVector store opened at {} (dim={})",
            path.display(),
            dimensions
        );
        Ok(Self {
            root: path.to_path_buf(),
            dimensions,
            hnsw,
            tables: Arc::new(RwLock::new(HashMap::new())),
        })
    }

    /// Builds a fresh [`VectorDB`] for `table_name` without caching it.
    ///
    /// The DB uses cosine distance, this store's dimensionality and HNSW
    /// tuning, and no quantization.
    fn make_db(&self, table_name: &str) -> Result<VectorDB, RuVectorError> {
        let db_path = self.root.join(format!("{table_name}.db"));
        let options = DbOptions {
            dimensions: self.dimensions,
            distance_metric: DistanceMetric::Cosine,
            storage_path: db_path.to_string_lossy().into_owned(),
            hnsw_config: Some(RuvHnswConfig {
                m: self.hnsw.m as usize,
                ef_construction: self.hnsw.ef_construction as usize,
                ef_search: self.hnsw.ef_search as usize,
                max_elements: self.hnsw.max_elements as usize,
            }),
            quantization: None,
        };
        VectorDB::new(options).map_err(Into::into)
    }

    /// Ensures `table_name` is opened and cached in `self.tables`.
    ///
    /// The existence check is repeated under the write lock. Without it, two
    /// concurrent callers could both open the same storage path, and the loser
    /// would either fail (e.g. if the backend locks its storage file) or
    /// replace the already-cached handle with a second one. The write lock is
    /// held while the DB is opened, which briefly blocks other table lookups.
    fn get_or_create_db(&self, table_name: &str) -> Result<(), RuVectorError> {
        // Fast path: a shared lock is enough when the table is already cached.
        // The read guard is a temporary and is released at the end of this statement.
        let already_open = self
            .tables
            .read()
            .map_err(|_| RuVectorError::LockPoisoned)?
            .contains_key(table_name);
        if already_open {
            return Ok(());
        }
        // Slow path: re-check under the exclusive lock in case another caller
        // opened the table between the two lock acquisitions.
        let mut tables = self
            .tables
            .write()
            .map_err(|_| RuVectorError::LockPoisoned)?;
        if !tables.contains_key(table_name) {
            let db = self.make_db(table_name)?;
            tables.insert(table_name.to_string(), db);
        }
        Ok(())
    }

    /// Opens the built-in tables `facts_vec`, `episodes_vec` and `graph_vec`.
    ///
    /// Each table gets up to `MAX_RETRIES` retries with exponential backoff
    /// (200 ms, 400 ms, 800 ms, ...) to ride out transient contention while
    /// opening storage.
    ///
    /// # Errors
    /// Returns the last error of the first table that still fails after all retries.
    pub async fn ensure_tables(&self) -> Result<(), RuVectorError> {
        const MAX_RETRIES: u32 = 5;
        const BASE_DELAY_MS: u64 = 200;
        for name in &["facts_vec", "episodes_vec", "graph_vec"] {
            let mut last_err = None;
            for attempt in 0..=MAX_RETRIES {
                match self.get_or_create_db(name) {
                    Ok(()) => {
                        if attempt > 0 {
                            info!("RuVector table '{name}' opened after {attempt} retries");
                        } else {
                            info!("Ensured RuVector table: {name}");
                        }
                        last_err = None;
                        break;
                    }
                    Err(e) if attempt < MAX_RETRIES => {
                        let delay_ms = BASE_DELAY_MS * 2u64.pow(attempt);
                        warn!(
                            table = name,
                            attempt = attempt + 1,
                            delay_ms,
                            error = %e,
                            "RuVector table lock contention, retrying"
                        );
                        tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
                        last_err = Some(e);
                    }
                    Err(e) => {
                        last_err = Some(e);
                    }
                }
            }
            if let Some(e) = last_err {
                return Err(e);
            }
        }
        Ok(())
    }

    /// Inserts `(id, vector)` pairs into `table_name`, opening or creating the
    /// table if it is not cached yet.
    ///
    /// Each vector is validated and normalized. It then receives a small
    /// deterministic per-id jitter and is renormalized; the jitter is intended
    /// to keep identical embeddings under different ids from being exact
    /// duplicates. `_contents`, `_timestamps` and `_source_type` are accepted
    /// but currently ignored.
    ///
    /// If `ids` and `vectors` differ in length, only the paired prefix is
    /// inserted and a warning is logged.
    ///
    /// # Errors
    /// Lock poisoning, table-open failure, or the first failing insert. Inserts
    /// that succeeded earlier in the same batch are not rolled back.
    pub async fn add_vectors(
        &self,
        table_name: &str,
        ids: Vec<String>,
        _contents: Vec<String>,
        vectors: Vec<Vec<f32>>,
        _timestamps: Vec<String>,
        _source_type: &str,
    ) -> Result<(), RuVectorError> {
        self.get_or_create_db(table_name)?;
        let tables = self
            .tables
            .read()
            .map_err(|_| RuVectorError::LockPoisoned)?;
        let db = tables
            .get(table_name)
            .ok_or_else(|| RuVectorError::TableNotFound(table_name.to_string()))?;
        if ids.len() != vectors.len() {
            // `zip` below stops at the shorter input; make the dropped tail visible.
            warn!(
                table = table_name,
                ids = ids.len(),
                vectors = vectors.len(),
                "ids/vectors length mismatch; unpaired entries are skipped"
            );
        }
        let mut inserted = 0usize;
        for (id, vector) in ids.into_iter().zip(vectors) {
            let safe_vector = sanitize_vector_for_insert(vector, self.dimensions, &id);
            let entry = VectorEntry {
                id: Some(id),
                vector: safe_vector,
                metadata: None,
            };
            db.insert(entry)?;
            inserted += 1;
        }
        info!("Added {inserted} vectors to '{table_name}'");
        Ok(())
    }

    /// Returns up to `top_k` nearest neighbours of `query_vector` in
    /// `table_name`, opening or creating the table if it is not cached yet.
    ///
    /// The query is validated and normalized like inserts, but without jitter.
    /// An invalid or zero-norm query is replaced by a deterministic fallback
    /// seeded by the table name. Returned distances are finite and non-negative.
    pub async fn search(
        &self,
        table_name: &str,
        query_vector: Vec<f32>,
        top_k: usize,
    ) -> Result<Vec<VectorResult>, RuVectorError> {
        self.get_or_create_db(table_name)?;
        let tables = self
            .tables
            .read()
            .map_err(|_| RuVectorError::LockPoisoned)?;
        let db = tables
            .get(table_name)
            .ok_or_else(|| RuVectorError::TableNotFound(table_name.to_string()))?;
        let safe_query = sanitize_vector_for_query(query_vector, self.dimensions, table_name);
        let results = db.search(SearchQuery {
            vector: safe_query,
            k: top_k,
            filter: None,
            ef_search: None,
        })?;
        Ok(results
            .into_iter()
            .map(|r| VectorResult {
                id: r.id,
                distance: sanitize_distance(r.score),
            })
            .collect())
    }

    /// Deletes `id` from `table_name`.
    ///
    /// Only tables already cached in this process are consulted. A table not
    /// yet opened (via `ensure_tables`, `add_vectors` or `search`) makes this
    /// a silent no-op.
    /// NOTE(review): unlike add/search this does not open the table lazily;
    /// confirm that is intended.
    pub async fn delete(&self, table_name: &str, id: &str) -> Result<(), RuVectorError> {
        let tables = self
            .tables
            .read()
            .map_err(|_| RuVectorError::LockPoisoned)?;
        if let Some(db) = tables.get(table_name) {
            db.delete(id)?;
        }
        Ok(())
    }

    /// Deletes every id in `ids` from `table_name`, continuing past failures.
    ///
    /// Returns the `(id, error)` pairs that failed. The outer `Err` is only
    /// used for lock poisoning. As with [`Self::delete`], an uncached table is
    /// treated as empty.
    pub async fn delete_batch(
        &self,
        table_name: &str,
        ids: &[&str],
    ) -> Result<Vec<(String, RuVectorError)>, RuVectorError> {
        let tables = self
            .tables
            .read()
            .map_err(|_| RuVectorError::LockPoisoned)?;
        let mut failures = Vec::new();
        if let Some(db) = tables.get(table_name) {
            for id in ids {
                if let Err(e) = db.delete(id) {
                    failures.push(((*id).to_string(), RuVectorError::from(e)));
                }
            }
        }
        Ok(failures)
    }

    /// Number of vectors in `table_name`.
    ///
    /// Returns 0 if the table is not cached in this process or if the backend
    /// fails to report its size.
    pub async fn table_count(&self, table_name: &str) -> Result<usize, RuVectorError> {
        let tables = self
            .tables
            .read()
            .map_err(|_| RuVectorError::LockPoisoned)?;
        Ok(tables
            .get(table_name)
            .map(|db| db.len().unwrap_or(0))
            .unwrap_or(0))
    }

    /// Names of the tables cached so far in this process. This is not a
    /// listing of the store directory.
    pub async fn table_names(&self) -> Result<Vec<String>, RuVectorError> {
        Ok(self
            .tables
            .read()
            .map_err(|_| RuVectorError::LockPoisoned)?
            .keys()
            .cloned()
            .collect())
    }
}
/// Maps a raw backend score to a usable distance.
///
/// NaN and ±infinity become `f32::MAX` (sorted as "farthest"). Negative
/// scores are clamped to `0.0`. Everything else passes through unchanged.
fn sanitize_distance(score: f32) -> f32 {
    match score {
        s if !s.is_finite() => f32::MAX,
        s if s < 0.0 => 0.0,
        s => s,
    }
}
/// Prepares an embedding for insertion under `id`.
///
/// The vector is first validated and normalized exactly like a query, seeded by
/// `id`. A deterministic per-id jitter is then added, and the result is
/// renormalized.
fn sanitize_vector_for_insert(vector: Vec<f32>, dimensions: usize, id: &str) -> Vec<f32> {
    let mut cleaned = sanitize_vector_for_query(vector, dimensions, id);
    apply_insert_jitter(&mut cleaned, id);
    // Jitter changes the magnitude, so renormalize. The status flag is not
    // needed: on failure the helper has already written the fallback vector.
    let _ = normalize_in_place_or_fallback(&mut cleaned, id);
    cleaned
}
/// Validates and L2-normalizes an embedding.
///
/// Returns an empty vector when `dimensions == 0`. A deterministic unit
/// vector derived from `seed` is returned instead of the input when `vector`
/// has the wrong length, contains NaN or infinite components, or has a
/// (near-)zero norm. A warning is logged only for the length or value case.
fn sanitize_vector_for_query(vector: Vec<f32>, dimensions: usize, seed: &str) -> Vec<f32> {
    if dimensions == 0 {
        return Vec::new();
    }
    if vector.len() != dimensions || vector.iter().any(|x| !x.is_finite()) {
        warn!(
            expected_dim = dimensions,
            got_dim = vector.len(),
            "Invalid embedding shape/value; using deterministic fallback"
        );
        return deterministic_fallback_vector(seed, dimensions);
    }
    let mut out = vector;
    // On a degenerate norm the helper has already overwritten `out` with
    // `deterministic_fallback_vector(seed, dimensions)` (here `out.len() ==
    // dimensions`), so recomputing the fallback would be redundant work.
    let _ = normalize_in_place_or_fallback(&mut out, seed);
    out
}
/// L2-normalizes `vector` in place.
///
/// Returns `true` on success; an empty slice is left untouched and also
/// reports `true`. If the squared norm is non-finite or `<= VECTOR_NORM_EPS`,
/// the slice is overwritten with `deterministic_fallback_vector(seed, len)`
/// and `false` is returned.
fn normalize_in_place_or_fallback(vector: &mut [f32], seed: &str) -> bool {
    if vector.is_empty() {
        return true;
    }
    let squared_len: f32 = vector.iter().map(|x| x * x).sum();
    let degenerate = !squared_len.is_finite() || squared_len <= VECTOR_NORM_EPS;
    if degenerate {
        vector.copy_from_slice(&deterministic_fallback_vector(seed, vector.len()));
        false
    } else {
        let len = squared_len.sqrt();
        vector.iter_mut().for_each(|component| *component /= len);
        true
    }
}
/// Adds a small deterministic perturbation derived from `id`.
///
/// A 64-bit FNV-1a hash of `id` selects two component indices (which may
/// coincide) and two signs. The first component moves by `INSERT_JITTER_EPS`
/// and the second by half that. The result is not renormalized here.
/// NOTE(review): `as usize` truncates the hash on 32-bit targets, so the chosen
/// indices are platform-dependent.
fn apply_insert_jitter(vector: &mut [f32], id: &str) {
    let len = vector.len();
    if len == 0 {
        return;
    }
    let hash = id.bytes().fold(0xcbf29ce484222325_u64, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(0x100000001b3)
    });
    let first = (hash as usize) % len;
    let second = (hash.rotate_left(17) as usize) % len;
    let first_sign: f32 = if (hash & 1) == 0 { 1.0 } else { -1.0 };
    let second_sign: f32 = if ((hash >> 1) & 1) == 0 { -1.0 } else { 1.0 };
    vector[first] += first_sign * INSERT_JITTER_EPS;
    vector[second] += second_sign * INSERT_JITTER_EPS * 0.5;
}
/// Builds a deterministic unit vector of length `dimensions` from `seed`.
///
/// The vector comes from a xorshift64* stream seeded with the 64-bit FNV-1a
/// hash of `seed`. Each component lies in `[-1, 1]` and the result is then
/// L2-normalized. If the norm degenerates (non-finite or `<= VECTOR_NORM_EPS`),
/// the first standard basis vector is returned instead. `dimensions == 0`
/// yields an empty vector.
fn deterministic_fallback_vector(seed: &str, dimensions: usize) -> Vec<f32> {
    if dimensions == 0 {
        return Vec::new();
    }
    let hashed = seed.bytes().fold(0xcbf29ce484222325_u64, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(0x100000001b3)
    });
    // xorshift must never run from a zero state (it would stay zero forever).
    let mut state = if hashed == 0 { 1 } else { hashed };
    let mut out: Vec<f32> = std::iter::repeat_with(|| {
        state ^= state >> 12;
        state ^= state << 25;
        state ^= state >> 27;
        let scrambled = state.wrapping_mul(0x2545f4914f6cdd1d);
        let unit = (scrambled as f64 / u64::MAX as f64) as f32;
        unit * 2.0 - 1.0
    })
    .take(dimensions)
    .collect();
    let norm = out.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm.is_finite() && norm > VECTOR_NORM_EPS {
        out.iter_mut().for_each(|component| *component /= norm);
        out
    } else {
        let mut axis = vec![0.0_f32; dimensions];
        axis[0] = 1.0;
        axis
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a default-configured store in a fresh temp dir. The `TempDir` is
    /// returned so the directory outlives the store for the whole test.
    async fn temp_store() -> (RuVectorStore, tempfile::TempDir) {
        let dir = tempfile::tempdir().unwrap();
        let store = RuVectorStore::open(dir.path(), VECTOR_DIM).await.unwrap();
        (store, dir)
    }

    #[tokio::test]
    async fn open_with_config_persists_tuning() {
        let dir = tempfile::tempdir().unwrap();
        let custom = HnswConfig {
            m: 32,
            ef_construction: 400,
            ef_search: 100,
            max_elements: 5_000_000,
        };
        let store = RuVectorStore::open_with_config(dir.path(), VECTOR_DIM, custom)
            .await
            .unwrap();
        assert_eq!(store.hnsw, custom);
        let default_store = RuVectorStore::open(dir.path(), VECTOR_DIM).await.unwrap();
        assert_eq!(default_store.hnsw, HnswConfig::default());
    }

    /// Standard basis vector along `axis`, of length `VECTOR_DIM`.
    fn unit_vec(axis: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; VECTOR_DIM];
        v[axis] = 1.0;
        v
    }

    #[tokio::test]
    async fn test_open_and_ensure_tables() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();
        let mut tables = store.table_names().await.unwrap();
        tables.sort();
        // All three built-in tables, and nothing else.
        assert_eq!(tables, vec!["episodes_vec", "facts_vec", "graph_vec"]);
    }

    #[tokio::test]
    async fn test_ensure_tables_idempotent() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();
        store.ensure_tables().await.unwrap();
        // The second call must reuse the cached tables, not fail reopening them.
        assert_eq!(store.table_names().await.unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_add_and_count() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();
        store
            .add_vectors(
                "episodes_vec",
                vec!["ep001".into()],
                vec![],
                vec![unit_vec(0)],
                vec![],
                "episodic",
            )
            .await
            .unwrap();
        assert_eq!(store.table_count("episodes_vec").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_mismatched_ids_and_vectors_insert_only_pairs() {
        let (store, _dir) = temp_store().await;
        store
            .add_vectors(
                "facts_vec",
                vec!["a".into(), "b".into()],
                vec![],
                vec![unit_vec(0)],
                vec![],
                "semantic",
            )
            .await
            .unwrap();
        // Only the paired prefix is inserted; the unpaired id is skipped.
        assert_eq!(store.table_count("facts_vec").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_vector_search() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();
        let v1 = unit_vec(0);
        let v2 = unit_vec(1);
        let mut v3 = vec![0.0f32; VECTOR_DIM];
        v3[0] = 0.9;
        v3[1] = 0.1;
        store
            .add_vectors(
                "facts_vec",
                vec!["f1".into(), "f2".into(), "f3".into()],
                vec![],
                vec![v1.clone(), v2, v3],
                vec![],
                "semantic",
            )
            .await
            .unwrap();
        let results = store.search("facts_vec", v1, 2).await.unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "f1");
    }

    #[tokio::test]
    async fn test_delete() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();
        store
            .add_vectors(
                "facts_vec",
                vec!["f1".into()],
                vec![],
                vec![unit_vec(0)],
                vec![],
                "semantic",
            )
            .await
            .unwrap();
        assert_eq!(store.table_count("facts_vec").await.unwrap(), 1);
        store.delete("facts_vec", "f1").await.unwrap();
        assert_eq!(store.table_count("facts_vec").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_identical_vectors_with_different_ids_do_not_panic() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();
        let repeated = unit_vec(0);
        for i in 0..64 {
            store
                .add_vectors(
                    "facts_vec",
                    vec![format!("dup-{i}")],
                    vec![],
                    vec![repeated.clone()],
                    vec![],
                    "semantic",
                )
                .await
                .unwrap();
        }
        let results = store.search("facts_vec", unit_vec(0), 5).await.unwrap();
        assert!(!results.is_empty());
        assert!(results.iter().all(|r| r.distance.is_finite()));
    }

    #[tokio::test]
    async fn test_invalid_or_zero_vectors_are_sanitized() {
        let (store, _dir) = temp_store().await;
        store.ensure_tables().await.unwrap();
        store
            .add_vectors(
                "facts_vec",
                vec!["zero".into(), "nan".into()],
                vec![],
                vec![vec![0.0_f32; VECTOR_DIM], vec![f32::NAN; VECTOR_DIM]],
                vec![],
                "semantic",
            )
            .await
            .unwrap();
        let results = store
            .search("facts_vec", vec![0.0_f32; VECTOR_DIM], 2)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.distance.is_finite()));
        assert!(results.iter().all(|r| r.distance >= 0.0));
    }

    #[test]
    fn sanitize_distance_clamps_invalid_scores() {
        assert_eq!(sanitize_distance(f32::NAN), f32::MAX);
        assert_eq!(sanitize_distance(f32::INFINITY), f32::MAX);
        assert_eq!(sanitize_distance(-0.25), 0.0);
        assert_eq!(sanitize_distance(0.5), 0.5);
    }

    #[test]
    fn fallback_vector_is_deterministic_and_unit_length() {
        let a = deterministic_fallback_vector("seed", 16);
        let b = deterministic_fallback_vector("seed", 16);
        assert_eq!(a, b);
        let norm = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
        assert!(deterministic_fallback_vector("seed", 0).is_empty());
    }

    #[test]
    fn query_with_wrong_dimension_uses_fallback() {
        let out = sanitize_vector_for_query(vec![1.0; 3], 4, "seed");
        assert_eq!(out, deterministic_fallback_vector("seed", 4));
    }
}