pub mod codebook;
pub mod encode;
pub mod id_map;
pub mod io;
pub mod pack;
pub mod rotation;
pub mod search;
pub use id_map::IdMapIndex;
use std::path::Path;
use std::sync::OnceLock;
// NOTE(review): none of these constants is referenced in this file; they are
// presumably consumed by the submodules declared above — confirm before
// changing or removing any of them.
/// Fixed seed, presumably for the pseudo-random rotation produced by
/// `rotation::make_rotation_matrix` — TODO confirm.
const ROTATION_SEED: u64 = 42;
/// Block size, presumably the number of vectors grouped together by
/// `pack::repack` — TODO confirm.
const BLOCK: usize = 32;
/// Presumably a batch size after which some buffer is flushed (likely in
/// `search`) — TODO confirm where and what is flushed.
const FLUSH_EVERY: usize = 256;
/// Cached output of `pack::repack`: a re-packed copy of the stored codes in the
/// layout that `search::search` consumes.
///
/// Built lazily on the first search / `prepare` call and discarded whenever the
/// stored codes change (`add`, `swap_remove`).
struct BlockedCache {
    /// Re-packed code bytes, passed to `search::search`.
    data: Vec<u8>,
    /// Block count reported by `pack::repack`; presumably
    /// `ceil(n_vectors / BLOCK)` — TODO confirm.
    n_blocks: usize,
}
/// Quantized vector index: stores a `bit_width`-bit code per coordinate plus one
/// scale per vector.
///
/// The rotation matrix, codebook centroids and the blocked search layout are
/// derived data, computed on first use and cached in `OnceLock`s.
pub struct TurboQuantIndex {
    /// Vector dimensionality; `None` for an index created with `new_lazy` (or
    /// loaded with a stored dim of 0) until the first `add_2d`.
    dim: Option<usize>,
    /// Code width in bits per coordinate; the constructors only allow 2, 3 or 4.
    bit_width: usize,
    /// Number of vectors currently stored.
    n_vectors: usize,
    /// Codes of all vectors back to back, `dim * bit_width / 8` bytes each.
    packed_codes: Vec<u8>,
    /// One scale per stored vector, indexed by vector position.
    scales: Vec<f32>,
    /// Cached result of `rotation::make_rotation_matrix(dim)`.
    rotation: OnceLock<Vec<f32>>,
    /// Cached centroids from `codebook::codebook(bit_width, dim)`.
    centroids: OnceLock<Vec<f32>>,
    /// Cached search layout of `packed_codes`; reset whenever the codes change.
    blocked: OnceLock<BlockedCache>,
}
/// Results of a batched search, stored row-major: query `qi` owns entries
/// `qi * k .. (qi + 1) * k` of both `scores` and `indices`.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResults {
    /// Scores for every query, `nq * k` entries.
    pub scores: Vec<f32>,
    /// Vector positions matching `scores` entry for entry.
    pub indices: Vec<i64>,
    /// Number of queries in the batch.
    pub nq: usize,
    /// Number of results per query.
    pub k: usize,
}

impl SearchResults {
    /// Returns the `k` scores belonging to query `qi`.
    ///
    /// # Panics
    ///
    /// Panics if the row for `qi` lies outside `scores` (for well-formed
    /// results, when `qi >= nq` and `k > 0`).
    pub fn scores_for_query(&self, qi: usize) -> &[f32] {
        &self.scores[self.row(qi)]
    }

    /// Returns the `k` vector positions belonging to query `qi`.
    ///
    /// # Panics
    ///
    /// Panics if the row for `qi` lies outside `indices` (for well-formed
    /// results, when `qi >= nq` and `k > 0`).
    pub fn indices_for_query(&self, qi: usize) -> &[i64] {
        &self.indices[self.row(qi)]
    }

    /// Index range of query `qi`'s row in the flat result vectors.
    fn row(&self, qi: usize) -> std::ops::Range<usize> {
        let start = qi * self.k;
        start..start + self.k
    }
}
impl TurboQuantIndex {
pub fn new(dim: usize, bit_width: usize) -> Self {
assert!((2..=4).contains(&bit_width), "bit_width must be 2, 3, or 4");
assert!(dim % 8 == 0, "dim must be a multiple of 8");
Self {
dim: Some(dim),
bit_width,
n_vectors: 0,
packed_codes: Vec::new(),
scales: Vec::new(),
rotation: OnceLock::new(),
centroids: OnceLock::new(),
blocked: OnceLock::new(),
}
}
pub fn new_lazy(bit_width: usize) -> Self {
assert!((2..=4).contains(&bit_width), "bit_width must be 2, 3, or 4");
Self {
dim: None,
bit_width,
n_vectors: 0,
packed_codes: Vec::new(),
scales: Vec::new(),
rotation: OnceLock::new(),
centroids: OnceLock::new(),
blocked: OnceLock::new(),
}
}
pub fn add(&mut self, vectors: &[f32]) {
let dim = self.dim.expect(
"TurboQuantIndex dim is not set; use add_2d(vectors, dim) on the \
first add or construct via TurboQuantIndex::new(dim, bit_width)",
);
let n = vectors.len() / dim;
assert_eq!(
vectors.len(),
n * dim,
"vectors length must be a multiple of dim"
);
let rotation = self
.rotation
.get_or_init(|| rotation::make_rotation_matrix(dim));
let (boundaries, centroids) = codebook::codebook(self.bit_width, dim);
let (packed, scales) = encode::encode(
vectors,
n,
dim,
rotation,
&boundaries,
¢roids,
self.bit_width,
);
if self.n_vectors == 0 {
self.packed_codes = packed;
self.scales = scales;
} else {
self.packed_codes.extend_from_slice(&packed);
self.scales.extend_from_slice(&scales);
}
self.n_vectors += n;
self.blocked = OnceLock::new();
}
pub fn add_2d(&mut self, vectors: &[f32], dim: usize) {
match self.dim {
Some(existing) => assert_eq!(
existing, dim,
"dim mismatch: index dim={existing}, batch dim={dim}"
),
None => {
assert!(dim % 8 == 0, "dim must be a multiple of 8");
self.dim = Some(dim);
}
}
self.add(vectors);
}
pub fn search(&self, queries: &[f32], k: usize) -> SearchResults {
self.search_with_mask(queries, k, None)
}
pub fn search_with_mask(
&self,
queries: &[f32],
k: usize,
mask: Option<&[bool]>,
) -> SearchResults {
let Some(dim) = self.dim else {
return SearchResults {
scores: Vec::new(),
indices: Vec::new(),
nq: 0,
k: 0,
};
};
let nq = queries.len() / dim;
assert_eq!(queries.len(), nq * dim);
let rotation = self
.rotation
.get_or_init(|| rotation::make_rotation_matrix(dim));
let centroids = self.centroids.get_or_init(|| {
let (_, c) = codebook::codebook(self.bit_width, dim);
c
});
let blocked = self.blocked.get_or_init(|| {
let (data, n_blocks) =
pack::repack(&self.packed_codes, self.n_vectors, self.bit_width, dim);
BlockedCache { data, n_blocks }
});
let packed_mask = mask.map(|m| {
assert_eq!(
m.len(),
self.n_vectors,
"mask length {} does not match index size {}",
m.len(),
self.n_vectors,
);
let n_words = (self.n_vectors + 63) / 64;
let mut buf = vec![0u64; n_words];
for (i, &b) in m.iter().enumerate() {
if b {
buf[i >> 6] |= 1u64 << (i & 63);
}
}
buf
});
let n_allowed = packed_mask.as_ref().map_or(self.n_vectors, |p| {
p.iter().map(|w| w.count_ones() as usize).sum::<usize>()
});
let effective_k = k.min(self.n_vectors).min(n_allowed);
let (scores, indices) = search::search(
queries,
nq,
rotation,
&blocked.data,
centroids,
&self.scales,
self.bit_width,
dim,
self.n_vectors,
blocked.n_blocks,
k,
packed_mask.as_deref(),
);
SearchResults {
scores,
indices,
nq,
k: effective_k,
}
}
pub fn prepare(&self) {
let Some(dim) = self.dim else { return };
self.rotation
.get_or_init(|| rotation::make_rotation_matrix(dim));
self.centroids.get_or_init(|| {
let (_, c) = codebook::codebook(self.bit_width, dim);
c
});
self.blocked.get_or_init(|| {
let (data, n_blocks) =
pack::repack(&self.packed_codes, self.n_vectors, self.bit_width, dim);
BlockedCache { data, n_blocks }
});
}
pub fn write(&self, path: impl AsRef<Path>) -> std::io::Result<()> {
io::write(
path,
self.bit_width,
self.dim.unwrap_or(0),
self.n_vectors,
&self.packed_codes,
&self.scales,
)
}
pub fn load(path: impl AsRef<Path>) -> std::io::Result<Self> {
let (bit_width, dim, n_vectors, packed_codes, scales) = io::load(path)?;
let dim_opt = if dim == 0 { None } else { Some(dim) };
Ok(Self::from_parts(dim_opt, bit_width, n_vectors, packed_codes, scales))
}
pub(crate) fn from_parts(
dim: Option<usize>,
bit_width: usize,
n_vectors: usize,
packed_codes: Vec<u8>,
scales: Vec<f32>,
) -> Self {
Self {
dim,
bit_width,
n_vectors,
packed_codes,
scales,
rotation: OnceLock::new(),
centroids: OnceLock::new(),
blocked: OnceLock::new(),
}
}
pub(crate) fn packed_codes(&self) -> &[u8] {
&self.packed_codes
}
pub(crate) fn scales(&self) -> &[f32] {
&self.scales
}
pub fn swap_remove(&mut self, idx: usize) -> usize {
assert!(
idx < self.n_vectors,
"index {idx} out of bounds (n_vectors = {})",
self.n_vectors
);
let dim = self.dim.expect("n_vectors > 0 but dim is None");
let bytes_per_vec = dim * self.bit_width / 8;
let last = self.n_vectors - 1;
if idx != last {
let src = last * bytes_per_vec;
let dst = idx * bytes_per_vec;
self.packed_codes.copy_within(src..src + bytes_per_vec, dst);
self.scales[idx] = self.scales[last];
}
self.packed_codes.truncate(last * bytes_per_vec);
self.scales.truncate(last);
self.n_vectors -= 1;
self.blocked = OnceLock::new();
last
}
pub fn len(&self) -> usize {
self.n_vectors
}
pub fn is_empty(&self) -> bool {
self.n_vectors == 0
}
pub fn dim(&self) -> usize {
self.dim.unwrap_or(0)
}
pub fn dim_opt(&self) -> Option<usize> {
self.dim
}
pub fn bit_width(&self) -> usize {
self.bit_width
}
}