1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
use crate::backend::Backend;

/// Zero-sized marker type selecting the `Float` tensor kind.
///
/// It carries no data and is only used at the type level: its [`TensorKind`]
/// impl resolves the tensor primitive to the backend's `TensorPrimitive`.
/// `Copy`, `Default`, `PartialEq`, `Eq` and `Hash` are derived because they are
/// free for a unit struct and let generic code copy, compare or construct the
/// marker without extra bounds.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Float;
/// Zero-sized marker type selecting the `Int` tensor kind.
///
/// It carries no data and is only used at the type level: its [`TensorKind`]
/// impl resolves the tensor primitive to the backend's `IntTensorPrimitive`.
/// `Copy`, `Default`, `PartialEq`, `Eq` and `Hash` are derived because they are
/// free for a unit struct and let generic code copy, compare or construct the
/// marker without extra bounds.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Int;
/// Zero-sized marker type selecting the `Bool` tensor kind.
///
/// It carries no data and is only used at the type level: its [`TensorKind`]
/// impl resolves the tensor primitive to the backend's `BoolTensorPrimitive`.
/// `Copy`, `Default`, `PartialEq`, `Eq` and `Hash` are derived because they are
/// free for a unit struct and let generic code copy, compare or construct the
/// marker without extra bounds.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Bool;

/// Associates a tensor-kind marker type with the primitive tensor type that a
/// backend `B` uses to store tensors of that kind.
///
/// Implementors must themselves be `Clone + Debug`, and so must every
/// `Primitive<D>` they name.
pub trait TensorKind<B>: Clone + core::fmt::Debug
where
    B: Backend,
{
    /// Backend-specific primitive type for this kind, parameterized by the
    /// const `D` (presumably the tensor rank — confirm against `Backend`).
    type Primitive<const D: usize>: Clone + core::fmt::Debug;
}

/// `Float` tensors are stored in the backend's `TensorPrimitive`; `D` is
/// forwarded unchanged.
impl<B> TensorKind<B> for Float
where
    B: Backend,
{
    type Primitive<const D: usize> = <B as Backend>::TensorPrimitive<D>;
}

/// `Int` tensors are stored in the backend's `IntTensorPrimitive`; `D` is
/// forwarded unchanged.
impl<B> TensorKind<B> for Int
where
    B: Backend,
{
    type Primitive<const D: usize> = <B as Backend>::IntTensorPrimitive<D>;
}

/// `Bool` tensors are stored in the backend's `BoolTensorPrimitive`; `D` is
/// forwarded unchanged.
impl<B> TensorKind<B> for Bool
where
    B: Backend,
{
    type Primitive<const D: usize> = <B as Backend>::BoolTensorPrimitive<D>;
}