burn_backend/backend/
primitive.rs

1use crate::Backend;
2use burn_std::quantization::{QuantAcc, QuantPropagation, QuantScheme};
3use burn_std::{DType, Shape};
4
5#[derive(Debug, Clone)]
6/// A primitive tensor representation.
7pub enum TensorPrimitive<B: Backend> {
8    /// Float tensor primitive.
9    Float(B::FloatTensorPrimitive),
10    /// Quantized float tensor primitive.
11    QFloat(B::QuantizedTensorPrimitive),
12}
13
14impl<B: Backend> TensorPrimitive<B> {
15    /// Returns the full tensor representation.
16    pub fn tensor(self) -> B::FloatTensorPrimitive {
17        match self {
18            Self::QFloat(tensor) => B::dequantize(tensor),
19            Self::Float(tensor) => tensor,
20        }
21    }
22}
23
24impl<B: Backend> TensorMetadata for TensorPrimitive<B> {
25    fn dtype(&self) -> DType {
26        match self {
27            TensorPrimitive::Float(tensor) => tensor.dtype(),
28            TensorPrimitive::QFloat(tensor) => tensor.dtype(),
29        }
30    }
31
32    fn shape(&self) -> Shape {
33        match self {
34            TensorPrimitive::Float(tensor) => tensor.shape(),
35            TensorPrimitive::QFloat(tensor) => tensor.shape(),
36        }
37    }
38
39    fn rank(&self) -> usize {
40        match self {
41            TensorPrimitive::Float(tensor) => tensor.rank(),
42            TensorPrimitive::QFloat(tensor) => tensor.rank(),
43        }
44    }
45}
46
47/// Tensor metadata trait for tensor primitive.
48pub trait TensorMetadata: Clone + Send + Sync + core::fmt::Debug {
49    /// The dtype of the tensor.
50    fn dtype(&self) -> DType;
51    /// The shape of the tensor.
52    fn shape(&self) -> Shape;
53
54    /// The number of dimensions of the tensor.
55    fn rank(&self) -> usize {
56        self.shape().num_dims()
57    }
58}
59
60/// Quantized tensor primitive.
61pub trait QTensorPrimitive {
62    /// Returns the quantization settings for the given tensor.
63    fn scheme(&self) -> &QuantScheme;
64    /// The precision used for the accumulation in various kernels.
65    fn acc_precision(&self) -> QuantAcc {
66        QuantAcc::F32
67    }
68    /// How quantization is propagated during computation.
69    fn propagation(&self) -> QuantPropagation {
70        QuantPropagation::Inhibit
71    }
72
73    /// Returns the default tensor quantization scheme.
74    fn default_scheme() -> QuantScheme {
75        QuantScheme::default()
76    }
77}