use serde::{Deserialize, Serialize};
/// Selects how a tensor is split into groups that each get their own
/// quantization parameters.
///
/// `PerChannel` is the `Default` variant. `params_per_elements` counts two
/// parameters per group for every variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum QuantGranularity {
/// A single parameter group for the whole tensor.
PerTensor,
/// One parameter group per channel (the default).
#[default]
PerChannel,
/// One parameter group per token; `params_per_elements` derives the group
/// count as `n / channels`.
PerToken,
/// One parameter group per fixed-size block of elements.
PerBlock {
/// Number of elements per block. `params_per_elements` rounds the block
/// count up, so a trailing partial block still counts as a group.
block_size: usize,
},
}
impl QuantGranularity {
    /// Returns the number of quantization parameters needed for a tensor of
    /// `n` elements spread over `channels` channels.
    ///
    /// Two parameters are counted per group, where the number of groups is:
    ///
    /// * `PerTensor`: 1
    /// * `PerChannel`: `channels`
    /// * `PerToken`: `n / channels` (integer division)
    /// * `PerBlock`: `n / block_size`, rounded up
    ///
    /// Degenerate inputs never panic: `channels == 0` for `PerToken` and
    /// `block_size == 0` for `PerBlock` describe no groups at all and
    /// therefore yield `0`.
    pub fn params_per_elements(&self, n: usize, channels: usize) -> usize {
        let groups = match *self {
            QuantGranularity::PerTensor => 1,
            QuantGranularity::PerChannel => channels,
            // `checked_div` turns a zero channel count into "no groups"
            // instead of a division-by-zero panic.
            QuantGranularity::PerToken => n.checked_div(channels).unwrap_or(0),
            QuantGranularity::PerBlock { block_size } => match block_size {
                0 => 0,
                // Ceiling division written without `n + size - 1`, which
                // could overflow for very large `n`.
                size => n / size + usize::from(n % size != 0),
            },
        };
        groups * 2
    }
}
/// Variant of the STE (presumably "straight-through estimator" — confirm)
/// selected by a QAT configuration.
///
/// `Standard` is the `Default` variant. The default is derived via
/// `#[default]`, matching how `QuantGranularity` declares its default.
#[derive(Debug, Clone, Copy, PartialEq, Default, Serialize, Deserialize)]
pub enum SteVariant {
    /// Plain variant with no extra parameters (the default).
    #[default]
    Standard,
    /// Clipped variant.
    Clipped {
        /// Clip value; `QatConfig::validate` requires it to be positive.
        clip_val: f32,
    },
    /// Learned-step-size variant; the only one for which
    /// `requires_scale_grad` returns `true`.
    LearnedStepSize,
    /// EWGS variant (presumably "element-wise gradient scaling" — confirm).
    Ewgs {
        /// Scaling factor; `QatConfig::validate` requires it to be
        /// non-negative.
        lambda: f32,
    },
}
impl SteVariant {
pub fn clipped() -> Self {
SteVariant::Clipped { clip_val: 1.0 }
}
pub fn ewgs() -> Self {
SteVariant::Ewgs { lambda: 0.1 }
}
pub fn requires_scale_grad(&self) -> bool {
matches!(self, SteVariant::LearnedStepSize)
}
}
/// Weights for the loss terms of a QAT run, plus the distillation
/// temperature.
///
/// `QatConfig::validate` requires the three `lambda_*` weights to be
/// non-negative and `kd_temperature` to be positive.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct QatLossWeights {
/// Weight of the task loss term.
pub lambda_task: f32,
/// Weight of the KD loss term (presumably knowledge distillation — confirm).
pub lambda_kd: f32,
/// Weight of the reasoning loss term.
pub lambda_reasoning: f32,
/// Temperature for the KD term. It is not part of the sum that `normalized`
/// divides the `lambda_*` weights by.
pub kd_temperature: f32,
}
impl Default for QatLossWeights {
fn default() -> Self {
Self {
lambda_task: 1.0,
lambda_kd: 0.5,
lambda_reasoning: 0.2,
kd_temperature: 3.0,
}
}
}
impl QatLossWeights {
    /// Preset that favours the KD term: task `0.5`, KD `1.0`, reasoning
    /// `0.3`, KD temperature `4.0`.
    pub fn distillation_heavy() -> Self {
        Self {
            lambda_task: 0.5,
            lambda_kd: 1.0,
            lambda_reasoning: 0.3,
            kd_temperature: 4.0,
        }
    }

    /// Preset that favours the reasoning term: task `1.0`, KD `0.3`,
    /// reasoning `0.5`, KD temperature `2.0`.
    pub fn reasoning_focused() -> Self {
        Self {
            lambda_task: 1.0,
            lambda_kd: 0.3,
            lambda_reasoning: 0.5,
            kd_temperature: 2.0,
        }
    }

    /// Returns a copy whose three `lambda_*` weights are rescaled to sum to 1.
    ///
    /// `kd_temperature` is carried over unchanged. When the weights cannot be
    /// rescaled (their sum is zero, negative, NaN or infinite), the result is
    /// the *normalized* default weights with the default KD temperature, so
    /// the returned weights sum to 1 in every case.
    pub fn normalized(&self) -> Self {
        let own_sum = self.lambda_sum();
        let fallback = Self::default();
        // `is_finite` keeps an infinite sum out of the division below, which
        // would otherwise produce NaN (inf / inf) or all-zero weights.
        let source = if own_sum.is_finite() && own_sum > 0.0 {
            self
        } else {
            &fallback
        };
        let sum = source.lambda_sum();
        Self {
            lambda_task: source.lambda_task / sum,
            lambda_kd: source.lambda_kd / sum,
            lambda_reasoning: source.lambda_reasoning / sum,
            kd_temperature: source.kd_temperature,
        }
    }

    /// Sum of the three loss weights, in the order task + KD + reasoning.
    fn lambda_sum(&self) -> f32 {
        self.lambda_task + self.lambda_kd + self.lambda_reasoning
    }
}
/// Configuration for a QAT (presumably quantization-aware training — confirm)
/// run.
///
/// Construct it through `Default`, `new`, one of the `pi_quant`/`piq*`
/// presets, or the `with_*` builder methods, then call `validate` before use:
/// none of the constructors check the values they are given.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct QatConfig {
/// Quantization bit width; `validate` accepts 2, 3, 4, 5 and 8.
pub bits: u8,
/// STE variant to use.
pub ste_variant: SteVariant,
/// Grouping of quantization parameters; also drives `bits_per_weight`.
pub granularity: QuantGranularity,
/// Loss-term weights and KD temperature.
pub loss_weights: QatLossWeights,
/// Number of epochs; `validate` requires it to be non-zero.
pub epochs: usize,
/// Learning rate; `validate` requires it to be positive.
pub learning_rate: f32,
/// Optional divisor for pi-based quantization: `pi_step` returns
/// `alpha * PI / k`. `validate` accepts 2 through 5 when set.
pub pi_k: Option<u8>,
/// Incoherence flag; enabled by the `piq2` preset. Not interpreted
/// elsewhere in this file.
pub use_incoherence: bool,
/// Symmetric-quantization flag. Not interpreted in this file
/// (`symmetric_range` does not consult it).
pub symmetric: bool,
/// Number of warmup epochs. Not checked by `validate`.
pub warmup_epochs: usize,
/// Optional gradient-clipping value. Not checked by `validate`.
pub grad_clip: Option<f32>,
/// Weight decay. Not checked by `validate`.
pub weight_decay: f32,
/// Distillation flag. Not interpreted in this file.
pub use_distillation: bool,
/// Teacher-freezing flag. Not interpreted in this file.
pub freeze_teacher: bool,
}
impl Default for QatConfig {
fn default() -> Self {
Self {
bits: 4,
ste_variant: SteVariant::Standard,
granularity: QuantGranularity::PerChannel,
loss_weights: QatLossWeights::default(),
epochs: 3,
learning_rate: 1e-4,
pi_k: None,
use_incoherence: false,
symmetric: true,
warmup_epochs: 1,
grad_clip: Some(1.0),
weight_decay: 0.01,
use_distillation: true,
freeze_teacher: true, }
}
}
impl QatConfig {
    /// Creates a configuration with the given bit width and default values
    /// for every other field.
    ///
    /// Nothing is validated here; call [`QatConfig::validate`] afterwards.
    pub fn new(bits: u8) -> Self {
        Self {
            bits,
            ..Default::default()
        }
    }

    /// Creates a pi-quantization configuration: the given bit width,
    /// `pi_k = Some(k)` and the `LearnedStepSize` STE variant, with defaults
    /// for everything else.
    pub fn pi_quant(bits: u8, k: u8) -> Self {
        Self {
            bits,
            pi_k: Some(k),
            ste_variant: SteVariant::LearnedStepSize,
            ..Default::default()
        }
    }

    /// Preset equal to `pi_quant(3, 4)`.
    pub fn piq3() -> Self {
        Self::pi_quant(3, 4)
    }

    /// Preset equal to `pi_quant(2, 3)` with `use_incoherence` enabled.
    pub fn piq2() -> Self {
        Self {
            use_incoherence: true,
            ..Self::pi_quant(2, 3)
        }
    }

    /// Builder: sets the bit width.
    pub fn with_bits(mut self, bits: u8) -> Self {
        self.bits = bits;
        self
    }

    /// Builder: sets the STE variant.
    pub fn with_ste(mut self, ste: SteVariant) -> Self {
        self.ste_variant = ste;
        self
    }

    /// Builder: sets the quantization granularity.
    pub fn with_granularity(mut self, granularity: QuantGranularity) -> Self {
        self.granularity = granularity;
        self
    }

    /// Builder: sets the loss weights.
    pub fn with_loss_weights(mut self, weights: QatLossWeights) -> Self {
        self.loss_weights = weights;
        self
    }

    /// Builder: sets the number of epochs.
    pub fn with_epochs(mut self, epochs: usize) -> Self {
        self.epochs = epochs;
        self
    }

    /// Builder: sets the learning rate.
    pub fn with_learning_rate(mut self, lr: f32) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Builder: sets `pi_k` to `Some(k)`.
    pub fn with_pi_k(mut self, k: u8) -> Self {
        self.pi_k = Some(k);
        self
    }

    /// Builder: sets the incoherence flag.
    pub fn with_incoherence(mut self, enable: bool) -> Self {
        self.use_incoherence = enable;
        self
    }

    /// Builder: sets the symmetric flag.
    pub fn with_symmetric(mut self, symmetric: bool) -> Self {
        self.symmetric = symmetric;
        self
    }

    /// Builder: sets the number of warmup epochs.
    pub fn with_warmup(mut self, epochs: usize) -> Self {
        self.warmup_epochs = epochs;
        self
    }

    /// Builder: sets (or clears, with `None`) the gradient-clipping value.
    pub fn with_grad_clip(mut self, clip: Option<f32>) -> Self {
        self.grad_clip = clip;
        self
    }

    /// Builder: sets the distillation flag.
    pub fn with_distillation(mut self, enable: bool) -> Self {
        self.use_distillation = enable;
        self
    }

    /// Checks the configuration and reports the first problem found.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when:
    ///
    /// * `bits` is not one of 2, 3, 4, 5, 8;
    /// * `pi_k` is set but not in `2..=5`;
    /// * `learning_rate` is not a finite positive number;
    /// * `epochs` is zero;
    /// * any loss weight is negative or not finite;
    /// * `kd_temperature` is not a finite positive number;
    /// * a `Clipped` STE has a clip value that is not finite and positive;
    /// * an `Ewgs` STE has a lambda that is negative or not finite;
    /// * a `PerBlock` granularity has a block size of zero.
    ///
    /// NaN and infinities are rejected explicitly: a plain `x <= 0.0` test is
    /// `false` for NaN and would let it through.
    pub fn validate(&self) -> Result<(), String> {
        if !matches!(self.bits, 2 | 3 | 4 | 5 | 8) {
            return Err(format!(
                "Invalid bit width: {}. Must be 2, 3, 4, 5, or 8",
                self.bits
            ));
        }
        if let Some(k) = self.pi_k {
            if !matches!(k, 2..=5) {
                return Err(format!("Invalid pi_k: {}. Must be 2, 3, 4, or 5", k));
            }
        }
        if !Self::is_positive(self.learning_rate) {
            return Err("Learning rate must be positive".to_string());
        }
        if self.epochs == 0 {
            return Err("Epochs must be greater than 0".to_string());
        }
        let lambdas = [
            self.loss_weights.lambda_task,
            self.loss_weights.lambda_kd,
            self.loss_weights.lambda_reasoning,
        ];
        if !lambdas.iter().all(|&weight| Self::is_non_negative(weight)) {
            return Err("Loss weights must be non-negative".to_string());
        }
        if !Self::is_positive(self.loss_weights.kd_temperature) {
            return Err("KD temperature must be positive".to_string());
        }
        if let SteVariant::Clipped { clip_val } = self.ste_variant {
            if !Self::is_positive(clip_val) {
                return Err("Clip value must be positive".to_string());
            }
        }
        if let SteVariant::Ewgs { lambda } = self.ste_variant {
            if !Self::is_non_negative(lambda) {
                return Err("EWGS lambda must be non-negative".to_string());
            }
        }
        // A zero block size would divide by zero when counting blocks and
        // makes `bits_per_weight` infinite.
        if let QuantGranularity::PerBlock { block_size: 0 } = self.granularity {
            return Err("Block size must be greater than 0".to_string());
        }
        Ok(())
    }

    /// `true` for finite values strictly greater than zero (NaN is rejected).
    fn is_positive(value: f32) -> bool {
        value.is_finite() && value > 0.0
    }

    /// `true` for finite values greater than or equal to zero (NaN is rejected).
    fn is_non_negative(value: f32) -> bool {
        value.is_finite() && value >= 0.0
    }

    /// Number of representable levels, `2^bits`.
    ///
    /// Saturates at `usize::MAX` when `2^bits` does not fit in a `usize`
    /// instead of overflowing the shift.
    pub fn num_levels(&self) -> usize {
        1usize
            .checked_shl(u32::from(self.bits))
            .unwrap_or(usize::MAX)
    }

    /// Signed integer range `(-2^(bits-1), 2^(bits-1) - 1)`.
    ///
    /// Bit widths of 32 or more saturate to the full `i32` range. A width of
    /// 0 yields `(0, -1)`.
    pub fn symmetric_range(&self) -> (i32, i32) {
        // Computed in i64 so that `1 << 31` and `1 << 32` neither overflow
        // nor flip sign.
        let bits = u32::from(self.bits.min(32));
        let half = (1i64 << bits) / 2;
        // Lossless casts: `half` is at most 2^31, so `-half >= i32::MIN` and
        // `half - 1 <= i32::MAX`.
        ((-half) as i32, (half - 1) as i32)
    }

    /// Step size `alpha * PI / k` for pi-quantization, or `None` when `pi_k`
    /// is unset.
    pub fn pi_step(&self, alpha: f32) -> Option<f32> {
        self.pi_k
            .map(|k| alpha * std::f32::consts::PI / f32::from(k))
    }

    /// Returns `true` when `pi_k` is set.
    pub fn is_pi_quant(&self) -> bool {
        self.pi_k.is_some()
    }

    /// Estimated storage cost per weight: `bits` plus a granularity-dependent
    /// overhead.
    ///
    /// The overhead is a fixed constant for `PerTensor` (0.0625),
    /// `PerChannel` (0.125) and `PerToken` (0.25), and `16 / block_size` for
    /// `PerBlock`. A zero block size therefore gives an infinite result;
    /// `validate` rejects such configurations.
    pub fn bits_per_weight(&self) -> f32 {
        let base_bits = f32::from(self.bits);
        let overhead = match self.granularity {
            QuantGranularity::PerTensor => 0.0625,
            QuantGranularity::PerChannel => 0.125,
            QuantGranularity::PerToken => 0.25,
            QuantGranularity::PerBlock { block_size } => 16.0 / (block_size as f32),
        };
        base_bits + overhead
    }

    /// Serializes the configuration to pretty-printed JSON.
    ///
    /// # Errors
    ///
    /// Propagates any error from `serde_json::to_string_pretty`.
    pub fn to_json(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string_pretty(self)
    }

    /// Parses a configuration from JSON. The result is not validated.
    ///
    /// # Errors
    ///
    /// Propagates any error from `serde_json::from_str`.
    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str(json)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is 4-bit with the standard STE and is valid.
    #[test]
    fn test_default_config() {
        let defaults = QatConfig::default();
        assert_eq!(defaults.bits, 4);
        assert_eq!(defaults.ste_variant, SteVariant::Standard);
        assert!(defaults.validate().is_ok());
    }

    /// `piq3` is the 3-bit preset with `pi_k = 4`.
    #[test]
    fn test_piq3_config() {
        let preset = QatConfig::piq3();
        assert_eq!(preset.bits, 3);
        assert_eq!(preset.pi_k, Some(4));
        assert!(preset.is_pi_quant());
        assert!(preset.validate().is_ok());
    }

    /// `piq2` is the 2-bit preset with `pi_k = 3` and incoherence enabled.
    #[test]
    fn test_piq2_config() {
        let preset = QatConfig::piq2();
        assert_eq!(preset.bits, 2);
        assert_eq!(preset.pi_k, Some(3));
        assert!(preset.use_incoherence);
        assert!(preset.validate().is_ok());
    }

    /// Chained builder calls each update exactly their own field.
    #[test]
    fn test_builder_pattern() {
        let built = QatConfig::default()
            .with_bits(3)
            .with_ste(SteVariant::LearnedStepSize)
            .with_epochs(5)
            .with_learning_rate(2e-4)
            .with_pi_k(4);

        assert_eq!(built.bits, 3);
        assert_eq!(built.ste_variant, SteVariant::LearnedStepSize);
        assert_eq!(built.epochs, 5);
        assert_eq!(built.learning_rate, 2e-4);
        assert_eq!(built.pi_k, Some(4));
        assert!(built.validate().is_ok());
    }

    /// A bit width outside the supported set is rejected.
    #[test]
    fn test_invalid_bits() {
        let outcome = QatConfig::default().with_bits(6).validate();
        assert!(outcome.is_err());
    }

    /// A `pi_k` outside the supported set is rejected.
    #[test]
    fn test_invalid_pi_k() {
        let config = QatConfig {
            pi_k: Some(7),
            ..QatConfig::default()
        };
        assert!(config.validate().is_err());
    }

    /// With `alpha = 1` and `k = 4` the pi step is `PI / 4`.
    #[test]
    fn test_pi_step_calculation() {
        let alpha = 1.0;
        let step = QatConfig::piq3().pi_step(alpha).unwrap();
        let expected = std::f32::consts::PI / 4.0;
        let error = (step - expected).abs();
        assert!(error < 1e-6);
    }

    /// Three bits give the signed range -4..=3.
    #[test]
    fn test_symmetric_range() {
        let (min, max) = QatConfig::default().with_bits(3).symmetric_range();
        assert_eq!(min, -4);
        assert_eq!(max, 3);
    }

    /// The level count is `2^bits`.
    #[test]
    fn test_num_levels() {
        let cases: [(u8, usize); 3] = [(2, 4), (3, 8), (4, 16)];
        for &(bits, levels) in cases.iter() {
            assert_eq!(QatConfig::default().with_bits(bits).num_levels(), levels);
        }
    }

    /// Normalizing the default weights makes them sum to one.
    #[test]
    fn test_loss_weights() {
        let weights = QatLossWeights::default();
        assert_eq!(weights.lambda_task, 1.0);

        let scaled = weights.normalized();
        let total = scaled.lambda_task + scaled.lambda_kd + scaled.lambda_reasoning;
        assert!((total - 1.0).abs() < 1e-6);
    }

    /// Parameter counts for a 1024-element tensor with 64 channels.
    #[test]
    fn test_granularity_params() {
        let (n, channels) = (1024, 64);
        assert_eq!(
            QuantGranularity::PerTensor.params_per_elements(n, channels),
            2
        );
        assert_eq!(
            QuantGranularity::PerChannel.params_per_elements(n, channels),
            128
        );
        assert_eq!(
            QuantGranularity::PerBlock { block_size: 32 }.params_per_elements(n, channels),
            64
        );
    }

    /// Only the learned-step-size variant needs a scale gradient.
    #[test]
    fn test_ste_variants() {
        assert!(!SteVariant::Standard.requires_scale_grad());
        assert!(!SteVariant::clipped().requires_scale_grad());
        assert!(SteVariant::LearnedStepSize.requires_scale_grad());
        assert!(!SteVariant::ewgs().requires_scale_grad());
    }

    /// `bits` and `pi_k` survive a JSON round trip.
    #[test]
    fn test_json_serialization() {
        let original = QatConfig::piq3();
        let encoded = original.to_json().unwrap();
        let decoded = QatConfig::from_json(&encoded).unwrap();
        assert_eq!(original.bits, decoded.bits);
        assert_eq!(original.pi_k, decoded.pi_k);
    }

    /// Per-tensor overhead adds a small amount on top of the 4 base bits.
    #[test]
    fn test_bits_per_weight() {
        let config = QatConfig::default()
            .with_bits(4)
            .with_granularity(QuantGranularity::PerTensor);
        let cost = config.bits_per_weight();
        assert!(cost > 4.0);
        assert!(cost < 4.2);
    }
}