burn_backend/backend/
primitive.rs1use crate::{Backend, get_device_settings};
2use burn_std::quantization::{QuantAcc, QuantPropagation, QuantScheme};
3use burn_std::{DType, Shape};
4
5#[derive(Debug, Clone)]
6pub enum TensorPrimitive<B: Backend> {
8 Float(B::FloatTensorPrimitive),
10 QFloat(B::QuantizedTensorPrimitive),
12}
13
14impl<B: Backend> TensorPrimitive<B> {
15 pub fn tensor(self) -> B::FloatTensorPrimitive {
17 match self {
18 Self::QFloat(tensor) => {
19 let dtype = get_device_settings::<B>(&B::q_device(&tensor)).float_dtype;
20 B::dequantize(tensor, dtype)
21 }
22 Self::Float(tensor) => tensor,
23 }
24 }
25
26 pub fn get_mut_ref(&mut self) -> &mut B::FloatTensorPrimitive {
28 match self {
29 Self::QFloat(_tensor) => todo!(),
30 Self::Float(tensor) => tensor,
31 }
32 }
33}
34
35impl<B: Backend> TensorMetadata for TensorPrimitive<B> {
36 fn dtype(&self) -> DType {
37 match self {
38 TensorPrimitive::Float(tensor) => tensor.dtype(),
39 TensorPrimitive::QFloat(tensor) => tensor.dtype(),
40 }
41 }
42
43 fn shape(&self) -> Shape {
44 match self {
45 TensorPrimitive::Float(tensor) => tensor.shape(),
46 TensorPrimitive::QFloat(tensor) => tensor.shape(),
47 }
48 }
49
50 fn rank(&self) -> usize {
51 match self {
52 TensorPrimitive::Float(tensor) => tensor.rank(),
53 TensorPrimitive::QFloat(tensor) => tensor.rank(),
54 }
55 }
56}
57
58pub trait TensorMetadata: Clone + Send + Sync + core::fmt::Debug {
60 fn dtype(&self) -> DType;
62 fn shape(&self) -> Shape;
64
65 fn rank(&self) -> usize {
67 self.shape().num_dims()
68 }
69}
70
71pub trait QTensorPrimitive {
73 fn scheme(&self) -> &QuantScheme;
75 fn acc_precision(&self) -> QuantAcc {
77 QuantAcc::F32
78 }
79 fn propagation(&self) -> QuantPropagation {
81 QuantPropagation::Inhibit
82 }
83
84 fn default_scheme() -> QuantScheme {
86 QuantScheme::default()
87 }
88}