use crate::{Backend, TensorMetadata, TensorPrimitive};
/// Zero-sized marker type for the `Float` tensor kind.
///
/// It carries no data; it only selects, through its `TensorKind` impl,
/// which backend primitive a tensor is stored in.
#[derive(Debug, Clone)]
pub struct Float;
/// Zero-sized marker type for the `Int` tensor kind.
///
/// It carries no data; it only selects, through its `TensorKind` impl,
/// which backend primitive a tensor is stored in.
#[derive(Debug, Clone)]
pub struct Int;
/// Zero-sized marker type for the `Bool` tensor kind.
///
/// It carries no data; it only selects, through its `TensorKind` impl,
/// which backend primitive a tensor is stored in.
#[derive(Debug, Clone)]
pub struct Bool;
/// Associates a tensor-kind marker with the backend `B` primitive type
/// that stores tensors of that kind.
pub trait TensorKind<B>: Clone + core::fmt::Debug
where
    B: Backend,
{
    /// Backend representation of a tensor of this kind; it must expose
    /// `TensorMetadata`.
    type Primitive: TensorMetadata;

    /// Static, human-readable name of this kind (e.g. `"Float"`).
    fn name() -> &'static str;
}
/// `Float` tensors are stored as the backend-generic `TensorPrimitive<B>`.
impl<B> TensorKind<B> for Float
where
    B: Backend,
{
    type Primitive = TensorPrimitive<B>;

    fn name() -> &'static str {
        "Float"
    }
}
/// `Int` tensors are stored as the backend's own `IntTensorPrimitive`.
impl<B> TensorKind<B> for Int
where
    B: Backend,
{
    type Primitive = B::IntTensorPrimitive;

    fn name() -> &'static str {
        "Int"
    }
}
/// `Bool` tensors are stored as the backend's own `BoolTensorPrimitive`.
impl<B> TensorKind<B> for Bool
where
    B: Backend,
{
    type Primitive = B::BoolTensorPrimitive;

    fn name() -> &'static str {
        "Bool"
    }
}