use crate::{backend::Backend, DType, Shape};
/// Marker type used as the [`TensorKind`] for floating-point tensors.
#[derive(Debug, Clone)]
pub struct Float;
/// Marker type used as the [`TensorKind`] for integer tensors.
#[derive(Debug, Clone)]
pub struct Int;
/// Marker type used as the [`TensorKind`] for boolean tensors.
#[derive(Debug, Clone)]
pub struct Bool;
/// Primitive backing a float tensor: either a plain float primitive or a
/// quantized primitive of backend `B`.
#[derive(Clone, Debug)]
pub enum TensorPrimitive<B: Backend> {
    /// A regular (non-quantized) float tensor primitive.
    Float(B::FloatTensorPrimitive),
    /// A quantized tensor primitive.
    QFloat(B::QuantizedTensorPrimitive),
}
impl<B> TensorPrimitive<B>
where
    B: Backend,
{
    /// Consumes the primitive and returns a plain float tensor primitive.
    ///
    /// A `Float` primitive is returned as-is. A `QFloat` primitive is first
    /// converted with `B::dequantize`.
    pub fn tensor(self) -> B::FloatTensorPrimitive {
        match self {
            Self::Float(float) => float,
            Self::QFloat(quantized) => B::dequantize(quantized),
        }
    }
}
impl<B> TensorMetadata for TensorPrimitive<B>
where
    B: Backend,
{
    /// Returns the dtype reported by the wrapped primitive, whichever variant it is.
    fn dtype(&self) -> DType {
        match self {
            Self::Float(inner) => inner.dtype(),
            Self::QFloat(inner) => inner.dtype(),
        }
    }

    /// Returns the shape reported by the wrapped primitive, whichever variant it is.
    fn shape(&self) -> Shape {
        match self {
            Self::Float(inner) => inner.shape(),
            Self::QFloat(inner) => inner.shape(),
        }
    }
}
/// Metadata queries shared by tensor primitives.
///
/// Implementors must be cloneable, thread-transferable (`Send + Sync`) and
/// debug-printable.
pub trait TensorMetadata
where
    Self: Clone + Send + Sync + core::fmt::Debug,
{
    /// Returns the data type of the tensor.
    fn dtype(&self) -> DType;

    /// Returns the shape of the tensor.
    fn shape(&self) -> Shape;
}
/// Associates a tensor kind marker (such as `Float`, `Int` or `Bool`) with the
/// primitive type that backend `B` uses to store tensors of that kind.
pub trait TensorKind<B>
where
    B: Backend,
    Self: Clone + core::fmt::Debug,
{
    /// Backend primitive storing tensors of this kind; must expose [`TensorMetadata`].
    type Primitive: TensorMetadata;

    /// Returns a static, human-readable name for this kind.
    fn name() -> &'static str;
}
impl<B> TensorKind<B> for Float
where
    B: Backend,
{
    /// Float tensors may be either plain or quantized, so they use the
    /// [`TensorPrimitive`] wrapper rather than a raw backend primitive.
    type Primitive = TensorPrimitive<B>;

    fn name() -> &'static str {
        "Float"
    }
}
impl<B> TensorKind<B> for Int
where
    B: Backend,
{
    /// Integer tensors are stored directly as the backend's integer primitive.
    type Primitive = B::IntTensorPrimitive;

    fn name() -> &'static str {
        "Int"
    }
}
impl<B> TensorKind<B> for Bool
where
    B: Backend,
{
    /// Boolean tensors are stored directly as the backend's boolean primitive.
    type Primitive = B::BoolTensorPrimitive;

    fn name() -> &'static str {
        "Bool"
    }
}