burn_backend/backend/
primitive.rs1use crate::Backend;
2use burn_std::quantization::{QuantAcc, QuantPropagation, QuantScheme};
3use burn_std::{DType, Shape};
4
5#[derive(Debug, Clone)]
6pub enum TensorPrimitive<B: Backend> {
8 Float(B::FloatTensorPrimitive),
10 QFloat(B::QuantizedTensorPrimitive),
12}
13
14impl<B: Backend> TensorPrimitive<B> {
15 pub fn tensor(self) -> B::FloatTensorPrimitive {
17 match self {
18 Self::QFloat(tensor) => B::dequantize(tensor),
19 Self::Float(tensor) => tensor,
20 }
21 }
22}
23
24impl<B: Backend> TensorMetadata for TensorPrimitive<B> {
25 fn dtype(&self) -> DType {
26 match self {
27 TensorPrimitive::Float(tensor) => tensor.dtype(),
28 TensorPrimitive::QFloat(tensor) => tensor.dtype(),
29 }
30 }
31
32 fn shape(&self) -> Shape {
33 match self {
34 TensorPrimitive::Float(tensor) => tensor.shape(),
35 TensorPrimitive::QFloat(tensor) => tensor.shape(),
36 }
37 }
38
39 fn rank(&self) -> usize {
40 match self {
41 TensorPrimitive::Float(tensor) => tensor.rank(),
42 TensorPrimitive::QFloat(tensor) => tensor.rank(),
43 }
44 }
45}
46
47pub trait TensorMetadata: Clone + Send + Sync + core::fmt::Debug {
49 fn dtype(&self) -> DType;
51 fn shape(&self) -> Shape;
53
54 fn rank(&self) -> usize {
56 self.shape().num_dims()
57 }
58}
59
60pub trait QTensorPrimitive {
62 fn scheme(&self) -> &QuantScheme;
64 fn acc_precision(&self) -> QuantAcc {
66 QuantAcc::F32
67 }
68 fn propagation(&self) -> QuantPropagation {
70 QuantPropagation::Inhibit
71 }
72
73 fn default_scheme() -> QuantScheme {
75 QuantScheme::default()
76 }
77}