use std::default::Default;
use std::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::tree::base_tree_regressor::Splitter;
/// Hyper-parameters for [`ExtraTreesRegressor`].
///
/// Defaults (see the `Default` impl): unlimited depth, `min_samples_leaf = 1`,
/// `min_samples_split = 2`, `n_trees = 10`, `m = None`, `keep_samples = false`,
/// `seed = 0`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct ExtraTreesRegressorParameters {
    /// Maximum depth of each tree; `None` means depth is not limited.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_depth: Option<u16>,
    /// Minimum number of samples required to be at a leaf node.
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_samples_leaf: usize,
    /// Minimum number of samples required to split an internal node.
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_samples_split: usize,
    /// Number of trees in the ensemble.
    #[cfg_attr(feature = "serde", serde(default))]
    pub n_trees: usize,
    /// Number of features to consider at each split; `None` lets the base
    /// forest pick its own default.
    /// NOTE(review): exact semantics live in `BaseForestRegressor` — confirm there.
    #[cfg_attr(feature = "serde", serde(default))]
    pub m: Option<usize>,
    /// Whether each tree keeps the indices of the samples it was trained on.
    /// Required for `predict_oob` to have data to work with — presumably; verify
    /// against `BaseForestRegressor`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub keep_samples: bool,
    /// RNG seed; the same seed yields reproducible forests (see
    /// `test_reproducibility`).
    #[cfg_attr(feature = "serde", serde(default))]
    pub seed: u64,
}
/// Extra-Trees (extremely randomized trees) regressor.
///
/// A thin wrapper over [`BaseForestRegressor`] configured with random split
/// thresholds (`Splitter::Random`) and no bootstrapping — each tree is trained
/// on the full dataset (see `fit`).
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct ExtraTreesRegressor<
    TX: Number + FloatNumber + PartialOrd,
    TY: Number,
    X: Array2<TX>,
    Y: Array1<TY>,
> {
    // `None` until `fit` succeeds (e.g. right after `SupervisedEstimator::new`).
    forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
}
impl ExtraTreesRegressorParameters {
    /// Sets the maximum depth of each tree.
    pub fn with_max_depth(self, max_depth: u16) -> Self {
        Self {
            max_depth: Some(max_depth),
            ..self
        }
    }
    /// Sets the minimum number of samples required at a leaf node.
    pub fn with_min_samples_leaf(self, min_samples_leaf: usize) -> Self {
        Self {
            min_samples_leaf,
            ..self
        }
    }
    /// Sets the minimum number of samples required to split a node.
    pub fn with_min_samples_split(self, min_samples_split: usize) -> Self {
        Self {
            min_samples_split,
            ..self
        }
    }
    /// Sets the number of trees in the ensemble.
    pub fn with_n_trees(self, n_trees: usize) -> Self {
        Self { n_trees, ..self }
    }
    /// Sets the number of features considered at each split.
    pub fn with_m(self, m: usize) -> Self {
        Self { m: Some(m), ..self }
    }
    /// Sets whether each tree keeps its training-sample indices.
    pub fn with_keep_samples(self, keep_samples: bool) -> Self {
        Self {
            keep_samples,
            ..self
        }
    }
    /// Sets the RNG seed for reproducible fits.
    pub fn with_seed(self, seed: u64) -> Self {
        Self { seed, ..self }
    }
}
impl Default for ExtraTreesRegressorParameters {
fn default() -> Self {
ExtraTreesRegressorParameters {
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 10,
m: Option::None,
keep_samples: false,
seed: 0,
}
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    SupervisedEstimator<X, Y, ExtraTreesRegressorParameters> for ExtraTreesRegressor<TX, TY, X, Y>
{
    /// Creates an unfitted regressor; `fit` must be called before predicting.
    fn new() -> Self {
        Self {
            forest_regressor: None,
        }
    }

    /// Trains the regressor — delegates to the inherent [`ExtraTreesRegressor::fit`].
    fn fit(x: &X, y: &Y, parameters: ExtraTreesRegressorParameters) -> Result<Self, Failed> {
        ExtraTreesRegressor::fit(x, y, parameters)
    }
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
Predictor<X, Y> for ExtraTreesRegressor<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
self.predict(x)
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
ExtraTreesRegressor<TX, TY, X, Y>
{
pub fn fit(
x: &X,
y: &Y,
parameters: ExtraTreesRegressorParameters,
) -> Result<ExtraTreesRegressor<TX, TY, X, Y>, Failed> {
let regressor_params = BaseForestRegressorParameters {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
n_trees: parameters.n_trees,
m: parameters.m,
keep_samples: parameters.keep_samples,
seed: parameters.seed,
bootstrap: false,
splitter: Splitter::Random,
};
let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;
Ok(ExtraTreesRegressor {
forest_regressor: Some(forest_regressor),
})
}
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let forest_regressor = self.forest_regressor.as_ref().unwrap();
forest_regressor.predict(x)
}
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
let forest_regressor = self.forest_regressor.as_ref().unwrap();
forest_regressor.predict_oob(x)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;
    use crate::metrics::mean_squared_error;

    // Smoke test: fit on a tiny 2-feature dataset and check predictions are
    // the right length and close to the targets.
    #[test]
    fn test_extra_trees_regressor_fit_predict() {
        let features = DenseMatrix::from_2d_array(&[
            &[1., 2.],
            &[3., 4.],
            &[5., 6.],
            &[7., 8.],
            &[9., 10.],
            &[11., 12.],
            &[13., 14.],
            &[15., 16.],
        ])
        .unwrap();
        let targets = vec![1., 2., 3., 4., 5., 6., 7., 8.];

        let params = ExtraTreesRegressorParameters::default()
            .with_n_trees(100)
            .with_seed(42);
        let model = ExtraTreesRegressor::fit(&features, &targets, params).unwrap();
        let predictions = model.predict(&features).unwrap();

        assert_eq!(predictions.len(), targets.len());
        assert!(mean_squared_error(&targets, &predictions) < 1.0);
    }

    // Same smoke test with 10 features, including constant (zero) columns.
    #[test]
    fn test_fit_predict_higher_dims() {
        let features = DenseMatrix::from_2d_array(&[
            &[0., 0., 10., 5., 8., 1., 4., 9., 2., 7.],
            &[0., 0., 20., 1., 2., 3., 4., 5., 6., 7.],
            &[0., 0., 30., 7., 6., 5., 4., 3., 2., 1.],
            &[0., 0., 40., 9., 2., 4., 6., 8., 1., 3.],
            &[0., 0., 55., 3., 1., 8., 6., 4., 2., 9.],
            &[0., 0., 65., 2., 4., 7., 5., 3., 1., 8.],
        ])
        .unwrap();
        let targets = vec![10., 20., 30., 40., 55., 65.];

        let params = ExtraTreesRegressorParameters::default()
            .with_n_trees(100)
            .with_seed(42);
        let model = ExtraTreesRegressor::fit(&features, &targets, params).unwrap();
        let predictions = model.predict(&features).unwrap();

        assert_eq!(predictions.len(), targets.len());
        assert!(mean_squared_error(&targets, &predictions) < 1.0);
    }

    // Two fits with the same seed must produce identical predictions.
    #[test]
    fn test_reproducibility() {
        let features = DenseMatrix::from_2d_array(&[
            &[1., 2.],
            &[3., 4.],
            &[5., 6.],
            &[7., 8.],
            &[9., 10.],
            &[11., 12.],
        ])
        .unwrap();
        let targets = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let params = ExtraTreesRegressorParameters::default().with_seed(42);

        let first = ExtraTreesRegressor::fit(&features, &targets, params.clone())
            .unwrap()
            .predict(&features)
            .unwrap();
        let second = ExtraTreesRegressor::fit(&features, &targets, params.clone())
            .unwrap()
            .predict(&features)
            .unwrap();

        assert_eq!(first, second);
    }
}