// gloss_burn_multibackend/tensor.rs

1#[cfg(feature = "burn-candle")]
2use crate::backend::CandleBackend;
3#[cfg(feature = "burn-ndarray")]
4use crate::backend::NdArrayBackend;
5#[cfg(feature = "burn-wgpu")]
6use crate::backend::WgpuBackend;
7
8use burn::tensor::{
9    ops::{BoolTensor, FloatTensor, IntTensor},
10    quantization::{QTensorPrimitive, QuantScheme},
11    DType, Shape, TensorMetadata,
12};
13
/// Float tensor primitive that may live on any of the enabled backends.
///
/// Every variant wraps that backend's own float tensor primitive and is only
/// compiled in when the matching `burn-*` cargo feature is enabled.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum MultiFloatTensor {
    /// Float tensor owned by the Candle backend.
    #[cfg(feature = "burn-candle")]
    Candle(FloatTensor<CandleBackend>),
    /// Float tensor owned by the NdArray backend.
    #[cfg(feature = "burn-ndarray")]
    NdArray(FloatTensor<NdArrayBackend>),
    /// Float tensor owned by the Wgpu backend.
    #[cfg(feature = "burn-wgpu")]
    Wgpu(FloatTensor<WgpuBackend>),
    // TODO: an `Autodiff(FloatTensor<burn_autodiff::Autodiff<MultiBackend>>)`
    // variant gated on the `autodiff` feature is sketched but not enabled yet.
}
26
/// Integer tensor primitive that may live on any of the enabled backends.
///
/// Every variant wraps that backend's own int tensor primitive and is only
/// compiled in when the matching `burn-*` cargo feature is enabled.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum MultiIntTensor {
    /// Int tensor owned by the Candle backend.
    #[cfg(feature = "burn-candle")]
    Candle(IntTensor<CandleBackend>),
    /// Int tensor owned by the NdArray backend.
    #[cfg(feature = "burn-ndarray")]
    NdArray(IntTensor<NdArrayBackend>),
    /// Int tensor owned by the Wgpu backend.
    #[cfg(feature = "burn-wgpu")]
    Wgpu(IntTensor<WgpuBackend>),
    // TODO: an `Autodiff(IntTensor<burn_autodiff::Autodiff<MultiBackend>>)`
    // variant gated on the `autodiff` feature is sketched but not enabled yet.
}
39
/// Boolean tensor primitive that may live on any of the enabled backends.
///
/// Every variant wraps that backend's own bool tensor primitive and is only
/// compiled in when the matching `burn-*` cargo feature is enabled.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum MultiBoolTensor {
    /// Bool tensor owned by the Candle backend.
    #[cfg(feature = "burn-candle")]
    Candle(BoolTensor<CandleBackend>),
    /// Bool tensor owned by the NdArray backend.
    #[cfg(feature = "burn-ndarray")]
    NdArray(BoolTensor<NdArrayBackend>),
    /// Bool tensor owned by the Wgpu backend.
    #[cfg(feature = "burn-wgpu")]
    Wgpu(BoolTensor<WgpuBackend>),
    // TODO: an `Autodiff(BoolTensor<burn_autodiff::Autodiff<MultiBackend>>)`
    // variant gated on the `autodiff` feature is sketched but not enabled yet.
}
52
53// TensorMetadata implementations
54impl TensorMetadata for MultiFloatTensor {
55    fn dtype(&self) -> DType {
56        DType::F32
57    }
58
59    fn shape(&self) -> Shape {
60        match self {
61            #[cfg(feature = "burn-candle")]
62            MultiFloatTensor::Candle(t) => t.shape(),
63            #[cfg(feature = "burn-ndarray")]
64            MultiFloatTensor::NdArray(t) => t.shape(),
65            #[cfg(feature = "burn-wgpu")]
66            MultiFloatTensor::Wgpu(t) => t.shape(),
67            // #[cfg(feature = "autodiff")]
68            // MultiFloatTensor::Autodiff(t) => t.shape(),
69        }
70    }
71}
72
73impl TensorMetadata for MultiIntTensor {
74    fn dtype(&self) -> DType {
75        match self {
76            #[cfg(feature = "burn-candle")]
77            MultiIntTensor::Candle(_) => DType::I64,
78            #[cfg(feature = "burn-ndarray")]
79            MultiIntTensor::NdArray(_) => DType::I32,
80            #[cfg(feature = "burn-wgpu")]
81            MultiIntTensor::Wgpu(_) => DType::I32,
82        }
83    }
84
85    fn shape(&self) -> Shape {
86        match self {
87            #[cfg(feature = "burn-candle")]
88            MultiIntTensor::Candle(t) => t.shape(),
89            #[cfg(feature = "burn-ndarray")]
90            MultiIntTensor::NdArray(t) => t.shape(),
91            #[cfg(feature = "burn-wgpu")]
92            MultiIntTensor::Wgpu(t) => t.shape(),
93            // #[cfg(feature = "autodiff")]
94            // MultiIntTensor::Autodiff(t) => t.shape(),
95        }
96    }
97}
98
99impl TensorMetadata for MultiBoolTensor {
100    fn dtype(&self) -> DType {
101        DType::U8
102    }
103
104    fn shape(&self) -> Shape {
105        match self {
106            #[cfg(feature = "burn-candle")]
107            MultiBoolTensor::Candle(t) => t.shape(),
108            #[cfg(feature = "burn-ndarray")]
109            MultiBoolTensor::NdArray(t) => t.shape(),
110            #[cfg(feature = "burn-wgpu")]
111            MultiBoolTensor::Wgpu(t) => t.shape(),
112            // #[cfg(feature = "autodiff")]
113            // MultiIntTensor::Autodiff(t) => t.shape(),
114        }
115    }
116}
117
/// Quantization hook for [`MultiIntTensor`]; quantization is not supported.
///
/// NOTE(review): it is not clear from this file why `QTensorPrimitive` is
/// implemented on the *int* tensor wrapper. Presumably the multi-backend uses
/// it as a placeholder quantized primitive; confirm against the `Backend` impl.
impl QTensorPrimitive for MultiIntTensor {
    /// Would return the tensor's quantization scheme, but this wrapper stores
    /// none, so it always panics.
    ///
    /// # Panics
    ///
    /// Always, via `unimplemented!`.
    fn scheme(&self) -> &QuantScheme {
        unimplemented!("Quantization is not supported")
    }
}
122}