#![allow(clippy::cast_precision_loss, clippy::float_cmp)]
use super::sharded_vectors::{ShardedVectors, NUM_SHARDS};
use serial_test::serial;
use std::sync::Arc;
use std::thread;
#[test]
fn test_sharded_vectors_new_is_empty() {
    // A freshly constructed store reports no entries through both accessors.
    let fresh = ShardedVectors::new(3);
    assert!(fresh.is_empty());
    assert_eq!(fresh.len(), 0);
}
#[test]
fn test_sharded_vectors_insert_and_get() {
    // Round-trip a single vector: what is stored under id 0 must come back unchanged.
    let sharded = ShardedVectors::new(3);
    let original = vec![1.0, 2.0, 3.0];
    sharded.insert(0, &original);
    let fetched = sharded.get(0);
    assert_eq!(fetched, Some(original));
    assert_eq!(sharded.len(), 1);
}
#[test]
fn test_sharded_vectors_insert_multiple_shards() {
    // Insert 32 consecutive ids (presumably spread over several shards) and
    // check that each id still resolves to its own vector. The file-level
    // `#![allow(clippy::cast_precision_loss)]` already covers the casts below.
    let sharded = ShardedVectors::new(3);
    (0..32).for_each(|id| {
        sharded.insert(id, &[id as f32; 3]);
    });
    assert_eq!(sharded.len(), 32);
    for id in 0..32 {
        let expected = vec![id as f32; 3];
        assert_eq!(sharded.get(id), Some(expected));
    }
}
#[test]
fn test_sharded_vectors_get_nonexistent() {
    // Looking up an id that was never inserted yields None.
    let empty = ShardedVectors::new(3);
    let missing = empty.get(999);
    assert_eq!(missing, None);
}
#[test]
fn test_sharded_vectors_contains() {
    // Membership is true for the inserted id and false for an unrelated one.
    let sharded = ShardedVectors::new(1);
    let present_id = 42;
    sharded.insert(present_id, &[1.0]);
    assert!(sharded.contains(present_id));
    assert!(!sharded.contains(999));
}
#[test]
fn test_sharded_vectors_remove() {
    // Removing an entry hands back the stored vector and leaves the store empty.
    let sharded = ShardedVectors::new(2);
    sharded.insert(42, &[1.0, 2.0]);
    assert_eq!(sharded.remove(42), Some(vec![1.0, 2.0]));
    assert!(!sharded.contains(42));
    assert!(sharded.is_empty());
}
#[test]
fn test_sharded_vectors_remove_nonexistent() {
    // Removing an id that was never inserted returns None.
    let sharded = ShardedVectors::new(1);
    let removed = sharded.remove(999);
    assert_eq!(removed, None);
}
#[test]
fn test_sharded_vectors_with_vector() {
    // `with_vector` applies the closure to the stored slice and wraps the result in Some.
    let sharded = ShardedVectors::new(3);
    sharded.insert(0, &[1.0, 2.0, 3.0]);
    let total = sharded.with_vector(0, |slice| slice.iter().sum::<f32>());
    assert_eq!(total, Some(6.0));
}
#[test]
fn test_sharded_vectors_with_vector_nonexistent() {
    // For an id that is not stored, `with_vector` returns None.
    let sharded = ShardedVectors::new(1);
    assert_eq!(sharded.with_vector(999, <[f32]>::len), None);
}
#[test]
fn test_sharded_vectors_insert_batch() {
    // One batch of 100 entries must land in full and be individually retrievable.
    let sharded = ShardedVectors::new(3);
    let entries: Vec<(usize, Vec<f32>)> = (0..100)
        .map(|id| (id, vec![id as f32; 3]))
        .collect();
    sharded.insert_batch(entries);
    assert_eq!(sharded.len(), 100);
    (0..100).for_each(|id| assert_eq!(sharded.get(id), Some(vec![id as f32; 3])));
}
#[test]
fn test_sharded_vectors_iter_all() {
    // Ids 0 and NUM_SHARDS should map to the same shard (shard_index wraps
    // modulo NUM_SHARDS), while id 1 maps to a different one. iter_all must
    // therefore gather entries both within a shard and across shards.
    let storage = ShardedVectors::new(1);
    storage.insert(0, &[1.0]);
    storage.insert(NUM_SHARDS, &[2.0]);
    storage.insert(1, &[3.0]);
    let mut all: Vec<(usize, Vec<f32>)> = storage.iter_all();
    assert_eq!(all.len(), 3);
    // Cross-shard iteration order is not part of the contract; sort by id before
    // comparing so the check pins contents rather than ordering.
    all.sort_unstable_by_key(|(idx, _)| *idx);
    assert_eq!(
        all,
        vec![(0, vec![1.0]), (1, vec![3.0]), (NUM_SHARDS, vec![2.0])]
    );
}
#[test]
fn test_sharded_vectors_for_each_parallel() {
    // Store ids 0..50, each as a one-element vector holding the id, and
    // accumulate through the visitor: 0 + 1 + ... + 49 = 1225.
    let storage = ShardedVectors::new(1);
    for id in 0..50 {
        storage.insert(id, &[id as f32]);
    }
    let mut total = 0.0;
    storage.for_each_parallel(|_, values| total += values[0]);
    // Every partial sum is a small integer, exactly representable in f32, so the
    // total is exact. At this magnitude an exact comparison is equivalent to an
    // `abs() < f32::EPSILON` check.
    assert_eq!(total, 1225.0);
}
#[test]
fn test_shard_index_distribution() {
    // shard_index must be the identity below NUM_SHARDS and wrap modulo
    // NUM_SHARDS above it. The wrap cases are written in terms of NUM_SHARDS,
    // rather than the literals 16/17/32, so the test stays valid if the shard
    // count is changed. Two full wraps are checked for every shard.
    for i in 0..NUM_SHARDS {
        assert_eq!(ShardedVectors::shard_index(i), i);
        assert_eq!(ShardedVectors::shard_index(i + NUM_SHARDS), i);
        assert_eq!(ShardedVectors::shard_index(i + 2 * NUM_SHARDS), i);
    }
}
#[test]
fn test_sharded_vectors_concurrent_insert() {
    // Four writers insert disjoint id ranges [t*50, t*50 + 50) concurrently;
    // afterwards every insert must be accounted for.
    const THREADS: usize = 4;
    const PER_THREAD: usize = 50;
    let storage = Arc::new(ShardedVectors::new(32));
    let mut workers = Vec::with_capacity(THREADS);
    for t in 0..THREADS {
        let shared = Arc::clone(&storage);
        workers.push(thread::spawn(move || {
            let first = t * PER_THREAD;
            for id in first..first + PER_THREAD {
                shared.insert(id, &[id as f32; 32]);
            }
        }));
    }
    for worker in workers {
        worker.join().expect("Thread should not panic");
    }
    assert_eq!(storage.len(), THREADS * PER_THREAD);
}
#[test]
fn test_sharded_vectors_concurrent_read_write() {
    // Readers repeatedly query id 50 while writers append the fresh ids
    // 100..140. No writer touches id 50, so every read must observe exactly the
    // vector inserted during setup. Asserting on the results, instead of
    // discarding them, is what lets this test catch torn or lost reads.
    let storage = Arc::new(ShardedVectors::new(16));
    for i in 0..100 {
        storage.insert(i, &[i as f32; 16]);
    }
    let num_readers = 2;
    let num_writers = 2;
    let mut handles = vec![];
    for _ in 0..num_readers {
        let s = Arc::clone(&storage);
        handles.push(thread::spawn(move || {
            let expected = [50.0_f32; 16];
            for _ in 0..100 {
                assert_eq!(s.get(50).as_deref(), Some(&expected[..]));
                assert!(s.contains(50));
                assert_eq!(s.with_vector(50, <[f32]>::len), Some(16));
            }
        }));
    }
    for t in 0..num_writers {
        let s = Arc::clone(&storage);
        handles.push(thread::spawn(move || {
            let start = 100 + t * 20;
            for i in start..(start + 20) {
                s.insert(i, &[i as f32; 16]);
            }
        }));
    }
    for h in handles {
        h.join().expect("Thread should not panic");
    }
    // Writer ranges are disjoint from each other and from the initial 0..100.
    assert_eq!(storage.len(), 100 + num_writers * 20);
}
#[test]
fn test_sharded_vectors_parallel_batch_insert() {
    // Two threads each submit one batch over a disjoint id range; the final
    // count must equal the combined size of both batches.
    let storage = Arc::new(ShardedVectors::new(16));
    let (num_threads, batch_size) = (2, 50);
    let handles: Vec<thread::JoinHandle<()>> = (0..num_threads)
        .map(|t| {
            let shared = Arc::clone(&storage);
            let lo = t * batch_size;
            thread::spawn(move || {
                let batch = (lo..lo + batch_size)
                    .map(|id| (id, vec![id as f32; 16]))
                    .collect::<Vec<(usize, Vec<f32>)>>();
                shared.insert_batch(batch);
            })
        })
        .collect();
    handles
        .into_iter()
        .for_each(|h| h.join().expect("Thread should not panic"));
    assert_eq!(storage.len(), num_threads * batch_size);
}
#[test]
fn test_sharded_vectors_no_data_corruption() {
    // Each thread writes its own disjoint ids and immediately reads each one
    // back. Once all threads have joined, the main thread re-verifies every id.
    let storage = Arc::new(ShardedVectors::new(10));
    let num_threads = 4;
    let ops_per_thread = 50;
    let workers: Vec<_> = (0..num_threads)
        .map(|t| {
            let shared = Arc::clone(&storage);
            thread::spawn(move || {
                let base = t * ops_per_thread;
                for idx in base..base + ops_per_thread {
                    let expected = vec![idx as f32; 10];
                    shared.insert(idx, &expected);
                    assert_eq!(shared.get(idx), Some(expected), "Data corruption at idx {idx}");
                }
            })
        })
        .collect();
    for worker in workers {
        worker.join().expect("No data corruption");
    }
    (0..num_threads * ops_per_thread).for_each(|idx| {
        assert_eq!(storage.get(idx), Some(vec![idx as f32; 10]));
    });
}
#[test]
fn test_sharded_vectors_collect_for_parallel_returns_all() {
    // Every stored entry must be returned exactly once, with its full vector.
    let storage = ShardedVectors::new(4);
    for i in 0..100 {
        storage.insert(i, &[i as f32; 4]);
    }
    let collected = storage.collect_for_parallel();
    assert_eq!(collected.len(), 100);
    for (idx, values) in &collected {
        assert_eq!(values.len(), 4);
        // Check every component, not just the first one.
        assert!(values.iter().all(|&x| x == *idx as f32));
    }
    // A length check alone would accept a duplicated id in place of a missing
    // one, so pin the exact id set (order is irrelevant, hence the sort).
    let mut ids: Vec<_> = collected.iter().map(|(idx, _)| *idx).collect();
    ids.sort_unstable();
    assert_eq!(ids, (0..100).collect::<Vec<_>>());
}
#[test]
fn test_sharded_vectors_collect_for_parallel_empty() {
    // Snapshotting an empty store yields an empty collection.
    let empty = ShardedVectors::new(4);
    let snapshot = empty.collect_for_parallel();
    assert!(snapshot.is_empty());
}
#[test]
#[serial]
fn test_sharded_vectors_par_map_computes_correctly() {
    use rayon::prelude::*;
    // Each stored vector is four copies of its id, so its sum must be 4 * id.
    let storage = ShardedVectors::new(4);
    for id in 0..50 {
        storage.insert(id, &[id as f32; 4]);
    }
    let snapshot = storage.collect_for_parallel();
    let sums: Vec<(usize, f32)> = snapshot
        .par_iter()
        .map(|(id, values)| (*id, values.iter().sum::<f32>()))
        .collect();
    assert_eq!(sums.len(), 50);
    for &(id, total) in &sums {
        assert_eq!(total, id as f32 * 4.0);
    }
}
#[test]
#[serial]
fn test_sharded_vectors_par_filter_map_works() {
    use rayon::prelude::*;
    // Of ids 0..100, exactly the 50 even ones must survive the parallel filter.
    let storage = ShardedVectors::new(4);
    for id in 0..100 {
        storage.insert(id, &[id as f32; 4]);
    }
    let evens: Vec<usize> = storage
        .collect_for_parallel()
        .par_iter()
        .filter(|(id, _)| id % 2 == 0)
        .map(|(id, _)| *id)
        .collect();
    assert_eq!(evens.len(), 50);
    assert!(evens.iter().all(|id| id % 2 == 0));
}
#[test]
fn test_collect_into_reuses_buffer() {
    // collect_into should refill a caller-owned buffer without shrinking its
    // capacity, so one pre-sized buffer can serve repeated snapshots.
    let storage = ShardedVectors::new(4);
    for i in 0..50 {
        storage.insert(i, &[i as f32; 4]);
    }
    let mut buffer: Vec<(usize, Vec<f32>)> = Vec::with_capacity(100);
    storage.collect_into(&mut buffer);
    assert_eq!(buffer.len(), 50);
    assert!(buffer.capacity() >= 100);
    // Deliberately no manual `buffer.clear()` here. Reusing a still-populated
    // buffer is the realistic pattern, and collect_into must replace its
    // contents rather than append to them.
    storage.collect_into(&mut buffer);
    assert_eq!(buffer.len(), 50);
    assert!(buffer.capacity() >= 100);
}
#[test]
fn test_collect_into_clears_and_fills() {
    // Stale entries already in the buffer (sentinel id 999) must be discarded
    // and replaced by exactly the 20 stored entries.
    let storage = ShardedVectors::new(3);
    for id in 0..20 {
        storage.insert(id, &[id as f32; 3]);
    }
    let stale = (999, vec![0.0; 3]);
    let mut buffer: Vec<(usize, Vec<f32>)> = vec![stale; 5];
    storage.collect_into(&mut buffer);
    assert_eq!(buffer.len(), 20);
    assert!(buffer.iter().all(|(id, _)| *id != 999));
}
#[test]
fn test_collect_into_empty_storage() {
    // Collecting from an empty store must still wipe whatever the buffer held.
    let storage = ShardedVectors::new(1);
    let mut buffer = vec![(1_usize, vec![1.0_f32]); 10];
    storage.collect_into(&mut buffer);
    assert!(buffer.is_empty());
}
#[test]
fn test_collect_into_matches_collect_for_parallel() {
    // Both snapshot APIs must agree on which ids are present; their output
    // order may differ, so ids are compared after sorting.
    let storage = ShardedVectors::new(8);
    for id in 0..100 {
        storage.insert(id, &[id as f32; 8]);
    }
    let snapshot = storage.collect_for_parallel();
    let mut buffer = Vec::new();
    storage.collect_into(&mut buffer);
    assert_eq!(snapshot.len(), buffer.len());
    let mut from_snapshot: Vec<_> = snapshot.iter().map(|(id, _)| *id).collect();
    from_snapshot.sort_unstable();
    let mut from_buffer: Vec<_> = buffer.iter().map(|(id, _)| *id).collect();
    from_buffer.sort_unstable();
    assert_eq!(from_snapshot, from_buffer);
}