pub mod codebook;
pub mod encode;
pub mod error;
pub mod id_map;
pub mod io;
pub mod pack;
pub mod rotation;
pub mod search;
pub use error::{AddError, ConstructError};
pub use id_map::IdMapIndex;
use std::path::Path;
use std::sync::OnceLock;
/// Seed value for the rotation matrix.
/// NOTE(review): not referenced in this file; presumably consumed by the `rotation` module — confirm.
const ROTATION_SEED: u64 = 42;
/// NOTE(review): not referenced in this file; presumably a block size used by `pack`/`search` — confirm.
const BLOCK: usize = 32;
/// NOTE(review): not referenced in this file; meaning is not visible here — confirm at the use site.
const FLUSH_EVERY: usize = 256;
/// Exclusive upper bound on `|coordinate|` for inputs and queries; larger values
/// are rejected to avoid f32 overflow when norms are computed.
const MAX_INPUT_MAGNITUDE: f32 = 1e16;

/// Locates the first coordinate in `values` that is non-finite or whose
/// magnitude is at least [`MAX_INPUT_MAGNITUDE`].
///
/// `values` is read as row-major vectors of length `dim`. On a hit, this returns
/// `(vector_index, coord_index, value)`. With `dim == 0` the flat position is
/// reported as the coordinate index of vector 0. Returns `None` when every value
/// is acceptable.
fn first_invalid_coord(values: &[f32], dim: usize) -> Option<(usize, usize, f32)> {
    // Accept only finite values strictly below the magnitude bound.
    let bad = values
        .iter()
        .position(|&x| !(x.is_finite() && x.abs() < MAX_INPUT_MAGNITUDE))?;
    let (vector_index, coord_index) = match dim {
        0 => (0, bad),
        d => (bad / d, bad % d),
    };
    Some((vector_index, coord_index, values[bad]))
}
/// Search-time layout of the packed codes, as returned by `pack::repack`.
/// The index rebuilds it lazily and drops it whenever the stored codes change.
struct BlockedCache {
    /// Number of blocks reported by `pack::repack`.
    n_blocks: usize,
    /// Repacked code bytes returned by `pack::repack`.
    data: Vec<u8>,
}
/// Quantized vector index.
///
/// Stores `bit_width`-bit packed codes with one scale per vector plus optional
/// per-coordinate TQ+ parameters. Rotation, codebook and blocked-layout data are
/// derived lazily and cached in `OnceLock`s.
pub struct TurboQuantIndex {
    /// Bits per coordinate code (`new`/`new_lazy` require `2..=4`).
    bit_width: usize,
    /// Committed vector dimension. `None` for a lazy index until its first `add_2d`.
    dim: Option<usize>,
    /// Number of stored vectors.
    n_vectors: usize,
    /// One scale per stored vector (`from_parts` checks `len == n_vectors`).
    scales: Vec<f32>,
    /// Packed codes (`from_parts` checks `len == n_vectors * dim * bit_width / 8`).
    packed_codes: Vec<u8>,
    /// TQ+ shift: empty or `dim` entries, always as long as `tqplus_scale`.
    tqplus_shift: Vec<f32>,
    /// TQ+ scale: empty or `dim` entries, always as long as `tqplus_shift`.
    tqplus_scale: Vec<f32>,

    // Derived caches, built on first use.
    /// Matrix from `rotation::make_rotation_matrix(dim)`.
    rotation: OnceLock<Vec<f32>>,
    /// Boundary table from `codebook::codebook(bit_width, dim)`.
    boundaries: OnceLock<Vec<f32>>,
    /// Centroid table from `codebook::codebook(bit_width, dim)`.
    centroids: OnceLock<Vec<f32>>,
    /// Search layout of `packed_codes`; reset whenever the codes change.
    blocked: OnceLock<BlockedCache>,
}
/// Output of a batched search, laid out row-major with stride `k`.
///
/// Query `qi` owns entries `qi * k .. (qi + 1) * k` of both `scores` and `indices`.
pub struct SearchResults {
    /// Scores for all queries, row-major with stride `k`.
    pub scores: Vec<f32>,
    /// Vector indices for all queries, row-major with stride `k`.
    pub indices: Vec<i64>,
    /// Number of queries in the batch.
    pub nq: usize,
    /// Number of results per query (the row stride).
    pub k: usize,
}

impl SearchResults {
    /// Returns the `k` scores belonging to query `qi`.
    ///
    /// # Panics
    ///
    /// If query `qi`'s row extends past the end of `scores`.
    pub fn scores_for_query(&self, qi: usize) -> &[f32] {
        let row = self.row(qi);
        &self.scores[row]
    }

    /// Returns the `k` result indices belonging to query `qi`.
    ///
    /// # Panics
    ///
    /// If query `qi`'s row extends past the end of `indices`.
    pub fn indices_for_query(&self, qi: usize) -> &[i64] {
        let row = self.row(qi);
        &self.indices[row]
    }

    /// Flat index range covered by query `qi`.
    fn row(&self, qi: usize) -> std::ops::Range<usize> {
        let start = qi * self.k;
        start..start + self.k
    }
}
impl TurboQuantIndex {
pub fn new(dim: usize, bit_width: usize) -> Result<Self, ConstructError> {
if !(2..=4).contains(&bit_width) {
return Err(ConstructError::BitWidthOutOfRange(bit_width));
}
if dim == 0 || dim % 8 != 0 {
return Err(ConstructError::DimNotPositiveMultipleOf8(dim));
}
Ok(Self {
dim: Some(dim),
bit_width,
n_vectors: 0,
packed_codes: Vec::new(),
scales: Vec::new(),
tqplus_shift: Vec::new(),
tqplus_scale: Vec::new(),
rotation: OnceLock::new(),
boundaries: OnceLock::new(),
centroids: OnceLock::new(),
blocked: OnceLock::new(),
})
}
pub fn new_lazy(bit_width: usize) -> Result<Self, ConstructError> {
if !(2..=4).contains(&bit_width) {
return Err(ConstructError::BitWidthOutOfRange(bit_width));
}
Ok(Self {
dim: None,
bit_width,
n_vectors: 0,
packed_codes: Vec::new(),
scales: Vec::new(),
tqplus_shift: Vec::new(),
tqplus_scale: Vec::new(),
rotation: OnceLock::new(),
boundaries: OnceLock::new(),
centroids: OnceLock::new(),
blocked: OnceLock::new(),
})
}
pub fn add(&mut self, vectors: &[f32]) {
let dim = self.dim.expect(
"TurboQuantIndex dim is not set; use add_2d(vectors, dim) on the \
first add or construct via TurboQuantIndex::new(dim, bit_width)",
);
let n = vectors.len() / dim;
assert_eq!(
vectors.len(),
n * dim,
"vectors length must be a multiple of dim"
);
if n == 0 {
return;
}
if let Some((vi, ci, v)) = first_invalid_coord(vectors, dim) {
panic!(
"invalid input value at vector {vi}, coord {ci}: {v} \
(must be finite and |value| < 1e16 to avoid f32 norm overflow)",
);
}
let rotation = self
.rotation
.get_or_init(|| rotation::make_rotation_matrix(dim));
if self.boundaries.get().is_none() || self.centroids.get().is_none() {
let (boundaries, centroids) = codebook::codebook(self.bit_width, dim);
let _ = self.boundaries.set(boundaries);
let _ = self.centroids.set(centroids);
}
let boundaries = self
.boundaries
.get()
.expect("boundaries cache is initialized");
let centroids = self
.centroids
.get()
.expect("centroids cache is initialized");
let existing = if self.tqplus_shift.is_empty() {
None
} else {
Some((self.tqplus_shift.as_slice(), self.tqplus_scale.as_slice()))
};
let (packed, scales, shift, scale_tq) = encode::encode(
vectors,
n,
dim,
rotation,
boundaries,
centroids,
self.bit_width,
existing,
);
if self.n_vectors == 0 {
self.packed_codes = packed;
self.scales = scales;
self.tqplus_shift = shift;
self.tqplus_scale = scale_tq;
} else {
self.packed_codes.extend_from_slice(&packed);
self.scales.extend_from_slice(&scales);
}
self.n_vectors += n;
self.blocked = OnceLock::new();
}
pub fn add_2d(&mut self, vectors: &[f32], dim: usize) -> Result<(), AddError> {
match self.dim {
Some(existing) if existing != dim => {
return Err(AddError::DimMismatch { existing, got: dim });
}
Some(_) => {}
None => {
if dim % 8 != 0 {
return Err(AddError::DimNotMultipleOf8(dim));
}
}
}
if let Some((vi, ci, v)) = first_invalid_coord(vectors, dim) {
return Err(AddError::InvalidInputValue {
vector_index: vi,
coord_index: ci,
value: v,
});
}
if self.dim.is_none() {
self.dim = Some(dim);
}
self.add(vectors);
Ok(())
}
pub fn search(&self, queries: &[f32], k: usize) -> SearchResults {
self.search_with_mask(queries, k, None)
}
pub fn search_with_mask(
&self,
queries: &[f32],
k: usize,
mask: Option<&[bool]>,
) -> SearchResults {
let Some(dim) = self.dim else {
return SearchResults {
scores: Vec::new(),
indices: Vec::new(),
nq: 0,
k: 0,
};
};
let nq = queries.len() / dim;
assert_eq!(queries.len(), nq * dim);
if let Some((vi, ci, v)) = first_invalid_coord(queries, dim) {
panic!(
"invalid query value at query {vi}, coord {ci}: {v} \
(must be finite and |value| < 1e16 to avoid f32 overflow)",
);
}
let rotation = self
.rotation
.get_or_init(|| rotation::make_rotation_matrix(dim));
let centroids = self.centroids.get_or_init(|| {
let (_, c) = codebook::codebook(self.bit_width, dim);
c
});
let blocked = self.blocked.get_or_init(|| {
let (data, n_blocks) =
pack::repack(&self.packed_codes, self.n_vectors, self.bit_width, dim);
BlockedCache { data, n_blocks }
});
let packed_mask = mask.map(|m| {
assert_eq!(
m.len(),
self.n_vectors,
"mask length {} does not match index size {}",
m.len(),
self.n_vectors,
);
let n_words = (self.n_vectors + 63) / 64;
let mut buf = vec![0u64; n_words];
for (i, &b) in m.iter().enumerate() {
if b {
buf[i >> 6] |= 1u64 << (i & 63);
}
}
buf
});
let n_allowed = packed_mask.as_ref().map_or(self.n_vectors, |p| {
p.iter().map(|w| w.count_ones() as usize).sum::<usize>()
});
let effective_k = k.min(self.n_vectors).min(n_allowed);
let (scores, indices) = search::search(
queries,
nq,
rotation,
&blocked.data,
centroids,
&self.scales,
&self.tqplus_shift,
&self.tqplus_scale,
self.bit_width,
dim,
self.n_vectors,
blocked.n_blocks,
k,
packed_mask.as_deref(),
);
SearchResults {
scores,
indices,
nq,
k: effective_k,
}
}
pub fn prepare(&self) {
let Some(dim) = self.dim else { return };
self.rotation
.get_or_init(|| rotation::make_rotation_matrix(dim));
self.centroids.get_or_init(|| {
let (_, c) = codebook::codebook(self.bit_width, dim);
c
});
self.blocked.get_or_init(|| {
let (data, n_blocks) =
pack::repack(&self.packed_codes, self.n_vectors, self.bit_width, dim);
BlockedCache { data, n_blocks }
});
}
pub fn write(&self, path: impl AsRef<Path>) -> std::io::Result<()> {
io::write(
path,
self.bit_width,
self.dim.unwrap_or(0),
self.n_vectors,
&self.packed_codes,
&self.scales,
&self.tqplus_shift,
&self.tqplus_scale,
)
}
pub fn load(path: impl AsRef<Path>) -> std::io::Result<Self> {
let (bit_width, dim, n_vectors, packed_codes, scales, tqplus_shift, tqplus_scale) =
io::load(path)?;
let dim_opt = if dim == 0 { None } else { Some(dim) };
Ok(Self::from_parts(
dim_opt,
bit_width,
n_vectors,
packed_codes,
scales,
tqplus_shift,
tqplus_scale,
))
}
pub(crate) fn from_parts(
dim: Option<usize>,
bit_width: usize,
n_vectors: usize,
packed_codes: Vec<u8>,
scales: Vec<f32>,
tqplus_shift: Vec<f32>,
tqplus_scale: Vec<f32>,
) -> Self {
assert_eq!(
tqplus_shift.len(),
tqplus_scale.len(),
"from_parts: tqplus_shift.len()={} != tqplus_scale.len()={}",
tqplus_shift.len(),
tqplus_scale.len(),
);
match dim {
Some(d) => {
let expected_packed = n_vectors * d * bit_width / 8;
assert_eq!(
packed_codes.len(),
expected_packed,
"from_parts: packed_codes.len()={} != n_vectors({}) * dim({}) * bit_width({}) / 8 = {}",
packed_codes.len(),
n_vectors,
d,
bit_width,
expected_packed,
);
assert_eq!(
scales.len(),
n_vectors,
"from_parts: scales.len()={} != n_vectors={}",
scales.len(),
n_vectors,
);
if !tqplus_shift.is_empty() {
assert_eq!(
tqplus_shift.len(),
d,
"from_parts: non-empty TQ+ length {} must equal dim {}",
tqplus_shift.len(),
d,
);
}
}
None => {
assert_eq!(n_vectors, 0, "from_parts: lazy index must have n_vectors=0");
assert!(
packed_codes.is_empty(),
"from_parts: lazy index must have empty packed_codes",
);
assert!(scales.is_empty(), "from_parts: lazy index must have empty scales");
assert!(
tqplus_shift.is_empty(),
"from_parts: lazy index must have empty tqplus_shift",
);
}
}
let (tqplus_shift, tqplus_scale) = if tqplus_shift.is_empty() && n_vectors > 0 {
let d = dim.expect(
"from_parts: n_vectors > 0 implies a committed dim — \
mismatch indicates a corrupted side-car or a misuse",
);
(vec![0.0; d], vec![1.0; d])
} else {
(tqplus_shift, tqplus_scale)
};
Self {
dim,
bit_width,
n_vectors,
packed_codes,
scales,
tqplus_shift,
tqplus_scale,
rotation: OnceLock::new(),
boundaries: OnceLock::new(),
centroids: OnceLock::new(),
blocked: OnceLock::new(),
}
}
pub(crate) fn packed_codes(&self) -> &[u8] {
&self.packed_codes
}
pub(crate) fn scales(&self) -> &[f32] {
&self.scales
}
pub(crate) fn tqplus_shift(&self) -> &[f32] {
&self.tqplus_shift
}
pub(crate) fn tqplus_scale(&self) -> &[f32] {
&self.tqplus_scale
}
pub fn swap_remove(&mut self, idx: usize) -> usize {
assert!(
idx < self.n_vectors,
"index {idx} out of bounds (n_vectors = {})",
self.n_vectors
);
let dim = self.dim.expect("n_vectors > 0 but dim is None");
let bytes_per_vec = dim * self.bit_width / 8;
let last = self.n_vectors - 1;
if idx != last {
let src = last * bytes_per_vec;
let dst = idx * bytes_per_vec;
self.packed_codes.copy_within(src..src + bytes_per_vec, dst);
self.scales[idx] = self.scales[last];
}
self.packed_codes.truncate(last * bytes_per_vec);
self.scales.truncate(last);
self.n_vectors -= 1;
self.blocked = OnceLock::new();
last
}
pub fn len(&self) -> usize {
self.n_vectors
}
pub fn is_empty(&self) -> bool {
self.n_vectors == 0
}
pub fn dim(&self) -> usize {
self.dim.unwrap_or(0)
}
pub fn dim_opt(&self) -> Option<usize> {
self.dim
}
pub fn bit_width(&self) -> usize {
self.bit_width
}
}
#[cfg(test)]
mod from_parts_tests {
    use super::TurboQuantIndex;

    /// Calls `from_parts` for a dim-64, 4-bit, 2-vector index with caller-chosen
    /// buffer lengths, so each test can break exactly one invariant.
    /// A consistent layout is: packed = 2 * 64 * 4 / 8 = 64 bytes, scales = 2,
    /// TQ+ = 0 or 64 entries for both shift and scale.
    fn eager(packed: usize, scales: usize, shift: usize, scale: usize) -> TurboQuantIndex {
        TurboQuantIndex::from_parts(
            Some(64),
            4,
            2,
            vec![0u8; packed],
            vec![1.0f32; scales],
            vec![0.0f32; shift],
            vec![1.0f32; scale],
        )
    }

    /// Calls `from_parts` for an uncommitted (lazy) 4-bit index with empty buffers.
    fn lazy(n_vectors: usize) -> TurboQuantIndex {
        TurboQuantIndex::from_parts(
            None,
            4,
            n_vectors,
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
        )
    }

    #[test]
    #[should_panic(expected = "packed_codes.len()")]
    fn from_parts_panics_on_packed_codes_length_mismatch() {
        // Half of the 64 packed bytes the header implies.
        let _ = eager(32, 2, 0, 0);
    }

    #[test]
    #[should_panic(expected = "scales.len()")]
    fn from_parts_panics_on_scales_length_mismatch() {
        // Five scales for two vectors.
        let _ = eager(64, 5, 0, 0);
    }

    #[test]
    #[should_panic(expected = "tqplus_shift.len()")]
    fn from_parts_panics_on_mismatched_tqplus_lengths() {
        // Shift and scale disagree with each other.
        let _ = eager(64, 2, 64, 32);
    }

    #[test]
    #[should_panic(expected = "non-empty TQ+ length")]
    fn from_parts_panics_when_tqplus_length_does_not_equal_dim() {
        // Shift and scale agree with each other, but not with dim = 64.
        let _ = eager(64, 2, 48, 48);
    }

    #[test]
    #[should_panic(expected = "lazy index must have n_vectors=0")]
    fn from_parts_panics_on_lazy_with_nonzero_n_vectors() {
        let _ = lazy(5);
    }

    #[test]
    fn from_parts_accepts_lazy_uncommitted() {
        let index = lazy(0);
        assert_eq!(index.dim_opt(), None);
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn from_parts_accepts_eager_with_consistent_lengths() {
        let index = eager(64, 2, 0, 0);
        assert_eq!(index.dim(), 64);
        assert_eq!(index.len(), 2);
    }
}