burn_candle/
tensor.rs

1use burn_tensor::{
2    quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
3    DType, Element, Shape, TensorData, TensorMetadata,
4};
5
6use crate::{element::CandleElement, CandleDevice};
7
8/// A tensor that uses the candle backend.
9#[derive(Debug, Clone)]
10pub struct CandleTensor {
11    pub(crate) tensor: candle_core::Tensor,
12}
13
14impl TensorMetadata for CandleTensor {
15    fn dtype(&self) -> DType {
16        match self.tensor.dtype() {
17            candle_core::DType::U8 => DType::U8,
18            candle_core::DType::U32 => DType::U32,
19            candle_core::DType::I64 => DType::I64,
20            candle_core::DType::BF16 => DType::BF16,
21            candle_core::DType::F16 => DType::F16,
22            candle_core::DType::F32 => DType::F32,
23            candle_core::DType::F64 => DType::F64,
24        }
25    }
26
27    fn shape(&self) -> Shape {
28        Shape::from(self.tensor.dims().to_vec())
29    }
30}
31
32impl CandleTensor {
33    /// Create a new tensor.
34    pub fn new(tensor: candle_core::Tensor) -> Self {
35        Self { tensor }
36    }
37
38    /// Creates a new tensor from data and a device.
39    ///
40    /// # Arguments
41    ///
42    /// * `data` - The tensor's data.
43    /// * `device` - The device on which the tensor will be allocated.
44    ///
45    /// # Returns
46    ///
47    /// A new tensor.
48    pub fn from_data<E: CandleElement>(data: TensorData, device: CandleDevice) -> Self {
49        let candle_shape: candle_core::Shape = data.shape.clone().into();
50        let tensor = candle_core::Tensor::from_slice(
51            data.convert::<E>().as_slice::<E>().unwrap(),
52            candle_shape,
53            &device.into(),
54        );
55        Self::new(tensor.unwrap())
56    }
57}
58
59/// A quantized tensor for the candle backend.
60#[derive(Clone, Debug)]
61pub struct CandleQTensor {
62    /// The quantized tensor.
63    // NOTE: candle  does not implement `WithDType` for i8
64    pub qtensor: CandleTensor,
65    /// The quantization scheme.
66    pub scheme: QuantizationScheme,
67}
68
69impl QTensorPrimitive for CandleQTensor {
70    fn scheme(&self) -> &QuantizationScheme {
71        &self.scheme
72    }
73}
74
75impl TensorMetadata for CandleQTensor {
76    fn dtype(&self) -> DType {
77        DType::QFloat(self.scheme)
78    }
79
80    fn shape(&self) -> Shape {
81        self.qtensor.shape()
82    }
83}