use alloc::boxed::Box;
use burn_backend::{
Backend, DType, QTensorPrimitive, Shape, TensorMetadata, quantization::QuantScheme,
};
#[cfg(feature = "autodiff")]
use crate::CheckpointingStrategy;
use crate::backends::*;
#[cfg(feature = "autodiff")]
use burn_backend::tensor::FloatTensor;
/// A tensor primitive produced by a concrete backend `B`, tagged with its kind.
///
/// Each variant wraps the backend's associated primitive type for that tensor
/// kind; the `Autodiff` variant (feature-gated) wraps a float tensor tracked by
/// the autodiff decorator backend.
#[derive(Clone, Debug)]
pub enum BackendTensor<B: Backend> {
    /// Floating-point tensor primitive.
    Float(B::FloatTensorPrimitive),
    /// Integer tensor primitive.
    Int(B::IntTensorPrimitive),
    /// Boolean tensor primitive.
    Bool(B::BoolTensorPrimitive),
    /// Quantized tensor primitive.
    Quantized(B::QuantizedTensorPrimitive),
    /// Float tensor wrapped by the autodiff decorator backend.
    #[cfg(feature = "autodiff")]
    Autodiff(FloatTensor<Autodiff<B>>),
}
impl<B: Backend> BackendTensor<B> {
    /// Consumes the tensor, returning the float primitive.
    ///
    /// # Panics
    /// Panics if the wrapped tensor is not a float tensor.
    pub(crate) fn float(self) -> B::FloatTensorPrimitive {
        match self {
            BackendTensor::Float(tensor) => tensor,
            BackendTensor::Int(_) => panic!("Should be float, got int"),
            BackendTensor::Bool(_) => panic!("Should be float, got bool"),
            BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
            #[cfg(feature = "autodiff")]
            BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
        }
    }
    /// Borrows the float primitive.
    ///
    /// # Panics
    /// Panics if the wrapped tensor is not a float tensor.
    pub(crate) fn as_float(&self) -> &B::FloatTensorPrimitive {
        match self {
            BackendTensor::Float(tensor) => tensor,
            BackendTensor::Int(_) => panic!("Should be float, got int"),
            BackendTensor::Bool(_) => panic!("Should be float, got bool"),
            BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
            #[cfg(feature = "autodiff")]
            BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
        }
    }
    /// Consumes the tensor, returning the int primitive.
    ///
    /// # Panics
    /// Panics if the wrapped tensor is not an int tensor.
    pub(crate) fn int(self) -> B::IntTensorPrimitive {
        match self {
            BackendTensor::Int(tensor) => tensor,
            BackendTensor::Float(_) => panic!("Should be int, got float"),
            BackendTensor::Bool(_) => panic!("Should be int, got bool"),
            BackendTensor::Quantized(_) => panic!("Should be int, got quantized"),
            #[cfg(feature = "autodiff")]
            BackendTensor::Autodiff(_) => panic!("Should be int, got autodiff"),
        }
    }
    /// Consumes the tensor, returning the bool primitive.
    ///
    /// # Panics
    /// Panics if the wrapped tensor is not a bool tensor.
    pub(crate) fn bool(self) -> B::BoolTensorPrimitive {
        match self {
            BackendTensor::Bool(tensor) => tensor,
            BackendTensor::Float(_) => panic!("Should be bool, got float"),
            BackendTensor::Int(_) => panic!("Should be bool, got int"),
            BackendTensor::Quantized(_) => panic!("Should be bool, got quantized"),
            #[cfg(feature = "autodiff")]
            BackendTensor::Autodiff(_) => panic!("Should be bool, got autodiff"),
        }
    }
    /// Consumes the tensor, returning the quantized primitive.
    ///
    /// # Panics
    /// Panics if the wrapped tensor is not a quantized tensor.
    pub(crate) fn quantized(self) -> B::QuantizedTensorPrimitive {
        // Exhaustive match (no `_`/`unreachable!`): mismatches give a clear
        // message, and adding a variant forces this match to be updated.
        match self {
            BackendTensor::Quantized(tensor) => tensor,
            BackendTensor::Float(_) => panic!("Should be quantized, got float"),
            BackendTensor::Int(_) => panic!("Should be quantized, got int"),
            BackendTensor::Bool(_) => panic!("Should be quantized, got bool"),
            #[cfg(feature = "autodiff")]
            BackendTensor::Autodiff(_) => panic!("Should be quantized, got autodiff"),
        }
    }
    /// Consumes the tensor, returning the autodiff-tracked float tensor.
    ///
    /// # Panics
    /// Panics if the wrapped tensor is not an autodiff tensor.
    #[cfg(feature = "autodiff")]
    pub(crate) fn autodiff(self) -> FloatTensor<Autodiff<B>> {
        match self {
            BackendTensor::Autodiff(tensor) => tensor,
            BackendTensor::Float(_) => panic!("Should be autodiff, got float"),
            BackendTensor::Int(_) => panic!("Should be autodiff, got int"),
            BackendTensor::Bool(_) => panic!("Should be autodiff, got bool"),
            BackendTensor::Quantized(_) => panic!("Should be autodiff, got quantized"),
        }
    }
    /// Borrows the autodiff-tracked float tensor.
    ///
    /// # Panics
    /// Panics if the wrapped tensor is not an autodiff tensor.
    #[cfg(feature = "autodiff")]
    pub(crate) fn as_autodiff(&self) -> &FloatTensor<Autodiff<B>> {
        match self {
            BackendTensor::Autodiff(tensor) => tensor,
            BackendTensor::Float(_) => panic!("Should be autodiff, got float"),
            BackendTensor::Int(_) => panic!("Should be autodiff, got int"),
            BackendTensor::Bool(_) => panic!("Should be autodiff, got bool"),
            BackendTensor::Quantized(_) => panic!("Should be autodiff, got quantized"),
        }
    }
    /// Consumes the tensor, returning the inner (untracked) float primitive of
    /// an autodiff tensor.
    ///
    /// # Panics
    /// Panics if the wrapped tensor is not an autodiff tensor.
    #[cfg(feature = "autodiff")]
    pub(crate) fn autodiff_inner(self) -> B::FloatTensorPrimitive {
        match self {
            BackendTensor::Autodiff(tensor) => tensor.primitive,
            BackendTensor::Float(_) => panic!("Should be autodiff, got float"),
            BackendTensor::Int(_) => panic!("Should be autodiff, got int"),
            BackendTensor::Bool(_) => panic!("Should be autodiff, got bool"),
            BackendTensor::Quantized(_) => panic!("Should be autodiff, got quantized"),
        }
    }
    /// Returns the device holding this tensor, regardless of kind.
    pub(crate) fn device(&self) -> B::Device {
        match self {
            BackendTensor::Float(tensor) => B::float_device(tensor),
            BackendTensor::Int(tensor) => B::int_device(tensor),
            BackendTensor::Bool(tensor) => B::bool_device(tensor),
            BackendTensor::Quantized(tensor) => B::q_device(tensor),
            // Autodiff wraps an inner float primitive; query its device.
            #[cfg(feature = "autodiff")]
            BackendTensor::Autodiff(tensor) => B::float_device(&tensor.primitive),
        }
    }
}
impl<B: Backend> TensorMetadata for BackendTensor<B> {
    /// Element data type, delegated to the wrapped primitive.
    fn dtype(&self) -> DType {
        match self {
            Self::Float(primitive) => primitive.dtype(),
            Self::Int(primitive) => primitive.dtype(),
            Self::Bool(primitive) => primitive.dtype(),
            Self::Quantized(primitive) => primitive.dtype(),
            #[cfg(feature = "autodiff")]
            Self::Autodiff(primitive) => primitive.dtype(),
        }
    }
    /// Tensor shape, delegated to the wrapped primitive.
    fn shape(&self) -> Shape {
        match self {
            Self::Float(primitive) => primitive.shape(),
            Self::Int(primitive) => primitive.shape(),
            Self::Bool(primitive) => primitive.shape(),
            Self::Quantized(primitive) => primitive.shape(),
            #[cfg(feature = "autodiff")]
            Self::Autodiff(primitive) => primitive.shape(),
        }
    }
}
impl<B: Backend> QTensorPrimitive for BackendTensor<B> {
    /// Quantization scheme of the wrapped tensor.
    ///
    /// # Panics
    /// Panics when called on a non-quantized tensor.
    fn scheme(&self) -> &QuantScheme {
        if let Self::Quantized(primitive) = self {
            primitive.scheme()
        } else {
            panic!(
                "Quantization scheme is not valid for dtype {:?}",
                self.dtype(),
            )
        }
    }
}
/// A tensor paired with backend-dispatch information.
#[derive(Clone, Debug)]
pub struct DispatchTensor {
    // The backend-specific tensor this dispatch tensor wraps.
    pub(crate) kind: DispatchTensorKind,
    // Strategy used for autodiff checkpointing (autodiff builds only).
    #[cfg(feature = "autodiff")]
    pub(crate) checkpointing: CheckpointingStrategy,
}
/// A [`BackendTensor`] from one of the compiled-in backends.
///
/// Each variant is gated on its backend's feature flag (or a custom cfg such
/// as `wgpu_metal` — presumably set by the build script; confirm there).
#[derive(Clone, Debug)]
pub enum DispatchTensorKind {
    #[cfg(feature = "cpu")]
    Cpu(BackendTensor<Cpu>),
    #[cfg(feature = "cuda")]
    Cuda(BackendTensor<Cuda>),
    #[cfg(wgpu_metal)]
    Metal(BackendTensor<Metal>),
    #[cfg(feature = "rocm")]
    Rocm(BackendTensor<Rocm>),
    #[cfg(wgpu_vulkan)]
    Vulkan(BackendTensor<Vulkan>),
    #[cfg(wgpu_webgpu)]
    Wgpu(BackendTensor<Wgpu>),
    #[cfg(feature = "flex")]
    Flex(BackendTensor<Flex>),
    #[cfg(feature = "ndarray")]
    NdArray(BackendTensor<NdArray>),
    #[cfg(feature = "tch")]
    LibTorch(BackendTensor<LibTorch>),
    /// Autodiff-tracked tensor wrapping one of the other variants (boxed to
    /// keep the enum small).
    #[cfg(feature = "autodiff")]
    Autodiff(Box<DispatchTensorKind>),
}
impl TensorMetadata for DispatchTensorKind {
    /// Element data type, delegated to whichever backend tensor is active.
    fn dtype(&self) -> DType {
        match self {
            #[cfg(feature = "cpu")]
            Self::Cpu(inner) => inner.dtype(),
            #[cfg(feature = "cuda")]
            Self::Cuda(inner) => inner.dtype(),
            #[cfg(wgpu_metal)]
            Self::Metal(inner) => inner.dtype(),
            #[cfg(feature = "rocm")]
            Self::Rocm(inner) => inner.dtype(),
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(inner) => inner.dtype(),
            #[cfg(wgpu_webgpu)]
            Self::Wgpu(inner) => inner.dtype(),
            #[cfg(feature = "flex")]
            Self::Flex(inner) => inner.dtype(),
            #[cfg(feature = "ndarray")]
            Self::NdArray(inner) => inner.dtype(),
            #[cfg(feature = "tch")]
            Self::LibTorch(inner) => inner.dtype(),
            // Boxed recursive variant: auto-deref recurses into the wrapped kind.
            #[cfg(feature = "autodiff")]
            Self::Autodiff(inner) => inner.dtype(),
        }
    }
    /// Tensor shape, delegated to whichever backend tensor is active.
    fn shape(&self) -> Shape {
        match self {
            #[cfg(feature = "cpu")]
            Self::Cpu(inner) => inner.shape(),
            #[cfg(feature = "cuda")]
            Self::Cuda(inner) => inner.shape(),
            #[cfg(wgpu_metal)]
            Self::Metal(inner) => inner.shape(),
            #[cfg(feature = "rocm")]
            Self::Rocm(inner) => inner.shape(),
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(inner) => inner.shape(),
            #[cfg(wgpu_webgpu)]
            Self::Wgpu(inner) => inner.shape(),
            #[cfg(feature = "flex")]
            Self::Flex(inner) => inner.shape(),
            #[cfg(feature = "ndarray")]
            Self::NdArray(inner) => inner.shape(),
            #[cfg(feature = "tch")]
            Self::LibTorch(inner) => inner.shape(),
            // Boxed recursive variant: auto-deref recurses into the wrapped kind.
            #[cfg(feature = "autodiff")]
            Self::Autodiff(inner) => inner.shape(),
        }
    }
}
impl QTensorPrimitive for DispatchTensorKind {
    /// Quantization scheme, delegated to whichever backend tensor is active.
    ///
    /// # Panics
    /// Panics (in the active backend's impl) when the tensor is not quantized.
    fn scheme(&self) -> &QuantScheme {
        match self {
            #[cfg(feature = "cpu")]
            Self::Cpu(inner) => inner.scheme(),
            #[cfg(feature = "cuda")]
            Self::Cuda(inner) => inner.scheme(),
            #[cfg(wgpu_metal)]
            Self::Metal(inner) => inner.scheme(),
            #[cfg(feature = "rocm")]
            Self::Rocm(inner) => inner.scheme(),
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(inner) => inner.scheme(),
            #[cfg(wgpu_webgpu)]
            Self::Wgpu(inner) => inner.scheme(),
            #[cfg(feature = "flex")]
            Self::Flex(inner) => inner.scheme(),
            #[cfg(feature = "ndarray")]
            Self::NdArray(inner) => inner.scheme(),
            #[cfg(feature = "tch")]
            Self::LibTorch(inner) => inner.scheme(),
            // Boxed recursive variant: auto-deref recurses into the wrapped kind.
            #[cfg(feature = "autodiff")]
            Self::Autodiff(inner) => inner.scheme(),
        }
    }
}
impl TensorMetadata for DispatchTensor {
fn dtype(&self) -> DType {
self.kind.dtype()
}
fn shape(&self) -> Shape {
self.kind.shape()
}
}
impl QTensorPrimitive for DispatchTensor {
fn scheme(&self) -> &QuantScheme {
self.kind.scheme()
}
}