// treeboost 0.1.0
//
// High-performance Gradient Boosted Decision Tree engine for large-scale tabular data.
// TreeBoost GPU Histogram Building Shader - Subgroup Optimized
//
// Computes gradient/hessian histograms for GBDT training.
// One workgroup per feature, 256 threads per workgroup.
//
// Optimization: Uses subgroup operations to reduce atomic contention.
// When multiple threads in a subgroup write to the SAME bin, we use
// subgroupAdd to combine their values before doing a single atomic write.
//
// This helps when data has locality - adjacent rows often have the same
// bin value for a feature, so threads processing adjacent rows hit the same bin.
//
// Requires: wgpu::Features::SUBGROUP
//
// Fixed-point format (same as base shader):
// - Gradients/hessians packed as i16 pairs in u32 (2x bandwidth reduction)
// - Scale factor: 2^10 = 1024
//
// Note: The `enable subgroups;` directive is not used because naga (wgpu's WGSL
// compiler) hasn't fully implemented the WGSL subgroups extension yet.
// The subgroup builtins work directly when the SUBGROUP feature is requested.

// Uniform buffer with histogram parameters.
//
// num_rows:     number of rows processed by histogram_dense_subgroups when
//               num_indices == 0 (rows 0..num_rows-1 are visited directly).
// num_features: features per row; also the row stride, in bytes, of the
//               packed `bins` buffer (bin offset = row * num_features + feature).
// num_indices:  when > 0, histogram_dense_subgroups visits row_indices[0..num_indices)
//               instead of all rows.
// num_batches:  number of per-batch histogram slices (each num_features * 256
//               entries) cleared by zero_histograms_batched. Presumably equals
//               the length of `batch_info` — confirm on the host side.
//
// NOTE(review): per-bin sums of i16 fixed-point values are accumulated in
// atomic<i32>; a single bin overflows after roughly 65k rows of max-magnitude
// values — confirm this is acceptable for the dataset sizes in use.
struct Params {
    num_rows: u32,
    num_features: u32,
    num_indices: u32,
    num_batches: u32,
}

// One batch of rows for histogram_batched_subgroups: the batch covers
// row_indices[start .. start + count).
struct BatchInfo {
    start: u32,
    count: u32,
}

// Shape/size parameters shared by every entry point in this shader.
@group(0) @binding(0) var<uniform> params: Params;
// 8-bit bin indices packed four per u32 (byte 0 = least significant),
// addressed by byte offset row * num_features + feature.
@group(0) @binding(1) var<storage, read> bins: array<u32>;
// One u32 per row: fixed-point gradient in the low 16 bits, hessian in the
// high 16 bits, each a signed i16.
@group(0) @binding(2) var<storage, read> grad_hess: array<u32>;
// Row indices to visit (see Params.num_indices and BatchInfo).
@group(0) @binding(3) var<storage, read> row_indices: array<u32>;

// Output histograms: 256 bins per feature (and per batch slice in the
// batched kernels).
@group(0) @binding(4) var<storage, read_write> hist_grad: array<atomic<i32>>;
@group(0) @binding(5) var<storage, read_write> hist_hess: array<atomic<i32>>;
@group(0) @binding(6) var<storage, read_write> hist_counts: array<atomic<u32>>;

// Only used by histogram_batched_subgroups. NOTE(review): bindings 7-10 are
// not declared in this shader; presumably they are reserved to match a shared
// pipeline layout — confirm before renumbering.
@group(0) @binding(11) var<storage, read> batch_info: array<BatchInfo>;

// Workgroup-shared 256-bin partial histogram for the feature (and batch)
// handled by the current workgroup; flushed to global memory at the end.
var<workgroup> local_grad: array<atomic<i32>, 256>;
var<workgroup> local_hess: array<atomic<i32>, 256>;
var<workgroup> local_counts: array<atomic<u32>, 256>;

// Returns byte `byte_idx` (0 = least significant) of a word holding four
// packed 8-bit bin indices.
fn get_bin(packed: u32, byte_idx: u32) -> u32 {
    let shift = byte_idx << 3u;
    let shifted = packed >> shift;
    return shifted & 255u;
}

// Decodes the fixed-point gradient: the low 16 bits of `packed`,
// sign-extended from i16 to i32.
fn unpack_grad(packed: u32) -> i32 {
    // Move the low half to the top, then an arithmetic right shift on i32
    // brings it back down while replicating bit 15 into the upper bits.
    let low_on_top = bitcast<i32>(packed << 16u);
    return low_on_top >> 16u;
}

// Decodes the fixed-point hessian: the high 16 bits of `packed`,
// sign-extended from i16 to i32.
fn unpack_hess(packed: u32) -> i32 {
    // Reinterpreting as i32 makes `>>` an arithmetic shift, so bit 31
    // (the i16 sign bit) fills the vacated upper bits.
    let as_signed = bitcast<i32>(packed);
    return as_signed >> 16u;
}

// Subgroup-optimized histogram building (all rows, or an index gather).
//
// Dispatch: one workgroup per feature (workgroup_id.x is the feature index).
// The 256 threads visit logical rows i = thread_id, thread_id + 256, ... and
// accumulate into a 256-bin workgroup-shared histogram, which is then added
// into the global histogram slice for that feature (offset feature * 256).
//
// Row selection: when params.num_indices > 0, logical row i maps to
// row_indices[i] and num_indices rows are visited; otherwise rows
// 0..num_rows-1 are visited directly.
//
// Per row, within each subgroup:
// 1. Each thread loads its bin and unpacks its gradient/hessian.
// 2. subgroupBroadcastFirst reads the "pivot" bin and lane id of the first
//    active invocation.
// 3. Threads whose bin equals the pivot combine their values with subgroupAdd,
//    and the pivot lane alone issues one atomic per counter for the group.
// 4. Every other thread issues its own atomics directly. There is a single
//    reduction pass; non-matching threads are not regrouped.
//
// Best case: all active threads hit the same bin -> 1 atomic per counter
// instead of one per thread.
// Worst case: all bins differ -> the same number of atomics as a plain
// per-thread histogram, plus the broadcast/reduction overhead.
//
// The global write uses atomicAdd, so results are added to whatever hist_*
// already holds; presumably callers clear it first with zero_histograms.
@compute @workgroup_size(256, 1, 1)
fn histogram_dense_subgroups(
    @builtin(workgroup_id) wg_id: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(subgroup_invocation_id) sg_id: u32,
) {
    let feature = wg_id.x;
    let thread_id = lid.x;
    let num_threads = 256u;

    // Clear this workgroup's shared histogram (one bin per thread).
    atomicStore(&local_grad[thread_id], 0i);
    atomicStore(&local_hess[thread_id], 0i);
    atomicStore(&local_counts[thread_id], 0u);

    workgroupBarrier();

    let use_indices = params.num_indices > 0u;
    let total_rows = select(params.num_rows, params.num_indices, use_indices);

    for (var i = thread_id; i < total_rows; i += num_threads) {
        // Branch instead of select(): select() evaluates both operands, which
        // would read row_indices[i] even when no index buffer is in use.
        // `use_indices` comes from the uniform buffer, so this is uniform.
        var row = i;
        if use_indices {
            row = row_indices[i];
        }

        // Bin bytes are packed four per u32 at byte offset row * num_features + feature.
        let bin_offset = row * params.num_features + feature;
        let bin = get_bin(bins[bin_offset / 4u], bin_offset % 4u);

        let packed_gh = grad_hess[row];
        let grad = unpack_grad(packed_gh);
        let hess = unpack_hess(packed_gh);

        // Pivot = bin and lane id of the first active invocation. Both
        // broadcasts read from that same invocation, so `leader_lane` is
        // always a member of the set whose bin matches `first_bin`.
        let first_bin = subgroupBroadcastFirst(bin);
        let leader_lane = subgroupBroadcastFirst(sg_id);

        if bin == first_bin {
            // Sums over the invocations active in this branch, i.e. exactly
            // those whose bin matches the pivot.
            let sum_grad = subgroupAdd(grad);
            let sum_hess = subgroupAdd(hess);
            let sum_count = subgroupAdd(1u);

            // The writer is taken from the matching set itself. Testing
            // `sg_id == 0u` would silently drop the whole group's
            // contribution whenever lane 0 is inactive or is not the first
            // active invocation (WGSL does not fix how local invocations map
            // onto subgroup lanes).
            if sg_id == leader_lane {
                atomicAdd(&local_grad[first_bin], sum_grad);
                atomicAdd(&local_hess[first_bin], sum_hess);
                atomicAdd(&local_counts[first_bin], sum_count);
            }
        } else {
            // Not part of the pivot group: add this row's values directly.
            atomicAdd(&local_grad[bin], grad);
            atomicAdd(&local_hess[bin], hess);
            atomicAdd(&local_counts[bin], 1u);
        }
    }

    workgroupBarrier();

    // Each thread flushes one bin of the shared histogram to global memory.
    let global_offset = feature * 256u + thread_id;
    let local_count = atomicLoad(&local_counts[thread_id]);

    // Empty bins carry zero sums, so adding them would be a no-op.
    if local_count > 0u {
        atomicAdd(&hist_grad[global_offset], atomicLoad(&local_grad[thread_id]));
        atomicAdd(&hist_hess[global_offset], atomicLoad(&local_hess[thread_id]));
        atomicAdd(&hist_counts[global_offset], local_count);
    }
}

// Clears hist_grad, hist_hess and hist_counts for every (feature, bin) slot,
// i.e. the first num_features * 256 entries (same as base shader).
// Only global_invocation_id.x is used; invocations past the end do nothing.
@compute @workgroup_size(256, 1, 1)
fn zero_histograms(
    @builtin(global_invocation_id) gid: vec3<u32>,
) {
    let slot = gid.x;
    let slot_count = params.num_features * 256u;

    if slot >= slot_count {
        return;
    }

    atomicStore(&hist_grad[slot], 0i);
    atomicStore(&hist_hess[slot], 0i);
    atomicStore(&hist_counts[slot], 0u);
}

// Batched histogram building with subgroups.
//
// Dispatch: workgroup_id.x selects the feature and workgroup_id.y the batch.
// Batch b visits rows row_indices[batch_info[b].start + i] for
// i in 0..batch_info[b].count and writes its own histogram slice at offset
// b * num_features * 256 + feature * 256.
//
// Per-row accumulation is a single subgroup pass. Invocations whose bin
// matches the first active invocation's bin are combined with subgroupAdd and
// flushed by one atomic per counter. Every other invocation adds its row with
// its own atomics.
//
// The slice for this (batch, feature) is written with atomicStore. Empty bins
// are skipped, so the slice must already be zero (presumably via
// zero_histograms_batched).
@compute @workgroup_size(256, 1, 1)
fn histogram_batched_subgroups(
    @builtin(workgroup_id) wg_id: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(subgroup_invocation_id) sg_id: u32,
) {
    let feature = wg_id.x;
    let batch = wg_id.y;
    let thread_id = lid.x;
    let num_threads = 256u;

    let batch_start = batch_info[batch].start;
    let batch_count = batch_info[batch].count;

    // Clear this workgroup's shared histogram (one bin per thread).
    atomicStore(&local_grad[thread_id], 0i);
    atomicStore(&local_hess[thread_id], 0i);
    atomicStore(&local_counts[thread_id], 0u);

    workgroupBarrier();

    for (var i = thread_id; i < batch_count; i += num_threads) {
        let row = row_indices[batch_start + i];

        // Bin bytes are packed four per u32 at byte offset row * num_features + feature.
        let bin_offset = row * params.num_features + feature;
        let bin = get_bin(bins[bin_offset / 4u], bin_offset % 4u);

        let packed_gh = grad_hess[row];
        let grad = unpack_grad(packed_gh);
        let hess = unpack_hess(packed_gh);

        // Pivot = bin and lane id of the first active invocation. Both
        // broadcasts read from that same invocation, so `leader_lane` is
        // always a member of the set whose bin matches `first_bin`.
        let first_bin = subgroupBroadcastFirst(bin);
        let leader_lane = subgroupBroadcastFirst(sg_id);

        if bin == first_bin {
            // Sums over the invocations active in this branch, i.e. exactly
            // those whose bin matches the pivot.
            let sum_grad = subgroupAdd(grad);
            let sum_hess = subgroupAdd(hess);
            let sum_count = subgroupAdd(1u);

            // The writer is taken from the matching set itself. Testing
            // `sg_id == 0u` would silently drop the whole group's
            // contribution whenever lane 0 is inactive or is not the first
            // active invocation (WGSL does not fix how local invocations map
            // onto subgroup lanes).
            if sg_id == leader_lane {
                atomicAdd(&local_grad[first_bin], sum_grad);
                atomicAdd(&local_hess[first_bin], sum_hess);
                atomicAdd(&local_counts[first_bin], sum_count);
            }
        } else {
            // Not part of the pivot group: add this row's values directly.
            atomicAdd(&local_grad[bin], grad);
            atomicAdd(&local_hess[bin], hess);
            atomicAdd(&local_counts[bin], 1u);
        }
    }

    workgroupBarrier();

    // Each thread writes one bin of this batch's slice for this feature.
    let hist_stride = params.num_features * 256u;
    let global_offset = batch * hist_stride + feature * 256u + thread_id;
    let local_count = atomicLoad(&local_counts[thread_id]);

    if local_count > 0u {
        atomicStore(&hist_grad[global_offset], atomicLoad(&local_grad[thread_id]));
        atomicStore(&hist_hess[global_offset], atomicLoad(&local_hess[thread_id]));
        atomicStore(&hist_counts[global_offset], local_count);
    }
}

// Clears hist_grad, hist_hess and hist_counts across every batch slice:
// the first num_batches * num_features * 256 entries.
// Only global_invocation_id.x is used; invocations past the end do nothing.
@compute @workgroup_size(256, 1, 1)
fn zero_histograms_batched(
    @builtin(global_invocation_id) gid: vec3<u32>,
) {
    let slot = gid.x;
    let slots_per_batch = params.num_features * 256u;
    let slot_count = params.num_batches * slots_per_batch;

    if slot >= slot_count {
        return;
    }

    atomicStore(&hist_grad[slot], 0i);
    atomicStore(&hist_hess[slot], 0i);
    atomicStore(&hist_counts[slot], 0u);
}