burn_tensor/tensor/api/
kind.rs

1use crate::{DType, Shape, backend::Backend};
2
/// A type-level representation of the kind of a float tensor.
#[derive(Clone, Debug)]
pub struct Float;
6
/// A type-level representation of the kind of an int tensor.
#[derive(Clone, Debug)]
pub struct Int;
10
/// Zero-sized, type-level marker selecting the bool tensor kind.
#[derive(Clone)]
#[derive(Debug)]
pub struct Bool;
14
15#[derive(Debug, Clone)]
16/// A primitive tensor representation.
17pub enum TensorPrimitive<B: Backend> {
18    /// Float tensor primitive.
19    Float(B::FloatTensorPrimitive),
20    /// Quantized float tensor primitive.
21    QFloat(B::QuantizedTensorPrimitive),
22}
23
24impl<B: Backend> TensorPrimitive<B> {
25    /// Returns the full tensor representation.
26    pub fn tensor(self) -> B::FloatTensorPrimitive {
27        match self {
28            Self::QFloat(tensor) => B::dequantize(tensor),
29            Self::Float(tensor) => tensor,
30        }
31    }
32}
33
34impl<B: Backend> TensorMetadata for TensorPrimitive<B> {
35    fn dtype(&self) -> DType {
36        match self {
37            TensorPrimitive::Float(tensor) => tensor.dtype(),
38            TensorPrimitive::QFloat(tensor) => tensor.dtype(),
39        }
40    }
41
42    fn shape(&self) -> Shape {
43        match self {
44            TensorPrimitive::Float(tensor) => tensor.shape(),
45            TensorPrimitive::QFloat(tensor) => tensor.shape(),
46        }
47    }
48
49    fn rank(&self) -> usize {
50        match self {
51            TensorPrimitive::Float(tensor) => tensor.rank(),
52            TensorPrimitive::QFloat(tensor) => tensor.rank(),
53        }
54    }
55}
56
57/// Tensor metadata trait for tensor primitive.
58pub trait TensorMetadata: Clone + Send + Sync + core::fmt::Debug {
59    /// The dtype of the tensor.
60    fn dtype(&self) -> DType;
61    /// The shape of the tensor.
62    fn shape(&self) -> Shape;
63
64    /// The number of dimensions of the tensor.
65    fn rank(&self) -> usize {
66        self.shape().num_dims()
67    }
68}
69
70/// A type-level representation of the kind of a tensor.
71/// Metadata access is lazy.
72pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
73    /// The primitive type of the tensor.
74    type Primitive: TensorMetadata;
75
76    /// The name of the tensor kind.
77    fn name() -> &'static str;
78}
79
80impl<B: Backend> TensorKind<B> for Float {
81    type Primitive = TensorPrimitive<B>;
82    fn name() -> &'static str {
83        "Float"
84    }
85}
86
87impl<B: Backend> TensorKind<B> for Int {
88    type Primitive = B::IntTensorPrimitive;
89    fn name() -> &'static str {
90        "Int"
91    }
92}
93
94impl<B: Backend> TensorKind<B> for Bool {
95    type Primitive = B::BoolTensorPrimitive;
96    fn name() -> &'static str {
97        "Bool"
98    }
99}