burn_wgpu/
lib.rs

1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
2
3extern crate alloc;
4
5#[cfg(feature = "template")]
6pub use burn_cubecl::{
7    kernel::{KernelMetadata, into_contiguous},
8    kernel_source,
9    template::{KernelSource, SourceKernel, SourceTemplate, build_info},
10};
11
12pub use burn_cubecl::{BoolElement, FloatElement, IntElement};
13pub use burn_cubecl::{CubeBackend, tensor::CubeTensor};
14pub use cubecl::CubeDim;
15pub use cubecl::flex32;
16
17pub use cubecl::wgpu::{
18    AutoCompiler, MemoryConfiguration, RuntimeOptions, WgpuDevice, WgpuResource, WgpuRuntime,
19    WgpuSetup, WgpuStorage, init_device, init_setup, init_setup_async,
20};
21// Vulkan and WebGpu would have conflicting type names
22pub mod graphics {
23    pub use cubecl::wgpu::{AutoGraphicsApi, Dx12, GraphicsApi, Metal, OpenGl, Vulkan, WebGpu};
24}
25
26#[cfg(feature = "cubecl-wgsl")]
27pub use cubecl::wgpu::WgslCompiler;
28#[cfg(feature = "cubecl-spirv")]
29pub use cubecl::wgpu::vulkan::VkSpirvCompiler;
30
/// Tensor backend that runs GPU compute shaders through the wgpu crate.
///
/// Supported graphics APIs include:
///   - [Vulkan](crate::graphics::Vulkan) on Linux, Windows, and Android.
///   - [OpenGL](crate::graphics::OpenGl) on Linux, Windows, and Android.
///   - [DirectX 12](crate::graphics::Dx12) on Windows.
///   - [Metal](crate::graphics::Metal) on Apple hardware.
///   - [WebGPU](crate::graphics::WebGpu) on supported browsers and `wasm` runtimes.
///
/// Choosing a graphics API or a memory strategy requires initializing the runtime manually,
/// e.g.:
///
/// ```rust, ignore
/// fn custom_init() {
///     let device = Default::default();
///     burn::backend::wgpu::init_setup::<burn::backend::wgpu::graphics::Vulkan>(
///         &device,
///         Default::default(),
///     );
/// }
/// ```
///
/// This sets up the given device (here, the default one) to use Vulkan as its graphics API.
/// An already-existing wgpu device can be used instead through
/// [`init_device`](crate::init_device).
///
/// # Notes
///
/// This variant wraps the backend in [burn_fusion], which compiles and optimizes streams of
/// tensor operations for better performance.
///
/// Disabling the `fusion` feature flag removes that layer; this may currently be required on
/// `wasm`.
#[cfg(feature = "fusion")]
pub type Wgpu<F = f32, I = i32, B = u32> =
    burn_fusion::Fusion<CubeBackend<cubecl::wgpu::WgpuRuntime, F, I, B>>;
65
66#[cfg(not(feature = "fusion"))]
67/// Tensor backend that uses the wgpu crate for executing GPU compute shaders.
68///
69/// This backend can target multiple graphics APIs, including:
70///   - [Vulkan] on Linux, Windows, and Android.
71///   - [OpenGL](crate::OpenGl) on Linux, Windows, and Android.
72///   - [DirectX 12](crate::Dx12) on Windows.
73///   - [Metal] on Apple hardware.
74///   - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes.
75///
76/// To configure the wgpu backend, eg. to select what graphics API to use or what memory strategy to use,
77/// you have to manually initialize the runtime. For example:
78///
79/// ```rust, ignore
80/// fn custom_init() {
81///     let device = Default::default();
82///     burn::backend::wgpu::init_setup::<burn::backend::wgpu::graphics::Vulkan>(
83///         &device,
84///         Default::default(),
85///     );
86/// }
87/// ```
88/// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API.
89/// It's also possible to use an existing wgpu device, by using `init_device`.
90///
91/// # Notes
92///
93/// This version of the wgpu backend doesn't use [burn_fusion] to compile and optimize streams of tensor
94/// operations.
95///
96/// You can enable the `fusion` feature flag to add that functionality, which might improve
97/// performance.
98pub type Wgpu<F = f32, I = i32, B = u32> = CubeBackend<cubecl::wgpu::WgpuRuntime, F, I, B>;
99
/// Vulkan flavour of [`Wgpu`]: executes GPU compute shaders compiled to SPIR-V through the
/// Vulkan graphics API.
///
/// Only available with the `vulkan` feature. Unlike [`Wgpu`], the `B` parameter defaults to `u8`.
#[cfg(feature = "vulkan")]
pub type Vulkan<F = f32, I = i32, B = u8> = Wgpu<F, I, B>;
103
/// WebGPU flavour of [`Wgpu`]: executes GPU compute shaders written in WGSL through the wgpu
/// crate.
///
/// Only available with the `webgpu` feature; uses the same default type parameters as [`Wgpu`].
#[cfg(feature = "webgpu")]
pub type WebGpu<F = f32, I = i32, B = u32> = Wgpu<F, I, B>;
107
/// Metal flavour of [`Wgpu`]: executes GPU compute shaders compiled to MSL through the Metal
/// graphics API.
///
/// Only available with the `metal` feature. Unlike [`Wgpu`], the `B` parameter defaults to `u8`.
#[cfg(feature = "metal")]
pub type Metal<F = f32, I = i32, B = u8> = Wgpu<F, I, B>;
111
#[cfg(test)]
mod tests {
    use burn_cubecl::CubeBackend;
    // `f16` is needed by both the Vulkan and the Metal suites. Import it exactly once: two
    // separately-gated imports define the name twice (E0252) when both features are enabled.
    #[cfg(any(feature = "vulkan", feature = "metal"))]
    pub use half::f16;

    pub type TestRuntime = cubecl::wgpu::WgpuRuntime;

    // Don't test `flex32` for now, burn sees it as `f32` but is actually `f16` precision, so it
    // breaks a lot of tests from precision issues
    //
    // The cfgs below are mutually exclusive so that exactly one `testgen_all!` suite expands into
    // this module (two expansions here would presumably produce clashing item names). When both
    // `vulkan` and `metal` are enabled, the Vulkan suite takes precedence.
    // NOTE(review): confirm Vulkan is the right suite to run when both features are active.
    #[cfg(feature = "vulkan")]
    burn_cubecl::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]);
    #[cfg(all(feature = "metal", not(feature = "vulkan")))]
    burn_cubecl::testgen_all!([f16, f32], [i16, i32], [u32]);
    #[cfg(all(not(feature = "vulkan"), not(feature = "metal")))]
    burn_cubecl::testgen_all!([f32], [i32], [u32]);
}
128    #[cfg(all(not(feature = "vulkan"), not(feature = "metal")))]
129    burn_cubecl::testgen_all!([f32], [i32], [u32]);
130}