//! randomforest 0.1.6
//!
//! Random forest regressor and classifier.
//! Documentation
use crate::criterion::Criterion;
use crate::functions;
use crate::table::{ColumnType, Table};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use rand::seq::SliceRandom as _;
use rand::Rng;
use std::io::{Read, Write};

const MIN_SAMPLES_SPLIT: usize = 2;
const MAX_DEPTH: usize = 64;

/// Options controlling how a [`DecisionTree`] is fitted.
#[derive(Debug, Clone, Default)]
pub struct DecisionTreeOptions {
    /// Maximum number of candidate feature columns examined per split.
    /// `None` means all features are considered (random forests pass a
    /// subset size here).
    pub max_features: Option<usize>,
    /// If `true`, leaves hold the mean of the targets (regression);
    /// otherwise the most frequent target value (classification).
    pub is_regression: bool,
}

/// A single decision tree (regressor or classifier), represented by
/// its root node.
#[derive(Debug)]
pub struct DecisionTree {
    // Entry point of the fitted tree; the whole model is this node graph.
    root: Node,
}

impl DecisionTree {
    /// Fits a decision tree to `table` using the split `criterion`.
    ///
    /// `options.max_features` bounds the number of candidate columns
    /// considered per split (defaults to every feature), and
    /// `options.is_regression` selects mean vs. most-frequent leaf
    /// values.
    pub fn fit<R: Rng + ?Sized, T: Criterion>(
        rng: &mut R,
        criterion: T,
        mut table: Table,
        options: DecisionTreeOptions,
    ) -> Self {
        let max_features = match options.max_features {
            Some(n) => n,
            None => table.features_len(),
        };
        let mut builder = NodeBuilder {
            rng,
            max_features,
            is_regression: options.is_regression,
            criterion,
        };
        Self {
            root: builder.build(&mut table, 1),
        }
    }

    /// Predicts the target value for the feature row `xs`, using
    /// `columns` to interpret each feature as numerical or categorical.
    pub fn predict(&self, xs: &[f64], columns: &[ColumnType]) -> f64 {
        self.root.predict(xs, columns)
    }

    /// Writes this tree to `writer` in a compact binary format.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
        self.root.serialize(writer)
    }

    /// Reads a tree previously written by [`DecisionTree::serialize`].
    pub fn deserialize<R: Read>(reader: &mut R) -> std::io::Result<Self> {
        Ok(Self {
            root: Node::deserialize(reader)?,
        })
    }
}

/// A node of the decision tree: either a terminal leaf carrying the
/// prediction value, or an internal node with a split rule and two
/// children.
#[derive(Debug)]
pub enum Node {
    /// Terminal node; prediction returns `value`.
    Leaf { value: f64 },
    /// Branch node; samples are routed left or right by `children.split`.
    Internal { children: Children },
}

impl Node {
    /// Walks the tree from this node, following the split rule at each
    /// internal node, and returns the value of the leaf reached by `xs`.
    fn predict(&self, xs: &[f64], columns: &[ColumnType]) -> f64 {
        let mut node = self;
        loop {
            match node {
                Self::Leaf { value } => return *value,
                Self::Internal { children } => {
                    let col = children.split.column;
                    node = if columns[col].is_left(xs[col], children.split.value) {
                        &*children.left
                    } else {
                        &*children.right
                    };
                }
            }
        }
    }

    /// Writes this subtree: a 1-byte tag (0 = leaf, 1 = internal)
    /// followed by the node's payload.
    fn serialize<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
        match self {
            Self::Leaf { value } => {
                writer.write_u8(0)?;
                writer.write_f64::<BigEndian>(*value)
            }
            Self::Internal { children } => {
                writer.write_u8(1)?;
                children.serialize(writer)
            }
        }
    }

    /// Reads a subtree written by [`Node::serialize`]; rejects any
    /// unknown tag byte with an `InvalidData` error.
    fn deserialize<R: Read>(reader: &mut R) -> std::io::Result<Self> {
        match reader.read_u8()? {
            0 => Ok(Self::Leaf {
                value: reader.read_f64::<BigEndian>()?,
            }),
            1 => Ok(Self::Internal {
                children: Children::deserialize(reader)?,
            }),
            v => Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("unknown node type {}", v),
            )),
        }
    }
}

/// The split rule plus the two subtrees of an internal node.
#[derive(Debug)]
pub struct Children {
    // Decision rule that routes a sample to `left` or `right`.
    split: SplitPoint,
    // Subtree for samples on the "left" side of the split.
    left: Box<Node>,
    // Subtree for the remaining samples.
    right: Box<Node>,
}

impl Children {
    /// Serializes the split rule, then the left subtree, then the right.
    fn serialize<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
        self.split.serialize(writer)?;
        self.left.serialize(writer)?;
        self.right.serialize(writer)
    }

    /// Reads a split rule and two subtrees in the order written by
    /// [`Children::serialize`].
    fn deserialize<R: Read>(reader: &mut R) -> std::io::Result<Self> {
        Ok(Self {
            split: SplitPoint::deserialize(reader)?,
            left: Box::new(Node::deserialize(reader)?),
            right: Box::new(Node::deserialize(reader)?),
        })
    }
}

/// A split rule: compare feature `column` of a sample against `value`.
/// The comparison semantics depend on the column's `ColumnType` (see
/// `ColumnType::is_left`).
#[derive(Debug)]
pub struct SplitPoint {
    // Index of the feature column this rule tests.
    column: usize,
    // Threshold for numerical columns; presumably the category code for
    // categorical ones — exact semantics live in `ColumnType::is_left`.
    value: f64,
}

impl SplitPoint {
    /// Writes the split as a big-endian `u16` column index followed by
    /// the big-endian `f64` split value.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` error if the column index does not fit
    /// in a `u16` (the on-disk format caps it at 65535). The previous
    /// `as u16` cast silently truncated larger indices, producing a
    /// stream that deserializes to the wrong column.
    fn serialize<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
        if self.column > usize::from(u16::MAX) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("column index {} does not fit in u16", self.column),
            ));
        }
        writer.write_u16::<BigEndian>(self.column as u16)?;
        writer.write_f64::<BigEndian>(self.value)?;
        Ok(())
    }

    /// Reads a split written by [`SplitPoint::serialize`].
    fn deserialize<R: Read>(reader: &mut R) -> std::io::Result<Self> {
        let column = reader.read_u16::<BigEndian>()? as usize;
        let value = reader.read_f64::<BigEndian>()?;
        Ok(Self { column, value })
    }
}

/// Internal helper that recursively grows the tree during `fit`.
#[derive(Debug)]
struct NodeBuilder<R, T> {
    // Randomness source for sampling candidate columns at each split.
    rng: R,
    // Upper bound on candidate columns examined per split.
    max_features: usize,
    // Selects mean (regression) vs. most-frequent (classification) leaves.
    is_regression: bool,
    // Impurity criterion used to score candidate splits.
    criterion: T,
}

impl<R: Rng, T: Criterion> NodeBuilder<R, T> {
    /// Recursively grows the (sub)tree for the rows currently in `table`.
    ///
    /// Emits a leaf when there are too few rows to split, the depth
    /// limit is reached, or no valid split point exists; otherwise
    /// picks the split with the highest information gain among a
    /// random subset of at most `max_features` NaN-free columns.
    fn build(&mut self, table: &mut Table, depth: usize) -> Node {
        if table.rows_len() < MIN_SAMPLES_SPLIT || depth > MAX_DEPTH {
            let value = self.average(table.target());
            return Node::Leaf { value };
        }

        let impurity = self.criterion.calculate(table.target());
        // Columns containing NaN cannot be ordered, so exclude them
        // from the candidate set.
        let valid_columns = (0..table.features_len())
            .filter(|&i| !table.column(i).any(|f| f.is_nan()))
            .collect::<Vec<_>>();

        let mut best_split: Option<SplitPoint> = None;
        let mut best_information_gain = std::f64::MIN;
        let max_features = std::cmp::min(valid_columns.len(), self.max_features);
        for &column in valid_columns.choose_multiple(&mut self.rng, max_features) {
            table.sort_rows_by_column(column);
            for (left_row, value) in table.split_points(column) {
                // Targets of the rows on each side of the candidate split.
                let rows_l = table.target().take(left_row.end).skip(left_row.start);
                let rows_r = table
                    .target()
                    .take(left_row.start)
                    .chain(table.target().skip(left_row.end));
                let impurity_l = self.criterion.calculate(rows_l);
                let impurity_r = self.criterion.calculate(rows_r);
                let ratio_l = (left_row.end - left_row.start) as f64 / table.rows_len() as f64;
                let ratio_r = 1.0 - ratio_l;

                // Impurity reduction achieved by this split; keep the best.
                let information_gain = impurity - (ratio_l * impurity_l + ratio_r * impurity_r);
                if best_information_gain < information_gain {
                    best_information_gain = information_gain;
                    best_split = Some(SplitPoint { column, value });
                }
            }
        }

        if let Some(split) = best_split {
            let children = self.build_children(table, split, depth);
            Node::Internal { children }
        } else {
            let value = self.average(table.target());
            Node::Leaf { value }
        }
    }

    /// Partitions `table` at `split` and builds the two subtrees.
    fn build_children(&mut self, table: &mut Table, split: SplitPoint, depth: usize) -> Children {
        // Reorder the rows so that all "left" samples precede the rest.
        match table.column_types()[split.column] {
            ColumnType::Categorical => {
                table.sort_rows_by_categorical_column(split.column, split.value);
            }
            ColumnType::Numerical => {
                table.sort_rows_by_column(split.column);
            }
        }

        // Index of the first row that falls on the right-hand side.
        let split_row = table
            .column(split.column)
            .take_while(|&f| table.column_types()[split.column].is_left(f, split.value))
            .count();
        let (left, right) =
            table.with_split(split_row, |table| Box::new(self.build(table, depth + 1)));
        Children { split, left, right }
    }

    /// Leaf value for a set of targets: mean for regression, the most
    /// frequent value for classification.
    fn average(&self, ys: impl Iterator<Item = f64>) -> f64 {
        if self.is_regression {
            functions::mean(ys)
        } else {
            functions::most_frequent(ys)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::criterion::Mse;
    use crate::table::TableBuilder;
    use rand;

    /// Trains on the first 12 rows of a small toy data set and checks
    /// the predictions for the two held-out rows.
    #[test]
    fn regression_works() -> Result<(), anyhow::Error> {
        let features: [&[f64]; 14] = [
            &[0.0, 2.0, 1.0, 0.0],
            &[0.0, 2.0, 1.0, 1.0],
            &[1.0, 2.0, 1.0, 0.0],
            &[2.0, 1.0, 1.0, 0.0],
            &[2.0, 0.0, 0.0, 0.0],
            &[2.0, 0.0, 0.0, 1.0],
            &[1.0, 0.0, 0.0, 1.0],
            &[0.0, 1.0, 1.0, 0.0],
            &[0.0, 0.0, 0.0, 0.0],
            &[2.0, 1.0, 0.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
            &[1.0, 1.0, 1.0, 1.0],
            &[1.0, 2.0, 0.0, 0.0],
            &[2.0, 1.0, 1.0, 1.0],
        ];
        let target = [
            25.0, 30.0, 46.0, 45.0, 52.0, 23.0, 43.0, 35.0, 38.0, 46.0, 48.0, 52.0, 44.0, 30.0,
        ];
        // Hold out the final two rows for prediction checks.
        let train_len = target.len() - 2;

        let mut builder = TableBuilder::new();
        for (xs, &y) in features.iter().zip(target.iter()).take(train_len) {
            builder.add_row(xs, y)?;
        }
        let table = builder.build()?;

        let columns = [
            ColumnType::Numerical,
            ColumnType::Numerical,
            ColumnType::Numerical,
            ColumnType::Numerical,
        ];
        let regressor = DecisionTree::fit(&mut rand::thread_rng(), Mse, table, Default::default());
        assert_eq!(regressor.predict(features[train_len], &columns), 46.0);
        assert_eq!(regressor.predict(features[train_len + 1], &columns), 52.0);

        Ok(())
    }
}