1use scirs2_core::ndarray::{Array1, Array2};
7use sklears_core::types::Float;
8use smartcore::linalg::basic::matrix::DenseMatrix;
9
10#[cfg(feature = "oblique")]
11use sklears_core::error::Result;
12
13use crate::criteria::SplitCriterion;
15
/// Per-feature monotonicity requirement for a tree model.
///
/// A plain fieldless marker; how it is enforced is decided by the code that
/// reads `DecisionTreeConfig::monotonic_constraints`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum MonotonicConstraint {
    /// No monotonicity requirement on this feature.
    None,
    /// Requests an increasing relationship with this feature.
    Increasing,
    /// Requests a decreasing relationship with this feature.
    Decreasing,
}
26
/// Restrictions on which features may interact within a tree.
///
/// All feature references are column indices.
#[derive(Clone, Debug)]
pub enum InteractionConstraint {
    /// No interaction restrictions.
    None,
    /// Groups of feature indices (presumably interactions are confined to a group).
    Groups(Vec<Vec<usize>>),
    /// Feature-index pairs that must not interact.
    Forbidden(Vec<(usize, usize)>),
    /// Feature-index pairs that are explicitly permitted to interact.
    Allowed(Vec<(usize, usize)>),
}
39
40#[derive(Debug, Clone)]
42pub enum FeatureGrouping {
43 None,
45 AutoCorrelation {
47 threshold: Float,
49 selection_method: GroupSelectionMethod,
51 },
52 Manual {
54 groups: Vec<Vec<usize>>,
56 selection_method: GroupSelectionMethod,
58 },
59 Hierarchical {
61 n_clusters: usize,
63 linkage: LinkageMethod,
65 selection_method: GroupSelectionMethod,
67 },
68}
69
/// Rule for choosing which feature(s) stand in for a feature group.
///
/// Variant names state the intent; the selection itself is implemented by
/// whatever code consumes this value.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum GroupSelectionMethod {
    /// Member with the largest variance.
    MaxVariance,
    /// Member most correlated with the target.
    MaxTargetCorrelation,
    /// First member listed in the group.
    First,
    /// A randomly chosen member.
    Random,
    /// All members, combined with weights.
    WeightedAll,
}
84
/// Cluster-merging rule for hierarchical feature grouping.
///
/// Names follow the conventional agglomerative-clustering linkages.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LinkageMethod {
    /// Single linkage (conventionally: closest pair of members).
    Single,
    /// Complete linkage (conventionally: farthest pair of members).
    Complete,
    /// Average linkage (conventionally: mean pairwise distance).
    Average,
    /// Ward linkage (conventionally: minimum variance increase).
    Ward,
}
97
98#[derive(Debug, Clone)]
100pub struct FeatureGroupInfo {
101 pub groups: Vec<Vec<usize>>,
103 pub representatives: Vec<usize>,
105 pub correlation_matrix: Option<Array2<Float>>,
107 pub group_correlations: Vec<Float>,
109}
110
/// Number of features to examine when searching for a split.
#[derive(Clone, Debug)]
pub enum MaxFeatures {
    /// Every feature.
    All,
    /// Square root of the feature count (rounding is up to the consumer).
    Sqrt,
    /// Base-2 logarithm of the feature count (rounding is up to the consumer).
    Log2,
    /// An explicit feature count.
    Number(usize),
    /// A fraction of the feature count (presumably in `(0, 1]`).
    Fraction(f64),
}
125
/// Pruning applied after the tree has been grown.
#[derive(Clone, Copy, Debug)]
pub enum PruningStrategy {
    /// Keep the tree exactly as grown.
    None,
    /// Cost-complexity pruning.
    CostComplexity {
        /// Complexity penalty (presumably larger values prune more).
        alpha: f64,
    },
    /// Reduced-error pruning.
    ReducedError,
}
136
/// Treatment of missing feature values during splitting.
#[derive(Clone, Copy, Debug)]
pub enum MissingValueStrategy {
    /// Leave missing values out.
    Skip,
    /// Route missing values the majority way.
    Majority,
    /// Route missing values via surrogate splits.
    Surrogate,
}
147
/// Declared kind of a single input feature.
#[derive(Clone, Debug)]
pub enum FeatureType {
    /// Numeric (real-valued) feature.
    Continuous,
    /// Categorical feature with the listed category labels.
    Categorical(Vec<String>),
}
156
/// A categorical split with possibly more than two branches.
#[derive(Clone, Debug)]
pub struct MultiwaySplit {
    /// Column index of the feature being split.
    pub feature_idx: usize,
    /// One inner list per branch, holding the category labels for that branch.
    pub category_branches: Vec<Vec<String>>,
    /// Impurity reduction credited to this split.
    pub impurity_decrease: f64,
}
167
/// Order in which tree nodes are expanded.
#[derive(Clone, Copy, Debug)]
pub enum TreeGrowingStrategy {
    /// Depth-first expansion.
    DepthFirst,
    /// Best-first expansion, with an optional leaf budget.
    BestFirst {
        /// Maximum number of leaves; `None` presumably means no cap.
        max_leaves: Option<usize>,
    },
}
176
/// Geometry of the decision made at each internal node.
#[derive(Clone, Copy, Debug)]
pub enum SplitType {
    /// Single-feature threshold splits.
    AxisAligned,
    /// Linear-combination (hyperplane) splits.
    Oblique {
        /// Number of candidate hyperplanes (presumably tried per node).
        n_hyperplanes: usize,
        /// Whether hyperplanes are fitted by ridge regression
        /// (presumably instead of being drawn at random).
        use_ridge: bool,
    },
}
190
191#[derive(Debug, Clone)]
193pub struct HyperplaneSplit {
194 pub coefficients: Array1<f64>,
196 pub threshold: f64,
198 pub bias: f64,
200 pub impurity_decrease: f64,
202}
203
204impl HyperplaneSplit {
205 pub fn evaluate(&self, sample: &Array1<f64>) -> bool {
207 let dot_product = self.coefficients.dot(sample) + self.bias;
208 dot_product >= self.threshold
209 }
210
211 pub fn random(n_features: usize, rng: &mut scirs2_core::CoreRandom) -> Self {
213 let mut coefficients = Array1::zeros(n_features);
214 for i in 0..n_features {
215 coefficients[i] = rng.gen_range(-1.0..1.0);
216 }
217
218 let dot_product: f64 = coefficients.dot(&coefficients);
220 let norm = dot_product.sqrt();
221 if norm > 1e-10_f64 {
222 coefficients /= norm;
223 }
224
225 Self {
226 coefficients,
227 threshold: rng.gen_range(-1.0..1.0),
228 bias: rng.gen_range(-0.1..0.1),
229 impurity_decrease: 0.0,
230 }
231 }
232
233 #[cfg(feature = "oblique")]
235 pub fn from_ridge_regression(x: &Array2<f64>, y: &Array1<f64>, alpha: f64) -> Result<Self> {
236 use scirs2_core::ndarray::s;
237 use sklears_core::error::SklearsError;
238
239 let n_features = x.ncols();
240 if x.nrows() < 2 {
241 return Err(SklearsError::InvalidInput(
242 "Need at least 2 samples for ridge regression".to_string(),
243 ));
244 }
245
246 let mut x_bias = Array2::ones((x.nrows(), n_features + 1));
248 x_bias.slice_mut(s![.., ..n_features]).assign(x);
249
250 let xtx = x_bias.t().dot(&x_bias);
252 let ridge_matrix = xtx + Array2::<f64>::eye(n_features + 1) * alpha;
253 let xty = x_bias.t().dot(y);
254
255 match gauss_jordan_inverse(&ridge_matrix) {
257 Ok(inv_matrix) => {
258 let coefficients_full = inv_matrix.dot(&xty);
259
260 let coefficients = coefficients_full.slice(s![..n_features]).to_owned();
261 let bias = coefficients_full[n_features];
262
263 Ok(Self {
264 coefficients,
265 threshold: 0.0, bias,
267 impurity_decrease: 0.0,
268 })
269 }
270 Err(_) => {
271 let mut rng = scirs2_core::random::thread_rng();
273 Ok(Self::random(n_features, &mut rng))
274 }
275 }
276 }
277}
278
279#[derive(Debug, Clone)]
281pub struct DecisionTreeConfig {
282 pub criterion: SplitCriterion,
284 pub max_depth: Option<usize>,
286 pub min_samples_split: usize,
288 pub min_samples_leaf: usize,
290 pub max_features: MaxFeatures,
292 pub random_state: Option<u64>,
294 pub min_weight_fraction_leaf: f64,
296 pub min_impurity_decrease: f64,
298 pub pruning: PruningStrategy,
300 pub missing_values: MissingValueStrategy,
302 pub feature_types: Option<Vec<FeatureType>>,
304 pub growing_strategy: TreeGrowingStrategy,
306 pub split_type: SplitType,
308 pub monotonic_constraints: Option<Vec<MonotonicConstraint>>,
310 pub interaction_constraints: InteractionConstraint,
312 pub feature_grouping: FeatureGrouping,
314}
315
316impl Default for DecisionTreeConfig {
317 fn default() -> Self {
318 Self {
319 criterion: SplitCriterion::Gini,
320 max_depth: None,
321 min_samples_split: 2,
322 min_samples_leaf: 1,
323 max_features: MaxFeatures::All,
324 random_state: None,
325 min_weight_fraction_leaf: 0.0,
326 min_impurity_decrease: 0.0,
327 pruning: PruningStrategy::None,
328 missing_values: MissingValueStrategy::Skip,
329 feature_types: None,
330 growing_strategy: TreeGrowingStrategy::DepthFirst,
331 split_type: SplitType::AxisAligned,
332 monotonic_constraints: None,
333 interaction_constraints: InteractionConstraint::None,
334 feature_grouping: FeatureGrouping::None,
335 }
336 }
337}
338
339pub fn ndarray_to_dense_matrix(arr: &Array2<f64>) -> DenseMatrix<f64> {
341 let _rows = arr.nrows();
342 let _cols = arr.ncols();
343 let mut data = Vec::new();
344 for row in arr.outer_iter() {
345 data.push(row.to_vec());
346 }
347 DenseMatrix::from_2d_vec(&data).expect("Failed to convert ndarray to DenseMatrix")
348}
349
/// Inverts a square matrix using Gauss–Jordan elimination with partial pivoting.
///
/// The matrix is augmented to `[A | I]`, and row operations reduce the left
/// half to the identity; the right half is then `A⁻¹`. For each column, the row
/// with the largest absolute pivot at or below the diagonal is swapped into
/// place. On ties, the earliest such row is kept.
///
/// # Errors
///
/// * `"Matrix must be square"` if `matrix` is not `n × n`.
/// * `"Matrix contains non-finite values"` if any entry is NaN or infinite.
/// * `"Matrix is singular"` if a pivot's magnitude is below `1e-10` (an
///   absolute tolerance) or becomes non-finite during elimination.
#[cfg(feature = "oblique")]
fn gauss_jordan_inverse(matrix: &Array2<f64>) -> std::result::Result<Array2<f64>, &'static str> {
    use scirs2_core::ndarray::s;

    // Pivots with a smaller magnitude than this are treated as zero.
    const PIVOT_TOLERANCE: f64 = 1e-10;

    let n = matrix.nrows();
    if n != matrix.ncols() {
        return Err("Matrix must be square");
    }
    // NaN never compares below the tolerance, so without this check a NaN
    // entry would slip through and produce an `Ok` result full of NaNs.
    if matrix.iter().any(|v| !v.is_finite()) {
        return Err("Matrix contains non-finite values");
    }

    // Build the augmented matrix [A | I].
    let mut augmented = Array2::<f64>::zeros((n, 2 * n));
    augmented.slice_mut(s![.., ..n]).assign(matrix);
    for i in 0..n {
        augmented[[i, n + i]] = 1.0;
    }

    for col in 0..n {
        // Partial pivoting: strictly-greater comparison keeps the earliest row on ties.
        let mut pivot_row = col;
        for row in col + 1..n {
            if augmented[[row, col]].abs() > augmented[[pivot_row, col]].abs() {
                pivot_row = row;
            }
        }
        if pivot_row != col {
            for j in 0..2 * n {
                augmented.swap([col, j], [pivot_row, j]);
            }
        }

        let pivot = augmented[[col, col]];
        if !pivot.is_finite() || pivot.abs() < PIVOT_TOLERANCE {
            return Err("Matrix is singular");
        }

        // Scale the pivot row so the pivot becomes 1.
        for j in 0..2 * n {
            augmented[[col, j]] /= pivot;
        }

        // Eliminate this column from every other row.
        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = augmented[[row, col]];
            for j in 0..2 * n {
                augmented[[row, j]] -= factor * augmented[[col, j]];
            }
        }
    }

    // The right half of the reduced augmented matrix is the inverse.
    Ok(augmented.slice(s![.., n..]).to_owned())
}