use crate::core::error::{Error, Result, VectorError};
use crate::core::hasher::IdentityHasher;
use crate::core::id::NodeId;
use crate::core::property::MAX_VECTOR_DIMENSIONS;
use crate::core::vector::SparseVec;
use bitcode::{Decode, Encode};
use crc32fast::Hasher;
use dashmap::DashMap;
use parking_lot::Mutex;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::fs;
use std::hash::BuildHasherDefault;
use std::io::Write;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
/// Hard upper bound on `k` for any search or fusion call, guarding against
/// pathological requests allocating huge result buffers.
const MAX_K: usize = 100_000;
/// File magic for persisted sparse indexes: the ASCII bytes "ASPS".
const SPARSE_INDEX_MAGIC: [u8; 4] = [0x41, 0x53, 0x50, 0x53];
/// Current on-disk format version; `load` rejects files newer than this.
const SPARSE_INDEX_VERSION: u16 = 1;
/// Scoring function used to rank documents against a sparse query.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
#[cfg_attr(feature = "config-toml", derive(serde::Serialize, serde::Deserialize))]
pub enum ScoringMethod {
    /// Plain dot product of query and document values (the default).
    #[default]
    DotProduct,
    /// Dot product normalized by both vector magnitudes (cosine similarity).
    Cosine,
    /// Okapi BM25 with term-frequency saturation `k1` and length
    /// normalization `b`.
    BM25 {
        k1: f32,
        b: f32,
    },
}
impl ScoringMethod {
pub fn bm25_default() -> Self {
ScoringMethod::BM25 { k1: 1.5, b: 0.75 }
}
}
/// Configuration for a [`SparseVectorIndex`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "config-toml", derive(serde::Serialize, serde::Deserialize))]
pub struct SparseIndexConfig {
    /// Dimensionality every indexed vector and query must match.
    pub dimensions: usize,
    /// Ranking function used by searches.
    pub scoring: ScoringMethod,
    /// Initial capacity hint for the internal hash maps.
    pub initial_capacity: usize,
}
impl Default for SparseIndexConfig {
    /// Zero dimensions (must be overridden before building an index),
    /// dot-product scoring, and room for 1000 vectors.
    fn default() -> Self {
        Self {
            dimensions: 0,
            scoring: ScoringMethod::default(),
            initial_capacity: 1000,
        }
    }
}
impl SparseIndexConfig {
pub fn new(dimensions: usize) -> Self {
SparseIndexConfig {
dimensions,
..Default::default()
}
}
pub fn with_scoring(mut self, scoring: ScoringMethod) -> Self {
self.scoring = scoring;
self
}
pub fn with_capacity(mut self, capacity: usize) -> Self {
self.initial_capacity = capacity;
self
}
}
/// One entry in a dimension's posting list: a document id and that
/// document's value for the dimension.
#[derive(Debug, Clone)]
struct Posting {
    node_id: NodeId,
    value: f32,
}
/// Candidate result carried through the bounded top-k selection heap.
#[derive(Debug, Clone)]
struct ScoreEntry {
    node_id: NodeId,
    score: f32,
}
// Ordering is intentionally REVERSED on score so that a `BinaryHeap<ScoreEntry>`
// acts as a min-heap: `pop()` evicts the lowest-scoring entry, which is exactly
// what bounded top-k selection in `search_with_filter` needs.
impl PartialEq for ScoreEntry {
    fn eq(&self, other: &Self) -> bool {
        // Must agree with `Ord::cmp`: the contract requires `a == b` exactly
        // when `a.cmp(b) == Ordering::Equal`. The previous implementation also
        // compared `node_id`, making `eq` stricter than `cmp` and violating
        // that contract, so compare scores only.
        self.score.total_cmp(&other.score) == Ordering::Equal
    }
}
impl Eq for ScoreEntry {}
impl PartialOrd for ScoreEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for ScoreEntry {
    /// Reversed comparison on `score` using `total_cmp`, a total order over
    /// `f32` (NaN sorts consistently instead of poisoning the heap).
    fn cmp(&self, other: &Self) -> Ordering {
        other.score.total_cmp(&self.score)
    }
}
/// A vector as stored in the index, with its magnitude precomputed at insert
/// time so cosine scoring does not recompute it on every query.
#[derive(Debug, Clone)]
struct StoredVector {
    vector: Arc<SparseVec>,
    magnitude: f32,
}
/// Inverted index over sparse vectors supporting dot-product, cosine, and
/// BM25 scoring.
///
/// Searches read concurrently through the `DashMap`s and atomics; all
/// mutations (`add`/`remove`/`save`/`compact`) serialize on `write_lock` so
/// posting lists, the forward map, and the counters stay mutually consistent.
pub struct SparseVectorIndex {
    config: SparseIndexConfig,
    /// dimension -> postings for documents with a nonzero value there.
    inverted_index: DashMap<u32, Vec<Posting>>,
    /// Forward map: node id -> stored vector with cached magnitude.
    vectors: DashMap<NodeId, StoredVector>,
    /// Number of indexed vectors.
    count: AtomicUsize,
    /// Sum of nnz over all stored vectors (drives BM25 avg document length).
    total_length: AtomicUsize,
    /// dimension -> number of documents with a nonzero value there.
    doc_freq: DashMap<u32, usize>,
    /// Serializes mutations; searches do not take it.
    write_lock: Mutex<()>,
}
impl SparseVectorIndex {
pub fn new(config: SparseIndexConfig) -> Result<Self> {
if config.dimensions == 0 {
return Err(Error::Vector(VectorError::InvalidVector {
reason: "Dimensions must be greater than 0".to_string(),
}));
}
if config.dimensions > MAX_VECTOR_DIMENSIONS {
return Err(Error::Vector(VectorError::DimensionTooLarge {
dimension: config.dimensions,
max_allowed: MAX_VECTOR_DIMENSIONS,
}));
}
let capacity = config.initial_capacity;
Ok(SparseVectorIndex {
config,
inverted_index: DashMap::with_capacity(capacity),
vectors: DashMap::with_capacity(capacity),
count: AtomicUsize::new(0),
total_length: AtomicUsize::new(0),
doc_freq: DashMap::with_capacity(capacity),
write_lock: Mutex::new(()),
})
}
/// Inserts the vector under `id`, replacing any existing vector with the
/// same id (upsert semantics — re-adding an id leaves `len()` unchanged).
///
/// # Errors
/// `DimensionMismatch` when `vector.dimension()` differs from the config.
pub fn add(&self, id: NodeId, vector: &SparseVec) -> Result<()> {
    if vector.dimension() != self.config.dimensions {
        return Err(Error::Vector(VectorError::DimensionMismatch {
            expected: self.config.dimensions,
            actual: vector.dimension(),
        }));
    }
    // Serialize with other mutations so postings and counters stay in sync.
    let _guard = self.write_lock.lock();
    // Drop any previous postings for this id before re-indexing it.
    self.remove_internal_unlocked(id);
    // Cache the magnitude so cosine searches don't recompute it per query.
    let magnitude = vector.magnitude();
    let stored = StoredVector {
        vector: Arc::new(vector.clone()),
        magnitude,
    };
    self.vectors.insert(id, stored);
    // One posting per nonzero dimension, bumping document frequency.
    for (&dim, &val) in vector.indices().iter().zip(vector.values().iter()) {
        self.inverted_index.entry(dim).or_default().push(Posting {
            node_id: id,
            value: val,
        });
        *self.doc_freq.entry(dim).or_insert(0) += 1;
    }
    self.total_length
        .fetch_add(vector.nnz(), AtomicOrdering::Relaxed);
    self.count.fetch_add(1, AtomicOrdering::Release);
    Ok(())
}
/// Removes the vector stored under `id`.
///
/// Removing an id that is not present is a no-op; the call still succeeds.
pub fn remove(&self, id: NodeId) -> Result<()> {
    let _guard = self.write_lock.lock();
    let _was_present = self.remove_internal_unlocked(id);
    Ok(())
}
/// Removes `id` from the forward map, posting lists, and counters.
/// Returns `true` when the id was present.
///
/// Caller must hold `write_lock`. Posting vectors that become empty are
/// left in place (their storage is reclaimed later by `compact`).
fn remove_internal_unlocked(&self, id: NodeId) -> bool {
    if let Some((_, stored)) = self.vectors.remove(&id) {
        let vec = &stored.vector;
        for &dim in vec.indices() {
            // Filter this id out of the dimension's posting list.
            if let Some(mut postings) = self.inverted_index.get_mut(&dim) {
                postings.retain(|p| p.node_id != id);
            }
            // saturating_sub guards against a counter underflow if the
            // frequency map ever drifts out of sync.
            if let Some(mut freq) = self.doc_freq.get_mut(&dim) {
                *freq = freq.saturating_sub(1);
            }
        }
        self.total_length
            .fetch_sub(vec.nnz(), AtomicOrdering::Relaxed);
        self.count.fetch_sub(1, AtomicOrdering::Release);
        true
    } else {
        false
    }
}
/// Top-`k` search with no candidate filtering.
///
/// Equivalent to [`Self::search_with_filter`] with an always-true predicate.
#[must_use = "search results should be used"]
pub fn search(&self, query: &SparseVec, k: usize) -> Result<Vec<(NodeId, f32)>> {
    self.search_with_filter(query, k, |_candidate| true)
}
/// Top-`k` search over the inverted index, keeping only candidates for
/// which `predicate` returns `true`.
///
/// Scores according to `config.scoring` and returns `(id, score)` pairs
/// sorted by descending score. `k` is capped at `MAX_K`; `k == 0` or an
/// empty index yields an empty result.
///
/// # Errors
/// `DimensionMismatch` when the query dimensionality differs from the index.
#[must_use = "search results should be used"]
pub fn search_with_filter<F>(
    &self,
    query: &SparseVec,
    k: usize,
    predicate: F,
) -> Result<Vec<(NodeId, f32)>>
where
    F: Fn(&NodeId) -> bool + Send + Sync,
{
    if query.dimension() != self.config.dimensions {
        return Err(Error::Vector(VectorError::DimensionMismatch {
            expected: self.config.dimensions,
            actual: query.dimension(),
        }));
    }
    let k = k.min(MAX_K);
    if k == 0 || self.is_empty() {
        return Ok(Vec::new());
    }
    let is_cosine = matches!(self.config.scoring, ScoringMethod::Cosine);
    // Per-candidate score accumulator; NodeId keys use the identity hasher.
    let mut scores: HashMap<NodeId, f32, BuildHasherDefault<IdentityHasher>> =
        HashMap::default();
    // Cosine only: lazily cached document magnitudes.
    let mut magnitudes: HashMap<NodeId, f32, BuildHasherDefault<IdentityHasher>> =
        HashMap::default();
    // BM25 only: lazily cached document lengths (nnz).
    let mut doc_lengths: HashMap<NodeId, f32, BuildHasherDefault<IdentityHasher>> =
        HashMap::default();
    let query_magnitude = query.magnitude();
    let n = self.count.load(AtomicOrdering::Acquire) as f32;
    // Average document length for BM25 length normalization.
    let avgdl = if n > 0.0 {
        self.total_length.load(AtomicOrdering::Acquire) as f32 / n
    } else {
        1.0
    };
    // Accumulate partial scores one query dimension at a time.
    for (&dim, &query_val) in query.indices().iter().zip(query.values().iter()) {
        if let Some(postings) = self.inverted_index.get(&dim) {
            let df = self.doc_freq.get(&dim).map(|v| *v).unwrap_or(0) as f32;
            // BM25 idf, clamped at zero so very common terms never subtract.
            let idf = if df > 0.0 && n > 0.0 {
                ((n - df + 0.5) / (df + 0.5) + 1.0).ln().max(0.0)
            } else {
                0.0
            };
            for posting in postings.iter() {
                if !predicate(&posting.node_id) {
                    continue;
                }
                let score_delta = match self.config.scoring {
                    ScoringMethod::DotProduct => query_val * posting.value,
                    ScoringMethod::Cosine => {
                        // Remember the stored magnitude for the final
                        // normalization pass below.
                        if !magnitudes.contains_key(&posting.node_id)
                            && let Some(stored) = self.vectors.get(&posting.node_id)
                        {
                            magnitudes.insert(posting.node_id, stored.magnitude);
                        }
                        query_val * posting.value
                    }
                    ScoringMethod::BM25 { k1, b } => {
                        let dl = *doc_lengths.entry(posting.node_id).or_insert_with(|| {
                            self.vectors
                                .get(&posting.node_id)
                                .map(|v| v.vector.nnz() as f32)
                                .unwrap_or(1.0)
                        });
                        let tf = posting.value;
                        let numerator = tf * (k1 + 1.0);
                        let denominator = tf + k1 * (1.0 - b + b * dl / avgdl);
                        idf * numerator / denominator * query_val
                    }
                };
                *scores.entry(posting.node_id).or_insert(0.0) += score_delta;
            }
        }
    }
    // Cosine: divide each accumulated dot product by both magnitudes.
    if is_cosine && query_magnitude > 0.0 {
        for (&node_id, score) in scores.iter_mut() {
            if let Some(&mag) = magnitudes.get(&node_id)
                && mag > 0.0
            {
                *score /= query_magnitude * mag;
            }
        }
    }
    // Bounded top-k selection: `ScoreEntry`'s reversed Ord makes this a
    // min-heap, so popping past `k` evicts the current worst candidate.
    let mut heap: BinaryHeap<ScoreEntry> = BinaryHeap::with_capacity(k + 1);
    for (node_id, score) in scores {
        heap.push(ScoreEntry { node_id, score });
        if heap.len() > k {
            heap.pop();
        }
    }
    let mut results: Vec<(NodeId, f32)> =
        heap.into_iter().map(|e| (e.node_id, e.score)).collect();
    results.sort_by(|a, b| b.1.total_cmp(&a.1));
    Ok(results)
}
/// Number of vectors currently stored.
#[must_use]
pub fn len(&self) -> usize {
    self.count.load(AtomicOrdering::Acquire)
}
/// `true` when no vectors are stored.
#[must_use]
pub fn is_empty(&self) -> bool {
    self.len() == 0
}
/// Dimensionality every indexed vector must have.
#[must_use]
pub fn dimensions(&self) -> usize {
    self.config.dimensions
}
/// Scoring method used by searches.
#[must_use]
pub fn scoring(&self) -> ScoringMethod {
    self.config.scoring
}
/// Full index configuration.
#[must_use]
pub fn config(&self) -> &SparseIndexConfig {
    &self.config
}
/// Whether a vector is stored under `id`.
#[must_use]
pub fn contains(&self, id: NodeId) -> bool {
    self.vectors.contains_key(&id)
}
/// The stored vector for `id`, if any (cheap `Arc` refcount bump).
#[must_use]
pub fn get(&self, id: NodeId) -> Option<Arc<SparseVec>> {
    self.vectors.get(&id).map(|entry| Arc::clone(&entry.vector))
}
/// Rough estimate of heap memory used by the postings and stored vectors,
/// in bytes. The per-item sizes are approximations, not measured values.
#[must_use]
pub fn memory_usage(&self) -> usize {
    // Approximate byte costs per posting, per stored vector, and per
    // sparse (index, value) element.
    const POSTING_SIZE: usize = 16;
    const VECTOR_OVERHEAD: usize = 48;
    const ELEMENT_SIZE: usize = 8;
    let postings_bytes: usize = self
        .inverted_index
        .iter()
        .map(|entry| entry.value().len() * POSTING_SIZE)
        .sum();
    let vectors_bytes: usize = self
        .vectors
        .iter()
        .map(|entry| VECTOR_OVERHEAD + entry.value().vector.nnz() * ELEMENT_SIZE)
        .sum();
    postings_bytes + vectors_bytes
}
/// Atomically persists the index to `path`.
///
/// On-disk layout: 4-byte magic ("ASPS") | 2-byte little-endian version |
/// bitcode-encoded `SparseIndexData` | 4-byte little-endian CRC32 computed
/// over everything before it. The data is written to
/// `path.with_extension("tmp")`, fsynced, then renamed over `path` so
/// readers never observe a partially written file.
///
/// # Errors
/// `IndexError` when the temp file cannot be created, written, synced, or
/// renamed.
pub fn save(&self, path: &Path) -> Result<()> {
    // Block mutations so the snapshot is internally consistent.
    let _guard = self.write_lock.lock();
    let mut vectors = Vec::with_capacity(self.len());
    for entry in self.vectors.iter() {
        let node_id = entry.key();
        let stored = entry.value();
        vectors.push(PersistedSparseVector {
            node_id: node_id.as_u64(),
            indices: stored.vector.indices().to_vec(),
            values: stored.vector.values().to_vec(),
        });
    }
    let doc_freq: Vec<(u32, u64)> = self
        .doc_freq
        .iter()
        .map(|entry| (*entry.key(), *entry.value() as u64))
        .collect();
    let data = SparseIndexData {
        dimensions: self.config.dimensions as u32,
        scoring: self.config.scoring.into(),
        count: self.count.load(AtomicOrdering::Acquire) as u64,
        total_length: self.total_length.load(AtomicOrdering::Acquire) as u64,
        vectors,
        doc_freq,
    };
    let encoded = bitcode::encode(&data);
    let mut file_data = Vec::with_capacity(4 + 2 + encoded.len() + 4);
    file_data.extend_from_slice(&SPARSE_INDEX_MAGIC);
    file_data.extend_from_slice(&SPARSE_INDEX_VERSION.to_le_bytes());
    file_data.extend_from_slice(&encoded);
    // Checksum covers magic + version + payload.
    let mut hasher = Hasher::new();
    hasher.update(&file_data);
    let crc = hasher.finalize();
    file_data.extend_from_slice(&crc.to_le_bytes());
    let temp_path = path.with_extension("tmp");
    let mut file = fs::File::create(&temp_path).map_err(|e| {
        Error::Vector(VectorError::IndexError(format!(
            "Failed to create temp file: {}",
            e
        )))
    })?;
    file.write_all(&file_data).map_err(|e| {
        Error::Vector(VectorError::IndexError(format!(
            "Failed to write sparse index: {}",
            e
        )))
    })?;
    // Ensure bytes are durable before the rename makes them visible.
    file.sync_all().map_err(|e| {
        Error::Vector(VectorError::IndexError(format!(
            "Failed to sync sparse index: {}",
            e
        )))
    })?;
    drop(file);
    // Atomic replace on POSIX filesystems.
    fs::rename(&temp_path, path).map_err(|e| {
        Error::Vector(VectorError::IndexError(format!(
            "Failed to rename temp file: {}",
            e
        )))
    })?;
    Ok(())
}
/// Loads an index previously written by [`Self::save`].
///
/// Validates the magic bytes, the version (files newer than
/// `SPARSE_INDEX_VERSION` are rejected), and the trailing CRC32 before
/// decoding the payload. `config.dimensions` must match the persisted
/// dimensionality; posting lists and magnitudes are rebuilt from the
/// persisted vectors rather than stored on disk.
///
/// # Errors
/// `IndexError` for unreadable, truncated, corrupt, or unsupported files;
/// `DimensionMismatch` when `config.dimensions` disagrees with the file.
pub fn load(path: &Path, config: SparseIndexConfig) -> Result<Self> {
    let file_data = fs::read(path).map_err(|e| {
        Error::Vector(VectorError::IndexError(format!(
            "Failed to read sparse index file: {}",
            e
        )))
    })?;
    // Minimum size: 4-byte magic + 2-byte version + 4-byte trailing CRC.
    if file_data.len() < 10 {
        return Err(Error::Vector(VectorError::IndexError(
            "Sparse index file too small to be valid".to_string(),
        )));
    }
    let magic: [u8; 4] = file_data[0..4].try_into().map_err(|_| {
        Error::Vector(VectorError::IndexError(
            "Failed to read magic bytes".to_string(),
        ))
    })?;
    if magic != SPARSE_INDEX_MAGIC {
        return Err(Error::Vector(VectorError::IndexError(format!(
            "Invalid magic bytes: expected {:?}, got {:?}",
            SPARSE_INDEX_MAGIC, magic
        ))));
    }
    let version = u16::from_le_bytes(file_data[4..6].try_into().map_err(|_| {
        Error::Vector(VectorError::IndexError(
            "Failed to read version".to_string(),
        ))
    })?);
    // Older on-disk versions are accepted; newer-than-us are rejected.
    if version > SPARSE_INDEX_VERSION {
        return Err(Error::Vector(VectorError::IndexError(format!(
            "Unsupported sparse index version: {} (max supported: {})",
            version, SPARSE_INDEX_VERSION
        ))));
    }
    // Verify the trailing CRC32 over everything that precedes it.
    let crc_offset = file_data.len() - 4;
    let stored_crc = u32::from_le_bytes(file_data[crc_offset..].try_into().map_err(|_| {
        Error::Vector(VectorError::IndexError("Failed to read CRC32".to_string()))
    })?);
    let mut hasher = Hasher::new();
    hasher.update(&file_data[..crc_offset]);
    let computed_crc = hasher.finalize();
    if stored_crc != computed_crc {
        return Err(Error::Vector(VectorError::IndexError(format!(
            "CRC32 mismatch: stored={:#x}, computed={:#x}",
            stored_crc, computed_crc
        ))));
    }
    let encoded_data = &file_data[6..crc_offset];
    let data: SparseIndexData = bitcode::decode(encoded_data).map_err(|e| {
        Error::Vector(VectorError::IndexError(format!(
            "Failed to decode sparse index: {}",
            e
        )))
    })?;
    if data.dimensions as usize != config.dimensions {
        return Err(Error::Vector(VectorError::DimensionMismatch {
            expected: config.dimensions,
            actual: data.dimensions as usize,
        }));
    }
    // NOTE(review): the scoring method comes from the file, so a different
    // `config.scoring` is silently ignored here — confirm this is intended.
    let loaded_config = SparseIndexConfig {
        dimensions: data.dimensions as usize,
        scoring: data.scoring.into(),
        initial_capacity: data.count as usize,
    };
    let index = SparseVectorIndex {
        config: loaded_config,
        inverted_index: DashMap::with_capacity(data.count as usize),
        vectors: DashMap::with_capacity(data.count as usize),
        count: AtomicUsize::new(data.count as usize),
        total_length: AtomicUsize::new(data.total_length as usize),
        doc_freq: DashMap::with_capacity(data.doc_freq.len()),
        write_lock: Mutex::new(()),
    };
    for (dim, freq) in data.doc_freq {
        index.doc_freq.insert(dim, freq as usize);
    }
    // Rebuild the forward map and posting lists from the persisted vectors.
    for persisted in data.vectors {
        let node_id = NodeId::new(persisted.node_id).map_err(|_| {
            Error::Vector(VectorError::IndexError(format!(
                "Invalid node ID: {}",
                persisted.node_id
            )))
        })?;
        let vector = SparseVec::new(persisted.indices, persisted.values, data.dimensions)?;
        // Magnitude is recomputed on load rather than persisted.
        let magnitude = vector.magnitude();
        let stored = StoredVector {
            vector: Arc::new(vector),
            magnitude,
        };
        index.vectors.insert(node_id, stored.clone());
        for (&dim, &val) in stored
            .vector
            .indices()
            .iter()
            .zip(stored.vector.values().iter())
        {
            index.inverted_index.entry(dim).or_default().push(Posting {
                node_id,
                value: val,
            });
        }
    }
    Ok(index)
}
/// Reclaims storage left behind by removals: drops dimensions whose posting
/// lists are empty, shrinks the remaining lists, and prunes zero-frequency
/// entries from the document-frequency map.
pub fn compact(&self) {
    let _guard = self.write_lock.lock();
    self.inverted_index
        .retain(|_, postings| !postings.is_empty());
    self.inverted_index
        .iter_mut()
        .for_each(|mut entry| entry.value_mut().shrink_to_fit());
    self.doc_freq.retain(|_, &mut freq| freq > 0);
}
/// Aggregated statistics about posting-list shape and estimated memory use.
#[must_use]
pub fn stats(&self) -> SparseIndexStats {
    // Single pass over the posting lists, accumulating
    // (total postings, non-empty dimensions, longest list).
    let (total_postings, non_empty_dimensions, max_posting_length) = self
        .inverted_index
        .iter()
        .fold((0usize, 0usize, 0usize), |(total, dims, longest), entry| {
            let len = entry.value().len();
            if len > 0 {
                (total + len, dims + 1, longest.max(len))
            } else {
                (total, dims, longest)
            }
        });
    let avg_posting_length = if non_empty_dimensions > 0 {
        total_postings as f32 / non_empty_dimensions as f32
    } else {
        0.0
    };
    let avg_vector_nnz = if self.is_empty() {
        0.0
    } else {
        self.total_length.load(AtomicOrdering::Acquire) as f32 / self.len() as f32
    };
    SparseIndexStats {
        num_vectors: self.len(),
        dimensions: self.config.dimensions,
        non_empty_dimensions,
        total_postings,
        avg_posting_length,
        max_posting_length,
        avg_vector_nnz,
        memory_usage: self.memory_usage(),
    }
}
}
/// Snapshot of index shape produced by [`SparseVectorIndex::stats`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "config-toml", derive(serde::Serialize, serde::Deserialize))]
pub struct SparseIndexStats {
    /// Number of stored vectors.
    pub num_vectors: usize,
    /// Configured dimensionality.
    pub dimensions: usize,
    /// Dimensions with at least one posting.
    pub non_empty_dimensions: usize,
    /// Sum of posting-list lengths across all dimensions.
    pub total_postings: usize,
    /// Mean posting-list length over non-empty dimensions.
    pub avg_posting_length: f32,
    /// Longest posting list.
    pub max_posting_length: usize,
    /// Mean number of nonzero elements per stored vector.
    pub avg_vector_nnz: f32,
    /// Estimated heap usage in bytes (see `memory_usage`).
    pub memory_usage: usize,
}
/// On-disk mirror of [`ScoringMethod`], kept separate so the public enum can
/// evolve without changing the persisted encoding.
#[derive(Debug, Clone, Encode, Decode)]
enum PersistedScoringMethod {
    DotProduct,
    Cosine,
    BM25 { k1: f32, b: f32 },
}
/// In-memory -> persisted scoring representation (used by `save`).
impl From<ScoringMethod> for PersistedScoringMethod {
    fn from(method: ScoringMethod) -> Self {
        match method {
            ScoringMethod::DotProduct => PersistedScoringMethod::DotProduct,
            ScoringMethod::Cosine => PersistedScoringMethod::Cosine,
            ScoringMethod::BM25 { k1, b } => PersistedScoringMethod::BM25 { k1, b },
        }
    }
}
/// Persisted -> in-memory scoring representation (used by `load`).
impl From<PersistedScoringMethod> for ScoringMethod {
    fn from(method: PersistedScoringMethod) -> Self {
        match method {
            PersistedScoringMethod::DotProduct => ScoringMethod::DotProduct,
            PersistedScoringMethod::Cosine => ScoringMethod::Cosine,
            PersistedScoringMethod::BM25 { k1, b } => ScoringMethod::BM25 { k1, b },
        }
    }
}
/// One stored vector in the persisted file: node id plus parallel
/// index/value arrays (magnitude is recomputed on load).
#[derive(Debug, Clone, Encode, Decode)]
struct PersistedSparseVector {
    node_id: u64,
    indices: Vec<u32>,
    values: Vec<f32>,
}
/// Root of the persisted payload (bitcode-encoded between the file header
/// and the trailing CRC32). Posting lists are not persisted; `load`
/// rebuilds them from `vectors`.
#[derive(Debug, Clone, Encode, Decode)]
struct SparseIndexData {
    dimensions: u32,
    scoring: PersistedScoringMethod,
    count: u64,
    total_length: u64,
    vectors: Vec<PersistedSparseVector>,
    doc_freq: Vec<(u32, u64)>,
}
/// Linearly blends dense and sparse result lists into a single ranking.
///
/// Both lists are min-max normalized to `[0, 1]`; dense scores are weighted
/// by `alpha` (clamped to `[0, 1]`) and sparse scores by `1 - alpha`.
/// Returns at most `k` (capped at `MAX_K`) pairs sorted by combined score,
/// descending.
pub fn hybrid_fusion(
    dense_results: &[(NodeId, f32)],
    sparse_results: &[(NodeId, f32)],
    alpha: f32,
    k: usize,
) -> Vec<(NodeId, f32)> {
    let alpha = alpha.clamp(0.0, 1.0);
    let k = k.min(MAX_K);
    if dense_results.is_empty() && sparse_results.is_empty() {
        return Vec::new();
    }
    let mut combined: HashMap<NodeId, f32, BuildHasherDefault<IdentityHasher>> = HashMap::default();
    // Weight each normalized list, then merge by summing per node id
    // (dense contributions first, then sparse — same order as before).
    let weighted = normalize_scores(dense_results)
        .into_iter()
        .map(|(id, score)| (id, alpha * score))
        .chain(
            normalize_scores(sparse_results)
                .into_iter()
                .map(|(id, score)| (id, (1.0 - alpha) * score)),
        );
    for (id, contribution) in weighted {
        *combined.entry(id).or_insert(0.0) += contribution;
    }
    let mut results: Vec<(NodeId, f32)> = combined.into_iter().collect();
    results.sort_by(|a, b| b.1.total_cmp(&a.1));
    results.truncate(k);
    results
}
/// Combines two ranked lists with Reciprocal Rank Fusion.
///
/// Each occurrence of an id contributes `1 / (k_constant + rank + 1)` with
/// rank starting at 0; the raw scores are ignored, only positions matter.
/// `k_constant` is floored at 1.0 and `k` capped at `MAX_K`.
pub fn reciprocal_rank_fusion(
    dense_results: &[(NodeId, f32)],
    sparse_results: &[(NodeId, f32)],
    k_constant: f32,
    k: usize,
) -> Vec<(NodeId, f32)> {
    let k = k.min(MAX_K);
    let k_constant = k_constant.max(1.0);
    let mut rrf_scores: HashMap<NodeId, f32, BuildHasherDefault<IdentityHasher>> =
        HashMap::default();
    // Both lists share the same positional weighting.
    let mut accumulate = |ranked: &[(NodeId, f32)]| {
        for (rank, (id, _)) in ranked.iter().enumerate() {
            *rrf_scores.entry(*id).or_insert(0.0) += 1.0 / (k_constant + rank as f32 + 1.0);
        }
    };
    accumulate(dense_results);
    accumulate(sparse_results);
    let mut results: Vec<(NodeId, f32)> = rrf_scores.into_iter().collect();
    results.sort_by(|a, b| b.1.total_cmp(&a.1));
    results.truncate(k);
    results
}
/// Min-max normalizes scores into `[0, 1]`.
///
/// When every score is identical (range zero) each entry maps to 1.0, so a
/// single-element list still contributes full weight to fusion.
fn normalize_scores(results: &[(NodeId, f32)]) -> Vec<(NodeId, f32)> {
    if results.is_empty() {
        return Vec::new();
    }
    // One pass collects both extremes.
    let (min_score, max_score) = results
        .iter()
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &(_, s)| {
            (lo.min(s), hi.max(s))
        });
    let range = max_score - min_score;
    if range == 0.0 {
        results.iter().map(|&(id, _)| (id, 1.0)).collect()
    } else {
        results
            .iter()
            .map(|&(id, score)| (id, (score - min_score) / range))
            .collect()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashSet;
#[test]
fn test_create_sparse_index() {
let config = SparseIndexConfig::new(10_000);
let index = SparseVectorIndex::new(config).unwrap();
assert_eq!(index.dimensions(), 10_000);
assert_eq!(index.len(), 0);
assert!(index.is_empty());
}
#[test]
fn test_create_sparse_index_zero_dimensions_fails() {
let config = SparseIndexConfig::new(0);
let result = SparseVectorIndex::new(config);
assert!(result.is_err());
}
#[test]
fn test_add_and_retrieve_vector() {
let index = SparseVectorIndex::new(SparseIndexConfig::new(100)).unwrap();
let node_id = NodeId::new(1).unwrap();
let vector = SparseVec::new(vec![10, 50, 90], vec![1.0, 2.0, 3.0], 100).unwrap();
index.add(node_id, &vector).unwrap();
assert_eq!(index.len(), 1);
assert!(!index.is_empty());
assert!(index.contains(node_id));
let retrieved = index.get(node_id).unwrap();
assert_eq!(retrieved.dimension(), 100);
assert_eq!(retrieved.nnz(), 3);
}
#[test]
fn test_add_dimension_mismatch() {
let index = SparseVectorIndex::new(SparseIndexConfig::new(100)).unwrap();
let node_id = NodeId::new(1).unwrap();
let vector = SparseVec::new(vec![10], vec![1.0], 200).unwrap();
let result = index.add(node_id, &vector);
assert!(matches!(
result,
Err(Error::Vector(VectorError::DimensionMismatch { .. }))
));
}
#[test]
fn test_remove_vector() {
let index = SparseVectorIndex::new(SparseIndexConfig::new(100)).unwrap();
let node_id = NodeId::new(1).unwrap();
let vector = SparseVec::new(vec![10, 50], vec![1.0, 2.0], 100).unwrap();
index.add(node_id, &vector).unwrap();
assert_eq!(index.len(), 1);
index.remove(node_id).unwrap();
assert_eq!(index.len(), 0);
assert!(!index.contains(node_id));
}
#[test]
fn test_remove_nonexistent_vector() {
let index = SparseVectorIndex::new(SparseIndexConfig::new(100)).unwrap();
let node_id = NodeId::new(999).unwrap();
index.remove(node_id).unwrap();
assert_eq!(index.len(), 0);
}
#[test]
fn test_update_existing_vector() {
let index = SparseVectorIndex::new(SparseIndexConfig::new(100)).unwrap();
let node_id = NodeId::new(1).unwrap();
let vector1 = SparseVec::new(vec![10], vec![1.0], 100).unwrap();
index.add(node_id, &vector1).unwrap();
assert_eq!(index.len(), 1);
let vector2 = SparseVec::new(vec![20, 30], vec![2.0, 3.0], 100).unwrap();
index.add(node_id, &vector2).unwrap();
assert_eq!(index.len(), 1);
let retrieved = index.get(node_id).unwrap();
assert_eq!(retrieved.nnz(), 2);
assert_eq!(retrieved.indices(), &[20, 30]);
}
#[test]
fn test_search_dot_product_basic() {
let config = SparseIndexConfig::new(100).with_scoring(ScoringMethod::DotProduct);
let index = SparseVectorIndex::new(config).unwrap();
let doc1 = SparseVec::new(vec![0, 5, 10], vec![1.0, 2.0, 3.0], 100).unwrap();
let doc2 = SparseVec::new(vec![5, 10, 15], vec![1.0, 1.0, 1.0], 100).unwrap();
index.add(NodeId::new(1).unwrap(), &doc1).unwrap();
index.add(NodeId::new(2).unwrap(), &doc2).unwrap();
let query = SparseVec::new(vec![5, 10], vec![1.0, 1.0], 100).unwrap();
let results = index.search(&query, 10).unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].0, NodeId::new(1).unwrap());
assert!((results[0].1 - 5.0).abs() < 1e-6);
assert_eq!(results[1].0, NodeId::new(2).unwrap());
assert!((results[1].1 - 2.0).abs() < 1e-6);
}
#[test]
fn test_search_no_overlap() {
let config = SparseIndexConfig::new(100);
let index = SparseVectorIndex::new(config).unwrap();
let doc = SparseVec::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0], 100).unwrap();
index.add(NodeId::new(1).unwrap(), &doc).unwrap();
let query = SparseVec::new(vec![50, 60], vec![1.0, 1.0], 100).unwrap();
let results = index.search(&query, 10).unwrap();
assert!(results.is_empty());
}
#[test]
fn test_search_empty_index() {
let index = SparseVectorIndex::new(SparseIndexConfig::new(100)).unwrap();
let query = SparseVec::new(vec![0], vec![1.0], 100).unwrap();
let results = index.search(&query, 10).unwrap();
assert!(results.is_empty());
}
#[test]
fn test_search_k_zero() {
let index = SparseVectorIndex::new(SparseIndexConfig::new(100)).unwrap();
let doc = SparseVec::new(vec![0], vec![1.0], 100).unwrap();
index.add(NodeId::new(1).unwrap(), &doc).unwrap();
let query = SparseVec::new(vec![0], vec![1.0], 100).unwrap();
let results = index.search(&query, 0).unwrap();
assert!(results.is_empty());
}
#[test]
fn test_search_top_k() {
let index = SparseVectorIndex::new(SparseIndexConfig::new(100)).unwrap();
for i in 1..=10 {
let doc = SparseVec::new(vec![0], vec![i as f32], 100).unwrap();
index.add(NodeId::new(i).unwrap(), &doc).unwrap();
}
let query = SparseVec::new(vec![0], vec![1.0], 100).unwrap();
let results = index.search(&query, 3).unwrap();
assert_eq!(results.len(), 3);
assert_eq!(results[0].0, NodeId::new(10).unwrap());
assert_eq!(results[1].0, NodeId::new(9).unwrap());
assert_eq!(results[2].0, NodeId::new(8).unwrap());
}
#[test]
fn test_search_cosine_similarity() {
let config = SparseIndexConfig::new(100).with_scoring(ScoringMethod::Cosine);
let index = SparseVectorIndex::new(config).unwrap();
let doc1 = SparseVec::new(vec![0, 1], vec![1.0, 1.0], 100).unwrap();
let doc2 = SparseVec::new(vec![0, 1], vec![10.0, 10.0], 100).unwrap();
index.add(NodeId::new(1).unwrap(), &doc1).unwrap();
index.add(NodeId::new(2).unwrap(), &doc2).unwrap();
let query = SparseVec::new(vec![0, 1], vec![1.0, 1.0], 100).unwrap();
let results = index.search(&query, 10).unwrap();
assert_eq!(results.len(), 2);
assert!((results[0].1 - 1.0).abs() < 1e-5);
assert!((results[1].1 - 1.0).abs() < 1e-5);
}
#[test]
fn test_search_cosine_orthogonal() {
let config = SparseIndexConfig::new(100).with_scoring(ScoringMethod::Cosine);
let index = SparseVectorIndex::new(config).unwrap();
let doc = SparseVec::new(vec![0, 1], vec![1.0, 1.0], 100).unwrap();
index.add(NodeId::new(1).unwrap(), &doc).unwrap();
let query = SparseVec::new(vec![50, 51], vec![1.0, 1.0], 100).unwrap();
let results = index.search(&query, 10).unwrap();
assert!(results.is_empty());
}
#[test]
fn test_search_bm25() {
    let config = SparseIndexConfig::new(100).with_scoring(ScoringMethod::bm25_default());
    let index = SparseVectorIndex::new(config).unwrap();
    // doc1 has a higher term frequency for dimension 0 than doc2.
    let doc1 = SparseVec::new(vec![0, 1, 2], vec![3.0, 1.0, 1.0], 100).unwrap();
    let doc2 = SparseVec::new(vec![0, 3, 4], vec![1.0, 1.0, 1.0], 100).unwrap();
    index.add(NodeId::new(1).unwrap(), &doc1).unwrap();
    index.add(NodeId::new(2).unwrap(), &doc2).unwrap();
    let query = SparseVec::new(vec![0], vec![1.0], 100).unwrap();
    let results = index.search(&query, 10).unwrap();
    assert_eq!(results.len(), 2);
    // The higher-tf document must rank strictly first under BM25.
    assert!(results[0].1 > results[1].1);
}
#[test]
fn test_search_with_filter() {
let index = SparseVectorIndex::new(SparseIndexConfig::new(100)).unwrap();
for i in 1..=10 {
let doc = SparseVec::new(vec![0], vec![i as f32], 100).unwrap();
index.add(NodeId::new(i).unwrap(), &doc).unwrap();
}
let query = SparseVec::new(vec![0], vec![1.0], 100).unwrap();
let allowed: HashSet<NodeId> = (1..=10)
.filter(|i| i % 2 == 0)
.map(|i| NodeId::new(i).unwrap())
.collect();
let results = index
.search_with_filter(&query, 10, |id| allowed.contains(id))
.unwrap();
assert_eq!(results.len(), 5);
for (id, _) in &results {
assert!(allowed.contains(id));
}
}
#[test]
fn test_index_stats() {
    let index = SparseVectorIndex::new(SparseIndexConfig::new(1000)).unwrap();
    let doc1 = SparseVec::new(vec![0, 100, 500], vec![1.0, 2.0, 3.0], 1000).unwrap();
    let doc2 = SparseVec::new(vec![0, 200], vec![1.0, 1.0], 1000).unwrap();
    let doc3 = SparseVec::new(vec![0, 100, 200, 300], vec![1.0, 1.0, 1.0, 1.0], 1000).unwrap();
    index.add(NodeId::new(1).unwrap(), &doc1).unwrap();
    index.add(NodeId::new(2).unwrap(), &doc2).unwrap();
    index.add(NodeId::new(3).unwrap(), &doc3).unwrap();
    let stats = index.stats();
    assert_eq!(stats.num_vectors, 3);
    assert_eq!(stats.dimensions, 1000);
    assert!(stats.non_empty_dimensions > 0);
    // 3 + 2 + 4 nonzero elements across the three documents.
    assert_eq!(stats.total_postings, 9);
    assert!(stats.avg_vector_nnz > 0.0);
}
#[test]
fn test_hybrid_fusion_basic() {
let dense = vec![
(NodeId::new(1).unwrap(), 0.9),
(NodeId::new(2).unwrap(), 0.85),
(NodeId::new(4).unwrap(), 0.7),
];
let sparse = vec![
(NodeId::new(2).unwrap(), 10.0),
(NodeId::new(3).unwrap(), 8.0),
(NodeId::new(4).unwrap(), 6.0),
];
let fused = hybrid_fusion(&dense, &sparse, 0.5, 10);
assert_eq!(fused.len(), 4);
assert_eq!(fused[0].0, NodeId::new(2).unwrap());
}
#[test]
fn test_hybrid_fusion_dense_only() {
let dense = vec![
(NodeId::new(1).unwrap(), 0.9),
(NodeId::new(2).unwrap(), 0.8),
];
let sparse: Vec<(NodeId, f32)> = vec![];
let fused = hybrid_fusion(&dense, &sparse, 0.5, 10);
assert_eq!(fused.len(), 2);
assert_eq!(fused[0].0, NodeId::new(1).unwrap());
}
#[test]
fn test_hybrid_fusion_sparse_only() {
let dense: Vec<(NodeId, f32)> = vec![];
let sparse = vec![
(NodeId::new(1).unwrap(), 10.0),
(NodeId::new(2).unwrap(), 8.0),
];
let fused = hybrid_fusion(&dense, &sparse, 0.5, 10);
assert_eq!(fused.len(), 2);
assert_eq!(fused[0].0, NodeId::new(1).unwrap());
}
#[test]
fn test_hybrid_fusion_alpha_extremes() {
let dense = vec![(NodeId::new(1).unwrap(), 0.9)];
let sparse = vec![(NodeId::new(2).unwrap(), 10.0)];
let fused = hybrid_fusion(&dense, &sparse, 1.0, 10);
assert_eq!(fused[0].0, NodeId::new(1).unwrap());
assert!(fused[0].1 > fused[1].1);
let fused = hybrid_fusion(&dense, &sparse, 0.0, 10);
assert_eq!(fused[0].0, NodeId::new(2).unwrap());
}
#[test]
fn test_reciprocal_rank_fusion() {
let dense = vec![
(NodeId::new(1).unwrap(), 0.9),
(NodeId::new(2).unwrap(), 0.8),
(NodeId::new(3).unwrap(), 0.7),
];
let sparse = vec![
(NodeId::new(2).unwrap(), 10.0),
(NodeId::new(4).unwrap(), 8.0),
(NodeId::new(1).unwrap(), 6.0),
];
let fused = reciprocal_rank_fusion(&dense, &sparse, 60.0, 10);
assert!(fused.len() <= 4);
}
#[test]
fn test_concurrent_adds() {
use std::thread;
let index = Arc::new(SparseVectorIndex::new(SparseIndexConfig::new(100)).unwrap());
let mut handles = vec![];
for i in 0..10 {
let index_clone = Arc::clone(&index);
let handle = thread::spawn(move || {
let doc = SparseVec::new(vec![i as u32], vec![1.0], 100).unwrap();
index_clone.add(NodeId::new(i + 1).unwrap(), &doc).unwrap();
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
assert_eq!(index.len(), 10);
}
#[test]
fn test_concurrent_search() {
use std::thread;
let index = Arc::new(SparseVectorIndex::new(SparseIndexConfig::new(200)).unwrap());
for i in 1..=100 {
let doc = SparseVec::new(vec![0, i as u32], vec![1.0, i as f32], 200).unwrap();
index.add(NodeId::new(i).unwrap(), &doc).unwrap();
}
let mut handles = vec![];
for _ in 0..10 {
let index_clone = Arc::clone(&index);
let handle = thread::spawn(move || {
let query = SparseVec::new(vec![0], vec![1.0], 200).unwrap();
let results = index_clone.search(&query, 10).unwrap();
assert_eq!(results.len(), 10);
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_empty_sparse_vector() {
let index = SparseVectorIndex::new(SparseIndexConfig::new(100)).unwrap();
let empty = SparseVec::new(vec![], vec![], 100).unwrap();
index.add(NodeId::new(1).unwrap(), &empty).unwrap();
assert_eq!(index.len(), 1);
let query = SparseVec::new(vec![], vec![], 100).unwrap();
let results = index.search(&query, 10).unwrap();
assert!(results.is_empty());
}
#[test]
fn test_single_dimension_vectors() {
let index = SparseVectorIndex::new(SparseIndexConfig::new(10_000)).unwrap();
for i in 1..=100 {
let doc = SparseVec::new(vec![0], vec![i as f32], 10_000).unwrap();
index.add(NodeId::new(i).unwrap(), &doc).unwrap();
}
let query = SparseVec::new(vec![0], vec![1.0], 10_000).unwrap();
let results = index.search(&query, 5).unwrap();
assert_eq!(results.len(), 5);
assert_eq!(results[0].0, NodeId::new(100).unwrap());
}
#[test]
fn test_very_sparse_high_dimensional() {
let dim = 100_000;
let index = SparseVectorIndex::new(SparseIndexConfig::new(dim)).unwrap();
let doc = SparseVec::new(vec![0, 50_000, 99_999], vec![1.0, 2.0, 3.0], dim as u32).unwrap();
index.add(NodeId::new(1).unwrap(), &doc).unwrap();
let query = SparseVec::new(vec![50_000], vec![1.0], dim as u32).unwrap();
let results = index.search(&query, 10).unwrap();
assert_eq!(results.len(), 1);
assert!((results[0].1 - 2.0).abs() < 1e-6);
}
#[test]
fn test_compact() {
let index = SparseVectorIndex::new(SparseIndexConfig::new(100)).unwrap();
for i in 1..=10 {
let doc = SparseVec::new(vec![i as u32], vec![1.0], 100).unwrap();
index.add(NodeId::new(i).unwrap(), &doc).unwrap();
}
for i in 1..=10 {
index.remove(NodeId::new(i).unwrap()).unwrap();
}
assert_eq!(index.len(), 0);
index.compact();
let stats = index.stats();
assert_eq!(stats.total_postings, 0);
}
#[test]
fn test_memory_usage() {
let index = SparseVectorIndex::new(SparseIndexConfig::new(1000)).unwrap();
let initial_mem = index.memory_usage();
for i in 1..=100 {
let doc = SparseVec::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0], 1000).unwrap();
index.add(NodeId::new(i).unwrap(), &doc).unwrap();
}
let final_mem = index.memory_usage();
assert!(final_mem > initial_mem);
}
#[test]
fn test_config_builder() {
let config = SparseIndexConfig::new(1000)
.with_scoring(ScoringMethod::Cosine)
.with_capacity(5000);
assert_eq!(config.dimensions, 1000);
assert_eq!(config.scoring, ScoringMethod::Cosine);
assert_eq!(config.initial_capacity, 5000);
}
#[test]
fn test_bm25_custom_params() {
let scoring = ScoringMethod::BM25 { k1: 2.0, b: 0.5 };
let config = SparseIndexConfig::new(100).with_scoring(scoring);
if let ScoringMethod::BM25 { k1, b } = config.scoring {
assert_eq!(k1, 2.0);
assert_eq!(b, 0.5);
} else {
panic!("Expected BM25 scoring");
}
}
#[test]
fn test_max_dimensions_boundary() {
let result = SparseVectorIndex::new(SparseIndexConfig::new(MAX_VECTOR_DIMENSIONS));
assert!(result.is_ok());
let result = SparseVectorIndex::new(SparseIndexConfig::new(MAX_VECTOR_DIMENSIONS + 1));
assert!(result.is_err());
match result {
Err(Error::Vector(VectorError::DimensionTooLarge {
dimension,
max_allowed,
})) => {
assert_eq!(dimension, MAX_VECTOR_DIMENSIONS + 1);
assert_eq!(max_allowed, MAX_VECTOR_DIMENSIONS);
}
_ => panic!("Expected DimensionTooLarge error"),
}
}
#[test]
fn test_max_k_capping() {
    // Requesting k at the MAX_K boundary must not panic, and the result set
    // can never exceed the number of indexed documents.
    // NOTE(review): this passes k == MAX_K, i.e. the boundary itself; whether
    // values above MAX_K are capped or rejected is not asserted here.
    let index = SparseVectorIndex::new(SparseIndexConfig::new(100)).unwrap();
    for i in 1..=100 {
        let doc = SparseVec::new(vec![0], vec![i as f32], 100).unwrap();
        index.add(NodeId::new(i).unwrap(), &doc).unwrap();
    }
    let query = SparseVec::new(vec![0], vec![1.0], 100).unwrap();
    // Use the named constant instead of its literal value (100_000) so the
    // test keeps tracking MAX_K if the constant ever changes.
    let results = index.search(&query, MAX_K).unwrap();
    assert!(results.len() <= 100);
}
#[test]
fn test_nan_values_in_search_results() {
    // Verifies that results come back in descending score order; the NaN
    // clause keeps the check tolerant should NaN scores ever appear.
    // NOTE(review): no NaN values are actually inserted by this test.
    let index = SparseVectorIndex::new(SparseIndexConfig::new(100)).unwrap();
    for i in 1..=5 {
        let doc = SparseVec::new(vec![0], vec![i as f32], 100).unwrap();
        index.add(NodeId::new(i).unwrap(), &doc).unwrap();
    }
    let query = SparseVec::new(vec![0], vec![1.0], 100).unwrap();
    let results = index.search(&query, 10).unwrap();
    assert!(!results.is_empty());
    // Compare each adjacent pair rather than indexing by offset.
    for pair in results.windows(2) {
        assert!(
            pair[0].1 >= pair[1].1 || pair[1].1.is_nan(),
            "Results should be sorted by score descending"
        );
    }
}
#[test]
fn test_concurrent_add_remove_same_node() {
    use std::sync::Arc;
    use std::thread;
    // Hammer a single node id with interleaved add/remove calls from several
    // threads; the index must stay internally consistent and searchable.
    let index = Arc::new(SparseVectorIndex::new(SparseIndexConfig::new(100)).unwrap());
    let num_threads = 4;
    let iterations = 100;
    let mut handles = Vec::with_capacity(num_threads as usize);
    for thread_id in 0..num_threads {
        let index = Arc::clone(&index);
        handles.push(thread::spawn(move || {
            for i in 0..iterations {
                let node_id = NodeId::new(1).unwrap();
                let dim = (thread_id * iterations + i) as u32 % 50;
                let doc = SparseVec::new(vec![dim], vec![1.0], 100).unwrap();
                // Races with sibling threads are expected; ignore outcomes.
                let _ = index.add(node_id, &doc);
                let _ = index.remove(node_id);
            }
        }));
    }
    for handle in handles {
        handle.join().unwrap();
    }
    // The final operation on node 1 was either an add or a remove.
    assert!(index.len() <= 1);
    let query = SparseVec::new(vec![0], vec![1.0], 100).unwrap();
    assert!(index.search(&query, 10).is_ok());
}
#[test]
fn test_save_and_load_basic() {
    use std::fs;
    use tempfile::tempdir;
    // Round-trip an index through save/load and confirm that both the stored
    // vectors and search behavior survive intact.
    let dir = tempdir().unwrap();
    let path = dir.path().join("sparse_index.gsp");
    let config = SparseIndexConfig::new(100).with_scoring(ScoringMethod::DotProduct);
    let index = SparseVectorIndex::new(config.clone()).unwrap();
    let v1 = SparseVec::new(vec![0, 10, 50], vec![1.0, 2.0, 3.0], 100).unwrap();
    let v2 = SparseVec::new(vec![10, 20, 30], vec![0.5, 1.5, 2.5], 100).unwrap();
    index.add(NodeId::new(1).unwrap(), &v1).unwrap();
    index.add(NodeId::new(2).unwrap(), &v2).unwrap();
    index.save(&path).unwrap();
    assert!(path.exists());
    let loaded = SparseVectorIndex::load(&path, config).unwrap();
    assert_eq!(loaded.len(), 2);
    assert_eq!(loaded.dimensions(), 100);
    // Each stored vector must match what was written, element for element.
    for (id, expected) in [(1, &v1), (2, &v2)] {
        let restored = loaded.get(NodeId::new(id).unwrap()).unwrap();
        assert_eq!(restored.indices(), expected.indices());
        assert_eq!(restored.values(), expected.values());
    }
    // Dimension 10 appears in both vectors, so both documents must match.
    let query = SparseVec::new(vec![10], vec![1.0], 100).unwrap();
    let results = loaded.search(&query, 10).unwrap();
    assert_eq!(results.len(), 2);
    fs::remove_file(&path).ok();
}
#[test]
fn test_save_and_load_bm25() {
    use tempfile::tempdir;
    // BM25 parameters (k1, b) must survive serialization exactly.
    let dir = tempdir().unwrap();
    let path = dir.path().join("sparse_bm25.gsp");
    let config =
        SparseIndexConfig::new(1000).with_scoring(ScoringMethod::BM25 { k1: 1.8, b: 0.6 });
    let index = SparseVectorIndex::new(config.clone()).unwrap();
    for i in 1..=10 {
        let v = SparseVec::new(vec![i as u32, (i * 10) as u32], vec![1.0, 2.0], 1000).unwrap();
        index.add(NodeId::new(i).unwrap(), &v).unwrap();
    }
    index.save(&path).unwrap();
    let loaded = SparseVectorIndex::load(&path, config).unwrap();
    match loaded.scoring() {
        ScoringMethod::BM25 { k1, b } => {
            // Compare with a tolerance to sidestep float round-trip noise.
            assert!((k1 - 1.8).abs() < 1e-6);
            assert!((b - 0.6).abs() < 1e-6);
        }
        _ => panic!("Expected BM25 scoring method"),
    }
    assert_eq!(loaded.len(), 10);
}
#[test]
fn test_save_and_load_cosine() {
    use tempfile::tempdir;
    // Cosine scoring and document membership must survive a round trip.
    let dir = tempdir().unwrap();
    let path = dir.path().join("sparse_cosine.gsp");
    let config = SparseIndexConfig::new(500).with_scoring(ScoringMethod::Cosine);
    let index = SparseVectorIndex::new(config.clone()).unwrap();
    let node = NodeId::new(42).unwrap();
    let v = SparseVec::new(vec![0, 100, 200], vec![1.0, 1.0, 1.0], 500).unwrap();
    index.add(node, &v).unwrap();
    index.save(&path).unwrap();
    let loaded = SparseVectorIndex::load(&path, config).unwrap();
    assert_eq!(loaded.scoring(), ScoringMethod::Cosine);
    assert_eq!(loaded.len(), 1);
    assert!(loaded.contains(node));
}
#[test]
fn test_save_and_load_empty_index() {
    use tempfile::tempdir;
    // An index with no documents must still serialize and deserialize cleanly.
    let dir = tempdir().unwrap();
    let path = dir.path().join("sparse_empty.gsp");
    let config = SparseIndexConfig::new(100);
    SparseVectorIndex::new(config.clone())
        .unwrap()
        .save(&path)
        .unwrap();
    let loaded = SparseVectorIndex::load(&path, config).unwrap();
    assert_eq!(loaded.len(), 0);
    assert!(loaded.is_empty());
}
#[test]
fn test_load_invalid_magic() {
    use std::fs;
    use tempfile::tempdir;
    // A file whose header does not start with the expected magic bytes
    // must be rejected with a clear error message.
    let dir = tempdir().unwrap();
    let path = dir.path().join("invalid_magic.gsp");
    fs::write(&path, b"BADM\x01\x00\x00\x00\x00\x00").unwrap();
    let result = SparseVectorIndex::load(&path, SparseIndexConfig::new(100));
    assert!(result.is_err());
    let message = result.err().unwrap().to_string();
    assert!(message.contains("Invalid magic bytes"));
}
#[test]
fn test_load_corrupted_crc() {
    use std::fs;
    use tempfile::tempdir;
    // Flipping any byte of a saved file must trip the CRC32 integrity check.
    let dir = tempdir().unwrap();
    let path = dir.path().join("corrupted.gsp");
    let config = SparseIndexConfig::new(100);
    let index = SparseVectorIndex::new(config.clone()).unwrap();
    let v = SparseVec::new(vec![0], vec![1.0], 100).unwrap();
    index.add(NodeId::new(1).unwrap(), &v).unwrap();
    index.save(&path).unwrap();
    // Corrupt the final byte of the file in place.
    let mut data = fs::read(&path).unwrap();
    *data.last_mut().unwrap() ^= 0xFF;
    fs::write(&path, &data).unwrap();
    let result = SparseVectorIndex::load(&path, config);
    assert!(result.is_err());
    let message = result.err().unwrap().to_string();
    assert!(message.contains("CRC32 mismatch"));
}
#[test]
fn test_load_dimension_mismatch() {
    use tempfile::tempdir;
    // Loading with a config whose dimensions disagree with the file must fail.
    let dir = tempdir().unwrap();
    let path = dir.path().join("dim_mismatch.gsp");
    // Save with 100 dimensions...
    SparseVectorIndex::new(SparseIndexConfig::new(100))
        .unwrap()
        .save(&path)
        .unwrap();
    // ...then attempt to load with 200.
    let result = SparseVectorIndex::load(&path, SparseIndexConfig::new(200));
    assert!(result.is_err());
}
#[test]
fn test_load_file_too_small() {
    use std::fs;
    use tempfile::tempdir;
    // A file shorter than the fixed header (magic alone, no version/CRC)
    // must be rejected before any decoding is attempted.
    let dir = tempdir().unwrap();
    let path = dir.path().join("too_small.gsp");
    fs::write(&path, b"ASPS").unwrap();
    let result = SparseVectorIndex::load(&path, SparseIndexConfig::new(100));
    assert!(result.is_err());
    let message = result.err().unwrap().to_string();
    assert!(message.contains("too small"));
}
#[test]
fn test_save_and_load_preserves_search_results() {
    use tempfile::tempdir;
    // Search over a reloaded index must yield the same ids and (within float
    // tolerance) the same scores as the in-memory original.
    let dir = tempdir().unwrap();
    let path = dir.path().join("search_preserve.gsp");
    let config = SparseIndexConfig::new(100);
    let index = SparseVectorIndex::new(config.clone()).unwrap();
    for i in 1..=5 {
        let v = SparseVec::new(vec![0, 1], vec![i as f32, (6 - i) as f32], 100).unwrap();
        index.add(NodeId::new(i).unwrap(), &v).unwrap();
    }
    let query = SparseVec::new(vec![0], vec![1.0], 100).unwrap();
    let results_before = index.search(&query, 5).unwrap();
    index.save(&path).unwrap();
    let loaded = SparseVectorIndex::load(&path, config).unwrap();
    let results_after = loaded.search(&query, 5).unwrap();
    assert_eq!(results_before.len(), results_after.len());
    for ((id_before, score_before), (id_after, score_after)) in
        results_before.iter().zip(results_after.iter())
    {
        assert_eq!(id_before, id_after);
        assert!((score_before - score_after).abs() < 1e-6);
    }
}
}