burn_wgpu/lib.rs
1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
2
3extern crate alloc;
4
5#[cfg(feature = "template")]
6pub use burn_cubecl::{
7 kernel::{KernelMetadata, into_contiguous},
8 kernel_source,
9 template::{KernelSource, SourceKernel, SourceTemplate, build_info},
10};
11
12pub use burn_cubecl::{BoolElement, FloatElement, IntElement};
13pub use burn_cubecl::{CubeBackend, tensor::CubeTensor};
14pub use cubecl::CubeDim;
15pub use cubecl::flex32;
16
17pub use cubecl::wgpu::{
18 AutoCompiler, MemoryConfiguration, RuntimeOptions, WgpuDevice, WgpuResource, WgpuRuntime,
19 WgpuSetup, WgpuStorage, init_device, init_setup, init_setup_async,
20};
21// Vulkan and WebGpu would have conflicting type names
22pub mod graphics {
23 pub use cubecl::wgpu::{AutoGraphicsApi, Dx12, GraphicsApi, Metal, OpenGl, Vulkan, WebGpu};
24}
25
26#[cfg(feature = "cubecl-wgsl")]
27pub use cubecl::wgpu::WgslCompiler;
28#[cfg(feature = "cubecl-spirv")]
29pub use cubecl::wgpu::vulkan::VkSpirvCompiler;
30
31#[cfg(feature = "fusion")]
32/// Tensor backend that uses the wgpu crate for executing GPU compute shaders.
33///
34/// This backend can target multiple graphics APIs, including:
35/// - [Vulkan][crate::graphics::Vulkan] on Linux, Windows, and Android.
36/// - [OpenGL](crate::graphics::OpenGl) on Linux, Windows, and Android.
37/// - [DirectX 12](crate::graphics::Dx12) on Windows.
38/// - [Metal][crate::graphics::Metal] on Apple hardware.
39/// - [WebGPU](crate::graphics::WebGpu) on supported browsers and `wasm` runtimes.
40///
41/// To configure the wgpu backend, eg. to select what graphics API to use or what memory strategy to use,
42/// you have to manually initialize the runtime. For example:
43///
44/// ```rust, ignore
45/// fn custom_init() {
46/// let device = Default::default();
47/// burn::backend::wgpu::init_setup::<burn::backend::wgpu::graphics::Vulkan>(
48/// &device,
49/// Default::default(),
50/// );
51/// }
52/// ```
53/// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API.
54/// It's also possible to use an existing wgpu device, by using `init_device`.
55///
56/// # Notes
57///
58/// This version of the wgpu backend uses [burn_fusion] to compile and optimize streams of tensor
59/// operations for improved performance.
60///
61/// You can disable the `fusion` feature flag to remove that functionality, which might be
62/// necessary on `wasm` for now.
63pub type Wgpu<F = f32, I = i32, B = u32> =
64 burn_fusion::Fusion<CubeBackend<cubecl::wgpu::WgpuRuntime, F, I, B>>;
65
66#[cfg(not(feature = "fusion"))]
67/// Tensor backend that uses the wgpu crate for executing GPU compute shaders.
68///
69/// This backend can target multiple graphics APIs, including:
70/// - [Vulkan] on Linux, Windows, and Android.
71/// - [OpenGL](crate::OpenGl) on Linux, Windows, and Android.
72/// - [DirectX 12](crate::Dx12) on Windows.
73/// - [Metal] on Apple hardware.
74/// - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes.
75///
76/// To configure the wgpu backend, eg. to select what graphics API to use or what memory strategy to use,
77/// you have to manually initialize the runtime. For example:
78///
79/// ```rust, ignore
80/// fn custom_init() {
81/// let device = Default::default();
82/// burn::backend::wgpu::init_setup::<burn::backend::wgpu::graphics::Vulkan>(
83/// &device,
84/// Default::default(),
85/// );
86/// }
87/// ```
88/// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API.
89/// It's also possible to use an existing wgpu device, by using `init_device`.
90///
91/// # Notes
92///
93/// This version of the wgpu backend doesn't use [burn_fusion] to compile and optimize streams of tensor
94/// operations.
95///
96/// You can enable the `fusion` feature flag to add that functionality, which might improve
97/// performance.
98pub type Wgpu<F = f32, I = i32, B = u32> = CubeBackend<cubecl::wgpu::WgpuRuntime, F, I, B>;
99
100#[cfg(feature = "vulkan")]
101/// Tensor backend that leverages the Vulkan graphics API to execute GPU compute shaders compiled to SPIR-V.
102pub type Vulkan<F = f32, I = i32, B = u8> = Wgpu<F, I, B>;
103
104#[cfg(feature = "webgpu")]
105/// Tensor backend that uses the wgpu crate to execute GPU compute shaders written in WGSL.
106pub type WebGpu<F = f32, I = i32, B = u32> = Wgpu<F, I, B>;
107
108#[cfg(feature = "metal")]
109/// Tensor backend that leverages the Metal graphics API to execute GPU compute shaders compiled to MSL.
110pub type Metal<F = f32, I = i32, B = u8> = Wgpu<F, I, B>;
111
112#[cfg(test)]
113mod tests {
114 use burn_cubecl::CubeBackend;
115 #[cfg(feature = "vulkan")]
116 pub use half::f16;
117 #[cfg(feature = "metal")]
118 pub use half::f16;
119
120 pub type TestRuntime = cubecl::wgpu::WgpuRuntime;
121
122 // Don't test `flex32` for now, burn sees it as `f32` but is actually `f16` precision, so it
123 // breaks a lot of tests from precision issues
124 #[cfg(feature = "vulkan")]
125 burn_cubecl::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]);
126 #[cfg(feature = "metal")]
127 burn_cubecl::testgen_all!([f16, f32], [i16, i32], [u32]);
128 #[cfg(all(not(feature = "vulkan"), not(feature = "metal")))]
129 burn_cubecl::testgen_all!([f32], [i32], [u32]);
130}