1use burn_backend::{DType, FloatDType, IntDType, Shape, quantization::QuantScheme};
2use burn_backend::{Element, QTensorPrimitive, TensorData, TensorMetadata};
3
4use crate::{CandleDevice, element::CandleElement};
5
6#[derive(Debug, Clone)]
8pub struct CandleTensor {
9 pub(crate) tensor: candle_core::Tensor,
10}
11
12impl TensorMetadata for CandleTensor {
13 fn dtype(&self) -> DType {
14 match self.tensor.dtype() {
15 candle_core::DType::U8 => DType::U8,
16 candle_core::DType::U32 => DType::U32,
17 candle_core::DType::I64 => DType::I64,
18 candle_core::DType::BF16 => DType::BF16,
19 candle_core::DType::F16 => DType::F16,
20 candle_core::DType::F32 => DType::F32,
21 candle_core::DType::F64 => DType::F64,
22 }
23 }
24
25 fn shape(&self) -> Shape {
26 Shape::from(self.tensor.dims().to_vec())
27 }
28
29 fn rank(&self) -> usize {
30 self.tensor.dims().len()
31 }
32}
33
34impl QTensorPrimitive for CandleTensor {
35 fn scheme(&self) -> &QuantScheme {
36 unimplemented!("Quantization is not supported")
37 }
38}
39
40impl CandleTensor {
41 pub fn new(tensor: candle_core::Tensor) -> Self {
43 Self { tensor }
44 }
45
46 pub fn from_data<E: CandleElement>(data: TensorData, device: CandleDevice) -> Self {
57 let candle_shape: candle_core::Shape = data.shape.clone().into();
58 let tensor = candle_core::Tensor::from_slice(
59 data.as_slice::<E>().unwrap(),
60 candle_shape,
61 &device.into(),
62 );
63 Self::new(tensor.unwrap())
64 }
65}
66
67pub(crate) trait IntoDType {
68 fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error>;
69
70 fn into_dtype(self) -> candle_core::DType
71 where
72 Self: Sized,
73 {
74 self.try_into_dtype().unwrap()
75 }
76}
77
78impl IntoDType for IntDType {
79 fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {
80 let dtype: DType = self.into();
81 dtype.try_into_dtype()
82 }
83}
84
85impl IntoDType for FloatDType {
86 fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {
87 let dtype: DType = self.into();
88 dtype.try_into_dtype()
89 }
90}
91
92impl IntoDType for DType {
93 fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {
94 match self {
95 DType::F64 => Ok(candle_core::DType::F64),
96 DType::F32 => Ok(candle_core::DType::F32),
97 DType::Flex32 => Ok(candle_core::DType::F32),
98 DType::F16 => Ok(candle_core::DType::F16),
99 DType::BF16 => Ok(candle_core::DType::BF16),
100 DType::I64 => Ok(candle_core::DType::I64),
101 DType::U32 => Ok(candle_core::DType::U32),
102 DType::U8 => Ok(candle_core::DType::U8),
103 _ => Err(candle_core::Error::Msg(format!(
105 "Unsupported dtype {self:?}"
106 ))),
107 }
108 }
109}