use super::super::{
calibration_ema::{
calibrate_matrix_ema, calibrate_matrix_per_channel_ema, calibrate_vector_ema,
},
QuantizationMethod, QuantizationParams,
};
use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{ArrayView1, ArrayView2};
use std::fmt::Debug;
use super::{matrix_calibration::*, utils::*, vector_calibration::*};
/// Strategy used to derive quantization parameters from calibration data.
///
/// `calibrate_matrix` and `calibrate_vector` dispatch on this value to pick
/// the concrete calibration routine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CalibrationMethod {
    /// Dispatches to the `*_minmax` routines; only `symmetric` is forwarded.
    MinMax,
    /// Dispatches to the `*_moving_average` routines, forwarding `windowsize`.
    MovingAverageMinMax,
    /// Dispatches to the `*_entropy` routines, forwarding `num_bins`.
    EntropyCalibration,
    /// Dispatches to the `*_percentile` routines, forwarding `percentile`.
    PercentileCalibration,
    /// Dispatches to the `*_mse` routines; only `symmetric` is forwarded.
    MSEOptimization,
    /// Dispatches to the `*_ema` routines, forwarding `ema_factor`,
    /// `max_iterations` and `convergence_threshold`.
    ExponentialMovingAverage,
}
/// Settings consumed by `calibrate_matrix` and `calibrate_vector`.
///
/// Only the fields relevant to the selected [`CalibrationMethod`] are
/// forwarded to the underlying routine; the rest are ignored by the dispatch.
#[derive(Debug, Clone)]
pub struct CalibrationConfig {
    /// Which calibration routine family to dispatch to.
    pub method: CalibrationMethod,
    /// Forwarded to the entropy calibration routines.
    pub num_bins: usize,
    /// Forwarded to the percentile calibration routines.
    pub percentile: f32,
    /// Forwarded to the moving-average calibration routines.
    pub windowsize: usize,
    /// Selects the `per_channel` matrix routines in `calibrate_matrix`;
    /// not consulted when calibrating vectors.
    pub per_channel: bool,
    /// Forwarded to every calibration routine.
    pub symmetric: bool,
    /// Forwarded to the EMA calibration routines.
    pub ema_factor: f32,
    /// Forwarded to the EMA calibration routines.
    pub max_iterations: usize,
    /// Forwarded to the EMA calibration routines.
    pub convergence_threshold: f32,
}
impl Default for CalibrationConfig {
    /// Baseline configuration: symmetric, whole-tensor `MinMax` calibration.
    ///
    /// The method-specific fields are populated as well, so switching
    /// `method` via struct-update syntax yields a usable configuration.
    fn default() -> Self {
        Self {
            // Method and granularity.
            method: CalibrationMethod::MinMax,
            per_channel: false,
            symmetric: true,
            // Entropy / percentile / moving-average settings.
            num_bins: 2048,
            percentile: 0.999,
            windowsize: 10,
            // EMA settings.
            ema_factor: 0.1,
            max_iterations: 10,
            convergence_threshold: 1e-6,
        }
    }
}
#[allow(dead_code)]
pub fn calibrate_matrix<F>(
matrix: &ArrayView2<F>,
bits: u8,
config: &CalibrationConfig,
) -> LinalgResult<QuantizationParams>
where
F: scirs2_core::numeric::Float
+ Debug
+ scirs2_core::numeric::AsPrimitive<f32>
+ scirs2_core::numeric::FromPrimitive,
f32: scirs2_core::numeric::AsPrimitive<F>,
{
match config.method {
CalibrationMethod::MinMax => {
if config.per_channel {
calibrate_matrix_per_channel_minmax(matrix, bits, config.symmetric)
} else {
calibrate_matrix_minmax(matrix, bits, config.symmetric)
}
}
CalibrationMethod::MovingAverageMinMax => {
if config.per_channel {
calibrate_matrix_per_channel_moving_average(
matrix,
bits,
config.windowsize,
config.symmetric,
)
} else {
calibrate_matrix_moving_average(matrix, bits, config.windowsize, config.symmetric)
}
}
CalibrationMethod::PercentileCalibration => {
if config.per_channel {
calibrate_matrix_per_channel_percentile(
matrix,
bits,
config.percentile,
config.symmetric,
)
} else {
calibrate_matrix_percentile(matrix, bits, config.percentile, config.symmetric)
}
}
CalibrationMethod::EntropyCalibration => {
if config.per_channel {
calibrate_matrix_per_channel_entropy(
matrix,
bits,
config.num_bins,
config.symmetric,
)
} else {
calibrate_matrix_entropy(matrix, bits, config.num_bins, config.symmetric)
}
}
CalibrationMethod::MSEOptimization => {
if config.per_channel {
calibrate_matrix_per_channel_mse(matrix, bits, config.symmetric)
} else {
calibrate_matrix_mse(matrix, bits, config.symmetric)
}
}
CalibrationMethod::ExponentialMovingAverage => {
if config.per_channel {
calibrate_matrix_per_channel_ema(
matrix,
bits,
config.ema_factor,
config.max_iterations,
config.convergence_threshold,
config.symmetric,
)
} else {
calibrate_matrix_ema(
matrix,
bits,
config.ema_factor,
config.max_iterations,
config.convergence_threshold,
config.symmetric,
)
}
}
}
}
/// Derives quantization parameters for a vector using the strategy in `config`.
///
/// `config.method` selects the vector calibration routine, and only the
/// `config` fields relevant to that routine are forwarded.
/// `config.per_channel` is never read: each method maps to exactly one vector
/// routine, so the caller's configuration can be borrowed as-is without
/// copying or patching it.
///
/// # Arguments
///
/// * `vector` - View of the vector to calibrate against.
/// * `bits` - Bit width, forwarded unchanged to the selected routine.
/// * `config` - Calibration settings (method, symmetry and method-specific
///   parameters).
///
/// # Errors
///
/// Propagates whatever the selected calibration routine returns; no extra
/// validation is performed here.
#[allow(dead_code)]
pub fn calibrate_vector<F>(
    vector: &ArrayView1<F>,
    bits: u8,
    config: &CalibrationConfig,
) -> LinalgResult<QuantizationParams>
where
    F: scirs2_core::numeric::Float
        + Debug
        + scirs2_core::numeric::AsPrimitive<f32>
        + scirs2_core::numeric::FromPrimitive,
    f32: scirs2_core::numeric::AsPrimitive<F>,
{
    match config.method {
        CalibrationMethod::MinMax => calibrate_vector_minmax(vector, bits, config.symmetric),
        CalibrationMethod::MovingAverageMinMax => {
            calibrate_vector_moving_average(vector, bits, config.windowsize, config.symmetric)
        }
        CalibrationMethod::PercentileCalibration => {
            calibrate_vector_percentile(vector, bits, config.percentile, config.symmetric)
        }
        CalibrationMethod::EntropyCalibration => {
            calibrate_vector_entropy(vector, bits, config.num_bins, config.symmetric)
        }
        CalibrationMethod::MSEOptimization => calibrate_vector_mse(vector, bits, config.symmetric),
        CalibrationMethod::ExponentialMovingAverage => calibrate_vector_ema(
            vector,
            bits,
            config.ema_factor,
            config.max_iterations,
            config.convergence_threshold,
            config.symmetric,
        ),
    }
}
/// Builds the calibration preset used for weights.
///
/// Both presets are symmetric and per-channel; they differ only in method:
///
/// * `aggressive == true` selects `PercentileCalibration` with
///   `percentile = 0.99`.
/// * `aggressive == false` selects `EntropyCalibration` with
///   `num_bins = 2048`.
///
/// Every other field keeps its `Default` value.
///
/// # Arguments
///
/// * `_bits` - Currently unused; it does not influence the returned preset.
/// * `aggressive` - Chooses between the percentile and entropy presets.
#[allow(dead_code)]
pub fn get_weight_calibration_config(_bits: u8, aggressive: bool) -> CalibrationConfig {
    // Settings shared by both weight presets.
    let shared = CalibrationConfig {
        symmetric: true,
        per_channel: true,
        ..Default::default()
    };

    if aggressive {
        CalibrationConfig {
            method: CalibrationMethod::PercentileCalibration,
            percentile: 0.99,
            ..shared
        }
    } else {
        CalibrationConfig {
            method: CalibrationMethod::EntropyCalibration,
            num_bins: 2048,
            ..shared
        }
    }
}
/// Builds the calibration preset used for activations.
///
/// The preset is always whole-tensor (`per_channel = false`) and is symmetric
/// exactly when `non_negative` is `false`. The method depends on
/// `outlier_sensitive`:
///
/// * `true` selects `MSEOptimization` with `num_bins = 1024`.
/// * `false` selects `PercentileCalibration` with `percentile = 0.9995`.
///
/// Every other field keeps its `Default` value.
///
/// # Arguments
///
/// * `_bits` - Currently unused; it does not influence the returned preset.
/// * `non_negative` - When `true`, the returned preset is asymmetric.
/// * `outlier_sensitive` - Chooses between the MSE and percentile presets.
#[allow(dead_code)]
pub fn get_activation_calibration_config(
    _bits: u8,
    non_negative: bool,
    outlier_sensitive: bool,
) -> CalibrationConfig {
    let symmetric = !non_negative;

    if outlier_sensitive {
        // NOTE(review): the MSE arms in `calibrate_matrix`/`calibrate_vector`
        // do not forward `num_bins`, so this value only matters to code that
        // reads the field directly -- confirm it is intended.
        CalibrationConfig {
            method: CalibrationMethod::MSEOptimization,
            num_bins: 1024,
            per_channel: false,
            symmetric,
            ..Default::default()
        }
    } else {
        CalibrationConfig {
            method: CalibrationMethod::PercentileCalibration,
            percentile: 0.9995,
            per_channel: false,
            symmetric,
            ..Default::default()
        }
    }
}