rafor 0.3.0

Fast Random Forest library.
Documentation
use crate::{ClassTarget, SampleWeight};

/// Incremental accumulator of per-class weights used to score splits with the
/// Gini criterion.
///
/// Note that the derived `Default` value has no bins at all; an accumulator
/// sized for a given number of classes is built via `WithClasses::with_classes`.
#[derive(Clone, Debug, Default)]
pub struct Gini {
    /// Accumulated weight per class, indexed by class id.
    bins: Vec<f64>,
    /// Sum of all weights currently held in `bins`.
    total_weight: f64,
    /// Sum of the squared per-class weights, maintained incrementally.
    sum_squares: f64,
}

/// Incremental accumulator of a weighted mean and the weighted sum of squared
/// deviations from that mean, used to score regression splits.
///
/// Derives `Debug` like the sibling `Gini` accumulator so that both metrics can
/// be inspected and embedded in other `Debug` types.
#[derive(Default, Clone, Debug)]
pub struct Mse {
    /// Weighted mean of the targets currently accumulated.
    mean: f64,
    /// Weighted sum of squared deviations of the accumulated targets from `mean`.
    sum_squares: f64,
    /// Sum of the weights of the targets currently accumulated.
    total_weight: f64,
}

/// Incremental impurity accumulator over weighted targets of type `Target`.
///
/// Implementors in this file (`Gini` for class targets, `Mse` for `f32`
/// targets) keep running statistics that are updated one sample at a time, so
/// a sample can be moved between two accumulators with one `pop` and one `push`.
pub trait ImpurityMetric<Target> {
    /// Adds the target `item` with the given `weight` to the accumulator.
    fn push(&mut self, item: Target, weight: SampleWeight);
    /// Removes the target `item` with the given `weight` from the accumulator;
    /// the inverse of [`push`](ImpurityMetric::push).
    fn pop(&mut self, item: Target, weight: SampleWeight);
    /// Reports whether the accumulated targets are pure: `Gini` returns `true`
    /// when at most one class bin is non-zero, `Mse` when its sum of squared
    /// deviations is exactly zero.
    fn pure(&self) -> bool;
    /// Returns the impurity score of a split whose two sides are `self` and
    /// `other`. The scale of the score is implementation-specific.
    fn split_impurity(&self, other: &Self) -> f64;
}

/// Constructor trait for metrics whose internal state is sized by the number
/// of target classes.
pub trait WithClasses {
    /// Creates an instance prepared for `num_classes` distinct classes.
    fn with_classes(num_classes: usize) -> Self;
}

impl ImpurityMetric<ClassTarget> for Gini {
    /// Adds `weight` to the bin of class `bin_index`.
    ///
    /// `sum_squares` is updated in place using
    /// `(b + w)^2 - b^2 = w * (2b + w)`, so the bins are never re-summed.
    ///
    /// # Panics
    ///
    /// Panics if `bin_index` is not a valid index into the bins.
    #[inline(always)]
    fn push(&mut self, bin_index: ClassTarget, weight: SampleWeight) {
        let added = weight as f64;
        let slot = &mut self.bins[bin_index as usize];
        // The squared-sum delta must be computed from the bin value *before*
        // the bin itself is increased.
        self.sum_squares += added * (2. * *slot + added);
        *slot += added;
        self.total_weight += added;
    }

    /// Removes `weight` from the bin of class `bin_index`.
    ///
    /// `sum_squares` is updated in place using
    /// `(b - w)^2 - b^2 = w * (w - 2b)`.
    ///
    /// # Panics
    ///
    /// Panics if `bin_index` is not a valid index into the bins.
    #[inline(always)]
    fn pop(&mut self, bin_index: ClassTarget, weight: SampleWeight) {
        let removed = weight as f64;
        let slot = &mut self.bins[bin_index as usize];
        // As in `push`, the delta uses the bin value before it is modified.
        self.sum_squares += removed * (removed - 2. * *slot);
        *slot -= removed;
        self.total_weight -= removed;
    }

    /// Returns `true` when at most one bin holds a non-zero weight.
    ///
    /// The comparison against zero is exact.
    #[inline(always)]
    fn pure(&self) -> bool {
        let occupied = self.bins.iter().filter(|&&mass| mass != 0.).count();
        occupied <= 1
    }

    /// Returns the Gini impurity of the split formed by `self` and `other`,
    /// with each side weighted by its share of the combined weight.
    ///
    /// Algebraically this is `1 - (S1 / W1 + S2 / W2) / (W1 + W2)`, where `S`
    /// is a side's `sum_squares` and `W` its `total_weight`, evaluated with a
    /// single division. The denominator is zero when either side has zero
    /// total weight, in which case the result is not a finite number.
    #[inline(always)]
    fn split_impurity(&self, other: &Self) -> f64 {
        let numerator =
            self.sum_squares * other.total_weight + other.sum_squares * self.total_weight;
        let denominator =
            self.total_weight * other.total_weight * (self.total_weight + other.total_weight);
        1.0 - numerator / denominator
    }
}

impl WithClasses for Gini {
    /// Builds an empty accumulator holding `num_classes` bins, each starting
    /// at zero weight, with zeroed running totals.
    fn with_classes(num_classes: usize) -> Self {
        let bins = vec![0.; num_classes];
        Self {
            bins,
            total_weight: 0.,
            sum_squares: 0.,
        }
    }
}

impl ImpurityMetric<f32> for Mse {
    /// Adds the target `y` with the given `weight`.
    ///
    /// Uses the weighted incremental (Welford/West) update: the mean moves by
    /// `w * (y - mean) / (W + w)` and the sum of squared deviations grows by
    /// `w * (y - old_mean) * (y - new_mean)`.
    ///
    /// A zero-weight sample is a no-op. Without that guard, pushing a
    /// zero-weight sample onto an empty accumulator would evaluate `0 / 0` and
    /// leave NaN in the running statistics.
    #[inline(always)]
    fn push(&mut self, y: f32, weight: SampleWeight) {
        // `as` is kept here because `SampleWeight` is declared outside this block.
        let weight = weight as f64;
        if weight == 0. {
            return;
        }
        let y = f64::from(y);

        let delta = y - self.mean;
        let next_mean = self.mean + weight * delta / (self.total_weight + weight);
        self.sum_squares += weight * delta * (y - next_mean);
        self.mean = next_mean;
        self.total_weight += weight;
    }

    /// Removes the target `y` with the given `weight`; the inverse of `push`.
    ///
    /// The mean of the remaining samples is `(W * mean - w * y) / (W - w)`,
    /// computed here as `y + W * (mean - y) / (W - w)`.
    ///
    /// A zero-weight sample is a no-op. When the removal brings the total
    /// weight to exactly zero the accumulator is reset to its empty state;
    /// the general formula would divide by zero there and poison `mean` and
    /// `sum_squares` with NaN or infinity for every later update.
    #[inline(always)]
    fn pop(&mut self, y: f32, weight: SampleWeight) {
        // `as` is kept here because `SampleWeight` is declared outside this block.
        let weight = weight as f64;
        if weight == 0. {
            return;
        }

        let remaining = self.total_weight - weight;
        if remaining == 0. {
            // Nothing is left: return to the empty state instead of dividing by zero.
            self.mean = 0.;
            self.sum_squares = 0.;
            self.total_weight = 0.;
            return;
        }
        let y = f64::from(y);

        let next_mean = y + self.total_weight * (self.mean - y) / remaining;
        self.sum_squares -= weight * (y - next_mean) * (y - self.mean);
        self.mean = next_mean;
        self.total_weight = remaining;
    }

    /// Returns `true` when the accumulated sum of squared deviations is
    /// exactly zero.
    #[inline(always)]
    fn pure(&self) -> bool {
        self.sum_squares == 0.
    }

    /// Returns the total sum of squared deviations over both sides of the
    /// split, i.e. `self.sum_squares + other.sum_squares`. The value is not
    /// divided by the total weight.
    #[inline(always)]
    fn split_impurity(&self, other: &Self) -> f64 {
        self.sum_squares + other.sum_squares
    }
}