use alloc::vec::Vec;
use burn_backend::{DType, QTensorPrimitive, TensorMetadata, quantization::QuantStore};
use burn_std::{QuantScheme, Shape};
use crate::tensor::FlexTensor;
/// Quantized counterpart of [`FlexTensor`].
///
/// Pairs the wrapped tensor data with the [`QuantScheme`] describing its quantization
/// and the `f32` scale factors that accompany it.
#[derive(Debug, Clone)]
pub struct FlexQTensor {
    /// Wrapped tensor holding the quantized values.
    ///
    /// NOTE(review): the element dtype of this tensor is not constrained here. It is
    /// presumably the storage type implied by `scheme`; confirm.
    pub tensor: FlexTensor,
    /// Quantization scheme. It is returned by `QTensorPrimitive::scheme` and embedded in
    /// the `DType::QFloat` reported by `TensorMetadata::dtype`.
    pub scheme: QuantScheme,
    /// Quantization scale factors.
    ///
    /// NOTE(review): the expected length is not enforced here. It is presumably one entry
    /// per tensor or per block, depending on `scheme`; confirm.
    pub scales: Vec<f32>,
}
impl QTensorPrimitive for FlexQTensor {
    /// Default quantization scheme for this primitive.
    ///
    /// This is `QuantScheme::default()` with its store switched to [`QuantStore::Native`].
    fn default_scheme() -> QuantScheme {
        let base = QuantScheme::default();
        base.with_store(QuantStore::Native)
    }

    /// Borrows the quantization scheme stored on this tensor.
    fn scheme(&self) -> &QuantScheme {
        let Self { scheme, .. } = self;
        scheme
    }
}
impl TensorMetadata for FlexQTensor {
    /// Number of dimensions, delegated to the wrapped tensor.
    fn rank(&self) -> usize {
        self.tensor.rank()
    }

    /// Shape, delegated to the wrapped tensor.
    fn shape(&self) -> Shape {
        let Self { tensor, .. } = self;
        tensor.shape()
    }

    /// Element type.
    ///
    /// Always `DType::QFloat`, carrying a copy of this tensor's quantization scheme.
    fn dtype(&self) -> DType {
        let scheme = self.scheme;
        DType::QFloat(scheme)
    }
}