1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3extern crate alloc;
4
5#[cfg(feature = "template")]
6pub use burn_cubecl::{
7 kernel::{KernelMetadata, into_contiguous},
8 kernel_source,
9 template::{KernelSource, SourceKernel, SourceTemplate, build_info},
10};
11
12pub use burn_cubecl::{BoolElement, FloatElement, IntElement};
13pub use burn_cubecl::{CubeBackend, tensor::CubeTensor};
14pub use cubecl::CubeDim;
15pub use cubecl::flex32;
16
17pub use cubecl::wgpu::{
18 AutoCompiler, MemoryConfiguration, RuntimeOptions, WgpuDevice, WgpuResource, WgpuRuntime,
19 WgpuSetup, WgpuStorage, init_device, init_setup, init_setup_async,
20};
21pub mod graphics {
23 pub use cubecl::wgpu::{AutoGraphicsApi, Dx12, GraphicsApi, Metal, OpenGl, Vulkan, WebGpu};
24}
25
26#[cfg(feature = "cubecl-wgsl")]
27pub use cubecl::wgpu::WgslCompiler;
28#[cfg(feature = "cubecl-spirv")]
29pub use cubecl::wgpu::vulkan::VkSpirvCompiler;
30
/// Backend type alias used when the `fusion` feature is enabled.
///
/// Expands to a [`CubeBackend`] over `cubecl::wgpu::WgpuRuntime`, wrapped in
/// `burn_fusion::Fusion`.
///
/// The type parameters default to `F = f32`, `I = i32` and `B = u32`
/// (presumably the float, int and bool element types; confirm against `CubeBackend`).
#[cfg(feature = "fusion")]
pub type Wgpu<F = f32, I = i32, B = u32> =
    burn_fusion::Fusion<CubeBackend<cubecl::wgpu::WgpuRuntime, F, I, B>>;
65
66#[cfg(not(feature = "fusion"))]
67pub type Wgpu<F = f32, I = i32, B = u32> = CubeBackend<cubecl::wgpu::WgpuRuntime, F, I, B>;
99
/// Alias of [`Wgpu`] that is only available with the `vulkan` feature.
///
/// It differs from [`Wgpu`] only in the default of its `B` parameter, which is `u8`
/// here instead of `u32`.
#[cfg(feature = "vulkan")]
pub type Vulkan<F = f32, I = i32, B = u8> = Wgpu<F, I, B>;
103
/// Alias of [`Wgpu`] that is only available with the `webgpu` feature.
///
/// It keeps the same parameter defaults as [`Wgpu`] (`F = f32`, `I = i32`, `B = u32`).
#[cfg(feature = "webgpu")]
pub type WebGpu<F = f32, I = i32, B = u32> = Wgpu<F, I, B>;
107
/// Alias of [`Wgpu`] that is only available with the `metal` feature.
///
/// It differs from [`Wgpu`] only in the default of its `B` parameter, which is `u8`
/// here instead of `u32`.
#[cfg(feature = "metal")]
pub type Metal<F = f32, I = i32, B = u8> = Wgpu<F, I, B>;
111
#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::{Backend, BoolStore, DType, QTensorPrimitive};

    /// Checks which dtypes the default [`Wgpu`] backend reports as supported on the
    /// default device.
    ///
    /// A common set of dtypes is asserted in every build; the remaining expectations
    /// depend on the enabled `vulkan` / `metal` features and on the target OS.
    #[test]
    fn should_support_dtypes() {
        type TestBackend = Wgpu;
        let device = Default::default();
        // Shorthand for querying the backend about one dtype on `device`.
        let supports = |dtype: DType| TestBackend::supports_dtype(&device, dtype);

        // Asserted in every configuration.
        assert!(supports(DType::F32));
        assert!(supports(DType::I64));
        assert!(supports(DType::I32));
        assert!(supports(DType::U64));
        assert!(supports(DType::U32));
        assert!(supports(DType::QFloat(
            CubeTensor::<WgpuRuntime>::default_scheme()
        )));
        assert!(supports(DType::Bool(BoolStore::Native)));

        // `vulkan` feature: 16-bit float and the 8/16-bit integers are expected,
        // while F64, Flex32 and BF16 are not.
        #[cfg(feature = "vulkan")]
        {
            assert!(supports(DType::F16));
            assert!(supports(DType::I16));
            assert!(supports(DType::I8));
            assert!(supports(DType::U16));
            assert!(supports(DType::U8));

            assert!(!supports(DType::F64));
            assert!(!supports(DType::Flex32));
            assert!(!supports(DType::BF16));
        }

        // `metal` feature: same expectations as the `vulkan` block.
        #[cfg(feature = "metal")]
        {
            assert!(supports(DType::F16));
            assert!(supports(DType::I16));
            assert!(supports(DType::I8));
            assert!(supports(DType::U16));
            assert!(supports(DType::U8));

            assert!(!supports(DType::F64));
            assert!(!supports(DType::BF16));
            assert!(!supports(DType::Flex32));
        }

        // Neither `vulkan` nor `metal`, built for macOS: Flex32 and F16 are expected;
        // F64, BF16 and the 8/16-bit integers are not.
        #[cfg(all(not(any(feature = "vulkan", feature = "metal")), target_os = "macos"))]
        {
            assert!(supports(DType::Flex32));
            assert!(supports(DType::F16));

            assert!(!supports(DType::F64));
            assert!(!supports(DType::BF16));
            assert!(!supports(DType::I16));
            assert!(!supports(DType::I8));
            assert!(!supports(DType::U16));
            assert!(!supports(DType::U8));
        }

        // Neither `vulkan` nor `metal`, and not macOS: F64 is additionally expected.
        #[cfg(not(any(feature = "vulkan", feature = "metal", target_os = "macos")))]
        {
            assert!(supports(DType::F64));
            assert!(supports(DType::Flex32));
            assert!(supports(DType::F16));

            assert!(!supports(DType::BF16));
            assert!(!supports(DType::I16));
            assert!(!supports(DType::I8));
            assert!(!supports(DType::U16));
            assert!(!supports(DType::U8));
        }
    }
}