1use ferrolearn_core::error::FerroError;
28use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
29use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
30use ferrolearn_core::traits::{Fit, Predict};
31use ndarray::{Array1, Array2};
32use num_traits::{Float, FromPrimitive, ToPrimitive};
33use rand::SeedableRng;
34use rand::rngs::StdRng;
35use rand::seq::index::sample as rand_sample_indices;
36use rayon::prelude::*;
37use serde::{Deserialize, Serialize};
38
39use crate::decision_tree::{
40 self, ClassificationCriterion, Node, build_classification_tree_with_feature_subset,
41 build_regression_tree_with_feature_subset, compute_feature_importances,
42};
43
/// Strategy for how many feature columns each tree in the forest considers.
///
/// Resolved to a concrete count (clamped to `[1, n_features]`) by
/// `resolve_max_features`; the subset is drawn once per tree during `fit`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum MaxFeatures {
    /// `ceil(sqrt(n_features))`.
    Sqrt,
    /// `ceil(log2(n_features))`, at least 1.
    Log2,
    /// Every feature (no subsetting).
    All,
    /// Exactly this many features (capped at `n_features`).
    Fixed(usize),
    /// `ceil(n_features * fraction)`.
    Fraction(f64),
}
62
63fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
65 let result = match strategy {
66 MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
67 MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
68 MaxFeatures::All => n_features,
69 MaxFeatures::Fixed(n) => n.min(n_features),
70 MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
71 };
72 result.max(1).min(n_features)
73}
74
75fn make_tree_params(
80 max_depth: Option<usize>,
81 min_samples_split: usize,
82 min_samples_leaf: usize,
83) -> decision_tree::TreeParams {
84 decision_tree::TreeParams {
85 max_depth,
86 min_samples_split,
87 min_samples_leaf,
88 }
89}
90
/// Random-forest classifier: an ensemble of decision trees, each fit on a
/// bootstrap sample with a random feature subset, predicting by majority vote.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomForestClassifier<F> {
    /// Number of trees in the ensemble (default 100).
    pub n_estimators: usize,
    /// Optional depth limit forwarded to each tree (default `None`).
    pub max_depth: Option<usize>,
    /// How many feature columns each tree considers (default `Sqrt`).
    pub max_features: MaxFeatures,
    /// Minimum samples required to split a node (default 2).
    pub min_samples_split: usize,
    /// Minimum samples required in a leaf (default 1).
    pub min_samples_leaf: usize,
    /// Optional RNG seed for reproducible fits (default `None`).
    pub random_state: Option<u64>,
    /// Split-quality criterion (default `Gini`).
    pub criterion: ClassificationCriterion,
    // Ties the otherwise-unused float parameter `F` to the type.
    _marker: std::marker::PhantomData<F>,
}
122
123impl<F: Float> RandomForestClassifier<F> {
124 #[must_use]
131 pub fn new() -> Self {
132 Self {
133 n_estimators: 100,
134 max_depth: None,
135 max_features: MaxFeatures::Sqrt,
136 min_samples_split: 2,
137 min_samples_leaf: 1,
138 random_state: None,
139 criterion: ClassificationCriterion::Gini,
140 _marker: std::marker::PhantomData,
141 }
142 }
143
144 #[must_use]
146 pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
147 self.n_estimators = n_estimators;
148 self
149 }
150
151 #[must_use]
153 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
154 self.max_depth = max_depth;
155 self
156 }
157
158 #[must_use]
160 pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
161 self.max_features = max_features;
162 self
163 }
164
165 #[must_use]
167 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
168 self.min_samples_split = min_samples_split;
169 self
170 }
171
172 #[must_use]
174 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
175 self.min_samples_leaf = min_samples_leaf;
176 self
177 }
178
179 #[must_use]
181 pub fn with_random_state(mut self, seed: u64) -> Self {
182 self.random_state = Some(seed);
183 self
184 }
185
186 #[must_use]
188 pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
189 self.criterion = criterion;
190 self
191 }
192}
193
impl<F: Float> Default for RandomForestClassifier<F> {
    /// Same as [`RandomForestClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
199
/// A trained random-forest classifier produced by
/// [`RandomForestClassifier::fit`].
#[derive(Debug, Clone)]
pub struct FittedRandomForestClassifier<F> {
    /// One flattened node arena per fitted tree.
    trees: Vec<Vec<Node<F>>>,
    /// Distinct training labels, sorted ascending; position = internal index.
    classes: Vec<usize>,
    /// Number of feature columns the model was fitted on.
    n_features: usize,
    /// Per-feature importance scores, normalized to sum to one when non-zero.
    feature_importances: Array1<F>,
}
219
220impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for RandomForestClassifier<F> {
221 type Fitted = FittedRandomForestClassifier<F>;
222 type Error = FerroError;
223
224 fn fit(
236 &self,
237 x: &Array2<F>,
238 y: &Array1<usize>,
239 ) -> Result<FittedRandomForestClassifier<F>, FerroError> {
240 let (n_samples, n_features) = x.dim();
241
242 if n_samples != y.len() {
243 return Err(FerroError::ShapeMismatch {
244 expected: vec![n_samples],
245 actual: vec![y.len()],
246 context: "y length must match number of samples in X".into(),
247 });
248 }
249 if n_samples == 0 {
250 return Err(FerroError::InsufficientSamples {
251 required: 1,
252 actual: 0,
253 context: "RandomForestClassifier requires at least one sample".into(),
254 });
255 }
256 if self.n_estimators == 0 {
257 return Err(FerroError::InvalidParameter {
258 name: "n_estimators".into(),
259 reason: "must be at least 1".into(),
260 });
261 }
262
263 let mut classes: Vec<usize> = y.iter().copied().collect();
265 classes.sort_unstable();
266 classes.dedup();
267 let n_classes = classes.len();
268
269 let y_mapped: Vec<usize> = y
270 .iter()
271 .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
272 .collect();
273
274 let max_features_n = resolve_max_features(self.max_features, n_features);
275 let params = make_tree_params(
276 self.max_depth,
277 self.min_samples_split,
278 self.min_samples_leaf,
279 );
280 let criterion = self.criterion;
281
282 let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
284 let mut master_rng = StdRng::seed_from_u64(seed);
285 (0..self.n_estimators)
286 .map(|_| {
287 use rand::RngCore;
288 master_rng.next_u64()
289 })
290 .collect()
291 } else {
292 (0..self.n_estimators)
293 .map(|_| {
294 use rand::RngCore;
295 rand::rng().next_u64()
296 })
297 .collect()
298 };
299
300 let trees: Vec<Vec<Node<F>>> = tree_seeds
302 .par_iter()
303 .map(|&seed| {
304 let mut rng = StdRng::seed_from_u64(seed);
305
306 let bootstrap_indices: Vec<usize> = (0..n_samples)
308 .map(|_| {
309 use rand::RngCore;
310 (rng.next_u64() as usize) % n_samples
311 })
312 .collect();
313
314 let feature_indices: Vec<usize> =
316 rand_sample_indices(&mut rng, n_features, max_features_n).into_vec();
317
318 build_classification_tree_with_feature_subset(
319 x,
320 &y_mapped,
321 n_classes,
322 &bootstrap_indices,
323 &feature_indices,
324 ¶ms,
325 criterion,
326 )
327 })
328 .collect();
329
330 let mut total_importances = Array1::<F>::zeros(n_features);
332 for tree_nodes in &trees {
333 let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
334 total_importances = total_importances + tree_imp;
335 }
336 let imp_sum: F = total_importances
337 .iter()
338 .copied()
339 .fold(F::zero(), |a, b| a + b);
340 if imp_sum > F::zero() {
341 total_importances.mapv_inplace(|v| v / imp_sum);
342 }
343
344 Ok(FittedRandomForestClassifier {
345 trees,
346 classes,
347 n_features,
348 feature_importances: total_importances,
349 })
350 }
351}
352
impl<F: Float + Send + Sync + 'static> FittedRandomForestClassifier<F> {
    /// Borrows the fitted trees (one flattened node arena per tree).
    #[must_use]
    pub fn trees(&self) -> &[Vec<Node<F>>] {
        &self.trees
    }

    /// Number of feature columns the model was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }
}
366
367impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedRandomForestClassifier<F> {
368 type Output = Array1<usize>;
369 type Error = FerroError;
370
371 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
378 if x.ncols() != self.n_features {
379 return Err(FerroError::ShapeMismatch {
380 expected: vec![self.n_features],
381 actual: vec![x.ncols()],
382 context: "number of features must match fitted model".into(),
383 });
384 }
385
386 let n_samples = x.nrows();
387 let n_classes = self.classes.len();
388 let mut predictions = Array1::zeros(n_samples);
389
390 for i in 0..n_samples {
391 let row = x.row(i);
392 let mut votes = vec![0usize; n_classes];
393
394 for tree_nodes in &self.trees {
395 let leaf_idx = decision_tree::traverse(tree_nodes, &row);
396 if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
397 let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
398 if class_idx < n_classes {
399 votes[class_idx] += 1;
400 }
401 }
402 }
403
404 let winner = votes
405 .iter()
406 .enumerate()
407 .max_by_key(|&(_, &count)| count)
408 .map_or(0, |(idx, _)| idx);
409 predictions[i] = self.classes[winner];
410 }
411
412 Ok(predictions)
413 }
414}
415
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
    for FittedRandomForestClassifier<F>
{
    /// Per-feature importance scores accumulated over all trees during fit
    /// (normalized to sum to one when any score is non-zero).
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
423
impl<F: Float + Send + Sync + 'static> HasClasses for FittedRandomForestClassifier<F> {
    /// The distinct training labels, sorted ascending.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes seen during fit.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
433
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for RandomForestClassifier<F>
{
    /// Pipeline adapter: converts the float target vector to `usize` labels
    /// and fits the classifier.
    ///
    /// NOTE(review): `to_usize().unwrap_or(0)` silently maps negative or NaN
    /// targets to class 0 — confirm upstream guarantees non-negative,
    /// integral class labels.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedForestClassifierPipelineAdapter(fitted)))
    }
}
448
/// Newtype wrapper that lets a fitted classifier serve as a pipeline step by
/// exposing its `usize` predictions as `F` values.
struct FittedForestClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedRandomForestClassifier<F>,
);
453
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedForestClassifierPipelineAdapter<F>
{
    /// Predicts class labels and converts them to `F` for the pipeline;
    /// labels not representable in `F` become NaN.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
    }
}
462
/// Random-forest regressor: an ensemble of regression trees, each fit on a
/// bootstrap sample, predicting by averaging the trees' leaf values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomForestRegressor<F> {
    /// Number of trees in the ensemble (default 100).
    pub n_estimators: usize,
    /// Optional depth limit forwarded to each tree (default `None`).
    pub max_depth: Option<usize>,
    /// How many feature columns each tree considers (default `All`).
    pub max_features: MaxFeatures,
    /// Minimum samples required to split a node (default 2).
    pub min_samples_split: usize,
    /// Minimum samples required in a leaf (default 1).
    pub min_samples_leaf: usize,
    /// Optional RNG seed for reproducible fits (default `None`).
    pub random_state: Option<u64>,
    // Ties the otherwise-unused float parameter `F` to the type.
    _marker: std::marker::PhantomData<F>,
}
492
493impl<F: Float> RandomForestRegressor<F> {
494 #[must_use]
500 pub fn new() -> Self {
501 Self {
502 n_estimators: 100,
503 max_depth: None,
504 max_features: MaxFeatures::All,
505 min_samples_split: 2,
506 min_samples_leaf: 1,
507 random_state: None,
508 _marker: std::marker::PhantomData,
509 }
510 }
511
512 #[must_use]
514 pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
515 self.n_estimators = n_estimators;
516 self
517 }
518
519 #[must_use]
521 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
522 self.max_depth = max_depth;
523 self
524 }
525
526 #[must_use]
528 pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
529 self.max_features = max_features;
530 self
531 }
532
533 #[must_use]
535 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
536 self.min_samples_split = min_samples_split;
537 self
538 }
539
540 #[must_use]
542 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
543 self.min_samples_leaf = min_samples_leaf;
544 self
545 }
546
547 #[must_use]
549 pub fn with_random_state(mut self, seed: u64) -> Self {
550 self.random_state = Some(seed);
551 self
552 }
553}
554
impl<F: Float> Default for RandomForestRegressor<F> {
    /// Same as [`RandomForestRegressor::new`].
    fn default() -> Self {
        Self::new()
    }
}
560
/// A trained random-forest regressor produced by
/// [`RandomForestRegressor::fit`].
#[derive(Debug, Clone)]
pub struct FittedRandomForestRegressor<F> {
    /// One flattened node arena per fitted tree.
    trees: Vec<Vec<Node<F>>>,
    /// Number of feature columns the model was fitted on.
    n_features: usize,
    /// Per-feature importance scores, normalized to sum to one when non-zero.
    feature_importances: Array1<F>,
}
578
579impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for RandomForestRegressor<F> {
580 type Fitted = FittedRandomForestRegressor<F>;
581 type Error = FerroError;
582
583 fn fit(
592 &self,
593 x: &Array2<F>,
594 y: &Array1<F>,
595 ) -> Result<FittedRandomForestRegressor<F>, FerroError> {
596 let (n_samples, n_features) = x.dim();
597
598 if n_samples != y.len() {
599 return Err(FerroError::ShapeMismatch {
600 expected: vec![n_samples],
601 actual: vec![y.len()],
602 context: "y length must match number of samples in X".into(),
603 });
604 }
605 if n_samples == 0 {
606 return Err(FerroError::InsufficientSamples {
607 required: 1,
608 actual: 0,
609 context: "RandomForestRegressor requires at least one sample".into(),
610 });
611 }
612 if self.n_estimators == 0 {
613 return Err(FerroError::InvalidParameter {
614 name: "n_estimators".into(),
615 reason: "must be at least 1".into(),
616 });
617 }
618
619 let max_features_n = resolve_max_features(self.max_features, n_features);
620 let params = make_tree_params(
621 self.max_depth,
622 self.min_samples_split,
623 self.min_samples_leaf,
624 );
625
626 let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
628 let mut master_rng = StdRng::seed_from_u64(seed);
629 (0..self.n_estimators)
630 .map(|_| {
631 use rand::RngCore;
632 master_rng.next_u64()
633 })
634 .collect()
635 } else {
636 (0..self.n_estimators)
637 .map(|_| {
638 use rand::RngCore;
639 rand::rng().next_u64()
640 })
641 .collect()
642 };
643
644 let trees: Vec<Vec<Node<F>>> = tree_seeds
646 .par_iter()
647 .map(|&seed| {
648 let mut rng = StdRng::seed_from_u64(seed);
649
650 let bootstrap_indices: Vec<usize> = (0..n_samples)
651 .map(|_| {
652 use rand::RngCore;
653 (rng.next_u64() as usize) % n_samples
654 })
655 .collect();
656
657 let feature_indices: Vec<usize> =
658 rand_sample_indices(&mut rng, n_features, max_features_n).into_vec();
659
660 build_regression_tree_with_feature_subset(
661 x,
662 y,
663 &bootstrap_indices,
664 &feature_indices,
665 ¶ms,
666 )
667 })
668 .collect();
669
670 let mut total_importances = Array1::<F>::zeros(n_features);
672 for tree_nodes in &trees {
673 let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
674 total_importances = total_importances + tree_imp;
675 }
676 let imp_sum: F = total_importances
677 .iter()
678 .copied()
679 .fold(F::zero(), |a, b| a + b);
680 if imp_sum > F::zero() {
681 total_importances.mapv_inplace(|v| v / imp_sum);
682 }
683
684 Ok(FittedRandomForestRegressor {
685 trees,
686 n_features,
687 feature_importances: total_importances,
688 })
689 }
690}
691
impl<F: Float + Send + Sync + 'static> FittedRandomForestRegressor<F> {
    /// Borrows the fitted trees (one flattened node arena per tree).
    #[must_use]
    pub fn trees(&self) -> &[Vec<Node<F>>] {
        &self.trees
    }

    /// Number of feature columns the model was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }
}
705
706impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedRandomForestRegressor<F> {
707 type Output = Array1<F>;
708 type Error = FerroError;
709
710 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
717 if x.ncols() != self.n_features {
718 return Err(FerroError::ShapeMismatch {
719 expected: vec![self.n_features],
720 actual: vec![x.ncols()],
721 context: "number of features must match fitted model".into(),
722 });
723 }
724
725 let n_samples = x.nrows();
726 let n_trees_f = F::from(self.trees.len()).unwrap();
727 let mut predictions = Array1::zeros(n_samples);
728
729 for i in 0..n_samples {
730 let row = x.row(i);
731 let mut sum = F::zero();
732
733 for tree_nodes in &self.trees {
734 let leaf_idx = decision_tree::traverse(tree_nodes, &row);
735 if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
736 sum = sum + value;
737 }
738 }
739
740 predictions[i] = sum / n_trees_f;
741 }
742
743 Ok(predictions)
744 }
745}
746
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedRandomForestRegressor<F> {
    /// Per-feature importance scores accumulated over all trees during fit
    /// (normalized to sum to one when any score is non-zero).
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
752
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for RandomForestRegressor<F> {
    /// Pipeline adapter: fits the regressor on the float targets as-is; the
    /// fitted model itself implements the fitted-pipeline trait.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}
764
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedRandomForestRegressor<F>
{
    /// Delegates directly to [`Predict::predict`]; no conversion needed.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
772
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // ---------------- RandomForestClassifier ----------------

    #[test]
    fn test_forest_classifier_simple() {
        // Two well-separated clusters; the forest should classify them all.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for i in 0..4 {
            assert_eq!(preds[i], 0);
        }
        for i in 4..8 {
            assert_eq!(preds[i], 1);
        }
    }

    #[test]
    fn test_forest_classifier_reproducibility() {
        // Same random_state => identical predictions across fits.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);

        let fitted1 = model.fit(&x, &y).unwrap();
        let fitted2 = model.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        assert_eq!(preds1, preds2);
    }

    #[test]
    fn test_forest_classifier_feature_importances() {
        // Only column 0 carries signal; it should dominate the importances.
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 3);
        assert!(importances[0] > importances[1]);
        assert!(importances[0] > importances[2]);
    }

    #[test]
    fn test_forest_classifier_has_classes() {
        // Classes are reported sorted and deduplicated.
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1, 2, 0, 1, 2];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 1, 2]);
        assert_eq!(fitted.n_classes(), 3);
    }

    #[test]
    fn test_forest_classifier_shape_mismatch_fit() {
        // y shorter than x rows => fit must fail.
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1];

        let model = RandomForestClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_forest_classifier_shape_mismatch_predict() {
        // Predicting with the wrong column count must fail.
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_forest_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = RandomForestClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_forest_classifier_zero_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = RandomForestClassifier::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_forest_classifier_single_tree() {
        // n_estimators = 1 degenerates to a single decision tree.
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(1)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_forest_classifier_pipeline_integration() {
        // Float targets round-trip through the pipeline adapter.
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_forest_classifier_max_depth() {
        // Depth-1 stumps still produce a full prediction vector.
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_max_depth(Some(1))
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
    }

    // ---------------- RandomForestRegressor ----------------

    #[test]
    fn test_forest_regressor_simple() {
        // Step function: predictions should land near the two plateaus.
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(50)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for i in 0..4 {
            assert!(preds[i] < 3.0, "Expected ~1.0, got {}", preds[i]);
        }
        for i in 4..8 {
            assert!(preds[i] > 3.0, "Expected ~5.0, got {}", preds[i]);
        }
    }

    #[test]
    fn test_forest_regressor_reproducibility() {
        // Same random_state => numerically identical predictions.
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);

        let fitted1 = model.fit(&x, &y).unwrap();
        let fitted2 = model.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for (p1, p2) in preds1.iter().zip(preds2.iter()) {
            assert_relative_eq!(*p1, *p2, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_forest_regressor_feature_importances() {
        // Column 1 is constant, so column 0 must dominate.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_forest_regressor_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = RandomForestRegressor::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_forest_regressor_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_forest_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = RandomForestRegressor::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_forest_regressor_zero_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = RandomForestRegressor::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_forest_regressor_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_forest_regressor_max_features_strategies() {
        // Every MaxFeatures variant should fit and predict without error.
        let x = Array2::from_shape_vec(
            (8, 4),
            vec![
                1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0, 3.0, 4.0, 5.0, 6.0, 4.0, 5.0, 6.0, 7.0,
                5.0, 6.0, 7.0, 8.0, 6.0, 7.0, 8.0, 9.0, 7.0, 8.0, 9.0, 10.0, 8.0, 9.0, 10.0, 11.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        for strategy in &[
            MaxFeatures::Sqrt,
            MaxFeatures::Log2,
            MaxFeatures::All,
            MaxFeatures::Fixed(2),
            MaxFeatures::Fraction(0.5),
        ] {
            let model = RandomForestRegressor::<f64>::new()
                .with_n_estimators(5)
                .with_max_features(*strategy)
                .with_random_state(42);
            let fitted = model.fit(&x, &y).unwrap();
            let preds = fitted.predict(&x).unwrap();
            assert_eq!(preds.len(), 8);
        }
    }

    // ---------------- resolve_max_features ----------------

    #[test]
    fn test_resolve_max_features_sqrt() {
        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 9), 3);
        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 10), 4);
        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 1), 1);
    }

    #[test]
    fn test_resolve_max_features_log2() {
        assert_eq!(resolve_max_features(MaxFeatures::Log2, 8), 3);
        assert_eq!(resolve_max_features(MaxFeatures::Log2, 1), 1);
    }

    #[test]
    fn test_resolve_max_features_all() {
        assert_eq!(resolve_max_features(MaxFeatures::All, 10), 10);
        assert_eq!(resolve_max_features(MaxFeatures::All, 1), 1);
    }

    #[test]
    fn test_resolve_max_features_fixed() {
        // Fixed counts are capped at n_features.
        assert_eq!(resolve_max_features(MaxFeatures::Fixed(3), 10), 3);
        assert_eq!(resolve_max_features(MaxFeatures::Fixed(20), 10), 10);
    }

    #[test]
    fn test_resolve_max_features_fraction() {
        assert_eq!(resolve_max_features(MaxFeatures::Fraction(0.5), 10), 5);
        assert_eq!(resolve_max_features(MaxFeatures::Fraction(0.1), 10), 1);
    }

    // ---------------- f32 support ----------------

    #[test]
    fn test_forest_classifier_f32_support() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = RandomForestClassifier::<f32>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_forest_regressor_f32_support() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);

        let model = RandomForestRegressor::<f32>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
}