// sklears_tree/config.rs

1//! Configuration types and enums for decision trees
2//!
3//! This module contains all the configuration enums, structs, and parameters
4//! used by decision tree classifiers and regressors.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use sklears_core::types::Float;
8use smartcore::linalg::basic::matrix::DenseMatrix;
9
10#[cfg(feature = "oblique")]
11use sklears_core::error::Result;
12
13// Import types from criteria module
14use crate::criteria::SplitCriterion;
15
/// Monotonic constraint for a feature
///
/// Constrains the direction of the learned relationship between a single
/// feature and the target during split selection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MonotonicConstraint {
    /// No constraint on the relationship
    None,
    /// Feature must have increasing relationship with target (positive monotonicity)
    Increasing,
    /// Feature must have decreasing relationship with target (negative monotonicity)
    Decreasing,
}
26
/// Interaction constraint between features
///
/// Restricts which features may be combined along the same decision path.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InteractionConstraint {
    /// No constraints on feature interactions
    None,
    /// Allow interactions only within specified groups
    Groups(Vec<Vec<usize>>),
    /// Forbid specific feature pairs from interacting
    Forbidden(Vec<(usize, usize)>),
    /// Allow only specific feature pairs to interact
    Allowed(Vec<(usize, usize)>),
}
39
/// Feature grouping strategy for handling correlated features
///
/// When features are highly correlated, grouping lets the tree work with one
/// representative member per group (or weighted members) instead of every
/// feature separately.
#[derive(Debug, Clone)]
pub enum FeatureGrouping {
    /// No feature grouping (default)
    None,
    /// Automatic grouping based on correlation threshold
    AutoCorrelation {
        /// Correlation threshold above which features are grouped together
        /// (presumably compared against pairwise correlation magnitude —
        /// confirm against the grouping implementation)
        threshold: Float,
        /// Method to select representative feature from each group
        selection_method: GroupSelectionMethod,
    },
    /// Manual feature groups specified by user
    Manual {
        /// List of feature groups, each group contains feature indices
        groups: Vec<Vec<usize>>,
        /// Method to select representative feature from each group
        selection_method: GroupSelectionMethod,
    },
    /// Hierarchical clustering-based grouping
    Hierarchical {
        /// Number of clusters to create
        n_clusters: usize,
        /// Linkage method for hierarchical clustering
        linkage: LinkageMethod,
        /// Method to select representative feature from each group
        selection_method: GroupSelectionMethod,
    },
}
69
/// Method for selecting representative feature from a group
///
/// Used by [`FeatureGrouping`] variants to pick which member of each
/// correlated group participates in split evaluation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GroupSelectionMethod {
    /// Select feature with highest variance within the group
    MaxVariance,
    /// Select feature with highest correlation to target
    MaxTargetCorrelation,
    /// Select first feature in the group (index order)
    First,
    /// Select random feature from the group
    Random,
    /// Use all features from the group but with reduced weight
    WeightedAll,
}
84
/// Linkage method for hierarchical clustering
///
/// Selects how inter-cluster distance is computed when
/// [`FeatureGrouping::Hierarchical`] is used.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LinkageMethod {
    /// Single linkage (minimum distance)
    Single,
    /// Complete linkage (maximum distance)
    Complete,
    /// Average linkage
    Average,
    /// Ward linkage (minimize within-cluster variance)
    Ward,
}
97
/// Information about feature groups discovered or specified
///
/// Produced by the grouping step; `representatives` appears intended to be
/// parallel to `groups` (one representative per group) — confirm in the
/// code that constructs this struct.
#[derive(Debug, Clone)]
pub struct FeatureGroupInfo {
    /// Groups of correlated features
    pub groups: Vec<Vec<usize>>,
    /// Representative feature index for each group
    pub representatives: Vec<usize>,
    /// Correlation matrix used for grouping (if applicable)
    pub correlation_matrix: Option<Array2<Float>>,
    /// Within-group correlations for each group
    pub group_correlations: Vec<Float>,
}
110
/// Strategy for selecting max features
///
/// Controls how many candidate features are considered at each split.
/// No `Eq`/`Hash`: the `Fraction` variant holds an `f64`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MaxFeatures {
    /// Use all features
    All,
    /// Use sqrt(n_features)
    Sqrt,
    /// Use log2(n_features)
    Log2,
    /// Use a specific number of features
    Number(usize),
    /// Use a fraction of features
    Fraction(f64),
}
125
/// Pruning strategy for decision trees
///
/// No `Eq`: `CostComplexity` carries an `f64` parameter.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PruningStrategy {
    /// No pruning
    None,
    /// Cost-complexity pruning (post-pruning)
    CostComplexity {
        /// Complexity parameter; larger values prune more aggressively
        alpha: f64,
    },
    /// Reduced error pruning
    ReducedError,
}
136
/// Missing value handling strategy
///
/// Determines how samples with missing feature values are treated during
/// split evaluation and prediction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MissingValueStrategy {
    /// Skip samples with missing values
    Skip,
    /// Use majority class/mean for splits
    Majority,
    /// Use surrogate splits
    Surrogate,
}
147
/// Feature type specification for multiway splits
///
/// Declares per-feature split semantics: continuous features get binary
/// threshold splits, categorical features get one branch per category
/// subset.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FeatureType {
    /// Continuous numerical feature (binary splits)
    Continuous,
    /// Categorical feature with specified categories (multiway splits)
    Categorical(Vec<String>),
}
156
/// Information about a multiway split
///
/// Unlike a binary split, a multiway split on a categorical feature routes
/// samples into one branch per category subset.
#[derive(Debug, Clone)]
pub struct MultiwaySplit {
    /// Feature index
    pub feature_idx: usize,
    /// Category assignments for each branch (one `Vec<String>` per branch)
    pub category_branches: Vec<Vec<String>>,
    /// Impurity decrease achieved by this split
    pub impurity_decrease: f64,
}
167
/// Tree growing strategy
///
/// Selects the order in which candidate nodes are expanded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeGrowingStrategy {
    /// Depth-first growing (traditional CART)
    DepthFirst,
    /// Best-first growing (expand node with highest impurity decrease)
    BestFirst {
        /// Optional cap on the number of leaves; `None` leaves it uncapped
        max_leaves: Option<usize>,
    },
}
176
/// Split type for decision trees
///
/// Chooses between axis-aligned thresholds and oblique (hyperplane) splits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SplitType {
    /// Traditional axis-aligned splits (threshold on single feature)
    AxisAligned,
    /// Linear hyperplane splits (linear combination of features)
    Oblique {
        /// Number of random hyperplanes to evaluate per split
        n_hyperplanes: usize,
        /// Use ridge regression to find optimal hyperplane
        use_ridge: bool,
    },
}
190
/// Hyperplane split information for oblique trees
///
/// A sample is routed to the "true" branch when
/// `coefficients · x + bias >= threshold` (see [`HyperplaneSplit::evaluate`]).
#[derive(Debug, Clone)]
pub struct HyperplaneSplit {
    /// Feature coefficients for the hyperplane (w^T x >= threshold)
    pub coefficients: Array1<f64>,
    /// Threshold for the hyperplane split
    pub threshold: f64,
    /// Bias term for the hyperplane
    pub bias: f64,
    /// Impurity decrease achieved by this split
    pub impurity_decrease: f64,
}
203
204impl HyperplaneSplit {
205    /// Evaluate the hyperplane split for a sample
206    pub fn evaluate(&self, sample: &Array1<f64>) -> bool {
207        let dot_product = self.coefficients.dot(sample) + self.bias;
208        dot_product >= self.threshold
209    }
210
211    /// Create a random hyperplane with normalized coefficients
212    pub fn random(n_features: usize, rng: &mut scirs2_core::CoreRandom) -> Self {
213        let mut coefficients = Array1::zeros(n_features);
214        for i in 0..n_features {
215            coefficients[i] = rng.gen_range(-1.0..1.0);
216        }
217
218        // Normalize coefficients
219        let dot_product: f64 = coefficients.dot(&coefficients);
220        let norm = dot_product.sqrt();
221        if norm > 1e-10_f64 {
222            coefficients /= norm;
223        }
224
225        Self {
226            coefficients,
227            threshold: rng.gen_range(-1.0..1.0),
228            bias: rng.gen_range(-0.1..0.1),
229            impurity_decrease: 0.0,
230        }
231    }
232
233    /// Find optimal hyperplane using ridge regression
234    #[cfg(feature = "oblique")]
235    pub fn from_ridge_regression(x: &Array2<f64>, y: &Array1<f64>, alpha: f64) -> Result<Self> {
236        use scirs2_core::ndarray::s;
237        use sklears_core::error::SklearsError;
238
239        let n_features = x.ncols();
240        if x.nrows() < 2 {
241            return Err(SklearsError::InvalidInput(
242                "Need at least 2 samples for ridge regression".to_string(),
243            ));
244        }
245
246        // Add bias column to X
247        let mut x_bias = Array2::ones((x.nrows(), n_features + 1));
248        x_bias.slice_mut(s![.., ..n_features]).assign(x);
249
250        // Ridge regression: w = (X^T X + α I)^(-1) X^T y
251        let xtx = x_bias.t().dot(&x_bias);
252        let ridge_matrix = xtx + Array2::<f64>::eye(n_features + 1) * alpha;
253        let xty = x_bias.t().dot(y);
254
255        // Simple matrix inverse using Gauss-Jordan elimination
256        match gauss_jordan_inverse(&ridge_matrix) {
257            Ok(inv_matrix) => {
258                let coefficients_full = inv_matrix.dot(&xty);
259
260                let coefficients = coefficients_full.slice(s![..n_features]).to_owned();
261                let bias = coefficients_full[n_features];
262
263                Ok(Self {
264                    coefficients,
265                    threshold: 0.0, // Will be set during split evaluation
266                    bias,
267                    impurity_decrease: 0.0,
268                })
269            }
270            Err(_) => {
271                // Fallback to random hyperplane if matrix is singular
272                let mut rng = scirs2_core::random::thread_rng();
273                Ok(Self::random(n_features, &mut rng))
274            }
275        }
276    }
277}
278
/// Configuration for Decision Trees
///
/// Bundles every tunable used by the decision tree classifier and
/// regressor; [`Default`] provides the baseline values.
#[derive(Debug, Clone)]
pub struct DecisionTreeConfig {
    /// Split criterion
    pub criterion: SplitCriterion,
    /// Maximum depth of the tree (`None` = no explicit cap requested;
    /// how that is enforced lives in the tree-building code — confirm there)
    pub max_depth: Option<usize>,
    /// Minimum samples required to split an internal node
    pub min_samples_split: usize,
    /// Minimum samples required to be at a leaf node
    pub min_samples_leaf: usize,
    /// Maximum number of features to consider for splits
    pub max_features: MaxFeatures,
    /// Random seed for reproducibility (`None` = no fixed seed)
    pub random_state: Option<u64>,
    /// Minimum weighted fraction of samples required to be at a leaf
    pub min_weight_fraction_leaf: f64,
    /// Minimum impurity decrease required for a split
    pub min_impurity_decrease: f64,
    /// Pruning strategy to apply
    pub pruning: PruningStrategy,
    /// Strategy for handling missing values
    pub missing_values: MissingValueStrategy,
    /// Feature types for each feature (enables multiway splits for categorical features)
    pub feature_types: Option<Vec<FeatureType>>,
    /// Tree growing strategy
    pub growing_strategy: TreeGrowingStrategy,
    /// Split type (axis-aligned or oblique)
    pub split_type: SplitType,
    /// Monotonic constraints for each feature (presumably one entry per
    /// feature when set — confirm with the fitting code)
    pub monotonic_constraints: Option<Vec<MonotonicConstraint>>,
    /// Interaction constraints between features
    pub interaction_constraints: InteractionConstraint,
    /// Feature grouping strategy for handling correlated features
    pub feature_grouping: FeatureGrouping,
}
315
316impl Default for DecisionTreeConfig {
317    fn default() -> Self {
318        Self {
319            criterion: SplitCriterion::Gini,
320            max_depth: None,
321            min_samples_split: 2,
322            min_samples_leaf: 1,
323            max_features: MaxFeatures::All,
324            random_state: None,
325            min_weight_fraction_leaf: 0.0,
326            min_impurity_decrease: 0.0,
327            pruning: PruningStrategy::None,
328            missing_values: MissingValueStrategy::Skip,
329            feature_types: None,
330            growing_strategy: TreeGrowingStrategy::DepthFirst,
331            split_type: SplitType::AxisAligned,
332            monotonic_constraints: None,
333            interaction_constraints: InteractionConstraint::None,
334            feature_grouping: FeatureGrouping::None,
335        }
336    }
337}
338
339/// Helper function to convert ndarray to DenseMatrix
340pub fn ndarray_to_dense_matrix(arr: &Array2<f64>) -> DenseMatrix<f64> {
341    let _rows = arr.nrows();
342    let _cols = arr.ncols();
343    let mut data = Vec::new();
344    for row in arr.outer_iter() {
345        data.push(row.to_vec());
346    }
347    DenseMatrix::from_2d_vec(&data).expect("Failed to convert ndarray to DenseMatrix")
348}
349
/// Simple Gauss-Jordan elimination for matrix inversion
///
/// Reduces the augmented system `[A | I]` to `[I | A^(-1)]` using partial
/// pivoting. Returns an error for non-square or numerically singular input.
#[cfg(feature = "oblique")]
fn gauss_jordan_inverse(matrix: &Array2<f64>) -> std::result::Result<Array2<f64>, &'static str> {
    let n = matrix.nrows();
    if n != matrix.ncols() {
        return Err("Matrix must be square");
    }

    // Build the augmented matrix [A | I].
    let mut aug = Array2::zeros((n, 2 * n));
    for row in 0..n {
        for col in 0..n {
            aug[[row, col]] = matrix[[row, col]];
        }
        aug[[row, row + n]] = 1.0;
    }

    for col in 0..n {
        // Partial pivoting: choose the row (at or below `col`) whose entry
        // in this column has the largest magnitude; the first maximum wins.
        let pivot_row = (col..n).fold(col, |best, r| {
            if aug[[r, col]].abs() > aug[[best, col]].abs() {
                r
            } else {
                best
            }
        });

        // Bring the pivot row into place.
        if pivot_row != col {
            for j in 0..2 * n {
                let tmp = aug[[col, j]];
                aug[[col, j]] = aug[[pivot_row, j]];
                aug[[pivot_row, j]] = tmp;
            }
        }

        // A vanishing pivot means the matrix is not invertible.
        if aug[[col, col]].abs() < 1e-10 {
            return Err("Matrix is singular");
        }

        // Scale the pivot row so its diagonal entry becomes 1.
        let pivot = aug[[col, col]];
        for j in 0..2 * n {
            aug[[col, j]] /= pivot;
        }

        // Eliminate this column from every other row.
        for r in 0..n {
            if r != col {
                let factor = aug[[r, col]];
                for j in 0..2 * n {
                    aug[[r, j]] -= factor * aug[[col, j]];
                }
            }
        }
    }

    // The right half of the augmented matrix now holds A^(-1).
    let mut inverse = Array2::zeros((n, n));
    for row in 0..n {
        for col in 0..n {
            inverse[[row, col]] = aug[[row, col + n]];
        }
    }

    Ok(inverse)
}