use std::collections::HashMap;
use turbovec::TurboQuantIndex;
use crate::distance::Distance;
use crate::index::{IndexError, NodeId, SearchResult};
/// A vector table backed by a `turbovec` [`TurboQuantIndex`], with id
/// bookkeeping layered on top so entries can be addressed by [`NodeId`] and
/// soft-deleted.
pub struct TurboTable {
/// Bit width handed to `TurboQuantIndex::new`; the constructor accepts `2..=4`.
bits: u8,
/// Metric that decides whether vectors are L2-normalised on the way in and
/// how raw index scores are converted into result scores.
distance: Distance,
/// Vector dimensionality; the constructor requires a non-zero multiple of 8.
dim: u16,
/// The underlying quantised index that stores the vectors.
index: TurboQuantIndex,
/// One entry per vector ever added, in insertion order: `Some(id)` while the
/// entry is live, `None` once deleted. Search maps the index's returned
/// positions through this vector (presumably position `i` is the `i`-th
/// vector added to `index` — confirm against turbovec).
slots: Vec<Option<NodeId>>,
/// Reverse lookup from a live id to its position in `slots`.
id_to_slot: HashMap<NodeId, usize>,
}
impl TurboTable {
    /// Creates an empty table for `dim`-dimensional vectors quantised with
    /// `bits` bits.
    ///
    /// # Errors
    ///
    /// * [`IndexError::Empty`] if `dim` is zero, if `bits` is outside `2..=4`,
    ///   or if the underlying index cannot be constructed.
    /// * [`IndexError::DimensionMismatch`] if `dim` is not a multiple of 8;
    ///   `expected` carries a nearby valid dimension as a hint.
    pub fn new(distance: Distance, dim: u16, bits: u8) -> Result<Self, IndexError> {
        if dim == 0 {
            return Err(IndexError::Empty);
        }
        if !dim.is_multiple_of(8) {
            // Hint at the next multiple of 8. For the last few `u16` values
            // rounding up does not fit, so fall back to the multiple below
            // instead of overflowing.
            let expected = dim.checked_next_multiple_of(8).unwrap_or(dim - dim % 8);
            return Err(IndexError::DimensionMismatch { expected, got: dim });
        }
        if !(2..=4).contains(&bits) {
            return Err(IndexError::Empty);
        }
        let index = TurboQuantIndex::new(usize::from(dim), usize::from(bits))
            .map_err(|_| IndexError::Empty)?;
        Ok(Self {
            bits,
            distance,
            dim,
            index,
            slots: Vec::new(),
            id_to_slot: HashMap::new(),
        })
    }

    /// Number of live (inserted and not yet deleted) entries.
    #[must_use]
    pub fn len(&self) -> usize {
        // `id_to_slot` holds exactly one entry per `Some` slot: `insert` adds
        // to both and `delete` clears both, so this equals the count of live
        // slots without scanning them.
        self.id_to_slot.len()
    }

    /// Returns `true` when the table holds no live entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.id_to_slot.is_empty()
    }

    /// Dimensionality every stored vector and query must have.
    #[must_use]
    pub fn dim(&self) -> u16 {
        self.dim
    }

    /// Distance metric this table was created with.
    #[must_use]
    pub fn distance(&self) -> Distance {
        self.distance
    }

    /// Quantisation bit width this table was created with.
    #[must_use]
    pub fn bits(&self) -> u8 {
        self.bits
    }

    /// Adds `vector` under `id`.
    ///
    /// For `Cosine` and `Euclidean` tables the vector is L2-normalised before
    /// it is handed to the index; for `DotProduct` it is stored as given.
    ///
    /// # Errors
    ///
    /// * [`IndexError::Empty`] if `vector` is empty or the underlying index
    ///   rejects it.
    /// * [`IndexError::DimensionMismatch`] if `vector.len()` differs from
    ///   [`Self::dim`].
    /// * [`IndexError::Duplicate`] if `id` is already live in the table.
    pub fn insert(&mut self, id: NodeId, vector: Vec<f32>) -> Result<(), IndexError> {
        if vector.is_empty() {
            return Err(IndexError::Empty);
        }
        // Lengths beyond `u16` saturate to `u16::MAX`, which can never equal
        // `self.dim` (a multiple of 8), so they are reported as a mismatch.
        let got = u16::try_from(vector.len()).unwrap_or(u16::MAX);
        if got != self.dim {
            return Err(IndexError::DimensionMismatch {
                expected: self.dim,
                got,
            });
        }
        if self.id_to_slot.contains_key(&id) {
            return Err(IndexError::Duplicate(id));
        }
        let prepared = match self.distance {
            Distance::Cosine | Distance::Euclidean => l2_normalise(&vector),
            Distance::DotProduct => vector,
        };
        self.index
            .add_2d(&prepared, usize::from(self.dim))
            .map_err(|_| IndexError::Empty)?;
        // Record the id only after the index accepted the vector, so a failed
        // add leaves the bookkeeping untouched.
        let slot = self.slots.len();
        self.slots.push(Some(id));
        self.id_to_slot.insert(id, slot);
        Ok(())
    }

    /// Soft-deletes `id`, returning `true` if it was present.
    ///
    /// The slot is cleared to `None` so later searches mask it out; this
    /// method does not touch the underlying index.
    pub fn delete(&mut self, id: NodeId) -> bool {
        let Some(slot) = self.id_to_slot.remove(&id) else {
            return false;
        };
        if let Some(entry) = self.slots.get_mut(slot) {
            *entry = None;
        }
        true
    }

    /// Returns `true` if `id` is currently live in the table.
    #[must_use]
    pub fn contains(&self, id: NodeId) -> bool {
        self.id_to_slot.contains_key(&id)
    }

    /// Returns up to `k` live entries closest to `query`, best (lowest score)
    /// first.
    ///
    /// The raw value reported by the index (treated here as a similarity —
    /// confirm against turbovec) is converted so that smaller is better:
    /// `DotProduct` negates it, `Cosine` uses `1 - s`, and `Euclidean` uses
    /// `sqrt(max(0, 2 - 2s))`. Since Euclidean vectors and queries are both
    /// L2-normalised, that last form is the distance between unit vectors.
    ///
    /// `_ef` is ignored by this implementation.
    ///
    /// An empty `query` or a table that has never had an insert yields an
    /// empty result without a dimension check. `k == 0` and a table whose
    /// entries are all deleted also yield an empty result.
    ///
    /// # Errors
    ///
    /// [`IndexError::DimensionMismatch`] if a non-empty `query` has a length
    /// other than [`Self::dim`].
    pub fn search(
        &self,
        query: &[f32],
        k: usize,
        _ef: Option<usize>,
    ) -> Result<Vec<SearchResult>, IndexError> {
        if query.is_empty() || self.slots.is_empty() {
            return Ok(Vec::new());
        }
        let got = u16::try_from(query.len()).unwrap_or(u16::MAX);
        if got != self.dim {
            return Err(IndexError::DimensionMismatch {
                expected: self.dim,
                got,
            });
        }
        // Nothing can be returned, so skip the index call entirely.
        if k == 0 || self.id_to_slot.is_empty() {
            return Ok(Vec::new());
        }
        let prepared = match self.distance {
            Distance::Cosine | Distance::Euclidean => l2_normalise(query),
            Distance::DotProduct => query.to_vec(),
        };
        // `true` marks a slot that is still live; deleted slots are masked out.
        let mask: Vec<bool> = self.slots.iter().map(Option::is_some).collect();
        let res = self.index.search_with_mask(&prepared, k, Some(&mask));

        let distance = self.distance;
        let mut out: Vec<SearchResult> = res
            .indices
            .iter()
            .zip(res.scores.iter())
            .take(res.k)
            .filter_map(|(&raw_idx, &similarity)| {
                // Negative positions fail the conversion and are skipped, as
                // are positions that are out of range or point at a deleted
                // slot.
                let slot = usize::try_from(raw_idx).ok()?;
                let id = (*self.slots.get(slot)?)?;
                let score = match distance {
                    Distance::DotProduct => -similarity,
                    Distance::Cosine => 1.0 - similarity,
                    Distance::Euclidean => (2.0 - 2.0 * similarity).max(0.0).sqrt(),
                };
                Some(SearchResult { id, score })
            })
            .collect();

        // `total_cmp` is a total order even when a score is NaN, which
        // `sort_by` requires of its comparator.
        out.sort_by(|a, b| a.score.total_cmp(&b.score));
        out.truncate(k);
        Ok(out)
    }
}
/// Returns a copy of `v` scaled to unit Euclidean length.
///
/// Components are divided by the largest magnitude before being squared, so
/// the intermediate sum neither overflows for very large inputs nor
/// underflows to zero for very small ones.
///
/// Vectors without a usable norm — empty, all-zero, or containing NaN or
/// infinite components — are returned unchanged.
fn l2_normalise(v: &[f32]) -> Vec<f32> {
    // `f32::max` ignores NaN operands, so `scale` is the largest finite or
    // infinite magnitude present, or 0.0 for an empty or all-zero slice.
    let scale = v.iter().fold(0.0_f32, |acc, x| acc.max(x.abs()));
    if scale <= 0.0 || scale.is_infinite() {
        return v.to_vec();
    }
    // Every scaled component lies in [-1, 1], so the squares cannot overflow.
    let norm = v
        .iter()
        .map(|x| {
            let scaled = x / scale;
            scaled * scaled
        })
        .sum::<f32>()
        .sqrt();
    // Only a NaN component can make the norm non-finite at this point.
    if !norm.is_finite() {
        return v.to_vec();
    }
    v.iter().map(|x| x / scale / norm).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in `[-1, 1)`,
    /// produced by a xorshift generator started from `seed`.
    #[allow(
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        reason = "test fixture: PRNG narrowed to f32"
    )]
    fn rand_vec(seed: u64, dim: usize) -> Vec<f32> {
        let mut state = seed;
        (0..dim)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                // Keep 53 bits so the value maps exactly onto an f64 fraction.
                let mantissa = (state >> 11) & ((1_u64 << 53) - 1);
                let unit = (mantissa as f64) / ((1_u64 << 53) as f64);
                (unit * 2.0 - 1.0) as f32
            })
            .collect()
    }

    /// Fresh 64-dimensional, 4-bit cosine table shared by every test.
    fn cosine_table() -> TurboTable {
        TurboTable::new(Distance::Cosine, 64, 4).unwrap()
    }

    /// Querying with a stored vector should rank that vector first.
    #[test]
    fn insert_and_search_returns_self_first() {
        let mut table = cosine_table();
        let needle = rand_vec(42, 64);
        table.insert(0, needle.clone()).unwrap();
        for id in 1..50_u64 {
            let filler = rand_vec(id.wrapping_mul(1_000_003) + 1, 64);
            table.insert(id, filler).unwrap();
        }

        let hits = table.search(&needle, 3, None).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, 0);
    }

    /// A deleted id must never show up in later results.
    #[test]
    fn delete_excludes_from_search() {
        let mut table = cosine_table();
        for id in 0..30_u64 {
            table.insert(id, rand_vec(id + 1, 64)).unwrap();
        }
        let query = rand_vec(1, 64);

        let removed = table.search(&query, 5, None).unwrap()[0].id;
        assert!(table.delete(removed));

        let survivors = table.search(&query, 5, None).unwrap();
        assert!(!survivors.iter().any(|hit| hit.id == removed));
    }

    /// Searching a table with no inserts yields nothing rather than an error.
    #[test]
    fn empty_table_search_is_empty() {
        let table = cosine_table();
        let hits = table.search(&rand_vec(0, 64), 5, None).unwrap();
        assert!(hits.is_empty());
    }

    /// Re-inserting a live id is reported as a duplicate.
    #[test]
    fn duplicate_id_rejected() {
        let mut table = cosine_table();
        table.insert(7, rand_vec(7, 64)).unwrap();

        let second = table.insert(7, rand_vec(8, 64));
        assert!(matches!(second, Err(IndexError::Duplicate(7))));
    }

    /// A vector of the wrong length is refused.
    #[test]
    fn dimension_mismatch_rejected() {
        let mut table = cosine_table();

        let outcome = table.insert(0, vec![0.1; 32]);
        assert!(matches!(
            outcome,
            Err(IndexError::DimensionMismatch { .. })
        ));
    }
}