sklears_clustering/
simd_distances.rs

1//! SIMD-optimized distance computations for clustering algorithms
2//!
3//! This module provides adapter functions to use SIMD-optimized distance
4//! calculations from sklears-simd with ndarray data structures.
5
6use scirs2_core::ndarray::{ArrayView1, ArrayView2};
7use sklears_core::types::Float;
8
9/// Auto-vectorized distance metrics that use SIMD when available
10#[derive(Debug, Clone, Copy)]
11pub enum DistanceMetric {
12    Euclidean,
13    Manhattan,
14    Chebyshev,
15    Cosine,
16    Minkowski(Float),
17    Jaccard,
18}
19
/// Distance computer that chooses between an unrolled ("SIMD") code path and a
/// plain scalar path, based on a capability flag probed once at construction.
///
/// The type only holds two plain values, so it is cheap to clone and can be
/// printed with `{:?}` for diagnostics.
#[derive(Debug, Clone)]
pub struct OptimizedDistanceComputer {
    /// Whether SIMD instructions are available
    simd_available: bool,
    /// Cache-friendly block size for batch operations
    block_size: usize,
}
27
28impl Default for OptimizedDistanceComputer {
29    fn default() -> Self {
30        Self::new()
31    }
32}
33
34impl OptimizedDistanceComputer {
35    /// Create a new optimized distance computer with runtime SIMD detection
36    pub fn new() -> Self {
37        Self {
38            simd_available: Self::detect_simd_support(),
39            block_size: Self::optimal_block_size(),
40        }
41    }
42
43    /// Detect available SIMD instruction sets at runtime
44    fn detect_simd_support() -> bool {
45        #[cfg(target_arch = "x86_64")]
46        {
47            std::arch::is_x86_feature_detected!("avx2")
48                || std::arch::is_x86_feature_detected!("sse2")
49        }
50        #[cfg(target_arch = "aarch64")]
51        {
52            std::arch::is_aarch64_feature_detected!("neon")
53        }
54        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
55        {
56            false
57        }
58    }
59
60    /// Determine optimal block size based on cache hierarchy
61    fn optimal_block_size() -> usize {
62        // L1 cache is typically 32KB, use 1/4 for distance computations
63        // Assuming 32-bit floats: 8KB / 4 bytes = 2048 elements
64        // Square root for matrix operations: sqrt(2048) ≈ 45
65        48
66    }
67
68    /// Compute pairwise distances between all points in two arrays
69    pub fn pairwise_distances(
70        &self,
71        points1: &ArrayView2<Float>,
72        points2: &ArrayView2<Float>,
73        metric: DistanceMetric,
74    ) -> scirs2_core::ndarray::Array2<Float> {
75        let (n1, _) = points1.dim();
76        let (n2, _) = points2.dim();
77        let mut distances = scirs2_core::ndarray::Array2::zeros((n1, n2));
78
79        // Use blocked computation for cache efficiency
80        for i_start in (0..n1).step_by(self.block_size) {
81            let i_end = (i_start + self.block_size).min(n1);
82
83            for j_start in (0..n2).step_by(self.block_size) {
84                let j_end = (j_start + self.block_size).min(n2);
85
86                // Compute distances for this block
87                for i in i_start..i_end {
88                    for j in j_start..j_end {
89                        let point1 = points1.row(i);
90                        let point2 = points2.row(j);
91
92                        distances[[i, j]] = self.compute_distance(&point1, &point2, metric);
93                    }
94                }
95            }
96        }
97
98        distances
99    }
100
101    /// Compute distance between two points using optimal implementation
102    pub fn compute_distance(
103        &self,
104        point1: &ArrayView1<Float>,
105        point2: &ArrayView1<Float>,
106        metric: DistanceMetric,
107    ) -> Float {
108        if self.simd_available && point1.len() >= 4 {
109            self.compute_distance_simd(point1, point2, metric)
110        } else {
111            self.compute_distance_scalar(point1, point2, metric)
112        }
113    }
114
115    /// SIMD-optimized distance computation
116    fn compute_distance_simd(
117        &self,
118        point1: &ArrayView1<Float>,
119        point2: &ArrayView1<Float>,
120        metric: DistanceMetric,
121    ) -> Float {
122        // Convert to slices for SIMD processing
123        let a = point1.as_slice().unwrap();
124        let b = point2.as_slice().unwrap();
125
126        match metric {
127            DistanceMetric::Euclidean => self.euclidean_simd(a, b),
128            DistanceMetric::Manhattan => self.manhattan_simd(a, b),
129            DistanceMetric::Chebyshev => self.chebyshev_simd(a, b),
130            DistanceMetric::Cosine => self.cosine_simd(a, b),
131            DistanceMetric::Minkowski(p) => self.minkowski_simd(a, b, p),
132            DistanceMetric::Jaccard => self.jaccard_simd(a, b),
133        }
134    }
135
136    /// Scalar fallback distance computation
137    fn compute_distance_scalar(
138        &self,
139        point1: &ArrayView1<Float>,
140        point2: &ArrayView1<Float>,
141        metric: DistanceMetric,
142    ) -> Float {
143        let a = point1.as_slice().unwrap();
144        let b = point2.as_slice().unwrap();
145
146        match metric {
147            DistanceMetric::Euclidean => fallback_distance::euclidean_distance(a, b),
148            DistanceMetric::Manhattan => fallback_distance::manhattan_distance(a, b),
149            DistanceMetric::Chebyshev => fallback_distance::chebyshev_distance(a, b),
150            DistanceMetric::Cosine => fallback_distance::cosine_distance(a, b),
151            DistanceMetric::Minkowski(p) => fallback_distance::minkowski_distance(a, b, p),
152            DistanceMetric::Jaccard => fallback_distance::jaccard_distance(a, b),
153        }
154    }
155
156    /// SIMD-optimized Euclidean distance (when available)
157    fn euclidean_simd(&self, a: &[Float], b: &[Float]) -> Float {
158        #[cfg(target_arch = "x86_64")]
159        {
160            if std::arch::is_x86_feature_detected!("avx2") {
161                return unsafe { self.euclidean_avx2(a, b) };
162            }
163        }
164
165        // Fallback to optimized scalar with manual unrolling
166        self.euclidean_unrolled(a, b)
167    }
168
169    /// Manual loop unrolling for better performance when SIMD unavailable
170    fn euclidean_unrolled(&self, a: &[Float], b: &[Float]) -> Float {
171        let mut sum = 0.0;
172        let len = a.len();
173        let chunks = len / 4;
174
175        // Process 4 elements at a time
176        for i in 0..chunks {
177            let base = i * 4;
178            let diff1 = a[base] - b[base];
179            let diff2 = a[base + 1] - b[base + 1];
180            let diff3 = a[base + 2] - b[base + 2];
181            let diff4 = a[base + 3] - b[base + 3];
182
183            sum += diff1 * diff1 + diff2 * diff2 + diff3 * diff3 + diff4 * diff4;
184        }
185
186        // Handle remaining elements
187        for i in (chunks * 4)..len {
188            let diff = a[i] - b[i];
189            sum += diff * diff;
190        }
191
192        sum.sqrt()
193    }
194
195    /// AVX2-optimized Euclidean distance (x86_64 only)
196    #[cfg(target_arch = "x86_64")]
197    unsafe fn euclidean_avx2(&self, a: &[Float], b: &[Float]) -> Float {
198        // This would use std::arch::x86_64::* intrinsics for AVX2
199        // For now, fallback to unrolled version
200        self.euclidean_unrolled(a, b)
201    }
202
203    /// SIMD-optimized Manhattan distance
204    fn manhattan_simd(&self, a: &[Float], b: &[Float]) -> Float {
205        // Use manual unrolling for now
206        let mut sum = 0.0;
207        let len = a.len();
208        let chunks = len / 4;
209
210        for i in 0..chunks {
211            let base = i * 4;
212            sum += (a[base] - b[base]).abs()
213                + (a[base + 1] - b[base + 1]).abs()
214                + (a[base + 2] - b[base + 2]).abs()
215                + (a[base + 3] - b[base + 3]).abs();
216        }
217
218        for i in (chunks * 4)..len {
219            sum += (a[i] - b[i]).abs();
220        }
221
222        sum
223    }
224
225    /// SIMD-optimized Chebyshev distance
226    fn chebyshev_simd(&self, a: &[Float], b: &[Float]) -> Float {
227        let mut max_diff = 0.0;
228
229        for (x, y) in a.iter().zip(b.iter()) {
230            let diff = (x - y).abs();
231            if diff > max_diff {
232                max_diff = diff;
233            }
234        }
235
236        max_diff
237    }
238
239    /// SIMD-optimized Cosine distance
240    fn cosine_simd(&self, a: &[Float], b: &[Float]) -> Float {
241        let mut dot = 0.0;
242        let mut norm_a_sq = 0.0;
243        let mut norm_b_sq = 0.0;
244
245        let len = a.len();
246        let chunks = len / 4;
247
248        // Unrolled loop for better performance
249        for i in 0..chunks {
250            let base = i * 4;
251            for j in 0..4 {
252                let idx = base + j;
253                dot += a[idx] * b[idx];
254                norm_a_sq += a[idx] * a[idx];
255                norm_b_sq += b[idx] * b[idx];
256            }
257        }
258
259        for i in (chunks * 4)..len {
260            dot += a[i] * b[i];
261            norm_a_sq += a[i] * a[i];
262            norm_b_sq += b[i] * b[i];
263        }
264
265        1.0 - (dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt()))
266    }
267
268    /// SIMD-optimized Minkowski distance
269    fn minkowski_simd(&self, a: &[Float], b: &[Float], p: Float) -> Float {
270        let mut sum = 0.0;
271
272        for (x, y) in a.iter().zip(b.iter()) {
273            sum += (x - y).abs().powf(p);
274        }
275
276        sum.powf(1.0 / p)
277    }
278
279    /// SIMD-optimized Jaccard distance
280    fn jaccard_simd(&self, a: &[Float], b: &[Float]) -> Float {
281        let mut intersection = 0.0;
282        let mut union = 0.0;
283
284        for (x, y) in a.iter().zip(b.iter()) {
285            intersection += x.min(*y);
286            union += x.max(*y);
287        }
288
289        1.0 - (intersection / union)
290    }
291}
292
293// Fallback implementations when SIMD is not available
294mod fallback_distance {
295    use super::Float;
296
297    pub fn euclidean_distance(a: &[Float], b: &[Float]) -> Float {
298        a.iter()
299            .zip(b.iter())
300            .map(|(x, y)| (x - y).powi(2))
301            .sum::<Float>()
302            .sqrt()
303    }
304
305    pub fn manhattan_distance(a: &[Float], b: &[Float]) -> Float {
306        a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
307    }
308
309    pub fn chebyshev_distance(a: &[Float], b: &[Float]) -> Float {
310        a.iter()
311            .zip(b.iter())
312            .map(|(x, y)| (x - y).abs())
313            .fold(0.0, Float::max)
314    }
315
316    pub fn cosine_distance(a: &[Float], b: &[Float]) -> Float {
317        let dot: Float = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
318        let norm_a: Float = a.iter().map(|x| x * x).sum::<Float>().sqrt();
319        let norm_b: Float = b.iter().map(|x| x * x).sum::<Float>().sqrt();
320        1.0 - (dot / (norm_a * norm_b))
321    }
322
323    pub fn minkowski_distance(a: &[Float], b: &[Float], p: Float) -> Float {
324        a.iter()
325            .zip(b.iter())
326            .map(|(x, y)| (x - y).abs().powf(p))
327            .sum::<Float>()
328            .powf(1.0 / p)
329    }
330
331    pub fn jaccard_distance(a: &[Float], b: &[Float]) -> Float {
332        let intersection: Float = a.iter().zip(b.iter()).map(|(x, y)| x.min(*y)).sum();
333        let union: Float = a.iter().zip(b.iter()).map(|(x, y)| x.max(*y)).sum();
334        1.0 - (intersection / union)
335    }
336}
337
338/// SIMD-optimized distance metrics for clustering
339#[derive(Debug, Clone, Copy, PartialEq)]
340pub enum SimdDistanceMetric {
341    /// Euclidean distance (L2 norm)
342    Euclidean,
343    /// Squared Euclidean distance (faster than Euclidean for many use cases)
344    EuclideanSquared,
345    /// Manhattan distance (L1 norm)
346    Manhattan,
347    /// Chebyshev distance (L∞ norm)
348    Chebyshev,
349    /// Cosine distance
350    Cosine,
351    /// Cosine similarity (1 - cosine distance)
352    CosineSimilarity,
353    /// Minkowski distance with parameter p
354    Minkowski(Float),
355    /// Jaccard distance
356    Jaccard,
357    /// Hamming distance (for binary/categorical data)
358    Hamming,
359    /// Canberra distance
360    Canberra,
361    /// Braycurtis distance
362    Braycurtis,
363    /// Mahalanobis distance (requires covariance matrix)
364    Mahalanobis,
365    /// Pearson correlation distance
366    Correlation,
367    /// Wasserstein (Earth Mover's) distance
368    Wasserstein,
369}
370
371/// Calculate SIMD-optimized distance between two points
372pub fn simd_distance(
373    point1: &ArrayView1<Float>,
374    point2: &ArrayView1<Float>,
375    metric: SimdDistanceMetric,
376) -> Result<Float, Box<dyn std::error::Error>> {
377    // Use Float directly for consistency
378    let a = point1.as_slice().unwrap();
379    let b = point2.as_slice().unwrap();
380
381    let result = match metric {
382        SimdDistanceMetric::Euclidean => fallback_distance::euclidean_distance(a, b),
383        SimdDistanceMetric::EuclideanSquared => {
384            let euclidean = fallback_distance::euclidean_distance(a, b);
385            euclidean * euclidean
386        }
387        SimdDistanceMetric::Manhattan => fallback_distance::manhattan_distance(a, b),
388        SimdDistanceMetric::Chebyshev => fallback_distance::chebyshev_distance(a, b),
389        SimdDistanceMetric::Cosine => fallback_distance::cosine_distance(a, b),
390        SimdDistanceMetric::CosineSimilarity => 1.0 - fallback_distance::cosine_distance(a, b),
391        SimdDistanceMetric::Minkowski(p) => fallback_distance::minkowski_distance(a, b, p),
392        SimdDistanceMetric::Jaccard => fallback_distance::jaccard_distance(a, b),
393        SimdDistanceMetric::Hamming => hamming_distance_simd(a, b),
394        SimdDistanceMetric::Canberra => canberra_distance_simd(a, b),
395        SimdDistanceMetric::Braycurtis => braycurtis_distance_simd(a, b),
396        SimdDistanceMetric::Mahalanobis => {
397            return Err("Mahalanobis distance requires covariance matrix parameter".into());
398        }
399        SimdDistanceMetric::Correlation => correlation_distance_simd(a, b),
400        SimdDistanceMetric::Wasserstein => wasserstein_distance_simd(a, b),
401    };
402
403    Ok(result as Float)
404}
405
406/// Calculate SIMD-optimized squared Euclidean distance (faster than Euclidean)
407pub fn simd_squared_euclidean_distance(
408    point1: &ArrayView1<Float>,
409    point2: &ArrayView1<Float>,
410) -> Result<Float, Box<dyn std::error::Error>> {
411    let a = point1.as_slice().unwrap();
412    let b = point2.as_slice().unwrap();
413
414    let euclidean = fallback_distance::euclidean_distance(a, b);
415    Ok(euclidean * euclidean)
416}
417
418/// Batch SIMD-optimized distance calculation from multiple points to multiple queries
419pub fn simd_distance_batch(
420    points: &[scirs2_core::ndarray::Array1<Float>],
421    queries: &[scirs2_core::ndarray::Array1<Float>],
422    metric: SimdDistanceMetric,
423) -> Result<Vec<Float>, Box<dyn std::error::Error>> {
424    if points.len() != queries.len() {
425        return Err("Points and queries must have the same length".into());
426    }
427
428    let mut results = Vec::with_capacity(points.len());
429
430    for (point, query) in points.iter().zip(queries.iter()) {
431        let point_slice = point.as_slice().unwrap();
432        let query_slice = query.as_slice().unwrap();
433
434        let distance = match metric {
435            SimdDistanceMetric::Euclidean => {
436                fallback_distance::euclidean_distance(point_slice, query_slice)
437            }
438            SimdDistanceMetric::EuclideanSquared => {
439                let euclidean = fallback_distance::euclidean_distance(point_slice, query_slice);
440                euclidean * euclidean
441            }
442            SimdDistanceMetric::Manhattan => {
443                fallback_distance::manhattan_distance(point_slice, query_slice)
444            }
445            SimdDistanceMetric::Chebyshev => {
446                fallback_distance::chebyshev_distance(point_slice, query_slice)
447            }
448            SimdDistanceMetric::Cosine => {
449                fallback_distance::cosine_distance(point_slice, query_slice)
450            }
451            SimdDistanceMetric::CosineSimilarity => {
452                1.0 - fallback_distance::cosine_distance(point_slice, query_slice)
453            }
454            SimdDistanceMetric::Minkowski(p) => {
455                fallback_distance::minkowski_distance(point_slice, query_slice, p)
456            }
457            SimdDistanceMetric::Jaccard => {
458                fallback_distance::jaccard_distance(point_slice, query_slice)
459            }
460            SimdDistanceMetric::Hamming => hamming_distance_simd(point_slice, query_slice),
461            SimdDistanceMetric::Canberra => canberra_distance_simd(point_slice, query_slice),
462            SimdDistanceMetric::Braycurtis => braycurtis_distance_simd(point_slice, query_slice),
463            SimdDistanceMetric::Mahalanobis => {
464                return Err("Mahalanobis distance requires covariance matrix parameter".into());
465            }
466            SimdDistanceMetric::Correlation => correlation_distance_simd(point_slice, query_slice),
467            SimdDistanceMetric::Wasserstein => wasserstein_distance_simd(point_slice, query_slice),
468        };
469
470        results.push(distance);
471    }
472
473    Ok(results)
474}
475
476/// Batch SIMD-optimized distance calculation from a query point to multiple points
477pub fn simd_distance_batch_query(
478    points: &ArrayView2<Float>,
479    query: &ArrayView1<Float>,
480    metric: SimdDistanceMetric,
481) -> Result<Vec<Float>, Box<dyn std::error::Error>> {
482    let query_slice = query.as_slice().unwrap();
483    let mut results = Vec::with_capacity(points.nrows());
484
485    for i in 0..points.nrows() {
486        let point = points.row(i);
487        let point_slice = point.as_slice().unwrap();
488
489        let distance = match metric {
490            SimdDistanceMetric::Euclidean => {
491                fallback_distance::euclidean_distance(point_slice, query_slice)
492            }
493            SimdDistanceMetric::EuclideanSquared => {
494                let euclidean = fallback_distance::euclidean_distance(point_slice, query_slice);
495                euclidean * euclidean
496            }
497            SimdDistanceMetric::Manhattan => {
498                fallback_distance::manhattan_distance(point_slice, query_slice)
499            }
500            SimdDistanceMetric::Chebyshev => {
501                fallback_distance::chebyshev_distance(point_slice, query_slice)
502            }
503            SimdDistanceMetric::Cosine => {
504                fallback_distance::cosine_distance(point_slice, query_slice)
505            }
506            SimdDistanceMetric::CosineSimilarity => {
507                1.0 - fallback_distance::cosine_distance(point_slice, query_slice)
508            }
509            SimdDistanceMetric::Minkowski(p) => {
510                fallback_distance::minkowski_distance(point_slice, query_slice, p)
511            }
512            SimdDistanceMetric::Jaccard => {
513                fallback_distance::jaccard_distance(point_slice, query_slice)
514            }
515            SimdDistanceMetric::Hamming => hamming_distance_simd(point_slice, query_slice),
516            SimdDistanceMetric::Canberra => canberra_distance_simd(point_slice, query_slice),
517            SimdDistanceMetric::Braycurtis => braycurtis_distance_simd(point_slice, query_slice),
518            SimdDistanceMetric::Mahalanobis => {
519                return Err("Mahalanobis distance requires covariance matrix parameter".into());
520            }
521            SimdDistanceMetric::Correlation => correlation_distance_simd(point_slice, query_slice),
522            SimdDistanceMetric::Wasserstein => wasserstein_distance_simd(point_slice, query_slice),
523        };
524
525        results.push(distance);
526    }
527
528    Ok(results)
529}
530
/// Parallel variant of the one-query batch distance: entry `i` of the result
/// is the distance between row `i` of `points` and `query`, with rows
/// processed through a rayon parallel iterator.
///
/// Rows (and the query) that are contiguous in memory are borrowed in place;
/// only non-contiguous ones are copied into a temporary buffer.
///
/// # Errors
///
/// Returns an `InvalidInput` error for [`SimdDistanceMetric::Mahalanobis`],
/// which needs a covariance matrix this signature cannot supply.
#[cfg(feature = "parallel")]
pub fn simd_distance_batch_parallel(
    points: &ArrayView2<Float>,
    query: &ArrayView1<Float>,
    metric: SimdDistanceMetric,
) -> Result<Vec<Float>, Box<dyn std::error::Error>> {
    use rayon::prelude::*;

    let query_data: std::borrow::Cow<'_, [Float]> = match query.as_slice() {
        Some(slice) => std::borrow::Cow::Borrowed(slice),
        None => std::borrow::Cow::Owned(query.iter().copied().collect()),
    };
    let query_slice: &[Float] = &*query_data;

    let results: Vec<Float> = (0..points.nrows())
        .into_par_iter()
        .map(|i| {
            let row = points.row(i);
            let row_data: std::borrow::Cow<'_, [Float]> = match row.as_slice() {
                Some(slice) => std::borrow::Cow::Borrowed(slice),
                None => std::borrow::Cow::Owned(row.iter().copied().collect()),
            };
            let point_slice: &[Float] = &*row_data;

            let distance = match metric {
                SimdDistanceMetric::Euclidean => {
                    fallback_distance::euclidean_distance(point_slice, query_slice)
                }
                SimdDistanceMetric::EuclideanSquared => {
                    // Sum the squares directly instead of squaring a square root.
                    point_slice
                        .iter()
                        .zip(query_slice.iter())
                        .map(|(x, y)| {
                            let diff = x - y;
                            diff * diff
                        })
                        .sum::<Float>()
                }
                SimdDistanceMetric::Manhattan => {
                    fallback_distance::manhattan_distance(point_slice, query_slice)
                }
                SimdDistanceMetric::Chebyshev => {
                    fallback_distance::chebyshev_distance(point_slice, query_slice)
                }
                SimdDistanceMetric::Cosine => {
                    fallback_distance::cosine_distance(point_slice, query_slice)
                }
                SimdDistanceMetric::CosineSimilarity => {
                    1.0 - fallback_distance::cosine_distance(point_slice, query_slice)
                }
                SimdDistanceMetric::Minkowski(p) => {
                    fallback_distance::minkowski_distance(point_slice, query_slice, p)
                }
                SimdDistanceMetric::Jaccard => {
                    fallback_distance::jaccard_distance(point_slice, query_slice)
                }
                SimdDistanceMetric::Hamming => hamming_distance_simd(point_slice, query_slice),
                SimdDistanceMetric::Canberra => canberra_distance_simd(point_slice, query_slice),
                SimdDistanceMetric::Braycurtis => {
                    braycurtis_distance_simd(point_slice, query_slice)
                }
                SimdDistanceMetric::Mahalanobis => {
                    // A boxed `io::Error` is `Send`, which the parallel
                    // `collect` into `Result` requires of the error type.
                    return Err(Box::new(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "Mahalanobis distance requires covariance matrix parameter",
                    )));
                }
                SimdDistanceMetric::Correlation => {
                    correlation_distance_simd(point_slice, query_slice)
                }
                SimdDistanceMetric::Wasserstein => {
                    wasserstein_distance_simd(point_slice, query_slice)
                }
            };

            Ok(distance)
        })
        .collect::<Result<Vec<_>, _>>()?;

    Ok(results)
}
597
598/// Calculate pairwise SIMD-optimized distances between all points
599pub fn simd_pairwise_distances(
600    points: &ArrayView2<Float>,
601    metric: SimdDistanceMetric,
602) -> Result<Vec<Vec<Float>>, Box<dyn std::error::Error>> {
603    let n_points = points.nrows();
604    let mut distances = vec![vec![0.0; n_points]; n_points];
605
606    for i in 0..n_points {
607        let point_i = points.row(i);
608        for j in (i + 1)..n_points {
609            let point_j = points.row(j);
610            let dist = simd_distance(&point_i, &point_j, metric)?;
611            distances[i][j] = dist;
612            distances[j][i] = dist; // Symmetric
613        }
614    }
615
616    Ok(distances)
617}
618
619/// Calculate k-nearest neighbors using SIMD-optimized distances
620pub fn simd_k_nearest_neighbors(
621    points: &ArrayView2<Float>,
622    query: &ArrayView1<Float>,
623    k: usize,
624    metric: SimdDistanceMetric,
625) -> Result<Vec<(usize, Float)>, Box<dyn std::error::Error>> {
626    let distances = simd_distance_batch_query(points, query, metric)?;
627
628    let mut indexed_distances: Vec<(usize, Float)> = distances.into_iter().enumerate().collect();
629
630    // Sort by distance and take the k nearest
631    indexed_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
632    indexed_distances.truncate(k);
633
634    Ok(indexed_distances)
635}
636
637/// Find neighbors within a radius using SIMD-optimized distances
638pub fn simd_radius_neighbors(
639    points: &ArrayView2<Float>,
640    query: &ArrayView1<Float>,
641    radius: Float,
642    metric: SimdDistanceMetric,
643) -> Result<Vec<usize>, Box<dyn std::error::Error>> {
644    let distances = simd_distance_batch_query(points, query, metric)?;
645
646    let neighbors: Vec<usize> = distances
647        .into_iter()
648        .enumerate()
649        .filter_map(|(idx, dist)| if dist <= radius { Some(idx) } else { None })
650        .collect();
651
652    Ok(neighbors)
653}
654
655/// SIMD-optimized distance matrix computation
656pub fn simd_distance_matrix(
657    points: &ArrayView2<Float>,
658    metric: SimdDistanceMetric,
659) -> Result<scirs2_core::ndarray::Array2<Float>, Box<dyn std::error::Error>> {
660    let n_points = points.nrows();
661    let mut matrix = scirs2_core::ndarray::Array2::zeros((n_points, n_points));
662
663    for i in 0..n_points {
664        let point_i = points.row(i);
665        for j in (i + 1)..n_points {
666            let point_j = points.row(j);
667            let dist = simd_distance(&point_i, &point_j, metric)?;
668            matrix[[i, j]] = dist;
669            matrix[[j, i]] = dist;
670        }
671    }
672
673    Ok(matrix)
674}
675
676/// Performance comparison between SIMD and scalar implementations
677pub fn benchmark_simd_vs_scalar(
678    points: &ArrayView2<Float>,
679    query: &ArrayView1<Float>,
680    metric: SimdDistanceMetric,
681) -> (f64, f64) {
682    use std::time::Instant;
683
684    // SIMD version
685    let start = Instant::now();
686    let _simd_result = simd_distance_batch_query(points, query, metric).unwrap();
687    let simd_time = start.elapsed().as_secs_f64();
688
689    // Scalar version (simple fallback)
690    let start = Instant::now();
691    let _scalar_result = scalar_distance_batch(points, query, metric);
692    let scalar_time = start.elapsed().as_secs_f64();
693
694    (simd_time, scalar_time)
695}
696
697/// Scalar fallback implementation for comparison
698fn scalar_distance_batch(
699    points: &ArrayView2<Float>,
700    query: &ArrayView1<Float>,
701    metric: SimdDistanceMetric,
702) -> Vec<Float> {
703    let mut results = Vec::with_capacity(points.nrows());
704
705    for i in 0..points.nrows() {
706        let point = points.row(i);
707        let dist = match metric {
708            SimdDistanceMetric::Euclidean => {
709                let mut sum = 0.0;
710                for (&a, &b) in point.iter().zip(query.iter()) {
711                    let diff = a - b;
712                    sum += diff * diff;
713                }
714                sum.sqrt()
715            }
716            SimdDistanceMetric::Manhattan => {
717                let mut sum = 0.0;
718                for (&a, &b) in point.iter().zip(query.iter()) {
719                    sum += (a - b).abs();
720                }
721                sum
722            }
723            SimdDistanceMetric::Chebyshev => {
724                let mut max_diff = 0.0;
725                for (&a, &b) in point.iter().zip(query.iter()) {
726                    let diff = (a - b).abs();
727                    if diff > max_diff {
728                        max_diff = diff;
729                    }
730                }
731                max_diff
732            }
733            SimdDistanceMetric::Cosine => {
734                let mut dot = 0.0;
735                let mut norm_a = 0.0;
736                let mut norm_b = 0.0;
737                for (&a, &b) in point.iter().zip(query.iter()) {
738                    dot += a * b;
739                    norm_a += a * a;
740                    norm_b += b * b;
741                }
742                let norm_product = norm_a.sqrt() * norm_b.sqrt();
743                if norm_product == 0.0 {
744                    0.0
745                } else {
746                    1.0 - (dot / norm_product)
747                }
748            }
749            SimdDistanceMetric::CosineSimilarity => {
750                let mut dot = 0.0;
751                let mut norm_a = 0.0;
752                let mut norm_b = 0.0;
753                for (&a, &b) in point.iter().zip(query.iter()) {
754                    dot += a * b;
755                    norm_a += a * a;
756                    norm_b += b * b;
757                }
758                let norm_product = norm_a.sqrt() * norm_b.sqrt();
759                if norm_product == 0.0 {
760                    0.0
761                } else {
762                    dot / norm_product
763                }
764            }
765            SimdDistanceMetric::Minkowski(p) => {
766                let mut sum = 0.0;
767                for (&a, &b) in point.iter().zip(query.iter()) {
768                    sum += (a - b).abs().powf(p as Float);
769                }
770                sum.powf(1.0 / p as Float)
771            }
772            SimdDistanceMetric::Jaccard => {
773                let mut intersection = 0.0;
774                let mut union = 0.0;
775                for (&a, &b) in point.iter().zip(query.iter()) {
776                    intersection += a.min(b);
777                    union += a.max(b);
778                }
779                if union == 0.0 {
780                    0.0
781                } else {
782                    1.0 - (intersection / union)
783                }
784            }
785            SimdDistanceMetric::EuclideanSquared => {
786                let mut sum = 0.0;
787                for (&a, &b) in point.iter().zip(query.iter()) {
788                    let diff = a - b;
789                    sum += diff * diff;
790                }
791                sum
792            }
793            SimdDistanceMetric::Hamming => {
794                let mut count = 0.0;
795                for (&a, &b) in point.iter().zip(query.iter()) {
796                    if (a - b).abs() > Float::EPSILON {
797                        count += 1.0;
798                    }
799                }
800                count
801            }
802            SimdDistanceMetric::Canberra => {
803                let mut sum = 0.0;
804                for (&a, &b) in point.iter().zip(query.iter()) {
805                    let numerator = (a - b).abs();
806                    let denominator = a.abs() + b.abs();
807                    if denominator > 0.0 {
808                        sum += numerator / denominator;
809                    }
810                }
811                sum
812            }
813            SimdDistanceMetric::Braycurtis => {
814                let mut numerator = 0.0;
815                let mut denominator = 0.0;
816                for (&a, &b) in point.iter().zip(query.iter()) {
817                    numerator += (a - b).abs();
818                    denominator += a.abs() + b.abs();
819                }
820                if denominator == 0.0 {
821                    0.0
822                } else {
823                    numerator / denominator
824                }
825            }
826            SimdDistanceMetric::Mahalanobis => {
827                // This should not be reached as it requires covariance matrix
828                0.0
829            }
830            SimdDistanceMetric::Correlation => {
831                // Pearson correlation distance
832                let n = point.len() as Float;
833                let sum_a: Float = point.iter().sum();
834                let sum_b: Float = query.iter().sum();
835                let mean_a = sum_a / n;
836                let mean_b = sum_b / n;
837
838                let mut numerator = 0.0;
839                let mut var_a = 0.0;
840                let mut var_b = 0.0;
841
842                for (&a, &b) in point.iter().zip(query.iter()) {
843                    let diff_a = a - mean_a;
844                    let diff_b = b - mean_b;
845                    numerator += diff_a * diff_b;
846                    var_a += diff_a * diff_a;
847                    var_b += diff_b * diff_b;
848                }
849
850                let denominator = (var_a * var_b).sqrt();
851                if denominator == 0.0 {
852                    0.0
853                } else {
854                    1.0 - (numerator / denominator)
855                }
856            }
857            SimdDistanceMetric::Wasserstein => {
858                // Simple 1D Wasserstein distance (Earth Mover's Distance)
859                let mut sorted_a: Vec<Float> = point.iter().cloned().collect();
860                let mut sorted_b: Vec<Float> = query.iter().cloned().collect();
861                sorted_a.sort_by(|a, b| a.partial_cmp(b).unwrap());
862                sorted_b.sort_by(|a, b| a.partial_cmp(b).unwrap());
863
864                let mut sum = 0.0;
865                for (a, b) in sorted_a.iter().zip(sorted_b.iter()) {
866                    sum += (a - b).abs();
867                }
868                sum / point.len() as Float
869            }
870        };
871        results.push(dist);
872    }
873
874    results
875}
876
877/// Adaptive distance function that chooses between SIMD and scalar based on data size
878pub fn adaptive_distance_batch(
879    points: &ArrayView2<Float>,
880    query: &ArrayView1<Float>,
881    metric: SimdDistanceMetric,
882    simd_threshold: usize,
883) -> Result<Vec<Float>, Box<dyn std::error::Error>> {
884    if points.nrows() >= simd_threshold && query.len() >= 4 {
885        simd_distance_batch_query(points, query, metric)
886    } else {
887        Ok(scalar_distance_batch(points, query, metric))
888    }
889}
890
891/// Custom distance function with user-defined metric
892pub fn custom_distance<F>(
893    point1: &ArrayView1<Float>,
894    point2: &ArrayView1<Float>,
895    distance_fn: F,
896) -> Float
897where
898    F: Fn(&ArrayView1<Float>, &ArrayView1<Float>) -> Float,
899{
900    distance_fn(point1, point2)
901}
902
903/// Mahalanobis distance with provided covariance matrix
904pub fn mahalanobis_distance(
905    point1: &ArrayView1<Float>,
906    point2: &ArrayView1<Float>,
907    cov_inv: &scirs2_core::ndarray::Array2<Float>,
908) -> Result<Float, Box<dyn std::error::Error>> {
909    if point1.len() != point2.len() {
910        return Err("Points must have the same dimensions".into());
911    }
912
913    if cov_inv.nrows() != point1.len() || cov_inv.ncols() != point1.len() {
914        return Err("Covariance matrix dimensions must match point dimensions".into());
915    }
916
917    let diff: scirs2_core::ndarray::Array1<Float> = point1.to_owned() - point2;
918    let temp = cov_inv.dot(&diff);
919    let distance_squared = diff.dot(&temp);
920
921    Ok(distance_squared.sqrt())
922}
923
924/// Distance metrics with preprocessing for categorical data
925pub fn categorical_distance(
926    point1: &ArrayView1<Float>,
927    point2: &ArrayView1<Float>,
928    metric: CategoricalDistanceMetric,
929) -> Float {
930    match metric {
931        CategoricalDistanceMetric::Hamming => {
932            let mut count = 0.0;
933            for (&a, &b) in point1.iter().zip(point2.iter()) {
934                if (a - b).abs() > Float::EPSILON {
935                    count += 1.0;
936                }
937            }
938            count / point1.len() as Float
939        }
940        CategoricalDistanceMetric::MatchingDissimilarity => {
941            let mut mismatches = 0.0;
942            for (&a, &b) in point1.iter().zip(point2.iter()) {
943                if (a - b).abs() > Float::EPSILON {
944                    mismatches += 1.0;
945                }
946            }
947            mismatches / point1.len() as Float
948        }
949    }
950}
951
/// Metric selector accepted by `categorical_distance`.
///
/// The enum is field-less, so it additionally derives `Eq` and `Hash`; this
/// lets a metric be used as a `HashMap`/`HashSet` key or compared with full
/// equality semantics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CategoricalDistanceMetric {
    /// Hamming distance for categorical variables
    Hamming,
    /// Matching dissimilarity
    MatchingDissimilarity,
}
960
961/// Weighted distance calculation
962pub fn weighted_distance(
963    point1: &ArrayView1<Float>,
964    point2: &ArrayView1<Float>,
965    weights: &ArrayView1<Float>,
966    metric: SimdDistanceMetric,
967) -> Result<Float, Box<dyn std::error::Error>> {
968    if point1.len() != point2.len() || point1.len() != weights.len() {
969        return Err("All arrays must have the same length".into());
970    }
971
972    match metric {
973        SimdDistanceMetric::Euclidean => {
974            let mut sum = 0.0;
975            for ((&a, &b), &w) in point1.iter().zip(point2.iter()).zip(weights.iter()) {
976                let diff = a - b;
977                sum += w * diff * diff;
978            }
979            Ok(sum.sqrt())
980        }
981        SimdDistanceMetric::Manhattan => {
982            let mut sum = 0.0;
983            for ((&a, &b), &w) in point1.iter().zip(point2.iter()).zip(weights.iter()) {
984                sum += w * (a - b).abs();
985            }
986            Ok(sum)
987        }
988        _ => {
989            // For other metrics, apply weights as scaling factors
990            let weighted_p1: scirs2_core::ndarray::Array1<Float> = point1
991                .iter()
992                .zip(weights.iter())
993                .map(|(&p, &w)| p * w.sqrt())
994                .collect();
995            let weighted_p2: scirs2_core::ndarray::Array1<Float> = point2
996                .iter()
997                .zip(weights.iter())
998                .map(|(&p, &w)| p * w.sqrt())
999                .collect();
1000            simd_distance(&weighted_p1.view(), &weighted_p2.view(), metric)
1001        }
1002    }
1003}
1004
1005// Helper functions for additional distance metrics
/// Counts the positions at which two slices differ.
///
/// A position counts as different when the absolute difference of the two
/// values exceeds `Float::EPSILON`. Values are paired positionally up to the
/// shorter slice, and the count is returned as a `Float` (not normalised by
/// the length).
fn hamming_distance_simd(a: &[Float], b: &[Float]) -> Float {
    let differing = a
        .iter()
        .zip(b)
        .filter(|&(&x, &y)| (x - y).abs() > Float::EPSILON)
        .count();
    differing as Float
}
1015
/// Canberra distance: `Σ |xᵢ − yᵢ| / (|xᵢ| + |yᵢ|)`.
///
/// Terms whose denominator is not strictly positive (both values zero, or a
/// NaN denominator) are skipped. Values are paired positionally up to the
/// shorter slice.
fn canberra_distance_simd(a: &[Float], b: &[Float]) -> Float {
    a.iter()
        .zip(b)
        .filter_map(|(&x, &y)| {
            let magnitude = x.abs() + y.abs();
            (magnitude > 0.0).then(|| (x - y).abs() / magnitude)
        })
        // Explicit fold seeded with 0.0 keeps the running-sum order and start.
        .fold(0.0, |total, term| total + term)
}
1027
/// Bray-Curtis style dissimilarity: `Σ |xᵢ − yᵢ| / Σ (|xᵢ| + |yᵢ|)`.
///
/// Returns `0.0` when the denominator is exactly zero. Values are paired
/// positionally up to the shorter slice.
fn braycurtis_distance_simd(a: &[Float], b: &[Float]) -> Float {
    let (abs_diff_total, magnitude_total): (Float, Float) =
        a.iter()
            .zip(b)
            .fold((0.0, 0.0), |(diff_acc, mag_acc), (&x, &y)| {
                (diff_acc + (x - y).abs(), mag_acc + (x.abs() + y.abs()))
            });

    if magnitude_total != 0.0 {
        abs_diff_total / magnitude_total
    } else {
        0.0
    }
}
1041
/// Pearson correlation distance: `1 − r`, where `r` is the sample correlation
/// coefficient of the two slices.
///
/// Both means are taken over `a.len()` elements. When the product of the two
/// sums of squared deviations has a zero square root (for example, one slice
/// is constant, or the input is empty) the function returns `0.0`.
fn correlation_distance_simd(a: &[Float], b: &[Float]) -> Float {
    let len = a.len() as Float;
    let mean_a = a.iter().sum::<Float>() / len;
    let mean_b = b.iter().sum::<Float>() / len;

    // Single pass accumulating the cross term and both sums of squares.
    let (cross, sq_a, sq_b): (Float, Float, Float) =
        a.iter()
            .zip(b)
            .fold((0.0, 0.0, 0.0), |(cross, sq_a, sq_b), (&x, &y)| {
                let dx = x - mean_a;
                let dy = y - mean_b;
                (cross + dx * dy, sq_a + dx * dx, sq_b + dy * dy)
            });

    let norm = (sq_a * sq_b).sqrt();
    if norm == 0.0 {
        return 0.0;
    }
    1.0 - (cross / norm)
}
1068
/// Simple one-dimensional Wasserstein-style distance between two slices.
///
/// Each slice is copied and sorted in ascending order; the result is the sum
/// of absolute differences between equally-ranked values, divided by
/// `a.len()`. Values are paired up to the shorter slice.
///
/// # Edge cases
///
/// * Sorting uses the IEEE 754 total order (`total_cmp`), so a NaN in the
///   input no longer panics; it propagates into a NaN result instead.
/// * An empty `a` yields `0.0` rather than the NaN of a `0 / 0` division.
fn wasserstein_distance_simd(a: &[Float], b: &[Float]) -> Float {
    if a.is_empty() {
        return 0.0;
    }

    let mut sorted_a = a.to_vec();
    let mut sorted_b = b.to_vec();
    // Values that compare equal under `total_cmp` are bit-identical, so an
    // unstable sort is sufficient and avoids the stable sort's allocation.
    sorted_a.sort_unstable_by(Float::total_cmp);
    sorted_b.sort_unstable_by(Float::total_cmp);

    let total: Float = sorted_a
        .iter()
        .zip(&sorted_b)
        .fold(0.0, |acc, (x, y)| acc + (x - y).abs());

    total / a.len() as Float
}
1081
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    //! Unit tests for the batch, neighbour-search, matrix and adaptive
    //! distance entry points, plus `OptimizedDistanceComputer`.

    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{array, Array1, Array2};

    /// Euclidean batch distances from the origin to three 2-D rows.
    #[test]
    fn test_simd_euclidean_distance() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let query = array![0.0, 0.0];

        let distances =
            simd_distance_batch_query(&data.view(), &query.view(), SimdDistanceMetric::Euclidean)
                .unwrap();

        assert_eq!(distances.len(), 3);
        assert_abs_diff_eq!(distances[0], (5.0_f64).sqrt(), epsilon = 1e-6);
        assert_abs_diff_eq!(distances[1], 5.0, epsilon = 1e-6);
        assert_abs_diff_eq!(distances[2], (61.0_f64).sqrt(), epsilon = 1e-6);
    }

    /// Manhattan batch distances from the origin to two 3-D rows.
    #[test]
    fn test_simd_manhattan_distance() {
        let data = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let query = array![0.0, 0.0, 0.0];

        let distances =
            simd_distance_batch_query(&data.view(), &query.view(), SimdDistanceMetric::Manhattan)
                .unwrap();

        assert_eq!(distances.len(), 2);
        assert_abs_diff_eq!(distances[0], 6.0, epsilon = 1e-6); // |1| + |2| + |3|
        assert_abs_diff_eq!(distances[1], 15.0, epsilon = 1e-6); // |4| + |5| + |6|
    }

    /// The SIMD batch routine and the scalar batch routine must agree.
    #[test]
    fn test_simd_vs_scalar_consistency() {
        let data = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let query = array![0.0, 0.0, 0.0];

        let simd_distances =
            simd_distance_batch_query(&data.view(), &query.view(), SimdDistanceMetric::Euclidean)
                .unwrap();
        let scalar_distances =
            scalar_distance_batch(&data.view(), &query.view(), SimdDistanceMetric::Euclidean);

        assert_eq!(simd_distances.len(), scalar_distances.len());
        for (simd, scalar) in simd_distances.iter().zip(scalar_distances.iter()) {
            assert_abs_diff_eq!(simd, scalar, epsilon = 1e-5);
        }
    }

    /// The three nearest rows to the origin are returned in ascending
    /// distance order.
    #[test]
    fn test_simd_k_nearest_neighbors() {
        let data = Array2::from_shape_vec(
            (5, 2),
            vec![
                1.0, 1.0, // Distance: sqrt(2) ≈ 1.414
                2.0, 2.0, // Distance: sqrt(8) ≈ 2.828
                0.0, 0.0, // Distance: 0
                3.0, 3.0, // Distance: sqrt(18) ≈ 4.243
                0.5, 0.5, // Distance: sqrt(0.5) ≈ 0.707
            ],
        )
        .unwrap();
        let query = array![0.0, 0.0];

        let neighbors = simd_k_nearest_neighbors(
            &data.view(),
            &query.view(),
            3,
            SimdDistanceMetric::Euclidean,
        )
        .unwrap();

        assert_eq!(neighbors.len(), 3);
        assert_eq!(neighbors[0].0, 2); // Nearest is (0,0)
        assert_eq!(neighbors[1].0, 4); // Second nearest is (0.5, 0.5)
        assert_eq!(neighbors[2].0, 0); // Third nearest is (1,1)
    }

    /// Only the rows within radius 1.5 of the origin are reported.
    #[test]
    fn test_simd_radius_neighbors() {
        let data = Array2::from_shape_vec(
            (4, 2),
            vec![
                1.0, 0.0, // Distance: 1
                0.0, 1.0, // Distance: 1
                2.0, 0.0, // Distance: 2
                0.0, 2.0, // Distance: 2
            ],
        )
        .unwrap();
        let query = array![0.0, 0.0];

        let neighbors = simd_radius_neighbors(
            &data.view(),
            &query.view(),
            1.5,
            SimdDistanceMetric::Euclidean,
        )
        .unwrap();

        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.contains(&0));
        assert!(neighbors.contains(&1));
    }

    /// The pairwise matrix has a zero diagonal and the expected symmetric
    /// entries.
    #[test]
    fn test_simd_distance_matrix() {
        let data = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0]).unwrap();

        let matrix = simd_distance_matrix(&data.view(), SimdDistanceMetric::Euclidean).unwrap();

        assert_eq!(matrix.shape(), &[3, 3]);

        // Diagonal should be zero
        assert_abs_diff_eq!(matrix[[0, 0]], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(matrix[[1, 1]], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(matrix[[2, 2]], 0.0, epsilon = 1e-6);

        // Distance from (0,0) to (1,0) should be 1
        assert_abs_diff_eq!(matrix[[0, 1]], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(matrix[[1, 0]], 1.0, epsilon = 1e-6);

        // Distance from (0,0) to (0,1) should be 1
        assert_abs_diff_eq!(matrix[[0, 2]], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(matrix[[2, 0]], 1.0, epsilon = 1e-6);

        // Distance from (1,0) to (0,1) should be sqrt(2)
        assert_abs_diff_eq!(matrix[[1, 2]], (2.0_f64).sqrt(), epsilon = 1e-6);
        assert_abs_diff_eq!(matrix[[2, 1]], (2.0_f64).sqrt(), epsilon = 1e-6);
    }

    /// Both branches of `adaptive_distance_batch` return one distance per row.
    #[test]
    fn test_adaptive_distance_batch() {
        let small_data = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let large_data =
            Array2::from_shape_vec((10, 4), (0..40).map(|x| x as f64).collect()).unwrap();
        let query = array![0.0, 0.0, 0.0, 0.0];

        // Small data should use scalar
        let small_result = adaptive_distance_batch(
            &small_data.view(),
            &query.view().slice(scirs2_core::ndarray::s![..2]),
            SimdDistanceMetric::Euclidean,
            5,
        )
        .unwrap();
        assert_eq!(small_result.len(), 2);

        // Large data should use SIMD
        let large_result = adaptive_distance_batch(
            &large_data.view(),
            &query.view(),
            SimdDistanceMetric::Euclidean,
            5,
        )
        .unwrap();
        assert_eq!(large_result.len(), 10);
    }

    /// The parallel batch routine must agree with the sequential one.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_parallel_simd_distance_batch() {
        let data = Array2::from_shape_vec((6, 3), (0..18).map(|x| x as f64).collect()).unwrap();
        let query = array![0.0, 0.0, 0.0];

        let parallel_result = simd_distance_batch_parallel(
            &data.view(),
            &query.view(),
            SimdDistanceMetric::Euclidean,
        )
        .unwrap();
        let sequential_result =
            simd_distance_batch_query(&data.view(), &query.view(), SimdDistanceMetric::Euclidean)
                .unwrap();

        assert_eq!(parallel_result.len(), sequential_result.len());
        for (par, seq) in parallel_result.iter().zip(sequential_result.iter()) {
            assert_abs_diff_eq!(par, seq, epsilon = 1e-6);
        }
    }

    /// Pairwise distances between an all-ones and an all-zeros point set.
    #[test]
    fn test_optimized_distance_computer_performance() {
        // `Array2` comes from the module-level import; no function-local
        // `use` is needed here.
        let computer = OptimizedDistanceComputer::new();
        let n_points = 100;
        let n_features = 10;

        // Generate test data
        let points1 = Array2::<Float>::ones((n_points, n_features));
        let points2 = Array2::<Float>::zeros((n_points, n_features));

        // Test pairwise distances
        let distances = computer.pairwise_distances(
            &points1.view(),
            &points2.view(),
            DistanceMetric::Euclidean,
        );

        assert_eq!(distances.dim(), (n_points, n_points));

        // All distances should be sqrt(n_features) since points are (1,1,1...) vs (0,0,0...)
        let expected_distance = (n_features as Float).sqrt();
        for &dist in distances.iter() {
            assert!((dist - expected_distance).abs() < 1e-6);
        }
    }

    /// Constructing the computer (which runs SIMD detection) and computing a
    /// single distance must work and give the exact Euclidean value.
    #[test]
    fn test_simd_detection() {
        let computer = OptimizedDistanceComputer::new();

        // Test that SIMD detection works without crashing
        let point1 = Array1::from(vec![1.0, 2.0, 3.0, 4.0]);
        let point2 = Array1::from(vec![2.0, 3.0, 4.0, 5.0]);

        let distance =
            computer.compute_distance(&point1.view(), &point2.view(), DistanceMetric::Euclidean);

        // Distance should be 2.0 (sqrt of 4 * 1^2)
        assert!((distance - 2.0).abs() < 1e-6);
    }
}