1use std::{collections::HashSet, hash::Hash};
7
8use diskann_utils::{
9 strided::StridedView,
10 views::{Matrix, MatrixView},
11};
12use thiserror::Error;
13
/// Summary statistics returned by [`knn`] for one batch of queries.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot construct it with a
/// struct literal or match on it exhaustively.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RecallMetrics {
    /// The `recall_k` value the metrics were computed with.
    pub recall_k: usize,
    /// The `recall_n` value the metrics were computed with.
    pub recall_n: usize,
    /// Number of queries (rows of the results container) that were evaluated.
    pub num_queries: usize,
    /// Sum of the per-query match counts divided by `recall_k * num_queries`.
    pub average: f64,
    /// Smallest per-query match count; `0` when no queries were evaluated.
    pub minimum: usize,
    /// Largest per-query match count; `0` when no queries were evaluated.
    pub maximum: usize,
}
30
31#[derive(Debug, Error)]
32pub enum ComputeRecallError {
33 #[error("results matrix has {0} rows but ground truth has {1}")]
34 RowsMismatch(usize, usize),
35 #[error("distances matrix has {0} rows but ground truth has {1}")]
36 DistanceRowsMismatch(usize, usize),
37 #[error("recall k value {0} must be less than or equal to recall n {1}")]
38 RecallKAndNError(usize, usize),
39 #[error("number of results per query {0} must be at least the specified recall k {1}")]
40 NotEnoughResults(usize, usize),
41 #[error(
42 "number of groundtruth values per query {0} must be at least the specified recall n {1}"
43 )]
44 NotEnoughGroundTruth(usize, usize),
45 #[error("number of groundtruth distances {0} does not match groundtruth entries {1}")]
46 NotEnoughGroundTruthDistances(usize, usize),
47}
48
/// A read-only collection of `T` values that can be accessed one row at a time.
///
/// This is the abstraction [`knn`] and [`average_precision`] use for both ground truth and
/// search results, so rectangular matrices and ragged `Vec<Vec<T>>` data can be passed
/// interchangeably.
pub trait Rows<T> {
    /// Returns the number of rows.
    fn nrows(&self) -> usize;

    /// Returns row `i` as a slice.
    ///
    /// What happens when `i >= self.nrows()` is left to the implementation.
    fn row(&self, i: usize) -> &[T];

    /// Returns the shared row length when all rows are known to have the same length.
    ///
    /// The default is `None`, meaning the row length is not known up front and callers
    /// must inspect each row individually.
    fn ncols(&self) -> Option<usize> {
        None
    }
}
85
/// [`Rows`] for an owned [`Matrix`]; every method forwards to the matrix's own accessor.
impl<T> Rows<T> for Matrix<T> {
    fn nrows(&self) -> usize {
        // Path syntax on the concrete type names the matrix's method explicitly instead of
        // relying on method-call resolution (presumably to rule out calling this trait
        // method recursively -- confirm against `diskann_utils`).
        Matrix::<T>::nrows(self)
    }
    fn row(&self, i: usize) -> &[T] {
        Matrix::<T>::row(self, i)
    }
    fn ncols(&self) -> Option<usize> {
        // The matrix reports a single column count, so the row length is known up front.
        Some(Matrix::<T>::ncols(self))
    }
}
97
/// [`Rows`] for a borrowed [`MatrixView`]; every method forwards to the view's own accessor.
impl<T> Rows<T> for MatrixView<'_, T> {
    fn nrows(&self) -> usize {
        // Path syntax on the concrete type names the view's method explicitly instead of
        // relying on method-call resolution (presumably to rule out calling this trait
        // method recursively -- confirm against `diskann_utils`).
        MatrixView::<'_, T>::nrows(self)
    }
    fn row(&self, i: usize) -> &[T] {
        MatrixView::<'_, T>::row(self, i)
    }
    fn ncols(&self) -> Option<usize> {
        // The view reports a single column count, so the row length is known up front.
        Some(MatrixView::<'_, T>::ncols(self))
    }
}
109
110impl<T> Rows<T> for Vec<Vec<T>> {
111 fn nrows(&self) -> usize {
112 self.len()
113 }
114 fn row(&self, i: usize) -> &[T] {
115 &self[i]
116 }
117}
118
/// Bundle of the bounds an element type needs to be used with the metrics in this file:
/// equality and hashing (for `HashSet` membership), cloning, and `Debug`.
pub trait RecallCompatible: Eq + Hash + Clone + std::fmt::Debug {}

/// Blanket implementation: any type meeting the bounds is automatically
/// [`RecallCompatible`].
impl<T: Eq + Hash + Clone + std::fmt::Debug> RecallCompatible for T {}
123
124pub fn knn<T>(
136 groundtruth: &dyn Rows<T>,
137 groundtruth_distances: Option<StridedView<'_, f32>>,
138 results: &dyn Rows<T>,
139 recall_k: usize,
140 recall_n: usize,
141 allow_insufficient_results: bool,
142) -> Result<RecallMetrics, ComputeRecallError>
143where
144 T: RecallCompatible,
145{
146 if recall_k > recall_n {
147 return Err(ComputeRecallError::RecallKAndNError(recall_k, recall_n));
148 }
149
150 let nrows = results.nrows();
151 if nrows != groundtruth.nrows() {
152 return Err(ComputeRecallError::RowsMismatch(nrows, groundtruth.nrows()));
153 }
154
155 if let Some(cols) = results.ncols()
156 && cols < recall_n
157 && !allow_insufficient_results
158 {
159 return Err(ComputeRecallError::NotEnoughResults(cols, recall_n));
160 }
161
162 match groundtruth.ncols() {
164 Some(ncols) if ncols < recall_k => {
165 return Err(ComputeRecallError::NotEnoughGroundTruth(ncols, recall_k));
166 }
167 _ => {}
168 }
169
170 if let Some(distances) = groundtruth_distances {
171 if nrows != distances.nrows() {
172 return Err(ComputeRecallError::DistanceRowsMismatch(
173 distances.nrows(),
174 nrows,
175 ));
176 }
177
178 match groundtruth.ncols() {
179 Some(ncols) if distances.ncols() != ncols => {
180 return Err(ComputeRecallError::NotEnoughGroundTruthDistances(
181 distances.ncols(),
182 ncols,
183 ));
184 }
185 _ => {}
186 }
187 }
188
189 let mut recall_values: Vec<usize> = Vec::new();
191 let mut this_groundtruth = HashSet::new();
192 let mut this_results = HashSet::new();
193
194 for i in 0..results.nrows() {
195 let result = results.row(i);
196 if !allow_insufficient_results && result.len() < recall_n {
197 return Err(ComputeRecallError::NotEnoughResults(result.len(), recall_n));
198 }
199
200 let gt_row = groundtruth.row(i);
201 if gt_row.len() < recall_k {
202 return Err(ComputeRecallError::NotEnoughGroundTruth(
203 gt_row.len(),
204 recall_k,
205 ));
206 }
207
208 this_groundtruth.clear();
210 this_groundtruth.extend(gt_row.iter().take(recall_k).cloned());
211
212 if let Some(distances) = groundtruth_distances
215 && recall_k > 0
216 {
217 let distances_row = distances.row(i);
218 if distances_row.len() > recall_k - 1 && gt_row.len() > recall_k - 1 {
219 let last_distance = distances_row[recall_k - 1];
220 for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(recall_k) {
221 if *d == last_distance {
222 this_groundtruth.insert(g.clone());
223 } else {
224 break;
225 }
226 }
227 }
228 }
229
230 this_results.clear();
231 this_results.extend(result.iter().take(recall_n).cloned());
232
233 let r = this_groundtruth
235 .iter()
236 .filter(|i| this_results.contains(i))
237 .count()
238 .min(recall_k);
239
240 recall_values.push(r);
241 }
242
243 let total: usize = recall_values.iter().sum();
245 let minimum = recall_values.iter().min().unwrap_or(&0);
246 let maximum = recall_values.iter().max().unwrap_or(&0);
247
248 let div = recall_k * nrows;
250 let average = (total as f64) / (div as f64);
251
252 Ok(RecallMetrics {
253 recall_k,
254 recall_n,
255 num_queries: nrows,
256 average,
257 minimum: *minimum,
258 maximum: *maximum,
259 })
260}
261
/// Result of [`average_precision`].
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot construct it with a
/// struct literal or match on it exhaustively.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct AveragePrecisionMetrics {
    /// Number of queries (rows of the results container) that were evaluated.
    pub num_queries: usize,
    /// Number of ground-truth entries found in the matching result row, summed over all
    /// queries, divided by the total number of ground-truth entries.
    pub average_precision: f64,
}
270
/// Reasons [`average_precision`] can reject its inputs.
#[derive(Debug, Error)]
pub enum AveragePrecisionError {
    /// The results and the ground truth have different row counts:
    /// `(results rows, ground-truth rows)`.
    #[error("results has {0} elements but ground truth has {1}")]
    EntriesMismatch(usize, usize),
}
276
277pub fn average_precision<T>(
279 results: &dyn Rows<T>,
280 groundtruth: &dyn Rows<T>,
281) -> Result<AveragePrecisionMetrics, AveragePrecisionError>
282where
283 T: RecallCompatible,
284{
285 let nrows = results.nrows();
286 let groundtruth_nrows = groundtruth.nrows();
287 if nrows != groundtruth_nrows {
288 return Err(AveragePrecisionError::EntriesMismatch(
289 nrows,
290 groundtruth_nrows,
291 ));
292 }
293
294 let mut num_gt_results = 0;
296 let mut num_reported_results = 0;
297
298 let mut scratch = HashSet::new();
299 let nrows = results.nrows();
300
301 for i in 0..nrows {
302 let result = results.row(i);
303 let gt = groundtruth.row(i);
304
305 scratch.clear();
306 scratch.extend(result.iter().cloned());
307 num_reported_results += gt.iter().filter(|i| scratch.contains(i)).count();
308 num_gt_results += gt.len();
309 }
310
311 let average_precision = (num_reported_results as f64) / (num_gt_results as f64);
313
314 Ok(AveragePrecisionMetrics {
315 average_precision,
316 num_queries: nrows,
317 })
318}
319
320#[cfg(test)]
325mod tests {
326 use diskann_utils::views::{self, Matrix};
327
328 use super::*;
329
330 fn test_rows_inner(rows: &dyn Rows<usize>, ncols: Option<usize>) {
331 assert_eq!(rows.ncols(), ncols);
332 assert_eq!(rows.nrows(), 3);
333 assert_eq!(rows.row(0), &[0, 1, 2, 3]);
334 assert_eq!(rows.row(1), &[4, 5, 6, 7]);
335 assert_eq!(rows.row(2), &[8, 9, 10, 11]);
336 }
337
338 #[test]
339 fn test_rows() {
340 let mut i = 0usize;
341 let mat = Matrix::new(
342 views::Init(|| {
343 let v = i;
344 i += 1;
345 v
346 }),
347 3,
348 4,
349 );
350
351 test_rows_inner(&mat, Some(4));
352 test_rows_inner(&(mat.as_view()), Some(4));
353
354 let vecs = vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 11]];
355 test_rows_inner(&vecs, None);
356 }
357
358 struct ExpectedRecall {
359 recall_k: usize,
360 recall_n: usize,
361 components: Vec<usize>,
363 }
364
365 impl ExpectedRecall {
366 fn new(recall_k: usize, recall_n: usize, components: Vec<usize>) -> Self {
367 assert!(recall_k <= recall_n);
368 components.iter().for_each(|x| {
369 assert!(*x <= recall_k);
370 });
371 Self {
372 recall_k,
373 recall_n,
374 components,
375 }
376 }
377
378 fn compute_recall(&self) -> f64 {
379 (self.components.iter().sum::<usize>() as f64)
380 / ((self.components.len() * self.recall_k) as f64)
381 }
382 }
383
384 #[test]
385 fn test_happy_path() {
386 let groundtruth = Matrix::try_from(
387 vec![
388 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ]
393 .into(),
394 4,
395 10,
396 )
397 .unwrap();
398
399 let distances = Matrix::try_from(
400 vec![
401 0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, ]
406 .into(),
407 4,
408 10,
409 )
410 .unwrap();
411
412 let our_results = Matrix::try_from(
414 vec![
415 100, 0, 1, 2, 5, 6, 100, 101, 7, 8, 9, 10, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, ]
420 .into(),
421 4,
422 6,
423 )
424 .unwrap();
425
426 let expected_no_ties = vec![
430 ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]),
432 ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]),
433 ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]),
434 ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]),
435 ExpectedRecall::new(5, 5, vec![3, 3, 5, 5]),
436 ExpectedRecall::new(6, 6, vec![4, 4, 6, 6]),
437 ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]),
439 ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]),
440 ExpectedRecall::new(2, 3, vec![2, 0, 2, 2]),
441 ExpectedRecall::new(3, 5, vec![3, 1, 3, 3]),
442 ];
443 let epsilon = 1e-6; for (i, expected) in expected_no_ties.iter().enumerate() {
446 assert_eq!(expected.components.len(), our_results.nrows());
447 let recall = knn(
448 &groundtruth,
449 None,
450 &our_results,
451 expected.recall_k,
452 expected.recall_n,
453 false,
454 )
455 .unwrap();
456
457 let left = recall.average;
458 let right = expected.compute_recall();
459 assert!(
460 (left - right).abs() < epsilon,
461 "left = {}, right = {} on input {}",
462 left,
463 right,
464 i
465 );
466
467 assert_eq!(recall.num_queries, our_results.nrows());
468 assert_eq!(recall.recall_k, expected.recall_k);
469 assert_eq!(recall.recall_n, expected.recall_n);
470 assert_eq!(recall.minimum, *expected.components.iter().min().unwrap());
471 assert_eq!(recall.maximum, *expected.components.iter().max().unwrap());
472 }
473
474 let expected_with_ties = vec![
478 ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]),
480 ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]),
481 ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]),
482 ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]),
483 ExpectedRecall::new(5, 5, vec![4, 3, 5, 5]), ExpectedRecall::new(6, 6, vec![5, 4, 6, 6]), ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]),
487 ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]),
488 ExpectedRecall::new(2, 3, vec![2, 1, 2, 2]),
489 ExpectedRecall::new(4, 5, vec![4, 3, 4, 4]),
490 ];
491
492 for (i, expected) in expected_with_ties.iter().enumerate() {
493 assert_eq!(expected.components.len(), our_results.nrows());
494 let recall = knn(
495 &groundtruth,
496 Some(distances.as_view().into()),
497 &our_results,
498 expected.recall_k,
499 expected.recall_n,
500 false,
501 )
502 .unwrap();
503
504 let left = recall.average;
505 let right = expected.compute_recall();
506 assert!(
507 (left - right).abs() < epsilon,
508 "left = {}, right = {} on input {}",
509 left,
510 right,
511 i
512 );
513
514 assert_eq!(recall.num_queries, our_results.nrows());
515 assert_eq!(recall.recall_k, expected.recall_k);
516 assert_eq!(recall.recall_n, expected.recall_n);
517 assert_eq!(recall.minimum, *expected.components.iter().min().unwrap());
518 assert_eq!(recall.maximum, *expected.components.iter().max().unwrap());
519 }
520 }
521
522 #[test]
523 fn test_errors() {
524 {
526 let groundtruth = Matrix::<u32>::new(0, 10, 10);
527 let results = Matrix::<u32>::new(0, 10, 10);
528 let err = knn(&groundtruth, None, &results, 11, 10, false).unwrap_err();
529 assert!(matches!(err, ComputeRecallError::RecallKAndNError(..)));
530 }
531
532 {
534 let groundtruth = Matrix::<u32>::new(0, 11, 10);
535 let results = Matrix::<u32>::new(0, 10, 10);
536 let err = knn(&groundtruth, None, &results, 10, 10, false).unwrap_err();
537 assert!(matches!(err, ComputeRecallError::RowsMismatch(..)));
538 let err_allow_insufficient_results =
539 knn(&groundtruth, None, &results, 10, 10, true).unwrap_err();
540 assert!(matches!(
541 err_allow_insufficient_results,
542 ComputeRecallError::RowsMismatch(..)
543 ));
544 }
545
546 {
548 let groundtruth = Matrix::<u32>::new(0, 10, 10);
549 let results = Matrix::<u32>::new(0, 10, 5);
550 let err = knn(&groundtruth, None, &results, 5, 10, false).unwrap_err();
551 assert!(matches!(err, ComputeRecallError::NotEnoughResults(..)));
552 let _ = knn(&groundtruth, None, &results, 5, 10, true);
553 }
554
555 {
557 let groundtruth = Matrix::<u32>::new(0, 10, 10);
558 let results: Vec<_> = (0..10).map(|_| vec![0; 5]).collect();
559 let err = knn(&groundtruth, None, &results, 5, 10, false).unwrap_err();
560 assert!(matches!(err, ComputeRecallError::NotEnoughResults(..)));
561 let _ = knn(&groundtruth, None, &results, 5, 10, true);
562 }
563
564 {
566 let groundtruth = Matrix::<u32>::new(0, 10, 5);
567 let results = Matrix::<u32>::new(0, 10, 10);
568 let err = knn(&groundtruth, None, &results, 10, 10, false).unwrap_err();
569 assert!(matches!(err, ComputeRecallError::NotEnoughGroundTruth(..)));
570 let err_allow_insufficient_results =
571 knn(&groundtruth, None, &results, 10, 10, true).unwrap_err();
572 assert!(matches!(
573 err_allow_insufficient_results,
574 ComputeRecallError::NotEnoughGroundTruth(..)
575 ));
576 }
577
578 {
580 let groundtruth: Vec<_> = (0..10).map(|_| vec![0; 5]).collect();
581 let results = Matrix::<u32>::new(0, 10, 10);
582 let err = knn(&groundtruth, None, &results, 10, 10, false).unwrap_err();
583 assert!(matches!(err, ComputeRecallError::NotEnoughGroundTruth(..)));
584 let err_allow_insufficient_results =
585 knn(&groundtruth, None, &results, 10, 10, true).unwrap_err();
586 assert!(matches!(
587 err_allow_insufficient_results,
588 ComputeRecallError::NotEnoughGroundTruth(..)
589 ));
590 }
591
592 {
594 let groundtruth = Matrix::<u32>::new(0, 10, 10);
595 let distances = Matrix::<f32>::new(0.0, 9, 10);
596 let results = Matrix::<u32>::new(0, 10, 10);
597 let err = knn(
598 &groundtruth,
599 Some(distances.as_view().into()),
600 &results,
601 10,
602 10,
603 false,
604 )
605 .unwrap_err();
606 assert!(matches!(err, ComputeRecallError::DistanceRowsMismatch(..)));
607 }
608
609 {
611 let groundtruth = Matrix::<u32>::new(0, 10, 10);
612 let distances = Matrix::<f32>::new(0.0, 10, 9);
613 let results = Matrix::<u32>::new(0, 10, 10);
614 let err = knn(
615 &groundtruth,
616 Some(distances.as_view().into()),
617 &results,
618 10,
619 10,
620 false,
621 )
622 .unwrap_err();
623 assert!(matches!(
624 err,
625 ComputeRecallError::NotEnoughGroundTruthDistances(..)
626 ));
627 }
628 }
629}