Skip to main content

burn_flex/
qtensor.rs

1use alloc::vec::Vec;
2
3use burn_backend::{DType, QTensorPrimitive, TensorMetadata, quantization::QuantStore};
4use burn_std::{QuantScheme, Shape};
5
6use crate::tensor::FlexTensor;
7
8/// Quantized tensor for the Flex backend.
9///
10/// Stores quantized i8 values in the tensor and keeps scales separately
11/// for efficient dequantization without reparsing bytes.
12#[derive(Clone, Debug)]
13pub struct FlexQTensor {
14    /// The underlying quantized data (stored as i8).
15    pub(crate) tensor: FlexTensor,
16    /// Quantization scheme.
17    pub(crate) scheme: QuantScheme,
18    /// Per-tensor or per-block scale factors.
19    pub(crate) scales: Vec<f32>,
20}
21
22impl FlexQTensor {
23    /// Create a new quantized tensor.
24    ///
25    /// The tensor must store i8 data and scales must be non-empty.
26    pub fn new(tensor: FlexTensor, scheme: QuantScheme, scales: Vec<f32>) -> Self {
27        assert_eq!(
28            tensor.dtype(),
29            DType::I8,
30            "quantized tensor must store i8 data, got {:?}",
31            tensor.dtype()
32        );
33        assert!(
34            !scales.is_empty(),
35            "quantized tensor must have at least one scale factor"
36        );
37        Self {
38            tensor,
39            scheme,
40            scales,
41        }
42    }
43
44    /// Get the underlying tensor.
45    pub fn tensor(&self) -> &FlexTensor {
46        &self.tensor
47    }
48
49    /// Get the quantization scales.
50    pub fn scales(&self) -> &[f32] {
51        &self.scales
52    }
53}
54
55impl QTensorPrimitive for FlexQTensor {
56    fn scheme(&self) -> &QuantScheme {
57        &self.scheme
58    }
59
60    fn default_scheme() -> QuantScheme {
61        QuantScheme::default().with_store(QuantStore::Native)
62    }
63}
64
65impl TensorMetadata for FlexQTensor {
66    fn dtype(&self) -> DType {
67        DType::QFloat(self.scheme)
68    }
69
70    fn shape(&self) -> Shape {
71        self.tensor.shape()
72    }
73
74    fn rank(&self) -> usize {
75        self.tensor.rank()
76    }
77}