// treeboost 0.1.0
//
// High-performance Gradient Boosted Decision Tree engine for large-scale tabular data.
//
// TreeBoost GPU Histogram Building Shader - Register Cache Optimization
//
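// Reduces shared-memory atomic contention by caching recently seen bins in
// registers. Each thread keeps a small cache of (bin, grad sum, hess sum, count)
// entries. Rows that hit a cached bin are accumulated locally. An entry is only
// flushed to the workgroup histogram with atomics when it is evicted
// (round-robin, on a cache miss) or at the end of the pass.
//
// Each thread visits rows 256 apart (row t, t + 256, t + 512, ...). The cache
// therefore pays off when the rows a thread sees keep falling into the same few
// bins, e.g. for sorted or semi-sorted data or low-cardinality features.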
//
// Based on CatBoost's register-resident accumulation pattern.

// Uniform parameters read by the entry points in this shader.
struct Params {
    // Number of rows the dense kernel processes when no row subset is given
    // (i.e. when num_indices == 0).
    num_rows: u32,
    // Number of features. This is the per-row stride, in bytes, of the packed
    // `bins` matrix and the number of 256-bin histograms per output slice.
    num_features: u32,
    // Length of the row subset in `row_indices` used by the dense kernel;
    // 0 means "process rows 0..num_rows".
    num_indices: u32,
    // Number of batches. Read by the batched zeroing kernel to size the
    // region it clears.
    num_batches: u32,
}

// Row range of one batch: entries row_indices[start .. start + count).
struct BatchInfo {
    // Offset of the batch's first entry in `row_indices`.
    start: u32,
    // Number of rows in the batch.
    count: u32,
}

// Resources in bind group 0 used by the entry points below.
// Row/feature/index/batch counts.
@group(0) @binding(0) var<uniform> params: Params;
// Feature bins packed as bytes, four per u32 (byte 0 = least significant). The
// bin for (row, feature) is byte `row * num_features + feature` (see get_bin
// and the kernels' indexing).
@group(0) @binding(1) var<storage, read> bins: array<u32>;
// One u32 per row: signed 16-bit gradient in the low half, signed 16-bit
// hessian in the high half (see unpack_grad / unpack_hess).
// NOTE(review): presumably quantized fixed-point values; the scale is not
// visible in this shader.
@group(0) @binding(2) var<storage, read> grad_hess: array<u32>;
// Row ids to process. The dense kernel uses it as an optional subset (when
// params.num_indices > 0). The batched kernel reads each batch's range from it.
@group(0) @binding(3) var<storage, read> row_indices: array<u32>;
// Output histograms with 256 bins per feature (per batch slice in the batched
// kernels): gradient sums, hessian sums, and row counts.
@group(0) @binding(4) var<storage, read_write> hist_grad: array<atomic<i32>>;
@group(0) @binding(5) var<storage, read_write> hist_hess: array<atomic<i32>>;
@group(0) @binding(6) var<storage, read_write> hist_counts: array<atomic<u32>>;
// Per-batch row ranges into row_indices.
// NOTE(review): bindings 7-10 are not declared here. They are presumably used
// by other shaders sharing this bind group layout; confirm against the host.
@group(0) @binding(11) var<storage, read> batch_info: array<BatchInfo>;

// Workgroup-shared histogram for the feature being processed, one entry per
// bin (256). Each of the 256 threads zeroes and writes back one entry.
var<workgroup> local_grad: array<atomic<i32>, 256>;
var<workgroup> local_hess: array<atomic<i32>, 256>;
var<workgroup> local_counts: array<atomic<u32>, 256>;

// Returns byte `byte_idx` (0 = least significant) of a u32 that packs four
// 8-bit bin values. The result is always in 0..=255.
fn get_bin(packed: u32, byte_idx: u32) -> u32 {
    let shift = byte_idx << 3u;
    return (packed >> shift) % 256u;
}

// Decodes the signed 16-bit gradient held in the low half of `packed`,
// sign-extended to i32.
fn unpack_grad(packed: u32) -> i32 {
    // Move the low half-word into the high half and reinterpret it as signed.
    // Shifting back down with `>>` on an i32 is arithmetic, so the sign bit is
    // replicated into the upper 16 bits.
    return bitcast<i32>(packed << 16u) >> 16u;
}

// Decodes the signed 16-bit hessian held in the high half of `packed`,
// sign-extended to i32.
fn unpack_hess(packed: u32) -> i32 {
    // Reinterpret the whole word as signed; the arithmetic right shift drops
    // the gradient half and sign-extends the hessian half.
    return bitcast<i32>(packed) >> 16u;
}

// Number of per-thread register-cache slots. More slots raise the hit rate at
// the cost of register pressure. NOTE(review): the choice of 4 cites CatBoost
// using 4-8 slots; that claim is not verified here.
const CACHE_SIZE: u32 = 4u;
// Tag marking an empty cache slot. get_bin() masks to 8 bits, so real bins are
// 0..=255 and can never equal this value.
const INVALID_BIN: u32 = 0xFFFFFFFFu;

// Dense (single-slice) histogram kernel with a per-thread register cache.
//
// Dispatch: one workgroup per feature along x (wg_id.x = feature index).
// Rows processed: row_indices[0 .. num_indices) when params.num_indices > 0,
// otherwise rows 0 .. num_rows. Thread t handles rows t, t + 256, t + 512, ...
// Output: this feature's 256 bins of hist_grad / hist_hess / hist_counts are
// incremented with atomicAdd. The buffers must therefore be cleared beforehand
// (e.g. with zero_histograms_register_cache) unless accumulation across
// dispatches is intended.
@compute @workgroup_size(256, 1, 1)
fn histogram_dense_register_cache(
    @builtin(workgroup_id) wg_id: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    let feature = wg_id.x;
    let thread_id = lid.x;
    let num_threads = 256u;

    // A workgroup beyond the last feature would read another row's bin bytes
    // and write past the histogram range. Both operands are workgroup-uniform,
    // so this early return keeps the barriers below in uniform control flow.
    if feature >= params.num_features {
        return;
    }

    // Thread t zeroes workgroup bin t (workgroup size == bin count == 256).
    atomicStore(&local_grad[thread_id], 0i);
    atomicStore(&local_hess[thread_id], 0i);
    atomicStore(&local_counts[thread_id], 0u);

    workgroupBarrier();

    // Per-thread register cache of (bin, grad sum, hess sum, row count).
    // Function-scope variables are zero-initialised in WGSL, so only the bin
    // tags need the explicit "empty" marker. Sizing by CACHE_SIZE keeps the
    // arrays, the lookup and the round-robin eviction in sync with the constant.
    var cache_bins: array<u32, CACHE_SIZE>;
    var cache_grads: array<i32, CACHE_SIZE>;
    var cache_hess: array<i32, CACHE_SIZE>;
    var cache_counts: array<u32, CACHE_SIZE>;
    for (var slot = 0u; slot < CACHE_SIZE; slot++) {
        cache_bins[slot] = INVALID_BIN;
    }

    // Next slot to evict on a cache miss (round-robin).
    var next_slot: u32 = 0u;

    let use_indices = params.num_indices > 0u;
    let total_rows = select(params.num_rows, params.num_indices, use_indices);

    for (var i = thread_id; i < total_rows; i += num_threads) {
        // Use an explicit branch rather than select(). select() evaluates both
        // operands, so it would read row_indices[i] even on a full-dataset pass,
        // where that binding may be only a short placeholder.
        var row = i;
        if use_indices {
            row = row_indices[i];
        }

        // `bins` holds one byte per (row, feature), four bytes per u32.
        let bin_offset = row * params.num_features + feature;
        let bin = get_bin(bins[bin_offset / 4u], bin_offset % 4u);

        let packed_gh = grad_hess[row];
        let grad = unpack_grad(packed_gh);
        let hess = unpack_hess(packed_gh);

        // Cache lookup. A bin occupies at most one slot (it is only inserted on
        // a miss), so stopping at the first match is exact.
        var hit = false;
        for (var slot = 0u; slot < CACHE_SIZE; slot++) {
            if cache_bins[slot] == bin {
                cache_grads[slot] += grad;
                cache_hess[slot] += hess;
                cache_counts[slot] += 1u;
                hit = true;
                break;
            }
        }

        if !hit {
            // Miss: flush the round-robin victim (if occupied) to the workgroup
            // histogram, then reuse its slot for the new bin.
            let evict_slot = next_slot;
            next_slot = (next_slot + 1u) % CACHE_SIZE;

            let evict_bin = cache_bins[evict_slot];
            if evict_bin != INVALID_BIN {
                atomicAdd(&local_grad[evict_bin], cache_grads[evict_slot]);
                atomicAdd(&local_hess[evict_bin], cache_hess[evict_slot]);
                atomicAdd(&local_counts[evict_bin], cache_counts[evict_slot]);
            }

            cache_bins[evict_slot] = bin;
            cache_grads[evict_slot] = grad;
            cache_hess[evict_slot] = hess;
            cache_counts[evict_slot] = 1u;
        }
    }

    // Flush whatever is still cached.
    for (var slot = 0u; slot < CACHE_SIZE; slot++) {
        let bin = cache_bins[slot];
        if bin != INVALID_BIN {
            atomicAdd(&local_grad[bin], cache_grads[slot]);
            atomicAdd(&local_hess[bin], cache_hess[slot]);
            atomicAdd(&local_counts[bin], cache_counts[slot]);
        }
    }

    workgroupBarrier();

    // Thread t adds workgroup bin t into the global histogram. Empty bins are
    // skipped: no row touched them, so their grad/hess sums are still zero and
    // adding them would be a no-op.
    let global_offset = feature * 256u + thread_id;
    let local_count = atomicLoad(&local_counts[thread_id]);

    if local_count > 0u {
        atomicAdd(&hist_grad[global_offset], atomicLoad(&local_grad[thread_id]));
        atomicAdd(&hist_hess[global_offset], atomicLoad(&local_hess[thread_id]));
        atomicAdd(&hist_counts[global_offset], local_count);
    }
}

// Clears the single-slice histogram buffers (params.num_features * 256 bins),
// one invocation per bin.
@compute @workgroup_size(256, 1, 1)
fn zero_histograms_register_cache(
    @builtin(global_invocation_id) gid: vec3<u32>,
) {
    let cell = gid.x;

    // Invocations past the last bin have nothing to clear (e.g. when the
    // dispatch is rounded up to a whole number of workgroups).
    if cell >= params.num_features * 256u {
        return;
    }

    atomicStore(&hist_grad[cell], 0i);
    atomicStore(&hist_hess[cell], 0i);
    atomicStore(&hist_counts[cell], 0u);
}

// Batched histogram kernel with a per-thread register cache.
//
// Dispatch: x = feature, y = batch. Workgroup (feature, batch) processes rows
// row_indices[start .. start + count), taken from batch_info[batch]. It
// overwrites this feature's 256 bins in output slice `batch` (each slice holds
// num_features * 256 bins) of hist_grad / hist_hess / hist_counts.
@compute @workgroup_size(256, 1, 1)
fn histogram_batched_register_cache(
    @builtin(workgroup_id) wg_id: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    let feature = wg_id.x;
    let batch = wg_id.y;
    let thread_id = lid.x;
    let num_threads = 256u;

    // A workgroup beyond the last feature would read another row's bin bytes
    // and overwrite bins of the next batch slice. Both operands are
    // workgroup-uniform, so this early return keeps the barriers below in
    // uniform control flow.
    if feature >= params.num_features {
        return;
    }

    let batch_start = batch_info[batch].start;
    let batch_count = batch_info[batch].count;

    // Thread t zeroes workgroup bin t (workgroup size == bin count == 256).
    atomicStore(&local_grad[thread_id], 0i);
    atomicStore(&local_hess[thread_id], 0i);
    atomicStore(&local_counts[thread_id], 0u);

    workgroupBarrier();

    // Per-thread register cache of (bin, grad sum, hess sum, row count).
    // Function-scope variables are zero-initialised in WGSL, so only the bin
    // tags need the explicit "empty" marker. Sizing by CACHE_SIZE keeps the
    // arrays, the lookup and the round-robin eviction in sync with the constant.
    var cache_bins: array<u32, CACHE_SIZE>;
    var cache_grads: array<i32, CACHE_SIZE>;
    var cache_hess: array<i32, CACHE_SIZE>;
    var cache_counts: array<u32, CACHE_SIZE>;
    for (var slot = 0u; slot < CACHE_SIZE; slot++) {
        cache_bins[slot] = INVALID_BIN;
    }

    // Next slot to evict on a cache miss (round-robin).
    var next_slot: u32 = 0u;

    for (var i = thread_id; i < batch_count; i += num_threads) {
        let row = row_indices[batch_start + i];

        // `bins` holds one byte per (row, feature), four bytes per u32.
        let bin_offset = row * params.num_features + feature;
        let bin = get_bin(bins[bin_offset / 4u], bin_offset % 4u);

        let packed_gh = grad_hess[row];
        let grad = unpack_grad(packed_gh);
        let hess = unpack_hess(packed_gh);

        // Cache lookup. A bin occupies at most one slot (it is only inserted on
        // a miss), so stopping at the first match is exact.
        var hit = false;
        for (var slot = 0u; slot < CACHE_SIZE; slot++) {
            if cache_bins[slot] == bin {
                cache_grads[slot] += grad;
                cache_hess[slot] += hess;
                cache_counts[slot] += 1u;
                hit = true;
                break;
            }
        }

        if !hit {
            // Miss: flush the round-robin victim (if occupied) to the workgroup
            // histogram, then reuse its slot for the new bin.
            let evict_slot = next_slot;
            next_slot = (next_slot + 1u) % CACHE_SIZE;

            let evict_bin = cache_bins[evict_slot];
            if evict_bin != INVALID_BIN {
                atomicAdd(&local_grad[evict_bin], cache_grads[evict_slot]);
                atomicAdd(&local_hess[evict_bin], cache_hess[evict_slot]);
                atomicAdd(&local_counts[evict_bin], cache_counts[evict_slot]);
            }

            cache_bins[evict_slot] = bin;
            cache_grads[evict_slot] = grad;
            cache_hess[evict_slot] = hess;
            cache_counts[evict_slot] = 1u;
        }
    }

    // Flush whatever is still cached.
    for (var slot = 0u; slot < CACHE_SIZE; slot++) {
        let bin = cache_bins[slot];
        if bin != INVALID_BIN {
            atomicAdd(&local_grad[bin], cache_grads[slot]);
            atomicAdd(&local_hess[bin], cache_hess[slot]);
            atomicAdd(&local_counts[bin], cache_counts[slot]);
        }
    }

    workgroupBarrier();

    // Thread t stores workgroup bin t into its slot of the output slice. Every
    // bin is written, including empty ones (whose sums and count are zero).
    // The slice therefore never keeps stale values from an earlier dispatch,
    // even if the buffers were not cleared first.
    let hist_stride = params.num_features * 256u;
    let global_offset = batch * hist_stride + feature * 256u + thread_id;

    atomicStore(&hist_grad[global_offset], atomicLoad(&local_grad[thread_id]));
    atomicStore(&hist_hess[global_offset], atomicLoad(&local_hess[thread_id]));
    atomicStore(&hist_counts[global_offset], atomicLoad(&local_counts[thread_id]));
}

// Clears every batch slice of the histogram buffers
// (params.num_batches * params.num_features * 256 bins), one invocation per bin.
@compute @workgroup_size(256, 1, 1)
fn zero_histograms_batched_register_cache(
    @builtin(global_invocation_id) gid: vec3<u32>,
) {
    let cell = gid.x;
    let bins_per_slice = params.num_features * 256u;

    // Invocations past the last bin have nothing to clear (e.g. when the
    // dispatch is rounded up to a whole number of workgroups).
    if cell >= params.num_batches * bins_per_slice {
        return;
    }

    atomicStore(&hist_grad[cell], 0i);
    atomicStore(&hist_hess[cell], 0i);
    atomicStore(&hist_counts[cell], 0u);
}