use super::core::{RandomForest, RandomForestOptions};
use crate::criterion::RegressionCriterion;
use crate::functions;
use crate::table::Table;
use std::io::{Read, Write};
use std::num::NonZeroUsize;
/// Builder-style options for configuring and fitting a regression random forest.
///
/// Construct with [`RandomForestRegressorOptions::new`], chain the setter
/// methods, then call `fit` to obtain a `RandomForestRegressor`.
#[derive(Debug, Clone, Default)]
pub struct RandomForestRegressorOptions {
    // Option set shared with the rest of the forest core; this wrapper only
    // exposes the entry points relevant to regression.
    inner: RandomForestOptions,
}
impl RandomForestRegressorOptions {
    /// Makes a new options instance populated with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the seed of the random number generator used while fitting.
    pub fn seed(&mut self, value: u64) -> &mut Self {
        self.inner.seed(value);
        self
    }

    /// Sets how many decision trees the forest will contain.
    pub fn trees(&mut self, count: NonZeroUsize) -> &mut Self {
        self.inner.trees(count);
        self
    }

    /// Sets the maximum number of candidate features considered per split.
    pub fn max_features(&mut self, limit: NonZeroUsize) -> &mut Self {
        self.inner.max_features(limit);
        self
    }

    /// Sets the maximum number of rows sampled when building each tree.
    pub fn max_samples(&mut self, limit: NonZeroUsize) -> &mut Self {
        self.inner.max_samples(limit);
        self
    }

    /// Enables multi-threaded fitting.
    pub fn parallel(&mut self) -> &mut Self {
        self.inner.parallel();
        self
    }

    /// Fits a forest to `table` with the given split criterion and wraps it
    /// as a `RandomForestRegressor`.
    pub fn fit<T: RegressionCriterion>(&self, criterion: T, table: Table) -> RandomForestRegressor {
        // The boolean flag switches the shared core into its regression mode
        // (presumably an `is_regression` parameter) — confirm in `core`.
        let inner = self.inner.fit(criterion, true, table);
        RandomForestRegressor { inner }
    }
}
/// Random forest regressor: an ensemble of decision trees whose per-tree
/// predictions are averaged to produce the final value (see `predict`).
#[derive(Debug)]
pub struct RandomForestRegressor {
    // Shared forest implementation; the regression-specific aggregation
    // (taking the mean) lives in this wrapper's methods.
    inner: RandomForest,
}
impl RandomForestRegressor {
    /// Fits a regressor to `table` using the default options.
    pub fn fit<T: RegressionCriterion>(criterion: T, table: Table) -> Self {
        RandomForestRegressorOptions::default().fit(criterion, table)
    }

    /// Predicts the target value for `features` by averaging the
    /// predictions of every tree in the forest.
    pub fn predict(&self, features: &[f64]) -> f64 {
        let per_tree = self.inner.predict(features);
        functions::mean(per_tree)
    }

    /// Returns an iterator over each individual tree's prediction for
    /// `features`, without aggregating them.
    pub fn predict_individuals<'a>(
        &'a self,
        features: &'a [f64],
    ) -> impl 'a + Iterator<Item = f64> {
        self.inner.predict(features)
    }

    /// Writes this regressor to `writer` in the forest's binary format.
    pub fn serialize<W: Write>(&self, writer: W) -> std::io::Result<()> {
        self.inner.serialize(writer)
    }

    /// Reads a regressor previously written by
    /// [`RandomForestRegressor::serialize`].
    pub fn deserialize<R: Read>(reader: R) -> std::io::Result<Self> {
        RandomForest::deserialize(reader).map(|inner| Self { inner })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::criterion::Mse;
    use crate::table::TableBuilder;

    // End-to-end check: fit on a small hand-written dataset, then verify
    // prediction values, parallel/sequential agreement, and a
    // serialize/deserialize round trip.
    #[test]
    fn regression_works() -> Result<(), anyhow::Error> {
        // 14 rows of 4 small-integer-valued features; the last two rows are
        // held out as the prediction inputs.
        let features = [
            &[0.0, 2.0, 1.0, 0.0][..],
            &[0.0, 2.0, 1.0, 1.0][..],
            &[1.0, 2.0, 1.0, 0.0][..],
            &[2.0, 1.0, 1.0, 0.0][..],
            &[2.0, 0.0, 0.0, 0.0][..],
            &[2.0, 0.0, 0.0, 1.0][..],
            &[1.0, 0.0, 0.0, 1.0][..],
            &[0.0, 1.0, 1.0, 0.0][..],
            &[0.0, 0.0, 0.0, 0.0][..],
            &[2.0, 1.0, 0.0, 0.0][..],
            &[0.0, 1.0, 0.0, 1.0][..],
            &[1.0, 1.0, 1.0, 1.0][..],
            &[1.0, 2.0, 0.0, 0.0][..],
            &[2.0, 1.0, 1.0, 1.0][..],
        ];
        let target = [
            25.0, 30.0, 46.0, 45.0, 52.0, 23.0, 43.0, 35.0, 38.0, 46.0, 48.0, 52.0, 44.0, 30.0,
        ];
        // Train on all but the final two rows.
        let train_len = target.len() - 2;
        let mut table_builder = TableBuilder::new();
        for (xs, y) in features.iter().zip(target.iter()).take(train_len) {
            table_builder.add_row(xs, *y)?;
        }
        let table = table_builder.build()?;
        let regressor = RandomForestRegressorOptions::new()
            .seed(0)
            .fit(Mse, table.clone());
        // These exact values are tied to seed 0 and the current fitting
        // internals; any change to the RNG usage or tree construction order
        // is expected to move them.
        assert_eq!(regressor.predict(&features[train_len]), 41.9785);
        assert_eq!(
            regressor.predict(&features[train_len + 1]),
            43.50333333333333
        );
        // With the same seed, parallel fitting must reproduce the
        // sequential result exactly.
        let regressor_parallel = RandomForestRegressorOptions::new()
            .seed(0)
            .parallel()
            .fit(Mse, table);
        assert_eq!(
            regressor.predict(&features[train_len]),
            regressor_parallel.predict(&features[train_len])
        );
        // Round trip through the binary format must preserve predictions.
        let mut bytes = Vec::new();
        regressor.serialize(&mut bytes)?;
        let regressor_deserialized = RandomForestRegressor::deserialize(&mut &bytes[..])?;
        assert_eq!(
            regressor.predict(&features[train_len]),
            regressor_deserialized.predict(&features[train_len])
        );
        Ok(())
    }
}