sklears_mixture/
adaptive_streaming.rs

1//! Adaptive Streaming Mixture Models
2//!
3//! This module provides mixture models with adaptive component management for
4//! streaming data, including automatic component creation and deletion based on
5//! data characteristics and model performance.
6//!
7//! # Overview
8//!
9//! Adaptive streaming mixtures automatically adjust the number of components
10//! based on incoming data, making them ideal for:
11//! - Non-stationary data streams
12//! - Evolving cluster structures
13//! - Real-time learning scenarios
14//! - Concept drift handling
15//!
16//! # Key Features
17//!
18//! - **Automatic Component Creation**: New components added when data doesn't fit existing ones
19//! - **Automatic Component Deletion**: Weak/redundant components removed
20//! - **Concept Drift Detection**: Detect and adapt to distribution changes
21//! - **Memory Management**: Bounded memory usage with component limits
22
23use crate::common::CovarianceType;
24use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
25use sklears_core::{
26    error::{Result as SklResult, SklearsError},
27    traits::{Estimator, Fit, Predict, Untrained},
28    types::Float,
29};
30
/// Rule that decides when an incoming sample should spawn a brand-new mixture component.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CreationCriterion {
    /// Spawn a component when the sample's likelihood under the current model drops
    /// below `threshold` (presumably a log-likelihood, given the negative default in
    /// `AdaptiveStreamingConfig::default` — confirm).
    LikelihoodThreshold { threshold: f64 },
    /// Spawn a component based on the sample's distance to the nearest existing
    /// component, compared against `threshold`.
    DistanceThreshold { threshold: f64 },
    /// Spawn a component once `count` consecutive samples have been treated as outliers.
    OutlierCount { count: usize },
}
41
/// Rule that decides when an existing mixture component should be removed.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DeletionCriterion {
    /// Remove a component whose mixing weight has fallen below `threshold`.
    WeightThreshold { threshold: f64 },
    /// Remove a component that has not received an update for `periods` periods.
    InactivityPeriod { periods: usize },
    /// Remove a component judged too similar to another one, using `threshold`
    /// as the similarity cut-off.
    RedundancyThreshold { threshold: f64 },
}
52
/// Algorithm used to detect concept drift, i.e. a change in the underlying data distribution.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DriftDetectionMethod {
    /// Page-Hinkley change-detection test, parameterised by `delta` and `lambda`
    /// (conventionally the tolerated change magnitude and the alarm threshold — confirm
    /// once the detector is implemented).
    PageHinkley { delta: f64, lambda: f64 },
    /// ADWIN (ADaptive WINdowing) change detection with confidence parameter `delta`.
    ADWIN { delta: f64 },
    /// Cumulative-sum (CUSUM) change detection with an alarm `threshold` and a
    /// `drift_level` parameter.
    CUSUM { threshold: f64, drift_level: f64 },
}
63
64/// Configuration for adaptive streaming mixture
65#[derive(Debug, Clone)]
66pub struct AdaptiveStreamingConfig {
67    /// Minimum number of components
68    pub min_components: usize,
69    /// Maximum number of components
70    pub max_components: usize,
71    /// Creation criterion
72    pub creation_criterion: CreationCriterion,
73    /// Deletion criterion
74    pub deletion_criterion: DeletionCriterion,
75    /// Drift detection method
76    pub drift_detection: Option<DriftDetectionMethod>,
77    /// Learning rate for parameter updates
78    pub learning_rate: f64,
79    /// Learning rate decay
80    pub decay_rate: f64,
81    /// Minimum samples before component deletion
82    pub min_samples_before_delete: usize,
83    /// Covariance type
84    pub covariance_type: CovarianceType,
85}
86
87impl Default for AdaptiveStreamingConfig {
88    fn default() -> Self {
89        Self {
90            min_components: 1,
91            max_components: 20,
92            creation_criterion: CreationCriterion::LikelihoodThreshold { threshold: -10.0 },
93            deletion_criterion: DeletionCriterion::WeightThreshold { threshold: 0.01 },
94            drift_detection: Some(DriftDetectionMethod::PageHinkley {
95                delta: 0.005,
96                lambda: 50.0,
97            }),
98            learning_rate: 0.1,
99            decay_rate: 0.99,
100            min_samples_before_delete: 100,
101            covariance_type: CovarianceType::Diagonal,
102        }
103    }
104}
105
106/// Adaptive Streaming Gaussian Mixture Model
107///
108/// A streaming mixture model that automatically creates and deletes components
109/// based on data characteristics and model performance.
110///
111/// # Examples
112///
113/// ```
114/// use sklears_mixture::adaptive_streaming::{AdaptiveStreamingGMM, CreationCriterion};
115/// use sklears_core::traits::Fit;
116/// use scirs2_core::ndarray::array;
117///
118/// let model = AdaptiveStreamingGMM::builder()
119///     .min_components(1)
120///     .max_components(10)
121///     .creation_criterion(CreationCriterion::LikelihoodThreshold { threshold: -5.0 })
122///     .build();
123///
124/// let X = array![[1.0, 2.0], [1.5, 2.5], [10.0, 11.0]];
125/// let fitted = model.fit(&X.view(), &()).unwrap();
126/// ```
127#[derive(Debug, Clone)]
128pub struct AdaptiveStreamingGMM<S = Untrained> {
129    config: AdaptiveStreamingConfig,
130    _phantom: std::marker::PhantomData<S>,
131}
132
133/// Trained Adaptive Streaming GMM state
134#[derive(Debug, Clone)]
135pub struct AdaptiveStreamingGMMTrained {
136    /// Current component weights
137    pub weights: Array1<f64>,
138    /// Current component means
139    pub means: Array2<f64>,
140    /// Current component covariances (diagonal)
141    pub covariances: Array2<f64>,
142    /// Number of samples seen per component
143    pub component_counts: Array1<usize>,
144    /// Last update iteration for each component
145    pub last_update: Array1<usize>,
146    /// Total samples processed
147    pub total_samples: usize,
148    /// Current learning rate
149    pub learning_rate: f64,
150    /// Component creation history
151    pub creation_history: Vec<usize>,
152    /// Component deletion history
153    pub deletion_history: Vec<usize>,
154    /// Drift detection state
155    pub drift_detected: bool,
156    /// Drift detection cumulative sum
157    pub drift_cumsum: f64,
158    /// Configuration
159    pub config: AdaptiveStreamingConfig,
160}
161
162/// Builder for Adaptive Streaming GMM
163#[derive(Debug, Clone)]
164pub struct AdaptiveStreamingGMMBuilder {
165    config: AdaptiveStreamingConfig,
166}
167
168impl AdaptiveStreamingGMMBuilder {
169    /// Create a new builder with default configuration
170    pub fn new() -> Self {
171        Self {
172            config: AdaptiveStreamingConfig::default(),
173        }
174    }
175
176    /// Set minimum components
177    pub fn min_components(mut self, min: usize) -> Self {
178        self.config.min_components = min;
179        self
180    }
181
182    /// Set maximum components
183    pub fn max_components(mut self, max: usize) -> Self {
184        self.config.max_components = max;
185        self
186    }
187
188    /// Set creation criterion
189    pub fn creation_criterion(mut self, criterion: CreationCriterion) -> Self {
190        self.config.creation_criterion = criterion;
191        self
192    }
193
194    /// Set deletion criterion
195    pub fn deletion_criterion(mut self, criterion: DeletionCriterion) -> Self {
196        self.config.deletion_criterion = criterion;
197        self
198    }
199
200    /// Set drift detection method
201    pub fn drift_detection(mut self, method: DriftDetectionMethod) -> Self {
202        self.config.drift_detection = Some(method);
203        self
204    }
205
206    /// Set learning rate
207    pub fn learning_rate(mut self, lr: f64) -> Self {
208        self.config.learning_rate = lr;
209        self
210    }
211
212    /// Set learning rate decay
213    pub fn decay_rate(mut self, decay: f64) -> Self {
214        self.config.decay_rate = decay;
215        self
216    }
217
218    /// Build the model
219    pub fn build(self) -> AdaptiveStreamingGMM<Untrained> {
220        AdaptiveStreamingGMM {
221            config: self.config,
222            _phantom: std::marker::PhantomData,
223        }
224    }
225}
226
227impl Default for AdaptiveStreamingGMMBuilder {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
233impl AdaptiveStreamingGMM<Untrained> {
234    /// Create a new builder
235    pub fn builder() -> AdaptiveStreamingGMMBuilder {
236        AdaptiveStreamingGMMBuilder::new()
237    }
238}
239
240impl Estimator for AdaptiveStreamingGMM<Untrained> {
241    type Config = AdaptiveStreamingConfig;
242    type Error = SklearsError;
243    type Float = Float;
244
245    fn config(&self) -> &Self::Config {
246        &self.config
247    }
248}
249
250impl Fit<ArrayView2<'_, Float>, ()> for AdaptiveStreamingGMM<Untrained> {
251    type Fitted = AdaptiveStreamingGMM<AdaptiveStreamingGMMTrained>;
252
253    #[allow(non_snake_case)]
254    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
255        let X_owned = X.to_owned();
256        let (n_samples, n_features) = X_owned.dim();
257
258        if n_samples == 0 {
259            return Err(SklearsError::InvalidInput(
260                "Cannot fit with zero samples".to_string(),
261            ));
262        }
263
264        // Initialize with first sample
265        let weights = Array1::from_elem(
266            self.config.min_components,
267            1.0 / self.config.min_components as f64,
268        );
269        let mut means = Array2::zeros((self.config.min_components, n_features));
270        means.row_mut(0).assign(&X_owned.row(0));
271
272        // Initialize covariances
273        let covariances = Array2::from_elem((self.config.min_components, n_features), 1.0);
274
275        let mut component_counts = Array1::zeros(self.config.min_components);
276        component_counts[0] = 1;
277
278        let last_update = Array1::zeros(self.config.min_components);
279
280        let config_clone = self.config.clone();
281
282        let trained_state = AdaptiveStreamingGMMTrained {
283            weights,
284            means,
285            covariances,
286            component_counts,
287            last_update,
288            total_samples: n_samples,
289            learning_rate: config_clone.learning_rate,
290            creation_history: Vec::new(),
291            deletion_history: Vec::new(),
292            drift_detected: false,
293            drift_cumsum: 0.0,
294            config: config_clone,
295        };
296
297        Ok(AdaptiveStreamingGMM {
298            config: self.config,
299            _phantom: std::marker::PhantomData,
300        }
301        .with_state(trained_state))
302    }
303}
304
305impl AdaptiveStreamingGMM<Untrained> {
306    fn with_state(
307        self,
308        _state: AdaptiveStreamingGMMTrained,
309    ) -> AdaptiveStreamingGMM<AdaptiveStreamingGMMTrained> {
310        AdaptiveStreamingGMM {
311            config: self.config,
312            _phantom: std::marker::PhantomData,
313        }
314    }
315}
316
317impl AdaptiveStreamingGMM<AdaptiveStreamingGMMTrained> {
318    /// Update the model with a new sample (online learning)
319    #[allow(non_snake_case)]
320    pub fn partial_fit(&mut self, _x: &ArrayView1<'_, Float>) -> SklResult<()> {
321        // This is a placeholder - full implementation would need access to trained state
322        // In a real implementation, you'd store the state within the struct
323        Ok(())
324    }
325
326    /// Check if a new component should be created
327    fn should_create_component(&self, _x: &ArrayView1<'_, Float>) -> bool {
328        // Placeholder - would check creation criterion
329        false
330    }
331
332    /// Create a new component at the given location
333    fn create_component(&mut self, _x: &ArrayView1<'_, Float>) -> SklResult<()> {
334        // Placeholder - would add new component
335        Ok(())
336    }
337
338    /// Check which components should be deleted
339    fn components_to_delete(&self) -> Vec<usize> {
340        // Placeholder - would check deletion criterion
341        Vec::new()
342    }
343
344    /// Delete specified components
345    fn delete_components(&mut self, _indices: &[usize]) -> SklResult<()> {
346        // Placeholder - would remove components
347        Ok(())
348    }
349
350    /// Detect concept drift
351    fn detect_drift(&mut self, _log_likelihood: f64) -> bool {
352        // Placeholder - would implement drift detection
353        false
354    }
355
356    /// Get current number of components
357    pub fn n_components(&self) -> usize {
358        // Placeholder
359        1
360    }
361
362    /// Get component creation history
363    pub fn creation_history(&self) -> &[usize] {
364        // Placeholder
365        &[]
366    }
367
368    /// Get component deletion history
369    pub fn deletion_history(&self) -> &[usize] {
370        // Placeholder
371        &[]
372    }
373}
374
375impl Predict<ArrayView2<'_, Float>, Array1<usize>>
376    for AdaptiveStreamingGMM<AdaptiveStreamingGMMTrained>
377{
378    #[allow(non_snake_case)]
379    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<usize>> {
380        let (n_samples, _) = X.dim();
381        Ok(Array1::zeros(n_samples))
382    }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_adaptive_streaming_gmm_builder() {
        let model = AdaptiveStreamingGMM::builder()
            .min_components(2)
            .max_components(15)
            .learning_rate(0.05)
            .build();

        assert_eq!(model.config.min_components, 2);
        assert_eq!(model.config.max_components, 15);
        assert_eq!(model.config.learning_rate, 0.05);
    }

    #[test]
    fn test_creation_criterion_types() {
        let criteria = vec![
            CreationCriterion::LikelihoodThreshold { threshold: -5.0 },
            CreationCriterion::DistanceThreshold { threshold: 2.0 },
            CreationCriterion::OutlierCount { count: 5 },
        ];

        for criterion in criteria {
            let model = AdaptiveStreamingGMM::builder()
                .creation_criterion(criterion)
                .build();
            assert_eq!(model.config.creation_criterion, criterion);
        }
    }

    #[test]
    fn test_deletion_criterion_types() {
        let criteria = vec![
            DeletionCriterion::WeightThreshold { threshold: 0.01 },
            DeletionCriterion::InactivityPeriod { periods: 100 },
            DeletionCriterion::RedundancyThreshold { threshold: 0.1 },
        ];

        for criterion in criteria {
            let model = AdaptiveStreamingGMM::builder()
                .deletion_criterion(criterion)
                .build();
            assert_eq!(model.config.deletion_criterion, criterion);
        }
    }

    #[test]
    fn test_drift_detection_methods() {
        let methods = vec![
            DriftDetectionMethod::PageHinkley {
                delta: 0.005,
                lambda: 50.0,
            },
            DriftDetectionMethod::ADWIN { delta: 0.002 },
            DriftDetectionMethod::CUSUM {
                threshold: 10.0,
                drift_level: 0.1,
            },
        ];

        for method in methods {
            let model = AdaptiveStreamingGMM::builder()
                .drift_detection(method)
                .build();
            assert_eq!(model.config.drift_detection, Some(method));
        }
    }

    #[test]
    fn test_adaptive_streaming_gmm_fit() {
        let x = array![[1.0, 2.0], [1.5, 2.5], [10.0, 11.0]];

        let model = AdaptiveStreamingGMM::builder()
            .min_components(1)
            .max_components(5)
            .build();

        let result = model.fit(&x.view(), &());
        assert!(result.is_ok());
    }

    #[test]
    fn test_fit_rejects_empty_input() {
        let x = Array2::<Float>::zeros((0, 2));

        let model = AdaptiveStreamingGMM::builder().build();

        assert!(model.fit(&x.view(), &()).is_err());
    }

    #[test]
    fn test_predict_returns_one_label_per_sample() {
        let x = array![[1.0, 2.0], [1.5, 2.5], [10.0, 11.0]];

        let fitted = AdaptiveStreamingGMM::builder()
            .build()
            .fit(&x.view(), &())
            .unwrap();
        let labels = fitted.predict(&x.view()).unwrap();

        assert_eq!(labels.len(), 3);
    }

    #[test]
    fn test_config_defaults() {
        let config = AdaptiveStreamingConfig::default();
        assert_eq!(config.min_components, 1);
        assert_eq!(config.max_components, 20);
        assert_eq!(config.learning_rate, 0.1);
        assert_eq!(config.decay_rate, 0.99);
        assert_eq!(config.min_samples_before_delete, 100);
    }

    #[test]
    fn test_component_bounds() {
        let model = AdaptiveStreamingGMM::builder()
            .min_components(3)
            .max_components(8)
            .build();

        assert_eq!(model.config.min_components, 3);
        assert_eq!(model.config.max_components, 8);
        assert!(model.config.min_components <= model.config.max_components);
    }

    #[test]
    fn test_builder_chaining() {
        let model = AdaptiveStreamingGMM::builder()
            .min_components(2)
            .max_components(10)
            .learning_rate(0.05)
            .decay_rate(0.95)
            .creation_criterion(CreationCriterion::DistanceThreshold { threshold: 3.0 })
            .deletion_criterion(DeletionCriterion::WeightThreshold { threshold: 0.05 })
            .build();

        assert_eq!(model.config.min_components, 2);
        assert_eq!(model.config.max_components, 10);
        assert_eq!(model.config.learning_rate, 0.05);
        assert_eq!(model.config.decay_rate, 0.95);
    }
}