burn_dispatch/lib.rs
1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4#![recursion_limit = "138"]
5
6//! Burn multi-backend dispatch.
7//!
8//! # Available Backends
9//!
//! The dispatch backend supports the following variants, each enabled by its own Cargo feature:
11//!
12//! | Backend | Feature | Description |
13//! |------------|------------|-------------|
14//! | `Cpu` | `cpu` | Rust CPU backend (MLIR + LLVM) |
15//! | `Cuda` | `cuda` | NVIDIA CUDA backend |
16//! | `Metal` | `metal` | Apple Metal backend via `wgpu` (MSL) |
17//! | `Rocm` | `rocm` | AMD ROCm backend |
18//! | `Vulkan` | `vulkan` | Vulkan backend via `wgpu` (SPIR-V) |
19//! | `WebGpu` | `webgpu` | WebGPU backend via `wgpu` (WGSL) |
20//! | `NdArray` | `ndarray` | Pure Rust CPU backend using `ndarray` |
//! | `LibTorch` | `tch` | LibTorch backend via `tch` |
22//! | `Autodiff` | `autodiff` | Autodiff-enabled backend (used in combination with any of the backends above) |
23//!
24//! **Note:** WGPU-based backends (`metal`, `vulkan`, `webgpu`) are mutually exclusive.
25//! All other backends can be combined freely.
26//!
27//! ## WGPU Backend Exclusivity
28//!
//! The WGPU-based backends (`metal`, `vulkan`, `webgpu`) are **mutually exclusive** because
//! the current automatic compilation setup can only select one target (MSL, SPIR-V, or WGSL)
//! at a time.
31//!
32//! Enable only **one** of these features in your `Cargo.toml`:
33//! - `metal`
34//! - `vulkan`
35//! - `webgpu`
36//!
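//! If multiple WGPU features are enabled, the build script will emit a warning and **disable all WGPU
//! backends** to prevent unintended behavior. If no other backend is enabled in that case, the crate
//! fails to compile because no backend remains available.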
39
40#[cfg(not(any(
41 feature = "cpu",
42 feature = "cuda",
43 wgpu_metal,
44 feature = "rocm",
45 wgpu_vulkan,
46 wgpu_webgpu,
47 feature = "ndarray",
48 feature = "tch",
49)))]
50compile_error!("At least one backend feature must be enabled.");
51
52#[macro_use]
53mod macros;
54
55mod backend;
56mod device;
57mod ops;
58mod tensor;
59
60pub use backend::*;
61pub use device::*;
62pub use tensor::*;
63
64extern crate alloc;
65
/// Backends and devices used.
///
/// Each re-export is compiled in only when its backend is enabled. Regular backends are gated on
/// their Cargo feature. The WGPU variants are gated on the custom `wgpu_metal` / `wgpu_vulkan` /
/// `wgpu_webgpu` cfg flags (presumably emitted by the build script — TODO confirm).
pub(crate) mod backends {
    // Autodiff wrapper, meant to be combined with any of the backends below.
    #[cfg(feature = "autodiff")]
    pub use burn_autodiff::Autodiff;

    // Backends that each bring their own backend type and device type.
    #[cfg(feature = "cpu")]
    pub use burn_cpu::Cpu;
    #[cfg(feature = "cpu")]
    pub use burn_cpu::CpuDevice;

    #[cfg(feature = "cuda")]
    pub use burn_cuda::Cuda;
    #[cfg(feature = "cuda")]
    pub use burn_cuda::CudaDevice;

    #[cfg(feature = "rocm")]
    pub use burn_rocm::Rocm;
    #[cfg(feature = "rocm")]
    pub use burn_rocm::RocmDevice;

    #[cfg(feature = "ndarray")]
    pub use burn_ndarray::NdArray;
    #[cfg(feature = "ndarray")]
    pub use burn_ndarray::NdArrayDevice;

    #[cfg(feature = "tch")]
    pub use burn_tch::LibTorch;
    #[cfg(feature = "tch")]
    pub use burn_tch::LibTorchDevice;

    // WGPU variants: one backend type per target, all sharing `WgpuDevice`.
    #[cfg(wgpu_metal)]
    pub use burn_wgpu::Metal;
    #[cfg(wgpu_vulkan)]
    pub use burn_wgpu::Vulkan;
    #[cfg(wgpu_webgpu)]
    pub use burn_wgpu::WebGpu;
    #[cfg(any(wgpu_metal, wgpu_vulkan, wgpu_webgpu))]
    pub use burn_wgpu::WgpuDevice;
}
90}