// forestfire-core 0.2.0
//
// Core tree-learning algorithms for the ForestFire project.
use crate::sampling::sample_feature_subset;
use forestfire_data::{BINARY_MISSING_BIN, TableAccess};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rayon::prelude::*;

/// The 64-bit golden-ratio constant (odd), used by `mix_seed` as a multiplicative
/// spreader and by `fingerprint_rows` as an additive offset.
const GOLDEN_GAMMA: u64 = 0x9e37_79b9_7f4a_7c15;
/// First multiplier of the SplitMix64 output finalizer used in `avalanche64`.
const MIX_MULTIPLIER_A: u64 = 0xbf58_476d_1ce4_e5b9;
/// Second multiplier of the SplitMix64 output finalizer used in `avalanche64`.
const MIX_MULTIPLIER_B: u64 = 0x94d0_49bb_1331_11eb;

/// Histogram of per-bin payloads `T` for one binned feature, accumulated over a
/// set of rows.
///
/// The shape depends on how the table bins the feature: boolean-binned features
/// use three fixed bins, all other features use one payload per numeric bin index.
#[derive(Debug, Clone)]
pub(crate) enum FeatureHistogram<T> {
    /// Feature for which `is_binary_binned_feature` is true; rows are split by
    /// their `binned_boolean_value`.
    Binary {
        /// Payload for rows whose binned boolean value is `Some(false)`.
        false_bin: T,
        /// Payload for rows whose binned boolean value is `Some(true)`.
        true_bin: T,
        /// Payload for rows with no binned boolean value (`None`).
        missing_bin: T,
    },
    /// Any other feature; rows are bucketed by `binned_value`.
    Numeric {
        /// Payloads indexed by bin. The builder allocates `numeric_bin_cap() + 1`
        /// entries, so index `numeric_bin_cap()` (the value `numeric_missing_bin`
        /// returns) is included.
        bins: Vec<T>,
        /// Ascending indices into `bins` whose payload reports `is_observed()`.
        observed_bins: Vec<usize>,
    },
}

/// Payload stored in each histogram bin.
///
/// `Clone` is required because numeric histograms are initialised by cloning a
/// single freshly made payload into every bin.
pub(crate) trait HistogramBin: Clone {
    /// Returns `parent` minus `child`; `subtract_feature_histograms` applies this
    /// bin by bin to derive one histogram from two others.
    fn subtract(parent: &Self, child: &Self) -> Self;
    /// Whether this payload counts as observed; numeric bins for which this is
    /// `false` are left out of `observed_bins`.
    fn is_observed(&self) -> bool;
}

/// Where rows with a missing value for the split feature are routed.
///
/// `partition_rows_for_binary_split` puts missing rows in the left partition only
/// for `Left`; for both `Right` and `Node` they end up in the right partition.
/// NOTE(review): `Node` presumably means "keep missing rows at the node" elsewhere
/// in the crate — confirm against its other users.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum MissingBranchDirection {
    /// Missing rows join the left partition.
    Left,
    /// Missing rows join the right partition.
    Right,
    /// See the NOTE(review) above; partitioning treats this like `Right`.
    Node,
}

/// Bin index this module uses for missing values of non-binary features: the
/// table's numeric bin cap, narrowed to `u16`.
///
/// NOTE(review): the `as u16` narrowing would silently truncate a cap above
/// `u16::MAX`; presumably the table never configures one that large — confirm.
#[inline]
pub(crate) fn numeric_missing_bin(table: &dyn TableAccess) -> u16 {
    let bin_cap = table.numeric_bin_cap();
    bin_cap as u16
}

/// Bin index that holds missing values for `feature_index`: `BINARY_MISSING_BIN`
/// for boolean-binned features, otherwise `numeric_missing_bin(table)`.
// No caller is visible in this file; the allow keeps the helper available
// without a warning until one is added.
#[allow(dead_code)]
#[inline]
pub(crate) fn missing_bin_for_feature(table: &dyn TableAccess, feature_index: usize) -> u16 {
    if !table.is_binary_binned_feature(feature_index) {
        return numeric_missing_bin(table);
    }
    BINARY_MISSING_BIN
}

/// Accumulates the histogram of a single feature over `rows`.
///
/// `make_bin(feature_index)` produces an empty payload and
/// `add_row(feature_index, &mut payload, row)` folds one row into a payload.
///
/// - Boolean-binned features (`is_binary_binned_feature`) get three payloads,
///   created in the order false, true, missing and keyed on `binned_boolean_value`.
/// - Every other feature gets `numeric_bin_cap() + 1` clones of one fresh payload,
///   indexed by `binned_value`, plus the ascending indices of observed bins.
///
/// # Panics
///
/// Panics if a row's `binned_value` exceeds `numeric_bin_cap()`, because that
/// index lies outside the allocated bins.
fn build_feature_histogram_for_feature<T, MakeBin, AddRow>(
    table: &dyn TableAccess,
    rows: &[usize],
    feature_index: usize,
    make_bin: &MakeBin,
    add_row: &AddRow,
) -> FeatureHistogram<T>
where
    T: HistogramBin,
    MakeBin: Fn(usize) -> T,
    AddRow: Fn(usize, &mut T, usize),
{
    if !table.is_binary_binned_feature(feature_index) {
        // One slot per numeric bin index, including index `numeric_bin_cap()`.
        let slot_count = table.numeric_bin_cap() + 1;
        let mut bins = vec![make_bin(feature_index); slot_count];
        for &row in rows {
            let slot = table.binned_value(feature_index, row) as usize;
            add_row(feature_index, &mut bins[slot], row);
        }
        let observed_bins: Vec<usize> = (0..bins.len())
            .filter(|&slot| bins[slot].is_observed())
            .collect();
        return FeatureHistogram::Numeric {
            bins,
            observed_bins,
        };
    }

    let mut false_bin = make_bin(feature_index);
    let mut true_bin = make_bin(feature_index);
    let mut missing_bin = make_bin(feature_index);
    for &row in rows {
        let target = match table.binned_boolean_value(feature_index, row) {
            Some(false) => &mut false_bin,
            Some(true) => &mut true_bin,
            None => &mut missing_bin,
        };
        add_row(feature_index, target, row);
    }
    FeatureHistogram::Binary {
        false_bin,
        true_bin,
        missing_bin,
    }
}

/// Builds one histogram per binned feature (in feature-index order) over `rows`,
/// sequentially.
///
/// See `build_feature_histograms_parallel` for the rayon-backed variant; both
/// produce the same histograms.
pub(crate) fn build_feature_histograms<T, MakeBin, AddRow>(
    table: &dyn TableAccess,
    rows: &[usize],
    make_bin: MakeBin,
    add_row: AddRow,
) -> Vec<FeatureHistogram<T>>
where
    T: HistogramBin,
    MakeBin: Fn(usize) -> T,
    AddRow: Fn(usize, &mut T, usize),
{
    let feature_count = table.binned_feature_count();
    let mut histograms = Vec::with_capacity(feature_count);
    for feature_index in 0..feature_count {
        histograms.push(build_feature_histogram_for_feature(
            table,
            rows,
            feature_index,
            &make_bin,
            &add_row,
        ));
    }
    histograms
}

/// Builds one histogram per binned feature over `rows`, distributing features
/// across the rayon thread pool.
///
/// The output is indexed by feature, exactly like `build_feature_histograms`,
/// because rayon's indexed `collect` preserves order.
pub(crate) fn build_feature_histograms_parallel<T, MakeBin, AddRow>(
    table: &dyn TableAccess,
    rows: &[usize],
    make_bin: MakeBin,
    add_row: AddRow,
) -> Vec<FeatureHistogram<T>>
where
    T: HistogramBin + Send,
    MakeBin: Fn(usize) -> T + Sync,
    AddRow: Fn(usize, &mut T, usize) + Sync,
{
    let feature_count = table.binned_feature_count();
    let build_one = |feature_index: usize| {
        build_feature_histogram_for_feature(table, rows, feature_index, &make_bin, &add_row)
    };
    (0..feature_count).into_par_iter().map(build_one).collect()
}

/// Derives a histogram set by subtracting `child` from `parent`, bin by bin.
///
/// For every feature the two histograms must have the same shape: both binary, or
/// both numeric with the same number of bins. For numeric results,
/// `observed_bins` is recomputed from the subtracted payloads rather than copied
/// from either input.
///
/// # Panics
///
/// Panics on any shape mismatch:
/// - `parent` and `child` cover a different number of features;
/// - a feature's variants differ;
/// - two numeric histograms have different bin counts.
///
/// These are invariant violations. Silently truncating with `zip` would return
/// histograms that no longer line up with feature or bin indices.
pub(crate) fn subtract_feature_histograms<T: HistogramBin>(
    parent: &[FeatureHistogram<T>],
    child: &[FeatureHistogram<T>],
) -> Vec<FeatureHistogram<T>> {
    assert_eq!(
        parent.len(),
        child.len(),
        "parent and child histograms must cover the same features"
    );
    parent
        .iter()
        .zip(child)
        .map(|pair| match pair {
            (
                FeatureHistogram::Binary {
                    false_bin: parent_false,
                    true_bin: parent_true,
                    missing_bin: parent_missing,
                },
                FeatureHistogram::Binary {
                    false_bin: child_false,
                    true_bin: child_true,
                    missing_bin: child_missing,
                },
            ) => FeatureHistogram::Binary {
                false_bin: T::subtract(parent_false, child_false),
                true_bin: T::subtract(parent_true, child_true),
                missing_bin: T::subtract(parent_missing, child_missing),
            },
            (
                FeatureHistogram::Numeric {
                    bins: parent_bins, ..
                },
                FeatureHistogram::Numeric {
                    bins: child_bins, ..
                },
            ) => {
                assert_eq!(
                    parent_bins.len(),
                    child_bins.len(),
                    "numeric histograms must have the same bin count"
                );
                let bins = parent_bins
                    .iter()
                    .zip(child_bins)
                    .map(|(parent_bin, child_bin)| T::subtract(parent_bin, child_bin))
                    .collect::<Vec<_>>();
                // Observation status can change after subtraction, so rebuild it.
                let observed_bins = bins
                    .iter()
                    .enumerate()
                    .filter_map(|(bin, payload)| payload.is_observed().then_some(bin))
                    .collect::<Vec<_>>();
                FeatureHistogram::Numeric {
                    bins,
                    observed_bins,
                }
            }
            _ => unreachable!("histogram shapes must match"),
        })
        .collect()
}

/// Picks one of `candidate_thresholds` pseudo-randomly, or returns `None` when
/// there are no candidates.
///
/// The pick is deterministic. The RNG seed is derived from:
/// - `salt`;
/// - the feature index and candidate count (via `mix_seed`);
/// - the candidate list, whose fingerprint is order-sensitive;
/// - the row set, whose fingerprint is order-insensitive, so permuting `rows`
///   does not change the result.
pub(crate) fn choose_random_threshold(
    candidate_thresholds: &[u16],
    feature_index: usize,
    rows: &[usize],
    salt: u64,
) -> Option<u16> {
    let candidate_count = candidate_thresholds.len();
    if candidate_count == 0 {
        return None;
    }

    let feature_context = mix_seed(feature_index as u64, candidate_count as u64);
    let row_print = fingerprint_rows(rows);
    let threshold_print = fingerprint_thresholds(candidate_thresholds);
    let seed = avalanche64(salt ^ feature_context ^ row_print ^ threshold_print);

    let pick = StdRng::seed_from_u64(seed).gen_range(0..candidate_count);
    // `pick` is drawn from `0..candidate_count`, so indexing cannot fail.
    Some(candidate_thresholds[pick])
}

/// Partitions `rows` in place for a binary split on `feature_index` and returns
/// how many rows went left; those rows occupy `rows[..left]`.
///
/// A row goes left when one of these holds:
/// - its value is missing (`table.is_missing`) and `missing_direction` is `Left`;
/// - the feature is boolean-binned and the row's binned value is `false`;
/// - the feature is numeric and the row's bin is `<= threshold_bin`.
///
/// `threshold_bin` is ignored for boolean-binned features. Rows sent left keep
/// their relative order; rows sent right may be reordered by the swaps.
///
/// # Panics
///
/// Panics if a non-missing row of a boolean-binned feature has no binned boolean
/// value.
pub(crate) fn partition_rows_for_binary_split(
    table: &dyn TableAccess,
    feature_index: usize,
    threshold_bin: u16,
    missing_direction: MissingBranchDirection,
    rows: &mut [usize],
) -> usize {
    // Both facts are fixed for the whole split. Hoisting them avoids a dynamic
    // call and an enum comparison per row.
    let is_binary = table.is_binary_binned_feature(feature_index);
    let missing_goes_left = missing_direction == MissingBranchDirection::Left;

    let mut left = 0usize;
    for index in 0..rows.len() {
        let row_idx = rows[index];
        let go_left = if table.is_missing(feature_index, row_idx) {
            missing_goes_left
        } else if is_binary {
            !table
                .binned_boolean_value(feature_index, row_idx)
                .expect("observed binary feature must expose boolean values")
        } else {
            table.binned_value(feature_index, row_idx) <= threshold_bin
        };
        if go_left {
            rows.swap(left, index);
            left += 1;
        }
    }
    left
}

/// Returns the feature indices to consider for a split.
///
/// When `max_features` is `Some(count)`, returns the seeded subset from
/// `sample_feature_subset(feature_count, count, seed)`. Otherwise returns every
/// index in `0..feature_count`, in order.
pub(crate) fn candidate_feature_indices(
    feature_count: usize,
    max_features: Option<usize>,
    seed: u64,
) -> Vec<usize> {
    if let Some(count) = max_features {
        sample_feature_subset(feature_count, count, seed)
    } else {
        (0..feature_count).collect()
    }
}

/// Folds `value` into `base_seed`.
///
/// `value` is spread by a wrapping multiply with the odd constant `GOLDEN_GAMMA`
/// and a 17-bit left rotation, then XORed into `base_seed`. Every step is
/// invertible, so for a fixed `base_seed` distinct values never collide.
pub(crate) fn mix_seed(base_seed: u64, value: u64) -> u64 {
    let spread = value.wrapping_mul(GOLDEN_GAMMA);
    spread.rotate_left(17) ^ base_seed
}

/// Derives the RNG seed for a tree node.
///
/// The seed combines `base_seed ^ salt`, the node depth (via `mix_seed`) and the
/// node's row fingerprint, then finalizes the result with `avalanche64`. The row
/// fingerprint ignores order, so the same row set yields the same seed in any
/// order.
pub(crate) fn node_seed(base_seed: u64, depth: usize, rows: &[usize], salt: u64) -> u64 {
    let salted = base_seed ^ salt;
    let depth_mixed = mix_seed(salted, depth as u64);
    let row_print = fingerprint_rows(rows);
    avalanche64(depth_mixed ^ row_print)
}

/// SplitMix64 output finalizer.
///
/// It applies two xor-shift / wrapping-multiply rounds (shifts 30 and 27,
/// multipliers `MIX_MULTIPLIER_A` and `MIX_MULTIPLIER_B`) followed by a final
/// xor-shift by 31. This makes every input bit affect every output bit. Note
/// that `0` maps to `0`.
fn avalanche64(value: u64) -> u64 {
    let first_round = (value ^ (value >> 30)).wrapping_mul(MIX_MULTIPLIER_A);
    let second_round = (first_round ^ (first_round >> 27)).wrapping_mul(MIX_MULTIPLIER_B);
    second_round ^ (second_round >> 31)
}

/// Order-insensitive 64-bit fingerprint of a row-index set.
///
/// Each row is hashed with `avalanche64(row + GOLDEN_GAMMA)`. Three commutative
/// aggregates of those hashes are then combined with the row count and finalized
/// with `avalanche64`:
/// - the XOR of the hashes;
/// - their wrapping sum;
/// - the wrapping sum of each hash rotated by `row & 63`.
///
/// Permuting `rows` leaves the result unchanged, and an empty slice maps to `0`.
fn fingerprint_rows(rows: &[usize]) -> u64 {
    let (xor, sum, rotated_sum) =
        rows.iter()
            .fold((0u64, 0u64, 0u64), |(xor_acc, sum_acc, rot_acc), &row| {
                let mixed = avalanche64((row as u64).wrapping_add(GOLDEN_GAMMA));
                (
                    xor_acc ^ mixed,
                    sum_acc.wrapping_add(mixed),
                    rot_acc.wrapping_add(mixed.rotate_left((row as u32) & 63)),
                )
            });
    let length_term = (rows.len() as u64).wrapping_mul(MIX_MULTIPLIER_A);
    avalanche64(xor ^ sum.rotate_left(7) ^ rotated_sum.rotate_left(19) ^ length_term)
}

/// Order-sensitive 64-bit fingerprint of a candidate-threshold list.
///
/// Each threshold is combined with its position (`threshold + (index << 16)`) and
/// hashed with `avalanche64`. The XOR of the hashes and the wrapping sum of each
/// hash rotated by `index & 63` are combined with the list length and finalized.
/// An empty list maps to `0`.
fn fingerprint_thresholds(candidate_thresholds: &[u16]) -> u64 {
    let (xor, rotated_sum) = candidate_thresholds.iter().enumerate().fold(
        (0u64, 0u64),
        |(xor_acc, rot_acc), (position, &threshold)| {
            let slot_key = u64::from(threshold).wrapping_add((position as u64) << 16);
            let mixed = avalanche64(slot_key);
            (
                xor_acc ^ mixed,
                rot_acc.wrapping_add(mixed.rotate_left((position as u32) & 63)),
            )
        },
    );
    let length_term = (candidate_thresholds.len() as u64).wrapping_mul(MIX_MULTIPLIER_B);
    avalanche64(xor ^ rotated_sum.rotate_left(13) ^ length_term)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeSet;

    #[test]
    fn mix_seed_is_deterministic_and_unique_across_many_values() {
        let mixed: Vec<u64> = (0..4096u64).map(|value| mix_seed(17, value)).collect();
        let distinct: BTreeSet<u64> = mixed.iter().copied().collect();

        assert_eq!(distinct.len(), mixed.len());
        assert_eq!(mix_seed(17, 123), mixed[123]);
    }

    #[test]
    fn node_seed_is_order_invariant_for_same_row_set() {
        let rows = [9usize, 3, 1, 7, 11, 5];
        let reversed: Vec<usize> = rows.iter().rev().copied().collect();

        assert_eq!(node_seed(41, 2, &rows, 99), node_seed(41, 2, &reversed, 99));
    }

    #[test]
    fn node_seed_changes_when_depth_salt_or_row_membership_changes() {
        let rows = [1usize, 2, 3, 4];
        let with_extra = [1usize, 2, 3, 4, 5];
        let base = node_seed(41, 2, &rows, 99);

        // Each variant perturbs exactly one input: depth, salt, then row membership.
        let variants = [
            node_seed(41, 3, &rows, 99),
            node_seed(41, 2, &rows, 100),
            node_seed(41, 2, &with_extra, 99),
        ];
        for variant in variants.iter() {
            assert_ne!(base, *variant);
        }
    }

    #[test]
    fn choose_random_threshold_is_order_invariant_for_same_row_set() {
        let thresholds = [1u16, 3, 7, 9, 11];
        let rows = [8usize, 2, 6, 4, 10];
        let mut permuted = rows;
        permuted.rotate_left(2);

        let original_pick = choose_random_threshold(&thresholds, 5, &rows, 1234);
        let permuted_pick = choose_random_threshold(&thresholds, 5, &permuted, 1234);
        assert_eq!(original_pick, permuted_pick);
    }

    #[test]
    fn feature_subset_sampling_stays_well_formed_under_many_seeds() {
        const FEATURE_COUNT: usize = 32;
        const SAMPLE_SIZE: usize = 6;
        const TRIALS: u64 = 4096;
        const TOLERANCE: isize = 280;

        let mut counts = [0usize; FEATURE_COUNT];
        for seed in 0..TRIALS {
            let sample = candidate_feature_indices(FEATURE_COUNT, Some(SAMPLE_SIZE), seed);
            let unique: BTreeSet<usize> = sample.iter().copied().collect();

            assert_eq!(sample.len(), SAMPLE_SIZE);
            assert_eq!(unique.len(), SAMPLE_SIZE);
            assert!(sample.iter().all(|&feature| feature < FEATURE_COUNT));

            sample.into_iter().for_each(|feature| counts[feature] += 1);
        }

        // Uniform sampling would select each feature TRIALS * SAMPLE_SIZE / FEATURE_COUNT times.
        let expected = (TRIALS as usize * SAMPLE_SIZE / FEATURE_COUNT) as isize;
        let min = counts.iter().copied().min().unwrap() as isize;
        let max = counts.iter().copied().max().unwrap() as isize;

        assert!(
            min >= expected - TOLERANCE,
            "min count {min} too far below {expected}"
        );
        assert!(
            max <= expected + TOLERANCE,
            "max count {max} too far above {expected}"
        );
    }

    #[test]
    fn randomized_threshold_selection_covers_candidates_without_extreme_bias() {
        let thresholds: Vec<u16> = (0..8).collect();
        let mut counts = vec![0usize; thresholds.len()];

        for context in 0..4096usize {
            let rows: Vec<usize> = (0..17usize)
                .map(|offset| (context * 37 + offset * 13) % 257)
                .collect();
            let selected = choose_random_threshold(&thresholds, context % 11, &rows, 0xC1A5_5EED)
                .expect("non-empty candidates always yield a threshold");
            counts[usize::from(selected)] += 1;
        }

        // 4096 draws over 8 candidates averages 512 per candidate.
        assert!(counts.iter().all(|&count| count > 300 && count < 800));
    }
}