Skip to main content

burn_dispatch/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4#![recursion_limit = "138"]
5
//! Burn multi-backend dispatch.
//!
//! # Available Backends
//!
//! The dispatch backend supports the following variants, each enabled via a cargo feature:
//!
//! | Backend    | Feature    | Description |
//! |------------|------------|-------------|
//! | `Cpu`      | `cpu`      | Rust CPU backend (MLIR + LLVM) |
//! | `Cuda`     | `cuda`     | NVIDIA CUDA backend |
//! | `Metal`    | `metal`    | Apple Metal backend via `wgpu` (MSL) |
//! | `Rocm`     | `rocm`     | AMD ROCm backend |
//! | `Vulkan`   | `vulkan`   | Vulkan backend via `wgpu` (SPIR-V) |
//! | `WebGpu`   | `webgpu`   | WebGPU backend via `wgpu` (WGSL) |
//! | `NdArray`  | `ndarray`  | Pure Rust CPU backend using `ndarray` |
//! | `LibTorch` | `tch`      | Libtorch backend via `tch` |
//! | `Autodiff` | `autodiff` | Autodiff-enabled backend (used in combination with any of the backends above) |
//!
//! **Note:** WGPU-based backends (`metal`, `vulkan`, `webgpu`) are mutually exclusive.
//! All other backends can be combined freely.
//!
//! ## WGPU Backend Exclusivity
//!
//! The WGPU-based backends (`metal`, `vulkan`, `webgpu`) are **mutually exclusive** because the
//! compilation target is currently selected automatically, and only one target can be selected
//! at a time.
//!
//! Enable only **one** of these features in your `Cargo.toml`:
//! - `metal`
//! - `vulkan`
//! - `webgpu`
//!
//! If multiple WGPU features are enabled, the build script will emit a warning and **disable all WGPU
//! backends** to prevent unintended behavior.
39
40#[cfg(not(any(
41    feature = "cpu",
42    feature = "cuda",
43    wgpu_metal,
44    feature = "rocm",
45    wgpu_vulkan,
46    wgpu_webgpu,
47    feature = "ndarray",
48    feature = "tch",
49)))]
50compile_error!("At least one backend feature must be enabled.");
51
52#[macro_use]
53mod macros;
54
55mod backend;
56mod device;
57mod ops;
58mod tensor;
59
60pub use backend::*;
61pub use device::*;
62pub use tensor::*;
63
64extern crate alloc;
65
/// Crate-internal re-exports of the backend types and their device types.
///
/// Each re-export is guarded by the cfg that enables the corresponding backend, so this
/// module only exposes the backends selected for the current build.
pub(crate) mod backends {
    // Backends that each come with their own device type.
    #[cfg(feature = "cpu")]
    pub use burn_cpu::{Cpu, CpuDevice};
    #[cfg(feature = "cuda")]
    pub use burn_cuda::{Cuda, CudaDevice};
    #[cfg(feature = "ndarray")]
    pub use burn_ndarray::{NdArray, NdArrayDevice};
    #[cfg(feature = "rocm")]
    pub use burn_rocm::{Rocm, RocmDevice};
    #[cfg(feature = "tch")]
    pub use burn_tch::{LibTorch, LibTorchDevice};

    // wgpu-based backends: one backend type per `wgpu_*` cfg flag, all sharing `WgpuDevice`.
    #[cfg(any(wgpu_metal, wgpu_vulkan, wgpu_webgpu))]
    pub use burn_wgpu::WgpuDevice;
    #[cfg(wgpu_metal)]
    pub use burn_wgpu::Metal;
    #[cfg(wgpu_vulkan)]
    pub use burn_wgpu::Vulkan;
    #[cfg(wgpu_webgpu)]
    pub use burn_wgpu::WebGpu;

    // Autodiff type, only available with the `autodiff` feature.
    #[cfg(feature = "autodiff")]
    pub use burn_autodiff::Autodiff;
}