1use diskann_utils::{
7 strided::StridedView,
8 views::{self, Matrix},
9};
10#[cfg(feature = "rayon")]
11use rayon::iter::{IntoParallelIterator, ParallelIterator};
12use thiserror::Error;
13
14use crate::{
15 Parallelism,
16 algorithms::kmeans::{
17 self,
18 common::{BlockTranspose, square_norm},
19 },
20 cancel::Cancelation,
21 random::{BoxedRngBuilder, RngBuilder},
22};
23
/// Parameters controlling lightweight product-quantization (PQ) training.
pub struct LightPQTrainingParameters {
    // Number of centroids (pivots) to compute for each chunk of the schema.
    ncenters: usize,
    // Number of Lloyd's-algorithm refinement iterations run after
    // k-means++ initialization.
    lloyds_reps: usize,
}
30
31impl LightPQTrainingParameters {
32 pub fn new(ncenters: usize, lloyds_reps: usize) -> Self {
34 Self {
35 ncenters,
36 lloyds_reps,
37 }
38 }
39}
40
/// Trained PQ pivot tables: one centroid matrix per chunk of the schema.
#[derive(Debug)]
pub struct SimplePivots {
    // Total dimensionality of the original data (sum of all chunk widths).
    dim: usize,
    // Number of centroids in each per-chunk pivot matrix.
    ncenters: usize,
    // One `ncenters x chunk_width` centroid matrix per chunk.
    pivots: Vec<Matrix<f32>>,
}
47
48fn flatten<T: Copy + Default>(pivots: &[Matrix<T>], ncenters: usize, dim: usize) -> Matrix<T> {
49 let mut flattened = Matrix::new(T::default(), ncenters, dim);
50 let mut col_start = 0;
51 for matrix in pivots {
52 assert_eq!(matrix.nrows(), flattened.nrows());
53 for (row_index, row) in matrix.row_iter().enumerate() {
54 let dst = &mut flattened.row_mut(row_index)[col_start..col_start + row.len()];
55 dst.copy_from_slice(row);
56 }
57 col_start += matrix.ncols();
58 }
59 flattened
60}
61
62impl SimplePivots {
63 pub fn pivots(&self) -> &[Matrix<f32>] {
65 &self.pivots
66 }
67
68 pub fn flatten(&self) -> Vec<f32> {
70 flatten(self.pivots(), self.ncenters, self.dim)
71 .into_inner()
72 .into()
73 }
74}
75
/// Train a quantizer over a dataset whose columns are partitioned into
/// chunks by a schema.
pub trait TrainQuantizer {
    /// The trained quantizer produced on success.
    type Quantizer;
    /// The error type returned when training fails.
    type Error: std::error::Error;

    /// Train on `data` (one point per row), splitting columns according to
    /// `schema`. `parallelism` selects sequential or rayon execution,
    /// `rng_builder` seeds per-chunk RNGs, and `cancelation` is polled to
    /// allow cooperative early exit.
    fn train<R, C>(
        &self,
        data: views::MatrixView<f32>,
        schema: crate::views::ChunkOffsetsView<'_>,
        parallelism: Parallelism,
        rng_builder: &R,
        cancelation: &C,
    ) -> Result<Self::Quantizer, Self::Error>
    where
        R: RngBuilder<usize> + Sync,
        C: Cancelation + Sync;
}
92
impl TrainQuantizer for LightPQTrainingParameters {
    type Quantizer = SimplePivots;
    type Error = PQTrainingError;

    fn train<R, C>(
        &self,
        data: views::MatrixView<f32>,
        schema: crate::views::ChunkOffsetsView<'_>,
        parallelism: Parallelism,
        rng_builder: &R,
        cancelation: &C,
    ) -> Result<Self::Quantizer, Self::Error>
    where
        R: RngBuilder<usize> + Sync,
        C: Cancelation + Sync,
    {
        // Inner non-generic function taking trait objects: the generic outer
        // shell only erases `R`/`C`, so all the real work is compiled once.
        // NOTE(review): `#[inline(never)]` presumably guards against the body
        // being re-inlined per monomorphization — confirm intent.
        #[inline(never)]
        fn train(
            trainer: &LightPQTrainingParameters,
            data: views::MatrixView<f32>,
            schema: crate::views::ChunkOffsetsView<'_>,
            parallelism: Parallelism,
            rng_builder: &(dyn BoxedRngBuilder<usize> + Sync),
            cancelation: &(dyn Cancelation + Sync),
        ) -> Result<SimplePivots, PQTrainingError> {
            // The schema must describe exactly the columns of `data`.
            assert_eq!(data.ncols(), schema.dim());

            // Train k-means on chunk `i` of the schema, yielding that chunk's
            // `ncenters x chunk_width` pivot matrix.
            let thunk = |i| -> Result<Matrix<f32>, PQTrainingError> {
                let range = schema.at(i);

                // Helper to turn a pending cancelation request into an error
                // carrying this chunk's context.
                let exit_if_canceled = || -> Result<(), PQTrainingError> {
                    if cancelation.should_cancel() {
                        Err(PQTrainingError {
                            chunk: i,
                            of: schema.len(),
                            dim: range.len(),
                            kind: PQTrainingErrorKind::Canceled,
                        })
                    } else {
                        Ok(())
                    }
                };

                exit_if_canceled()?;

                // Strided view over just this chunk's columns: rows are
                // `range.len()` wide, separated by the full row stride
                // `schema.dim()`. Failure here is an internal invariant bug.
                let view = StridedView::try_shrink_from(
                    &(data.as_slice()[range.start..]),
                    data.nrows(),
                    range.len(),
                    schema.dim(),
                )
                .map_err(|err| PQTrainingError {
                    chunk: i,
                    of: schema.len(),
                    dim: range.len(),
                    kind: PQTrainingErrorKind::InternalError(Box::new(err.as_static())),
                })?;

                // Precompute per-row squared norms and a blocked transpose,
                // both consumed by the kmeans kernels below.
                let norms: Vec<f32> = view.row_iter().map(square_norm).collect();
                let transpose = BlockTranspose::<16>::from_strided(view);
                let mut centers = Matrix::new(0.0, trainer.ncenters, range.len());

                // Seed the chunk's RNG from its index so runs are reproducible
                // regardless of execution order.
                let mut rng = rng_builder.build_boxed_rng(i);

                // k-means++ initialization. Numerically recoverable errors
                // are tolerated (Lloyd's iterations can still proceed);
                // anything else aborts this chunk.
                kmeans::plusplus::kmeans_plusplus_into_inner(
                    centers.as_mut_view(),
                    view,
                    &transpose,
                    &norms,
                    &mut rng,
                )
                .or_else(|err| {
                    if !err.is_numerically_recoverable() {
                        Err(PQTrainingError {
                            chunk: i,
                            of: schema.len(),
                            dim: range.len(),
                            kind: PQTrainingErrorKind::Initialization(Box::new(err)),
                        })
                    } else {
                        Ok(())
                    }
                })?;

                exit_if_canceled()?;

                // Refine the centers in place with Lloyd's iterations.
                kmeans::lloyds::lloyds_inner(
                    view,
                    &norms,
                    &transpose,
                    centers.as_mut_view(),
                    trainer.lloyds_reps,
                );
                Ok(centers)
            };

            // Run all chunks, sequentially or via rayon; the first error
            // short-circuits collection.
            let pivots: Result<Vec<_>, _> = match parallelism {
                Parallelism::Sequential => (0..schema.len()).map(thunk).collect(),

                #[cfg(feature = "rayon")]
                Parallelism::Rayon => (0..schema.len()).into_par_iter().map(thunk).collect(),
            };

            let dim = data.ncols();
            let ncenters = trainer.ncenters;
            Ok(SimplePivots {
                dim,
                ncenters,
                pivots: pivots?,
            })
        }

        train(self, data, schema, parallelism, rng_builder, cancelation)
    }
}
234
/// Error raised when PQ training fails; records which chunk failed and why.
#[derive(Debug, Error)]
#[error("pq training failed on chunk {chunk} of {of} (dim {dim})")]
pub struct PQTrainingError {
    // Index of the chunk that failed.
    chunk: usize,
    // Total number of chunks in the schema.
    of: usize,
    // Width (number of columns) of the failing chunk.
    dim: usize,
    // Underlying cause, surfaced through `std::error::Error::source`.
    #[source]
    kind: PQTrainingErrorKind,
}
244
245impl PQTrainingError {
246 pub fn was_canceled(&self) -> bool {
248 matches!(self.kind, PQTrainingErrorKind::Canceled)
249 }
250}
251
/// The specific reason a chunk's training failed.
#[derive(Debug, Error)]
#[non_exhaustive]
enum PQTrainingErrorKind {
    /// Cancelation was requested through the `Cancelation` hook.
    #[error("canceled by request")]
    Canceled,
    /// k-means++ initialization failed with a non-recoverable error.
    #[error("initial pivot selection error")]
    Initialization(#[source] Box<dyn std::error::Error + Send + Sync>),
    /// An internal invariant was violated (e.g. building the strided view).
    #[error("internal logic error")]
    InternalError(#[source] Box<dyn std::error::Error + Send + Sync>),
}
262
263#[cfg(not(miri))]
268#[cfg(test)]
269mod tests {
270 use std::sync::atomic::{AtomicUsize, Ordering};
271
272 use rand::{
273 Rng, SeedableRng,
274 distr::{Distribution, StandardUniform, Uniform},
275 rngs::StdRng,
276 seq::SliceRandom,
277 };
278
279 use diskann_utils::lazy_format;
280
281 use super::*;
282 use crate::{cancel::DontCancel, error::format, random::StdRngBuilder};
283
284 #[test]
287 fn test_flatten() {
288 let nrows = 5;
290 let sub_dims = [1, 2, 3, 4, 5];
292 let prefix_sum: Vec<usize> = sub_dims
294 .iter()
295 .scan(0, |state, i| {
296 let this = *state;
297 *state += *i;
298 Some(this)
299 })
300 .collect();
301
302 let dim: usize = sub_dims.iter().sum();
303
304 let matrices: Vec<Matrix<usize>> = std::iter::zip(sub_dims.iter(), prefix_sum.iter())
306 .map(|(&this_dim, &offset)| {
307 let mut m = Matrix::new(0, nrows, this_dim);
308 for r in 0..nrows {
309 for c in 0..this_dim {
310 m[(r, c)] = dim * r + offset + c;
311 }
312 }
313 m
314 })
315 .collect();
316
317 let flattened = flatten(&matrices, nrows, dim);
318 for (i, v) in flattened.as_slice().iter().enumerate() {
320 assert_eq!(*v, i, "failed at index {i}");
321 }
322 }
323
    /// Recipe for a synthetic dataset of well-separated clusters.
    struct DatasetBuilder {
        // Number of clusters to generate.
        nclusters: usize,
        // Number of points per cluster.
        cluster_size: usize,
        // Spacing between consecutive cluster centers within a chunk.
        step_between_clusters: f32,
    }
329
    /// A generated dataset together with its ground-truth cluster centers.
    struct ClusteredDataset {
        // Shuffled data matrix with `nclusters * cluster_size` rows.
        data: Matrix<f32>,
        // Ground-truth centers: `nclusters` rows by `schema.len()` columns;
        // entry (c, i) is cluster c's center value within chunk i.
        centers: Matrix<f32>,
    }
335
336 impl DatasetBuilder {
337 fn build<R>(
338 &self,
339 schema: crate::views::ChunkOffsetsView<'_>,
340 rng: &mut R,
341 ) -> ClusteredDataset
342 where
343 R: Rng,
344 {
345 let ndata = self.nclusters * self.cluster_size;
346
347 let offsets_distribution = Uniform::<f32>::new(-100.0, 100.0).unwrap();
350
351 let perturbation_distribution = rand_distr::StandardNormal;
354
355 let mut indices: Vec<usize> = (0..ndata).collect();
357
358 let (pieces, centers): (Vec<_>, Vec<_>) = (0..schema.len())
360 .map(|chunk| {
361 let dim = schema.at(chunk).len();
362
363 let mut initial = Matrix::new(0.0, ndata, dim);
364 let mut centers = Matrix::new(0.0, self.nclusters, 1);
365
366 let offset = offsets_distribution.sample(rng);
368
369 for cluster in 0..self.nclusters {
371 let this_offset = offset + (cluster as f32 * self.step_between_clusters);
372 centers[(cluster, 0)] = this_offset;
373
374 for element in 0..self.cluster_size {
375 let row = initial.row_mut(cluster * self.cluster_size + element);
376 for r in row.iter_mut() {
377 let perturbation: f32 = perturbation_distribution.sample(rng);
378 *r = this_offset + perturbation;
379 }
380 }
381 }
382
383 indices.shuffle(rng);
385 let mut piece = Matrix::new(0.0, ndata, dim);
386 for (dst, src) in indices.iter().enumerate() {
387 piece.row_mut(dst).copy_from_slice(initial.row(*src));
388 }
389 (piece, centers)
390 })
391 .unzip();
392
393 ClusteredDataset {
394 data: flatten(&pieces, ndata, schema.dim()),
395 centers: flatten(¢ers, self.nclusters, schema.len()),
396 }
397 }
398 }
399
400 fn broadcast_distance(x: &[f32], y: f32) -> f32 {
401 x.iter()
402 .map(|i| {
403 let d = *i - y;
404 d * d
405 })
406 .sum()
407 }
408
    /// End-to-end training on a well-separated synthetic dataset: every
    /// recovered pivot row must land near exactly one true cluster center,
    /// and `flatten` must agree with the free function.
    fn test_pq_training_happy_path(parallelism: Parallelism) {
        let mut rng = StdRng::seed_from_u64(0x749cb951cf960384);
        let builder = DatasetBuilder {
            nclusters: 16,
            cluster_size: 20,
            // Large inter-cluster spacing relative to the unit noise keeps
            // the clusters cleanly separated.
            step_between_clusters: 20.0,
        };

        // Ask for exactly as many centers as there are true clusters.
        let ncenters = builder.nclusters;

        let offsets = [0, 2, 3, 8, 12, 16];
        let schema = crate::views::ChunkOffsetsView::new(&offsets).unwrap();
        let dataset = builder.build(schema, &mut rng);

        let trainer = LightPQTrainingParameters::new(ncenters, 6);

        let quantizer = trainer
            .train(
                dataset.data.as_view(),
                schema,
                parallelism,
                &StdRngBuilder::new(StandardUniform {}.sample(&mut rng)),
                &DontCancel,
            )
            .unwrap();

        assert_eq!(quantizer.dim, schema.dim());
        assert_eq!(quantizer.ncenters, ncenters);
        assert_eq!(quantizer.pivots.len(), schema.len());
        for (i, pivot) in quantizer.pivots.iter().enumerate() {
            assert_eq!(
                pivot.ncols(),
                schema.at(i).len(),
                "center {i} has the incorrect number of columns"
            );
            assert_eq!(pivot.nrows(), ncenters);

            // Track which true clusters have already claimed a pivot, to
            // enforce a one-to-one assignment.
            let mut seen: Vec<bool> = (0..dataset.centers.nrows()).map(|_| false).collect();
            for row in pivot.row_iter() {
                // Find the nearest true center for this pivot row.
                let mut min_distance = f32::MAX;
                let mut min_index = 0;
                for c in 0..dataset.centers.nrows() {
                    let distance = broadcast_distance(row, dataset.centers[(c, i)]);
                    if distance < min_distance {
                        min_distance = distance;
                        min_index = c;
                    }
                }

                // Each pivot must sit close to its cluster's center.
                assert!(
                    min_distance < 1.0,
                    "got a minimum distance of {}, pivot = {}. Row = {:?}",
                    min_distance,
                    i,
                    row,
                );

                // No two pivots may map to the same true cluster.
                let seen_before = &mut seen[min_index];
                assert!(
                    !*seen_before,
                    "cluster {} has more than one assignment",
                    min_index
                );
                *seen_before = true;
            }

            assert!(seen.iter().all(|i| *i), "not all clusters were seen");
        }

        // `SimplePivots::flatten` must agree with the free `flatten` helper.
        let flattened = quantizer.flatten();
        assert_eq!(
            &flattened,
            flatten(&quantizer.pivots, quantizer.ncenters, quantizer.dim).as_slice()
        );
    }
504
    /// Happy-path training, single-threaded execution.
    #[test]
    fn test_pq_training_happy_path_sequential() {
        test_pq_training_happy_path(Parallelism::Sequential);
    }
509
    /// Happy-path training, rayon-parallel execution.
    #[test]
    #[cfg(feature = "rayon")]
    fn test_pq_training_happy_path_parallel() {
        test_pq_training_happy_path(Parallelism::Rayon);
    }
515
    /// Test cancelation hook that starts reporting "cancel" after a fixed
    /// number of polls.
    struct CancelAfter {
        // How many times `should_cancel` has been polled so far.
        counter: AtomicUsize,
        // Zero-based poll index at which cancelation begins.
        after: usize,
    }
521
522 impl CancelAfter {
523 fn new(after: usize) -> Self {
524 Self {
525 counter: AtomicUsize::new(0),
526 after,
527 }
528 }
529 }
530
531 impl Cancelation for CancelAfter {
532 fn should_cancel(&self) -> bool {
533 let v = self.counter.fetch_add(1, Ordering::Relaxed);
534 v >= self.after
535 }
536 }
537
    /// Cancelation requested after `after` polls must surface as a
    /// `was_canceled()` error for both sequential and parallel execution.
    #[test]
    fn test_cancel() {
        let mut rng = StdRng::seed_from_u64(0xb85352d38cc5353b);
        let builder = DatasetBuilder {
            nclusters: 16,
            cluster_size: 20,
            step_between_clusters: 20.0,
        };

        let offsets = [0, 2, 3, 8, 12, 16];
        let schema = crate::views::ChunkOffsetsView::new(&offsets).unwrap();
        let dataset = builder.build(schema, &mut rng);

        let trainer = LightPQTrainingParameters::new(builder.nclusters, 6);

        // Sweep several cancelation points; training polls the hook more than
        // ten times in total, so all of these must trigger.
        for after in 0..10 {
            let parallelism = [
                Parallelism::Sequential,
                #[cfg(feature = "rayon")]
                Parallelism::Rayon,
            ];

            for par in parallelism {
                let result = trainer.train(
                    dataset.data.as_view(),
                    schema,
                    par,
                    &StdRngBuilder::new(StandardUniform {}.sample(&mut rng)),
                    &CancelAfter::new(after),
                );
                assert!(result.is_err(), "expected the operation to be canceled");
                let err = result.unwrap_err();
                assert!(
                    err.was_canceled(),
                    "expected the failure reason to be cancellation"
                );
            }
        }
    }
579
    /// Requesting more centers than there are distinct points must still
    /// succeed. With all-ones data there is only one distinct point, so one
    /// pivot captures it and the surplus pivots stay at zero.
    #[test]
    fn tests_succeeded_with_too_many_pivots() {
        let data = Matrix::<f32>::new(1.0, 10, 5);
        let offsets: Vec<usize> = vec![0, 1, 4, 5];

        // Twice as many centers as data points.
        let trainer = LightPQTrainingParameters::new(2 * data.nrows(), 6);

        let quantizer = trainer
            .train(
                data.as_view(),
                crate::views::ChunkOffsetsView::new(&offsets).unwrap(),
                Parallelism::Sequential,
                &StdRngBuilder::new(0),
                &DontCancel,
            )
            .unwrap();

        let flat = flatten(&quantizer.pivots, quantizer.ncenters, quantizer.dim);

        // NOTE(review): this assumes the surviving pivot always lands in row 0
        // — presumably guaranteed by the kmeans implementation; confirm.
        assert!(
            flat.row(0).iter().all(|i| *i == 1.0),
            "expected pivot 0 to be the non-zero pivot"
        );

        for (i, row) in flat.row_iter().enumerate() {
            if i == 0 {
                continue;
            }

            assert!(
                row.iter().all(|j| *j == 0.0),
                "expected pivot {i} to be all zeros"
            );
        }
    }
623
    /// A single non-finite value (+inf, -inf, or NaN) anywhere in the data
    /// must make training fail with a non-cancelation error.
    #[test]
    fn test_infinity_and_nan_is_not_recoverable() {
        let num_trials = 10;
        let nrows = 10;
        let ncols = 5;

        let offsets: Vec<usize> = vec![0, 1, 4, 5];

        let trainer = LightPQTrainingParameters::new(nrows, 6);

        // Distributions over a random poison location within the matrix.
        let row_distribution = Uniform::new(0, nrows).unwrap();
        let col_distribution = Uniform::new(0, ncols).unwrap();
        let mut rng = StdRng::seed_from_u64(0xe746cfebba2d7e35);

        for trial in 0..num_trials {
            let context = lazy_format!("trial {} of {}", trial + 1, num_trials);

            let r = row_distribution.sample(&mut rng);
            let c = col_distribution.sample(&mut rng);

            // Shared failure check: must be a non-cancelation error whose
            // message mentions "infinity".
            // NOTE(review): the NaN case also goes through this check, so the
            // underlying error text presumably covers NaN too — confirm.
            let check_result = |r: Result<_, PQTrainingError>| {
                assert!(
                    r.is_err(),
                    "expected error due to infinities/NaN -- {}",
                    context
                );
                let err = r.unwrap_err();
                assert!(!err.was_canceled());
                assert!(format(&err).contains("infinity"));
            };

            let mut data = Matrix::<f32>::new(1.0, nrows, ncols);

            // Poison with +infinity.
            data[(r, c)] = f32::INFINITY;
            let result = trainer.train(
                data.as_view(),
                crate::views::ChunkOffsetsView::new(&offsets).unwrap(),
                Parallelism::Sequential,
                &StdRngBuilder::new(0),
                &DontCancel,
            );
            check_result(result);

            // Poison with -infinity.
            data[(r, c)] = f32::NEG_INFINITY;
            let result = trainer.train(
                data.as_view(),
                crate::views::ChunkOffsetsView::new(&offsets).unwrap(),
                Parallelism::Sequential,
                &StdRngBuilder::new(0),
                &DontCancel,
            );
            check_result(result);

            // Poison with NaN.
            data[(r, c)] = f32::NAN;
            let result = trainer.train(
                data.as_view(),
                crate::views::ChunkOffsetsView::new(&offsets).unwrap(),
                Parallelism::Sequential,
                &StdRngBuilder::new(0),
                &DontCancel,
            );
            check_result(result);
        }
    }
691}