pub mod codebook;
pub mod encode;
pub mod io;
pub mod pack;
pub mod rotation;
pub mod search;
use std::path::Path;
use std::sync::OnceLock;
// NOTE(review): none of these constants is referenced in this file; presumably they are read by
// the submodules (via `crate::`/`super::`). Confirm, or remove them if they are dead code.
/// Fixed seed, presumably for the random rotation built in `rotation` — TODO confirm.
const ROTATION_SEED: u64 = 42;
/// Block size, presumably the number of vectors per block in the `pack::repack` layout — TODO
/// confirm.
const BLOCK: usize = 32;
/// NOTE(review): meaning not visible here; presumably a flush/merge interval used during search —
/// confirm.
const FLUSH_EVERY: usize = 256;
/// Search-ready copy of the packed codes, as returned by `pack::repack`.
struct BlockedCache {
    /// Block count reported by `pack::repack`; forwarded unchanged to `search::search`.
    n_blocks: usize,
    /// Repacked code bytes consumed by `search::search`.
    data: Vec<u8>,
}
/// Quantized vector index: keeps encoded vectors plus their norms, and lazily builds the
/// auxiliary data (rotation, centroids, blocked code layout) that searching needs.
pub struct TurboQuantIndex {
    // --- configuration -------------------------------------------------------------------
    /// Number of coordinates per vector.
    dim: usize,
    /// Code width handed to the codebook, encoder and packer.
    bit_width: usize,

    // --- stored data (persisted by `write`) ------------------------------------------------
    /// Number of vectors added so far.
    n_vectors: usize,
    /// Values returned next to the codes by `encode::encode`, concatenated across `add` calls.
    norms: Vec<f32>,
    /// Encoded vectors returned by `encode::encode`, concatenated across `add` calls.
    packed_codes: Vec<u8>,

    // --- lazily-built caches (never persisted) ---------------------------------------------
    /// `packed_codes` re-laid-out by `pack::repack`; discarded whenever vectors are added.
    blocked: OnceLock<BlockedCache>,
    /// Centroids from `codebook::codebook(bit_width, dim)`.
    centroids: OnceLock<Vec<f32>>,
    /// Matrix from `rotation::make_rotation_matrix(dim)`.
    rotation: OnceLock<Vec<f32>>,
}
/// Top-`k` results for a batch of `nq` queries.
///
/// `scores` and `indices` are flat buffers of `nq * k` entries. The results for query `qi`
/// occupy the range `qi * k..(qi + 1) * k`; use [`SearchResults::scores_for_query`] and
/// [`SearchResults::indices_for_query`] to slice them.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResults {
    /// Scores for every query, `k` per query.
    pub scores: Vec<f32>,
    /// Matched vector ids for every query, `k` per query (presumably insertion-order positions
    /// in the index — confirm against `search::search`).
    pub indices: Vec<i64>,
    /// Number of queries in the batch.
    pub nq: usize,
    /// Results per query; may be smaller than requested when the index holds fewer vectors.
    pub k: usize,
}
impl SearchResults {
    /// Returns the `k` scores belonging to query `qi`.
    ///
    /// # Panics
    ///
    /// Panics if `qi >= self.nq`.
    pub fn scores_for_query(&self, qi: usize) -> &[f32] {
        &self.scores[self.query_range(qi)]
    }

    /// Returns the `k` result indices belonging to query `qi`.
    ///
    /// # Panics
    ///
    /// Panics if `qi >= self.nq`.
    pub fn indices_for_query(&self, qi: usize) -> &[i64] {
        &self.indices[self.query_range(qi)]
    }

    /// Range of query `qi`'s entries inside the flat `scores` / `indices` buffers.
    fn query_range(&self, qi: usize) -> std::ops::Range<usize> {
        // Check explicitly: when `k == 0` every range is empty, so slicing alone would silently
        // accept any `qi` instead of reporting the bad query index.
        assert!(
            qi < self.nq,
            "query index {qi} out of range for {} queries",
            self.nq
        );
        let start = qi * self.k;
        start..start + self.k
    }
}
impl TurboQuantIndex {
pub fn new(dim: usize, bit_width: usize) -> Self {
assert!((2..=4).contains(&bit_width), "bit_width must be 2, 3, or 4");
assert!(dim % 8 == 0, "dim must be a multiple of 8");
Self {
dim,
bit_width,
n_vectors: 0,
packed_codes: Vec::new(),
norms: Vec::new(),
rotation: OnceLock::new(),
centroids: OnceLock::new(),
blocked: OnceLock::new(),
}
}
pub fn add(&mut self, vectors: &[f32]) {
let n = vectors.len() / self.dim;
assert_eq!(
vectors.len(),
n * self.dim,
"vectors length must be a multiple of dim"
);
let rotation = self
.rotation
.get_or_init(|| rotation::make_rotation_matrix(self.dim));
let (boundaries, _) = codebook::codebook(self.bit_width, self.dim);
let (packed, norms) =
encode::encode(vectors, n, self.dim, rotation, &boundaries, self.bit_width);
if self.n_vectors == 0 {
self.packed_codes = packed;
self.norms = norms;
} else {
self.packed_codes.extend_from_slice(&packed);
self.norms.extend_from_slice(&norms);
}
self.n_vectors += n;
self.blocked = OnceLock::new();
}
pub fn search(&self, queries: &[f32], k: usize) -> SearchResults {
let nq = queries.len() / self.dim;
assert_eq!(queries.len(), nq * self.dim);
let rotation = self
.rotation
.get_or_init(|| rotation::make_rotation_matrix(self.dim));
let centroids = self.centroids.get_or_init(|| {
let (_, c) = codebook::codebook(self.bit_width, self.dim);
c
});
let blocked = self.blocked.get_or_init(|| {
let (data, n_blocks) =
pack::repack(&self.packed_codes, self.n_vectors, self.bit_width, self.dim);
BlockedCache { data, n_blocks }
});
let k = k.min(self.n_vectors);
let (scores, indices) = search::search(
queries,
nq,
rotation,
&blocked.data,
centroids,
&self.norms,
self.bit_width,
self.dim,
self.n_vectors,
blocked.n_blocks,
k,
);
SearchResults {
scores,
indices,
nq,
k,
}
}
pub fn prepare(&self) {
self.rotation
.get_or_init(|| rotation::make_rotation_matrix(self.dim));
self.centroids.get_or_init(|| {
let (_, c) = codebook::codebook(self.bit_width, self.dim);
c
});
self.blocked.get_or_init(|| {
let (data, n_blocks) =
pack::repack(&self.packed_codes, self.n_vectors, self.bit_width, self.dim);
BlockedCache { data, n_blocks }
});
}
pub fn write(&self, path: impl AsRef<Path>) -> std::io::Result<()> {
io::write(
path,
self.bit_width,
self.dim,
self.n_vectors,
&self.packed_codes,
&self.norms,
)
}
pub fn load(path: impl AsRef<Path>) -> std::io::Result<Self> {
let (bit_width, dim, n_vectors, packed_codes, norms) = io::load(path)?;
Ok(Self {
dim,
bit_width,
n_vectors,
packed_codes,
norms,
rotation: OnceLock::new(),
centroids: OnceLock::new(),
blocked: OnceLock::new(),
})
}
pub fn len(&self) -> usize {
self.n_vectors
}
pub fn is_empty(&self) -> bool {
self.n_vectors == 0
}
pub fn dim(&self) -> usize {
self.dim
}
pub fn bit_width(&self) -> usize {
self.bit_width
}
}