1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
7use sklears_core::{
8 error::Result as SklResult,
9 prelude::{Predict, SklearsError},
10 traits::{Estimator, Fit, Untrained},
11 types::Float,
12};
13use std::collections::{HashMap, VecDeque};
14use std::time::{Duration, Instant, SystemTime};
15
16use crate::{PipelinePredictor, PipelineStep};
17
/// A single observation flowing through the streaming pipeline.
#[derive(Debug, Clone)]
pub struct StreamDataPoint {
    /// Feature vector for this observation.
    pub features: Array1<f64>,
    /// Optional supervised target (present only for labeled points).
    pub target: Option<f64>,
    /// Arrival / event time of the point.
    pub timestamp: SystemTime,
    /// Free-form key/value annotations attached by the producer.
    pub metadata: HashMap<String, String>,
    /// Caller-supplied identifier for this point.
    pub id: String,
}
32
33impl StreamDataPoint {
34 #[must_use]
36 pub fn new(features: Array1<f64>, id: String) -> Self {
37 Self {
38 features,
39 target: None,
40 timestamp: SystemTime::now(),
41 metadata: HashMap::new(),
42 id,
43 }
44 }
45
46 #[must_use]
48 pub fn with_target(mut self, target: f64) -> Self {
49 self.target = Some(target);
50 self
51 }
52
53 #[must_use]
55 pub fn with_timestamp(mut self, timestamp: SystemTime) -> Self {
56 self.timestamp = timestamp;
57 self
58 }
59
60 #[must_use]
62 pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
63 self.metadata = metadata;
64 self
65 }
66}
67
/// A time- or count-bounded batch of stream points.
#[derive(Debug, Clone)]
pub struct StreamWindow {
    /// Points collected into this window.
    pub data_points: Vec<StreamDataPoint>,
    /// Inclusive start of the window's time span.
    pub start_time: SystemTime,
    /// End of the window's time span (mutated as sessions extend).
    pub end_time: SystemTime,
    /// Free-form annotations for the window.
    pub metadata: HashMap<String, String>,
}
80
81impl StreamWindow {
82 #[must_use]
84 pub fn new(start_time: SystemTime, end_time: SystemTime) -> Self {
85 Self {
86 data_points: Vec::new(),
87 start_time,
88 end_time,
89 metadata: HashMap::new(),
90 }
91 }
92
93 pub fn add_point(&mut self, point: StreamDataPoint) {
95 self.data_points.push(point);
96 }
97
98 pub fn features_matrix(&self) -> SklResult<Array2<f64>> {
100 if self.data_points.is_empty() {
101 return Err(SklearsError::InvalidInput("Empty window".to_string()));
102 }
103
104 let n_samples = self.data_points.len();
105 let n_features = self.data_points[0].features.len();
106
107 let mut features = Array2::zeros((n_samples, n_features));
108 for (i, point) in self.data_points.iter().enumerate() {
109 features.row_mut(i).assign(&point.features);
110 }
111
112 Ok(features)
113 }
114
115 #[must_use]
117 pub fn targets_array(&self) -> Option<Array1<f64>> {
118 if self.data_points.iter().all(|p| p.target.is_some()) {
119 Some(Array1::from_vec(
120 self.data_points.iter().map(|p| p.target.unwrap()).collect(),
121 ))
122 } else {
123 None
124 }
125 }
126
127 #[must_use]
129 pub fn size(&self) -> usize {
130 self.data_points.len()
131 }
132
133 #[must_use]
135 pub fn is_empty(&self) -> bool {
136 self.data_points.is_empty()
137 }
138}
139
/// Strategy used to slice the incoming stream into windows.
///
/// No `Debug`/`Clone` derive: the `Custom` variant holds a boxed closure.
pub enum WindowingStrategy {
    /// Fixed-length, non-overlapping time windows.
    TumblingTime {
        /// Length of each window.
        duration: Duration,
    },
    /// Overlapping time windows advanced by `slide`.
    SlidingTime {
        /// Length of each window.
        duration: Duration,
        /// Offset between consecutive window starts.
        slide: Duration,
    },
    /// Fixed-size, non-overlapping count windows.
    TumblingCount {
        /// Number of points per window.
        count: usize,
    },
    /// Overlapping count windows advanced by `step`.
    SlidingCount {
        /// Number of points per window.
        size: usize,
        /// Number of points between consecutive window starts.
        step: usize,
    },
    /// Sessions closed after `gap` of inactivity.
    Session {
        /// Maximum idle time before a session is closed.
        gap: Duration,
    },
    /// User-supplied trigger deciding when the buffered points form a window.
    Custom {
        /// Returns `true` when the buffered points should be emitted as a window.
        trigger_fn: Box<dyn Fn(&[StreamDataPoint]) -> bool + Send + Sync>,
    },
}
177
/// Runtime configuration of a [`StreamingPipeline`].
pub struct StreamConfig {
    /// How the stream is partitioned into windows.
    pub windowing: WindowingStrategy,
    /// Nominal capacity used to compute buffer utilization.
    pub buffer_size: usize,
    /// Degree of parallelism (currently informational).
    pub parallelism: usize,
    /// Buffer length above which `process_point` rejects new input.
    pub backpressure_threshold: usize,
    /// Target per-point processing latency (currently informational).
    pub latency_target: Duration,
    /// How often checkpoints should be taken (currently informational).
    pub checkpoint_interval: Duration,
    /// Where pipeline state is persisted.
    pub state_management: StateManagement,
}
195
impl Default for StreamConfig {
    /// Sensible defaults: 60 s tumbling windows, 10 000-point buffer with
    /// backpressure at 8 000, single-threaded, 100 ms latency target,
    /// checkpoints every 5 minutes, in-memory state.
    fn default() -> Self {
        Self {
            windowing: WindowingStrategy::TumblingTime {
                duration: Duration::from_secs(60),
            },
            buffer_size: 10000,
            parallelism: 1,
            backpressure_threshold: 8000,
            latency_target: Duration::from_millis(100),
            checkpoint_interval: Duration::from_secs(300),
            state_management: StateManagement::InMemory,
        }
    }
}
211
/// Where and how pipeline state is persisted.
#[derive(Debug, Clone)]
pub enum StateManagement {
    /// Keep all state in process memory (lost on crash).
    InMemory,
    /// Periodic snapshots to a directory.
    Snapshots {
        /// Directory snapshots are written to.
        directory: String,
        /// Time between snapshots.
        interval: Duration,
    },
    /// Durable write-ahead logging.
    WriteAheadLog {
        /// Path of the log file.
        log_path: String,
    },
    /// Externally-managed state store.
    External {
        /// Backend-specific connection settings.
        config: HashMap<String, String>,
    },
}
235
/// Policy deciding when the model is refit on recent data.
///
/// No `Debug`/`Clone` derive: the `Custom` variant holds a boxed closure.
pub enum UpdateStrategy {
    /// Refit whenever the buffer is non-empty.
    Immediate,
    /// Refit once `batch_size` points have accumulated.
    Batch {
        /// Points required before a refit.
        batch_size: usize,
    },
    /// Refit at a fixed time interval.
    TimeBased {
        /// Minimum time between refits.
        interval: Duration,
    },
    /// Refit when drift is detected, bounded by min/max intervals.
    Adaptive {
        /// Feature-drift score above which a refit is triggered.
        drift_threshold: f64,
        /// Never refit more often than this.
        min_interval: Duration,
        /// Always refit at least this often.
        max_interval: Duration,
    },
    /// User-supplied trigger evaluated on the latest window and statistics.
    Custom {
        /// Returns `true` when a refit should occur.
        trigger_fn: Box<dyn Fn(&StreamWindow, &StreamStats) -> bool + Send + Sync>,
    },
}
265
/// Running statistics maintained while the pipeline processes a stream.
#[derive(Debug, Clone)]
pub struct StreamStats {
    /// Total points seen (initial fit + streaming).
    pub total_samples: usize,
    /// Points per second since `start_time`.
    pub throughput: f64,
    /// Exponentially-weighted average per-point latency, in milliseconds.
    pub avg_latency: f64,
    /// Buffer length divided by configured `buffer_size`.
    pub buffer_utilization: f64,
    /// Optional online accuracy estimate (not computed by this module).
    pub accuracy: Option<f64>,
    /// Named drift scores, e.g. `"feature_drift"` set by `detect_drift`.
    pub drift_metrics: HashMap<String, f64>,
    /// Fraction of failed operations (not computed by this module).
    pub error_rate: f64,
    /// When streaming began.
    pub start_time: SystemTime,
    /// When the model was last refit.
    pub last_update: SystemTime,
}
288
impl Default for StreamStats {
    /// Zeroed statistics with both timestamps set to the same instant.
    fn default() -> Self {
        let now = SystemTime::now();
        Self {
            total_samples: 0,
            throughput: 0.0,
            avg_latency: 0.0,
            buffer_utilization: 0.0,
            accuracy: None,
            drift_metrics: HashMap::new(),
            error_rate: 0.0,
            start_time: now,
            last_update: now,
        }
    }
}
305
/// Online learning pipeline parameterized by training state (typestate:
/// `Untrained` before `fit`, `StreamingPipelineTrained` after).
pub struct StreamingPipeline<S = Untrained> {
    /// Typestate marker / trained payload.
    state: S,
    /// Estimator to be fitted; consumed by `fit`, `None` afterwards.
    base_estimator: Option<Box<dyn PipelinePredictor>>,
    config: StreamConfig,
    update_strategy: UpdateStrategy,
    /// Incoming points awaiting window assignment.
    data_buffer: VecDeque<StreamDataPoint>,
    /// Currently retained windows (newest last).
    windows: Vec<StreamWindow>,
    statistics: StreamStats,
}
316
/// Trained-state payload: holds the fitted estimator plus all live
/// streaming state (the outer `StreamingPipeline` fields become placeholders
/// after `fit`).
pub struct StreamingPipelineTrained {
    fitted_estimator: Box<dyn PipelinePredictor>,
    config: StreamConfig,
    update_strategy: UpdateStrategy,
    /// Incoming points awaiting window assignment.
    data_buffer: VecDeque<StreamDataPoint>,
    /// Currently retained windows (newest last).
    windows: Vec<StreamWindow>,
    statistics: StreamStats,
    /// Scalar bookkeeping about training, e.g. sample counts.
    model_state: HashMap<String, f64>,
    /// Number of features seen during the initial fit.
    n_features_in: usize,
    /// Optional feature names (not populated by `fit`).
    feature_names_in: Option<Vec<String>>,
}
329
330impl StreamingPipeline<Untrained> {
331 #[must_use]
333 pub fn new(base_estimator: Box<dyn PipelinePredictor>, config: StreamConfig) -> Self {
334 Self {
335 state: Untrained,
336 base_estimator: Some(base_estimator),
337 config,
338 update_strategy: UpdateStrategy::Batch { batch_size: 100 },
339 data_buffer: VecDeque::new(),
340 windows: Vec::new(),
341 statistics: StreamStats::default(),
342 }
343 }
344
345 #[must_use]
347 pub fn update_strategy(mut self, strategy: UpdateStrategy) -> Self {
348 self.update_strategy = strategy;
349 self
350 }
351
352 #[must_use]
354 pub fn tumbling_time(
355 base_estimator: Box<dyn PipelinePredictor>,
356 window_duration: Duration,
357 ) -> Self {
358 let config = StreamConfig {
359 windowing: WindowingStrategy::TumblingTime {
360 duration: window_duration,
361 },
362 ..StreamConfig::default()
363 };
364 Self::new(base_estimator, config)
365 }
366
367 #[must_use]
369 pub fn sliding_window(
370 base_estimator: Box<dyn PipelinePredictor>,
371 window_size: usize,
372 slide_step: usize,
373 ) -> Self {
374 let config = StreamConfig {
375 windowing: WindowingStrategy::SlidingCount {
376 size: window_size,
377 step: slide_step,
378 },
379 ..StreamConfig::default()
380 };
381 Self::new(base_estimator, config)
382 }
383
384 #[must_use]
386 pub fn session_window(
387 base_estimator: Box<dyn PipelinePredictor>,
388 session_gap: Duration,
389 ) -> Self {
390 let config = StreamConfig {
391 windowing: WindowingStrategy::Session { gap: session_gap },
392 ..StreamConfig::default()
393 };
394 Self::new(base_estimator, config)
395 }
396}
397
impl Estimator for StreamingPipeline<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// This estimator carries no trait-level configuration, so a reference to
    /// the unit value is returned.
    fn config(&self) -> &Self::Config {
        &()
    }
}
407
408impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for StreamingPipeline<Untrained> {
409 type Fitted = StreamingPipeline<StreamingPipelineTrained>;
410
411 fn fit(
412 self,
413 x: &ArrayView2<'_, Float>,
414 y: &Option<&ArrayView1<'_, Float>>,
415 ) -> SklResult<Self::Fitted> {
416 let mut base_estimator = self
417 .base_estimator
418 .ok_or_else(|| SklearsError::InvalidInput("No base estimator provided".to_string()))?;
419
420 if let Some(y_ref) = y {
422 base_estimator.fit(x, y_ref)?;
423 } else {
424 return Err(SklearsError::InvalidInput(
425 "No target values provided for initial training".to_string(),
426 ));
427 }
428
429 let mut model_state = HashMap::new();
431 model_state.insert("batch_training_samples".to_string(), x.nrows() as f64);
432
433 let mut statistics = self.statistics;
434 statistics.total_samples = x.nrows();
435 statistics.start_time = SystemTime::now();
436 statistics.last_update = SystemTime::now();
437
438 Ok(StreamingPipeline {
439 state: StreamingPipelineTrained {
440 fitted_estimator: base_estimator,
441 config: self.config,
442 update_strategy: self.update_strategy,
443 data_buffer: self.data_buffer,
444 windows: self.windows,
445 statistics,
446 model_state,
447 n_features_in: x.ncols(),
448 feature_names_in: None,
449 },
450 base_estimator: None,
451 config: StreamConfig::default(),
452 update_strategy: UpdateStrategy::Immediate,
453 data_buffer: VecDeque::new(),
454 windows: Vec::new(),
455 statistics: StreamStats::default(),
456 })
457 }
458}
459
460impl StreamingPipeline<StreamingPipelineTrained> {
461 pub fn process_point(&mut self, point: StreamDataPoint) -> SklResult<Option<Array1<f64>>> {
463 let start_time = Instant::now();
464
465 if self.state.data_buffer.len() >= self.state.config.backpressure_threshold {
467 return Err(SklearsError::InvalidInput(
468 "Backpressure threshold exceeded".to_string(),
469 ));
470 }
471
472 self.state.data_buffer.push_back(point.clone());
474
475 self.state.statistics.total_samples += 1;
477 self.state.statistics.buffer_utilization =
478 self.state.data_buffer.len() as f64 / self.state.config.buffer_size as f64;
479
480 let features_2d =
482 Array2::from_shape_vec((1, point.features.len()), point.features.to_vec()).map_err(
483 |e| SklearsError::InvalidData {
484 reason: format!("Feature reshaping failed: {e}"),
485 },
486 )?;
487
488 let prediction = self.state.fitted_estimator.predict(&features_2d.view())?;
490
491 self.process_windows()?;
493
494 self.check_model_update()?;
496
497 let processing_time = start_time.elapsed().as_millis() as f64;
499 self.state.statistics.avg_latency =
500 (self.state.statistics.avg_latency * 0.9) + (processing_time * 0.1);
501
502 let elapsed = self
504 .state
505 .statistics
506 .start_time
507 .elapsed()
508 .unwrap_or(Duration::from_secs(1));
509 self.state.statistics.throughput =
510 self.state.statistics.total_samples as f64 / elapsed.as_secs_f64();
511
512 Ok(Some(prediction))
513 }
514
515 pub fn process_batch(&mut self, points: Vec<StreamDataPoint>) -> SklResult<Array2<f64>> {
517 let mut predictions = Vec::new();
518
519 for point in points {
520 if let Some(pred) = self.process_point(point)? {
521 predictions.extend(pred.iter().copied());
522 }
523 }
524
525 if predictions.is_empty() {
526 return Ok(Array2::zeros((0, 1)));
527 }
528
529 let n_predictions = predictions.len();
530 Array2::from_shape_vec((n_predictions, 1), predictions).map_err(|e| {
531 SklearsError::InvalidData {
532 reason: format!("Batch prediction reshape failed: {e}"),
533 }
534 })
535 }
536
    /// Dispatches buffered points to the window builder matching the
    /// configured windowing strategy.
    fn process_windows(&mut self) -> SklResult<()> {
        match &self.state.config.windowing {
            WindowingStrategy::TumblingTime { duration } => {
                self.process_tumbling_time_windows(*duration)
            }
            WindowingStrategy::SlidingTime { duration, slide } => {
                self.process_sliding_time_windows(*duration, *slide)
            }
            WindowingStrategy::TumblingCount { count } => {
                self.process_tumbling_count_windows(*count)
            }
            WindowingStrategy::SlidingCount { size, step } => {
                self.process_sliding_count_windows(*size, *step)
            }
            WindowingStrategy::Session { gap } => self.process_session_windows(*gap),
            WindowingStrategy::Custom { .. } => {
                // The closure is deliberately not bound here; the safe variant
                // re-matches the strategy internally so no borrow of the
                // config is held across the `&mut self` call.
                self.process_custom_windows_safe()
            }
        }
    }
559
    /// Drains the buffer into consecutive time-based tumbling windows and
    /// discards windows whose end time has already passed.
    fn process_tumbling_time_windows(&mut self, duration: Duration) -> SklResult<()> {
        let now = SystemTime::now();

        // Bootstrap the first window at "now".
        if self.state.windows.is_empty() {
            let window = StreamWindow::new(now, now + duration);
            self.state.windows.push(window);
        }

        while let Some(point) = self.state.data_buffer.pop_front() {
            if let Some(current_window) = self.state.windows.last_mut() {
                if point.timestamp <= current_window.end_time {
                    current_window.add_point(point);
                } else {
                    // NOTE(review): only one successor window is opened, so a
                    // point later than end_time + duration still lands in it
                    // — confirm this coarse assignment is intended.
                    let mut new_window = StreamWindow::new(
                        current_window.end_time,
                        current_window.end_time + duration,
                    );
                    new_window.add_point(point);
                    self.state.windows.push(new_window);
                }
            }
        }

        // Retire windows that have already closed.
        self.state.windows.retain(|w| w.end_time > now);

        Ok(())
    }
592
593 fn process_sliding_time_windows(
595 &mut self,
596 duration: Duration,
597 slide: Duration,
598 ) -> SklResult<()> {
599 self.process_tumbling_time_windows(duration)
601 }
602
603 fn process_tumbling_count_windows(&mut self, count: usize) -> SklResult<()> {
605 let now = SystemTime::now();
606
607 while self.state.data_buffer.len() >= count {
608 let mut window = StreamWindow::new(now, now);
609 for _ in 0..count {
610 if let Some(point) = self.state.data_buffer.pop_front() {
611 window.add_point(point);
612 }
613 }
614 self.state.windows.push(window);
615 }
616
617 Ok(())
618 }
619
620 fn process_sliding_count_windows(&mut self, size: usize, step: usize) -> SklResult<()> {
622 self.process_tumbling_count_windows(step)
624 }
625
626 fn process_session_windows(&mut self, gap: Duration) -> SklResult<()> {
628 let now = SystemTime::now();
630
631 if let Some(mut current_window) = self.state.windows.pop() {
632 while let Some(point) = self.state.data_buffer.pop_front() {
633 let time_since_last = point
634 .timestamp
635 .duration_since(current_window.end_time)
636 .unwrap_or(Duration::ZERO);
637
638 if time_since_last <= gap {
639 current_window.add_point(point.clone());
640 current_window.end_time = point.timestamp;
641 } else {
642 self.state.windows.push(current_window);
644 current_window = StreamWindow::new(point.timestamp, point.timestamp);
645 current_window.add_point(point);
646 }
647 }
648 self.state.windows.push(current_window);
649 } else if !self.state.data_buffer.is_empty() {
650 if let Some(point) = self.state.data_buffer.pop_front() {
652 let mut window = StreamWindow::new(point.timestamp, point.timestamp);
653 window.add_point(point);
654 self.state.windows.push(window);
655 }
656 }
657
658 Ok(())
659 }
660
661 fn process_custom_windows_safe(&mut self) -> SklResult<()> {
663 if let WindowingStrategy::Custom { trigger_fn } = &self.state.config.windowing {
665 let buffer_vec: Vec<StreamDataPoint> = self.state.data_buffer.iter().cloned().collect();
666
667 if trigger_fn(&buffer_vec) {
668 let now = SystemTime::now();
669 let mut window = StreamWindow::new(now, now);
670
671 while let Some(point) = self.state.data_buffer.pop_front() {
672 window.add_point(point);
673 }
674
675 if !window.is_empty() {
676 self.state.windows.push(window);
677 }
678 }
679 }
680
681 Ok(())
682 }
683
684 fn process_custom_windows(
686 &mut self,
687 trigger_fn: &Box<dyn Fn(&[StreamDataPoint]) -> bool + Send + Sync>,
688 ) -> SklResult<()> {
689 let buffer_vec: Vec<StreamDataPoint> = self.state.data_buffer.iter().cloned().collect();
690
691 if trigger_fn(&buffer_vec) {
692 let now = SystemTime::now();
693 let mut window = StreamWindow::new(now, now);
694
695 while let Some(point) = self.state.data_buffer.pop_front() {
696 window.add_point(point);
697 }
698
699 if !window.is_empty() {
700 self.state.windows.push(window);
701 }
702 }
703
704 Ok(())
705 }
706
707 fn check_model_update(&mut self) -> SklResult<()> {
709 let should_update = match &self.state.update_strategy {
710 UpdateStrategy::Immediate => !self.state.data_buffer.is_empty(),
711 UpdateStrategy::Batch { batch_size } => self.state.data_buffer.len() >= *batch_size,
712 UpdateStrategy::TimeBased { interval } => {
713 self.state
714 .statistics
715 .last_update
716 .elapsed()
717 .unwrap_or(Duration::ZERO)
718 >= *interval
719 }
720 UpdateStrategy::Adaptive {
721 drift_threshold,
722 min_interval,
723 max_interval,
724 } => self.check_adaptive_update(*drift_threshold, *min_interval, *max_interval),
725 UpdateStrategy::Custom { trigger_fn } => {
726 if let Some(window) = self.state.windows.last() {
727 trigger_fn(window, &self.state.statistics)
728 } else {
729 false
730 }
731 }
732 };
733
734 if should_update {
735 self.update_model()?;
736 }
737
738 Ok(())
739 }
740
741 fn check_adaptive_update(
743 &self,
744 drift_threshold: f64,
745 min_interval: Duration,
746 max_interval: Duration,
747 ) -> bool {
748 let elapsed = self
749 .state
750 .statistics
751 .last_update
752 .elapsed()
753 .unwrap_or(Duration::ZERO);
754
755 if elapsed < min_interval {
756 return false;
757 }
758
759 if elapsed >= max_interval {
760 return true;
761 }
762
763 let drift_score = self
765 .state
766 .statistics
767 .drift_metrics
768 .get("feature_drift")
769 .unwrap_or(&0.0);
770 *drift_score > drift_threshold
771 }
772
773 fn update_model(&mut self) -> SklResult<()> {
775 if let Some(window) = self.state.windows.last() {
776 if !window.is_empty() {
777 let features = window.features_matrix()?;
778 let targets = window.targets_array();
779
780 if let Some(targets_array) = targets {
781 self.state
783 .fitted_estimator
784 .fit(&features.view(), &targets_array.view())?;
785
786 self.state.statistics.last_update = SystemTime::now();
787 self.state
788 .model_state
789 .insert("last_update_samples".to_string(), window.size() as f64);
790 }
791 }
792 }
793
794 Ok(())
795 }
796
    /// Read-only view of the running stream statistics.
    #[must_use]
    pub fn statistics(&self) -> &StreamStats {
        &self.state.statistics
    }

    /// Number of points currently held in the ingest buffer.
    #[must_use]
    pub fn buffer_size(&self) -> usize {
        self.state.data_buffer.len()
    }

    /// Number of windows currently retained.
    #[must_use]
    pub fn active_windows(&self) -> usize {
        self.state.windows.len()
    }
814
815 pub fn checkpoint(&self) -> SklResult<HashMap<String, String>> {
817 let mut checkpoint = HashMap::new();
818 checkpoint.insert(
819 "total_samples".to_string(),
820 self.state.statistics.total_samples.to_string(),
821 );
822 checkpoint.insert(
823 "buffer_size".to_string(),
824 self.state.data_buffer.len().to_string(),
825 );
826 checkpoint.insert(
827 "active_windows".to_string(),
828 self.state.windows.len().to_string(),
829 );
830 checkpoint.insert(
831 "throughput".to_string(),
832 self.state.statistics.throughput.to_string(),
833 );
834
835 Ok(checkpoint)
836 }
837
    /// Discards all buffered points and retained windows; statistics are left
    /// intact.
    pub fn clear_buffers(&mut self) {
        self.state.data_buffer.clear();
        self.state.windows.clear();
    }

    /// Read-only view of recorded drift metrics (e.g. `"feature_drift"`).
    #[must_use]
    pub fn drift_metrics(&self) -> &HashMap<String, f64> {
        &self.state.statistics.drift_metrics
    }
849
850 pub fn detect_drift(
852 &mut self,
853 reference_window: &StreamWindow,
854 current_window: &StreamWindow,
855 ) -> SklResult<f64> {
856 if reference_window.is_empty() || current_window.is_empty() {
857 return Ok(0.0);
858 }
859
860 let ref_features = reference_window.features_matrix()?;
861 let cur_features = current_window.features_matrix()?;
862
863 let ref_mean = ref_features.mean_axis(Axis(0)).unwrap();
865 let cur_mean = cur_features.mean_axis(Axis(0)).unwrap();
866
867 let drift_score = (&ref_mean - &cur_mean).mapv(|x| x * x).sum().sqrt();
868
869 self.state
871 .statistics
872 .drift_metrics
873 .insert("feature_drift".to_string(), drift_score);
874
875 Ok(drift_score)
876 }
877}
878
// NOTE(review): every identifier below is already snake_case — confirm
// whether this `allow(non_snake_case)` attribute is still required.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockPredictor;
    use scirs2_core::ndarray::array;

    /// Builder chain should populate id, features, and target.
    #[test]
    fn test_stream_data_point() {
        let features = array![1.0, 2.0, 3.0];
        let point =
            StreamDataPoint::new(features.clone(), "test_point".to_string()).with_target(1.0);

        assert_eq!(point.id, "test_point");
        assert_eq!(point.features, features);
        assert_eq!(point.target, Some(1.0));
    }

    /// Two 2-feature points should stack into a 2x2 feature matrix.
    #[test]
    fn test_stream_window() {
        let start_time = SystemTime::now();
        let end_time = start_time + Duration::from_secs(60);
        let mut window = StreamWindow::new(start_time, end_time);

        let point1 = StreamDataPoint::new(array![1.0, 2.0], "point1".to_string());
        let point2 = StreamDataPoint::new(array![3.0, 4.0], "point2".to_string());

        window.add_point(point1);
        window.add_point(point2);

        assert_eq!(window.size(), 2);

        let features = window.features_matrix().unwrap();
        assert_eq!(features.nrows(), 2);
        assert_eq!(features.ncols(), 2);
    }

    /// The tumbling_time constructor should select the TumblingTime strategy.
    #[test]
    fn test_streaming_pipeline_creation() {
        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = StreamingPipeline::tumbling_time(base_estimator, Duration::from_secs(60));

        assert!(matches!(
            pipeline.config.windowing,
            WindowingStrategy::TumblingTime { .. }
        ));
    }

    /// Fitting should record input width and sample count in the trained state.
    #[test]
    fn test_streaming_pipeline_fit() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 0.0];

        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = StreamingPipeline::tumbling_time(base_estimator, Duration::from_secs(60));

        let fitted_pipeline = pipeline.fit(&x.view(), &Some(&y.view())).unwrap();
        assert_eq!(fitted_pipeline.state.n_features_in, 2);
        assert_eq!(fitted_pipeline.state.statistics.total_samples, 2);
    }

    /// A processed point should yield a prediction and open one window.
    #[test]
    fn test_point_processing() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 0.0];

        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = StreamingPipeline::tumbling_time(base_estimator, Duration::from_secs(60));

        let mut fitted_pipeline = pipeline.fit(&x.view(), &Some(&y.view())).unwrap();

        let point = StreamDataPoint::new(array![5.0, 6.0], "test_point".to_string());
        let prediction = fitted_pipeline.process_point(point).unwrap();

        assert!(prediction.is_some());
        assert_eq!(fitted_pipeline.active_windows(), 1);
    }

    /// Custom StreamConfig windowing should be preserved by `new`.
    #[test]
    fn test_window_strategies() {
        let base_estimator = Box::new(MockPredictor::new());

        let pipeline = StreamingPipeline::new(
            base_estimator,
            StreamConfig {
                windowing: WindowingStrategy::TumblingCount { count: 2 },
                ..StreamConfig::default()
            },
        );

        assert!(matches!(
            pipeline.config.windowing,
            WindowingStrategy::TumblingCount { count: 2 }
        ));
    }

    /// The update_strategy builder should override the default batch size.
    #[test]
    fn test_update_strategies() {
        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = StreamingPipeline::tumbling_time(base_estimator, Duration::from_secs(60))
            .update_strategy(UpdateStrategy::Batch { batch_size: 10 });

        assert!(matches!(
            pipeline.update_strategy,
            UpdateStrategy::Batch { batch_size: 10 }
        ));
    }
}