burn_tensor/tensor/api/
kind.rs1use crate::{backend::Backend, DType, Shape};
2
/// Type-level marker for the float tensor kind.
///
/// Its [`TensorKind`] primitive is [`TensorPrimitive`], which wraps either a
/// regular or a quantized backend tensor.
#[derive(Debug, Clone)]
pub struct Float;
6
/// Type-level marker for the integer tensor kind.
///
/// Its [`TensorKind`] primitive is the backend's `IntTensorPrimitive`.
#[derive(Debug, Clone)]
pub struct Int;
10
/// Type-level marker for the boolean tensor kind.
///
/// Its [`TensorKind`] primitive is the backend's `BoolTensorPrimitive`.
#[derive(Debug, Clone)]
pub struct Bool;
14
15#[derive(Debug, Clone)]
16pub enum TensorPrimitive<B: Backend> {
18 Float(B::FloatTensorPrimitive),
20 QFloat(B::QuantizedTensorPrimitive),
22}
23
24impl<B: Backend> TensorPrimitive<B> {
25 pub fn tensor(self) -> B::FloatTensorPrimitive {
27 match self {
28 Self::QFloat(tensor) => B::dequantize(tensor),
29 Self::Float(tensor) => tensor,
30 }
31 }
32}
33
34impl<B: Backend> TensorMetadata for TensorPrimitive<B> {
35 fn dtype(&self) -> DType {
36 match self {
37 TensorPrimitive::Float(tensor) => tensor.dtype(),
38 TensorPrimitive::QFloat(tensor) => tensor.dtype(),
39 }
40 }
41
42 fn shape(&self) -> Shape {
43 match self {
44 TensorPrimitive::Float(tensor) => tensor.shape(),
45 TensorPrimitive::QFloat(tensor) => tensor.shape(),
46 }
47 }
48}
49
50pub trait TensorMetadata: Clone + Send + Sync + core::fmt::Debug {
52 fn dtype(&self) -> DType;
54 fn shape(&self) -> Shape;
56}
57
58pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
61 type Primitive: TensorMetadata;
63
64 fn name() -> &'static str;
66}
67
68impl<B: Backend> TensorKind<B> for Float {
69 type Primitive = TensorPrimitive<B>;
70 fn name() -> &'static str {
71 "Float"
72 }
73}
74
75impl<B: Backend> TensorKind<B> for Int {
76 type Primitive = B::IntTensorPrimitive;
77 fn name() -> &'static str {
78 "Int"
79 }
80}
81
82impl<B: Backend> TensorKind<B> for Bool {
83 type Primitive = B::BoolTensorPrimitive;
84 fn name() -> &'static str {
85 "Bool"
86 }
87}