Skip to main content

scirs2_cluster/meanshift/
mod.rs

1//! Mean Shift clustering implementation
2//!
3//! Mean Shift is a non-parametric clustering technique that does not require
4//! specifying the number of clusters in advance. It works by iteratively
5//! shifting each data point towards the mode of the local density.
6//!
7//! # Features
8//!
9//! - **Flat kernel** and **Gaussian kernel** support
10//! - **Bandwidth estimation**: Silverman's rule, Scott's rule, k-NN quantile
11//! - **Bin seeding** for acceleration on large datasets
12//! - **Cluster-all** mode and noise detection mode
13//!
14//! # Examples
15//!
16//! ```
17//! use scirs2_core::ndarray::array;
18//! use scirs2_cluster::meanshift::{mean_shift, MeanShiftOptions, KernelType};
19//!
20//! let data = array![
21//!     [1.0, 1.0], [2.0, 1.0], [1.0, 0.0],
22//!     [4.0, 7.0], [3.0, 5.0], [3.0, 6.0]
23//! ];
24//!
25//! let options = MeanShiftOptions {
26//!     bandwidth: Some(2.0),
27//!     kernel: KernelType::Gaussian,
28//!     ..Default::default()
29//! };
30//!
31//! let (centers, labels) = mean_shift(&data.view(), options).expect("Operation failed");
32//! println!("Number of clusters: {}", centers.nrows());
33//! ```
34
35use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
36use scirs2_core::numeric::{Float, FromPrimitive};
37use std::collections::HashMap;
38use std::fmt::{Debug, Display};
39use std::hash::{Hash, Hasher};
40use std::marker::{Send, Sync};
41
42use crate::error::ClusteringError;
43use scirs2_core::validation::{
44    check_positive, checkarray_finite, clustering::validate_clustering_data,
45    parameters::check_unit_interval,
46};
47use scirs2_spatial::distance::EuclideanDistance;
48use scirs2_spatial::kdtree::KDTree;
49
/// Kernel type for Mean Shift.
///
/// The kernel decides how neighbouring points are weighted when a seed's mean
/// is recomputed. [`KernelType::Flat`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum KernelType {
    /// Flat (uniform) kernel: all points within bandwidth contribute equally
    #[default]
    Flat,
    /// Gaussian kernel: points are weighted by exp(-||x - xi||^2 / (2 * bandwidth^2)).
    /// Only points within 3 * bandwidth of the current mean are considered.
    Gaussian,
}
64
/// Bandwidth estimation method, used when no explicit bandwidth is given.
///
/// [`BandwidthEstimator::KNNQuantile`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BandwidthEstimator {
    /// k-NN quantile method (default): mean distance to the k-th nearest neighbour,
    /// with k derived from a quantile of the sample count
    #[default]
    KNNQuantile,
    /// Silverman's rule of thumb, per dimension: h = 0.9 * min(std, IQR/1.34) * n^(-1/5),
    /// averaged over dimensions
    Silverman,
    /// Scott's rule: h = n^(-1/(d+4)) * std, with std averaged over dimensions
    Scott,
}
81
82/// Configuration options for Mean Shift algorithm
83pub struct MeanShiftOptions<T: Float> {
84    /// Bandwidth parameter.
85    /// If not provided, it will be estimated from the data.
86    pub bandwidth: Option<T>,
87
88    /// Points used as initial kernel locations.
89    /// If not provided, either all points or discretized bins will be used.
90    pub seeds: Option<Array2<T>>,
91
92    /// If true, initial kernels are located on a grid with bin_size = bandwidth.
93    pub bin_seeding: bool,
94
95    /// Only bins with at least min_bin_freq points will be selected as seeds.
96    pub min_bin_freq: usize,
97
98    /// If true, all points are assigned to clusters, even outliers.
99    pub cluster_all: bool,
100
101    /// Maximum number of iterations for a single seed.
102    pub max_iter: usize,
103
104    /// Kernel type to use
105    pub kernel: KernelType,
106
107    /// Bandwidth estimation method (used when bandwidth is None)
108    pub bandwidth_estimator: BandwidthEstimator,
109}
110
111impl<T: Float> Default for MeanShiftOptions<T> {
112    fn default() -> Self {
113        Self {
114            bandwidth: None,
115            seeds: None,
116            bin_seeding: false,
117            min_bin_freq: 1,
118            cluster_all: true,
119            max_iter: 300,
120            kernel: KernelType::Flat,
121            bandwidth_estimator: BandwidthEstimator::KNNQuantile,
122        }
123    }
124}
125
126/// FloatPoint wrapper to make f32/f64 arrays comparable and hashable
127#[derive(Debug, Clone)]
128struct FloatPoint<T: Float>(Vec<T>);
129
130impl<T: Float> PartialEq for FloatPoint<T> {
131    fn eq(&self, other: &Self) -> bool {
132        if self.0.len() != other.0.len() {
133            return false;
134        }
135
136        for (a, b) in self.0.iter().zip(other.0.iter()) {
137            if !a.is_finite() || !b.is_finite() || (*a - *b).abs() > T::epsilon() {
138                return false;
139            }
140        }
141        true
142    }
143}
144
145impl<T: Float> Eq for FloatPoint<T> {}
146
147impl<T: Float> Hash for FloatPoint<T> {
148    fn hash<H: Hasher>(&self, state: &mut H) {
149        for value in &self.0 {
150            let bits = if let Some(bits) = value.to_f64() {
151                (bits * 1e10).round() as i64
152            } else {
153                0
154            };
155            bits.hash(state);
156        }
157    }
158}
159
160/// Estimate bandwidth using Silverman's rule of thumb
161///
162/// h = 0.9 * min(std, IQR/1.34) * n^(-1/5)
163///
164/// This works well for normally distributed data.
165pub fn estimate_bandwidth_silverman<T: Float + Display + FromPrimitive + Send + Sync + 'static>(
166    data: &ArrayView2<T>,
167) -> Result<T, ClusteringError> {
168    checkarray_finite(data, "data")?;
169
170    let n = data.nrows();
171    if n < 2 {
172        return Ok(T::from(1.0).ok_or_else(|| {
173            ClusteringError::ComputationError("Failed to convert constant".into())
174        })?);
175    }
176
177    let n_features = data.ncols();
178    let n_f = T::from(n)
179        .ok_or_else(|| ClusteringError::ComputationError("Failed to convert n".into()))?;
180
181    // Compute bandwidth per dimension and take the average
182    let mut bandwidth_sum = T::zero();
183
184    for col_idx in 0..n_features {
185        // Gather column values
186        let mut values: Vec<T> = (0..n).map(|i| data[[i, col_idx]]).collect();
187        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
188
189        // Standard deviation
190        let mean = values.iter().fold(T::zero(), |a, &b| a + b) / n_f;
191        let var = values
192            .iter()
193            .fold(T::zero(), |acc, &v| acc + (v - mean) * (v - mean))
194            / n_f;
195        let std_dev = var.sqrt();
196
197        // IQR
198        let q1_idx = n / 4;
199        let q3_idx = (3 * n) / 4;
200        let iqr = values[q3_idx.min(n - 1)] - values[q1_idx];
201        let one_point_three_four = T::from(1.34).ok_or_else(|| {
202            ClusteringError::ComputationError("Failed to convert constant".into())
203        })?;
204        let iqr_scaled = iqr / one_point_three_four;
205
206        // min(std, IQR/1.34), but skip IQR if it's zero
207        let spread = if iqr_scaled > T::zero() && iqr_scaled < std_dev {
208            iqr_scaled
209        } else {
210            std_dev
211        };
212
213        // Silverman factor: 0.9 * spread * n^(-1/5)
214        let zero_nine = T::from(0.9).ok_or_else(|| {
215            ClusteringError::ComputationError("Failed to convert constant".into())
216        })?;
217        let exponent = T::from(-0.2).ok_or_else(|| {
218            ClusteringError::ComputationError("Failed to convert constant".into())
219        })?;
220        let n_factor = n_f.powf(exponent);
221
222        let h = zero_nine * spread * n_factor;
223        bandwidth_sum = bandwidth_sum + h;
224    }
225
226    let n_feat_f = T::from(n_features)
227        .ok_or_else(|| ClusteringError::ComputationError("Failed to convert n_features".into()))?;
228    let bandwidth = bandwidth_sum / n_feat_f;
229
230    // Ensure positive bandwidth
231    if bandwidth <= T::zero() {
232        return Ok(T::from(1.0).ok_or_else(|| {
233            ClusteringError::ComputationError("Failed to convert constant".into())
234        })?);
235    }
236
237    Ok(bandwidth)
238}
239
240/// Estimate bandwidth using Scott's rule
241///
242/// h = n^(-1/(d+4)) * std
243///
244/// Good general-purpose estimator for multivariate data.
245pub fn estimate_bandwidth_scott<T: Float + Display + FromPrimitive + Send + Sync + 'static>(
246    data: &ArrayView2<T>,
247) -> Result<T, ClusteringError> {
248    checkarray_finite(data, "data")?;
249
250    let n = data.nrows();
251    if n < 2 {
252        return Ok(T::from(1.0).ok_or_else(|| {
253            ClusteringError::ComputationError("Failed to convert constant".into())
254        })?);
255    }
256
257    let n_features = data.ncols();
258    let n_f = T::from(n)
259        .ok_or_else(|| ClusteringError::ComputationError("Failed to convert n".into()))?;
260
261    // Scott's exponent: -1/(d+4)
262    let d_plus_4 = T::from(n_features as f64 + 4.0)
263        .ok_or_else(|| ClusteringError::ComputationError("Failed to convert dimension".into()))?;
264    let exponent = T::from(-1.0)
265        .ok_or_else(|| ClusteringError::ComputationError("Failed to convert constant".into()))?
266        / d_plus_4;
267    let n_factor = n_f.powf(exponent);
268
269    // Average standard deviation across dimensions
270    let mut std_sum = T::zero();
271    for col_idx in 0..n_features {
272        let mean = (0..n)
273            .map(|i| data[[i, col_idx]])
274            .fold(T::zero(), |a, b| a + b)
275            / n_f;
276        let var = (0..n)
277            .map(|i| {
278                let diff = data[[i, col_idx]] - mean;
279                diff * diff
280            })
281            .fold(T::zero(), |a, b| a + b)
282            / n_f;
283        std_sum = std_sum + var.sqrt();
284    }
285
286    let avg_std = std_sum
287        / T::from(n_features).ok_or_else(|| {
288            ClusteringError::ComputationError("Failed to convert n_features".into())
289        })?;
290
291    let bandwidth = n_factor * avg_std;
292
293    if bandwidth <= T::zero() {
294        return Ok(T::from(1.0).ok_or_else(|| {
295            ClusteringError::ComputationError("Failed to convert constant".into())
296        })?);
297    }
298
299    Ok(bandwidth)
300}
301
302/// Estimate the bandwidth using k-NN quantile method.
303///
304/// Computes the average distance to the k-th nearest neighbor across all points,
305/// where k = quantile * n_samples.
306pub fn estimate_bandwidth<T: Float + Display + FromPrimitive + Send + Sync + 'static>(
307    data: &ArrayView2<T>,
308    quantile: Option<T>,
309    n_samples: Option<usize>,
310    _random_state: Option<u64>,
311) -> Result<T, ClusteringError> {
312    checkarray_finite(data, "data")?;
313
314    let quantile = quantile
315        .unwrap_or_else(|| T::from(0.3).unwrap_or_else(|| T::from(0.3f64).unwrap_or(T::one())));
316    let _quantile = check_unit_interval(quantile, "quantile", "estimate_bandwidth")?;
317
318    // Select a subset of samples if specified
319    let data = if let Some(n) = n_samples {
320        if n >= data.nrows() {
321            data.to_owned()
322        } else {
323            let mut rng = scirs2_core::random::rng();
324            use scirs2_core::random::seq::SliceRandom;
325            let mut indices: Vec<usize> = (0..data.nrows()).collect();
326            indices.shuffle(&mut rng);
327
328            let indices = &indices[0..n];
329            let mut sampled_data = Array2::zeros((n, data.ncols()));
330            for (i, &idx) in indices.iter().enumerate() {
331                sampled_data.row_mut(i).assign(&data.row(idx));
332            }
333            sampled_data
334        }
335    } else {
336        data.to_owned()
337    };
338
339    let n_neighbors = (T::from(data.nrows()).unwrap_or(T::one()) * quantile)
340        .to_usize()
341        .unwrap_or(1)
342        .max(1)
343        .min(data.nrows().saturating_sub(1));
344
345    // Build KDTree for nearest neighbor search
346    let kdtree = KDTree::<_, EuclideanDistance<T>>::new(&data)
347        .map_err(|e| ClusteringError::ComputationError(format!("Failed to build KDTree: {}", e)))?;
348
349    let mut bandwidth_sum = T::zero();
350
351    let batch_size = 500;
352    for i in (0..data.nrows()).step_by(batch_size) {
353        let end = (i + batch_size).min(data.nrows());
354        let batch = data.slice(scirs2_core::ndarray::s![i..end, ..]);
355
356        for row in batch.rows() {
357            let (_, distances) = kdtree.query(&row.to_vec(), n_neighbors + 1).map_err(|e| {
358                ClusteringError::ComputationError(format!("Failed to query KDTree: {}", e))
359            })?;
360
361            if distances.len() > 1 {
362                let kth_dist = distances
363                    .last()
364                    .copied()
365                    .unwrap_or_else(|| T::from(1.0).unwrap_or(T::one()));
366                bandwidth_sum = bandwidth_sum + kth_dist;
367            } else if !distances.is_empty() {
368                bandwidth_sum = bandwidth_sum + T::from(1.0).unwrap_or(T::one());
369            }
370        }
371    }
372
373    Ok(bandwidth_sum / T::from(data.nrows()).unwrap_or(T::one()))
374}
375
376/// Find seeds for mean_shift by binning data onto a grid.
377pub fn get_bin_seeds<T: Float + Display + FromPrimitive + Send + Sync + 'static>(
378    data: &ArrayView2<T>,
379    bin_size: T,
380    min_bin_freq: usize,
381) -> Array2<T> {
382    if bin_size <= T::zero() {
383        return data.to_owned();
384    }
385
386    let mut bin_sizes: HashMap<FloatPoint<T>, usize> = HashMap::new();
387
388    for row in data.rows() {
389        let mut binned_point = Vec::with_capacity(row.len());
390        for &val in row.iter() {
391            binned_point.push((val / bin_size).round() * bin_size);
392        }
393        let point = FloatPoint::<T>(binned_point);
394        *bin_sizes.entry(point).or_insert(0) += 1;
395    }
396
397    let seeds: Vec<Vec<T>> = bin_sizes
398        .into_iter()
399        .filter(|(_, freq)| *freq >= min_bin_freq)
400        .map(|(point, _)| point.0)
401        .collect();
402
403    if seeds.len() == data.nrows() {
404        return data.to_owned();
405    }
406
407    if seeds.is_empty() {
408        Array2::zeros((0, data.ncols()))
409    } else {
410        let mut result = Array2::zeros((seeds.len(), data.ncols()));
411        for (i, seed) in seeds.into_iter().enumerate() {
412            for (j, val) in seed.into_iter().enumerate() {
413                result[[i, j]] = val;
414            }
415        }
416        result
417    }
418}
419
420/// Mean Shift single seed update with flat kernel
421fn mean_shift_single_seed_flat<
422    T: Float
423        + Display
424        + std::iter::Sum
425        + FromPrimitive
426        + Send
427        + Sync
428        + 'static
429        + scirs2_core::ndarray::ScalarOperand,
430>(
431    seed: ArrayView1<T>,
432    data: &ArrayView2<T>,
433    bandwidth: T,
434    max_iter: usize,
435) -> (Vec<T>, usize, usize) {
436    let stop_thresh = bandwidth * T::from(1e-3).unwrap_or(T::epsilon());
437    let mut my_mean = seed.to_owned();
438    let mut completed_iterations = 0;
439
440    let owned_data = data.to_owned();
441    let kdtree = match KDTree::<_, EuclideanDistance<T>>::new(&owned_data) {
442        Ok(tree) => tree,
443        Err(_) => return (seed.to_vec(), 0, 0),
444    };
445
446    loop {
447        let (indices, _distances) = match kdtree.query_radius(&my_mean.to_vec(), bandwidth) {
448            Ok((idx, distances)) => (idx, distances),
449            Err(_) => return (my_mean.to_vec(), 0, completed_iterations),
450        };
451
452        if indices.is_empty() {
453            break;
454        }
455        let my_old_mean = my_mean.clone();
456
457        // Flat kernel: equal weights for all neighbors
458        my_mean.fill(T::zero());
459        let mut sum = Array1::zeros(my_mean.dim());
460        for &point_idx in &indices {
461            let row_clone = data.row(point_idx).to_owned();
462            for (s, v) in sum.iter_mut().zip(row_clone.iter()) {
463                *s = *s + *v;
464            }
465        }
466        my_mean = sum / T::from(indices.len()).unwrap_or(T::one());
467
468        let mut dist_squared = T::zero();
469        for (a, b) in my_mean.iter().zip(my_old_mean.iter()) {
470            dist_squared = dist_squared + (*a - *b) * (*a - *b);
471        }
472        let dist = dist_squared.sqrt();
473
474        if dist <= stop_thresh || completed_iterations == max_iter {
475            break;
476        }
477
478        completed_iterations += 1;
479    }
480
481    let (final_indices, _) = match kdtree.query_radius(&my_mean.to_vec(), bandwidth) {
482        Ok((idx, distances)) => (idx, distances),
483        Err(_) => return (my_mean.to_vec(), 0, completed_iterations),
484    };
485
486    (my_mean.to_vec(), final_indices.len(), completed_iterations)
487}
488
489/// Mean Shift single seed update with Gaussian kernel
490fn mean_shift_single_seed_gaussian<
491    T: Float
492        + Display
493        + std::iter::Sum
494        + FromPrimitive
495        + Send
496        + Sync
497        + 'static
498        + scirs2_core::ndarray::ScalarOperand,
499>(
500    seed: ArrayView1<T>,
501    data: &ArrayView2<T>,
502    bandwidth: T,
503    max_iter: usize,
504) -> (Vec<T>, usize, usize) {
505    let stop_thresh = bandwidth * T::from(1e-3).unwrap_or(T::epsilon());
506    let mut my_mean = seed.to_owned();
507    let mut completed_iterations = 0;
508    let bw_sq = bandwidth * bandwidth;
509
510    // Use 3*bandwidth as the search radius for Gaussian kernel
511    let search_radius = bandwidth * T::from(3.0).unwrap_or(T::one() + T::one() + T::one());
512
513    let owned_data = data.to_owned();
514    let kdtree = match KDTree::<_, EuclideanDistance<T>>::new(&owned_data) {
515        Ok(tree) => tree,
516        Err(_) => return (seed.to_vec(), 0, 0),
517    };
518
519    loop {
520        let (indices, distances) = match kdtree.query_radius(&my_mean.to_vec(), search_radius) {
521            Ok((idx, distances)) => (idx, distances),
522            Err(_) => return (my_mean.to_vec(), 0, completed_iterations),
523        };
524
525        if indices.is_empty() {
526            break;
527        }
528        let my_old_mean = my_mean.clone();
529
530        // Gaussian kernel: weight = exp(-dist^2 / (2 * bw^2))
531        let two = T::from(2.0).unwrap_or(T::one() + T::one());
532        let n_features = my_mean.dim();
533        let mut weighted_sum = Array1::zeros(n_features);
534        let mut weight_total = T::zero();
535
536        for (local_idx, &point_idx) in indices.iter().enumerate() {
537            let dist = distances[local_idx];
538            let dist_sq = dist * dist;
539            let weight = (-dist_sq / (two * bw_sq)).exp();
540
541            let row = data.row(point_idx);
542            for (ws, &v) in weighted_sum.iter_mut().zip(row.iter()) {
543                *ws = *ws + v * weight;
544            }
545            weight_total = weight_total + weight;
546        }
547
548        if weight_total > T::zero() {
549            my_mean = weighted_sum / weight_total;
550        }
551
552        let mut dist_squared = T::zero();
553        for (a, b) in my_mean.iter().zip(my_old_mean.iter()) {
554            dist_squared = dist_squared + (*a - *b) * (*a - *b);
555        }
556        let dist = dist_squared.sqrt();
557
558        if dist <= stop_thresh || completed_iterations == max_iter {
559            break;
560        }
561
562        completed_iterations += 1;
563    }
564
565    let (final_indices, _) = match kdtree.query_radius(&my_mean.to_vec(), bandwidth) {
566        Ok((idx, distances)) => (idx, distances),
567        Err(_) => return (my_mean.to_vec(), 0, completed_iterations),
568    };
569
570    (my_mean.to_vec(), final_indices.len(), completed_iterations)
571}
572
573/// Perform Mean Shift clustering.
574///
575/// # Arguments
576///
577/// * `data` - The input data as a 2D array.
578/// * `options` - The Mean Shift algorithm options.
579///
580/// # Returns
581///
582/// * `Result<(Array2<T>, Array1<i32>), ClusteringError>` - Tuple of (cluster centers, labels).
583pub fn mean_shift<
584    T: Float
585        + Display
586        + std::iter::Sum
587        + FromPrimitive
588        + Send
589        + Sync
590        + 'static
591        + scirs2_core::ndarray::ScalarOperand
592        + Debug,
593>(
594    data: &ArrayView2<T>,
595    options: MeanShiftOptions<T>,
596) -> Result<(Array2<T>, Array1<i32>), ClusteringError> {
597    let mut model = MeanShift::new(options);
598    let model = model.fit(data)?;
599    Ok((
600        model.cluster_centers().to_owned(),
601        model.labels().to_owned(),
602    ))
603}
604
605/// Mean Shift clustering model.
606pub struct MeanShift<T: Float> {
607    options: MeanShiftOptions<T>,
608    cluster_centers_: Option<Array2<T>>,
609    labels_: Option<Array1<i32>>,
610    n_iter_: usize,
611    bandwidth_used_: Option<T>,
612}
613
614impl<
615        T: Float
616            + Display
617            + std::iter::Sum
618            + FromPrimitive
619            + Send
620            + Sync
621            + 'static
622            + scirs2_core::ndarray::ScalarOperand
623            + Debug,
624    > MeanShift<T>
625{
626    /// Create a new Mean Shift instance.
627    pub fn new(options: MeanShiftOptions<T>) -> Self {
628        Self {
629            options,
630            cluster_centers_: None,
631            labels_: None,
632            n_iter_: 0,
633            bandwidth_used_: None,
634        }
635    }
636
637    /// Fit the Mean Shift model to data.
638    pub fn fit(&mut self, data: &ArrayView2<T>) -> Result<&mut Self, ClusteringError> {
639        let config = crate::input_validation::ValidationConfig::default();
640        crate::input_validation::validate_clustering_data(data.view(), &config)?;
641
642        let (n_samples, n_features) = data.dim();
643
644        // Determine bandwidth
645        let bandwidth = match self.options.bandwidth {
646            Some(bw) => check_positive(bw, "bandwidth")?,
647            None => match self.options.bandwidth_estimator {
648                BandwidthEstimator::Silverman => estimate_bandwidth_silverman(data)?,
649                BandwidthEstimator::Scott => estimate_bandwidth_scott(data)?,
650                BandwidthEstimator::KNNQuantile => {
651                    estimate_bandwidth(data, Some(T::from(0.3).unwrap_or(T::one())), None, None)?
652                }
653            },
654        };
655        self.bandwidth_used_ = Some(bandwidth);
656
657        // Get seeds
658        let seeds = match &self.options.seeds {
659            Some(s) => s.clone(),
660            None => {
661                if self.options.bin_seeding {
662                    get_bin_seeds(data, bandwidth, self.options.min_bin_freq)
663                } else {
664                    data.to_owned()
665                }
666            }
667        };
668
669        if seeds.is_empty() {
670            return Err(ClusteringError::ComputationError(
671                "No seeds provided and bin seeding produced no seeds".to_string(),
672            ));
673        }
674
675        // Run mean shift on each seed with the appropriate kernel
676        let kernel = self.options.kernel;
677        let max_iter = self.options.max_iter;
678
679        let seed_results: Vec<_> = seeds
680            .axis_iter(Axis(0))
681            .map(|seed| match kernel {
682                KernelType::Flat => mean_shift_single_seed_flat(seed, data, bandwidth, max_iter),
683                KernelType::Gaussian => {
684                    mean_shift_single_seed_gaussian(seed, data, bandwidth, max_iter)
685                }
686            })
687            .collect();
688
689        // Process results
690        let mut center_intensity_dict: HashMap<FloatPoint<T>, usize> = HashMap::new();
691        for (center, size, iterations) in seed_results {
692            if size > 0 {
693                center_intensity_dict.insert(FloatPoint(center), size);
694            }
695            self.n_iter_ = self.n_iter_.max(iterations);
696        }
697
698        if center_intensity_dict.is_empty() {
699            return Err(ClusteringError::ComputationError(format!(
700                "No point was within bandwidth={} of any seed. \
701                 Try a different seeding strategy or increase the bandwidth.",
702                bandwidth
703            )));
704        }
705
706        // Sort centers by intensity
707        let mut sorted_by_intensity: Vec<_> = center_intensity_dict.into_iter().collect();
708        sorted_by_intensity.sort_by(|a, b| {
709            b.1.cmp(&a.1).then_with(|| {
710                a.0 .0
711                    .iter()
712                    .zip(b.0 .0.iter())
713                    .find_map(|(a_val, b_val)| a_val.partial_cmp(b_val))
714                    .unwrap_or(std::cmp::Ordering::Equal)
715            })
716        });
717
718        if !self.options.cluster_all {
719            let min_density_threshold = 2;
720            sorted_by_intensity.retain(|(_, intensity)| *intensity >= min_density_threshold);
721
722            if sorted_by_intensity.is_empty() {
723                return Err(ClusteringError::ComputationError(
724                    "No clusters found with sufficient density.".to_string(),
725                ));
726            }
727        }
728
729        // Convert to Array2
730        let mut sorted_centers = Array2::zeros((sorted_by_intensity.len(), n_features));
731        for (i, center_) in sorted_by_intensity.iter().enumerate() {
732            for (j, &val) in center_.0 .0.iter().enumerate() {
733                sorted_centers[[i, j]] = val;
734            }
735        }
736
737        // Remove near-duplicate centers
738        let mut unique = vec![true; sorted_centers.nrows()];
739
740        let kdtree = KDTree::<_, EuclideanDistance<T>>::new(&sorted_centers).map_err(|e| {
741            ClusteringError::ComputationError(format!("Failed to build KDTree: {}", e))
742        })?;
743
744        let merge_threshold = bandwidth * T::from(0.1).unwrap_or(T::epsilon());
745
746        for i in 0..sorted_centers.nrows() {
747            if unique[i] {
748                let (indices_, _) = kdtree
749                    .query_radius(&sorted_centers.row(i).to_vec(), merge_threshold)
750                    .map_err(|e| {
751                        ClusteringError::ComputationError(format!("Failed to query KDTree: {}", e))
752                    })?;
753
754                for &idx in indices_.iter() {
755                    if idx != i {
756                        unique[idx] = false;
757                    }
758                }
759            }
760        }
761
762        let unique_indices: Vec<_> = unique
763            .iter()
764            .enumerate()
765            .filter(|&(_, &is_unique)| is_unique)
766            .map(|(i_, _)| i_)
767            .collect();
768
769        let mut cluster_centers = Array2::zeros((unique_indices.len(), n_features));
770        for (i, &idx) in unique_indices.iter().enumerate() {
771            cluster_centers.row_mut(i).assign(&sorted_centers.row(idx));
772        }
773
774        // Assign labels
775        let centers_kdtree =
776            KDTree::<_, EuclideanDistance<T>>::new(&cluster_centers).map_err(|e| {
777                ClusteringError::ComputationError(format!("Failed to build KDTree: {}", e))
778            })?;
779
780        let mut labels = Array1::zeros(n_samples);
781
782        let batch_size = 1000;
783        for i in (0..n_samples).step_by(batch_size) {
784            let end = (i + batch_size).min(n_samples);
785            let batch = data.slice(scirs2_core::ndarray::s![i..end, ..]);
786
787            for (row_idx, row) in batch.rows().into_iter().enumerate() {
788                let point_idx = i + row_idx;
789
790                let (indices, distances) = centers_kdtree.query(&row.to_vec(), 1).map_err(|e| {
791                    ClusteringError::ComputationError(format!("Failed to query KDTree: {}", e))
792                })?;
793
794                if !indices.is_empty() {
795                    let idx = indices[0];
796                    let distance = T::from(distances[0]).unwrap_or(T::zero());
797
798                    if self.options.cluster_all || (distance <= bandwidth) {
799                        labels[point_idx] =
800                            T::to_i32(&T::from(idx).unwrap_or(T::zero())).unwrap_or(0);
801                    } else {
802                        labels[point_idx] = -1;
803                    }
804                } else {
805                    labels[point_idx] = -1;
806                }
807            }
808        }
809
810        self.cluster_centers_ = Some(cluster_centers);
811        self.labels_ = Some(labels);
812
813        Ok(self)
814    }
815
816    /// Get cluster centers found by the algorithm.
817    pub fn cluster_centers(&self) -> &Array2<T> {
818        self.cluster_centers_
819            .as_ref()
820            .expect("Model has not been fitted yet")
821    }
822
823    /// Get labels assigned to each data point.
824    pub fn labels(&self) -> &Array1<i32> {
825        self.labels_
826            .as_ref()
827            .expect("Model has not been fitted yet")
828    }
829
830    /// Get the number of iterations performed for the most complex seed.
831    pub fn n_iter(&self) -> usize {
832        self.n_iter_
833    }
834
835    /// Get the bandwidth that was actually used (useful when auto-estimated).
836    pub fn bandwidth_used(&self) -> Option<T> {
837        self.bandwidth_used_
838    }
839
840    /// Predict the closest cluster each sample in data belongs to.
841    pub fn predict(&self, data: &ArrayView2<T>) -> Result<Array1<i32>, ClusteringError> {
842        let centers = self.cluster_centers_.as_ref().ok_or_else(|| {
843            ClusteringError::InvalidState("Model has not been fitted yet".to_string())
844        })?;
845
846        checkarray_finite(data, "prediction data")?;
847
848        let n_samples = data.nrows();
849        let mut labels = Array1::zeros(n_samples);
850
851        let kdtree = KDTree::<_, EuclideanDistance<T>>::new(centers).map_err(|e| {
852            ClusteringError::ComputationError(format!("Failed to build KDTree: {}", e))
853        })?;
854
855        let batch_size = 1000;
856        for i in (0..n_samples).step_by(batch_size) {
857            let end = (i + batch_size).min(n_samples);
858            let batch = data.slice(scirs2_core::ndarray::s![i..end, ..]);
859
860            for (row_idx, row) in batch.rows().into_iter().enumerate() {
861                let (indices_, _distances) = kdtree.query(&row.to_vec(), 1).map_err(|e| {
862                    ClusteringError::ComputationError(format!("Failed to query KDTree: {}", e))
863                })?;
864
865                if !indices_.is_empty() {
866                    labels[i + row_idx] =
867                        T::to_i32(&T::from(indices_[0]).unwrap_or(T::zero())).unwrap_or(0);
868                } else {
869                    labels[i + row_idx] = -1;
870                }
871            }
872        }
873
874        Ok(labels)
875    }
876}
877
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};
    use std::collections::HashSet;

    /// Six 2-D points forming two visually separated groups: three near (1, 1)
    /// and three near (3, 6).
    fn make_test_data() -> Array2<f64> {
        array![
            [1.0, 1.0],
            [2.0, 1.0],
            [1.0, 0.0],
            [4.0, 7.0],
            [3.0, 5.0],
            [3.0, 6.0]
        ]
    }

    #[test]
    fn test_estimate_bandwidth() {
        let data = make_test_data();
        let bandwidth = estimate_bandwidth(&data.view(), Some(0.4), None, None)
            .expect("Bandwidth estimation should succeed");

        assert!(
            bandwidth > 0.0,
            "Bandwidth should be positive, got: {}",
            bandwidth
        );
        assert!(
            bandwidth < 20.0,
            "Bandwidth should be reasonable, got: {}",
            bandwidth
        );
    }

    #[test]
    fn test_estimate_bandwidth_silverman() {
        let data = make_test_data();
        let bandwidth = estimate_bandwidth_silverman(&data.view())
            .expect("Silverman estimation should succeed");

        assert!(bandwidth > 0.0, "Silverman bandwidth should be positive");
        assert!(bandwidth < 20.0, "Silverman bandwidth should be reasonable");
    }

    #[test]
    fn test_estimate_bandwidth_scott() {
        let data = make_test_data();
        let bandwidth =
            estimate_bandwidth_scott(&data.view()).expect("Scott estimation should succeed");

        assert!(bandwidth > 0.0, "Scott bandwidth should be positive");
        assert!(bandwidth < 20.0, "Scott bandwidth should be reasonable");
    }

    #[test]
    fn test_estimate_bandwidth_small_sample() {
        // A single sample: this test pins the returned bandwidth to exactly 1.0.
        let data = array![[1.0, 1.0]];
        let bandwidth = estimate_bandwidth(&data.view(), Some(0.3), None, None)
            .expect("Should work for single sample");
        assert!(bandwidth > 0.0);
        assert_eq!(bandwidth, 1.0);
    }

    #[test]
    fn test_get_bin_seeds() {
        let data = array![
            [1.0, 1.0],
            [1.4, 1.4],
            [1.8, 1.2],
            [2.0, 1.0],
            [2.1, 1.1],
            [0.0, 0.0]
        ];

        let bin_seeds = get_bin_seeds(&data.view(), 1.0, 1);
        assert_eq!(bin_seeds.nrows(), 3);

        let bin_seeds = get_bin_seeds(&data.view(), 1.0, 2);
        assert_eq!(bin_seeds.nrows(), 2);

        let bin_seeds = get_bin_seeds(&data.view(), 0.01, 1);
        assert_eq!(bin_seeds.nrows(), data.nrows());
    }

    #[test]
    fn test_mean_shift_flat_kernel() {
        let data = make_test_data();

        let options = MeanShiftOptions {
            bandwidth: Some(2.0),
            kernel: KernelType::Flat,
            ..Default::default()
        };

        let (centers, labels) =
            mean_shift(&data.view(), options).expect("Mean shift with flat kernel should succeed");

        assert!(centers.nrows() >= 1, "Should find at least 1 cluster");
        assert!(centers.nrows() <= 3, "Should find at most 3 clusters");
        assert!(
            labels.iter().all(|&l| l >= 0),
            "All labels should be non-negative"
        );
    }

    #[test]
    fn test_mean_shift_gaussian_kernel() {
        let data = make_test_data();

        let options = MeanShiftOptions {
            bandwidth: Some(2.0),
            kernel: KernelType::Gaussian,
            ..Default::default()
        };

        let (centers, labels) = mean_shift(&data.view(), options)
            .expect("Mean shift with Gaussian kernel should succeed");

        assert!(centers.nrows() >= 1, "Should find at least 1 cluster");
        assert!(
            labels.iter().all(|&l| l >= 0),
            "All labels should be non-negative"
        );
    }

    #[test]
    fn test_mean_shift_bin_seeding() {
        let data = make_test_data();

        let options = MeanShiftOptions {
            bandwidth: Some(2.0),
            bin_seeding: true,
            ..Default::default()
        };

        let (centers, labels) =
            mean_shift(&data.view(), options).expect("Mean shift with bin seeding should succeed");

        assert!(centers.nrows() >= 1);
        assert!(centers.nrows() <= 3);
        assert!(labels.iter().all(|&l| l >= 0));
    }

    #[test]
    fn test_mean_shift_no_cluster_all() {
        // The last point sits far from both groups, so with cluster_all disabled
        // at least one label is expected to be the noise label -1.
        let data = array![
            [1.0, 1.0],
            [2.0, 1.0],
            [1.0, 0.0],
            [4.0, 7.0],
            [3.0, 5.0],
            [3.0, 6.0],
            [10.0, 10.0]
        ];

        let options = MeanShiftOptions {
            bandwidth: Some(2.0),
            cluster_all: false,
            ..Default::default()
        };

        let (_centers, labels) =
            mean_shift(&data.view(), options).expect("Mean shift should succeed");

        assert!(labels.iter().any(|&l| l == -1));
    }

    #[test]
    fn test_mean_shift_max_iter() {
        let data = make_test_data();

        let options = MeanShiftOptions {
            bandwidth: Some(2.0),
            max_iter: 1,
            ..Default::default()
        };

        let mut model = MeanShift::new(options);
        model.fit(&data.view()).expect("Should fit");

        assert_eq!(model.n_iter(), 1);
    }

    #[test]
    fn test_mean_shift_predict() {
        let data = make_test_data();

        let options = MeanShiftOptions {
            bandwidth: Some(2.0),
            ..Default::default()
        };

        let mut model = MeanShift::new(options);
        model.fit(&data.view()).expect("Should fit");

        let predicted_labels = model.predict(&data.view()).expect("Predict should succeed");
        assert_eq!(predicted_labels, model.labels().clone());
    }

    #[test]
    fn test_mean_shift_predict_unseen_points() {
        let data = make_test_data();

        let options = MeanShiftOptions {
            bandwidth: Some(2.0),
            ..Default::default()
        };

        let mut model = MeanShift::new(options);
        model.fit(&data.view()).expect("Should fit");
        let train_labels = model.labels().clone();

        // Each new point lies right next to a training point from one group, so
        // it should receive the same cluster label as that training point.
        let new_points = array![[1.1, 0.9], [3.2, 5.8]];
        let predicted = model
            .predict(&new_points.view())
            .expect("Predict on unseen points should succeed");

        assert_eq!(predicted.len(), 2);
        assert_eq!(predicted[0], train_labels[0]);
        assert_eq!(predicted[1], train_labels[5]);
    }

    #[test]
    fn test_mean_shift_predict_before_fit_errors() {
        let data = make_test_data();

        let options = MeanShiftOptions {
            bandwidth: Some(2.0),
            ..Default::default()
        };

        // `fit` is never called, so there are no cluster centers to predict against.
        let model = MeanShift::new(options);
        assert!(
            model.predict(&data.view()).is_err(),
            "Predict on an unfitted model should fail"
        );
    }

    #[test]
    fn test_mean_shift_silverman_bandwidth() {
        let data = make_test_data();

        let options = MeanShiftOptions {
            bandwidth: None,
            bandwidth_estimator: BandwidthEstimator::Silverman,
            ..Default::default()
        };

        let mut model = MeanShift::new(options);
        model
            .fit(&data.view())
            .expect("Should fit with Silverman bandwidth");

        assert!(model.bandwidth_used().is_some());
        assert!(
            model.bandwidth_used().unwrap_or(0.0) > 0.0,
            "Silverman bandwidth should be positive"
        );
    }

    #[test]
    fn test_mean_shift_scott_bandwidth() {
        let data = make_test_data();

        let options = MeanShiftOptions {
            bandwidth: None,
            bandwidth_estimator: BandwidthEstimator::Scott,
            ..Default::default()
        };

        let mut model = MeanShift::new(options);
        model
            .fit(&data.view())
            .expect("Should fit with Scott bandwidth");

        assert!(model.bandwidth_used().is_some());
        assert!(
            model.bandwidth_used().unwrap_or(0.0) > 0.0,
            "Scott bandwidth should be positive"
        );
    }

    #[test]
    fn test_mean_shift_large_dataset() {
        let mut data = Array2::zeros((20, 2));

        for i in 0..10 {
            data[[i, 0]] = 1.0 + 0.05 * (i as f64);
            data[[i, 1]] = 1.0 + 0.05 * (i as f64);
        }

        for i in 10..20 {
            data[[i, 0]] = 8.0 + 0.05 * ((i - 10) as f64);
            data[[i, 1]] = 8.0 + 0.05 * ((i - 10) as f64);
        }

        let options = MeanShiftOptions {
            bandwidth: Some(1.5),
            bin_seeding: true,
            ..Default::default()
        };

        let (centers, labels) =
            mean_shift(&data.view(), options).expect("Should handle larger dataset");

        assert!(centers.nrows() >= 1);
        assert!(centers.nrows() <= 3);

        let unique_labels: HashSet<_> = labels.iter().cloned().collect();
        assert!(!unique_labels.is_empty());
        assert!(unique_labels.len() <= centers.nrows());
    }
}