//! Quantization configuration: schemes, backends, observers, and the
//! [`QuantConfig`]/[`QuantConfigBuilder`] types that describe how tensors
//! are quantized.
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
extern crate alloc;
// Without `std`, fall back to `BTreeMap` and import `String`/`ToString`
// explicitly: both are prelude items only on `std` builds.
#[cfg(not(feature = "std"))]
use alloc::{
collections::BTreeMap as HashMap,
string::{String, ToString},
};
use torsh_core::{
dtype::DType,
error::{Result as TorshResult, TorshError},
};
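/// Quantization scheme: how scales and zero-points are shared across a
/// tensor (per tensor, per channel, per group) and at what bit width.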
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum QScheme {
/// A single scale and zero-point for the whole tensor.
PerTensorAffine,
/// A separate scale and zero-point per channel along `ch_axis`.
PerChannelAffine,
/// A single scale with the zero-point fixed at zero.
PerTensorSymmetric,
/// A per-channel scale with the zero-point fixed at zero.
PerChannelSymmetric,
/// 4-bit quantization ([-8, 7]) with per-tensor parameters.
Int4PerTensor,
/// 4-bit quantization with per-channel parameters.
Int4PerChannel,
/// Different precisions per layer (see [`MixedPrecisionConfig`]).
MixedPrecision,
/// 1-bit quantization to {-1, 1}.
Binary,
/// Three-level quantization to {-1, 0, 1}.
Ternary,
/// Groups of channels share quantization parameters (see `group_size`).
GroupWise,
}
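/// Backend that executes the quantized kernels.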
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum QuantBackend {
/// FBGEMM kernels (x86 servers).
Fbgemm,
/// QNNPACK kernels (ARM mobile).
Qnnpack,
/// Built-in native implementation.
Native,
/// XNNPACK kernels (portable CPU).
Xnnpack,
}
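/// Whether to shrink the quantized integer range by one bit for backend headroom.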
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ReduceRange {
/// Use the full integer range of the dtype.
None,
/// Drop one bit of range (e.g. i8 becomes [-64, 63]).
Reduce,
}
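/// Calibration strategy used to estimate quantization parameters.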
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ObserverType {
/// Track the running min/max of observed values.
MinMax,
/// Exponential moving average of min/max (see `averaging_constant`).
MovingAverage,
/// Histogram-based range selection.
Histogram,
/// Clip the range to a percentile of the observed values.
Percentile,
/// Minimize KL divergence between the float and quantized distributions.
KLDivergence,
/// Entropy-based calibration.
Entropy,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct QuantConfig {
/// Storage dtype of the quantized values (e.g. `DType::I8`).
pub dtype: DType,
/// How scales and zero-points are shared across the tensor.
pub scheme: QScheme,
/// Simulate quantization in floating point (fake quant), as used in QAT.
pub enable_fake_quant: bool,
/// Calibration observer used to estimate scale and zero-point.
pub observer_type: ObserverType,
/// Backend that executes the quantized kernels.
pub backend: QuantBackend,
/// Whether to shrink the integer range by one bit.
pub reduce_range: ReduceRange,
/// Explicit lower bound of the integer range; overrides the dtype default.
pub qint_min: Option<i32>,
/// Explicit upper bound of the integer range; overrides the dtype default.
pub qint_max: Option<i32>,
/// Small constant keeping scales bounded away from zero.
pub eps: f32,
/// Smoothing factor in (0, 1) for the `MovingAverage` observer.
pub averaging_constant: f32,
/// Channel axis for per-channel and group-wise schemes.
pub ch_axis: Option<usize>,
/// Number of channels per group for `QScheme::GroupWise`.
pub group_size: Option<usize>,
}
impl Default for QuantConfig {
fn default() -> Self {
Self {
dtype: DType::I8,
scheme: QScheme::PerTensorAffine,
enable_fake_quant: false,
observer_type: ObserverType::MinMax,
backend: QuantBackend::Native,
reduce_range: ReduceRange::None,
qint_min: None,
qint_max: None,
eps: 1e-8,
averaging_constant: 0.01,
ch_axis: None,
group_size: None,
}
}
}
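/// Per-layer precision assignments for `QScheme::MixedPrecision`.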
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
/// Precision overrides keyed by layer name.
pub layer_precision: HashMap<String, DType>,
/// Precision used for layers without an explicit override.
pub default_precision: DType,
/// Sensitivity threshold used when deciding which layers need higher precision.
pub sensitivity_threshold: f32,
}
impl Default for MixedPrecisionConfig {
fn default() -> Self {
let mut layer_precision = HashMap::new();
layer_precision.insert("embedding".to_string(), DType::I8);
layer_precision.insert("attention".to_string(), DType::F16);
layer_precision.insert("output".to_string(), DType::F32);
Self {
layer_precision,
default_precision: DType::I8,
sensitivity_threshold: 0.1,
}
}
}
impl QuantConfig {
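/// Create a config with the default settings (per-tensor affine i8).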
pub fn new() -> Self {
Self::default()
}
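/// Signed 8-bit preset with the full [-128, 127] range.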
pub fn int8() -> Self {
Self {
dtype: DType::I8,
qint_min: Some(-128),
qint_max: Some(127),
..Self::default()
}
}
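/// 4-bit preset (stored in i8) with the [-8, 7] range and histogram calibration.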
pub fn int4() -> Self {
Self {
dtype: DType::I8,
scheme: QScheme::Int4PerTensor,
qint_min: Some(-8),
qint_max: Some(7),
observer_type: ObserverType::Histogram,
..Self::default()
}
}
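/// 1-bit preset quantizing to {-1, 1}.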
pub fn binary() -> Self {
Self {
dtype: DType::I8,
scheme: QScheme::Binary,
qint_min: Some(-1),
qint_max: Some(1),
observer_type: ObserverType::MinMax,
..Self::default()
}
}
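/// Ternary preset quantizing to {-1, 0, 1}.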
pub fn ternary() -> Self {
Self {
dtype: DType::I8,
scheme: QScheme::Ternary,
qint_min: Some(-1),
qint_max: Some(1),
observer_type: ObserverType::MinMax,
..Self::default()
}
}
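/// Mixed-precision preset using KL-divergence calibration.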
pub fn mixed_precision() -> Self {
Self {
dtype: DType::I8,
scheme: QScheme::MixedPrecision,
observer_type: ObserverType::KLDivergence,
..Self::default()
}
}
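/// Unsigned 8-bit preset with the full [0, 255] range.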
pub fn uint8() -> Self {
Self {
dtype: DType::U8,
qint_min: Some(0),
qint_max: Some(255),
..Self::default()
}
}
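/// Per-channel affine preset along the given channel axis.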
pub fn per_channel(ch_axis: usize) -> Self {
Self {
scheme: QScheme::PerChannelAffine,
ch_axis: Some(ch_axis),
..Self::default()
}
}
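/// Group-wise preset: channels along `ch_axis` share parameters in groups of `group_size`.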
pub fn group_wise(ch_axis: usize, group_size: usize) -> Self {
Self {
scheme: QScheme::GroupWise,
ch_axis: Some(ch_axis),
group_size: Some(group_size),
observer_type: ObserverType::Histogram,
..Self::default()
}
}
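/// Quantization-aware-training preset: fake quant plus a moving-average observer.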
pub fn qat() -> Self {
Self {
enable_fake_quant: true,
observer_type: ObserverType::MovingAverage,
..Self::default()
}
}
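/// Set the kernel backend.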
pub fn with_backend(mut self, backend: QuantBackend) -> Self {
self.backend = backend;
self
}
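/// Set the calibration observer.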
pub fn with_observer(mut self, observer_type: ObserverType) -> Self {
self.observer_type = observer_type;
self
}
/// Set the scheme, filling in defaults for `ch_axis` (0) and `group_size`
/// (32) when a per-channel or group-wise scheme needs them.
pub fn with_scheme(mut self, scheme: QScheme) -> Self {
self.scheme = scheme;
if matches!(
scheme,
QScheme::PerChannelAffine | QScheme::PerChannelSymmetric | QScheme::GroupWise
) && self.ch_axis.is_none()
{
self.ch_axis = Some(0);
}
if matches!(scheme, QScheme::GroupWise) && self.group_size.is_none() {
self.group_size = Some(32);
}
self
}
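/// Enable or disable fake quantization (QAT).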
pub fn with_fake_quant(mut self, enable: bool) -> Self {
self.enable_fake_quant = enable;
self
}
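/// Set the range-reduction mode.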
pub fn with_reduce_range(mut self, reduce_range: ReduceRange) -> Self {
self.reduce_range = reduce_range;
self
}
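/// Set the group size for group-wise quantization.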
pub fn with_group_size(mut self, group_size: usize) -> Self {
self.group_size = Some(group_size);
self
}
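/// Effective `(qmin, qmax)` integer range, accounting for the scheme,
/// dtype, explicit overrides, and `reduce_range`.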
pub fn get_qint_range(&self) -> (i32, i32) {
let (base_min, base_max) = match self.scheme {
QScheme::Int4PerTensor | QScheme::Int4PerChannel => (-8, 7),
QScheme::Binary => (-1, 1),
QScheme::Ternary => (-1, 1),
_ => {
// Explicit bounds take precedence; otherwise use the dtype's full range.
let (dmin, dmax) = match self.dtype {
DType::U8 => (0, 255),
_ => (-128, 127),
};
(self.qint_min.unwrap_or(dmin), self.qint_max.unwrap_or(dmax))
}
};
// `Reduce` drops one bit of range (FBGEMM-style): i8 -> [-64, 63],
// u8 -> [0, 127], leaving headroom for backends whose accumulators
// would otherwise overflow.
let (qmin, qmax) = match self.reduce_range {
ReduceRange::None => (base_min, base_max),
ReduceRange::Reduce => (base_min / 2, base_max / 2),
};
(qmin, qmax)
}
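/// Check internal consistency: channel axis and group size where required,
/// observer/scheme compatibility, explicit integer bounds, and positive
/// `eps` with a valid smoothing factor.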
pub fn validate(&self) -> TorshResult<()> {
if matches!(
self.scheme,
QScheme::PerChannelAffine | QScheme::PerChannelSymmetric | QScheme::GroupWise
) && self.ch_axis.is_none()
{
return Err(TorshError::InvalidArgument(
"Per-channel/Group-wise quantization requires channel axis".to_string(),
));
}
if matches!(self.scheme, QScheme::GroupWise) {
if self.group_size.is_none() {
return Err(TorshError::InvalidArgument(
"Group-wise quantization requires group size".to_string(),
));
}
if let Some(group_size) = self.group_size {
if group_size == 0 {
return Err(TorshError::InvalidArgument(
"Group size must be greater than 0".to_string(),
));
}
}
}
// Sanity-check explicit integer bounds when both are provided.
if let (Some(qmin), Some(qmax)) = (self.qint_min, self.qint_max) {
if qmin >= qmax {
return Err(TorshError::InvalidArgument(
"qint_min must be less than qint_max".to_string(),
));
}
}
if matches!(self.scheme, QScheme::Binary | QScheme::Ternary)
&& !matches!(
self.observer_type,
ObserverType::MinMax | ObserverType::MovingAverage
)
{
return Err(TorshError::InvalidArgument(
"Binary/ternary quantization requires MinMax or MovingAverage observer".to_string(),
));
}
if matches!(self.scheme, QScheme::MixedPrecision)
&& !matches!(
self.observer_type,
ObserverType::KLDivergence | ObserverType::Entropy
)
{
return Err(TorshError::InvalidArgument(
"Mixed precision quantization requires KLDivergence or Entropy observer"
.to_string(),
));
}
if self.eps <= 0.0 {
return Err(TorshError::InvalidArgument(
"eps must be positive".to_string(),
));
}
if self.averaging_constant <= 0.0 || self.averaging_constant >= 1.0 {
return Err(TorshError::InvalidArgument(
"averaging_constant must be in (0, 1)".to_string(),
));
}
Ok(())
}
}
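/// Fluent builder for `QuantConfig`; `build` runs `validate` on the result.
///
/// A minimal sketch of typical usage:
///
/// ```rust,ignore
/// let config = QuantConfigBuilder::new()
///     .scheme(QScheme::GroupWise)
///     .channel_axis(0)
///     .group_size(64)
///     .build()?;
/// ```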
pub struct QuantConfigBuilder {
config: QuantConfig,
}
impl QuantConfigBuilder {
pub fn new() -> Self {
Self {
config: QuantConfig::default(),
}
}
pub fn dtype(mut self, dtype: DType) -> Self {
self.config.dtype = dtype;
self
}
pub fn scheme(mut self, scheme: QScheme) -> Self {
self.config = self.config.with_scheme(scheme);
self
}
pub fn observer(mut self, observer_type: ObserverType) -> Self {
self.config.observer_type = observer_type;
self
}
pub fn backend(mut self, backend: QuantBackend) -> Self {
self.config.backend = backend;
self
}
pub fn fake_quant(mut self, enable: bool) -> Self {
self.config.enable_fake_quant = enable;
self
}
pub fn channel_axis(mut self, axis: usize) -> Self {
self.config.ch_axis = Some(axis);
self
}
pub fn group_size(mut self, size: usize) -> Self {
self.config.group_size = Some(size);
self
}
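/// Validate the accumulated settings and return the finished config.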
pub fn build(self) -> TorshResult<QuantConfig> {
self.config.validate()?;
Ok(self.config)
}
}
impl Default for QuantConfigBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_quant_config_defaults() {
let config = QuantConfig::default();
assert_eq!(config.dtype, DType::I8);
assert_eq!(config.scheme, QScheme::PerTensorAffine);
assert!(!config.enable_fake_quant);
assert_eq!(config.observer_type, ObserverType::MinMax);
assert_eq!(config.backend, QuantBackend::Native);
assert_eq!(config.reduce_range, ReduceRange::None);
}
#[test]
fn test_quant_config_presets() {
let int8_config = QuantConfig::int8();
assert_eq!(int8_config.dtype, DType::I8);
assert_eq!(int8_config.qint_min, Some(-128));
assert_eq!(int8_config.qint_max, Some(127));
let binary_config = QuantConfig::binary();
assert_eq!(binary_config.scheme, QScheme::Binary);
assert_eq!(binary_config.qint_min, Some(-1));
assert_eq!(binary_config.qint_max, Some(1));
let int4_config = QuantConfig::int4();
assert_eq!(int4_config.scheme, QScheme::Int4PerTensor);
assert_eq!(int4_config.observer_type, ObserverType::Histogram);
}
#[test]
fn test_quant_config_builder() {
let config = QuantConfigBuilder::new()
.dtype(DType::I8)
.scheme(QScheme::PerChannelAffine)
.observer(ObserverType::Histogram)
.backend(QuantBackend::Fbgemm)
.channel_axis(1)
.build()
.unwrap();
assert_eq!(config.dtype, DType::I8);
assert_eq!(config.scheme, QScheme::PerChannelAffine);
assert_eq!(config.observer_type, ObserverType::Histogram);
assert_eq!(config.backend, QuantBackend::Fbgemm);
assert_eq!(config.ch_axis, Some(1));
}
#[test]
fn test_config_validation() {
let valid_config = QuantConfig::per_channel(0);
assert!(valid_config.validate().is_ok());
let invalid_config = QuantConfig {
scheme: QScheme::PerChannelAffine,
ch_axis: None,
..QuantConfig::default()
};
assert!(invalid_config.validate().is_err());
let invalid_group = QuantConfig {
scheme: QScheme::GroupWise,
ch_axis: Some(0),
group_size: None,
..QuantConfig::default()
};
assert!(invalid_group.validate().is_err());
let invalid_eps = QuantConfig {
eps: -1.0,
..QuantConfig::default()
};
assert!(invalid_eps.validate().is_err());
let invalid_avg = QuantConfig {
averaging_constant: 1.5,
..QuantConfig::default()
};
assert!(invalid_avg.validate().is_err());
}
#[test]
fn test_get_qint_range() {
let int8_config = QuantConfig::int8();
assert_eq!(int8_config.get_qint_range(), (-128, 127));
let uint8_config = QuantConfig::uint8();
assert_eq!(uint8_config.get_qint_range(), (0, 255));
let int4_config = QuantConfig::int4();
assert_eq!(int4_config.get_qint_range(), (-8, 7));
let binary_config = QuantConfig::binary();
assert_eq!(binary_config.get_qint_range(), (-1, 1));
let reduced_config = QuantConfig::int8().with_reduce_range(ReduceRange::Reduce);
let (min, max) = reduced_config.get_qint_range();
assert!(min > -128 && max < 127);
}
#[test]
fn test_mixed_precision_config() {
let mixed_config = MixedPrecisionConfig::default();
assert_eq!(mixed_config.default_precision, DType::I8);
assert_eq!(mixed_config.sensitivity_threshold, 0.1);
assert!(mixed_config.layer_precision.contains_key("embedding"));
}
#[test]
fn test_config_serialization() {
let config = QuantConfig::int8().with_observer(ObserverType::Histogram);
let json = serde_json::to_string(&config).unwrap();
let deserialized: QuantConfig = serde_json::from_str(&json).unwrap();
assert_eq!(config.dtype, deserialized.dtype);
assert_eq!(config.scheme, deserialized.scheme);
assert_eq!(config.observer_type, deserialized.observer_type);
}
}