use serde::{Deserialize, Serialize};
use tracing::{debug, info, warn};
use ailake_core::{AilakeError, AilakeResult, RowId, VectorMetric};
use ailake_vec::{kmeans_centroids, PQCodebook};
/// Runs k-means over `vecs`, picking the first backend that accepts the job.
///
/// Backends are tried in this order:
/// 1. `crate::gpu::try_nvidia_kmeans` — used if it returns `Some`.
/// 2. `crate::gpu::try_rocm_kmeans` — used if it returns `Some`.
/// 3. `kmeans_centroids` — unconditional fallback.
///
/// `k` and `max_iter` are forwarded unchanged to whichever backend runs, and the
/// chosen backend is reported through a `debug!` log line.
fn kmeans_dispatch(vecs: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
    let n = vecs.len();
    match crate::gpu::try_nvidia_kmeans(vecs, k, max_iter) {
        Some(centroids) => {
            debug!(
                "ailake: IVF-PQ k-means used NVIDIA CUDA (n={} k={} max_iter={})",
                n, k, max_iter
            );
            centroids
        }
        None => match crate::gpu::try_rocm_kmeans(vecs, k, max_iter) {
            Some(centroids) => {
                debug!(
                    "ailake: IVF-PQ k-means used AMD ROCm (n={} k={} max_iter={})",
                    n, k, max_iter
                );
                centroids
            }
            None => {
                debug!(
                    "ailake: IVF-PQ k-means using CPU rayon (n={} k={} max_iter={})",
                    n, k, max_iter
                );
                kmeans_centroids(vecs, k, max_iter)
            }
        },
    }
}
/// Tuning parameters for an IVF-PQ index.
///
/// The values here are requests: training may adjust `nlist`, `nprobe`, `pq_m`
/// and `pq_k` to fit the data (see the per-field notes).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IvfPqConfig {
/// Requested number of coarse clusters (inverted lists); training clamps it to
/// the number of training vectors.
pub nlist: usize,
/// Number of inverted lists scanned per query when `search` is not given an
/// explicit `nprobe`; clamped to `nlist`.
pub nprobe: usize,
/// Requested PQ sub-quantizer count; training lowers it to the largest value
/// that evenly divides the vector dimension (`find_valid_pq_m`).
pub pq_m: usize,
/// Requested centroids per PQ sub-quantizer; training caps it at 256.
pub pq_k: usize,
/// Iteration limit forwarded to k-means during training.
pub max_iter: usize,
/// When true, PQ is trained on and encodes residuals (vector minus its nearest
/// coarse centroid) instead of the vectors themselves. Defaults to `false`
/// when the field is missing from serialized input.
#[serde(default)]
pub residual: bool,
}
impl Default for IvfPqConfig {
    /// Baseline configuration: 256 lists with 8 probed per query, `pq_m = 8`,
    /// `pq_k = 256`, 25 k-means iterations, residual encoding off.
    fn default() -> Self {
        // Coarse (IVF) level.
        let (nlist, nprobe) = (256, 8);
        // Product-quantizer level.
        let (pq_m, pq_k) = (8, 256);
        Self {
            nlist,
            nprobe,
            pq_m,
            pq_k,
            max_iter: 25,
            residual: false,
        }
    }
}
impl IvfPqConfig {
    /// Picks a `pq_m` for `dim`: starts from `dim / 8` limited to `4..=96`, then
    /// lets `find_valid_pq_m` lower it until it divides `dim`.
    fn pq_m_for_dim(dim: usize) -> usize {
        let hint = (dim / 8).clamp(4, 96);
        find_valid_pq_m(hint, dim)
    }

    /// Default configuration with `pq_m` adapted to the vector dimension `dim`.
    pub fn for_dim(dim: usize) -> Self {
        Self {
            pq_m: Self::pq_m_for_dim(dim),
            ..Self::default()
        }
    }

    /// Configuration sized for a dataset of `n_vectors` vectors of dimension `dim`.
    ///
    /// `nlist` is the integer part of `sqrt(n_vectors)` limited to `16..=1024`,
    /// `nprobe` is a quarter of `nlist` (at least 1), and `pq_m` is chosen as in
    /// [`Self::for_dim`]. `pq_k`, `max_iter` and `residual` are fixed at 256, 25
    /// and `false`.
    pub fn for_dataset(dim: usize, n_vectors: usize) -> Self {
        let sqrt_n = (n_vectors as f64).sqrt() as usize;
        let nlist = sqrt_n.clamp(16, 1024);
        Self {
            nlist,
            nprobe: (nlist / 4).max(1),
            pq_m: Self::pq_m_for_dim(dim),
            pq_k: 256,
            max_iter: 25,
            residual: false,
        }
    }

    /// Returns this configuration with residual encoding switched on.
    pub fn with_residual(self) -> Self {
        Self {
            residual: true,
            ..self
        }
    }
}
/// An IVF-PQ index: every vector is assigned to its nearest coarse centroid and
/// stored as a PQ code in that centroid's inverted list.
pub struct IvfPqIndex {
/// Parameters in effect for this index; `search` reads `nprobe`, `nlist` and
/// `pq_m` from here.
pub config: IvfPqConfig,
/// Metric the index was built for; with `Cosine`, vectors and queries are
/// L2-normalized before use.
pub metric: VectorMetric,
/// Vector dimension recorded from the codebook.
pub dim: usize,
/// Coarse centroids used to route vectors and queries to inverted lists.
coarse_centroids: Vec<Vec<f32>>,
/// Product quantizer used to encode vectors and to build ADC tables at query time.
pq: PQCodebook,
/// Per inverted list: the raw `u64` row ids, in insertion order.
inv_row_ids: Vec<Vec<u64>>,
/// Per inverted list: the PQ codes of its rows concatenated in the same order as
/// `inv_row_ids`; `search` reads `config.pq_m` bytes per row.
inv_codes: Vec<Vec<u8>>,
/// Whether the stored codes encode residuals. `search` consults this field,
/// not `config.residual`.
residual: bool,
}
/// Trained quantizers for IVF-PQ, without any indexed rows.
///
/// Produced by `IvfPqIndex::train_codebook` and consumed by
/// `IvfPqIndex::build_with_codebook`.
#[derive(Clone)]
pub struct IvfPqCodebook {
/// Coarse centroids returned by k-means; each vector is routed to its nearest one.
pub coarse_centroids: Vec<Vec<f32>>,
/// Product quantizer trained on the vectors, or on their residuals when
/// `residual` is set.
pub pq: PQCodebook,
/// Number of inverted lists: the configured `nlist` clamped to the training-set size.
pub nlist: usize,
/// Default probe count: the configured `nprobe` clamped to `nlist`.
pub nprobe: usize,
/// `pq_m` value actually passed to PQ training (chosen by `find_valid_pq_m`).
pub pq_m: usize,
/// Vector dimension, taken from the first training vector.
pub dim: usize,
/// Metric the codebook was trained for; `Cosine` inputs are L2-normalized first.
pub metric: VectorMetric,
/// Whether PQ was trained on residuals (vector minus nearest coarse centroid).
pub residual: bool,
}
impl IvfPqIndex {
    /// Trains a codebook on `vectors` and builds an index over the same vectors.
    ///
    /// `row_ids[i]` is the identifier stored for `vectors[i]`.
    ///
    /// # Errors
    /// Returns `AilakeError::Catalog` if `row_ids` and `vectors` differ in length,
    /// or if [`Self::train_codebook`] or [`Self::build_with_codebook`] fails.
    pub fn train(
        row_ids: &[RowId],
        vectors: &[Vec<f32>],
        metric: VectorMetric,
        config: IvfPqConfig,
    ) -> AilakeResult<Self> {
        // k-means training is the expensive step, so reject misaligned input first.
        Self::ensure_aligned(row_ids, vectors)?;
        let codebook = Self::train_codebook(vectors, metric, &config)?;
        Self::build_with_codebook(row_ids, vectors, &codebook)
    }

    /// Trains the coarse centroids and the product quantizer on `vectors`.
    ///
    /// With `VectorMetric::Cosine` the vectors are L2-normalized first. `nlist` is
    /// clamped to the number of vectors, `nprobe` to `nlist`, `pq_m` is lowered by
    /// `find_valid_pq_m` until it divides the dimension, and `pq_k` is capped at 256.
    /// When `config.residual` is set, PQ is trained on each vector minus its nearest
    /// coarse centroid.
    ///
    /// # Errors
    /// Returns `AilakeError::Catalog` if `vectors` is empty, `config.nlist` is 0,
    /// the vectors are zero-dimensional or differ in dimension, k-means yields no
    /// centroids, or PQ training fails.
    pub fn train_codebook(
        vectors: &[Vec<f32>],
        metric: VectorMetric,
        config: &IvfPqConfig,
    ) -> AilakeResult<IvfPqCodebook> {
        let n = vectors.len();
        if n == 0 {
            return Err(AilakeError::Catalog(
                "IVF-PQ training requires at least 1 vector".into(),
            ));
        }
        if config.nlist == 0 {
            return Err(AilakeError::Catalog(
                "IVF-PQ training requires nlist >= 1".into(),
            ));
        }
        let dim = vectors[0].len();
        if dim == 0 {
            return Err(AilakeError::Catalog(
                "IVF-PQ training requires vectors with at least 1 dimension".into(),
            ));
        }
        // Distance helpers zip their inputs, so a ragged dataset would otherwise be
        // silently truncated to the shorter length instead of being reported.
        Self::ensure_uniform_dim(vectors, dim)?;

        let normed_storage: Vec<Vec<f32>>;
        let vecs: &[Vec<f32>] = if metric == VectorMetric::Cosine {
            normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
            &normed_storage
        } else {
            vectors
        };

        let nlist = config.nlist.min(n);
        if nlist < config.nlist {
            warn!(
                "ailake: IVF-PQ nlist clamped from {} to {} (n={} vectors); \
                 consider using HNSW for small datasets",
                config.nlist, nlist, n
            );
        }
        let nprobe = config.nprobe.min(nlist);
        let pq_m = find_valid_pq_m(config.pq_m, dim);
        info!(
            "ailake: training IVF-PQ codebook — n={} dim={} nlist={} nprobe={} pq_m={}",
            n, dim, nlist, nprobe, pq_m
        );

        let coarse_centroids = kmeans_dispatch(vecs, nlist, config.max_iter);
        if coarse_centroids.is_empty() {
            // Without this guard the residual path below would index an empty slice.
            return Err(AilakeError::Catalog(
                "IVF-PQ k-means produced no coarse centroids".into(),
            ));
        }

        let pq_train_vecs: Vec<Vec<f32>>;
        let pq_input: &[Vec<f32>] = if config.residual {
            pq_train_vecs = vecs
                .iter()
                .map(|v| {
                    let nearest = nearest_idx(v, &coarse_centroids);
                    Self::residual_of(v, &coarse_centroids[nearest])
                })
                .collect();
            &pq_train_vecs
        } else {
            vecs
        };

        let pq = PQCodebook::train_with_kmeans(
            pq_input,
            pq_m,
            config.pq_k.min(256),
            config.max_iter,
            kmeans_dispatch,
        )
        .map_err(|e| AilakeError::Catalog(format!("PQ training failed: {e}")))?;

        Ok(IvfPqCodebook {
            coarse_centroids,
            pq,
            nlist,
            nprobe,
            pq_m,
            dim,
            metric,
            residual: config.residual,
        })
    }

    /// Builds an index over `vectors` using an already trained `codebook`.
    ///
    /// Each vector (L2-normalized first when the codebook metric is `Cosine`) is
    /// assigned to its nearest coarse centroid and its PQ code is appended to that
    /// centroid's inverted list together with `row_ids[i].0`. The resulting
    /// `config.max_iter` is 0 because no training happens here.
    ///
    /// # Errors
    /// Returns `AilakeError::Catalog` if `vectors` is empty, `row_ids` and `vectors`
    /// differ in length, any vector's length differs from `codebook.dim`, or the
    /// codebook's centroid count is 0 or exceeds `codebook.nlist`.
    pub fn build_with_codebook(
        row_ids: &[RowId],
        vectors: &[Vec<f32>],
        codebook: &IvfPqCodebook,
    ) -> AilakeResult<Self> {
        if vectors.is_empty() {
            return Err(AilakeError::Catalog(
                "IVF-PQ build requires at least 1 vector".into(),
            ));
        }
        Self::ensure_aligned(row_ids, vectors)?;
        Self::ensure_uniform_dim(vectors, codebook.dim)?;

        let nlist = codebook.nlist;
        let n_centroids = codebook.coarse_centroids.len();
        // Every centroid index must address an inverted list, and there must be at
        // least one centroid to assign to.
        if n_centroids == 0 || n_centroids > nlist {
            return Err(AilakeError::Catalog(format!(
                "IVF-PQ codebook has {n_centroids} coarse centroids but nlist={nlist}"
            )));
        }

        let normed_storage: Vec<Vec<f32>>;
        let vecs: &[Vec<f32>] = if codebook.metric == VectorMetric::Cosine {
            normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
            &normed_storage
        } else {
            vectors
        };

        let mut inv_row_ids: Vec<Vec<u64>> = vec![Vec::new(); nlist];
        let mut inv_codes: Vec<Vec<u8>> = vec![Vec::new(); nlist];
        for (row_id, v) in row_ids.iter().zip(vecs) {
            let list_idx = nearest_idx(v, &codebook.coarse_centroids);
            let codes = if codebook.residual {
                let residual = Self::residual_of(v, &codebook.coarse_centroids[list_idx]);
                codebook.pq.encode(&residual)
            } else {
                codebook.pq.encode(v)
            };
            inv_row_ids[list_idx].push(row_id.0);
            inv_codes[list_idx].extend_from_slice(&codes);
        }

        Ok(IvfPqIndex {
            config: IvfPqConfig {
                nlist: codebook.nlist,
                nprobe: codebook.nprobe,
                pq_m: codebook.pq_m,
                pq_k: codebook.pq.num_centroids,
                max_iter: 0,
                residual: codebook.residual,
            },
            metric: codebook.metric,
            dim: codebook.dim,
            coarse_centroids: codebook.coarse_centroids.clone(),
            pq: codebook.pq.clone(),
            inv_row_ids,
            inv_codes,
            residual: codebook.residual,
        })
    }

    /// Returns up to `top_k` `(row id, approximate distance)` pairs for `query`,
    /// sorted by ascending distance.
    ///
    /// The `nprobe` inverted lists whose coarse centroids are closest to the query
    /// (squared L2) are scanned; `nprobe` defaults to `config.nprobe` and is clamped
    /// to `config.nlist`. Distances come from `PQCodebook::adc_distance`. For a
    /// residual index the ADC table is rebuilt per list from `query - centroid`;
    /// otherwise one table is shared by all lists. With `VectorMetric::Cosine` the
    /// query is L2-normalized first.
    ///
    /// NOTE(review): `query` is expected to have `self.dim` elements; this is not
    /// checked here because the signature has no error channel.
    pub fn search(
        &self,
        query: &[f32],
        top_k: usize,
        nprobe: Option<usize>,
    ) -> Vec<(RowId, f32)> {
        let nprobe = nprobe.unwrap_or(self.config.nprobe).min(self.config.nlist);
        if top_k == 0 || nprobe == 0 {
            // Nothing can be returned; skip the scan entirely.
            return Vec::new();
        }

        let q_normed: Vec<f32>;
        let q: &[f32] = if self.metric == VectorMetric::Cosine {
            q_normed = l2_normalize(query);
            &q_normed
        } else {
            query
        };

        // Rank coarse centroids by distance to the query and keep the closest `nprobe`.
        let mut c_dists: Vec<(usize, f32)> = self
            .coarse_centroids
            .iter()
            .enumerate()
            .map(|(i, c)| (i, l2_sq(q, c)))
            .collect();
        c_dists.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        c_dists.truncate(nprobe);

        // Non-residual codes are all relative to the origin, so one table serves every list.
        let global_adc = if self.residual {
            None
        } else {
            Some(self.pq.compute_adc_table(q))
        };

        let pq_m = self.config.pq_m;
        let n_candidates: usize = c_dists
            .iter()
            .map(|&(list_idx, _)| self.inv_row_ids[list_idx].len())
            .sum();
        let mut candidates: Vec<(RowId, f32)> = Vec::with_capacity(n_candidates);

        for &(list_idx, _) in &c_dists {
            let row_ids = &self.inv_row_ids[list_idx];
            let codes_flat = &self.inv_codes[list_idx];

            let cluster_adc;
            let adc_table = match &global_adc {
                Some(table) => table,
                None => {
                    // Residual codes are relative to this list's centroid.
                    let q_res = Self::residual_of(q, &self.coarse_centroids[list_idx]);
                    cluster_adc = self.pq.compute_adc_table(&q_res);
                    &cluster_adc
                }
            };

            candidates.extend(row_ids.iter().enumerate().map(|(j, &rid)| {
                let codes = &codes_flat[j * pq_m..(j + 1) * pq_m];
                (RowId(rid), self.pq.adc_distance(codes, adc_table))
            }));
        }

        candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        candidates.truncate(top_k);
        candidates
    }

    /// Total number of rows stored across all inverted lists.
    pub fn node_count(&self) -> u64 {
        self.inv_row_ids.iter().map(|l| l.len() as u64).sum()
    }

    /// Vector dimension recorded for this index.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Fails unless there is exactly one row id per vector.
    fn ensure_aligned(row_ids: &[RowId], vectors: &[Vec<f32>]) -> AilakeResult<()> {
        if row_ids.len() == vectors.len() {
            Ok(())
        } else {
            Err(AilakeError::Catalog(format!(
                "IVF-PQ got {} row ids for {} vectors",
                row_ids.len(),
                vectors.len()
            )))
        }
    }

    /// Fails, naming the first offender, unless every vector has exactly `dim` elements.
    fn ensure_uniform_dim(vectors: &[Vec<f32>], dim: usize) -> AilakeResult<()> {
        match vectors.iter().position(|v| v.len() != dim) {
            Some(i) => Err(AilakeError::Catalog(format!(
                "IVF-PQ vector {i} has dimension {} but {dim} was expected",
                vectors[i].len()
            ))),
            None => Ok(()),
        }
    }

    /// Element-wise `v - centroid`.
    fn residual_of(v: &[f32], centroid: &[f32]) -> Vec<f32> {
        v.iter().zip(centroid).map(|(a, b)| a - b).collect()
    }
}
/// Serialized form of an `IvfPqIndex`, minus the residual flag, which
/// `IvfPqSerializer::to_bytes` appends as a separate trailing byte.
///
/// NOTE(review): this struct is encoded with bincode, so the field order below is
/// presumably part of the stored layout — confirm before reordering or inserting fields.
#[derive(Serialize, Deserialize)]
struct IvfPqSnapshotCore {
// `nlist` .. `max_iter` mirror the corresponding `IvfPqConfig` fields of the index.
nlist: usize,
nprobe: usize,
pq_m: usize,
pq_k: usize,
max_iter: usize,
/// Vector dimension; also the chunk length used to split `coarse_flat` on load.
dim: usize,
/// Metric encoded as a byte by `metric_to_u8` and decoded by `u8_to_metric`.
metric: u8,
// `coarse_flat`: all coarse centroids concatenated in order, `dim` floats per centroid.
// `pq`: the product quantizer, serialized through its own serde implementation.
coarse_flat: Vec<f32>, pq: PQCodebook,
/// Row ids per inverted list.
inv_row_ids: Vec<Vec<u64>>,
// `inv_codes`: concatenated PQ codes per inverted list, parallel to `inv_row_ids`.
inv_codes: Vec<Vec<u8>>, }
/// Stateless namespace for converting an `IvfPqIndex` to and from its byte snapshot
/// (`to_bytes` / `from_bytes`).
pub struct IvfPqSerializer;
impl IvfPqSerializer {
    /// Serializes `index` into a byte buffer.
    ///
    /// The layout is the bincode encoding of `IvfPqSnapshotCore` followed by a
    /// single trailing byte holding the residual flag (0 or 1).
    ///
    /// # Errors
    /// Returns `AilakeError::Bincode` if bincode encoding fails.
    pub fn to_bytes(index: &IvfPqIndex) -> AilakeResult<Vec<u8>> {
        // Centroids are stored as one flat array; `dim` lets the reader split it again.
        let coarse_flat: Vec<f32> = index
            .coarse_centroids
            .iter()
            .flatten()
            .copied()
            .collect();
        let core = IvfPqSnapshotCore {
            nlist: index.config.nlist,
            nprobe: index.config.nprobe,
            pq_m: index.config.pq_m,
            pq_k: index.config.pq_k,
            max_iter: index.config.max_iter,
            dim: index.dim,
            metric: metric_to_u8(index.metric),
            coarse_flat,
            pq: index.pq.clone(),
            inv_row_ids: index.inv_row_ids.clone(),
            inv_codes: index.inv_codes.clone(),
        };
        let mut bytes =
            bincode::serialize(&core).map_err(|e| AilakeError::Bincode(e.to_string()))?;
        bytes.push(u8::from(index.residual));
        Ok(bytes)
    }

    /// Reconstructs an index from bytes produced by [`Self::to_bytes`].
    ///
    /// A snapshot with no byte after the bincode payload is accepted and treated
    /// as non-residual.
    ///
    /// # Errors
    /// Returns `AilakeError::Bincode` if the payload cannot be decoded, and
    /// `AilakeError::Catalog` if the metric byte is unknown or the decoded data is
    /// internally inconsistent: zero `dim`, a centroid array that is not a multiple
    /// of `dim`, mismatched inverted-list counts, more centroids than lists, or a
    /// list whose code length is not `rows * pq_m`. Without these checks such
    /// snapshots would panic here or later inside `search`.
    pub fn from_bytes(bytes: &[u8]) -> AilakeResult<IvfPqIndex> {
        let corrupt =
            |detail: String| AilakeError::Catalog(format!("corrupt IVF-PQ snapshot: {detail}"));

        let mut cursor = std::io::Cursor::new(bytes);
        let core: IvfPqSnapshotCore = bincode::deserialize_from(&mut cursor)
            .map_err(|e| AilakeError::Bincode(e.to_string()))?;

        // The residual flag, if present, is the first byte after the bincode payload.
        let residual = usize::try_from(cursor.position())
            .ok()
            .and_then(|pos| bytes.get(pos))
            .is_some_and(|&flag| flag != 0);

        let metric = u8_to_metric(core.metric)?;

        // `chunks_exact` panics on a zero chunk size, so reject that case explicitly.
        if core.dim == 0 {
            return Err(corrupt("dim is 0".to_string()));
        }
        let chunks = core.coarse_flat.chunks_exact(core.dim);
        if !chunks.remainder().is_empty() {
            return Err(corrupt(format!(
                "{} centroid values is not a multiple of dim={}",
                core.coarse_flat.len(),
                core.dim
            )));
        }
        let coarse_centroids: Vec<Vec<f32>> = chunks.map(|c| c.to_vec()).collect();

        if core.inv_row_ids.len() != core.inv_codes.len() {
            return Err(corrupt(format!(
                "{} row-id lists but {} code lists",
                core.inv_row_ids.len(),
                core.inv_codes.len()
            )));
        }
        // `search` indexes the inverted lists by centroid index.
        if coarse_centroids.len() > core.inv_row_ids.len() {
            return Err(corrupt(format!(
                "{} coarse centroids but only {} inverted lists",
                coarse_centroids.len(),
                core.inv_row_ids.len()
            )));
        }
        // `search` slices `pq_m` code bytes per row, so each list must match exactly.
        for (list_idx, (rows, codes)) in
            core.inv_row_ids.iter().zip(&core.inv_codes).enumerate()
        {
            if rows.len().checked_mul(core.pq_m) != Some(codes.len()) {
                return Err(corrupt(format!(
                    "inverted list {list_idx} holds {} code bytes for {} rows (pq_m={})",
                    codes.len(),
                    rows.len(),
                    core.pq_m
                )));
            }
        }

        Ok(IvfPqIndex {
            config: IvfPqConfig {
                nlist: core.nlist,
                nprobe: core.nprobe,
                pq_m: core.pq_m,
                pq_k: core.pq_k,
                max_iter: core.max_iter,
                residual,
            },
            metric,
            dim: core.dim,
            coarse_centroids,
            pq: core.pq,
            inv_row_ids: core.inv_row_ids,
            inv_codes: core.inv_codes,
            residual,
        })
    }
}
/// Returns `v` scaled to unit L2 length.
///
/// Vectors whose norm is below `1e-9` (including the zero vector and the empty
/// slice) are returned as an unchanged copy to avoid dividing by ~0.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let squared_sum: f32 = v.iter().map(|x| x * x).sum();
    let norm = squared_sum.sqrt();
    if norm < 1e-9 {
        return v.to_vec();
    }
    let mut unit = Vec::with_capacity(v.len());
    for &component in v {
        unit.push(component / norm);
    }
    unit
}
/// Squared Euclidean distance between `a` and `b`.
///
/// Elements are paired positionally; if the slices differ in length, the extra
/// elements of the longer one are ignored.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let diff = lhs - rhs;
            diff.powi(2)
        })
        .sum()
}
/// Index of the centroid closest to `v` by squared L2 distance.
///
/// Distances are ordered with `f32::total_cmp`; on ties the earliest centroid
/// wins. Returns 0 when `centroids` is empty.
fn nearest_idx(v: &[f32], centroids: &[Vec<f32>]) -> usize {
    let mut best: Option<(usize, f32)> = None;
    for (idx, centroid) in centroids.iter().enumerate() {
        let dist = l2_sq(v, centroid);
        // Replace only on a strictly smaller distance so the first minimum is kept.
        let improves = match best {
            Some((_, best_dist)) => dist.total_cmp(&best_dist).is_lt(),
            None => true,
        };
        if improves {
            best = Some((idx, dist));
        }
    }
    best.map_or(0, |(idx, _)| idx)
}
/// Largest `m` in `1..=requested` that evenly divides `dim`.
///
/// Returns 1 when `requested` is 0 (the range is then empty).
pub fn find_valid_pq_m(requested: usize, dim: usize) -> usize {
    (1..=requested)
        .rev()
        .find(|&m| dim.is_multiple_of(m))
        .unwrap_or(1)
}
/// Encodes a metric as the single byte stored in index snapshots.
///
/// The codes must stay in sync with `u8_to_metric`, which performs the reverse mapping.
fn metric_to_u8(m: VectorMetric) -> u8 {
    const COSINE: u8 = 0;
    const EUCLIDEAN: u8 = 1;
    const DOT_PRODUCT: u8 = 2;
    const NORMALIZED_COSINE: u8 = 3;

    match m {
        VectorMetric::Cosine => COSINE,
        VectorMetric::Euclidean => EUCLIDEAN,
        VectorMetric::DotProduct => DOT_PRODUCT,
        VectorMetric::NormalizedCosine => NORMALIZED_COSINE,
    }
}
/// Decodes a snapshot metric byte (as written by `metric_to_u8`).
///
/// # Errors
/// Returns `AilakeError::Catalog` for any byte other than 0 through 3.
fn u8_to_metric(v: u8) -> AilakeResult<VectorMetric> {
    let metric = match v {
        0 => VectorMetric::Cosine,
        1 => VectorMetric::Euclidean,
        2 => VectorMetric::DotProduct,
        3 => VectorMetric::NormalizedCosine,
        _ => return Err(AilakeError::Catalog(format!("unknown metric byte: {v}"))),
    };
    Ok(metric)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds row ids `0..n` and `n` one-hot vectors of length `dim`; vector `i`
    /// has its single `1.0` at position `i % dim`.
    fn make_vecs(n: usize, dim: usize) -> (Vec<RowId>, Vec<Vec<f32>>) {
        let mut row_ids = Vec::with_capacity(n);
        let mut vecs = Vec::with_capacity(n);
        for i in 0..n {
            row_ids.push(RowId(i as u64));
            let mut one_hot = vec![0.0f32; dim];
            one_hot[i % dim] = 1.0;
            vecs.push(one_hot);
        }
        (row_ids, vecs)
    }

    /// Small test configuration; every test uses `pq_m = 2` and `pq_k = 4`.
    fn tiny_config(nlist: usize, nprobe: usize, max_iter: usize, residual: bool) -> IvfPqConfig {
        IvfPqConfig {
            nlist,
            nprobe,
            pq_m: 2,
            pq_k: 4,
            max_iter,
            residual,
        }
    }

    /// Trains an index on `make_vecs(n, dim)` and returns it with the vectors used.
    fn build_index(
        n: usize,
        dim: usize,
        metric: VectorMetric,
        config: IvfPqConfig,
    ) -> (IvfPqIndex, Vec<Vec<f32>>) {
        let (ids, vecs) = make_vecs(n, dim);
        let idx = IvfPqIndex::train(&ids, &vecs, metric, config).unwrap();
        (idx, vecs)
    }

    /// Serializes `idx` and loads it back.
    fn roundtrip(idx: &IvfPqIndex) -> IvfPqIndex {
        let bytes = IvfPqSerializer::to_bytes(idx).unwrap();
        IvfPqSerializer::from_bytes(&bytes).unwrap()
    }

    /// Asserts that both indexes return the same row ids, in order, for a top-5 search.
    fn assert_same_top5(before: &IvfPqIndex, after: &IvfPqIndex, query: &[f32]) {
        let r1 = before.search(query, 5, None);
        let r2 = after.search(query, 5, None);
        assert_eq!(r1.len(), r2.len());
        for (a, b) in r1.iter().zip(r2.iter()) {
            assert_eq!(a.0, b.0, "row_ids should match after roundtrip");
        }
    }

    #[test]
    fn train_and_search_basic() {
        let (idx, vecs) =
            build_index(64, 8, VectorMetric::Euclidean, tiny_config(4, 2, 10, false));
        assert_eq!(idx.node_count(), 64);

        let query = vecs[0].clone();
        let results = idx.search(&query, 5, None);
        assert!(!results.is_empty());
        assert!(results[0].1 < 0.1, "nearest should be approximate self");
    }

    #[test]
    fn train_cosine_normalizes() {
        let (idx, vecs) = build_index(32, 4, VectorMetric::Cosine, tiny_config(4, 2, 10, false));
        let results = idx.search(&vecs[0], 1, None);
        assert!(!results.is_empty());
    }

    #[test]
    fn serialize_roundtrip() {
        let (idx, vecs) =
            build_index(32, 8, VectorMetric::Euclidean, tiny_config(4, 2, 10, false));
        let idx2 = roundtrip(&idx);
        assert_eq!(idx2.node_count(), idx.node_count());
        assert_eq!(idx2.dim(), idx.dim());

        let q = vecs[0].clone();
        assert_same_top5(&idx, &idx2, &q);
    }

    #[test]
    fn nlist_clamped_to_n() {
        // Only 10 vectors for a requested nlist of 256.
        let (idx, _vecs) =
            build_index(10, 4, VectorMetric::Euclidean, tiny_config(256, 8, 5, false));
        assert!(idx.config.nlist <= 10);
        assert_eq!(idx.node_count(), 10);
    }

    #[test]
    fn residual_pq_search_finds_nearest() {
        let (idx, vecs) =
            build_index(64, 8, VectorMetric::Euclidean, tiny_config(4, 4, 10, true));
        assert_eq!(idx.node_count(), 64);
        assert!(idx.residual);

        let query = vecs[0].clone();
        let results = idx.search(&query, 5, None);
        assert!(!results.is_empty());
        assert!(
            results[0].1 < 0.1,
            "nearest residual-PQ result should be close to query"
        );
    }

    #[test]
    fn residual_pq_serialize_roundtrip() {
        let (idx, vecs) =
            build_index(32, 8, VectorMetric::Euclidean, tiny_config(4, 2, 10, true));
        let idx2 = roundtrip(&idx);
        assert_eq!(idx2.node_count(), idx.node_count());
        assert!(idx2.residual, "residual flag must survive roundtrip");

        let q = vecs[0].clone();
        assert_same_top5(&idx, &idx2, &q);
    }

    #[test]
    fn non_residual_snapshot_deserializes_as_false() {
        let (idx, _vecs) =
            build_index(16, 8, VectorMetric::Euclidean, tiny_config(2, 1, 5, false));
        let idx2 = roundtrip(&idx);
        assert!(
            !idx2.residual,
            "non-residual index must deserialize as residual=false"
        );
    }
}