Skip to main content

scirs2_cluster/
streaming_cluster.rs

1//! Streaming (online) clustering algorithms
2//!
3//! This module provides algorithms designed for clustering data streams where
4//! points arrive sequentially and must be processed incrementally.
5//!
6//! # Algorithms
7//!
8//! - **CluStream**: Micro-cluster based stream clustering (Aggarwal et al. 2003)
9//! - **DenStream**: Density-based stream clustering (Cao et al. 2006)
10//! - **StreamKM++**: Coreset-based streaming k-means (Ackermann et al. 2012)
11//! - **Sliding window clustering**: Fixed-window online clustering
12//! - **Online K-means with forgetting factor**: Exponentially weighted online k-means
13
14use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
15use scirs2_core::numeric::{Float, FromPrimitive};
16use std::collections::VecDeque;
17use std::fmt::Debug;
18
19use crate::error::{ClusteringError, Result};
20
21// ---------------------------------------------------------------------------
22// Micro-cluster (shared primitive for CluStream / DenStream)
23// ---------------------------------------------------------------------------
24
25/// A micro-cluster summarising a set of nearby points.
26#[derive(Debug, Clone)]
27pub struct MicroCluster<F: Float> {
28    /// Linear sum of points (LS).
29    pub linear_sum: Vec<F>,
30    /// Squared sum of points (SS).
31    pub squared_sum: Vec<F>,
32    /// Number of points absorbed.
33    pub n_points: usize,
34    /// Creation timestamp.
35    pub creation_time: u64,
36    /// Last update timestamp.
37    pub last_update_time: u64,
38    /// Weight (may differ from n_points for fading models).
39    pub weight: F,
40}
41
42impl<F: Float + FromPrimitive + Debug> MicroCluster<F> {
43    /// Create a new micro-cluster from a single point.
44    pub fn from_point(point: &[F], timestamp: u64) -> Self {
45        let d = point.len();
46        let mut ls = vec![F::zero(); d];
47        let mut ss = vec![F::zero(); d];
48        for i in 0..d {
49            ls[i] = point[i];
50            ss[i] = point[i] * point[i];
51        }
52        Self {
53            linear_sum: ls,
54            squared_sum: ss,
55            n_points: 1,
56            creation_time: timestamp,
57            last_update_time: timestamp,
58            weight: F::one(),
59        }
60    }
61
62    /// Centroid of the micro-cluster.
63    pub fn centroid(&self) -> Vec<F> {
64        if self.weight <= F::epsilon() {
65            return self.linear_sum.clone();
66        }
67        self.linear_sum.iter().map(|&v| v / self.weight).collect()
68    }
69
70    /// Radius (RMS deviation from centroid).
71    pub fn radius(&self) -> F {
72        if self.weight <= F::one() {
73            return F::zero();
74        }
75        let d = self.linear_sum.len();
76        let w = self.weight;
77        let mut variance = F::zero();
78        for i in 0..d {
79            let mean = self.linear_sum[i] / w;
80            let mean_sq = self.squared_sum[i] / w;
81            let v = mean_sq - mean * mean;
82            variance = variance + if v > F::zero() { v } else { F::zero() };
83        }
84        (variance / F::from(d).unwrap_or_else(|| F::one())).sqrt()
85    }
86
87    /// Absorb a single point.
88    pub fn absorb(&mut self, point: &[F], timestamp: u64) {
89        let d = self.linear_sum.len().min(point.len());
90        for i in 0..d {
91            self.linear_sum[i] = self.linear_sum[i] + point[i];
92            self.squared_sum[i] = self.squared_sum[i] + point[i] * point[i];
93        }
94        self.n_points += 1;
95        self.weight = self.weight + F::one();
96        self.last_update_time = timestamp;
97    }
98
99    /// Merge another micro-cluster into this one.
100    pub fn merge(&mut self, other: &MicroCluster<F>) {
101        let d = self.linear_sum.len().min(other.linear_sum.len());
102        for i in 0..d {
103            self.linear_sum[i] = self.linear_sum[i] + other.linear_sum[i];
104            self.squared_sum[i] = self.squared_sum[i] + other.squared_sum[i];
105        }
106        self.n_points += other.n_points;
107        self.weight = self.weight + other.weight;
108        if other.last_update_time > self.last_update_time {
109            self.last_update_time = other.last_update_time;
110        }
111    }
112
113    /// Apply exponential fading with factor lambda over elapsed time.
114    pub fn apply_fading(&mut self, lambda: F, elapsed: F) {
115        let factor = (F::zero() - lambda * elapsed).exp();
116        let d = self.linear_sum.len();
117        for i in 0..d {
118            self.linear_sum[i] = self.linear_sum[i] * factor;
119            self.squared_sum[i] = self.squared_sum[i] * factor;
120        }
121        self.weight = self.weight * factor;
122    }
123
124    /// Squared distance from centroid to a point.
125    fn distance_sq_to(&self, point: &[F]) -> F {
126        let centroid = self.centroid();
127        let d = centroid.len().min(point.len());
128        let mut s = F::zero();
129        for i in 0..d {
130            let diff = centroid[i] - point[i];
131            s = s + diff * diff;
132        }
133        s
134    }
135}
136
137// ---------------------------------------------------------------------------
138// CluStream
139// ---------------------------------------------------------------------------
140
/// Tuning parameters for the CluStream algorithm.
#[derive(Debug, Clone)]
pub struct CluStreamConfig {
    /// Upper bound on the number of micro-clusters kept in memory.
    pub max_micro_clusters: usize,
    /// How many macro-clusters the offline phase should produce.
    pub n_macro_clusters: usize,
    /// Snapshot-pyramid time horizon (`T` in the CluStream paper).
    pub time_horizon: u64,
    /// Multiple of a micro-cluster's radius within which a new point is absorbed.
    pub radius_factor: f64,
}

impl Default for CluStreamConfig {
    /// 100 micro-clusters, 5 macro-clusters, a horizon of 1000 and a radius factor of 2.
    fn default() -> Self {
        let time_horizon = 1000;
        let radius_factor = 2.0;
        CluStreamConfig {
            max_micro_clusters: 100,
            n_macro_clusters: 5,
            time_horizon,
            radius_factor,
        }
    }
}
164
165/// CluStream online clustering algorithm.
166///
167/// Maintains a set of micro-clusters that summarise the data stream.
168/// Periodically, macro-clustering (e.g. weighted k-means on micro-cluster
169/// centroids) produces the final cluster assignments.
170pub struct CluStream<F: Float> {
171    config: CluStreamConfig,
172    micro_clusters: Vec<MicroCluster<F>>,
173    current_time: u64,
174    n_features: usize,
175    initialized: bool,
176}
177
178impl<F: Float + FromPrimitive + Debug> CluStream<F> {
179    /// Create a new CluStream instance.
180    pub fn new(config: CluStreamConfig) -> Self {
181        Self {
182            config,
183            micro_clusters: Vec::new(),
184            current_time: 0,
185            n_features: 0,
186            initialized: false,
187        }
188    }
189
190    /// Initialize with a batch of points.
191    pub fn initialize(&mut self, data: ArrayView2<F>) -> Result<()> {
192        let (n, d) = (data.shape()[0], data.shape()[1]);
193        if n == 0 {
194            return Err(ClusteringError::InvalidInput("Empty data".into()));
195        }
196        self.n_features = d;
197
198        // Run simple k-means to get initial micro-clusters
199        let k = self.config.max_micro_clusters.min(n);
200        let labels = simple_kmeans_init(data, k);
201
202        self.micro_clusters.clear();
203        for ci in 0..k {
204            let mut mc: Option<MicroCluster<F>> = None;
205            for i in 0..n {
206                if labels[i] == ci as i32 {
207                    match mc.as_mut() {
208                        Some(m) => m.absorb(data.row(i).as_slice().unwrap_or(&[]), 0),
209                        None => {
210                            mc = Some(MicroCluster::from_point(
211                                data.row(i).as_slice().unwrap_or(&[]),
212                                0,
213                            ));
214                        }
215                    }
216                }
217            }
218            if let Some(m) = mc {
219                self.micro_clusters.push(m);
220            }
221        }
222        self.initialized = true;
223        Ok(())
224    }
225
226    /// Process a single new data point from the stream.
227    pub fn process_point(&mut self, point: &[F]) -> Result<()> {
228        if !self.initialized {
229            return Err(ClusteringError::InvalidState(
230                "CluStream not initialized".into(),
231            ));
232        }
233        self.current_time += 1;
234
235        // Find nearest micro-cluster
236        let (nearest_idx, nearest_dist) = self.find_nearest_mc(point);
237
238        let rf = F::from(self.config.radius_factor)
239            .unwrap_or_else(|| F::from(2.0).unwrap_or_else(|| F::one()));
240
241        // Check if point fits within the micro-cluster radius
242        let fits = if let Some(mc) = self.micro_clusters.get(nearest_idx) {
243            let r = mc.radius();
244            nearest_dist.sqrt() <= r * rf + F::epsilon()
245        } else {
246            false
247        };
248
249        if fits {
250            if let Some(mc) = self.micro_clusters.get_mut(nearest_idx) {
251                mc.absorb(point, self.current_time);
252            }
253        } else {
254            // Create a new micro-cluster
255            if self.micro_clusters.len() >= self.config.max_micro_clusters {
256                // Merge two closest micro-clusters to make room
257                self.merge_closest_pair();
258            }
259            self.micro_clusters
260                .push(MicroCluster::from_point(point, self.current_time));
261        }
262
263        Ok(())
264    }
265
266    /// Process a batch of points.
267    pub fn process_batch(&mut self, data: ArrayView2<F>) -> Result<()> {
268        for i in 0..data.shape()[0] {
269            let row = data.row(i);
270            self.process_point(row.as_slice().unwrap_or(&[]))?;
271        }
272        Ok(())
273    }
274
275    /// Get current macro-cluster labels for the micro-clusters.
276    ///
277    /// Returns (micro-cluster centroids, macro-cluster labels for each micro-cluster).
278    pub fn get_macro_clusters(&self) -> Result<(Array2<F>, Array1<i32>)> {
279        if self.micro_clusters.is_empty() {
280            return Err(ClusteringError::InvalidState(
281                "No micro-clusters available".into(),
282            ));
283        }
284
285        let n_mc = self.micro_clusters.len();
286        let d = self.n_features;
287        let k = self.config.n_macro_clusters.min(n_mc);
288
289        // Build matrix of micro-cluster centroids with weights
290        let mut centroids = Array2::<F>::zeros((n_mc, d));
291        for (i, mc) in self.micro_clusters.iter().enumerate() {
292            let c = mc.centroid();
293            for j in 0..d.min(c.len()) {
294                centroids[[i, j]] = c[j];
295            }
296        }
297
298        let labels = simple_kmeans_init(centroids.view(), k);
299        Ok((centroids, labels))
300    }
301
302    /// Number of current micro-clusters.
303    pub fn n_micro_clusters(&self) -> usize {
304        self.micro_clusters.len()
305    }
306
307    /// Get reference to micro-clusters.
308    pub fn micro_clusters(&self) -> &[MicroCluster<F>] {
309        &self.micro_clusters
310    }
311
312    fn find_nearest_mc(&self, point: &[F]) -> (usize, F) {
313        let mut best_idx = 0;
314        let mut best_dist = F::infinity();
315        for (i, mc) in self.micro_clusters.iter().enumerate() {
316            let d = mc.distance_sq_to(point);
317            if d < best_dist {
318                best_dist = d;
319                best_idx = i;
320            }
321        }
322        (best_idx, best_dist)
323    }
324
325    fn merge_closest_pair(&mut self) {
326        if self.micro_clusters.len() < 2 {
327            return;
328        }
329        let n = self.micro_clusters.len();
330        let mut best_i = 0;
331        let mut best_j = 1;
332        let mut best_dist = F::infinity();
333        for i in 0..n {
334            let ci = self.micro_clusters[i].centroid();
335            for j in (i + 1)..n {
336                let cj = self.micro_clusters[j].centroid();
337                let d: F = ci
338                    .iter()
339                    .zip(cj.iter())
340                    .map(|(&a, &b)| (a - b) * (a - b))
341                    .fold(F::zero(), |acc, v| acc + v);
342                if d < best_dist {
343                    best_dist = d;
344                    best_i = i;
345                    best_j = j;
346                }
347            }
348        }
349        // Merge j into i, remove j
350        let mc_j = self.micro_clusters[best_j].clone();
351        self.micro_clusters[best_i].merge(&mc_j);
352        self.micro_clusters.remove(best_j);
353    }
354}
355
356// ---------------------------------------------------------------------------
357// DenStream
358// ---------------------------------------------------------------------------
359
/// Tuning parameters for the DenStream algorithm.
#[derive(Debug, Clone)]
pub struct DenStreamConfig {
    /// Neighbourhood radius, used both for absorbing points and for the
    /// DBSCAN-style macro-clustering step.
    pub epsilon: f64,
    /// Minimum number of points for an initial group to start as a potential
    /// micro-cluster; also the DBSCAN minimum on centroids.
    pub min_points: usize,
    /// Fading rate (lambda); larger values forget old data faster.
    pub lambda: f64,
    /// Beta: with `mu`, sets the weight (`beta * mu`) below which a potential
    /// micro-cluster is pruned.
    pub beta: f64,
    /// Mu: weight an outlier micro-cluster must reach to be promoted.
    pub mu: f64,
    /// Number of processed points between fading/cleanup passes.
    pub cleanup_interval: u64,
}

impl Default for DenStreamConfig {
    /// epsilon 1.0, min_points 3, lambda 0.25, beta 0.2, mu 1.0, cleanup every 100 points.
    fn default() -> Self {
        let (epsilon, mu) = (1.0, 1.0);
        let (lambda, beta) = (0.25, 0.2);
        DenStreamConfig {
            epsilon,
            min_points: 3,
            lambda,
            beta,
            mu,
            cleanup_interval: 100,
        }
    }
}
389
390/// DenStream: density-based stream clustering.
391///
392/// Maintains potential micro-clusters (p-micro-clusters) and outlier
393/// micro-clusters (o-micro-clusters). Points are absorbed into nearby
394/// p-micro-clusters or create outlier micro-clusters that may be promoted.
395pub struct DenStream<F: Float> {
396    config: DenStreamConfig,
397    /// Potential micro-clusters.
398    p_micro_clusters: Vec<MicroCluster<F>>,
399    /// Outlier micro-clusters.
400    o_micro_clusters: Vec<MicroCluster<F>>,
401    current_time: u64,
402    n_features: usize,
403    initialized: bool,
404}
405
406impl<F: Float + FromPrimitive + Debug> DenStream<F> {
407    /// Create a new DenStream instance.
408    pub fn new(config: DenStreamConfig) -> Self {
409        Self {
410            config,
411            p_micro_clusters: Vec::new(),
412            o_micro_clusters: Vec::new(),
413            current_time: 0,
414            n_features: 0,
415            initialized: false,
416        }
417    }
418
419    /// Initialize with a batch of data.
420    pub fn initialize(&mut self, data: ArrayView2<F>) -> Result<()> {
421        let (n, d) = (data.shape()[0], data.shape()[1]);
422        if n == 0 {
423            return Err(ClusteringError::InvalidInput("Empty data".into()));
424        }
425        self.n_features = d;
426
427        // DBSCAN-like initialization to create initial p-micro-clusters
428        let eps = F::from(self.config.epsilon).unwrap_or_else(|| F::one());
429        let min_pts = self.config.min_points;
430
431        // Simple: group nearby points into micro-clusters
432        let mut assigned = vec![false; n];
433        for i in 0..n {
434            if assigned[i] {
435                continue;
436            }
437            let mut mc = MicroCluster::from_point(data.row(i).as_slice().unwrap_or(&[]), 0);
438            assigned[i] = true;
439            for j in (i + 1)..n {
440                if assigned[j] {
441                    continue;
442                }
443                let dist_sq = row_dist_sq(data.row(i), data.row(j));
444                if dist_sq <= eps * eps {
445                    mc.absorb(data.row(j).as_slice().unwrap_or(&[]), 0);
446                    assigned[j] = true;
447                }
448            }
449            if mc.n_points >= min_pts {
450                self.p_micro_clusters.push(mc);
451            } else {
452                self.o_micro_clusters.push(mc);
453            }
454        }
455
456        self.initialized = true;
457        Ok(())
458    }
459
460    /// Process a single new data point.
461    pub fn process_point(&mut self, point: &[F]) -> Result<()> {
462        if !self.initialized {
463            return Err(ClusteringError::InvalidState(
464                "DenStream not initialized".into(),
465            ));
466        }
467        self.current_time += 1;
468        let lambda_f = F::from(self.config.lambda).unwrap_or_else(|| F::zero());
469        let eps = F::from(self.config.epsilon).unwrap_or_else(|| F::one());
470        let mu_f = F::from(self.config.mu).unwrap_or_else(|| F::one());
471
472        // Try to absorb into nearest p-micro-cluster
473        let (p_idx, p_dist) = nearest_mc_idx(&self.p_micro_clusters, point);
474        if !self.p_micro_clusters.is_empty() && p_dist.sqrt() <= eps {
475            self.p_micro_clusters[p_idx].absorb(point, self.current_time);
476        } else {
477            // Try outlier micro-clusters
478            let (o_idx, o_dist) = nearest_mc_idx(&self.o_micro_clusters, point);
479            if !self.o_micro_clusters.is_empty() && o_dist.sqrt() <= eps {
480                self.o_micro_clusters[o_idx].absorb(point, self.current_time);
481                // Check if outlier should be promoted
482                if self.o_micro_clusters[o_idx].weight >= mu_f {
483                    let promoted = self.o_micro_clusters.remove(o_idx);
484                    self.p_micro_clusters.push(promoted);
485                }
486            } else {
487                // Create new outlier micro-cluster
488                self.o_micro_clusters
489                    .push(MicroCluster::from_point(point, self.current_time));
490            }
491        }
492
493        // Periodic cleanup
494        if self.current_time % self.config.cleanup_interval == 0 {
495            self.cleanup(lambda_f);
496        }
497
498        Ok(())
499    }
500
501    /// Process a batch of points.
502    pub fn process_batch(&mut self, data: ArrayView2<F>) -> Result<()> {
503        for i in 0..data.shape()[0] {
504            self.process_point(data.row(i).as_slice().unwrap_or(&[]))?;
505        }
506        Ok(())
507    }
508
509    /// Get current cluster labels by running DBSCAN on p-micro-cluster centroids.
510    pub fn get_clusters(&self) -> Result<(Array2<F>, Array1<i32>)> {
511        if self.p_micro_clusters.is_empty() {
512            return Err(ClusteringError::InvalidState(
513                "No potential micro-clusters".into(),
514            ));
515        }
516
517        let n = self.p_micro_clusters.len();
518        let d = self.n_features;
519        let mut centroids = Array2::<F>::zeros((n, d));
520        for (i, mc) in self.p_micro_clusters.iter().enumerate() {
521            let c = mc.centroid();
522            for j in 0..d.min(c.len()) {
523                centroids[[i, j]] = c[j];
524            }
525        }
526
527        // Simple DBSCAN on centroids
528        let eps = F::from(self.config.epsilon).unwrap_or_else(|| F::one());
529        let labels = dbscan_on_centroids(&centroids, eps, self.config.min_points);
530
531        Ok((centroids, labels))
532    }
533
534    /// Number of potential micro-clusters.
535    pub fn n_potential(&self) -> usize {
536        self.p_micro_clusters.len()
537    }
538
539    /// Number of outlier micro-clusters.
540    pub fn n_outliers(&self) -> usize {
541        self.o_micro_clusters.len()
542    }
543
544    fn cleanup(&mut self, lambda: F) {
545        let one = F::one();
546        let beta_f = F::from(self.config.beta).unwrap_or_else(|| F::zero());
547        let mu_f = F::from(self.config.mu).unwrap_or_else(|| F::one());
548
549        // Apply fading to all micro-clusters
550        for mc in self.p_micro_clusters.iter_mut() {
551            mc.apply_fading(lambda, one);
552        }
553        for mc in self.o_micro_clusters.iter_mut() {
554            mc.apply_fading(lambda, one);
555        }
556
557        // Remove p-micro-clusters that fell below threshold
558        let threshold = beta_f * mu_f;
559        self.p_micro_clusters.retain(|mc| mc.weight >= threshold);
560
561        // Remove very weak outlier micro-clusters
562        let outlier_threshold = F::from(0.01).unwrap_or_else(|| F::epsilon());
563        self.o_micro_clusters
564            .retain(|mc| mc.weight >= outlier_threshold);
565    }
566}
567
568// ---------------------------------------------------------------------------
569// StreamKM++ (coreset-based)
570// ---------------------------------------------------------------------------
571
/// Tuning parameters for StreamKM++.
#[derive(Debug, Clone)]
pub struct StreamKMConfig {
    /// How many clusters the final weighted k-means produces.
    pub n_clusters: usize,
    /// Buffer size that triggers a merge-and-reduce of the weighted coreset.
    pub coreset_size: usize,
    /// Iteration budget for the final weighted k-means over the coreset.
    pub kmeans_iterations: usize,
}

impl Default for StreamKMConfig {
    /// 5 clusters, a coreset size of 200 and 50 k-means iterations.
    fn default() -> Self {
        let coreset_size = 200;
        StreamKMConfig {
            n_clusters: 5,
            coreset_size,
            kmeans_iterations: 50,
        }
    }
}
592
593/// Coreset point with weight.
594#[derive(Debug, Clone)]
595pub struct CoresetPoint<F: Float> {
596    /// Coordinates.
597    pub coords: Vec<F>,
598    /// Weight (number of original points represented).
599    pub weight: F,
600}
601
602/// StreamKM++: coreset-based streaming k-means.
603///
604/// Maintains a weighted coreset that summarises the stream. When the
605/// buffer overflows, a merge-and-reduce step compresses it back down
606/// using k-means++ seeding to select coreset representatives.
607pub struct StreamKMPlusPlus<F: Float> {
608    config: StreamKMConfig,
609    coreset: Vec<CoresetPoint<F>>,
610    buffer: Vec<Vec<F>>,
611    n_features: usize,
612    initialized: bool,
613}
614
615impl<F: Float + FromPrimitive + Debug> StreamKMPlusPlus<F> {
616    /// Create a new StreamKM++ instance.
617    pub fn new(config: StreamKMConfig) -> Self {
618        Self {
619            config,
620            coreset: Vec::new(),
621            buffer: Vec::new(),
622            n_features: 0,
623            initialized: false,
624        }
625    }
626
627    /// Process a single point from the stream.
628    pub fn process_point(&mut self, point: &[F]) -> Result<()> {
629        if !self.initialized {
630            self.n_features = point.len();
631            self.initialized = true;
632        }
633        self.buffer.push(point.to_vec());
634
635        // When buffer is full, merge-and-reduce
636        if self.buffer.len() >= self.config.coreset_size {
637            self.merge_reduce()?;
638        }
639        Ok(())
640    }
641
642    /// Process a batch of points.
643    pub fn process_batch(&mut self, data: ArrayView2<F>) -> Result<()> {
644        for i in 0..data.shape()[0] {
645            self.process_point(data.row(i).as_slice().unwrap_or(&[]))?;
646        }
647        Ok(())
648    }
649
650    /// Get final cluster centroids and coreset labels.
651    pub fn get_clusters(&self) -> Result<(Array2<F>, Array1<i32>)> {
652        // Combine coreset and buffer into a weighted point set
653        let mut all_points: Vec<(Vec<F>, F)> = Vec::new();
654        for cp in &self.coreset {
655            all_points.push((cp.coords.clone(), cp.weight));
656        }
657        for bp in &self.buffer {
658            all_points.push((bp.clone(), F::one()));
659        }
660
661        if all_points.is_empty() {
662            return Err(ClusteringError::InvalidState(
663                "No data processed yet".into(),
664            ));
665        }
666
667        let n = all_points.len();
668        let d = self.n_features;
669        let k = self.config.n_clusters.min(n);
670
671        // Build matrix
672        let mut mat = Array2::<F>::zeros((n, d));
673        let mut weights = Array1::<F>::zeros(n);
674        for (i, (coords, w)) in all_points.iter().enumerate() {
675            for j in 0..d.min(coords.len()) {
676                mat[[i, j]] = coords[j];
677            }
678            weights[i] = *w;
679        }
680
681        // Weighted k-means
682        let labels = weighted_kmeans(mat.view(), &weights, k, self.config.kmeans_iterations);
683
684        // Compute centroids
685        let mut centroids = Array2::<F>::zeros((k, d));
686        let mut total_weights = vec![F::zero(); k];
687        for i in 0..n {
688            let ci = labels[i] as usize;
689            if ci < k {
690                total_weights[ci] = total_weights[ci] + weights[i];
691                for j in 0..d {
692                    centroids[[ci, j]] = centroids[[ci, j]] + mat[[i, j]] * weights[i];
693                }
694            }
695        }
696        for ci in 0..k {
697            if total_weights[ci] > F::epsilon() {
698                for j in 0..d {
699                    centroids[[ci, j]] = centroids[[ci, j]] / total_weights[ci];
700                }
701            }
702        }
703
704        Ok((centroids, labels))
705    }
706
707    /// Current coreset size (number of weighted representatives).
708    pub fn coreset_size(&self) -> usize {
709        self.coreset.len()
710    }
711
712    fn merge_reduce(&mut self) -> Result<()> {
713        // Combine coreset + buffer into one set, then k-means++ reduce
714        let mut all: Vec<(Vec<F>, F)> = Vec::new();
715        for cp in self.coreset.drain(..) {
716            all.push((cp.coords, cp.weight));
717        }
718        for bp in self.buffer.drain(..) {
719            all.push((bp, F::one()));
720        }
721
722        let target = self.config.coreset_size / 2;
723        if all.len() <= target {
724            for (coords, w) in all {
725                self.coreset.push(CoresetPoint { coords, weight: w });
726            }
727            return Ok(());
728        }
729
730        let n = all.len();
731        let d = self.n_features;
732        let k = target.min(n);
733
734        // Build matrix for k-means
735        let mut mat = Array2::<F>::zeros((n, d));
736        let mut weights = Array1::<F>::zeros(n);
737        for (i, (coords, w)) in all.iter().enumerate() {
738            for j in 0..d.min(coords.len()) {
739                mat[[i, j]] = coords[j];
740            }
741            weights[i] = *w;
742        }
743
744        let labels = weighted_kmeans(mat.view(), &weights, k, 10);
745
746        // Build new coreset from cluster centroids with summed weights
747        for ci in 0..k {
748            let mut sum = vec![F::zero(); d];
749            let mut total_w = F::zero();
750            for i in 0..n {
751                if labels[i] == ci as i32 {
752                    total_w = total_w + weights[i];
753                    for j in 0..d {
754                        sum[j] = sum[j] + mat[[i, j]] * weights[i];
755                    }
756                }
757            }
758            if total_w > F::epsilon() {
759                for j in 0..d {
760                    sum[j] = sum[j] / total_w;
761                }
762                self.coreset.push(CoresetPoint {
763                    coords: sum,
764                    weight: total_w,
765                });
766            }
767        }
768
769        Ok(())
770    }
771}
772
773// ---------------------------------------------------------------------------
774// Sliding Window Clustering
775// ---------------------------------------------------------------------------
776
/// Tuning parameters for sliding window clustering.
#[derive(Debug, Clone)]
pub struct SlidingWindowConfig {
    /// Maximum number of most-recent points retained.
    pub window_size: usize,
    /// How many clusters to form over the window.
    pub n_clusters: usize,
    /// K-means iteration budget for each clustering query.
    pub kmeans_iterations: usize,
}

impl Default for SlidingWindowConfig {
    /// A 1000-point window, 5 clusters and 20 k-means iterations.
    fn default() -> Self {
        let window_size = 1000;
        let kmeans_iterations = 20;
        SlidingWindowConfig {
            window_size,
            n_clusters: 5,
            kmeans_iterations,
        }
    }
}
797
798/// Sliding window clustering: maintains a fixed-size window of the most
799/// recent points and clusters them on demand.
800pub struct SlidingWindowClustering<F: Float> {
801    config: SlidingWindowConfig,
802    window: VecDeque<Vec<F>>,
803    n_features: usize,
804}
805
806impl<F: Float + FromPrimitive + Debug> SlidingWindowClustering<F> {
807    /// Create a new sliding window clustering instance.
808    pub fn new(config: SlidingWindowConfig) -> Self {
809        Self {
810            config,
811            window: VecDeque::new(),
812            n_features: 0,
813        }
814    }
815
816    /// Add a single point to the window.
817    pub fn add_point(&mut self, point: &[F]) {
818        if self.n_features == 0 {
819            self.n_features = point.len();
820        }
821        self.window.push_back(point.to_vec());
822        if self.window.len() > self.config.window_size {
823            self.window.pop_front();
824        }
825    }
826
827    /// Add a batch of points.
828    pub fn add_batch(&mut self, data: ArrayView2<F>) {
829        for i in 0..data.shape()[0] {
830            self.add_point(data.row(i).as_slice().unwrap_or(&[]));
831        }
832    }
833
834    /// Get current clustering of the window contents.
835    pub fn get_clusters(&self) -> Result<(Array2<F>, Array1<i32>)> {
836        if self.window.is_empty() {
837            return Err(ClusteringError::InvalidState("Window is empty".into()));
838        }
839
840        let n = self.window.len();
841        let d = self.n_features;
842        let k = self.config.n_clusters.min(n);
843
844        let mut mat = Array2::<F>::zeros((n, d));
845        for (i, pt) in self.window.iter().enumerate() {
846            for j in 0..d.min(pt.len()) {
847                mat[[i, j]] = pt[j];
848            }
849        }
850
851        let labels = simple_kmeans_init(mat.view(), k);
852
853        // Compute centroids
854        let mut centroids = Array2::<F>::zeros((k, d));
855        let mut counts = vec![0usize; k];
856        for i in 0..n {
857            let ci = labels[i] as usize;
858            if ci < k {
859                counts[ci] += 1;
860                for j in 0..d {
861                    centroids[[ci, j]] = centroids[[ci, j]] + mat[[i, j]];
862                }
863            }
864        }
865        for ci in 0..k {
866            if counts[ci] > 0 {
867                let cnt = F::from(counts[ci]).unwrap_or_else(|| F::one());
868                for j in 0..d {
869                    centroids[[ci, j]] = centroids[[ci, j]] / cnt;
870                }
871            }
872        }
873
874        Ok((centroids, labels))
875    }
876
877    /// Current number of points in the window.
878    pub fn window_len(&self) -> usize {
879        self.window.len()
880    }
881}
882
883// ---------------------------------------------------------------------------
884// Online K-Means with Forgetting Factor
885// ---------------------------------------------------------------------------
886
/// Tuning parameters for online k-means with a forgetting factor.
#[derive(Debug, Clone)]
pub struct OnlineKMeansConfig {
    /// How many centroids to maintain.
    pub n_clusters: usize,
    /// Forgetting factor in (0, 1]; 1.0 means no forgetting (plain online k-means).
    pub forgetting_factor: f64,
    /// When true, use a decaying 1/n_i learning rate; otherwise a constant rate.
    pub adaptive_learning: bool,
}

impl Default for OnlineKMeansConfig {
    /// 5 clusters, forgetting factor 0.99, adaptive learning rate enabled.
    fn default() -> Self {
        let forgetting_factor = 0.99;
        OnlineKMeansConfig {
            n_clusters: 5,
            forgetting_factor,
            adaptive_learning: true,
        }
    }
}
907
908/// Online K-means with exponential forgetting factor.
909///
910/// Maintains cluster centroids that are updated incrementally.
911/// The forgetting factor downweights older contributions, allowing
912/// the algorithm to adapt to concept drift.
913pub struct OnlineKMeans<F: Float> {
914    config: OnlineKMeansConfig,
915    centroids: Option<Array2<F>>,
916    cluster_counts: Vec<F>,
917    n_features: usize,
918    initialized: bool,
919    total_points: usize,
920}
921
922impl<F: Float + FromPrimitive + Debug> OnlineKMeans<F> {
923    /// Create a new online k-means instance.
924    pub fn new(config: OnlineKMeansConfig) -> Self {
925        Self {
926            config,
927            centroids: None,
928            cluster_counts: Vec::new(),
929            n_features: 0,
930            initialized: false,
931            total_points: 0,
932        }
933    }
934
935    /// Initialize with a batch of data (used for seeding centroids).
936    pub fn initialize(&mut self, data: ArrayView2<F>) -> Result<()> {
937        let (n, d) = (data.shape()[0], data.shape()[1]);
938        if n == 0 {
939            return Err(ClusteringError::InvalidInput("Empty data".into()));
940        }
941        self.n_features = d;
942        let k = self.config.n_clusters.min(n);
943
944        let labels = simple_kmeans_init(data, k);
945
946        let mut centroids = Array2::<F>::zeros((k, d));
947        let mut counts = vec![F::zero(); k];
948        for i in 0..n {
949            let ci = labels[i] as usize;
950            if ci < k {
951                counts[ci] = counts[ci] + F::one();
952                for j in 0..d {
953                    centroids[[ci, j]] = centroids[[ci, j]] + data[[i, j]];
954                }
955            }
956        }
957        for ci in 0..k {
958            if counts[ci] > F::epsilon() {
959                for j in 0..d {
960                    centroids[[ci, j]] = centroids[[ci, j]] / counts[ci];
961                }
962            }
963        }
964
965        self.centroids = Some(centroids);
966        self.cluster_counts = counts;
967        self.initialized = true;
968        self.total_points = n;
969        Ok(())
970    }
971
972    /// Process a single new point.
973    pub fn process_point(&mut self, point: &[F]) -> Result<i32> {
974        if !self.initialized {
975            return Err(ClusteringError::InvalidState(
976                "OnlineKMeans not initialized".into(),
977            ));
978        }
979
980        let centroids = self
981            .centroids
982            .as_ref()
983            .ok_or_else(|| ClusteringError::InvalidState("No centroids".into()))?;
984
985        let k = centroids.shape()[0];
986        let d = centroids.shape()[1];
987        let ff = F::from(self.config.forgetting_factor).unwrap_or_else(|| F::one());
988
989        // Find nearest centroid
990        let mut best_ci = 0;
991        let mut best_dist = F::infinity();
992        for ci in 0..k {
993            let mut dist = F::zero();
994            for j in 0..d.min(point.len()) {
995                let diff = point[j] - centroids[[ci, j]];
996                dist = dist + diff * diff;
997            }
998            if dist < best_dist {
999                best_dist = dist;
1000                best_ci = ci;
1001            }
1002        }
1003
1004        // Apply forgetting factor to all cluster counts
1005        for ci in 0..k {
1006            self.cluster_counts[ci] = self.cluster_counts[ci] * ff;
1007        }
1008
1009        // Update centroid
1010        self.cluster_counts[best_ci] = self.cluster_counts[best_ci] + F::one();
1011        let eta = if self.config.adaptive_learning {
1012            F::one() / self.cluster_counts[best_ci]
1013        } else {
1014            F::from(0.01).unwrap_or_else(|| F::epsilon())
1015        };
1016
1017        let centroids_mut = self
1018            .centroids
1019            .as_mut()
1020            .ok_or_else(|| ClusteringError::InvalidState("No centroids".into()))?;
1021        for j in 0..d.min(point.len()) {
1022            centroids_mut[[best_ci, j]] =
1023                centroids_mut[[best_ci, j]] + eta * (point[j] - centroids_mut[[best_ci, j]]);
1024        }
1025
1026        self.total_points += 1;
1027        Ok(best_ci as i32)
1028    }
1029
1030    /// Process a batch and return labels.
1031    pub fn process_batch(&mut self, data: ArrayView2<F>) -> Result<Array1<i32>> {
1032        let n = data.shape()[0];
1033        let mut labels = Array1::from_elem(n, -1i32);
1034        for i in 0..n {
1035            labels[i] = self.process_point(data.row(i).as_slice().unwrap_or(&[]))?;
1036        }
1037        Ok(labels)
1038    }
1039
1040    /// Get current centroids.
1041    pub fn centroids(&self) -> Option<&Array2<F>> {
1042        self.centroids.as_ref()
1043    }
1044
1045    /// Predict cluster for a point without updating.
1046    pub fn predict(&self, point: &[F]) -> Result<i32> {
1047        let centroids = self
1048            .centroids
1049            .as_ref()
1050            .ok_or_else(|| ClusteringError::InvalidState("Not initialized".into()))?;
1051        let k = centroids.shape()[0];
1052        let d = centroids.shape()[1];
1053        let mut best_ci = 0i32;
1054        let mut best_dist = F::infinity();
1055        for ci in 0..k {
1056            let mut dist = F::zero();
1057            for j in 0..d.min(point.len()) {
1058                let diff = point[j] - centroids[[ci, j]];
1059                dist = dist + diff * diff;
1060            }
1061            if dist < best_dist {
1062                best_dist = dist;
1063                best_ci = ci as i32;
1064            }
1065        }
1066        Ok(best_ci)
1067    }
1068}
1069
1070// ---------------------------------------------------------------------------
1071// Helper functions
1072// ---------------------------------------------------------------------------
1073
1074/// Squared distance between two array rows.
1075fn row_dist_sq<F: Float>(a: ArrayView1<F>, b: ArrayView1<F>) -> F {
1076    let mut s = F::zero();
1077    for i in 0..a.len().min(b.len()) {
1078        let diff = a[i] - b[i];
1079        s = s + diff * diff;
1080    }
1081    s
1082}
1083
1084/// Find the nearest micro-cluster to a point; returns (index, sq distance).
1085fn nearest_mc_idx<F: Float + FromPrimitive + Debug>(
1086    clusters: &[MicroCluster<F>],
1087    point: &[F],
1088) -> (usize, F) {
1089    let mut best = 0;
1090    let mut best_d = F::infinity();
1091    for (i, mc) in clusters.iter().enumerate() {
1092        let d = mc.distance_sq_to(point);
1093        if d < best_d {
1094            best_d = d;
1095            best = i;
1096        }
1097    }
1098    (best, best_d)
1099}
1100
1101/// Simple DBSCAN on a centroid matrix (for DenStream macro-clustering).
1102fn dbscan_on_centroids<F: Float + FromPrimitive + Debug>(
1103    centroids: &Array2<F>,
1104    eps: F,
1105    min_pts: usize,
1106) -> Array1<i32> {
1107    let n = centroids.shape()[0];
1108    let eps_sq = eps * eps;
1109    let mut labels = vec![-2i32; n]; // -2 = undefined, -1 = noise
1110    let mut cluster_id = 0i32;
1111
1112    for i in 0..n {
1113        if labels[i] != -2 {
1114            continue;
1115        }
1116        let neighbors: Vec<usize> = (0..n)
1117            .filter(|&j| {
1118                let d = row_dist_sq(centroids.row(i), centroids.row(j));
1119                d <= eps_sq
1120            })
1121            .collect();
1122
1123        if neighbors.len() < min_pts {
1124            labels[i] = -1;
1125            continue;
1126        }
1127
1128        labels[i] = cluster_id;
1129        let mut queue = neighbors.clone();
1130        let mut head = 0usize;
1131        while head < queue.len() {
1132            let cur = queue[head];
1133            head += 1;
1134            if labels[cur] == -1 {
1135                labels[cur] = cluster_id;
1136                continue;
1137            }
1138            if labels[cur] != -2 {
1139                continue;
1140            }
1141            labels[cur] = cluster_id;
1142
1143            let cur_neighbors: Vec<usize> = (0..n)
1144                .filter(|&j| {
1145                    let d = row_dist_sq(centroids.row(cur), centroids.row(j));
1146                    d <= eps_sq
1147                })
1148                .collect();
1149
1150            if cur_neighbors.len() >= min_pts {
1151                for nb in cur_neighbors {
1152                    if labels[nb] == -2 || labels[nb] == -1 {
1153                        queue.push(nb);
1154                    }
1155                }
1156            }
1157        }
1158        cluster_id += 1;
1159    }
1160
1161    Array1::from_vec(labels)
1162}
1163
1164/// Simple k-means for initialization purposes.
1165fn simple_kmeans_init<F: Float + FromPrimitive + Debug>(
1166    data: ArrayView2<F>,
1167    k: usize,
1168) -> Array1<i32> {
1169    let (n, d) = (data.shape()[0], data.shape()[1]);
1170    if n == 0 || k == 0 {
1171        return Array1::from_elem(n, 0i32);
1172    }
1173    let k = k.min(n);
1174
1175    // K-means++ style init
1176    let mut centroids = Array2::<F>::zeros((k, d));
1177    centroids.row_mut(0).assign(&data.row(0));
1178
1179    for ci in 1..k {
1180        let mut best_idx = 0;
1181        let mut best_dist = F::zero();
1182        for i in 0..n {
1183            let mut min_d = F::infinity();
1184            for prev in 0..ci {
1185                let d = row_dist_sq(data.row(i), centroids.row(prev));
1186                if d < min_d {
1187                    min_d = d;
1188                }
1189            }
1190            if min_d > best_dist {
1191                best_dist = min_d;
1192                best_idx = i;
1193            }
1194        }
1195        centroids.row_mut(ci).assign(&data.row(best_idx));
1196    }
1197
1198    let mut labels = Array1::from_elem(n, 0i32);
1199    for _ in 0..20 {
1200        // Assign
1201        let mut changed = false;
1202        for i in 0..n {
1203            let mut best_ci = 0i32;
1204            let mut best_d = F::infinity();
1205            for ci in 0..k {
1206                let d = row_dist_sq(data.row(i), centroids.row(ci));
1207                if d < best_d {
1208                    best_d = d;
1209                    best_ci = ci as i32;
1210                }
1211            }
1212            if labels[i] != best_ci {
1213                labels[i] = best_ci;
1214                changed = true;
1215            }
1216        }
1217        if !changed {
1218            break;
1219        }
1220        // Update
1221        let mut counts = vec![0usize; k];
1222        let mut sums = Array2::<F>::zeros((k, d));
1223        for i in 0..n {
1224            let ci = labels[i] as usize;
1225            counts[ci] += 1;
1226            for j in 0..d {
1227                sums[[ci, j]] = sums[[ci, j]] + data[[i, j]];
1228            }
1229        }
1230        for ci in 0..k {
1231            if counts[ci] > 0 {
1232                let cnt = F::from(counts[ci]).unwrap_or_else(|| F::one());
1233                for j in 0..d {
1234                    centroids[[ci, j]] = sums[[ci, j]] / cnt;
1235                }
1236            }
1237        }
1238    }
1239    labels
1240}
1241
1242/// Weighted k-means.
1243fn weighted_kmeans<F: Float + FromPrimitive + Debug>(
1244    data: ArrayView2<F>,
1245    weights: &Array1<F>,
1246    k: usize,
1247    max_iter: usize,
1248) -> Array1<i32> {
1249    let (n, d) = (data.shape()[0], data.shape()[1]);
1250    if n == 0 || k == 0 {
1251        return Array1::from_elem(n, 0i32);
1252    }
1253    let k = k.min(n);
1254
1255    // Init: pick k spread points
1256    let mut centroids = Array2::<F>::zeros((k, d));
1257    let step = (n as f64 / k as f64).max(1.0);
1258    for ci in 0..k {
1259        let idx = ((ci as f64 * step) as usize).min(n - 1);
1260        centroids.row_mut(ci).assign(&data.row(idx));
1261    }
1262
1263    let mut labels = Array1::from_elem(n, 0i32);
1264    for _ in 0..max_iter {
1265        let mut changed = false;
1266        for i in 0..n {
1267            let mut best = 0i32;
1268            let mut best_d = F::infinity();
1269            for ci in 0..k {
1270                let dd = row_dist_sq(data.row(i), centroids.row(ci));
1271                if dd < best_d {
1272                    best_d = dd;
1273                    best = ci as i32;
1274                }
1275            }
1276            if labels[i] != best {
1277                labels[i] = best;
1278                changed = true;
1279            }
1280        }
1281        if !changed {
1282            break;
1283        }
1284
1285        let mut sums = Array2::<F>::zeros((k, d));
1286        let mut total_w = vec![F::zero(); k];
1287        for i in 0..n {
1288            let ci = labels[i] as usize;
1289            total_w[ci] = total_w[ci] + weights[i];
1290            for j in 0..d {
1291                sums[[ci, j]] = sums[[ci, j]] + data[[i, j]] * weights[i];
1292            }
1293        }
1294        for ci in 0..k {
1295            if total_w[ci] > F::epsilon() {
1296                for j in 0..d {
1297                    centroids[[ci, j]] = sums[[ci, j]] / total_w[ci];
1298                }
1299            }
1300        }
1301    }
1302    labels
1303}
1304
1305// ---------------------------------------------------------------------------
1306// Tests
1307// ---------------------------------------------------------------------------
1308
#[cfg(test)]
mod tests {
    // Unit tests for the streaming clustering algorithms and their helpers.
    use super::*;
    use scirs2_core::ndarray::Array2;

    fn make_stream_data() -> Array2<f64> {
        let mut data = Vec::new();
        // Cluster A around (1, 1)
        for i in 0..30 {
            let noise = (i as f64 * 0.073).sin() * 0.3;
            data.push(1.0 + noise);
            data.push(1.0 + noise * 0.7);
        }
        // Cluster B around (5, 5)
        for i in 0..30 {
            let noise = (i as f64 * 0.131).sin() * 0.3;
            data.push(5.0 + noise);
            data.push(5.0 + noise * 0.7);
        }
        Array2::from_shape_vec((60, 2), data).expect("shape failed")
    }

    // -- MicroCluster tests --

    #[test]
    fn test_micro_cluster_from_point() {
        let mc = MicroCluster::<f64>::from_point(&[1.0, 2.0, 3.0], 0);
        assert_eq!(mc.n_points, 1);
        let c = mc.centroid();
        assert!((c[0] - 1.0).abs() < 1e-10);
        assert!((c[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_micro_cluster_absorb() {
        let mut mc = MicroCluster::<f64>::from_point(&[1.0, 2.0], 0);
        mc.absorb(&[3.0, 4.0], 1);
        assert_eq!(mc.n_points, 2);
        let c = mc.centroid();
        assert!((c[0] - 2.0).abs() < 1e-10);
        assert!((c[1] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_micro_cluster_merge() {
        let mut mc1 = MicroCluster::<f64>::from_point(&[1.0, 1.0], 0);
        let mc2 = MicroCluster::<f64>::from_point(&[3.0, 3.0], 1);
        mc1.merge(&mc2);
        assert_eq!(mc1.n_points, 2);
        let c = mc1.centroid();
        assert!((c[0] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_micro_cluster_radius() {
        let mut mc = MicroCluster::<f64>::from_point(&[0.0, 0.0], 0);
        mc.absorb(&[2.0, 0.0], 1);
        mc.absorb(&[0.0, 2.0], 2);
        let r = mc.radius();
        // Should be > 0 for spread-out points
        assert!(r > 0.0);
    }

    #[test]
    fn test_micro_cluster_fading() {
        let mut mc = MicroCluster::<f64>::from_point(&[1.0, 2.0], 0);
        mc.absorb(&[1.0, 2.0], 1);
        let w_before = mc.weight;
        mc.apply_fading(0.5, 1.0);
        assert!(mc.weight < w_before);
    }

    // -- CluStream tests --

    #[test]
    fn test_clustream_basic() {
        let data = make_stream_data();
        let config = CluStreamConfig {
            max_micro_clusters: 20,
            n_macro_clusters: 2,
            ..Default::default()
        };
        let mut cs = CluStream::new(config);
        let init_data = data.slice(scirs2_core::ndarray::s![0..20, ..]);
        cs.initialize(init_data).expect("init failed");

        // Process remaining points
        for i in 20..60 {
            cs.process_point(data.row(i).as_slice().unwrap_or(&[]))
                .expect("process failed");
        }

        assert!(cs.n_micro_clusters() > 0);
        let (_, labels) = cs.get_macro_clusters().expect("macro failed");
        assert_eq!(labels.len(), cs.n_micro_clusters());
    }

    #[test]
    fn test_clustream_empty_init() {
        let data = Array2::<f64>::zeros((0, 3));
        let config = CluStreamConfig::default();
        let mut cs = CluStream::new(config);
        assert!(cs.initialize(data.view()).is_err());
    }

    #[test]
    fn test_clustream_not_initialized() {
        let cs = CluStream::<f64>::new(CluStreamConfig::default());
        assert!(cs.get_macro_clusters().is_err());
    }

    // -- DenStream tests --

    #[test]
    fn test_denstream_basic() {
        let data = make_stream_data();
        let config = DenStreamConfig {
            epsilon: 2.0,
            min_points: 2,
            lambda: 0.1,
            ..Default::default()
        };
        let mut ds = DenStream::new(config);
        let init_data = data.slice(scirs2_core::ndarray::s![0..30, ..]);
        ds.initialize(init_data).expect("init failed");

        for i in 30..60 {
            ds.process_point(data.row(i).as_slice().unwrap_or(&[]))
                .expect("process failed");
        }

        assert!(ds.n_potential() > 0);
        let result = ds.get_clusters();
        assert!(result.is_ok());
    }

    #[test]
    fn test_denstream_empty_init() {
        let data = Array2::<f64>::zeros((0, 2));
        let config = DenStreamConfig::default();
        let mut ds = DenStream::new(config);
        assert!(ds.initialize(data.view()).is_err());
    }

    // -- StreamKM++ tests --

    #[test]
    fn test_streamkm_basic() {
        let data = make_stream_data();
        let config = StreamKMConfig {
            n_clusters: 2,
            coreset_size: 20,
            kmeans_iterations: 20,
        };
        let mut skm = StreamKMPlusPlus::new(config);
        skm.process_batch(data.view()).expect("batch failed");
        let (_, labels) = skm.get_clusters().expect("clusters failed");
        assert_eq!(labels.len(), skm.coreset_size() + skm.buffer.len());
    }

    #[test]
    fn test_streamkm_single_point() {
        let config = StreamKMConfig {
            n_clusters: 1,
            coreset_size: 100,
            ..Default::default()
        };
        let mut skm = StreamKMPlusPlus::<f64>::new(config);
        skm.process_point(&[1.0, 2.0]).expect("failed");
        let (_, labels) = skm.get_clusters().expect("clusters failed");
        assert_eq!(labels.len(), 1);
    }

    // -- Sliding Window tests --

    #[test]
    fn test_sliding_window_basic() {
        let data = make_stream_data();
        let config = SlidingWindowConfig {
            window_size: 50,
            n_clusters: 2,
            kmeans_iterations: 20,
        };
        let mut sw = SlidingWindowClustering::new(config);
        sw.add_batch(data.view());
        assert_eq!(sw.window_len(), 50); // capped at window_size
        let (_, labels) = sw.get_clusters().expect("clusters failed");
        assert_eq!(labels.len(), 50);
    }

    #[test]
    fn test_sliding_window_empty() {
        let sw = SlidingWindowClustering::<f64>::new(SlidingWindowConfig::default());
        assert!(sw.get_clusters().is_err());
    }

    #[test]
    fn test_sliding_window_overflow() {
        let config = SlidingWindowConfig {
            window_size: 5,
            n_clusters: 2,
            ..Default::default()
        };
        let mut sw = SlidingWindowClustering::<f64>::new(config);
        for i in 0..10 {
            sw.add_point(&[i as f64, i as f64 * 2.0]);
        }
        assert_eq!(sw.window_len(), 5);
    }

    // -- Online K-Means tests --

    #[test]
    fn test_online_kmeans_basic() {
        let data = make_stream_data();
        let config = OnlineKMeansConfig {
            n_clusters: 2,
            forgetting_factor: 0.99,
            adaptive_learning: true,
        };
        let mut okm = OnlineKMeans::new(config);
        let init_data = data.slice(scirs2_core::ndarray::s![0..20, ..]);
        okm.initialize(init_data).expect("init failed");

        let labels = okm
            .process_batch(data.slice(scirs2_core::ndarray::s![20..60, ..]))
            .expect("batch failed");
        assert_eq!(labels.len(), 40);

        // Predict should work
        let pred = okm.predict(&[1.0, 1.0]).expect("predict failed");
        assert!(pred >= 0);
    }

    #[test]
    fn test_online_kmeans_not_init() {
        let okm = OnlineKMeans::<f64>::new(OnlineKMeansConfig::default());
        assert!(okm.predict(&[1.0]).is_err());
    }

    #[test]
    fn test_online_kmeans_process_point_requires_init() {
        let mut okm = OnlineKMeans::<f64>::new(OnlineKMeansConfig::default());
        assert!(okm.process_point(&[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_online_kmeans_predict_distinguishes_clusters() {
        let data = make_stream_data();
        let config = OnlineKMeansConfig {
            n_clusters: 2,
            ..Default::default()
        };
        let mut okm = OnlineKMeans::<f64>::new(config);
        okm.initialize(data.view()).expect("init failed");
        let a = okm.predict(&[1.0, 1.0]).expect("predict failed");
        let b = okm.predict(&[5.0, 5.0]).expect("predict failed");
        assert_ne!(a, b);
    }

    #[test]
    fn test_online_kmeans_forgetting() {
        let config = OnlineKMeansConfig {
            n_clusters: 2,
            forgetting_factor: 0.5, // aggressive forgetting
            adaptive_learning: true,
        };
        let mut okm = OnlineKMeans::<f64>::new(config);
        let init = Array2::from_shape_vec((10, 2), (0..20).map(|i| (i as f64) * 0.1).collect())
            .expect("shape failed");
        okm.initialize(init.view()).expect("init failed");

        // Feed many points from a different region
        for _ in 0..50 {
            let _ = okm.process_point(&[10.0, 10.0]);
        }

        // Centroids should have drifted toward (10, 10)
        let centroids = okm.centroids().expect("no centroids");
        let mut any_close = false;
        for ci in 0..centroids.shape()[0] {
            if (centroids[[ci, 0]] - 10.0).abs() < 3.0 {
                any_close = true;
            }
        }
        assert!(any_close, "Expected centroids to drift toward new data");
    }

    // -- Helper function tests --

    #[test]
    fn test_row_dist_sq() {
        let a = Array1::from_vec(vec![1.0, 2.0]);
        let b = Array1::from_vec(vec![4.0, 6.0]);
        let d = row_dist_sq(a.view(), b.view());
        assert!((d - 25.0).abs() < 1e-10);
    }

    #[test]
    fn test_row_dist_sq_ignores_extra_coordinates() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![1.0, 2.0]);
        let d = row_dist_sq(a.view(), b.view());
        assert!(d.abs() < 1e-12);
    }

    #[test]
    fn test_dbscan_on_centroids() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 1.0, 1.1, 1.1, 1.2, 1.2, 5.0, 5.0, 5.1, 5.1, 5.2, 5.2],
        )
        .expect("shape failed");
        let labels = dbscan_on_centroids(&data, 0.5, 2);
        // Points 0-2 should be one cluster, 3-5 another
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[1], labels[2]);
        assert_eq!(labels[3], labels[4]);
        assert_eq!(labels[4], labels[5]);
        assert_ne!(labels[0], labels[3]);
    }

    #[test]
    fn test_dbscan_on_centroids_marks_noise() {
        let data = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 0.1, 0.1, 10.0, 10.0])
            .expect("shape failed");
        let labels = dbscan_on_centroids(&data, 0.5, 2);
        assert_eq!(labels.to_vec(), vec![0, 0, -1]);
    }

    #[test]
    fn test_simple_kmeans_init() {
        let data = make_stream_data();
        let labels = simple_kmeans_init(data.view(), 2);
        assert_eq!(labels.len(), 60);
        // Should have at least 2 distinct labels
        let unique: std::collections::HashSet<i32> = labels.iter().copied().collect();
        assert!(unique.len() >= 2);
    }

    #[test]
    fn test_simple_kmeans_init_separates_clusters() {
        let data = make_stream_data();
        let labels = simple_kmeans_init(data.view(), 2);
        assert!(labels.iter().take(30).all(|&l| l == labels[0]));
        assert!(labels.iter().skip(30).all(|&l| l == labels[30]));
        assert_ne!(labels[0], labels[30]);
    }

    #[test]
    fn test_simple_kmeans_init_zero_k() {
        let data = make_stream_data();
        let labels = simple_kmeans_init(data.view(), 0);
        assert_eq!(labels.len(), 60);
        assert!(labels.iter().all(|&l| l == 0));
    }

    #[test]
    fn test_weighted_kmeans_separates_clusters() {
        let data = make_stream_data();
        let weights = Array1::from_elem(60, 1.0);
        let labels = weighted_kmeans(data.view(), &weights, 2, 20);
        assert_eq!(labels.len(), 60);
        assert!(labels.iter().take(30).all(|&l| l == labels[0]));
        assert!(labels.iter().skip(30).all(|&l| l == labels[30]));
        assert_ne!(labels[0], labels[30]);
    }

    #[test]
    fn test_clustream_batch() {
        let data = make_stream_data();
        let config = CluStreamConfig {
            max_micro_clusters: 10,
            n_macro_clusters: 2,
            ..Default::default()
        };
        let mut cs = CluStream::new(config);
        cs.initialize(data.slice(scirs2_core::ndarray::s![0..20, ..]))
            .expect("init");
        cs.process_batch(data.slice(scirs2_core::ndarray::s![20..60, ..]))
            .expect("batch");
        assert!(cs.n_micro_clusters() >= 2);
    }

    #[test]
    fn test_streamkm_coreset_reduces() {
        let config = StreamKMConfig {
            n_clusters: 2,
            coreset_size: 10,
            ..Default::default()
        };
        let mut skm = StreamKMPlusPlus::<f64>::new(config);
        // Feed more than coreset_size points
        for i in 0..30 {
            skm.process_point(&[i as f64, (i * 2) as f64])
                .expect("fail");
        }
        // Coreset should have been compressed
        assert!(skm.coreset_size() > 0);
    }
}