1use diskann_utils::{
7 strided::StridedView,
8 views::{self, Matrix},
9};
10#[cfg(feature = "rayon")]
11use rayon::iter::{IntoParallelIterator, ParallelIterator};
12use thiserror::Error;
13
14use crate::{
15 Parallelism,
16 algorithms::kmeans::{self, common::square_norm},
17 cancel::Cancelation,
18 multi_vector::BlockTransposed,
19 random::{BoxedRngBuilder, RngBuilder},
20};
21
/// Parameters controlling lightweight PQ (product quantization) training.
pub struct LightPQTrainingParameters {
    /// Number of centroids to compute for each chunk of the schema.
    ncenters: usize,
    /// Number of Lloyd's-algorithm refinement iterations to run after
    /// k-means++ initialization.
    lloyds_reps: usize,
}
28
29impl LightPQTrainingParameters {
30 pub fn new(ncenters: usize, lloyds_reps: usize) -> Self {
32 Self {
33 ncenters,
34 lloyds_reps,
35 }
36 }
37}
38
/// Trained PQ pivots: one centroid matrix per chunk of the schema.
#[derive(Debug)]
pub struct SimplePivots {
    /// Total dimensionality across all chunks combined.
    dim: usize,
    /// Number of centroids (rows) in each per-chunk matrix.
    ncenters: usize,
    /// Per-chunk centroid matrices; matrix `i` has `ncenters` rows and as
    /// many columns as chunk `i` has dimensions.
    pivots: Vec<Matrix<f32>>,
}
45
46fn flatten<T: Copy + Default>(pivots: &[Matrix<T>], ncenters: usize, dim: usize) -> Matrix<T> {
47 let mut flattened = Matrix::new(T::default(), ncenters, dim);
48 let mut col_start = 0;
49 for matrix in pivots {
50 assert_eq!(matrix.nrows(), flattened.nrows());
51 for (row_index, row) in matrix.row_iter().enumerate() {
52 let dst = &mut flattened.row_mut(row_index)[col_start..col_start + row.len()];
53 dst.copy_from_slice(row);
54 }
55 col_start += matrix.ncols();
56 }
57 flattened
58}
59
60impl SimplePivots {
61 pub fn pivots(&self) -> &[Matrix<f32>] {
63 &self.pivots
64 }
65
66 pub fn flatten(&self) -> Vec<f32> {
68 flatten(self.pivots(), self.ncenters, self.dim)
69 .into_inner()
70 .into()
71 }
72}
73
/// Train a quantizer from a dataset whose columns are partitioned into
/// chunks by a schema.
pub trait TrainQuantizer {
    /// The trained quantizer produced on success.
    type Quantizer;
    /// The error type returned when training fails.
    type Error: std::error::Error;

    /// Train a quantizer over `data`, with columns split into chunks
    /// according to `schema`.
    ///
    /// `parallelism` selects how per-chunk work is dispatched, `rng_builder`
    /// produces deterministic per-chunk RNGs, and `cancelation` is polled to
    /// allow early termination.
    fn train<R, C>(
        &self,
        data: views::MatrixView<f32>,
        schema: crate::views::ChunkOffsetsView<'_>,
        parallelism: Parallelism,
        rng_builder: &R,
        cancelation: &C,
    ) -> Result<Self::Quantizer, Self::Error>
    where
        R: RngBuilder<usize> + Sync,
        C: Cancelation + Sync;
}
90
impl TrainQuantizer for LightPQTrainingParameters {
    type Quantizer = SimplePivots;
    type Error = PQTrainingError;

    fn train<R, C>(
        &self,
        data: views::MatrixView<f32>,
        schema: crate::views::ChunkOffsetsView<'_>,
        parallelism: Parallelism,
        rng_builder: &R,
        cancelation: &C,
    ) -> Result<Self::Quantizer, Self::Error>
    where
        R: RngBuilder<usize> + Sync,
        C: Cancelation + Sync,
    {
        // The heavy training body lives in a non-generic inner function that
        // takes trait objects, so it is compiled once rather than per (R, C)
        // instantiation; `#[inline(never)]` keeps it that way.
        #[inline(never)]
        fn train(
            trainer: &LightPQTrainingParameters,
            data: views::MatrixView<f32>,
            schema: crate::views::ChunkOffsetsView<'_>,
            parallelism: Parallelism,
            rng_builder: &(dyn BoxedRngBuilder<usize> + Sync),
            cancelation: &(dyn Cancelation + Sync),
        ) -> Result<SimplePivots, PQTrainingError> {
            assert_eq!(data.ncols(), schema.dim());

            // Runs k-means training for chunk `i` of the schema, producing
            // that chunk's `ncenters x chunk_dim` centroid matrix.
            let thunk = |i| -> Result<Matrix<f32>, PQTrainingError> {
                let range = schema.at(i);

                // Polled before each expensive phase so a cancellation
                // request is observed promptly.
                let exit_if_canceled = || -> Result<(), PQTrainingError> {
                    if cancelation.should_cancel() {
                        Err(PQTrainingError {
                            chunk: i,
                            of: schema.len(),
                            dim: range.len(),
                            kind: PQTrainingErrorKind::Canceled,
                        })
                    } else {
                        Ok(())
                    }
                };

                exit_if_canceled()?;

                // Strided view selecting just this chunk's columns across all
                // rows of the full-width dataset.
                let view = StridedView::try_shrink_from(
                    &(data.as_slice()[range.start..]),
                    data.nrows(),
                    range.len(),
                    schema.dim(),
                )
                .map_err(|err| PQTrainingError {
                    chunk: i,
                    of: schema.len(),
                    dim: range.len(),
                    kind: PQTrainingErrorKind::InternalError(Box::new(err.as_static())),
                })?;

                // Precompute per-row squared norms and a block-transposed
                // copy of the chunk, both consumed by the k-means kernels.
                let norms: Vec<f32> = view.row_iter().map(square_norm).collect();
                let transpose = BlockTransposed::<f32, 16>::from_strided(view);
                let mut centers = Matrix::new(0.0, trainer.ncenters, range.len());

                // Seeded by the chunk index, so each chunk draws a
                // reproducible RNG stream regardless of execution order.
                let mut rng = rng_builder.build_boxed_rng(i);

                // k-means++ seeding of the centroid matrix.
                kmeans::plusplus::kmeans_plusplus_into_inner(
                    centers.as_mut_view(),
                    view,
                    transpose.as_view(),
                    &norms,
                    &mut rng,
                )
                .or_else(|err| {
                    // Numerically recoverable initialization problems are
                    // tolerated — Lloyd's iterations below can still refine
                    // whatever centers were produced. Anything else aborts
                    // training for this chunk.
                    if !err.is_numerically_recoverable() {
                        Err(PQTrainingError {
                            chunk: i,
                            of: schema.len(),
                            dim: range.len(),
                            kind: PQTrainingErrorKind::Initialization(Box::new(err)),
                        })
                    } else {
                        Ok(())
                    }
                })?;

                exit_if_canceled()?;

                // Refine the initial centroids with Lloyd's iterations.
                kmeans::lloyds::lloyds_inner(
                    view,
                    &norms,
                    transpose.as_view(),
                    centers.as_mut_view(),
                    trainer.lloyds_reps,
                );
                Ok(centers)
            };

            // Chunks are independent, so they can be trained sequentially or
            // (when the feature is enabled) in parallel via rayon. Collecting
            // into `Result` short-circuits on the first failed chunk.
            let pivots: Result<Vec<_>, _> = match parallelism {
                Parallelism::Sequential => (0..schema.len()).map(thunk).collect(),

                #[cfg(feature = "rayon")]
                Parallelism::Rayon => (0..schema.len()).into_par_iter().map(thunk).collect(),
            };

            let dim = data.ncols();
            let ncenters = trainer.ncenters;
            Ok(SimplePivots {
                dim,
                ncenters,
                pivots: pivots?,
            })
        }

        train(self, data, schema, parallelism, rng_builder, cancelation)
    }
}
232
/// Error produced when PQ training fails, identifying the chunk involved.
#[derive(Debug, Error)]
#[error("pq training failed on chunk {chunk} of {of} (dim {dim})")]
pub struct PQTrainingError {
    /// Index of the chunk that failed.
    chunk: usize,
    /// Total number of chunks in the schema.
    of: usize,
    /// Dimensionality of the failing chunk.
    dim: usize,
    /// Underlying cause, exposed via `std::error::Error::source`.
    #[source]
    kind: PQTrainingErrorKind,
}
242
impl PQTrainingError {
    /// Returns `true` if training failed because cancellation was requested,
    /// as opposed to an initialization or internal error.
    pub fn was_canceled(&self) -> bool {
        matches!(self.kind, PQTrainingErrorKind::Canceled)
    }
}
249
/// Underlying cause of a [`PQTrainingError`].
#[derive(Debug, Error)]
#[non_exhaustive]
enum PQTrainingErrorKind {
    /// The provided `Cancelation` requested that training stop.
    #[error("canceled by request")]
    Canceled,
    /// k-means++ initial pivot selection failed unrecoverably.
    #[error("initial pivot selection error")]
    Initialization(#[source] Box<dyn std::error::Error + Send + Sync>),
    /// An internal invariant was violated (e.g. constructing the strided
    /// chunk view failed).
    #[error("internal logic error")]
    InternalError(#[source] Box<dyn std::error::Error + Send + Sync>),
}
260
// Excluded under miri: these tests do substantial floating-point work and
// would be prohibitively slow under the interpreter.
#[cfg(not(miri))]
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use rand::{
        Rng, SeedableRng,
        distr::{Distribution, StandardUniform, Uniform},
        rngs::StdRng,
        seq::SliceRandom,
    };

    use diskann_utils::lazy_format;

    use super::*;
    use crate::{cancel::DontCancel, error::format, random::StdRngBuilder};

    /// `flatten` must concatenate the chunk matrices column-wise while
    /// preserving row order. Each source matrix is filled so that the
    /// flattened result is exactly `0..nrows * dim` in row-major order.
    #[test]
    fn test_flatten() {
        let nrows = 5;
        let sub_dims = [1, 2, 3, 4, 5];
        // Exclusive prefix sum of `sub_dims`: the starting column of each
        // chunk in the flattened matrix.
        let prefix_sum: Vec<usize> = sub_dims
            .iter()
            .scan(0, |state, i| {
                let this = *state;
                *state += *i;
                Some(this)
            })
            .collect();

        let dim: usize = sub_dims.iter().sum();

        let matrices: Vec<Matrix<usize>> = std::iter::zip(sub_dims.iter(), prefix_sum.iter())
            .map(|(&this_dim, &offset)| {
                let mut m = Matrix::new(0, nrows, this_dim);
                for r in 0..nrows {
                    for c in 0..this_dim {
                        // Each value equals its eventual row-major index in
                        // the flattened matrix.
                        m[(r, c)] = dim * r + offset + c;
                    }
                }
                m
            })
            .collect();

        let flattened = flatten(&matrices, nrows, dim);
        for (i, v) in flattened.as_slice().iter().enumerate() {
            assert_eq!(*v, i, "failed at index {i}");
        }
    }

    /// Builds synthetic datasets with well-separated clusters so that
    /// k-means has an unambiguous right answer.
    struct DatasetBuilder {
        nclusters: usize,
        cluster_size: usize,
        // Separation between consecutive cluster centers; large relative to
        // the unit-variance perturbations, so clusters do not overlap.
        step_between_clusters: f32,
    }

    struct ClusteredDataset {
        data: Matrix<f32>,
        // `nclusters x schema.len()`: column `i` holds the scalar center
        // used for each cluster in chunk `i`.
        centers: Matrix<f32>,
    }

    impl DatasetBuilder {
        fn build<R>(
            &self,
            schema: crate::views::ChunkOffsetsView<'_>,
            rng: &mut R,
        ) -> ClusteredDataset
        where
            R: Rng,
        {
            let ndata = self.nclusters * self.cluster_size;

            // Random base offset per chunk, so chunks differ from each other.
            let offsets_distribution = Uniform::<f32>::new(-100.0, 100.0).unwrap();

            // Unit-variance noise around each cluster center.
            let perturbation_distribution = rand_distr::StandardNormal;

            let mut indices: Vec<usize> = (0..ndata).collect();

            let (pieces, centers): (Vec<_>, Vec<_>) = (0..schema.len())
                .map(|chunk| {
                    let dim = schema.at(chunk).len();

                    let mut initial = Matrix::new(0.0, ndata, dim);
                    let mut centers = Matrix::new(0.0, self.nclusters, 1);

                    let offset = offsets_distribution.sample(rng);

                    // Fill each cluster's points: a shared scalar center per
                    // cluster plus Gaussian noise in every coordinate.
                    for cluster in 0..self.nclusters {
                        let this_offset = offset + (cluster as f32 * self.step_between_clusters);
                        centers[(cluster, 0)] = this_offset;

                        for element in 0..self.cluster_size {
                            let row = initial.row_mut(cluster * self.cluster_size + element);
                            for r in row.iter_mut() {
                                let perturbation: f32 = perturbation_distribution.sample(rng);
                                *r = this_offset + perturbation;
                            }
                        }
                    }

                    // Shuffle row order so clusters are not laid out
                    // contiguously in the dataset.
                    indices.shuffle(rng);
                    let mut piece = Matrix::new(0.0, ndata, dim);
                    for (dst, src) in indices.iter().enumerate() {
                        piece.row_mut(dst).copy_from_slice(initial.row(*src));
                    }
                    (piece, centers)
                })
                .unzip();

            ClusteredDataset {
                data: flatten(&pieces, ndata, schema.dim()),
                centers: flatten(&centers, self.nclusters, schema.len()),
            }
        }
    }

    /// Squared Euclidean distance between `x` and the vector obtained by
    /// broadcasting the scalar `y` to `x.len()` components.
    fn broadcast_distance(x: &[f32], y: f32) -> f32 {
        x.iter()
            .map(|i| {
                let d = *i - y;
                d * d
            })
            .sum()
    }

    /// End-to-end training on a cleanly-clustered dataset: every trained
    /// pivot should land essentially on top of exactly one true cluster
    /// center, and every cluster should be claimed by exactly one pivot.
    fn test_pq_training_happy_path(parallelism: Parallelism) {
        let mut rng = StdRng::seed_from_u64(0x749cb951cf960384);
        let builder = DatasetBuilder {
            nclusters: 16,
            cluster_size: 20,
            step_between_clusters: 20.0,
        };

        let ncenters = builder.nclusters;

        let offsets = [0, 2, 3, 8, 12, 16];
        let schema = crate::views::ChunkOffsetsView::new(&offsets).unwrap();
        let dataset = builder.build(schema, &mut rng);

        let trainer = LightPQTrainingParameters::new(ncenters, 6);

        let quantizer = trainer
            .train(
                dataset.data.as_view(),
                schema,
                parallelism,
                &StdRngBuilder::new(StandardUniform {}.sample(&mut rng)),
                &DontCancel,
            )
            .unwrap();

        assert_eq!(quantizer.dim, schema.dim());
        assert_eq!(quantizer.ncenters, ncenters);
        assert_eq!(quantizer.pivots.len(), schema.len());
        for (i, pivot) in quantizer.pivots.iter().enumerate() {
            assert_eq!(
                pivot.ncols(),
                schema.at(i).len(),
                "center {i} has the incorrect number of columns"
            );
            assert_eq!(pivot.nrows(), ncenters);

            // Track which true clusters have been matched by a pivot; each
            // must be matched exactly once.
            let mut seen: Vec<bool> = (0..dataset.centers.nrows()).map(|_| false).collect();
            for row in pivot.row_iter() {
                // Find the nearest true cluster center to this pivot row.
                let mut min_distance = f32::MAX;
                let mut min_index = 0;
                for c in 0..dataset.centers.nrows() {
                    let distance = broadcast_distance(row, dataset.centers[(c, i)]);
                    if distance < min_distance {
                        min_distance = distance;
                        min_index = c;
                    }
                }

                assert!(
                    min_distance < 1.0,
                    "got a minimum distance of {}, pivot = {}. Row = {:?}",
                    min_distance,
                    i,
                    row,
                );

                let seen_before = &mut seen[min_index];
                assert!(
                    !*seen_before,
                    "cluster {} has more than one assignment",
                    min_index
                );
                *seen_before = true;
            }

            assert!(seen.iter().all(|i| *i), "not all clusters were seen");
        }

        // `SimplePivots::flatten` must agree with the free `flatten` helper.
        let flattened = quantizer.flatten();
        assert_eq!(
            &flattened,
            flatten(&quantizer.pivots, quantizer.ncenters, quantizer.dim).as_slice()
        );
    }

    #[test]
    fn test_pq_training_happy_path_sequential() {
        test_pq_training_happy_path(Parallelism::Sequential);
    }

    #[test]
    #[cfg(feature = "rayon")]
    fn test_pq_training_happy_path_parallel() {
        test_pq_training_happy_path(Parallelism::Rayon);
    }

    /// A `Cancelation` that reports `true` starting with the `after`-th poll.
    struct CancelAfter {
        counter: AtomicUsize,
        after: usize,
    }

    impl CancelAfter {
        fn new(after: usize) -> Self {
            Self {
                counter: AtomicUsize::new(0),
                after,
            }
        }
    }

    impl Cancelation for CancelAfter {
        fn should_cancel(&self) -> bool {
            // Relaxed is sufficient: only the count matters, not ordering
            // with respect to other memory operations.
            let v = self.counter.fetch_add(1, Ordering::Relaxed);
            v >= self.after
        }
    }

    /// Canceling after any small number of polls must surface as an error
    /// whose reason is cancellation, in both execution modes.
    #[test]
    fn test_cancel() {
        let mut rng = StdRng::seed_from_u64(0xb85352d38cc5353b);
        let builder = DatasetBuilder {
            nclusters: 16,
            cluster_size: 20,
            step_between_clusters: 20.0,
        };

        let offsets = [0, 2, 3, 8, 12, 16];
        let schema = crate::views::ChunkOffsetsView::new(&offsets).unwrap();
        let dataset = builder.build(schema, &mut rng);

        let trainer = LightPQTrainingParameters::new(builder.nclusters, 6);

        for after in 0..10 {
            let parallelism = [
                Parallelism::Sequential,
                #[cfg(feature = "rayon")]
                Parallelism::Rayon,
            ];

            for par in parallelism {
                let result = trainer.train(
                    dataset.data.as_view(),
                    schema,
                    par,
                    &StdRngBuilder::new(StandardUniform {}.sample(&mut rng)),
                    &CancelAfter::new(after),
                );
                assert!(result.is_err(), "expected the operation to be canceled");
                let err = result.unwrap_err();
                assert!(
                    err.was_canceled(),
                    "expected the failure reason to be cancellation"
                );
            }
        }
    }

    /// Requesting more centers than there are data points must still succeed:
    /// with identical all-ones data, one pivot captures the data and the
    /// surplus pivots remain at their zero initialization.
    #[test]
    fn tests_succeeded_with_too_many_pivots() {
        let data = Matrix::<f32>::new(1.0, 10, 5);
        let offsets: Vec<usize> = vec![0, 1, 4, 5];

        // Twice as many centers as data points.
        let trainer = LightPQTrainingParameters::new(2 * data.nrows(), 6);

        let quantizer = trainer
            .train(
                data.as_view(),
                crate::views::ChunkOffsetsView::new(&offsets).unwrap(),
                Parallelism::Sequential,
                &StdRngBuilder::new(0),
                &DontCancel,
            )
            .unwrap();

        let flat = flatten(&quantizer.pivots, quantizer.ncenters, quantizer.dim);

        assert!(
            flat.row(0).iter().all(|i| *i == 1.0),
            "expected pivot 0 to be the non-zero pivot"
        );

        for (i, row) in flat.row_iter().enumerate() {
            if i == 0 {
                continue;
            }

            assert!(
                row.iter().all(|j| *j == 0.0),
                "expected pivot {i} to be all zeros"
            );
        }
    }

    /// Non-finite values (positive/negative infinity, NaN) anywhere in the
    /// data must produce a non-cancellation training error, regardless of
    /// which cell is poisoned.
    #[test]
    fn test_infinity_and_nan_is_not_recoverable() {
        let num_trials = 10;
        let nrows = 10;
        let ncols = 5;

        let offsets: Vec<usize> = vec![0, 1, 4, 5];

        let trainer = LightPQTrainingParameters::new(nrows, 6);

        let row_distribution = Uniform::new(0, nrows).unwrap();
        let col_distribution = Uniform::new(0, ncols).unwrap();
        let mut rng = StdRng::seed_from_u64(0xe746cfebba2d7e35);

        for trial in 0..num_trials {
            let context = lazy_format!("trial {} of {}", trial + 1, num_trials);

            // Poison a randomly-chosen cell each trial.
            let r = row_distribution.sample(&mut rng);
            let c = col_distribution.sample(&mut rng);

            let check_result = |r: Result<_, PQTrainingError>| {
                assert!(
                    r.is_err(),
                    "expected error due to infinities/NaN -- {}",
                    context
                );
                let err = r.unwrap_err();
                assert!(!err.was_canceled());
                assert!(format(&err).contains("infinity"));
            };

            let mut data = Matrix::<f32>::new(1.0, nrows, ncols);

            data[(r, c)] = f32::INFINITY;
            let result = trainer.train(
                data.as_view(),
                crate::views::ChunkOffsetsView::new(&offsets).unwrap(),
                Parallelism::Sequential,
                &StdRngBuilder::new(0),
                &DontCancel,
            );
            check_result(result);

            data[(r, c)] = f32::NEG_INFINITY;
            let result = trainer.train(
                data.as_view(),
                crate::views::ChunkOffsetsView::new(&offsets).unwrap(),
                Parallelism::Sequential,
                &StdRngBuilder::new(0),
                &DontCancel,
            );
            check_result(result);

            data[(r, c)] = f32::NAN;
            let result = trainer.train(
                data.as_view(),
                crate::views::ChunkOffsetsView::new(&offsets).unwrap(),
                Parallelism::Sequential,
                &StdRngBuilder::new(0),
                &DontCancel,
            );
            check_result(result);
        }
    }
}