use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2};
use num_traits::{Float, FromPrimitive, ToPrimitive};
use rand::SeedableRng;
use rand::rngs::StdRng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use crate::decision_tree::{
ClassificationCriterion, Node, TreeParams, compute_feature_importances, traverse,
};
use crate::extra_tree::{
build_extra_classification_tree_for_ensemble, build_extra_regression_tree_for_ensemble,
};
use crate::random_forest::MaxFeatures;
/// Translate a `MaxFeatures` strategy into a concrete per-split feature
/// count, clamped into `1..=n_features`.
fn resolve_max_features(strategy: MaxFeatures, n_features: usize) -> usize {
    let n = n_features as f64;
    let requested = match strategy {
        MaxFeatures::All => n_features,
        MaxFeatures::Fixed(k) => k.min(n_features),
        MaxFeatures::Fraction(frac) => (n * frac).ceil() as usize,
        MaxFeatures::Sqrt => n.sqrt().ceil() as usize,
        MaxFeatures::Log2 => n.log2().ceil().max(1.0) as usize,
    };
    // Lower bound first, then cap at the available feature count.
    requested.max(1).min(n_features)
}
/// Bundle the shared stopping criteria into a `TreeParams` value.
fn make_tree_params(
    max_depth: Option<usize>,
    min_samples_split: usize,
    min_samples_leaf: usize,
) -> TreeParams {
    // Field-shorthand construction; nothing else to compute.
    TreeParams { max_depth, min_samples_split, min_samples_leaf }
}
/// Extremely-randomized-trees classifier configuration (unfitted estimator).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreesClassifier<F> {
    /// Number of trees in the ensemble.
    pub n_estimators: usize,
    /// Maximum tree depth; `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required at a leaf.
    pub min_samples_leaf: usize,
    /// Per-split feature subsampling strategy.
    pub max_features: MaxFeatures,
    /// When true, each tree trains on a bootstrap resample of the rows.
    pub bootstrap: bool,
    /// Split-quality criterion (e.g. `Gini`, `Entropy`).
    pub criterion: ClassificationCriterion,
    /// Optional RNG seed; when set, fits are deterministic.
    pub random_state: Option<u64>,
    /// Optional thread count for a dedicated rayon pool during fit.
    pub n_jobs: Option<usize>,
    // Ties the estimator to its float type without storing any data.
    _marker: std::marker::PhantomData<F>,
}
impl<F: Float> ExtraTreesClassifier<F> {
    /// Create a classifier with default settings: 100 trees, unlimited depth,
    /// gini criterion, sqrt feature subsampling, no bootstrap, no fixed seed.
    #[must_use]
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: MaxFeatures::Sqrt,
            bootstrap: false,
            criterion: ClassificationCriterion::Gini,
            random_state: None,
            n_jobs: None,
            _marker: std::marker::PhantomData,
        }
    }
    /// Set the number of trees in the ensemble.
    #[must_use]
    pub fn with_n_estimators(self, n_estimators: usize) -> Self {
        Self { n_estimators, ..self }
    }
    /// Set the maximum tree depth (`None` = unlimited).
    #[must_use]
    pub fn with_max_depth(self, max_depth: Option<usize>) -> Self {
        Self { max_depth, ..self }
    }
    /// Set the minimum sample count required to split a node.
    #[must_use]
    pub fn with_min_samples_split(self, min_samples_split: usize) -> Self {
        Self { min_samples_split, ..self }
    }
    /// Set the minimum sample count required at a leaf.
    #[must_use]
    pub fn with_min_samples_leaf(self, min_samples_leaf: usize) -> Self {
        Self { min_samples_leaf, ..self }
    }
    /// Set the per-split feature subsampling strategy.
    #[must_use]
    pub fn with_max_features(self, max_features: MaxFeatures) -> Self {
        Self { max_features, ..self }
    }
    /// Enable or disable bootstrap resampling of training rows.
    #[must_use]
    pub fn with_bootstrap(self, bootstrap: bool) -> Self {
        Self { bootstrap, ..self }
    }
    /// Set the split-quality criterion.
    #[must_use]
    pub fn with_criterion(self, criterion: ClassificationCriterion) -> Self {
        Self { criterion, ..self }
    }
    /// Fix the RNG seed for reproducible fits.
    #[must_use]
    pub fn with_random_state(self, seed: u64) -> Self {
        Self { random_state: Some(seed), ..self }
    }
    /// Build trees inside a dedicated rayon pool with this many threads.
    #[must_use]
    pub fn with_n_jobs(self, n_jobs: usize) -> Self {
        Self { n_jobs: Some(n_jobs), ..self }
    }
}
impl<F: Float> Default for ExtraTreesClassifier<F> {
    /// Equivalent to [`ExtraTreesClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A fitted extra-trees classifier, produced by fitting [`ExtraTreesClassifier`].
#[derive(Debug, Clone)]
pub struct FittedExtraTreesClassifier<F> {
    /// Node arenas, one `Vec<Node<F>>` per tree.
    trees: Vec<Vec<Node<F>>>,
    /// Original class labels, sorted ascending; position = internal class index.
    classes: Vec<usize>,
    /// Feature count seen at fit time; predictions are validated against it.
    n_features: usize,
    /// Impurity-based importances, normalized to sum to 1 when non-zero.
    feature_importances: Array1<F>,
}
impl<F: Float + Send + Sync + 'static> FittedExtraTreesClassifier<F> {
    /// Borrow the fitted trees (one node arena per tree).
    #[must_use]
    pub fn trees(&self) -> &[Vec<Node<F>>] {
        self.trees.as_slice()
    }
    /// Number of features the model was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }
    /// Number of trees in the ensemble.
    #[must_use]
    pub fn n_estimators(&self) -> usize {
        self.trees.len()
    }
    /// Class-probability estimates: per-leaf class distributions averaged
    /// over all trees (leaves without a stored distribution contribute zero).
    ///
    /// # Errors
    /// `ShapeMismatch` when `x` has a different column count than the
    /// training data.
    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        if self.n_features != x.ncols() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_classes = self.classes.len();
        let n_trees_f = F::from(self.trees.len()).unwrap();
        let mut proba = Array2::zeros((x.nrows(), n_classes));
        for (i, row) in x.rows().into_iter().enumerate() {
            for tree_nodes in &self.trees {
                let leaf_idx = traverse(tree_nodes, &row);
                if let Node::Leaf {
                    class_distribution: Some(dist),
                    ..
                } = &tree_nodes[leaf_idx]
                {
                    for (j, &p) in dist.iter().enumerate() {
                        proba[[i, j]] = proba[[i, j]] + p;
                    }
                }
            }
            // Divide (not multiply by a reciprocal) so results stay
            // bit-identical to a straightforward mean.
            for j in 0..n_classes {
                proba[[i, j]] = proba[[i, j]] / n_trees_f;
            }
        }
        Ok(proba)
    }
    /// Mean accuracy of `predict(x)` against `y`.
    ///
    /// # Errors
    /// `ShapeMismatch` when `y.len() != x.nrows()`; otherwise any error
    /// from `predict`.
    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
        if y.len() != x.nrows() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows()],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        self.predict(x).map(|preds| crate::mean_accuracy(&preds, y))
    }
    /// Natural log of `predict_proba`.
    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        self.predict_proba(x).map(|proba| crate::log_proba(&proba))
    }
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ExtraTreesClassifier<F> {
    type Fitted = FittedExtraTreesClassifier<F>;
    type Error = FerroError;
    /// Fit the ensemble on `x` (samples x features) and labels `y`.
    ///
    /// Labels may be arbitrary `usize` values; they are remapped to the
    /// compact range `0..n_classes` for tree building and mapped back at
    /// prediction time.
    ///
    /// # Errors
    /// - `ShapeMismatch` when `y.len() != x.nrows()`
    /// - `InsufficientSamples` when `x` has zero rows
    /// - `InvalidParameter` when `n_estimators == 0`
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedExtraTreesClassifier<F>, FerroError> {
        let (n_samples, n_features) = x.dim();
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "ExtraTreesClassifier requires at least one sample".into(),
            });
        }
        if self.n_estimators == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_estimators".into(),
                reason: "must be at least 1".into(),
            });
        }
        // Sorted, deduplicated label set; position = internal class index.
        let mut classes: Vec<usize> = y.iter().copied().collect();
        classes.sort_unstable();
        classes.dedup();
        let n_classes = classes.len();
        // Linear-scan remap is O(n_samples * n_classes) — fine for the small
        // class counts typical here.
        let y_mapped: Vec<usize> = y
            .iter()
            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
            .collect();
        let max_features_n = resolve_max_features(self.max_features, n_features);
        let params = make_tree_params(
            self.max_depth,
            self.min_samples_split,
            self.min_samples_leaf,
        );
        let criterion = self.criterion;
        let bootstrap = self.bootstrap;
        // One seed per tree: drawn from a master RNG when `random_state` is
        // set (fully deterministic), otherwise from the thread-local RNG.
        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
            let mut master_rng = StdRng::seed_from_u64(seed);
            (0..self.n_estimators)
                .map(|_| {
                    use rand::RngCore;
                    master_rng.next_u64()
                })
                .collect()
        } else {
            (0..self.n_estimators)
                .map(|_| {
                    use rand::RngCore;
                    rand::rng().next_u64()
                })
                .collect()
        };
        // With explicit `n_jobs`, build trees inside a dedicated rayon pool
        // (falling back to a default-sized pool if construction fails);
        // otherwise use the global rayon pool.
        let trees: Vec<Vec<Node<F>>> = if let Some(n_jobs) = self.n_jobs {
            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads(n_jobs)
                .build()
                .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
            pool.install(|| {
                tree_seeds
                    .par_iter()
                    .map(|&seed| {
                        build_single_classification_tree(
                            x,
                            &y_mapped,
                            n_classes,
                            n_samples,
                            n_features,
                            max_features_n,
                            &params,
                            criterion,
                            bootstrap,
                            seed,
                        )
                    })
                    .collect()
            })
        } else {
            tree_seeds
                .par_iter()
                .map(|&seed| {
                    build_single_classification_tree(
                        x,
                        &y_mapped,
                        n_classes,
                        n_samples,
                        n_features,
                        max_features_n,
                        &params,
                        criterion,
                        bootstrap,
                        seed,
                    )
                })
                .collect()
        };
        // Accumulate per-tree impurity importances, then normalize so they
        // sum to 1 (skipped when the total is zero, e.g. all-stump trees).
        let mut total_importances = Array1::<F>::zeros(n_features);
        for tree_nodes in &trees {
            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
            total_importances = total_importances + tree_imp;
        }
        let imp_sum: F = total_importances
            .iter()
            .copied()
            .fold(F::zero(), |a, b| a + b);
        if imp_sum > F::zero() {
            total_importances.mapv_inplace(|v| v / imp_sum);
        }
        Ok(FittedExtraTreesClassifier {
            trees,
            classes,
            n_features,
            feature_importances: total_importances,
        })
    }
}
/// Build one extremely-randomized classification tree.
///
/// When `bootstrap` is set, the tree trains on `n_samples` row indices drawn
/// with replacement from a per-tree seeded RNG; otherwise it sees every row
/// exactly once.
#[allow(clippy::too_many_arguments)]
fn build_single_classification_tree<F: Float>(
    x: &Array2<F>,
    y_mapped: &[usize],
    n_classes: usize,
    n_samples: usize,
    n_features: usize,
    max_features_n: usize,
    params: &TreeParams,
    criterion: ClassificationCriterion,
    bootstrap: bool,
    seed: u64,
) -> Vec<Node<F>> {
    let mut rng = StdRng::seed_from_u64(seed);
    let indices: Vec<usize> = if bootstrap {
        // Fix: `next_u64() as usize % n_samples` is modulo-biased;
        // `random_range` samples uniformly over `0..n_samples`.
        use rand::Rng;
        (0..n_samples)
            .map(|_| rng.random_range(0..n_samples))
            .collect()
    } else {
        (0..n_samples).collect()
    };
    build_extra_classification_tree_for_ensemble(
        x,
        y_mapped,
        n_classes,
        &indices,
        None,
        params,
        criterion,
        n_features,
        max_features_n,
        &mut rng,
    )
}
impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreesClassifier<F> {
    type Output = Array1<usize>;
    type Error = FerroError;
    /// Majority vote over all trees, mapped back to the original class labels.
    ///
    /// # Errors
    /// `ShapeMismatch` when `x` has a different column count than the
    /// training data.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        if self.n_features != x.ncols() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_classes = self.classes.len();
        let mut predictions = Array1::zeros(x.nrows());
        for (i, row) in x.rows().into_iter().enumerate() {
            let mut votes = vec![0usize; n_classes];
            for tree_nodes in &self.trees {
                let leaf_idx = traverse(tree_nodes, &row);
                if let Node::Leaf { value, .. } = &tree_nodes[leaf_idx] {
                    // The leaf `value` encodes the mapped class index as a float.
                    let class_idx = value.to_f64().map_or(0, |f| f.round() as usize);
                    if class_idx < n_classes {
                        votes[class_idx] += 1;
                    }
                }
            }
            // Last maximum wins on ties — same tie-breaking as
            // `Iterator::max_by_key`, which the original relied on.
            let mut winner = 0usize;
            for (idx, &count) in votes.iter().enumerate() {
                if count >= votes[winner] {
                    winner = idx;
                }
            }
            predictions[i] = self.classes[winner];
        }
        Ok(predictions)
    }
}
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreesClassifier<F> {
    /// Impurity-based importances, normalized to sum to 1 when non-zero.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
impl<F: Float + Send + Sync + 'static> HasClasses for FittedExtraTreesClassifier<F> {
    /// Original class labels, sorted ascending.
    fn classes(&self) -> &[usize] {
        &self.classes
    }
    /// Number of distinct classes seen during fit.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for ExtraTreesClassifier<F>
{
    /// Adapt the classifier to the float-label pipeline interface by casting
    /// `y` to `usize` labels before fitting.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        // NOTE(review): labels that fail the usize conversion (negative,
        // non-finite) silently become class 0 — confirm this is intended
        // rather than an InvalidParameter error.
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedExtraTreesClassifierPipelineAdapter(fitted)))
    }
}
/// Newtype adapter exposing the fitted classifier through the pipeline trait,
/// converting `usize` class predictions back to `F`.
struct FittedExtraTreesClassifierPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedExtraTreesClassifier<F>,
);
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedExtraTreesClassifierPipelineAdapter<F>
{
    /// Predict class labels and cast them to `F`.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        // A label that cannot be represented in F becomes NaN.
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
    }
}
/// Extremely-randomized-trees regressor configuration (unfitted estimator).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtraTreesRegressor<F> {
    /// Number of trees in the ensemble.
    pub n_estimators: usize,
    /// Maximum tree depth; `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required at a leaf.
    pub min_samples_leaf: usize,
    /// Per-split feature subsampling strategy.
    pub max_features: MaxFeatures,
    /// When true, each tree trains on a bootstrap resample of the rows.
    pub bootstrap: bool,
    /// Optional RNG seed; when set, fits are deterministic.
    pub random_state: Option<u64>,
    /// Optional thread count for a dedicated rayon pool during fit.
    pub n_jobs: Option<usize>,
    // Ties the estimator to its float type without storing any data.
    _marker: std::marker::PhantomData<F>,
}
impl<F: Float> ExtraTreesRegressor<F> {
    /// Create a regressor with default settings: 100 trees, unlimited depth,
    /// all features per split, no bootstrap, no fixed seed.
    #[must_use]
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: MaxFeatures::All,
            bootstrap: false,
            random_state: None,
            n_jobs: None,
            _marker: std::marker::PhantomData,
        }
    }
    /// Set the number of trees in the ensemble.
    #[must_use]
    pub fn with_n_estimators(self, n_estimators: usize) -> Self {
        Self { n_estimators, ..self }
    }
    /// Set the maximum tree depth (`None` = unlimited).
    #[must_use]
    pub fn with_max_depth(self, max_depth: Option<usize>) -> Self {
        Self { max_depth, ..self }
    }
    /// Set the minimum sample count required to split a node.
    #[must_use]
    pub fn with_min_samples_split(self, min_samples_split: usize) -> Self {
        Self { min_samples_split, ..self }
    }
    /// Set the minimum sample count required at a leaf.
    #[must_use]
    pub fn with_min_samples_leaf(self, min_samples_leaf: usize) -> Self {
        Self { min_samples_leaf, ..self }
    }
    /// Set the per-split feature subsampling strategy.
    #[must_use]
    pub fn with_max_features(self, max_features: MaxFeatures) -> Self {
        Self { max_features, ..self }
    }
    /// Enable or disable bootstrap resampling of training rows.
    #[must_use]
    pub fn with_bootstrap(self, bootstrap: bool) -> Self {
        Self { bootstrap, ..self }
    }
    /// Fix the RNG seed for reproducible fits.
    #[must_use]
    pub fn with_random_state(self, seed: u64) -> Self {
        Self { random_state: Some(seed), ..self }
    }
    /// Build trees inside a dedicated rayon pool with this many threads.
    #[must_use]
    pub fn with_n_jobs(self, n_jobs: usize) -> Self {
        Self { n_jobs: Some(n_jobs), ..self }
    }
}
impl<F: Float> Default for ExtraTreesRegressor<F> {
    /// Equivalent to [`ExtraTreesRegressor::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A fitted extra-trees regressor, produced by fitting [`ExtraTreesRegressor`].
#[derive(Debug, Clone)]
pub struct FittedExtraTreesRegressor<F> {
    /// Node arenas, one `Vec<Node<F>>` per tree.
    trees: Vec<Vec<Node<F>>>,
    /// Feature count seen at fit time; predictions are validated against it.
    n_features: usize,
    /// Impurity-based importances, normalized to sum to 1 when non-zero.
    feature_importances: Array1<F>,
}
impl<F: Float + Send + Sync + 'static> FittedExtraTreesRegressor<F> {
    /// Borrow the fitted trees (one node arena per tree).
    #[must_use]
    pub fn trees(&self) -> &[Vec<Node<F>>] {
        self.trees.as_slice()
    }
    /// Number of features the model was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }
    /// Number of trees in the ensemble.
    #[must_use]
    pub fn n_estimators(&self) -> usize {
        self.trees.len()
    }
    /// Coefficient of determination (R²) of `predict(x)` against `y`.
    ///
    /// # Errors
    /// `ShapeMismatch` when `y.len() != x.nrows()`; otherwise any error
    /// from `predict`.
    pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<F, FerroError> {
        if y.len() != x.nrows() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows()],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        self.predict(x).map(|preds| crate::r2_score(&preds, y))
    }
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>> for ExtraTreesRegressor<F> {
    type Fitted = FittedExtraTreesRegressor<F>;
    type Error = FerroError;
    /// Fit the ensemble on `x` (samples x features) and targets `y`.
    ///
    /// # Errors
    /// - `ShapeMismatch` when `y.len() != x.nrows()`
    /// - `InsufficientSamples` when `x` has zero rows
    /// - `InvalidParameter` when `n_estimators == 0`
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedExtraTreesRegressor<F>, FerroError> {
        let (n_samples, n_features) = x.dim();
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "ExtraTreesRegressor requires at least one sample".into(),
            });
        }
        if self.n_estimators == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_estimators".into(),
                reason: "must be at least 1".into(),
            });
        }
        let max_features_n = resolve_max_features(self.max_features, n_features);
        let params = make_tree_params(
            self.max_depth,
            self.min_samples_split,
            self.min_samples_leaf,
        );
        let bootstrap = self.bootstrap;
        // One seed per tree: drawn from a master RNG when `random_state` is
        // set (fully deterministic), otherwise from the thread-local RNG.
        let tree_seeds: Vec<u64> = if let Some(seed) = self.random_state {
            let mut master_rng = StdRng::seed_from_u64(seed);
            (0..self.n_estimators)
                .map(|_| {
                    use rand::RngCore;
                    master_rng.next_u64()
                })
                .collect()
        } else {
            (0..self.n_estimators)
                .map(|_| {
                    use rand::RngCore;
                    rand::rng().next_u64()
                })
                .collect()
        };
        // With explicit `n_jobs`, build trees inside a dedicated rayon pool
        // (falling back to a default-sized pool if construction fails);
        // otherwise use the global rayon pool.
        let trees: Vec<Vec<Node<F>>> = if let Some(n_jobs) = self.n_jobs {
            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads(n_jobs)
                .build()
                .unwrap_or_else(|_| rayon::ThreadPoolBuilder::new().build().unwrap());
            pool.install(|| {
                tree_seeds
                    .par_iter()
                    .map(|&seed| {
                        build_single_regression_tree(
                            x,
                            y,
                            n_samples,
                            n_features,
                            max_features_n,
                            &params,
                            bootstrap,
                            seed,
                        )
                    })
                    .collect()
            })
        } else {
            tree_seeds
                .par_iter()
                .map(|&seed| {
                    build_single_regression_tree(
                        x,
                        y,
                        n_samples,
                        n_features,
                        max_features_n,
                        &params,
                        bootstrap,
                        seed,
                    )
                })
                .collect()
        };
        // Accumulate per-tree impurity importances, then normalize so they
        // sum to 1 (skipped when the total is zero, e.g. all-stump trees).
        let mut total_importances = Array1::<F>::zeros(n_features);
        for tree_nodes in &trees {
            let tree_imp = compute_feature_importances(tree_nodes, n_features, n_samples);
            total_importances = total_importances + tree_imp;
        }
        let imp_sum: F = total_importances
            .iter()
            .copied()
            .fold(F::zero(), |a, b| a + b);
        if imp_sum > F::zero() {
            total_importances.mapv_inplace(|v| v / imp_sum);
        }
        Ok(FittedExtraTreesRegressor {
            trees,
            n_features,
            feature_importances: total_importances,
        })
    }
}
/// Build one extremely-randomized regression tree.
///
/// When `bootstrap` is set, the tree trains on `n_samples` row indices drawn
/// with replacement from a per-tree seeded RNG; otherwise it sees every row
/// exactly once.
#[allow(clippy::too_many_arguments)]
fn build_single_regression_tree<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    n_samples: usize,
    n_features: usize,
    max_features_n: usize,
    params: &TreeParams,
    bootstrap: bool,
    seed: u64,
) -> Vec<Node<F>> {
    let mut rng = StdRng::seed_from_u64(seed);
    let indices: Vec<usize> = if bootstrap {
        // Fix: `next_u64() as usize % n_samples` is modulo-biased;
        // `random_range` samples uniformly over `0..n_samples`.
        use rand::Rng;
        (0..n_samples)
            .map(|_| rng.random_range(0..n_samples))
            .collect()
    } else {
        (0..n_samples).collect()
    };
    build_extra_regression_tree_for_ensemble(
        x,
        y,
        &indices,
        None,
        params,
        n_features,
        max_features_n,
        &mut rng,
    )
}
impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedExtraTreesRegressor<F> {
    type Output = Array1<F>;
    type Error = FerroError;
    /// Mean of the per-tree leaf predictions for each row.
    ///
    /// # Errors
    /// `ShapeMismatch` when `x` has a different column count than the
    /// training data.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        if self.n_features != x.ncols() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }
        let n_trees_f = F::from(self.trees.len()).unwrap();
        let mut predictions = Array1::zeros(x.nrows());
        for (i, row) in x.rows().into_iter().enumerate() {
            // Fold over trees in order — same accumulation order as a plain
            // loop, so float results are identical.
            let total = self.trees.iter().fold(F::zero(), |acc, tree_nodes| {
                let leaf_idx = traverse(tree_nodes, &row);
                if let Node::Leaf { value, .. } = tree_nodes[leaf_idx] {
                    acc + value
                } else {
                    acc
                }
            });
            predictions[i] = total / n_trees_f;
        }
        Ok(predictions)
    }
}
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F> for FittedExtraTreesRegressor<F> {
    /// Impurity-based importances, normalized to sum to 1 when non-zero.
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for ExtraTreesRegressor<F> {
    /// Fit and box the result for use in a pipeline; no label conversion
    /// needed since the regressor already works on `F` targets.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F> for FittedExtraTreesRegressor<F> {
    /// Delegates directly to [`Predict::predict`].
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for ExtraTreesClassifier / ExtraTreesRegressor: happy paths,
    // probability normalization, feature importances, determinism under a
    // fixed seed, builder/default field values, and error cases.
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    // Linearly separable two-class data should be fit perfectly.
    #[test]
    fn test_ensemble_classifier_simple() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds, y);
    }
    // Default (no bootstrap): every tree sees all rows, training fit is exact.
    #[test]
    fn test_ensemble_classifier_no_bootstrap() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        assert!(!model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds, y);
    }
    // With bootstrap resampling we only check the output shape, not exactness.
    #[test]
    fn test_ensemble_classifier_with_bootstrap() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_bootstrap(true)
            .with_random_state(42);
        assert!(model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
    }
    // Averaged leaf distributions must form valid per-row probabilities.
    #[test]
    fn test_ensemble_classifier_predict_proba() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();
        assert_eq!(proba.dim(), (6, 2));
        for i in 0..6 {
            let row_sum = proba.row(i).sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }
    // Feature 0 is informative, feature 1 is constant — importances must
    // be normalized and rank feature 0 higher.
    #[test]
    fn test_ensemble_classifier_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 5.0, 1.0, 6.0, 1.0, 7.0, 1.0, 8.0, 1.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();
        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        assert!(importances[0] > importances[1]);
    }
    // The fitted model reports the configured number of trees.
    #[test]
    fn test_ensemble_classifier_n_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(15)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 15);
    }
    // Non-contiguous labels (0, 3) survive the internal remapping round-trip.
    #[test]
    fn test_ensemble_classifier_classes() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0, 0, 3, 3, 3];
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 3]);
        assert_eq!(fitted.n_classes(), 2);
    }
    // y shorter than x rows -> fit error.
    #[test]
    fn test_ensemble_classifier_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 0];
        let model = ExtraTreesClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }
    // Zero samples -> fit error.
    #[test]
    fn test_ensemble_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);
        let model = ExtraTreesClassifier::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }
    // n_estimators == 0 -> fit error.
    #[test]
    fn test_ensemble_classifier_zero_estimators() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![0, 1];
        let model = ExtraTreesClassifier::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }
    // Same seed, same data -> identical predictions.
    #[test]
    fn test_ensemble_classifier_deterministic() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        let model1 = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);
        let model2 = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(123);
        let preds1 = model1.fit(&x, &y).unwrap().predict(&x).unwrap();
        let preds2 = model2.fit(&x, &y).unwrap().predict(&x).unwrap();
        assert_eq!(preds1, preds2);
    }
    // Predicting with the wrong column count -> error.
    #[test]
    fn test_ensemble_classifier_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }
    // Regression on y = x should land within a loose tolerance on train data.
    #[test]
    fn test_ensemble_regressor_simple() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
        for i in 0..6 {
            assert_relative_eq!(preds[i], y[i], epsilon = 1.0);
        }
    }
    // A constant target must be predicted exactly by every tree.
    #[test]
    fn test_ensemble_regressor_constant_target() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![5.0, 5.0, 5.0, 5.0];
        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        for &p in &preds {
            assert_relative_eq!(p, 5.0, epsilon = 1e-10);
        }
    }
    // Default (no bootstrap) regressor: shape-only smoke test.
    #[test]
    fn test_ensemble_regressor_no_bootstrap() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42);
        assert!(!model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
    // Bootstrap regressor: shape-only smoke test.
    #[test]
    fn test_ensemble_regressor_with_bootstrap() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_bootstrap(true)
            .with_random_state(42);
        assert!(model.bootstrap);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }
    // Feature 0 informative, feature 1 constant — normalized, ranked importances.
    #[test]
    fn test_ensemble_regressor_feature_importances() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_max_features(MaxFeatures::All)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();
        assert_eq!(importances.len(), 2);
        let total: f64 = importances.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
        assert!(importances[0] > importances[1]);
    }
    // The fitted regressor reports the configured number of trees.
    #[test]
    fn test_ensemble_regressor_n_estimators() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(7)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_estimators(), 7);
    }
    // y shorter than x rows -> fit error.
    #[test]
    fn test_ensemble_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0];
        let model = ExtraTreesRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }
    // Zero samples -> fit error.
    #[test]
    fn test_ensemble_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let model = ExtraTreesRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }
    // n_estimators == 0 -> fit error.
    #[test]
    fn test_ensemble_regressor_zero_estimators() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![1.0, 2.0];
        let model = ExtraTreesRegressor::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }
    // Same seed, same data -> bit-identical regression predictions.
    #[test]
    fn test_ensemble_regressor_deterministic() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let model1 = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);
        let model2 = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(99);
        let preds1 = model1.fit(&x, &y).unwrap().predict(&x).unwrap();
        let preds2 = model2.fit(&x, &y).unwrap().predict(&x).unwrap();
        for i in 0..6 {
            assert_relative_eq!(preds1[i], preds2[i], epsilon = 1e-12);
        }
    }
    // Predicting with the wrong column count -> error.
    #[test]
    fn test_ensemble_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let x_wrong = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_wrong).is_err());
    }
    // Every classifier builder method sets exactly its field.
    #[test]
    fn test_ensemble_classifier_builder() {
        let model = ExtraTreesClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_max_depth(Some(5))
            .with_min_samples_split(10)
            .with_min_samples_leaf(3)
            .with_max_features(MaxFeatures::Log2)
            .with_bootstrap(true)
            .with_criterion(ClassificationCriterion::Entropy)
            .with_random_state(42)
            .with_n_jobs(4);
        assert_eq!(model.n_estimators, 50);
        assert_eq!(model.max_depth, Some(5));
        assert_eq!(model.min_samples_split, 10);
        assert_eq!(model.min_samples_leaf, 3);
        assert_eq!(model.max_features, MaxFeatures::Log2);
        assert!(model.bootstrap);
        assert_eq!(model.criterion, ClassificationCriterion::Entropy);
        assert_eq!(model.random_state, Some(42));
        assert_eq!(model.n_jobs, Some(4));
    }
    // Every regressor builder method sets exactly its field.
    #[test]
    fn test_ensemble_regressor_builder() {
        let model = ExtraTreesRegressor::<f64>::new()
            .with_n_estimators(25)
            .with_max_depth(Some(8))
            .with_min_samples_split(5)
            .with_min_samples_leaf(2)
            .with_max_features(MaxFeatures::Fraction(0.5))
            .with_bootstrap(true)
            .with_random_state(99)
            .with_n_jobs(2);
        assert_eq!(model.n_estimators, 25);
        assert_eq!(model.max_depth, Some(8));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.min_samples_leaf, 2);
        assert_eq!(model.max_features, MaxFeatures::Fraction(0.5));
        assert!(model.bootstrap);
        assert_eq!(model.random_state, Some(99));
        assert_eq!(model.n_jobs, Some(2));
    }
    // Default::default() matches new() for the classifier.
    #[test]
    fn test_ensemble_classifier_default() {
        let model = ExtraTreesClassifier::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::Sqrt);
        assert!(!model.bootstrap);
        assert_eq!(model.criterion, ClassificationCriterion::Gini);
        assert_eq!(model.random_state, None);
        assert_eq!(model.n_jobs, None);
    }
    // Default::default() matches new() for the regressor.
    #[test]
    fn test_ensemble_regressor_default() {
        let model = ExtraTreesRegressor::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert_eq!(model.max_depth, None);
        assert_eq!(model.min_samples_split, 2);
        assert_eq!(model.min_samples_leaf, 1);
        assert_eq!(model.max_features, MaxFeatures::All);
        assert!(!model.bootstrap);
        assert_eq!(model.random_state, None);
        assert_eq!(model.n_jobs, None);
    }
}