burn_tensor/tensor/api/
kind.rs1use crate::{DType, Shape, backend::Backend};
2
/// Marker type identifying the float tensor kind.
///
/// The struct carries no data; it is only used at the type level through its
/// [`TensorKind`] implementation.
#[derive(Clone, Debug)]
pub struct Float;
6
/// Marker type identifying the int tensor kind.
///
/// The struct carries no data; it is only used at the type level through its
/// [`TensorKind`] implementation.
#[derive(Clone, Debug)]
pub struct Int;
10
/// Marker type identifying the bool tensor kind.
///
/// The struct carries no data; it is only used at the type level through its
/// [`TensorKind`] implementation.
#[derive(Clone, Debug)]
pub struct Bool;
14
15#[derive(Debug, Clone)]
16pub enum TensorPrimitive<B: Backend> {
18 Float(B::FloatTensorPrimitive),
20 QFloat(B::QuantizedTensorPrimitive),
22}
23
24impl<B: Backend> TensorPrimitive<B> {
25 pub fn tensor(self) -> B::FloatTensorPrimitive {
27 match self {
28 Self::QFloat(tensor) => B::dequantize(tensor),
29 Self::Float(tensor) => tensor,
30 }
31 }
32}
33
34impl<B: Backend> TensorMetadata for TensorPrimitive<B> {
35 fn dtype(&self) -> DType {
36 match self {
37 TensorPrimitive::Float(tensor) => tensor.dtype(),
38 TensorPrimitive::QFloat(tensor) => tensor.dtype(),
39 }
40 }
41
42 fn shape(&self) -> Shape {
43 match self {
44 TensorPrimitive::Float(tensor) => tensor.shape(),
45 TensorPrimitive::QFloat(tensor) => tensor.shape(),
46 }
47 }
48
49 fn rank(&self) -> usize {
50 match self {
51 TensorPrimitive::Float(tensor) => tensor.rank(),
52 TensorPrimitive::QFloat(tensor) => tensor.rank(),
53 }
54 }
55}
56
57pub trait TensorMetadata: Clone + Send + Sync + core::fmt::Debug {
59 fn dtype(&self) -> DType;
61 fn shape(&self) -> Shape;
63
64 fn rank(&self) -> usize {
66 self.shape().num_dims()
67 }
68}
69
70pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
73 type Primitive: TensorMetadata;
75
76 fn name() -> &'static str;
78}
79
80impl<B: Backend> TensorKind<B> for Float {
81 type Primitive = TensorPrimitive<B>;
82 fn name() -> &'static str {
83 "Float"
84 }
85}
86
87impl<B: Backend> TensorKind<B> for Int {
88 type Primitive = B::IntTensorPrimitive;
89 fn name() -> &'static str {
90 "Int"
91 }
92}
93
94impl<B: Backend> TensorKind<B> for Bool {
95 type Primitive = B::BoolTensorPrimitive;
96 fn name() -> &'static str {
97 "Bool"
98 }
99}