// forestfire-core 0.1.4
//
// Core tree-learning algorithms for the ForestFire project.
// Documentation
use super::*;

/// Reorders `rows` in place so that rows falling into the same branch of a
/// multiway split are contiguous, and reports where each branch's rows live.
///
/// The partition is a two-pass counting sort keyed on each row's bin for
/// `feature_index`:
///
/// 1. count how many rows land in every branch,
/// 2. turn the counts into start offsets and scatter the rows into place.
///
/// Rows are scattered in their original iteration order, so the relative
/// order of rows inside one branch is preserved.
///
/// # Arguments
///
/// * `table` - source of binned feature values.
/// * `feature_index` - feature whose bins define the branches.
/// * `branch_bins` - one bin per branch. It is searched with
///   `binary_search`, so it must be sorted in ascending order.
/// * `rows` - row indices to partition; overwritten with the grouped order.
///
/// # Returns
///
/// One `(bin, start, end)` triple per entry of `branch_bins`, in the same
/// order, where `rows[start..end]` holds the rows whose bin equals `bin`.
/// Branches that received no rows yield an empty range (`start == end`).
///
/// # Panics
///
/// * if a row's bin is not present in `branch_bins`;
/// * if the feature is reported as binary but
///   `binned_boolean_value` yields no value for a row.
pub(super) fn partition_rows_for_multiway_split(
    table: &dyn TableAccess,
    feature_index: usize,
    branch_bins: &[u16],
    rows: &mut [usize],
) -> Vec<(u16, usize, usize)> {
    // The feature kind does not depend on the row, so query it once instead
    // of on every iteration of both passes.
    let is_binary = table.is_binary_binned_feature(feature_index);

    // Single definition of the row -> branch mapping shared by both passes,
    // so the counting pass and the scatter pass cannot disagree.
    let branch_index_of = |row_idx: usize| -> usize {
        let bin: u16 = if is_binary {
            u16::from(
                table
                    .binned_boolean_value(feature_index, row_idx)
                    .expect("binary feature must expose boolean values"),
            )
        } else {
            table.binned_value(feature_index, row_idx)
        };
        branch_bins
            .binary_search(&bin)
            .expect("branch bins must cover all observed bins")
    };

    // Pass 1: histogram of rows per branch.
    let mut counts = vec![0usize; branch_bins.len()];
    for &row_idx in rows.iter() {
        counts[branch_index_of(row_idx)] += 1;
    }

    // Exclusive prefix sum: offsets[i] is where branch i starts in `rows`.
    let mut offsets = Vec::with_capacity(counts.len());
    let mut next = 0usize;
    for &count in &counts {
        offsets.push(next);
        next += count;
    }

    // Pass 2: scatter every row to the next free slot of its branch. The
    // result is built in a scratch buffer because the scatter is not an
    // in-place permutation.
    let mut write_positions = offsets.clone();
    let mut scratch = vec![0usize; rows.len()];
    for &row_idx in rows.iter() {
        let slot = &mut write_positions[branch_index_of(row_idx)];
        scratch[*slot] = row_idx;
        *slot += 1;
    }
    rows.copy_from_slice(&scratch);

    branch_bins
        .iter()
        .copied()
        .zip(offsets)
        .zip(counts)
        .map(|((bin, start), count)| (bin, start, start + count))
        .collect()
}