// burn_dispatch/lib.rs
1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4#![recursion_limit = "138"]
5
6//! Burn multi-backend dispatch.
7//!
8//! # Available Backends
9//!
10//! The dispatch backend supports the following variants, each enabled via cargo features:
11//!
12//! | Backend | Feature | Description |
13//! |------------|------------|-------------|
14//! | `Cpu` | `cpu` | Rust CPU backend (MLIR + LLVM) |
15//! | `Cuda` | `cuda` | NVIDIA CUDA backend |
16//! | `Metal` | `metal` | Apple Metal backend via `wgpu` (MSL) |
17//! | `Rocm` | `rocm` | AMD ROCm backend |
18//! | `Vulkan` | `vulkan` | Vulkan backend via `wgpu` (SPIR-V) |
19//! | `Wgpu` | `webgpu` | WebGPU backend via `wgpu` (WGSL) |
20//! | `Flex` | `flex` | Pure Rust CPU backend using `burn-flex` |
21//! | `NdArray` | `ndarray` | Pure Rust CPU backend using `ndarray` (legacy - prefer `flex`) |
22//! | `LibTorch` | `tch` | Libtorch backend via `tch` |
23//! | `Autodiff` | `autodiff` | Autodiff-enabled backend (used in combination with any of the backends above) |
24//!
25//! **Note:** WGPU-based backends (`metal`, `vulkan`, `webgpu`) are mutually exclusive.
26//! All other backends can be combined freely.
27//!
28//! ## WGPU Backend Exclusivity
29//!
30//! The WGPU-based backends (`metal`, `vulkan`, `webgpu`) are **mutually exclusive** due to
//! the current automatic compilation step, which can only select one target at a time.
32//!
33//! Enable only **one** of these features in your `Cargo.toml`:
34//! - `metal`
35//! - `vulkan`
36//! - `webgpu`
37//!
38//! If multiple WGPU features are enabled, the build script will emit a warning and **disable all WGPU
39//! backends** to prevent unintended behavior.
40
// Compile-time guard: the crate is useless without at least one concrete
// backend, so fail the build early with a clear message instead of letting
// downstream code hit confusing "type not found" errors.
//
// NOTE(review): `wgpu_metal` / `wgpu_vulkan` / `wgpu_webgpu` are bare cfg
// flags, not cargo features — presumably emitted by the build script that
// resolves the mutually-exclusive WGPU features (see the crate docs above);
// confirm against `build.rs`. `autodiff` is deliberately absent from this
// list: it only wraps another backend and cannot stand alone.
#[cfg(not(any(
    feature = "cpu",
    feature = "cuda",
    wgpu_metal,
    feature = "rocm",
    wgpu_vulkan,
    wgpu_webgpu,
    feature = "flex",
    feature = "ndarray",
    feature = "tch",
)))]
compile_error!("At least one backend feature must be enabled.");
53
54#[macro_use]
55mod macros;
56
57mod backend;
58mod device;
59mod ops;
60mod tensor;
61
62pub use backend::*;
63pub use device::*;
64pub use tensor::*;
65
66extern crate alloc;
67
68/// Backends and devices used.
69pub(crate) mod backends {
70 #[cfg(feature = "autodiff")]
71 pub use burn_autodiff::Autodiff;
72
73 #[cfg(feature = "cpu")]
74 pub use burn_cpu::{Cpu, CpuDevice};
75 #[cfg(feature = "cuda")]
76 pub use burn_cuda::{Cuda, CudaDevice};
77 #[cfg(feature = "rocm")]
78 pub use burn_rocm::{Rocm, RocmDevice};
79 #[cfg(wgpu_metal)]
80 pub use burn_wgpu::Metal;
81 #[cfg(wgpu_vulkan)]
82 pub use burn_wgpu::Vulkan;
83 #[cfg(wgpu_webgpu)]
84 pub use burn_wgpu::Wgpu;
85 #[cfg(any(wgpu_metal, wgpu_vulkan, wgpu_webgpu))]
86 pub use burn_wgpu::WgpuDevice;
87
88 #[cfg(feature = "flex")]
89 pub use burn_flex::{Flex, FlexDevice};
90 #[cfg(feature = "ndarray")]
91 pub use burn_ndarray::{NdArray, NdArrayDevice};
92 #[cfg(feature = "tch")]
93 pub use burn_tch::{LibTorch, LibTorchDevice};
94}