mod config;
pub mod scalar;
mod traits;
#[cfg(feature = "gpu")]
pub mod wgpu;
#[cfg(feature = "cuda")]
pub mod cuda;
pub use config::{BackendConfig, BackendPreset, BackendType, GpuMode};
pub use scalar::ScalarBackend;
pub use traits::{BinStorage, HistogramBackend, SparseColumn, SplitCandidate, SplitConfig};
#[cfg(feature = "gpu")]
pub use wgpu::WgpuBackend;
#[cfg(feature = "cuda")]
pub use cuda::CudaBackend;
/// Chooses a [`HistogramBackend`] implementation according to a [`BackendConfig`].
///
/// The selector itself is cheap; backends are only constructed when
/// [`BackendSelector::select`] is called.
#[derive(Debug)]
pub struct BackendSelector {
    /// Selection policy: preferred backend, minimum row count for non-scalar
    /// backends, subgroup usage for WGPU, and whether to fall back to scalar.
    config: BackendConfig,
}
impl BackendSelector {
    /// Creates a selector that uses [`BackendConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(BackendConfig::default())
    }

    /// Creates a selector that uses the given configuration.
    #[must_use]
    pub fn with_config(config: BackendConfig) -> Self {
        Self { config }
    }

    /// Returns a histogram backend for a dataset with `num_rows` rows.
    ///
    /// Datasets with fewer than `config.tensor_tile_min_rows` rows always get a
    /// [`ScalarBackend`], whatever `config.preferred` requests. Otherwise:
    ///
    /// * `Auto` tries CUDA, then WGPU, and finally returns [`ScalarBackend`]
    ///   (this path never panics).
    /// * `Scalar` always returns [`ScalarBackend`].
    /// * `Wgpu` / `Cuda` return that backend when its cargo feature (`gpu` /
    ///   `cuda`) is enabled and its constructor succeeds.
    /// * `Avx512`, `Sve2`, `Rocm` and `Metal` are not implemented yet.
    ///
    /// A requested backend that cannot be provided is replaced by
    /// [`ScalarBackend`] when `config.fallback_to_scalar` is `true`.
    ///
    /// # Panics
    ///
    /// Panics when the requested backend cannot be provided and
    /// `config.fallback_to_scalar` is `false`. For `Avx512`, `Sve2`, `Rocm` and
    /// `Metal` this is currently always the case.
    #[must_use]
    pub fn select(&self, num_rows: usize) -> Box<dyn HistogramBackend> {
        // Below the configured threshold the scalar path is used unconditionally,
        // even when a specific accelerated backend was requested.
        if num_rows < self.config.tensor_tile_min_rows {
            return Box::new(ScalarBackend::new());
        }
        match self.config.preferred {
            BackendType::Auto => self.detect_best(),
            BackendType::Scalar => Box::new(ScalarBackend::new()),
            BackendType::Wgpu => self.probe_wgpu().unwrap_or_else(|| {
                self.scalar_or_panic(
                    "WGPU backend requested but not available (no GPU or 'gpu' feature disabled)",
                )
            }),
            BackendType::Cuda => self.probe_cuda().unwrap_or_else(|| {
                self.scalar_or_panic(
                    "CUDA backend requested but not available (no NVIDIA GPU or 'cuda' feature disabled)",
                )
            }),
            // The remaining backends have no implementation yet; they can only
            // fall back to scalar (or panic when fallback is disabled).
            BackendType::Avx512 => {
                self.scalar_or_panic("AVX-512 tensor-tile backend not yet implemented")
            }
            BackendType::Sve2 => {
                self.scalar_or_panic("SVE2 tensor-tile backend not yet implemented")
            }
            BackendType::Rocm => self.scalar_or_panic("ROCm backend not yet implemented"),
            BackendType::Metal => self.scalar_or_panic("Metal backend not yet implemented"),
        }
    }

    /// Auto-detection: CUDA first, then WGPU, otherwise [`ScalarBackend`].
    ///
    /// WGPU is only probed if CUDA is unavailable (`or_else` is lazy).
    fn detect_best(&self) -> Box<dyn HistogramBackend> {
        self.probe_cuda()
            .or_else(|| self.probe_wgpu())
            .unwrap_or_else(|| Box::new(ScalarBackend::new()))
    }

    /// Tries to construct the WGPU backend with the configured subgroup setting.
    ///
    /// Returns `None` when the `gpu` feature is disabled or
    /// `WgpuBackend::new()` yields `None`.
    fn probe_wgpu(&self) -> Option<Box<dyn HistogramBackend>> {
        #[cfg(feature = "gpu")]
        {
            if let Some(backend) = wgpu::WgpuBackend::new() {
                // NOTE(review): `backend` is not bound `mut`, so this relies on
                // `set_use_subgroups` taking `&self` — confirm in the wgpu module.
                backend.set_use_subgroups(self.config.use_gpu_subgroups);
                return Some(Box::new(backend));
            }
        }
        None
    }

    /// Tries to construct the CUDA backend.
    ///
    /// Returns `None` when the `cuda` feature is disabled or
    /// `CudaBackend::new()` yields `None`.
    fn probe_cuda(&self) -> Option<Box<dyn HistogramBackend>> {
        #[cfg(feature = "cuda")]
        {
            if let Some(backend) = cuda::CudaBackend::new() {
                return Some(Box::new(backend));
            }
        }
        None
    }

    /// Returns [`ScalarBackend`] if `config.fallback_to_scalar` is set, and
    /// otherwise panics with `message`.
    ///
    /// `#[track_caller]` makes the panic location point at the requesting
    /// call site rather than at this helper.
    #[track_caller]
    fn scalar_or_panic(&self, message: &str) -> Box<dyn HistogramBackend> {
        if self.config.fallback_to_scalar {
            Box::new(ScalarBackend::new())
        } else {
            // Explicit "{}" keeps the message verbatim on both edition 2018 and 2021.
            panic!("{}", message)
        }
    }

    /// Returns the name of the backend [`select`](Self::select) would choose for
    /// `num_rows` rows.
    ///
    /// This calls `select` and drops the result, so the construction cost of the
    /// chosen backend is paid just to read its name. Cache the backend instead
    /// if you also need to use it.
    ///
    /// # Panics
    ///
    /// Same conditions as [`select`](Self::select).
    #[must_use]
    pub fn backend_name(&self, num_rows: usize) -> &'static str {
        self.select(num_rows).name()
    }
}
impl Default for BackendSelector {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `backend` is one of the scalar implementations.
    fn assert_scalar(backend: &dyn HistogramBackend) {
        assert!(backend.name().starts_with("Scalar"));
    }

    /// Asserts that `backend` is CUDA, WGPU, or a scalar fallback. Which one is
    /// returned depends on the enabled features and the hardware.
    fn assert_gpu_or_scalar(backend: &dyn HistogramBackend) {
        let name = backend.name();
        let is_gpu = matches!(name, "CUDA" | "WGPU");
        assert!(
            is_gpu || name.starts_with("Scalar"),
            "Expected CUDA, WGPU, or Scalar, got: {}",
            name
        );
    }

    #[test]
    fn test_backend_selector_default() {
        assert_scalar(&*BackendSelector::new().select(1000));
    }

    #[test]
    fn test_backend_selector_small_dataset() {
        assert_scalar(&*BackendSelector::new().select(100));
    }

    #[test]
    fn test_backend_selector_large_dataset() {
        assert_gpu_or_scalar(&*BackendSelector::new().select(100_000));
    }

    #[test]
    fn test_backend_config_scalar() {
        let scalar_only = BackendSelector::with_config(BackendConfig::scalar());
        assert_scalar(&*scalar_only.select(1_000_000));
    }

    #[test]
    fn test_backend_config_prefer_gpu() {
        let gpu_first = BackendSelector::with_config(BackendConfig::prefer_gpu());
        assert_gpu_or_scalar(&*gpu_first.select(100_000));
    }
}