use crate::core::error::{Error, Result, VectorError};
use crate::core::id::NodeId;
use crate::core::vector::validate_vector;
use crate::index::vector::{
DistanceMetric, HnswConfig, HnswIndex, Quantization, VectorIndex, merge_top_k_results,
};
use rayon::prelude::*;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::path::Path;
use std::sync::Arc;
/// Hard upper bound on `k` for search requests; larger values are clamped
/// so a pathological request cannot force unbounded allocation.
const MAX_K: usize = 100_000;
/// Shard count used by `ShardedVectorConfig::default()`.
const DEFAULT_NUM_SHARDS: usize = 4;
/// How node IDs are routed to shards.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ShardingStrategy {
    /// Route by hashing the node ID (default).
    #[default]
    HashBased,
    /// Route by partitioning the `u64` ID space into contiguous ranges.
    RangeBased,
}
/// Point-in-time occupancy statistics across all shards.
#[derive(Debug, Default, Clone)]
pub struct ShardStats {
    /// Number of vectors in each shard, indexed by shard position.
    pub shard_sizes: Vec<usize>,
    /// Sum of all shard sizes.
    pub total_vectors: usize,
    /// Largest shard size divided by smallest; `f64::INFINITY` when some
    /// shard is empty while another is not, `1.0` for an empty index.
    pub imbalance_ratio: f64,
}
/// Tuning knobs for shard rebalancing decisions.
#[derive(Debug, Clone)]
pub struct RebalanceConfig {
    /// Imbalance ratio above which rebalancing is recommended (>= 1.0).
    pub imbalance_threshold: f64,
    /// Maximum number of vectors to move in one rebalance pass (>= 1).
    pub batch_size: usize,
}
impl Default for RebalanceConfig {
fn default() -> Self {
Self {
imbalance_threshold: 2.0, batch_size: 1000,
}
}
}
impl RebalanceConfig {
    /// Equivalent to [`RebalanceConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: sets the imbalance threshold, clamped to a minimum of 1.0
    /// (a max/min ratio below 1.0 is meaningless).
    pub fn with_imbalance_threshold(self, threshold: f64) -> Self {
        Self {
            imbalance_threshold: threshold.max(1.0),
            ..self
        }
    }

    /// Builder: sets the per-pass batch size, clamped to a minimum of 1.
    pub fn with_batch_size(self, size: usize) -> Self {
        Self {
            batch_size: size.max(1),
            ..self
        }
    }
}
/// Configuration for a `ShardedVectorIndex`.
#[derive(Debug, Clone)]
pub struct ShardedVectorConfig {
    /// Number of shards (clamped to >= 1 by `new`).
    pub num_shards: usize,
    /// Strategy used to route node IDs to shards.
    pub strategy: ShardingStrategy,
    /// Configuration applied to every underlying HNSW shard.
    pub hnsw_config: HnswConfig,
    /// Thresholds controlling rebalancing recommendations.
    pub rebalance_config: RebalanceConfig,
}
impl ShardedVectorConfig {
    /// Creates a config with `num_shards` shards (clamped to at least 1)
    /// and defaults for everything else.
    pub fn new(num_shards: usize) -> Self {
        Self {
            num_shards: num_shards.max(1),
            strategy: ShardingStrategy::default(),
            hnsw_config: HnswConfig::default(),
            rebalance_config: RebalanceConfig::default(),
        }
    }

    /// Builder: sets the sharding strategy.
    pub fn with_strategy(self, strategy: ShardingStrategy) -> Self {
        Self { strategy, ..self }
    }

    /// Builder: sets the per-shard HNSW configuration.
    pub fn with_hnsw_config(self, config: HnswConfig) -> Self {
        Self {
            hnsw_config: config,
            ..self
        }
    }

    /// Builder: sets the rebalancing configuration.
    pub fn with_rebalance_config(self, config: RebalanceConfig) -> Self {
        Self {
            rebalance_config: config,
            ..self
        }
    }
}
impl Default for ShardedVectorConfig {
fn default() -> Self {
Self::new(DEFAULT_NUM_SHARDS)
}
}
/// Vector index that partitions vectors across several independent HNSW
/// shards, routing each node ID to exactly one shard and merging per-shard
/// search results into a global top-k.
pub struct ShardedVectorIndex {
    // Configuration this index was built with.
    config: ShardedVectorConfig,
    // One HNSW index per shard; `Arc` so shards can be searched in parallel.
    shards: Vec<Arc<HnswIndex>>,
}
impl ShardedVectorIndex {
    /// Creates an empty sharded index from `config`.
    ///
    /// The config is normalized first: `num_shards` is clamped to at least 1
    /// and the clamped value is stored back, so `config().num_shards` always
    /// agrees with `num_shards()`.
    ///
    /// # Errors
    /// Returns an error if `config.hnsw_config.dimensions` is 0 or a shard
    /// fails to initialize.
    pub fn new(config: ShardedVectorConfig) -> Result<Self> {
        let config = Self::validated_config(config)?;
        let mut shards = Vec::with_capacity(config.num_shards);
        for _ in 0..config.num_shards {
            let shard = HnswIndex::new(config.hnsw_config.clone())?;
            shards.push(Arc::new(shard));
        }
        Ok(Self { config, shards })
    }

    /// Enforces invariants shared by `new` and `load`: dimensions must be
    /// non-zero, and `num_shards` is clamped to at least 1 so shard routing
    /// can never divide or take a modulus by zero.
    fn validated_config(mut config: ShardedVectorConfig) -> Result<ShardedVectorConfig> {
        if config.hnsw_config.dimensions == 0 {
            return Err(Error::Vector(VectorError::InvalidVector {
                reason: "dimensions must be > 0".to_string(),
            }));
        }
        config.num_shards = config.num_shards.max(1);
        Ok(config)
    }

    /// Convenience constructor using default shard and rebalance settings.
    pub fn with_defaults(
        dimensions: usize,
        metric: DistanceMetric,
        num_shards: usize,
    ) -> Result<Self> {
        let hnsw_config = HnswConfig::new(dimensions, metric);
        let config = ShardedVectorConfig::new(num_shards).with_hnsw_config(hnsw_config);
        Self::new(config)
    }

    /// Number of shards backing this index.
    pub fn num_shards(&self) -> usize {
        self.shards.len()
    }

    /// Strategy used to route node IDs to shards.
    pub fn strategy(&self) -> ShardingStrategy {
        self.config.strategy
    }

    /// The (validated) configuration this index was built with.
    pub fn config(&self) -> &ShardedVectorConfig {
        &self.config
    }

    /// Computes per-shard occupancy statistics.
    ///
    /// `imbalance_ratio` is max/min shard size; `INFINITY` when some shard
    /// is empty while another is not, and `1.0` for a fully empty index.
    pub fn stats(&self) -> ShardStats {
        let shard_sizes: Vec<usize> = self.shards.iter().map(|s| s.len()).collect();
        let total_vectors: usize = shard_sizes.iter().sum();
        let min_size = shard_sizes.iter().min().copied().unwrap_or(0);
        let max_size = shard_sizes.iter().max().copied().unwrap_or(0);
        let imbalance_ratio = if min_size > 0 {
            max_size as f64 / min_size as f64
        } else if max_size > 0 {
            f64::INFINITY
        } else {
            1.0
        };
        ShardStats {
            shard_sizes,
            total_vectors,
            imbalance_ratio,
        }
    }

    /// Maps a node ID to the index of the shard that owns it.
    /// Routing is deterministic: the same ID always maps to the same shard.
    fn shard_for_id(&self, id: NodeId) -> usize {
        debug_assert!(!self.shards.is_empty(), "shards cannot be empty");
        let num_shards = self.shards.len();
        match self.config.strategy {
            ShardingStrategy::HashBased => {
                let mut hasher = DefaultHasher::new();
                id.as_u64().hash(&mut hasher);
                (hasher.finish() as usize) % num_shards
            }
            ShardingStrategy::RangeBased => {
                // Widen to u128 so `id * num_shards` cannot overflow; dividing
                // by u64::MAX + 1 splits the full u64 ID space evenly.
                let num_shards_128 = num_shards as u128;
                let id_128 = id.as_u64() as u128;
                let shard = ((id_128 * num_shards_128) / (u64::MAX as u128 + 1)) as usize;
                shard.min(num_shards - 1)
            }
        }
    }

    /// Direct access to a shard by position, if it exists.
    pub fn get_shard(&self, index: usize) -> Option<&Arc<HnswIndex>> {
        self.shards.get(index)
    }

    /// Whether the current imbalance exceeds the configured threshold.
    pub fn needs_rebalancing(&self) -> bool {
        let stats = self.stats();
        stats.imbalance_ratio > self.config.rebalance_config.imbalance_threshold
    }

    /// Estimates how many vectors the next rebalance pass would move,
    /// capped by `rebalance_config.batch_size`. Returns 0 when the index
    /// is already within the imbalance threshold.
    pub fn estimate_rebalance_cost(&self) -> Result<usize> {
        debug_assert!(!self.shards.is_empty(), "shards cannot be empty");
        let stats = self.stats();
        if stats.imbalance_ratio <= self.config.rebalance_config.imbalance_threshold {
            return Ok(0);
        }
        let target_size = stats.total_vectors / self.shards.len();
        let mut vectors_to_move = 0;
        for size in &stats.shard_sizes {
            if *size > target_size {
                vectors_to_move += size - target_size;
            }
        }
        vectors_to_move = vectors_to_move.min(self.config.rebalance_config.batch_size);
        Ok(vectors_to_move)
    }

    /// Total memory used across all shards.
    pub fn total_memory_usage(&self) -> usize {
        self.shards.iter().map(|s| s.memory_usage()).sum()
    }

    /// Loads an index previously written by `save`.
    ///
    /// Applies the same config validation as `new`, so a config with zero
    /// dimensions is rejected and `num_shards == 0` is clamped instead of
    /// producing a shardless index whose routing would panic.
    pub fn load(path: &Path, config: ShardedVectorConfig) -> Result<Self> {
        Self::validate_path(path)?;
        let config = Self::validated_config(config)?;
        let mut shards = Vec::with_capacity(config.num_shards);
        for i in 0..config.num_shards {
            let shard_path = Self::shard_path(path, i)?;
            let shard = HnswIndex::load(&shard_path, config.hnsw_config.clone())?;
            shards.push(Arc::new(shard));
        }
        Ok(Self { config, shards })
    }

    /// Rejects paths containing `..` traversal components or lacking a
    /// file name.
    fn validate_path(path: &Path) -> Result<()> {
        let path_str = path.to_string_lossy();
        if path_str.contains("..") {
            return Err(Error::Vector(VectorError::IndexError(
                "Path contains invalid traversal characters (..)".to_string(),
            )));
        }
        if path.file_name().is_none() {
            return Err(Error::Vector(VectorError::IndexError(
                "Path must have a valid file name".to_string(),
            )));
        }
        Ok(())
    }

    /// Derives the on-disk path for shard `shard_index` from the base path,
    /// e.g. `index` -> `index.shard_0`.
    fn shard_path(base_path: &Path, shard_index: usize) -> Result<std::path::PathBuf> {
        let file_name = base_path
            .file_name()
            .and_then(|n| n.to_str())
            .ok_or_else(|| {
                Error::Vector(VectorError::IndexError(
                    "Path must have a valid UTF-8 file name".to_string(),
                ))
            })?;
        Ok(base_path.with_file_name(format!("{}.shard_{}", file_name, shard_index)))
    }

    /// Merges per-shard top-k result lists into a single global top-k.
    fn merge_results(shard_results: Vec<Vec<(NodeId, f32)>>, k: usize) -> Vec<(NodeId, f32)> {
        merge_top_k_results(shard_results, k)
    }
}
impl VectorIndex for ShardedVectorIndex {
fn add(&self, id: NodeId, vector: &[f32]) -> Result<()> {
validate_vector(vector)?;
if vector.len() != self.config.hnsw_config.dimensions {
return Err(Error::Vector(VectorError::DimensionMismatch {
expected: self.config.hnsw_config.dimensions,
actual: vector.len(),
}));
}
let shard_idx = self.shard_for_id(id);
self.shards[shard_idx].add(id, vector)
}
fn remove(&self, id: NodeId) -> Result<()> {
let shard_idx = self.shard_for_id(id);
self.shards[shard_idx].remove(id)
}
fn search(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>> {
validate_vector(query)?;
if query.len() != self.config.hnsw_config.dimensions {
return Err(Error::Vector(VectorError::DimensionMismatch {
expected: self.config.hnsw_config.dimensions,
actual: query.len(),
}));
}
let k_capped = k.min(MAX_K);
let shard_results: Result<Vec<Vec<(NodeId, f32)>>> = self
.shards
.par_iter()
.filter(|shard| !shard.is_empty())
.map(|shard| shard.search(query, k_capped))
.collect();
let shard_results = shard_results?;
Ok(Self::merge_results(shard_results, k_capped))
}
fn search_with_filter<F>(
&self,
query: &[f32],
k: usize,
predicate: F,
) -> Result<Vec<(NodeId, f32)>>
where
F: Fn(&NodeId) -> bool + Send + Sync,
{
validate_vector(query)?;
if query.len() != self.config.hnsw_config.dimensions {
return Err(Error::Vector(VectorError::DimensionMismatch {
expected: self.config.hnsw_config.dimensions,
actual: query.len(),
}));
}
let k_capped = k.min(MAX_K);
let shard_results: Result<Vec<Vec<(NodeId, f32)>>> = self
.shards
.par_iter()
.filter(|shard| !shard.is_empty())
.map(|shard| shard.search_with_filter(query, k_capped, &predicate))
.collect();
let shard_results = shard_results?;
Ok(Self::merge_results(shard_results, k_capped))
}
fn len(&self) -> usize {
self.shards.iter().map(|s| s.len()).sum()
}
fn dimensions(&self) -> usize {
self.config.hnsw_config.dimensions
}
fn distance_metric(&self) -> DistanceMetric {
self.config.hnsw_config.metric
}
fn add_batch(&self, items: &[(NodeId, Vec<f32>)]) -> Result<()> {
let refs: Vec<(NodeId, &[f32])> = items
.iter()
.map(|(id, vec)| (*id, vec.as_slice()))
.collect();
self.add_batch_ref(&refs)
}
fn add_batch_ref(&self, items: &[(NodeId, &[f32])]) -> Result<()> {
if items.is_empty() {
return Ok(());
}
let capacity = items
.len()
.checked_div(self.shards.len())
.unwrap_or(items.len())
+ 1;
let mut shard_indices: Vec<Vec<usize>> = (0..self.shards.len())
.map(|_| Vec::with_capacity(capacity))
.collect();
for (idx, (id, _)) in items.iter().enumerate() {
let shard_idx = self.shard_for_id(*id);
shard_indices[shard_idx].push(idx);
}
for (shard_idx, indices) in shard_indices.into_iter().enumerate() {
if !indices.is_empty() {
let shard_items: Vec<(NodeId, &[f32])> = indices
.iter()
.map(|&idx| (items[idx].0, items[idx].1))
.collect();
self.shards[shard_idx].add_batch_ref(&shard_items)?;
}
}
Ok(())
}
fn remove_batch(&self, ids: &[NodeId]) -> Result<()> {
if ids.is_empty() {
return Ok(());
}
let capacity = ids
.len()
.checked_div(self.shards.len())
.unwrap_or(ids.len())
+ 1;
let mut shard_ids: Vec<Vec<NodeId>> = (0..self.shards.len())
.map(|_| Vec::with_capacity(capacity))
.collect();
for id in ids {
let shard_idx = self.shard_for_id(*id);
shard_ids[shard_idx].push(*id);
}
for (shard_idx, ids) in shard_ids.into_iter().enumerate() {
if !ids.is_empty() {
self.shards[shard_idx].remove_batch(&ids)?;
}
}
Ok(())
}
fn save(&self, path: &Path) -> Result<()> {
Self::validate_path(path)?;
for (i, shard) in self.shards.iter().enumerate() {
let shard_path = Self::shard_path(path, i)?;
shard.save(&shard_path)?;
}
Ok(())
}
fn memory_usage(&self) -> usize {
self.total_memory_usage()
}
fn quantization(&self) -> Quantization {
self.config.hnsw_config.quantization
}
fn compact(&self) -> Result<()> {
for shard in &self.shards {
shard.compact()?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::vector::OrderedFloat;

    // --- Configuration builders and clamping ---

    #[test]
    fn test_sharded_config_defaults() {
        let config = ShardedVectorConfig::default();
        assert_eq!(config.num_shards, DEFAULT_NUM_SHARDS);
        assert_eq!(config.strategy, ShardingStrategy::HashBased);
    }

    #[test]
    fn test_sharded_config_builder() {
        let hnsw_config = HnswConfig::new(128, DistanceMetric::Cosine);
        let config = ShardedVectorConfig::new(8)
            .with_strategy(ShardingStrategy::RangeBased)
            .with_hnsw_config(hnsw_config)
            .with_rebalance_config(RebalanceConfig::new().with_imbalance_threshold(1.5));
        assert_eq!(config.num_shards, 8);
        assert_eq!(config.strategy, ShardingStrategy::RangeBased);
        assert_eq!(config.hnsw_config.dimensions, 128);
        assert!((config.rebalance_config.imbalance_threshold - 1.5).abs() < 0.001);
    }

    #[test]
    fn test_sharded_config_clamps_num_shards() {
        // Zero shards is nonsensical; `new` clamps it to 1.
        let config = ShardedVectorConfig::new(0);
        assert_eq!(config.num_shards, 1);
    }

    #[test]
    fn test_rebalance_config_defaults() {
        let config = RebalanceConfig::default();
        assert!((config.imbalance_threshold - 2.0).abs() < 0.001);
        assert_eq!(config.batch_size, 1000);
    }

    #[test]
    fn test_rebalance_config_clamps_threshold() {
        // Thresholds below 1.0 are clamped up to 1.0.
        let config = RebalanceConfig::new().with_imbalance_threshold(0.5);
        assert!((config.imbalance_threshold - 1.0).abs() < 0.001);
    }

    // --- Index construction ---

    #[test]
    fn test_create_sharded_index() -> Result<()> {
        let config = ShardedVectorConfig::new(4)
            .with_hnsw_config(HnswConfig::new(128, DistanceMetric::Cosine));
        let index = ShardedVectorIndex::new(config)?;
        assert_eq!(index.num_shards(), 4);
        assert_eq!(index.dimensions(), 128);
        assert_eq!(index.distance_metric(), DistanceMetric::Cosine);
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
        Ok(())
    }

    #[test]
    fn test_create_with_defaults() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(256, DistanceMetric::Euclidean, 2)?;
        assert_eq!(index.num_shards(), 2);
        assert_eq!(index.dimensions(), 256);
        assert_eq!(index.distance_metric(), DistanceMetric::Euclidean);
        Ok(())
    }

    #[test]
    fn test_create_fails_with_zero_dimensions() {
        let config = ShardedVectorConfig::new(4)
            .with_hnsw_config(HnswConfig::new(0, DistanceMetric::Cosine));
        let result = ShardedVectorIndex::new(config);
        assert!(result.is_err());
    }

    // --- Single-item add/remove ---

    #[test]
    fn test_add_single_vector() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        let node = NodeId::new(1).unwrap();
        let vector = vec![1.0, 0.0, 0.0, 0.0];
        index.add(node, &vector)?;
        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
        Ok(())
    }

    #[test]
    fn test_add_multiple_vectors() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        for i in 1..=100 {
            let node = NodeId::new(i).unwrap();
            let vector = vec![i as f32, 0.0, 0.0, 0.0];
            index.add(node, &vector)?;
        }
        assert_eq!(index.len(), 100);
        Ok(())
    }

    #[test]
    fn test_add_dimension_mismatch() {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4).unwrap();
        let node = NodeId::new(1).unwrap();
        let wrong_dim_vector = vec![1.0, 0.0];
        let result = index.add(node, &wrong_dim_vector);
        assert!(result.is_err());
    }

    #[test]
    fn test_add_with_nan() {
        // Non-finite components must be rejected by validate_vector.
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4).unwrap();
        let node = NodeId::new(1).unwrap();
        let nan_vector = vec![1.0, f32::NAN, 0.0, 0.0];
        let result = index.add(node, &nan_vector);
        assert!(result.is_err());
    }

    #[test]
    fn test_remove_vector() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        let node = NodeId::new(1).unwrap();
        let vector = vec![1.0, 0.0, 0.0, 0.0];
        index.add(node, &vector)?;
        assert_eq!(index.len(), 1);
        index.remove(node)?;
        assert_eq!(index.len(), 0);
        Ok(())
    }

    #[test]
    fn test_remove_nonexistent() -> Result<()> {
        // Removing an ID that was never added is not an error.
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        let node = NodeId::new(1).unwrap();
        index.remove(node)?;
        assert_eq!(index.len(), 0);
        Ok(())
    }

    // --- Batch operations ---

    #[test]
    fn test_add_batch() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        let items: Vec<(NodeId, Vec<f32>)> = (1..=10)
            .map(|i| (NodeId::new(i).unwrap(), vec![i as f32, 0.0, 0.0, 0.0]))
            .collect();
        index.add_batch(&items)?;
        assert_eq!(index.len(), 10);
        Ok(())
    }

    #[test]
    fn test_remove_batch() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        for i in 1..=10 {
            let node = NodeId::new(i).unwrap();
            index.add(node, &[i as f32, 0.0, 0.0, 0.0])?;
        }
        assert_eq!(index.len(), 10);
        let ids_to_remove: Vec<NodeId> = (1..=5).map(|i| NodeId::new(i).unwrap()).collect();
        index.remove_batch(&ids_to_remove)?;
        assert_eq!(index.len(), 5);
        Ok(())
    }

    // --- Search ---

    #[test]
    fn test_search_empty_index() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = index.search(&query, 10)?;
        assert!(results.is_empty());
        Ok(())
    }

    #[test]
    fn test_search_basic() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        let node1 = NodeId::new(1).unwrap();
        let node2 = NodeId::new(2).unwrap();
        let node3 = NodeId::new(3).unwrap();
        index.add(node1, &[1.0, 0.0, 0.0, 0.0])?;
        index.add(node2, &[0.9, 0.1, 0.0, 0.0])?;
        index.add(node3, &[0.0, 1.0, 0.0, 0.0])?;
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = index.search(&query, 3)?;
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, node1);
        assert_eq!(results[1].0, node2);
        Ok(())
    }

    #[test]
    fn test_search_k_larger_than_index() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        let node1 = NodeId::new(1).unwrap();
        let node2 = NodeId::new(2).unwrap();
        index.add(node1, &[1.0, 0.0, 0.0, 0.0])?;
        index.add(node2, &[0.0, 1.0, 0.0, 0.0])?;
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = index.search(&query, 100)?;
        assert_eq!(results.len(), 2);
        Ok(())
    }

    #[test]
    fn test_search_dimension_mismatch() {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4).unwrap();
        let node = NodeId::new(1).unwrap();
        index.add(node, &[1.0, 0.0, 0.0, 0.0]).unwrap();
        let wrong_dim_query = vec![1.0, 0.0];
        let result = index.search(&wrong_dim_query, 10);
        assert!(result.is_err());
    }

    #[test]
    fn test_search_with_filter() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        for i in 1..=10 {
            let node = NodeId::new(i).unwrap();
            index.add(node, &[i as f32, 0.0, 0.0, 0.0])?;
        }
        let query = vec![5.0, 0.0, 0.0, 0.0];
        // Only even IDs may appear in the results.
        let results = index.search_with_filter(&query, 10, |id| id.as_u64() % 2 == 0)?;
        for (id, _) in &results {
            assert_eq!(id.as_u64() % 2, 0);
        }
        Ok(())
    }

    #[test]
    fn test_search_with_filter_no_matches() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        let node = NodeId::new(1).unwrap();
        index.add(node, &[1.0, 0.0, 0.0, 0.0])?;
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = index.search_with_filter(&query, 10, |_| false)?;
        assert!(results.is_empty());
        Ok(())
    }

    // --- Shard routing and distribution ---

    #[test]
    fn test_hash_based_distribution() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        for i in 1..=1000 {
            let node = NodeId::new(i).unwrap();
            index.add(node, &[i as f32, 0.0, 0.0, 0.0])?;
        }
        let stats = index.stats();
        assert_eq!(stats.total_vectors, 1000);
        // Hashing should spread 1000 IDs roughly evenly over 4 shards.
        for size in &stats.shard_sizes {
            assert!(*size >= 100, "Shard too small: {}", size);
            assert!(*size <= 400, "Shard too large: {}", size);
        }
        Ok(())
    }

    #[test]
    fn test_range_based_distribution() -> Result<()> {
        let config = ShardedVectorConfig::new(4)
            .with_strategy(ShardingStrategy::RangeBased)
            .with_hnsw_config(HnswConfig::new(4, DistanceMetric::Cosine));
        let index = ShardedVectorIndex::new(config)?;
        for i in 1..=100 {
            let node = NodeId::new(i).unwrap();
            index.add(node, &[i as f32, 0.0, 0.0, 0.0])?;
        }
        assert_eq!(index.len(), 100);
        let stats = index.stats();
        assert_eq!(stats.total_vectors, 100);
        Ok(())
    }

    #[test]
    fn test_range_based_distribution_even() -> Result<()> {
        let config = ShardedVectorConfig::new(4)
            .with_strategy(ShardingStrategy::RangeBased)
            .with_hnsw_config(HnswConfig::new(4, DistanceMetric::Cosine));
        let index = ShardedVectorIndex::new(config)?;
        // IDs spread across the u64 range should land in different shards.
        let test_ids: Vec<u64> = vec![
            0,
            u64::MAX / 4,
            u64::MAX / 2,
            u64::MAX / 4 * 3,
            u64::MAX - 1000,
        ];
        for id in test_ids {
            if let Ok(node) = NodeId::new(id) {
                index.add(node, &[id as f32, 0.0, 0.0, 0.0])?;
            }
        }
        let stats = index.stats();
        let non_empty_shards = stats.shard_sizes.iter().filter(|&&s| s > 0).count();
        assert!(
            non_empty_shards >= 2,
            "Vectors should be in multiple shards"
        );
        Ok(())
    }

    #[test]
    fn test_consistent_routing() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        let node = NodeId::new(42).unwrap();
        let shard1 = index.shard_for_id(node);
        let shard2 = index.shard_for_id(node);
        let shard3 = index.shard_for_id(node);
        assert_eq!(shard1, shard2);
        assert_eq!(shard2, shard3);
        Ok(())
    }

    #[test]
    fn test_range_based_max_id_edge_case() -> Result<()> {
        let config = ShardedVectorConfig::new(4)
            .with_strategy(ShardingStrategy::RangeBased)
            .with_hnsw_config(HnswConfig::new(4, DistanceMetric::Cosine));
        let index = ShardedVectorIndex::new(config)?;
        // Extremes of the u64 range must never route out of bounds.
        let edge_ids: Vec<u64> = vec![
            0,
            1,
            u64::MAX - 1,
            u64::MAX,
        ];
        for id in edge_ids {
            if let Ok(node) = NodeId::new(id) {
                let shard = index.shard_for_id(node);
                assert!(
                    shard < 4,
                    "Shard index {} out of bounds for id {}",
                    shard,
                    id
                );
            }
        }
        Ok(())
    }

    // --- Stats and rebalancing ---

    #[test]
    fn test_stats_empty_index() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        let stats = index.stats();
        assert_eq!(stats.total_vectors, 0);
        assert_eq!(stats.shard_sizes.len(), 4);
        assert!((stats.imbalance_ratio - 1.0).abs() < 0.001);
        Ok(())
    }

    #[test]
    fn test_needs_rebalancing() -> Result<()> {
        let config = ShardedVectorConfig::new(4)
            .with_hnsw_config(HnswConfig::new(4, DistanceMetric::Cosine))
            .with_rebalance_config(RebalanceConfig::new().with_imbalance_threshold(2.0));
        let index = ShardedVectorIndex::new(config)?;
        assert!(!index.needs_rebalancing());
        Ok(())
    }

    #[test]
    fn test_estimate_rebalance_cost() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        for i in 1..=100 {
            let node = NodeId::new(i).unwrap();
            index.add(node, &[i as f32, 0.0, 0.0, 0.0])?;
        }
        let cost = index.estimate_rebalance_cost()?;
        assert!(cost <= 100);
        Ok(())
    }

    #[test]
    fn test_memory_usage() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(128, DistanceMetric::Cosine, 4)?;
        let empty_usage = index.memory_usage();
        assert!(empty_usage > 0);
        for i in 1..=100 {
            let node = NodeId::new(i).unwrap();
            let vector: Vec<f32> = (0..128).map(|j| (i * j) as f32).collect();
            index.add(node, &vector)?;
        }
        let usage_with_vectors = index.memory_usage();
        assert!(usage_with_vectors > empty_usage);
        Ok(())
    }

    // --- Concurrency ---

    #[test]
    fn test_concurrent_add() -> Result<()> {
        use std::thread;
        let index = Arc::new(ShardedVectorIndex::with_defaults(
            4,
            DistanceMetric::Cosine,
            4,
        )?);
        let mut handles = vec![];
        // Four threads each insert a disjoint range of 25 IDs.
        for t in 0..4 {
            let index = Arc::clone(&index);
            let handle = thread::spawn(move || {
                for i in 0..25 {
                    let id = t * 25 + i + 1;
                    let node = NodeId::new(id as u64).unwrap();
                    let vector = vec![id as f32, 0.0, 0.0, 0.0];
                    index.add(node, &vector).unwrap();
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(index.len(), 100);
        Ok(())
    }

    #[test]
    fn test_concurrent_search() -> Result<()> {
        use std::thread;
        let index = Arc::new(ShardedVectorIndex::with_defaults(
            4,
            DistanceMetric::Cosine,
            4,
        )?);
        for i in 1..=50 {
            let node = NodeId::new(i).unwrap();
            index.add(node, &[i as f32, 0.0, 0.0, 0.0])?;
        }
        let mut handles = vec![];
        for _ in 0..4 {
            let index = Arc::clone(&index);
            let handle = thread::spawn(move || {
                for i in 1..=10 {
                    let query = vec![i as f32, 0.0, 0.0, 0.0];
                    let results = index.search(&query, 5).unwrap();
                    assert!(!results.is_empty());
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        Ok(())
    }

    #[test]
    fn test_quantization() -> Result<()> {
        let config = ShardedVectorConfig::new(4).with_hnsw_config(
            HnswConfig::new(4, DistanceMetric::Cosine).with_quantization(Quantization::F16),
        );
        let index = ShardedVectorIndex::new(config)?;
        assert_eq!(index.quantization(), Quantization::F16);
        Ok(())
    }

    // --- Cross-shard result merging ---

    #[test]
    fn test_cross_shard_search_merging() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        for i in 1..=100 {
            let node = NodeId::new(i).unwrap();
            index.add(node, &[i as f32, 0.0, 0.0, 0.0])?;
        }
        let query = vec![50.0, 0.0, 0.0, 0.0];
        let results = index.search(&query, 10)?;
        assert_eq!(results.len(), 10);
        // Merged results must be globally sorted by descending similarity.
        for i in 0..results.len() - 1 {
            assert!(
                results[i].1 >= results[i + 1].1,
                "Results not sorted by similarity"
            );
        }
        Ok(())
    }

    #[test]
    fn test_search_results_from_multiple_shards() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        let vectors = vec![
            (1u64, [1.0f32, 0.0, 0.0, 0.0]),
            (1000, [0.9, 0.1, 0.0, 0.0]),
            (2000, [0.8, 0.2, 0.0, 0.0]),
            (3000, [0.7, 0.3, 0.0, 0.0]),
        ];
        for (id, vec) in &vectors {
            let node = NodeId::new(*id).unwrap();
            index.add(node, vec)?;
        }
        let stats = index.stats();
        assert_eq!(stats.total_vectors, 4);
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = index.search(&query, 10)?;
        assert_eq!(results.len(), 4);
        Ok(())
    }

    #[test]
    fn test_merge_results_empty() {
        let results = ShardedVectorIndex::merge_results(vec![], 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_merge_results_single_shard() {
        let shard_results = vec![vec![
            (NodeId::new(1).unwrap(), 0.9),
            (NodeId::new(2).unwrap(), 0.8),
            (NodeId::new(3).unwrap(), 0.7),
        ]];
        let merged = ShardedVectorIndex::merge_results(shard_results, 2);
        assert_eq!(merged.len(), 2);
        assert_eq!(merged[0].0, NodeId::new(1).unwrap());
        assert_eq!(merged[1].0, NodeId::new(2).unwrap());
    }

    #[test]
    fn test_merge_results_multiple_shards() {
        let shard_results = vec![
            vec![
                (NodeId::new(1).unwrap(), 0.9),
                (NodeId::new(2).unwrap(), 0.7),
            ],
            vec![
                (NodeId::new(3).unwrap(), 0.85),
                (NodeId::new(4).unwrap(), 0.6),
            ],
        ];
        let merged = ShardedVectorIndex::merge_results(shard_results, 3);
        assert_eq!(merged.len(), 3);
        assert_eq!(merged[0].0, NodeId::new(1).unwrap());
        assert_eq!(merged[1].0, NodeId::new(3).unwrap());
        assert_eq!(merged[2].0, NodeId::new(2).unwrap());
    }

    #[test]
    fn test_merge_results_k_zero() {
        let shard_results = vec![vec![(NodeId::new(1).unwrap(), 0.9)]];
        let merged = ShardedVectorIndex::merge_results(shard_results, 0);
        assert!(merged.is_empty());
    }

    // --- Path validation and persistence ---

    #[test]
    fn test_path_validation_traversal() {
        let path = Path::new("/data/../etc/passwd");
        let result = ShardedVectorIndex::validate_path(path);
        assert!(result.is_err());
    }

    #[test]
    fn test_path_validation_no_filename() {
        let path = Path::new("/");
        let result = ShardedVectorIndex::validate_path(path);
        assert!(result.is_err());
    }

    #[test]
    fn test_path_validation_valid() {
        let path = Path::new("/data/index");
        let result = ShardedVectorIndex::validate_path(path);
        assert!(result.is_ok());
    }

    #[test]
    fn test_save_and_load() -> Result<()> {
        use tempfile::tempdir;
        let dir = tempdir().map_err(Error::Io)?;
        let path = dir.path().join("test_index");
        let config = ShardedVectorConfig::new(2)
            .with_hnsw_config(HnswConfig::new(4, DistanceMetric::Cosine));
        let index = ShardedVectorIndex::new(config.clone())?;
        for i in 1..=10 {
            let node = NodeId::new(i).unwrap();
            index.add(node, &[i as f32, 0.0, 0.0, 0.0])?;
        }
        assert_eq!(index.len(), 10);
        index.save(&path)?;
        let loaded = ShardedVectorIndex::load(&path, config)?;
        assert_eq!(loaded.len(), 10);
        assert_eq!(loaded.num_shards(), 2);
        let query = vec![5.0, 0.0, 0.0, 0.0];
        let results = loaded.search(&query, 5)?;
        assert!(!results.is_empty());
        Ok(())
    }

    // --- OrderedFloat helper ---

    #[test]
    fn test_ordered_float_ordering() {
        let a = OrderedFloat(0.5);
        let b = OrderedFloat(0.7);
        let c = OrderedFloat(0.5);
        assert!(a < b);
        assert!(b > a);
        assert_eq!(a, c);
    }

    #[test]
    fn test_ordered_float_nan_handling() {
        // NaN must order consistently (below normal values) instead of
        // breaking comparisons.
        let nan = OrderedFloat(f32::NAN);
        let normal = OrderedFloat(0.5);
        assert!(nan < normal);
        assert!(normal > nan);
    }

    #[test]
    fn test_add_batch_ref() -> Result<()> {
        let index = ShardedVectorIndex::with_defaults(4, DistanceMetric::Cosine, 4)?;
        let vectors: Vec<Vec<f32>> = (1..=10).map(|i| vec![i as f32, 0.0, 0.0, 0.0]).collect();
        let items: Vec<(NodeId, &[f32])> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (NodeId::new((i + 1) as u64).unwrap(), v.as_slice()))
            .collect();
        index.add_batch_ref(&items)?;
        assert_eq!(index.len(), 10);
        Ok(())
    }
}