use serde::{Deserialize, Serialize};
use super::{QuantGranularity, QuantMode};
/// Quantization parameters: scales and zero points grouped with the
/// granularity, mode, and `bits` settings they belong to.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QuantParams {
/// Scale factors. `num_groups` reports the length of this vector, so each
/// entry is treated as one quantization group.
pub scales: Vec<f32>,
/// Zero points.
/// NOTE(review): presumably one entry per element of `scales`; nothing in
/// this type enforces that, so confirm against the code that builds it.
pub zero_points: Vec<i32>,
/// Granularity setting; see `QuantGranularity` for the available values.
pub granularity: QuantGranularity,
/// Quantization mode. `is_asymmetric` tests it against `QuantMode::Asymmetric`.
pub mode: QuantMode,
/// NOTE(review): presumably the quantized bit width; the valid range is not
/// checked here, so confirm with callers.
pub bits: u8,
}
impl QuantParams {
    /// Returns the number of quantization groups, defined as the number of
    /// stored scale factors.
    pub fn num_groups(&self) -> usize {
        let Self { scales, .. } = self;
        scales.len()
    }

    /// Returns `true` when `mode` is `QuantMode::Asymmetric`, `false` for any
    /// other mode.
    pub fn is_asymmetric(&self) -> bool {
        matches!(self.mode, QuantMode::Asymmetric)
    }
}
/// Quantized values stored as `i8`, bundled with their quantization
/// parameters and a shape.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QuantizedTensor {
/// Quantized values; `memory_bytes` counts one byte per element.
pub data: Vec<i8>,
/// Scales, zero points, and mode settings that accompany `data`.
pub params: QuantParams,
/// Dimension sizes.
/// NOTE(review): presumably the logical tensor shape; its relationship to
/// `data.len()` is not enforced by this type, so confirm with the constructor.
pub shape: Vec<usize>,
}
impl QuantizedTensor {
    /// Returns the number of bytes occupied by the quantized payload.
    ///
    /// The total is the sum of the element bytes of `data`, `params.scales`,
    /// and `params.zero_points`. Only elements in use are counted (length,
    /// not reserved capacity); `shape` and the struct's own inline size are
    /// not included.
    pub fn memory_bytes(&self) -> usize {
        let params = &self.params;
        let parts = [
            std::mem::size_of_val(self.data.as_slice()),
            std::mem::size_of_val(params.scales.as_slice()),
            std::mem::size_of_val(params.zero_points.as_slice()),
        ];
        parts.iter().sum()
    }
}