1use ferrolearn_core::error::FerroError;
28use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
29use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
30use ferrolearn_core::traits::{Fit, Predict};
31use ndarray::{Array1, Array2};
32use num_traits::{Float, FromPrimitive, ToPrimitive};
33use rand::SeedableRng;
34use rand::rngs::StdRng;
35use rand::seq::index::sample as rand_sample_indices;
36use rayon::prelude::*;
37use serde::{Deserialize, Serialize};
38
39use crate::decision_tree::{
40 self, ClassificationCriterion, Node, build_classification_tree_with_feature_subset,
41 build_regression_tree_with_feature_subset, compute_feature_importances,
42};
43
/// Strategy for deciding how many features each tree may consider.
///
/// The strategy is resolved against the actual feature count by
/// `resolve_max_features`, which clamps the result to `[1, n_features]`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum MaxFeatures {
    /// `ceil(sqrt(n_features))` — the classifier's default.
    Sqrt,
    /// `ceil(log2(n_features))`, floored at 1.
    Log2,
    /// Use every feature — the regressor's default.
    All,
    /// Exactly `n` features, capped at the available feature count.
    Fixed(usize),
    /// `ceil(n_features * fraction)`.
    Fraction(f64),
}
62
63fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
65 let result = match strategy {
66 MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
67 MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
68 MaxFeatures::All => n_features,
69 MaxFeatures::Fixed(n) => n.min(n_features),
70 MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
71 };
72 result.max(1).min(n_features)
73}
74
/// Bundle the per-tree hyperparameters shared by classifier and regressor
/// into the `TreeParams` value consumed by the tree builders.
fn make_tree_params(
    max_depth: Option<usize>,
    min_samples_split: usize,
    min_samples_leaf: usize,
) -> decision_tree::TreeParams {
    decision_tree::TreeParams {
        max_depth,
        min_samples_split,
        min_samples_leaf,
    }
}
90
/// Random forest classifier: an ensemble of decision trees, each trained on a
/// bootstrap sample of the rows with a random feature subset, predicting by
/// majority vote across trees.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomForestClassifier<F> {
    /// Number of trees to build (must be >= 1 at fit time).
    pub n_estimators: usize,
    /// Maximum tree depth passed to the tree builder (`None` = no explicit cap).
    pub max_depth: Option<usize>,
    /// How many features each tree may consider (resolved once per tree).
    pub max_features: MaxFeatures,
    /// Minimum samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum samples required at a leaf node.
    pub min_samples_leaf: usize,
    /// Seed for reproducible fits; `None` draws seeds from OS entropy.
    pub random_state: Option<u64>,
    /// Split-quality criterion used when growing the trees.
    pub criterion: ClassificationCriterion,
    // Ties the element type `F` to the struct without storing any data.
    _marker: std::marker::PhantomData<F>,
}
122
123impl<F: Float> RandomForestClassifier<F> {
124 #[must_use]
131 pub fn new() -> Self {
132 Self {
133 n_estimators: 100,
134 max_depth: None,
135 max_features: MaxFeatures::Sqrt,
136 min_samples_split: 2,
137 min_samples_leaf: 1,
138 random_state: None,
139 criterion: ClassificationCriterion::Gini,
140 _marker: std::marker::PhantomData,
141 }
142 }
143
144 #[must_use]
146 pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
147 self.n_estimators = n_estimators;
148 self
149 }
150
151 #[must_use]
153 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
154 self.max_depth = max_depth;
155 self
156 }
157
158 #[must_use]
160 pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
161 self.max_features = max_features;
162 self
163 }
164
165 #[must_use]
167 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
168 self.min_samples_split = min_samples_split;
169 self
170 }
171
172 #[must_use]
174 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
175 self.min_samples_leaf = min_samples_leaf;
176 self
177 }
178
179 #[must_use]
181 pub fn with_random_state(mut self, seed: u64) -> Self {
182 self.random_state = Some(seed);
183 self
184 }
185
186 #[must_use]
188 pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
189 self.criterion = criterion;
190 self
191 }
192}
193
/// Defaults are identical to [`RandomForestClassifier::new`].
impl<F: Float> Default for RandomForestClassifier<F> {
    fn default() -> Self {
        Self::new()
    }
}
199
/// A trained random forest classifier produced by
/// [`RandomForestClassifier`]'s `fit`.
#[derive(Debug, Clone)]
pub struct FittedRandomForestClassifier<F> {
    /// One flattened node arena per fitted tree.
    trees: Vec<Vec<Node<F>>>,
    /// Sorted, deduplicated original class labels; tree leaves encode indices
    /// into this vec.
    classes: Vec<usize>,
    /// Feature count seen at fit time; `predict` validates against it.
    n_features: usize,
    /// Importances summed over trees, normalized to sum to 1 when nonzero.
    feature_importances: Array1<F>,
}
219
220impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for RandomForestClassifier<F> {
221 type Fitted = FittedRandomForestClassifier<F>;
222 type Error = FerroError;
223
224 fn fit(
236 &self,
237 x: &Array2<F>,
238 y: &Array1<usize>,
239 ) -> Result<FittedRandomForestClassifier<F>, FerroError> {
240 let (n_samples, n_features) = x.dim();
241
242 if n_samples != y.len() {
243 return Err(FerroError::ShapeMismatch {
244 expected: vec![n_samples],
245 actual: vec![y.len()],
246 context: "y length must match number of samples in X".into(),
247 });
248 }
249 if n_samples == 0 {
250 return Err(FerroError::InsufficientSamples {
251 required: 1,
252 actual: 0,
253 context: "RandomForestClassifier requires at least one sample".into(),
254 });
255 }
256 if self.n_estimators == 0 {
257 return Err(FerroError::InvalidParameter {
258 name: "n_estimators".into(),
259 reason: "must be at least 1".into(),
260 });
261 }
262
263 let mut classes: Vec<usize> = y.iter().copied().collect();
265 classes.sort_unstable();
266 classes.dedup();
267 let n_classes = classes.len();
268
269 let y_mapped: Vec<usize> = y
270 .iter()
271 .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
272 .collect();
273
274 let max_features_n = resolve_max_features(self.max_features, n_features);
275 let params = make_tree_params(
276 self.max_depth,
277 self.min_samples_split,
278 self.min_samples_leaf,
279 );
280 let criterion = self.criterion;
281
282 let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
284 let mut master_rng = StdRng::seed_from_u64(seed);
285 (0..self.n_estimators)
286 .map(|_| {
287 use rand::RngCore;
288 master_rng.next_u64()
289 })
290 .collect()
291 } else {
292 (0..self.n_estimators)
293 .map(|_| {
294 use rand::RngCore;
295 rand::rng().next_u64()
296 })
297 .collect()
298 };
299
300 let trees: Vec<Vec<Node<F>>> = tree_seeds
302 .par_iter()
303 .map(|&seed| {
304 let mut rng = StdRng::seed_from_u64(seed);
305
306 let bootstrap_indices: Vec<usize> = (0..n_samples)
308 .map(|_| {
309 use rand::RngCore;
310 (rng.next_u64() as usize) % n_samples
311 })
312 .collect();
313
314 let feature_indices: Vec<usize> =
316 rand_sample_indices(&mut rng, n_features, max_features_n).into_vec();
317
318 build_classification_tree_with_feature_subset(
319 x,
320 &y_mapped,
321 n_classes,
322 &bootstrap_indices,
323 &feature_indices,
324 ¶ms,
325 criterion,
326 )
327 })
328 .collect();
329
330 let mut total_importances = Array1::<F>::zeros(n_features);
332 for tree_nodes in &trees {
333 let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
334 total_importances = total_importances + tree_imp;
335 }
336 let imp_sum: F = total_importances
337 .iter()
338 .copied()
339 .fold(F::zero(), |a, b| a + b);
340 if imp_sum > F::zero() {
341 total_importances.mapv_inplace(|v| v / imp_sum);
342 }
343
344 Ok(FittedRandomForestClassifier {
345 trees,
346 classes,
347 n_features,
348 feature_importances: total_importances,
349 })
350 }
351}
352
impl<F: Float + Send + Sync + 'static> FittedRandomForestClassifier<F> {
    /// The fitted trees, each a flattened vector of nodes.
    #[must_use]
    pub fn trees(&self) -> &[Vec<Node<F>>] {
        &self.trees
    }

    /// Number of features the model was trained on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }
}
366
367impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedRandomForestClassifier<F> {
368 type Output = Array1<usize>;
369 type Error = FerroError;
370
371 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
378 if x.ncols() != self.n_features {
379 return Err(FerroError::ShapeMismatch {
380 expected: vec![self.n_features],
381 actual: vec![x.ncols()],
382 context: "number of features must match fitted model".into(),
383 });
384 }
385
386 let n_samples = x.nrows();
387 let n_classes = self.classes.len();
388 let mut predictions = Array1::zeros(n_samples);
389
390 for i in 0..n_samples {
391 let row = x.row(i);
392 let mut votes = vec![0usize; n_classes];
393
394 for tree_nodes in &self.trees {
395 let leaf_idx = decision_tree::traverse(tree_nodes, &row);
396 if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
397 let class_idx = value.to_f64().map(|f| f.round() as usize).unwrap_or(0);
398 if class_idx < n_classes {
399 votes[class_idx] += 1;
400 }
401 }
402 }
403
404 let winner = votes
405 .iter()
406 .enumerate()
407 .max_by_key(|&(_, &count)| count)
408 .map(|(idx, _)| idx)
409 .unwrap_or(0);
410 predictions[i] = self.classes[winner];
411 }
412
413 Ok(predictions)
414 }
415}
416
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
    for FittedRandomForestClassifier<F>
{
    /// Normalized feature importances accumulated across all trees in `fit`.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
424
impl<F: Float + Send + Sync + 'static> HasClasses for FittedRandomForestClassifier<F> {
    /// Sorted, deduplicated class labels observed during `fit`.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes observed during `fit`.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
434
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for RandomForestClassifier<F>
{
    /// Adapt the classifier to the float-typed pipeline API: labels arrive as
    /// floats and are cast to `usize` before fitting.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        // NOTE(review): labels that fail the usize conversion (negative/NaN)
        // silently become class 0 — consider surfacing an error instead.
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedForestClassifierPipelineAdapter(fitted)))
    }
}
449
/// Newtype adapter exposing a fitted forest classifier through the
/// float-typed pipeline prediction API.
struct FittedForestClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedRandomForestClassifier<F>,
);
454
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedForestClassifierPipelineAdapter<F>
{
    /// Predict class labels, then convert them back to the pipeline's float type.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        // Labels unrepresentable in F map to NaN rather than panicking.
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
    }
}
463
/// Random forest regressor: an ensemble of regression trees, each trained on
/// a bootstrap sample with a random feature subset, predicting the mean of
/// the per-tree leaf values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomForestRegressor<F> {
    /// Number of trees to build (must be >= 1 at fit time).
    pub n_estimators: usize,
    /// Maximum tree depth passed to the tree builder (`None` = no explicit cap).
    pub max_depth: Option<usize>,
    /// How many features each tree may consider (resolved once per tree).
    pub max_features: MaxFeatures,
    /// Minimum samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum samples required at a leaf node.
    pub min_samples_leaf: usize,
    /// Seed for reproducible fits; `None` draws seeds from OS entropy.
    pub random_state: Option<u64>,
    // Ties the element type `F` to the struct without storing any data.
    _marker: std::marker::PhantomData<F>,
}
493
494impl<F: Float> RandomForestRegressor<F> {
495 #[must_use]
501 pub fn new() -> Self {
502 Self {
503 n_estimators: 100,
504 max_depth: None,
505 max_features: MaxFeatures::All,
506 min_samples_split: 2,
507 min_samples_leaf: 1,
508 random_state: None,
509 _marker: std::marker::PhantomData,
510 }
511 }
512
513 #[must_use]
515 pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
516 self.n_estimators = n_estimators;
517 self
518 }
519
520 #[must_use]
522 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
523 self.max_depth = max_depth;
524 self
525 }
526
527 #[must_use]
529 pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
530 self.max_features = max_features;
531 self
532 }
533
534 #[must_use]
536 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
537 self.min_samples_split = min_samples_split;
538 self
539 }
540
541 #[must_use]
543 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
544 self.min_samples_leaf = min_samples_leaf;
545 self
546 }
547
548 #[must_use]
550 pub fn with_random_state(mut self, seed: u64) -> Self {
551 self.random_state = Some(seed);
552 self
553 }
554}
555
/// Defaults are identical to [`RandomForestRegressor::new`].
impl<F: Float> Default for RandomForestRegressor<F> {
    fn default() -> Self {
        Self::new()
    }
}
561
/// A trained random forest regressor produced by
/// [`RandomForestRegressor`]'s `fit`.
#[derive(Debug, Clone)]
pub struct FittedRandomForestRegressor<F> {
    /// One flattened node arena per fitted tree.
    trees: Vec<Vec<Node<F>>>,
    /// Feature count seen at fit time; `predict` validates against it.
    n_features: usize,
    /// Importances summed over trees, normalized to sum to 1 when nonzero.
    feature_importances: Array1<F>,
}
579
580impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for RandomForestRegressor<F> {
581 type Fitted = FittedRandomForestRegressor<F>;
582 type Error = FerroError;
583
584 fn fit(
593 &self,
594 x: &Array2<F>,
595 y: &Array1<F>,
596 ) -> Result<FittedRandomForestRegressor<F>, FerroError> {
597 let (n_samples, n_features) = x.dim();
598
599 if n_samples != y.len() {
600 return Err(FerroError::ShapeMismatch {
601 expected: vec![n_samples],
602 actual: vec![y.len()],
603 context: "y length must match number of samples in X".into(),
604 });
605 }
606 if n_samples == 0 {
607 return Err(FerroError::InsufficientSamples {
608 required: 1,
609 actual: 0,
610 context: "RandomForestRegressor requires at least one sample".into(),
611 });
612 }
613 if self.n_estimators == 0 {
614 return Err(FerroError::InvalidParameter {
615 name: "n_estimators".into(),
616 reason: "must be at least 1".into(),
617 });
618 }
619
620 let max_features_n = resolve_max_features(self.max_features, n_features);
621 let params = make_tree_params(
622 self.max_depth,
623 self.min_samples_split,
624 self.min_samples_leaf,
625 );
626
627 let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
629 let mut master_rng = StdRng::seed_from_u64(seed);
630 (0..self.n_estimators)
631 .map(|_| {
632 use rand::RngCore;
633 master_rng.next_u64()
634 })
635 .collect()
636 } else {
637 (0..self.n_estimators)
638 .map(|_| {
639 use rand::RngCore;
640 rand::rng().next_u64()
641 })
642 .collect()
643 };
644
645 let trees: Vec<Vec<Node<F>>> = tree_seeds
647 .par_iter()
648 .map(|&seed| {
649 let mut rng = StdRng::seed_from_u64(seed);
650
651 let bootstrap_indices: Vec<usize> = (0..n_samples)
652 .map(|_| {
653 use rand::RngCore;
654 (rng.next_u64() as usize) % n_samples
655 })
656 .collect();
657
658 let feature_indices: Vec<usize> =
659 rand_sample_indices(&mut rng, n_features, max_features_n).into_vec();
660
661 build_regression_tree_with_feature_subset(
662 x,
663 y,
664 &bootstrap_indices,
665 &feature_indices,
666 ¶ms,
667 )
668 })
669 .collect();
670
671 let mut total_importances = Array1::<F>::zeros(n_features);
673 for tree_nodes in &trees {
674 let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
675 total_importances = total_importances + tree_imp;
676 }
677 let imp_sum: F = total_importances
678 .iter()
679 .copied()
680 .fold(F::zero(), |a, b| a + b);
681 if imp_sum > F::zero() {
682 total_importances.mapv_inplace(|v| v / imp_sum);
683 }
684
685 Ok(FittedRandomForestRegressor {
686 trees,
687 n_features,
688 feature_importances: total_importances,
689 })
690 }
691}
692
impl<F: Float + Send + Sync + 'static> FittedRandomForestRegressor<F> {
    /// The fitted trees, each a flattened vector of nodes.
    #[must_use]
    pub fn trees(&self) -> &[Vec<Node<F>>] {
        &self.trees
    }

    /// Number of features the model was trained on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }
}
706
707impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedRandomForestRegressor<F> {
708 type Output = Array1<F>;
709 type Error = FerroError;
710
711 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
718 if x.ncols() != self.n_features {
719 return Err(FerroError::ShapeMismatch {
720 expected: vec![self.n_features],
721 actual: vec![x.ncols()],
722 context: "number of features must match fitted model".into(),
723 });
724 }
725
726 let n_samples = x.nrows();
727 let n_trees_f = F::from(self.trees.len()).unwrap();
728 let mut predictions = Array1::zeros(n_samples);
729
730 for i in 0..n_samples {
731 let row = x.row(i);
732 let mut sum = F::zero();
733
734 for tree_nodes in &self.trees {
735 let leaf_idx = decision_tree::traverse(tree_nodes, &row);
736 if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
737 sum = sum + value;
738 }
739 }
740
741 predictions[i] = sum / n_trees_f;
742 }
743
744 Ok(predictions)
745 }
746}
747
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedRandomForestRegressor<F> {
    /// Normalized feature importances accumulated across all trees in `fit`.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
753
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for RandomForestRegressor<F> {
    /// Fit and box the result for dynamic pipeline use; no label conversion
    /// is needed since the regressor already consumes `Array1<F>`.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}
765
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedRandomForestRegressor<F>
{
    /// The regressor already returns `Array1<F>`, so pipeline prediction
    /// delegates directly to `predict`.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
773
#[cfg(test)]
mod tests {
    //! Unit tests covering the classifier, the regressor, the `MaxFeatures`
    //! resolution helper, pipeline adapters, and f32 support.
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // --- Classifier ---

    // Two well-separated clusters should be classified perfectly.
    #[test]
    fn test_forest_classifier_simple() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for i in 0..4 {
            assert_eq!(preds[i], 0);
        }
        for i in 4..8 {
            assert_eq!(preds[i], 1);
        }
    }

    // Same seed => identical predictions across two independent fits.
    #[test]
    fn test_forest_classifier_reproducibility() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);

        let fitted1 = model.fit(&x, &y).unwrap();
        let fitted2 = model.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        assert_eq!(preds1, preds2);
    }

    // Only feature 0 carries signal; it must dominate the importances.
    #[test]
    fn test_forest_classifier_feature_importances() {
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 3);
        assert!(importances[0] > importances[1]);
        assert!(importances[0] > importances[2]);
    }

    // The HasClasses introspection trait exposes sorted, deduplicated labels.
    #[test]
    fn test_forest_classifier_has_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1, 2, 0, 1, 2];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 1, 2]);
        assert_eq!(fitted.n_classes(), 3);
    }

    // fit must reject a y shorter than x's row count.
    #[test]
    fn test_forest_classifier_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1];

        let model = RandomForestClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    // predict must reject input with a different column count than training.
    #[test]
    fn test_forest_classifier_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    // Zero samples is an error, not a silent no-op.
    #[test]
    fn test_forest_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = RandomForestClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    // n_estimators == 0 is an invalid parameter.
    #[test]
    fn test_forest_classifier_zero_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = RandomForestClassifier::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    // A single-tree forest is a valid degenerate case.
    #[test]
    fn test_forest_classifier_single_tree() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(1)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
    }

    // Float labels round-trip through the pipeline adapter.
    #[test]
    fn test_forest_classifier_pipeline_integration() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    // Depth-1 stumps should still fit and predict without error.
    #[test]
    fn test_forest_classifier_max_depth() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_max_depth(Some(1))
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
    }

    // --- Regressor ---

    // Two target plateaus (1.0 / 5.0): averaged predictions should land on
    // the correct side of the midpoint 3.0.
    #[test]
    fn test_forest_regressor_simple() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(50)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for i in 0..4 {
            assert!(preds[i] < 3.0, "Expected ~1.0, got {}", preds[i]);
        }
        for i in 4..8 {
            assert!(preds[i] > 3.0, "Expected ~5.0, got {}", preds[i]);
        }
    }

    // Same seed => numerically identical predictions across two fits.
    #[test]
    fn test_forest_regressor_reproducibility() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);

        let fitted1 = model.fit(&x, &y).unwrap();
        let fitted2 = model.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for (p1, p2) in preds1.iter().zip(preds2.iter()) {
            assert_relative_eq!(*p1, *p2, epsilon = 1e-10);
        }
    }

    // Constant feature 1 carries no signal; feature 0 must dominate.
    #[test]
    fn test_forest_regressor_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        assert!(importances[0] > importances[1]);
    }

    // fit must reject a y shorter than x's row count.
    #[test]
    fn test_forest_regressor_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = RandomForestRegressor::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    // predict must reject input with a different column count than training.
    #[test]
    fn test_forest_regressor_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    // Zero samples is an error.
    #[test]
    fn test_forest_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = RandomForestRegressor::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    // n_estimators == 0 is an invalid parameter.
    #[test]
    fn test_forest_regressor_zero_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = RandomForestRegressor::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    // The regressor's pipeline adapter is a direct pass-through.
    #[test]
    fn test_forest_regressor_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // Every MaxFeatures strategy should fit and predict without error.
    #[test]
    fn test_forest_regressor_max_features_strategies() {
        let x = Array2::from_shape_vec(
            (8, 4),
            vec![
                1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0, 3.0, 4.0, 5.0, 6.0, 4.0, 5.0, 6.0, 7.0,
                5.0, 6.0, 7.0, 8.0, 6.0, 7.0, 8.0, 9.0, 7.0, 8.0, 9.0, 10.0, 8.0, 9.0, 10.0, 11.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        for strategy in &[
            MaxFeatures::Sqrt,
            MaxFeatures::Log2,
            MaxFeatures::All,
            MaxFeatures::Fixed(2),
            MaxFeatures::Fraction(0.5),
        ] {
            let model = RandomForestRegressor::<f64>::new()
                .with_n_estimators(5)
                .with_max_features(*strategy)
                .with_random_state(42);
            let fitted = model.fit(&x, &y).unwrap();
            let preds = fitted.predict(&x).unwrap();
            assert_eq!(preds.len(), 8);
        }
    }

    // --- resolve_max_features ---

    #[test]
    fn test_resolve_max_features_sqrt() {
        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 9), 3);
        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 10), 4);
        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 1), 1);
    }

    #[test]
    fn test_resolve_max_features_log2() {
        assert_eq!(resolve_max_features(MaxFeatures::Log2, 8), 3);
        assert_eq!(resolve_max_features(MaxFeatures::Log2, 1), 1);
    }

    #[test]
    fn test_resolve_max_features_all() {
        assert_eq!(resolve_max_features(MaxFeatures::All, 10), 10);
        assert_eq!(resolve_max_features(MaxFeatures::All, 1), 1);
    }

    // Fixed values are capped at the available feature count.
    #[test]
    fn test_resolve_max_features_fixed() {
        assert_eq!(resolve_max_features(MaxFeatures::Fixed(3), 10), 3);
        assert_eq!(resolve_max_features(MaxFeatures::Fixed(20), 10), 10);
    }

    #[test]
    fn test_resolve_max_features_fraction() {
        assert_eq!(resolve_max_features(MaxFeatures::Fraction(0.5), 10), 5);
        assert_eq!(resolve_max_features(MaxFeatures::Fraction(0.1), 10), 1);
    }

    // --- f32 element type ---

    #[test]
    fn test_forest_classifier_f32_support() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = RandomForestClassifier::<f32>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_forest_regressor_f32_support() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);

        let model = RandomForestRegressor::<f32>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
}