use crate::functions;
use crate::table::Table;
use rand::Rng;
use rand::seq::IndexedRandom as _;
/// Minimum number of rows a node must hold to be considered for splitting;
/// a node with fewer rows becomes a leaf.
const MIN_SAMPLES_SPLIT: usize = 2;
/// Deepest level at which a node may still be split. The root is built at
/// depth 1 and each child one level deeper; a node whose depth exceeds this
/// value becomes a leaf.
const MAX_DEPTH: usize = 64;
/// Hyper-parameters for [`DecisionTreeRegressor::fit`].
///
/// `Default` leaves every option unset (`max_features: None`).
#[derive(Debug, Clone, Default)]
pub struct DecisionTreeOptions {
/// Number of feature columns randomly sampled as split candidates at each
/// node. `None` means every feature column of the training table. At each
/// node the value is additionally capped at the number of columns that
/// contain no NaN values there.
pub max_features: Option<usize>,
}
/// A single regression tree fitted on a [`Table`]; build one with
/// [`DecisionTreeRegressor::fit`].
#[derive(Debug)]
pub struct DecisionTreeRegressor {
tree: Tree,
}
impl DecisionTreeRegressor {
pub fn fit<R: Rng + ?Sized>(rng: &mut R, table: Table, options: DecisionTreeOptions) -> Self {
let tree = Tree::fit(rng, table, options);
Self { tree }
}
#[cfg(test)]
pub fn predict(&self, xs: &[f64]) -> f64 {
self.tree.predict(xs)
}
pub fn fold<InternalT, InternalF, LeafT, LeafF>(
&self,
internal_init: InternalT,
mut internal_f: InternalF,
leaf_init: LeafT,
mut leaf_f: LeafF,
) -> LeafT
where
InternalF: FnMut(InternalT, &SplitPoint) -> (InternalT, InternalT),
LeafF: FnMut(LeafT, InternalT, f64) -> LeafT,
{
let mut leaf_acc = leaf_init;
let mut stack = vec![(&self.tree.root, internal_init)];
while let Some((node, internal_acc)) = stack.pop() {
match node {
Node::Leaf { value } => {
leaf_acc = leaf_f(leaf_acc, internal_acc, *value);
}
Node::Internal { children } => {
let (acc_l, acc_r) = internal_f(internal_acc, &children.split);
stack.push((&children.left, acc_l));
stack.push((&children.right, acc_r));
}
}
}
leaf_acc
}
}
/// A fitted tree: owns the root [`Node`] from which every prediction starts.
#[derive(Debug)]
pub struct Tree {
root: Node,
}
impl Tree {
    /// Grows a tree from `table`, starting the root at depth 1.
    ///
    /// When `options.max_features` is unset, every feature column of `table`
    /// is a split candidate at each node.
    pub fn fit<R: Rng + ?Sized>(
        rng: &mut R,
        mut table: Table,
        options: DecisionTreeOptions,
    ) -> Self {
        let max_features = match options.max_features {
            Some(limit) => limit,
            None => table.features_len(),
        };
        let mut grower = NodeBuilder { rng, max_features };
        Self {
            root: grower.build(&mut table, 1),
        }
    }

    /// Predicts the target value for one feature row `xs` (indexed by column).
    #[cfg(test)]
    fn predict(&self, xs: &[f64]) -> f64 {
        self.root.predict(xs)
    }
}
/// A node of a fitted tree.
#[derive(Debug)]
pub enum Node {
/// Terminal node; `value` is the mean target of the training rows that reached it.
Leaf { value: f64 },
/// Decision node that routes a sample to one of two subtrees.
Internal { children: Children },
}
impl Node {
    /// Routes the feature row `xs` from this node down to a leaf and returns
    /// that leaf's value.
    ///
    /// At each internal node the row goes left when its value in the split
    /// column is `<= threshold`; otherwise (including NaN) it goes right.
    ///
    /// # Panics
    /// Panics if `xs` is too short for a split column met along the path.
    #[cfg(test)]
    fn predict(&self, xs: &[f64]) -> f64 {
        let mut node = self;
        loop {
            match node {
                Self::Leaf { value } => return *value,
                Self::Internal { children } => {
                    let split = &children.split;
                    node = if xs[split.column] <= split.threshold {
                        &*children.left
                    } else {
                        &*children.right
                    };
                }
            }
        }
    }
}
/// The split rule of an internal node together with its two subtrees.
#[derive(Debug)]
pub struct Children {
split: SplitPoint,
/// Subtree for samples whose `split.column` value is `<= split.threshold`.
left: Box<Node>,
/// Subtree for every other sample (value `> split.threshold`, or NaN).
right: Box<Node>,
}
/// A binary split rule: compare feature `column` against `threshold`.
#[derive(Debug)]
pub struct SplitPoint {
/// Index of the feature column tested by this split.
pub column: usize,
/// Samples whose feature value is `<= threshold` go to the left child.
pub threshold: f64,
}
/// State shared by every recursive node construction during one fit.
#[derive(Debug)]
struct NodeBuilder<R> {
/// Random source used to sample candidate feature columns at each node.
rng: R,
/// Upper bound on the number of candidate feature columns examined per node.
max_features: usize,
}
impl<R: Rng> NodeBuilder<R> {
    /// Recursively grows the subtree for the rows currently held by `table`.
    ///
    /// `depth` is the depth of the node being built; the root is built at depth 1.
    /// The node becomes a [`Node::Leaf`] holding the mean target of its rows when:
    /// * it has fewer than `MIN_SAMPLES_SPLIT` rows,
    /// * `depth` exceeds `MAX_DEPTH`,
    /// * its targets have zero MSE (all rows share one target value), or
    /// * no candidate split exists among the sampled feature columns.
    ///
    /// Otherwise it samples up to `self.max_features` NaN-free feature columns.
    /// It then picks the threshold with the largest weighted MSE reduction; ties
    /// keep the first candidate examined. Finally it recurses into both halves.
    ///
    /// `table`'s rows are re-sorted in place by each candidate column.
    fn build(&mut self, table: &mut Table, depth: usize) -> Node {
        if table.rows_len() < MIN_SAMPLES_SPLIT || depth > MAX_DEPTH {
            let value = functions::mean(table.target());
            return Node::Leaf { value };
        }
        let impurity = functions::mse(table.target());
        // A node whose targets are all equal cannot be improved by any split.
        // Every candidate has zero gain, yet it still beats the `f64::MIN`
        // sentinel below, so the node would keep splitting until another
        // stopping rule fires. That only adds redundant leaves carrying the
        // same value. If rounding leaves the MSE of equal targets slightly
        // positive (or it is NaN), this check is skipped and the node is
        // handled exactly as before.
        if impurity <= 0.0 {
            let value = functions::mean(table.target());
            return Node::Leaf { value };
        }
        // NaN compares false with everything, so ordering/thresholding such a
        // column is ill-defined; those columns are never split on.
        let valid_columns = (0..table.features_len())
            .filter(|&i| !table.column(i).any(|f| f.is_nan()))
            .collect::<Vec<_>>();
        let max_features = std::cmp::min(valid_columns.len(), self.max_features);
        let mut best_split: Option<SplitPoint> = None;
        let mut best_information_gain = f64::MIN;
        for &column in valid_columns.sample(&mut self.rng, max_features) {
            table.sort_rows_by_column(column);
            // `row` is used below as the number of leading (sorted) rows on the
            // left side of the candidate threshold.
            for (row, threshold) in table.thresholds(column) {
                let impurity_l = functions::mse(table.target().take(row));
                let impurity_r = functions::mse(table.target().skip(row));
                let ratio_l = row as f64 / table.rows_len() as f64;
                let ratio_r = 1.0 - ratio_l;
                // Weighted reduction in MSE achieved by this candidate split.
                let information_gain = impurity - (ratio_l * impurity_l + ratio_r * impurity_r);
                if best_information_gain < information_gain {
                    best_information_gain = information_gain;
                    best_split = Some(SplitPoint { column, threshold });
                }
            }
        }
        if let Some(split) = best_split {
            let children = self.build_children(table, split, depth);
            Node::Internal { children }
        } else {
            let value = functions::mean(table.target());
            Node::Leaf { value }
        }
    }

    /// Partitions `table` at `split` and builds both child subtrees one level deeper.
    ///
    /// Rows are re-sorted by `split.column` because the last candidate column
    /// examined in `build` may differ. Assuming `sort_rows_by_column` sorts
    /// ascending, the rows with value `<= split.threshold` then form a prefix of
    /// length `split_row`, matching the `<=` routing used at prediction time.
    fn build_children(&mut self, table: &mut Table, split: SplitPoint, depth: usize) -> Children {
        table.sort_rows_by_column(split.column);
        let split_row = table
            .column(split.column)
            .take_while(|&f| f <= split.threshold)
            .count();
        // NOTE(review): `with_split` is assumed to run the closure on the first
        // `split_row` rows and then on the remainder, returning (left, right) — confirm.
        let (left, right) =
            table.with_split(split_row, |table| Box::new(self.build(table, depth + 1)));
        Children { split, left, right }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fits on all but the last two rows of a small dataset and checks the
    /// predictions made for those two held-out rows.
    #[test]
    fn regression_works() -> Result<(), Box<dyn std::error::Error>> {
        let columns = [
            &[
                0.0, 0.0, 1.0, 2.0, 2.0, 2.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 1.0, 2.0,
            ],
            &[
                2.0, 2.0, 2.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0,
            ],
            &[
                1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0,
            ],
            &[
                0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0,
            ],
            &[
                25.0, 30.0, 46.0, 45.0, 52.0, 23.0, 43.0, 35.0, 38.0, 46.0, 48.0, 52.0, 44.0, 30.0,
            ],
        ];
        let train_len = columns[0].len() - 2;
        let table = Table::new(columns.iter().map(|c| &c[..train_len]).collect())?;
        let regressor = DecisionTreeRegressor::fit(&mut rand::rng(), table, Default::default());

        // Gathers the value of every column at `row` into one input vector.
        let sample_at = |row: usize| -> Vec<f64> {
            columns.iter().map(|column| column[row]).collect()
        };

        assert_eq!(regressor.predict(&sample_at(train_len)), 46.0);
        assert_eq!(regressor.predict(&sample_at(train_len + 1)), 52.0);
        Ok(())
    }
}