gllm_kernels/
backend.rs

1use burn::tensor::backend::Backend;
2
3/// CUDA backend (NVIDIA GPU)
4#[cfg(all(feature = "cuda", feature = "fusion"))]
5pub type DefaultBackend = burn_fusion::Fusion<burn_cuda::Cuda>;
6
7/// CUDA backend (NVIDIA GPU)
8#[cfg(all(feature = "cuda", not(feature = "fusion")))]
9pub type DefaultBackend = burn_cuda::Cuda;
10
11/// WGPU backend (WebGPU/Vulkan - cross-platform GPU)
12#[cfg(all(feature = "wgpu", not(feature = "cuda")))]
13pub type DefaultBackend = burn_wgpu::Wgpu;
14
15/// CPU backend (fallback)
16#[cfg(all(feature = "cpu", not(feature = "cuda"), not(feature = "wgpu")))]
17pub type DefaultBackend = burn_ndarray::NdArray<f32>;
18
19/// Select the default device for a backend
20pub fn select_device<B: Backend>() -> B::Device {
21    B::Device::default()
22}