use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};
use std::time::Instant;

use scirs2_core::ndarray::{s, Array2, ArrayView2};

#[cfg(feature = "parallel")]
use rayon::prelude::*;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use sklears_core::{
    error::{Result, SklearsError},
    traits::Transform,
};

use crate::streaming::{StreamingConfig, StreamingStats, StreamingTransformer};

/// A single cached transformation result with metadata for TTL and eviction.
#[derive(Clone, Debug)]
struct CacheEntry<T> {
    result: T,
    timestamp: Instant,
    access_count: usize,
}

/// Configuration for the transformation result cache.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CacheConfig {
    /// Maximum number of entries kept before eviction.
    pub max_entries: usize,
    /// Time-to-live for entries, in seconds.
    pub ttl_seconds: u64,
    /// Whether caching is enabled at all.
    pub enabled: bool,
}

impl Default for CacheConfig {
    fn default() -> Self {
        Self {
            max_entries: 100,
            ttl_seconds: 3600,
            enabled: true,
        }
    }
}

/// A thread-safe, TTL-aware cache for transformation results, keyed by a hash
/// of the input.
pub struct TransformationCache<T> {
    cache: Arc<RwLock<HashMap<u64, CacheEntry<T>>>>,
    config: CacheConfig,
}

impl<T: Clone> TransformationCache<T> {
    pub fn new(config: CacheConfig) -> Self {
        Self {
            cache: Arc::new(RwLock::new(HashMap::new())),
            config,
        }
    }

    fn generate_key<U: Hash>(&self, input: U) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        input.hash(&mut hasher);
        hasher.finish()
    }
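
    // Note: `DefaultHasher`'s algorithm is not guaranteed to be stable across
    // Rust releases, so generated keys are only meaningful within a single
    // process run and should not be persisted.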

    pub fn get(&self, key: u64) -> Option<T> {
        if !self.config.enabled {
            return None;
        }

        let mut cache = self.cache.write().ok()?;

        if let Some(entry) = cache.get_mut(&key) {
            let age = entry.timestamp.elapsed();
            if age.as_secs() <= self.config.ttl_seconds {
                entry.access_count += 1;
                return Some(entry.result.clone());
            } else {
                // The entry has outlived its TTL; evict it.
                cache.remove(&key);
            }
        }

        None
    }

    pub fn put(&self, key: u64, value: T) {
        if !self.config.enabled {
            return;
        }

        let mut cache = self.cache.write().unwrap();

        if cache.len() >= self.config.max_entries {
            // At capacity: make room before inserting.
            self.evict_lru(&mut cache);
        }

        cache.insert(
            key,
            CacheEntry {
                result: value,
                timestamp: Instant::now(),
                access_count: 1,
            },
        );
    }

    fn evict_lru(&self, cache: &mut HashMap<u64, CacheEntry<T>>) {
        // Approximate LRU: evict the entry with the fewest recorded accesses.
        if let Some((key_to_remove, _)) = cache.iter().min_by_key(|(_, entry)| entry.access_count)
        {
            let key_to_remove = *key_to_remove;
            cache.remove(&key_to_remove);
        }
    }

    pub fn clear(&self) {
        if let Ok(mut cache) = self.cache.write() {
            cache.clear();
        }
    }

    pub fn stats(&self) -> CacheStats {
        let cache = self.cache.read().unwrap();
        CacheStats {
            entries: cache.len(),
            max_entries: self.config.max_entries,
            enabled: self.config.enabled,
        }
    }
}

/// A snapshot of cache occupancy and configuration.
#[derive(Debug, Clone)]
pub struct CacheStats {
    pub entries: usize,
    pub max_entries: usize,
    pub enabled: bool,
}
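
// Illustrative get-or-compute usage within this module (a sketch;
// `expensive_transform` is a hypothetical closure standing in for any costly
// computation):
//
//     let cache: TransformationCache<Array2<f64>> =
//         TransformationCache::new(CacheConfig::default());
//     let key = cache.generate_key("step-0");
//     let result = match cache.get(key) {
//         Some(hit) => hit,
//         None => {
//             let computed = expensive_transform();
//             cache.put(key, computed.clone());
//             computed
//         }
//     };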

/// Boxed predicate that decides whether a conditional step should run on the
/// given input.
pub type ConditionFn = Box<dyn Fn(&ArrayView2<f64>) -> bool + Send + Sync>;

/// Configuration for a pipeline step that only runs when a condition holds.
pub struct ConditionalStepConfig<T> {
    pub transformer: T,
    pub condition: ConditionFn,
    pub name: String,
    /// When the condition is false: pass data through unchanged (`true`) or
    /// fail with an error (`false`).
    pub skip_on_false: bool,
}
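
// Example condition (a sketch): run the wrapped transformer only when the
// input has at least two columns.
//
//     let condition: ConditionFn = Box::new(|data: &ArrayView2<f64>| data.ncols() >= 2);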

impl<T: std::fmt::Debug> std::fmt::Debug for ConditionalStepConfig<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConditionalStepConfig")
            .field("transformer", &self.transformer)
            .field("condition", &"<function>")
            .field("name", &self.name)
            .field("skip_on_false", &self.skip_on_false)
            .finish()
    }
}

pub struct ConditionalStep<T> {
    config: ConditionalStepConfig<T>,
    fitted: bool,
}

impl<T> ConditionalStep<T>
where
    T: Transform<Array2<f64>, Array2<f64>> + Clone,
{
    pub fn new(config: ConditionalStepConfig<T>) -> Self {
        Self {
            config,
            fitted: false,
        }
    }

    pub fn check_condition(&self, data: &ArrayView2<f64>) -> bool {
        (self.config.condition)(data)
    }
}

impl<T> Transform<Array2<f64>, Array2<f64>> for ConditionalStep<T>
where
    T: Transform<Array2<f64>, Array2<f64>> + Clone,
{
    fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
        let data_view = data.view();

        if self.check_condition(&data_view) {
            self.config.transformer.transform(data)
        } else if self.config.skip_on_false {
            // Condition failed but the step is optional: pass data through.
            Ok(data.clone())
        } else {
            Err(SklearsError::InvalidInput(format!(
                "Condition not met for step: {}",
                self.config.name
            )))
        }
    }
}

#[derive(Debug)]
pub struct ParallelBranchConfig<T> {
    pub transformers: Vec<T>,
    pub branch_names: Vec<String>,
    pub combination_strategy: BranchCombinationStrategy,
}

/// How the outputs of parallel branches are merged into a single array.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum BranchCombinationStrategy {
    /// Concatenate branch outputs column-wise (row counts must match).
    Concatenate,
    /// Element-wise mean of branch outputs (shapes must match).
    Average,
    /// Take the first branch's output.
    FirstSuccess,
    /// Element-wise weighted sum, one weight per branch.
    WeightedCombination(Vec<f64>),
}
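
// Note on `WeightedCombination`: branch outputs are combined element-wise as
// `combined = Σ_i w_i · result_i`. Weights are applied as given (they are not
// normalized), so pass weights summing to 1.0 if a weighted average is intended.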

pub struct ParallelBranches<T> {
    config: ParallelBranchConfig<T>,
    fitted: bool,
}

impl<T> ParallelBranches<T>
where
    T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
{
    pub fn new(config: ParallelBranchConfig<T>) -> Result<Self> {
        if config.transformers.len() != config.branch_names.len() {
            return Err(SklearsError::InvalidInput(
                "Number of transformers must match number of branch names".to_string(),
            ));
        }

        if let BranchCombinationStrategy::WeightedCombination(ref weights) =
            config.combination_strategy
        {
            if weights.len() != config.transformers.len() {
                return Err(SklearsError::InvalidInput(
                    "Number of weights must match number of transformers".to_string(),
                ));
            }
        }

        Ok(Self {
            config,
            fitted: false,
        })
    }
}

impl<T> Transform<Array2<f64>, Array2<f64>> for ParallelBranches<T>
where
    T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
{
    fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
        #[cfg(feature = "parallel")]
        let results: Result<Vec<Array2<f64>>> = self
            .config
            .transformers
            .par_iter()
            .zip(self.config.branch_names.par_iter())
            .map(|(transformer, name)| {
                transformer.transform(data).map_err(|e| {
                    SklearsError::TransformError(format!("Error in branch '{}': {}", name, e))
                })
            })
            .collect();

        #[cfg(not(feature = "parallel"))]
        let results: Result<Vec<Array2<f64>>> = self
            .config
            .transformers
            .iter()
            .zip(self.config.branch_names.iter())
            .map(|(transformer, name)| {
                transformer.transform(data).map_err(|e| {
                    SklearsError::TransformError(format!("Error in branch '{}': {}", name, e))
                })
            })
            .collect();

        let branch_results = results?;

        match &self.config.combination_strategy {
            BranchCombinationStrategy::Concatenate => self.concatenate_results(branch_results),
            BranchCombinationStrategy::Average => self.average_results(branch_results),
            BranchCombinationStrategy::FirstSuccess => {
                // All branches have already run and succeeded at this point;
                // return the first branch's output without unwrapping blindly.
                branch_results.into_iter().next().ok_or_else(|| {
                    SklearsError::InvalidInput("No branch results available".to_string())
                })
            }
            BranchCombinationStrategy::WeightedCombination(weights) => {
                self.weighted_combination(branch_results, weights)
            }
        }
    }
}
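
// With the `parallel` feature enabled, branches run on rayon's global thread
// pool via `par_iter`; otherwise they run sequentially. Both paths produce the
// same outputs in the same branch order.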

impl<T> ParallelBranches<T> {
    fn concatenate_results(&self, results: Vec<Array2<f64>>) -> Result<Array2<f64>> {
        if results.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No results to concatenate".to_string(),
            ));
        }

        let n_rows = results[0].nrows();
        if !results.iter().all(|r| r.nrows() == n_rows) {
            return Err(SklearsError::InvalidInput(
                "All results must have the same number of rows for concatenation".to_string(),
            ));
        }

        let total_cols: usize = results.iter().map(|r| r.ncols()).sum();
        let mut combined = Array2::zeros((n_rows, total_cols));

        let mut col_offset = 0;
        for result in results {
            let n_cols = result.ncols();
            combined
                .slice_mut(s![.., col_offset..col_offset + n_cols])
                .assign(&result);
            col_offset += n_cols;
        }

        Ok(combined)
    }

    fn average_results(&self, results: Vec<Array2<f64>>) -> Result<Array2<f64>> {
        if results.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No results to average".to_string(),
            ));
        }

        let shape = results[0].raw_dim();
        if !results.iter().all(|r| r.raw_dim() == shape) {
            return Err(SklearsError::InvalidInput(
                "All results must have the same shape for averaging".to_string(),
            ));
        }

        let mut sum = Array2::zeros(shape);
        for result in &results {
            sum += result;
        }
        sum /= results.len() as f64;

        Ok(sum)
    }

    fn weighted_combination(
        &self,
        results: Vec<Array2<f64>>,
        weights: &[f64],
    ) -> Result<Array2<f64>> {
        if results.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No results to combine".to_string(),
            ));
        }

        let shape = results[0].raw_dim();
        if !results.iter().all(|r| r.raw_dim() == shape) {
            return Err(SklearsError::InvalidInput(
                "All results must have the same shape for weighted combination".to_string(),
            ));
        }

        let mut combined = Array2::zeros(shape);
        for (result, &weight) in results.iter().zip(weights.iter()) {
            combined += &(result * weight);
        }

        Ok(combined)
    }
}

/// Adapts a `StreamingTransformer` to the batch `Transform` interface so it
/// can participate in a pipeline.
pub struct StreamingTransformerWrapper {
    transformer: Box<dyn StreamingTransformer + Send + Sync>,
    name: String,
    fitted: bool,
}

impl StreamingTransformerWrapper {
    pub fn new<S>(transformer: S, name: String) -> Self
    where
        S: StreamingTransformer + Send + Sync + 'static,
    {
        Self {
            transformer: Box::new(transformer),
            name,
            fitted: false,
        }
    }

    pub fn partial_fit(&mut self, data: &Array2<f64>) -> Result<()> {
        self.transformer.partial_fit(data).map_err(|e| {
            SklearsError::InvalidInput(format!("Streaming transformer error: {}", e))
        })?;
        self.fitted = true;
        Ok(())
    }

    pub fn is_fitted(&self) -> bool {
        self.fitted && self.transformer.is_fitted()
    }

    pub fn get_streaming_stats(&self) -> Option<StreamingStats> {
        Some(self.transformer.get_stats())
    }

    pub fn reset(&mut self) {
        self.transformer.reset();
        self.fitted = false;
    }

    pub fn name(&self) -> &str {
        &self.name
    }
}

impl Transform<Array2<f64>, Array2<f64>> for StreamingTransformerWrapper {
    fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
        if !self.is_fitted() {
            return Err(SklearsError::NotFitted {
                operation: format!("transform on streaming transformer '{}'", self.name),
            });
        }
        self.transformer
            .transform(data)
            .map_err(|e| SklearsError::InvalidInput(e.to_string()))
    }
}

impl Clone for StreamingTransformerWrapper {
    fn clone(&self) -> Self {
        // `dyn StreamingTransformer` cannot be cloned; fall back to a fresh,
        // unfitted default scaler. The clone must be re-fitted before use.
        Self {
            transformer: Box::new(crate::streaming::StreamingStandardScaler::new(
                StreamingConfig::default(),
            )),
            name: self.name.clone(),
            fitted: false,
        }
    }
}

/// A pipeline supporting simple, conditional, parallel, cached, and streaming
/// steps.
pub struct AdvancedPipeline<T> {
    steps: Vec<PipelineStep<T>>,
    cache: TransformationCache<Array2<f64>>,
    config: AdvancedPipelineConfig,
}

pub enum PipelineStep<T> {
    Simple(T),
    Conditional(ConditionalStep<T>),
    Parallel(ParallelBranches<T>),
    Cached(T, String),
    Streaming(StreamingTransformerWrapper),
}

/// Top-level configuration for `AdvancedPipeline` and `DynamicPipeline`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AdvancedPipelineConfig {
    pub cache_config: CacheConfig,
    pub parallel_execution: bool,
    pub error_strategy: ErrorHandlingStrategy,
}

/// How the pipeline reacts when a step returns an error.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ErrorHandlingStrategy {
    /// Propagate the error immediately.
    StopOnError,
    /// Log a warning and continue with the unmodified data.
    SkipOnError,
    /// Log a warning and fall back to passing the data through unchanged.
    Fallback,
}

impl Default for AdvancedPipelineConfig {
    fn default() -> Self {
        Self {
            cache_config: CacheConfig::default(),
            parallel_execution: true,
            error_strategy: ErrorHandlingStrategy::StopOnError,
        }
    }
}
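
// Note: as implemented in `transform` below, `SkipOnError` and `Fallback`
// currently behave identically (log a warning and pass the data through
// unchanged); only the warning text differs.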

impl<T> AdvancedPipeline<T>
where
    T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
{
    pub fn new(config: AdvancedPipelineConfig) -> Self {
        Self {
            steps: Vec::new(),
            cache: TransformationCache::new(config.cache_config.clone()),
            config,
        }
    }

    pub fn add_step(mut self, transformer: T) -> Self {
        self.steps.push(PipelineStep::Simple(transformer));
        self
    }

    pub fn add_conditional_step(mut self, config: ConditionalStepConfig<T>) -> Self {
        self.steps
            .push(PipelineStep::Conditional(ConditionalStep::new(config)));
        self
    }

    pub fn add_parallel_branches(mut self, config: ParallelBranchConfig<T>) -> Result<Self> {
        let branches = ParallelBranches::new(config)?;
        self.steps.push(PipelineStep::Parallel(branches));
        Ok(self)
    }

    pub fn add_cached_step(mut self, transformer: T, cache_key_prefix: String) -> Self {
        self.steps
            .push(PipelineStep::Cached(transformer, cache_key_prefix));
        self
    }

    pub fn add_streaming_step<S>(mut self, transformer: S, name: String) -> Self
    where
        S: StreamingTransformer + Send + Sync + 'static,
    {
        let wrapper = StreamingTransformerWrapper::new(transformer, name);
        self.steps.push(PipelineStep::Streaming(wrapper));
        self
    }

    pub fn add_pca_step(self, _pca: crate::dimensionality_reduction::PCA) -> Self {
        // Currently a no-op placeholder: the PCA step is accepted but not yet
        // wired into the generic `PipelineStep` enum.
        self
    }

    pub fn cache_stats(&self) -> CacheStats {
        self.cache.stats()
    }

    pub fn clear_cache(&self) {
        self.cache.clear();
    }

    pub fn partial_fit(&mut self, data: &Array2<f64>) -> Result<()> {
        let mut current_data = data.clone();

        for step in &mut self.steps {
            match step {
                PipelineStep::Streaming(ref mut streaming_wrapper) => {
                    streaming_wrapper.partial_fit(&current_data)?;
                    if streaming_wrapper.is_fitted() {
                        current_data = streaming_wrapper.transform(&current_data)?;
                    }
                }
                PipelineStep::Simple(transformer) => {
                    // Non-streaming steps cannot be fitted incrementally;
                    // apply them best-effort so later steps see transformed data.
                    if let Ok(transformed) = transformer.transform(&current_data) {
                        current_data = transformed;
                    }
                }
                PipelineStep::Conditional(conditional) => {
                    if let Ok(transformed) = conditional.transform(&current_data) {
                        current_data = transformed;
                    }
                }
                PipelineStep::Parallel(parallel) => {
                    if let Ok(transformed) = parallel.transform(&current_data) {
                        current_data = transformed;
                    }
                }
                PipelineStep::Cached(transformer, _) => {
                    if let Ok(transformed) = transformer.transform(&current_data) {
                        current_data = transformed;
                    }
                }
            }
        }

        Ok(())
    }

    pub fn get_streaming_stats(&self) -> Vec<(String, Option<StreamingStats>)> {
        let mut stats = Vec::new();

        for step in &self.steps {
            if let PipelineStep::Streaming(streaming_wrapper) = step {
                stats.push((
                    streaming_wrapper.name().to_string(),
                    streaming_wrapper.get_streaming_stats(),
                ));
            }
        }

        stats
    }

    pub fn reset_streaming(&mut self) {
        for step in &mut self.steps {
            if let PipelineStep::Streaming(ref mut streaming_wrapper) = step {
                streaming_wrapper.reset();
            }
        }
    }
}

impl<T> Transform<Array2<f64>, Array2<f64>> for AdvancedPipeline<T>
where
    T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
{
    fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
        let mut current_data = data.clone();
        for (step_idx, step) in self.steps.iter().enumerate() {
            let step_result = match step {
                PipelineStep::Simple(transformer) => transformer.transform(&current_data),
                PipelineStep::Conditional(conditional) => conditional.transform(&current_data),
                PipelineStep::Parallel(parallel) => parallel.transform(&current_data),
                PipelineStep::Cached(transformer, _cache_key_prefix) => {
                    // Result caching is not applied here yet; the step runs directly.
                    transformer.transform(&current_data)
                }
                PipelineStep::Streaming(streaming_wrapper) => {
                    streaming_wrapper.transform(&current_data)
                }
            };

            match step_result {
                Ok(result) => {
                    current_data = result;
                }
                Err(e) => match self.config.error_strategy {
                    ErrorHandlingStrategy::StopOnError => return Err(e),
                    ErrorHandlingStrategy::SkipOnError => {
                        eprintln!("Warning: Step {} failed: {}. Skipping...", step_idx, e);
                    }
                    ErrorHandlingStrategy::Fallback => {
                        eprintln!(
                            "Warning: Step {} failed: {}. Using fallback (passthrough)...",
                            step_idx, e
                        );
                    }
                },
            }
        }

        Ok(current_data)
    }
}

pub struct AdvancedPipelineBuilder<T> {
    config: AdvancedPipelineConfig,
    pipeline: AdvancedPipeline<T>,
}

impl<T> AdvancedPipelineBuilder<T>
where
    T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
{
    pub fn new() -> Self {
        let config = AdvancedPipelineConfig::default();
        let pipeline = AdvancedPipeline::new(config.clone());
        Self { config, pipeline }
    }

    pub fn with_cache_config(mut self, cache_config: CacheConfig) -> Self {
        self.config.cache_config = cache_config;
        self.pipeline.cache = TransformationCache::new(self.config.cache_config.clone());
        self
    }

    pub fn with_error_strategy(mut self, strategy: ErrorHandlingStrategy) -> Self {
        self.config.error_strategy = strategy;
        self.pipeline.config.error_strategy = strategy;
        self
    }

    pub fn add_step(mut self, transformer: T) -> Self {
        self.pipeline = self.pipeline.add_step(transformer);
        self
    }

    pub fn add_conditional_step(mut self, config: ConditionalStepConfig<T>) -> Self {
        self.pipeline = self.pipeline.add_conditional_step(config);
        self
    }

    pub fn add_parallel_branches(mut self, config: ParallelBranchConfig<T>) -> Result<Self> {
        self.pipeline = self.pipeline.add_parallel_branches(config)?;
        Ok(self)
    }

    pub fn add_cached_step(mut self, transformer: T, cache_key_prefix: String) -> Self {
        self.pipeline = self.pipeline.add_cached_step(transformer, cache_key_prefix);
        self
    }

    pub fn add_streaming_step<S>(mut self, transformer: S, name: String) -> Self
    where
        S: StreamingTransformer + Send + Sync + 'static,
    {
        self.pipeline = self.pipeline.add_streaming_step(transformer, name);
        self
    }

    pub fn build(self) -> AdvancedPipeline<T> {
        self.pipeline
    }
}

impl<T> Default for AdvancedPipelineBuilder<T>
where
    T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
{
    fn default() -> Self {
        Self::new()
    }
}
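
// Illustrative builder usage (a sketch; `MyScaler` is a hypothetical type
// implementing `Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync`):
//
//     let pipeline = AdvancedPipelineBuilder::new()
//         .with_error_strategy(ErrorHandlingStrategy::SkipOnError)
//         .add_step(MyScaler::default())
//         .build();
//     let output = pipeline.transform(&input)?;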

/// A pipeline whose step list can be mutated at runtime from multiple threads.
pub struct DynamicPipeline<T> {
    steps: Arc<RwLock<Vec<PipelineStep<T>>>>,
    cache: TransformationCache<Array2<f64>>,
    config: AdvancedPipelineConfig,
}

impl<T> DynamicPipeline<T>
where
    T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
{
    pub fn new(config: AdvancedPipelineConfig) -> Self {
        Self {
            steps: Arc::new(RwLock::new(Vec::new())),
            cache: TransformationCache::new(config.cache_config.clone()),
            config,
        }
    }

    pub fn add_step_runtime(&self, transformer: T) -> Result<()> {
        let mut steps = self
            .steps
            .write()
            .map_err(|_| SklearsError::InvalidInput("Failed to acquire write lock".to_string()))?;
        steps.push(PipelineStep::Simple(transformer));
        Ok(())
    }

    pub fn add_streaming_step_runtime<S>(&self, transformer: S, name: String) -> Result<()>
    where
        S: StreamingTransformer + Send + Sync + 'static,
    {
        let mut steps = self
            .steps
            .write()
            .map_err(|_| SklearsError::InvalidInput("Failed to acquire write lock".to_string()))?;
        let wrapper = StreamingTransformerWrapper::new(transformer, name);
        steps.push(PipelineStep::Streaming(wrapper));
        Ok(())
    }

    pub fn remove_step(&self, index: usize) -> Result<()> {
        let mut steps = self
            .steps
            .write()
            .map_err(|_| SklearsError::InvalidInput("Failed to acquire write lock".to_string()))?;

        if index >= steps.len() {
            return Err(SklearsError::InvalidInput(
                "Step index out of bounds".to_string(),
            ));
        }

        steps.remove(index);
        Ok(())
    }

    pub fn len(&self) -> usize {
        self.steps.read().map(|s| s.len()).unwrap_or(0)
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    pub fn partial_fit(&self, data: &Array2<f64>) -> Result<()> {
        let mut current_data = data.clone();
        let mut steps = self
            .steps
            .write()
            .map_err(|_| SklearsError::InvalidInput("Failed to acquire write lock".to_string()))?;

        for step in steps.iter_mut() {
            match step {
                PipelineStep::Streaming(ref mut streaming_wrapper) => {
                    streaming_wrapper.partial_fit(&current_data)?;
                    if streaming_wrapper.is_fitted() {
                        current_data = streaming_wrapper.transform(&current_data)?;
                    }
                }
                PipelineStep::Simple(transformer) => {
                    if let Ok(transformed) = transformer.transform(&current_data) {
                        current_data = transformed;
                    }
                }
                PipelineStep::Conditional(conditional) => {
                    if let Ok(transformed) = conditional.transform(&current_data) {
                        current_data = transformed;
                    }
                }
                PipelineStep::Parallel(parallel) => {
                    if let Ok(transformed) = parallel.transform(&current_data) {
                        current_data = transformed;
                    }
                }
                PipelineStep::Cached(transformer, _) => {
                    if let Ok(transformed) = transformer.transform(&current_data) {
                        current_data = transformed;
                    }
                }
            }
        }

        Ok(())
    }

    pub fn get_streaming_stats(&self) -> Result<Vec<(String, Option<StreamingStats>)>> {
        let mut stats = Vec::new();
        let steps = self
            .steps
            .read()
            .map_err(|_| SklearsError::InvalidInput("Failed to acquire read lock".to_string()))?;

        for step in steps.iter() {
            if let PipelineStep::Streaming(streaming_wrapper) = step {
                stats.push((
                    streaming_wrapper.name().to_string(),
                    streaming_wrapper.get_streaming_stats(),
                ));
            }
        }

        Ok(stats)
    }

    pub fn reset_streaming(&self) -> Result<()> {
        let mut steps = self
            .steps
            .write()
            .map_err(|_| SklearsError::InvalidInput("Failed to acquire write lock".to_string()))?;

        for step in steps.iter_mut() {
            if let PipelineStep::Streaming(ref mut streaming_wrapper) = step {
                streaming_wrapper.reset();
            }
        }

        Ok(())
    }
}
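
// Runtime-mutation sketch (`MyScaler` again being a hypothetical transformer):
// because the step list lives behind `Arc<RwLock<..>>`, steps can be added or
// removed after construction, even while the pipeline is shared across threads.
//
//     let pipeline = DynamicPipeline::new(AdvancedPipelineConfig::default());
//     pipeline.add_step_runtime(MyScaler::default())?;
//     assert_eq!(pipeline.len(), 1);
//     pipeline.remove_step(0)?;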

impl<T> Transform<Array2<f64>, Array2<f64>> for DynamicPipeline<T>
where
    T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
{
    fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
        let mut current_data = data.clone();
        let steps = self
            .steps
            .read()
            .map_err(|_| SklearsError::InvalidInput("Failed to acquire read lock".to_string()))?;

        for (step_idx, step) in steps.iter().enumerate() {
            let step_result = match step {
                PipelineStep::Simple(transformer) => transformer.transform(&current_data),
                PipelineStep::Conditional(conditional) => conditional.transform(&current_data),
                PipelineStep::Parallel(parallel) => parallel.transform(&current_data),
                PipelineStep::Cached(transformer, _cache_key_prefix) => {
                    transformer.transform(&current_data)
                }
                PipelineStep::Streaming(streaming_wrapper) => {
                    streaming_wrapper.transform(&current_data)
                }
            };

            match step_result {
                Ok(result) => {
                    current_data = result;
                }
                Err(e) => match self.config.error_strategy {
                    ErrorHandlingStrategy::StopOnError => return Err(e),
                    ErrorHandlingStrategy::SkipOnError => {
                        eprintln!("Warning: Step {} failed: {}. Skipping...", step_idx, e);
                    }
                    ErrorHandlingStrategy::Fallback => {
                        eprintln!(
                            "Warning: Step {} failed: {}. Using fallback (passthrough)...",
                            step_idx, e
                        );
                    }
                },
            }
        }

        Ok(current_data)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    #[test]
    fn test_transformation_cache() {
        let config = CacheConfig {
            max_entries: 2,
            ttl_seconds: 1,
            enabled: true,
        };

        let cache = TransformationCache::new(config);
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        let key = cache.generate_key("test_key");
        assert!(cache.get(key).is_none());

        cache.put(key, data.clone());
        assert!(cache.get(key).is_some());

        let stats = cache.stats();
        assert_eq!(stats.entries, 1);
        assert!(stats.enabled);
    }

    #[test]
    fn test_streaming_transformer_wrapper() {
        use crate::streaming::{StreamingConfig, StreamingStandardScaler};
        use scirs2_core::ndarray::Array2;

        let scaler = StreamingStandardScaler::new(StreamingConfig::default());
        let mut wrapper = StreamingTransformerWrapper::new(scaler, "test_scaler".to_string());

        assert!(!wrapper.is_fitted());

        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        wrapper.partial_fit(&data).unwrap();

        assert!(wrapper.is_fitted());

        let result = wrapper.transform(&data).unwrap();
        assert_eq!(result.dim(), data.dim());

        let stats = wrapper.get_streaming_stats();
        assert!(stats.is_some());

        assert_eq!(wrapper.name(), "test_scaler");

        wrapper.reset();
        assert!(!wrapper.is_fitted());
    }
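
    // An extra, illustrative test (a sketch added for documentation purposes):
    // a minimal identity transformer, defined only inside this test module,
    // exercises `ParallelBranches` with the `Average` combination strategy.
    #[derive(Clone)]
    struct IdentityTransformer;

    impl Transform<Array2<f64>, Array2<f64>> for IdentityTransformer {
        fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
            Ok(data.clone())
        }
    }

    #[test]
    fn test_parallel_branches_average_identity() {
        let config = ParallelBranchConfig {
            transformers: vec![IdentityTransformer, IdentityTransformer],
            branch_names: vec!["a".to_string(), "b".to_string()],
            combination_strategy: BranchCombinationStrategy::Average,
        };

        let branches = ParallelBranches::new(config).unwrap();
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        // Averaging two identical branch outputs reproduces the input exactly.
        let result = branches.transform(&data).unwrap();
        assert_eq!(result, data);
    }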
}