1use scirs2_core::ndarray::{Array2, ArrayView2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::collections::HashMap;
10use std::fmt::Debug;
11
12use crate::error::{ClusteringError, Result};
13use crate::hierarchy::{LinkageMethod, Metric};
14
15#[derive(Debug, Clone)]
19pub struct SparseDistanceMatrix<F: Float> {
20 rows: Vec<usize>,
22 cols: Vec<usize>,
24 data: Vec<F>,
26 n_samples: usize,
28 default_value: F,
30}
31
32impl<F: Float + FromPrimitive> SparseDistanceMatrix<F> {
33 pub fn new(n_samples: usize, default_value: F) -> Self {
35 Self {
36 rows: Vec::new(),
37 cols: Vec::new(),
38 data: Vec::new(),
39 n_samples,
40 default_value,
41 }
42 }
43
44 pub fn from_dense(dense: ArrayView2<F>, threshold: F) -> Self {
46 let n_samples = dense.shape()[0];
47 let mut rows = Vec::new();
48 let mut cols = Vec::new();
49 let mut data = Vec::new();
50
51 for i in 0..n_samples {
52 for j in (i + 1)..n_samples {
53 let distance = dense[[i, j]];
54 if distance > threshold {
55 rows.push(i);
56 cols.push(j);
57 data.push(distance);
58 }
59 }
60 }
61
62 Self {
63 rows,
64 cols,
65 data,
66 n_samples,
67 default_value: F::zero(),
68 }
69 }
70
71 pub fn add_distance(&mut self, i: usize, j: usize, distance: F) -> Result<()> {
73 if i >= self.n_samples || j >= self.n_samples {
74 return Err(ClusteringError::InvalidInput("Index out of bounds".into()));
75 }
76
77 let (row, col) = if i < j { (i, j) } else { (j, i) };
79
80 for idx in 0..self.rows.len() {
82 if self.rows[idx] == row && self.cols[idx] == col {
83 if distance < self.data[idx] {
85 self.data[idx] = distance;
86 }
87 return Ok(());
88 }
89 }
90
91 self.rows.push(row);
93 self.cols.push(col);
94 self.data.push(distance);
95
96 Ok(())
97 }
98
99 pub fn get_distance(&self, i: usize, j: usize) -> F {
101 if i == j {
102 return F::zero();
103 }
104
105 let (row, col) = if i < j { (i, j) } else { (j, i) };
106
107 for idx in 0..self.rows.len() {
109 if self.rows[idx] == row && self.cols[idx] == col {
110 return self.data[idx];
111 }
112 }
113
114 self.default_value
115 }
116
117 pub fn neighbors_within_distance(&self, point: usize, maxdistance: F) -> Vec<(usize, F)> {
119 let mut neighbors = Vec::new();
120
121 for idx in 0..self.rows.len() {
123 let (neighbor, distance) = if self.rows[idx] == point {
124 (self.cols[idx], self.data[idx])
125 } else if self.cols[idx] == point {
126 (self.rows[idx], self.data[idx])
127 } else {
128 continue;
129 };
130
131 if distance <= maxdistance {
132 neighbors.push((neighbor, distance));
133 }
134 }
135
136 neighbors
137 }
138
139 pub fn k_nearest_neighbors(&self, point: usize, k: usize) -> Vec<(usize, F)> {
141 let mut all_neighbors = Vec::new();
142
143 for idx in 0..self.rows.len() {
145 let (neighbor, distance) = if self.rows[idx] == point {
146 (self.cols[idx], self.data[idx])
147 } else if self.cols[idx] == point {
148 (self.rows[idx], self.data[idx])
149 } else {
150 continue;
151 };
152
153 all_neighbors.push((neighbor, distance));
154 }
155
156 all_neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("Operation failed"));
158 all_neighbors.truncate(k);
159
160 all_neighbors
161 }
162
163 pub fn to_dense(&self) -> Array2<F> {
165 let mut dense = Array2::from_elem((self.n_samples, self.n_samples), self.default_value);
166
167 for i in 0..self.n_samples {
169 dense[[i, i]] = F::zero();
170 }
171
172 for idx in 0..self.rows.len() {
174 let i = self.rows[idx];
175 let j = self.cols[idx];
176 let distance = self.data[idx];
177
178 dense[[i, j]] = distance;
179 dense[[j, i]] = distance;
180 }
181
182 dense
183 }
184
185 pub fn nnz(&self) -> usize {
187 self.data.len()
188 }
189
190 pub fn sparsity(&self) -> f64 {
192 let total_entries = self.n_samples * (self.n_samples - 1) / 2;
193 1.0 - (self.nnz() as f64 / total_entries as f64)
194 }
195
196 pub fn n_samples(&self) -> usize {
198 self.n_samples
199 }
200}
201
202pub struct SparseHierarchicalClustering<F: Float> {
207 sparse_matrix: SparseDistanceMatrix<F>,
208 linkage_method: LinkageMethod,
209}
210
211impl<F: Float + FromPrimitive + Debug + PartialOrd> SparseHierarchicalClustering<F> {
212 pub fn new(sparse_matrix: SparseDistanceMatrix<F>, linkage_method: LinkageMethod) -> Self {
214 Self {
215 sparse_matrix,
216 linkage_method,
217 }
218 }
219
220 pub fn fit(&self) -> Result<Array2<F>> {
222 let n_samples = self.sparse_matrix.n_samples();
223
224 if n_samples < 2 {
225 return Err(ClusteringError::InvalidInput(
226 "Need at least 2 samples for clustering".into(),
227 ));
228 }
229
230 let mst_edges = self.minimum_spanning_tree()?;
232
233 self.mst_to_linkage(mst_edges)
235 }
236
237 fn minimum_spanning_tree(&self) -> Result<Vec<(usize, usize, F)>> {
239 let n_samples = self.sparse_matrix.n_samples();
240 let mut mst_edges = Vec::new();
241 let mut visited = vec![false; n_samples];
242 let mut min_edge: HashMap<usize, (usize, F)> = HashMap::new();
243
244 visited[0] = true;
246
247 for neighbor_idx in 0..self.sparse_matrix.rows.len() {
249 let (i, j) = (
250 self.sparse_matrix.rows[neighbor_idx],
251 self.sparse_matrix.cols[neighbor_idx],
252 );
253 let distance = self.sparse_matrix.data[neighbor_idx];
254
255 if i == 0 && !visited[j] {
256 min_edge.insert(j, (i, distance));
257 } else if j == 0 && !visited[i] {
258 min_edge.insert(i, (j, distance));
259 }
260 }
261
262 for _ in 1..n_samples {
264 let mut min_dist = F::infinity();
266 let mut min_vertex = 0;
267 let mut min_parent = 0;
268
269 for (&vertex, &(parent, distance)) in &min_edge {
270 if !visited[vertex] && distance < min_dist {
271 min_dist = distance;
272 min_vertex = vertex;
273 min_parent = parent;
274 }
275 }
276
277 if min_dist == F::infinity() {
278 min_dist = self.sparse_matrix.default_value;
280 }
281
282 mst_edges.push((min_parent, min_vertex, min_dist));
284 visited[min_vertex] = true;
285
286 for neighbor_idx in 0..self.sparse_matrix.rows.len() {
288 let (i, j) = (
289 self.sparse_matrix.rows[neighbor_idx],
290 self.sparse_matrix.cols[neighbor_idx],
291 );
292 let distance = self.sparse_matrix.data[neighbor_idx];
293
294 let (from_vertex, to_vertex) = if i == min_vertex && !visited[j] {
295 (i, j)
296 } else if j == min_vertex && !visited[i] {
297 (j, i)
298 } else {
299 continue;
300 };
301
302 match min_edge.get(&to_vertex) {
304 Some(&(_, current_dist)) if distance < current_dist => {
305 min_edge.insert(to_vertex, (from_vertex, distance));
306 }
307 None => {
308 min_edge.insert(to_vertex, (from_vertex, distance));
309 }
310 _ => {}
311 }
312 }
313 }
314
315 Ok(mst_edges)
316 }
317
318 fn mst_to_linkage(&self, mut mst_edges: Vec<(usize, usize, F)>) -> Result<Array2<F>> {
320 let n_samples = self.sparse_matrix.n_samples();
321
322 match self.linkage_method {
324 LinkageMethod::Single => {
325 mst_edges.sort_by(|a, b| a.2.partial_cmp(&b.2).expect("Operation failed"));
327 }
328 _ => {
329 }
332 }
333
334 let mut linkage_matrix = Array2::zeros((n_samples - 1, 4));
335 let mut cluster_map: HashMap<usize, usize> = HashMap::new();
336 let mut next_cluster_id = n_samples;
337
338 for i in 0..n_samples {
340 cluster_map.insert(i, i);
341 }
342
343 for (step, (i, j, distance)) in mst_edges.iter().enumerate() {
344 let cluster_i = cluster_map[i];
345 let cluster_j = cluster_map[j];
346
347 linkage_matrix[[step, 0]] = F::from(cluster_i).expect("Failed to convert to float");
349 linkage_matrix[[step, 1]] = F::from(cluster_j).expect("Failed to convert to float");
350 linkage_matrix[[step, 2]] = *distance;
351 linkage_matrix[[step, 3]] = F::from(2).expect("Failed to convert constant to float"); cluster_map.insert(*i, next_cluster_id);
355 cluster_map.insert(*j, next_cluster_id);
356 next_cluster_id += 1;
357 }
358
359 Ok(linkage_matrix)
360 }
361}
362
363#[allow(dead_code)]
365pub fn sparse_knn_graph<F>(
366 data: ArrayView2<F>,
367 k: usize,
368 metric: Metric,
369) -> Result<SparseDistanceMatrix<F>>
370where
371 F: Float + FromPrimitive + Debug,
372{
373 let n_samples = data.shape()[0];
374 let n_features = data.shape()[1];
375
376 if k >= n_samples {
377 return Err(ClusteringError::InvalidInput(
378 "k must be less than number of samples".into(),
379 ));
380 }
381
382 let mut sparse_matrix = SparseDistanceMatrix::new(n_samples, F::infinity());
383
384 for i in 0..n_samples {
386 let mut distances: Vec<(usize, F)> = Vec::new();
387
388 for j in 0..n_samples {
390 if i == j {
391 continue;
392 }
393
394 let dist = match metric {
395 Metric::Euclidean => {
396 let mut sum = F::zero();
397 for k in 0..n_features {
398 let diff = data[[i, k]] - data[[j, k]];
399 sum = sum + diff * diff;
400 }
401 sum.sqrt()
402 }
403 Metric::Manhattan => {
404 let mut sum = F::zero();
405 for k in 0..n_features {
406 let diff = (data[[i, k]] - data[[j, k]]).abs();
407 sum = sum + diff;
408 }
409 sum
410 }
411 Metric::Chebyshev => {
412 let mut max_diff = F::zero();
413 for k in 0..n_features {
414 let diff = (data[[i, k]] - data[[j, k]]).abs();
415 if diff > max_diff {
416 max_diff = diff;
417 }
418 }
419 max_diff
420 }
421 Metric::Cosine => {
422 let mut dot = F::zero();
423 let mut norm_i = F::zero();
424 let mut norm_j = F::zero();
425 for k in 0..n_features {
426 let vi = data[[i, k]];
427 let vj = data[[j, k]];
428 dot = dot + vi * vj;
429 norm_i = norm_i + vi * vi;
430 norm_j = norm_j + vj * vj;
431 }
432 let norm_prod = (norm_i * norm_j).sqrt();
433 if norm_prod
434 < F::from_f64(1e-10).ok_or_else(|| {
435 ClusteringError::InvalidInput("float conversion failed".into())
436 })?
437 {
438 F::one()
439 } else {
440 F::one() - dot / norm_prod
441 }
442 }
443 Metric::Correlation => {
444 let n_f = F::from_usize(n_features).ok_or_else(|| {
445 ClusteringError::InvalidInput("float conversion failed".into())
446 })?;
447 let mut mean_i = F::zero();
448 let mut mean_j = F::zero();
449 for k in 0..n_features {
450 mean_i = mean_i + data[[i, k]];
451 mean_j = mean_j + data[[j, k]];
452 }
453 mean_i = mean_i / n_f;
454 mean_j = mean_j / n_f;
455
456 let mut numerator = F::zero();
457 let mut denom_i = F::zero();
458 let mut denom_j = F::zero();
459 for k in 0..n_features {
460 let di = data[[i, k]] - mean_i;
461 let dj = data[[j, k]] - mean_j;
462 numerator = numerator + di * dj;
463 denom_i = denom_i + di * di;
464 denom_j = denom_j + dj * dj;
465 }
466 let denom = (denom_i * denom_j).sqrt();
467 if denom
468 < F::from_f64(1e-10).ok_or_else(|| {
469 ClusteringError::InvalidInput("float conversion failed".into())
470 })?
471 {
472 F::zero()
473 } else {
474 F::one() - numerator / denom
475 }
476 }
477 };
478
479 distances.push((j, dist));
480 }
481
482 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("Operation failed"));
484 distances.truncate(k);
485
486 for (neighbor, distance) in distances {
488 sparse_matrix.add_distance(i, neighbor, distance)?;
489 }
490 }
491
492 Ok(sparse_matrix)
493}
494
495#[allow(dead_code)]
497pub fn sparse_epsilon_graph<F>(
498 data: ArrayView2<F>,
499 epsilon: F,
500 metric: Metric,
501) -> Result<SparseDistanceMatrix<F>>
502where
503 F: Float + FromPrimitive + Debug,
504{
505 let n_samples = data.shape()[0];
506 let n_features = data.shape()[1];
507
508 let mut sparse_matrix = SparseDistanceMatrix::new(n_samples, F::infinity());
509
510 for i in 0..n_samples {
512 for j in (i + 1)..n_samples {
513 let dist = match metric {
514 Metric::Euclidean => {
515 let mut sum = F::zero();
516 for k in 0..n_features {
517 let diff = data[[i, k]] - data[[j, k]];
518 sum = sum + diff * diff;
519 }
520 sum.sqrt()
521 }
522 Metric::Manhattan => {
523 let mut sum = F::zero();
524 for k in 0..n_features {
525 let diff = (data[[i, k]] - data[[j, k]]).abs();
526 sum = sum + diff;
527 }
528 sum
529 }
530 Metric::Chebyshev => {
531 let mut max_diff = F::zero();
532 for k in 0..n_features {
533 let diff = (data[[i, k]] - data[[j, k]]).abs();
534 if diff > max_diff {
535 max_diff = diff;
536 }
537 }
538 max_diff
539 }
540 Metric::Cosine => {
541 let mut dot = F::zero();
542 let mut norm_i = F::zero();
543 let mut norm_j = F::zero();
544 for k in 0..n_features {
545 let vi = data[[i, k]];
546 let vj = data[[j, k]];
547 dot = dot + vi * vj;
548 norm_i = norm_i + vi * vi;
549 norm_j = norm_j + vj * vj;
550 }
551 let norm_prod = (norm_i * norm_j).sqrt();
552 if norm_prod
553 < F::from_f64(1e-10).ok_or_else(|| {
554 ClusteringError::InvalidInput("float conversion failed".into())
555 })?
556 {
557 F::one()
558 } else {
559 F::one() - dot / norm_prod
560 }
561 }
562 Metric::Correlation => {
563 let n_f = F::from_usize(n_features).ok_or_else(|| {
564 ClusteringError::InvalidInput("float conversion failed".into())
565 })?;
566 let mut mean_i = F::zero();
567 let mut mean_j = F::zero();
568 for k in 0..n_features {
569 mean_i = mean_i + data[[i, k]];
570 mean_j = mean_j + data[[j, k]];
571 }
572 mean_i = mean_i / n_f;
573 mean_j = mean_j / n_f;
574
575 let mut numerator = F::zero();
576 let mut denom_i = F::zero();
577 let mut denom_j = F::zero();
578 for k in 0..n_features {
579 let di = data[[i, k]] - mean_i;
580 let dj = data[[j, k]] - mean_j;
581 numerator = numerator + di * dj;
582 denom_i = denom_i + di * di;
583 denom_j = denom_j + dj * dj;
584 }
585 let denom = (denom_i * denom_j).sqrt();
586 if denom
587 < F::from_f64(1e-10).ok_or_else(|| {
588 ClusteringError::InvalidInput("float conversion failed".into())
589 })?
590 {
591 F::zero()
592 } else {
593 F::one() - numerator / denom
594 }
595 }
596 };
597
598 if dist <= epsilon {
599 sparse_matrix.add_distance(i, j, dist)?;
600 }
601 }
602 }
603
604 Ok(sparse_matrix)
605}
606
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds an `n`-sample matrix with the given default, storing every `(i, j, d)` edge.
    fn matrix_with(
        n: usize,
        default: f64,
        edges: &[(usize, usize, f64)],
    ) -> SparseDistanceMatrix<f64> {
        let mut matrix = SparseDistanceMatrix::new(n, default);
        for &(i, j, d) in edges {
            matrix.add_distance(i, j, d).expect("Operation failed");
        }
        matrix
    }

    #[test]
    fn test_sparse_distance_matrix_creation() {
        let empty = SparseDistanceMatrix::<f64>::new(5, 0.0);

        assert_eq!(empty.n_samples(), 5);
        assert_eq!(empty.nnz(), 0);
        assert_eq!(empty.sparsity(), 1.0);
    }

    #[test]
    fn test_sparse_distance_matrix_add_distance() {
        let matrix = matrix_with(3, 0.0, &[(0, 1, 2.0), (1, 2, 3.0)]);

        assert_eq!(matrix.get_distance(0, 1), 2.0);
        // Lookups are symmetric.
        assert_eq!(matrix.get_distance(1, 0), 2.0);
        assert_eq!(matrix.get_distance(1, 2), 3.0);
        // Unstored pairs fall back to the default value.
        assert_eq!(matrix.get_distance(0, 2), 0.0);
        assert_eq!(matrix.nnz(), 2);
    }

    #[test]
    fn test_sparse_from_dense() {
        let dense = Array2::from_shape_vec(
            (3, 3),
            vec![0.0, 1.0, 5.0, 1.0, 0.0, 2.0, 5.0, 2.0, 0.0],
        )
        .expect("Operation failed");

        let sparse = SparseDistanceMatrix::from_dense(dense.view(), 1.5);

        // Only the entries above the 1.5 threshold survive.
        assert_eq!(sparse.nnz(), 2);
        assert_eq!(sparse.get_distance(0, 2), 5.0);
        assert_eq!(sparse.get_distance(1, 2), 2.0);
        assert_eq!(sparse.get_distance(0, 1), 0.0);
    }

    #[test]
    fn test_neighbors_within_distance() {
        let matrix = matrix_with(
            4,
            f64::INFINITY,
            &[(0, 1, 1.0), (0, 2, 2.5), (0, 3, 0.5)],
        );

        let neighbors = matrix.neighbors_within_distance(0, 2.0);
        assert_eq!(neighbors.len(), 2);

        let mut found: Vec<f64> = neighbors.into_iter().map(|(_, d)| d).collect();
        found.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
        assert_eq!(found, vec![0.5, 1.0]);
    }

    #[test]
    fn test_k_nearest_neighbors() {
        let matrix = matrix_with(
            5,
            f64::INFINITY,
            &[(0, 1, 3.0), (0, 2, 1.0), (0, 3, 2.0), (0, 4, 4.0)],
        );

        let nearest = matrix.k_nearest_neighbors(0, 2);

        assert_eq!(nearest.len(), 2);
        assert_eq!(nearest[0], (2, 1.0));
        assert_eq!(nearest[1], (3, 2.0));
    }

    #[test]
    fn test_sparse_knn_graph() {
        let points = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 5.0, 5.0])
            .expect("Operation failed");

        let graph =
            sparse_knn_graph(points.view(), 2, Metric::Euclidean).expect("Operation failed");

        assert!(graph.nnz() > 0);
        assert!(graph.sparsity() > 0.0);
    }

    #[test]
    fn test_sparse_epsilon_graph() {
        let points = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 5.0, 5.0])
            .expect("Operation failed");

        let graph = sparse_epsilon_graph(points.view(), 1.0, Metric::Euclidean)
            .expect("Operation failed");

        assert!(graph.nnz() >= 3);
        assert!(graph.get_distance(0, 1) <= 1.0);
        assert!(graph.get_distance(0, 2) <= 1.0);
    }

    #[test]
    fn test_to_dense() {
        let matrix = matrix_with(3, f64::INFINITY, &[(0, 1, 2.0), (1, 2, 3.0)]);

        let dense = matrix.to_dense();

        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense[[0, 1]], 2.0);
        assert_eq!(dense[[1, 0]], 2.0);
        assert_eq!(dense[[1, 2]], 3.0);
        assert_eq!(dense[[2, 1]], 3.0);
        assert_eq!(dense[[0, 0]], 0.0);
        assert_eq!(dense[[0, 2]], f64::INFINITY);
    }

    #[test]
    fn test_sparse_knn_graph_chebyshev() {
        let points =
            Array2::from_shape_vec((4, 2), vec![0.0_f64, 0.0, 1.0, 0.0, 0.0, 1.0, 5.0, 5.0])
                .expect("Shape error");

        let graph = sparse_knn_graph(points.view(), 2, Metric::Chebyshev)
            .expect("sparse_knn_graph Chebyshev failed");

        assert!(graph.nnz() > 0, "Chebyshev KNN graph should have edges");
        assert!(
            graph.get_distance(0, 1) > 0.0,
            "points 0 and 1 should be neighbours"
        );
    }

    #[test]
    fn test_sparse_knn_graph_cosine() {
        let points =
            Array2::from_shape_vec((4, 2), vec![1.0_f64, 0.0, 2.0, 0.0, 0.0, 1.0, 0.0, 3.0])
                .expect("Shape error");

        let graph = sparse_knn_graph(points.view(), 2, Metric::Cosine)
            .expect("sparse_knn_graph Cosine failed");

        assert!(graph.nnz() > 0, "Cosine KNN graph should have edges");
        assert_eq!(
            graph.get_distance(0, 1),
            0.0,
            "parallel vectors have cosine distance 0"
        );
    }

    #[test]
    fn test_sparse_knn_graph_correlation() {
        let rows = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0_f64, 2.0, 3.0, 4.0, // increasing
                2.0, 4.0, 6.0, 8.0, // scaled copy of row 0
                4.0, 3.0, 2.0, 1.0, // reversed
            ],
        )
        .expect("Shape error");

        let graph = sparse_knn_graph(rows.view(), 2, Metric::Correlation)
            .expect("sparse_knn_graph Correlation failed");

        assert!(graph.nnz() > 0, "Correlation KNN graph should have edges");
        let d01 = graph.get_distance(0, 1);
        assert!(
            d01 < 1e-9,
            "perfectly correlated rows have correlation distance ≈ 0, got {d01}"
        );
    }

    #[test]
    fn test_sparse_epsilon_graph_chebyshev() {
        let points = Array2::from_shape_vec((3, 2), vec![0.0_f64, 0.0, 0.8, 0.0, 5.0, 5.0])
            .expect("Shape error");

        let graph = sparse_epsilon_graph(points.view(), 1.0, Metric::Chebyshev)
            .expect("sparse_epsilon_graph Chebyshev failed");

        assert!(
            graph.get_distance(0, 1) > 0.0,
            "points 0 and 1 should be connected under Chebyshev"
        );
        assert_eq!(
            graph.get_distance(0, 2),
            f64::INFINITY,
            "distant point should be disconnected"
        );
    }

    #[test]
    fn test_sparse_epsilon_graph_cosine() {
        let points = Array2::from_shape_vec((3, 2), vec![1.0_f64, 0.0, 2.0, 0.0, 0.0, 1.0])
            .expect("Shape error");

        let graph = sparse_epsilon_graph(points.view(), 0.5, Metric::Cosine)
            .expect("sparse_epsilon_graph Cosine failed");

        assert!(
            graph.get_distance(0, 1) < 0.5,
            "parallel vectors connected under cosine epsilon graph"
        );
        assert_eq!(
            graph.get_distance(0, 2),
            f64::INFINITY,
            "orthogonal vector should be disconnected"
        );
    }

    #[test]
    fn test_sparse_epsilon_graph_correlation() {
        let rows = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0_f64, 2.0, 3.0, 4.0, // increasing
                2.0, 4.0, 6.0, 8.0, // scaled copy of row 0
                4.0, 3.0, 2.0, 1.0, // reversed
            ],
        )
        .expect("Shape error");

        let graph = sparse_epsilon_graph(rows.view(), 0.01, Metric::Correlation)
            .expect("sparse_epsilon_graph Correlation failed");

        let d01 = graph.get_distance(0, 1);
        assert!(
            d01 < 0.01,
            "perfectly correlated rows connected under correlation epsilon graph, got {d01}"
        );
    }
}