burn_candle/
tensor.rs

1use burn_backend::{DType, FloatDType, IntDType, Shape, quantization::QuantScheme};
2use burn_backend::{Element, QTensorPrimitive, TensorData, TensorMetadata};
3
4use crate::{CandleDevice, element::CandleElement};
5
6/// A tensor that uses the candle backend.
7#[derive(Debug, Clone)]
8pub struct CandleTensor {
9    pub(crate) tensor: candle_core::Tensor,
10}
11
12impl TensorMetadata for CandleTensor {
13    fn dtype(&self) -> DType {
14        match self.tensor.dtype() {
15            candle_core::DType::U8 => DType::U8,
16            candle_core::DType::U32 => DType::U32,
17            candle_core::DType::I64 => DType::I64,
18            candle_core::DType::BF16 => DType::BF16,
19            candle_core::DType::F16 => DType::F16,
20            candle_core::DType::F32 => DType::F32,
21            candle_core::DType::F64 => DType::F64,
22        }
23    }
24
25    fn shape(&self) -> Shape {
26        Shape::from(self.tensor.dims().to_vec())
27    }
28
29    fn rank(&self) -> usize {
30        self.tensor.dims().len()
31    }
32}
33
34impl QTensorPrimitive for CandleTensor {
35    fn scheme(&self) -> &QuantScheme {
36        unimplemented!("Quantization is not supported")
37    }
38}
39
40impl CandleTensor {
41    /// Create a new tensor.
42    pub fn new(tensor: candle_core::Tensor) -> Self {
43        Self { tensor }
44    }
45
46    /// Creates a new tensor from data and a device.
47    ///
48    /// # Arguments
49    ///
50    /// * `data` - The tensor's data.
51    /// * `device` - The device on which the tensor will be allocated.
52    ///
53    /// # Returns
54    ///
55    /// A new tensor.
56    pub fn from_data<E: CandleElement>(data: TensorData, device: CandleDevice) -> Self {
57        let candle_shape: candle_core::Shape = data.shape.clone().into();
58        let tensor = candle_core::Tensor::from_slice(
59            data.as_slice::<E>().unwrap(),
60            candle_shape,
61            &device.into(),
62        );
63        Self::new(tensor.unwrap())
64    }
65}
66
67pub(crate) trait IntoDType {
68    fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error>;
69
70    fn into_dtype(self) -> candle_core::DType
71    where
72        Self: Sized,
73    {
74        self.try_into_dtype().unwrap()
75    }
76}
77
78impl IntoDType for IntDType {
79    fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {
80        let dtype: DType = self.into();
81        dtype.try_into_dtype()
82    }
83}
84
85impl IntoDType for FloatDType {
86    fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {
87        let dtype: DType = self.into();
88        dtype.try_into_dtype()
89    }
90}
91
92impl IntoDType for DType {
93    fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {
94        match self {
95            DType::F64 => Ok(candle_core::DType::F64),
96            DType::F32 => Ok(candle_core::DType::F32),
97            DType::Flex32 => Ok(candle_core::DType::F32),
98            DType::F16 => Ok(candle_core::DType::F16),
99            DType::BF16 => Ok(candle_core::DType::BF16),
100            DType::I64 => Ok(candle_core::DType::I64),
101            DType::U32 => Ok(candle_core::DType::U32),
102            DType::U8 => Ok(candle_core::DType::U8),
103            // DType::Bool => Ok(candle_core::DType::U8),
104            _ => Err(candle_core::Error::Msg(format!(
105                "Unsupported dtype {self:?}"
106            ))),
107        }
108    }
109}