forestfire-core 0.1.4

Core tree-learning algorithms for the ForestFire project.
Documentation
use super::histogram::ClassificationHistogramBin;
use super::*;

/// Read-only inputs shared by the split-scoring functions in this module.
pub(super) struct SplitScoringContext<'a> {
    /// Table queried for the binned (or boolean) feature value of each row.
    pub(super) table: &'a dyn TableAccess,
    /// Class index for each row, looked up by row index; every value is used
    /// as an index into the per-class count vectors.
    pub(super) class_indices: &'a [usize],
    /// Number of classes; sets the length of every per-class count vector.
    pub(super) num_classes: usize,
    /// Impurity criterion forwarded to `classification_impurity`.
    pub(super) criterion: Criterion,
    /// Minimum number of rows each child of a split must hold; candidate
    /// splits that would create a smaller child are rejected.
    pub(super) min_samples_leaf: usize,
}

/// Selects how `score_multiway_split_choice` turns impurity reduction into a score.
#[derive(Debug, Clone, Copy)]
pub(super) enum MultiwayMetric {
    /// Score is the parent impurity minus the size-weighted impurity of the branches.
    InformationGain,
    /// Score is that impurity reduction divided by the split information
    /// (the base-2 entropy of the branch-size proportions).
    GainRatio,
}

/// Scores a split of `rows` on `feature_index` that creates one branch per observed bin.
///
/// Rows are grouped by their binned value. For binary binned features the boolean
/// value is used instead, with bin `0` standing for `false` and bin `1` for `true`.
/// The impurity reduction relative to the parent is computed with `context.criterion`
/// and then converted into a score according to `metric`.
///
/// Returns `None` when fewer than two bins are observed, when any branch would hold
/// fewer than `context.min_samples_leaf` rows, or (for [`MultiwayMetric::GainRatio`])
/// when the split information is exactly zero. Otherwise the returned choice carries
/// the score and the observed bins in ascending order.
///
/// # Panics
///
/// Panics if a binary binned feature does not expose a boolean value for a row, or
/// if a row index or class index is out of range for the slices it indexes.
pub(super) fn score_multiway_split_choice(
    context: &SplitScoringContext<'_>,
    feature_index: usize,
    rows: &[usize],
    metric: MultiwayMetric,
) -> Option<MultiwaySplitChoice> {
    let table = context.table;
    let labels = context.class_indices;
    let class_total = context.num_classes;

    // One `(bin, (row count, per-class counts))` entry per observed bin, ascending by bin.
    let branches: Vec<(u16, (usize, Vec<usize>))> =
        if table.is_binary_binned_feature(feature_index) {
            // Slot 0 tallies the `false` rows, slot 1 the `true` rows.
            let mut tallies = [
                (0usize, vec![0usize; class_total]),
                (0usize, vec![0usize; class_total]),
            ];
            for &row in rows {
                let label = labels[row];
                let is_set = table
                    .binned_boolean_value(feature_index, row)
                    .expect("binary feature must expose boolean values");
                let (size, counts) = &mut tallies[usize::from(is_set)];
                counts[label] += 1;
                *size += 1;
            }
            // Only sides that actually received rows become branches.
            (0u16..)
                .zip(tallies)
                .filter(|(_, (size, _))| *size > 0)
                .collect()
        } else {
            let mut by_bin: BTreeMap<u16, (usize, Vec<usize>)> = BTreeMap::new();
            for &row in rows {
                let bin = table.binned_value(feature_index, row);
                let (size, counts) = by_bin
                    .entry(bin)
                    .or_insert_with(|| (0usize, vec![0usize; class_total]));
                *size += 1;
                counts[labels[row]] += 1;
            }
            by_bin.into_iter().collect()
        };

    // A usable split needs at least two branches, each large enough to be a leaf.
    let min_leaf = context.min_samples_leaf;
    let splittable =
        branches.len() > 1 && branches.iter().all(|(_, (size, _))| *size >= min_leaf);
    if !splittable {
        return None;
    }

    let row_total = rows.len();
    // Fraction of the parent's rows that land in a branch of the given size.
    let share = |size: usize| size as f64 / row_total as f64;

    let parent_counts = class_counts(rows, labels, class_total);
    let parent_impurity = classification_impurity(&parent_counts, row_total, context.criterion);
    let branch_impurity = branches
        .iter()
        .map(|(_, (size, counts))| {
            share(*size) * classification_impurity(counts, *size, context.criterion)
        })
        .sum::<f64>();
    let gain = parent_impurity - branch_impurity;

    let score = match metric {
        MultiwayMetric::InformationGain => gain,
        MultiwayMetric::GainRatio => {
            // Split information: base-2 entropy of the branch-size distribution.
            let split_info = branches
                .iter()
                .map(|(_, (size, _))| {
                    let proportion = share(*size);
                    -proportion * proportion.log2()
                })
                .sum::<f64>();
            if split_info == 0.0 {
                return None;
            }
            gain / split_info
        }
    };

    let branch_bins = branches.into_iter().map(|(bin, _)| bin).collect();
    Some(MultiwaySplitChoice {
        feature_index,
        score,
        branch_bins,
    })
}

/// Dispatches histogram-based binary split scoring for a single feature.
///
/// * Binary histograms are scored identically for the CART and randomized
///   algorithms, from the `false`/`true` bin counts.
/// * Numeric histograms are scored by `score_numeric_cart_split_choice_from_hist`
///   for CART and by `score_numeric_randomized_split_choice_from_hist` for the
///   randomized algorithm.
///
/// Every other algorithm/histogram combination yields `None`, as does any
/// combination for which the selected scorer finds no admissible split.
pub(super) fn score_binary_split_choice_from_hist(
    context: &SplitScoringContext<'_>,
    histogram: &ClassificationFeatureHistogram,
    feature_index: usize,
    rows: &[usize],
    parent_counts: &[usize],
    algorithm: DecisionTreeAlgorithm,
) -> Option<BinarySplitChoice> {
    match (algorithm, histogram) {
        // Both algorithms treat a boolean feature the same way: there is only
        // one possible split, so nothing needs to be searched or sampled.
        (
            DecisionTreeAlgorithm::Cart | DecisionTreeAlgorithm::Randomized,
            ClassificationFeatureHistogram::Binary {
                false_bin,
                true_bin,
            },
        ) => {
            let false_rows = false_bin.size();
            let true_rows = true_bin.size();
            score_binary_cart_split_choice_from_counts(
                context,
                feature_index,
                parent_counts,
                &false_bin.counts,
                false_rows,
                &true_bin.counts,
                true_rows,
            )
        }
        (
            DecisionTreeAlgorithm::Cart,
            ClassificationFeatureHistogram::Numeric {
                bins,
                observed_bins,
            },
        ) => {
            let row_count = rows.len();
            score_numeric_cart_split_choice_from_hist(
                context,
                feature_index,
                parent_counts,
                row_count,
                bins,
                observed_bins,
            )
        }
        (
            DecisionTreeAlgorithm::Randomized,
            ClassificationFeatureHistogram::Numeric { observed_bins, .. },
        ) => score_numeric_randomized_split_choice_from_hist(
            context,
            feature_index,
            rows,
            parent_counts,
            observed_bins,
            histogram,
        ),
        _ => None,
    }
}

/// Scores the single possible split of a two-valued feature from precomputed counts.
///
/// `left_counts`/`left_size` and `right_counts`/`right_size` describe the two
/// children; the parent is taken to contain `left_size + right_size` rows. The
/// score is the parent impurity minus the size-weighted impurity of the children,
/// and the reported `threshold_bin` is always `0`.
///
/// Returns `None` when either child would hold fewer than
/// `context.min_samples_leaf` rows.
fn score_binary_cart_split_choice_from_counts(
    context: &SplitScoringContext<'_>,
    feature_index: usize,
    parent_counts: &[usize],
    left_counts: &[usize],
    left_size: usize,
    right_counts: &[usize],
    right_size: usize,
) -> Option<BinarySplitChoice> {
    // The smaller child decides whether the split is admissible.
    if left_size.min(right_size) < context.min_samples_leaf {
        return None;
    }

    let criterion = context.criterion;
    let total = left_size + right_size;

    let parent_impurity = classification_impurity(parent_counts, total, criterion);
    let left_term = (left_size as f64 / total as f64)
        * classification_impurity(left_counts, left_size, criterion);
    let right_term = (right_size as f64 / total as f64)
        * classification_impurity(right_counts, right_size, criterion);
    let children_impurity = left_term + right_term;

    Some(BinarySplitChoice {
        feature_index,
        score: parent_impurity - children_impurity,
        threshold_bin: 0,
    })
}

/// Finds the best threshold for a numeric feature by scanning its histogram.
///
/// `observed_bins` is walked in the given order while the left child's class
/// counts are accumulated bin by bin; after each bin the split "everything
/// accumulated so far goes left, the rest goes right" is scored as the parent
/// impurity minus the size-weighted child impurity. Candidates that would leave
/// either child with fewer than `context.min_samples_leaf` rows are skipped.
///
/// Returns the candidate with the strictly highest score (the earliest one wins
/// ties), or `None` when at most one bin is observed or no candidate is admissible.
///
/// # Panics
///
/// Panics if an entry of `observed_bins` is out of range for `bin_class_counts`.
fn score_numeric_cart_split_choice_from_hist(
    context: &SplitScoringContext<'_>,
    feature_index: usize,
    parent_counts: &[usize],
    row_count: usize,
    bin_class_counts: &[ClassificationHistogramBin],
    observed_bins: &[usize],
) -> Option<BinarySplitChoice> {
    if observed_bins.len() < 2 {
        return None;
    }

    let criterion = context.criterion;
    let min_leaf = context.min_samples_leaf;
    let parent_impurity = classification_impurity(parent_counts, row_count, criterion);

    let mut left_counts = vec![0usize; context.num_classes];
    let mut left_size = 0usize;
    // Scratch buffer for the right child's counts, refilled for every candidate.
    let mut right_counts: Vec<usize> = Vec::new();
    // Best candidate seen so far as `(threshold bin, score)`.
    let mut best: Option<(u16, f64)> = None;

    for &bin in observed_bins {
        // Move this bin's rows into the left child.
        let histogram_bin = &bin_class_counts[bin];
        for (accumulated, bin_count) in left_counts.iter_mut().zip(histogram_bin.counts.iter()) {
            *accumulated += *bin_count;
        }
        left_size += histogram_bin.size();

        let right_size = row_count - left_size;
        if left_size < min_leaf || right_size < min_leaf {
            continue;
        }

        // The right child is whatever the parent holds beyond the left child.
        right_counts.clear();
        right_counts.extend(
            parent_counts
                .iter()
                .zip(left_counts.iter())
                .map(|(parent, left)| parent - left),
        );

        let weighted_impurity = (left_size as f64 / row_count as f64)
            * classification_impurity(&left_counts, left_size, criterion)
            + (right_size as f64 / row_count as f64)
                * classification_impurity(&right_counts, right_size, criterion);
        let score = parent_impurity - weighted_impurity;

        // Strict comparison keeps the earliest of equally good thresholds; the
        // first candidate must still beat negative infinity to be recorded.
        let improves = match best {
            Some((_, best_score)) => score > best_score,
            None => score > f64::NEG_INFINITY,
        };
        if improves {
            best = Some((bin as u16, score));
        }
    }

    best.map(|(threshold_bin, score)| BinarySplitChoice {
        feature_index,
        score,
        threshold_bin,
    })
}

/// Scores one randomly chosen threshold for a numeric feature.
///
/// The threshold is picked by `choose_random_threshold` from the observed bins
/// (it also receives `feature_index`, `rows` and a fixed constant, presumably as
/// seed material; confirm against that helper). Every histogram bin with index
/// `0..=threshold_bin` forms the left child and the remaining rows of `rows` form
/// the right child. The score is the parent impurity minus the size-weighted
/// child impurity.
///
/// Returns `None` when no threshold could be chosen or when either child would
/// hold fewer than `context.min_samples_leaf` rows.
///
/// # Panics
///
/// Panics if a threshold was chosen but `histogram` is not the numeric variant.
fn score_numeric_randomized_split_choice_from_hist(
    context: &SplitScoringContext<'_>,
    feature_index: usize,
    rows: &[usize],
    parent_counts: &[usize],
    observed_bins: &[usize],
    histogram: &ClassificationFeatureHistogram,
) -> Option<BinarySplitChoice> {
    let candidate_thresholds: Vec<u16> = observed_bins.iter().map(|&bin| bin as u16).collect();
    let threshold_bin =
        choose_random_threshold(&candidate_thresholds, feature_index, rows, 0xC1A551F1u64)?;

    let bins = match histogram {
        ClassificationFeatureHistogram::Numeric { bins, .. } => bins,
        _ => unreachable!("randomized numeric histogram must be numeric"),
    };

    // Accumulate every bin up to and including the threshold into the left child;
    // `take` stops early if the histogram has fewer bins than that.
    let mut left_counts = vec![0usize; context.num_classes];
    let mut left_size = 0usize;
    for histogram_bin in bins.iter().take(usize::from(threshold_bin) + 1) {
        for (accumulated, bin_count) in left_counts.iter_mut().zip(histogram_bin.counts.iter()) {
            *accumulated += *bin_count;
        }
        left_size += histogram_bin.size();
    }

    let row_count = rows.len();
    let right_size = row_count - left_size;
    let min_leaf = context.min_samples_leaf;
    if left_size < min_leaf || right_size < min_leaf {
        return None;
    }

    // The right child is whatever the parent holds beyond the left child.
    let right_counts: Vec<usize> = parent_counts
        .iter()
        .zip(left_counts.iter())
        .map(|(parent, left)| parent - left)
        .collect();

    let criterion = context.criterion;
    let parent_impurity = classification_impurity(parent_counts, row_count, criterion);
    let weighted_impurity = (left_size as f64 / row_count as f64)
        * classification_impurity(&left_counts, left_size, criterion)
        + (right_size as f64 / row_count as f64)
            * classification_impurity(&right_counts, right_size, criterion);

    Some(BinarySplitChoice {
        feature_index,
        score: parent_impurity - weighted_impurity,
        threshold_bin,
    })
}