sklears_multioutput/
tree.rs

//! Tree-based multi-output algorithms
//!
//! This module provides tree-based algorithms for multi-output prediction problems,
//! including decision trees, random forests, and structured predictors.

use crate::multi_label::{BinaryRelevance, BinaryRelevanceTrained};
// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::HashMap;

// Tree-related enums and helper structures

/// Classification criterion for decision trees
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClassificationCriterion {
    /// Gini impurity
    Gini,
    /// Information gain / entropy
    Entropy,
}
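
// A minimal sketch of what the two criteria compute for a class distribution
// `probs` (it mirrors the formulas used in `calculate_classification_metrics`
// below; illustrative only, not part of the public API):
#[allow(dead_code)]
fn criterion_impurity_example(probs: &[Float], criterion: ClassificationCriterion) -> Float {
    match criterion {
        // Gini impurity: 1 - sum(p_k^2), equivalently sum(p_k * (1 - p_k))
        ClassificationCriterion::Gini => 1.0 - probs.iter().map(|p| p * p).sum::<Float>(),
        // Entropy (in nats): -sum(p_k * ln(p_k)), skipping zero-probability classes
        ClassificationCriterion::Entropy => probs
            .iter()
            .filter(|&&p| p > 0.0)
            .map(|p| -p * p.ln())
            .sum(),
    }
}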

/// DAG inference methods
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DAGInferenceMethod {
    /// Greedy inference following topological order
    Greedy,
    /// Belief propagation on the DAG
    BeliefPropagation,
    /// Integer linear programming for exact inference
    ExactILP,
}

#[derive(Debug, Clone)]
struct DecisionNode {
    is_leaf: bool,
    prediction: Option<Array1<Float>>, // Mean values for each target
    feature_idx: Option<usize>,
    threshold: Option<Float>,
    left: Option<Box<DecisionNode>>,
    right: Option<Box<DecisionNode>>,
    n_samples: usize,
    variance: Float, // Sum of variances across all targets
}

#[derive(Debug, Clone)]
pub struct ClassificationDecisionNode {
    is_leaf: bool,
    prediction: Option<Array1<i32>>, // Mode/majority class for each target
    // Per-target class distributions: an (n_targets, max_classes) matrix,
    // zero-padded for targets with fewer classes.
    probabilities: Option<Array2<Float>>,
    feature_idx: Option<usize>,
    threshold: Option<Float>,
    left: Option<Box<ClassificationDecisionNode>>,
    right: Option<Box<ClassificationDecisionNode>>,
    n_samples: usize,
    impurity: Float, // Combined impurity across all targets
}

/// Multi-Target Regression Tree
///
/// A decision tree regressor that can handle multiple target variables simultaneously.
/// Uses joint variance reduction for optimal splits across all targets.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::MultiTargetRegressionTree;
/// use sklears_core::traits::{Predict, Fit};
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
/// let y = array![[1.5, 2.5], [2.5, 3.5], [3.5, 1.5], [4.5, 4.5]];
///
/// let tree = MultiTargetRegressionTree::new()
///     .max_depth(Some(3))
///     .min_samples_split(2);
/// let trained_tree = tree.fit(&X.view(), &y).unwrap();
/// let predictions = trained_tree.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct MultiTargetRegressionTree<S = Untrained> {
    state: S,
    max_depth: Option<usize>,
    min_samples_split: usize,
    min_samples_leaf: usize,
    random_state: Option<u64>,
}

#[derive(Debug, Clone)]
pub struct MultiTargetRegressionTreeTrained {
    tree: DecisionNode,
    n_features: usize,
    n_targets: usize,
    feature_importances: Array1<Float>,
}

impl MultiTargetRegressionTree<Untrained> {
    /// Create a new MultiTargetRegressionTree instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            max_depth: Some(5),
            min_samples_split: 2,
            min_samples_leaf: 1,
            random_state: None,
        }
    }

    /// Set the maximum depth of the tree
    pub fn max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Set the minimum number of samples required to split an internal node
    pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Set the minimum number of samples required to be at a leaf node
    pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Set the random state for reproducible results
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    /// Get the maximum depth of the tree
    pub fn get_max_depth(&self) -> Option<usize> {
        self.max_depth
    }

    /// Get the minimum number of samples required to split an internal node
    pub fn get_min_samples_split(&self) -> usize {
        self.min_samples_split
    }

    /// Get the minimum number of samples required to be at a leaf node
    pub fn get_min_samples_leaf(&self) -> usize {
        self.min_samples_leaf
    }

    /// Get the random state
    pub fn get_random_state(&self) -> Option<u64> {
        self.random_state
    }
}

impl Default for MultiTargetRegressionTree<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MultiTargetRegressionTree<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<Float>> for MultiTargetRegressionTree<Untrained> {
    type Fitted = MultiTargetRegressionTree<MultiTargetRegressionTreeTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<Float>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let n_targets = y.ncols();
        if n_targets == 0 {
            return Err(SklearsError::InvalidInput(
                "y must have at least one target".to_string(),
            ));
        }

        if n_samples < self.min_samples_split {
            return Err(SklearsError::InvalidInput(
                "Number of samples is less than min_samples_split".to_string(),
            ));
        }

        // Build the tree
        let indices: Vec<usize> = (0..n_samples).collect();
        let tree = self.build_tree(&X, y, &indices, 0)?;

        // Calculate feature importances (simplified)
        let mut feature_importances = Array1::<Float>::zeros(n_features);
        self.calculate_feature_importances(&tree, &mut feature_importances, n_samples as Float);

        // Normalize feature importances
        let sum_importances: Float = feature_importances.sum();
        if sum_importances > 0.0 {
            feature_importances /= sum_importances;
        }

        Ok(MultiTargetRegressionTree {
            state: MultiTargetRegressionTreeTrained {
                tree,
                n_features,
                n_targets,
                feature_importances,
            },
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
            random_state: self.random_state,
        })
    }
}

impl MultiTargetRegressionTree<Untrained> {
    fn build_tree(
        &self,
        X: &Array2<Float>,
        y: &Array2<Float>,
        indices: &[usize],
        depth: usize,
    ) -> SklResult<DecisionNode> {
        let n_samples = indices.len();
        let n_targets = y.ncols();

        // Calculate current prediction (mean of targets)
        let mut prediction = Array1::<Float>::zeros(n_targets);
        for &idx in indices {
            for j in 0..n_targets {
                prediction[j] += y[[idx, j]];
            }
        }
        prediction /= n_samples as Float;

        // Calculate variance across all targets
        let mut variance = 0.0;
        for &idx in indices {
            for j in 0..n_targets {
                let diff = y[[idx, j]] - prediction[j];
                variance += diff * diff;
            }
        }
        variance /= n_samples as Float;

        // Check stopping criteria
        let should_stop = n_samples < self.min_samples_split
            || n_samples < self.min_samples_leaf
            || self.max_depth.is_some_and(|max_d| depth >= max_d)
            || variance < 1e-10;

        if should_stop {
            return Ok(DecisionNode {
                is_leaf: true,
                prediction: Some(prediction),
                feature_idx: None,
                threshold: None,
                left: None,
                right: None,
                n_samples,
                variance,
            });
        }

        // Find best split
        let (best_feature, best_threshold, best_variance_reduction) =
            self.find_best_split(X, y, indices)?;

        if best_variance_reduction <= 0.0 {
            return Ok(DecisionNode {
                is_leaf: true,
                prediction: Some(prediction),
                feature_idx: None,
                threshold: None,
                left: None,
                right: None,
                n_samples,
                variance,
            });
        }

        // Split the data
        let (left_indices, right_indices) =
            self.split_data(X, indices, best_feature, best_threshold);

        if left_indices.len() < self.min_samples_leaf || right_indices.len() < self.min_samples_leaf
        {
            return Ok(DecisionNode {
                is_leaf: true,
                prediction: Some(prediction),
                feature_idx: None,
                threshold: None,
                left: None,
                right: None,
                n_samples,
                variance,
            });
        }

        // Recursively build child nodes
        let left_child = self.build_tree(X, y, &left_indices, depth + 1)?;
        let right_child = self.build_tree(X, y, &right_indices, depth + 1)?;

        Ok(DecisionNode {
            is_leaf: false,
            prediction: None,
            feature_idx: Some(best_feature),
            threshold: Some(best_threshold),
            left: Some(Box::new(left_child)),
            right: Some(Box::new(right_child)),
            n_samples,
            variance,
        })
    }

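    /// Scans every feature and every candidate threshold (midpoints between
    /// consecutive unique feature values) and returns the split that maximizes
    /// the joint variance reduction across all targets:
    ///
    ///   reduction = Var(S) - (|S_L| * Var(S_L) + |S_R| * Var(S_R)) / |S|
    ///
    /// where Var(.) is the per-sample sum of squared deviations over every
    /// target, as computed by `calculate_variance`.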
    fn find_best_split(
        &self,
        X: &Array2<Float>,
        y: &Array2<Float>,
        indices: &[usize],
    ) -> SklResult<(usize, Float, Float)> {
        let n_features = X.ncols();
        let mut best_feature = 0;
        let mut best_threshold = 0.0;
        let mut best_variance_reduction = 0.0;

        // Calculate current variance
        let current_variance = self.calculate_variance(y, indices);

        for feature_idx in 0..n_features {
            // Get unique feature values
            let mut feature_values: Vec<Float> =
                indices.iter().map(|&idx| X[[idx, feature_idx]]).collect();
            feature_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
            feature_values.dedup();

            for i in 0..feature_values.len().saturating_sub(1) {
                let threshold = (feature_values[i] + feature_values[i + 1]) / 2.0;

                let (left_indices, right_indices) =
                    self.split_data(X, indices, feature_idx, threshold);

                if left_indices.is_empty() || right_indices.is_empty() {
                    continue;
                }

                let left_variance = self.calculate_variance(y, &left_indices);
                let right_variance = self.calculate_variance(y, &right_indices);

                let weighted_variance = (left_indices.len() as Float * left_variance
                    + right_indices.len() as Float * right_variance)
                    / indices.len() as Float;

                let variance_reduction = current_variance - weighted_variance;

                if variance_reduction > best_variance_reduction {
                    best_variance_reduction = variance_reduction;
                    best_feature = feature_idx;
                    best_threshold = threshold;
                }
            }
        }

        Ok((best_feature, best_threshold, best_variance_reduction))
    }

    fn calculate_variance(&self, y: &Array2<Float>, indices: &[usize]) -> Float {
        if indices.is_empty() {
            return 0.0;
        }

        let n_targets = y.ncols();
        let n_samples = indices.len();

        // Calculate means
        let mut means = Array1::<Float>::zeros(n_targets);
        for &idx in indices {
            for j in 0..n_targets {
                means[j] += y[[idx, j]];
            }
        }
        means /= n_samples as Float;

        // Calculate variance
        let mut variance = 0.0;
        for &idx in indices {
            for j in 0..n_targets {
                let diff = y[[idx, j]] - means[j];
                variance += diff * diff;
            }
        }
        variance / n_samples as Float
    }

    fn split_data(
        &self,
        X: &Array2<Float>,
        indices: &[usize],
        feature_idx: usize,
        threshold: Float,
    ) -> (Vec<usize>, Vec<usize>) {
        let mut left_indices = Vec::new();
        let mut right_indices = Vec::new();

        for &idx in indices {
            if X[[idx, feature_idx]] <= threshold {
                left_indices.push(idx);
            } else {
                right_indices.push(idx);
            }
        }

        (left_indices, right_indices)
    }

    fn calculate_feature_importances(
        &self,
        node: &DecisionNode,
        importances: &mut Array1<Float>,
        total_samples: Float,
    ) {
        if let (Some(feature_idx), Some(left), Some(right)) =
            (node.feature_idx, &node.left, &node.right)
        {
            // Simplified importance: the sample-weighted variance at the node
            // itself, without subtracting the children's variances as the
            // classical impurity-decrease formulation would.
            let importance = (node.n_samples as Float / total_samples) * node.variance;
            importances[feature_idx] += importance;

            self.calculate_feature_importances(left, importances, total_samples);
            self.calculate_feature_importances(right, importances, total_samples);
        }
    }
}

impl MultiTargetRegressionTree<MultiTargetRegressionTreeTrained> {
    /// Get the feature importances
    pub fn feature_importances(&self) -> &Array1<Float> {
        &self.state.feature_importances
    }

    /// Get the number of features
    pub fn n_features(&self) -> usize {
        self.state.n_features
    }

    /// Get the number of targets
    pub fn n_targets(&self) -> usize {
        self.state.n_targets
    }

    /// Predict using the fitted tree
    #[allow(non_snake_case)]
    pub fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = *X;
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features doesn't match training data".to_string(),
            ));
        }

        let mut predictions = Array2::<Float>::zeros((n_samples, self.state.n_targets));

        for i in 0..n_samples {
            let sample = X.slice(s![i, ..]);
            let prediction = self.predict_single(&self.state.tree, &sample)?;
            for j in 0..self.state.n_targets {
                predictions[[i, j]] = prediction[j];
            }
        }

        Ok(predictions)
    }

    fn predict_single(
        &self,
        node: &DecisionNode,
        sample: &ArrayView1<'_, Float>,
    ) -> SklResult<Array1<Float>> {
        if node.is_leaf {
            if let Some(ref prediction) = node.prediction {
                Ok(prediction.clone())
            } else {
                Err(SklearsError::InvalidInput(
                    "Leaf node without prediction".to_string(),
                ))
            }
        } else {
            let feature_idx = node.feature_idx.ok_or_else(|| {
                SklearsError::InvalidInput("Non-leaf node without feature index".to_string())
            })?;
            let threshold = node.threshold.ok_or_else(|| {
                SklearsError::InvalidInput("Non-leaf node without threshold".to_string())
            })?;

            if sample[feature_idx] <= threshold {
                if let Some(ref left) = node.left {
                    self.predict_single(left, sample)
                } else {
                    Err(SklearsError::InvalidInput(
                        "Non-leaf node without left child".to_string(),
                    ))
                }
            } else if let Some(ref right) = node.right {
                self.predict_single(right, sample)
            } else {
                Err(SklearsError::InvalidInput(
                    "Non-leaf node without right child".to_string(),
                ))
            }
        }
    }
}

/// Multi-Target Decision Tree Classifier
///
/// A decision tree classifier that can handle multiple target variables simultaneously.
/// Uses joint entropy/Gini reduction for optimal splits across all targets.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::MultiTargetDecisionTreeClassifier;
/// use sklears_core::traits::{Predict, Fit};
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
/// let y = array![[0, 1], [1, 0], [1, 1], [0, 0]]; // Two binary classification targets
///
/// let tree = MultiTargetDecisionTreeClassifier::new()
///     .max_depth(Some(3));
/// let trained_tree = tree.fit(&X.view(), &y).unwrap();
/// let predictions = trained_tree.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct MultiTargetDecisionTreeClassifier<S = Untrained> {
    state: S,
    max_depth: Option<usize>,
    min_samples_split: usize,
    min_samples_leaf: usize,
    criterion: ClassificationCriterion,
    random_state: Option<u64>,
}

#[derive(Debug, Clone)]
pub struct MultiTargetDecisionTreeClassifierTrained {
    tree: ClassificationDecisionNode,
    n_features: usize,
    n_targets: usize,
    feature_importances: Array1<Float>,
    classes_per_target: Vec<Vec<i32>>,
}

impl MultiTargetDecisionTreeClassifier<Untrained> {
    /// Create a new MultiTargetDecisionTreeClassifier instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            max_depth: Some(5),
            min_samples_split: 2,
            min_samples_leaf: 1,
            criterion: ClassificationCriterion::Gini,
            random_state: None,
        }
    }

    /// Set the maximum depth of the tree
    pub fn max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Set the minimum number of samples required to split an internal node
    pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Set the minimum number of samples required to be at a leaf node
    pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Set the split criterion
    pub fn criterion(mut self, criterion: ClassificationCriterion) -> Self {
        self.criterion = criterion;
        self
    }

    /// Set the random state for reproducible results
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }
}

impl Default for MultiTargetDecisionTreeClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MultiTargetDecisionTreeClassifier<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>> for MultiTargetDecisionTreeClassifier<Untrained> {
    type Fitted = MultiTargetDecisionTreeClassifier<MultiTargetDecisionTreeClassifierTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let n_targets = y.ncols();
        if n_targets == 0 {
            return Err(SklearsError::InvalidInput(
                "y must have at least one target".to_string(),
            ));
        }

        // Get unique classes for each target
        let mut classes_per_target = Vec::new();
        for target_idx in 0..n_targets {
            let target_column = y.column(target_idx);
            let mut unique_classes: Vec<i32> = target_column.iter().cloned().collect();
            unique_classes.sort_unstable();
            unique_classes.dedup();
            classes_per_target.push(unique_classes);
        }

        // Initialize feature importances
        let mut feature_importances = Array1::<Float>::zeros(n_features);

        // Build the tree
        let indices: Vec<usize> = (0..n_samples).collect();
        let tree = build_classification_tree(
            &X,
            y,
            &indices,
            &mut feature_importances,
            0,
            self.max_depth,
            self.min_samples_split,
            self.min_samples_leaf,
            self.criterion,
            &classes_per_target,
        )?;

        // Normalize feature importances
        let importance_sum = feature_importances.sum();
        if importance_sum > 0.0 {
            feature_importances /= importance_sum;
        }

        let trained_state = MultiTargetDecisionTreeClassifierTrained {
            tree,
            n_features,
            n_targets,
            feature_importances,
            classes_per_target,
        };

        Ok(MultiTargetDecisionTreeClassifier {
            state: trained_state,
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
            criterion: self.criterion,
            random_state: self.random_state,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for MultiTargetDecisionTreeClassifier<MultiTargetDecisionTreeClassifierTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let (n_samples, n_features) = X.dim();
        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_targets));

        for i in 0..n_samples {
            let sample = X.row(i);
            let prediction = predict_classification_sample(&self.state.tree, &sample);
            for j in 0..self.state.n_targets {
                predictions[[i, j]] = prediction[j];
            }
        }

        Ok(predictions)
    }
}

impl MultiTargetDecisionTreeClassifier<MultiTargetDecisionTreeClassifierTrained> {
    /// Get feature importances
    pub fn feature_importances(&self) -> &Array1<Float> {
        &self.state.feature_importances
    }

    /// Predict class probabilities for each target
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Vec<Array2<Float>>> {
        let (n_samples, n_features) = X.dim();
        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut all_probabilities = Vec::new();

        // Initialize probability arrays for each target
        for target_idx in 0..self.state.n_targets {
            let n_classes = self.state.classes_per_target[target_idx].len();
            all_probabilities.push(Array2::<Float>::zeros((n_samples, n_classes)));
        }

        for i in 0..n_samples {
            let sample = X.row(i);
            let probabilities = predict_classification_probabilities(
                &self.state.tree,
                &sample,
                &self.state.classes_per_target,
            );

            for (target_idx, target_probs) in probabilities.iter().enumerate() {
                for (class_idx, &prob) in target_probs.iter().enumerate() {
                    all_probabilities[target_idx][[i, class_idx]] = prob;
                }
            }
        }

        Ok(all_probabilities)
    }
}
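
// Sketch of consuming `predict_proba` (illustrative, reusing the data from the
// classifier doc example above): the result is one probability matrix per
// target, each of shape (n_samples, n_classes_for_that_target).
//
//     let probas = trained_tree.predict_proba(&X.view()).unwrap();
//     for (target_idx, target_probas) in probas.iter().enumerate() {
//         // Each row holds the class probabilities for the classes of that
//         // target, in the same sorted order discovered during `fit`.
//         println!("target {target_idx}: {:?}", target_probas.row(0));
//     }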

/// Random Forest Multi-Output Extension
///
/// A random forest that can handle multiple output variables simultaneously.
/// This implementation creates multiple multi-target regression trees and
/// averages their predictions for robust multi-output regression.
///
/// # Mathematical Foundation
///
/// The random forest combines multiple multi-target regression trees:
/// - Each tree is trained on a bootstrap sample of the data
/// - Per-split feature subsampling is configurable via `max_features`, though
///   the underlying trees currently evaluate every feature at each split
/// - Final prediction is the average of all tree predictions
/// - Feature importance is averaged across all trees
///
/// # Examples
///
/// ```
/// use sklears_multioutput::RandomForestMultiOutput;
/// use sklears_core::traits::{Predict, Fit};
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
/// let y = array![[1.5, 2.5], [2.5, 3.5], [3.5, 1.5], [4.5, 4.5]];
///
/// let forest = RandomForestMultiOutput::new()
///     .n_estimators(10)
///     .max_depth(Some(3));
/// let trained_forest = forest.fit(&X.view(), &y).unwrap();
/// let predictions = trained_forest.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct RandomForestMultiOutput<S = Untrained> {
    state: S,
    n_estimators: usize,
    max_depth: Option<usize>,
    min_samples_split: usize,
    min_samples_leaf: usize,
    max_features: Option<usize>,
    bootstrap: bool,
    random_state: Option<u64>,
}

#[derive(Debug, Clone)]
pub struct RandomForestMultiOutputTrained {
    trees: Vec<MultiTargetRegressionTree<MultiTargetRegressionTreeTrained>>,
    n_features: usize,
    n_targets: usize,
    feature_importances: Array1<Float>,
}

impl RandomForestMultiOutput<Untrained> {
    /// Create a new RandomForestMultiOutput instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_estimators: 10,
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: None,
            bootstrap: true,
            random_state: None,
        }
    }

    /// Set the number of trees in the forest
    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
        self.n_estimators = n_estimators;
        self
    }

    /// Set the maximum depth of the trees
    pub fn max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Set the minimum number of samples required to split an internal node
    pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Set the minimum number of samples required to be at a leaf node
    pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Set the number of features to consider when looking for the best split
    pub fn max_features(mut self, max_features: Option<usize>) -> Self {
        self.max_features = max_features;
        self
    }

    /// Set whether to use bootstrap samples when building trees
    pub fn bootstrap(mut self, bootstrap: bool) -> Self {
        self.bootstrap = bootstrap;
        self
    }

    /// Set the random state for reproducible results
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    /// Get the number of trees in the forest
    pub fn get_n_estimators(&self) -> usize {
        self.n_estimators
    }

    /// Get the maximum depth of the trees
    pub fn get_max_depth(&self) -> Option<usize> {
        self.max_depth
    }

    /// Get the minimum number of samples required to split an internal node
    pub fn get_min_samples_split(&self) -> usize {
        self.min_samples_split
    }

    /// Get the minimum number of samples required to be at a leaf node
    pub fn get_min_samples_leaf(&self) -> usize {
        self.min_samples_leaf
    }

    /// Get the maximum number of features to consider when looking for the best split
    pub fn get_max_features(&self) -> Option<usize> {
        self.max_features
    }

    /// Get whether bootstrap samples are used when building trees
    pub fn get_bootstrap(&self) -> bool {
        self.bootstrap
    }

    /// Get the random state
    pub fn get_random_state(&self) -> Option<u64> {
        self.random_state
    }
}

impl Default for RandomForestMultiOutput<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for RandomForestMultiOutput<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<Float>> for RandomForestMultiOutput<Untrained> {
    type Fitted = RandomForestMultiOutput<RandomForestMultiOutputTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<Float>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let n_targets = y.ncols();
        if n_targets == 0 {
            return Err(SklearsError::InvalidInput(
                "y must have at least one target".to_string(),
            ));
        }

        let mut trees = Vec::new();
        let mut feature_importances = Array1::<Float>::zeros(n_features);

        for i in 0..self.n_estimators {
            // Create bootstrap sample if needed
            let (X_sample, y_sample) = if self.bootstrap {
                self.create_bootstrap_sample(&X, y, i)?
            } else {
                (X.clone(), y.clone())
            };

            // Create and train tree
            let tree = MultiTargetRegressionTree::new()
                .max_depth(self.max_depth)
                .min_samples_split(self.min_samples_split)
                .min_samples_leaf(self.min_samples_leaf)
                .random_state(self.random_state.map(|s| s.wrapping_add(i as u64)));

            let trained_tree = tree.fit(&X_sample.view(), &y_sample)?;

            // Accumulate feature importances
            feature_importances += trained_tree.feature_importances();

            trees.push(trained_tree);
        }

        // Average feature importances
        feature_importances /= self.n_estimators as Float;

        Ok(RandomForestMultiOutput {
            state: RandomForestMultiOutputTrained {
                trees,
                n_features,
                n_targets,
                feature_importances,
            },
            n_estimators: self.n_estimators,
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
            max_features: self.max_features,
            bootstrap: self.bootstrap,
            random_state: self.random_state,
        })
    }
}

impl RandomForestMultiOutput<Untrained> {
    fn create_bootstrap_sample(
        &self,
        X: &Array2<Float>,
        y: &Array2<Float>,
        seed: usize,
    ) -> SklResult<(Array2<Float>, Array2<Float>)> {
        let n_samples = X.nrows();
        // Simple linear congruential generator (glibc-style constants), seeded
        // from the configured random state plus the tree index so that each
        // tree draws a different bootstrap sample.
        let mut rng_state = self.random_state.unwrap_or(42).wrapping_add(seed as u64);

        let mut X_sample = Array2::<Float>::zeros(X.raw_dim());
        let mut y_sample = Array2::<Float>::zeros(y.raw_dim());

        for i in 0..n_samples {
            rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
            // Discard the low-order bits, which cycle quickly in an LCG.
            let idx = (rng_state / 65536) % (n_samples as u64);

            X_sample
                .slice_mut(s![i, ..])
                .assign(&X.slice(s![idx as usize, ..]));
            y_sample
                .slice_mut(s![i, ..])
                .assign(&y.slice(s![idx as usize, ..]));
        }

        Ok((X_sample, y_sample))
    }
}

impl RandomForestMultiOutput<RandomForestMultiOutputTrained> {
    /// Get the feature importances averaged across all trees
    pub fn feature_importances(&self) -> &Array1<Float> {
        &self.state.feature_importances
    }

    /// Get the number of estimators (trees)
    pub fn n_estimators(&self) -> usize {
        self.state.trees.len()
    }

    /// Get the number of features
    pub fn n_features(&self) -> usize {
        self.state.n_features
    }

    /// Get the number of targets
    pub fn n_targets(&self) -> usize {
        self.state.n_targets
    }

    /// Predict using the forest (average of all tree predictions)
    #[allow(non_snake_case)]
    pub fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = *X;
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features doesn't match training data".to_string(),
            ));
        }

        let mut predictions = Array2::<Float>::zeros((n_samples, self.state.n_targets));

        // Average predictions from all trees
        for tree in &self.state.trees {
            let tree_predictions = tree.predict(&X)?;
            predictions += &tree_predictions;
        }

        predictions /= self.state.trees.len() as Float;
        Ok(predictions)
    }
}

/// Tree-Structured Prediction
///
/// A structured prediction method for outputs that follow a tree structure.
/// Each internal node in the tree has a classifier that decides which branch to take,
/// enabling hierarchical multi-label classification.
///
/// # Examples
///
/// ```
/// use sklears_core::traits::{Predict, Fit};
/// use sklears_multioutput::TreeStructuredPredictor;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0]];
/// let y = array![[0, 1, 2], [1, 2, 0]]; // Each row is a path of node ids through the tree
///
/// let tree_predictor = TreeStructuredPredictor::new()
///     .max_depth(3)
///     .branching_factor(3);
/// let trained_predictor = tree_predictor.fit(&X, &y).unwrap();
/// let predictions = trained_predictor.predict(&X).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct TreeStructuredPredictor<State = Untrained> {
    max_depth: usize,
    branching_factor: usize,
    tree_structure: Vec<Vec<usize>>, // Adjacency list representation
    // Builder-side placeholder; the trained classifiers live in the fitted
    // state's `node_classifiers` map.
    node_classifiers: HashMap<usize, String>,
    state: State,
}

/// Trained state for Tree-Structured Predictor
#[derive(Debug, Clone)]
pub struct TreeStructuredPredictorTrained {
    node_classifiers: HashMap<usize, BinaryRelevance<BinaryRelevanceTrained>>,
    tree_structure: Vec<Vec<usize>>,
    max_depth: usize,
    n_nodes: usize,
}


impl Default for TreeStructuredPredictor<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl TreeStructuredPredictor<Untrained> {
    /// Create new tree-structured predictor
    pub fn new() -> Self {
        Self {
            max_depth: 5,
            branching_factor: 2,
            tree_structure: Vec::new(),
            node_classifiers: HashMap::new(),
            state: Untrained,
        }
    }

    /// Set maximum tree depth
    pub fn max_depth(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

    /// Set branching factor
    pub fn branching_factor(mut self, factor: usize) -> Self {
        self.branching_factor = factor;
        self
    }

    /// Set custom tree structure
    pub fn tree_structure(mut self, structure: Vec<Vec<usize>>) -> Self {
        self.tree_structure = structure;
        self
    }
}
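
// Example of supplying a custom structure (a hedged sketch): the adjacency
// list maps each node id to the ids of its children, so the value below
// encodes a root (node 0) with two leaf children (nodes 1 and 2).
//
//     let predictor = TreeStructuredPredictor::new()
//         .tree_structure(vec![
//             vec![1, 2], // node 0 -> children 1 and 2
//             vec![],     // node 1 is a leaf
//             vec![],     // node 2 is a leaf
//         ]);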

impl Estimator for TreeStructuredPredictor<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array2<i32>> for TreeStructuredPredictor<Untrained> {
    type Fitted = TreeStructuredPredictor<TreeStructuredPredictorTrained>;

    fn fit(self, X: &Array2<Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        let (n_samples, _n_features) = X.dim();
        let (y_samples, max_path_length) = y.dim();

        if n_samples != y_samples {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        // Build tree structure if not provided
        let tree_structure = if self.tree_structure.is_empty() {
            self.build_default_tree_structure()?
        } else {
            self.tree_structure.clone()
        };

        let n_nodes = tree_structure.len();
        let mut node_classifiers = HashMap::new();

        // Train a classifier for each internal node
        for node_id in 0..n_nodes {
            if !tree_structure[node_id].is_empty() {
                // Internal node: create binary classification data for it
                let (node_X, node_y) = self.create_node_training_data(
                    &X.view(),
                    &y.view(),
                    node_id,
                    &tree_structure,
                    max_path_length,
                )?;

                if !node_y.is_empty() {
                    let classifier = BinaryRelevance::new();
                    let trained_classifier = classifier.fit(&node_X.view(), &node_y)?;
                    node_classifiers.insert(node_id, trained_classifier);
                }
            }
        }

        Ok(TreeStructuredPredictor {
            max_depth: self.max_depth,
            branching_factor: self.branching_factor,
            tree_structure: tree_structure.clone(),
            // Builder-side placeholder; the trained classifiers are stored in
            // the fitted state below.
            node_classifiers: HashMap::new(),
            state: TreeStructuredPredictorTrained {
                node_classifiers,
                tree_structure,
                max_depth: self.max_depth,
                n_nodes,
            },
        })
    }
}
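
// Note on the label encoding consumed by `fit` (see the doc example above):
// each row of `y` is read as a sequence of node ids along a path through the
// tree. Whenever position `pos` of a row equals an internal node's id, that
// sample becomes a training example for the node's classifier, labeled with
// the index of the child it descends into at `pos + 1` (see
// `create_node_training_data` below).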

impl TreeStructuredPredictor<Untrained> {
    /// Build default tree structure
    fn build_default_tree_structure(&self) -> SklResult<Vec<Vec<usize>>> {
        let mut total_nodes = 0;
        for depth in 0..self.max_depth {
            total_nodes += self.branching_factor.pow(depth as u32);
        }

        let mut tree_structure = vec![Vec::new(); total_nodes];
        let mut node_id = 0;

        // Build a complete tree in level order: with branching factor b, the
        // children of node i are b*i + 1 ..= b*i + b.
        for depth in 0..self.max_depth.saturating_sub(1) {
            let nodes_at_depth = self.branching_factor.pow(depth as u32);

            for _ in 0..nodes_at_depth {
                for child in 0..self.branching_factor {
                    let child_id = node_id * self.branching_factor + 1 + child;
                    if child_id < total_nodes {
                        tree_structure[node_id].push(child_id);
                    }
                }
                node_id += 1;
            }
        }

        Ok(tree_structure)
    }

    /// Create training data for a specific node
    fn create_node_training_data(
        &self,
        X: &ArrayView2<Float>,
        y: &ArrayView2<i32>,
        node_id: usize,
        tree_structure: &[Vec<usize>],
        max_path_length: usize,
    ) -> SklResult<(Array2<Float>, Array2<i32>)> {
        let n_samples = X.nrows();
        let mut valid_samples = Vec::new();
        let mut node_labels = Vec::new();

        for sample_idx in 0..n_samples {
            let path = y.row(sample_idx);

            // Check if this sample's path passes through this node
            for pos in 0..max_path_length {
                if path[pos] as usize == node_id && pos + 1 < max_path_length {
                    // Sample passes through this node; determine which child to predict
                    let next_node = path[pos + 1] as usize;

                    // Find which child index this corresponds to
                    if let Some(child_idx) = tree_structure[node_id]
                        .iter()
                        .position(|&child| child == next_node)
                    {
                        valid_samples.push(sample_idx);
                        node_labels.push(child_idx as i32);
                        break;
                    }
                }
            }
        }

        // Create training data for this node
        let n_valid = valid_samples.len();
        if n_valid == 0 {
            return Ok((
                Array2::<Float>::zeros((0, X.ncols())),
                Array2::<i32>::zeros((0, 1)),
            ));
        }

        let mut node_X = Array2::<Float>::zeros((n_valid, X.ncols()));
        let mut node_y = Array2::<i32>::zeros((n_valid, 1));

        for (i, &sample_idx) in valid_samples.iter().enumerate() {
            for j in 0..X.ncols() {
                node_X[[i, j]] = X[[sample_idx, j]];
            }
            node_y[[i, 0]] = node_labels[i];
        }

        Ok((node_X, node_y))
    }
}

impl Predict<Array2<Float>, Array2<i32>>
    for TreeStructuredPredictor<TreeStructuredPredictorTrained>
{
    fn predict(&self, X: &Array2<Float>) -> SklResult<Array2<i32>> {
        let n_samples = X.nrows();
        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.max_depth));

        for sample_idx in 0..n_samples {
            let sample = X.row(sample_idx);
            let path = self.predict_tree_path(&sample)?;

            for (pos, &node) in path.iter().enumerate() {
                if pos < self.state.max_depth {
                    predictions[[sample_idx, pos]] = node as i32;
                }
            }
        }

        Ok(predictions)
    }
}

impl TreeStructuredPredictor<TreeStructuredPredictorTrained> {
    /// Predict tree path for a single sample
    fn predict_tree_path(&self, sample: &ArrayView1<Float>) -> SklResult<Vec<usize>> {
        let mut path = Vec::new();
        let mut current_node = 0; // Start at root
        path.push(current_node);

        while !self.state.tree_structure[current_node].is_empty() {
            // Internal node - predict which child to take
            if let Some(classifier) = self.state.node_classifiers.get(&current_node) {
                let sample_2d = sample.to_owned().insert_axis(scirs2_core::ndarray::Axis(0));
                let prediction = classifier.predict(&sample_2d.view())?;
                let child_idx = prediction[[0, 0]] as usize;

                if child_idx < self.state.tree_structure[current_node].len() {
                    current_node = self.state.tree_structure[current_node][child_idx];
                    path.push(current_node);
                } else {
                    break; // Invalid prediction
                }
            } else {
                break; // No classifier for this node
            }
        }

        Ok(path)
    }

    /// Get tree structure
    pub fn tree_structure(&self) -> &Vec<Vec<usize>> {
        &self.state.tree_structure
    }
}

// Supporting functions for tree algorithms

/// Build a classification decision tree recursively
pub fn build_classification_tree(
    X: &Array2<Float>,
    y: &Array2<i32>,
    indices: &[usize],
    feature_importances: &mut Array1<Float>,
    depth: usize,
    max_depth: Option<usize>,
    min_samples_split: usize,
    min_samples_leaf: usize,
    criterion: ClassificationCriterion,
    classes_per_target: &[Vec<i32>],
) -> SklResult<ClassificationDecisionNode> {
    let n_samples = indices.len();

    // Calculate current impurity and predictions
    let (current_impurity, prediction, probabilities) =
        calculate_classification_metrics(y, indices, classes_per_target, criterion);

    // Check stopping criteria
    let should_stop = n_samples < min_samples_split
        || max_depth.is_some_and(|max_d| depth >= max_d)
        || current_impurity == 0.0;

    if should_stop {
        return Ok(ClassificationDecisionNode {
            is_leaf: true,
            prediction: Some(prediction),
            probabilities: Some(probabilities),
            feature_idx: None,
            threshold: None,
            left: None,
            right: None,
            n_samples,
            impurity: current_impurity,
        });
    }

    // Find best split
    let mut best_impurity_reduction = 0.0;
    let mut best_feature = None;
    let mut best_threshold = None;
    let mut best_left_indices = Vec::new();
    let mut best_right_indices = Vec::new();

    for feature_idx in 0..X.ncols() {
        // Get unique values for this feature in current samples
        let mut feature_values: Vec<Float> = indices.iter().map(|&i| X[[i, feature_idx]]).collect();
        feature_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
        feature_values.dedup();

        // Try splits between consecutive unique values
        for i in 0..feature_values.len().saturating_sub(1) {
            let threshold = (feature_values[i] + feature_values[i + 1]) / 2.0;

            let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
                .iter()
                .partition(|&&idx| X[[idx, feature_idx]] <= threshold);

            // Check minimum samples per leaf
            if left_indices.len() < min_samples_leaf || right_indices.len() < min_samples_leaf {
                continue;
            }

            // Calculate impurity reduction
            let (left_impurity, _, _) =
                calculate_classification_metrics(y, &left_indices, classes_per_target, criterion);
            let (right_impurity, _, _) =
                calculate_classification_metrics(y, &right_indices, classes_per_target, criterion);

            let weighted_impurity = (left_indices.len() as Float * left_impurity
                + right_indices.len() as Float * right_impurity)
                / n_samples as Float;
            let impurity_reduction = current_impurity - weighted_impurity;

            if impurity_reduction > best_impurity_reduction {
                best_impurity_reduction = impurity_reduction;
                best_feature = Some(feature_idx);
                best_threshold = Some(threshold);
                best_left_indices = left_indices;
                best_right_indices = right_indices;
            }
        }
    }

    // If no good split found, create leaf
    if best_feature.is_none() || best_impurity_reduction <= 0.0 {
        return Ok(ClassificationDecisionNode {
            is_leaf: true,
            prediction: Some(prediction),
            probabilities: Some(probabilities),
            feature_idx: None,
            threshold: None,
            left: None,
            right: None,
            n_samples,
            impurity: current_impurity,
        });
    }

    // Update feature importance
    feature_importances[best_feature.unwrap()] += best_impurity_reduction * n_samples as Float;

    // Recursively build left and right subtrees
    let left_child = build_classification_tree(
        X,
        y,
        &best_left_indices,
        feature_importances,
        depth + 1,
        max_depth,
        min_samples_split,
        min_samples_leaf,
        criterion,
        classes_per_target,
    )?;

    let right_child = build_classification_tree(
        X,
        y,
        &best_right_indices,
        feature_importances,
        depth + 1,
        max_depth,
        min_samples_split,
        min_samples_leaf,
        criterion,
        classes_per_target,
    )?;

    Ok(ClassificationDecisionNode {
        is_leaf: false,
        prediction: Some(prediction),
        probabilities: Some(probabilities),
        feature_idx: best_feature,
        threshold: best_threshold,
        left: Some(Box::new(left_child)),
        right: Some(Box::new(right_child)),
        n_samples,
        impurity: current_impurity,
    })
}
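
// `build_classification_tree` is normally driven by
// `MultiTargetDecisionTreeClassifier::fit` above; leaves store both the
// per-target majority-class prediction and the full probability table
// produced by `calculate_classification_metrics` below.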

/// Calculate classification metrics for a set of samples
pub fn calculate_classification_metrics(
    y: &Array2<i32>,
    indices: &[usize],
    classes_per_target: &[Vec<i32>],
    criterion: ClassificationCriterion,
) -> (Float, Array1<i32>, Array2<Float>) {
    let n_targets = y.ncols();
    let n_samples = indices.len();

    let mut prediction = Array1::<i32>::zeros(n_targets);
    let mut total_impurity = 0.0;

    // Size the probability matrix by the largest class count across targets;
    // targets with fewer classes leave their trailing columns at zero.
    let max_classes = classes_per_target
        .iter()
        .map(|classes| classes.len())
        .max()
        .unwrap_or(0);
    let mut probabilities = Array2::<Float>::zeros((n_targets, max_classes));

    for target_idx in 0..n_targets {
        let classes = &classes_per_target[target_idx];
        let n_classes = classes.len();

        // Count class frequencies
        let mut class_counts = vec![0; n_classes];
        for &sample_idx in indices {
            let class_label = y[[sample_idx, target_idx]];
            if let Some(class_idx) = classes.iter().position(|&c| c == class_label) {
                class_counts[class_idx] += 1;
            }
        }

        // Find majority class
        let majority_class_idx = class_counts
            .iter()
            .enumerate()
            .max_by_key(|(_, &count)| count)
            .map(|(idx, _)| idx)
            .unwrap_or(0);

        prediction[target_idx] = classes[majority_class_idx];

        // Calculate probabilities and impurity. Summing p * (1 - p) over the
        // classes already yields the Gini impurity (1 - sum of p^2), so no
        // further scaling is needed; entropy uses the natural log.
        let mut target_impurity = 0.0;
        for (class_idx, &count) in class_counts.iter().enumerate() {
            let prob = count as Float / n_samples as Float;
            probabilities[[target_idx, class_idx]] = prob;

            if prob > 0.0 {
                target_impurity += match criterion {
                    ClassificationCriterion::Gini => prob * (1.0 - prob),
                    ClassificationCriterion::Entropy => -prob * prob.ln(),
                };
            }
        }

        total_impurity += target_impurity;
    }

    // Average impurity across targets
    total_impurity /= n_targets as Float;

    (total_impurity, prediction, probabilities)
}
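
// A small worked example (hedged sketch): for a single binary target with
// labels [0, 0, 1, 1] under the Gini criterion, the class probabilities are
// [0.5, 0.5], so the returned impurity is 0.5 * 0.5 + 0.5 * 0.5 = 0.5
// (= 1 - 0.25 - 0.25), the prediction is class 1 (ties resolve to the last
// maximal count), and the probability row is [0.5, 0.5].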

/// Predict for a single classification sample
pub fn predict_classification_sample(
    node: &ClassificationDecisionNode,
    sample: &ArrayView1<Float>,
) -> Array1<i32> {
    if node.is_leaf {
        return node.prediction.as_ref().unwrap().clone();
    }

    let feature_value = sample[node.feature_idx.unwrap()];
    let threshold = node.threshold.unwrap();

    if feature_value <= threshold {
        predict_classification_sample(node.left.as_ref().unwrap(), sample)
    } else {
        predict_classification_sample(node.right.as_ref().unwrap(), sample)
    }
}

/// Predict probabilities for a single classification sample
pub fn predict_classification_probabilities(
    node: &ClassificationDecisionNode,
    sample: &ArrayView1<Float>,
    classes_per_target: &[Vec<i32>],
) -> Vec<Array1<Float>> {
    if node.is_leaf {
        let mut result = Vec::new();
        for target_idx in 0..classes_per_target.len() {
            let n_classes = classes_per_target[target_idx].len();
            let mut target_probs = Array1::<Float>::zeros(n_classes);
            for class_idx in 0..n_classes {
                target_probs[class_idx] =
                    node.probabilities.as_ref().unwrap()[[target_idx, class_idx]];
            }
            result.push(target_probs);
        }
        return result;
    }

    let feature_value = sample[node.feature_idx.unwrap()];
    let threshold = node.threshold.unwrap();

    if feature_value <= threshold {
        predict_classification_probabilities(
            node.left.as_ref().unwrap(),
            sample,
            classes_per_target,
        )
    } else {
        predict_classification_probabilities(
            node.right.as_ref().unwrap(),
            sample,
            classes_per_target,
        )
    }
}
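
// A minimal smoke test (a sketch assuming the standard `cargo test` harness):
// fits the regression tree on the tiny dataset from its doc example and
// checks the prediction shape.
#[cfg(test)]
mod tree_smoke_tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn multi_target_regression_tree_shapes() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
        let y = array![[1.5, 2.5], [2.5, 3.5], [3.5, 1.5], [4.5, 4.5]];

        let tree = MultiTargetRegressionTree::new().max_depth(Some(3));
        let trained = tree.fit(&x.view(), &y).unwrap();
        let preds = trained.predict(&x.view()).unwrap();

        // One prediction row per sample, one column per target.
        assert_eq!(preds.dim(), (4, 2));
        assert_eq!(trained.n_targets(), 2);
    }
}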