use crate::persistence::PersistenceError;
use bitvec::prelude::*;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use thiserror::Error;
/// Errors produced by `FlatIndex` operations.
#[derive(Debug, Clone, PartialEq, Error)]
pub enum FlatIndexError {
    /// A vector or query had the wrong number of components.
    #[error("dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch {
        expected: usize,
        actual: usize,
    },
    /// `k == 0` was passed to a search method.
    #[error("invalid k: must be greater than 0")]
    InvalidK,
    /// `search_quantized` was called without cached binary codes.
    #[error("quantization not enabled: call enable_quantization() first")]
    QuantizationNotEnabled,
    /// The index holds no vectors.
    #[error("index is empty")]
    EmptyIndex,
    /// The monotonically increasing ID counter was exhausted.
    #[error("ID counter overflow: cannot assign more IDs (u64::MAX reached)")]
    IdOverflow,
}
/// Supported scoring functions.
///
/// `Cosine` and `DotProduct` are similarities (higher is better);
/// `L2` and `Hamming` are distances (lower is better).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum DistanceMetric {
    #[default]
    Cosine,
    DotProduct,
    L2,
    Hamming,
}
impl DistanceMetric {
    /// True for metrics where a larger score means a closer match
    /// (cosine similarity and dot product).
    #[must_use]
    pub const fn is_similarity(&self) -> bool {
        match self {
            Self::Cosine | Self::DotProduct => true,
            Self::L2 | Self::Hamming => false,
        }
    }
}
/// A single hit returned by `search` / `search_quantized`.
#[derive(Debug, Clone)]
pub struct FlatSearchResult {
    /// Slot ID of the matched vector.
    pub id: u64,
    /// Metric-dependent score: similarity (higher is better) for
    /// cosine/dot product, distance (lower is better) for L2/Hamming.
    pub score: f32,
}
impl PartialEq for FlatSearchResult {
    // Scores are compared with an epsilon tolerance to absorb float noise.
    // NOTE(review): epsilon equality is not transitive, so the `Eq` impl
    // below is a loose reading of the `Eq` contract; harmless for current
    // uses, but worth confirming nothing relies on strict Eq semantics.
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id && (self.score - other.score).abs() < f32::EPSILON
    }
}
impl Eq for FlatSearchResult {}
/// Internal candidate held in the top-k heap during search.
#[derive(Debug, Clone)]
struct HeapEntry {
    id: u64,
    // Heap ordering key; similarity scores are stored negated by `search`
    // so the heap's root is always the current worst candidate.
    score: f32,
}
impl PartialEq for HeapEntry {
    // NOTE(review): epsilon equality is not transitive, so `Eq` below is a
    // loose reading of the contract; `BinaryHeap` only relies on `Ord`.
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id && (self.score - other.score).abs() < f32::EPSILON
    }
}
impl Eq for HeapEntry {}
impl PartialOrd for HeapEntry {
    // Canonical delegation to `Ord` keeps the two orderings consistent.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for HeapEntry {
    /// Total order on `score` using IEEE-754 `total_cmp`.
    ///
    /// BUG FIX: the previous `partial_cmp(..).unwrap_or(Ordering::Equal)`
    /// made NaN compare "equal" to every value, which is not a total order
    /// and can silently corrupt `BinaryHeap`'s internal invariants if a NaN
    /// score ever enters the heap (possible with NaN vector components).
    /// `total_cmp` is a proper total order (NaN sorts above +inf), so the
    /// heap stays consistent even with pathological inputs.
    fn cmp(&self, other: &Self) -> Ordering {
        self.score.total_cmp(&other.score)
    }
}
/// Tunable parameters for a `FlatIndex`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlatIndexConfig {
    /// Number of components every stored vector must have.
    pub dimensions: u32,
    /// Scoring function used by `search`.
    pub metric: DistanceMetric,
    /// Number of vectors to pre-reserve storage for.
    pub initial_capacity: usize,
    /// Deletion ratio (0.0..=1.0) above which `delete` auto-compacts.
    pub cleanup_threshold: f32,
}
impl FlatIndexConfig {
    /// Default configuration: cosine metric, room for 1000 vectors, and
    /// auto-compaction above 50% deletions.
    #[must_use]
    pub fn new(dimensions: u32) -> Self {
        Self {
            dimensions,
            metric: DistanceMetric::Cosine,
            initial_capacity: 1000,
            cleanup_threshold: 0.5,
        }
    }
    /// Builder: selects the scoring function.
    #[must_use]
    pub fn with_metric(self, metric: DistanceMetric) -> Self {
        Self { metric, ..self }
    }
    /// Builder: sets the number of vectors to pre-reserve.
    #[must_use]
    pub fn with_capacity(self, capacity: usize) -> Self {
        Self {
            initial_capacity: capacity,
            ..self
        }
    }
    /// Builder: sets the auto-compaction trigger, clamped into `0.0..=1.0`.
    #[must_use]
    pub fn with_cleanup_threshold(self, threshold: f32) -> Self {
        Self {
            cleanup_threshold: threshold.clamp(0.0, 1.0),
            ..self
        }
    }
}
/// Brute-force (exact) vector index with soft deletes and optional
/// binary quantization.
pub struct FlatIndex {
    /// Immutable configuration (dimensions, metric, thresholds).
    config: FlatIndexConfig,
    /// Row-major storage: slot `i` occupies `[i * dim, (i + 1) * dim)`.
    vectors: Vec<f32>,
    /// Total slots ever stored, including tombstoned ones.
    count: u64,
    /// Tombstone bitmap, one bit per slot.
    deleted: BitVec,
    /// Number of set bits in `deleted`.
    delete_count: usize,
    /// ID handed to the next `insert`; IDs double as storage slot indices.
    next_id: u64,
    /// Cached 1-bit-per-component codes; invalidated on mutation.
    quantized: Option<Vec<u8>>,
}
impl FlatIndex {
/// Creates an empty index, pre-reserving room for
/// `config.initial_capacity` vectors.
#[must_use]
pub fn new(config: FlatIndexConfig) -> Self {
    let reserve = config.initial_capacity;
    let dim = config.dimensions as usize;
    Self {
        vectors: Vec::with_capacity(reserve * dim),
        count: 0,
        deleted: BitVec::with_capacity(reserve),
        delete_count: 0,
        next_id: 0,
        quantized: None,
        config,
    }
}
/// Dimensionality every stored vector must have.
#[must_use]
pub fn dimensions(&self) -> u32 {
    self.config.dimensions
}
/// Scoring function used by `search`.
#[must_use]
pub fn metric(&self) -> DistanceMetric {
    self.config.metric
}
/// Number of live (non-deleted) vectors.
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub fn len(&self) -> usize {
    (self.count as usize).saturating_sub(self.delete_count)
}
/// True when no live vectors remain.
#[must_use]
pub fn is_empty(&self) -> bool {
    self.len() == 0
}
/// Total number of slots, including tombstoned ones.
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub fn capacity(&self) -> usize {
    self.count as usize
}
/// Read-only access to the configuration.
#[must_use]
pub fn config(&self) -> &FlatIndexConfig {
    &self.config
}
/// Appends `vector` and returns its assigned ID (== storage slot index).
///
/// Invalidates any cached quantization codes.
///
/// # Errors
/// - `DimensionMismatch` if `vector.len()` differs from the configured
///   dimensionality.
/// - `IdOverflow` once `u64::MAX` IDs have been handed out.
pub fn insert(&mut self, vector: &[f32]) -> Result<u64, FlatIndexError> {
    let dim = self.config.dimensions as usize;
    if vector.len() != dim {
        return Err(FlatIndexError::DimensionMismatch {
            expected: dim,
            actual: vector.len(),
        });
    }
    let id = self.next_id;
    self.next_id = id.checked_add(1).ok_or(FlatIndexError::IdOverflow)?;
    self.vectors.extend_from_slice(vector);
    self.count += 1;
    self.deleted.push(false);
    // Cached codes no longer cover this vector.
    self.quantized = None;
    Ok(id)
}
/// Inserts vectors in order, returning their assigned IDs.
///
/// Stops at the first failure; vectors inserted before the error remain
/// in the index (no rollback).
///
/// # Errors
/// Propagates the first `insert` error.
pub fn insert_batch(&mut self, vectors: &[&[f32]]) -> Result<Vec<u64>, FlatIndexError> {
    vectors.iter().map(|vector| self.insert(vector)).collect()
}
/// Returns the stored vector for `id`, or `None` if the slot does not
/// exist or has been tombstoned.
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub fn get(&self, id: u64) -> Option<&[f32]> {
    if !self.contains(id) {
        return None;
    }
    let dim = self.config.dimensions as usize;
    let offset = id as usize * dim;
    Some(&self.vectors[offset..offset + dim])
}
/// True if `id` refers to a live (non-deleted) slot.
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub fn contains(&self, id: u64) -> bool {
    let idx = id as usize;
    if idx >= self.count as usize {
        return false;
    }
    // A missing bitmap entry is treated as deleted (defensive).
    self.deleted.get(idx).map_or(false, |flag| !*flag)
}
/// Number of tombstoned (deleted but not yet compacted) slots.
#[must_use]
pub fn deleted_count(&self) -> usize {
    self.delete_count
}
/// Fraction of slots that are tombstoned (0.0 for an empty index).
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn deletion_ratio(&self) -> f32 {
    if self.count == 0 {
        0.0
    } else {
        self.delete_count as f32 / self.count as f32
    }
}
/// True when binary quantization codes are currently materialized.
#[must_use]
pub fn is_quantized(&self) -> bool {
    self.quantized.is_some()
}
/// Approximate heap usage of the payload buffers (vectors, deletion
/// bitmap, quantized codes); excludes struct overhead and spare `Vec`
/// capacity.
#[must_use]
pub fn memory_usage(&self) -> usize {
    let vector_bytes = self.vectors.len() * std::mem::size_of::<f32>();
    // Bits rounded up to whole bytes.
    let bitmap_bytes = (self.deleted.len() + 7) / 8;
    let quantized_bytes = self.quantized.as_ref().map_or(0, Vec::len);
    vector_bytes + bitmap_bytes + quantized_bytes
}
/// Returns `(total_slots, deleted_slots, deleted_ratio)` in one call.
#[must_use]
#[allow(clippy::cast_precision_loss)]
#[allow(clippy::cast_possible_truncation)]
pub fn deletion_stats(&self) -> (usize, usize, f32) {
    let total = self.count as usize;
    let ratio = if total == 0 {
        0.0
    } else {
        self.delete_count as f32 / total as f32
    };
    (total, self.delete_count, ratio)
}
/// Tombstones `id`; returns `true` if the slot was live.
///
/// May trigger automatic compaction once the deletion ratio exceeds
/// `cleanup_threshold`. Also invalidates cached quantization codes.
#[allow(clippy::cast_possible_truncation)]
pub fn delete(&mut self, id: u64) -> bool {
    let idx = id as usize;
    let is_live = idx < self.count as usize
        && !self.deleted.get(idx).map_or(true, |b| *b);
    if !is_live {
        return false;
    }
    self.deleted.set(idx, true);
    self.delete_count += 1;
    // Cached codes no longer match the live set.
    self.quantized = None;
    if self.should_compact() {
        self.compact();
    }
    true
}
/// True when the tombstone ratio strictly exceeds the configured
/// cleanup threshold (never for an empty index).
#[allow(clippy::cast_precision_loss)]
fn should_compact(&self) -> bool {
    if self.count == 0 {
        return false;
    }
    let ratio = self.delete_count as f32 / self.count as f32;
    ratio > self.config.cleanup_threshold
}
/// Physically removes tombstoned slots, re-packing live vectors.
///
/// Note: compaction RE-INDEXES surviving vectors — IDs handed out before
/// a compaction are remapped to `0..len()` in insertion order, so callers
/// must not reuse previously returned IDs afterwards.
#[allow(clippy::cast_possible_truncation)]
pub fn compact(&mut self) {
    if self.delete_count == 0 {
        return;
    }
    let dim = self.config.dimensions as usize;
    let count = self.count as usize;
    let new_count = count - self.delete_count;
    let mut new_vectors = Vec::with_capacity(new_count * dim);
    let mut new_deleted = BitVec::with_capacity(new_count);
    for idx in 0..count {
        if !self.deleted.get(idx).map_or(true, |b| *b) {
            let start = idx * dim;
            new_vectors.extend_from_slice(&self.vectors[start..start + dim]);
            new_deleted.push(false);
        }
    }
    self.vectors = new_vectors;
    self.deleted = new_deleted;
    self.count = new_count as u64;
    self.delete_count = 0;
    // BUG FIX: IDs double as storage indices (`get`/`delete`/`contains`
    // index by id), so after re-packing the next insert must land at index
    // `new_count`. Leaving `next_id` at its pre-compaction value made every
    // post-compaction insert unreachable (its id pointed past `count`).
    self.next_id = self.count;
    // BUG FIX: `compact()` is public, but cached binary codes still reflect
    // the pre-compaction slot layout; drop them so `search_quantized`
    // cannot read misaligned codes after a direct `compact()` call.
    self.quantized = None;
}
/// Materializes 1-bit-per-component codes (sign test: `> 0.0`) for every
/// live vector. Deleted slots are encoded as all-zero rows so slot
/// indices stay aligned with the float storage. No-op if codes already
/// exist.
///
/// # Errors
/// Currently infallible; returns `Result` for interface stability.
#[allow(clippy::cast_possible_truncation)]
pub fn enable_quantization(&mut self) -> Result<(), FlatIndexError> {
    if self.quantized.is_some() {
        return Ok(());
    }
    if self.count == 0 {
        self.quantized = Some(Vec::new());
        return Ok(());
    }
    let dim = self.config.dimensions as usize;
    let packed_dim = (dim + 7) / 8;
    let count = self.count as usize;
    let mut codes = Vec::with_capacity(count * packed_dim);
    for idx in 0..count {
        let is_live = !self.deleted.get(idx).map_or(true, |b| *b);
        if is_live {
            let row = &self.vectors[idx * dim..idx * dim + dim];
            codes.extend_from_slice(&Self::binarize_vector(row));
        } else {
            // Keep slot alignment: deleted rows become zero padding.
            codes.resize(codes.len() + packed_dim, 0);
        }
    }
    self.quantized = Some(codes);
    Ok(())
}
/// Drops the cached binary codes; `search_quantized` will error until
/// `enable_quantization` is called again.
pub fn disable_quantization(&mut self) {
    self.quantized = None;
}
/// Packs a float vector into sign bits, MSB-first within each byte:
/// bit `7 - (i % 8)` of byte `i / 8` is set when component `i` is `> 0.0`.
/// Trailing bits of a partial final byte stay zero.
fn binarize_vector(vector: &[f32]) -> Vec<u8> {
    vector
        .chunks(8)
        .map(|chunk| {
            chunk.iter().enumerate().fold(0u8, |acc, (bit, &value)| {
                if value > 0.0 {
                    acc | (1 << (7 - bit))
                } else {
                    acc
                }
            })
        })
        .collect()
}
/// Number of differing bits between two packed binary codes
/// (compared up to the shorter slice).
#[inline]
fn hamming_distance_binary(a: &[u8], b: &[u8]) -> u32 {
    let mut total = 0;
    for (x, y) in a.iter().zip(b.iter()) {
        total += (x ^ y).count_ones();
    }
    total
}
/// k-NN over the cached binary codes using bit-level Hamming distance.
///
/// Scores are Hamming distances (lower is better) and results are sorted
/// ascending. Requires `enable_quantization()` since the last mutation.
///
/// # Errors
/// - `QuantizationNotEnabled` if no codes are cached.
/// - `DimensionMismatch` if `query.len()` differs from the index
///   dimensionality.
/// - `InvalidK` if `k == 0`.
#[allow(clippy::cast_possible_truncation)]
#[allow(clippy::cast_precision_loss)]
pub fn search_quantized(
    &self,
    query: &[f32],
    k: usize,
) -> Result<Vec<FlatSearchResult>, FlatIndexError> {
    let quantized = self
        .quantized
        .as_ref()
        .ok_or(FlatIndexError::QuantizationNotEnabled)?;
    let expected_dim = self.config.dimensions as usize;
    if query.len() != expected_dim {
        return Err(FlatIndexError::DimensionMismatch {
            expected: expected_dim,
            actual: query.len(),
        });
    }
    if k == 0 {
        return Err(FlatIndexError::InvalidK);
    }
    if self.count == 0 {
        return Ok(Vec::new());
    }
    // Binarize the query with the same sign rule used for stored codes.
    let query_packed = Self::binarize_vector(query);
    let packed_dim = query_packed.len();
    // Max-heap of the k best candidates; the root is the current worst.
    let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
    let count = self.count as usize;
    for idx in 0..count {
        // Tombstoned slots are skipped (their codes are zero padding).
        if self.deleted.get(idx).map_or(true, |b| *b) {
            continue;
        }
        let start = idx * packed_dim;
        let vector_packed = &quantized[start..start + packed_dim];
        let distance = Self::hamming_distance_binary(&query_packed, vector_packed);
        if heap.len() < k {
            heap.push(HeapEntry {
                id: idx as u64,
                score: distance as f32,
            });
        } else if let Some(top) = heap.peek() {
            // Replace the current worst candidate only on strict improvement.
            if (distance as f32) < top.score {
                heap.pop();
                heap.push(HeapEntry {
                    id: idx as u64,
                    score: distance as f32,
                });
            }
        }
    }
    let mut results: Vec<FlatSearchResult> = heap
        .into_iter()
        .map(|entry| FlatSearchResult {
            id: entry.id,
            score: entry.score,
        })
        .collect();
    // Ascending order: smaller Hamming distance = better match.
    results.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(Ordering::Equal));
    Ok(results)
}
/// Exhaustive (exact) k-nearest-neighbor search over all live vectors.
///
/// Scores use the configured metric. For similarity metrics (cosine,
/// dot product) results are sorted descending (best first); for distance
/// metrics (L2, Hamming) ascending.
///
/// # Errors
/// - `DimensionMismatch` if `query.len()` differs from the index
///   dimensionality.
/// - `InvalidK` if `k == 0`.
#[allow(clippy::cast_possible_truncation)]
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<FlatSearchResult>, FlatIndexError> {
    let expected_dim = self.config.dimensions as usize;
    if query.len() != expected_dim {
        return Err(FlatIndexError::DimensionMismatch {
            expected: expected_dim,
            actual: query.len(),
        });
    }
    if k == 0 {
        return Err(FlatIndexError::InvalidK);
    }
    if self.count == 0 {
        return Ok(Vec::new());
    }
    let dim = self.config.dimensions as usize;
    let is_similarity = self.config.metric.is_similarity();
    // Max-heap of the k best candidates; similarity scores are negated so
    // the heap's root is always the current worst regardless of metric.
    let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
    let count = self.count as usize;
    for idx in 0..count {
        // Tombstoned slots are skipped, never returned.
        if self.deleted.get(idx).map_or(true, |b| *b) {
            continue;
        }
        let start = idx * dim;
        let end = start + dim;
        let vector = &self.vectors[start..end];
        let raw_score = self.compute_distance(query, vector);
        let heap_score = if is_similarity { -raw_score } else { raw_score };
        if heap.len() < k {
            heap.push(HeapEntry {
                id: idx as u64,
                score: heap_score,
            });
        } else if let Some(top) = heap.peek() {
            // Replace the current worst candidate only on strict improvement.
            if heap_score < top.score {
                heap.pop();
                heap.push(HeapEntry {
                    id: idx as u64,
                    score: heap_score,
                });
            }
        }
    }
    // Undo the similarity negation before handing scores back.
    let mut results: Vec<FlatSearchResult> = heap
        .into_iter()
        .map(|entry| FlatSearchResult {
            id: entry.id,
            score: if is_similarity {
                -entry.score
            } else {
                entry.score
            },
        })
        .collect();
    if is_similarity {
        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
    } else {
        results.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(Ordering::Equal));
    }
    Ok(results)
}
/// Dispatches to the scoring function for the configured metric.
/// Cosine/dot product return similarities (higher is better); L2 and
/// Hamming return distances (lower is better).
#[allow(clippy::unused_self)]
fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
    match self.config.metric {
        DistanceMetric::Cosine => Self::cosine_similarity(a, b),
        DistanceMetric::DotProduct => Self::dot_product(a, b),
        DistanceMetric::L2 => Self::euclidean_distance(a, b),
        DistanceMetric::Hamming => Self::hamming_distance(a, b),
    }
}
/// Plain dot product; accumulation order matches element order.
#[inline]
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
/// Cosine similarity in `[-1, 1]`; returns `0.0` when either vector has
/// zero norm. Note the dot product pairs elements up to the shorter
/// slice, while each norm covers its full slice (as in the original).
#[inline]
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot = {
        let mut acc = 0.0_f32;
        for (x, y) in a.iter().zip(b.iter()) {
            acc += x * y;
        }
        acc
    };
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
/// Euclidean (L2) distance between paired elements.
#[inline]
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let mut sum_sq = 0.0_f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let diff = x - y;
        sum_sq += diff * diff;
    }
    sum_sq.sqrt()
}
/// Float-level Hamming distance: counts positions where exactly one of
/// the two components is nonzero. (The packed-code path in
/// `binarize_vector` uses `> 0.0` instead; the two are intentionally
/// separate code paths.)
#[inline]
#[allow(clippy::float_cmp)]
#[allow(clippy::cast_precision_loss)]
fn hamming_distance(a: &[f32], b: &[f32]) -> f32 {
    let mut mismatches = 0_usize;
    for (x, y) in a.iter().zip(b.iter()) {
        let x_set = *x != 0.0;
        let y_set = *y != 0.0;
        if x_set != y_set {
            mismatches += 1;
        }
    }
    mismatches as f32
}
}
/// On-disk snapshot format version; bump on any layout change.
pub const FLAT_INDEX_VERSION: u32 = 1;
/// Magic bytes identifying a flat-index snapshot ("EVFI").
pub const FLAT_INDEX_MAGIC: [u8; 4] = [b'E', b'V', b'F', b'I'];
/// Postcard-encoded snapshot header written by `FlatIndex::to_snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlatIndexHeader {
    /// Must equal `FLAT_INDEX_MAGIC`.
    pub magic: [u8; 4],
    /// Format version (`FLAT_INDEX_VERSION` at write time).
    pub version: u32,
    pub dimensions: u32,
    pub metric: DistanceMetric,
    /// Total slots, including tombstoned ones.
    pub count: u64,
    pub delete_count: u64,
    pub next_id: u64,
    /// Whether a quantized-codes section follows the vector data.
    pub is_quantized: bool,
    pub cleanup_threshold: f32,
    /// CRC32 over the deleted-bitmap, vector, and quantized sections.
    pub checksum: u32,
}
impl FlatIndexHeader {
    /// Builds a header describing `index`, embedding a precomputed
    /// `checksum` over the payload sections.
    #[must_use]
    pub fn from_index(index: &FlatIndex, checksum: u32) -> Self {
        Self {
            magic: FLAT_INDEX_MAGIC,
            version: FLAT_INDEX_VERSION,
            dimensions: index.config.dimensions,
            metric: index.config.metric,
            count: index.count,
            delete_count: index.delete_count as u64,
            next_id: index.next_id,
            is_quantized: index.quantized.is_some(),
            cleanup_threshold: index.config.cleanup_threshold,
            checksum,
        }
    }
    /// Checks magic bytes and version compatibility (older versions are
    /// accepted, newer ones rejected).
    ///
    /// # Errors
    /// `InvalidMagic` or `UnsupportedVersion` on mismatch.
    #[allow(clippy::cast_possible_truncation)]
    pub fn validate(&self) -> Result<(), PersistenceError> {
        if self.magic != FLAT_INDEX_MAGIC {
            return Err(PersistenceError::InvalidMagic {
                expected: FLAT_INDEX_MAGIC,
                actual: self.magic,
            });
        }
        if self.version > FLAT_INDEX_VERSION {
            // The error variant carries a (major, minor) byte pair, so the
            // u32 version is split into its low two bytes for reporting.
            return Err(PersistenceError::UnsupportedVersion(
                (self.version >> 8) as u8,
                (self.version & 0xFF) as u8,
            ));
        }
        Ok(())
    }
}
impl FlatIndex {
/// Serializes the index into a single self-describing buffer.
///
/// Layout (all length prefixes little-endian):
/// `[u32 header_len][postcard header][u32 deleted_len][deleted bitmap]`
/// `[u64 vectors_len][f32 LE vector data][u64 quantized_len][codes]`.
/// The header embeds a CRC32 over the three payload sections, so it is
/// computed first.
///
/// # Errors
/// `SerializationError` if header encoding fails.
#[allow(clippy::cast_possible_truncation)]
pub fn to_snapshot(&self) -> Result<Vec<u8>, PersistenceError> {
    let mut buffer = Vec::new();
    let deleted_bytes = self.serialize_deleted_bitmap();
    let vectors_bytes = self.serialize_vectors();
    let quantized_bytes = self.serialize_quantized();
    let checksum = Self::compute_checksum(&deleted_bytes, &vectors_bytes, &quantized_bytes);
    let header = FlatIndexHeader::from_index(self, checksum);
    let header_bytes = postcard::to_allocvec(&header)
        .map_err(|e| PersistenceError::SerializationError(e.to_string()))?;
    buffer.extend_from_slice(&(header_bytes.len() as u32).to_le_bytes());
    buffer.extend_from_slice(&header_bytes);
    buffer.extend_from_slice(&(deleted_bytes.len() as u32).to_le_bytes());
    buffer.extend_from_slice(&deleted_bytes);
    buffer.extend_from_slice(&(vectors_bytes.len() as u64).to_le_bytes());
    buffer.extend_from_slice(&vectors_bytes);
    // The quantized length is always written; the payload only when
    // quantization is active (length 0 means "no codes" on read).
    buffer.extend_from_slice(&(quantized_bytes.len() as u64).to_le_bytes());
    if !quantized_bytes.is_empty() {
        buffer.extend_from_slice(&quantized_bytes);
    }
    Ok(buffer)
}
/// Reconstructs an index from a buffer produced by `to_snapshot`.
///
/// Each section is length-prefixed; every read is bounds-checked before
/// slicing, and the embedded CRC32 is verified before any state is built.
///
/// # Errors
/// `TruncatedData` on short input, `DeserializationError` on a bad
/// header, `InvalidMagic`/`UnsupportedVersion` from header validation,
/// and `ChecksumMismatch` when the payload CRC differs from the header's.
#[allow(clippy::cast_possible_truncation)]
#[allow(clippy::missing_panics_doc)]
pub fn from_snapshot(data: &[u8]) -> Result<Self, PersistenceError> {
    let mut cursor = 0;
    // --- header frame: u32 length + postcard-encoded FlatIndexHeader ---
    if data.len() < 4 {
        return Err(PersistenceError::TruncatedData);
    }
    let header_len = u32::from_le_bytes(
        data[0..4]
            .try_into()
            .map_err(|_| PersistenceError::TruncatedData)?,
    ) as usize;
    cursor += 4;
    if data.len() < cursor + header_len {
        return Err(PersistenceError::TruncatedData);
    }
    let header: FlatIndexHeader = postcard::from_bytes(&data[cursor..cursor + header_len])
        .map_err(|e| PersistenceError::DeserializationError(e.to_string()))?;
    cursor += header_len;
    header.validate()?;
    // --- deletion bitmap frame (u32 length prefix) ---
    if data.len() < cursor + 4 {
        return Err(PersistenceError::TruncatedData);
    }
    let deleted_len = u32::from_le_bytes(
        data[cursor..cursor + 4]
            .try_into()
            .map_err(|_| PersistenceError::TruncatedData)?,
    ) as usize;
    cursor += 4;
    if data.len() < cursor + deleted_len {
        return Err(PersistenceError::TruncatedData);
    }
    let deleted_bytes = &data[cursor..cursor + deleted_len];
    cursor += deleted_len;
    // --- vector data frame (u64 length prefix) ---
    if data.len() < cursor + 8 {
        return Err(PersistenceError::TruncatedData);
    }
    let vectors_len = u64::from_le_bytes(
        data[cursor..cursor + 8]
            .try_into()
            .map_err(|_| PersistenceError::TruncatedData)?,
    ) as usize;
    cursor += 8;
    if data.len() < cursor + vectors_len {
        return Err(PersistenceError::TruncatedData);
    }
    let vectors_bytes = &data[cursor..cursor + vectors_len];
    cursor += vectors_len;
    // --- quantized codes: length always present, payload only if > 0 ---
    if data.len() < cursor + 8 {
        return Err(PersistenceError::TruncatedData);
    }
    let quantized_len = u64::from_le_bytes(
        data[cursor..cursor + 8]
            .try_into()
            .map_err(|_| PersistenceError::TruncatedData)?,
    ) as usize;
    cursor += 8;
    let quantized_bytes = if quantized_len > 0 {
        if data.len() < cursor + quantized_len {
            return Err(PersistenceError::TruncatedData);
        }
        Some(data[cursor..cursor + quantized_len].to_vec())
    } else {
        None
    };
    // Verify integrity before building any state.
    let computed_checksum = Self::compute_checksum(
        deleted_bytes,
        vectors_bytes,
        quantized_bytes.as_deref().unwrap_or(&[]),
    );
    if computed_checksum != header.checksum {
        return Err(PersistenceError::ChecksumMismatch {
            expected: header.checksum,
            actual: computed_checksum,
        });
    }
    let config = FlatIndexConfig {
        dimensions: header.dimensions,
        metric: header.metric,
        initial_capacity: header.count as usize,
        cleanup_threshold: header.cleanup_threshold,
    };
    let vectors: Vec<f32> = vectors_bytes
        .chunks_exact(4)
        .map(|chunk| {
            f32::from_le_bytes(chunk.try_into().expect("chunks_exact guarantees 4 bytes"))
        })
        .collect();
    let deleted = Self::deserialize_deleted_bitmap(deleted_bytes, header.count as usize);
    Ok(Self {
        config,
        vectors,
        count: header.count,
        deleted,
        delete_count: header.delete_count as usize,
        next_id: header.next_id,
        quantized: quantized_bytes,
    })
}
/// Serializes the tombstone bitmap as its backing words in little-endian
/// byte order.
fn serialize_deleted_bitmap(&self) -> Vec<u8> {
    self.deleted
        .as_raw_slice()
        .iter()
        .flat_map(|word| word.to_le_bytes())
        .collect()
}
/// Rebuilds the tombstone bitmap from little-endian words.
///
/// NOTE(review): the word type is the native `usize`, so snapshots appear
/// non-portable between 32-bit and 64-bit builds — confirm whether
/// cross-platform snapshot files are a requirement.
fn deserialize_deleted_bitmap(bytes: &[u8], count: usize) -> BitVec {
    let word_size = std::mem::size_of::<usize>();
    let mut words: Vec<usize> = bytes
        .chunks(word_size)
        .map(|chunk| {
            // Zero-pad a trailing partial chunk before decoding.
            let mut arr = [0u8; std::mem::size_of::<usize>()];
            let len = chunk.len().min(word_size);
            arr[..len].copy_from_slice(&chunk[..len]);
            usize::from_le_bytes(arr)
        })
        .collect();
    // Ensure enough backing words for `count` bits, then trim excess bits.
    let needed_words = (count + usize::BITS as usize - 1) / usize::BITS as usize;
    words.resize(needed_words, 0);
    let mut bv = BitVec::from_vec(words);
    bv.truncate(count);
    bv
}
/// Serializes the raw vector storage as little-endian f32 bytes.
fn serialize_vectors(&self) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(self.vectors.len() * std::mem::size_of::<f32>());
    for value in &self.vectors {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
/// Returns a copy of the cached codes, or an empty buffer when
/// quantization is disabled.
fn serialize_quantized(&self) -> Vec<u8> {
    match &self.quantized {
        Some(codes) => codes.clone(),
        None => Vec::new(),
    }
}
/// CRC32 over the three payload sections, folded in snapshot order
/// (deleted bitmap, vectors, quantized codes).
fn compute_checksum(deleted: &[u8], vectors: &[u8], quantized: &[u8]) -> u32 {
    let mut hasher = crc32fast::Hasher::new();
    hasher.update(deleted);
    hasher.update(vectors);
    hasher.update(quantized);
    hasher.finalize()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_new() {
let config = FlatIndexConfig::new(128);
assert_eq!(config.dimensions, 128);
assert_eq!(config.metric, DistanceMetric::Cosine);
assert_eq!(config.initial_capacity, 1000);
assert!((config.cleanup_threshold - 0.5).abs() < f32::EPSILON);
}
#[test]
fn test_config_builder() {
let config = FlatIndexConfig::new(64)
.with_metric(DistanceMetric::DotProduct)
.with_capacity(5000)
.with_cleanup_threshold(0.3);
assert_eq!(config.dimensions, 64);
assert_eq!(config.metric, DistanceMetric::DotProduct);
assert_eq!(config.initial_capacity, 5000);
assert!((config.cleanup_threshold - 0.3).abs() < f32::EPSILON);
}
#[test]
fn test_config_cleanup_threshold_clamping() {
let config_low = FlatIndexConfig::new(64).with_cleanup_threshold(-0.5);
assert!((config_low.cleanup_threshold - 0.0).abs() < f32::EPSILON);
let config_high = FlatIndexConfig::new(64).with_cleanup_threshold(1.5);
assert!((config_high.cleanup_threshold - 1.0).abs() < f32::EPSILON);
}
#[test]
fn test_distance_metric_is_similarity() {
assert!(DistanceMetric::Cosine.is_similarity());
assert!(DistanceMetric::DotProduct.is_similarity());
assert!(!DistanceMetric::L2.is_similarity());
assert!(!DistanceMetric::Hamming.is_similarity());
}
#[test]
fn test_distance_metric_default() {
let metric: DistanceMetric = DistanceMetric::default();
assert_eq!(metric, DistanceMetric::Cosine);
}
#[test]
fn test_new_flat_index() {
let config = FlatIndexConfig::new(128);
let index = FlatIndex::new(config);
assert_eq!(index.dimensions(), 128);
assert_eq!(index.metric(), DistanceMetric::Cosine);
assert_eq!(index.len(), 0);
assert!(index.is_empty());
assert_eq!(index.capacity(), 0);
assert_eq!(index.deleted_count(), 0);
assert!(!index.is_quantized());
}
#[test]
fn test_new_with_different_metrics() {
for metric in [
DistanceMetric::Cosine,
DistanceMetric::DotProduct,
DistanceMetric::L2,
DistanceMetric::Hamming,
] {
let config = FlatIndexConfig::new(64).with_metric(metric);
let index = FlatIndex::new(config);
assert_eq!(index.metric(), metric);
}
}
#[test]
fn test_insert_single() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
let id = index.insert(&[1.0, 2.0, 3.0]).unwrap();
assert_eq!(id, 0);
assert_eq!(index.len(), 1);
assert!(!index.is_empty());
assert_eq!(index.capacity(), 1);
}
#[test]
fn test_insert_multiple_sequential_ids() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
let id1 = index.insert(&[1.0, 2.0, 3.0]).unwrap();
let id2 = index.insert(&[4.0, 5.0, 6.0]).unwrap();
let id3 = index.insert(&[7.0, 8.0, 9.0]).unwrap();
assert_eq!(id1, 0);
assert_eq!(id2, 1);
assert_eq!(id3, 2);
assert_eq!(index.len(), 3);
}
#[test]
fn test_insert_dimension_mismatch() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
let result = index.insert(&[1.0, 2.0]);
assert!(matches!(
result,
Err(FlatIndexError::DimensionMismatch {
expected: 3,
actual: 2
})
));
let result = index.insert(&[1.0, 2.0, 3.0, 4.0]);
assert!(matches!(
result,
Err(FlatIndexError::DimensionMismatch {
expected: 3,
actual: 4
})
));
assert!(index.is_empty());
}
#[test]
fn test_insert_batch() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
let vectors: Vec<&[f32]> = vec![&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]];
let ids = index.insert_batch(&vectors).unwrap();
assert_eq!(ids, vec![0, 1, 2]);
assert_eq!(index.len(), 3);
}
#[test]
fn test_insert_batch_dimension_mismatch() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
let vectors: Vec<&[f32]> = vec![
&[1.0, 2.0, 3.0], &[4.0, 5.0], &[7.0, 8.0, 9.0], ];
let result = index.insert_batch(&vectors);
assert!(result.is_err());
assert_eq!(index.len(), 1);
assert!(index.contains(0));
}
#[test]
fn test_insert_capacity_growth() {
let config = FlatIndexConfig::new(3).with_capacity(2);
let mut index = FlatIndex::new(config);
for i in 0..10 {
let id = index.insert(&[i as f32, i as f32, i as f32]).unwrap();
assert_eq!(id, i);
}
assert_eq!(index.len(), 10);
for i in 0..10 {
assert!(index.contains(i));
}
}
#[test]
fn test_get_vector() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
index.insert(&[1.0, 2.0, 3.0]).unwrap();
index.insert(&[4.0, 5.0, 6.0]).unwrap();
let v0 = index.get(0).unwrap();
let v1 = index.get(1).unwrap();
assert_eq!(v0, &[1.0, 2.0, 3.0]);
assert_eq!(v1, &[4.0, 5.0, 6.0]);
}
#[test]
fn test_get_nonexistent() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
index.insert(&[1.0, 2.0, 3.0]).unwrap();
assert!(index.get(1).is_none()); assert!(index.get(99).is_none()); assert!(index.get(u64::MAX).is_none()); }
#[test]
fn test_get_empty_index() {
let index = FlatIndex::new(FlatIndexConfig::new(3));
assert!(index.get(0).is_none());
}
#[test]
fn test_contains() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
index.insert(&[1.0, 2.0, 3.0]).unwrap();
index.insert(&[4.0, 5.0, 6.0]).unwrap();
assert!(index.contains(0));
assert!(index.contains(1));
assert!(!index.contains(2));
assert!(!index.contains(99));
}
#[test]
fn test_memory_usage() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
let empty_usage = index.memory_usage();
assert_eq!(empty_usage, 0);
index.insert(&[1.0, 2.0, 3.0]).unwrap();
let usage = index.memory_usage();
assert!(usage >= 12);
}
#[test]
fn test_deletion_ratio_empty() {
let index = FlatIndex::new(FlatIndexConfig::new(3));
assert!((index.deletion_ratio() - 0.0).abs() < f32::EPSILON);
}
#[test]
fn test_high_dimension_vectors() {
let dim = 768; let mut index = FlatIndex::new(FlatIndexConfig::new(dim));
let vector: Vec<f32> = (0..dim).map(|i| i as f32 / dim as f32).collect();
let id = index.insert(&vector).unwrap();
assert_eq!(id, 0);
let retrieved = index.get(0).unwrap();
assert_eq!(retrieved.len(), dim as usize);
assert_eq!(retrieved, vector.as_slice());
}
#[test]
fn test_zero_vector() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
let id = index.insert(&[0.0, 0.0, 0.0]).unwrap();
let v = index.get(id).unwrap();
assert_eq!(v, &[0.0, 0.0, 0.0]);
}
#[test]
fn test_negative_values() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
let id = index.insert(&[-1.0, -2.0, -3.0]).unwrap();
let v = index.get(id).unwrap();
assert_eq!(v, &[-1.0, -2.0, -3.0]);
}
#[test]
fn test_special_float_values() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
let id = index
.insert(&[f32::INFINITY, f32::NEG_INFINITY, 0.0])
.unwrap();
let v = index.get(id).unwrap();
assert!(v[0].is_infinite());
assert!(v[1].is_infinite());
}
#[test]
fn test_search_basic_cosine() {
let config = FlatIndexConfig::new(3).with_metric(DistanceMetric::Cosine);
let mut index = FlatIndex::new(config);
index.insert(&[1.0, 0.0, 0.0]).unwrap(); index.insert(&[0.0, 1.0, 0.0]).unwrap(); index.insert(&[0.0, 0.0, 1.0]).unwrap();
let results = index.search(&[0.9, 0.1, 0.0], 2).unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].id, 0); }
#[test]
fn test_search_all_metrics() {
for metric in [
DistanceMetric::Cosine,
DistanceMetric::DotProduct,
DistanceMetric::L2,
DistanceMetric::Hamming,
] {
let config = FlatIndexConfig::new(3).with_metric(metric);
let mut index = FlatIndex::new(config);
index.insert(&[1.0, 0.0, 0.0]).unwrap();
index.insert(&[0.0, 1.0, 0.0]).unwrap();
let results = index.search(&[1.0, 0.0, 0.0], 2).unwrap();
assert_eq!(results.len(), 2, "Failed for metric {:?}", metric);
assert_eq!(results[0].id, 0, "Failed for metric {:?}", metric);
}
}
#[test]
fn test_search_dimension_mismatch() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
index.insert(&[1.0, 0.0, 0.0]).unwrap();
let result = index.search(&[1.0, 0.0], 1);
assert!(matches!(
result,
Err(FlatIndexError::DimensionMismatch {
expected: 3,
actual: 2
})
));
}
#[test]
fn test_search_k_zero() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
index.insert(&[1.0, 0.0, 0.0]).unwrap();
let result = index.search(&[1.0, 0.0, 0.0], 0);
assert!(matches!(result, Err(FlatIndexError::InvalidK)));
}
#[test]
fn test_search_empty_index() {
let index = FlatIndex::new(FlatIndexConfig::new(3));
let results = index.search(&[1.0, 0.0, 0.0], 5).unwrap();
assert!(results.is_empty());
}
#[test]
fn test_search_k_larger_than_count() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
index.insert(&[1.0, 0.0, 0.0]).unwrap();
index.insert(&[0.0, 1.0, 0.0]).unwrap();
let results = index.search(&[1.0, 0.0, 0.0], 10).unwrap();
assert_eq!(results.len(), 2); }
#[test]
fn test_search_results_sorted_cosine() {
let config = FlatIndexConfig::new(3).with_metric(DistanceMetric::Cosine);
let mut index = FlatIndex::new(config);
index.insert(&[1.0, 0.0, 0.0]).unwrap(); index.insert(&[0.707, 0.707, 0.0]).unwrap(); index.insert(&[0.0, 1.0, 0.0]).unwrap();
let query = [1.0, 0.0, 0.0];
let results = index.search(&query, 3).unwrap();
assert_eq!(results[0].id, 0); assert_eq!(results[2].id, 2);
for i in 1..results.len() {
assert!(
results[i - 1].score >= results[i].score,
"Results not sorted at index {}: {} < {}",
i,
results[i - 1].score,
results[i].score
);
}
}
#[test]
fn test_search_l2_metric() {
let config = FlatIndexConfig::new(3).with_metric(DistanceMetric::L2);
let mut index = FlatIndex::new(config);
index.insert(&[0.0, 0.0, 0.0]).unwrap(); index.insert(&[1.0, 0.0, 0.0]).unwrap(); index.insert(&[2.0, 0.0, 0.0]).unwrap();
let results = index.search(&[0.0, 0.0, 0.0], 3).unwrap();
assert_eq!(results[0].id, 0);
assert!((results[0].score - 0.0).abs() < 1e-6);
assert_eq!(results[1].id, 1);
assert!((results[1].score - 1.0).abs() < 1e-6);
assert_eq!(results[2].id, 2);
assert!((results[2].score - 2.0).abs() < 1e-6);
}
#[test]
fn test_search_dot_product_metric() {
let config = FlatIndexConfig::new(3).with_metric(DistanceMetric::DotProduct);
let mut index = FlatIndex::new(config);
index.insert(&[1.0, 0.0, 0.0]).unwrap(); index.insert(&[0.5, 0.0, 0.0]).unwrap(); index.insert(&[0.0, 1.0, 0.0]).unwrap();
let results = index.search(&[1.0, 0.0, 0.0], 3).unwrap();
assert_eq!(results[0].id, 0);
assert!((results[0].score - 1.0).abs() < 1e-6);
assert_eq!(results[1].id, 1);
assert!((results[1].score - 0.5).abs() < 1e-6);
}
#[test]
fn test_search_hamming_metric() {
let config = FlatIndexConfig::new(4).with_metric(DistanceMetric::Hamming);
let mut index = FlatIndex::new(config);
index.insert(&[1.0, 1.0, 0.0, 0.0]).unwrap(); index.insert(&[1.0, 0.0, 0.0, 0.0]).unwrap(); index.insert(&[0.0, 0.0, 1.0, 1.0]).unwrap();
let results = index.search(&[1.0, 1.0, 0.0, 0.0], 3).unwrap();
assert_eq!(results[0].id, 0);
assert!((results[0].score - 0.0).abs() < 1e-6);
assert_eq!(results[1].id, 1);
assert!((results[1].score - 1.0).abs() < 1e-6);
assert_eq!(results[2].id, 2);
assert!((results[2].score - 4.0).abs() < 1e-6);
}
#[test]
fn test_search_100_recall_validation() {
let mut index = FlatIndex::new(FlatIndexConfig::new(64));
let mut seed: u64 = 42;
let lcg = |s: &mut u64| -> f32 {
*s = s.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
((*s >> 33) as f32) / (u32::MAX as f32)
};
for _ in 0..100 {
let v: Vec<f32> = (0..64).map(|_| lcg(&mut seed)).collect();
index.insert(&v).unwrap();
}
let query: Vec<f32> = (0..64).map(|_| lcg(&mut seed)).collect();
let results = index.search(&query, 10).unwrap();
assert_eq!(results.len(), 10);
for i in 1..results.len() {
assert!(results[i - 1].score >= results[i].score);
}
}
#[test]
fn test_search_high_dimension() {
let dim = 768;
let config = FlatIndexConfig::new(dim).with_metric(DistanceMetric::Cosine);
let mut index = FlatIndex::new(config);
for i in 0..100 {
let v: Vec<f32> = (0..dim as usize).map(|j| (i * j) as f32 / 1000.0).collect();
index.insert(&v).unwrap();
}
let query: Vec<f32> = (0..dim as usize).map(|j| j as f32 / 1000.0).collect();
let results = index.search(&query, 5).unwrap();
assert_eq!(results.len(), 5);
}
#[test]
fn test_delete_basic() {
let config = FlatIndexConfig::new(3).with_cleanup_threshold(1.0); let mut index = FlatIndex::new(config);
let id = index.insert(&[1.0, 2.0, 3.0]).unwrap();
assert!(index.contains(id));
assert!(index.delete(id)); assert!(!index.contains(id)); assert!(index.get(id).is_none());
}
#[test]
fn test_delete_already_deleted() {
let config = FlatIndexConfig::new(3).with_cleanup_threshold(1.0);
let mut index = FlatIndex::new(config);
let id = index.insert(&[1.0, 2.0, 3.0]).unwrap();
assert!(index.delete(id)); assert!(!index.delete(id)); }
#[test]
fn test_delete_nonexistent() {
let mut index = FlatIndex::new(FlatIndexConfig::new(3));
assert!(!index.delete(0)); assert!(!index.delete(999)); }
#[test]
fn test_delete_updates_len() {
let config = FlatIndexConfig::new(3).with_cleanup_threshold(1.0);
let mut index = FlatIndex::new(config);
index.insert(&[1.0, 2.0, 3.0]).unwrap();
index.insert(&[4.0, 5.0, 6.0]).unwrap();
assert_eq!(index.len(), 2);
index.delete(0);
assert_eq!(index.len(), 1);
index.delete(1);
assert_eq!(index.len(), 0);
assert!(index.is_empty());
}
#[test]
fn test_delete_updates_deleted_count() {
let config = FlatIndexConfig::new(3).with_cleanup_threshold(1.0);
let mut index = FlatIndex::new(config);
index.insert(&[1.0, 2.0, 3.0]).unwrap();
index.insert(&[4.0, 5.0, 6.0]).unwrap();
assert_eq!(index.deleted_count(), 0);
index.delete(0);
assert_eq!(index.deleted_count(), 1);
index.delete(1);
assert_eq!(index.deleted_count(), 2);
}
#[test]
fn test_deletion_stats() {
    // deletion_stats() returns (total slots, deleted count, deleted ratio).
    let cfg = FlatIndexConfig::new(3).with_cleanup_threshold(1.0);
    let mut idx = FlatIndex::new(cfg);
    idx.insert(&[1.0, 2.0, 3.0]).unwrap();
    idx.insert(&[4.0, 5.0, 6.0]).unwrap();

    // No deletions yet: ratio is exactly zero.
    let (total, deleted, ratio) = idx.deletion_stats();
    assert_eq!(total, 2);
    assert_eq!(deleted, 0);
    assert!((ratio - 0.0).abs() < f32::EPSILON);

    // One of two deleted: ratio is one half.
    idx.delete(0);
    let (total, deleted, ratio) = idx.deletion_stats();
    assert_eq!(total, 2);
    assert_eq!(deleted, 1);
    assert!((ratio - 0.5).abs() < 0.01);
}
#[test]
fn test_search_skips_deleted() {
    // Deleted vectors must not appear in search results.
    let cfg = FlatIndexConfig::new(3)
        .with_metric(DistanceMetric::Cosine)
        .with_cleanup_threshold(1.0);
    let mut idx = FlatIndex::new(cfg);
    idx.insert(&[1.0, 0.0, 0.0]).unwrap();
    idx.insert(&[0.0, 1.0, 0.0]).unwrap();

    // Before deletion the exact match (id 0) ranks first.
    let hits = idx.search(&[1.0, 0.0, 0.0], 2).unwrap();
    assert_eq!(hits[0].id, 0);

    // After deleting id 0, only id 1 remains.
    idx.delete(0);
    let hits = idx.search(&[1.0, 0.0, 0.0], 2).unwrap();
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].id, 1);
}
#[test]
fn test_compact_basic() {
    // compact() reclaims tombstoned slots and clears the deleted counter.
    let cfg = FlatIndexConfig::new(3).with_cleanup_threshold(1.0);
    let mut idx = FlatIndex::new(cfg);
    idx.insert(&[1.0, 2.0, 3.0]).unwrap();
    idx.insert(&[4.0, 5.0, 6.0]).unwrap();
    idx.insert(&[7.0, 8.0, 9.0]).unwrap();
    idx.delete(1);
    assert_eq!(idx.capacity(), 3);
    assert_eq!(idx.len(), 2);
    idx.compact();
    assert_eq!(idx.capacity(), 2);
    assert_eq!(idx.len(), 2);
    assert_eq!(idx.deleted_count(), 0);
}
#[test]
fn test_compact_preserves_data() {
    // Surviving vectors remain intact across compaction; per this assertion
    // the survivor is addressable at id 0 afterwards (compaction renumbers).
    let cfg = FlatIndexConfig::new(3).with_cleanup_threshold(1.0);
    let mut idx = FlatIndex::new(cfg);
    idx.insert(&[1.0, 2.0, 3.0]).unwrap();
    idx.insert(&[4.0, 5.0, 6.0]).unwrap();
    idx.insert(&[7.0, 8.0, 9.0]).unwrap();
    idx.delete(0);
    idx.delete(2);
    idx.compact();
    assert_eq!(idx.len(), 1);
    let survivor = idx.get(0).unwrap();
    assert_eq!(survivor, &[4.0, 5.0, 6.0]);
}
#[test]
fn test_compact_empty() {
    // Compacting an empty index is a harmless no-op.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(3));
    idx.compact();
    assert!(idx.is_empty());
}
#[test]
fn test_compact_nothing_to_do() {
    // With no deletions pending, compact() leaves capacity untouched.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(3));
    idx.insert(&[1.0, 2.0, 3.0]).unwrap();
    let before = idx.capacity();
    idx.compact();
    assert_eq!(idx.capacity(), before);
}
#[test]
fn test_auto_compact_on_threshold() {
    // With cleanup_threshold 0.3, the index auto-compacts once the deleted
    // ratio exceeds 30% (3/10 = 0.30 does not trigger; 4/10 = 0.40 does).
    let cfg = FlatIndexConfig::new(3).with_cleanup_threshold(0.3);
    let mut idx = FlatIndex::new(cfg);
    for i in 0..10 {
        idx.insert(&[i as f32, 0.0, 0.0]).unwrap();
    }
    assert_eq!(idx.capacity(), 10);

    idx.delete(0);
    idx.delete(1);
    idx.delete(2);
    assert_eq!(idx.capacity(), 10); // still below threshold

    idx.delete(3); // crosses the threshold -> auto compact
    assert_eq!(idx.capacity(), 6);
    assert_eq!(idx.deleted_count(), 0);
}
#[test]
fn test_enable_quantization() {
    // enable_quantization() flips the quantized flag on a populated index.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(8));
    idx.insert(&[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0])
        .unwrap();
    assert!(!idx.is_quantized());
    idx.enable_quantization().unwrap();
    assert!(idx.is_quantized());
}
#[test]
fn test_enable_quantization_idempotent() {
    // Calling enable_quantization() twice succeeds and keeps the flag set.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(8));
    idx.insert(&[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0])
        .unwrap();
    idx.enable_quantization().unwrap();
    idx.enable_quantization().unwrap();
    assert!(idx.is_quantized());
}
#[test]
fn test_enable_quantization_empty_index() {
    // Quantization can be enabled before any vectors are inserted.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(8));
    idx.enable_quantization().unwrap();
    assert!(idx.is_quantized());
}
#[test]
fn test_disable_quantization() {
    // disable_quantization() clears the quantized state.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(8));
    idx.insert(&[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0])
        .unwrap();
    idx.enable_quantization().unwrap();
    assert!(idx.is_quantized());
    idx.disable_quantization();
    assert!(!idx.is_quantized());
}
#[test]
fn test_search_quantized_basic() {
    // An exact binary match ranks first with Hamming distance 0.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(8));
    idx.insert(&[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
        .unwrap();
    idx.insert(&[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0])
        .unwrap();
    idx.insert(&[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0])
        .unwrap();
    idx.enable_quantization().unwrap();

    let query = [1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
    let hits = idx.search_quantized(&query, 3).unwrap();
    assert_eq!(hits.len(), 3);
    assert_eq!(hits[0].id, 1);
    assert!((hits[0].score - 0.0).abs() < f32::EPSILON);
}
#[test]
fn test_search_quantized_hamming_distances() {
    // Scores equal the Hamming distances: 0, 1, and 8 bits differ.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(8));
    idx.insert(&[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
        .unwrap();
    idx.insert(&[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0])
        .unwrap();
    idx.insert(&[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0])
        .unwrap();
    idx.enable_quantization().unwrap();

    let hits = idx.search_quantized(&[1.0; 8], 3).unwrap();
    assert_eq!(hits[0].id, 0);
    assert!((hits[0].score - 0.0).abs() < f32::EPSILON);
    assert_eq!(hits[1].id, 1);
    assert!((hits[1].score - 1.0).abs() < f32::EPSILON);
    assert_eq!(hits[2].id, 2);
    assert!((hits[2].score - 8.0).abs() < f32::EPSILON);
}
#[test]
fn test_search_quantized_not_enabled() {
    // Quantized search without enabling quantization is a hard error.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(8));
    idx.insert(&[1.0; 8]).unwrap();
    let outcome = idx.search_quantized(&[1.0; 8], 1);
    assert!(matches!(
        outcome,
        Err(FlatIndexError::QuantizationNotEnabled)
    ));
}
#[test]
fn test_search_quantized_dimension_mismatch() {
    // A 4-dim query against an 8-dim index reports both dimensions.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(8));
    idx.insert(&[1.0; 8]).unwrap();
    idx.enable_quantization().unwrap();
    let outcome = idx.search_quantized(&[1.0; 4], 1);
    assert!(matches!(
        outcome,
        Err(FlatIndexError::DimensionMismatch {
            expected: 8,
            actual: 4
        })
    ));
}
#[test]
fn test_search_quantized_k_zero() {
    // k == 0 is rejected with InvalidK.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(8));
    idx.insert(&[1.0; 8]).unwrap();
    idx.enable_quantization().unwrap();
    assert!(matches!(
        idx.search_quantized(&[1.0; 8], 0),
        Err(FlatIndexError::InvalidK)
    ));
}
#[test]
fn test_search_quantized_skips_deleted() {
    // Deleted vectors are excluded from quantized search. Quantization is
    // re-enabled after the delete because deletion invalidates the codes.
    let cfg = FlatIndexConfig::new(8).with_cleanup_threshold(1.0);
    let mut idx = FlatIndex::new(cfg);
    idx.insert(&[1.0; 8]).unwrap();
    idx.insert(&[-1.0; 8]).unwrap();
    idx.enable_quantization().unwrap();

    let hits = idx.search_quantized(&[1.0; 8], 2).unwrap();
    assert_eq!(hits[0].id, 0);

    idx.delete(0);
    idx.enable_quantization().unwrap();
    let hits = idx.search_quantized(&[1.0; 8], 2).unwrap();
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].id, 1);
}
#[test]
fn test_insert_invalidates_quantization() {
    // Inserting after quantization clears the quantized state.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(8));
    idx.insert(&[1.0; 8]).unwrap();
    idx.enable_quantization().unwrap();
    assert!(idx.is_quantized());
    idx.insert(&[-1.0; 8]).unwrap();
    assert!(!idx.is_quantized());
}
#[test]
fn test_delete_invalidates_quantization() {
    // Deleting after quantization clears the quantized state.
    let cfg = FlatIndexConfig::new(8).with_cleanup_threshold(1.0);
    let mut idx = FlatIndex::new(cfg);
    idx.insert(&[1.0; 8]).unwrap();
    idx.enable_quantization().unwrap();
    assert!(idx.is_quantized());
    idx.delete(0);
    assert!(!idx.is_quantized());
}
#[test]
fn test_quantization_memory_reduction() {
    // Per the assertions, enabling quantization ADDS the binary codes on top
    // of the f32 vectors (total usage grows), while the per-vector BQ code is
    // 32x smaller than the f32 representation (dim*4 bytes vs dim/8 bytes).
    let dim = 768;
    let mut idx = FlatIndex::new(FlatIndexConfig::new(dim));
    for i in 0..100 {
        let v: Vec<f32> = (0..dim as usize)
            .map(|j| if (i + j) % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        idx.insert(&v).unwrap();
    }

    let before = idx.memory_usage();
    idx.enable_quantization().unwrap();
    let after = idx.memory_usage();
    assert!(after > before);

    let f32_bytes = dim as usize * 4;
    let bq_bytes = (dim as usize + 7) / 8;
    assert_eq!(f32_bytes / bq_bytes, 32);
}
#[test]
fn test_binarize_vector() {
    // Alternating +1/-1 over 8 dims packs to a single 0b1010_1010 byte.
    let bits = FlatIndex::binarize_vector(&[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0]);
    assert_eq!(bits.len(), 1);
    assert_eq!(bits[0], 0b1010_1010);
}
#[test]
fn test_binarize_vector_16_dim() {
    // 16 dims pack into two bytes: all-positive then all-negative.
    let bits = FlatIndex::binarize_vector(&[
        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0,
    ]);
    assert_eq!(bits.len(), 2);
    assert_eq!(bits[0], 0b1111_1111);
    assert_eq!(bits[1], 0b0000_0000);
}
#[test]
fn test_hamming_distance_binary() {
    // Hamming distance counts differing bits: 1 bit, then all 8 bits.
    let all_ones = [0b1111_1111u8];
    let one_off = [0b1111_1110u8];
    assert_eq!(FlatIndex::hamming_distance_binary(&all_ones, &one_off), 1);

    let all_zeros = [0b0000_0000u8];
    assert_eq!(FlatIndex::hamming_distance_binary(&all_ones, &all_zeros), 8);
}
#[test]
fn test_search_quantized_high_dimension() {
    // 768-dim vectors with alternating sign patterns: even-indexed vectors
    // share the query's pattern, so the top hit must have an even id.
    let dim = 768;
    let mut idx = FlatIndex::new(FlatIndexConfig::new(dim));
    for i in 0..50 {
        let v: Vec<f32> = (0..dim as usize)
            .map(|j| if (i + j) % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        idx.insert(&v).unwrap();
    }
    idx.enable_quantization().unwrap();

    let query: Vec<f32> = (0..dim as usize)
        .map(|j| if j % 2 == 0 { 1.0 } else { -1.0 })
        .collect();
    let hits = idx.search_quantized(&query, 10).unwrap();
    assert_eq!(hits.len(), 10);
    assert!(hits[0].id % 2 == 0);
}
#[test]
fn test_snapshot_round_trip_basic() {
    // Serializing and restoring preserves config, length, and every vector.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(64));
    for i in 0..100 {
        let v: Vec<f32> = (0..64).map(|j| (i * 64 + j) as f32 / 1000.0).collect();
        idx.insert(&v).unwrap();
    }

    let bytes = idx.to_snapshot().unwrap();
    let restored = FlatIndex::from_snapshot(&bytes).unwrap();
    assert_eq!(restored.dimensions(), idx.dimensions());
    assert_eq!(restored.len(), idx.len());
    assert_eq!(restored.metric(), idx.metric());
    for i in 0..100 {
        assert_eq!(idx.get(i).unwrap(), restored.get(i).unwrap());
    }
}
#[test]
#[allow(clippy::useless_vec)]
fn test_snapshot_with_deletions() {
    // Tombstones survive the snapshot round trip: deleted ids stay deleted,
    // live ids stay retrievable, and the deleted count is preserved.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(16));
    for i in 0..50 {
        idx.insert(&vec![i as f32; 16]).unwrap();
    }
    assert!(idx.delete(10));
    assert!(idx.delete(20));
    assert!(idx.delete(30));

    let bytes = idx.to_snapshot().unwrap();
    let restored = FlatIndex::from_snapshot(&bytes).unwrap();
    for gone in [10, 20, 30] {
        assert!(restored.get(gone).is_none());
    }
    assert!(restored.get(0).is_some());
    assert!(restored.get(49).is_some());

    let (_, deleted, _) = restored.deletion_stats();
    assert_eq!(deleted, 3);
}
#[test]
fn test_snapshot_with_quantization() {
    // Quantization state survives a snapshot round trip, and quantized
    // search works on the restored index without re-enabling.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(128));
    for i in 0..100 {
        let v: Vec<f32> = (0..128)
            .map(|j| if (i + j) % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        idx.insert(&v).unwrap();
    }
    idx.enable_quantization().unwrap();
    assert!(idx.is_quantized());

    let bytes = idx.to_snapshot().unwrap();
    let restored = FlatIndex::from_snapshot(&bytes).unwrap();
    assert!(restored.is_quantized());

    let query: Vec<f32> = (0..128)
        .map(|j| if j % 2 == 0 { 1.0 } else { -1.0 })
        .collect();
    let hits = restored.search_quantized(&query, 5).unwrap();
    assert_eq!(hits.len(), 5);
}
#[test]
fn test_snapshot_different_metrics() {
    // Every distance metric round-trips through the snapshot format.
    let metrics = [
        DistanceMetric::Cosine,
        DistanceMetric::DotProduct,
        DistanceMetric::L2,
        DistanceMetric::Hamming,
    ];
    for metric in metrics {
        let mut idx = FlatIndex::new(FlatIndexConfig::new(32).with_metric(metric));
        idx.insert(&[0.5; 32]).unwrap();
        let bytes = idx.to_snapshot().unwrap();
        let restored = FlatIndex::from_snapshot(&bytes).unwrap();
        assert_eq!(restored.metric(), metric);
    }
}
#[test]
fn test_snapshot_invalid_magic() {
    // Corrupting bytes 4..8 of the snapshot (the magic region this test
    // targets) must make deserialization fail.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(4));
    idx.insert(&[1.0; 4]).unwrap();
    let mut bytes = idx.to_snapshot().unwrap();
    if bytes.len() > 8 {
        for b in &mut bytes[4..8] {
            *b = b'X';
        }
    }
    assert!(FlatIndex::from_snapshot(&bytes).is_err());
}
#[test]
fn test_snapshot_truncated() {
    // Restoring from half a snapshot must fail rather than misparse.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(16));
    idx.insert(&[1.0; 16]).unwrap();
    let bytes = idx.to_snapshot().unwrap();
    let half = &bytes[..bytes.len() / 2];
    assert!(FlatIndex::from_snapshot(half).is_err());
}
#[test]
fn test_snapshot_corrupted_checksum() {
    // Flipping a payload byte must be caught by checksum validation.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(16));
    idx.insert(&[1.0; 16]).unwrap();
    let mut bytes = idx.to_snapshot().unwrap();
    if bytes.len() > 100 {
        bytes[100] ^= 0xFF;
    }
    assert!(FlatIndex::from_snapshot(&bytes).is_err());
}
#[test]
fn test_search_after_restore() {
    // A restored index returns the same ranking and (near-)identical scores.
    let cfg = FlatIndexConfig::new(64).with_metric(DistanceMetric::Cosine);
    let mut idx = FlatIndex::new(cfg);
    idx.insert(&[1.0; 64]).unwrap();
    idx.insert(&[0.5; 64]).unwrap();
    idx.insert(&[0.0; 64]).unwrap();

    let bytes = idx.to_snapshot().unwrap();
    let restored = FlatIndex::from_snapshot(&bytes).unwrap();

    let query = [1.0; 64];
    let before = idx.search(&query, 3).unwrap();
    let after = restored.search(&query, 3).unwrap();
    assert_eq!(before.len(), after.len());
    for (a, b) in before.iter().zip(after.iter()) {
        assert_eq!(a.id, b.id);
        assert!((a.score - b.score).abs() < 1e-6);
    }
}
#[test]
fn test_snapshot_empty_index() {
    // An empty index round-trips: dimensions kept, still empty.
    let idx = FlatIndex::new(FlatIndexConfig::new(128));
    let bytes = idx.to_snapshot().unwrap();
    let restored = FlatIndex::from_snapshot(&bytes).unwrap();
    assert_eq!(restored.dimensions(), 128);
    assert!(restored.is_empty());
    assert_eq!(restored.len(), 0);
}
#[test]
fn test_snapshot_preserves_next_id() {
    // Ids are never reused (even after a delete), and the id counter
    // survives the snapshot round trip: the next insert gets id 3.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(8));
    let first = idx.insert(&[1.0; 8]).unwrap();
    let second = idx.insert(&[2.0; 8]).unwrap();
    idx.delete(first);
    let third = idx.insert(&[3.0; 8]).unwrap();
    assert_eq!(first, 0);
    assert_eq!(second, 1);
    assert_eq!(third, 2);

    let bytes = idx.to_snapshot().unwrap();
    let mut restored = FlatIndex::from_snapshot(&bytes).unwrap();
    let fourth = restored.insert(&[4.0; 8]).unwrap();
    assert_eq!(fourth, 3);
}
#[test]
fn test_snapshot_cleanup_threshold() {
    // A non-default cleanup threshold does not break the round trip.
    let idx = FlatIndex::new(FlatIndexConfig::new(8).with_cleanup_threshold(0.25));
    let bytes = idx.to_snapshot().unwrap();
    let restored = FlatIndex::from_snapshot(&bytes).unwrap();
    assert_eq!(restored.dimensions(), 8);
}
#[test]
fn test_snapshot_header_validation() {
    use crate::persistence::PersistenceError;

    // A well-formed header validates cleanly.
    let base = FlatIndexHeader {
        magic: FLAT_INDEX_MAGIC,
        version: FLAT_INDEX_VERSION,
        dimensions: 64,
        metric: DistanceMetric::Cosine,
        count: 10,
        delete_count: 0,
        next_id: 10,
        is_quantized: false,
        cleanup_threshold: 0.5,
        checksum: 0,
    };
    assert!(base.validate().is_ok());

    // Wrong magic bytes are rejected.
    let bad_magic = FlatIndexHeader {
        magic: [b'X', b'X', b'X', b'X'],
        ..base.clone()
    };
    assert!(matches!(
        bad_magic.validate(),
        Err(PersistenceError::InvalidMagic { .. })
    ));

    // A version newer than this build understands is rejected.
    let future_version = FlatIndexHeader {
        version: FLAT_INDEX_VERSION + 1,
        ..base.clone()
    };
    assert!(matches!(
        future_version.validate(),
        Err(PersistenceError::UnsupportedVersion(_, _))
    ));
}
#[test]
fn test_bq_vs_f32_recall_comparison() {
use std::collections::HashSet;
// Measures recall of binary-quantized search against exact f32 search on
// uniform random vectors. Random data is a worst case for BQ, so only a
// loose 20% floor (plus a better-than-random sanity check) is asserted.
let dim = 128;
let count = 500;
let k = 10;
let num_queries = 50;
let seed = 12345u64;
let config = FlatIndexConfig::new(dim).with_metric(DistanceMetric::Cosine);
let mut index = FlatIndex::new(config);
// Deterministic LCG (MMIX multiplier) producing f32 values in [-1, 1);
// keeps the test reproducible without an RNG dependency.
let mut state = seed;
let next_f32 = |s: &mut u64| -> f32 {
*s = s.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
((*s >> 33) as f32 / (u32::MAX >> 1) as f32) * 2.0 - 1.0
};
for _ in 0..count {
let v: Vec<f32> = (0..dim).map(|_| next_f32(&mut state)).collect();
index.insert(&v).unwrap();
}
index.enable_quantization().unwrap();
let mut total_recall = 0.0;
// Use evenly spaced stored vectors as queries; recall@k is the overlap
// between exact (f32) and quantized top-k result sets.
for q in 0..num_queries {
let query_id = q * (count / num_queries);
let query = index
.get(u64::try_from(query_id).unwrap())
.unwrap()
.to_vec();
let f32_results = index.search(&query, k).unwrap();
let f32_ids: HashSet<u64> = f32_results.iter().map(|r| r.id).collect();
let bq_results = index.search_quantized(&query, k).unwrap();
let bq_ids: HashSet<u64> = bq_results.iter().map(|r| r.id).collect();
let intersection = f32_ids.intersection(&bq_ids).count();
let recall = intersection as f32 / k as f32;
total_recall += recall;
}
let avg_recall = total_recall / num_queries as f32;
println!(
"BQ vs F32 Recall Comparison:\n\
- Dataset: {} vectors @ {}D (random uniform)\n\
- Queries: {}\n\
- k: {}\n\
- Average Recall@{}: {:.1}%\n\
- Note: Random data has lower recall; real embeddings achieve 70-90%",
count,
dim,
num_queries,
k,
k,
avg_recall * 100.0
);
// Absolute floor: 20% recall on random data.
assert!(
avg_recall >= 0.20,
"BQ recall too low: {:.1}% (minimum: 20% for random data)",
avg_recall * 100.0
);
// Relative floor: must beat random selection (k/count) by at least 5x.
let random_chance = k as f32 / count as f32;
assert!(
avg_recall > random_chance * 5.0,
"BQ recall ({:.1}%) not significantly better than random ({:.1}%)",
avg_recall * 100.0,
random_chance * 100.0
);
}
#[test]
fn test_id_overflow_protection() {
    // With the id counter pinned at u64::MAX, inserting must fail with
    // IdOverflow instead of wrapping.
    let mut idx = FlatIndex::new(FlatIndexConfig::new(3));
    idx.next_id = u64::MAX;
    assert!(matches!(
        idx.insert(&[1.0, 2.0, 3.0]),
        Err(FlatIndexError::IdOverflow)
    ));
}
}