1use serde::{Deserialize, Serialize};
14use tracing::{debug, info, warn};
15
16use ailake_core::{AilakeError, AilakeResult, RowId, VectorMetric};
17use ailake_vec::{kmeans_centroids, PQCodebook};
18
19fn kmeans_dispatch(vecs: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
21 if let Some(result) = crate::gpu::try_nvidia_kmeans(vecs, k, max_iter) {
22 debug!(
23 "ailake: IVF-PQ k-means used NVIDIA CUDA (n={} k={} max_iter={})",
24 vecs.len(),
25 k,
26 max_iter
27 );
28 return result;
29 }
30 if let Some(result) = crate::gpu::try_rocm_kmeans(vecs, k, max_iter) {
31 debug!(
32 "ailake: IVF-PQ k-means used AMD ROCm (n={} k={} max_iter={})",
33 vecs.len(),
34 k,
35 max_iter
36 );
37 return result;
38 }
39 debug!(
40 "ailake: IVF-PQ k-means using CPU rayon (n={} k={} max_iter={})",
41 vecs.len(),
42 k,
43 max_iter
44 );
45 kmeans_centroids(vecs, k, max_iter)
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct IvfPqConfig {
51 pub nlist: usize,
54 pub nprobe: usize,
57 pub pq_m: usize,
59 pub pq_k: usize,
61 pub max_iter: usize,
63}
64
65impl Default for IvfPqConfig {
66 fn default() -> Self {
67 Self {
68 nlist: 256,
69 nprobe: 8,
70 pq_m: 8,
71 pq_k: 256,
72 max_iter: 25,
73 }
74 }
75}
76
77impl IvfPqConfig {
78 pub fn for_dim(dim: usize) -> Self {
80 let pq_m = (dim / 16).clamp(4, 64);
81 Self {
82 pq_m: find_valid_pq_m(pq_m, dim),
83 ..Self::default()
84 }
85 }
86
87 pub fn for_dataset(dim: usize, n_vectors: usize) -> Self {
92 let nlist = ((n_vectors as f64).sqrt() as usize).clamp(16, 1024);
93 let nprobe = (nlist / 4).max(1); let pq_m_hint = (dim / 16).clamp(4, 64);
95 Self {
96 nlist,
97 nprobe,
98 pq_m: find_valid_pq_m(pq_m_hint, dim),
99 pq_k: 256,
100 max_iter: 25,
101 }
102 }
103}
104
105pub struct IvfPqIndex {
106 pub config: IvfPqConfig,
107 pub metric: VectorMetric,
108 pub dim: usize,
109 coarse_centroids: Vec<Vec<f32>>,
111 pq: PQCodebook,
113 inv_row_ids: Vec<Vec<u64>>,
115 inv_codes: Vec<Vec<u8>>,
117}
118
119#[derive(Clone)]
123pub struct IvfPqCodebook {
124 pub coarse_centroids: Vec<Vec<f32>>,
125 pub pq: PQCodebook,
126 pub nlist: usize,
127 pub nprobe: usize,
128 pub pq_m: usize,
129 pub dim: usize,
130 pub metric: VectorMetric,
131}
132
133impl IvfPqIndex {
134 pub fn train(
136 row_ids: &[RowId],
137 vectors: &[Vec<f32>],
138 metric: VectorMetric,
139 config: IvfPqConfig,
140 ) -> AilakeResult<Self> {
141 let codebook = Self::train_codebook(vectors, metric, &config)?;
142 Self::build_with_codebook(row_ids, vectors, &codebook)
143 }
144
145 pub fn train_codebook(
148 vectors: &[Vec<f32>],
149 metric: VectorMetric,
150 config: &IvfPqConfig,
151 ) -> AilakeResult<IvfPqCodebook> {
152 let n = vectors.len();
153 if n == 0 {
154 return Err(AilakeError::Catalog(
155 "IVF-PQ training requires at least 1 vector".into(),
156 ));
157 }
158 let dim = vectors[0].len();
159
160 let normed_storage: Vec<Vec<f32>>;
161 let vecs: &[Vec<f32>] = if metric == VectorMetric::Cosine {
162 normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
163 &normed_storage
164 } else {
165 vectors
166 };
167
168 let nlist = config.nlist.min(n);
169 if nlist < config.nlist {
170 warn!(
171 "ailake: IVF-PQ nlist clamped from {} to {} (n={} vectors); \
172 consider using HNSW for small datasets",
173 config.nlist, nlist, n
174 );
175 }
176 let nprobe = config.nprobe.min(nlist);
177 let pq_m = find_valid_pq_m(config.pq_m, dim);
178
179 info!(
180 "ailake: training IVF-PQ codebook — n={} dim={} nlist={} nprobe={} pq_m={}",
181 n, dim, nlist, nprobe, pq_m
182 );
183
184 let coarse_centroids = kmeans_dispatch(vecs, nlist, config.max_iter);
185 let pq = PQCodebook::train_with_kmeans(
186 vecs,
187 pq_m,
188 config.pq_k.min(256),
189 config.max_iter,
190 kmeans_dispatch,
191 )
192 .map_err(|e| AilakeError::Catalog(format!("PQ training failed: {e}")))?;
193
194 Ok(IvfPqCodebook {
195 coarse_centroids,
196 pq,
197 nlist,
198 nprobe,
199 pq_m,
200 dim,
201 metric,
202 })
203 }
204
205 pub fn build_with_codebook(
208 row_ids: &[RowId],
209 vectors: &[Vec<f32>],
210 codebook: &IvfPqCodebook,
211 ) -> AilakeResult<Self> {
212 let n = vectors.len();
213 if n == 0 {
214 return Err(AilakeError::Catalog(
215 "IVF-PQ build requires at least 1 vector".into(),
216 ));
217 }
218
219 let normed_storage: Vec<Vec<f32>>;
220 let vecs: &[Vec<f32>] = if codebook.metric == VectorMetric::Cosine {
221 normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
222 &normed_storage
223 } else {
224 vectors
225 };
226
227 let nlist = codebook.nlist;
228 let assignments: Vec<usize> = vecs
229 .iter()
230 .map(|v| nearest_idx(v, &codebook.coarse_centroids))
231 .collect();
232
233 let mut inv_row_ids = vec![Vec::new(); nlist];
234 let mut inv_codes = vec![Vec::new(); nlist];
235
236 for (i, (v, &list_idx)) in vecs.iter().zip(assignments.iter()).enumerate() {
237 let codes = codebook.pq.encode(v);
238 inv_row_ids[list_idx].push(row_ids[i].0);
239 inv_codes[list_idx].extend_from_slice(&codes);
240 }
241
242 Ok(IvfPqIndex {
243 config: IvfPqConfig {
244 nlist: codebook.nlist,
245 nprobe: codebook.nprobe,
246 pq_m: codebook.pq_m,
247 pq_k: codebook.pq.num_centroids,
248 max_iter: 0,
249 },
250 metric: codebook.metric,
251 dim: codebook.dim,
252 coarse_centroids: codebook.coarse_centroids.clone(),
253 pq: codebook.pq.clone(),
254 inv_row_ids,
255 inv_codes,
256 })
257 }
258
259 pub fn search(&self, query: &[f32], top_k: usize, nprobe: Option<usize>) -> Vec<(RowId, f32)> {
263 let nprobe = nprobe.unwrap_or(self.config.nprobe).min(self.config.nlist);
264
265 let q_normed: Vec<f32>;
266 let q: &[f32] = if self.metric == VectorMetric::Cosine {
267 q_normed = l2_normalize(query);
268 &q_normed
269 } else {
270 query
271 };
272
273 let mut c_dists: Vec<(usize, f32)> = self
275 .coarse_centroids
276 .iter()
277 .enumerate()
278 .map(|(i, c)| (i, l2_sq(q, c)))
279 .collect();
280 c_dists.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
281 c_dists.truncate(nprobe);
282
283 let adc_table = self.pq.compute_adc_table(q);
285
286 let pq_m = self.config.pq_m;
288 let mut candidates: Vec<(RowId, f32)> = Vec::new();
289
290 for (list_idx, _) in &c_dists {
291 let row_ids = &self.inv_row_ids[*list_idx];
292 let codes_flat = &self.inv_codes[*list_idx];
293
294 for (j, &rid) in row_ids.iter().enumerate() {
295 let codes = &codes_flat[j * pq_m..(j + 1) * pq_m];
296 let dist = self.pq.adc_distance(codes, &adc_table);
297 candidates.push((RowId(rid), dist));
298 }
299 }
300
301 candidates.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
302 candidates.truncate(top_k);
303 candidates
304 }
305
306 pub fn node_count(&self) -> u64 {
307 self.inv_row_ids.iter().map(|l| l.len() as u64).sum()
308 }
309
310 pub fn dim(&self) -> usize {
311 self.dim
312 }
313}
314
315#[derive(Serialize, Deserialize)]
318struct IvfPqSnapshot {
319 nlist: usize,
320 nprobe: usize,
321 pq_m: usize,
322 pq_k: usize,
323 max_iter: usize,
324 dim: usize,
325 metric: u8,
326 coarse_flat: Vec<f32>, pq: PQCodebook,
328 inv_row_ids: Vec<Vec<u64>>,
329 inv_codes: Vec<Vec<u8>>, }
331
332pub struct IvfPqSerializer;
333
334impl IvfPqSerializer {
335 pub fn to_bytes(index: &IvfPqIndex) -> AilakeResult<Vec<u8>> {
336 let coarse_flat: Vec<f32> = index
337 .coarse_centroids
338 .iter()
339 .flat_map(|c| c.iter().copied())
340 .collect();
341 let snap = IvfPqSnapshot {
342 nlist: index.config.nlist,
343 nprobe: index.config.nprobe,
344 pq_m: index.config.pq_m,
345 pq_k: index.config.pq_k,
346 max_iter: index.config.max_iter,
347 dim: index.dim,
348 metric: metric_to_u8(index.metric),
349 coarse_flat,
350 pq: index.pq.clone(),
351 inv_row_ids: index.inv_row_ids.clone(),
352 inv_codes: index.inv_codes.clone(),
353 };
354 bincode::serialize(&snap).map_err(|e| AilakeError::Bincode(e.to_string()))
355 }
356
357 pub fn from_bytes(bytes: &[u8]) -> AilakeResult<IvfPqIndex> {
358 let snap: IvfPqSnapshot =
359 bincode::deserialize(bytes).map_err(|e| AilakeError::Bincode(e.to_string()))?;
360 let metric = u8_to_metric(snap.metric)?;
361 let coarse_centroids: Vec<Vec<f32>> = snap
362 .coarse_flat
363 .chunks_exact(snap.dim)
364 .map(|c| c.to_vec())
365 .collect();
366 Ok(IvfPqIndex {
367 config: IvfPqConfig {
368 nlist: snap.nlist,
369 nprobe: snap.nprobe,
370 pq_m: snap.pq_m,
371 pq_k: snap.pq_k,
372 max_iter: snap.max_iter,
373 },
374 metric,
375 dim: snap.dim,
376 coarse_centroids,
377 pq: snap.pq,
378 inv_row_ids: snap.inv_row_ids,
379 inv_codes: snap.inv_codes,
380 })
381 }
382}
383
/// Returns `v` scaled to unit Euclidean length.
///
/// Vectors whose length is below `1e-9` are returned unchanged so that a (near-)zero
/// vector is never divided by its own length.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let magnitude = v
        .iter()
        .map(|component| component * component)
        .sum::<f32>()
        .sqrt();

    if magnitude < 1e-9 {
        return v.to_vec();
    }
    v.iter().map(|component| component / magnitude).collect()
}
394
/// Squared Euclidean distance between `a` and `b`.
///
/// Components are paired positionally; when the slices differ in length the surplus
/// tail of the longer one is not compared.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    let squared_gaps = a.iter().zip(b).map(|(lhs, rhs)| {
        let gap = lhs - rhs;
        gap.powi(2)
    });
    squared_gaps.sum()
}
398
399fn nearest_idx(v: &[f32], centroids: &[Vec<f32>]) -> usize {
400 centroids
401 .iter()
402 .enumerate()
403 .map(|(i, c)| (i, l2_sq(v, c)))
404 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
405 .map(|(i, _)| i)
406 .unwrap_or(0)
407}
408
/// Largest `m` in `1..=requested` that divides `dim` evenly.
///
/// Returns 1 when `requested` is 0, since the candidate range is then empty.
pub fn find_valid_pq_m(requested: usize, dim: usize) -> usize {
    (1..=requested)
        .rev()
        .find(|&candidate| dim.is_multiple_of(candidate))
        .unwrap_or(1)
}
418
419fn metric_to_u8(m: VectorMetric) -> u8 {
420 match m {
421 VectorMetric::Cosine => 0,
422 VectorMetric::Euclidean => 1,
423 VectorMetric::DotProduct => 2,
424 VectorMetric::NormalizedCosine => 3,
425 }
426}
427
428fn u8_to_metric(v: u8) -> AilakeResult<VectorMetric> {
429 match v {
430 0 => Ok(VectorMetric::Cosine),
431 1 => Ok(VectorMetric::Euclidean),
432 2 => Ok(VectorMetric::DotProduct),
433 3 => Ok(VectorMetric::NormalizedCosine),
434 _ => Err(AilakeError::Catalog(format!("unknown metric byte: {v}"))),
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces `n` one-hot vectors of width `dim` (the hot slot cycles through the
    /// dimensions) together with sequential row ids `0..n`.
    fn make_vecs(n: usize, dim: usize) -> (Vec<RowId>, Vec<Vec<f32>>) {
        let mut row_ids = Vec::with_capacity(n);
        let mut vecs = Vec::with_capacity(n);
        for i in 0..n {
            row_ids.push(RowId(i as u64));
            let mut one_hot = vec![0.0f32; dim];
            one_hot[i % dim] = 1.0;
            vecs.push(one_hot);
        }
        (row_ids, vecs)
    }

    /// Small configuration shared by the tests that train on a few dozen vectors.
    fn tiny_config() -> IvfPqConfig {
        IvfPqConfig {
            nlist: 4,
            nprobe: 2,
            pq_m: 2,
            pq_k: 4,
            max_iter: 10,
        }
    }

    /// Training indexes every row and a stored vector finds a near-zero match.
    #[test]
    fn train_and_search_basic() {
        let (ids, vecs) = make_vecs(64, 8);
        let idx =
            IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, tiny_config()).unwrap();
        assert_eq!(idx.node_count(), 64);

        let results = idx.search(&vecs[0], 5, None);
        assert!(!results.is_empty());
        assert!(results[0].1 < 0.1, "nearest should be approximate self");
    }

    /// The cosine path trains and answers a query.
    #[test]
    fn train_cosine_normalizes() {
        let (ids, vecs) = make_vecs(32, 4);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Cosine, tiny_config()).unwrap();

        let results = idx.search(&vecs[0], 1, None);
        assert!(!results.is_empty());
    }

    /// A serialized and restored index reports the same size and search results.
    #[test]
    fn serialize_roundtrip() {
        let (ids, vecs) = make_vecs(32, 8);
        let original =
            IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, tiny_config()).unwrap();

        let bytes = IvfPqSerializer::to_bytes(&original).unwrap();
        let restored = IvfPqSerializer::from_bytes(&bytes).unwrap();

        assert_eq!(restored.node_count(), original.node_count());
        assert_eq!(restored.dim(), original.dim());

        let before = original.search(&vecs[0], 5, None);
        let after = restored.search(&vecs[0], 5, None);
        assert_eq!(before.len(), after.len());
        for (a, b) in before.iter().zip(after.iter()) {
            assert_eq!(a.0, b.0, "row_ids should match after roundtrip");
        }
    }

    /// Asking for more lists than there are vectors lowers `nlist` to fit.
    #[test]
    fn nlist_clamped_to_n() {
        let (ids, vecs) = make_vecs(10, 4);
        let oversized = IvfPqConfig {
            nlist: 256,
            nprobe: 8,
            pq_m: 2,
            pq_k: 4,
            max_iter: 5,
        };
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, oversized).unwrap();
        assert!(idx.config.nlist <= 10);
        assert_eq!(idx.node_count(), 10);
    }
}