Skip to main content

burn_dispatch/
lib.rs

// Build without the standard library unless the `std` feature is enabled.
#![cfg_attr(not(feature = "std"), no_std)]
// Turn on `doc_cfg` only when the `docsrs` cfg is set (conventionally by docs.rs builds).
#![cfg_attr(docsrs, feature(doc_cfg))]
// Emit a warning for any public item that lacks documentation.
#![warn(missing_docs)]
// Raised recursion limit; presumably required by deep macro/trait expansion in this crate — TODO confirm.
#![recursion_limit = "138"]
5
6//! Burn multi-backend dispatch.
7//!
8//! # Available Backends
9//!
10//! The dispatch backend supports the following variants, each enabled via a Cargo feature:
11//!
12//! | Backend    | Feature    | Description |
13//! |------------|------------|-------------|
14//! | `Cpu`      | `cpu`      | Rust CPU backend (MLIR + LLVM) |
15//! | `Cuda`     | `cuda`     | NVIDIA CUDA backend |
16//! | `Metal`    | `metal`    | Apple Metal backend via `wgpu` (MSL) |
17//! | `Rocm`     | `rocm`     | AMD ROCm backend |
18//! | `Vulkan`   | `vulkan`   | Vulkan backend via `wgpu` (SPIR-V) |
19//! | `Wgpu`     | `webgpu`   | WebGPU backend via `wgpu` (WGSL) |
20//! | `Flex`     | `flex`     | Pure Rust CPU backend using `burn-flex` |
21//! | `NdArray`  | `ndarray`  | Pure Rust CPU backend using `ndarray` (legacy - prefer `flex`) |
22//! | `LibTorch` | `tch`      | Libtorch backend via `tch` |
23//! | `Autodiff` | `autodiff` | Autodiff-enabled backend (used in combination with any of the backends above) |
24//!
25//! **Note:** WGPU-based backends (`metal`, `vulkan`, `webgpu`) are mutually exclusive.
26//! All other backends can be combined freely.
27//!
28//! ## WGPU Backend Exclusivity
29//!
30//! The WGPU-based backends (`metal`, `vulkan`, `webgpu`) are **mutually exclusive** because
31//! the current automatic compilation can only select one target at a time.
32//!
33//! Enable only **one** of these features in your `Cargo.toml`:
34//! - `metal`
35//! - `vulkan`
36//! - `webgpu`
37//!
38//! If multiple WGPU features are enabled, the build script will emit a warning and **disable all WGPU
39//! backends** to prevent unintended behavior; if no other backend is enabled, the build then fails.
40
41#[cfg(not(any(
42    feature = "cpu",
43    feature = "cuda",
44    wgpu_metal,
45    feature = "rocm",
46    wgpu_vulkan,
47    wgpu_webgpu,
48    feature = "flex",
49    feature = "ndarray",
50    feature = "tch",
51)))]
52compile_error!("At least one backend feature must be enabled.");
53
54#[macro_use]
55mod macros;
56
57mod backend;
58mod device;
59mod ops;
60mod tensor;
61
62pub use backend::*;
63pub use device::*;
64pub use tensor::*;
65
66extern crate alloc;
67
/// Backend types (and their device types) compiled into this build.
///
/// Every re-export is gated on the cfg that enables its backend. Most use a
/// cargo feature. The WGPU backends use the `wgpu_*` cfgs set by the build
/// script, so this module only contains what the active configuration provides.
pub(crate) mod backends {
    // Autodiff wrapper, usable on top of any of the backends below.
    #[cfg(feature = "autodiff")]
    pub use burn_autodiff::Autodiff;

    // Pure-Rust CPU backends.
    #[cfg(feature = "flex")]
    pub use burn_flex::{Flex, FlexDevice};
    #[cfg(feature = "ndarray")]
    pub use burn_ndarray::{NdArray, NdArrayDevice};

    // Backends provided by dedicated crates.
    #[cfg(feature = "cpu")]
    pub use burn_cpu::{Cpu, CpuDevice};
    #[cfg(feature = "cuda")]
    pub use burn_cuda::{Cuda, CudaDevice};
    #[cfg(feature = "rocm")]
    pub use burn_rocm::{Rocm, RocmDevice};
    #[cfg(feature = "tch")]
    pub use burn_tch::{LibTorch, LibTorchDevice};

    // WGPU-based backends. They share a single device type, which is exported
    // whenever any of them is active.
    #[cfg(any(wgpu_metal, wgpu_vulkan, wgpu_webgpu))]
    pub use burn_wgpu::WgpuDevice;
    #[cfg(wgpu_metal)]
    pub use burn_wgpu::Metal;
    #[cfg(wgpu_vulkan)]
    pub use burn_wgpu::Vulkan;
    #[cfg(wgpu_webgpu)]
    pub use burn_wgpu::Wgpu;
}