1use std::collections::HashMap;
27
28use num::Bounded;
29
30use crate::error::{Failed, FailedError};
31use crate::linalg::basic::arrays::{Array1, Array2};
32use crate::metrics::distance::euclidian::Euclidian;
33use crate::metrics::distance::PairwiseDistance;
34use crate::numbers::floatnum::FloatNumber;
35use crate::numbers::realnum::RealNumber;
36
37#[derive(Debug, Clone)]
45pub struct FastPair<'a, T: RealNumber + FloatNumber, M: Array2<T>> {
46 samples: &'a M,
48 pub distances: HashMap<usize, PairwiseDistance<T>>,
50 pub neighbours: Vec<usize>,
52}
53
54impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
55 pub fn new(m: &'a M) -> Result<Self, Failed> {
58 if m.shape().0 < 3 {
59 return Err(Failed::because(
60 FailedError::FindFailed,
61 "min number of rows should be 3",
62 ));
63 }
64
65 let mut init = Self {
66 samples: m,
67 distances: HashMap::with_capacity(m.shape().0),
69 neighbours: Vec::with_capacity(m.shape().0 + 1),
70 };
71 init.init();
72 Ok(init)
73 }
74
75 fn init(&mut self) {
78 let len = self.samples.shape().0;
80 let max_index = self.samples.shape().0 - 1;
81
82 let _distances = Box::new(HashMap::with_capacity(len));
84 let _neighbours = Box::new(Vec::with_capacity(len));
85
86 let mut distances = *_distances;
87 let mut neighbours = *_neighbours;
88
89 neighbours.extend(0..len);
91
92 for index_row_i in 0..(max_index) {
94 distances.insert(
95 index_row_i,
96 PairwiseDistance {
97 node: index_row_i,
98 neighbour: Option::None,
99 distance: Some(<T as Bounded>::max_value()),
100 },
101 );
102 }
103
104 for index_row_i in 0..(len) {
106 let mut index_closest = index_row_i + 1; let mut nbd: Option<T> = distances[&index_row_i].distance; for index_row_j in (index_row_i + 1)..len {
110 distances.insert(
111 index_row_j,
112 PairwiseDistance {
113 node: index_row_j,
114 neighbour: Some(index_row_i),
115 distance: nbd,
116 },
117 );
118
119 let d = Euclidian::squared_distance(
120 &Vec::from_iterator(
121 self.samples.get_row(index_row_i).iterator(0).copied(),
122 self.samples.shape().1,
123 ),
124 &Vec::from_iterator(
125 self.samples.get_row(index_row_j).iterator(0).copied(),
126 self.samples.shape().1,
127 ),
128 );
129 if d < nbd.unwrap().to_f64().unwrap() {
130 index_closest = index_row_j;
132 nbd = Some(T::from(d).unwrap());
133 }
134 }
135
136 distances.entry(index_row_i).and_modify(|e| {
138 e.distance = nbd;
139 e.neighbour = Some(index_closest);
140 });
141 }
142 distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
145 distances.get_mut(&(len - 1)).unwrap().distance = Some(<T as Bounded>::max_value());
146
147 let mut sparse_matrix = M::zeros(len, len);
149 for (_, p) in distances.iter() {
150 sparse_matrix.set((p.node, p.neighbour.unwrap()), p.distance.unwrap());
151 }
152
153 self.distances = distances;
154 self.neighbours = neighbours;
155 }
156
157 #[allow(dead_code)]
159 pub fn closest_pair(&self) -> PairwiseDistance<T> {
160 let mut a = self.neighbours[0]; let mut d = self.distances[&a].distance;
162 for p in self.neighbours.iter() {
163 if self.distances[p].distance < d {
164 a = *p; d = self.distances[p].distance;
166 }
167 }
168 let b = self.distances[&a].neighbour;
169 PairwiseDistance {
170 node: a,
171 neighbour: b,
172 distance: d,
173 }
174 }
175
176 #[allow(dead_code)]
180 pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance<T>> {
181 let mut distances = self
184 .distances
185 .values()
186 .collect::<Vec<&PairwiseDistance<T>>>();
187 distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
188 distances.into_iter()
189 }
190
191 #[allow(dead_code)]
196 fn distances_from(&self, index_row: usize) -> Vec<PairwiseDistance<T>> {
197 let mut distances = Vec::<PairwiseDistance<T>>::with_capacity(self.samples.shape().0);
198 for other in self.neighbours.iter() {
199 if index_row != *other {
200 distances.push(PairwiseDistance {
201 node: index_row,
202 neighbour: Some(*other),
203 distance: Some(
204 T::from(Euclidian::squared_distance(
205 &Vec::from_iterator(
206 self.samples.get_row(index_row).iterator(0).copied(),
207 self.samples.shape().1,
208 ),
209 &Vec::from_iterator(
210 self.samples.get_row(*other).iterator(0).copied(),
211 self.samples.shape().1,
212 ),
213 ))
214 .unwrap(),
215 ),
216 })
217 }
218 }
219 distances
220 }
221}
222
#[cfg(test)]
mod tests_fastpair {

    use super::*;
    use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};

    /// Brute-force O(n²) reference for `FastPair::closest_pair`.
    ///
    /// Checks every unordered pair of rows and keeps the first pair (in
    /// `combinations` order) with the strictly smallest squared Euclidean
    /// distance.
    pub fn closest_pair_brute(
        fastpair: &FastPair<'_, f64, DenseMatrix<f64>>,
    ) -> PairwiseDistance<f64> {
        use itertools::Itertools;
        let m = fastpair.samples.shape().0;

        let mut closest_pair = PairwiseDistance {
            node: 0,
            neighbour: Option::None,
            distance: Some(f64::MAX),
        };
        for pair in (0..m).combinations(2) {
            let d = Euclidian::squared_distance(
                &Vec::from_iterator(
                    fastpair.samples.get_row(pair[0]).iterator(0).copied(),
                    fastpair.samples.shape().1,
                ),
                &Vec::from_iterator(
                    fastpair.samples.get_row(pair[1]).iterator(0).copied(),
                    fastpair.samples.shape().1,
                ),
            );
            if d < closest_pair.distance.unwrap() {
                closest_pair.node = pair[0];
                closest_pair.neighbour = Some(pair[1]);
                closest_pair.distance = Some(d);
            }
        }
        closest_pair
    }

    /// Fifteen four-feature rows shared by several tests below.
    fn iris_subset() -> DenseMatrix<f64> {
        DenseMatrix::<f64>::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
        ])
        .unwrap()
    }

    #[test]
    fn fastpair_init() {
        let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
        let result = FastPair::new(&x);
        assert!(result.is_ok());

        let fastpair = result.unwrap();

        let distances = fastpair.distances;
        let neighbours = fastpair.neighbours;

        assert!(!distances.is_empty());
        assert!(!neighbours.is_empty());

        assert_eq!(10, neighbours.len());
        assert_eq!(10, distances.len());
    }

    #[test]
    fn dataset_has_at_least_three_points() {
        let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]).unwrap();

        let fastpair = FastPair::new(&dataset);
        assert!(fastpair.is_err());

        if let Err(e) = fastpair {
            let expected_error =
                Failed::because(FailedError::FindFailed, "min number of rows should be 3");
            assert_eq!(e, expected_error)
        }
    }

    #[test]
    fn one_dimensional_dataset_minimal() {
        let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0], &[2.0], &[9.0]]).unwrap();

        let result = FastPair::new(&dataset);
        assert!(result.is_ok());

        let fastpair = result.unwrap();
        let closest_pair = fastpair.closest_pair();
        let expected_closest_pair = PairwiseDistance {
            node: 0,
            neighbour: Some(1),
            distance: Some(4.0),
        };
        assert_eq!(closest_pair, expected_closest_pair);

        let closest_pair_brute = closest_pair_brute(&fastpair);
        assert_eq!(closest_pair_brute, expected_closest_pair);
    }

    #[test]
    fn one_dimensional_dataset_2() {
        let dataset =
            DenseMatrix::<f64>::from_2d_array(&[&[27.0], &[0.0], &[9.0], &[2.0]]).unwrap();

        let result = FastPair::new(&dataset);
        assert!(result.is_ok());

        let fastpair = result.unwrap();
        let closest_pair = fastpair.closest_pair();
        let expected_closest_pair = PairwiseDistance {
            node: 1,
            neighbour: Some(3),
            distance: Some(4.0),
        };
        assert_eq!(closest_pair, closest_pair_brute(&fastpair));
        assert_eq!(closest_pair, expected_closest_pair);
    }

    #[test]
    fn fastpair_new() {
        let x = iris_subset();
        let fastpair = FastPair::new(&x);
        assert!(fastpair.is_ok());

        let result = fastpair.unwrap();

        // Expected nearest *later* neighbour for every row; the last row is a
        // self-neighbour with the maximum distance sentinel.
        let dissimilarities = vec![
            (
                1,
                PairwiseDistance {
                    node: 1,
                    neighbour: Some(9),
                    distance: Some(0.030000000000000037),
                },
            ),
            (
                10,
                PairwiseDistance {
                    node: 10,
                    neighbour: Some(12),
                    distance: Some(0.07000000000000003),
                },
            ),
            (
                11,
                PairwiseDistance {
                    node: 11,
                    neighbour: Some(14),
                    distance: Some(0.18000000000000013),
                },
            ),
            (
                12,
                PairwiseDistance {
                    node: 12,
                    neighbour: Some(14),
                    distance: Some(0.34000000000000086),
                },
            ),
            (
                13,
                PairwiseDistance {
                    node: 13,
                    neighbour: Some(14),
                    distance: Some(1.6499999999999997),
                },
            ),
            (
                14,
                PairwiseDistance {
                    node: 14,
                    neighbour: Some(14),
                    distance: Some(f64::MAX),
                },
            ),
            (
                6,
                PairwiseDistance {
                    node: 6,
                    neighbour: Some(7),
                    distance: Some(0.18000000000000027),
                },
            ),
            (
                0,
                PairwiseDistance {
                    node: 0,
                    neighbour: Some(4),
                    distance: Some(0.01999999999999995),
                },
            ),
            (
                8,
                PairwiseDistance {
                    node: 8,
                    neighbour: Some(9),
                    distance: Some(0.3100000000000001),
                },
            ),
            (
                2,
                PairwiseDistance {
                    node: 2,
                    neighbour: Some(3),
                    distance: Some(0.0600000000000001),
                },
            ),
            (
                3,
                PairwiseDistance {
                    node: 3,
                    neighbour: Some(8),
                    distance: Some(0.08999999999999982),
                },
            ),
            (
                7,
                PairwiseDistance {
                    node: 7,
                    neighbour: Some(9),
                    distance: Some(0.10999999999999982),
                },
            ),
            (
                9,
                PairwiseDistance {
                    node: 9,
                    neighbour: Some(13),
                    distance: Some(8.69),
                },
            ),
            (
                4,
                PairwiseDistance {
                    node: 4,
                    neighbour: Some(7),
                    distance: Some(0.050000000000000086),
                },
            ),
            (
                5,
                PairwiseDistance {
                    node: 5,
                    neighbour: Some(7),
                    distance: Some(0.4900000000000002),
                },
            ),
        ];

        let expected: HashMap<_, _> = dissimilarities.into_iter().collect();
        assert_eq!(result.distances.len(), expected.len());

        let n_rows = x.shape().0;
        for i in 0..n_rows {
            let expected_pair = &expected[&i];
            let computed = &result.distances[&i];

            // Check what FastPair actually computed, not just the fixture.
            assert_eq!(computed.node, expected_pair.node);
            assert_eq!(computed.neighbour, expected_pair.neighbour);
            assert_eq!(computed.distance, expected_pair.distance);

            // Every non-sentinel fixture distance must also agree with the samples.
            if i + 1 < n_rows {
                let neighbour = expected_pair.neighbour.unwrap();
                let distance = Euclidian::squared_distance(
                    &Vec::from_iterator(
                        result.samples.get_row(i).iterator(0).copied(),
                        result.samples.shape().1,
                    ),
                    &Vec::from_iterator(
                        result.samples.get_row(neighbour).iterator(0).copied(),
                        result.samples.shape().1,
                    ),
                );
                assert_eq!(Some(distance), expected_pair.distance);
            }
        }
    }

    #[test]
    fn fastpair_closest_pair() {
        let x = iris_subset();
        let fastpair = FastPair::new(&x);
        assert!(fastpair.is_ok());

        let dissimilarity = fastpair.unwrap().closest_pair();
        let closest = PairwiseDistance {
            node: 0,
            neighbour: Some(4),
            distance: Some(0.01999999999999995),
        };

        assert_eq!(closest, dissimilarity);
    }

    #[test]
    fn fastpair_closest_pair_random_matrix() {
        let x = DenseMatrix::<f64>::rand(200, 25);
        let fastpair = FastPair::new(&x);
        assert!(fastpair.is_ok());

        let result = fastpair.unwrap();

        let dissimilarity1 = result.closest_pair();
        let dissimilarity2 = closest_pair_brute(&result);

        assert_eq!(dissimilarity1, dissimilarity2);
    }

    #[test]
    fn fastpair_distances() {
        let x = iris_subset();
        let fastpair = FastPair::new(&x);
        assert!(fastpair.is_ok());

        let dissimilarities = fastpair.unwrap().distances_from(0);

        let mut min_dissimilarity = PairwiseDistance {
            node: 0,
            neighbour: Option::None,
            distance: Some(f64::MAX),
        };
        for p in dissimilarities.iter() {
            if p.distance.unwrap() < min_dissimilarity.distance.unwrap() {
                min_dissimilarity = *p
            }
        }

        let closest = PairwiseDistance {
            node: 0,
            neighbour: Some(4),
            distance: Some(0.01999999999999995),
        };

        assert_eq!(closest, min_dissimilarity);
    }

    #[test]
    fn fastpair_ordered_pairs() {
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
        ])
        .unwrap();
        let fastpair = FastPair::new(&x).unwrap();

        let ordered: Vec<&PairwiseDistance<f64>> = fastpair.ordered_pairs().collect();
        assert_eq!(ordered.len(), 15);

        // Each consecutive pair of entries must be non-decreasing in distance.
        for window in ordered.windows(2) {
            assert!(window[1].distance.unwrap() >= window[0].distance.unwrap());
        }
    }

    #[test]
    fn test_empty_set() {
        let empty_matrix = DenseMatrix::<f64>::zeros(0, 0);
        let result = FastPair::new(&empty_matrix);
        assert!(result.is_err());
        if let Err(e) = result {
            assert_eq!(
                e,
                Failed::because(FailedError::FindFailed, "min number of rows should be 3")
            );
        }
    }

    #[test]
    fn test_single_point() {
        let single_point = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap();
        let result = FastPair::new(&single_point);
        assert!(result.is_err());
        if let Err(e) = result {
            assert_eq!(
                e,
                Failed::because(FailedError::FindFailed, "min number of rows should be 3")
            );
        }
    }

    #[test]
    fn test_two_points() {
        let two_points = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
        let result = FastPair::new(&two_points);
        assert!(result.is_err());
        if let Err(e) = result {
            assert_eq!(
                e,
                Failed::because(FailedError::FindFailed, "min number of rows should be 3")
            );
        }
    }

    #[test]
    fn test_three_identical_points() {
        let identical_points =
            DenseMatrix::from_2d_array(&[&[1.0, 1.0], &[1.0, 1.0], &[1.0, 1.0]]).unwrap();
        let result = FastPair::new(&identical_points);
        assert!(result.is_ok());
        let fastpair = result.unwrap();
        let closest_pair = fastpair.closest_pair();
        assert_eq!(closest_pair.distance, Some(0.0));
    }

    #[test]
    fn test_result_unwrapping() {
        let valid_matrix =
            DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0], &[7.0, 8.0]])
                .unwrap();

        let result = FastPair::new(&valid_matrix);
        assert!(result.is_ok());

        let _fastpair = result.unwrap();
    }
}