use alloc::vec::Vec;
use burn_backend::{DType, QTensorPrimitive, TensorMetadata, quantization::QuantStore};
use burn_std::{QuantScheme, Shape};
use crate::tensor::FlexTensor;
/// A quantized tensor for the flex backend.
///
/// Bundles the quantized integer storage tensor with the quantization
/// scheme and the scale factors required to dequantize it.
#[derive(Clone, Debug)]
pub struct FlexQTensor {
    /// Underlying storage tensor holding the quantized values; must carry
    /// `i8` data (enforced by the constructor's assertion).
    pub(crate) tensor: FlexTensor,
    /// Scheme describing how the values were quantized.
    pub(crate) scheme: QuantScheme,
    /// Dequantization scale factors; never empty (enforced by the
    /// constructor). Exact layout presumably depends on `scheme`
    /// (per-tensor vs. per-block) — not validated here; TODO confirm.
    pub(crate) scales: Vec<f32>,
}
impl FlexQTensor {
    /// Builds a quantized tensor from its raw parts.
    ///
    /// # Panics
    ///
    /// Panics when `tensor` does not hold `i8` data, or when `scales`
    /// contains no scale factors.
    pub fn new(tensor: FlexTensor, scheme: QuantScheme, scales: Vec<f32>) -> Self {
        // Hoist the dtype once; it is used both for the check and the message.
        let dtype = tensor.dtype();
        assert_eq!(
            dtype,
            DType::I8,
            "quantized tensor must store i8 data, got {:?}",
            dtype
        );
        assert!(
            !scales.is_empty(),
            "quantized tensor must have at least one scale factor"
        );

        Self { tensor, scheme, scales }
    }

    /// Borrows the underlying quantized integer storage tensor.
    pub fn tensor(&self) -> &FlexTensor {
        &self.tensor
    }

    /// Borrows the dequantization scale factors as a slice.
    pub fn scales(&self) -> &[f32] {
        self.scales.as_slice()
    }
}
impl QTensorPrimitive for FlexQTensor {
    /// Returns the scheme this tensor was quantized with.
    fn scheme(&self) -> &QuantScheme {
        &self.scheme
    }

    /// Default scheme for this backend: the library default configured to
    /// use native storage.
    fn default_scheme() -> QuantScheme {
        let base = QuantScheme::default();
        base.with_store(QuantStore::Native)
    }
}
impl TensorMetadata for FlexQTensor {
    /// Reports the logical dtype: a quantized-float dtype wrapping the
    /// scheme, rather than the raw `i8` storage dtype.
    fn dtype(&self) -> DType {
        DType::QFloat(self.scheme)
    }

    /// Shape is delegated to the underlying storage tensor.
    fn shape(&self) -> Shape {
        self.tensor.shape()
    }

    /// Rank is delegated to the underlying storage tensor.
    fn rank(&self) -> usize {
        self.tensor.rank()
    }
}