use crate::hnsw::VectorId;
use crate::metric::{Hamming, Metric};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Errors returned by [`BinaryFlatIndex`] operations.
#[derive(Debug, Clone, Error)]
pub enum BinaryFlatIndexError {
    /// Requested dimensionality is not a multiple of 8, so the vectors
    /// cannot be bit-packed into whole bytes.
    #[error("dimensions must be divisible by 8, got {0}")]
    InvalidDimensions(usize),
    /// An inserted vector or query slice had a byte length different from
    /// `dimensions / 8`.
    #[error("vector length {actual} doesn't match expected {expected}")]
    DimensionMismatch {
        expected: usize,
        actual: usize,
    },
    /// `capacity * bytes_per_vector` overflowed `usize` during preallocation.
    #[error("capacity overflow: {0} * {1} exceeds usize::MAX")]
    CapacityOverflow(usize, usize),
}
/// A single nearest-neighbour hit returned by [`BinaryFlatIndex::search`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct BinaryFlatSearchResult {
    /// Identifier assigned at insertion time (ids are 1-based).
    pub id: VectorId,
    /// Hamming distance between the query and the stored vector.
    pub distance: f32,
}
/// When `k < count / PARTIAL_SORT_THRESHOLD`, `search` uses a partial
/// selection (`select_nth_unstable_by`) followed by sorting only the top-k,
/// instead of a full sort of all candidates.
const PARTIAL_SORT_THRESHOLD: usize = 10;
/// Fixed header size in bytes assumed by `serialized_size`.
const SERIALIZATION_HEADER_SIZE: usize = 8;
/// A flat (brute-force) index over bit-packed binary vectors, compared with
/// Hamming distance.
#[derive(Debug, Serialize, Deserialize)]
pub struct BinaryFlatIndex {
    // All vectors concatenated back to back; invariant:
    // vectors.len() == count * bytes_per_vector.
    vectors: Vec<u8>,
    // Number of bits per vector; always a multiple of 8.
    dimensions: usize,
    // Cached dimensions / 8.
    bytes_per_vector: usize,
    // Number of vectors currently stored.
    count: usize,
}
impl BinaryFlatIndex {
#[must_use = "constructors return a Result that should be used"]
pub fn new(dimensions: usize) -> Result<Self, BinaryFlatIndexError> {
if dimensions % 8 != 0 {
return Err(BinaryFlatIndexError::InvalidDimensions(dimensions));
}
Ok(Self {
vectors: Vec::new(),
dimensions,
bytes_per_vector: dimensions / 8,
count: 0,
})
}
/// Creates an empty index preallocated for `capacity` vectors of
/// `dimensions` bits each.
///
/// # Errors
///
/// Returns [`BinaryFlatIndexError::InvalidDimensions`] if `dimensions` is
/// not a multiple of 8, or [`BinaryFlatIndexError::CapacityOverflow`] if
/// `capacity * (dimensions / 8)` does not fit in `usize`.
#[must_use = "constructors return a Result that should be used"]
pub fn with_capacity(dimensions: usize, capacity: usize) -> Result<Self, BinaryFlatIndexError> {
    if dimensions % 8 != 0 {
        return Err(BinaryFlatIndexError::InvalidDimensions(dimensions));
    }
    let bytes_per_vector = dimensions / 8;
    // Size the backing buffer up front, refusing requests whose total byte
    // count would overflow `usize`.
    let total_bytes = match capacity.checked_mul(bytes_per_vector) {
        Some(n) => n,
        None => {
            return Err(BinaryFlatIndexError::CapacityOverflow(
                capacity,
                bytes_per_vector,
            ));
        }
    };
    Ok(Self {
        vectors: Vec::with_capacity(total_bytes),
        dimensions,
        bytes_per_vector,
        count: 0,
    })
}
/// Appends a bit-packed vector and returns its assigned id.
///
/// Ids are 1-based and increase monotonically in insertion order.
///
/// # Errors
///
/// Returns [`BinaryFlatIndexError::DimensionMismatch`] if `vector` is not
/// exactly `bytes_per_vector` bytes long.
#[inline]
#[must_use = "insert returns the assigned VectorId"]
pub fn insert(&mut self, vector: &[u8]) -> Result<VectorId, BinaryFlatIndexError> {
    let expected = self.bytes_per_vector;
    if vector.len() != expected {
        return Err(BinaryFlatIndexError::DimensionMismatch {
            expected,
            actual: vector.len(),
        });
    }
    self.vectors.extend_from_slice(vector);
    self.count += 1;
    // Ids start at 1, so the post-increment count doubles as the new id.
    Ok(VectorId(self.count as u64))
}
/// Returns the `k` nearest stored vectors to `query` by Hamming distance,
/// sorted ascending by distance.
///
/// Returns an empty result set when the index is empty or `k == 0`; `k` is
/// clamped to the number of stored vectors.
///
/// # Errors
///
/// Returns [`BinaryFlatIndexError::DimensionMismatch`] if `query` is not
/// exactly `bytes_per_vector` bytes long.
#[must_use = "search returns the nearest neighbors"]
pub fn search(
    &self,
    query: &[u8],
    k: usize,
) -> Result<Vec<BinaryFlatSearchResult>, BinaryFlatIndexError> {
    if query.len() != self.bytes_per_vector {
        return Err(BinaryFlatIndexError::DimensionMismatch {
            expected: self.bytes_per_vector,
            actual: query.len(),
        });
    }
    if self.count == 0 || k == 0 {
        return Ok(Vec::new());
    }
    let k = k.min(self.count);
    // `chunks_exact` walks the packed storage one stored vector at a time
    // without per-access bounds checks (vectors.len() == count * bytes_per_vector).
    let mut results: Vec<(VectorId, f32)> = self
        .vectors
        .chunks_exact(self.bytes_per_vector)
        .enumerate()
        .map(|(i, stored)| (VectorId((i + 1) as u64), Hamming::distance(query, stored)))
        .collect();
    if k < self.count / PARTIAL_SORT_THRESHOLD {
        // Small k relative to count: O(n) selection of the k smallest, then
        // sort only those k entries.
        results.select_nth_unstable_by(k - 1, |a, b| a.1.total_cmp(&b.1));
        results.truncate(k);
        results.sort_by(|a, b| a.1.total_cmp(&b.1));
    } else {
        results.sort_by(|a, b| a.1.total_cmp(&b.1));
        results.truncate(k);
    }
    Ok(results
        .into_iter()
        .map(|(id, distance)| BinaryFlatSearchResult { id, distance })
        .collect())
}
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub fn get(&self, id: VectorId) -> Option<&[u8]> {
let idx = (id.0 as usize).checked_sub(1)?;
if idx >= self.count {
return None;
}
let start = idx * self.bytes_per_vector;
let end = start + self.bytes_per_vector;
Some(&self.vectors[start..end])
}
/// Number of vectors currently stored.
#[must_use]
#[inline]
pub fn len(&self) -> usize {
    self.count
}
/// `true` if no vectors are stored.
#[must_use]
#[inline]
pub fn is_empty(&self) -> bool {
    self.count == 0
}
/// Number of bits per vector (always a multiple of 8).
#[must_use]
#[inline]
pub fn dimensions(&self) -> usize {
    self.dimensions
}
/// Number of bytes each packed vector occupies (`dimensions / 8`).
#[must_use]
#[inline]
pub fn bytes_per_vector(&self) -> usize {
    self.bytes_per_vector
}
/// Approximate memory footprint in bytes: the struct itself plus the
/// backing buffer's *capacity* (allocated bytes, not just bytes in use).
#[must_use]
pub fn memory_usage(&self) -> usize {
    std::mem::size_of::<Self>() + self.vectors.capacity()
}
/// Total bytes of vector data in use; equals `len() * bytes_per_vector()`.
#[inline]
#[must_use]
pub fn vectors_len(&self) -> usize {
    self.vectors.len()
}
/// Size in bytes this index would occupy in a flat serialization: a fixed
/// header followed by the raw vector bytes.
#[must_use]
pub fn serialized_size(&self) -> usize {
    SERIALIZATION_HEADER_SIZE + self.vectors.len()
}
/// Removes all vectors, keeping the allocated capacity for reuse.
pub fn clear(&mut self) {
    self.vectors.clear();
    self.count = 0;
}
/// Releases excess capacity in the backing buffer.
pub fn shrink_to_fit(&mut self) {
    self.vectors.shrink_to_fit();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor records dimensions and derives bytes_per_vector = dims / 8.
    #[test]
    fn test_new() {
        let index = BinaryFlatIndex::new(1024).unwrap();
        assert_eq!(index.dimensions(), 1024);
        assert_eq!(index.bytes_per_vector(), 128);
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    // Ids are assigned sequentially starting at 1 and round-trip through get().
    #[test]
    fn test_insert_and_get() {
        let mut index = BinaryFlatIndex::new(64).unwrap();
        let v1 = vec![0xFF; 8];
        let v2 = vec![0x00; 8];
        let id1 = index.insert(&v1).unwrap();
        let id2 = index.insert(&v2).unwrap();
        assert_eq!(id1, VectorId(1));
        assert_eq!(id2, VectorId(2));
        assert_eq!(index.len(), 2);
        assert_eq!(index.get(id1), Some(v1.as_slice()));
        assert_eq!(index.get(id2), Some(v2.as_slice()));
        assert_eq!(index.get(VectorId(99)), None);
    }

    // An exact match has Hamming distance 0 and ranks first.
    #[test]
    fn test_search_exact_match() {
        let mut index = BinaryFlatIndex::new(64).unwrap();
        let v1 = vec![0xFF; 8];
        let v2 = vec![0x00; 8];
        let v3 = vec![0xAA; 8];
        index.insert(&v1).unwrap();
        index.insert(&v2).unwrap();
        index.insert(&v3).unwrap();
        let results = index.search(&v2, 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, VectorId(2));
        assert!((results[0].distance - 0.0).abs() < f32::EPSILON);
    }

    // Results come back sorted ascending by Hamming distance (8, 32, 64 bits).
    #[test]
    fn test_search_ordering() {
        let mut index = BinaryFlatIndex::new(64).unwrap();
        let query = vec![0x00; 8];
        let v1 = vec![0xFF; 8];
        let v2 = vec![0x0F; 8];
        let v3 = vec![0x01; 8];
        index.insert(&v1).unwrap();
        index.insert(&v2).unwrap();
        index.insert(&v3).unwrap();
        let results = index.search(&query, 3).unwrap();
        assert_eq!(results[0].id, VectorId(3));
        assert!((results[0].distance - 8.0).abs() < f32::EPSILON);
        assert_eq!(results[1].id, VectorId(2));
        assert!((results[1].distance - 32.0).abs() < f32::EPSILON);
        assert_eq!(results[2].id, VectorId(1));
        assert!((results[2].distance - 64.0).abs() < f32::EPSILON);
    }

    // 100 vectors with k = 5 exercises the partial-sort path (k < count / 10)
    // and must still return exactly k results in ascending order.
    #[test]
    fn test_search_k_limit() {
        let mut index = BinaryFlatIndex::new(64).unwrap();
        for i in 0..100 {
            let v: Vec<u8> = (0..8)
                .map(|j| u8::try_from((i + j) % 256).unwrap())
                .collect();
            index.insert(&v).unwrap();
        }
        let query = vec![0x00; 8];
        let results = index.search(&query, 5).unwrap();
        assert_eq!(results.len(), 5);
        for i in 1..results.len() {
            assert!(results[i - 1].distance <= results[i].distance);
        }
    }

    // Searching an empty index yields no results rather than an error.
    #[test]
    fn test_empty_search() {
        let index = BinaryFlatIndex::new(64).unwrap();
        let query = vec![0x00; 8];
        let results = index.search(&query, 10).unwrap();
        assert!(results.is_empty());
    }

    // k == 0 short-circuits to an empty result set.
    #[test]
    fn test_search_k_zero() {
        let mut index = BinaryFlatIndex::new(64).unwrap();
        for _ in 0..20 {
            index.insert(&[0xFF; 8]).unwrap();
        }
        let results = index.search(&[0x00; 8], 0).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_clear() {
        let mut index = BinaryFlatIndex::new(64).unwrap();
        index.insert(&[0xFF; 8]).unwrap();
        index.insert(&[0x00; 8]).unwrap();
        assert_eq!(index.len(), 2);
        index.clear();
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    // memory_usage counts buffer capacity, so 100 inserts of 128 bytes imply
    // at least 12_800 bytes reported.
    #[test]
    fn test_memory_usage() {
        let mut index = BinaryFlatIndex::with_capacity(1024, 1000).unwrap();
        assert!(index.memory_usage() > 0);
        for _ in 0..100 {
            index.insert(&[0xAA; 128]).unwrap();
        }
        let usage = index.memory_usage();
        assert!(usage >= 12_800);
    }

    // 100 bits cannot be packed into whole bytes.
    #[test]
    fn test_invalid_dimensions() {
        let result = BinaryFlatIndex::new(100);
        assert!(matches!(
            result,
            Err(BinaryFlatIndexError::InvalidDimensions(100))
        ));
    }

    // A 64-bit index expects 8-byte vectors; 16 bytes must be rejected.
    #[test]
    fn test_invalid_vector_length() {
        let mut index = BinaryFlatIndex::new(64).unwrap();
        let result = index.insert(&[0xFF; 16]);
        assert!(matches!(
            result,
            Err(BinaryFlatIndexError::DimensionMismatch {
                expected: 8,
                actual: 16
            })
        ));
    }

    // Queries are validated the same way as inserted vectors.
    #[test]
    fn test_invalid_query_length() {
        let mut index = BinaryFlatIndex::new(64).unwrap();
        index.insert(&[0xFF; 8]).unwrap();
        let result = index.search(&[0x00; 16], 1);
        assert!(matches!(
            result,
            Err(BinaryFlatIndexError::DimensionMismatch {
                expected: 8,
                actual: 16
            })
        ));
    }

    // vectors_len() tracks len() * bytes_per_vector() through inserts/clear.
    #[test]
    fn test_vectors_len() {
        let mut index = BinaryFlatIndex::new(64).unwrap();
        assert_eq!(index.vectors_len(), 0);
        index.insert(&[0xAA; 8]).unwrap();
        assert_eq!(index.vectors_len(), 8);
        assert_eq!(index.vectors_len(), index.len() * index.bytes_per_vector());
        index.insert(&[0xBB; 8]).unwrap();
        assert_eq!(index.vectors_len(), 16);
        assert_eq!(index.vectors_len(), index.len() * index.bytes_per_vector());
        index.insert(&[0xCC; 8]).unwrap();
        assert_eq!(index.vectors_len(), 24);
        assert_eq!(index.vectors_len(), index.len() * index.bytes_per_vector());
        index.clear();
        assert_eq!(index.vectors_len(), 0);
        assert_eq!(index.len(), 0);
        assert_eq!(index.vectors_len(), index.len() * index.bytes_per_vector());
    }

    // serialized_size is the 8-byte header plus the bytes in use.
    #[test]
    fn test_serialized_size() {
        let mut index = BinaryFlatIndex::new(64).unwrap();
        assert_eq!(index.serialized_size(), 8);
        index.insert(&[0xAA; 8]).unwrap();
        assert_eq!(index.serialized_size(), 8 + 8);
        index.insert(&[0xBB; 8]).unwrap();
        assert_eq!(index.serialized_size(), 8 + 16);
        assert_eq!(
            index.serialized_size(),
            SERIALIZATION_HEADER_SIZE + index.vectors_len()
        );
        index.clear();
        assert_eq!(index.serialized_size(), SERIALIZATION_HEADER_SIZE);
    }

    // Shrinking must not lose data; stored vectors stay readable afterwards.
    #[test]
    fn test_shrink_to_fit() {
        let mut index = BinaryFlatIndex::with_capacity(1024, 10_000).unwrap();
        assert!(index.memory_usage() > 100_000);
        for _ in 0..10 {
            index.insert(&[0xAA; 128]).unwrap();
        }
        let before_shrink = index.memory_usage();
        index.shrink_to_fit();
        let after_shrink = index.memory_usage();
        assert!(after_shrink <= before_shrink);
        assert_eq!(index.len(), 10);
        assert_eq!(index.vectors_len(), 10 * 128);
        for i in 1..=10 {
            assert_eq!(index.get(VectorId(i as u64)), Some([0xAA; 128].as_slice()));
        }
        let mut empty = BinaryFlatIndex::new(64).unwrap();
        empty.shrink_to_fit();
        assert_eq!(empty.len(), 0);
    }

    // JSON round-trip preserves metadata and raw vector bytes.
    #[test]
    fn test_serde_roundtrip() {
        let mut index = BinaryFlatIndex::new(64).unwrap();
        index.insert(&[0xFF; 8]).unwrap();
        index.insert(&[0x00; 8]).unwrap();
        index.insert(&[0xAA; 8]).unwrap();
        let json = serde_json::to_string(&index).expect("serialize failed");
        let restored: BinaryFlatIndex = serde_json::from_str(&json).expect("deserialize failed");
        assert_eq!(restored.dimensions(), index.dimensions());
        assert_eq!(restored.bytes_per_vector(), index.bytes_per_vector());
        assert_eq!(restored.len(), index.len());
        assert_eq!(restored.vectors_len(), index.vectors_len());
        for i in 1..=3u64 {
            assert_eq!(
                restored.get(VectorId(i)),
                index.get(VectorId(i)),
                "vector {} mismatch after roundtrip",
                i
            );
        }
    }
}