1use scirs2_core::ndarray::{ArrayView1, ArrayView2};
7use sklears_core::types::Float;
8
9#[derive(Debug, Clone, Copy)]
11pub enum DistanceMetric {
12 Euclidean,
13 Manhattan,
14 Chebyshev,
15 Cosine,
16 Minkowski(Float),
17 Jaccard,
18}
19
/// Distance calculator that picks between an unrolled and a plain scalar
/// code path and evaluates pairwise distances in square tiles.
#[derive(Debug, Clone)]
pub struct OptimizedDistanceComputer {
    /// Whether a supported SIMD instruction set was detected at construction.
    simd_available: bool,
    /// Edge length (in rows) of the tiles used for pairwise computation.
    block_size: usize,
}
27
28impl Default for OptimizedDistanceComputer {
29 fn default() -> Self {
30 Self::new()
31 }
32}
33
34impl OptimizedDistanceComputer {
35 pub fn new() -> Self {
37 Self {
38 simd_available: Self::detect_simd_support(),
39 block_size: Self::optimal_block_size(),
40 }
41 }
42
43 fn detect_simd_support() -> bool {
45 #[cfg(target_arch = "x86_64")]
46 {
47 std::arch::is_x86_feature_detected!("avx2")
48 || std::arch::is_x86_feature_detected!("sse2")
49 }
50 #[cfg(target_arch = "aarch64")]
51 {
52 std::arch::is_aarch64_feature_detected!("neon")
53 }
54 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
55 {
56 false
57 }
58 }
59
60 fn optimal_block_size() -> usize {
62 48
66 }
67
68 pub fn pairwise_distances(
70 &self,
71 points1: &ArrayView2<Float>,
72 points2: &ArrayView2<Float>,
73 metric: DistanceMetric,
74 ) -> scirs2_core::ndarray::Array2<Float> {
75 let (n1, _) = points1.dim();
76 let (n2, _) = points2.dim();
77 let mut distances = scirs2_core::ndarray::Array2::zeros((n1, n2));
78
79 for i_start in (0..n1).step_by(self.block_size) {
81 let i_end = (i_start + self.block_size).min(n1);
82
83 for j_start in (0..n2).step_by(self.block_size) {
84 let j_end = (j_start + self.block_size).min(n2);
85
86 for i in i_start..i_end {
88 for j in j_start..j_end {
89 let point1 = points1.row(i);
90 let point2 = points2.row(j);
91
92 distances[[i, j]] = self.compute_distance(&point1, &point2, metric);
93 }
94 }
95 }
96 }
97
98 distances
99 }
100
101 pub fn compute_distance(
103 &self,
104 point1: &ArrayView1<Float>,
105 point2: &ArrayView1<Float>,
106 metric: DistanceMetric,
107 ) -> Float {
108 if self.simd_available && point1.len() >= 4 {
109 self.compute_distance_simd(point1, point2, metric)
110 } else {
111 self.compute_distance_scalar(point1, point2, metric)
112 }
113 }
114
115 fn compute_distance_simd(
117 &self,
118 point1: &ArrayView1<Float>,
119 point2: &ArrayView1<Float>,
120 metric: DistanceMetric,
121 ) -> Float {
122 let a = point1.as_slice().unwrap();
124 let b = point2.as_slice().unwrap();
125
126 match metric {
127 DistanceMetric::Euclidean => self.euclidean_simd(a, b),
128 DistanceMetric::Manhattan => self.manhattan_simd(a, b),
129 DistanceMetric::Chebyshev => self.chebyshev_simd(a, b),
130 DistanceMetric::Cosine => self.cosine_simd(a, b),
131 DistanceMetric::Minkowski(p) => self.minkowski_simd(a, b, p),
132 DistanceMetric::Jaccard => self.jaccard_simd(a, b),
133 }
134 }
135
136 fn compute_distance_scalar(
138 &self,
139 point1: &ArrayView1<Float>,
140 point2: &ArrayView1<Float>,
141 metric: DistanceMetric,
142 ) -> Float {
143 let a = point1.as_slice().unwrap();
144 let b = point2.as_slice().unwrap();
145
146 match metric {
147 DistanceMetric::Euclidean => fallback_distance::euclidean_distance(a, b),
148 DistanceMetric::Manhattan => fallback_distance::manhattan_distance(a, b),
149 DistanceMetric::Chebyshev => fallback_distance::chebyshev_distance(a, b),
150 DistanceMetric::Cosine => fallback_distance::cosine_distance(a, b),
151 DistanceMetric::Minkowski(p) => fallback_distance::minkowski_distance(a, b, p),
152 DistanceMetric::Jaccard => fallback_distance::jaccard_distance(a, b),
153 }
154 }
155
156 fn euclidean_simd(&self, a: &[Float], b: &[Float]) -> Float {
158 #[cfg(target_arch = "x86_64")]
159 {
160 if std::arch::is_x86_feature_detected!("avx2") {
161 return unsafe { self.euclidean_avx2(a, b) };
162 }
163 }
164
165 self.euclidean_unrolled(a, b)
167 }
168
169 fn euclidean_unrolled(&self, a: &[Float], b: &[Float]) -> Float {
171 let mut sum = 0.0;
172 let len = a.len();
173 let chunks = len / 4;
174
175 for i in 0..chunks {
177 let base = i * 4;
178 let diff1 = a[base] - b[base];
179 let diff2 = a[base + 1] - b[base + 1];
180 let diff3 = a[base + 2] - b[base + 2];
181 let diff4 = a[base + 3] - b[base + 3];
182
183 sum += diff1 * diff1 + diff2 * diff2 + diff3 * diff3 + diff4 * diff4;
184 }
185
186 for i in (chunks * 4)..len {
188 let diff = a[i] - b[i];
189 sum += diff * diff;
190 }
191
192 sum.sqrt()
193 }
194
195 #[cfg(target_arch = "x86_64")]
197 unsafe fn euclidean_avx2(&self, a: &[Float], b: &[Float]) -> Float {
198 self.euclidean_unrolled(a, b)
201 }
202
203 fn manhattan_simd(&self, a: &[Float], b: &[Float]) -> Float {
205 let mut sum = 0.0;
207 let len = a.len();
208 let chunks = len / 4;
209
210 for i in 0..chunks {
211 let base = i * 4;
212 sum += (a[base] - b[base]).abs()
213 + (a[base + 1] - b[base + 1]).abs()
214 + (a[base + 2] - b[base + 2]).abs()
215 + (a[base + 3] - b[base + 3]).abs();
216 }
217
218 for i in (chunks * 4)..len {
219 sum += (a[i] - b[i]).abs();
220 }
221
222 sum
223 }
224
225 fn chebyshev_simd(&self, a: &[Float], b: &[Float]) -> Float {
227 let mut max_diff = 0.0;
228
229 for (x, y) in a.iter().zip(b.iter()) {
230 let diff = (x - y).abs();
231 if diff > max_diff {
232 max_diff = diff;
233 }
234 }
235
236 max_diff
237 }
238
239 fn cosine_simd(&self, a: &[Float], b: &[Float]) -> Float {
241 let mut dot = 0.0;
242 let mut norm_a_sq = 0.0;
243 let mut norm_b_sq = 0.0;
244
245 let len = a.len();
246 let chunks = len / 4;
247
248 for i in 0..chunks {
250 let base = i * 4;
251 for j in 0..4 {
252 let idx = base + j;
253 dot += a[idx] * b[idx];
254 norm_a_sq += a[idx] * a[idx];
255 norm_b_sq += b[idx] * b[idx];
256 }
257 }
258
259 for i in (chunks * 4)..len {
260 dot += a[i] * b[i];
261 norm_a_sq += a[i] * a[i];
262 norm_b_sq += b[i] * b[i];
263 }
264
265 1.0 - (dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt()))
266 }
267
268 fn minkowski_simd(&self, a: &[Float], b: &[Float], p: Float) -> Float {
270 let mut sum = 0.0;
271
272 for (x, y) in a.iter().zip(b.iter()) {
273 sum += (x - y).abs().powf(p);
274 }
275
276 sum.powf(1.0 / p)
277 }
278
279 fn jaccard_simd(&self, a: &[Float], b: &[Float]) -> Float {
281 let mut intersection = 0.0;
282 let mut union = 0.0;
283
284 for (x, y) in a.iter().zip(b.iter()) {
285 intersection += x.min(*y);
286 union += x.max(*y);
287 }
288
289 1.0 - (intersection / union)
290 }
291}
292
293mod fallback_distance {
295 use super::Float;
296
297 pub fn euclidean_distance(a: &[Float], b: &[Float]) -> Float {
298 a.iter()
299 .zip(b.iter())
300 .map(|(x, y)| (x - y).powi(2))
301 .sum::<Float>()
302 .sqrt()
303 }
304
305 pub fn manhattan_distance(a: &[Float], b: &[Float]) -> Float {
306 a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
307 }
308
309 pub fn chebyshev_distance(a: &[Float], b: &[Float]) -> Float {
310 a.iter()
311 .zip(b.iter())
312 .map(|(x, y)| (x - y).abs())
313 .fold(0.0, Float::max)
314 }
315
316 pub fn cosine_distance(a: &[Float], b: &[Float]) -> Float {
317 let dot: Float = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
318 let norm_a: Float = a.iter().map(|x| x * x).sum::<Float>().sqrt();
319 let norm_b: Float = b.iter().map(|x| x * x).sum::<Float>().sqrt();
320 1.0 - (dot / (norm_a * norm_b))
321 }
322
323 pub fn minkowski_distance(a: &[Float], b: &[Float], p: Float) -> Float {
324 a.iter()
325 .zip(b.iter())
326 .map(|(x, y)| (x - y).abs().powf(p))
327 .sum::<Float>()
328 .powf(1.0 / p)
329 }
330
331 pub fn jaccard_distance(a: &[Float], b: &[Float]) -> Float {
332 let intersection: Float = a.iter().zip(b.iter()).map(|(x, y)| x.min(*y)).sum();
333 let union: Float = a.iter().zip(b.iter()).map(|(x, y)| x.max(*y)).sum();
334 1.0 - (intersection / union)
335 }
336}
337
338#[derive(Debug, Clone, Copy, PartialEq)]
340pub enum SimdDistanceMetric {
341 Euclidean,
343 EuclideanSquared,
345 Manhattan,
347 Chebyshev,
349 Cosine,
351 CosineSimilarity,
353 Minkowski(Float),
355 Jaccard,
357 Hamming,
359 Canberra,
361 Braycurtis,
363 Mahalanobis,
365 Correlation,
367 Wasserstein,
369}
370
371pub fn simd_distance(
373 point1: &ArrayView1<Float>,
374 point2: &ArrayView1<Float>,
375 metric: SimdDistanceMetric,
376) -> Result<Float, Box<dyn std::error::Error>> {
377 let a = point1.as_slice().unwrap();
379 let b = point2.as_slice().unwrap();
380
381 let result = match metric {
382 SimdDistanceMetric::Euclidean => fallback_distance::euclidean_distance(a, b),
383 SimdDistanceMetric::EuclideanSquared => {
384 let euclidean = fallback_distance::euclidean_distance(a, b);
385 euclidean * euclidean
386 }
387 SimdDistanceMetric::Manhattan => fallback_distance::manhattan_distance(a, b),
388 SimdDistanceMetric::Chebyshev => fallback_distance::chebyshev_distance(a, b),
389 SimdDistanceMetric::Cosine => fallback_distance::cosine_distance(a, b),
390 SimdDistanceMetric::CosineSimilarity => 1.0 - fallback_distance::cosine_distance(a, b),
391 SimdDistanceMetric::Minkowski(p) => fallback_distance::minkowski_distance(a, b, p),
392 SimdDistanceMetric::Jaccard => fallback_distance::jaccard_distance(a, b),
393 SimdDistanceMetric::Hamming => hamming_distance_simd(a, b),
394 SimdDistanceMetric::Canberra => canberra_distance_simd(a, b),
395 SimdDistanceMetric::Braycurtis => braycurtis_distance_simd(a, b),
396 SimdDistanceMetric::Mahalanobis => {
397 return Err("Mahalanobis distance requires covariance matrix parameter".into());
398 }
399 SimdDistanceMetric::Correlation => correlation_distance_simd(a, b),
400 SimdDistanceMetric::Wasserstein => wasserstein_distance_simd(a, b),
401 };
402
403 Ok(result as Float)
404}
405
406pub fn simd_squared_euclidean_distance(
408 point1: &ArrayView1<Float>,
409 point2: &ArrayView1<Float>,
410) -> Result<Float, Box<dyn std::error::Error>> {
411 let a = point1.as_slice().unwrap();
412 let b = point2.as_slice().unwrap();
413
414 let euclidean = fallback_distance::euclidean_distance(a, b);
415 Ok(euclidean * euclidean)
416}
417
418pub fn simd_distance_batch(
420 points: &[scirs2_core::ndarray::Array1<Float>],
421 queries: &[scirs2_core::ndarray::Array1<Float>],
422 metric: SimdDistanceMetric,
423) -> Result<Vec<Float>, Box<dyn std::error::Error>> {
424 if points.len() != queries.len() {
425 return Err("Points and queries must have the same length".into());
426 }
427
428 let mut results = Vec::with_capacity(points.len());
429
430 for (point, query) in points.iter().zip(queries.iter()) {
431 let point_slice = point.as_slice().unwrap();
432 let query_slice = query.as_slice().unwrap();
433
434 let distance = match metric {
435 SimdDistanceMetric::Euclidean => {
436 fallback_distance::euclidean_distance(point_slice, query_slice)
437 }
438 SimdDistanceMetric::EuclideanSquared => {
439 let euclidean = fallback_distance::euclidean_distance(point_slice, query_slice);
440 euclidean * euclidean
441 }
442 SimdDistanceMetric::Manhattan => {
443 fallback_distance::manhattan_distance(point_slice, query_slice)
444 }
445 SimdDistanceMetric::Chebyshev => {
446 fallback_distance::chebyshev_distance(point_slice, query_slice)
447 }
448 SimdDistanceMetric::Cosine => {
449 fallback_distance::cosine_distance(point_slice, query_slice)
450 }
451 SimdDistanceMetric::CosineSimilarity => {
452 1.0 - fallback_distance::cosine_distance(point_slice, query_slice)
453 }
454 SimdDistanceMetric::Minkowski(p) => {
455 fallback_distance::minkowski_distance(point_slice, query_slice, p)
456 }
457 SimdDistanceMetric::Jaccard => {
458 fallback_distance::jaccard_distance(point_slice, query_slice)
459 }
460 SimdDistanceMetric::Hamming => hamming_distance_simd(point_slice, query_slice),
461 SimdDistanceMetric::Canberra => canberra_distance_simd(point_slice, query_slice),
462 SimdDistanceMetric::Braycurtis => braycurtis_distance_simd(point_slice, query_slice),
463 SimdDistanceMetric::Mahalanobis => {
464 return Err("Mahalanobis distance requires covariance matrix parameter".into());
465 }
466 SimdDistanceMetric::Correlation => correlation_distance_simd(point_slice, query_slice),
467 SimdDistanceMetric::Wasserstein => wasserstein_distance_simd(point_slice, query_slice),
468 };
469
470 results.push(distance);
471 }
472
473 Ok(results)
474}
475
476pub fn simd_distance_batch_query(
478 points: &ArrayView2<Float>,
479 query: &ArrayView1<Float>,
480 metric: SimdDistanceMetric,
481) -> Result<Vec<Float>, Box<dyn std::error::Error>> {
482 let query_slice = query.as_slice().unwrap();
483 let mut results = Vec::with_capacity(points.nrows());
484
485 for i in 0..points.nrows() {
486 let point = points.row(i);
487 let point_slice = point.as_slice().unwrap();
488
489 let distance = match metric {
490 SimdDistanceMetric::Euclidean => {
491 fallback_distance::euclidean_distance(point_slice, query_slice)
492 }
493 SimdDistanceMetric::EuclideanSquared => {
494 let euclidean = fallback_distance::euclidean_distance(point_slice, query_slice);
495 euclidean * euclidean
496 }
497 SimdDistanceMetric::Manhattan => {
498 fallback_distance::manhattan_distance(point_slice, query_slice)
499 }
500 SimdDistanceMetric::Chebyshev => {
501 fallback_distance::chebyshev_distance(point_slice, query_slice)
502 }
503 SimdDistanceMetric::Cosine => {
504 fallback_distance::cosine_distance(point_slice, query_slice)
505 }
506 SimdDistanceMetric::CosineSimilarity => {
507 1.0 - fallback_distance::cosine_distance(point_slice, query_slice)
508 }
509 SimdDistanceMetric::Minkowski(p) => {
510 fallback_distance::minkowski_distance(point_slice, query_slice, p)
511 }
512 SimdDistanceMetric::Jaccard => {
513 fallback_distance::jaccard_distance(point_slice, query_slice)
514 }
515 SimdDistanceMetric::Hamming => hamming_distance_simd(point_slice, query_slice),
516 SimdDistanceMetric::Canberra => canberra_distance_simd(point_slice, query_slice),
517 SimdDistanceMetric::Braycurtis => braycurtis_distance_simd(point_slice, query_slice),
518 SimdDistanceMetric::Mahalanobis => {
519 return Err("Mahalanobis distance requires covariance matrix parameter".into());
520 }
521 SimdDistanceMetric::Correlation => correlation_distance_simd(point_slice, query_slice),
522 SimdDistanceMetric::Wasserstein => wasserstein_distance_simd(point_slice, query_slice),
523 };
524
525 results.push(distance);
526 }
527
528 Ok(results)
529}
530
/// Rayon-backed variant of the query batch: computes the distance from `query`
/// to every row of `points`, one parallel task per row.
///
/// The query and each row are copied into owned buffers before the kernels run.
///
/// # Errors
///
/// Returns an error for [`SimdDistanceMetric::Mahalanobis`], which needs a
/// covariance matrix.
#[cfg(feature = "parallel")]
pub fn simd_distance_batch_parallel(
    points: &ArrayView2<Float>,
    query: &ArrayView1<Float>,
    metric: SimdDistanceMetric,
) -> Result<Vec<Float>, Box<dyn std::error::Error>> {
    use rayon::prelude::*;

    let query_buf: Vec<Float> = query.iter().copied().collect();
    let target: &[Float] = &query_buf;

    // The per-row error type stays a concrete boxed `io::Error` so it can cross
    // threads; `?` below widens it to the boxed trait object.
    let distance_for_row = |row_idx: usize| -> Result<Float, Box<std::io::Error>> {
        let row_buf: Vec<Float> = points.row(row_idx).iter().copied().collect();
        let row: &[Float] = &row_buf;

        let value = match metric {
            SimdDistanceMetric::Euclidean => fallback_distance::euclidean_distance(row, target),
            SimdDistanceMetric::EuclideanSquared => {
                let root = fallback_distance::euclidean_distance(row, target);
                root * root
            }
            SimdDistanceMetric::Manhattan => fallback_distance::manhattan_distance(row, target),
            SimdDistanceMetric::Chebyshev => fallback_distance::chebyshev_distance(row, target),
            SimdDistanceMetric::Cosine => fallback_distance::cosine_distance(row, target),
            SimdDistanceMetric::CosineSimilarity => {
                1.0 - fallback_distance::cosine_distance(row, target)
            }
            SimdDistanceMetric::Minkowski(p) => {
                fallback_distance::minkowski_distance(row, target, p)
            }
            SimdDistanceMetric::Jaccard => fallback_distance::jaccard_distance(row, target),
            SimdDistanceMetric::Hamming => hamming_distance_simd(row, target),
            SimdDistanceMetric::Canberra => canberra_distance_simd(row, target),
            SimdDistanceMetric::Braycurtis => braycurtis_distance_simd(row, target),
            SimdDistanceMetric::Mahalanobis => {
                return Err(Box::new(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Mahalanobis distance requires covariance matrix parameter",
                )));
            }
            SimdDistanceMetric::Correlation => correlation_distance_simd(row, target),
            SimdDistanceMetric::Wasserstein => wasserstein_distance_simd(row, target),
        };

        Ok(value)
    };

    let collected = (0..points.nrows())
        .into_par_iter()
        .map(distance_for_row)
        .collect::<Result<Vec<Float>, _>>()?;

    Ok(collected)
}
597
598pub fn simd_pairwise_distances(
600 points: &ArrayView2<Float>,
601 metric: SimdDistanceMetric,
602) -> Result<Vec<Vec<Float>>, Box<dyn std::error::Error>> {
603 let n_points = points.nrows();
604 let mut distances = vec![vec![0.0; n_points]; n_points];
605
606 for i in 0..n_points {
607 let point_i = points.row(i);
608 for j in (i + 1)..n_points {
609 let point_j = points.row(j);
610 let dist = simd_distance(&point_i, &point_j, metric)?;
611 distances[i][j] = dist;
612 distances[j][i] = dist; }
614 }
615
616 Ok(distances)
617}
618
619pub fn simd_k_nearest_neighbors(
621 points: &ArrayView2<Float>,
622 query: &ArrayView1<Float>,
623 k: usize,
624 metric: SimdDistanceMetric,
625) -> Result<Vec<(usize, Float)>, Box<dyn std::error::Error>> {
626 let distances = simd_distance_batch_query(points, query, metric)?;
627
628 let mut indexed_distances: Vec<(usize, Float)> = distances.into_iter().enumerate().collect();
629
630 indexed_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
632 indexed_distances.truncate(k);
633
634 Ok(indexed_distances)
635}
636
637pub fn simd_radius_neighbors(
639 points: &ArrayView2<Float>,
640 query: &ArrayView1<Float>,
641 radius: Float,
642 metric: SimdDistanceMetric,
643) -> Result<Vec<usize>, Box<dyn std::error::Error>> {
644 let distances = simd_distance_batch_query(points, query, metric)?;
645
646 let neighbors: Vec<usize> = distances
647 .into_iter()
648 .enumerate()
649 .filter_map(|(idx, dist)| if dist <= radius { Some(idx) } else { None })
650 .collect();
651
652 Ok(neighbors)
653}
654
655pub fn simd_distance_matrix(
657 points: &ArrayView2<Float>,
658 metric: SimdDistanceMetric,
659) -> Result<scirs2_core::ndarray::Array2<Float>, Box<dyn std::error::Error>> {
660 let n_points = points.nrows();
661 let mut matrix = scirs2_core::ndarray::Array2::zeros((n_points, n_points));
662
663 for i in 0..n_points {
664 let point_i = points.row(i);
665 for j in (i + 1)..n_points {
666 let point_j = points.row(j);
667 let dist = simd_distance(&point_i, &point_j, metric)?;
668 matrix[[i, j]] = dist;
669 matrix[[j, i]] = dist;
670 }
671 }
672
673 Ok(matrix)
674}
675
676pub fn benchmark_simd_vs_scalar(
678 points: &ArrayView2<Float>,
679 query: &ArrayView1<Float>,
680 metric: SimdDistanceMetric,
681) -> (f64, f64) {
682 use std::time::Instant;
683
684 let start = Instant::now();
686 let _simd_result = simd_distance_batch_query(points, query, metric).unwrap();
687 let simd_time = start.elapsed().as_secs_f64();
688
689 let start = Instant::now();
691 let _scalar_result = scalar_distance_batch(points, query, metric);
692 let scalar_time = start.elapsed().as_secs_f64();
693
694 (simd_time, scalar_time)
695}
696
697fn scalar_distance_batch(
699 points: &ArrayView2<Float>,
700 query: &ArrayView1<Float>,
701 metric: SimdDistanceMetric,
702) -> Vec<Float> {
703 let mut results = Vec::with_capacity(points.nrows());
704
705 for i in 0..points.nrows() {
706 let point = points.row(i);
707 let dist = match metric {
708 SimdDistanceMetric::Euclidean => {
709 let mut sum = 0.0;
710 for (&a, &b) in point.iter().zip(query.iter()) {
711 let diff = a - b;
712 sum += diff * diff;
713 }
714 sum.sqrt()
715 }
716 SimdDistanceMetric::Manhattan => {
717 let mut sum = 0.0;
718 for (&a, &b) in point.iter().zip(query.iter()) {
719 sum += (a - b).abs();
720 }
721 sum
722 }
723 SimdDistanceMetric::Chebyshev => {
724 let mut max_diff = 0.0;
725 for (&a, &b) in point.iter().zip(query.iter()) {
726 let diff = (a - b).abs();
727 if diff > max_diff {
728 max_diff = diff;
729 }
730 }
731 max_diff
732 }
733 SimdDistanceMetric::Cosine => {
734 let mut dot = 0.0;
735 let mut norm_a = 0.0;
736 let mut norm_b = 0.0;
737 for (&a, &b) in point.iter().zip(query.iter()) {
738 dot += a * b;
739 norm_a += a * a;
740 norm_b += b * b;
741 }
742 let norm_product = norm_a.sqrt() * norm_b.sqrt();
743 if norm_product == 0.0 {
744 0.0
745 } else {
746 1.0 - (dot / norm_product)
747 }
748 }
749 SimdDistanceMetric::CosineSimilarity => {
750 let mut dot = 0.0;
751 let mut norm_a = 0.0;
752 let mut norm_b = 0.0;
753 for (&a, &b) in point.iter().zip(query.iter()) {
754 dot += a * b;
755 norm_a += a * a;
756 norm_b += b * b;
757 }
758 let norm_product = norm_a.sqrt() * norm_b.sqrt();
759 if norm_product == 0.0 {
760 0.0
761 } else {
762 dot / norm_product
763 }
764 }
765 SimdDistanceMetric::Minkowski(p) => {
766 let mut sum = 0.0;
767 for (&a, &b) in point.iter().zip(query.iter()) {
768 sum += (a - b).abs().powf(p as Float);
769 }
770 sum.powf(1.0 / p as Float)
771 }
772 SimdDistanceMetric::Jaccard => {
773 let mut intersection = 0.0;
774 let mut union = 0.0;
775 for (&a, &b) in point.iter().zip(query.iter()) {
776 intersection += a.min(b);
777 union += a.max(b);
778 }
779 if union == 0.0 {
780 0.0
781 } else {
782 1.0 - (intersection / union)
783 }
784 }
785 SimdDistanceMetric::EuclideanSquared => {
786 let mut sum = 0.0;
787 for (&a, &b) in point.iter().zip(query.iter()) {
788 let diff = a - b;
789 sum += diff * diff;
790 }
791 sum
792 }
793 SimdDistanceMetric::Hamming => {
794 let mut count = 0.0;
795 for (&a, &b) in point.iter().zip(query.iter()) {
796 if (a - b).abs() > Float::EPSILON {
797 count += 1.0;
798 }
799 }
800 count
801 }
802 SimdDistanceMetric::Canberra => {
803 let mut sum = 0.0;
804 for (&a, &b) in point.iter().zip(query.iter()) {
805 let numerator = (a - b).abs();
806 let denominator = a.abs() + b.abs();
807 if denominator > 0.0 {
808 sum += numerator / denominator;
809 }
810 }
811 sum
812 }
813 SimdDistanceMetric::Braycurtis => {
814 let mut numerator = 0.0;
815 let mut denominator = 0.0;
816 for (&a, &b) in point.iter().zip(query.iter()) {
817 numerator += (a - b).abs();
818 denominator += a.abs() + b.abs();
819 }
820 if denominator == 0.0 {
821 0.0
822 } else {
823 numerator / denominator
824 }
825 }
826 SimdDistanceMetric::Mahalanobis => {
827 0.0
829 }
830 SimdDistanceMetric::Correlation => {
831 let n = point.len() as Float;
833 let sum_a: Float = point.iter().sum();
834 let sum_b: Float = query.iter().sum();
835 let mean_a = sum_a / n;
836 let mean_b = sum_b / n;
837
838 let mut numerator = 0.0;
839 let mut var_a = 0.0;
840 let mut var_b = 0.0;
841
842 for (&a, &b) in point.iter().zip(query.iter()) {
843 let diff_a = a - mean_a;
844 let diff_b = b - mean_b;
845 numerator += diff_a * diff_b;
846 var_a += diff_a * diff_a;
847 var_b += diff_b * diff_b;
848 }
849
850 let denominator = (var_a * var_b).sqrt();
851 if denominator == 0.0 {
852 0.0
853 } else {
854 1.0 - (numerator / denominator)
855 }
856 }
857 SimdDistanceMetric::Wasserstein => {
858 let mut sorted_a: Vec<Float> = point.iter().cloned().collect();
860 let mut sorted_b: Vec<Float> = query.iter().cloned().collect();
861 sorted_a.sort_by(|a, b| a.partial_cmp(b).unwrap());
862 sorted_b.sort_by(|a, b| a.partial_cmp(b).unwrap());
863
864 let mut sum = 0.0;
865 for (a, b) in sorted_a.iter().zip(sorted_b.iter()) {
866 sum += (a - b).abs();
867 }
868 sum / point.len() as Float
869 }
870 };
871 results.push(dist);
872 }
873
874 results
875}
876
877pub fn adaptive_distance_batch(
879 points: &ArrayView2<Float>,
880 query: &ArrayView1<Float>,
881 metric: SimdDistanceMetric,
882 simd_threshold: usize,
883) -> Result<Vec<Float>, Box<dyn std::error::Error>> {
884 if points.nrows() >= simd_threshold && query.len() >= 4 {
885 simd_distance_batch_query(points, query, metric)
886 } else {
887 Ok(scalar_distance_batch(points, query, metric))
888 }
889}
890
891pub fn custom_distance<F>(
893 point1: &ArrayView1<Float>,
894 point2: &ArrayView1<Float>,
895 distance_fn: F,
896) -> Float
897where
898 F: Fn(&ArrayView1<Float>, &ArrayView1<Float>) -> Float,
899{
900 distance_fn(point1, point2)
901}
902
903pub fn mahalanobis_distance(
905 point1: &ArrayView1<Float>,
906 point2: &ArrayView1<Float>,
907 cov_inv: &scirs2_core::ndarray::Array2<Float>,
908) -> Result<Float, Box<dyn std::error::Error>> {
909 if point1.len() != point2.len() {
910 return Err("Points must have the same dimensions".into());
911 }
912
913 if cov_inv.nrows() != point1.len() || cov_inv.ncols() != point1.len() {
914 return Err("Covariance matrix dimensions must match point dimensions".into());
915 }
916
917 let diff: scirs2_core::ndarray::Array1<Float> = point1.to_owned() - point2;
918 let temp = cov_inv.dot(&diff);
919 let distance_squared = diff.dot(&temp);
920
921 Ok(distance_squared.sqrt())
922}
923
924pub fn categorical_distance(
926 point1: &ArrayView1<Float>,
927 point2: &ArrayView1<Float>,
928 metric: CategoricalDistanceMetric,
929) -> Float {
930 match metric {
931 CategoricalDistanceMetric::Hamming => {
932 let mut count = 0.0;
933 for (&a, &b) in point1.iter().zip(point2.iter()) {
934 if (a - b).abs() > Float::EPSILON {
935 count += 1.0;
936 }
937 }
938 count / point1.len() as Float
939 }
940 CategoricalDistanceMetric::MatchingDissimilarity => {
941 let mut mismatches = 0.0;
942 for (&a, &b) in point1.iter().zip(point2.iter()) {
943 if (a - b).abs() > Float::EPSILON {
944 mismatches += 1.0;
945 }
946 }
947 mismatches / point1.len() as Float
948 }
949 }
950}
951
/// Metric selector for [`categorical_distance`].
///
/// The enum carries no data, so `Eq` and `Hash` are derived as well, which
/// lets it serve as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CategoricalDistanceMetric {
    /// Share of positions whose values differ.
    Hamming,
    /// Share of mismatching positions; evaluated identically to `Hamming` by
    /// [`categorical_distance`].
    MatchingDissimilarity,
}
960
961pub fn weighted_distance(
963 point1: &ArrayView1<Float>,
964 point2: &ArrayView1<Float>,
965 weights: &ArrayView1<Float>,
966 metric: SimdDistanceMetric,
967) -> Result<Float, Box<dyn std::error::Error>> {
968 if point1.len() != point2.len() || point1.len() != weights.len() {
969 return Err("All arrays must have the same length".into());
970 }
971
972 match metric {
973 SimdDistanceMetric::Euclidean => {
974 let mut sum = 0.0;
975 for ((&a, &b), &w) in point1.iter().zip(point2.iter()).zip(weights.iter()) {
976 let diff = a - b;
977 sum += w * diff * diff;
978 }
979 Ok(sum.sqrt())
980 }
981 SimdDistanceMetric::Manhattan => {
982 let mut sum = 0.0;
983 for ((&a, &b), &w) in point1.iter().zip(point2.iter()).zip(weights.iter()) {
984 sum += w * (a - b).abs();
985 }
986 Ok(sum)
987 }
988 _ => {
989 let weighted_p1: scirs2_core::ndarray::Array1<Float> = point1
991 .iter()
992 .zip(weights.iter())
993 .map(|(&p, &w)| p * w.sqrt())
994 .collect();
995 let weighted_p2: scirs2_core::ndarray::Array1<Float> = point2
996 .iter()
997 .zip(weights.iter())
998 .map(|(&p, &w)| p * w.sqrt())
999 .collect();
1000 simd_distance(&weighted_p1.view(), &weighted_p2.view(), metric)
1001 }
1002 }
1003}
1004
1005fn hamming_distance_simd(a: &[Float], b: &[Float]) -> Float {
1007 let mut count = 0.0;
1008 for (&x, &y) in a.iter().zip(b.iter()) {
1009 if (x - y).abs() > Float::EPSILON {
1010 count += 1.0;
1011 }
1012 }
1013 count
1014}
1015
1016fn canberra_distance_simd(a: &[Float], b: &[Float]) -> Float {
1017 let mut sum = 0.0;
1018 for (&x, &y) in a.iter().zip(b.iter()) {
1019 let numerator = (x - y).abs();
1020 let denominator = x.abs() + y.abs();
1021 if denominator > 0.0 {
1022 sum += numerator / denominator;
1023 }
1024 }
1025 sum
1026}
1027
1028fn braycurtis_distance_simd(a: &[Float], b: &[Float]) -> Float {
1029 let mut numerator = 0.0;
1030 let mut denominator = 0.0;
1031 for (&x, &y) in a.iter().zip(b.iter()) {
1032 numerator += (x - y).abs();
1033 denominator += x.abs() + y.abs();
1034 }
1035 if denominator == 0.0 {
1036 0.0
1037 } else {
1038 numerator / denominator
1039 }
1040}
1041
1042fn correlation_distance_simd(a: &[Float], b: &[Float]) -> Float {
1043 let n = a.len() as Float;
1044 let sum_a: Float = a.iter().sum();
1045 let sum_b: Float = b.iter().sum();
1046 let mean_a = sum_a / n;
1047 let mean_b = sum_b / n;
1048
1049 let mut numerator = 0.0;
1050 let mut var_a = 0.0;
1051 let mut var_b = 0.0;
1052
1053 for (&x, &y) in a.iter().zip(b.iter()) {
1054 let diff_a = x - mean_a;
1055 let diff_b = y - mean_b;
1056 numerator += diff_a * diff_b;
1057 var_a += diff_a * diff_a;
1058 var_b += diff_b * diff_b;
1059 }
1060
1061 let denominator = (var_a * var_b).sqrt();
1062 if denominator == 0.0 {
1063 0.0
1064 } else {
1065 1.0 - (numerator / denominator)
1066 }
1067}
1068
1069fn wasserstein_distance_simd(a: &[Float], b: &[Float]) -> Float {
1070 let mut sorted_a = a.to_vec();
1071 let mut sorted_b = b.to_vec();
1072 sorted_a.sort_by(|x, y| x.partial_cmp(y).unwrap());
1073 sorted_b.sort_by(|x, y| x.partial_cmp(y).unwrap());
1074
1075 let mut sum = 0.0;
1076 for (x, y) in sorted_a.iter().zip(sorted_b.iter()) {
1077 sum += (x - y).abs();
1078 }
1079 sum / a.len() as Float
1080}
1081
/// Unit tests for the distance helpers. Expected values are written in terms
/// of the crate's `Float` alias rather than a hard-coded `f64`.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{array, Array1, Array2};

    #[test]
    fn test_simd_euclidean_distance() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let query = array![0.0, 0.0];

        let distances =
            simd_distance_batch_query(&data.view(), &query.view(), SimdDistanceMetric::Euclidean)
                .unwrap();

        assert_eq!(distances.len(), 3);
        assert_abs_diff_eq!(distances[0], (5.0 as Float).sqrt(), epsilon = 1e-6);
        assert_abs_diff_eq!(distances[1], 5.0, epsilon = 1e-6);
        assert_abs_diff_eq!(distances[2], (61.0 as Float).sqrt(), epsilon = 1e-6);
    }

    #[test]
    fn test_simd_manhattan_distance() {
        let data = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let query = array![0.0, 0.0, 0.0];

        let distances =
            simd_distance_batch_query(&data.view(), &query.view(), SimdDistanceMetric::Manhattan)
                .unwrap();

        assert_eq!(distances.len(), 2);
        // Row 0: 1 + 2 + 3.
        assert_abs_diff_eq!(distances[0], 6.0, epsilon = 1e-6);
        // Row 1: 4 + 5 + 6.
        assert_abs_diff_eq!(distances[1], 15.0, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_vs_scalar_consistency() {
        let data = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let query = array![0.0, 0.0, 0.0];

        let simd_distances =
            simd_distance_batch_query(&data.view(), &query.view(), SimdDistanceMetric::Euclidean)
                .unwrap();
        let scalar_distances =
            scalar_distance_batch(&data.view(), &query.view(), SimdDistanceMetric::Euclidean);

        assert_eq!(simd_distances.len(), scalar_distances.len());
        for (simd, scalar) in simd_distances.iter().zip(scalar_distances.iter()) {
            assert_abs_diff_eq!(simd, scalar, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_simd_k_nearest_neighbors() {
        let data = Array2::from_shape_vec(
            (5, 2),
            vec![
                1.0, 1.0, // row 0
                2.0, 2.0, // row 1
                0.0, 0.0, // row 2
                3.0, 3.0, // row 3
                0.5, 0.5, // row 4
            ],
        )
        .unwrap();
        let query = array![0.0, 0.0];

        let neighbors = simd_k_nearest_neighbors(
            &data.view(),
            &query.view(),
            3,
            SimdDistanceMetric::Euclidean,
        )
        .unwrap();

        assert_eq!(neighbors.len(), 3);
        assert_eq!(neighbors[0].0, 2);
        assert_eq!(neighbors[1].0, 4);
        assert_eq!(neighbors[2].0, 0);
    }

    #[test]
    fn test_simd_radius_neighbors() {
        let data = Array2::from_shape_vec(
            (4, 2),
            vec![
                1.0, 0.0, // row 0
                0.0, 1.0, // row 1
                2.0, 0.0, // row 2
                0.0, 2.0, // row 3
            ],
        )
        .unwrap();
        let query = array![0.0, 0.0];

        let neighbors = simd_radius_neighbors(
            &data.view(),
            &query.view(),
            1.5,
            SimdDistanceMetric::Euclidean,
        )
        .unwrap();

        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.contains(&0));
        assert!(neighbors.contains(&1));
    }

    #[test]
    fn test_simd_distance_matrix() {
        let data = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0]).unwrap();

        let matrix = simd_distance_matrix(&data.view(), SimdDistanceMetric::Euclidean).unwrap();

        assert_eq!(matrix.shape(), &[3, 3]);

        // The diagonal is never written and stays at zero.
        assert_abs_diff_eq!(matrix[[0, 0]], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(matrix[[1, 1]], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(matrix[[2, 2]], 0.0, epsilon = 1e-6);

        assert_abs_diff_eq!(matrix[[0, 1]], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(matrix[[1, 0]], 1.0, epsilon = 1e-6);

        assert_abs_diff_eq!(matrix[[0, 2]], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(matrix[[2, 0]], 1.0, epsilon = 1e-6);

        assert_abs_diff_eq!(matrix[[1, 2]], (2.0 as Float).sqrt(), epsilon = 1e-6);
        assert_abs_diff_eq!(matrix[[2, 1]], (2.0 as Float).sqrt(), epsilon = 1e-6);
    }

    #[test]
    fn test_adaptive_distance_batch() {
        let small_data = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let large_data =
            Array2::from_shape_vec((10, 4), (0..40).map(|x| x as Float).collect()).unwrap();
        let query = array![0.0, 0.0, 0.0, 0.0];

        // Two rows is below the threshold of five.
        let small_result = adaptive_distance_batch(
            &small_data.view(),
            &query.view().slice(scirs2_core::ndarray::s![..2]),
            SimdDistanceMetric::Euclidean,
            5,
        )
        .unwrap();
        assert_eq!(small_result.len(), 2);

        // Ten rows and four features satisfy both size conditions.
        let large_result = adaptive_distance_batch(
            &large_data.view(),
            &query.view(),
            SimdDistanceMetric::Euclidean,
            5,
        )
        .unwrap();
        assert_eq!(large_result.len(), 10);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_parallel_simd_distance_batch() {
        let data = Array2::from_shape_vec((6, 3), (0..18).map(|x| x as Float).collect()).unwrap();
        let query = array![0.0, 0.0, 0.0];

        let parallel_result = simd_distance_batch_parallel(
            &data.view(),
            &query.view(),
            SimdDistanceMetric::Euclidean,
        )
        .unwrap();
        let sequential_result =
            simd_distance_batch_query(&data.view(), &query.view(), SimdDistanceMetric::Euclidean)
                .unwrap();

        assert_eq!(parallel_result.len(), sequential_result.len());
        for (par, seq) in parallel_result.iter().zip(sequential_result.iter()) {
            assert_abs_diff_eq!(par, seq, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_optimized_distance_computer_performance() {
        let computer = OptimizedDistanceComputer::new();
        let n_points = 100;
        let n_features = 10;

        let points1 = Array2::<Float>::ones((n_points, n_features));
        let points2 = Array2::<Float>::zeros((n_points, n_features));

        let distances = computer.pairwise_distances(
            &points1.view(),
            &points2.view(),
            DistanceMetric::Euclidean,
        );

        assert_eq!(distances.dim(), (n_points, n_points));

        // Every pair is an all-ones row against an all-zeros row.
        let expected_distance = (n_features as Float).sqrt();
        for &dist in distances.iter() {
            assert!((dist - expected_distance).abs() < 1e-6);
        }
    }

    #[test]
    fn test_simd_detection() {
        let computer = OptimizedDistanceComputer::new();

        let point1 = Array1::from(vec![1.0, 2.0, 3.0, 4.0]);
        let point2 = Array1::from(vec![2.0, 3.0, 4.0, 5.0]);

        let distance =
            computer.compute_distance(&point1.view(), &point2.view(), DistanceMetric::Euclidean);

        // Four unit differences: sqrt(4) = 2.
        assert!((distance - 2.0).abs() < 1e-6);
    }
}