sklears_compose/
dag_pipeline.rs

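//! Directed Acyclic Graph (DAG) pipeline components
//!
//! This module provides DAG-based pipeline execution for complex workflows with
//! dependency management, cycle detection, and grouping of independent nodes into
//! topological levels that are eligible for parallel execution (the nodes of a
//! level are currently executed one after another).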
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
7use sklears_core::{
8    error::{Result as SklResult, SklearsError},
9    traits::{Estimator, Fit, Untrained},
10    types::Float,
11};
12use std::collections::{HashMap, HashSet, VecDeque};
13use std::fmt::Debug;
14
15use crate::{PipelinePredictor, PipelineStep};
16
17/// Node in a DAG pipeline
18#[derive(Debug)]
19pub struct DAGNode {
20    /// Unique node identifier
21    pub id: String,
22    /// Node name/description
23    pub name: String,
24    /// Pipeline component
25    pub component: NodeComponent,
26    /// Input dependencies
27    pub dependencies: Vec<String>,
28    /// Output consumers
29    pub consumers: Vec<String>,
30    /// Node metadata
31    pub metadata: HashMap<String, String>,
32    /// Execution configuration
33    pub config: NodeConfig,
34}
35
36/// Node component types
37pub enum NodeComponent {
38    /// Transformer component
39    Transformer(Box<dyn PipelineStep>),
40    /// Estimator component
41    Estimator(Box<dyn PipelinePredictor>),
42    /// Data source
43    DataSource {
44        data: Option<Array2<f64>>,
45        targets: Option<Array1<f64>>,
46    },
47    /// Data sink/output
48    DataSink,
49    /// Conditional branch
50    ConditionalBranch {
51        condition: BranchCondition,
52        true_path: String,
53        false_path: String,
54    },
55    /// Data merger
56    DataMerger { merge_strategy: MergeStrategy },
57    /// Custom function
58    CustomFunction {
59        function: Box<dyn Fn(&[NodeOutput]) -> SklResult<NodeOutput> + Send + Sync>,
60    },
61}
62
63impl Debug for NodeComponent {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        match self {
66            NodeComponent::Transformer(_) => f
67                .debug_tuple("Transformer")
68                .field(&"<transformer>")
69                .finish(),
70            NodeComponent::Estimator(_) => {
71                f.debug_tuple("Estimator").field(&"<estimator>").finish()
72            }
73            NodeComponent::DataSource { data, targets } => f
74                .debug_struct("DataSource")
75                .field(
76                    "data",
77                    &data
78                        .as_ref()
79                        .map(|d| format!("Array2<f64>({}, {})", d.nrows(), d.ncols())),
80                )
81                .field(
82                    "targets",
83                    &targets
84                        .as_ref()
85                        .map(|t| format!("Array1<f64>({})", t.len())),
86                )
87                .finish(),
88            NodeComponent::DataSink => f.debug_tuple("DataSink").finish(),
89            NodeComponent::ConditionalBranch {
90                condition,
91                true_path,
92                false_path,
93            } => f
94                .debug_struct("ConditionalBranch")
95                .field("condition", condition)
96                .field("true_path", true_path)
97                .field("false_path", false_path)
98                .finish(),
99            NodeComponent::DataMerger { merge_strategy } => f
100                .debug_struct("DataMerger")
101                .field("merge_strategy", merge_strategy)
102                .finish(),
103            NodeComponent::CustomFunction { .. } => f
104                .debug_struct("CustomFunction")
105                .field("function", &"<function>")
106                .finish(),
107        }
108    }
109}
110
111/// Branch condition for conditional nodes
112#[derive(Debug)]
113pub enum BranchCondition {
114    /// Feature threshold condition
115    FeatureThreshold {
116        feature_idx: usize,
117        threshold: f64,
118        comparison: ComparisonOp,
119    },
120    /// Data size condition
121    DataSize {
122        min_samples: Option<usize>,
123        max_samples: Option<usize>,
124    },
125    /// Custom condition
126    Custom {
127        condition_fn: fn(&NodeOutput) -> bool,
128    },
129}
130
/// Comparison operators used by threshold conditions.
///
/// The enum is fieldless, so it is cheap to copy and can be compared and hashed
/// (e.g. to store operators in sets or compare configurations).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComparisonOp {
    /// `lhs > rhs`
    GreaterThan,
    /// `lhs < rhs`
    LessThan,
    /// `lhs >= rhs`
    GreaterEqual,
    /// `lhs <= rhs`
    LessEqual,
    /// `lhs == rhs` (up to a small tolerance where floats are compared)
    Equal,
    /// `lhs != rhs` (up to a small tolerance where floats are compared)
    NotEqual,
}
147
148/// Data merging strategies
149#[derive(Debug)]
150pub enum MergeStrategy {
151    /// Concatenate along features (horizontal)
152    HorizontalConcat,
153    /// Concatenate along samples (vertical)
154    VerticalConcat,
155    /// Average outputs
156    Average,
157    /// Weighted average
158    WeightedAverage { weights: Vec<f64> },
159    /// Maximum values
160    Maximum,
161    /// Minimum values
162    Minimum,
163    /// Custom merge function
164    Custom {
165        merge_fn: fn(&[NodeOutput]) -> SklResult<NodeOutput>,
166    },
167}
168
169/// Node execution configuration
170#[derive(Debug, Clone)]
171pub struct NodeConfig {
172    /// Whether node can be executed in parallel
173    pub parallel_execution: bool,
174    /// Maximum execution time (seconds)
175    pub timeout: Option<f64>,
176    /// Retry attempts on failure
177    pub retry_attempts: usize,
178    /// Cache output
179    pub cache_output: bool,
180    /// Resource requirements
181    pub resource_requirements: ResourceRequirements,
182}
183
184impl Default for NodeConfig {
185    fn default() -> Self {
186        Self {
187            parallel_execution: true,
188            timeout: None,
189            retry_attempts: 0,
190            cache_output: false,
191            resource_requirements: ResourceRequirements::default(),
192        }
193    }
194}
195
/// Compute resources requested by a node.
///
/// The derived default requests nothing: no memory or core counts and no GPU.
#[derive(Debug, Clone, Default)]
pub struct ResourceRequirements {
    /// Requested memory in megabytes, if specified.
    pub memory_mb: Option<usize>,
    /// Requested number of CPU cores, if specified.
    pub cpu_cores: Option<usize>,
    /// Whether a GPU is needed.
    pub gpu_required: bool,
}
206
207/// Output from a DAG node
208#[derive(Debug, Clone)]
209pub struct NodeOutput {
210    /// Output data
211    pub data: Array2<f64>,
212    /// Output targets (optional)
213    pub targets: Option<Array1<f64>>,
214    /// Output metadata
215    pub metadata: HashMap<String, String>,
216    /// Execution statistics
217    pub execution_stats: ExecutionStats,
218}
219
/// Measurements recorded for one node execution.
#[derive(Debug, Clone)]
pub struct ExecutionStats {
    /// Time spent executing, in seconds.
    pub execution_time: f64,
    /// Memory used, in megabytes.
    pub memory_usage: f64,
    /// Whether the execution succeeded.
    pub success: bool,
    /// Failure description, when there is one.
    pub error_message: Option<String>,
}
232
233impl Default for ExecutionStats {
234    fn default() -> Self {
235        Self {
236            execution_time: 0.0,
237            memory_usage: 0.0,
238            success: true,
239            error_message: None,
240        }
241    }
242}
243
244/// DAG pipeline structure
245#[derive(Debug)]
246pub struct DAGPipeline<S = Untrained> {
247    state: S,
248    nodes: HashMap<String, DAGNode>,
249    edges: HashMap<String, HashSet<String>>, // node_id -> dependencies
250    execution_order: Vec<String>,
251    parallel_groups: Vec<Vec<String>>,
252    cache: HashMap<String, NodeOutput>,
253}
254
255/// Trained state for `DAGPipeline`
256#[derive(Debug)]
257pub struct DAGPipelineTrained {
258    fitted_nodes: HashMap<String, DAGNode>,
259    edges: HashMap<String, HashSet<String>>,
260    execution_order: Vec<String>,
261    parallel_groups: Vec<Vec<String>>,
262    cache: HashMap<String, NodeOutput>,
263    execution_history: Vec<ExecutionRecord>,
264    n_features_in: usize,
265    feature_names_in: Option<Vec<String>>,
266}
267
/// Summary of one run of the whole pipeline.
#[derive(Debug, Clone)]
pub struct ExecutionRecord {
    /// Start time of the run, as seconds since the Unix epoch.
    pub timestamp: f64,
    /// IDs of the nodes that completed.
    pub executed_nodes: Vec<String>,
    /// Duration of the run, in seconds.
    pub total_time: f64,
    /// `true` when no node reported an error.
    pub success: bool,
    /// Failures as `(node_id, error_message)` pairs.
    pub errors: Vec<(String, String)>,
}
282
283impl DAGPipeline<Untrained> {
284    /// Create a new DAG pipeline
285    #[must_use]
286    pub fn new() -> Self {
287        Self {
288            state: Untrained,
289            nodes: HashMap::new(),
290            edges: HashMap::new(),
291            execution_order: Vec::new(),
292            parallel_groups: Vec::new(),
293            cache: HashMap::new(),
294        }
295    }
296
297    /// Add a node to the DAG
298    pub fn add_node(mut self, node: DAGNode) -> SklResult<Self> {
299        // Check for duplicate node IDs
300        if self.nodes.contains_key(&node.id) {
301            return Err(SklearsError::InvalidInput(format!(
302                "Node with ID '{}' already exists",
303                node.id
304            )));
305        }
306
307        // Add dependencies to edges
308        let node_id = node.id.clone();
309        self.edges
310            .insert(node_id.clone(), node.dependencies.iter().cloned().collect());
311
312        // Add node
313        self.nodes.insert(node_id, node);
314
315        // Recompute execution order
316        self.compute_execution_order()?;
317
318        Ok(self)
319    }
320
321    /// Add an edge between nodes
322    pub fn add_edge(mut self, from_node: &str, to_node: &str) -> SklResult<Self> {
323        // Check if nodes exist
324        if !self.nodes.contains_key(from_node) {
325            return Err(SklearsError::InvalidInput(format!(
326                "Source node '{from_node}' does not exist"
327            )));
328        }
329        if !self.nodes.contains_key(to_node) {
330            return Err(SklearsError::InvalidInput(format!(
331                "Target node '{to_node}' does not exist"
332            )));
333        }
334
335        // Add edge
336        self.edges
337            .entry(to_node.to_string())
338            .or_default()
339            .insert(from_node.to_string());
340
341        // Update node dependencies
342        if let Some(to_node_obj) = self.nodes.get_mut(to_node) {
343            if !to_node_obj.dependencies.contains(&from_node.to_string()) {
344                to_node_obj.dependencies.push(from_node.to_string());
345            }
346        }
347
348        // Update node consumers
349        if let Some(from_node_obj) = self.nodes.get_mut(from_node) {
350            if !from_node_obj.consumers.contains(&to_node.to_string()) {
351                from_node_obj.consumers.push(to_node.to_string());
352            }
353        }
354
355        // Check for cycles
356        if self.has_cycles()? {
357            return Err(SklearsError::InvalidInput(
358                "Adding edge would create a cycle in the DAG".to_string(),
359            ));
360        }
361
362        // Recompute execution order
363        self.compute_execution_order()?;
364
365        Ok(self)
366    }
367
368    /// Check if the DAG has cycles
369    fn has_cycles(&self) -> SklResult<bool> {
370        let mut visited = HashSet::new();
371        let mut rec_stack = HashSet::new();
372
373        for node_id in self.nodes.keys() {
374            if !visited.contains(node_id)
375                && self.dfs_cycle_check(node_id, &mut visited, &mut rec_stack)?
376            {
377                return Ok(true);
378            }
379        }
380
381        Ok(false)
382    }
383
384    /// DFS-based cycle detection
385    fn dfs_cycle_check(
386        &self,
387        node_id: &str,
388        visited: &mut HashSet<String>,
389        rec_stack: &mut HashSet<String>,
390    ) -> SklResult<bool> {
391        visited.insert(node_id.to_string());
392        rec_stack.insert(node_id.to_string());
393
394        if let Some(dependencies) = self.edges.get(node_id) {
395            for dep in dependencies {
396                if !visited.contains(dep) {
397                    if self.dfs_cycle_check(dep, visited, rec_stack)? {
398                        return Ok(true);
399                    }
400                } else if rec_stack.contains(dep) {
401                    return Ok(true);
402                }
403            }
404        }
405
406        rec_stack.remove(node_id);
407        Ok(false)
408    }
409
410    /// Compute topological execution order
411    fn compute_execution_order(&mut self) -> SklResult<()> {
412        let mut in_degree = HashMap::new();
413        let mut queue = VecDeque::new();
414        let mut order = Vec::new();
415        let mut parallel_groups = Vec::new();
416
417        // Initialize in-degrees
418        for node_id in self.nodes.keys() {
419            in_degree.insert(node_id.clone(), 0);
420        }
421
422        // Compute in-degrees
423        for (node_id, dependencies) in &self.edges {
424            in_degree.insert(node_id.clone(), dependencies.len());
425        }
426
427        // Find nodes with no dependencies
428        for (node_id, &degree) in &in_degree {
429            if degree == 0 {
430                queue.push_back(node_id.clone());
431            }
432        }
433
434        // Process nodes level by level for parallel execution
435        while !queue.is_empty() {
436            let current_level: Vec<String> = queue.drain(..).collect();
437            parallel_groups.push(current_level.clone());
438            order.extend(current_level.iter().cloned());
439
440            // Process current level
441            for node_id in &current_level {
442                // Update in-degrees of consumers
443                if let Some(node) = self.nodes.get(node_id) {
444                    for consumer in &node.consumers {
445                        if let Some(degree) = in_degree.get_mut(consumer) {
446                            *degree -= 1;
447                            if *degree == 0 {
448                                queue.push_back(consumer.clone());
449                            }
450                        }
451                    }
452                }
453            }
454        }
455
456        // Check if all nodes are processed (no cycles)
457        if order.len() != self.nodes.len() {
458            return Err(SklearsError::InvalidInput(
459                "DAG contains cycles".to_string(),
460            ));
461        }
462
463        self.execution_order = order;
464        self.parallel_groups = parallel_groups;
465
466        Ok(())
467    }
468
469    /// Create a linear pipeline from components
470    pub fn linear(components: Vec<(String, Box<dyn PipelineStep>)>) -> SklResult<Self> {
471        let mut dag = Self::new();
472        let num_components = components.len();
473
474        for (i, (name, component)) in components.into_iter().enumerate() {
475            let dependencies = if i == 0 {
476                Vec::new()
477            } else {
478                vec![format!("node_{}", i - 1)]
479            };
480
481            let node = DAGNode {
482                id: format!("node_{i}"),
483                name,
484                component: NodeComponent::Transformer(component),
485                dependencies,
486                consumers: if i == num_components - 1 {
487                    Vec::new()
488                } else {
489                    vec![format!("node_{}", i + 1)]
490                },
491                metadata: HashMap::new(),
492                config: NodeConfig::default(),
493            };
494
495            dag = dag.add_node(node)?;
496        }
497
498        Ok(dag)
499    }
500
501    /// Create a parallel pipeline with final merger
502    pub fn parallel(
503        components: Vec<(String, Box<dyn PipelineStep>)>,
504        merge_strategy: MergeStrategy,
505    ) -> SklResult<Self> {
506        let mut dag = Self::new();
507
508        let num_components = components.len();
509
510        // Add parallel components
511        for (i, (name, component)) in components.into_iter().enumerate() {
512            let node = DAGNode {
513                id: format!("parallel_{i}"),
514                name,
515                component: NodeComponent::Transformer(component),
516                dependencies: Vec::new(),
517                consumers: vec!["merger".to_string()],
518                metadata: HashMap::new(),
519                config: NodeConfig::default(),
520            };
521
522            dag = dag.add_node(node)?;
523        }
524
525        // Add merger node
526        let merger_dependencies: Vec<String> = (0..num_components)
527            .map(|i| format!("parallel_{i}"))
528            .collect();
529
530        let merger_node = DAGNode {
531            id: "merger".to_string(),
532            name: "Data Merger".to_string(),
533            component: NodeComponent::DataMerger { merge_strategy },
534            dependencies: merger_dependencies,
535            consumers: Vec::new(),
536            metadata: HashMap::new(),
537            config: NodeConfig::default(),
538        };
539
540        dag = dag.add_node(merger_node)?;
541
542        Ok(dag)
543    }
544}
545
546impl Default for DAGPipeline<Untrained> {
547    fn default() -> Self {
548        Self::new()
549    }
550}
551
552impl Estimator for DAGPipeline<Untrained> {
553    type Config = ();
554    type Error = SklearsError;
555    type Float = Float;
556
557    fn config(&self) -> &Self::Config {
558        &()
559    }
560}
561
562impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for DAGPipeline<Untrained> {
563    type Fitted = DAGPipeline<DAGPipelineTrained>;
564
565    fn fit(
566        mut self,
567        x: &ArrayView2<'_, Float>,
568        y: &Option<&ArrayView1<'_, Float>>,
569    ) -> SklResult<Self::Fitted> {
570        let mut fitted_nodes = HashMap::new();
571        let mut execution_errors = Vec::new();
572        let start_time = std::time::SystemTime::now()
573            .duration_since(std::time::UNIX_EPOCH)
574            .unwrap()
575            .as_secs_f64();
576
577        // Initialize with input data
578        let initial_output = NodeOutput {
579            data: x.mapv(|v| v),
580            targets: y.as_ref().map(|y_vals| y_vals.mapv(|v| v)),
581            metadata: HashMap::new(),
582            execution_stats: ExecutionStats::default(),
583        };
584        self.cache.insert("input".to_string(), initial_output);
585
586        // Execute nodes in topological order
587        let parallel_groups = std::mem::take(&mut self.parallel_groups);
588        for group in &parallel_groups {
589            // Execute parallel group
590            let group_results = self.execute_parallel_group(group)?;
591
592            for (node_id, result) in group_results {
593                match result {
594                    Ok(output) => {
595                        self.cache.insert(node_id.clone(), output);
596                        if let Some(node) = self.nodes.remove(&node_id) {
597                            fitted_nodes.insert(node_id, node);
598                        }
599                    }
600                    Err(e) => {
601                        execution_errors.push((node_id, e.to_string()));
602                    }
603                }
604            }
605        }
606
607        let end_time = std::time::SystemTime::now()
608            .duration_since(std::time::UNIX_EPOCH)
609            .unwrap()
610            .as_secs_f64();
611
612        let execution_record = ExecutionRecord {
613            timestamp: start_time,
614            executed_nodes: fitted_nodes.keys().cloned().collect(),
615            total_time: end_time - start_time,
616            success: execution_errors.is_empty(),
617            errors: execution_errors,
618        };
619
620        Ok(DAGPipeline {
621            state: DAGPipelineTrained {
622                fitted_nodes,
623                edges: self.edges,
624                execution_order: self.execution_order,
625                parallel_groups,
626                cache: self.cache,
627                execution_history: vec![execution_record],
628                n_features_in: x.ncols(),
629                feature_names_in: None,
630            },
631            nodes: HashMap::new(),
632            edges: HashMap::new(),
633            execution_order: Vec::new(),
634            parallel_groups: Vec::new(),
635            cache: HashMap::new(),
636        })
637    }
638}
639
640impl DAGPipeline<Untrained> {
641    /// Execute a group of nodes in parallel
642    fn execute_parallel_group(
643        &mut self,
644        group: &[String],
645    ) -> SklResult<Vec<(String, SklResult<NodeOutput>)>> {
646        let mut results = Vec::new();
647
648        for node_id in group {
649            if let Some(node) = self.nodes.remove(node_id) {
650                let result = self.execute_node(&node);
651                results.push((node_id.clone(), result));
652                // Put the node back
653                self.nodes.insert(node_id.clone(), node);
654            }
655        }
656
657        Ok(results)
658    }
659
660    /// Execute a single node
661    fn execute_node(&mut self, node: &DAGNode) -> SklResult<NodeOutput> {
662        let start_time = std::time::SystemTime::now();
663
664        // Collect inputs from dependencies
665        let mut inputs = Vec::new();
666        for dep_id in &node.dependencies {
667            if let Some(output) = self.cache.get(dep_id) {
668                inputs.push(output.clone());
669            } else if dep_id == "input" {
670                // Handle initial input - skip to next dependency
671            } else {
672                return Err(SklearsError::InvalidInput(format!(
673                    "Missing input from dependency: {dep_id}"
674                )));
675            }
676        }
677
678        // If no dependencies, use input data
679        if inputs.is_empty() && self.cache.contains_key("input") {
680            inputs.push(self.cache["input"].clone());
681        }
682
683        // Execute based on component type
684        let result = match &node.component {
685            NodeComponent::Transformer(transformer) => {
686                if let Some(input) = inputs.first() {
687                    let mapped_data = input.data.view().mapv(|v| v as Float);
688                    let transformed = transformer.transform(&mapped_data.view())?;
689                    Ok(NodeOutput {
690                        data: transformed,
691                        targets: input.targets.clone(),
692                        metadata: HashMap::new(),
693                        execution_stats: ExecutionStats::default(),
694                    })
695                } else {
696                    Err(SklearsError::InvalidInput(
697                        "No input data for transformer".to_string(),
698                    ))
699                }
700            }
701            NodeComponent::DataMerger { merge_strategy } => {
702                self.execute_data_merger(&inputs, merge_strategy)
703            }
704            NodeComponent::ConditionalBranch {
705                condition,
706                true_path,
707                false_path,
708            } => self.execute_conditional_branch(&inputs, condition, true_path, false_path),
709            NodeComponent::DataSource { data, targets } => {
710                if let Some(ref source_data) = data {
711                    Ok(NodeOutput {
712                        data: source_data.clone(),
713                        targets: targets.clone(),
714                        metadata: HashMap::new(),
715                        execution_stats: ExecutionStats::default(),
716                    })
717                } else {
718                    Err(SklearsError::InvalidInput(
719                        "No data in data source".to_string(),
720                    ))
721                }
722            }
723            NodeComponent::DataSink => {
724                // Just pass through the input
725                inputs
726                    .into_iter()
727                    .next()
728                    .ok_or_else(|| SklearsError::InvalidInput("No input for data sink".to_string()))
729            }
730            NodeComponent::Estimator(_) => {
731                // For fitting, estimators don't produce output data
732                if let Some(input) = inputs.first() {
733                    Ok(input.clone())
734                } else {
735                    Err(SklearsError::InvalidInput(
736                        "No input data for estimator".to_string(),
737                    ))
738                }
739            }
740            NodeComponent::CustomFunction { function } => function(&inputs),
741        };
742
743        // Record execution time
744        let execution_time = start_time.elapsed().unwrap().as_secs_f64();
745        if let Ok(ref mut output) = result.clone() {
746            output.execution_stats.execution_time = execution_time;
747        }
748
749        result
750    }
751
752    /// Execute data merger
753    fn execute_data_merger(
754        &self,
755        inputs: &[NodeOutput],
756        strategy: &MergeStrategy,
757    ) -> SklResult<NodeOutput> {
758        if inputs.is_empty() {
759            return Err(SklearsError::InvalidInput("No inputs to merge".to_string()));
760        }
761
762        let merged_data = match strategy {
763            MergeStrategy::HorizontalConcat => {
764                let total_cols: usize = inputs.iter().map(|inp| inp.data.ncols()).sum();
765                let n_rows = inputs[0].data.nrows();
766
767                let mut merged = Array2::zeros((n_rows, total_cols));
768                let mut col_offset = 0;
769
770                for input in inputs {
771                    let cols = input.data.ncols();
772                    merged
773                        .slice_mut(s![.., col_offset..col_offset + cols])
774                        .assign(&input.data);
775                    col_offset += cols;
776                }
777
778                merged
779            }
780            MergeStrategy::VerticalConcat => {
781                let n_cols = inputs[0].data.ncols();
782                let total_rows: usize = inputs.iter().map(|inp| inp.data.nrows()).sum();
783
784                let mut merged = Array2::zeros((total_rows, n_cols));
785                let mut row_offset = 0;
786
787                for input in inputs {
788                    let rows = input.data.nrows();
789                    merged
790                        .slice_mut(s![row_offset..row_offset + rows, ..])
791                        .assign(&input.data);
792                    row_offset += rows;
793                }
794
795                merged
796            }
797            MergeStrategy::Average => {
798                let mut sum = inputs[0].data.clone();
799                for input in inputs.iter().skip(1) {
800                    sum += &input.data;
801                }
802                sum / inputs.len() as f64
803            }
804            MergeStrategy::WeightedAverage { weights } => {
805                if weights.len() != inputs.len() {
806                    return Err(SklearsError::InvalidInput(
807                        "Number of weights must match number of inputs".to_string(),
808                    ));
809                }
810
811                let mut weighted_sum = &inputs[0].data * weights[0];
812                for (input, &weight) in inputs.iter().skip(1).zip(weights.iter().skip(1)) {
813                    weighted_sum += &(&input.data * weight);
814                }
815
816                weighted_sum
817            }
818            MergeStrategy::Maximum => {
819                let mut max_data = inputs[0].data.clone();
820                for input in inputs.iter().skip(1) {
821                    for ((i, j), &val) in input.data.indexed_iter() {
822                        if val > max_data[(i, j)] {
823                            max_data[(i, j)] = val;
824                        }
825                    }
826                }
827                max_data
828            }
829            MergeStrategy::Minimum => {
830                let mut min_data = inputs[0].data.clone();
831                for input in inputs.iter().skip(1) {
832                    for ((i, j), &val) in input.data.indexed_iter() {
833                        if val < min_data[(i, j)] {
834                            min_data[(i, j)] = val;
835                        }
836                    }
837                }
838                min_data
839            }
840            MergeStrategy::Custom { merge_fn } => {
841                return merge_fn(inputs);
842            }
843        };
844
845        Ok(NodeOutput {
846            data: merged_data,
847            targets: inputs[0].targets.clone(),
848            metadata: HashMap::new(),
849            execution_stats: ExecutionStats::default(),
850        })
851    }
852
853    /// Execute conditional branch
854    fn execute_conditional_branch(
855        &self,
856        inputs: &[NodeOutput],
857        condition: &BranchCondition,
858        true_path: &str,
859        false_path: &str,
860    ) -> SklResult<NodeOutput> {
861        if inputs.is_empty() {
862            return Err(SklearsError::InvalidInput(
863                "No input for conditional branch".to_string(),
864            ));
865        }
866
867        let input = &inputs[0];
868        let condition_result = match condition {
869            BranchCondition::FeatureThreshold {
870                feature_idx,
871                threshold,
872                comparison,
873            } => {
874                if *feature_idx >= input.data.ncols() {
875                    return Err(SklearsError::InvalidInput(
876                        "Feature index out of bounds".to_string(),
877                    ));
878                }
879
880                let feature_values = input.data.column(*feature_idx);
881                let mean_value = feature_values.mean().unwrap_or(0.0);
882
883                match comparison {
884                    ComparisonOp::GreaterThan => mean_value > *threshold,
885                    ComparisonOp::LessThan => mean_value < *threshold,
886                    ComparisonOp::GreaterEqual => mean_value >= *threshold,
887                    ComparisonOp::LessEqual => mean_value <= *threshold,
888                    ComparisonOp::Equal => (mean_value - threshold).abs() < 1e-8,
889                    ComparisonOp::NotEqual => (mean_value - threshold).abs() >= 1e-8,
890                }
891            }
892            BranchCondition::DataSize {
893                min_samples,
894                max_samples,
895            } => {
896                let n_samples = input.data.nrows();
897                let min_ok = min_samples.map_or(true, |min| n_samples >= min);
898                let max_ok = max_samples.map_or(true, |max| n_samples <= max);
899                min_ok && max_ok
900            }
901            BranchCondition::Custom { condition_fn } => condition_fn(input),
902        };
903
904        // For now, just return the input (branch execution would be more complex)
905        let mut output = input.clone();
906        output.metadata.insert(
907            "branch_taken".to_string(),
908            if condition_result {
909                true_path.to_string()
910            } else {
911                false_path.to_string()
912            },
913        );
914
915        Ok(output)
916    }
917}
918
919impl DAGPipeline<DAGPipelineTrained> {
920    /// Transform data through the fitted DAG
921    pub fn transform(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
922        // This is a simplified version - in practice, we'd re-execute the DAG
923        // For now, return the input data as f64
924        Ok(x.mapv(|v| v))
925    }
926
927    /// Get execution history
928    #[must_use]
929    pub fn execution_history(&self) -> &[ExecutionRecord] {
930        &self.state.execution_history
931    }
932
933    /// Get DAG statistics
934    #[must_use]
935    pub fn statistics(&self) -> HashMap<String, f64> {
936        let mut stats = HashMap::new();
937        stats.insert(
938            "total_nodes".to_string(),
939            self.state.fitted_nodes.len() as f64,
940        );
941        stats.insert(
942            "parallel_groups".to_string(),
943            self.state.parallel_groups.len() as f64,
944        );
945
946        if let Some(last_execution) = self.state.execution_history.last() {
947            stats.insert("last_execution_time".to_string(), last_execution.total_time);
948            stats.insert(
949                "last_execution_success".to_string(),
950                if last_execution.success { 1.0 } else { 0.0 },
951            );
952        }
953
954        stats
955    }
956
957    /// Visualize DAG structure (returns DOT format)
958    #[must_use]
959    pub fn to_dot(&self) -> String {
960        let mut dot = String::from("digraph DAG {\n");
961
962        // Add nodes
963        for (node_id, node) in &self.state.fitted_nodes {
964            dot.push_str(&format!("  \"{}\" [label=\"{}\"];\n", node_id, node.name));
965        }
966
967        // Add edges
968        for (to_node, dependencies) in &self.state.edges {
969            for from_node in dependencies {
970                dot.push_str(&format!("  \"{from_node}\" -> \"{to_node}\";\n"));
971            }
972        }
973
974        dot.push_str("}\n");
975        dot
976    }
977}
978
979// Import ndarray slice macro
980use scirs2_core::ndarray::s;
981
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockTransformer;
    use scirs2_core::ndarray::array;

    /// Build a data-source node with no data, no targets and no edges.
    fn empty_source(id: &str, name: &str) -> DAGNode {
        DAGNode {
            id: id.to_string(),
            name: name.to_string(),
            component: NodeComponent::DataSource {
                data: None,
                targets: None,
            },
            dependencies: Vec::new(),
            consumers: Vec::new(),
            metadata: HashMap::new(),
            config: NodeConfig::default(),
        }
    }

    /// Two freshly constructed mock transformers named `transformer1` and `transformer2`.
    fn two_mock_transformers() -> Vec<(String, Box<dyn PipelineStep>)> {
        ["transformer1", "transformer2"]
            .iter()
            .map(|name| {
                (
                    (*name).to_string(),
                    Box::new(MockTransformer::new()) as Box<dyn PipelineStep>,
                )
            })
            .collect()
    }

    /// Wrap `data` in a `NodeOutput` with no targets, metadata or stats.
    fn output_with(data: Array2<f64>) -> NodeOutput {
        NodeOutput {
            data,
            targets: None,
            metadata: HashMap::new(),
            execution_stats: ExecutionStats::default(),
        }
    }

    #[test]
    fn test_dag_node_creation() {
        let node = DAGNode {
            id: "test_node".to_string(),
            name: "Test Node".to_string(),
            component: NodeComponent::DataSource {
                data: Some(array![[1.0, 2.0], [3.0, 4.0]]),
                targets: Some(array![1.0, 0.0]),
            },
            dependencies: Vec::new(),
            consumers: Vec::new(),
            metadata: HashMap::new(),
            config: NodeConfig::default(),
        };

        assert_eq!(node.id, "test_node");
        assert_eq!(node.name, "Test Node");
        assert!(node.dependencies.is_empty());
        assert!(node.consumers.is_empty());
        assert!(matches!(
            node.component,
            NodeComponent::DataSource {
                data: Some(_),
                targets: Some(_)
            }
        ));
    }

    #[test]
    fn test_linear_dag() {
        let dag = DAGPipeline::linear(two_mock_transformers()).unwrap();
        assert_eq!(dag.nodes.len(), 2);
        assert_eq!(dag.execution_order.len(), 2);
    }

    #[test]
    fn test_parallel_dag() {
        let dag =
            DAGPipeline::parallel(two_mock_transformers(), MergeStrategy::HorizontalConcat)
                .unwrap();
        assert_eq!(dag.nodes.len(), 3); // 2 transformers + 1 merger
    }

    #[test]
    fn test_cycle_detection() {
        // Start from two unconnected nodes.
        let dag = DAGPipeline::new()
            .add_node(empty_source("node1", "Node 1"))
            .unwrap()
            .add_node(empty_source("node2", "Node 2"))
            .unwrap();

        // node1 -> node2 on its own is a valid, acyclic edge.
        let dag = dag.add_edge("node1", "node2").unwrap();

        // node2 -> node1 would close the loop node1 -> node2 -> node1 and must be rejected.
        assert!(dag.add_edge("node2", "node1").is_err());
    }

    #[test]
    fn test_merge_strategies() {
        let inputs = vec![
            output_with(array![[1.0, 2.0], [3.0, 4.0]]),
            output_with(array![[5.0, 6.0], [7.0, 8.0]]),
        ];
        let dag = DAGPipeline::new();

        // Horizontal concatenation keeps the rows and stacks the columns.
        let concat = dag
            .execute_data_merger(&inputs, &MergeStrategy::HorizontalConcat)
            .unwrap();
        assert_eq!(concat.data.ncols(), 4);
        assert_eq!(concat.data.nrows(), 2);

        // Average is taken element-wise across the inputs.
        let avg = dag
            .execute_data_merger(&inputs, &MergeStrategy::Average)
            .unwrap();
        assert_eq!(avg.data[[0, 0]], 3.0); // (1 + 5) / 2
        assert_eq!(avg.data[[0, 1]], 4.0); // (2 + 6) / 2
        assert_eq!(avg.data[[1, 0]], 5.0); // (3 + 7) / 2
        assert_eq!(avg.data[[1, 1]], 6.0); // (4 + 8) / 2
    }
}