sklears_compose/
streaming.rs

1//! Streaming pipeline components for real-time data processing
2//!
3//! This module provides streaming capabilities including windowing strategies,
4//! online model updates, incremental processing, and real-time analytics.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
7use sklears_core::{
8    error::Result as SklResult,
9    prelude::{Predict, SklearsError},
10    traits::{Estimator, Fit, Untrained},
11    types::Float,
12};
13use std::collections::{HashMap, VecDeque};
14use std::time::{Duration, Instant, SystemTime};
15
16use crate::{PipelinePredictor, PipelineStep};
17
18/// Data point in a stream
19#[derive(Debug, Clone)]
20pub struct StreamDataPoint {
21    /// Feature values
22    pub features: Array1<f64>,
23    /// Target value (optional)
24    pub target: Option<f64>,
25    /// Timestamp
26    pub timestamp: SystemTime,
27    /// Metadata
28    pub metadata: HashMap<String, String>,
29    /// Data point ID
30    pub id: String,
31}
32
33impl StreamDataPoint {
34    /// Create a new stream data point
35    #[must_use]
36    pub fn new(features: Array1<f64>, id: String) -> Self {
37        Self {
38            features,
39            target: None,
40            timestamp: SystemTime::now(),
41            metadata: HashMap::new(),
42            id,
43        }
44    }
45
46    /// Set target value
47    #[must_use]
48    pub fn with_target(mut self, target: f64) -> Self {
49        self.target = Some(target);
50        self
51    }
52
53    /// Set timestamp
54    #[must_use]
55    pub fn with_timestamp(mut self, timestamp: SystemTime) -> Self {
56        self.timestamp = timestamp;
57        self
58    }
59
60    /// Set metadata
61    #[must_use]
62    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
63        self.metadata = metadata;
64        self
65    }
66}
67
68/// Window of stream data points
69#[derive(Debug, Clone)]
70pub struct StreamWindow {
71    /// Data points in the window
72    pub data_points: Vec<StreamDataPoint>,
73    /// Window start time
74    pub start_time: SystemTime,
75    /// Window end time
76    pub end_time: SystemTime,
77    /// Window metadata
78    pub metadata: HashMap<String, String>,
79}
80
81impl StreamWindow {
82    /// Create a new stream window
83    #[must_use]
84    pub fn new(start_time: SystemTime, end_time: SystemTime) -> Self {
85        Self {
86            data_points: Vec::new(),
87            start_time,
88            end_time,
89            metadata: HashMap::new(),
90        }
91    }
92
93    /// Add a data point to the window
94    pub fn add_point(&mut self, point: StreamDataPoint) {
95        self.data_points.push(point);
96    }
97
98    /// Get features matrix
99    pub fn features_matrix(&self) -> SklResult<Array2<f64>> {
100        if self.data_points.is_empty() {
101            return Err(SklearsError::InvalidInput("Empty window".to_string()));
102        }
103
104        let n_samples = self.data_points.len();
105        let n_features = self.data_points[0].features.len();
106
107        let mut features = Array2::zeros((n_samples, n_features));
108        for (i, point) in self.data_points.iter().enumerate() {
109            features.row_mut(i).assign(&point.features);
110        }
111
112        Ok(features)
113    }
114
115    /// Get targets array
116    #[must_use]
117    pub fn targets_array(&self) -> Option<Array1<f64>> {
118        if self.data_points.iter().all(|p| p.target.is_some()) {
119            Some(Array1::from_vec(
120                self.data_points.iter().map(|p| p.target.unwrap()).collect(),
121            ))
122        } else {
123            None
124        }
125    }
126
127    /// Get window size
128    #[must_use]
129    pub fn size(&self) -> usize {
130        self.data_points.len()
131    }
132
133    /// Check if window is empty
134    #[must_use]
135    pub fn is_empty(&self) -> bool {
136        self.data_points.is_empty()
137    }
138}
139
/// Windowing strategy for stream processing
///
/// Selects how buffered [`StreamDataPoint`]s are grouped into
/// [`StreamWindow`]s. The `Custom` variant stores a boxed closure, and the
/// enum derives neither `Debug` nor `Clone`.
pub enum WindowingStrategy {
    /// Fixed time windows
    TumblingTime {
        /// Window duration
        duration: Duration,
    },
    /// Sliding time windows
    SlidingTime {
        /// Window duration
        duration: Duration,
        /// Slide interval
        slide: Duration,
    },
    /// Fixed count windows
    TumblingCount {
        /// Number of elements per window
        count: usize,
    },
    /// Sliding count windows
    SlidingCount {
        /// Window size
        size: usize,
        /// Slide step
        step: usize,
    },
    /// Session windows (gap-based)
    Session {
        /// Maximum gap between elements
        gap: Duration,
    },
    /// Custom windowing
    Custom {
        /// Custom window trigger function
        ///
        /// Receives the currently buffered points and returns `true` when
        /// they should be closed into a window.
        trigger_fn: Box<dyn Fn(&[StreamDataPoint]) -> bool + Send + Sync>,
    },
}
177
178/// Stream processing configuration
179pub struct StreamConfig {
180    /// Windowing strategy
181    pub windowing: WindowingStrategy,
182    /// Buffer size for incoming data
183    pub buffer_size: usize,
184    /// Processing parallelism
185    pub parallelism: usize,
186    /// Backpressure threshold
187    pub backpressure_threshold: usize,
188    /// Latency targets
189    pub latency_target: Duration,
190    /// Checkpoint interval
191    pub checkpoint_interval: Duration,
192    /// State management
193    pub state_management: StateManagement,
194}
195
196impl Default for StreamConfig {
197    fn default() -> Self {
198        Self {
199            windowing: WindowingStrategy::TumblingTime {
200                duration: Duration::from_secs(60),
201            },
202            buffer_size: 10000,
203            parallelism: 1,
204            backpressure_threshold: 8000,
205            latency_target: Duration::from_millis(100),
206            checkpoint_interval: Duration::from_secs(300),
207            state_management: StateManagement::InMemory,
208        }
209    }
210}
211
/// State management strategy
///
/// Describes where pipeline state should live. In the visible source this
/// value is only stored in [`StreamConfig`] (defaulting to `InMemory`);
/// NOTE(review): no code shown here reads it, so persistence behaviour is
/// presumably implemented elsewhere or not yet — confirm.
#[derive(Debug, Clone)]
pub enum StateManagement {
    /// In-memory state (non-persistent)
    InMemory,
    /// Periodic snapshots to disk
    Snapshots {
        /// Snapshot directory
        directory: String,
        /// Snapshot interval
        interval: Duration,
    },
    /// Write-ahead log
    WriteAheadLog {
        /// Log file path
        log_path: String,
    },
    /// External state store
    External {
        /// State store configuration
        config: HashMap<String, String>,
    },
}
235
/// Online learning update strategy
///
/// Decides when a trained streaming pipeline refits its estimator. The
/// `Custom` variant stores a boxed closure, and the enum derives neither
/// `Debug` nor `Clone`.
pub enum UpdateStrategy {
    /// Update on every data point
    Immediate,
    /// Batch updates
    Batch {
        /// Batch size
        batch_size: usize,
    },
    /// Time-based updates
    TimeBased {
        /// Update interval
        interval: Duration,
    },
    /// Adaptive updates based on drift detection
    Adaptive {
        /// Drift detection threshold
        drift_threshold: f64,
        /// Minimum update interval
        min_interval: Duration,
        /// Maximum update interval
        max_interval: Duration,
    },
    /// Custom update trigger
    Custom {
        /// Update trigger function
        ///
        /// Called with the most recent window and the current statistics;
        /// returning `true` requests a model update.
        trigger_fn: Box<dyn Fn(&StreamWindow, &StreamStats) -> bool + Send + Sync>,
    },
}
265
/// Running metrics describing a stream-processing session.
#[derive(Debug, Clone)]
pub struct StreamStats {
    /// Number of samples processed so far
    pub total_samples: usize,
    /// Throughput in samples per second
    pub throughput: f64,
    /// Average latency in milliseconds
    pub avg_latency: f64,
    /// Fraction of the buffer currently in use
    pub buffer_utilization: f64,
    /// Model accuracy, when it is known
    pub accuracy: Option<f64>,
    /// Named data-drift measurements
    pub drift_metrics: HashMap<String, f64>,
    /// Observed error rate
    pub error_rate: f64,
    /// When processing began
    pub start_time: SystemTime,
    /// When the statistics were last refreshed
    pub last_update: SystemTime,
}

impl Default for StreamStats {
    /// All counters zeroed, no accuracy, no drift metrics, and both
    /// timestamps set to the same "now" instant.
    fn default() -> Self {
        let created_at = SystemTime::now();

        Self {
            start_time: created_at,
            last_update: created_at,
            total_samples: 0,
            throughput: 0.0,
            avg_latency: 0.0,
            buffer_utilization: 0.0,
            error_rate: 0.0,
            accuracy: None,
            drift_metrics: HashMap::new(),
        }
    }
}
305
/// Streaming pipeline processor
///
/// The type parameter `S` is the training state: it defaults to `Untrained`
/// and becomes [`StreamingPipelineTrained`] after fitting. The trained state
/// carries its own estimator, configuration, buffer, windows and statistics;
/// the methods on the trained pipeline operate on `state`, not on the
/// outer fields below.
pub struct StreamingPipeline<S = Untrained> {
    // Training-state marker or payload
    state: S,
    // Estimator to be fitted; `None` once it has been moved into the trained state
    base_estimator: Option<Box<dyn PipelinePredictor>>,
    // Stream configuration used before fitting
    config: StreamConfig,
    // Model-update policy used before fitting
    update_strategy: UpdateStrategy,
    // Points waiting to be windowed
    data_buffer: VecDeque<StreamDataPoint>,
    // Windows built so far
    windows: Vec<StreamWindow>,
    // Running statistics
    statistics: StreamStats,
}
316
/// Trained state for `StreamingPipeline`
///
/// Holds everything the fitted pipeline works with at runtime.
pub struct StreamingPipelineTrained {
    // Estimator after the initial batch fit
    fitted_estimator: Box<dyn PipelinePredictor>,
    // Stream configuration (windowing, backpressure, ...)
    config: StreamConfig,
    // Policy deciding when the estimator is refitted
    update_strategy: UpdateStrategy,
    // Points received but not yet moved into a window
    data_buffer: VecDeque<StreamDataPoint>,
    // Windows produced by the windowing strategy
    windows: Vec<StreamWindow>,
    // Running statistics for the stream
    statistics: StreamStats,
    // Numeric bookkeeping, e.g. "batch_training_samples" and "last_update_samples"
    model_state: HashMap<String, f64>,
    // Number of feature columns seen during the initial fit
    n_features_in: usize,
    // Optional feature names; the visible `fit` always stores `None`
    feature_names_in: Option<Vec<String>>,
}
329
330impl StreamingPipeline<Untrained> {
331    /// Create a new streaming pipeline
332    #[must_use]
333    pub fn new(base_estimator: Box<dyn PipelinePredictor>, config: StreamConfig) -> Self {
334        Self {
335            state: Untrained,
336            base_estimator: Some(base_estimator),
337            config,
338            update_strategy: UpdateStrategy::Batch { batch_size: 100 },
339            data_buffer: VecDeque::new(),
340            windows: Vec::new(),
341            statistics: StreamStats::default(),
342        }
343    }
344
345    /// Set update strategy
346    #[must_use]
347    pub fn update_strategy(mut self, strategy: UpdateStrategy) -> Self {
348        self.update_strategy = strategy;
349        self
350    }
351
352    /// Create a tumbling time window pipeline
353    #[must_use]
354    pub fn tumbling_time(
355        base_estimator: Box<dyn PipelinePredictor>,
356        window_duration: Duration,
357    ) -> Self {
358        let config = StreamConfig {
359            windowing: WindowingStrategy::TumblingTime {
360                duration: window_duration,
361            },
362            ..StreamConfig::default()
363        };
364        Self::new(base_estimator, config)
365    }
366
367    /// Create a sliding window pipeline
368    #[must_use]
369    pub fn sliding_window(
370        base_estimator: Box<dyn PipelinePredictor>,
371        window_size: usize,
372        slide_step: usize,
373    ) -> Self {
374        let config = StreamConfig {
375            windowing: WindowingStrategy::SlidingCount {
376                size: window_size,
377                step: slide_step,
378            },
379            ..StreamConfig::default()
380        };
381        Self::new(base_estimator, config)
382    }
383
384    /// Create a session window pipeline
385    #[must_use]
386    pub fn session_window(
387        base_estimator: Box<dyn PipelinePredictor>,
388        session_gap: Duration,
389    ) -> Self {
390        let config = StreamConfig {
391            windowing: WindowingStrategy::Session { gap: session_gap },
392            ..StreamConfig::default()
393        };
394        Self::new(base_estimator, config)
395    }
396}
397
/// `Estimator` implementation for the untrained pipeline.
///
/// The estimator-level configuration type is the unit type; the streaming
/// settings live in the pipeline's own `StreamConfig` field instead.
impl Estimator for StreamingPipeline<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the unit configuration.
    fn config(&self) -> &Self::Config {
        &()
    }
}
407
408impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for StreamingPipeline<Untrained> {
409    type Fitted = StreamingPipeline<StreamingPipelineTrained>;
410
411    fn fit(
412        self,
413        x: &ArrayView2<'_, Float>,
414        y: &Option<&ArrayView1<'_, Float>>,
415    ) -> SklResult<Self::Fitted> {
416        let mut base_estimator = self
417            .base_estimator
418            .ok_or_else(|| SklearsError::InvalidInput("No base estimator provided".to_string()))?;
419
420        // Initial training on batch data
421        if let Some(y_ref) = y {
422            base_estimator.fit(x, y_ref)?;
423        } else {
424            return Err(SklearsError::InvalidInput(
425                "No target values provided for initial training".to_string(),
426            ));
427        }
428
429        // Initialize streaming state
430        let mut model_state = HashMap::new();
431        model_state.insert("batch_training_samples".to_string(), x.nrows() as f64);
432
433        let mut statistics = self.statistics;
434        statistics.total_samples = x.nrows();
435        statistics.start_time = SystemTime::now();
436        statistics.last_update = SystemTime::now();
437
438        Ok(StreamingPipeline {
439            state: StreamingPipelineTrained {
440                fitted_estimator: base_estimator,
441                config: self.config,
442                update_strategy: self.update_strategy,
443                data_buffer: self.data_buffer,
444                windows: self.windows,
445                statistics,
446                model_state,
447                n_features_in: x.ncols(),
448                feature_names_in: None,
449            },
450            base_estimator: None,
451            config: StreamConfig::default(),
452            update_strategy: UpdateStrategy::Immediate,
453            data_buffer: VecDeque::new(),
454            windows: Vec::new(),
455            statistics: StreamStats::default(),
456        })
457    }
458}
459
460impl StreamingPipeline<StreamingPipelineTrained> {
461    /// Process a single data point from the stream
462    pub fn process_point(&mut self, point: StreamDataPoint) -> SklResult<Option<Array1<f64>>> {
463        let start_time = Instant::now();
464
465        // Check for backpressure
466        if self.state.data_buffer.len() >= self.state.config.backpressure_threshold {
467            return Err(SklearsError::InvalidInput(
468                "Backpressure threshold exceeded".to_string(),
469            ));
470        }
471
472        // Add to buffer
473        self.state.data_buffer.push_back(point.clone());
474
475        // Update statistics
476        self.state.statistics.total_samples += 1;
477        self.state.statistics.buffer_utilization =
478            self.state.data_buffer.len() as f64 / self.state.config.buffer_size as f64;
479
480        // Create prediction input
481        let features_2d =
482            Array2::from_shape_vec((1, point.features.len()), point.features.to_vec()).map_err(
483                |e| SklearsError::InvalidData {
484                    reason: format!("Feature reshaping failed: {e}"),
485                },
486            )?;
487
488        // Make prediction
489        let prediction = self.state.fitted_estimator.predict(&features_2d.view())?;
490
491        // Process windows
492        self.process_windows()?;
493
494        // Check for model updates
495        self.check_model_update()?;
496
497        // Update latency statistics
498        let processing_time = start_time.elapsed().as_millis() as f64;
499        self.state.statistics.avg_latency =
500            (self.state.statistics.avg_latency * 0.9) + (processing_time * 0.1);
501
502        // Update throughput
503        let elapsed = self
504            .state
505            .statistics
506            .start_time
507            .elapsed()
508            .unwrap_or(Duration::from_secs(1));
509        self.state.statistics.throughput =
510            self.state.statistics.total_samples as f64 / elapsed.as_secs_f64();
511
512        Ok(Some(prediction))
513    }
514
515    /// Process batch of data points
516    pub fn process_batch(&mut self, points: Vec<StreamDataPoint>) -> SklResult<Array2<f64>> {
517        let mut predictions = Vec::new();
518
519        for point in points {
520            if let Some(pred) = self.process_point(point)? {
521                predictions.extend(pred.iter().copied());
522            }
523        }
524
525        if predictions.is_empty() {
526            return Ok(Array2::zeros((0, 1)));
527        }
528
529        let n_predictions = predictions.len();
530        Array2::from_shape_vec((n_predictions, 1), predictions).map_err(|e| {
531            SklearsError::InvalidData {
532                reason: format!("Batch prediction reshape failed: {e}"),
533            }
534        })
535    }
536
537    /// Process windowing logic
538    fn process_windows(&mut self) -> SklResult<()> {
539        match &self.state.config.windowing {
540            WindowingStrategy::TumblingTime { duration } => {
541                self.process_tumbling_time_windows(*duration)
542            }
543            WindowingStrategy::SlidingTime { duration, slide } => {
544                self.process_sliding_time_windows(*duration, *slide)
545            }
546            WindowingStrategy::TumblingCount { count } => {
547                self.process_tumbling_count_windows(*count)
548            }
549            WindowingStrategy::SlidingCount { size, step } => {
550                self.process_sliding_count_windows(*size, *step)
551            }
552            WindowingStrategy::Session { gap } => self.process_session_windows(*gap),
553            WindowingStrategy::Custom { .. } => {
554                // Handle custom windowing differently to avoid borrow checker issues
555                self.process_custom_windows_safe()
556            }
557        }
558    }
559
560    /// Process tumbling time windows
561    fn process_tumbling_time_windows(&mut self, duration: Duration) -> SklResult<()> {
562        let now = SystemTime::now();
563
564        // Create new window if needed
565        if self.state.windows.is_empty() {
566            let window = StreamWindow::new(now, now + duration);
567            self.state.windows.push(window);
568        }
569
570        // Add points to current window
571        while let Some(point) = self.state.data_buffer.pop_front() {
572            if let Some(current_window) = self.state.windows.last_mut() {
573                if point.timestamp <= current_window.end_time {
574                    current_window.add_point(point);
575                } else {
576                    // Create new window
577                    let mut new_window = StreamWindow::new(
578                        current_window.end_time,
579                        current_window.end_time + duration,
580                    );
581                    new_window.add_point(point);
582                    self.state.windows.push(new_window);
583                }
584            }
585        }
586
587        // Remove completed windows (keep only current)
588        self.state.windows.retain(|w| w.end_time > now);
589
590        Ok(())
591    }
592
593    /// Process sliding time windows
594    fn process_sliding_time_windows(
595        &mut self,
596        duration: Duration,
597        slide: Duration,
598    ) -> SklResult<()> {
599        // Simplified implementation
600        self.process_tumbling_time_windows(duration)
601    }
602
603    /// Process tumbling count windows
604    fn process_tumbling_count_windows(&mut self, count: usize) -> SklResult<()> {
605        let now = SystemTime::now();
606
607        while self.state.data_buffer.len() >= count {
608            let mut window = StreamWindow::new(now, now);
609            for _ in 0..count {
610                if let Some(point) = self.state.data_buffer.pop_front() {
611                    window.add_point(point);
612                }
613            }
614            self.state.windows.push(window);
615        }
616
617        Ok(())
618    }
619
620    /// Process sliding count windows
621    fn process_sliding_count_windows(&mut self, size: usize, step: usize) -> SklResult<()> {
622        // Simplified implementation - just use tumbling for now
623        self.process_tumbling_count_windows(step)
624    }
625
626    /// Process session windows
627    fn process_session_windows(&mut self, gap: Duration) -> SklResult<()> {
628        // Simplified implementation
629        let now = SystemTime::now();
630
631        if let Some(mut current_window) = self.state.windows.pop() {
632            while let Some(point) = self.state.data_buffer.pop_front() {
633                let time_since_last = point
634                    .timestamp
635                    .duration_since(current_window.end_time)
636                    .unwrap_or(Duration::ZERO);
637
638                if time_since_last <= gap {
639                    current_window.add_point(point.clone());
640                    current_window.end_time = point.timestamp;
641                } else {
642                    // Start new session
643                    self.state.windows.push(current_window);
644                    current_window = StreamWindow::new(point.timestamp, point.timestamp);
645                    current_window.add_point(point);
646                }
647            }
648            self.state.windows.push(current_window);
649        } else if !self.state.data_buffer.is_empty() {
650            // Start first session
651            if let Some(point) = self.state.data_buffer.pop_front() {
652                let mut window = StreamWindow::new(point.timestamp, point.timestamp);
653                window.add_point(point);
654                self.state.windows.push(window);
655            }
656        }
657
658        Ok(())
659    }
660
661    /// Process custom windows safely (avoiding borrow checker issues)
662    fn process_custom_windows_safe(&mut self) -> SklResult<()> {
663        // Extract trigger function to avoid borrowing issues
664        if let WindowingStrategy::Custom { trigger_fn } = &self.state.config.windowing {
665            let buffer_vec: Vec<StreamDataPoint> = self.state.data_buffer.iter().cloned().collect();
666
667            if trigger_fn(&buffer_vec) {
668                let now = SystemTime::now();
669                let mut window = StreamWindow::new(now, now);
670
671                while let Some(point) = self.state.data_buffer.pop_front() {
672                    window.add_point(point);
673                }
674
675                if !window.is_empty() {
676                    self.state.windows.push(window);
677                }
678            }
679        }
680
681        Ok(())
682    }
683
684    /// Process custom windows
685    fn process_custom_windows(
686        &mut self,
687        trigger_fn: &Box<dyn Fn(&[StreamDataPoint]) -> bool + Send + Sync>,
688    ) -> SklResult<()> {
689        let buffer_vec: Vec<StreamDataPoint> = self.state.data_buffer.iter().cloned().collect();
690
691        if trigger_fn(&buffer_vec) {
692            let now = SystemTime::now();
693            let mut window = StreamWindow::new(now, now);
694
695            while let Some(point) = self.state.data_buffer.pop_front() {
696                window.add_point(point);
697            }
698
699            if !window.is_empty() {
700                self.state.windows.push(window);
701            }
702        }
703
704        Ok(())
705    }
706
707    /// Check if model should be updated
708    fn check_model_update(&mut self) -> SklResult<()> {
709        let should_update = match &self.state.update_strategy {
710            UpdateStrategy::Immediate => !self.state.data_buffer.is_empty(),
711            UpdateStrategy::Batch { batch_size } => self.state.data_buffer.len() >= *batch_size,
712            UpdateStrategy::TimeBased { interval } => {
713                self.state
714                    .statistics
715                    .last_update
716                    .elapsed()
717                    .unwrap_or(Duration::ZERO)
718                    >= *interval
719            }
720            UpdateStrategy::Adaptive {
721                drift_threshold,
722                min_interval,
723                max_interval,
724            } => self.check_adaptive_update(*drift_threshold, *min_interval, *max_interval),
725            UpdateStrategy::Custom { trigger_fn } => {
726                if let Some(window) = self.state.windows.last() {
727                    trigger_fn(window, &self.state.statistics)
728                } else {
729                    false
730                }
731            }
732        };
733
734        if should_update {
735            self.update_model()?;
736        }
737
738        Ok(())
739    }
740
741    /// Check if adaptive update should be triggered
742    fn check_adaptive_update(
743        &self,
744        drift_threshold: f64,
745        min_interval: Duration,
746        max_interval: Duration,
747    ) -> bool {
748        let elapsed = self
749            .state
750            .statistics
751            .last_update
752            .elapsed()
753            .unwrap_or(Duration::ZERO);
754
755        if elapsed < min_interval {
756            return false;
757        }
758
759        if elapsed >= max_interval {
760            return true;
761        }
762
763        // Check for drift (simplified)
764        let drift_score = self
765            .state
766            .statistics
767            .drift_metrics
768            .get("feature_drift")
769            .unwrap_or(&0.0);
770        *drift_score > drift_threshold
771    }
772
773    /// Update the model with recent data
774    fn update_model(&mut self) -> SklResult<()> {
775        if let Some(window) = self.state.windows.last() {
776            if !window.is_empty() {
777                let features = window.features_matrix()?;
778                let targets = window.targets_array();
779
780                if let Some(targets_array) = targets {
781                    // Incremental learning (simplified)
782                    self.state
783                        .fitted_estimator
784                        .fit(&features.view(), &targets_array.view())?;
785
786                    self.state.statistics.last_update = SystemTime::now();
787                    self.state
788                        .model_state
789                        .insert("last_update_samples".to_string(), window.size() as f64);
790                }
791            }
792        }
793
794        Ok(())
795    }
796
797    /// Get current statistics
798    #[must_use]
799    pub fn statistics(&self) -> &StreamStats {
800        &self.state.statistics
801    }
802
803    /// Get current buffer size
804    #[must_use]
805    pub fn buffer_size(&self) -> usize {
806        self.state.data_buffer.len()
807    }
808
809    /// Get number of active windows
810    #[must_use]
811    pub fn active_windows(&self) -> usize {
812        self.state.windows.len()
813    }
814
815    /// Checkpoint the current state
816    pub fn checkpoint(&self) -> SklResult<HashMap<String, String>> {
817        let mut checkpoint = HashMap::new();
818        checkpoint.insert(
819            "total_samples".to_string(),
820            self.state.statistics.total_samples.to_string(),
821        );
822        checkpoint.insert(
823            "buffer_size".to_string(),
824            self.state.data_buffer.len().to_string(),
825        );
826        checkpoint.insert(
827            "active_windows".to_string(),
828            self.state.windows.len().to_string(),
829        );
830        checkpoint.insert(
831            "throughput".to_string(),
832            self.state.statistics.throughput.to_string(),
833        );
834
835        Ok(checkpoint)
836    }
837
838    /// Clear internal buffers and windows
839    pub fn clear_buffers(&mut self) {
840        self.state.data_buffer.clear();
841        self.state.windows.clear();
842    }
843
844    /// Get drift detection metrics
845    #[must_use]
846    pub fn drift_metrics(&self) -> &HashMap<String, f64> {
847        &self.state.statistics.drift_metrics
848    }
849
850    /// Detect concept drift (simplified implementation)
851    pub fn detect_drift(
852        &mut self,
853        reference_window: &StreamWindow,
854        current_window: &StreamWindow,
855    ) -> SklResult<f64> {
856        if reference_window.is_empty() || current_window.is_empty() {
857            return Ok(0.0);
858        }
859
860        let ref_features = reference_window.features_matrix()?;
861        let cur_features = current_window.features_matrix()?;
862
863        // Simple drift detection using mean difference
864        let ref_mean = ref_features.mean_axis(Axis(0)).unwrap();
865        let cur_mean = cur_features.mean_axis(Axis(0)).unwrap();
866
867        let drift_score = (&ref_mean - &cur_mean).mapv(|x| x * x).sum().sqrt();
868
869        // Update drift metrics
870        self.state
871            .statistics
872            .drift_metrics
873            .insert("feature_drift".to_string(), drift_score);
874
875        Ok(drift_score)
876    }
877}
878
// Unit tests for the streaming building blocks and the pipeline's
// construction, fitting and single-point processing paths.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockPredictor;
    use scirs2_core::ndarray::array;

    /// The builder keeps the id and features and attaches the target.
    #[test]
    fn test_stream_data_point() {
        let features = array![1.0, 2.0, 3.0];
        let point =
            StreamDataPoint::new(features.clone(), "test_point".to_string()).with_target(1.0);

        assert_eq!(point.id, "test_point");
        assert_eq!(point.features, features);
        assert_eq!(point.target, Some(1.0));
    }

    /// Two 2-feature points yield a window of size 2 and a 2x2 feature matrix.
    #[test]
    fn test_stream_window() {
        let start_time = SystemTime::now();
        let end_time = start_time + Duration::from_secs(60);
        let mut window = StreamWindow::new(start_time, end_time);

        let point1 = StreamDataPoint::new(array![1.0, 2.0], "point1".to_string());
        let point2 = StreamDataPoint::new(array![3.0, 4.0], "point2".to_string());

        window.add_point(point1);
        window.add_point(point2);

        assert_eq!(window.size(), 2);

        let features = window.features_matrix().unwrap();
        assert_eq!(features.nrows(), 2);
        assert_eq!(features.ncols(), 2);
    }

    /// The `tumbling_time` shorthand configures tumbling time windowing.
    #[test]
    fn test_streaming_pipeline_creation() {
        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = StreamingPipeline::tumbling_time(base_estimator, Duration::from_secs(60));

        assert!(matches!(
            pipeline.config.windowing,
            WindowingStrategy::TumblingTime { .. }
        ));
    }

    /// Fitting records the feature count and counts the batch rows as samples.
    #[test]
    fn test_streaming_pipeline_fit() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 0.0];

        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = StreamingPipeline::tumbling_time(base_estimator, Duration::from_secs(60));

        let fitted_pipeline = pipeline.fit(&x.view(), &Some(&y.view())).unwrap();
        assert_eq!(fitted_pipeline.state.n_features_in, 2);
        assert_eq!(fitted_pipeline.state.statistics.total_samples, 2);
    }

    /// Processing one point returns a prediction and leaves one open window.
    #[test]
    fn test_point_processing() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 0.0];

        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = StreamingPipeline::tumbling_time(base_estimator, Duration::from_secs(60));

        let mut fitted_pipeline = pipeline.fit(&x.view(), &Some(&y.view())).unwrap();

        let point = StreamDataPoint::new(array![5.0, 6.0], "test_point".to_string());
        let prediction = fitted_pipeline.process_point(point).unwrap();

        assert!(prediction.is_some());
        assert_eq!(fitted_pipeline.active_windows(), 1);
    }

    /// A custom `StreamConfig` keeps the windowing strategy it was given.
    #[test]
    fn test_window_strategies() {
        let base_estimator = Box::new(MockPredictor::new());

        // Test tumbling count windows
        let pipeline = StreamingPipeline::new(
            base_estimator,
            StreamConfig {
                windowing: WindowingStrategy::TumblingCount { count: 2 },
                ..StreamConfig::default()
            },
        );

        assert!(matches!(
            pipeline.config.windowing,
            WindowingStrategy::TumblingCount { count: 2 }
        ));
    }

    /// The `update_strategy` builder replaces the default update strategy.
    #[test]
    fn test_update_strategies() {
        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = StreamingPipeline::tumbling_time(base_estimator, Duration::from_secs(60))
            .update_strategy(UpdateStrategy::Batch { batch_size: 10 });

        assert!(matches!(
            pipeline.update_strategy,
            UpdateStrategy::Batch { batch_size: 10 }
        ));
    }
}