#![cfg_attr(docsrs, feature(doc_cfg))]
extern crate alloc;
#[cfg(feature = "template")]
pub use burn_cubecl::{
kernel::{KernelMetadata, into_contiguous},
kernel_source,
template::{KernelSource, SourceKernel, SourceTemplate, build_info},
};
pub use burn_cubecl::{BoolElement, FloatElement, IntElement};
pub use burn_cubecl::{CubeBackend, tensor::CubeTensor};
pub use cubecl::CubeDim;
pub use cubecl::flex32;
pub use cubecl::wgpu::{
AutoCompiler, MemoryConfiguration, RuntimeOptions, WgpuDevice, WgpuResource, WgpuRuntime,
WgpuSetup, WgpuStorage, init_device, init_setup, init_setup_async,
};
pub mod graphics {
pub use cubecl::wgpu::{AutoGraphicsApi, Dx12, GraphicsApi, Metal, OpenGl, Vulkan, WebGpu};
}
#[cfg(feature = "cubecl-wgsl")]
pub use cubecl::wgpu::WgslCompiler;
#[cfg(feature = "cubecl-spirv")]
pub use cubecl::wgpu::vulkan::VkSpirvCompiler;
#[cfg(feature = "fusion")]
pub type Wgpu<F = f32, I = i32, B = u32> =
burn_fusion::Fusion<CubeBackend<cubecl::wgpu::WgpuRuntime, F, I, B>>;
#[cfg(not(feature = "fusion"))]
pub type Wgpu<F = f32, I = i32, B = u32> = CubeBackend<cubecl::wgpu::WgpuRuntime, F, I, B>;
#[cfg(feature = "vulkan")]
pub type Vulkan<F = f32, I = i32, B = u8> = Wgpu<F, I, B>;
#[cfg(feature = "webgpu")]
pub type WebGpu<F = f32, I = i32, B = u32> = Wgpu<F, I, B>;
#[cfg(feature = "metal")]
pub type Metal<F = f32, I = i32, B = u8> = Wgpu<F, I, B>;
#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::{Backend, DType, QTensorPrimitive};

    /// Checks which element types the default `Wgpu` device reports as supported,
    /// for each combination of backend feature (`vulkan` / `metal`) and host OS.
    #[test]
    fn should_support_dtypes() {
        type B = Wgpu;
        let device: <B as Backend>::Device = Default::default();
        let supports = |dtype: DType| B::supports_dtype(&device, dtype);

        // Expected on every configuration.
        assert!(supports(DType::F32));
        assert!(supports(DType::I64));
        assert!(supports(DType::I32));
        assert!(supports(DType::U64));
        assert!(supports(DType::U32));
        assert!(supports(DType::Bool));
        assert!(supports(DType::QFloat(
            CubeTensor::<WgpuRuntime>::default_scheme()
        )));

        // `vulkan` and `metal` share the same expectations: narrow integers and
        // F16 are available, while F64, BF16 and Flex32 are not.
        #[cfg(any(feature = "vulkan", feature = "metal"))]
        {
            assert!(supports(DType::F16));
            assert!(supports(DType::I16));
            assert!(supports(DType::I8));
            assert!(supports(DType::U16));
            assert!(supports(DType::U8));
            assert!(!supports(DType::F64));
            assert!(!supports(DType::BF16));
            assert!(!supports(DType::Flex32));
        }

        // Neither feature enabled: F16 and Flex32 are available, narrow integers
        // and BF16 are not.
        #[cfg(not(any(feature = "vulkan", feature = "metal")))]
        {
            assert!(supports(DType::Flex32));
            assert!(supports(DType::F16));
            assert!(!supports(DType::BF16));
            assert!(!supports(DType::I16));
            assert!(!supports(DType::I8));
            assert!(!supports(DType::U16));
            assert!(!supports(DType::U8));

            // F64 availability additionally depends on the host OS.
            #[cfg(target_os = "macos")]
            {
                assert!(!supports(DType::F64));
            }
            #[cfg(not(target_os = "macos"))]
            {
                assert!(supports(DType::F64));
            }
        }
    }
}