forestfire-core 0.1.4

Core tree-learning algorithms for the ForestFire project.
Documentation
use super::*;
use crate::tree::shared::{
    FeatureHistogram, HistogramBin, build_feature_histograms, subtract_feature_histograms,
};

/// Histogram bin payload for classification: one counter per class.
///
/// `counts[c]` is the number of rows tallied into this bin whose class index is `c`.
/// Each bin is created with a fixed number of class slots by [`ClassificationHistogramBin::new`].
#[derive(Debug, Clone)]
pub(super) struct ClassificationHistogramBin {
    pub(super) counts: Vec<usize>,
}

impl ClassificationHistogramBin {
    /// Returns an empty bin with `num_classes` zeroed counters.
    pub(super) fn new(num_classes: usize) -> Self {
        let counts = vec![0; num_classes];
        Self { counts }
    }

    /// Returns the total number of rows tallied in this bin, summed over all classes.
    pub(super) fn size(&self) -> usize {
        self.counts.iter().copied().sum::<usize>()
    }
}

impl HistogramBin for ClassificationHistogramBin {
    /// Returns the per-class difference `parent - child`.
    ///
    /// Both bins must track the same number of classes. Without the check below,
    /// `Iterator::zip` would silently drop the trailing classes of the longer bin.
    /// Each `child` count must not exceed the matching `parent` count, because the
    /// `usize` subtraction would otherwise underflow.
    fn subtract(parent: &Self, child: &Self) -> Self {
        // Catch mismatched class counts in debug builds instead of truncating silently.
        debug_assert_eq!(
            parent.counts.len(),
            child.counts.len(),
            "histogram bins must track the same number of classes"
        );
        Self {
            counts: parent
                .counts
                .iter()
                .zip(&child.counts)
                .map(|(parent_count, child_count)| parent_count - child_count)
                .collect(),
        }
    }

    /// A bin counts as observed when at least one class counter is non-zero.
    fn is_observed(&self) -> bool {
        self.counts.iter().any(|&count| count > 0)
    }
}

/// Feature histogram specialised to classification: every bin is a
/// [`ClassificationHistogramBin`] holding per-class row counts.
pub(super) type ClassificationFeatureHistogram = FeatureHistogram<ClassificationHistogramBin>;

/// Builds the classification feature histograms for a node's `rows`.
///
/// Every bin starts with `num_classes` zeroed counters. Each row routed into a bin
/// increments the counter for its class, `class_indices[row]`.
///
/// # Panics
///
/// Panics if a row index is out of range for `class_indices`, or if a class index
/// is `>= num_classes`.
pub(super) fn build_classification_node_histograms(
    table: &dyn TableAccess,
    class_indices: &[usize],
    rows: &[usize],
    num_classes: usize,
) -> Vec<ClassificationFeatureHistogram> {
    build_feature_histograms(
        table,
        rows,
        |_| ClassificationHistogramBin::new(num_classes),
        |_feature, bin, row| {
            let class = class_indices[row];
            bin.counts[class] += 1;
        },
    )
}

/// Computes the histograms for `parent - child` by delegating to the generic
/// `subtract_feature_histograms` helper.
///
/// The per-bin arithmetic is `HistogramBin::subtract` for `ClassificationHistogramBin`,
/// a per-class count difference. NOTE(review): presumably the helper pairs the bins
/// positionally, so `parent` and `child` should be built over the same features in the
/// same order; confirm against `tree::shared`.
pub(super) fn subtract_classification_node_histograms(
    parent: &[ClassificationFeatureHistogram],
    child: &[ClassificationFeatureHistogram],
) -> Vec<ClassificationFeatureHistogram> {
    subtract_feature_histograms(parent, child)
}