gloss_burn_multibackend/
tensor.rs1#[cfg(feature = "burn-candle")]
2use crate::backend::CandleBackend;
3#[cfg(feature = "burn-ndarray")]
4use crate::backend::NdArrayBackend;
5#[cfg(feature = "burn-wgpu")]
6use crate::backend::WgpuBackend;
7
8use burn::tensor::{
9 ops::{BoolTensor, FloatTensor, IntTensor},
10 quantization::{QTensorPrimitive, QuantScheme},
11 DType, Shape, TensorMetadata,
12};
13
/// Float tensor primitive that wraps a float tensor from one of the enabled backends.
///
/// There is one variant per backend feature (`burn-candle`, `burn-ndarray`,
/// `burn-wgpu`); variants whose feature is disabled are compiled out.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum MultiFloatTensor {
    /// Holds a `FloatTensor<CandleBackend>` primitive.
    #[cfg(feature = "burn-candle")]
    Candle(FloatTensor<CandleBackend>),
    /// Holds a `FloatTensor<NdArrayBackend>` primitive.
    #[cfg(feature = "burn-ndarray")]
    NdArray(FloatTensor<NdArrayBackend>),
    /// Holds a `FloatTensor<WgpuBackend>` primitive.
    #[cfg(feature = "burn-wgpu")]
    Wgpu(FloatTensor<WgpuBackend>),
}
26
/// Integer tensor primitive that wraps an int tensor from one of the enabled backends.
///
/// There is one variant per backend feature (`burn-candle`, `burn-ndarray`,
/// `burn-wgpu`); variants whose feature is disabled are compiled out.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum MultiIntTensor {
    /// Holds an `IntTensor<CandleBackend>` primitive.
    #[cfg(feature = "burn-candle")]
    Candle(IntTensor<CandleBackend>),
    /// Holds an `IntTensor<NdArrayBackend>` primitive.
    #[cfg(feature = "burn-ndarray")]
    NdArray(IntTensor<NdArrayBackend>),
    /// Holds an `IntTensor<WgpuBackend>` primitive.
    #[cfg(feature = "burn-wgpu")]
    Wgpu(IntTensor<WgpuBackend>),
}
39
/// Boolean tensor primitive that wraps a bool tensor from one of the enabled backends.
///
/// There is one variant per backend feature (`burn-candle`, `burn-ndarray`,
/// `burn-wgpu`); variants whose feature is disabled are compiled out.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum MultiBoolTensor {
    /// Holds a `BoolTensor<CandleBackend>` primitive.
    #[cfg(feature = "burn-candle")]
    Candle(BoolTensor<CandleBackend>),
    /// Holds a `BoolTensor<NdArrayBackend>` primitive.
    #[cfg(feature = "burn-ndarray")]
    NdArray(BoolTensor<NdArrayBackend>),
    /// Holds a `BoolTensor<WgpuBackend>` primitive.
    #[cfg(feature = "burn-wgpu")]
    Wgpu(BoolTensor<WgpuBackend>),
}
52
53impl TensorMetadata for MultiFloatTensor {
55 fn dtype(&self) -> DType {
56 DType::F32
57 }
58
59 fn shape(&self) -> Shape {
60 match self {
61 #[cfg(feature = "burn-candle")]
62 MultiFloatTensor::Candle(t) => t.shape(),
63 #[cfg(feature = "burn-ndarray")]
64 MultiFloatTensor::NdArray(t) => t.shape(),
65 #[cfg(feature = "burn-wgpu")]
66 MultiFloatTensor::Wgpu(t) => t.shape(),
67 }
70 }
71}
72
73impl TensorMetadata for MultiIntTensor {
74 fn dtype(&self) -> DType {
75 match self {
76 #[cfg(feature = "burn-candle")]
77 MultiIntTensor::Candle(_) => DType::I64,
78 #[cfg(feature = "burn-ndarray")]
79 MultiIntTensor::NdArray(_) => DType::I32,
80 #[cfg(feature = "burn-wgpu")]
81 MultiIntTensor::Wgpu(_) => DType::I32,
82 }
83 }
84
85 fn shape(&self) -> Shape {
86 match self {
87 #[cfg(feature = "burn-candle")]
88 MultiIntTensor::Candle(t) => t.shape(),
89 #[cfg(feature = "burn-ndarray")]
90 MultiIntTensor::NdArray(t) => t.shape(),
91 #[cfg(feature = "burn-wgpu")]
92 MultiIntTensor::Wgpu(t) => t.shape(),
93 }
96 }
97}
98
99impl TensorMetadata for MultiBoolTensor {
100 fn dtype(&self) -> DType {
101 DType::U8
102 }
103
104 fn shape(&self) -> Shape {
105 match self {
106 #[cfg(feature = "burn-candle")]
107 MultiBoolTensor::Candle(t) => t.shape(),
108 #[cfg(feature = "burn-ndarray")]
109 MultiBoolTensor::NdArray(t) => t.shape(),
110 #[cfg(feature = "burn-wgpu")]
111 MultiBoolTensor::Wgpu(t) => t.shape(),
112 }
115 }
116}
117
118impl QTensorPrimitive for MultiIntTensor {
119 fn scheme(&self) -> &QuantScheme {
120 unimplemented!("Quantization is not supported")
121 }
122}