1use ferrolearn_core::error::FerroError;
28use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
29use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
30use ferrolearn_core::traits::{Fit, Predict};
31use ndarray::{Array1, Array2};
32use num_traits::{Float, FromPrimitive, ToPrimitive};
33use rand::SeedableRng;
34use rand::rngs::StdRng;
35use rayon::prelude::*;
36use serde::{Deserialize, Serialize};
37
38use crate::decision_tree::{
39 self, ClassificationCriterion, Node, build_classification_tree_per_split_features,
40 build_regression_tree_per_split_features, compute_feature_importances,
41};
42
/// Strategy for choosing how many features to consider at each split.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum MaxFeatures {
    /// `ceil(sqrt(n_features))`.
    Sqrt,
    /// `max(1, ceil(log2(n_features)))`.
    Log2,
    /// All features (no subsampling).
    All,
    /// A fixed count, capped at `n_features`.
    Fixed(usize),
    /// `ceil(n_features * fraction)`.
    Fraction(f64),
}
61
62fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
64 let result = match strategy {
65 MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
66 MaxFeatures::Log2 => (n_features as f64).log2().ceil().max(1.0) as usize,
67 MaxFeatures::All => n_features,
68 MaxFeatures::Fixed(n) => n.min(n_features),
69 MaxFeatures::Fraction(f) => ((n_features as f64) * f).ceil() as usize,
70 };
71 result.max(1).min(n_features)
72}
73
/// Bundles the shared tree-growth hyperparameters into the
/// `decision_tree::TreeParams` struct consumed by the tree builders.
fn make_tree_params(
    max_depth: Option<usize>,
    min_samples_split: usize,
    min_samples_leaf: usize,
) -> decision_tree::TreeParams {
    decision_tree::TreeParams {
        max_depth,
        min_samples_split,
        min_samples_leaf,
    }
}
89
/// Random forest classifier: an ensemble of decision trees trained on
/// bootstrap samples with per-split feature subsampling; predictions are
/// made by majority vote across the trees.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomForestClassifier<F> {
    /// Number of trees in the forest (default 100, see `new`).
    pub n_estimators: usize,
    /// Maximum tree depth; `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Strategy for how many features to consider at each split.
    pub max_features: MaxFeatures,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf.
    pub min_samples_leaf: usize,
    /// Seed for reproducible fits; `None` uses the thread-local RNG.
    pub random_state: Option<u64>,
    /// Impurity criterion used when growing the trees.
    pub criterion: ClassificationCriterion,
    // Ties the generic float type `F` to the struct without storing one.
    _marker: std::marker::PhantomData<F>,
}
121
122impl<F: Float> RandomForestClassifier<F> {
123 #[must_use]
130 pub fn new() -> Self {
131 Self {
132 n_estimators: 100,
133 max_depth: None,
134 max_features: MaxFeatures::Sqrt,
135 min_samples_split: 2,
136 min_samples_leaf: 1,
137 random_state: None,
138 criterion: ClassificationCriterion::Gini,
139 _marker: std::marker::PhantomData,
140 }
141 }
142
143 #[must_use]
145 pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
146 self.n_estimators = n_estimators;
147 self
148 }
149
150 #[must_use]
152 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
153 self.max_depth = max_depth;
154 self
155 }
156
157 #[must_use]
159 pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
160 self.max_features = max_features;
161 self
162 }
163
164 #[must_use]
166 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
167 self.min_samples_split = min_samples_split;
168 self
169 }
170
171 #[must_use]
173 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
174 self.min_samples_leaf = min_samples_leaf;
175 self
176 }
177
178 #[must_use]
180 pub fn with_random_state(mut self, seed: u64) -> Self {
181 self.random_state = Some(seed);
182 self
183 }
184
185 #[must_use]
187 pub fn with_criterion(mut self, criterion: ClassificationCriterion) -> Self {
188 self.criterion = criterion;
189 self
190 }
191}
192
impl<F: Float> Default for RandomForestClassifier<F> {
    /// Equivalent to [`RandomForestClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
198
/// A trained random forest classifier produced by fitting a
/// [`RandomForestClassifier`].
#[derive(Debug, Clone)]
pub struct FittedRandomForestClassifier<F> {
    /// One flattened node array per trained tree.
    trees: Vec<Vec<Node<F>>>,
    /// Sorted, deduplicated class labels observed during fitting.
    classes: Vec<usize>,
    /// Number of feature columns the model was fitted on.
    n_features: usize,
    /// Normalized per-feature importances accumulated over all trees.
    feature_importances: Array1<F>,
}
218
219impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for RandomForestClassifier<F> {
220 type Fitted = FittedRandomForestClassifier<F>;
221 type Error = FerroError;
222
223 fn fit(
235 &self,
236 x: &Array2<F>,
237 y: &Array1<usize>,
238 ) -> Result<FittedRandomForestClassifier<F>, FerroError> {
239 let (n_samples, n_features) = x.dim();
240
241 if n_samples != y.len() {
242 return Err(FerroError::ShapeMismatch {
243 expected: vec![n_samples],
244 actual: vec![y.len()],
245 context: "y length must match number of samples in X".into(),
246 });
247 }
248 if n_samples == 0 {
249 return Err(FerroError::InsufficientSamples {
250 required: 1,
251 actual: 0,
252 context: "RandomForestClassifier requires at least one sample".into(),
253 });
254 }
255 if self.n_estimators == 0 {
256 return Err(FerroError::InvalidParameter {
257 name: "n_estimators".into(),
258 reason: "must be at least 1".into(),
259 });
260 }
261
262 let mut classes: Vec<usize> = y.iter().copied().collect();
264 classes.sort_unstable();
265 classes.dedup();
266 let n_classes = classes.len();
267
268 let y_mapped: Vec<usize> = y
269 .iter()
270 .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
271 .collect();
272
273 let max_features_n = resolve_max_features(self.max_features, n_features);
274 let params = make_tree_params(
275 self.max_depth,
276 self.min_samples_split,
277 self.min_samples_leaf,
278 );
279 let criterion = self.criterion;
280
281 let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
283 let mut master_rng = StdRng::seed_from_u64(seed);
284 (0..self.n_estimators)
285 .map(|_| {
286 use rand::RngCore;
287 master_rng.next_u64()
288 })
289 .collect()
290 } else {
291 (0..self.n_estimators)
292 .map(|_| {
293 use rand::RngCore;
294 rand::rng().next_u64()
295 })
296 .collect()
297 };
298
299 let trees: Vec<Vec<Node<F>>> = tree_seeds
309 .par_iter()
310 .map(|&seed| {
311 let mut bootstrap_rng = StdRng::seed_from_u64(seed);
312
313 let bootstrap_indices: Vec<usize> = (0..n_samples)
314 .map(|_| {
315 use rand::RngCore;
316 (bootstrap_rng.next_u64() as usize) % n_samples
317 })
318 .collect();
319
320 use rand::RngCore;
324 let split_seed = bootstrap_rng.next_u64();
325
326 build_classification_tree_per_split_features(
327 x,
328 &y_mapped,
329 n_classes,
330 &bootstrap_indices,
331 max_features_n,
332 ¶ms,
333 criterion,
334 split_seed,
335 )
336 })
337 .collect();
338
339 let mut total_importances = Array1::<F>::zeros(n_features);
341 for tree_nodes in &trees {
342 let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
343 total_importances = total_importances + tree_imp;
344 }
345 let imp_sum: F = total_importances
346 .iter()
347 .copied()
348 .fold(F::zero(), |a, b| a + b);
349 if imp_sum > F::zero() {
350 total_importances.mapv_inplace(|v| v / imp_sum);
351 }
352
353 Ok(FittedRandomForestClassifier {
354 trees,
355 classes,
356 n_features,
357 feature_importances: total_importances,
358 })
359 }
360}
361
impl<F: Float + Send + Sync + 'static> FittedRandomForestClassifier<F> {
    /// Returns the trained trees (one flattened node array per tree).
    #[must_use]
    pub fn trees(&self) -> &[Vec<Node<F>>] {
        &self.trees
    }

    /// Returns the number of features the model was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// Computes the mean accuracy of `self.predict(x)` against `y`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `y.len() != x.nrows()`, or
    /// propagates the error from `predict` on a column-count mismatch.
    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
        if x.nrows() != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows()],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        let preds = self.predict(x)?;
        Ok(crate::mean_accuracy(&preds, y))
    }

    /// Predicts class probabilities as the average of per-tree leaf
    /// contributions: leaves carrying a full `class_distribution` add that
    /// distribution, while value-only leaves add a one-hot vote for the
    /// rounded class index.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x.ncols()` differs from the
    /// fitted feature count.
    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_samples = x.nrows();
        let n_classes = self.classes.len();
        let n_trees_f = F::from(self.trees.len()).unwrap();
        let mut proba = Array2::<F>::zeros((n_samples, n_classes));

        for i in 0..n_samples {
            let row = x.row(i);
            for tree_nodes in &self.trees {
                let leaf_idx = decision_tree::traverse(tree_nodes, &row);
                match &tree_nodes[leaf_idx] {
                    // Leaf with a full class distribution: accumulate it
                    // directly (truncated to n_classes for safety).
                    Node::Leaf {
                        class_distribution: Some(dist),
                        ..
                    } => {
                        for (j, &p) in dist.iter().enumerate().take(n_classes) {
                            proba[[i, j]] = proba[[i, j]] + p;
                        }
                    }
                    // Value-only leaf: one-hot vote for the rounded class
                    // index; out-of-range indices are silently dropped.
                    Node::Leaf { value, .. } => {
                        let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
                        if class_idx < n_classes {
                            proba[[i, class_idx]] = proba[[i, class_idx]] + F::one();
                        }
                    }
                    // NOTE(review): `traverse` is presumably guaranteed to
                    // return a leaf index, making this arm unreachable; it
                    // is kept only for match exhaustiveness.
                    _ => {}
                }
            }
            // Convert accumulated counts/distributions into an average.
            for j in 0..n_classes {
                proba[[i, j]] = proba[[i, j]] / n_trees_f;
            }
        }
        Ok(proba)
    }

    /// Returns the element-wise natural log of [`Self::predict_proba`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::predict_proba`].
    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let proba = self.predict_proba(x)?;
        Ok(crate::log_proba(&proba))
    }
}
459
460impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedRandomForestClassifier<F> {
461 type Output = Array1<usize>;
462 type Error = FerroError;
463
464 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
471 if x.ncols() != self.n_features {
472 return Err(FerroError::ShapeMismatch {
473 expected: vec![self.n_features],
474 actual: vec![x.ncols()],
475 context: "number of features must match fitted model".into(),
476 });
477 }
478
479 let n_samples = x.nrows();
480 let n_classes = self.classes.len();
481 let mut predictions = Array1::zeros(n_samples);
482
483 for i in 0..n_samples {
484 let row = x.row(i);
485 let mut votes = vec![0usize; n_classes];
486
487 for tree_nodes in &self.trees {
488 let leaf_idx = decision_tree::traverse(tree_nodes, &row);
489 if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
490 let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
491 if class_idx < n_classes {
492 votes[class_idx] += 1;
493 }
494 }
495 }
496
497 let winner = votes
498 .iter()
499 .enumerate()
500 .max_by_key(|&(_, &count)| count)
501 .map_or(0, |(idx, _)| idx);
502 predictions[i] = self.classes[winner];
503 }
504
505 Ok(predictions)
506 }
507}
508
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
    for FittedRandomForestClassifier<F>
{
    /// Returns the normalized per-feature importances computed during fit.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
516
impl<F: Float + Send + Sync + 'static> HasClasses for FittedRandomForestClassifier<F> {
    /// Returns the sorted, deduplicated class labels seen during fitting.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Returns the number of distinct classes.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
526
527impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
529 for RandomForestClassifier<F>
530{
531 fn fit_pipeline(
532 &self,
533 x: &Array2<F>,
534 y: &Array1<F>,
535 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
536 let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
537 let fitted = self.fit(x, &y_usize)?;
538 Ok(Box::new(FittedForestClassifierPipelineAdapter(fitted)))
539 }
540}
541
/// Newtype adapter exposing a fitted forest classifier through the
/// float-based pipeline prediction interface.
struct FittedForestClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedRandomForestClassifier<F>,
);
546
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedForestClassifierPipelineAdapter<F>
{
    /// Predicts class labels and converts them to `F` for the pipeline;
    /// labels that cannot be represented in `F` become NaN.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
    }
}
555
/// Random forest regressor: an ensemble of regression trees trained on
/// bootstrap samples; predictions are the average of per-tree leaf values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomForestRegressor<F> {
    /// Number of trees in the forest (default 100, see `new`).
    pub n_estimators: usize,
    /// Maximum tree depth; `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Strategy for how many features to consider at each split.
    pub max_features: MaxFeatures,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required in a leaf.
    pub min_samples_leaf: usize,
    /// Seed for reproducible fits; `None` uses the thread-local RNG.
    pub random_state: Option<u64>,
    // Ties the generic float type `F` to the struct without storing one.
    _marker: std::marker::PhantomData<F>,
}
585
586impl<F: Float> RandomForestRegressor<F> {
587 #[must_use]
593 pub fn new() -> Self {
594 Self {
595 n_estimators: 100,
596 max_depth: None,
597 max_features: MaxFeatures::All,
598 min_samples_split: 2,
599 min_samples_leaf: 1,
600 random_state: None,
601 _marker: std::marker::PhantomData,
602 }
603 }
604
605 #[must_use]
607 pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
608 self.n_estimators = n_estimators;
609 self
610 }
611
612 #[must_use]
614 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
615 self.max_depth = max_depth;
616 self
617 }
618
619 #[must_use]
621 pub fn with_max_features(mut self, max_features: MaxFeatures) -> Self {
622 self.max_features = max_features;
623 self
624 }
625
626 #[must_use]
628 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
629 self.min_samples_split = min_samples_split;
630 self
631 }
632
633 #[must_use]
635 pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
636 self.min_samples_leaf = min_samples_leaf;
637 self
638 }
639
640 #[must_use]
642 pub fn with_random_state(mut self, seed: u64) -> Self {
643 self.random_state = Some(seed);
644 self
645 }
646}
647
impl<F: Float> Default for RandomForestRegressor<F> {
    /// Equivalent to [`RandomForestRegressor::new`].
    fn default() -> Self {
        Self::new()
    }
}
653
/// A trained random forest regressor produced by fitting a
/// [`RandomForestRegressor`].
#[derive(Debug, Clone)]
pub struct FittedRandomForestRegressor<F> {
    /// One flattened node array per trained tree.
    trees: Vec<Vec<Node<F>>>,
    /// Number of feature columns the model was fitted on.
    n_features: usize,
    /// Normalized per-feature importances accumulated over all trees.
    feature_importances: Array1<F>,
}
671
672impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for RandomForestRegressor<F> {
673 type Fitted = FittedRandomForestRegressor<F>;
674 type Error = FerroError;
675
676 fn fit(
685 &self,
686 x: &Array2<F>,
687 y: &Array1<F>,
688 ) -> Result<FittedRandomForestRegressor<F>, FerroError> {
689 let (n_samples, n_features) = x.dim();
690
691 if n_samples != y.len() {
692 return Err(FerroError::ShapeMismatch {
693 expected: vec![n_samples],
694 actual: vec![y.len()],
695 context: "y length must match number of samples in X".into(),
696 });
697 }
698 if n_samples == 0 {
699 return Err(FerroError::InsufficientSamples {
700 required: 1,
701 actual: 0,
702 context: "RandomForestRegressor requires at least one sample".into(),
703 });
704 }
705 if self.n_estimators == 0 {
706 return Err(FerroError::InvalidParameter {
707 name: "n_estimators".into(),
708 reason: "must be at least 1".into(),
709 });
710 }
711
712 let max_features_n = resolve_max_features(self.max_features, n_features);
713 let params = make_tree_params(
714 self.max_depth,
715 self.min_samples_split,
716 self.min_samples_leaf,
717 );
718
719 let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
721 let mut master_rng = StdRng::seed_from_u64(seed);
722 (0..self.n_estimators)
723 .map(|_| {
724 use rand::RngCore;
725 master_rng.next_u64()
726 })
727 .collect()
728 } else {
729 (0..self.n_estimators)
730 .map(|_| {
731 use rand::RngCore;
732 rand::rng().next_u64()
733 })
734 .collect()
735 };
736
737 let trees: Vec<Vec<Node<F>>> = tree_seeds
740 .par_iter()
741 .map(|&seed| {
742 let mut bootstrap_rng = StdRng::seed_from_u64(seed);
743
744 let bootstrap_indices: Vec<usize> = (0..n_samples)
745 .map(|_| {
746 use rand::RngCore;
747 (bootstrap_rng.next_u64() as usize) % n_samples
748 })
749 .collect();
750
751 use rand::RngCore;
752 let split_seed = bootstrap_rng.next_u64();
753
754 build_regression_tree_per_split_features(
755 x,
756 y,
757 &bootstrap_indices,
758 max_features_n,
759 ¶ms,
760 split_seed,
761 )
762 })
763 .collect();
764
765 let mut total_importances = Array1::<F>::zeros(n_features);
767 for tree_nodes in &trees {
768 let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
769 total_importances = total_importances + tree_imp;
770 }
771 let imp_sum: F = total_importances
772 .iter()
773 .copied()
774 .fold(F::zero(), |a, b| a + b);
775 if imp_sum > F::zero() {
776 total_importances.mapv_inplace(|v| v / imp_sum);
777 }
778
779 Ok(FittedRandomForestRegressor {
780 trees,
781 n_features,
782 feature_importances: total_importances,
783 })
784 }
785}
786
impl<F: Float + Send + Sync + 'static> FittedRandomForestRegressor<F> {
    /// Returns the trained trees (one flattened node array per tree).
    #[must_use]
    pub fn trees(&self) -> &[Vec<Node<F>>] {
        &self.trees
    }

    /// Returns the number of features the model was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// Computes the R² score of `self.predict(x)` against `y`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `y.len() != x.nrows()`, or
    /// propagates the error from `predict` on a column-count mismatch.
    pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<F, FerroError> {
        if x.nrows() != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows()],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        let preds = self.predict(x)?;
        Ok(crate::r2_score(&preds, y))
    }
}
819
820impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedRandomForestRegressor<F> {
821 type Output = Array1<F>;
822 type Error = FerroError;
823
824 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
831 if x.ncols() != self.n_features {
832 return Err(FerroError::ShapeMismatch {
833 expected: vec![self.n_features],
834 actual: vec![x.ncols()],
835 context: "number of features must match fitted model".into(),
836 });
837 }
838
839 let n_samples = x.nrows();
840 let n_trees_f = F::from(self.trees.len()).unwrap();
841 let mut predictions = Array1::zeros(n_samples);
842
843 for i in 0..n_samples {
844 let row = x.row(i);
845 let mut sum = F::zero();
846
847 for tree_nodes in &self.trees {
848 let leaf_idx = decision_tree::traverse(tree_nodes, &row);
849 if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
850 sum = sum + value;
851 }
852 }
853
854 predictions[i] = sum / n_trees_f;
855 }
856
857 Ok(predictions)
858 }
859}
860
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedRandomForestRegressor<F> {
    /// Returns the normalized per-feature importances computed during fit.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
866
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for RandomForestRegressor<F> {
    /// Fits the regressor for use in a pipeline. The fitted model itself
    /// implements [`FittedPipelineEstimator`], so no adapter is needed.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}
878
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedRandomForestRegressor<F>
{
    /// Delegates directly to [`Predict::predict`]; outputs are already `F`.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
886
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // ---------------- Classifier tests ----------------

    /// A linearly separable two-class problem should be fit and predicted
    /// perfectly.
    #[test]
    fn test_forest_classifier_simple() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for i in 0..4 {
            assert_eq!(preds[i], 0);
        }
        for i in 4..8 {
            assert_eq!(preds[i], 1);
        }
    }

    /// Two fits with the same `random_state` must yield identical predictions.
    #[test]
    fn test_forest_classifier_reproducibility() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);

        let fitted1 = model.fit(&x, &y).unwrap();
        let fitted2 = model.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        assert_eq!(preds1, preds2);
    }

    /// Only feature 0 is informative; it must dominate the importances.
    #[test]
    fn test_forest_classifier_feature_importances() {
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 3);
        assert!(importances[0] > importances[1]);
        assert!(importances[0] > importances[2]);
    }

    /// The `HasClasses` introspection trait reports sorted unique labels.
    #[test]
    fn test_forest_classifier_has_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1, 2, 0, 1, 2];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 1, 2]);
        assert_eq!(fitted.n_classes(), 3);
    }

    /// Fitting with mismatched x/y lengths must fail.
    #[test]
    fn test_forest_classifier_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1];

        let model = RandomForestClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    /// Predicting with the wrong column count must fail.
    #[test]
    fn test_forest_classifier_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    /// Fitting with zero samples must fail.
    #[test]
    fn test_forest_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);

        let model = RandomForestClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    /// `n_estimators == 0` is an invalid parameter.
    #[test]
    fn test_forest_classifier_zero_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = RandomForestClassifier::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    /// A single-tree forest degenerates to one decision tree and still works.
    #[test]
    fn test_forest_classifier_single_tree() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(1)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 6);
    }

    /// Classifier works through the float-label pipeline interface.
    #[test]
    fn test_forest_classifier_pipeline_integration() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    /// Depth-limited trees (stumps) still produce predictions.
    #[test]
    fn test_forest_classifier_max_depth() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = RandomForestClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_max_depth(Some(1))
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
    }

    // ---------------- Regressor tests ----------------

    /// A step function should be approximated: low half near 1, high near 5.
    #[test]
    fn test_forest_regressor_simple() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(50)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for i in 0..4 {
            assert!(preds[i] < 3.0, "Expected ~1.0, got {}", preds[i]);
        }
        for i in 4..8 {
            assert!(preds[i] > 3.0, "Expected ~5.0, got {}", preds[i]);
        }
    }

    /// Two fits with the same seed must agree to floating-point precision.
    #[test]
    fn test_forest_regressor_reproducibility() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);

        let fitted1 = model.fit(&x, &y).unwrap();
        let fitted2 = model.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for (p1, p2) in preds1.iter().zip(preds2.iter()) {
            assert_relative_eq!(*p1, *p2, epsilon = 1e-10);
        }
    }

    /// Only feature 0 is informative; it must dominate the importances.
    #[test]
    fn test_forest_regressor_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 2);
        assert!(importances[0] > importances[1]);
    }

    /// Fitting with mismatched x/y lengths must fail.
    #[test]
    fn test_forest_regressor_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = RandomForestRegressor::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    /// Predicting with the wrong column count must fail.
    #[test]
    fn test_forest_regressor_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    /// Fitting with zero samples must fail.
    #[test]
    fn test_forest_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = RandomForestRegressor::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }

    /// `n_estimators == 0` is an invalid parameter.
    #[test]
    fn test_forest_regressor_zero_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = RandomForestRegressor::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }

    /// Regressor works through the pipeline interface (no adapter needed).
    #[test]
    fn test_forest_regressor_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = RandomForestRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    /// Every `MaxFeatures` strategy must fit and predict without error.
    #[test]
    fn test_forest_regressor_max_features_strategies() {
        let x = Array2::from_shape_vec(
            (8, 4),
            vec![
                1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0, 3.0, 4.0, 5.0, 6.0, 4.0, 5.0, 6.0, 7.0,
                5.0, 6.0, 7.0, 8.0, 6.0, 7.0, 8.0, 9.0, 7.0, 8.0, 9.0, 10.0, 8.0, 9.0, 10.0, 11.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        for strategy in &[
            MaxFeatures::Sqrt,
            MaxFeatures::Log2,
            MaxFeatures::All,
            MaxFeatures::Fixed(2),
            MaxFeatures::Fraction(0.5),
        ] {
            let model = RandomForestRegressor::<f64>::new()
                .with_n_estimators(5)
                .with_max_features(*strategy)
                .with_random_state(42);
            let fitted = model.fit(&x, &y).unwrap();
            let preds = fitted.predict(&x).unwrap();
            assert_eq!(preds.len(), 8);
        }
    }

    // ---------------- resolve_max_features tests ----------------

    #[test]
    fn test_resolve_max_features_sqrt() {
        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 9), 3);
        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 10), 4);
        assert_eq!(resolve_max_features(MaxFeatures::Sqrt, 1), 1);
    }

    #[test]
    fn test_resolve_max_features_log2() {
        assert_eq!(resolve_max_features(MaxFeatures::Log2, 8), 3);
        assert_eq!(resolve_max_features(MaxFeatures::Log2, 1), 1);
    }

    #[test]
    fn test_resolve_max_features_all() {
        assert_eq!(resolve_max_features(MaxFeatures::All, 10), 10);
        assert_eq!(resolve_max_features(MaxFeatures::All, 1), 1);
    }

    #[test]
    fn test_resolve_max_features_fixed() {
        // Fixed counts are capped at the available feature count.
        assert_eq!(resolve_max_features(MaxFeatures::Fixed(3), 10), 3);
        assert_eq!(resolve_max_features(MaxFeatures::Fixed(20), 10), 10);
    }

    #[test]
    fn test_resolve_max_features_fraction() {
        assert_eq!(resolve_max_features(MaxFeatures::Fraction(0.5), 10), 5);
        assert_eq!(resolve_max_features(MaxFeatures::Fraction(0.1), 10), 1);
    }

    // ---------------- f32 support ----------------

    #[test]
    fn test_forest_classifier_f32_support() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let model = RandomForestClassifier::<f32>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_forest_regressor_f32_support() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);

        let model = RandomForestRegressor::<f32>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
}