pub mod device;
pub mod full_gpu;
pub mod kernels;
pub mod partition;
pub use device::GpuDevice;
pub use full_gpu::FullGpuTreeBuilder;
pub use kernels::GpuProfileData;
pub use partition::{NodeSplit, PartitionKernel, PartitionResult};
use std::sync::Arc;
use crate::backend::scalar::ScalarBackend;
use crate::backend::traits::{BinStorage, HistogramBackend, SplitCandidate, SplitConfig};
use crate::histogram::Histogram;
use crate::kernel;
use kernels::HistogramKernel;
/// GPU histogram backend built on wgpu (reports itself as `"WGPU"` through
/// `HistogramBackend::name`).
///
/// Construction goes through `WgpuBackend::new`, which returns `None` when no
/// GPU device can be created.
pub struct WgpuBackend {
    /// Shared handle to the GPU device; the kernel holds its own clone of this `Arc`.
    device: Arc<GpuDevice>,
    /// Kernel wrapper that every histogram-building method delegates to.
    kernel: HistogramKernel,
    /// CPU scalar backend; not read anywhere in this file (hence the leading underscore).
    _cpu_fallback: ScalarBackend,
}
impl WgpuBackend {
    /// Creates a backend on a newly acquired GPU device.
    ///
    /// Returns `None` if `GpuDevice::new()` returns `None`; the tests treat that
    /// as "no GPU available" and skip.
    pub fn new() -> Option<Self> {
        let device = Arc::new(GpuDevice::new()?);
        // The kernel keeps its own handle to the same device.
        let kernel = HistogramKernel::new(Arc::clone(&device));
        Some(Self {
            device,
            kernel,
            _cpu_fallback: ScalarBackend::new(),
        })
    }

    /// Name reported by the underlying GPU device.
    pub fn device_name(&self) -> String {
        self.device.name()
    }

    /// The `wgpu::Backend` the device is running on.
    pub fn backend_type(&self) -> wgpu::Backend {
        self.device.backend()
    }

    /// Forwards to `HistogramKernel::subgroups_available`.
    ///
    /// NOTE(review): how this differs from `has_subgroups` is defined in the kernel
    /// module — presumably "device supports subgroups" vs. "subgroups currently in
    /// use"; confirm.
    pub fn subgroups_available(&self) -> bool {
        self.kernel.subgroups_available()
    }

    /// Forwards to `HistogramKernel::has_subgroups`.
    pub fn has_subgroups(&self) -> bool {
        self.kernel.has_subgroups()
    }

    /// Enables or disables subgroup usage in the kernel.
    ///
    /// Takes `&self`, so the kernel must track this setting with interior mutability.
    pub fn set_use_subgroups(&self, enabled: bool) {
        self.kernel.set_use_subgroups(enabled);
    }

    /// Returns the device's `(min, max)` subgroup size.
    pub fn subgroup_size(&self) -> (u32, u32) {
        (self.device.min_subgroup_size, self.device.max_subgroup_size)
    }

    /// Builds per-feature histograms with the kernel's base shader on the 8-bit
    /// row-major bin matrix.
    ///
    /// Unlike `HistogramBackend::build_histograms`, this never takes the 4-bit path.
    pub fn build_histograms_base_shader(
        &self,
        bins: &dyn BinStorage,
        grad_hess: &[(f32, f32)],
        row_indices: &[usize],
    ) -> Vec<Histogram> {
        let num_rows = bins.num_rows();
        let num_features = bins.num_features();
        let bins_row_major = Self::row_major_bins(bins);
        self.kernel.build_histograms_base_shader(
            &bins_row_major,
            grad_hess,
            row_indices,
            num_rows,
            num_features,
        )
    }

    /// Same as the 8-bit histogram build, but also returns the kernel's
    /// `GpuProfileData` for the run.
    pub fn build_histograms_profiled(
        &self,
        bins: &dyn BinStorage,
        grad_hess: &[(f32, f32)],
        row_indices: &[usize],
    ) -> (Vec<Histogram>, GpuProfileData) {
        let num_rows = bins.num_rows();
        let num_features = bins.num_features();
        let bins_row_major = Self::row_major_bins(bins);
        self.kernel.build_histograms_profiled(
            &bins_row_major,
            grad_hess,
            row_indices,
            num_rows,
            num_features,
        )
    }

    /// Builds one histogram set per row-index batch.
    ///
    /// The bin matrix is converted to row-major once and shared by all batches.
    pub fn build_histograms_batched(
        &self,
        bins: &dyn BinStorage,
        grad_hess: &[(f32, f32)],
        batches: &[&[usize]],
    ) -> Vec<Vec<Histogram>> {
        let num_rows = bins.num_rows();
        let num_features = bins.num_features();
        let bins_row_major = Self::row_major_bins(bins);
        self.kernel.build_histograms_batched(
            &bins_row_major,
            grad_hess,
            batches,
            num_rows,
            num_features,
        )
    }

    /// Returns the bin matrix in row-major layout (`row * num_features + feature`).
    ///
    /// Borrows the storage's own row-major buffer when `as_row_major` provides one.
    /// Otherwise it transposes the per-feature columns into a freshly allocated buffer.
    /// Features for which `feature_column` returns `None` stay zero-filled.
    ///
    /// # Panics
    /// Panics if a feature column is shorter than `bins.num_rows()`.
    fn row_major_bins(bins: &dyn BinStorage) -> std::borrow::Cow<'_, [u8]> {
        if let Some(data) = bins.as_row_major() {
            return std::borrow::Cow::Borrowed(data);
        }
        let num_rows = bins.num_rows();
        let num_features = bins.num_features();
        let mut row_major = vec![0u8; num_rows * num_features];
        for f in 0..num_features {
            if let Some(col) = bins.feature_column(f) {
                // Slicing once hoists the bounds check out of the inner loop.
                for (r, &bin) in col[..num_rows].iter().enumerate() {
                    row_major[r * num_features + f] = bin;
                }
            }
        }
        std::borrow::Cow::Owned(row_major)
    }
}

impl HistogramBackend for WgpuBackend {
    fn name(&self) -> &'static str {
        "WGPU"
    }

    fn is_tensor_tile(&self) -> bool {
        true
    }

    /// Builds per-feature histograms for the rows in `row_indices`.
    ///
    /// Prefers the packed 4-bit kernel when the storage supports it and exposes a
    /// 4-bit row-major buffer. Otherwise it uses the 8-bit row-major kernel.
    fn build_histograms(
        &self,
        bins: &dyn BinStorage,
        grad_hess: &[(f32, f32)],
        row_indices: &[usize],
    ) -> Vec<Histogram> {
        let num_rows = bins.num_rows();
        let num_features = bins.num_features();
        if bins.supports_4bit() {
            if let Some(bins_4bit) = bins.as_row_major_4bit() {
                return self.kernel.build_histograms_4bit(
                    bins_4bit,
                    grad_hess,
                    row_indices,
                    num_rows,
                    num_features,
                );
            }
        }
        let bins_row_major = Self::row_major_bins(bins);
        self.kernel.build_histograms(
            &bins_row_major,
            grad_hess,
            row_indices,
            num_rows,
            num_features,
        )
    }

    /// Computes the sibling's histograms as `parent - smaller_child`, feature by
    /// feature, via `Histogram::from_subtraction` (runs on the CPU).
    fn build_histograms_sibling(
        &self,
        parent: &[Histogram],
        smaller_child: &[Histogram],
    ) -> Vec<Histogram> {
        // `zip` would silently drop trailing features on a length mismatch.
        debug_assert_eq!(
            parent.len(),
            smaller_child.len(),
            "parent and child must have one histogram per feature"
        );
        parent
            .iter()
            .zip(smaller_child)
            .map(|(p, s)| Histogram::from_subtraction(p, s))
            .collect()
    }

    /// Returns the highest-gain split across all feature histograms, if any.
    ///
    /// Features with fewer than `2 * min_samples_leaf` samples are skipped. A
    /// candidate is kept only when its gain is strictly greater than `config.min_gain`.
    /// Ties keep the earlier (lower-index) feature, because the comparison is strict.
    fn find_best_split(
        &self,
        histograms: &[Histogram],
        config: &SplitConfig,
    ) -> Option<SplitCandidate> {
        let mut best: Option<SplitCandidate> = None;
        for (feature, hist) in histograms.iter().enumerate() {
            let (total_grad, total_hess, total_count) = hist.totals();
            // Neither side of any split could satisfy min_samples_leaf.
            if total_count < 2 * config.min_samples_leaf {
                continue;
            }
            if let Some(candidate) = kernel::find_best_split(
                &hist.sum_gradients(),
                &hist.sum_hessians(),
                &hist.counts(),
                total_grad,
                total_hess,
                total_count,
                config.lambda,
                config.min_samples_leaf,
                config.min_hessian_leaf,
            ) {
                let split = SplitCandidate {
                    feature,
                    threshold: candidate.bin_threshold,
                    gain: candidate.gain,
                    left_gradient: candidate.left_gradient,
                    left_hessian: candidate.left_hessian,
                    left_count: candidate.left_count,
                    right_gradient: candidate.right_gradient,
                    right_hessian: candidate.right_hessian,
                    right_count: candidate.right_count,
                };
                if split.gain > config.min_gain {
                    let improves = match &best {
                        None => true,
                        Some(b) => split.gain > b.gain,
                    };
                    if improves {
                        best = Some(split);
                    }
                }
            }
        }
        best
    }

    fn build_histograms_batched(
        &self,
        bins: &dyn BinStorage,
        grad_hess: &[(f32, f32)],
        batches: &[&[usize]],
    ) -> Vec<Vec<Histogram>> {
        // Delegate explicitly to the inherent method of the same name.
        WgpuBackend::build_histograms_batched(self, bins, grad_hess, batches)
    }

    /// Builds histograms grouped by era on the 8-bit row-major bin matrix.
    ///
    /// NOTE(review): `era_indices` presumably maps each row to an era id in
    /// `0..num_eras` — confirm against `HistogramKernel::build_era_histograms`.
    fn build_era_histograms(
        &self,
        bins: &dyn BinStorage,
        grad_hess: &[(f32, f32)],
        row_indices: &[usize],
        era_indices: &[u16],
        num_eras: usize,
    ) -> Vec<Vec<Histogram>> {
        let num_rows = bins.num_rows();
        let num_features = bins.num_features();
        let bins_row_major = Self::row_major_bins(bins);
        self.kernel.build_era_histograms(
            &bins_row_major,
            grad_hess,
            row_indices,
            era_indices,
            num_rows,
            num_features,
            num_eras,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::ScalarBackend;
    use crate::dataset::{BinnedDataset, FeatureInfo, FeatureType};

    // Every GPU test returns early (and therefore passes) when `WgpuBackend::new()`
    // yields `None`, so these tests are no-ops on machines without a GPU.

    #[test]
    fn test_wgpu_backend_creation() {
        match WgpuBackend::new() {
            Some(backend) => {
                println!("WGPU backend created: {}", backend.device_name());
                assert_eq!(backend.name(), "WGPU");
                assert!(backend.is_tensor_tile());
            }
            None => {
                println!("No GPU available, skipping WGPU backend test");
            }
        }
    }

    /// Builds a column-major (`f * num_rows + r`) dataset whose bin for row `r` is
    /// `r % 256` in every feature, with targets `0..num_rows`.
    ///
    /// NOTE(review): once `num_rows > 255` the bins take 256 distinct values
    /// (0..=255) while `num_bins` is declared as 255 — confirm whether `num_bins`
    /// counts bins or is the maximum bin index.
    fn create_test_dataset(num_rows: usize, num_features: usize) -> BinnedDataset {
        let mut features = vec![0u8; num_rows * num_features];
        for f in 0..num_features {
            for r in 0..num_rows {
                features[f * num_rows + r] = (r % 256) as u8;
            }
        }
        let targets: Vec<f32> = (0..num_rows).map(|i| i as f32).collect();
        let feature_info: Vec<FeatureInfo> = (0..num_features)
            .map(|i| FeatureInfo {
                name: format!("feature_{}", i),
                feature_type: FeatureType::Numeric,
                num_bins: 255,
                bin_boundaries: vec![],
            })
            .collect();
        BinnedDataset::new(num_rows, features, targets, feature_info)
    }

    #[test]
    fn test_wgpu_matches_scalar_small() {
        let wgpu_backend = match WgpuBackend::new() {
            Some(b) => b,
            None => {
                println!("No GPU available, skipping WGPU vs Scalar comparison test");
                return;
            }
        };
        let scalar_backend = ScalarBackend::new();
        let dataset = create_test_dataset(256, 4);
        let grad_hess: Vec<(f32, f32)> = (0..256).map(|i| (i as f32 * 0.1, 1.0)).collect();
        let row_indices: Vec<usize> = (0..256).collect();
        let wgpu_hists = wgpu_backend.build_histograms(&dataset, &grad_hess, &row_indices);
        let scalar_hists = scalar_backend.build_histograms(&dataset, &grad_hess, &row_indices);
        assert_eq!(wgpu_hists.len(), scalar_hists.len());
        for (f, (wgpu_hist, scalar_hist)) in wgpu_hists.iter().zip(scalar_hists.iter()).enumerate()
        {
            let (wgpu_total_grad, wgpu_total_hess, wgpu_total_count) = wgpu_hist.totals();
            let (scalar_total_grad, scalar_total_hess, scalar_total_count) = scalar_hist.totals();
            assert_eq!(
                wgpu_total_count, scalar_total_count,
                "Feature {}: count mismatch ({} vs {})",
                f, wgpu_total_count, scalar_total_count
            );
            // Relative + absolute tolerance: GPU float accumulation order differs from CPU.
            let grad_diff = (wgpu_total_grad - scalar_total_grad).abs();
            let hess_diff = (wgpu_total_hess - scalar_total_hess).abs();
            let grad_tolerance = scalar_total_grad.abs() * 0.001 + 0.001;
            let hess_tolerance = scalar_total_hess.abs() * 0.001 + 0.001;
            assert!(
                grad_diff < grad_tolerance,
                "Feature {}: gradient mismatch ({} vs {}, diff={})",
                f,
                wgpu_total_grad,
                scalar_total_grad,
                grad_diff
            );
            assert!(
                hess_diff < hess_tolerance,
                "Feature {}: hessian mismatch ({} vs {}, diff={})",
                f,
                wgpu_total_hess,
                scalar_total_hess,
                hess_diff
            );
        }
        println!("WGPU matches Scalar for 256 rows × 4 features");
    }

    #[test]
    fn test_wgpu_matches_scalar_medium() {
        let wgpu_backend = match WgpuBackend::new() {
            Some(b) => b,
            None => {
                println!("No GPU available, skipping WGPU vs Scalar comparison test");
                return;
            }
        };
        let scalar_backend = ScalarBackend::new();
        let num_rows = 10000;
        let num_features = 10;
        let dataset = create_test_dataset(num_rows, num_features);
        let grad_hess: Vec<(f32, f32)> = (0..num_rows)
            .map(|i| {
                // Already f32; no cast needed.
                let g = (i as f32 * 0.01).sin() * 10.0;
                let h = 1.0 + (i % 10) as f32 * 0.1;
                (g, h)
            })
            .collect();
        let row_indices: Vec<usize> = (0..num_rows).collect();
        let wgpu_hists = wgpu_backend.build_histograms(&dataset, &grad_hess, &row_indices);
        let scalar_hists = scalar_backend.build_histograms(&dataset, &grad_hess, &row_indices);
        assert_eq!(wgpu_hists.len(), scalar_hists.len());
        for f in 0..num_features {
            let (wgpu_total_grad, wgpu_total_hess, wgpu_total_count) = wgpu_hists[f].totals();
            let (scalar_total_grad, scalar_total_hess, scalar_total_count) =
                scalar_hists[f].totals();
            assert_eq!(
                wgpu_total_count, scalar_total_count,
                "Feature {}: count mismatch",
                f
            );
            // Looser tolerance than the small test: more rows, more accumulated error.
            let grad_tolerance = scalar_total_grad.abs() * 0.01 + 0.1;
            let hess_tolerance = scalar_total_hess.abs() * 0.01 + 0.1;
            let grad_diff = (wgpu_total_grad - scalar_total_grad).abs();
            let hess_diff = (wgpu_total_hess - scalar_total_hess).abs();
            assert!(
                grad_diff < grad_tolerance,
                "Feature {}: gradient mismatch (GPU={}, CPU={}, diff={})",
                f,
                wgpu_total_grad,
                scalar_total_grad,
                grad_diff
            );
            assert!(
                hess_diff < hess_tolerance,
                "Feature {}: hessian mismatch (GPU={}, CPU={}, diff={})",
                f,
                wgpu_total_hess,
                scalar_total_hess,
                hess_diff
            );
        }
        println!(
            "WGPU matches Scalar for {} rows × {} features",
            num_rows, num_features
        );
    }

    #[test]
    fn test_wgpu_with_row_indices() {
        let wgpu_backend = match WgpuBackend::new() {
            Some(b) => b,
            None => {
                println!("No GPU available, skipping WGPU row indices test");
                return;
            }
        };
        let scalar_backend = ScalarBackend::new();
        let num_rows = 1000;
        let num_features = 5;
        let dataset = create_test_dataset(num_rows, num_features);
        let grad_hess: Vec<(f32, f32)> = (0..num_rows).map(|i| (i as f32, 1.0)).collect();
        // Only even rows: exercises the row_indices subset path.
        let row_indices: Vec<usize> = (0..num_rows).filter(|i| i % 2 == 0).collect();
        let wgpu_hists = wgpu_backend.build_histograms(&dataset, &grad_hess, &row_indices);
        let scalar_hists = scalar_backend.build_histograms(&dataset, &grad_hess, &row_indices);
        for f in 0..num_features {
            let (_, _, wgpu_count) = wgpu_hists[f].totals();
            let (_, _, scalar_count) = scalar_hists[f].totals();
            assert_eq!(
                wgpu_count, scalar_count,
                "Feature {}: count mismatch with row indices",
                f
            );
            assert_eq!(
                wgpu_count, 500,
                "Expected 500 rows (half of 1000), got {}",
                wgpu_count
            );
        }
        println!("WGPU correctly handles row indices");
    }

    /// Like `create_test_dataset`, but bins are `r % 16` and `num_bins` is 16,
    /// so the dataset qualifies for the 4-bit path.
    fn create_test_dataset_4bit(num_rows: usize, num_features: usize) -> BinnedDataset {
        let mut features = vec![0u8; num_rows * num_features];
        for f in 0..num_features {
            for r in 0..num_rows {
                features[f * num_rows + r] = (r % 16) as u8;
            }
        }
        let targets: Vec<f32> = (0..num_rows).map(|i| i as f32).collect();
        let feature_info: Vec<FeatureInfo> = (0..num_features)
            .map(|i| FeatureInfo {
                name: format!("feature_{}", i),
                feature_type: FeatureType::Numeric,
                num_bins: 16,
                bin_boundaries: vec![],
            })
            .collect();
        BinnedDataset::new(num_rows, features, targets, feature_info)
    }

    #[test]
    fn test_wgpu_4bit_matches_scalar() {
        let wgpu_backend = match WgpuBackend::new() {
            Some(b) => b,
            None => {
                println!("No GPU available, skipping WGPU 4-bit test");
                return;
            }
        };
        let scalar_backend = ScalarBackend::new();
        let num_rows = 1000;
        let num_features = 8;
        let dataset = create_test_dataset_4bit(num_rows, num_features);
        // Guard: the test is meaningless if the dataset does not take the 4-bit path.
        assert!(dataset.supports_4bit());
        assert!(dataset.max_bins() <= 16);
        let grad_hess: Vec<(f32, f32)> = (0..num_rows).map(|i| (i as f32 * 0.01, 1.0)).collect();
        let row_indices: Vec<usize> = (0..num_rows).collect();
        let wgpu_hists = wgpu_backend.build_histograms(&dataset, &grad_hess, &row_indices);
        let scalar_hists = scalar_backend.build_histograms(&dataset, &grad_hess, &row_indices);
        assert_eq!(wgpu_hists.len(), scalar_hists.len());
        for (f, (wgpu_hist, scalar_hist)) in wgpu_hists.iter().zip(scalar_hists.iter()).enumerate()
        {
            let (wgpu_total_grad, wgpu_total_hess, wgpu_total_count) = wgpu_hist.totals();
            let (scalar_total_grad, scalar_total_hess, scalar_total_count) = scalar_hist.totals();
            assert_eq!(
                wgpu_total_count, scalar_total_count,
                "Feature {}: count mismatch ({} vs {})",
                f, wgpu_total_count, scalar_total_count
            );
            let grad_diff = (wgpu_total_grad - scalar_total_grad).abs();
            let hess_diff = (wgpu_total_hess - scalar_total_hess).abs();
            let grad_tolerance = scalar_total_grad.abs() * 0.01 + 0.1;
            let hess_tolerance = scalar_total_hess.abs() * 0.01 + 0.1;
            assert!(
                grad_diff < grad_tolerance,
                "Feature {}: gradient mismatch ({} vs {}, diff={})",
                f,
                wgpu_total_grad,
                scalar_total_grad,
                grad_diff
            );
            assert!(
                hess_diff < hess_tolerance,
                "Feature {}: hessian mismatch ({} vs {}, diff={})",
                f,
                wgpu_total_hess,
                scalar_total_hess,
                hess_diff
            );
        }
        println!(
            "WGPU 4-bit matches Scalar for {} rows × {} features (max_bins={})",
            num_rows,
            num_features,
            dataset.max_bins()
        );
    }

    /// An odd feature count presumably leaves a half-filled byte per row in the
    /// packed 4-bit layout — confirm against `as_row_major_4bit`.
    #[test]
    fn test_wgpu_4bit_with_odd_features() {
        let wgpu_backend = match WgpuBackend::new() {
            Some(b) => b,
            None => {
                println!("No GPU available, skipping WGPU 4-bit odd features test");
                return;
            }
        };
        let scalar_backend = ScalarBackend::new();
        let num_rows = 500;
        let num_features = 7;
        let dataset = create_test_dataset_4bit(num_rows, num_features);
        assert!(dataset.supports_4bit());
        let grad_hess: Vec<(f32, f32)> = (0..num_rows)
            .map(|i| ((i as f32).sin() * 5.0, 1.0))
            .collect();
        let row_indices: Vec<usize> = (0..num_rows).collect();
        let wgpu_hists = wgpu_backend.build_histograms(&dataset, &grad_hess, &row_indices);
        let scalar_hists = scalar_backend.build_histograms(&dataset, &grad_hess, &row_indices);
        for f in 0..num_features {
            let (_, _, wgpu_count) = wgpu_hists[f].totals();
            let (_, _, scalar_count) = scalar_hists[f].totals();
            assert_eq!(
                wgpu_count, scalar_count,
                "Feature {}: count mismatch with odd features",
                f
            );
        }
        println!(
            "WGPU 4-bit correctly handles {} features (odd count)",
            num_features
        );
    }
}