use crate::quantization::variable::BinaryVector;
use bitvec::order::Lsb0;
use bitvec::vec::BitVec;
use std::fmt;
/// Failures reported by [`BinaryVectorStorage`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BinaryStorageError {
    /// The dimension passed to the constructor was rejected.
    ///
    /// `dimension` echoes the offending value.
    InvalidDimension { dimension: usize },

    /// A vector's dimension (in bits) differs from the storage's dimension.
    DimensionMismatch {
        /// Dimension the storage was created with.
        expected: usize,
        /// Dimension of the vector that was offered.
        actual: usize,
    },

    /// No slot has ever been allocated for `id`.
    NotFound { id: u64 },

    /// The slot for `id` exists but was deleted earlier.
    AlreadyDeleted { id: u64 },
}
impl fmt::Display for BinaryStorageError {
    /// Writes a human-readable, single-line description of the error.
    ///
    /// The `InvalidDimension` text names both constraints (non-zero and
    /// divisible by 8) because the variant is produced for either violation;
    /// mentioning only divisibility would be misleading for a dimension of 0.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidDimension { dimension } => write!(
                f,
                "dimension must be non-zero and divisible by 8, got {dimension}. \
                 Try using a dimension like 128, 384, 768, 1024, or 1536."
            ),
            Self::DimensionMismatch { expected, actual } => {
                write!(f, "dimension mismatch: expected {expected}, got {actual}")
            }
            Self::NotFound { id } => write!(f, "vector with id {id} not found"),
            Self::AlreadyDeleted { id } => {
                write!(f, "vector with id {id} is already deleted")
            }
        }
    }
}
// Empty impl: every `std::error::Error` method keeps its default, so this only
// marks `BinaryStorageError` as a standard error type.
impl std::error::Error for BinaryStorageError {}
/// Append-only store of fixed-width binary vectors kept in one flat byte buffer.
///
/// Deleting a vector only flips a flag; its bytes stay in `data`.
#[derive(Clone)]
pub struct BinaryVectorStorage {
    /// Vector payloads laid out back to back, `bytes_per_vector` bytes per slot.
    data: Vec<u8>,
    /// Width of every stored vector, in bits.
    dimension: usize,
    /// Width of every stored vector, in bytes (`dimension / 8`).
    bytes_per_vector: usize,
    /// One flag per slot; `true` marks the slot as deleted.
    deleted: BitVec<u8, Lsb0>,
    /// Number of slots ever appended, deleted ones included.
    count: usize,
    /// Identifier that the next insertion will return.
    next_id: u64,
}
impl BinaryVectorStorage {
pub fn new(dimension: usize) -> Result<Self, BinaryStorageError> {
if dimension == 0 || dimension % 8 != 0 {
return Err(BinaryStorageError::InvalidDimension { dimension });
}
Ok(Self {
data: Vec::new(),
dimension,
bytes_per_vector: dimension / 8,
deleted: BitVec::new(),
count: 0,
next_id: 0,
})
}
#[must_use]
#[inline]
pub fn dimension(&self) -> usize {
self.dimension
}
#[must_use]
#[inline]
pub fn bytes_per_vector(&self) -> usize {
self.bytes_per_vector
}
#[must_use]
#[inline]
pub fn len(&self) -> usize {
self.count
}
#[must_use]
#[inline]
pub fn is_empty(&self) -> bool {
self.count == 0
}
#[must_use]
#[inline]
pub fn memory_bytes(&self) -> usize {
self.data.len()
}
pub fn insert(&mut self, vector: &BinaryVector) -> Result<u64, BinaryStorageError> {
if vector.dimension() != self.dimension {
return Err(BinaryStorageError::DimensionMismatch {
expected: self.dimension,
actual: vector.dimension(),
});
}
let id = self.next_id;
self.next_id += 1;
self.data.extend_from_slice(vector.data());
self.deleted.push(false);
self.count += 1;
Ok(id)
}
pub fn insert_raw(&mut self, data: &[u8]) -> Result<u64, BinaryStorageError> {
if data.len() != self.bytes_per_vector {
return Err(BinaryStorageError::DimensionMismatch {
expected: self.dimension,
actual: data.len() * 8,
});
}
let id = self.next_id;
self.next_id += 1;
self.data.extend_from_slice(data);
self.deleted.push(false);
self.count += 1;
Ok(id)
}
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub fn get(&self, id: u64) -> Option<BinaryVector> {
let idx = id as usize;
if idx >= self.count || self.deleted[idx] {
return None;
}
let start = idx * self.bytes_per_vector;
let end = start + self.bytes_per_vector;
let data = self.data[start..end].to_vec();
BinaryVector::from_bytes(data, self.dimension).ok()
}
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub fn get_raw(&self, id: u64) -> Option<&[u8]> {
let idx = id as usize;
if idx >= self.count || self.deleted[idx] {
return None;
}
let start = idx * self.bytes_per_vector;
let end = start + self.bytes_per_vector;
Some(&self.data[start..end])
}
#[allow(clippy::cast_possible_truncation)]
pub fn delete(&mut self, id: u64) -> Result<(), BinaryStorageError> {
let idx = id as usize;
if idx >= self.count {
return Err(BinaryStorageError::NotFound { id });
}
if self.deleted[idx] {
return Err(BinaryStorageError::AlreadyDeleted { id });
}
self.deleted.set(idx, true);
Ok(())
}
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub fn is_deleted(&self, id: u64) -> bool {
let idx = id as usize;
idx < self.count && self.deleted[idx]
}
#[must_use]
pub fn live_count(&self) -> usize {
self.count - self.deleted_count()
}
#[must_use]
pub fn deleted_count(&self) -> usize {
self.deleted.count_ones()
}
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn compaction_ratio(&self) -> f64 {
if self.count == 0 {
1.0
} else {
self.live_count() as f64 / self.count as f64
}
}
pub fn iter_live(&self) -> impl Iterator<Item = (u64, &[u8])> {
let bytes_per_vector = self.bytes_per_vector;
let data = &self.data;
(0..self.count)
.filter(|&idx| !self.deleted[idx])
.map(move |idx| {
let start = idx * bytes_per_vector;
let end = start + bytes_per_vector;
(idx as u64, &data[start..end])
})
}
pub fn iter_all(&self) -> impl Iterator<Item = (u64, &[u8], bool)> {
let bytes_per_vector = self.bytes_per_vector;
let data = &self.data;
let deleted = &self.deleted;
(0..self.count).map(move |idx| {
let start = idx * bytes_per_vector;
let end = start + bytes_per_vector;
(idx as u64, &data[start..end], deleted[idx])
})
}
pub fn shrink_to_fit(&mut self) {
self.data.shrink_to_fit();
self.deleted.shrink_to_fit();
}
pub fn reserve(&mut self, additional: usize) {
self.data.reserve(additional * self.bytes_per_vector);
self.deleted.reserve(additional);
}
#[must_use]
pub fn capacity(&self) -> usize {
self.data.capacity() / self.bytes_per_vector
}
}
impl fmt::Debug for BinaryVectorStorage {
    /// Prints a compact summary of the storage instead of its raw contents.
    ///
    /// Only sizes and counters are shown; the trailing `..` produced by
    /// `finish_non_exhaustive` signals that fields were left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut summary = f.debug_struct("BinaryVectorStorage");
        summary.field("dimension", &self.dimension);
        summary.field("bytes_per_vector", &self.bytes_per_vector);
        summary.field("count", &self.count);
        summary.field("live_count", &self.live_count());
        summary.field("memory_bytes", &self.memory_bytes());
        summary.finish_non_exhaustive()
    }
}
#[cfg(test)]
#[allow(clippy::cast_precision_loss)]
mod tests {
    use super::*;

    /// Quantizes a vector of `dimension` components that all equal `value`.
    fn make_vector(dimension: usize, value: f32) -> BinaryVector {
        let components = vec![value; dimension];
        BinaryVector::quantize(&components).unwrap()
    }

    // ---- construction ----

    #[test]
    fn test_new_valid_dimensions() {
        for &dimension in &[128_usize, 384, 768, 1024, 1536] {
            assert!(BinaryVectorStorage::new(dimension).is_ok());
        }
    }

    #[test]
    fn test_new_invalid_dimension_not_divisible() {
        assert!(matches!(
            BinaryVectorStorage::new(100),
            Err(BinaryStorageError::InvalidDimension { dimension: 100 })
        ));
    }

    #[test]
    fn test_new_invalid_dimension_zero() {
        assert!(matches!(
            BinaryVectorStorage::new(0),
            Err(BinaryStorageError::InvalidDimension { dimension: 0 })
        ));
    }

    // ---- insertion ----

    #[test]
    fn test_insert_and_get() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        let original = make_vector(128, 1.0);

        let id = store.insert(&original).unwrap();
        assert_eq!(id, 0);
        assert_eq!(store.len(), 1);

        let fetched = store.get(id).unwrap();
        assert_eq!(fetched.dimension(), 128);
        assert_eq!(fetched.data(), original.data());
    }

    #[test]
    fn test_insert_multiple() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        for expected_id in 0..5 {
            let vector = make_vector(128, expected_id as f32);
            assert_eq!(store.insert(&vector).unwrap(), expected_id);
        }
        assert_eq!(store.len(), 5);
    }

    #[test]
    fn test_insert_dimension_mismatch() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        let oversized = make_vector(256, 1.0);
        assert_eq!(
            store.insert(&oversized),
            Err(BinaryStorageError::DimensionMismatch {
                expected: 128,
                actual: 256
            })
        );
    }

    #[test]
    fn test_insert_raw() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        let bytes = vec![0xFF; 16];

        let id = store.insert_raw(&bytes).unwrap();
        assert_eq!(id, 0);
        assert_eq!(store.get_raw(id).unwrap(), &bytes[..]);
    }

    #[test]
    fn test_insert_raw_dimension_mismatch() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        let too_long = vec![0xFF; 32];
        assert!(matches!(
            store.insert_raw(&too_long),
            Err(BinaryStorageError::DimensionMismatch { .. })
        ));
    }

    // ---- lookup ----

    #[test]
    fn test_get_nonexistent() {
        let store = BinaryVectorStorage::new(128).unwrap();
        for &missing in &[0_u64, 100] {
            assert!(store.get(missing).is_none());
        }
    }

    #[test]
    fn test_get_raw() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        let vector = make_vector(128, 1.0);
        let id = store.insert(&vector).unwrap();

        let bytes = store.get_raw(id).unwrap();
        assert_eq!(bytes.len(), 16);
        assert_eq!(bytes, vector.data());
    }

    #[test]
    fn test_get_raw_nonexistent() {
        let store = BinaryVectorStorage::new(128).unwrap();
        assert!(store.get_raw(0).is_none());
    }

    // ---- deletion ----

    #[test]
    fn test_delete() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        let id = store.insert(&make_vector(128, 1.0)).unwrap();
        assert!(store.get(id).is_some());

        store.delete(id).unwrap();
        assert!(store.get(id).is_none());
        assert!(store.is_deleted(id));
    }

    #[test]
    fn test_delete_not_found() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        assert_eq!(
            store.delete(0),
            Err(BinaryStorageError::NotFound { id: 0 })
        );
    }

    #[test]
    fn test_delete_already_deleted() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        let id = store.insert(&make_vector(128, 1.0)).unwrap();
        store.delete(id).unwrap();
        assert_eq!(
            store.delete(id),
            Err(BinaryStorageError::AlreadyDeleted { id: 0 })
        );
    }

    #[test]
    fn test_is_deleted() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        let id = store.insert(&make_vector(128, 1.0)).unwrap();
        assert!(!store.is_deleted(id));

        store.delete(id).unwrap();
        assert!(store.is_deleted(id));
    }

    #[test]
    fn test_is_deleted_out_of_bounds() {
        let store = BinaryVectorStorage::new(128).unwrap();
        assert!(!store.is_deleted(100));
    }

    // ---- counters ----

    #[test]
    fn test_live_count() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        for i in 0..5 {
            store.insert(&make_vector(128, i as f32)).unwrap();
        }
        assert_eq!(store.len(), 5);
        assert_eq!(store.live_count(), 5);
        assert_eq!(store.deleted_count(), 0);

        for &victim in &[1_u64, 3] {
            store.delete(victim).unwrap();
        }
        assert_eq!(store.len(), 5);
        assert_eq!(store.live_count(), 3);
        assert_eq!(store.deleted_count(), 2);
    }

    #[test]
    fn test_compaction_ratio() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        assert!((store.compaction_ratio() - 1.0).abs() < 0.001);

        for i in 0..10 {
            store.insert(&make_vector(128, i as f32)).unwrap();
        }
        assert!((store.compaction_ratio() - 1.0).abs() < 0.001);

        store.delete(0).unwrap();
        store.delete(5).unwrap();
        assert!((store.compaction_ratio() - 0.8).abs() < 0.001);
    }

    // ---- iteration ----

    #[test]
    fn test_iter_live() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        for i in 0..5 {
            store.insert(&make_vector(128, i as f32)).unwrap();
        }
        store.delete(1).unwrap();
        store.delete(3).unwrap();

        let surviving: Vec<_> = store.iter_live().map(|(id, _)| id).collect();
        assert_eq!(surviving, vec![0, 2, 4]);
    }

    #[test]
    fn test_iter_all() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        for i in 0..3 {
            store.insert(&make_vector(128, i as f32)).unwrap();
        }
        store.delete(1).unwrap();

        let flags: Vec<bool> = store
            .iter_all()
            .map(|(_, _, is_deleted)| is_deleted)
            .collect();
        assert_eq!(flags, vec![false, true, false]);
    }

    // ---- sizing ----

    #[test]
    fn test_memory_bytes() {
        let mut store = BinaryVectorStorage::new(768).unwrap();
        assert_eq!(store.memory_bytes(), 0);

        let vector = make_vector(768, 1.0);
        store.insert(&vector).unwrap();
        assert_eq!(store.memory_bytes(), 96);

        store.insert(&vector).unwrap();
        assert_eq!(store.memory_bytes(), 192);
    }

    #[test]
    fn test_bytes_per_vector() {
        for &(dimension, expected_bytes) in &[(768_usize, 96_usize), (128, 16), (1536, 192)] {
            let store = BinaryVectorStorage::new(dimension).unwrap();
            assert_eq!(store.bytes_per_vector(), expected_bytes);
        }
    }

    #[test]
    fn test_reserve_and_capacity() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        assert_eq!(store.capacity(), 0);

        store.reserve(100);
        assert!(store.capacity() >= 100);
    }

    #[test]
    fn test_shrink_to_fit() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        store.reserve(1000);
        store.insert(&make_vector(128, 1.0)).unwrap();

        let before = store.capacity();
        store.shrink_to_fit();
        let after = store.capacity();
        assert!(after <= before);
    }

    // ---- trait impls ----

    #[test]
    fn test_debug_format() {
        let store = BinaryVectorStorage::new(768).unwrap();
        let rendered = format!("{store:?}");
        for &fragment in &[
            "BinaryVectorStorage",
            "dimension: 768",
            "bytes_per_vector: 96",
            "..",
        ] {
            assert!(rendered.contains(fragment));
        }
    }

    #[test]
    fn test_error_display() {
        let invalid = BinaryStorageError::InvalidDimension { dimension: 100 }.to_string();
        assert!(invalid.contains("100"));
        assert!(invalid.contains("divisible by 8"));

        let mismatch = BinaryStorageError::DimensionMismatch {
            expected: 128,
            actual: 256,
        }
        .to_string();
        assert!(mismatch.contains("128"));
        assert!(mismatch.contains("256"));

        let not_found = BinaryStorageError::NotFound { id: 42 }.to_string();
        assert!(not_found.contains("42"));

        let already_deleted = BinaryStorageError::AlreadyDeleted { id: 42 }.to_string();
        assert!(already_deleted.contains("42"));
    }

    #[test]
    fn test_clone() {
        let mut store = BinaryVectorStorage::new(128).unwrap();
        store.insert(&make_vector(128, 1.0)).unwrap();

        let copy = store.clone();
        assert_eq!(copy.len(), store.len());
        assert_eq!(copy.dimension(), store.dimension());
        assert_eq!(copy.get_raw(0), store.get_raw(0));
    }
}