//! Model tree regression: a decision tree whose leaves hold fitted linear
//! (or constant) models rather than plain averages.

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
use sklears_core::error::{Result, SklearsError};
use sklears_core::traits::{Estimator, Fit, Predict, Trained, Untrained};
use sklears_core::types::Float;
use std::marker::PhantomData;

/// A node in a model tree.
///
/// Internal nodes carry a split (feature index and threshold); leaf nodes
/// carry the coefficients and intercept of the model fitted on the samples
/// that reached them.
#[derive(Debug, Clone)]
pub struct ModelTreeNode {
    /// Index of the splitting feature (`None` for leaves).
    pub feature: Option<usize>,
    /// Split threshold; samples with `feature <= threshold` go left.
    pub threshold: Float,
    /// Left child (internal nodes only).
    pub left: Option<Box<ModelTreeNode>>,
    /// Right child (internal nodes only).
    pub right: Option<Box<ModelTreeNode>>,
    /// Linear model coefficients (leaves only).
    pub coefficients: Option<Array1<Float>>,
    /// Linear model intercept (leaves only).
    pub intercept: Option<Float>,
    /// Number of training samples that reached this node.
    pub n_samples: usize,
    /// Standard deviation of the targets at this node.
    pub std_dev: Float,
}

impl ModelTreeNode {
    /// Creates a leaf node holding a fitted linear model.
    pub fn new_leaf(
        coefficients: Array1<Float>,
        intercept: Float,
        n_samples: usize,
        std_dev: Float,
    ) -> Self {
        Self {
            feature: None,
            threshold: 0.0,
            left: None,
            right: None,
            coefficients: Some(coefficients),
            intercept: Some(intercept),
            n_samples,
            std_dev,
        }
    }

    /// Creates an internal node splitting on `feature` at `threshold`.
    pub fn new_internal(
        feature: usize,
        threshold: Float,
        left: Self,
        right: Self,
        n_samples: usize,
        std_dev: Float,
    ) -> Self {
        Self {
            feature: Some(feature),
            threshold,
            left: Some(Box::new(left)),
            right: Some(Box::new(right)),
            coefficients: None,
            intercept: None,
            n_samples,
            std_dev,
        }
    }

    /// Returns `true` if this node has no children.
    pub fn is_leaf(&self) -> bool {
        self.left.is_none() && self.right.is_none()
    }

    /// Predicts a single sample by routing it down the tree and evaluating
    /// the leaf's linear model.
    pub fn predict_sample(&self, sample: &ArrayView1<Float>) -> Float {
        if self.is_leaf() {
            if let (Some(coef), Some(intercept)) = (&self.coefficients, &self.intercept) {
                sample.dot(coef) + intercept
            } else {
                0.0
            }
        } else if let Some(feature_idx) = self.feature {
            let value = sample[feature_idx];
            if value <= self.threshold {
                self.left
                    .as_ref()
                    .map(|node| node.predict_sample(sample))
                    .unwrap_or(0.0)
            } else {
                self.right
                    .as_ref()
                    .map(|node| node.predict_sample(sample))
                    .unwrap_or(0.0)
            }
        } else {
            0.0
        }
    }
}

/// Configuration for a [`ModelTree`].
#[derive(Debug, Clone)]
pub struct ModelTreeConfig {
    /// Maximum tree depth (`None` for unlimited).
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to attempt a split.
    pub min_samples_split: usize,
    /// Minimum number of samples required in each child of a split.
    pub min_samples_leaf: usize,
    /// Minimum reduction in target standard deviation required to keep a split.
    pub min_std_dev_reduction: Float,
    /// Whether to prune the tree (not currently used while building).
    pub prune: bool,
    /// Whether to smooth leaf predictions (not currently used while building).
    pub smoothing: bool,
    /// Type of model fitted in the leaves.
    pub leaf_model: LeafModelType,
}

impl Default for ModelTreeConfig {
    fn default() -> Self {
        Self {
            max_depth: None,
            min_samples_split: 4,
            min_samples_leaf: 2,
            min_std_dev_reduction: 0.05,
            prune: true,
            smoothing: true,
            leaf_model: LeafModelType::Linear,
        }
    }
}

/// Type of model fitted in each leaf.
#[derive(Debug, Clone, Copy)]
pub enum LeafModelType {
    /// Ordinary linear model over all features.
    Linear,
    /// Constant prediction (mean of the targets in the leaf).
    Constant,
    /// Polynomial model (currently falls back to a linear fit).
    Polynomial,
}

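/// A model tree regressor.
///
/// The feature space is recursively partitioned using a standard-deviation
/// reduction criterion, and a linear (or constant) model is fitted in each
/// leaf. The builder methods below configure the tree before fitting.
///
/// # Example
///
/// A minimal usage sketch; the example is marked `ignore` because the import
/// path for `ModelTree` depends on how this module is re-exported.
///
/// ```ignore
/// use scirs2_core::ndarray::{Array1, Array2};
/// use sklears_core::traits::{Fit, Predict};
///
/// let x = Array2::from_shape_vec((4, 1), vec![0.0, 1.0, 2.0, 3.0]).unwrap();
/// let y = Array1::from_vec(vec![0.0, 2.0, 4.0, 6.0]);
///
/// let fitted = ModelTree::new().max_depth(3).fit(&x, &y).unwrap();
/// let predictions = fitted.predict(&x).unwrap();
/// assert_eq!(predictions.len(), 4);
/// ```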
pub struct ModelTree<State = Untrained> {
    config: ModelTreeConfig,
    state: PhantomData<State>,
    root: Option<ModelTreeNode>,
    n_features: Option<usize>,
    feature_importances: Option<Array1<Float>>,
}

impl ModelTree<Untrained> {
    /// Creates a model tree with the default configuration.
    pub fn new() -> Self {
        Self {
            config: ModelTreeConfig::default(),
            state: PhantomData,
            root: None,
            n_features: None,
            feature_importances: None,
        }
    }

    /// Sets the maximum tree depth.
    pub fn max_depth(mut self, max_depth: usize) -> Self {
        self.config.max_depth = Some(max_depth);
        self
    }

    /// Sets the minimum number of samples required to attempt a split.
    pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.config.min_samples_split = min_samples_split;
        self
    }

    /// Sets the minimum number of samples required in each leaf.
    pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.config.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Sets the minimum standard-deviation reduction required to keep a split.
    pub fn min_std_dev_reduction(mut self, min_std_dev_reduction: Float) -> Self {
        self.config.min_std_dev_reduction = min_std_dev_reduction;
        self
    }

    /// Enables or disables pruning.
    pub fn prune(mut self, prune: bool) -> Self {
        self.config.prune = prune;
        self
    }

    /// Sets the type of model fitted in the leaves.
    pub fn leaf_model(mut self, leaf_model: LeafModelType) -> Self {
        self.config.leaf_model = leaf_model;
        self
    }
}

impl Default for ModelTree<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for ModelTree<Untrained> {
    type Config = ModelTreeConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<Array2<Float>, Array1<Float>> for ModelTree<Untrained> {
    type Fitted = ModelTree<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "No samples provided".to_string(),
            ));
        }

        if n_samples != y.len() {
            return Err(SklearsError::ShapeMismatch {
                expected: format!("X.shape[0] == y.len() ({})", x.nrows()),
                actual: format!("y.len() = {}", y.len()),
            });
        }

        let indices: Vec<usize> = (0..n_samples).collect();
        let root = build_model_tree(x, y, &indices, 0, &self.config)?;

        let mut feature_importances = Array1::zeros(n_features);
        compute_feature_importances(&root, &mut feature_importances);

        let sum = feature_importances.sum();
        if sum > 0.0 {
            feature_importances /= sum;
        }

        Ok(ModelTree::<Trained> {
            config: self.config,
            state: PhantomData,
            root: Some(root),
            n_features: Some(n_features),
            feature_importances: Some(feature_importances),
        })
    }
}

impl ModelTree<Trained> {
    pub fn n_features(&self) -> usize {
        self.n_features.expect("Model should be fitted")
    }

    pub fn feature_importances(&self) -> &Array1<Float> {
        self.feature_importances
            .as_ref()
            .expect("Model should be fitted")
    }

    pub fn tree(&self) -> &ModelTreeNode {
        self.root.as_ref().expect("Model should be fitted")
    }
}

impl Predict<Array2<Float>, Array1<Float>> for ModelTree<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let n_samples = x.nrows();

        if x.ncols() != self.n_features() {
            return Err(SklearsError::FeatureMismatch {
                expected: self.n_features(),
                actual: x.ncols(),
            });
        }

        let root = self.root.as_ref().ok_or(SklearsError::NotFitted {
            operation: "predict".to_string(),
        })?;

        let mut predictions = Array1::zeros(n_samples);
        for (i, sample) in x.axis_iter(Axis(0)).enumerate() {
            predictions[i] = root.predict_sample(&sample);
        }

        Ok(predictions)
    }
}

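/// Recursively builds a model tree over the samples referenced by `indices`.
///
/// A leaf is created when the node is too small, too deep, or already has a
/// low target standard deviation; otherwise the best split found by
/// `find_best_split` is applied and both children are built recursively.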
fn build_model_tree(
    x: &Array2<Float>,
    y: &Array1<Float>,
    indices: &[usize],
    depth: usize,
    config: &ModelTreeConfig,
) -> Result<ModelTreeNode> {
    let n_samples = indices.len();
    let _n_features = x.ncols();

    // Target mean and standard deviation over the samples at this node.
    let mean_y: Float = indices.iter().map(|&i| y[i]).sum::<Float>() / n_samples as Float;
    let variance: Float = indices
        .iter()
        .map(|&i| (y[i] - mean_y).powi(2))
        .sum::<Float>()
        / n_samples as Float;
    let std_dev = variance.sqrt();

    // Stopping criteria: too few samples, maximum depth reached, targets
    // already nearly constant, or not enough samples to form two leaves.
    let should_create_leaf = n_samples < config.min_samples_split
        || depth >= config.max_depth.unwrap_or(usize::MAX)
        || std_dev < config.min_std_dev_reduction
        || n_samples < 2 * config.min_samples_leaf;

    if should_create_leaf {
        return create_leaf_node(x, y, indices, config);
    }

    let best_split = find_best_split(x, y, indices, std_dev, config)?;

    if let Some((feature, threshold, left_indices, right_indices, std_dev_reduction)) = best_split {
        // Reject splits that do not reduce the standard deviation enough.
        if std_dev_reduction < config.min_std_dev_reduction {
            return create_leaf_node(x, y, indices, config);
        }

        let left_node = build_model_tree(x, y, &left_indices, depth + 1, config)?;
        let right_node = build_model_tree(x, y, &right_indices, depth + 1, config)?;

        Ok(ModelTreeNode::new_internal(
            feature, threshold, left_node, right_node, n_samples, std_dev,
        ))
    } else {
        create_leaf_node(x, y, indices, config)
    }
}

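/// Evaluates every candidate (feature, threshold) split and returns the one
/// with the largest standard-deviation reduction, together with the index
/// sets of the two children, or `None` if no valid split exists.
///
/// For a split into children L and R the reduction is
/// `current_std_dev - (|L| * std(L) + |R| * std(R)) / (|L| + |R|)`.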
fn find_best_split(
    x: &Array2<Float>,
    y: &Array1<Float>,
    indices: &[usize],
    current_std_dev: Float,
    config: &ModelTreeConfig,
) -> Result<Option<(usize, Float, Vec<usize>, Vec<usize>, Float)>> {
    let n_features = x.ncols();
    let n_samples = indices.len();

    let mut best_split = None;
    let mut best_reduction = 0.0;

    for feature in 0..n_features {
        let mut feature_values: Vec<Float> = indices.iter().map(|&i| x[[i, feature]]).collect();
        feature_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
        feature_values.dedup();

        if feature_values.len() < 2 {
            continue;
        }

        for i in 0..feature_values.len() - 1 {
            let threshold = (feature_values[i] + feature_values[i + 1]) / 2.0;

            let mut left_indices = Vec::new();
            let mut right_indices = Vec::new();

            for &idx in indices {
                if x[[idx, feature]] <= threshold {
                    left_indices.push(idx);
                } else {
                    right_indices.push(idx);
                }
            }

            if left_indices.len() < config.min_samples_leaf
                || right_indices.len() < config.min_samples_leaf
            {
                continue;
            }

            let left_std = calculate_std_dev(y, &left_indices);
            let right_std = calculate_std_dev(y, &right_indices);

            let weighted_std = (left_indices.len() as Float * left_std
                + right_indices.len() as Float * right_std)
                / n_samples as Float;

            let std_dev_reduction = current_std_dev - weighted_std;

            if std_dev_reduction > best_reduction {
                best_reduction = std_dev_reduction;
                best_split = Some((
                    feature,
                    threshold,
                    left_indices,
                    right_indices,
                    std_dev_reduction,
                ));
            }
        }
    }

    Ok(best_split)
}

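/// Population standard deviation of the targets restricted to `indices`;
/// returns 0.0 for an empty index set.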
fn calculate_std_dev(y: &Array1<Float>, indices: &[usize]) -> Float {
    if indices.is_empty() {
        return 0.0;
    }

    let mean: Float = indices.iter().map(|&i| y[i]).sum::<Float>() / indices.len() as Float;
    let variance: Float = indices
        .iter()
        .map(|&i| (y[i] - mean).powi(2))
        .sum::<Float>()
        / indices.len() as Float;

    variance.sqrt()
}

fn create_leaf_node(
    x: &Array2<Float>,
    y: &Array1<Float>,
    indices: &[usize],
    config: &ModelTreeConfig,
) -> Result<ModelTreeNode> {
    let n_samples = indices.len();
    let n_features = x.ncols();

    let std_dev = calculate_std_dev(y, indices);

    match config.leaf_model {
        LeafModelType::Constant => {
            // Constant leaves predict the target mean: zero coefficients and
            // the mean as the intercept.
            let mean: Float = indices.iter().map(|&i| y[i]).sum::<Float>() / n_samples as Float;
            let coefficients = Array1::zeros(n_features);
            Ok(ModelTreeNode::new_leaf(
                coefficients,
                mean,
                n_samples,
                std_dev,
            ))
        }
        LeafModelType::Linear | LeafModelType::Polynomial => {
            // Polynomial leaves currently reuse the plain linear fit.
            let (coefficients, intercept) = fit_linear_model(x, y, indices)?;
            Ok(ModelTreeNode::new_leaf(
                coefficients,
                intercept,
                n_samples,
                std_dev,
            ))
        }
    }
}

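/// Fits a least-squares linear model to the samples referenced by `indices`
/// via the normal equations, with a small ridge term (1e-6 on the diagonal)
/// for numerical stability. Returns the coefficient vector and the intercept;
/// an empty index set yields an all-zero model.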
fn fit_linear_model(
    x: &Array2<Float>,
    y: &Array1<Float>,
    indices: &[usize],
) -> Result<(Array1<Float>, Float)> {
    let n_samples = indices.len();
    let n_features = x.ncols();

    if n_samples == 0 {
        return Ok((Array1::zeros(n_features), 0.0));
    }

    // Gather the subset of rows and targets belonging to this leaf.
    let mut x_subset = Array2::zeros((n_samples, n_features));
    let mut y_subset = Array1::zeros(n_samples);

    for (i, &idx) in indices.iter().enumerate() {
        x_subset.row_mut(i).assign(&x.row(idx));
        y_subset[i] = y[idx];
    }

    // Design matrix with a trailing column of ones for the intercept.
    let mut x_design = Array2::ones((n_samples, n_features + 1));
    for i in 0..n_samples {
        for j in 0..n_features {
            x_design[[i, j]] = x_subset[[i, j]];
        }
    }

    // Normal equations: (X^T X + lambda * I) beta = X^T y, with lambda = 1e-6.
    let xt = x_design.t();
    let xtx = xt.dot(&x_design);
    let xty = xt.dot(&y_subset);

    let mut xtx_reg = xtx.to_owned();
    for i in 0..n_features + 1 {
        xtx_reg[[i, i]] += 1e-6;
    }

    let beta = solve_linear_system(&xtx_reg, &xty)?;

    let coefficients = beta
        .slice(scirs2_core::ndarray::s![0..n_features])
        .to_owned();
    let intercept = beta[n_features];

    Ok((coefficients, intercept))
}

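/// Solves the linear system `A x = b` by Gaussian elimination with partial
/// pivoting. Near-singular pivots (absolute value below 1e-10) are skipped
/// during elimination, and the corresponding unknowns are set to zero during
/// back substitution.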
fn solve_linear_system(a: &Array2<Float>, b: &Array1<Float>) -> Result<Array1<Float>> {
    let n = a.nrows();

    if n != b.len() {
        return Err(SklearsError::ShapeMismatch {
            expected: format!("A.nrows() == b.len() ({})", n),
            actual: format!("b.len() = {}", b.len()),
        });
    }

    // Build the augmented matrix [A | b].
    let mut aug = Array2::zeros((n, n + 1));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n]] = b[i];
    }

    // Forward elimination with partial pivoting.
    for k in 0..n {
        let mut pivot_row = k;
        let mut max_val = aug[[k, k]].abs();
        for i in (k + 1)..n {
            if aug[[i, k]].abs() > max_val {
                max_val = aug[[i, k]].abs();
                pivot_row = i;
            }
        }

        if pivot_row != k {
            for j in 0..=n {
                let temp = aug[[k, j]];
                aug[[k, j]] = aug[[pivot_row, j]];
                aug[[pivot_row, j]] = temp;
            }
        }

        let pivot = aug[[k, k]];
        if pivot.abs() < 1e-10 {
            // Skip near-singular pivots; the unknown is resolved to zero below.
            continue;
        }

        for i in (k + 1)..n {
            let factor = aug[[i, k]] / pivot;
            for j in k..=n {
                aug[[i, j]] -= factor * aug[[k, j]];
            }
        }
    }

    // Back substitution.
    let mut x = Array1::zeros(n);
    for i in (0..n).rev() {
        let mut sum = aug[[i, n]];
        for j in (i + 1)..n {
            sum -= aug[[i, j]] * x[j];
        }

        let diag = aug[[i, i]];
        x[i] = if diag.abs() > 1e-10 { sum / diag } else { 0.0 };
    }

    Ok(x)
}

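/// Accumulates a simple importance score for each splitting feature: the
/// node's target standard deviation weighted by the number of samples it
/// covers. `fit` normalizes the accumulated scores so they sum to one.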
fn compute_feature_importances(node: &ModelTreeNode, importances: &mut Array1<Float>) {
    if let Some(feature) = node.feature {
        let importance = node.std_dev * node.n_samples as Float;
        importances[feature] += importance;

        if let Some(ref left) = node.left {
            compute_feature_importances(left, importances);
        }
        if let Some(ref right) = node.right {
            compute_feature_importances(right, importances);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_model_tree_basic() {
        let mut x = Array2::zeros((100, 2));
        let mut y = Array1::zeros(100);

        for i in 0..100 {
            let x1 = (i as Float) / 50.0 - 1.0;
            let x2 = ((i * 2) as Float) / 50.0 - 2.0;
            x[[i, 0]] = x1;
            x[[i, 1]] = x2;
            y[i] = 2.0 * x1 + 3.0 * x2 + 1.0;
        }

        let model = ModelTree::new().max_depth(5).min_samples_leaf(5);

        let fitted = model.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        let y_mean = y.mean().unwrap();
        let ss_tot: Float = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
        let ss_res: Float = y
            .iter()
            .zip(predictions.iter())
            .map(|(&yi, &pred)| (yi - pred).powi(2))
            .sum();
        let r2 = 1.0 - ss_res / ss_tot;

        assert!(
            r2 > 0.8,
            "R² should be high for linear relationship: {}",
            r2
        );
    }

    #[test]
    fn test_model_tree_nonlinear() {
        let mut x = Array2::zeros((50, 1));
        let mut y = Array1::zeros(50);

        for i in 0..50 {
            let xi = (i as Float) / 10.0 - 2.5;
            x[[i, 0]] = xi;
            y[i] = xi * xi;
        }

        let model = ModelTree::new().max_depth(4).min_samples_leaf(3);

        let fitted = model.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        let mse: Float = y
            .iter()
            .zip(predictions.iter())
            .map(|(&yi, &pred)| (yi - pred).powi(2))
            .sum::<Float>()
            / y.len() as Float;

        assert!(
            mse < 1.0,
            "MSE should be reasonable for piecewise linear approximation: {}",
            mse
        );
    }

    #[test]
    fn test_linear_model_fitting() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from_vec(vec![3.0, 7.0, 11.0]);
        let indices = vec![0, 1, 2];

        let (coef, intercept) = fit_linear_model(&x, &y, &indices).unwrap();

        let mut error = 0.0;
        for i in 0..3 {
            let pred = x.row(i).dot(&coef) + intercept;
            error += (y[i] - pred).abs();
        }

        assert!(error < 0.1, "Linear model should fit well: error={}", error);
    }

    #[test]
    fn test_solve_linear_system() {
        let a = Array2::from_shape_vec((2, 2), vec![2.0, 1.0, 1.0, 3.0]).unwrap();
        let b = Array1::from_vec(vec![5.0, 8.0]);

        let x = solve_linear_system(&a, &b).unwrap();

        assert_relative_eq!(x[0], 1.4, epsilon = 1e-6);
        assert_relative_eq!(x[1], 2.2, epsilon = 1e-6);
    }

    #[test]
    fn test_constant_leaf_model() {
        let mut x = Array2::zeros((20, 1));
        let mut y = Array1::zeros(20);

        for i in 0..20 {
            x[[i, 0]] = (i as Float) / 10.0;
            y[i] = if i < 10 { 1.0 } else { 5.0 };
        }

        let model = ModelTree::new()
            .leaf_model(LeafModelType::Constant)
            .max_depth(2);

        let fitted = model.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert!(predictions[0] > 0.5 && predictions[0] < 2.0);
        assert!(predictions[19] > 4.0 && predictions[19] < 6.0);
    }
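
    // Additional illustrative tests (sketches; they assert only structural
    // invariants that follow directly from `fit` and `predict` above, not
    // specific numeric values).
    #[test]
    fn test_feature_importances_normalized() {
        let mut x = Array2::zeros((40, 2));
        let mut y = Array1::zeros(40);

        for i in 0..40 {
            x[[i, 0]] = i as Float;
            x[[i, 1]] = (i % 5) as Float;
            y[i] = if i < 20 { 1.0 } else { 10.0 };
        }

        let fitted = ModelTree::new().max_depth(3).fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        // After normalization the importances sum to one (or are all zero if
        // the tree made no split at all).
        let total = importances.sum();
        assert!(total == 0.0 || (total - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_predict_feature_mismatch() {
        let mut x = Array2::zeros((10, 2));
        let mut y = Array1::zeros(10);

        for i in 0..10 {
            x[[i, 0]] = i as Float;
            x[[i, 1]] = (i * 2) as Float;
            y[i] = i as Float;
        }

        let fitted = ModelTree::new().fit(&x, &y).unwrap();

        // A matrix with the wrong number of columns should be rejected.
        let x_bad: Array2<Float> = Array2::zeros((3, 3));
        assert!(fitted.predict(&x_bad).is_err());
    }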
}