Skip to main content

burn_backend/backend/
primitive.rs

1use crate::{Backend, get_device_settings};
2use burn_std::quantization::{QuantAcc, QuantPropagation, QuantScheme};
3use burn_std::{DType, Shape};
4
5#[derive(Debug, Clone)]
6/// A primitive tensor representation.
7pub enum TensorPrimitive<B: Backend> {
8    /// Float tensor primitive.
9    Float(B::FloatTensorPrimitive),
10    /// Quantized float tensor primitive.
11    QFloat(B::QuantizedTensorPrimitive),
12}
13
14impl<B: Backend> TensorPrimitive<B> {
15    /// Returns the full tensor representation.
16    pub fn tensor(self) -> B::FloatTensorPrimitive {
17        match self {
18            Self::QFloat(tensor) => {
19                let dtype = get_device_settings::<B>(&B::q_device(&tensor)).float_dtype;
20                B::dequantize(tensor, dtype)
21            }
22            Self::Float(tensor) => tensor,
23        }
24    }
25
26    /// Returns a mutable reference to the full tensor representation.
27    pub fn get_mut_ref(&mut self) -> &mut B::FloatTensorPrimitive {
28        match self {
29            Self::QFloat(_tensor) => todo!(),
30            Self::Float(tensor) => tensor,
31        }
32    }
33}
34
35impl<B: Backend> TensorMetadata for TensorPrimitive<B> {
36    fn dtype(&self) -> DType {
37        match self {
38            TensorPrimitive::Float(tensor) => tensor.dtype(),
39            TensorPrimitive::QFloat(tensor) => tensor.dtype(),
40        }
41    }
42
43    fn shape(&self) -> Shape {
44        match self {
45            TensorPrimitive::Float(tensor) => tensor.shape(),
46            TensorPrimitive::QFloat(tensor) => tensor.shape(),
47        }
48    }
49
50    fn rank(&self) -> usize {
51        match self {
52            TensorPrimitive::Float(tensor) => tensor.rank(),
53            TensorPrimitive::QFloat(tensor) => tensor.rank(),
54        }
55    }
56}
57
58/// Tensor metadata trait for tensor primitive.
59pub trait TensorMetadata: Clone + Send + Sync + core::fmt::Debug {
60    /// The dtype of the tensor.
61    fn dtype(&self) -> DType;
62    /// The shape of the tensor.
63    fn shape(&self) -> Shape;
64
65    /// The number of dimensions of the tensor.
66    fn rank(&self) -> usize {
67        self.shape().num_dims()
68    }
69}
70
71/// Quantized tensor primitive.
72pub trait QTensorPrimitive {
73    /// Returns the quantization settings for the given tensor.
74    fn scheme(&self) -> &QuantScheme;
75    /// The precision used for the accumulation in various kernels.
76    fn acc_precision(&self) -> QuantAcc {
77        QuantAcc::F32
78    }
79    /// How quantization is propagated during computation.
80    fn propagation(&self) -> QuantPropagation {
81        QuantPropagation::Inhibit
82    }
83
84    /// Returns the default tensor quantization scheme.
85    fn default_scheme() -> QuantScheme {
86        QuantScheme::default()
87    }
88}