treeboost 0.1.0

High-performance Gradient Boosted Decision Tree engine for large-scale tabular data
Documentation
//! Backend trait definitions for histogram building and split finding.
//!
//! These traits abstract over different hardware backends:
//! - Scalar (AVX2/NEON loads, scalar scatter) - current implementation
//! - WGPU (all GPUs via Vulkan/Metal/DX12) - future
//! - AVX-512 tensor-tile (vpconflictd) - future
//! - SVE2 tensor-tile (HISTCNT) - future
//! - Native backends: CUDA, ROCm, Metal - future extreme optimization

use crate::defaults::split as split_defaults;
use crate::histogram::Histogram;

// Re-export SparseColumn from dataset to avoid duplication
pub use crate::dataset::SparseColumn;

/// Read access to binned (quantized) feature data.
///
/// Implementors may lay the bins out in one of two ways:
/// - column-major, `bins[feature][row]`, used by the scalar backend;
/// - row-major, `bins[row][feature]`, used by tensor-tile backends
///   (GPU / AVX-512 / SVE2).
///
/// The layout-specific accessors return `None` when asked for a layout the
/// storage does not provide.
pub trait BinStorage: Sync {
    /// Bin value stored at (`row`, `feature`).
    fn get_bin(&self, row: usize, feature: usize) -> u8;

    /// Number of rows in the dataset.
    fn num_rows(&self) -> usize;

    /// Number of features in the dataset.
    fn num_features(&self) -> usize;

    /// Contiguous slice holding every row's bin for `feature` (scalar backend).
    ///
    /// Row-major storage returns `None`.
    fn feature_column(&self, feature: usize) -> Option<&[u8]>;

    /// Sparse representation of `feature`, when the storage keeps one.
    fn sparse_column(&self, feature: usize) -> Option<&SparseColumn>;

    /// Whether `feature` is stored sparsely.
    ///
    /// By default this is `true` exactly when [`BinStorage::sparse_column`]
    /// yields a value for `feature`.
    fn is_sparse(&self, feature: usize) -> bool {
        let sparse = self.sparse_column(feature);
        sparse.is_some()
    }

    /// The whole dataset as a single row-major buffer (tensor-tile backends).
    ///
    /// The default reports that no row-major view exists, which is what
    /// column-major storage should return.
    fn as_row_major(&self) -> Option<&[u8]> {
        None
    }

    /// Largest number of bins used by any feature.
    ///
    /// The default assumes full 8-bit bins and reports `u8::MAX` (255).
    fn max_bins(&self) -> u8 {
        u8::MAX
    }

    /// Whether every feature's bins fit in 4 bits, i.e. `max_bins()` is at most 16.
    fn supports_4bit(&self) -> bool {
        const MAX_NIBBLE_BINS: u8 = 16;
        self.max_bins() <= MAX_NIBBLE_BINS
    }

    /// Row-major bins packed two per byte (tensor-tile backends with small bins).
    ///
    /// Byte `i` of a packed row holds `(feature[2i + 1] << 4) | feature[2i]`.
    ///
    /// The default returns `None`. That is the expected answer when 4-bit packing
    /// is unsupported or the storage is column-major.
    fn as_row_major_4bit(&self) -> Option<&[u8]> {
        None
    }

    /// Bytes occupied by one row in the 4-bit packed layout.
    ///
    /// Two features share a byte, so an odd feature count rounds up.
    fn bytes_per_row_4bit(&self) -> usize {
        usize::div_ceil(self.num_features(), 2)
    }
}

/// Parameters that govern split finding (see `HistogramBackend::find_best_split`).
#[derive(Clone, Copy, Debug)]
pub struct SplitConfig {
    /// L2 regularization strength.
    pub lambda: f32,
    /// Fewest samples each leaf is allowed to receive.
    pub min_samples_leaf: u32,
    /// Smallest hessian sum each leaf is allowed to receive.
    pub min_hessian_leaf: f32,
    /// Gain a split must reach to be accepted.
    pub min_gain: f32,
    /// Weight of the Shannon-entropy regularization term.
    pub entropy_weight: f32,
}

impl Default for SplitConfig {
    fn default() -> Self {
        Self {
            lambda: split_defaults::DEFAULT_SPLIT_LAMBDA,
            min_samples_leaf: split_defaults::DEFAULT_SPLIT_MIN_SAMPLES_LEAF,
            min_hessian_leaf: split_defaults::DEFAULT_SPLIT_MIN_HESSIAN_LEAF,
            min_gain: split_defaults::DEFAULT_SPLIT_MIN_GAIN,
            entropy_weight: split_defaults::DEFAULT_SPLIT_ENTROPY_WEIGHT,
        }
    }
}

/// A proposed split on one feature, with the statistics of both children.
#[derive(Clone, Copy, Debug)]
pub struct SplitCandidate {
    /// Index of the feature being split.
    pub feature: usize,
    /// Bin threshold: rows with `bin <= threshold` go left.
    pub threshold: u8,
    /// Gain achieved by this split.
    pub gain: f32,
    /// Sum of gradients routed to the left child.
    pub left_gradient: f32,
    /// Sum of hessians routed to the left child.
    pub left_hessian: f32,
    /// Number of samples routed to the left child.
    pub left_count: u32,
    /// Sum of gradients routed to the right child.
    pub right_gradient: f32,
    /// Sum of hessians routed to the right child.
    pub right_hessian: f32,
    /// Number of samples routed to the right child.
    pub right_count: u32,
}

/// Main trait for histogram building backends.
///
/// Different backends can implement this trait to provide hardware-accelerated
/// histogram building:
/// - `ScalarBackend`: Current CPU implementation (AVX2/NEON loads)
/// - `WgpuBackend`: GPU via WGPU (Vulkan/Metal/DX12) - future
/// - `Avx512Backend`: AVX-512 tensor-tile with vpconflictd - future
/// - `Sve2Backend`: ARM SVE2 tensor-tile with HISTCNT - future
///
/// Only `name`, `is_tensor_tile`, `build_histograms`, `build_histograms_sibling`
/// and `find_best_split` must be implemented. The batched and era-stratified
/// builders have portable default implementations that backends may override.
pub trait HistogramBackend: Send + Sync {
    /// Human-readable name for this backend.
    fn name(&self) -> &'static str;

    /// Whether this backend uses tensor-tile (2D row-major) layout.
    /// True for GPU/AVX-512/SVE2, false for scalar.
    fn is_tensor_tile(&self) -> bool;

    /// Build histograms for all features at a tree node.
    ///
    /// # Arguments
    /// * `bins` - Binned feature data
    /// * `grad_hess` - (gradient, hessian) pair for each row of the full dataset
    /// * `row_indices` - Which rows belong to this node
    ///
    /// # Returns
    /// A vector of histograms, one per feature.
    fn build_histograms(
        &self,
        bins: &dyn BinStorage,
        grad_hess: &[(f32, f32)],
        row_indices: &[usize],
    ) -> Vec<Histogram>;

    /// Build sibling histogram using the subtraction trick.
    ///
    /// For a parent node with histogram H_parent, if we compute the smaller
    /// child histogram H_smaller, we can derive the larger child as:
    /// H_larger = H_parent - H_smaller
    ///
    /// This halves the computation for child histogram building.
    fn build_histograms_sibling(
        &self,
        parent: &[Histogram],
        smaller_child: &[Histogram],
    ) -> Vec<Histogram>;

    /// Find the best split across all features.
    ///
    /// # Arguments
    /// * `histograms` - One histogram per feature
    /// * `config` - Split finding configuration
    ///
    /// # Returns
    /// The single best split candidate over all features, or `None` if no valid
    /// split exists.
    fn find_best_split(
        &self,
        histograms: &[Histogram],
        config: &SplitConfig,
    ) -> Option<SplitCandidate>;

    /// Build histograms for multiple batches in a single dispatch.
    ///
    /// This method allows GPU backends to batch multiple small histogram builds
    /// into a single dispatch, amortizing dispatch overhead.
    ///
    /// # Arguments
    /// * `bins` - Binned feature data
    /// * `grad_hess` - (gradient, hessian) pair for each row of the full dataset
    /// * `batches` - Slice of row index slices, one per batch
    ///
    /// # Returns
    /// A vector of histogram vectors, one per batch, in the same order as `batches`.
    ///
    /// # Default Implementation
    /// Calls [`HistogramBackend::build_histograms`] once per batch.
    fn build_histograms_batched(
        &self,
        bins: &dyn BinStorage,
        grad_hess: &[(f32, f32)],
        batches: &[&[usize]],
    ) -> Vec<Vec<Histogram>> {
        batches
            .iter()
            .map(|row_indices| self.build_histograms(bins, grad_hess, row_indices))
            .collect()
    }

    /// Build era-stratified histograms for Directional Era Splitting (DES).
    ///
    /// Returns histograms indexed as `[era][feature]`, enabling directional
    /// agreement checks across eras during split finding.
    ///
    /// # Arguments
    /// * `bins` - Binned feature data
    /// * `grad_hess` - (gradient, hessian) pair for each row of the full dataset
    /// * `row_indices` - Which rows belong to this node
    /// * `era_indices` - Era index for each row in the full dataset
    /// * `num_eras` - Total number of unique eras
    ///
    /// # Returns
    /// A 2D vector of histograms: `[num_eras][num_features]`. An era with no rows
    /// in `row_indices` keeps freshly constructed (`Histogram::new()`) histograms.
    ///
    /// # Panics
    /// Panics if a row in `row_indices` is out of bounds for `era_indices` or
    /// `grad_hess`, or if `era_indices[row] >= num_eras` for any such row.
    ///
    /// # Default Implementation
    /// A CPU loop that walks `row_indices` once. For each row, it calls
    /// `bins.get_bin` for every feature and accumulates into that row's era.
    fn build_era_histograms(
        &self,
        bins: &dyn BinStorage,
        grad_hess: &[(f32, f32)],
        row_indices: &[usize],
        era_indices: &[u16],
        num_eras: usize,
    ) -> Vec<Vec<Histogram>> {
        let num_features = bins.num_features();
        let mut result = vec![vec![Histogram::new(); num_features]; num_eras];

        for &row in row_indices {
            let era = usize::from(era_indices[row]);
            let (g, h) = grad_hess[row];

            // An out-of-range era means `num_eras` disagrees with `era_indices`;
            // report that directly instead of a bare index-out-of-bounds panic.
            let Some(era_histograms) = result.get_mut(era) else {
                panic!(
                    "build_era_histograms: row {row} has era index {era}, \
                     but only {num_eras} eras were requested"
                );
            };

            for (feature, hist) in era_histograms.iter_mut().enumerate() {
                hist.accumulate(bins.get_bin(row, feature), g, h);
            }
        }

        result
    }
}