use crate::Backend;
use burn_std::quantization::{QuantAcc, QuantPropagation, QuantScheme};
use burn_std::{DType, Shape};
/// A tensor primitive of backend `B`, in one of two representations.
///
/// The value wraps either the backend's float tensor primitive or its
/// quantized tensor primitive.
#[derive(Debug, Clone)]
pub enum TensorPrimitive<B>
where
    B: Backend,
{
    /// Wraps the backend's float tensor primitive.
    Float(B::FloatTensorPrimitive),
    /// Wraps the backend's quantized tensor primitive.
    QFloat(B::QuantizedTensorPrimitive),
}
impl<B> TensorPrimitive<B>
where
    B: Backend,
{
    /// Consumes the primitive and returns the backend's float tensor primitive.
    ///
    /// A `Float` value is returned unchanged; a `QFloat` value is first passed
    /// through `B::dequantize`.
    pub fn tensor(self) -> B::FloatTensorPrimitive {
        match self {
            Self::Float(float) => float,
            Self::QFloat(quantized) => B::dequantize(quantized),
        }
    }
}
/// Forwards every metadata query to whichever primitive the value wraps.
impl<B> TensorMetadata for TensorPrimitive<B>
where
    B: Backend,
{
    /// Returns the dtype reported by the wrapped primitive.
    fn dtype(&self) -> DType {
        match self {
            Self::Float(float) => float.dtype(),
            Self::QFloat(quantized) => quantized.dtype(),
        }
    }

    /// Returns the shape reported by the wrapped primitive.
    fn shape(&self) -> Shape {
        match self {
            Self::Float(float) => float.shape(),
            Self::QFloat(quantized) => quantized.shape(),
        }
    }

    /// Returns the rank reported by the wrapped primitive.
    ///
    /// This calls the inner primitive's own `rank` rather than deriving the
    /// value from `shape()`, so any override on the inner type is honoured.
    fn rank(&self) -> usize {
        match self {
            Self::Float(float) => float.rank(),
            Self::QFloat(quantized) => quantized.rank(),
        }
    }
}
/// Metadata queries that a tensor type exposes.
///
/// Implementors must also be `Clone`, `Send`, `Sync` and `Debug`.
pub trait TensorMetadata: Clone + Send + Sync + core::fmt::Debug {
    /// Returns the tensor's data type.
    fn dtype(&self) -> DType;

    /// Returns the tensor's shape.
    fn shape(&self) -> Shape;

    /// Returns the tensor's rank.
    ///
    /// The default implementation is the number of dimensions of `shape()`.
    fn rank(&self) -> usize {
        let shape = self.shape();
        shape.num_dims()
    }
}
/// Quantization-related queries for a quantized tensor primitive.
///
/// Only [`scheme`](QTensorPrimitive::scheme) is required; the remaining
/// methods have default implementations that implementors may override.
pub trait QTensorPrimitive {
    /// Returns a reference to the tensor's quantization scheme.
    fn scheme(&self) -> &QuantScheme;

    /// Returns the accumulator precision setting.
    ///
    /// Defaults to `QuantAcc::F32`.
    fn acc_precision(&self) -> QuantAcc {
        QuantAcc::F32
    }

    /// Returns the quantization propagation setting.
    ///
    /// Defaults to `QuantPropagation::Inhibit`.
    fn propagation(&self) -> QuantPropagation {
        QuantPropagation::Inhibit
    }

    /// Returns the default quantization scheme for this tensor type.
    ///
    /// Defaults to `QuantScheme::default()`.
    fn default_scheme() -> QuantScheme {
        QuantScheme::default()
    }
}