use crate::bloom_filter::BloomFilter;
use crate::distance::DistanceMetric;
use crate::error::ShardexError;
use crate::identifiers::{DocumentId, ShardId};
use crate::posting_storage::PostingStorage;
use crate::structures::{Posting, SearchResult};
use crate::vector_storage::VectorStorage;
use std::path::{Path, PathBuf};
use std::time::SystemTime;
/// Point-in-time bookkeeping for a shard: entry counts, disk footprint, and a
/// bloom filter for fast "might this shard contain document X?" checks.
#[derive(Debug, Clone, PartialEq)]
pub struct ShardMetadata {
// Wall-clock time the shard was created.
pub created_at: SystemTime,
// Total slots ever written, including tombstoned (deleted) entries.
pub current_count: usize,
// Entries not marked deleted.
pub active_count: usize,
// Approximate on-disk footprint in bytes (see `update_from_storages`).
pub disk_usage: usize,
// True when the backing storages were opened without write access.
pub read_only: bool,
// Probabilistic membership index over document IDs stored in this shard.
pub bloom_filter: BloomFilter,
}
/// A single index shard: a pair of parallel memory-backed storages (vectors and
/// posting metadata) that share slot indices, plus derived in-memory state
/// (centroid, counters, bloom filter) rebuilt on open.
pub struct Shard {
// Stable identifier; also names the on-disk files ("<id>.vectors" / "<id>.postings").
id: ShardId,
// Fixed-capacity storage for embedding vectors.
vector_storage: VectorStorage,
// Fixed-capacity storage for posting metadata; slot i pairs with vector slot i.
posting_storage: PostingStorage,
// Maximum number of postings this shard can hold.
capacity: usize,
// Dimensionality every stored vector must have.
vector_size: usize,
// Directory holding this shard's storage files.
directory: PathBuf,
// Cached counters / bloom filter, refreshed after each mutation.
metadata: ShardMetadata,
// Incrementally-maintained mean of the active vectors.
centroid: Vec<f32>,
// Number of vectors currently contributing to `centroid`.
active_vector_count: usize,
}
impl ShardMetadata {
    /// Serialized size of one posting entry, in bytes.
    ///
    /// NOTE(review): 16 + 4 + 4 + 1 presumably corresponds to a 16-byte
    /// document id, 4-byte start, 4-byte length, and a 1-byte deleted flag —
    /// confirm against `PostingStorage`'s actual on-disk layout.
    const POSTING_ENTRY_BYTES: usize = 16 + 4 + 4 + 1;

    /// Create fresh metadata for a shard of the given capacity.
    ///
    /// The bloom filter is sized for `capacity` entries at a 1% target
    /// false-positive rate.
    ///
    /// # Errors
    /// Propagates bloom-filter construction failures.
    pub fn new(read_only: bool, capacity: usize) -> Result<Self, ShardexError> {
        let bloom_filter = BloomFilter::new(capacity, 0.01)?;
        Ok(Self {
            created_at: SystemTime::now(),
            current_count: 0,
            active_count: 0,
            disk_usage: 0,
            read_only,
            bloom_filter,
        })
    }

    /// Refresh counters and the disk-usage estimate from both backing storages.
    ///
    /// The two storages can momentarily disagree (e.g. mid-write), so take the
    /// max of the current counts and the min of the active counts as the
    /// conservative view.
    pub fn update_from_storages(&mut self, vector_storage: &VectorStorage, posting_storage: &PostingStorage) {
        self.current_count = vector_storage
            .current_count()
            .max(posting_storage.current_count());
        self.active_count = vector_storage
            .active_count()
            .min(posting_storage.active_count());
        // Estimate is based on the fully pre-allocated file sizes, not bytes in use.
        let vector_file_size =
            std::mem::size_of::<f32>() * vector_storage.capacity() * vector_storage.vector_dimension();
        let posting_file_size = posting_storage.capacity() * Self::POSTING_ENTRY_BYTES;
        self.disk_usage = vector_file_size + posting_file_size;
    }

    /// Fraction of `capacity` occupied by active entries; 0.0 for zero capacity.
    pub fn utilization(&self, capacity: usize) -> f32 {
        if capacity == 0 {
            0.0
        } else {
            self.active_count as f32 / capacity as f32
        }
    }

    /// Time elapsed since creation; zero if the system clock moved backwards.
    pub fn age(&self) -> std::time::Duration {
        self.created_at.elapsed().unwrap_or_default()
    }
}
impl Shard {
/// Create a new shard on disk with pre-allocated vector and posting storage.
///
/// Both files are built under temporary `.tmp` names and then renamed into
/// place, so a crash mid-creation never leaves half-initialized
/// `.vectors`/`.postings` files behind. Returns the shard opened read-write.
///
/// # Errors
/// - `Config` if `capacity` or `vector_size` is zero
/// - `Io` / `Shard` on filesystem or storage-creation failures
pub fn create(id: ShardId, capacity: usize, vector_size: usize, directory: PathBuf) -> Result<Self, ShardexError> {
if capacity == 0 {
return Err(ShardexError::Config("Shard capacity cannot be zero".to_string()));
}
if vector_size == 0 {
return Err(ShardexError::Config("Vector size cannot be zero".to_string()));
}
std::fs::create_dir_all(&directory).map_err(ShardexError::Io)?;
// Final and temporary paths; the shard id names the files.
let vector_path = directory.join(format!("{}.vectors", id));
let posting_path = directory.join(format!("{}.postings", id));
let temp_vector_path = directory.join(format!("{}.vectors.tmp", id));
let temp_posting_path = directory.join(format!("{}.postings.tmp", id));
// Phase 1: materialize both storages under temp names; clean up on failure.
let _vector_storage = VectorStorage::create(&temp_vector_path, vector_size, capacity)
.map_err(|e| ShardexError::Shard(format!("Failed to create vector storage: {}", e)))?;
let _posting_storage = PostingStorage::create(&temp_posting_path, capacity).map_err(|e| {
let _ = std::fs::remove_file(&temp_vector_path);
ShardexError::Shard(format!("Failed to create posting storage: {}", e))
})?;
// Phase 2: rename into place; on failure remove whatever was left behind.
std::fs::rename(&temp_vector_path, &vector_path).map_err(|e| {
let _ = std::fs::remove_file(&temp_vector_path);
let _ = std::fs::remove_file(&temp_posting_path);
ShardexError::Io(e)
})?;
std::fs::rename(&temp_posting_path, &posting_path).map_err(|e| {
let _ = std::fs::remove_file(&vector_path);
let _ = std::fs::remove_file(&temp_posting_path);
ShardexError::Io(e)
})?;
// Reopen from the final paths so the handles reference the real files.
let vector_storage = VectorStorage::open(&vector_path)
.map_err(|e| ShardexError::Shard(format!("Failed to reopen vector storage: {}", e)))?;
let posting_storage = PostingStorage::open(&posting_path)
.map_err(|e| ShardexError::Shard(format!("Failed to reopen posting storage: {}", e)))?;
let mut metadata = ShardMetadata::new(false, capacity)?;
metadata.update_from_storages(&vector_storage, &posting_storage);
// New shard: centroid starts at the origin with no contributing vectors.
let centroid = vec![0.0; vector_size];
Ok(Self {
id,
vector_storage,
posting_storage,
capacity,
vector_size,
directory,
metadata,
centroid,
active_vector_count: 0,
})
}
/// Open an existing shard read-write from `directory`.
///
/// Capacity and vector size come from the files themselves; the two storages
/// must agree on capacity or the shard is reported as corrupted. Derived
/// in-memory state (centroid, active count, bloom filter) is rebuilt by
/// scanning the stored postings.
///
/// # Errors
/// - `Shard` if either file is missing or fails to open
/// - `Corruption` on a capacity mismatch between the storages
pub fn open(id: ShardId, directory: &Path) -> Result<Self, ShardexError> {
let vector_path = directory.join(format!("{}.vectors", id));
let posting_path = directory.join(format!("{}.postings", id));
if !vector_path.exists() {
return Err(ShardexError::Shard(format!(
"Vector storage file not found: {}",
vector_path.display()
)));
}
if !posting_path.exists() {
return Err(ShardexError::Shard(format!(
"Posting storage file not found: {}",
posting_path.display()
)));
}
let vector_storage = VectorStorage::open(&vector_path)
.map_err(|e| ShardexError::Shard(format!("Failed to open vector storage: {}", e)))?;
let posting_storage = PostingStorage::open(&posting_path)
.map_err(|e| ShardexError::Shard(format!("Failed to open posting storage: {}", e)))?;
// Geometry is authoritative from the vector file; cross-check the posting file.
let capacity = vector_storage.capacity();
let vector_size = vector_storage.vector_dimension();
if posting_storage.capacity() != capacity {
return Err(ShardexError::Corruption(format!(
"Capacity mismatch: vector storage has {}, posting storage has {}",
capacity,
posting_storage.capacity()
)));
}
// Read-only if either storage came back read-only.
let mut metadata = ShardMetadata::new(
vector_storage.is_read_only() || posting_storage.is_read_only(),
capacity,
)?;
metadata.update_from_storages(&vector_storage, &posting_storage);
let centroid = vec![0.0; vector_size];
let active_vector_count = vector_storage.active_count();
let mut shard = Self {
id,
vector_storage,
posting_storage,
capacity,
vector_size,
directory: directory.to_path_buf(),
metadata,
centroid,
active_vector_count,
};
// Rebuild derived state that is not persisted.
if active_vector_count > 0 {
shard.recalculate_centroid();
}
shard.populate_bloom_filter()?;
Ok(shard)
}
/// Open an existing shard in read-only mode.
///
/// Mirrors [`Shard::open`] (NOTE(review): the two are near-duplicates and
/// could share a helper) but opens both storages read-only and forces
/// `metadata.read_only = true`, so all mutating operations will be rejected.
///
/// # Errors
/// - `Shard` if either file is missing or fails to open
/// - `Corruption` on a capacity mismatch between the storages
pub fn open_read_only(id: ShardId, directory: &Path) -> Result<Self, ShardexError> {
let vector_path = directory.join(format!("{}.vectors", id));
let posting_path = directory.join(format!("{}.postings", id));
if !vector_path.exists() {
return Err(ShardexError::Shard(format!(
"Vector storage file not found: {}",
vector_path.display()
)));
}
if !posting_path.exists() {
return Err(ShardexError::Shard(format!(
"Posting storage file not found: {}",
posting_path.display()
)));
}
let vector_storage = VectorStorage::open_read_only(&vector_path)
.map_err(|e| ShardexError::Shard(format!("Failed to open vector storage: {}", e)))?;
let posting_storage = PostingStorage::open_read_only(&posting_path)
.map_err(|e| ShardexError::Shard(format!("Failed to open posting storage: {}", e)))?;
let capacity = vector_storage.capacity();
let vector_size = vector_storage.vector_dimension();
if posting_storage.capacity() != capacity {
return Err(ShardexError::Corruption(format!(
"Capacity mismatch: vector storage has {}, posting storage has {}",
capacity,
posting_storage.capacity()
)));
}
// Unconditionally read-only, regardless of what the storages report.
let mut metadata = ShardMetadata::new(true, capacity)?;
metadata.update_from_storages(&vector_storage, &posting_storage);
let centroid = vec![0.0; vector_size];
let active_vector_count = vector_storage.active_count();
let mut shard = Self {
id,
vector_storage,
posting_storage,
capacity,
vector_size,
directory: directory.to_path_buf(),
metadata,
centroid,
active_vector_count,
};
// Rebuild derived state that is not persisted.
if active_vector_count > 0 {
shard.recalculate_centroid();
}
shard.populate_bloom_filter()?;
Ok(shard)
}
/// Unique identifier of this shard.
pub fn id(&self) -> ShardId {
self.id
}
/// Maximum number of postings this shard can hold.
pub fn capacity(&self) -> usize {
self.capacity
}
/// Dimensionality of every vector stored in this shard.
pub fn vector_size(&self) -> usize {
self.vector_size
}
/// Directory containing this shard's storage files.
pub fn directory(&self) -> &Path {
&self.directory
}
/// Total slots written so far, including tombstoned (deleted) entries.
pub fn current_count(&self) -> usize {
self.metadata.current_count
}
/// Count active (non-deleted) postings by scanning deduplicated entries.
///
/// NOTE(review): this is an O(n) scan on every call, and it counts *unique*
/// postings, so it can differ from `metadata.active_count`, which is taken
/// straight from the storages. Read errors are silently treated as
/// "not active" (`unwrap_or(true)` and `Err => false`) — confirm that
/// swallowing errors here is intended.
pub fn active_count(&self) -> usize {
self.iter_unique_postings_backward()
.filter(|result| {
match result {
Ok((index, _)) => !self.is_deleted(*index).unwrap_or(true),
Err(_) => false, }
})
.count()
}
/// True when this shard was opened without write access.
pub fn is_read_only(&self) -> bool {
self.metadata.read_only
}
/// True once every slot (including tombstones) has been used.
pub fn is_full(&self) -> bool {
self.current_count() >= self.capacity()
}
/// Slots still available for new postings (never underflows).
pub fn remaining_capacity(&self) -> usize {
self.capacity().saturating_sub(self.current_count())
}
/// Alias for [`Self::remaining_capacity`].
pub fn available_capacity(&self) -> usize {
self.remaining_capacity()
}
/// Snapshot of this shard's bookkeeping metadata.
pub fn metadata(&self) -> &ShardMetadata {
&self.metadata
}
/// Probabilistic membership test: `false` means "definitely absent";
/// `true` means "possibly present" (bloom filters can give false positives).
/// NOTE(review): declared `async` but contains no `await` — presumably kept
/// for call-site symmetry; confirm before changing the signature.
pub async fn contains_document(&self, doc_id: DocumentId) -> bool {
self.metadata.bloom_filter.contains(doc_id)
}
/// Append a posting (vector + metadata) to the shard, returning its slot index.
///
/// The vector and posting storages are parallel arrays: both writes must land
/// in the same slot. If the posting write fails — or the two storages hand
/// back different indices — the partial write is rolled back.
///
/// # Errors
/// - `Config` if the shard is read-only
/// - `InvalidDimension` if the vector length does not match `vector_size`
/// - `Shard` if the shard is full or a storage write fails
/// - `Corruption` if the two storages disagree on the assigned slot
pub fn add_posting(&mut self, posting: Posting) -> Result<usize, ShardexError> {
if self.is_read_only() {
return Err(ShardexError::Config(
"Cannot add posting to read-only shard".to_string(),
));
}
if posting.vector.len() != self.vector_size {
return Err(ShardexError::InvalidDimension {
expected: self.vector_size,
actual: posting.vector.len(),
});
}
if self.is_full() {
return Err(ShardexError::Shard("Shard is at capacity".to_string()));
}
let vector_index = self
.vector_storage
.add_vector(&posting.vector)
.map_err(|e| ShardexError::Shard(format!("Failed to add vector: {}", e)))?;
let posting_index = self
.posting_storage
.add_posting(posting.document_id, posting.start, posting.length)
.map_err(|e| {
// Roll back the already-written vector so the arrays stay parallel.
let _ = self.vector_storage.remove_vector(vector_index);
ShardexError::Shard(format!("Failed to add posting: {}", e))
})?;
if vector_index != posting_index {
// Slot invariant violated: undo both writes (best-effort) and report.
let _ = self.vector_storage.remove_vector(vector_index);
let _ = self.posting_storage.remove_posting(posting_index);
return Err(ShardexError::Corruption(format!(
"Index mismatch: vector index {} != posting index {}",
vector_index, posting_index
)));
}
// Success: update derived state (bloom filter, centroid, counters).
self.metadata.bloom_filter.insert(posting.document_id);
self.update_centroid_add(&posting.vector);
self.metadata
.update_from_storages(&self.vector_storage, &self.posting_storage);
Ok(vector_index)
}
/// Reconstruct the full posting (metadata plus vector copy) stored at `index`.
///
/// # Errors
/// - `Config` for an out-of-range index
/// - `Shard` when either backing storage cannot be read
pub fn get_posting(&self, index: usize) -> Result<Posting, ShardexError> {
    let count = self.current_count();
    if index >= count {
        return Err(ShardexError::Config(format!(
            "Index {} out of bounds (current count: {})",
            index, count
        )));
    }
    // Fetch the vector first, then the posting metadata (mirrors write order).
    let stored_vector = self
        .vector_storage
        .get_vector(index)
        .map_err(|e| ShardexError::Shard(format!("Failed to get vector: {}", e)))?;
    let (document_id, start, length) = self
        .posting_storage
        .get_posting(index)
        .map_err(|e| ShardexError::Shard(format!("Failed to get posting: {}", e)))?;
    Posting::new(document_id, start, length, stored_vector.to_vec(), self.vector_size)
}
/// Iterate all stored postings from the highest slot index down to 0.
///
/// Each item is `Ok((index, posting))`, or the read error for that slot.
pub fn iter_postings_backward(&self) -> impl Iterator<Item = Result<(usize, Posting), ShardexError>> + '_ {
    (0..self.current_count())
        .rev()
        .map(move |slot| self.get_posting(slot).map(|posting| (slot, posting)))
}
/// Iterate postings newest-first, yielding only the first (most recent)
/// occurrence of each `(document_id, start, length)` triple.
///
/// Read errors are passed through unchanged.
pub fn iter_unique_postings_backward(&self) -> impl Iterator<Item = Result<(usize, Posting), ShardexError>> + '_ {
    use std::collections::HashSet;
    let mut seen = HashSet::new();
    self.iter_postings_backward().filter_map(move |entry| match entry {
        Err(e) => Some(Err(e)),
        Ok((slot, posting)) => {
            let fingerprint = (posting.document_id, posting.start, posting.length);
            if seen.insert(fingerprint) {
                Some(Ok((slot, posting)))
            } else {
                // Older duplicate of an already-seen posting: skip it.
                None
            }
        }
    })
}
/// Report whether the entry at `index` is tombstoned, cross-checking both
/// storages. Out-of-range indices report `Ok(false)`.
///
/// # Errors
/// - `Shard` when either storage cannot be queried
/// - `Corruption` when the two storages disagree about the deletion flag
pub fn is_deleted(&self, index: usize) -> Result<bool, ShardexError> {
    if index >= self.current_count() {
        return Ok(false);
    }
    let vector_deleted = self
        .vector_storage
        .is_deleted(index)
        .map_err(|e| ShardexError::Shard(format!("Failed to check vector deletion: {}", e)))?;
    let posting_deleted = self
        .posting_storage
        .is_deleted(index)
        .map_err(|e| ShardexError::Shard(format!("Failed to check posting deletion: {}", e)))?;
    if vector_deleted == posting_deleted {
        Ok(vector_deleted)
    } else {
        Err(ShardexError::Corruption(format!(
            "Deletion state mismatch at index {}: vector={}, posting={}",
            index, vector_deleted, posting_deleted
        )))
    }
}
/// Tombstone the posting at `index` (soft delete — the slot is not reclaimed).
///
/// The vector is copied *before* removal because the incremental centroid
/// update needs its values after the storage entry is gone.
///
/// # Errors
/// - `Config` if the shard is read-only or `index` is out of range
/// - `Shard` if either storage operation fails
pub fn remove_posting(&mut self, index: usize) -> Result<(), ShardexError> {
if self.is_read_only() {
return Err(ShardexError::Config(
"Cannot remove posting from read-only shard".to_string(),
));
}
if index >= self.current_count() {
return Err(ShardexError::Config(format!(
"Index {} out of bounds (current count: {})",
index,
self.current_count()
)));
}
// Snapshot the vector first; it is needed for the centroid update below.
let vector = self
.vector_storage
.get_vector(index)
.map_err(|e| ShardexError::Shard(format!("Failed to get vector for centroid update: {}", e)))?;
let vector_copy = vector.to_vec();
self.vector_storage
.remove_vector(index)
.map_err(|e| ShardexError::Shard(format!("Failed to remove vector: {}", e)))?;
self.posting_storage
.remove_posting(index)
.map_err(|e| ShardexError::Shard(format!("Failed to remove posting: {}", e)))?;
self.update_centroid_remove(&vector_copy);
self.metadata
.update_from_storages(&self.vector_storage, &self.posting_storage);
Ok(())
}
/// Tombstone every active posting belonging to `doc_id`; returns the count removed.
///
/// The bloom filter short-circuits documents that are definitely absent
/// (a false positive just means the scan below finds nothing). Runs two
/// passes: first collect the vectors of matching entries (needed for centroid
/// updates after removal), then remove them.
///
/// NOTE(review): the bloom filter is not rebuilt after removal, so it keeps
/// answering "possibly present" for removed documents until the next
/// `populate_bloom_filter` — presumably acceptable for a probabilistic filter.
///
/// # Errors
/// - `Config` if the shard is read-only
/// - `Shard` / `Corruption` on storage read/write failures
pub fn remove_document(&mut self, doc_id: DocumentId) -> Result<usize, ShardexError> {
if self.is_read_only() {
return Err(ShardexError::Config(
"Cannot remove documents from read-only shard".to_string(),
));
}
// Definitely not here — skip the full scan.
if !self.metadata.bloom_filter.contains(doc_id) {
return Ok(0); }
// Pass 1: snapshot the vectors that will be removed.
let mut removed_vectors = Vec::new();
for index in 0..self.current_count() {
if self.is_deleted(index)? {
continue;
}
let (posting_doc_id, _, _) = self
.posting_storage
.get_posting(index)
.map_err(|e| ShardexError::Shard(format!("Failed to get posting: {}", e)))?;
if posting_doc_id == doc_id {
let vector = self
.vector_storage
.get_vector(index)
.map_err(|e| ShardexError::Shard(format!("Failed to get vector: {}", e)))?;
removed_vectors.push(vector.to_vec());
}
}
// Pass 2: tombstone matches (highest index first).
let mut removed_count = 0;
for index in (0..self.current_count()).rev() {
if self.is_deleted(index)? {
continue;
}
let (posting_doc_id, _, _) = self
.posting_storage
.get_posting(index)
.map_err(|e| ShardexError::Shard(format!("Failed to get posting: {}", e)))?;
if posting_doc_id == doc_id {
self.vector_storage
.remove_vector(index)
.map_err(|e| ShardexError::Shard(format!("Failed to remove vector: {}", e)))?;
self.posting_storage
.remove_posting(index)
.map_err(|e| ShardexError::Shard(format!("Failed to remove posting: {}", e)))?;
removed_count += 1;
}
}
// Fold the snapshotted vectors out of the incremental centroid.
for vector in &removed_vectors {
self.update_centroid_remove(vector);
}
self.metadata
.update_from_storages(&self.vector_storage, &self.posting_storage);
Ok(removed_count)
}
/// Exhaustive k-nearest search over this shard using cosine similarity
/// (mapped to [0, 1]; see `calculate_cosine_similarity`).
///
/// Returns up to `k` results, best first. `k == 0` yields an empty list.
///
/// # Errors
/// - `InvalidDimension` if the query length differs from `vector_size`
/// - `Shard` on storage read failures
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>, ShardexError> {
    if query.len() != self.vector_size {
        return Err(ShardexError::InvalidDimension {
            expected: self.vector_size,
            actual: query.len(),
        });
    }
    if k == 0 {
        return Ok(Vec::new());
    }
    // Score every unique posting; errors abort the whole search.
    let mut scored = Vec::new();
    for entry in self.iter_unique_postings_backward() {
        let (slot, posting) = entry?;
        let vector = self
            .vector_storage
            .get_vector(slot)
            .map_err(|e| ShardexError::Shard(format!("Failed to get vector: {}", e)))?;
        let score = Self::calculate_cosine_similarity(query, vector);
        scored.push(SearchResult::from_posting(posting, score)?);
    }
    // Highest similarity first; incomparable (NaN) scores compare as equal.
    scored.sort_by(|lhs, rhs| {
        rhs.similarity_score
            .partial_cmp(&lhs.similarity_score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    scored.truncate(k);
    Ok(scored)
}
/// Cosine similarity of `a` and `b`, rescaled from [-1, 1] to [0, 1].
///
/// Degenerate inputs — any non-finite component, a zero-length vector, or a
/// non-finite norm/quotient — score a neutral 0.5.
fn calculate_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");
    const NEUTRAL: f32 = 0.5;
    if !a.iter().all(|x| x.is_finite()) || !b.iter().all(|x| x.is_finite()) {
        return NEUTRAL;
    }
    // Single pass accumulates the dot product and both squared norms.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 || !norm_a.is_finite() || !norm_b.is_finite() {
        return NEUTRAL;
    }
    let cosine = dot / (norm_a * norm_b);
    if !cosine.is_finite() {
        return NEUTRAL;
    }
    // Clamp guards tiny floating-point overshoot, then map [-1, 1] -> [0, 1].
    (cosine.clamp(-1.0, 1.0) + 1.0) / 2.0
}
/// Exhaustive k-nearest search using a caller-supplied distance metric.
///
/// Identical to [`Shard::search`] except that scoring is delegated to
/// `metric.similarity`, which may fail (its errors are propagated).
///
/// # Errors
/// - `InvalidDimension` if the query length differs from `vector_size`
/// - `Shard` on storage read failures, plus any error from the metric
pub fn search_with_metric(
&self,
query: &[f32],
k: usize,
metric: DistanceMetric,
) -> Result<Vec<SearchResult>, ShardexError> {
if query.len() != self.vector_size {
return Err(ShardexError::InvalidDimension {
expected: self.vector_size,
actual: query.len(),
});
}
if k == 0 {
return Ok(Vec::new());
}
let mut results = Vec::new();
for result in self.iter_unique_postings_backward() {
let (index, posting) = result?;
let vector = self
.vector_storage
.get_vector(index)
.map_err(|e| ShardexError::Shard(format!("Failed to get vector: {}", e)))?;
let similarity_score = metric.similarity(query, vector)?;
let search_result = SearchResult::from_posting(posting, similarity_score)?;
results.push(search_result);
}
// Best score first; incomparable (NaN) scores compare as equal.
results.sort_by(|a, b| {
b.similarity_score
.partial_cmp(&a.similarity_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(k);
Ok(results)
}
/// Flush both backing storages to disk and refresh the cached metadata.
///
/// A read-only shard has nothing to flush, so this is a silent no-op.
///
/// # Errors
/// `Shard` if either storage fails to sync.
pub fn sync(&mut self) -> Result<(), ShardexError> {
if self.is_read_only() {
return Ok(());
}
self.vector_storage
.sync()
.map_err(|e| ShardexError::Shard(format!("Failed to sync vector storage: {}", e)))?;
self.posting_storage
.sync()
.map_err(|e| ShardexError::Shard(format!("Failed to sync posting storage: {}", e)))?;
self.metadata
.update_from_storages(&self.vector_storage, &self.posting_storage);
Ok(())
}
/// Deep consistency check across both storages and the derived in-memory state.
///
/// Verifies, in order: each storage's own integrity; matching total and active
/// counts; per-slot deletion-flag agreement plus readability of live entries;
/// centroid dimension; the cached active-vector count; and finally that the
/// incrementally-maintained centroid matches a full recomputation within a
/// small floating-point tolerance.
///
/// # Errors
/// `Shard` for storage-level failures, `Corruption` for any invariant violation.
pub fn validate_integrity(&self) -> Result<(), ShardexError> {
self.vector_storage
.validate_integrity()
.map_err(|e| ShardexError::Shard(format!("Vector storage integrity failed: {}", e)))?;
self.posting_storage
.validate_integrity()
.map_err(|e| ShardexError::Shard(format!("Posting storage integrity failed: {}", e)))?;
// The parallel arrays must agree on how many slots are in use...
let vector_count = self.vector_storage.current_count();
let posting_count = self.posting_storage.current_count();
if vector_count != posting_count {
return Err(ShardexError::Corruption(format!(
"Count mismatch: vector storage has {}, posting storage has {}",
vector_count, posting_count
)));
}
// ...and on how many of those are still active.
let vector_active = self.vector_storage.active_count();
let posting_active = self.posting_storage.active_count();
if vector_active != posting_active {
return Err(ShardexError::Corruption(format!(
"Active count mismatch: vector storage has {}, posting storage has {}",
vector_active, posting_active
)));
}
// Per-slot: deletion flags must agree, and live entries must be readable.
for i in 0..vector_count {
let vector_deleted = self
.vector_storage
.is_deleted(i)
.map_err(|e| ShardexError::Shard(format!("Failed to check vector deletion at {}: {}", i, e)))?;
let posting_deleted = self
.posting_storage
.is_deleted(i)
.map_err(|e| ShardexError::Shard(format!("Failed to check posting deletion at {}: {}", i, e)))?;
if vector_deleted != posting_deleted {
return Err(ShardexError::Corruption(format!(
"Deletion state mismatch at index {}: vector={}, posting={}",
i, vector_deleted, posting_deleted
)));
}
if !vector_deleted {
let _vector = self
.vector_storage
.get_vector(i)
.map_err(|e| ShardexError::Corruption(format!("Cannot read vector at index {}: {}", i, e)))?;
let _posting = self
.posting_storage
.get_posting(i)
.map_err(|e| ShardexError::Corruption(format!("Cannot read posting at index {}: {}", i, e)))?;
}
}
// Recompute the centroid from scratch to compare against the cached one.
let expected_centroid = self.calculate_centroid();
if self.centroid.len() != self.vector_size {
return Err(ShardexError::Corruption(format!(
"Centroid dimension mismatch: expected {}, got {}",
self.vector_size,
self.centroid.len()
)));
}
let actual_active_count = vector_active;
if self.active_vector_count != actual_active_count {
return Err(ShardexError::Corruption(format!(
"Active vector count mismatch: stored {}, actual {}",
self.active_vector_count, actual_active_count
)));
}
// Incremental updates accumulate rounding error; allow a small tolerance.
const CENTROID_TOLERANCE: f32 = 1e-5;
for (i, (&stored, &expected)) in self
.centroid
.iter()
.zip(expected_centroid.iter())
.enumerate()
{
let diff = (stored - expected).abs();
if diff > CENTROID_TOLERANCE {
return Err(ShardexError::Corruption(format!(
"Centroid component {} mismatch: stored {}, calculated {} (diff: {})",
i, stored, expected, diff
)));
}
}
Ok(())
}
/// Compute the mean of all readable unique-posting vectors from scratch.
///
/// Iteration errors and unreadable vectors are skipped silently; with nothing
/// readable the zero vector is returned.
pub fn calculate_centroid(&self) -> Vec<f32> {
    let mut accumulator = vec![0.0; self.vector_size];
    let mut contributing = 0usize;
    for entry in self.iter_unique_postings_backward() {
        if let Ok((slot, _)) = entry {
            if let Ok(vector) = self.vector_storage.get_vector(slot) {
                for (acc, &component) in accumulator.iter_mut().zip(vector.iter()) {
                    *acc += component;
                }
                contributing += 1;
            }
        }
    }
    if contributing > 0 {
        let divisor = contributing as f32;
        for component in &mut accumulator {
            *component /= divisor;
        }
    }
    accumulator
}
/// Borrow the incrementally-maintained centroid of the active vectors.
pub fn get_centroid(&self) -> &[f32] {
&self.centroid
}
/// Fold a newly added vector into the running centroid (incremental mean:
/// `c += (v - c) / n` after bumping the count to `n`).
pub fn update_centroid_add(&mut self, vector: &[f32]) {
    debug_assert_eq!(vector.len(), self.vector_size, "Vector dimension mismatch");
    self.active_vector_count += 1;
    // First contributor defines the centroid outright.
    if self.active_vector_count == 1 {
        self.centroid.copy_from_slice(vector);
        return;
    }
    let count = self.active_vector_count as f32;
    for (component, &incoming) in self.centroid.iter_mut().zip(vector.iter()) {
        *component += (incoming - *component) / count;
    }
}
/// Fold a removed vector out of the running centroid (inverse incremental
/// mean: `c = (c * n_old - v) / n_new`). No-op when nothing contributes.
pub fn update_centroid_remove(&mut self, vector: &[f32]) {
    debug_assert_eq!(vector.len(), self.vector_size, "Vector dimension mismatch");
    if self.active_vector_count == 0 {
        return; }
    // Removing the last contributor resets the centroid to the origin.
    if self.active_vector_count == 1 {
        for value in &mut self.centroid {
            *value = 0.0;
        }
        self.active_vector_count = 0;
        return;
    }
    let old_count = self.active_vector_count as f32;
    self.active_vector_count -= 1;
    let new_count = self.active_vector_count as f32;
    for (component, &outgoing) in self.centroid.iter_mut().zip(vector.iter()) {
        *component = (*component * old_count - outgoing) / new_count;
    }
}
/// Rebuild the centroid and the active-vector count from scratch by scanning
/// the unique postings (errors are not counted).
pub fn recalculate_centroid(&mut self) {
    self.centroid = self.calculate_centroid();
    self.active_vector_count = self
        .iter_unique_postings_backward()
        .filter(|entry| entry.is_ok())
        .count();
}
/// Rebuild the bloom filter from every unique posting's document id.
///
/// The postings are collected first so the `&self` iterator borrow ends
/// before the `&mut self` inserts begin.
fn populate_bloom_filter(&mut self) -> Result<(), ShardexError> {
    self.metadata.bloom_filter.clear();
    let entries = self
        .iter_unique_postings_backward()
        .collect::<Result<Vec<_>, ShardexError>>()?;
    for (_slot, posting) in entries {
        self.metadata.bloom_filter.insert(posting.document_id);
    }
    Ok(())
}
/// True once the shard reaches 90% of capacity (counting tombstones) and a
/// split should be scheduled. Zero-capacity shards never split.
pub fn should_split(&self) -> bool {
    match self.capacity {
        0 => false,
        cap => self.current_count() >= (cap as f64 * 0.9) as usize,
    }
}
/// Partition this shard's unique postings into two freshly created shards.
///
/// Clustering is delegated to `cluster_unique_vectors`; each cluster is copied
/// into a new shard in the same directory. The source shard is left untouched
/// — retiring it is the caller's responsibility.
///
/// NOTE(review): declared `async` although no `await` occurs — kept for
/// interface stability with async call sites.
///
/// # Errors
/// - `Config` if the shard is read-only or has fewer than 2 active postings
/// - storage/clustering errors bubbled up from reads and writes
pub async fn split(&self) -> Result<(Shard, Shard), ShardexError> {
    use std::collections::HashMap;
    if self.is_read_only() {
        return Err(ShardexError::Config("Cannot split read-only shard".to_string()));
    }
    if self.active_count() < 2 {
        return Err(ShardexError::Config(
            "Cannot split shard with less than 2 active postings".to_string(),
        ));
    }
    let unique_postings: Vec<(usize, Posting)> = self
        .iter_unique_postings_backward()
        .collect::<Result<_, ShardexError>>()?;
    let indices: Vec<usize> = unique_postings.iter().map(|(index, _)| *index).collect();
    let (cluster_a_indices, cluster_b_indices) = self.cluster_unique_vectors(&indices)?;
    // Size the children at half the parent capacity, but always leave head-room
    // for the larger cluster when duplicates inflated the parent's slot count.
    let unique_count = unique_postings.len();
    let has_duplicates = self.current_count() > unique_count;
    let new_capacity = if has_duplicates {
        let min_capacity_needed = (unique_count + 1) / 2 + 5;
        std::cmp::max(self.capacity / 2, min_capacity_needed)
    } else {
        std::cmp::max(self.capacity / 2, 5)
    };
    let mut shard_a = Shard::create(ShardId::new(), new_capacity, self.vector_size, self.directory.clone())?;
    let mut shard_b = Shard::create(ShardId::new(), new_capacity, self.vector_size, self.directory.clone())?;
    // O(1) posting lookup per cluster index instead of a linear scan each time.
    let posting_by_index: HashMap<usize, &Posting> =
        unique_postings.iter().map(|(index, posting)| (*index, posting)).collect();
    for &index in &cluster_a_indices {
        if let Some(posting) = posting_by_index.get(&index) {
            shard_a.add_posting((*posting).clone())?;
        }
    }
    for &index in &cluster_b_indices {
        if let Some(posting) = posting_by_index.get(&index) {
            shard_b.add_posting((*posting).clone())?;
        }
    }
    Ok((shard_a, shard_b))
}
/// Partition `indices` into two clusters for a shard split.
///
/// Strategy:
/// 1. Tiny sets (<= 4 entries) and sets of numerically identical vectors fall
///    back to a positional half/half split.
/// 2. Otherwise, seed two centroids with the furthest-apart pair of vectors
///    and assign every vector to the nearer seed (one k-means-style pass).
///
/// The previous implementation ran the assignment loop twice but never moved
/// the seed centroids, so the second pass always reproduced the first; the
/// dead loop was removed. The cheap O(n) identity check now also runs before
/// the O(n^2) furthest-pair scan so degenerate shards skip the heavy work.
/// Always returns two non-empty clusters.
///
/// # Errors
/// - `Config` with fewer than 2 indices
/// - `Shard` when a vector cannot be read
fn cluster_unique_vectors(&self, indices: &[usize]) -> Result<(Vec<usize>, Vec<usize>), ShardexError> {
    if indices.len() < 2 {
        return Err(ShardexError::Config(
            "Need at least 2 vectors for clustering".to_string(),
        ));
    }
    if indices.len() <= 4 {
        let mid = indices.len() / 2;
        return Ok((indices[..mid].to_vec(), indices[mid..].to_vec()));
    }
    // Identity check: if every vector is (numerically) the same, distance-based
    // clustering is meaningless — split positionally.
    let first_vector = self
        .vector_storage
        .get_vector(indices[0])
        .map_err(|e| ShardexError::Shard(format!("Failed to get first vector: {}", e)))?;
    let mut all_identical = true;
    for &idx in &indices[1..] {
        let vector = self
            .vector_storage
            .get_vector(idx)
            .map_err(|e| ShardexError::Shard(format!("Failed to get vector for identity check: {}", e)))?;
        if Self::euclidean_distance(first_vector, vector) > 1e-6 {
            all_identical = false;
            break;
        }
    }
    if all_identical {
        let mid = indices.len() / 2;
        return Ok((indices[..mid].to_vec(), indices[mid..].to_vec()));
    }
    // Seed the two clusters with the most distant pair of vectors.
    let (centroid_a_idx, centroid_b_idx) = self.find_furthest_pair(indices)?;
    let centroid_a = self
        .vector_storage
        .get_vector(centroid_a_idx)
        .map_err(|e| ShardexError::Shard(format!("Failed to get centroid A vector: {}", e)))?
        .to_vec();
    let centroid_b = self
        .vector_storage
        .get_vector(centroid_b_idx)
        .map_err(|e| ShardexError::Shard(format!("Failed to get centroid B vector: {}", e)))?
        .to_vec();
    // Single assignment pass: each vector joins the nearer seed.
    let mut cluster_a_indices = Vec::new();
    let mut cluster_b_indices = Vec::new();
    for &idx in indices {
        let vector = self
            .vector_storage
            .get_vector(idx)
            .map_err(|e| ShardexError::Shard(format!("Failed to get vector during clustering: {}", e)))?;
        let dist_a = Self::euclidean_distance(&centroid_a, vector);
        let dist_b = Self::euclidean_distance(&centroid_b, vector);
        // Ties go to cluster A.
        if dist_a <= dist_b {
            cluster_a_indices.push(idx);
        } else {
            cluster_b_indices.push(idx);
        }
    }
    // Safety net: never hand back an empty cluster.
    if cluster_a_indices.is_empty() || cluster_b_indices.is_empty() {
        let mid = indices.len() / 2;
        return Ok((indices[..mid].to_vec(), indices[mid..].to_vec()));
    }
    Ok((cluster_a_indices, cluster_b_indices))
}
/// Find the pair of stored vectors (among `indices`) with the greatest
/// Euclidean distance, used as the initial cluster seeds for a split.
///
/// O(n^2) pairwise scan. The outer vector is now fetched once per outer
/// iteration instead of once per inner iteration (the previous version
/// re-read it for every pair).
///
/// # Errors
/// - `Config` with fewer than 2 indices
/// - `Shard` when a vector cannot be read
fn find_furthest_pair(&self, indices: &[usize]) -> Result<(usize, usize), ShardexError> {
    if indices.len() < 2 {
        return Err(ShardexError::Config(
            "Need at least 2 vectors to find furthest pair".to_string(),
        ));
    }
    let mut max_distance = 0.0;
    let mut furthest_pair = (indices[0], indices[1]);
    // The last index never starts a pair, so stop the outer loop one short.
    for i in 0..indices.len() - 1 {
        // Hoisted: loop-invariant for the inner loop below.
        let vector_i = self
            .vector_storage
            .get_vector(indices[i])
            .map_err(|e| ShardexError::Shard(format!("Failed to get vector for furthest pair: {}", e)))?;
        for j in i + 1..indices.len() {
            let vector_j = self
                .vector_storage
                .get_vector(indices[j])
                .map_err(|e| ShardexError::Shard(format!("Failed to get vector for furthest pair: {}", e)))?;
            let distance = Self::euclidean_distance(vector_i, vector_j);
            if distance > max_distance {
                max_distance = distance;
                furthest_pair = (indices[i], indices[j]);
            }
        }
    }
    Ok(furthest_pair)
}
/// Euclidean (L2) distance between two equal-length vectors.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");
    let mut sum_of_squares = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        let delta = x - y;
        sum_of_squares += delta * delta;
    }
    sum_of_squares.sqrt()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::identifiers::DocumentId;
use tempfile::TempDir;
#[test]
// Creating a shard yields empty counters and materializes both backing files.
fn test_shard_creation() {
let temp_dir = TempDir::new().unwrap();
let shard_id = ShardId::new();
let shard = Shard::create(shard_id, 100, 128, temp_dir.path().to_path_buf()).unwrap();
assert_eq!(shard.id(), shard_id);
assert_eq!(shard.capacity(), 100);
assert_eq!(shard.vector_size(), 128);
assert_eq!(shard.current_count(), 0);
assert_eq!(shard.active_count(), 0);
assert!(!shard.is_read_only());
assert!(!shard.is_full());
assert_eq!(shard.remaining_capacity(), 100);
// The temp-then-rename creation flow must leave the final files on disk.
let vector_path = temp_dir.path().join(format!("{}.vectors", shard_id));
let posting_path = temp_dir.path().join(format!("{}.postings", shard_id));
assert!(vector_path.exists());
assert!(posting_path.exists());
}
#[test]
// Zero capacity and zero vector size are both rejected at creation time.
fn test_shard_creation_validation() {
let temp_dir = TempDir::new().unwrap();
let shard_id = ShardId::new();
let result = Shard::create(shard_id, 0, 128, temp_dir.path().to_path_buf());
assert!(result.is_err());
let result = Shard::create(shard_id, 100, 0, temp_dir.path().to_path_buf());
assert!(result.is_err());
}
#[test]
// A stored posting round-trips unchanged (id, range, and vector).
fn test_shard_add_and_get_posting() {
let temp_dir = TempDir::new().unwrap();
let shard_id = ShardId::new();
let mut shard = Shard::create(shard_id, 10, 3, temp_dir.path().to_path_buf()).unwrap();
let doc_id = DocumentId::new();
let vector = vec![1.0, 2.0, 3.0];
let posting = Posting::new(doc_id, 100, 50, vector.clone(), 3).unwrap();
let index = shard.add_posting(posting.clone()).unwrap();
// First posting lands in slot 0 and bumps both counters.
assert_eq!(index, 0);
assert_eq!(shard.current_count(), 1);
assert_eq!(shard.active_count(), 1);
let retrieved = shard.get_posting(index).unwrap();
assert_eq!(retrieved.document_id, posting.document_id);
assert_eq!(retrieved.start, posting.start);
assert_eq!(retrieved.length, posting.length);
assert_eq!(retrieved.vector, posting.vector);
}
#[test]
// Adding a vector whose dimension differs from the shard's is rejected with
// InvalidDimension carrying both the expected and actual sizes.
fn test_shard_vector_dimension_validation() {
let temp_dir = TempDir::new().unwrap();
let shard_id = ShardId::new();
let mut shard = Shard::create(shard_id, 10, 3, temp_dir.path().to_path_buf()).unwrap();
let doc_id = DocumentId::new();
let wrong_vector = vec![1.0, 2.0]; let posting = Posting::new(doc_id, 100, 50, wrong_vector, 2).unwrap();
let result = shard.add_posting(posting);
match result {
Err(ShardexError::InvalidDimension { expected, actual }) => {
assert_eq!(expected, 3);
assert_eq!(actual, 2);
}
_ => panic!("Expected InvalidDimension error"),
}
}
#[test]
// Filling the shard to capacity flips is_full and makes further adds fail.
fn test_shard_capacity_limits() {
let temp_dir = TempDir::new().unwrap();
let shard_id = ShardId::new();
let mut shard = Shard::create(
shard_id,
2, 2,
temp_dir.path().to_path_buf(),
)
.unwrap();
let doc_id = DocumentId::new();
let vector = vec![1.0, 2.0];
let posting1 = Posting::new(doc_id, 100, 50, vector.clone(), 2).unwrap();
shard.add_posting(posting1).unwrap();
let posting2 = Posting::new(doc_id, 200, 75, vector.clone(), 2).unwrap();
shard.add_posting(posting2).unwrap();
assert!(shard.is_full());
assert_eq!(shard.remaining_capacity(), 0);
// Third insert must be refused with a Shard error.
let posting3 = Posting::new(doc_id, 300, 25, vector, 2).unwrap();
let result = shard.add_posting(posting3);
assert!(matches!(result, Err(ShardexError::Shard(_))));
}
#[test]
// Removal is a soft delete: current_count stays, active_count drops, and the
// surviving posting remains readable.
fn test_shard_remove_posting() {
let temp_dir = TempDir::new().unwrap();
let shard_id = ShardId::new();
let mut shard = Shard::create(shard_id, 10, 3, temp_dir.path().to_path_buf()).unwrap();
let doc_id1 = DocumentId::new();
let doc_id2 = DocumentId::new();
let vector = vec![1.0, 2.0, 3.0];
let posting1 = Posting::new(doc_id1, 100, 50, vector.clone(), 3).unwrap();
let posting2 = Posting::new(doc_id2, 200, 75, vector, 3).unwrap();
let idx1 = shard.add_posting(posting1).unwrap();
let idx2 = shard.add_posting(posting2).unwrap();
assert_eq!(shard.active_count(), 2);
shard.remove_posting(idx1).unwrap();
assert_eq!(shard.current_count(), 2); assert_eq!(shard.active_count(), 1); assert!(shard.is_deleted(idx1).unwrap());
assert!(!shard.is_deleted(idx2).unwrap());
let retrieved = shard.get_posting(idx2).unwrap();
assert_eq!(retrieved.document_id, doc_id2);
}
#[test]
// Data written and synced in one shard instance is fully readable after
// dropping the instance and reopening the shard from disk.
fn test_shard_persistence() {
let temp_dir = TempDir::new().unwrap();
let shard_id = ShardId::new();
let directory = temp_dir.path().to_path_buf();
let postings_to_add = vec![
(DocumentId::new(), 100, 50, vec![1.0, 2.0, 3.0]),
(DocumentId::new(), 200, 75, vec![4.0, 5.0, 6.0]),
(DocumentId::new(), 300, 25, vec![7.0, 8.0, 9.0]),
];
// Scope 1: create, populate, sync, then drop the shard handle.
{
let mut shard = Shard::create(shard_id, 10, 3, directory.clone()).unwrap();
for (doc_id, start, length, vector) in &postings_to_add {
let posting = Posting::new(*doc_id, *start, *length, vector.clone(), 3).unwrap();
shard.add_posting(posting).unwrap();
}
shard.sync().unwrap();
}
// Scope 2: reopen from disk and verify every posting survived intact.
{
let shard = Shard::open(shard_id, &directory).unwrap();
assert_eq!(shard.id(), shard_id);
assert_eq!(shard.capacity(), 10);
assert_eq!(shard.vector_size(), 3);
assert_eq!(shard.current_count(), 3);
assert_eq!(shard.active_count(), 3);
for (i, (expected_doc_id, expected_start, expected_length, expected_vector)) in
postings_to_add.iter().enumerate()
{
let retrieved = shard.get_posting(i).unwrap();
assert_eq!(retrieved.document_id, *expected_doc_id);
assert_eq!(retrieved.start, *expected_start);
assert_eq!(retrieved.length, *expected_length);
assert_eq!(retrieved.vector, *expected_vector);
}
}
}
#[test]
// A shard reopened via open_read_only serves reads but rejects every mutation.
fn test_shard_read_only_mode() {
let temp_dir = TempDir::new().unwrap();
let shard_id = ShardId::new();
let directory = temp_dir.path().to_path_buf();
// Write one posting with a read-write handle, then drop it.
{
let mut shard = Shard::create(shard_id, 5, 2, directory.clone()).unwrap();
let doc_id = DocumentId::new();
let vector = vec![1.0, 2.0];
let posting = Posting::new(doc_id, 100, 50, vector, 2).unwrap();
shard.add_posting(posting).unwrap();
shard.sync().unwrap();
}
// Reopen read-only: reads succeed, add/remove fail.
{
let mut shard = Shard::open_read_only(shard_id, &directory).unwrap();
assert!(shard.is_read_only());
assert_eq!(shard.current_count(), 1);
let retrieved = shard.get_posting(0).unwrap();
assert_eq!(retrieved.start, 100);
assert_eq!(retrieved.length, 50);
let new_doc_id = DocumentId::new();
let new_vector = vec![3.0, 4.0];
let new_posting = Posting::new(new_doc_id, 200, 75, new_vector, 2).unwrap();
assert!(shard.add_posting(new_posting).is_err());
assert!(shard.remove_posting(0).is_err());
}
}
#[test]
fn test_shard_integrity_validation() {
    // Integrity validation must pass on a shard containing tombstoned entries.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    for i in 0..5 {
        let posting = Posting::new(
            DocumentId::new(),
            i * 100,
            50,
            vec![i as f32, (i + 1) as f32, (i + 2) as f32],
            3,
        )
        .unwrap();
        shard.add_posting(posting).unwrap();
    }
    shard.remove_posting(1).unwrap();
    shard.remove_posting(3).unwrap();
    shard.validate_integrity().unwrap();
    // Removals are soft deletes: total slots stay at 5, active drops to 3.
    assert_eq!(shard.current_count(), 5);
    assert_eq!(shard.active_count(), 3);
}
#[test]
fn test_shard_open_nonexistent() {
    // Opening a shard id that was never created must yield a Shard error.
    let temp_dir = TempDir::new().unwrap();
    let missing = Shard::open(ShardId::new(), temp_dir.path());
    assert!(matches!(missing, Err(ShardexError::Shard(_))));
}
#[test]
fn test_shard_metadata() {
    // Metadata counters must track additions and report non-zero utilization.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    let empty_meta = shard.metadata();
    assert_eq!(empty_meta.current_count, 0);
    assert_eq!(empty_meta.active_count, 0);
    assert!(!empty_meta.read_only);
    // Freshly created shard should report a near-zero age.
    assert!(empty_meta.age().as_secs() < 2);
    let posting = Posting::new(DocumentId::new(), 100, 50, vec![1.0, 2.0, 3.0], 3).unwrap();
    shard.add_posting(posting).unwrap();
    let populated_meta = shard.metadata();
    assert_eq!(populated_meta.current_count, 1);
    assert_eq!(populated_meta.active_count, 1);
    assert!(populated_meta.utilization(shard.capacity()) > 0.0);
}
#[test]
fn test_shard_out_of_bounds() {
    // Reads past the populated range must fail with a Config error.
    let temp_dir = TempDir::new().unwrap();
    let shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    for index in [0, 5] {
        let lookup = shard.get_posting(index);
        assert!(matches!(lookup, Err(ShardexError::Config(_))));
    }
}
#[test]
fn test_shard_available_capacity() {
    // Capacity accounting must decrease with each add until the shard is full.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 5, 2, temp_dir.path().to_path_buf()).unwrap();
    assert_eq!(shard.available_capacity(), 5);
    assert_eq!(shard.remaining_capacity(), 5);
    let first = Posting::new(DocumentId::new(), 100, 50, vec![1.0, 2.0], 2).unwrap();
    shard.add_posting(first).unwrap();
    assert_eq!(shard.available_capacity(), 4);
    assert_eq!(shard.remaining_capacity(), 4);
    // Fill the remaining four slots.
    for i in 1..5 {
        let filler =
            Posting::new(DocumentId::new(), i * 100, 50, vec![i as f32, (i + 1) as f32], 2)
                .unwrap();
        shard.add_posting(filler).unwrap();
    }
    assert_eq!(shard.available_capacity(), 0);
    assert!(shard.is_full());
}
#[test]
fn test_shard_remove_document_single() {
    // remove_document must tombstone every posting of the given document
    // while leaving other documents untouched.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    let target_doc = DocumentId::new();
    let other_doc = DocumentId::new();
    shard
        .add_posting(Posting::new(target_doc, 100, 50, vec![1.0, 2.0, 3.0], 3).unwrap())
        .unwrap();
    shard
        .add_posting(Posting::new(other_doc, 200, 75, vec![4.0, 5.0, 6.0], 3).unwrap())
        .unwrap();
    shard
        .add_posting(Posting::new(target_doc, 300, 25, vec![7.0, 8.0, 9.0], 3).unwrap())
        .unwrap();
    assert_eq!(shard.active_count(), 3);
    let removed = shard.remove_document(target_doc).unwrap();
    assert_eq!(removed, 2);
    // Soft delete: slots remain, only the target's entries are marked deleted.
    assert_eq!(shard.current_count(), 3);
    assert_eq!(shard.active_count(), 1);
    assert!(!shard.is_deleted(1).unwrap());
    assert!(shard.is_deleted(0).unwrap());
    assert!(shard.is_deleted(2).unwrap());
}
#[test]
fn test_shard_remove_document_multiple() {
    // Removing a document that owns many postings tombstones all of them.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 2, temp_dir.path().to_path_buf()).unwrap();
    let doc = DocumentId::new();
    for i in 0..5 {
        let posting = Posting::new(doc, i * 100, 50, vec![i as f32, (i + 1) as f32], 2).unwrap();
        shard.add_posting(posting).unwrap();
    }
    assert_eq!(shard.active_count(), 5);
    assert_eq!(shard.remove_document(doc).unwrap(), 5);
    // Slots persist as tombstones; nothing is active anymore.
    assert_eq!(shard.current_count(), 5);
    assert_eq!(shard.active_count(), 0);
}
#[test]
fn test_shard_remove_document_nonexistent() {
    // Removing a document that was never added is a no-op returning 0.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    let stored_doc = DocumentId::new();
    let absent_doc = DocumentId::new();
    shard
        .add_posting(Posting::new(stored_doc, 100, 50, vec![1.0, 2.0, 3.0], 3).unwrap())
        .unwrap();
    assert_eq!(shard.remove_document(absent_doc).unwrap(), 0);
    // The stored document is untouched.
    assert_eq!(shard.active_count(), 1);
    assert!(!shard.is_deleted(0).unwrap());
}
#[test]
fn test_shard_remove_document_read_only() {
    // remove_document must be rejected on a shard opened read-only.
    let temp_dir = TempDir::new().unwrap();
    let shard_id = ShardId::new();
    let dir = temp_dir.path().to_path_buf();
    {
        let mut writer = Shard::create(shard_id, 5, 2, dir.clone()).unwrap();
        writer
            .add_posting(Posting::new(DocumentId::new(), 100, 50, vec![1.0, 2.0], 2).unwrap())
            .unwrap();
        writer.sync().unwrap();
    }
    {
        let mut reader = Shard::open_read_only(shard_id, &dir).unwrap();
        let attempt = reader.remove_document(DocumentId::new());
        assert!(matches!(attempt, Err(ShardexError::Config(_))));
    }
}
#[test]
fn test_shard_search_basic() {
    // Search over three orthogonal unit vectors: the exact match ranks first
    // with a perfect score; the two orthogonal vectors tie at the 0.5 midpoint
    // (the scores the original assertions pin).
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    let doc_id1 = DocumentId::new();
    let doc_id2 = DocumentId::new();
    let doc_id3 = DocumentId::new();
    shard
        .add_posting(Posting::new(doc_id1, 100, 50, vec![1.0, 0.0, 0.0], 3).unwrap())
        .unwrap();
    shard
        .add_posting(Posting::new(doc_id2, 200, 75, vec![0.0, 1.0, 0.0], 3).unwrap())
        .unwrap();
    shard
        .add_posting(Posting::new(doc_id3, 300, 25, vec![0.0, 0.0, 1.0], 3).unwrap())
        .unwrap();
    let query = vec![1.0, 0.0, 0.0];
    let results = shard.search(&query, 3).unwrap();
    assert_eq!(results.len(), 3);
    assert_eq!(results[0].document_id, doc_id1);
    // Compare scores with a tolerance — consistent with the other search
    // tests in this file — instead of exact f32 equality, which is brittle.
    assert!((results[0].similarity_score - 1.0).abs() < 1e-6);
    assert!((results[1].similarity_score - 0.5).abs() < 1e-6);
    assert!((results[2].similarity_score - 0.5).abs() < 1e-6);
}
#[test]
fn test_shard_search_k_limit() {
    // search must return at most k results, ordered by descending score,
    // with every score inside [0.0, 1.0].
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 2, temp_dir.path().to_path_buf()).unwrap();
    let candidates = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.5, 0.5], [1.0, 1.0]];
    for (slot, candidate) in candidates.iter().enumerate() {
        let posting =
            Posting::new(DocumentId::new(), slot as u32 * 100, 50, candidate.to_vec(), 2).unwrap();
        shard.add_posting(posting).unwrap();
    }
    let query = vec![1.0, 0.0];
    let results = shard.search(&query, 2).unwrap();
    assert_eq!(results.len(), 2);
    assert!(results[0].similarity_score >= results[1].similarity_score);
    for result in &results {
        assert!(
            result.similarity_score >= 0.0 && result.similarity_score <= 1.0,
            "Similarity score {} is out of valid range [0.0, 1.0]",
            result.similarity_score
        );
    }
}
#[test]
fn test_shard_search_with_append_only_updates() {
    // After re-adding a posting for an existing document (append-only update),
    // a search must still yield exactly one hit per document.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 2, temp_dir.path().to_path_buf()).unwrap();
    let doc_id1 = DocumentId::new();
    let doc_id2 = DocumentId::new();
    shard
        .add_posting(Posting::new(doc_id1, 100, 50, vec![1.0, 0.0], 2).unwrap())
        .unwrap();
    shard
        .add_posting(Posting::new(doc_id2, 200, 75, vec![0.0, 1.0], 2).unwrap())
        .unwrap();
    // Append a replacement vector for doc 1 at the same text range.
    shard
        .add_posting(Posting::new(doc_id1, 100, 50, vec![0.5, 0.5], 2).unwrap())
        .unwrap();
    let query = vec![1.0, 1.0];
    let results = shard.search(&query, 10).unwrap();
    assert_eq!(results.len(), 2);
    let hit1 = results.iter().find(|r| r.document_id == doc_id1).unwrap();
    let hit2 = results.iter().find(|r| r.document_id == doc_id2).unwrap();
    assert_eq!(hit1.start, 100);
    assert_eq!(hit1.length, 50);
    assert_eq!(hit2.document_id, doc_id2);
}
#[test]
fn test_shard_search_dimension_validation() {
    // A query whose dimension differs from the shard's must be rejected with
    // an InvalidDimension error carrying both dimensions.
    let temp_dir = TempDir::new().unwrap();
    let shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    let short_query = vec![1.0, 2.0];
    match shard.search(&short_query, 5) {
        Err(ShardexError::InvalidDimension { expected, actual }) => {
            assert_eq!(expected, 3);
            assert_eq!(actual, 2);
        }
        _ => panic!("Expected InvalidDimension error"),
    }
}
#[test]
fn test_shard_search_empty_k() {
    // k == 0 is valid and yields an empty result set.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 2, temp_dir.path().to_path_buf()).unwrap();
    shard
        .add_posting(Posting::new(DocumentId::new(), 100, 50, vec![1.0, 2.0], 2).unwrap())
        .unwrap();
    let query = vec![1.0, 2.0];
    assert_eq!(shard.search(&query, 0).unwrap().len(), 0);
}
#[test]
fn test_shard_search_empty_shard() {
    // Searching a shard with no postings returns no results.
    let temp_dir = TempDir::new().unwrap();
    let shard = Shard::create(ShardId::new(), 10, 2, temp_dir.path().to_path_buf()).unwrap();
    let query = vec![1.0, 2.0];
    assert_eq!(shard.search(&query, 5).unwrap().len(), 0);
}
#[test]
fn test_cosine_similarity_calculation() {
    // Similarity is magnitude-invariant here: a query of [2, 0] scores a
    // perfect 1.0 against [1, 0] and the 0.5 midpoint against orthogonal
    // [0, 1] (the values the original assertions pin).
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 2, temp_dir.path().to_path_buf()).unwrap();
    shard
        .add_posting(Posting::new(DocumentId::new(), 100, 50, vec![1.0, 0.0], 2).unwrap())
        .unwrap();
    let query = vec![2.0, 0.0];
    let results = shard.search(&query, 1).unwrap();
    assert_eq!(results.len(), 1);
    // Tolerance-based float comparison, consistent with the centroid and
    // euclidean tests, instead of brittle exact f32 equality.
    assert!((results[0].similarity_score - 1.0).abs() < 1e-6);
    shard
        .add_posting(Posting::new(DocumentId::new(), 200, 50, vec![0.0, 1.0], 2).unwrap())
        .unwrap();
    let results2 = shard.search(&query, 2).unwrap();
    assert_eq!(results2.len(), 2);
    assert!((results2[0].similarity_score - 1.0).abs() < 1e-6);
    assert!((results2[1].similarity_score - 0.5).abs() < 1e-6);
}
#[test]
fn test_centroid_calculation_empty_shard() {
    // An empty shard has a zero centroid and no active vectors.
    let temp_dir = TempDir::new().unwrap();
    let shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    assert_eq!(shard.calculate_centroid(), vec![0.0, 0.0, 0.0]);
    assert_eq!(shard.get_centroid(), &[0.0, 0.0, 0.0]);
    assert_eq!(shard.active_vector_count, 0);
}
#[test]
fn test_centroid_single_vector() {
    // With a single vector stored, the centroid equals that vector exactly.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    let only_vector = vec![2.0, 4.0, 6.0];
    shard
        .add_posting(Posting::new(DocumentId::new(), 100, 50, only_vector.clone(), 3).unwrap())
        .unwrap();
    assert_eq!(shard.get_centroid(), &only_vector);
    assert_eq!(shard.active_vector_count, 1);
    // The from-scratch calculation agrees with the maintained value.
    assert_eq!(shard.calculate_centroid(), only_vector);
}
#[test]
fn test_centroid_multiple_vectors() {
    // The centroid of several vectors is their component-wise mean.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 2, temp_dir.path().to_path_buf()).unwrap();
    let inputs = [vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
    for (slot, input) in inputs.iter().enumerate() {
        let posting =
            Posting::new(DocumentId::new(), (slot as u32 + 1) * 100, 50, input.clone(), 2).unwrap();
        shard.add_posting(posting).unwrap();
    }
    // Mean of [1,2], [3,4], [5,6] is [3,4].
    let expected = [3.0, 4.0];
    let centroid = shard.get_centroid();
    assert_eq!(centroid.len(), 2);
    assert_eq!(shard.active_vector_count, 3);
    assert!((centroid[0] - expected[0]).abs() < 1e-6);
    assert!((centroid[1] - expected[1]).abs() < 1e-6);
}
#[test]
fn test_centroid_incremental_updates() {
// Drives the incremental centroid bookkeeping directly (no postings stored):
// after each add/remove the centroid must equal the mean of the vectors
// currently accounted for, and the active count must track exactly.
// Exact equality is intentional: these inputs produce exactly representable
// f32 means.
let temp_dir = TempDir::new().unwrap();
let shard_id = ShardId::new();
let mut shard = Shard::create(shard_id, 10, 2, temp_dir.path().to_path_buf()).unwrap();
// First vector: the centroid is the vector itself.
shard.update_centroid_add(&[2.0, 4.0]);
assert_eq!(shard.get_centroid(), &[2.0, 4.0]);
assert_eq!(shard.active_vector_count, 1);
// Mean of [2,4] and [4,2] is [3,3].
shard.update_centroid_add(&[4.0, 2.0]);
assert_eq!(shard.get_centroid(), &[3.0, 3.0]); assert_eq!(shard.active_vector_count, 2);
// Mean of [2,4], [4,2], [0,6] is [2,4].
shard.update_centroid_add(&[0.0, 6.0]);
assert_eq!(shard.get_centroid(), &[2.0, 4.0]); assert_eq!(shard.active_vector_count, 3);
// Removals walk the means back in reverse order.
shard.update_centroid_remove(&[0.0, 6.0]);
assert_eq!(shard.get_centroid(), &[3.0, 3.0]); assert_eq!(shard.active_vector_count, 2);
shard.update_centroid_remove(&[4.0, 2.0]);
assert_eq!(shard.get_centroid(), &[2.0, 4.0]); assert_eq!(shard.active_vector_count, 1);
// Removing the last vector resets the centroid to the zero vector.
shard.update_centroid_remove(&[2.0, 4.0]);
assert_eq!(shard.get_centroid(), &[0.0, 0.0]); assert_eq!(shard.active_vector_count, 0);
}
#[test]
fn test_centroid_with_deleted_vectors() {
// The centroid must be maintained incrementally as postings are removed:
// a deleted vector drops out of the mean immediately.
let temp_dir = TempDir::new().unwrap();
let shard_id = ShardId::new();
let mut shard = Shard::create(shard_id, 10, 2, temp_dir.path().to_path_buf()).unwrap();
let doc_id1 = DocumentId::new();
let doc_id2 = DocumentId::new();
let doc_id3 = DocumentId::new();
let posting1 = Posting::new(doc_id1, 100, 50, vec![1.0, 2.0], 2).unwrap();
let posting2 = Posting::new(doc_id2, 200, 50, vec![3.0, 4.0], 2).unwrap();
let posting3 = Posting::new(doc_id3, 300, 50, vec![5.0, 6.0], 2).unwrap();
let idx1 = shard.add_posting(posting1).unwrap();
let idx2 = shard.add_posting(posting2).unwrap();
shard.add_posting(posting3).unwrap();
// Mean of [1,2], [3,4], [5,6] is [3,4].
let initial_centroid = shard.get_centroid().to_vec();
assert!((initial_centroid[0] - 3.0).abs() < 1e-6);
assert!((initial_centroid[1] - 4.0).abs() < 1e-6);
assert_eq!(shard.active_vector_count, 3);
// Removing [3,4] — which equals the current mean — leaves the mean at [3,4].
shard.remove_posting(idx2).unwrap();
let updated_centroid = shard.get_centroid().to_vec();
assert!((updated_centroid[0] - 3.0).abs() < 1e-6);
assert!((updated_centroid[1] - 4.0).abs() < 1e-6);
assert_eq!(shard.active_vector_count, 2);
// After removing [1,2] as well, only [5,6] is active, so the centroid
// collapses onto it.
shard.remove_posting(idx1).unwrap();
let final_centroid = shard.get_centroid().to_vec();
assert!((final_centroid[0] - 5.0).abs() < 1e-6);
assert!((final_centroid[1] - 6.0).abs() < 1e-6);
assert_eq!(shard.active_vector_count, 1);
}
#[test]
fn test_centroid_recalculation() {
    // recalculate_centroid must rebuild the centroid and active count from
    // storage, overwriting any corrupted in-memory state.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 2, temp_dir.path().to_path_buf()).unwrap();
    shard
        .add_posting(Posting::new(DocumentId::new(), 100, 50, vec![2.0, 4.0], 2).unwrap())
        .unwrap();
    shard
        .add_posting(Posting::new(DocumentId::new(), 200, 50, vec![6.0, 8.0], 2).unwrap())
        .unwrap();
    let before = shard.get_centroid().to_vec();
    // Deliberately corrupt the cached state.
    shard.centroid = vec![999.0, 999.0];
    shard.active_vector_count = 999;
    shard.recalculate_centroid();
    let after = shard.get_centroid().to_vec();
    assert!((after[0] - before[0]).abs() < 1e-6);
    assert!((after[1] - before[1]).abs() < 1e-6);
    assert_eq!(shard.active_vector_count, 2);
}
#[test]
fn test_centroid_remove_document() {
// remove_document must update the centroid for every posting it tombstones,
// not just the first one it finds.
let temp_dir = TempDir::new().unwrap();
let shard_id = ShardId::new();
let mut shard = Shard::create(shard_id, 10, 2, temp_dir.path().to_path_buf()).unwrap();
let doc_id1 = DocumentId::new();
let doc_id2 = DocumentId::new();
// doc_id1 owns two postings; doc_id2 owns one.
let posting1 = Posting::new(doc_id1, 100, 50, vec![1.0, 2.0], 2).unwrap();
let posting2 = Posting::new(doc_id1, 200, 50, vec![3.0, 4.0], 2).unwrap();
let posting3 = Posting::new(doc_id2, 300, 50, vec![5.0, 6.0], 2).unwrap();
shard.add_posting(posting1).unwrap();
shard.add_posting(posting2).unwrap();
shard.add_posting(posting3).unwrap();
// Mean of [1,2], [3,4], [5,6] is [3,4].
let initial_centroid = shard.get_centroid().to_vec();
assert!((initial_centroid[0] - 3.0).abs() < 1e-6);
assert!((initial_centroid[1] - 4.0).abs() < 1e-6);
assert_eq!(shard.active_vector_count, 3);
let removed_count = shard.remove_document(doc_id1).unwrap();
assert_eq!(removed_count, 2);
// Only doc_id2's vector [5,6] remains in the mean.
let final_centroid = shard.get_centroid().to_vec();
assert!((final_centroid[0] - 5.0).abs() < 1e-6);
assert!((final_centroid[1] - 6.0).abs() < 1e-6);
assert_eq!(shard.active_vector_count, 1);
}
#[test]
fn test_centroid_persistence_on_reopen() {
// The centroid and active-vector count must be available again after a shard
// is synced, dropped, and reopened from disk.
let temp_dir = TempDir::new().unwrap();
let shard_id = ShardId::new();
let directory = temp_dir.path().to_path_buf();
// Mean of [1,2], [3,4], [5,6] is [3,4].
let expected_centroid = [3.0, 4.0];
{
// Write phase: populate, sync, then drop.
let mut shard = Shard::create(shard_id, 10, 2, directory.clone()).unwrap();
let doc_id1 = DocumentId::new();
let doc_id2 = DocumentId::new();
let doc_id3 = DocumentId::new();
let posting1 = Posting::new(doc_id1, 100, 50, vec![1.0, 2.0], 2).unwrap();
let posting2 = Posting::new(doc_id2, 200, 50, vec![3.0, 4.0], 2).unwrap();
let posting3 = Posting::new(doc_id3, 300, 50, vec![5.0, 6.0], 2).unwrap();
shard.add_posting(posting1).unwrap();
shard.add_posting(posting2).unwrap();
shard.add_posting(posting3).unwrap();
shard.sync().unwrap();
}
{
// Read phase: the reopened shard reports the same centroid and count.
let shard = Shard::open(shard_id, &directory).unwrap();
let centroid = shard.get_centroid();
assert!((centroid[0] - expected_centroid[0]).abs() < 1e-6);
assert!((centroid[1] - expected_centroid[1]).abs() < 1e-6);
assert_eq!(shard.active_vector_count, 3);
}
}
#[test]
fn test_centroid_integrity_validation() {
// validate_integrity must detect three kinds of centroid corruption, each
// reported with a distinct message: wrong dimension, wrong component values,
// and a wrong active-vector count.
let temp_dir = TempDir::new().unwrap();
let shard_id = ShardId::new();
let mut shard = Shard::create(shard_id, 10, 3, temp_dir.path().to_path_buf()).unwrap();
let doc_id = DocumentId::new();
let posting = Posting::new(doc_id, 100, 50, vec![1.0, 2.0, 3.0], 3).unwrap();
shard.add_posting(posting).unwrap();
// Sanity check: an uncorrupted shard validates cleanly.
assert!(shard.validate_integrity().is_ok());
// Corrupt the dimension (2 components instead of 3).
shard.centroid = vec![1.0, 2.0];
let result = shard.validate_integrity();
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Centroid dimension mismatch"));
// Right dimension, wrong component values.
shard.centroid = vec![999.0, 999.0, 999.0];
let result = shard.validate_integrity();
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Centroid component"));
// Restore the correct centroid but corrupt the active-vector count.
shard.centroid = vec![1.0, 2.0, 3.0]; shard.active_vector_count = 999;
let result = shard.validate_integrity();
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Active vector count mismatch"));
}
#[test]
fn test_shard_search_sorting() {
    // Results must come back sorted by score, descending. In 1-D, every
    // positive vector is co-directional with the query [1.0] (score 1.0, the
    // value the assertions pin) and [-1.0] points away (score 0.0).
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 1, temp_dir.path().to_path_buf()).unwrap();
    let magnitudes = [[1.0], [-1.0], [0.5], [2.0]];
    for (slot, value) in magnitudes.iter().enumerate() {
        let posting =
            Posting::new(DocumentId::new(), slot as u32 * 100, 50, value.to_vec(), 1).unwrap();
        shard.add_posting(posting).unwrap();
    }
    let query = vec![1.0];
    let results = shard.search(&query, 4).unwrap();
    assert_eq!(results.len(), 4);
    assert_eq!(results[0].similarity_score, 1.0);
    assert_eq!(results[1].similarity_score, 1.0);
    assert_eq!(results[2].similarity_score, 1.0);
    assert_eq!(results[3].similarity_score, 0.0);
    // Verify the descending order over adjacent pairs.
    for pair in results.windows(2) {
        assert!(pair[0].similarity_score >= pair[1].similarity_score);
    }
}
#[test]
fn test_shard_search_with_metric_cosine() {
    // Cosine metric scoring as pinned by the original assertions:
    // parallel → 1.0, orthogonal → 0.5, anti-parallel → 0.0.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    let vectors = [[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
    for (i, vector) in vectors.iter().enumerate() {
        let posting =
            Posting::new(DocumentId::new(), i as u32 * 100, 50, vector.to_vec(), 3).unwrap();
        shard.add_posting(posting).unwrap();
    }
    let query = vec![1.0, 0.0, 0.0];
    let results = shard
        .search_with_metric(&query, 3, DistanceMetric::Cosine)
        .unwrap();
    assert_eq!(results.len(), 3);
    // Tolerance-based comparisons (matching the Euclidean test) rather than
    // exact f32 equality, which is brittle under scoring changes.
    assert!((results[0].similarity_score - 1.0).abs() < 1e-6);
    assert!((results[1].similarity_score - 0.5).abs() < 1e-6);
    assert!(results[2].similarity_score.abs() < 1e-6);
}
#[test]
fn test_shard_search_with_metric_euclidean() {
    // Euclidean metric: similarity decreases with distance from the query.
    // The pinned scores (d=0 → 1, d=1 → 0.5, d=5 → 1/6) are consistent with
    // a 1 / (1 + distance) mapping.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 2, temp_dir.path().to_path_buf()).unwrap();
    let vectors = [[0.0, 0.0], [1.0, 0.0], [3.0, 4.0]];
    for (i, vector) in vectors.iter().enumerate() {
        let posting =
            Posting::new(DocumentId::new(), i as u32 * 100, 50, vector.to_vec(), 2).unwrap();
        shard.add_posting(posting).unwrap();
    }
    let query = vec![0.0, 0.0];
    let results = shard
        .search_with_metric(&query, 3, DistanceMetric::Euclidean)
        .unwrap();
    assert_eq!(results.len(), 3);
    assert!(results[0].similarity_score > results[1].similarity_score);
    assert!(results[1].similarity_score > results[2].similarity_score);
    // Use a tolerance for all three scores: the original mixed exact
    // `assert_eq!` and 1e-6 comparisons within this same assertion group.
    assert!((results[0].similarity_score - 1.0).abs() < 1e-6);
    assert!((results[1].similarity_score - 0.5).abs() < 1e-6);
    assert!((results[2].similarity_score - 1.0 / 6.0).abs() < 1e-6);
}
#[test]
fn test_shard_search_with_metric_dot_product() {
    // Dot-product metric: results must rank positive > orthogonal > negative
    // alignment with the query; scores are only checked in loose bands.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 2, temp_dir.path().to_path_buf()).unwrap();
    let vectors = [[2.0, 0.0], [0.0, 1.0], [-1.0, 0.0]];
    for (i, vector) in vectors.iter().enumerate() {
        let posting =
            Posting::new(DocumentId::new(), i as u32 * 100, 50, vector.to_vec(), 2).unwrap();
        shard.add_posting(posting).unwrap();
    }
    let query = vec![1.0, 0.0];
    let results = shard
        .search_with_metric(&query, 3, DistanceMetric::DotProduct)
        .unwrap();
    assert_eq!(results.len(), 3);
    assert!(results[0].similarity_score > results[1].similarity_score);
    assert!(results[1].similarity_score > results[2].similarity_score);
    // Loose bands only: the exact score mapping is not pinned by this test.
    assert!(results[0].similarity_score > 0.8);
    assert!((results[1].similarity_score - 0.5).abs() < 0.1);
    assert!(results[2].similarity_score < 0.4);
}
#[test]
fn test_search_with_metric_dimension_validation() {
    // Every metric must reject a query with the wrong dimensionality.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    shard
        .add_posting(Posting::new(DocumentId::new(), 0, 50, vec![1.0, 2.0, 3.0], 3).unwrap())
        .unwrap();
    let short_query = vec![1.0, 2.0];
    for metric in [
        DistanceMetric::Cosine,
        DistanceMetric::Euclidean,
        DistanceMetric::DotProduct,
    ] {
        let outcome = shard.search_with_metric(&short_query, 5, metric);
        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), ShardexError::InvalidDimension { .. }));
    }
}
#[test]
fn test_search_with_metric_empty_k() {
    // k == 0 yields an empty result set under every metric.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    shard
        .add_posting(Posting::new(DocumentId::new(), 0, 50, vec![1.0, 2.0, 3.0], 3).unwrap())
        .unwrap();
    let query = vec![1.0, 0.0, 0.0];
    for metric in [
        DistanceMetric::Cosine,
        DistanceMetric::Euclidean,
        DistanceMetric::DotProduct,
    ] {
        assert!(shard.search_with_metric(&query, 0, metric).unwrap().is_empty());
    }
}
#[test]
fn test_search_with_metric_vs_regular_search() {
    // search() must agree with search_with_metric(.., Cosine): same documents
    // in the same order with matching scores.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    let vectors = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]];
    for (i, vector) in vectors.iter().enumerate() {
        let posting =
            Posting::new(DocumentId::new(), i as u32 * 100, 50, vector.to_vec(), 3).unwrap();
        shard.add_posting(posting).unwrap();
    }
    let query = vec![1.0, 0.0, 0.0];
    let default_results = shard.search(&query, 3).unwrap();
    let cosine_results = shard
        .search_with_metric(&query, 3, DistanceMetric::Cosine)
        .unwrap();
    assert_eq!(default_results.len(), cosine_results.len());
    for (default_hit, cosine_hit) in default_results.iter().zip(cosine_results.iter()) {
        assert_eq!(default_hit.document_id, cosine_hit.document_id);
        assert!((default_hit.similarity_score - cosine_hit.similarity_score).abs() < 1e-6);
    }
}
#[test]
fn test_should_split_empty_shard() {
    // An empty shard is nowhere near its split threshold.
    let temp_dir = TempDir::new().unwrap();
    let shard = Shard::create(ShardId::new(), 100, 3, temp_dir.path().to_path_buf()).unwrap();
    assert!(!shard.should_split());
}
#[test]
fn test_should_split_below_threshold() {
    // 8 of 10 slots used (80%) is below the split threshold.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    for i in 0..8 {
        let posting = Posting::new(
            DocumentId::new(),
            i * 100,
            50,
            vec![i as f32, (i + 1) as f32, (i + 2) as f32],
            3,
        )
        .unwrap();
        shard.add_posting(posting).unwrap();
    }
    assert!(!shard.should_split());
}
#[test]
fn test_should_split_at_threshold() {
    // 9 of 10 slots used (90%) reaches the split threshold.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    for i in 0..9 {
        let posting = Posting::new(
            DocumentId::new(),
            i * 100,
            50,
            vec![i as f32, (i + 1) as f32, (i + 2) as f32],
            3,
        )
        .unwrap();
        shard.add_posting(posting).unwrap();
    }
    assert!(shard.should_split());
}
#[test]
fn test_should_split_zero_capacity() {
    // A zero-capacity shard must never report that it should split
    // (guards the division-by-zero edge in the utilization check).
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    shard.capacity = 0;
    assert!(!shard.should_split());
}
#[tokio::test]
async fn test_split_read_only_shard() {
    // split() must refuse to run on a read-only shard even when the
    // occupancy threshold is met.
    let temp_dir = TempDir::new().unwrap();
    let shard_id = ShardId::new();
    let dir = temp_dir.path().to_path_buf();
    {
        let mut writer = Shard::create(shard_id, 10, 2, dir.clone()).unwrap();
        for i in 0..9 {
            let posting =
                Posting::new(DocumentId::new(), i * 100, 50, vec![i as f32, (i + 1) as f32], 2)
                    .unwrap();
            writer.add_posting(posting).unwrap();
        }
        writer.sync().unwrap();
    }
    {
        let reader = Shard::open_read_only(shard_id, &dir).unwrap();
        assert!(reader.should_split());
        let outcome = reader.split().await;
        assert!(matches!(outcome, Err(ShardexError::Config(_))));
    }
}
#[tokio::test]
async fn test_split_insufficient_data() {
    // Splitting a shard holding a single posting must fail: there is not
    // enough data to form two halves.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 2, temp_dir.path().to_path_buf()).unwrap();
    shard
        .add_posting(Posting::new(DocumentId::new(), 100, 50, vec![1.0, 2.0], 2).unwrap())
        .unwrap();
    assert!(matches!(shard.split().await, Err(ShardexError::Config(_))));
}
#[tokio::test]
async fn test_debug_split_capacity() {
    // Diagnostic test: both halves of a split must report non-zero capacity.
    // The eprintln! output aids debugging when the assertions fail.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 2, temp_dir.path().to_path_buf()).unwrap();
    for i in 0..9 {
        let posting =
            Posting::new(DocumentId::new(), 100, 50, vec![i as f32, i as f32], 2).unwrap();
        shard.add_posting(posting).unwrap();
    }
    eprintln!("Before split:");
    eprintln!(" Original capacity: {}", shard.capacity());
    eprintln!(" Should split: {}", shard.should_split());
    let (shard_a, shard_b) = shard.split().await.unwrap();
    eprintln!("After split:");
    eprintln!(" Shard A capacity: {}", shard_a.capacity());
    eprintln!(" Shard B capacity: {}", shard_b.capacity());
    assert!(shard_a.capacity() > 0);
    assert!(shard_b.capacity() > 0);
}
#[tokio::test]
async fn test_split_basic_functionality() {
    // Split a near-full shard containing two well-separated clusters plus one
    // midpoint vector; every posting must land in exactly one of the halves.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 2, temp_dir.path().to_path_buf()).unwrap();
    let cluster_low = [[1.0, 1.0], [1.1, 1.1], [1.2, 1.2], [0.9, 0.9]];
    let cluster_high = [[5.0, 5.0], [5.1, 5.1], [4.9, 4.9], [5.2, 5.0]];
    for vector in cluster_low {
        shard
            .add_posting(Posting::new(DocumentId::new(), 100, 50, vector.to_vec(), 2).unwrap())
            .unwrap();
    }
    for vector in cluster_high {
        shard
            .add_posting(Posting::new(DocumentId::new(), 200, 50, vector.to_vec(), 2).unwrap())
            .unwrap();
    }
    // One vector halfway between the clusters.
    shard
        .add_posting(Posting::new(DocumentId::new(), 300, 50, vec![2.5, 2.5], 2).unwrap())
        .unwrap();
    assert!(shard.should_split());
    assert_eq!(shard.active_count(), 9);
    let (shard_a, shard_b) = shard.split().await.unwrap();
    assert!(shard_a.active_count() > 0);
    assert!(shard_b.active_count() > 0);
    assert_eq!(shard_a.active_count() + shard_b.active_count(), 9);
    // Each half inherits half the original capacity and the same dimension.
    assert_eq!(shard_a.capacity(), 5);
    assert_eq!(shard_b.capacity(), 5);
    assert_eq!(shard_a.vector_size(), 2);
    assert_eq!(shard_b.vector_size(), 2);
    // All three shards carry distinct ids.
    assert_ne!(shard_a.id(), shard_b.id());
    assert_ne!(shard_a.id(), shard.id());
    assert_ne!(shard_b.id(), shard.id());
}
#[tokio::test]
async fn test_split_balanced_distribution() {
    // Two separated clusters of 9 vectors each: the split must divide the
    // 18 postings nearly evenly between the halves.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 20, 2, temp_dir.path().to_path_buf()).unwrap();
    for i in 0..9 {
        let low = vec![i as f32 * 0.1, i as f32 * 0.1];
        shard
            .add_posting(Posting::new(DocumentId::new(), i * 100, 50, low, 2).unwrap())
            .unwrap();
        let high = vec![10.0 + i as f32 * 0.1, 10.0 + i as f32 * 0.1];
        shard
            .add_posting(Posting::new(DocumentId::new(), (i + 100) * 100, 50, high, 2).unwrap())
            .unwrap();
    }
    let (shard_a, shard_b) = shard.split().await.unwrap();
    let count_diff = (shard_a.active_count() as i32 - shard_b.active_count() as i32).abs();
    assert!(
        count_diff <= 1,
        "Split should be balanced: {} vs {}",
        shard_a.active_count(),
        shard_b.active_count()
    );
    assert_eq!(shard_a.active_count() + shard_b.active_count(), 18);
}
#[tokio::test]
async fn test_split_with_append_only_updates() {
    // Append-only updates create superseded entries; a split must carry over
    // only the 15 live postings, not the superseded ones.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 20, 2, temp_dir.path().to_path_buf()).unwrap();
    let doc_ids: Vec<DocumentId> = (0..15).map(|_| DocumentId::new()).collect();
    for (i, &doc_id) in doc_ids.iter().enumerate() {
        let posting =
            Posting::new(doc_id, (i * 100) as u32, 50, vec![i as f32, i as f32], 2).unwrap();
        shard.add_posting(posting).unwrap();
    }
    // Re-add three documents with new vectors at the same ranges.
    for i in [1, 3, 5] {
        let replacement = vec![(i + 100) as f32, (i + 100) as f32];
        shard
            .add_posting(Posting::new(doc_ids[i], (i * 100) as u32, 50, replacement, 2).unwrap())
            .unwrap();
    }
    assert_eq!(shard.current_count(), 18);
    assert_eq!(shard.active_count(), 15);
    let (shard_a, shard_b) = shard.split().await.unwrap();
    assert_eq!(shard_a.active_count() + shard_b.active_count(), 15);
    assert!(shard_a.active_count() > 0);
    assert!(shard_b.active_count() > 0);
}
#[tokio::test]
async fn test_split_small_dataset() {
    // Four postings split into exactly two per half.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 3, temp_dir.path().to_path_buf()).unwrap();
    let vectors = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]];
    for vector in vectors {
        shard
            .add_posting(Posting::new(DocumentId::new(), 100, 50, vector.to_vec(), 3).unwrap())
            .unwrap();
    }
    let (shard_a, shard_b) = shard.split().await.unwrap();
    assert_eq!(shard_a.active_count(), 2);
    assert_eq!(shard_b.active_count(), 2);
}
#[test]
fn test_euclidean_distance_calculation() {
    // Spot-check Shard::euclidean_distance against known distances.
    // Classic 3-4-5 right triangle.
    let origin = [0.0, 0.0];
    let corner = [3.0, 4.0];
    assert!((Shard::euclidean_distance(&origin, &corner) - 5.0).abs() < 1e-6);
    // Identical points are at distance zero.
    let point = [1.0, 2.0, 3.0];
    assert!(Shard::euclidean_distance(&point, &point) < 1e-6);
    // Orthogonal unit vectors are sqrt(2) apart.
    let unit_x = [1.0, 0.0];
    let unit_y = [0.0, 1.0];
    assert!((Shard::euclidean_distance(&unit_x, &unit_y) - 2.0_f32.sqrt()).abs() < 1e-6);
}
#[tokio::test]
async fn test_split_data_integrity() {
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 12, 2, temp_dir.path().to_path_buf()).unwrap();

    // Record every posting we add so the split output can be checked against it.
    let mut expected = Vec::new();
    for i in 0..10 {
        let posting = Posting::new(
            DocumentId::new(),
            i * 100 + 1000,
            50 + i,
            vec![i as f32, (i * 2) as f32],
            2,
        )
        .unwrap();
        expected.push(posting.clone());
        shard.add_posting(posting).unwrap();
    }

    let (part_a, part_b) = shard.split().await.unwrap();

    // Gather every surviving posting from both halves.
    let mut actual: Vec<_> = (0..part_a.active_count())
        .map(|slot| part_a.get_posting(slot).unwrap())
        .collect();
    actual.extend((0..part_b.active_count()).map(|slot| part_b.get_posting(slot).unwrap()));

    // Nothing lost, nothing duplicated, every field preserved.
    assert_eq!(actual.len(), expected.len());
    for original in &expected {
        let found = actual.iter().any(|p| {
            p.document_id == original.document_id
                && p.start == original.start
                && p.length == original.length
                && p.vector == original.vector
        });
        assert!(found, "Original posting not found in recovered postings");
    }
}
#[tokio::test]
async fn test_split_centroid_calculation() {
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 2, temp_dir.path().to_path_buf()).unwrap();

    // Two clearly separated clusters so the split has an unambiguous partition.
    let cluster_low = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
    let cluster_high = [[8.0, 8.0], [9.0, 9.0], [10.0, 10.0]];
    for point in cluster_low.iter().chain(cluster_high.iter()) {
        let posting = Posting::new(DocumentId::new(), 100, 50, point.to_vec(), 2).unwrap();
        shard.add_posting(posting).unwrap();
    }

    let (left, right) = shard.split().await.unwrap();
    assert_eq!(left.get_centroid().len(), 2);
    assert_eq!(right.get_centroid().len(), 2);

    // The two resulting centroids should land near different clusters.
    let separation = Shard::euclidean_distance(left.get_centroid(), right.get_centroid());
    assert!(
        separation > 1.0,
        "Centroids should be well separated: A={:?}, B={:?}",
        left.get_centroid(),
        right.get_centroid()
    );
}
#[tokio::test]
async fn test_split_minimum_capacity() {
    // A full capacity-6 shard splits into two shards that each report
    // capacity 5 — pins the sizing policy of split() for small shards.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 6, 2, temp_dir.path().to_path_buf()).unwrap();
    for i in 0..6u32 {
        let posting =
            Posting::new(DocumentId::new(), i * 100, 50, vec![i as f32, i as f32], 2).unwrap();
        shard.add_posting(posting).unwrap();
    }
    let (first, second) = shard.split().await.unwrap();
    assert_eq!(first.capacity(), 5);
    assert_eq!(second.capacity(), 5);
}
#[tokio::test]
async fn test_split_identical_vectors() {
    // Degenerate case: all vectors identical, so there is no natural partition.
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 12, 2, temp_dir.path().to_path_buf()).unwrap();
    for _ in 0..4 {
        let posting = Posting::new(DocumentId::new(), 100, 50, vec![5.0, 5.0], 2).unwrap();
        shard.add_posting(posting).unwrap();
    }

    let (left, right) = shard.split().await.unwrap();

    // The split must still place data in both halves without losing anything.
    assert!(left.active_count() > 0);
    assert!(right.active_count() > 0);
    assert_eq!(left.active_count() + right.active_count(), 4);

    // Both centroids collapse onto the shared point (5, 5).
    for centroid in [left.get_centroid(), right.get_centroid()] {
        assert!((centroid[0] - 5.0).abs() < 0.1);
        assert!((centroid[1] - 5.0).abs() < 0.1);
    }
}
#[test]
fn test_bloom_filter_integration_basic() {
let temp_dir = TempDir::new().unwrap();
let shard_id = ShardId::new();
let mut shard = Shard::create(shard_id, 100, 3, temp_dir.path().to_path_buf()).unwrap();
let doc_id1 = DocumentId::new();
let doc_id2 = DocumentId::new();
let doc_id3 = DocumentId::new(); let vector = vec![1.0, 2.0, 3.0];
let posting1 = Posting::new(doc_id1, 100, 50, vector.clone(), 3).unwrap();
let posting2 = Posting::new(doc_id2, 200, 75, vector, 3).unwrap();
shard.add_posting(posting1).unwrap();
shard.add_posting(posting2).unwrap();
assert!(tokio_test::block_on(shard.contains_document(doc_id1)));
assert!(tokio_test::block_on(shard.contains_document(doc_id2)));
let doc_id3_maybe_present = tokio_test::block_on(shard.contains_document(doc_id3));
let removed = shard.remove_document(doc_id3).unwrap();
if !doc_id3_maybe_present {
assert_eq!(removed, 0);
}
let removed = shard.remove_document(doc_id1).unwrap();
assert_eq!(removed, 1);
assert_eq!(shard.active_count(), 1);
}
#[tokio::test]
async fn test_bloom_filter_persistence() {
    let temp_dir = TempDir::new().unwrap();
    let shard_id = ShardId::new();
    let doc_id1 = DocumentId::new();
    let doc_id2 = DocumentId::new();
    let vector = vec![1.0, 2.0, 3.0];

    // Write two postings, sync to disk, then drop the shard.
    {
        let mut writer = Shard::create(shard_id, 100, 3, temp_dir.path().to_path_buf()).unwrap();
        writer
            .add_posting(Posting::new(doc_id1, 100, 50, vector.clone(), 3).unwrap())
            .unwrap();
        writer
            .add_posting(Posting::new(doc_id2, 200, 75, vector, 3).unwrap())
            .unwrap();
        writer.sync().unwrap();
    }

    // Reopening from the same directory must restore the document-presence
    // information for both previously written documents.
    let reopened = Shard::open(shard_id, temp_dir.path()).unwrap();
    assert!(reopened.contains_document(doc_id1).await);
    assert!(reopened.contains_document(doc_id2).await);
}
#[tokio::test]
async fn test_bloom_filter_split_maintenance() {
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 10, 2, temp_dir.path().to_path_buf()).unwrap();

    let mut doc_ids = Vec::new();
    for i in 0..9u32 {
        let doc_id = DocumentId::new();
        doc_ids.push(doc_id);
        let posting = Posting::new(doc_id, i * 100, 50, vec![i as f32, i as f32], 2).unwrap();
        shard.add_posting(posting).unwrap();
    }

    let (mut half_a, mut half_b) = shard.split().await.unwrap();

    // Each half must report at least its own documents as present
    // (>= because bloom-filter false positives can inflate the hit count).
    let mut hits_a = 0;
    let mut hits_b = 0;
    for doc_id in &doc_ids {
        if half_a.contains_document(*doc_id).await {
            hits_a += 1;
        }
        if half_b.contains_document(*doc_id).await {
            hits_b += 1;
        }
    }
    assert!(hits_a >= half_a.active_count());
    assert!(hits_b >= half_b.active_count());

    // A brand-new id that the filter rejects must remove zero postings.
    let unknown_doc = DocumentId::new();
    if !half_a.contains_document(unknown_doc).await {
        assert_eq!(half_a.remove_document(unknown_doc).unwrap(), 0);
    }
    if !half_b.contains_document(unknown_doc).await {
        assert_eq!(half_b.remove_document(unknown_doc).unwrap(), 0);
    }
}
#[tokio::test]
async fn test_bloom_filter_false_positive_handling() {
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 100, 3, temp_dir.path().to_path_buf()).unwrap();

    // Seed the shard (and its bloom filter) with ten known documents.
    let known_docs: Vec<_> = (0..10u32)
        .map(|i| {
            let doc_id = DocumentId::new();
            let posting =
                Posting::new(doc_id, i * 100, 50, vec![i as f32, i as f32, i as f32], 3).unwrap();
            shard.add_posting(posting).unwrap();
            doc_id
        })
        .collect();

    // Probe with 1000 fresh ids: any positive is a false positive, and
    // removing such an id must still delete nothing.
    let mut false_positives = 0;
    let mut true_negatives = 0;
    for _ in 0..1000 {
        let probe = DocumentId::new();
        if shard.contains_document(probe).await {
            false_positives += 1;
            assert_eq!(shard.remove_document(probe).unwrap(), 0);
        } else {
            true_negatives += 1;
        }
    }

    // Negatives should dominate, and the observed rate stays under 5%.
    assert!(true_negatives > false_positives);
    let fp_rate = false_positives as f64 / 1000.0;
    assert!(fp_rate < 0.05, "False positive rate too high: {}", fp_rate);

    // No false negatives: every inserted document still reports present.
    for doc_id in known_docs {
        assert!(shard.contains_document(doc_id).await);
    }
}
#[test]
fn test_bloom_filter_metadata_consistency() {
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 100, 3, temp_dir.path().to_path_buf()).unwrap();

    // A freshly created shard starts with an empty filter sized to capacity
    // and the 1% false-positive target used by ShardMetadata::new.
    assert_eq!(shard.metadata.bloom_filter.inserted_count(), 0);
    assert_eq!(shard.metadata.bloom_filter.capacity(), 100);
    assert!((shard.metadata.bloom_filter.false_positive_rate() - 0.01).abs() < 1e-6);

    let vector = vec![1.0, 2.0, 3.0];
    shard
        .add_posting(Posting::new(DocumentId::new(), 100, 50, vector.clone(), 3).unwrap())
        .unwrap();
    shard
        .add_posting(Posting::new(DocumentId::new(), 200, 75, vector, 3).unwrap())
        .unwrap();

    // The filter tracks insertions and exposes coherent statistics.
    assert_eq!(shard.metadata.bloom_filter.inserted_count(), 2);
    assert!(!shard.metadata.bloom_filter.is_at_capacity());
    let stats = shard.metadata.bloom_filter.stats();
    assert_eq!(stats.inserted_count, 2);
    assert_eq!(stats.capacity, 100);
    assert!(stats.load_factor > 0.0 && stats.load_factor < 1.0);
    assert!(stats.memory_usage > 0);
}
#[test]
fn test_append_only_semantics() {
    let temp_dir = TempDir::new().unwrap();
    let mut shard = Shard::create(ShardId::new(), 100, 3, temp_dir.path().to_path_buf()).unwrap();

    // Append two postings for the SAME (document, start, length) key: the
    // second is an update and must shadow the first without overwriting it.
    let doc_id = DocumentId::new();
    let _idx1 = shard
        .add_posting(Posting::new(doc_id, 100, 50, vec![1.0, 2.0, 3.0], 3).unwrap())
        .unwrap();
    let idx2 = shard
        .add_posting(Posting::new(doc_id, 100, 50, vec![4.0, 5.0, 6.0], 3).unwrap())
        .unwrap();

    // Both physical entries exist, but search sees only the latest one.
    assert_eq!(shard.current_count(), 2);
    let results = shard.search(&[4.0, 5.0, 6.0], 10).unwrap();
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].document_id, doc_id);
    assert_eq!(results[0].start, 100);
    assert_eq!(results[0].length, 50);

    // Backward iteration over unique postings yields only the newest entry,
    // reported at the index of the second (shadowing) append.
    let unique: Vec<_> = shard
        .iter_unique_postings_backward()
        .collect::<Result<Vec<_>, _>>()
        .unwrap();
    assert_eq!(unique.len(), 1);
    assert_eq!(unique[0].1.document_id, doc_id);
    assert_eq!(unique[0].0, idx2);

    // A posting for a different document adds a second unique entry,
    // visible both to the unique iterator and to search.
    let doc_id2 = DocumentId::new();
    shard
        .add_posting(Posting::new(doc_id2, 200, 30, vec![4.0, 5.0, 6.0], 3).unwrap())
        .unwrap();
    let unique: Vec<_> = shard
        .iter_unique_postings_backward()
        .collect::<Result<Vec<_>, _>>()
        .unwrap();
    assert_eq!(unique.len(), 2);
    assert_eq!(shard.search(&[4.0, 5.0, 6.0], 10).unwrap().len(), 2);
}
}