burn_tensor/tensor/api/
kind.rs

1use crate::{backend::Backend, DType, Shape};
2
/// Marker type that identifies the float tensor kind at the type level.
#[derive(Debug, Clone)]
pub struct Float;
6
/// Marker type that identifies the int tensor kind at the type level.
#[derive(Debug, Clone)]
pub struct Int;
10
/// Marker type that identifies the bool tensor kind at the type level.
#[derive(Debug, Clone)]
pub struct Bool;
14
15#[derive(Debug, Clone)]
16/// A primitive tensor representation.
17pub enum TensorPrimitive<B: Backend> {
18    /// Float tensor primitive.
19    Float(B::FloatTensorPrimitive),
20    /// Quantized float tensor primitive.
21    QFloat(B::QuantizedTensorPrimitive),
22}
23
24impl<B: Backend> TensorPrimitive<B> {
25    /// Returns the full tensor representation.
26    pub fn tensor(self) -> B::FloatTensorPrimitive {
27        match self {
28            Self::QFloat(tensor) => B::dequantize(tensor),
29            Self::Float(tensor) => tensor,
30        }
31    }
32}
33
34impl<B: Backend> TensorMetadata for TensorPrimitive<B> {
35    fn dtype(&self) -> DType {
36        match self {
37            TensorPrimitive::Float(tensor) => tensor.dtype(),
38            TensorPrimitive::QFloat(tensor) => tensor.dtype(),
39        }
40    }
41
42    fn shape(&self) -> Shape {
43        match self {
44            TensorPrimitive::Float(tensor) => tensor.shape(),
45            TensorPrimitive::QFloat(tensor) => tensor.shape(),
46        }
47    }
48}
49
50/// Tensor metadata trait for tensor primitive.
51pub trait TensorMetadata: Clone + Send + Sync + core::fmt::Debug {
52    /// The dtype of the tensor.
53    fn dtype(&self) -> DType;
54    /// The shape of the tensor.
55    fn shape(&self) -> Shape;
56}
57
58/// A type-level representation of the kind of a tensor.
59/// Metadata access is lazy.
60pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
61    /// The primitive type of the tensor.
62    type Primitive: TensorMetadata;
63
64    /// The name of the tensor kind.
65    fn name() -> &'static str;
66}
67
68impl<B: Backend> TensorKind<B> for Float {
69    type Primitive = TensorPrimitive<B>;
70    fn name() -> &'static str {
71        "Float"
72    }
73}
74
75impl<B: Backend> TensorKind<B> for Int {
76    type Primitive = B::IntTensorPrimitive;
77    fn name() -> &'static str {
78        "Int"
79    }
80}
81
82impl<B: Backend> TensorKind<B> for Bool {
83    type Primitive = B::BoolTensorPrimitive;
84    fn name() -> &'static str {
85        "Bool"
86    }
87}