use alloc::vec::Vec;

use burn_backend::{DType, QTensorPrimitive, TensorMetadata, quantization::QuantStore};
use burn_std::{QuantScheme, Shape};

use crate::tensor::FlexTensor;
7
/// A quantized tensor: `i8` element data together with the quantization
/// scheme and the scale factor(s) needed to recover float values.
#[derive(Clone, Debug)]
pub struct FlexQTensor {
    /// Backing tensor holding the quantized values; `new` enforces dtype `i8`.
    pub(crate) tensor: FlexTensor,
    /// Scheme describing how the values were quantized.
    pub(crate) scheme: QuantScheme,
    /// Scale factor(s) for dequantization; never empty (enforced by `new`).
    /// Granularity (per-tensor vs. per-block) is presumably determined by
    /// `scheme` — not visible here, confirm against the scheme definition.
    pub(crate) scales: Vec<f32>,
}
21
22impl FlexQTensor {
23 pub fn new(tensor: FlexTensor, scheme: QuantScheme, scales: Vec<f32>) -> Self {
27 assert_eq!(
28 tensor.dtype(),
29 DType::I8,
30 "quantized tensor must store i8 data, got {:?}",
31 tensor.dtype()
32 );
33 assert!(
34 !scales.is_empty(),
35 "quantized tensor must have at least one scale factor"
36 );
37 Self {
38 tensor,
39 scheme,
40 scales,
41 }
42 }
43
44 pub fn tensor(&self) -> &FlexTensor {
46 &self.tensor
47 }
48
49 pub fn scales(&self) -> &[f32] {
51 &self.scales
52 }
53}
54
55impl QTensorPrimitive for FlexQTensor {
56 fn scheme(&self) -> &QuantScheme {
57 &self.scheme
58 }
59
60 fn default_scheme() -> QuantScheme {
61 QuantScheme::default().with_store(QuantStore::Native)
62 }
63}
64
65impl TensorMetadata for FlexQTensor {
66 fn dtype(&self) -> DType {
67 DType::QFloat(self.scheme)
68 }
69
70 fn shape(&self) -> Shape {
71 self.tensor.shape()
72 }
73
74 fn rank(&self) -> usize {
75 self.tensor.rank()
76 }
77}