1use burn_tensor::{
2 quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
3 DType, Element, Shape, TensorData, TensorMetadata,
4};
5
6use crate::{element::CandleElement, CandleDevice};
7
8#[derive(Debug, Clone)]
10pub struct CandleTensor {
11 pub(crate) tensor: candle_core::Tensor,
12}
13
14impl TensorMetadata for CandleTensor {
15 fn dtype(&self) -> DType {
16 match self.tensor.dtype() {
17 candle_core::DType::U8 => DType::U8,
18 candle_core::DType::U32 => DType::U32,
19 candle_core::DType::I64 => DType::I64,
20 candle_core::DType::BF16 => DType::BF16,
21 candle_core::DType::F16 => DType::F16,
22 candle_core::DType::F32 => DType::F32,
23 candle_core::DType::F64 => DType::F64,
24 }
25 }
26
27 fn shape(&self) -> Shape {
28 Shape::from(self.tensor.dims().to_vec())
29 }
30}
31
32impl CandleTensor {
33 pub fn new(tensor: candle_core::Tensor) -> Self {
35 Self { tensor }
36 }
37
38 pub fn from_data<E: CandleElement>(data: TensorData, device: CandleDevice) -> Self {
49 let candle_shape: candle_core::Shape = data.shape.clone().into();
50 let tensor = candle_core::Tensor::from_slice(
51 data.convert::<E>().as_slice::<E>().unwrap(),
52 candle_shape,
53 &device.into(),
54 );
55 Self::new(tensor.unwrap())
56 }
57}
58
59#[derive(Clone, Debug)]
61pub struct CandleQTensor {
62 pub qtensor: CandleTensor,
65 pub scheme: QuantizationScheme,
67}
68
69impl QTensorPrimitive for CandleQTensor {
70 fn scheme(&self) -> &QuantizationScheme {
71 &self.scheme
72 }
73}
74
75impl TensorMetadata for CandleQTensor {
76 fn dtype(&self) -> DType {
77 DType::QFloat(self.scheme)
78 }
79
80 fn shape(&self) -> Shape {
81 self.qtensor.shape()
82 }
83}