use alloc::boxed::Box;
use burn_backend::{DeviceId, DeviceOps};
use crate::backends::*;
/// Device handle that can refer to a device of any backend compiled into this crate.
///
/// Each variant wraps the device type of one backend and only exists when the matching
/// Cargo feature or cfg flag is enabled. `Debug`, `Default` and `PartialEq` are implemented
/// by hand for this type rather than derived.
#[derive(Clone, Eq)]
pub enum DispatchDevice {
/// Device of the backend enabled by the `cpu` feature.
#[cfg(feature = "cpu")]
Cpu(CpuDevice),
/// Device of the backend enabled by the `cuda` feature.
#[cfg(feature = "cuda")]
Cuda(CudaDevice),
/// A `WgpuDevice` tagged as Metal; present when the `wgpu_metal` cfg is set.
#[cfg(wgpu_metal)]
Metal(WgpuDevice),
/// Device of the backend enabled by the `rocm` feature.
#[cfg(feature = "rocm")]
Rocm(RocmDevice),
/// A `WgpuDevice` tagged as Vulkan; present when the `wgpu_vulkan` cfg is set.
#[cfg(wgpu_vulkan)]
Vulkan(WgpuDevice),
/// A `WgpuDevice` tagged as WebGPU; present when the `wgpu_webgpu` cfg is set.
#[cfg(wgpu_webgpu)]
Wgpu(WgpuDevice),
/// Device of the backend enabled by the `flex` feature.
#[cfg(feature = "flex")]
Flex(FlexDevice),
/// Device of the backend enabled by the `ndarray` feature.
#[cfg(feature = "ndarray")]
NdArray(NdArrayDevice),
/// Device of the backend enabled by the `tch` feature.
#[cfg(feature = "tch")]
LibTorch(LibTorchDevice),
/// Another `DispatchDevice` wrapped together with a checkpointing strategy
/// (see `AutodiffDevice`); present when the `autodiff` feature is enabled.
#[cfg(feature = "autodiff")]
Autodiff(AutodiffDevice),
}
/// Wrapper that pairs a [`DispatchDevice`] with a [`CheckpointingStrategy`].
///
/// This is the payload of [`DispatchDevice::Autodiff`]. The derived `PartialEq` compares
/// both the wrapped device and the strategy.
#[cfg(feature = "autodiff")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AutodiffDevice {
/// The wrapped device. Boxed because `DispatchDevice` itself can contain an
/// `AutodiffDevice`, which would otherwise make the type infinitely sized.
pub(crate) inner: Box<DispatchDevice>,
/// Checkpointing strategy carried alongside the wrapped device.
pub(crate) checkpointing: CheckpointingStrategy,
}
#[cfg(feature = "autodiff")]
impl AutodiffDevice {
    /// Builds a wrapper around `device`, moving it to the heap and recording the
    /// `checkpointing` strategy that goes with it.
    pub(crate) fn new(device: DispatchDevice, checkpointing: CheckpointingStrategy) -> Self {
        let inner = Box::new(device);
        Self {
            inner,
            checkpointing,
        }
    }
}
#[cfg(feature = "autodiff")]
impl core::ops::Deref for AutodiffDevice {
    type Target = DispatchDevice;

    /// Borrows the wrapped device, so a `&AutodiffDevice` can be used wherever a
    /// `&DispatchDevice` is expected.
    fn deref(&self) -> &DispatchDevice {
        self.inner.as_ref()
    }
}
/// Checkpointing strategy attached to an [`AutodiffDevice`].
///
/// The default value is [`CheckpointingStrategy::None`]. The enum is `#[repr(u8)]`, so the
/// variants have the discriminants `Balanced = 0` and `None = 1` in declaration order.
#[cfg(feature = "autodiff")]
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum CheckpointingStrategy {
// NOTE(review): presumably selects a balanced gradient-checkpointing mode; the consumer
// of this value is not visible here — confirm before documenting publicly.
Balanced,
// Default variant. Presumably means "no checkpointing" — confirm against the consumer.
#[default]
None,
}
/// Checks that two checkpointing strategies are identical and returns that strategy.
///
/// # Panics
///
/// Panics (via `assert_eq!`) when `lhs` and `rhs` differ; the message names both values.
#[cfg(feature = "autodiff")]
pub(crate) fn validate_checkpointing(
lhs: crate::CheckpointingStrategy,
rhs: crate::CheckpointingStrategy,
) -> crate::CheckpointingStrategy {
assert_eq!(
lhs, rhs,
"Autodiff strategy mismatch: {lhs:?} vs {rhs:?}. Tensors in the same operation must share a strategy."
);
// Both values are equal at this point, so either one can be returned.
lhs
}
impl core::fmt::Debug for DispatchDevice {
    /// Formats the device as a one-field tuple named after its variant, e.g. `Cuda(<device>)`.
    ///
    /// For the `Autodiff` variant only the wrapped device is printed; the checkpointing
    /// strategy is left out of the output.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Single place that drives the tuple builder; each arm only supplies label + payload.
        fn tuple(
            out: &mut core::fmt::Formatter<'_>,
            label: &str,
            payload: &dyn core::fmt::Debug,
        ) -> core::fmt::Result {
            out.debug_tuple(label).field(payload).finish()
        }

        match self {
            #[cfg(feature = "cpu")]
            Self::Cpu(dev) => tuple(f, "Cpu", dev),
            #[cfg(feature = "cuda")]
            Self::Cuda(dev) => tuple(f, "Cuda", dev),
            #[cfg(wgpu_metal)]
            Self::Metal(dev) => tuple(f, "Metal", dev),
            #[cfg(feature = "rocm")]
            Self::Rocm(dev) => tuple(f, "Rocm", dev),
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(dev) => tuple(f, "Vulkan", dev),
            #[cfg(wgpu_webgpu)]
            Self::Wgpu(dev) => tuple(f, "Wgpu", dev),
            #[cfg(feature = "flex")]
            Self::Flex(dev) => tuple(f, "Flex", dev),
            #[cfg(feature = "ndarray")]
            Self::NdArray(dev) => tuple(f, "NdArray", dev),
            #[cfg(feature = "tch")]
            Self::LibTorch(dev) => tuple(f, "LibTorch", dev),
            // Print the boxed inner device directly instead of the `AutodiffDevice` struct.
            #[cfg(feature = "autodiff")]
            Self::Autodiff(wrapper) => tuple(f, "Autodiff", &wrapper.inner),
        }
    }
}
impl Default for DispatchDevice {
/// Returns the default device of the first backend, in the order below, that is compiled in.
///
/// Priority: `cuda`, `wgpu_metal`, `rocm`, `wgpu_vulkan`, `wgpu_webgpu`, `cpu`, `tch`,
/// `flex`, `ndarray`. The `autodiff` wrapper is never chosen as a default.
// Every enabled backend contributes its own `return`, so each statement after the first
// enabled one is dead code; that is why `unreachable_code` is allowed here.
// NOTE(review): with none of these backends enabled the body is empty and cannot produce
// a `Self`, so at least one backend must be compiled in.
#[allow(unreachable_code)]
fn default() -> Self {
#[cfg(feature = "cuda")]
return Self::Cuda(CudaDevice::default());
#[cfg(wgpu_metal)]
return Self::Metal(burn_wgpu::WgpuDevice::default());
#[cfg(feature = "rocm")]
return Self::Rocm(RocmDevice::default());
#[cfg(wgpu_vulkan)]
return Self::Vulkan(burn_wgpu::WgpuDevice::default());
#[cfg(wgpu_webgpu)]
return Self::Wgpu(burn_wgpu::WgpuDevice::default());
#[cfg(feature = "cpu")]
return Self::Cpu(CpuDevice);
#[cfg(feature = "tch")]
return Self::LibTorch(LibTorchDevice::default());
#[cfg(feature = "flex")]
return Self::Flex(FlexDevice);
#[cfg(feature = "ndarray")]
return Self::NdArray(NdArrayDevice::default());
}
}
impl PartialEq for DispatchDevice {
    /// Compares two devices.
    ///
    /// * Two `Autodiff` devices are compared with `AutodiffDevice`'s own equality.
    /// * An `Autodiff` device and a plain device are equal when the wrapped device equals
    ///   the plain one.
    /// * Two plain devices are equal when they are the same variant holding equal devices.
    ///
    /// NOTE(review): `AutodiffDevice` equality includes the checkpointing strategy, while the
    /// mixed case ignores it. Two wrappers with different strategies can therefore both equal
    /// the same plain device without equalling each other, which is not transitive even
    /// though `Eq` is derived on the enum. Confirm this is intended.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            // The autodiff arms must come first so wrapped devices never reach the
            // same-variant arms below.
            #[cfg(feature = "autodiff")]
            (Self::Autodiff(lhs), Self::Autodiff(rhs)) => lhs == rhs,
            #[cfg(feature = "autodiff")]
            (Self::Autodiff(wrapped), plain) => wrapped.inner.as_ref() == plain,
            #[cfg(feature = "autodiff")]
            (plain, Self::Autodiff(wrapped)) => plain == wrapped.inner.as_ref(),
            #[cfg(feature = "cpu")]
            (Self::Cpu(lhs), Self::Cpu(rhs)) => lhs == rhs,
            #[cfg(feature = "cuda")]
            (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs == rhs,
            #[cfg(wgpu_metal)]
            (Self::Metal(lhs), Self::Metal(rhs)) => lhs == rhs,
            #[cfg(feature = "rocm")]
            (Self::Rocm(lhs), Self::Rocm(rhs)) => lhs == rhs,
            #[cfg(wgpu_vulkan)]
            (Self::Vulkan(lhs), Self::Vulkan(rhs)) => lhs == rhs,
            #[cfg(wgpu_webgpu)]
            (Self::Wgpu(lhs), Self::Wgpu(rhs)) => lhs == rhs,
            #[cfg(feature = "flex")]
            (Self::Flex(lhs), Self::Flex(rhs)) => lhs == rhs,
            #[cfg(feature = "ndarray")]
            (Self::NdArray(lhs), Self::NdArray(rhs)) => lhs == rhs,
            #[cfg(feature = "tch")]
            (Self::LibTorch(lhs), Self::LibTorch(rhs)) => lhs == rhs,
            // Different backends are never equal. The arm can be unreachable when only a
            // single backend is compiled in.
            #[allow(unreachable_patterns)]
            _ => false,
        }
    }
}
/// Multiplier used to pack a backend identifier and a backend-local device type id into a
/// single `u16`: `encoded = backend_id * TYPE_ID_BASE + backend_type_id`.
///
/// Decoding uses division and remainder by this value, so a backend-local type id only
/// round-trips when it is smaller than `TYPE_ID_BASE`.
const TYPE_ID_BASE: u16 = 10;
impl DispatchDevice {
    /// Wraps `device` in the `Autodiff` variant using `CheckpointingStrategy::None`.
    #[cfg(feature = "autodiff")]
    pub fn autodiff(device: impl Into<DispatchDevice>) -> DispatchDevice {
        Self::Autodiff(AutodiffDevice::new(
            device.into(),
            CheckpointingStrategy::None,
        ))
    }

    /// Wraps `device` in the `Autodiff` variant with the given `checkpointing` strategy.
    #[cfg(feature = "autodiff")]
    pub fn autodiff_checkpointed(
        device: impl Into<DispatchDevice>,
        checkpointing: CheckpointingStrategy,
    ) -> DispatchDevice {
        let wrapper = AutodiffDevice::new(device.into(), checkpointing);
        Self::Autodiff(wrapper)
    }

    /// Returns the backend tag of this device; an `Autodiff` device reports the tag of the
    /// device it wraps.
    fn backend_id(&self) -> BackendId {
        match self {
            #[cfg(feature = "autodiff")]
            Self::Autodiff(wrapper) => wrapper.inner.backend_id(),
            #[cfg(feature = "cpu")]
            Self::Cpu(_) => BackendId::Cpu,
            #[cfg(feature = "cuda")]
            Self::Cuda(_) => BackendId::Cuda,
            #[cfg(wgpu_metal)]
            Self::Metal(_) => BackendId::Metal,
            #[cfg(feature = "rocm")]
            Self::Rocm(_) => BackendId::Rocm,
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(_) => BackendId::Vulkan,
            #[cfg(wgpu_webgpu)]
            Self::Wgpu(_) => BackendId::Wgpu,
            #[cfg(feature = "flex")]
            Self::Flex(_) => BackendId::Flex,
            #[cfg(feature = "ndarray")]
            Self::NdArray(_) => BackendId::NdArray,
            #[cfg(feature = "tch")]
            Self::LibTorch(_) => BackendId::LibTorch,
        }
    }

    /// Packs this device's backend tag and a backend-local type id into one value:
    /// `backend_id * TYPE_ID_BASE + backend_type_id`.
    ///
    /// `backend_type_id` must be smaller than `TYPE_ID_BASE` for `decode_type_id` to
    /// recover both parts.
    fn encode_type_id(&self, backend_type_id: u16) -> u16 {
        let backend = u16::from(self.backend_id());
        backend * TYPE_ID_BASE + backend_type_id
    }

    /// Splits an encoded type id into its backend tag and backend-local type id.
    ///
    /// # Panics
    ///
    /// Panics when the backend part does not correspond to a compiled-in `BackendId`.
    pub(crate) fn decode_type_id(type_id: u16) -> (BackendId, u16) {
        let backend = BackendId::try_from(type_id / TYPE_ID_BASE)
            .expect("Unknown DispatchDevice variant");
        (backend, type_id % TYPE_ID_BASE)
    }
}
/// Numeric tag identifying which backend a dispatch device belongs to.
///
/// The discriminant is the value multiplied by `TYPE_ID_BASE` when a device type id is
/// encoded. Each backend has an explicit, fixed value, so the numbers do not shift when a
/// feature is disabled. `TryFrom<u16>` must mirror the values listed here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub(crate) enum BackendId {
#[cfg(feature = "cpu")]
Cpu = 0,
#[cfg(feature = "cuda")]
Cuda = 1,
#[cfg(wgpu_metal)]
Metal = 2,
#[cfg(feature = "rocm")]
Rocm = 3,
#[cfg(wgpu_vulkan)]
Vulkan = 4,
#[cfg(wgpu_webgpu)]
Wgpu = 5,
#[cfg(feature = "ndarray")]
NdArray = 6,
#[cfg(feature = "tch")]
LibTorch = 7,
#[cfg(feature = "flex")]
Flex = 8,
}
impl From<BackendId> for u16 {
    /// Returns the explicit discriminant declared on the `BackendId` variant.
    fn from(id: BackendId) -> u16 {
        id as u16
    }
}
impl TryFrom<u16> for BackendId {
type Error = ();
fn try_from(value: u16) -> Result<Self, Self::Error> {
match value {
#[cfg(feature = "cpu")]
0 => Ok(Self::Cpu),
#[cfg(feature = "cuda")]
1 => Ok(Self::Cuda),
#[cfg(wgpu_metal)]
2 => Ok(Self::Metal),
#[cfg(feature = "rocm")]
3 => Ok(Self::Rocm),
#[cfg(wgpu_vulkan)]
4 => Ok(Self::Vulkan),
#[cfg(wgpu_webgpu)]
5 => Ok(Self::Wgpu),
#[cfg(feature = "ndarray")]
6 => Ok(Self::NdArray),
#[cfg(feature = "tch")]
7 => Ok(Self::LibTorch),
#[cfg(feature = "flex")]
8 => Ok(Self::Flex),
_ => Err(()),
}
}
}
impl DeviceOps for DispatchDevice {
    /// Returns the device wrapped by an `Autodiff` variant, or `self` for any other variant.
    ///
    /// Only one level of wrapping is removed.
    fn inner(&self) -> &Self {
        match self {
            #[cfg(feature = "autodiff")]
            Self::Autodiff(wrapper) => wrapper.inner.as_ref(),
            plain => plain,
        }
    }
}
impl burn_backend::Device for DispatchDevice {
    /// Rebuilds a dispatch device from an id produced by [`Self::to_id`].
    ///
    /// The backend tag is split off the encoded `type_id`, and the remaining backend-local
    /// type id is handed to that backend's own `from_id`. The result is never an `Autodiff`
    /// variant, because the wrapper is not part of the id.
    ///
    /// # Panics
    ///
    /// Panics when the encoded backend tag does not match a compiled-in backend.
    fn from_id(mut device_id: DeviceId) -> Self {
        let (dispatch_id, backend_type_id) = Self::decode_type_id(device_id.type_id);
        // Restore the backend-local type id before delegating to the backend.
        device_id.type_id = backend_type_id;
        match dispatch_id {
            #[cfg(feature = "cpu")]
            BackendId::Cpu => Self::Cpu(CpuDevice::from_id(device_id)),
            #[cfg(feature = "cuda")]
            BackendId::Cuda => Self::Cuda(CudaDevice::from_id(device_id)),
            #[cfg(wgpu_metal)]
            BackendId::Metal => Self::Metal(WgpuDevice::from_id(device_id)),
            #[cfg(feature = "rocm")]
            BackendId::Rocm => Self::Rocm(RocmDevice::from_id(device_id)),
            #[cfg(wgpu_vulkan)]
            BackendId::Vulkan => Self::Vulkan(WgpuDevice::from_id(device_id)),
            #[cfg(wgpu_webgpu)]
            BackendId::Wgpu => Self::Wgpu(WgpuDevice::from_id(device_id)),
            #[cfg(feature = "flex")]
            BackendId::Flex => Self::Flex(FlexDevice::from_id(device_id)),
            #[cfg(feature = "ndarray")]
            BackendId::NdArray => Self::NdArray(NdArrayDevice::from_id(device_id)),
            #[cfg(feature = "tch")]
            BackendId::LibTorch => Self::LibTorch(LibTorchDevice::from_id(device_id)),
        }
    }

    /// Returns the backend device's id with its `type_id` replaced by the encoded form
    /// `backend_id * TYPE_ID_BASE + backend_type_id`.
    ///
    /// An `Autodiff` device yields exactly the id of the device it wraps, so
    /// `from_id(to_id())` gives back the unwrapped device.
    fn to_id(&self) -> DeviceId {
        let mut device_id = match self {
            #[cfg(feature = "cpu")]
            Self::Cpu(device) => device.to_id(),
            #[cfg(feature = "cuda")]
            Self::Cuda(device) => device.to_id(),
            #[cfg(wgpu_metal)]
            Self::Metal(device) => device.to_id(),
            #[cfg(feature = "rocm")]
            Self::Rocm(device) => device.to_id(),
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(device) => device.to_id(),
            #[cfg(wgpu_webgpu)]
            Self::Wgpu(device) => device.to_id(),
            #[cfg(feature = "flex")]
            Self::Flex(device) => device.to_id(),
            #[cfg(feature = "ndarray")]
            Self::NdArray(device) => device.to_id(),
            #[cfg(feature = "tch")]
            Self::LibTorch(device) => device.to_id(),
            // The wrapped value is itself a `DispatchDevice`, so its `to_id` has already
            // encoded the type id. Return it untouched: encoding a second time would add
            // `backend_id * TYPE_ID_BASE` again and make the id decode to the wrong backend
            // (or panic) for every backend whose tag is not 0.
            #[cfg(feature = "autodiff")]
            Self::Autodiff(device) => return device.inner.to_id(),
        };
        device_id.type_id = self.encode_type_id(device_id.type_id);
        device_id
    }
}
#[cfg(feature = "cpu")]
impl From<CpuDevice> for DispatchDevice {
    /// Wraps the device in [`DispatchDevice::Cpu`].
    fn from(cpu: CpuDevice) -> Self {
        Self::Cpu(cpu)
    }
}
#[cfg(feature = "cuda")]
impl From<CudaDevice> for DispatchDevice {
    /// Wraps the device in [`DispatchDevice::Cuda`].
    fn from(cuda: CudaDevice) -> Self {
        Self::Cuda(cuda)
    }
}
#[cfg(wgpu_metal)]
impl From<WgpuDevice> for DispatchDevice {
    /// Wraps the device in [`DispatchDevice::Metal`] (used when the `wgpu_metal` cfg is set).
    fn from(wgpu: WgpuDevice) -> Self {
        Self::Metal(wgpu)
    }
}
#[cfg(feature = "rocm")]
impl From<RocmDevice> for DispatchDevice {
    /// Wraps the device in [`DispatchDevice::Rocm`].
    fn from(rocm: RocmDevice) -> Self {
        Self::Rocm(rocm)
    }
}
#[cfg(wgpu_vulkan)]
impl From<WgpuDevice> for DispatchDevice {
    /// Wraps the device in [`DispatchDevice::Vulkan`] (used when the `wgpu_vulkan` cfg is set).
    fn from(wgpu: WgpuDevice) -> Self {
        Self::Vulkan(wgpu)
    }
}
#[cfg(wgpu_webgpu)]
impl From<WgpuDevice> for DispatchDevice {
    /// Wraps the device in [`DispatchDevice::Wgpu`] (used when the `wgpu_webgpu` cfg is set).
    fn from(wgpu: WgpuDevice) -> Self {
        Self::Wgpu(wgpu)
    }
}
#[cfg(feature = "flex")]
impl From<FlexDevice> for DispatchDevice {
    /// Wraps the device in [`DispatchDevice::Flex`].
    fn from(flex: FlexDevice) -> Self {
        Self::Flex(flex)
    }
}
#[cfg(feature = "ndarray")]
impl From<NdArrayDevice> for DispatchDevice {
    /// Wraps the device in [`DispatchDevice::NdArray`].
    fn from(ndarray: NdArrayDevice) -> Self {
        Self::NdArray(ndarray)
    }
}
#[cfg(feature = "tch")]
impl From<LibTorchDevice> for DispatchDevice {
    /// Wraps the device in [`DispatchDevice::LibTorch`].
    fn from(libtorch: LibTorchDevice) -> Self {
        Self::LibTorch(libtorch)
    }
}