burn_backend/tensor/
kind.rs

1use crate::{Backend, TensorMetadata, TensorPrimitive};
2
/// Marker type that selects the float tensor kind at the type level.
///
/// It carries no data; it is only used as a type parameter.
#[derive(Debug, Clone)]
pub struct Float;
6
/// Marker type that selects the int tensor kind at the type level.
///
/// It carries no data; it is only used as a type parameter.
#[derive(Debug, Clone)]
pub struct Int;
10
/// Marker type that selects the bool tensor kind at the type level.
///
/// It carries no data; it is only used as a type parameter.
#[derive(Debug, Clone)]
pub struct Bool;
14
15/// A type-level representation of the kind of a tensor.
16/// Metadata access is lazy.
17pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
18    /// The primitive type of the tensor.
19    type Primitive: TensorMetadata;
20
21    /// The name of the tensor kind.
22    fn name() -> &'static str;
23}
24
25impl<B: Backend> TensorKind<B> for Float {
26    type Primitive = TensorPrimitive<B>;
27    fn name() -> &'static str {
28        "Float"
29    }
30}
31
32impl<B: Backend> TensorKind<B> for Int {
33    type Primitive = B::IntTensorPrimitive;
34    fn name() -> &'static str {
35        "Int"
36    }
37}
38
39impl<B: Backend> TensorKind<B> for Bool {
40    type Primitive = B::BoolTensorPrimitive;
41    fn name() -> &'static str {
42        "Bool"
43    }
44}