use crate::UtilsError;
use scirs2_core::ndarray::{s, Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};

/// A single, named processing step in a data pipeline.
pub trait PipelineStep: Send + Sync {
    type Input;
    type Output;
    type Error;

    fn process(&self, input: Self::Input) -> Result<Self::Output, Self::Error>;
    fn name(&self) -> &str;
    fn description(&self) -> Option<&str> {
        None
    }
}
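
// A minimal sketch of implementing `PipelineStep` by hand; the `Negate` step
// below is illustrative only and not part of this module's API:
//
//     struct Negate;
//
//     impl PipelineStep for Negate {
//         type Input = f64;
//         type Output = f64;
//         type Error = UtilsError;
//
//         fn process(&self, input: f64) -> Result<f64, UtilsError> {
//             Ok(-input)
//         }
//
//         fn name(&self) -> &str {
//             "negate"
//         }
//     }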

/// Boxed transformation closure used by [`TransformStep`].
pub type TransformFn<T, U> = Box<dyn Fn(T) -> Result<U, UtilsError> + Send + Sync>;

/// A [`PipelineStep`] backed by a boxed closure.
pub struct TransformStep<T, U> {
    name: String,
    description: Option<String>,
    transform_fn: TransformFn<T, U>,
}

impl<T, U> TransformStep<T, U> {
    pub fn new(name: String, transform_fn: TransformFn<T, U>) -> Self {
        Self {
            name,
            description: None,
            transform_fn,
        }
    }

    pub fn with_description(mut self, description: String) -> Self {
        self.description = Some(description);
        self
    }
}

impl<T, U> PipelineStep for TransformStep<T, U>
where
    T: Send + Sync,
    U: Send + Sync,
{
    type Input = T;
    type Output = U;
    type Error = UtilsError;

    fn process(&self, input: T) -> Result<U, UtilsError> {
        (self.transform_fn)(input)
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn description(&self) -> Option<&str> {
        self.description.as_deref()
    }
}
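
// A minimal usage sketch for `TransformStep` (names are illustrative):
//
//     let double = TransformStep::new("double".to_string(), Box::new(|x: f64| Ok(x * 2.0)))
//         .with_description("Doubles the input".to_string());
//     assert_eq!(double.process(3.0).unwrap(), 6.0);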

/// Shared execution context carrying metadata, timing, and a byte cache.
#[derive(Debug, Clone)]
pub struct PipelineContext {
    pub metadata: HashMap<String, String>,
    pub start_time: Instant,
    cache: Arc<RwLock<HashMap<String, Vec<u8>>>>,
}

impl Default for PipelineContext {
    fn default() -> Self {
        Self {
            metadata: HashMap::new(),
            start_time: Instant::now(),
            cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }
}

impl PipelineContext {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }

    pub fn elapsed(&self) -> Duration {
        self.start_time.elapsed()
    }

    pub fn cache_get(&self, key: &str) -> Option<Vec<u8>> {
        self.cache.read().ok()?.get(key).cloned()
    }

    pub fn cache_set(&self, key: String, value: Vec<u8>) {
        if let Ok(mut cache) = self.cache.write() {
            cache.insert(key, value);
        }
    }

    pub fn cache_clear(&self) {
        if let Ok(mut cache) = self.cache.write() {
            cache.clear();
        }
    }
}
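
// Note: `Clone` on `PipelineContext` clones the inner `Arc`, so every clone
// shares the same underlying cache. A quick sketch of the cache round-trip:
//
//     let ctx = PipelineContext::new();
//     ctx.cache_set("features".to_string(), vec![1, 2, 3]);
//     assert_eq!(ctx.cache_get("features"), Some(vec![1, 2, 3]));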

/// The output of a pipeline run, with timing and per-step bookkeeping.
#[derive(Debug, Clone)]
pub struct PipelineResult<T> {
    pub data: T,
    pub execution_time: Duration,
    pub steps_executed: Vec<String>,
    pub metadata: HashMap<String, String>,
}

impl<T> PipelineResult<T> {
    pub fn new(data: T, execution_time: Duration, steps_executed: Vec<String>) -> Self {
        Self {
            data,
            execution_time,
            steps_executed,
            metadata: HashMap::new(),
        }
    }

    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
        self.metadata = metadata;
        self
    }
}

/// A sequential pipeline of homogeneous steps (`T -> T`).
pub struct DataPipeline<T> {
    steps: Vec<Box<dyn PipelineStep<Input = T, Output = T, Error = UtilsError>>>,
    context: PipelineContext,
    // Note: `execute` does not yet consult these two flags; they are
    // configuration placeholders toggled by the builder methods below.
    validation_enabled: bool,
    parallel_execution: bool,
}

impl<T> Default for DataPipeline<T>
where
    T: Clone + Send + Sync + 'static,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<T> DataPipeline<T>
where
    T: Clone + Send + Sync + 'static,
{
    pub fn new() -> Self {
        Self {
            steps: Vec::new(),
            context: PipelineContext::new(),
            validation_enabled: true,
            parallel_execution: false,
        }
    }

    pub fn with_context(mut self, context: PipelineContext) -> Self {
        self.context = context;
        self
    }

    pub fn enable_validation(mut self, enabled: bool) -> Self {
        self.validation_enabled = enabled;
        self
    }

    pub fn enable_parallel_execution(mut self, enabled: bool) -> Self {
        self.parallel_execution = enabled;
        self
    }

    pub fn add_step(
        mut self,
        step: Box<dyn PipelineStep<Input = T, Output = T, Error = UtilsError>>,
    ) -> Self {
        self.steps.push(step);
        self
    }

    pub fn add_transform<F>(self, name: String, transform_fn: F) -> Self
    where
        F: Fn(T) -> Result<T, UtilsError> + Send + Sync + 'static,
    {
        let step = TransformStep::new(name, Box::new(transform_fn));
        self.add_step(Box::new(step))
    }

    /// Runs every step in order, threading the data through and timing each step.
    pub fn execute(&self, mut data: T) -> Result<PipelineResult<T>, UtilsError> {
        let start_time = Instant::now();
        let mut steps_executed = Vec::new();

        for step in &self.steps {
            let step_start = Instant::now();

            data = step.process(data).map_err(|e| {
                UtilsError::InvalidParameter(format!(
                    "Pipeline step '{}' failed: {}",
                    step.name(),
                    e
                ))
            })?;

            steps_executed.push(format!(
                "{} ({}ms)",
                step.name(),
                step_start.elapsed().as_millis()
            ));
        }

        let execution_time = start_time.elapsed();
        Ok(PipelineResult::new(data, execution_time, steps_executed)
            .with_metadata(self.context.metadata.clone()))
    }
}
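
// A minimal sketch of building and running a homogeneous pipeline
// (step names are illustrative):
//
//     let pipeline = DataPipeline::new()
//         .add_transform("add_one".to_string(), |x: f64| Ok(x + 1.0))
//         .add_transform("double".to_string(), |x: f64| Ok(x * 2.0));
//
//     let result = pipeline.execute(1.0)?; // 1.0 -> 2.0 -> 4.0
//     assert_eq!(result.data, 4.0);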

/// Prebuilt pipelines for common ML preprocessing tasks on `Array2<f64>`.
pub struct MLPipelineBuilder;

impl MLPipelineBuilder {
    /// Drop consecutive duplicate rows, impute non-finite values with the
    /// column mean, then z-score normalize each column.
    pub fn data_cleaning() -> DataPipeline<Array2<f64>> {
        DataPipeline::new()
            .add_transform("remove_duplicates".to_string(), |data: Array2<f64>| {
                // Note: only *consecutive* duplicate rows are removed.
                let mut unique_rows = Vec::new();
                let mut prev_row: Option<Array1<f64>> = None;

                for row in data.rows() {
                    let current_row = row.to_owned();
                    if prev_row.as_ref() != Some(&current_row) {
                        unique_rows.push(current_row.clone());
                    }
                    prev_row = Some(current_row);
                }

                if unique_rows.is_empty() {
                    return Err(UtilsError::EmptyInput);
                }

                let n_cols = unique_rows[0].len();
                let mut result = Array2::zeros((unique_rows.len(), n_cols));
                for (i, row) in unique_rows.iter().enumerate() {
                    result.row_mut(i).assign(row);
                }
                Ok(result)
            })
            .add_transform("handle_missing_values".to_string(), |mut data| {
                // Impute NaN/infinite entries with the mean of the column's finite values.
                for mut col in data.columns_mut() {
                    let valid_values: Vec<f64> =
                        col.iter().filter(|&&x| x.is_finite()).copied().collect();

                    if !valid_values.is_empty() {
                        let mean = valid_values.iter().sum::<f64>() / valid_values.len() as f64;
                        for val in col.iter_mut() {
                            if !val.is_finite() {
                                *val = mean;
                            }
                        }
                    }
                }
                Ok(data)
            })
            .add_transform("normalize_data".to_string(), |mut data| {
                // Z-score each column; near-constant columns are left unscaled.
                for mut col in data.columns_mut() {
                    let mean = col.mean().unwrap_or(0.0);
                    let std = {
                        let variance =
                            col.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / col.len() as f64;
                        variance.sqrt()
                    };

                    if std > 1e-10 {
                        for val in col.iter_mut() {
                            *val = (*val - mean) / std;
                        }
                    }
                }
                Ok(data)
            })
    }

    /// Append pairwise interaction terms and simple per-row statistics.
    pub fn feature_engineering() -> DataPipeline<Array2<f64>> {
        DataPipeline::new()
            .add_transform(
                "add_polynomial_features".to_string(),
                |data: Array2<f64>| {
                    // Append the product of every distinct column pair (i < j).
                    let (n_rows, n_cols) = data.dim();
                    let mut result = Array2::zeros((n_rows, n_cols + n_cols * (n_cols - 1) / 2));

                    result.slice_mut(s![.., ..n_cols]).assign(&data);

                    let mut col_idx = n_cols;
                    for i in 0..n_cols {
                        for j in (i + 1)..n_cols {
                            for row in 0..n_rows {
                                result[[row, col_idx]] = data[[row, i]] * data[[row, j]];
                            }
                            col_idx += 1;
                        }
                    }

                    Ok(result)
                },
            )
            .add_transform("add_statistical_features".to_string(), |data| {
                // Append per-row mean, standard deviation, and range.
                let (n_rows, n_cols) = data.dim();
                let mut result = Array2::zeros((n_rows, n_cols + 3));
                result.slice_mut(s![.., ..n_cols]).assign(&data);

                for (i, row) in data.rows().into_iter().enumerate() {
                    let mean = row.mean().unwrap_or(0.0);
                    let std = {
                        let variance =
                            row.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / row.len() as f64;
                        variance.sqrt()
                    };
                    let min = row.iter().copied().fold(f64::INFINITY, f64::min);
                    let max = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);
                    let range = max - min;

                    result[[i, n_cols]] = mean;
                    result[[i, n_cols + 1]] = std;
                    result[[i, n_cols + 2]] = range;
                }

                Ok(result)
            })
    }

    /// Validate shape, missing-value ratio, and per-feature variance.
    pub fn data_validation() -> DataPipeline<Array2<f64>> {
        DataPipeline::new()
            .add_transform(
                "check_shape_consistency".to_string(),
                |data: Array2<f64>| {
                    if data.is_empty() {
                        return Err(UtilsError::EmptyInput);
                    }
                    if data.nrows() == 0 || data.ncols() == 0 {
                        return Err(UtilsError::InvalidParameter(
                            "Data has zero rows or columns".to_string(),
                        ));
                    }
                    Ok(data)
                },
            )
            .add_transform("check_data_quality".to_string(), |data| {
                // Reject inputs where more than half the entries are non-finite.
                let total_elements = data.len();
                let nan_count = data.iter().filter(|&&x| !x.is_finite()).count();
                let nan_ratio = nan_count as f64 / total_elements as f64;

                if nan_ratio > 0.5 {
                    return Err(UtilsError::InvalidParameter(format!(
                        "Too many missing values: {:.2}%",
                        nan_ratio * 100.0
                    )));
                }

                Ok(data)
            })
            .add_transform("check_feature_variance".to_string(), |data| {
                // Reject inputs containing a (near-)constant feature column.
                for (i, col) in data.columns().into_iter().enumerate() {
                    let mean = col.mean().unwrap_or(0.0);
                    let variance =
                        col.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / col.len() as f64;

                    if variance < 1e-10 {
                        return Err(UtilsError::InvalidParameter(format!(
                            "Feature {i} has zero variance"
                        )));
                    }
                }
                Ok(data)
            })
    }
}
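
// A minimal sketch of chaining the prebuilt stages over one matrix; the
// ordering shown is one plausible workflow, not a prescribed one:
//
//     let cleaned = MLPipelineBuilder::data_cleaning().execute(raw)?.data;
//     let validated = MLPipelineBuilder::data_validation().execute(cleaned)?.data;
//     let features = MLPipelineBuilder::feature_engineering().execute(validated)?.data;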

/// Aggregate execution statistics for a pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineMetrics {
    pub total_executions: u64,
    pub successful_executions: u64,
    pub failed_executions: u64,
    pub average_execution_time: Duration,
    pub total_execution_time: Duration,
    pub step_metrics: HashMap<String, StepMetrics>,
}

/// Per-step execution statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepMetrics {
    pub executions: u64,
    pub average_time: Duration,
    pub total_time: Duration,
    pub success_rate: f64,
}

impl Default for PipelineMetrics {
    fn default() -> Self {
        Self {
            total_executions: 0,
            successful_executions: 0,
            failed_executions: 0,
            average_execution_time: Duration::from_secs(0),
            total_execution_time: Duration::from_secs(0),
            step_metrics: HashMap::new(),
        }
    }
}

impl PipelineMetrics {
    pub fn success_rate(&self) -> f64 {
        if self.total_executions == 0 {
            0.0
        } else {
            self.successful_executions as f64 / self.total_executions as f64
        }
    }

    pub fn record_execution(&mut self, result: &PipelineResult<impl Clone>, success: bool) {
        self.total_executions += 1;
        if success {
            self.successful_executions += 1;
        } else {
            self.failed_executions += 1;
        }

        // Recompute the running average as total time / total executions.
        self.total_execution_time += result.execution_time;
        self.average_execution_time = Duration::from_nanos(
            (self.total_execution_time.as_nanos() / self.total_executions as u128) as u64,
        );
    }
}
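
// Worked example of the running average: recording two runs that took 100ms
// and 50ms gives total_execution_time = 150ms, total_executions = 2, and so
// average_execution_time = 150ms / 2 = 75ms.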

/// Thread-safe collector of [`PipelineMetrics`] across executions.
pub struct PipelineMonitor {
    metrics: Arc<Mutex<PipelineMetrics>>,
    enabled: bool,
}

impl Default for PipelineMonitor {
    fn default() -> Self {
        Self {
            metrics: Arc::new(Mutex::new(PipelineMetrics::default())),
            enabled: true,
        }
    }
}

impl PipelineMonitor {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn enable(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    pub fn record_execution<T: Clone>(&self, result: &PipelineResult<T>, success: bool) {
        if !self.enabled {
            return;
        }

        if let Ok(mut metrics) = self.metrics.lock() {
            metrics.record_execution(result, success);
        }
    }

    pub fn get_metrics(&self) -> Option<PipelineMetrics> {
        self.metrics.lock().ok().map(|m| m.clone())
    }

    pub fn reset_metrics(&self) {
        if let Ok(mut metrics) = self.metrics.lock() {
            *metrics = PipelineMetrics::default();
        }
    }
}
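
// A minimal sketch of recording pipeline runs with a monitor (illustrative):
//
//     let monitor = PipelineMonitor::new();
//     let outcome = pipeline.execute(input);
//     if let Ok(ref result) = outcome {
//         monitor.record_execution(result, true);
//     }
//     println!("{}", monitor.get_metrics().unwrap_or_default());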

impl fmt::Display for PipelineMetrics {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Pipeline Metrics:")?;
        writeln!(f, "  Total Executions: {}", self.total_executions)?;
        writeln!(f, "  Success Rate: {:.2}%", self.success_rate() * 100.0)?;
        writeln!(
            f,
            "  Average Execution Time: {:?}",
            self.average_execution_time
        )?;
        writeln!(f, "  Total Execution Time: {:?}", self.total_execution_time)?;
        Ok(())
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_pipeline_context() {
        let context = PipelineContext::new().with_metadata("user".to_string(), "test".to_string());

        assert_eq!(context.metadata.get("user"), Some(&"test".to_string()));

        context.cache_set("key1".to_string(), vec![1, 2, 3]);
        assert_eq!(context.cache_get("key1"), Some(vec![1, 2, 3]));

        context.cache_clear();
        assert_eq!(context.cache_get("key1"), None);
    }

    #[test]
    fn test_transform_step() {
        let step = TransformStep::new("double".to_string(), Box::new(|x: f64| Ok(x * 2.0)))
            .with_description("Doubles the input value".to_string());

        assert_eq!(step.name(), "double");
        assert_eq!(step.description(), Some("Doubles the input value"));
        assert_eq!(step.process(5.0).unwrap(), 10.0);
    }

    #[test]
    fn test_data_pipeline_execution() {
        let pipeline = DataPipeline::new()
            .add_transform("add_one".to_string(), |x: f64| Ok(x + 1.0))
            .add_transform("multiply_two".to_string(), |x: f64| Ok(x * 2.0));

        let result = pipeline.execute(5.0).unwrap();
        assert_eq!(result.data, 12.0);
        assert_eq!(result.steps_executed.len(), 2);
    }

    #[test]
    fn test_ml_pipeline_data_cleaning() {
        let data = array![[1.0, 2.0, f64::NAN], [3.0, f64::NAN, 4.0], [5.0, 6.0, 7.0]];

        let pipeline = MLPipelineBuilder::data_cleaning();
        let result = pipeline.execute(data).unwrap();

        assert!(result.data.iter().all(|&x| x.is_finite()));
        assert_eq!(result.steps_executed.len(), 3);
    }

    #[test]
    fn test_ml_pipeline_feature_engineering() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let pipeline = MLPipelineBuilder::feature_engineering();
        let result = pipeline.execute(data).unwrap();

        // 2 original + 1 interaction + 3 statistical features.
        assert_eq!(result.data.ncols(), 6);
        assert_eq!(result.steps_executed.len(), 2);
    }

    #[test]
    fn test_ml_pipeline_validation() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        let pipeline = MLPipelineBuilder::data_validation();
        let result = pipeline.execute(data).unwrap();

        assert_eq!(result.data.shape(), &[3, 3]);
        assert_eq!(result.steps_executed.len(), 3);
    }

    #[test]
    fn test_pipeline_validation_failure() {
        // The second column is constant, so the variance check should fail.
        let data = array![[1.0, 1.0], [2.0, 1.0], [3.0, 1.0]];

        let pipeline = MLPipelineBuilder::data_validation();
        let result = pipeline.execute(data);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("zero variance"));
    }

    #[test]
    fn test_pipeline_monitor() {
        let monitor = PipelineMonitor::new();

        let result =
            PipelineResult::new(42.0, Duration::from_millis(100), vec!["step1".to_string()]);

        monitor.record_execution(&result, true);

        let metrics = monitor.get_metrics().unwrap();
        assert_eq!(metrics.total_executions, 1);
        assert_eq!(metrics.successful_executions, 1);
        assert_eq!(metrics.success_rate(), 1.0);

        monitor.reset_metrics();
        let metrics = monitor.get_metrics().unwrap();
        assert_eq!(metrics.total_executions, 0);
    }

    #[test]
    fn test_pipeline_metrics_display() {
        let mut metrics = PipelineMetrics::default();
        metrics.total_executions = 10;
        metrics.successful_executions = 8;
        metrics.average_execution_time = Duration::from_millis(50);

        let display = format!("{metrics}");
        assert!(display.contains("Total Executions: 10"));
        assert!(display.contains("Success Rate: 80.00%"));
    }
}
640}