1use serde::{Deserialize, Serialize};
18use tracing::{debug, info, warn};
19
20use ailake_core::{AilakeError, AilakeResult, RowId, VectorMetric};
21use ailake_vec::{kmeans_centroids, PQCodebook};
22
23fn kmeans_dispatch(vecs: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
25 if let Some(result) = crate::gpu::try_nvidia_kmeans(vecs, k, max_iter) {
26 debug!(
27 "ailake: IVF-PQ k-means used NVIDIA CUDA (n={} k={} max_iter={})",
28 vecs.len(),
29 k,
30 max_iter
31 );
32 return result;
33 }
34 if let Some(result) = crate::gpu::try_rocm_kmeans(vecs, k, max_iter) {
35 debug!(
36 "ailake: IVF-PQ k-means used AMD ROCm (n={} k={} max_iter={})",
37 vecs.len(),
38 k,
39 max_iter
40 );
41 return result;
42 }
43 debug!(
44 "ailake: IVF-PQ k-means using CPU rayon (n={} k={} max_iter={})",
45 vecs.len(),
46 k,
47 max_iter
48 );
49 kmeans_centroids(vecs, k, max_iter)
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct IvfPqConfig {
55 pub nlist: usize,
58 pub nprobe: usize,
61 pub pq_m: usize,
63 pub pq_k: usize,
65 pub max_iter: usize,
67 #[serde(default)]
70 pub residual: bool,
71}
72
73impl Default for IvfPqConfig {
74 fn default() -> Self {
75 Self {
76 nlist: 256,
77 nprobe: 8,
78 pq_m: 8,
79 pq_k: 256,
80 max_iter: 25,
81 residual: false,
82 }
83 }
84}
85
86impl IvfPqConfig {
87 pub fn for_dim(dim: usize) -> Self {
89 let pq_m = (dim / 8).clamp(4, 96);
90 Self {
91 pq_m: find_valid_pq_m(pq_m, dim),
92 ..Self::default()
93 }
94 }
95
96 pub fn for_dataset(dim: usize, n_vectors: usize) -> Self {
101 let nlist = ((n_vectors as f64).sqrt() as usize).clamp(16, 1024);
102 let nprobe = (nlist / 4).max(1); let pq_m_hint = (dim / 8).clamp(4, 96);
104 Self {
105 nlist,
106 nprobe,
107 pq_m: find_valid_pq_m(pq_m_hint, dim),
108 pq_k: 256,
109 max_iter: 25,
110 residual: false,
111 }
112 }
113
114 pub fn with_residual(mut self) -> Self {
116 self.residual = true;
117 self
118 }
119}
120
121pub struct IvfPqIndex {
122 pub config: IvfPqConfig,
123 pub metric: VectorMetric,
124 pub dim: usize,
125 coarse_centroids: Vec<Vec<f32>>,
127 pq: PQCodebook,
129 inv_row_ids: Vec<Vec<u64>>,
131 inv_codes: Vec<Vec<u8>>,
133 residual: bool,
135}
136
137#[derive(Clone)]
141pub struct IvfPqCodebook {
142 pub coarse_centroids: Vec<Vec<f32>>,
143 pub pq: PQCodebook,
144 pub nlist: usize,
145 pub nprobe: usize,
146 pub pq_m: usize,
147 pub dim: usize,
148 pub metric: VectorMetric,
149 pub residual: bool,
150}
151
152impl IvfPqIndex {
153 pub fn train(
155 row_ids: &[RowId],
156 vectors: &[Vec<f32>],
157 metric: VectorMetric,
158 config: IvfPqConfig,
159 ) -> AilakeResult<Self> {
160 let codebook = Self::train_codebook(vectors, metric, &config)?;
161 Self::build_with_codebook(row_ids, vectors, &codebook)
162 }
163
164 pub fn train_codebook(
167 vectors: &[Vec<f32>],
168 metric: VectorMetric,
169 config: &IvfPqConfig,
170 ) -> AilakeResult<IvfPqCodebook> {
171 let n = vectors.len();
172 if n == 0 {
173 return Err(AilakeError::Catalog(
174 "IVF-PQ training requires at least 1 vector".into(),
175 ));
176 }
177 let dim = vectors[0].len();
178
179 let normed_storage: Vec<Vec<f32>>;
180 let vecs: &[Vec<f32>] = if metric == VectorMetric::Cosine {
181 normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
182 &normed_storage
183 } else {
184 vectors
185 };
186
187 let nlist = config.nlist.min(n);
188 if nlist < config.nlist {
189 warn!(
190 "ailake: IVF-PQ nlist clamped from {} to {} (n={} vectors); \
191 consider using HNSW for small datasets",
192 config.nlist, nlist, n
193 );
194 }
195 let nprobe = config.nprobe.min(nlist);
196 let pq_m = find_valid_pq_m(config.pq_m, dim);
197
198 info!(
199 "ailake: training IVF-PQ codebook — n={} dim={} nlist={} nprobe={} pq_m={}",
200 n, dim, nlist, nprobe, pq_m
201 );
202
203 let coarse_centroids = kmeans_dispatch(vecs, nlist, config.max_iter);
204
205 let pq_train_vecs: Vec<Vec<f32>>;
206 let pq_input: &[Vec<f32>] = if config.residual {
207 let assignments: Vec<usize> = vecs
210 .iter()
211 .map(|v| nearest_idx(v, &coarse_centroids))
212 .collect();
213 pq_train_vecs = vecs
214 .iter()
215 .zip(assignments.iter())
216 .map(|(v, &c)| {
217 v.iter()
218 .zip(coarse_centroids[c].iter())
219 .map(|(a, b)| a - b)
220 .collect()
221 })
222 .collect();
223 &pq_train_vecs
224 } else {
225 vecs
226 };
227
228 let pq = PQCodebook::train_with_kmeans(
229 pq_input,
230 pq_m,
231 config.pq_k.min(256),
232 config.max_iter,
233 kmeans_dispatch,
234 )
235 .map_err(|e| AilakeError::Catalog(format!("PQ training failed: {e}")))?;
236
237 Ok(IvfPqCodebook {
238 coarse_centroids,
239 pq,
240 nlist,
241 nprobe,
242 pq_m,
243 dim,
244 metric,
245 residual: config.residual,
246 })
247 }
248
249 pub fn build_with_codebook(
252 row_ids: &[RowId],
253 vectors: &[Vec<f32>],
254 codebook: &IvfPqCodebook,
255 ) -> AilakeResult<Self> {
256 let n = vectors.len();
257 if n == 0 {
258 return Err(AilakeError::Catalog(
259 "IVF-PQ build requires at least 1 vector".into(),
260 ));
261 }
262
263 let normed_storage: Vec<Vec<f32>>;
264 let vecs: &[Vec<f32>] = if codebook.metric == VectorMetric::Cosine {
265 normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
266 &normed_storage
267 } else {
268 vectors
269 };
270
271 let nlist = codebook.nlist;
272 let assignments: Vec<usize> = vecs
273 .iter()
274 .map(|v| nearest_idx(v, &codebook.coarse_centroids))
275 .collect();
276
277 let mut inv_row_ids = vec![Vec::new(); nlist];
278 let mut inv_codes = vec![Vec::new(); nlist];
279
280 for (i, (v, &list_idx)) in vecs.iter().zip(assignments.iter()).enumerate() {
281 let codes = if codebook.residual {
282 let centroid = &codebook.coarse_centroids[list_idx];
283 let residual: Vec<f32> =
284 v.iter().zip(centroid.iter()).map(|(a, b)| a - b).collect();
285 codebook.pq.encode(&residual)
286 } else {
287 codebook.pq.encode(v)
288 };
289 inv_row_ids[list_idx].push(row_ids[i].0);
290 inv_codes[list_idx].extend_from_slice(&codes);
291 }
292
293 Ok(IvfPqIndex {
294 config: IvfPqConfig {
295 nlist: codebook.nlist,
296 nprobe: codebook.nprobe,
297 pq_m: codebook.pq_m,
298 pq_k: codebook.pq.num_centroids,
299 max_iter: 0,
300 residual: codebook.residual,
301 },
302 metric: codebook.metric,
303 dim: codebook.dim,
304 coarse_centroids: codebook.coarse_centroids.clone(),
305 pq: codebook.pq.clone(),
306 inv_row_ids,
307 inv_codes,
308 residual: codebook.residual,
309 })
310 }
311
312 pub fn search(&self, query: &[f32], top_k: usize, nprobe: Option<usize>) -> Vec<(RowId, f32)> {
316 let nprobe = nprobe.unwrap_or(self.config.nprobe).min(self.config.nlist);
317
318 let q_normed: Vec<f32>;
319 let q: &[f32] = if self.metric == VectorMetric::Cosine {
320 q_normed = l2_normalize(query);
321 &q_normed
322 } else {
323 query
324 };
325
326 let mut c_dists: Vec<(usize, f32)> = self
328 .coarse_centroids
329 .iter()
330 .enumerate()
331 .map(|(i, c)| (i, l2_sq(q, c)))
332 .collect();
333 c_dists.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
334 c_dists.truncate(nprobe);
335
336 let global_adc = if !self.residual {
339 Some(self.pq.compute_adc_table(q))
340 } else {
341 None
342 };
343
344 let pq_m = self.config.pq_m;
346 let mut candidates: Vec<(RowId, f32)> = Vec::new();
347
348 for (list_idx, _) in &c_dists {
349 let row_ids = &self.inv_row_ids[*list_idx];
350 let codes_flat = &self.inv_codes[*list_idx];
351
352 let cluster_adc;
354 let adc_table = if self.residual {
355 let centroid = &self.coarse_centroids[*list_idx];
356 let q_res: Vec<f32> = q.iter().zip(centroid.iter()).map(|(a, b)| a - b).collect();
357 cluster_adc = self.pq.compute_adc_table(&q_res);
358 &cluster_adc
359 } else {
360 global_adc
362 .as_ref()
363 .expect("global_adc must be Some for non-residual path")
364 };
365
366 for (j, &rid) in row_ids.iter().enumerate() {
367 let codes = &codes_flat[j * pq_m..(j + 1) * pq_m];
368 let dist = self.pq.adc_distance(codes, adc_table);
369 candidates.push((RowId(rid), dist));
370 }
371 }
372
373 candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
374 candidates.truncate(top_k);
375 candidates
376 }
377
378 pub fn node_count(&self) -> u64 {
379 self.inv_row_ids.iter().map(|l| l.len() as u64).sum()
380 }
381
382 pub fn dim(&self) -> usize {
383 self.dim
384 }
385}
386
387#[derive(Serialize, Deserialize)]
403struct IvfPqSnapshotCore {
404 nlist: usize,
405 nprobe: usize,
406 pq_m: usize,
407 pq_k: usize,
408 max_iter: usize,
409 dim: usize,
410 metric: u8,
411 coarse_flat: Vec<f32>, pq: PQCodebook,
413 inv_row_ids: Vec<Vec<u64>>,
414 inv_codes: Vec<Vec<u8>>, }
416
417pub struct IvfPqSerializer;
418
419impl IvfPqSerializer {
420 pub fn to_bytes(index: &IvfPqIndex) -> AilakeResult<Vec<u8>> {
421 let coarse_flat: Vec<f32> = index
422 .coarse_centroids
423 .iter()
424 .flat_map(|c| c.iter().copied())
425 .collect();
426 let core = IvfPqSnapshotCore {
427 nlist: index.config.nlist,
428 nprobe: index.config.nprobe,
429 pq_m: index.config.pq_m,
430 pq_k: index.config.pq_k,
431 max_iter: index.config.max_iter,
432 dim: index.dim,
433 metric: metric_to_u8(index.metric),
434 coarse_flat,
435 pq: index.pq.clone(),
436 inv_row_ids: index.inv_row_ids.clone(),
437 inv_codes: index.inv_codes.clone(),
438 };
439 let mut bytes =
440 bincode::serialize(&core).map_err(|e| AilakeError::Bincode(e.to_string()))?;
441 bytes.push(u8::from(index.residual));
444 Ok(bytes)
445 }
446
447 pub fn from_bytes(bytes: &[u8]) -> AilakeResult<IvfPqIndex> {
448 let mut cursor = std::io::Cursor::new(bytes);
451 let core: IvfPqSnapshotCore = bincode::deserialize_from(&mut cursor)
452 .map_err(|e| AilakeError::Bincode(e.to_string()))?;
453
454 let residual = if (cursor.position() as usize) < bytes.len() {
455 bytes[cursor.position() as usize] != 0
456 } else {
457 false };
459
460 let metric = u8_to_metric(core.metric)?;
461 let coarse_centroids: Vec<Vec<f32>> = core
462 .coarse_flat
463 .chunks_exact(core.dim)
464 .map(|c| c.to_vec())
465 .collect();
466 Ok(IvfPqIndex {
467 config: IvfPqConfig {
468 nlist: core.nlist,
469 nprobe: core.nprobe,
470 pq_m: core.pq_m,
471 pq_k: core.pq_k,
472 max_iter: core.max_iter,
473 residual,
474 },
475 metric,
476 dim: core.dim,
477 coarse_centroids,
478 pq: core.pq,
479 inv_row_ids: core.inv_row_ids,
480 inv_codes: core.inv_codes,
481 residual,
482 })
483 }
484}
485
/// Returns `v` scaled to unit L2 length.
///
/// Vectors whose norm is below `1e-9` (including the zero vector and the
/// empty slice) are returned as an unchanged copy to avoid dividing by a
/// vanishing norm.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let sum_sq: f32 = v.iter().map(|&x| x * x).sum();
    let norm = sum_sq.sqrt();
    if norm < 1e-9 {
        return v.to_vec();
    }
    let mut unit = Vec::with_capacity(v.len());
    unit.extend(v.iter().map(|&x| x / norm));
    unit
}
496
/// Squared Euclidean distance between `a` and `b`.
///
/// Only the overlapping prefix is compared when the slices differ in length,
/// because the element pairs come from `zip`.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let diff = lhs - rhs;
            diff.powi(2)
        })
        .sum()
}
500
501fn nearest_idx(v: &[f32], centroids: &[Vec<f32>]) -> usize {
502 centroids
503 .iter()
504 .enumerate()
505 .map(|(i, c)| (i, l2_sq(v, c)))
506 .min_by(|a, b| a.1.total_cmp(&b.1))
507 .map(|(i, _)| i)
508 .unwrap_or(0)
509}
510
/// Largest `m` in `1..=requested` that evenly divides `dim`.
///
/// Falls back to 1 when the range is empty, i.e. when `requested` is 0.
pub fn find_valid_pq_m(requested: usize, dim: usize) -> usize {
    (1..=requested)
        .rev()
        .find(|&m| dim.is_multiple_of(m))
        .unwrap_or(1)
}
520
521fn metric_to_u8(m: VectorMetric) -> u8 {
522 match m {
523 VectorMetric::Cosine => 0,
524 VectorMetric::Euclidean => 1,
525 VectorMetric::DotProduct => 2,
526 VectorMetric::NormalizedCosine => 3,
527 }
528}
529
530fn u8_to_metric(v: u8) -> AilakeResult<VectorMetric> {
531 match v {
532 0 => Ok(VectorMetric::Cosine),
533 1 => Ok(VectorMetric::Euclidean),
534 2 => Ok(VectorMetric::DotProduct),
535 3 => Ok(VectorMetric::NormalizedCosine),
536 _ => Err(AilakeError::Catalog(format!("unknown metric byte: {v}"))),
537 }
538}
539
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` one-hot vectors of width `dim` (the 1.0 sits at position
    /// `i % dim`) together with row ids `0..n`.
    fn make_vecs(n: usize, dim: usize) -> (Vec<RowId>, Vec<Vec<f32>>) {
        let mut row_ids = Vec::with_capacity(n);
        let mut vecs = Vec::with_capacity(n);
        for i in 0..n {
            row_ids.push(RowId(i as u64));
            let mut one_hot = vec![0.0f32; dim];
            one_hot[i % dim] = 1.0;
            vecs.push(one_hot);
        }
        (row_ids, vecs)
    }

    /// Config used by every test: 2 PQ sub-quantizers with 4 centroids each.
    fn tiny_config(nlist: usize, nprobe: usize, max_iter: usize, residual: bool) -> IvfPqConfig {
        IvfPqConfig {
            nlist,
            nprobe,
            pq_m: 2,
            pq_k: 4,
            max_iter,
            residual,
        }
    }

    /// Asserts that a top-5 search returns the same row ids, in the same
    /// order, on both indexes.
    fn assert_same_row_ids(before: &IvfPqIndex, after: &IvfPqIndex, query: &[f32]) {
        let r1 = before.search(query, 5, None);
        let r2 = after.search(query, 5, None);
        assert_eq!(r1.len(), r2.len());
        for (a, b) in r1.iter().zip(r2.iter()) {
            assert_eq!(a.0, b.0, "row_ids should match after roundtrip");
        }
    }

    #[test]
    fn train_and_search_basic() {
        let (ids, vecs) = make_vecs(64, 8);
        let idx = IvfPqIndex::train(
            &ids,
            &vecs,
            VectorMetric::Euclidean,
            tiny_config(4, 2, 10, false),
        )
        .unwrap();
        assert_eq!(idx.node_count(), 64);

        let query = vecs[0].clone();
        let results = idx.search(&query, 5, None);
        assert!(!results.is_empty());
        assert!(results[0].1 < 0.1, "nearest should be approximate self");
    }

    #[test]
    fn train_cosine_normalizes() {
        let (ids, vecs) = make_vecs(32, 4);
        let idx = IvfPqIndex::train(
            &ids,
            &vecs,
            VectorMetric::Cosine,
            tiny_config(4, 2, 10, false),
        )
        .unwrap();
        let results = idx.search(&vecs[0], 1, None);
        assert!(!results.is_empty());
    }

    #[test]
    fn serialize_roundtrip() {
        let (ids, vecs) = make_vecs(32, 8);
        let idx = IvfPqIndex::train(
            &ids,
            &vecs,
            VectorMetric::Euclidean,
            tiny_config(4, 2, 10, false),
        )
        .unwrap();
        let bytes = IvfPqSerializer::to_bytes(&idx).unwrap();
        let idx2 = IvfPqSerializer::from_bytes(&bytes).unwrap();

        assert_eq!(idx2.node_count(), idx.node_count());
        assert_eq!(idx2.dim(), idx.dim());

        let q = vecs[0].clone();
        assert_same_row_ids(&idx, &idx2, &q);
    }

    #[test]
    fn nlist_clamped_to_n() {
        // Only 10 vectors for a requested 256 lists.
        let (ids, vecs) = make_vecs(10, 4);
        let idx = IvfPqIndex::train(
            &ids,
            &vecs,
            VectorMetric::Euclidean,
            tiny_config(256, 8, 5, false),
        )
        .unwrap();
        assert!(idx.config.nlist <= 10);
        assert_eq!(idx.node_count(), 10);
    }

    #[test]
    fn residual_pq_search_finds_nearest() {
        let (ids, vecs) = make_vecs(64, 8);
        let idx = IvfPqIndex::train(
            &ids,
            &vecs,
            VectorMetric::Euclidean,
            tiny_config(4, 4, 10, true),
        )
        .unwrap();
        assert_eq!(idx.node_count(), 64);
        assert!(idx.residual);

        let query = vecs[0].clone();
        let results = idx.search(&query, 5, None);
        assert!(!results.is_empty());
        assert!(
            results[0].1 < 0.1,
            "nearest residual-PQ result should be close to query"
        );
    }

    #[test]
    fn residual_pq_serialize_roundtrip() {
        let (ids, vecs) = make_vecs(32, 8);
        let idx = IvfPqIndex::train(
            &ids,
            &vecs,
            VectorMetric::Euclidean,
            tiny_config(4, 2, 10, true),
        )
        .unwrap();
        let bytes = IvfPqSerializer::to_bytes(&idx).unwrap();
        let idx2 = IvfPqSerializer::from_bytes(&bytes).unwrap();

        assert_eq!(idx2.node_count(), idx.node_count());
        assert!(idx2.residual, "residual flag must survive roundtrip");

        let q = vecs[0].clone();
        assert_same_row_ids(&idx, &idx2, &q);
    }

    #[test]
    fn non_residual_snapshot_deserializes_as_false() {
        let (ids, vecs) = make_vecs(16, 8);
        let idx = IvfPqIndex::train(
            &ids,
            &vecs,
            VectorMetric::Euclidean,
            tiny_config(2, 1, 5, false),
        )
        .unwrap();
        let bytes = IvfPqSerializer::to_bytes(&idx).unwrap();
        let idx2 = IvfPqSerializer::from_bytes(&bytes).unwrap();
        assert!(
            !idx2.residual,
            "non-residual index must deserialize as residual=false"
        );
    }
}