burn_backend/tensor/
kind.rs1use crate::{Backend, TensorMetadata, TensorPrimitive};
2
/// Zero-sized marker type selecting the float tensor kind.
///
/// Its [`TensorKind`] implementation uses `TensorPrimitive<B>` as the primitive and
/// reports the name `"Float"`. Being a fieldless unit struct, it is trivially
/// `Copy`, `Default`, comparable and hashable, so those traits are derived.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Float;
6
/// Zero-sized marker type selecting the integer tensor kind.
///
/// Its [`TensorKind`] implementation uses the backend's `IntTensorPrimitive` as the
/// primitive and reports the name `"Int"`. Being a fieldless unit struct, it is
/// trivially `Copy`, `Default`, comparable and hashable, so those traits are derived.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Int;
10
/// Zero-sized marker type selecting the boolean tensor kind.
///
/// Its [`TensorKind`] implementation uses the backend's `BoolTensorPrimitive` as the
/// primitive and reports the name `"Bool"`. Being a fieldless unit struct, it is
/// trivially `Copy`, `Default`, comparable and hashable, so those traits are derived.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Bool;
14
15pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
18 type Primitive: TensorMetadata;
20
21 fn name() -> &'static str;
23}
24
25impl<B: Backend> TensorKind<B> for Float {
26 type Primitive = TensorPrimitive<B>;
27 fn name() -> &'static str {
28 "Float"
29 }
30}
31
32impl<B: Backend> TensorKind<B> for Int {
33 type Primitive = B::IntTensorPrimitive;
34 fn name() -> &'static str {
35 "Int"
36 }
37}
38
39impl<B: Backend> TensorKind<B> for Bool {
40 type Primitive = B::BoolTensorPrimitive;
41 fn name() -> &'static str {
42 "Bool"
43 }
44}