sklears_clustering/
streaming.rs

//! Streaming Clustering Algorithms
//!
//! This module provides implementations of clustering algorithms designed for
//! real-time data processing and online learning scenarios. These algorithms
//! can process data points one at a time or in small batches, making them
//! suitable for applications with memory constraints or continuous data streams.
//!
//! # Features
//!
//! - **Online K-Means**: Incremental version of K-means for streaming data
//! - **Streaming DBSCAN**: Density-based clustering for continuous data streams
//! - **CluStream**: Stream clustering algorithm with micro-clusters and macro-clusters
//! - **DenStream**: Density-based stream clustering with outlier detection
//! - **Sliding Window Clustering**: Time-window based clustering for temporal data
//!
//! # Mathematical Background
//!
//! Streaming clustering algorithms typically maintain:
//! - Summary statistics (centroids, weights, timestamps)
//! - Micro-clusters: Fine-grained cluster summaries
//! - Macro-clusters: High-level cluster representations
//! - Aging mechanisms: To handle concept drift and temporal relevance
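//!
//! # Examples
//!
//! A minimal usage sketch of the streaming workflow: fit once on an initial
//! batch, then feed new points with `partial_fit` and assign labels with
//! `predict`. The crate path `sklears_clustering::streaming` is assumed from
//! the file location, so the example is marked `ignore`.
//!
//! ```ignore
//! use scirs2_core::ndarray::{array, Array1};
//! use sklears_core::traits::{Fit, Predict};
//! use sklears_clustering::streaming::OnlineKMeans;
//!
//! // Fit on an initial batch (the label argument is ignored by the clusterer).
//! let x = array![[0.0, 0.0], [0.1, 0.1], [1.0, 1.0], [1.1, 1.1]];
//! let y: Array1<usize> = Array1::zeros(4);
//! let mut model = OnlineKMeans::new()
//!     .max_clusters(2)
//!     .learning_rate(0.1)
//!     .fit(&x.view(), &y.view())
//!     .unwrap();
//!
//! // Stream in a new observation, updating the centroids in place.
//! model.partial_fit(&array![0.95, 1.05].view()).unwrap();
//!
//! // Assign cluster labels to a small batch.
//! let labels = model.predict(&array![[0.05, 0.05], [1.05, 0.95]].view()).unwrap();
//! ```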

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::Random;
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Estimator, Fit, Predict, Trained, Untrained},
    types::Float,
};
use std::collections::VecDeque;
use std::marker::PhantomData;

/// Configuration for streaming clustering algorithms
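///
/// # Examples
///
/// A minimal sketch of overriding a few fields while keeping the remaining
/// defaults; the crate path is assumed, so the example is marked `ignore`.
///
/// ```ignore
/// use sklears_clustering::streaming::StreamingConfig;
///
/// let config = StreamingConfig {
///     max_clusters: 50,
///     window_size: 500,
///     ..StreamingConfig::default()
/// };
/// assert_eq!(config.max_clusters, 50);
/// ```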
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Maximum number of clusters to maintain
    pub max_clusters: usize,
    /// Learning rate for incremental updates
    pub learning_rate: Float,
    /// Decay factor for aging mechanism
    pub decay_factor: Float,
    /// Window size for sliding window approaches
    pub window_size: usize,
    /// Threshold for creating new clusters
    pub creation_threshold: Float,
    /// Threshold for merging clusters
    pub merge_threshold: Float,
    /// Minimum cluster weight to maintain
    pub min_weight: Float,
    /// Random seed for reproducibility
    pub random_state: Option<u64>,
    /// Update frequency for macro-cluster recalculation
    pub update_frequency: usize,
}

impl Default for StreamingConfig {
    fn default() -> Self {
        Self {
            max_clusters: 100,
            learning_rate: 0.1,
            decay_factor: 0.95,
            window_size: 1000,
            creation_threshold: 1.0,
            merge_threshold: 0.5,
            min_weight: 0.01,
            random_state: None,
            update_frequency: 10,
        }
    }
}

/// Micro-cluster data structure for streaming algorithms
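///
/// # Examples
///
/// A minimal sketch of the micro-cluster lifecycle: create from a point,
/// absorb a second point, then age the summary with `decay`. Only items from
/// this module are used; the crate path is assumed, so the example is marked
/// `ignore`.
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_clustering::streaming::MicroCluster;
///
/// let mut mc = MicroCluster::new(&array![1.0, 2.0].view(), 0);
/// mc.update(&array![3.0, 4.0].view(), 1, 0.5); // pulls the centroid toward the new point
/// mc.decay(0.95);                              // ages the weight and spread statistics
/// assert!(mc.weight > 1.0 && mc.weight < 2.0);
/// ```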
#[derive(Debug, Clone)]
pub struct MicroCluster {
    /// Cluster centroid
    pub centroid: Array1<Float>,
    /// Cluster weight (number of points)
    pub weight: Float,
    /// Sum of squared distances to centroid
    pub sum_squared: Float,
    /// Creation timestamp
    pub creation_time: usize,
    /// Last update timestamp
    pub last_update: usize,
    /// Cluster radius
    pub radius: Float,
}

impl MicroCluster {
    /// Create a new micro-cluster from a single point
    pub fn new(point: &ArrayView1<Float>, timestamp: usize) -> Self {
        Self {
            centroid: point.to_owned(),
            weight: 1.0,
            sum_squared: 0.0,
            creation_time: timestamp,
            last_update: timestamp,
            radius: 0.0,
        }
    }

    /// Update the micro-cluster with a new point
    pub fn update(&mut self, point: &ArrayView1<Float>, timestamp: usize, learning_rate: Float) {
        let distance = self.distance_to_centroid(point);

        // Incremental centroid update
        let weight_factor = learning_rate / (self.weight + 1.0);
        let diff = point - &self.centroid;
        self.centroid = &self.centroid + &(&diff * weight_factor);

        // Update statistics
        self.weight += 1.0;
        self.sum_squared += distance * distance;
        self.last_update = timestamp;

        // Update radius
        self.radius = (self.sum_squared / self.weight.max(1.0)).sqrt();
    }

    /// Calculate distance from point to cluster centroid
    pub fn distance_to_centroid(&self, point: &ArrayView1<Float>) -> Float {
        let diff = point - &self.centroid;
        diff.dot(&diff).sqrt()
    }

    /// Apply decay factor to the cluster
    pub fn decay(&mut self, decay_factor: Float) {
        self.weight *= decay_factor;
        self.sum_squared *= decay_factor;
    }

    /// Check if the cluster should be removed due to low weight
    pub fn should_remove(&self, min_weight: Float) -> bool {
        self.weight < min_weight
    }

    /// Calculate cluster density
    pub fn density(&self) -> Float {
        if self.radius > 0.0 {
            self.weight / (self.radius * self.radius)
        } else {
            self.weight
        }
    }
}

/// Online K-Means clustering for streaming data
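///
/// # Examples
///
/// A minimal sketch (marked `ignore` because the crate path is assumed): fit on
/// a small batch, then stream one extra point.
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array1};
/// use sklears_core::traits::Fit;
/// use sklears_clustering::streaming::OnlineKMeans;
///
/// let x = array![[0.0, 0.0], [1.0, 1.0]];
/// let y: Array1<usize> = Array1::zeros(2);
/// let mut model = OnlineKMeans::new()
///     .max_clusters(2)
///     .learning_rate(0.1)
///     .fit(&x.view(), &y.view())
///     .unwrap();
/// model.partial_fit(&array![0.5, 0.4].view()).unwrap();
/// assert_eq!(model.centroids().unwrap().nrows(), 2);
/// ```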
pub struct OnlineKMeans<State = Untrained> {
    config: StreamingConfig,
    state: PhantomData<State>,
    // Trained state fields
    centroids: Option<Array2<Float>>,
    weights: Option<Array1<Float>>,
    n_updates: usize,
    timestamp: usize,
}

impl<State> OnlineKMeans<State> {
    /// Create a new Online K-Means instance
    pub fn new() -> Self {
        Self {
            config: StreamingConfig::default(),
            state: PhantomData,
            centroids: None,
            weights: None,
            n_updates: 0,
            timestamp: 0,
        }
    }

    /// Set the maximum number of clusters
    pub fn max_clusters(mut self, max_clusters: usize) -> Self {
        self.config.max_clusters = max_clusters;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
        self.config.learning_rate = learning_rate;
        self
    }

    /// Set the decay factor
    pub fn decay_factor(mut self, decay_factor: Float) -> Self {
        self.config.decay_factor = decay_factor;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.config.random_state = Some(seed);
        self
    }
}

impl OnlineKMeans<Trained> {
    /// Process a new data point online
    pub fn partial_fit(&mut self, point: &ArrayView1<Float>) -> Result<()> {
        if let (Some(ref mut centroids), Some(ref mut weights)) =
            (&mut self.centroids, &mut self.weights)
        {
            // Find closest centroid
            let mut min_distance = Float::INFINITY;
            let mut closest_idx = 0;

            for (i, centroid) in centroids.outer_iter().enumerate() {
                let diff = point - &centroid;
                let distance = diff.dot(&diff).sqrt();
                if distance < min_distance {
                    min_distance = distance;
                    closest_idx = i;
                }
            }

            // Update closest centroid
            let lr = self.config.learning_rate / (weights[closest_idx] + 1.0);
            let diff = point - &centroids.row(closest_idx);
            let mut new_centroid = centroids.row(closest_idx).to_owned();
            new_centroid = &new_centroid + &(&diff * lr);
            centroids.row_mut(closest_idx).assign(&new_centroid);

            // Update weight
            weights[closest_idx] += 1.0;

            // Apply decay to all clusters
            for weight in weights.iter_mut() {
                *weight *= self.config.decay_factor;
            }

            self.n_updates += 1;
            self.timestamp += 1;

            Ok(())
        } else {
            Err(SklearsError::NotFitted {
                operation: "partial_fit".to_string(),
            })
        }
    }

    /// Get current centroids
    pub fn centroids(&self) -> Result<&Array2<Float>> {
        self.centroids
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "centroids".to_string(),
            })
    }

    /// Get current weights
    pub fn weights(&self) -> Result<&Array1<Float>> {
        self.weights
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "weights".to_string(),
            })
    }
}

impl<State> Default for OnlineKMeans<State> {
    fn default() -> Self {
        Self::new()
    }
}

impl<State> Estimator<State> for OnlineKMeans<State> {
    type Config = StreamingConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, usize>> for OnlineKMeans<Untrained> {
    type Fitted = OnlineKMeans<Trained>;

    fn fit(self, x: &ArrayView2<Float>, _y: &ArrayView1<usize>) -> Result<Self::Fitted> {
        let (n_samples, n_features) = x.dim();
        let k = self.config.max_clusters.min(n_samples);

        if n_samples == 0 || n_features == 0 {
            return Err(SklearsError::InvalidInput("Empty input data".to_string()));
        }

        // NOTE: a seeded constructor is not wired up here yet, so both branches
        // fall back to the default generator; `_seed` avoids an unused-variable warning.
        let mut rng = match self.config.random_state {
            Some(_seed) => Random::default(),
            None => Random::default(),
        };

        // Initialize centroids with random samples
        let mut centroids = Array2::zeros((k, n_features));
        let weights = Array1::ones(k);

        for i in 0..k {
            let idx = rng.gen_range(0..n_samples);
            centroids.row_mut(i).assign(&x.row(idx));
        }

        Ok(OnlineKMeans {
            config: self.config,
            state: PhantomData,
            centroids: Some(centroids),
            weights: Some(weights),
            n_updates: 0,
            timestamp: 0,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<usize>> for OnlineKMeans<Trained> {
    fn predict(&self, x: &ArrayView2<Float>) -> Result<Array1<usize>> {
        let centroids = self.centroids()?;
        let mut labels = Array1::zeros(x.nrows());

        for (i, sample) in x.outer_iter().enumerate() {
            let mut min_distance = Float::INFINITY;
            let mut best_cluster = 0;

            for (j, centroid) in centroids.outer_iter().enumerate() {
                let diff = &sample - &centroid;
                let distance = diff.dot(&diff).sqrt();
                if distance < min_distance {
                    min_distance = distance;
                    best_cluster = j;
                }
            }

            labels[i] = best_cluster;
        }

        Ok(labels)
    }
}

/// CluStream algorithm for stream clustering
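///
/// # Examples
///
/// A minimal sketch (marked `ignore` because the crate path is assumed): seed
/// the stream with a small batch, feed an additional point, then inspect the
/// micro-clusters.
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array1};
/// use sklears_core::traits::Fit;
/// use sklears_clustering::streaming::CluStream;
///
/// let x = array![[0.0, 0.0], [1.0, 1.0]];
/// let y: Array1<usize> = Array1::zeros(2);
/// let mut model = CluStream::new()
///     .max_clusters(3)
///     .creation_threshold(0.5)
///     .fit(&x.view(), &y.view())
///     .unwrap();
/// model.partial_fit(&array![2.0, 2.0].view()).unwrap();
/// assert!(!model.micro_clusters().unwrap().is_empty());
/// ```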
pub struct CluStream<State = Untrained> {
    config: StreamingConfig,
    state: PhantomData<State>,
    // Trained state fields
    micro_clusters: Option<Vec<MicroCluster>>,
    macro_clusters: Option<Array2<Float>>,
    timestamp: usize,
    update_counter: usize,
}

impl<State> CluStream<State> {
    /// Create a new CluStream instance
    pub fn new() -> Self {
        Self {
            config: StreamingConfig::default(),
            state: PhantomData,
            micro_clusters: None,
            macro_clusters: None,
            timestamp: 0,
            update_counter: 0,
        }
    }

    /// Set the maximum number of micro-clusters
    pub fn max_clusters(mut self, max_clusters: usize) -> Self {
        self.config.max_clusters = max_clusters;
        self
    }

    /// Set the creation threshold
    pub fn creation_threshold(mut self, threshold: Float) -> Self {
        self.config.creation_threshold = threshold;
        self
    }

    /// Set the merge threshold
    pub fn merge_threshold(mut self, threshold: Float) -> Self {
        self.config.merge_threshold = threshold;
        self
    }

    /// Set the decay factor
    pub fn decay_factor(mut self, decay_factor: Float) -> Self {
        self.config.decay_factor = decay_factor;
        self
    }

    /// Set the update frequency
    pub fn update_frequency(mut self, frequency: usize) -> Self {
        self.config.update_frequency = frequency;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.config.random_state = Some(seed);
        self
    }
}

impl CluStream<Trained> {
    /// Process a new data point online
    pub fn partial_fit(&mut self, point: &ArrayView1<Float>) -> Result<()> {
        if let Some(ref mut micro_clusters) = &mut self.micro_clusters {
            self.timestamp += 1;

            // Find closest micro-cluster
            let mut min_distance = Float::INFINITY;
            let mut closest_idx = None;

            for (i, cluster) in micro_clusters.iter().enumerate() {
                let distance = cluster.distance_to_centroid(point);
                if distance < cluster.radius + self.config.creation_threshold
                    && distance < min_distance
                {
                    min_distance = distance;
                    closest_idx = Some(i);
                }
            }

            if let Some(idx) = closest_idx {
                // Update existing micro-cluster
                micro_clusters[idx].update(point, self.timestamp, self.config.learning_rate);
            } else {
                // Create new micro-cluster or merge existing ones
                if micro_clusters.len() < self.config.max_clusters {
                    // Create new micro-cluster
                    micro_clusters.push(MicroCluster::new(point, self.timestamp));
                } else {
                    // Find two closest micro-clusters to merge
                    let mut min_merge_distance = Float::INFINITY;
                    let mut merge_indices = (0, 1);

                    for i in 0..micro_clusters.len() {
                        for j in (i + 1)..micro_clusters.len() {
                            let dist = micro_clusters[i]
                                .distance_to_centroid(&micro_clusters[j].centroid.view());
                            if dist < min_merge_distance {
                                min_merge_distance = dist;
                                merge_indices = (i, j);
                            }
                        }
                    }

                    // Merge the two closest clusters
                    if min_merge_distance < self.config.merge_threshold {
                        let (i, j) = merge_indices;
                        let merged_centroid = (&micro_clusters[i].centroid
                            * micro_clusters[i].weight
                            + &micro_clusters[j].centroid * micro_clusters[j].weight)
                            / (micro_clusters[i].weight + micro_clusters[j].weight);
                        let merged_weight = micro_clusters[i].weight + micro_clusters[j].weight;

                        micro_clusters[i].centroid = merged_centroid;
                        micro_clusters[i].weight = merged_weight;
                        micro_clusters[i].last_update = self.timestamp;

                        micro_clusters.remove(j);

                        // Add new micro-cluster for current point
                        micro_clusters.push(MicroCluster::new(point, self.timestamp));
                    } else {
                        // Replace oldest cluster
                        let mut oldest_idx = 0;
                        let mut oldest_time = micro_clusters[0].last_update;

                        for (i, cluster) in micro_clusters.iter().enumerate() {
                            if cluster.last_update < oldest_time {
                                oldest_time = cluster.last_update;
                                oldest_idx = i;
                            }
                        }

                        micro_clusters[oldest_idx] = MicroCluster::new(point, self.timestamp);
                    }
                }
            }

            // Apply decay to all micro-clusters
            for cluster in micro_clusters.iter_mut() {
                cluster.decay(self.config.decay_factor);
            }

            // Remove clusters with low weight
            micro_clusters.retain(|cluster| !cluster.should_remove(self.config.min_weight));

            // Update macro-clusters periodically
            self.update_counter += 1;
            if self.update_counter % self.config.update_frequency == 0 {
                self.update_macro_clusters()?;
            }

            Ok(())
        } else {
            Err(SklearsError::NotFitted {
                operation: "partial_fit".to_string(),
            })
        }
    }

    /// Update macro-clusters from micro-clusters
    fn update_macro_clusters(&mut self) -> Result<()> {
        if let Some(ref micro_clusters) = &self.micro_clusters {
            if micro_clusters.is_empty() {
                return Ok(());
            }

            let n_features = micro_clusters[0].centroid.len();
            let mut macro_centroids = Array2::zeros((micro_clusters.len(), n_features));

            for (i, cluster) in micro_clusters.iter().enumerate() {
                macro_centroids.row_mut(i).assign(&cluster.centroid);
            }

            self.macro_clusters = Some(macro_centroids);
        }

        Ok(())
    }

    /// Get current micro-clusters
    pub fn micro_clusters(&self) -> Result<&Vec<MicroCluster>> {
        self.micro_clusters
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "micro_clusters".to_string(),
            })
    }

    /// Get current macro-clusters
    pub fn macro_clusters(&self) -> Result<&Array2<Float>> {
        self.macro_clusters
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "macro_clusters".to_string(),
            })
    }
}

impl<State> Default for CluStream<State> {
    fn default() -> Self {
        Self::new()
    }
}

impl<State> Estimator<State> for CluStream<State> {
    type Config = StreamingConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, usize>> for CluStream<Untrained> {
    type Fitted = CluStream<Trained>;

    fn fit(self, x: &ArrayView2<Float>, _y: &ArrayView1<usize>) -> Result<Self::Fitted> {
        let (n_samples, n_features) = x.dim();

        if n_samples == 0 || n_features == 0 {
            return Err(SklearsError::InvalidInput("Empty input data".to_string()));
        }

        // Initialize with first few points as micro-clusters
        let initial_clusters = self.config.max_clusters.min(n_samples);
        let mut micro_clusters = Vec::with_capacity(initial_clusters);

        for i in 0..initial_clusters {
            micro_clusters.push(MicroCluster::new(&x.row(i), i));
        }

        Ok(CluStream {
            config: self.config,
            state: PhantomData,
            micro_clusters: Some(micro_clusters),
            macro_clusters: None,
            timestamp: initial_clusters,
            update_counter: 0,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<usize>> for CluStream<Trained> {
    fn predict(&self, x: &ArrayView2<Float>) -> Result<Array1<usize>> {
        let micro_clusters = self.micro_clusters()?;
        let mut labels = Array1::zeros(x.nrows());

        for (i, sample) in x.outer_iter().enumerate() {
            let mut min_distance = Float::INFINITY;
            let mut best_cluster = 0;

            for (j, cluster) in micro_clusters.iter().enumerate() {
                let distance = cluster.distance_to_centroid(&sample);
                if distance < min_distance {
                    min_distance = distance;
                    best_cluster = j;
                }
            }

            labels[i] = best_cluster;
        }

        Ok(labels)
    }
}

/// Sliding Window K-Means for temporal data
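///
/// # Examples
///
/// A minimal sketch (marked `ignore` because the crate path is assumed) showing
/// that the window stays bounded as new points arrive.
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array1};
/// use sklears_core::traits::Fit;
/// use sklears_clustering::streaming::SlidingWindowKMeans;
///
/// let x = array![[0.0, 0.0], [0.1, 0.1], [1.0, 1.0]];
/// let y: Array1<usize> = Array1::zeros(3);
/// let mut model = SlidingWindowKMeans::new()
///     .window_size(3)
///     .max_clusters(2)
///     .fit(&x.view(), &y.view())
///     .unwrap();
/// model.partial_fit(&array![2.0, 2.0].view()).unwrap();
/// assert_eq!(model.current_window_size(), 3); // the oldest point was evicted
/// ```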
pub struct SlidingWindowKMeans<State = Untrained> {
    config: StreamingConfig,
    state: PhantomData<State>,
    // Trained state fields
    window_data: Option<VecDeque<Array1<Float>>>,
    centroids: Option<Array2<Float>>,
    timestamps: Option<VecDeque<usize>>,
    current_time: usize,
}

impl<State> SlidingWindowKMeans<State> {
    /// Create a new Sliding Window K-Means instance
    pub fn new() -> Self {
        Self {
            config: StreamingConfig::default(),
            state: PhantomData,
            window_data: None,
            centroids: None,
            timestamps: None,
            current_time: 0,
        }
    }

    /// Set the window size
    pub fn window_size(mut self, window_size: usize) -> Self {
        self.config.window_size = window_size;
        self
    }

    /// Set the number of clusters
    pub fn max_clusters(mut self, max_clusters: usize) -> Self {
        self.config.max_clusters = max_clusters;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.config.random_state = Some(seed);
        self
    }
}

impl SlidingWindowKMeans<Trained> {
    /// Process a new data point with sliding window
    pub fn partial_fit(&mut self, point: &ArrayView1<Float>) -> Result<()> {
        if let (Some(ref mut window_data), Some(ref mut timestamps)) =
            (&mut self.window_data, &mut self.timestamps)
        {
            // Add new point to window
            window_data.push_back(point.to_owned());
            timestamps.push_back(self.current_time);

            // Remove old points if window is full
            while window_data.len() > self.config.window_size {
                window_data.pop_front();
                timestamps.pop_front();
            }

            // Recompute centroids if we have enough data
            if window_data.len() >= self.config.max_clusters {
                self.recompute_centroids()?;
            }

            self.current_time += 1;

            Ok(())
        } else {
            Err(SklearsError::NotFitted {
                operation: "partial_fit".to_string(),
            })
        }
    }

    /// Recompute centroids from current window
    fn recompute_centroids(&mut self) -> Result<()> {
        if let Some(ref window_data) = &self.window_data {
            if window_data.is_empty() {
                return Ok(());
            }

            let n_features = window_data[0].len();
            let k = self.config.max_clusters.min(window_data.len());

            // Simple k-means on window data
            let mut centroids = Array2::zeros((k, n_features));
            let mut counts = Array1::zeros(k);

            // Initialize centroids with first k points
            for (i, point) in window_data.iter().take(k).enumerate() {
                centroids.row_mut(i).assign(point);
            }

            // Assign points to clusters and update centroids
            for _ in 0..10 {
                // Fixed number of iterations for simplicity
                counts.fill(0.0);
                let mut new_centroids = Array2::zeros((k, n_features));

                for point in window_data.iter() {
                    // Find closest centroid
                    let mut min_distance = Float::INFINITY;
                    let mut closest_idx = 0;

                    for (j, centroid) in centroids.outer_iter().enumerate() {
                        let diff = point - &centroid;
                        let distance = diff.dot(&diff).sqrt();
                        if distance < min_distance {
                            min_distance = distance;
                            closest_idx = j;
                        }
                    }

                    // Update cluster sum and count
                    let mut row = new_centroids.row_mut(closest_idx);
                    row += point;
                    counts[closest_idx] += 1.0;
                }

                // Update centroids
                for i in 0..k {
                    if counts[i] > 0.0 {
                        let mut row = new_centroids.row_mut(i);
                        row /= counts[i];
                        centroids.row_mut(i).assign(&row);
                    }
                }
            }

            self.centroids = Some(centroids);
        }

        Ok(())
    }

    /// Get current centroids
    pub fn centroids(&self) -> Result<&Array2<Float>> {
        self.centroids
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "centroids".to_string(),
            })
    }

    /// Get current window size
    pub fn current_window_size(&self) -> usize {
        self.window_data.as_ref().map_or(0, |data| data.len())
    }
}

impl<State> Default for SlidingWindowKMeans<State> {
    fn default() -> Self {
        Self::new()
    }
}

impl<State> Estimator<State> for SlidingWindowKMeans<State> {
    type Config = StreamingConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, usize>> for SlidingWindowKMeans<Untrained> {
    type Fitted = SlidingWindowKMeans<Trained>;

    fn fit(self, x: &ArrayView2<Float>, _y: &ArrayView1<usize>) -> Result<Self::Fitted> {
        let (n_samples, n_features) = x.dim();

        if n_samples == 0 || n_features == 0 {
            return Err(SklearsError::InvalidInput("Empty input data".to_string()));
        }

        // Initialize window with initial data
        let window_size = self.config.window_size.min(n_samples);
        let mut window_data = VecDeque::with_capacity(self.config.window_size);
        let mut timestamps = VecDeque::with_capacity(self.config.window_size);

        for (i, row) in x.outer_iter().take(window_size).enumerate() {
            window_data.push_back(row.to_owned());
            timestamps.push_back(i);
        }

        // Initialize centroids
        let k = self.config.max_clusters.min(window_size);
        let mut centroids = Array2::zeros((k, n_features));
        for i in 0..k {
            centroids.row_mut(i).assign(&window_data[i]);
        }

        Ok(SlidingWindowKMeans {
            config: self.config,
            state: PhantomData,
            window_data: Some(window_data),
            centroids: Some(centroids),
            timestamps: Some(timestamps),
            current_time: window_size,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<usize>> for SlidingWindowKMeans<Trained> {
    fn predict(&self, x: &ArrayView2<Float>) -> Result<Array1<usize>> {
        let centroids = self.centroids()?;
        let mut labels = Array1::zeros(x.nrows());

        for (i, sample) in x.outer_iter().enumerate() {
            let mut min_distance = Float::INFINITY;
            let mut best_cluster = 0;

            for (j, centroid) in centroids.outer_iter().enumerate() {
                let diff = &sample - &centroid;
                let distance = diff.dot(&diff).sqrt();
                if distance < min_distance {
                    min_distance = distance;
                    best_cluster = j;
                }
            }

            labels[i] = best_cluster;
        }

        Ok(labels)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_micro_cluster_creation() {
        let point = array![1.0, 2.0];
        let cluster = MicroCluster::new(&point.view(), 0);

        assert_eq!(cluster.centroid, point);
        assert_eq!(cluster.weight, 1.0);
        assert_eq!(cluster.creation_time, 0);
        assert_eq!(cluster.last_update, 0);
    }

    #[test]
    fn test_micro_cluster_update() {
        let point1 = array![1.0, 2.0];
        let point2 = array![3.0, 4.0];
        let mut cluster = MicroCluster::new(&point1.view(), 0);

        cluster.update(&point2.view(), 1, 0.5);

        assert!(cluster.weight > 1.0);
        assert_eq!(cluster.last_update, 1);

        // Centroid should be updated toward point2
        assert!(cluster.centroid[0] > 1.0);
        assert!(cluster.centroid[1] > 2.0);
    }

    #[test]
    fn test_online_kmeans_fit() {
        let x = array![[0.0, 0.0], [0.1, 0.1], [1.0, 1.0], [1.1, 1.1],];
        let y = Array1::zeros(4);

        let model = OnlineKMeans::new()
            .max_clusters(2)
            .learning_rate(0.1)
            .random_state(42)
            .fit(&x.view(), &y.view())
            .unwrap();

        assert!(model.centroids().is_ok());
        assert!(model.weights().is_ok());

        let centroids = model.centroids().unwrap();
        assert_eq!(centroids.nrows(), 2);
        assert_eq!(centroids.ncols(), 2);
    }

    #[test]
    fn test_online_kmeans_partial_fit() {
        let x = array![[0.0, 0.0], [1.0, 1.0],];
        let y = Array1::zeros(2);

        let mut model = OnlineKMeans::new()
            .max_clusters(2)
            .learning_rate(0.1)
            .random_state(42)
            .fit(&x.view(), &y.view())
            .unwrap();

        // Test partial fit with new point
        let new_point = array![0.5, 0.5];
        model.partial_fit(&new_point.view()).unwrap();

        let centroids = model.centroids().unwrap();
        assert_eq!(centroids.nrows(), 2);
    }

    #[test]
    fn test_online_kmeans_predict() {
        let x = array![[0.0, 0.0], [1.0, 1.0],];
        let y = Array1::zeros(2);

        let model = OnlineKMeans::new()
            .max_clusters(2)
            .random_state(42)
            .fit(&x.view(), &y.view())
            .unwrap();

        let test_data = array![[0.1, 0.1], [0.9, 0.9],];

        let predictions = model.predict(&test_data.view()).unwrap();
        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_clustream_fit() {
        let x = array![[0.0, 0.0], [0.1, 0.1], [1.0, 1.0], [1.1, 1.1],];
        let y = Array1::zeros(4);

        let model = CluStream::new()
            .max_clusters(3)
            .creation_threshold(0.5)
            .merge_threshold(0.3)
            .random_state(42)
            .fit(&x.view(), &y.view())
            .unwrap();

        assert!(model.micro_clusters().is_ok());

        let micro_clusters = model.micro_clusters().unwrap();
        assert!(!micro_clusters.is_empty());
        assert!(micro_clusters.len() <= 3);
    }

    #[test]
    fn test_clustream_partial_fit() {
        let x = array![[0.0, 0.0], [1.0, 1.0],];
        let y = Array1::zeros(2);

        let mut model = CluStream::new()
            .max_clusters(3)
            .creation_threshold(0.5)
            .random_state(42)
            .fit(&x.view(), &y.view())
            .unwrap();

        // Test partial fit with new points
        let new_point1 = array![0.2, 0.2];
        let new_point2 = array![2.0, 2.0];

        model.partial_fit(&new_point1.view()).unwrap();
        model.partial_fit(&new_point2.view()).unwrap();

        let micro_clusters = model.micro_clusters().unwrap();
        assert!(!micro_clusters.is_empty());
    }

    #[test]
    fn test_sliding_window_kmeans() {
        let x = array![[0.0, 0.0], [0.1, 0.1], [1.0, 1.0], [1.1, 1.1],];
        let y = Array1::zeros(4);

        let mut model = SlidingWindowKMeans::new()
            .window_size(3)
            .max_clusters(2)
            .random_state(42)
            .fit(&x.view(), &y.view())
            .unwrap();

        assert!(model.centroids().is_ok());
        assert_eq!(model.current_window_size(), 3);

        // Test partial fit
        let new_point = array![2.0, 2.0];
        model.partial_fit(&new_point.view()).unwrap();

        // Window should still be size 3 (slides)
        assert_eq!(model.current_window_size(), 3);
    }

    #[test]
    fn test_micro_cluster_decay() {
        let point = array![1.0, 2.0];
        let mut cluster = MicroCluster::new(&point.view(), 0);

        let initial_weight = cluster.weight;
        cluster.decay(0.9);

        assert!(cluster.weight < initial_weight);
        assert_eq!(cluster.weight, initial_weight * 0.9);
    }

    #[test]
    fn test_micro_cluster_should_remove() {
        let point = array![1.0, 2.0];
        let mut cluster = MicroCluster::new(&point.view(), 0);

        // Reduce weight below threshold
        cluster.weight = 0.005;

        assert!(cluster.should_remove(0.01));
        assert!(!cluster.should_remove(0.001));
    }
}