use super::calibration::{
create_params_from_range, determine_data_type, find_min_max, find_min_max_vec,
};
use super::{QuantizationMethod, QuantizationParams};
use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array2, ArrayView1, ArrayView2};
use std::fmt::Debug;
/// Converts an `f32` into the generic float type `F` via `FromPrimitive::from_f32`.
///
/// # Panics
///
/// Panics with the message `"Operation failed"` if `F::from_f32` returns `None`.
#[allow(dead_code)]
fn to_f<F>(val: f32) -> F
where
    F: scirs2_core::numeric::Float + scirs2_core::numeric::FromPrimitive,
{
    F::from_f32(val).unwrap_or_else(|| panic!("Operation failed"))
}
/// Converts a value of the generic float type `F` into `f32` using its
/// `AsPrimitive<f32>` implementation.
#[allow(dead_code)]
fn to_f32<F>(val: F) -> f32
where
    F: scirs2_core::numeric::Float + scirs2_core::numeric::AsPrimitive<f32>,
{
    scirs2_core::numeric::AsPrimitive::<f32>::as_(val)
}
#[allow(dead_code)]
pub fn calibrate_matrix_ema<F>(
matrix: &ArrayView2<F>,
bits: u8,
ema_factor: f32,
max_iterations: usize,
convergence_threshold: f32,
symmetric: bool,
) -> LinalgResult<QuantizationParams>
where
F: scirs2_core::numeric::Float
+ Debug
+ scirs2_core::numeric::AsPrimitive<f32>
+ scirs2_core::numeric::FromPrimitive,
f32: scirs2_core::numeric::AsPrimitive<F>,
{
if !(0.0..=1.0).contains(&ema_factor) {
return Err(LinalgError::ValueError(
"EMA factor must be between 0.0 and 1.0".to_string(),
));
}
let (min_val_f32, max_val_f32) = find_min_max(matrix);
let mut min_val = to_f::<F>(min_val_f32);
let mut max_val = to_f::<F>(max_val_f32);
let matrix_f32 = matrix.mapv(|x| x.as_());
let mut prev_mse = f32::MAX;
for _iter in 0..max_iterations {
let min_val_f32 = to_f32(min_val);
let max_val_f32 = to_f32(max_val);
let params = create_params_from_range(bits, min_val_f32, max_val_f32, symmetric)?;
let dequantized = simulate_quantization(&matrix_f32, ¶ms, bits);
let mse = (&matrix_f32 - &dequantized).mapv(|x| x * x).sum() / matrix_f32.len() as f32;
if (prev_mse - mse).abs() < convergence_threshold {
return Ok(params);
}
prev_mse = mse;
if symmetric {
let abs_max = max_val.abs().max(min_val.abs());
let mean_abs_error = (&matrix_f32 - &dequantized)
.mapv(|x| x.abs())
.mean()
.unwrap_or(0.0);
let scale_adjustment = if mean_abs_error > 0.01 {
1.0 + mean_abs_error
} else {
1.0
};
let abs_max_f32 = to_f32(abs_max);
let new_abs_max_f32 = abs_max_f32 * scale_adjustment;
let updated_abs_max_f32 =
abs_max_f32 * (1.0 - ema_factor) + new_abs_max_f32 * ema_factor;
let updated_abs_max = to_f::<F>(updated_abs_max_f32);
min_val = -updated_abs_max;
max_val = updated_abs_max;
} else {
let negative_errors = matrix_f32
.iter()
.zip(dequantized.iter())
.filter_map(|(&orig, &deq)| {
if orig < deq {
Some((orig - deq).abs())
} else {
None
}
})
.fold(0.0, |sum, error| sum + error);
let positive_errors = matrix_f32
.iter()
.zip(dequantized.iter())
.filter_map(|(&orig, &deq)| {
if orig > deq {
Some((orig - deq).abs())
} else {
None
}
})
.fold(0.0, |sum, error| sum + error);
let neg_count = matrix_f32.iter().filter(|&&x| x < 0.0).count() as f32;
let pos_count = matrix_f32.iter().filter(|&&x| x > 0.0).count() as f32;
let neg_adjustment = if neg_count > 0.0 {
1.0 + (negative_errors / neg_count)
} else {
1.0
};
let pos_adjustment = if pos_count > 0.0 {
1.0 + (positive_errors / pos_count)
} else {
1.0
};
let min_val_f32 = to_f32(min_val);
let max_val_f32 = to_f32(max_val);
let new_min_f32 = min_val_f32 * neg_adjustment;
let new_max_f32 = max_val_f32 * pos_adjustment;
let updated_min_f32 = min_val_f32 * (1.0 - ema_factor) + new_min_f32 * ema_factor;
let updated_max_f32 = max_val_f32 * (1.0 - ema_factor) + new_max_f32 * ema_factor;
min_val = to_f::<F>(updated_min_f32);
max_val = to_f::<F>(updated_max_f32);
}
if _iter == max_iterations - 1 {
println!("EMA calibration reached max iterations with MSE: {mse}");
}
}
let min_val_f32 = to_f32(min_val);
let max_val_f32 = to_f32(max_val);
create_params_from_range(bits, min_val_f32, max_val_f32, symmetric)
}
#[allow(dead_code)]
pub fn calibrate_matrix_per_channel_ema<F>(
matrix: &ArrayView2<F>,
bits: u8,
ema_factor: f32,
max_iterations: usize,
convergence_threshold: f32,
symmetric: bool,
) -> LinalgResult<QuantizationParams>
where
F: scirs2_core::numeric::Float
+ Debug
+ scirs2_core::numeric::AsPrimitive<f32>
+ scirs2_core::numeric::FromPrimitive,
f32: scirs2_core::numeric::AsPrimitive<F>,
{
if !(0.0..=1.0).contains(&ema_factor) {
return Err(LinalgError::ValueError(
"EMA factor must be between 0.0 and 1.0".to_string(),
));
}
let (_rows, cols) = matrix.dim();
let (global_min, global_max) = find_min_max(matrix);
let mut channel_scales = Vec::with_capacity(cols);
let mut channel_zero_points = Vec::with_capacity(if symmetric { 0 } else { cols });
let matrix_f32 = matrix.mapv(|x| x.as_());
for col_idx in 0..cols {
let column_view = matrix.column(col_idx);
let column_f32_view = matrix_f32.column(col_idx);
let (col_min_f32, col_max_f32) = find_min_max_vec(&column_view);
let mut col_min = to_f::<F>(col_min_f32);
let mut col_max = to_f::<F>(col_max_f32);
let mut prev_mse = f32::MAX;
for _iter in 0..max_iterations {
let (scale, zero_point) = if symmetric {
let abs_max = col_max.abs().max(col_min.abs());
let abs_max_f32 = to_f32(abs_max);
let scale_f32 = abs_max_f32 / ((1 << (bits - 1)) - 1) as f32;
(scale_f32, 0)
} else {
let min_val_f32 = to_f32(col_min);
let max_val_f32 = to_f32(col_max);
let scale_f32 = (max_val_f32 - min_val_f32) / ((1 << bits) - 1) as f32;
let zero_point = (-min_val_f32 / scale_f32).round() as i32;
(scale_f32, zero_point)
};
let dequantized_col = simulate_quantization_vector_f32(
&column_f32_view,
scale,
zero_point,
bits,
symmetric,
);
let mse = column_f32_view
.iter()
.zip(dequantized_col.iter())
.map(|(&orig, &deq)| (orig - deq).powi(2))
.sum::<f32>()
/ column_f32_view.len() as f32;
if (prev_mse - mse).abs() < convergence_threshold {
break;
}
prev_mse = mse;
if symmetric {
let abs_max = col_max.abs().max(col_min.abs());
let mean_abs_error = column_f32_view
.iter()
.zip(dequantized_col.iter())
.map(|(&orig, &deq)| (orig - deq).abs())
.sum::<f32>()
/ column_f32_view.len() as f32;
let scale_adjustment = if mean_abs_error > 0.01 {
1.0 + mean_abs_error
} else {
1.0
};
let abs_max_f32 = to_f32(abs_max);
let new_abs_max_f32 = abs_max_f32 * scale_adjustment;
let updated_abs_max_f32 =
abs_max_f32 * (1.0 - ema_factor) + new_abs_max_f32 * ema_factor;
let updated_abs_max = to_f::<F>(updated_abs_max_f32);
col_min = -updated_abs_max;
col_max = updated_abs_max;
} else {
let negative_errors = column_f32_view
.iter()
.zip(dequantized_col.iter())
.filter_map(|(&orig, &deq)| {
if orig < deq {
Some((orig - deq).abs())
} else {
None
}
})
.sum::<f32>();
let positive_errors = column_f32_view
.iter()
.zip(dequantized_col.iter())
.filter_map(|(&orig, &deq)| {
if orig > deq {
Some((orig - deq).abs())
} else {
None
}
})
.sum::<f32>();
let neg_count = column_f32_view.iter().filter(|&&x| x < 0.0).count() as f32;
let pos_count = column_f32_view.iter().filter(|&&x| x > 0.0).count() as f32;
let neg_adjustment = if neg_count > 0.0 {
1.0 + (negative_errors / neg_count)
} else {
1.0
};
let pos_adjustment = if pos_count > 0.0 {
1.0 + (positive_errors / pos_count)
} else {
1.0
};
let min_val_f32 = to_f32(col_min);
let max_val_f32 = to_f32(col_max);
let new_min_f32 = min_val_f32 * neg_adjustment;
let new_max_f32 = max_val_f32 * pos_adjustment;
let updated_min_f32 = min_val_f32 * (1.0 - ema_factor) + new_min_f32 * ema_factor;
let updated_max_f32 = max_val_f32 * (1.0 - ema_factor) + new_max_f32 * ema_factor;
col_min = to_f::<F>(updated_min_f32);
col_max = to_f::<F>(updated_max_f32);
}
}
let (scale, zero_point) = if symmetric {
let abs_max = col_max.abs().max(col_min.abs());
let abs_max_f32 = to_f32(abs_max);
let scale_f32 = abs_max_f32 / ((1 << (bits - 1)) - 1) as f32;
(scale_f32, 0)
} else {
let min_val_f32 = to_f32(col_min);
let max_val_f32 = to_f32(col_max);
let scale_f32 = (max_val_f32 - min_val_f32) / ((1 << bits) - 1) as f32;
let zero_point = (-min_val_f32 / scale_f32).round() as i32;
(scale_f32, zero_point)
};
channel_scales.push(scale);
if !symmetric {
channel_zero_points.push(zero_point);
}
}
let q_method = if symmetric {
QuantizationMethod::PerChannelSymmetric
} else {
QuantizationMethod::PerChannelAffine
};
Ok(QuantizationParams {
bits,
scale: 0.0, zero_point: 0, min_val: global_min,
max_val: global_max,
method: q_method,
data_type: determine_data_type(bits),
channel_scales: Some(channel_scales),
channel_zero_points: if symmetric {
None
} else {
Some(channel_zero_points)
},
})
}
/// Calibrates quantization parameters for a vector by iteratively widening
/// the value range, smoothing each update with an exponential moving average
/// (EMA).
///
/// The range starts at the `(min, max)` pair reported by `find_min_max_vec`.
/// Each iteration derives a scale (and zero point in asymmetric mode) from the
/// current range, simulates a quantize/dequantize round trip with
/// `simulate_quantization_vector_f32`, and stops once the mean squared error
/// changes by less than `convergence_threshold`. Otherwise the range is scaled
/// by an error-derived factor and blended as
/// `old * (1 - ema_factor) + new * ema_factor`.
///
/// The final parameters are always produced by `create_params_from_range` for
/// the range held when the loop ends (by convergence or iteration limit).
///
/// # Arguments
///
/// * `vector` - input values; converted to `f32` for the simulation.
/// * `bits` - bit width; must be in `2..=31` for symmetric and `1..=30` for
///   asymmetric calibration so the `i32` level arithmetic cannot overflow.
/// * `ema_factor` - blending weight of the new range, must lie in `[0.0, 1.0]`.
/// * `max_iterations` - upper bound on refinement iterations.
/// * `convergence_threshold` - absolute MSE change below which the loop stops.
/// * `symmetric` - selects symmetric or asymmetric calibration.
///
/// # Errors
///
/// Returns `LinalgError::ValueError` if `ema_factor` is outside `[0.0, 1.0]`
/// (including NaN) or if `bits` is outside the supported range for the mode,
/// and propagates any error from `create_params_from_range`.
#[allow(dead_code)]
pub fn calibrate_vector_ema<F>(
    vector: &ArrayView1<F>,
    bits: u8,
    ema_factor: f32,
    max_iterations: usize,
    convergence_threshold: f32,
    symmetric: bool,
) -> LinalgResult<QuantizationParams>
where
    F: scirs2_core::numeric::Float
        + Debug
        + scirs2_core::numeric::AsPrimitive<f32>
        + scirs2_core::numeric::FromPrimitive,
    f32: scirs2_core::numeric::AsPrimitive<F>,
{
    /// Derives `(scale, zero_point)` for the range `[min, max]`.
    fn scale_and_zero_point(min: f32, max: f32, bits: u8, symmetric: bool) -> (f32, i32) {
        if symmetric {
            let abs_max = max.abs().max(min.abs());
            (abs_max / ((1 << (bits - 1)) - 1) as f32, 0)
        } else {
            let scale = (max - min) / ((1 << bits) - 1) as f32;
            let zero_point = (-min / scale).round() as i32;
            (scale, zero_point)
        }
    }

    if !(0.0..=1.0).contains(&ema_factor) {
        return Err(LinalgError::ValueError(
            "EMA factor must be between 0.0 and 1.0".to_string(),
        ));
    }

    // `bits - 1` underflows for 0 and the `1 << n` / `- 1` level arithmetic
    // overflows `i32` for large widths; reject those instead of panicking.
    let (min_bits, max_bits) = if symmetric { (2u8, 31u8) } else { (1u8, 30u8) };
    if !(min_bits..=max_bits).contains(&bits) {
        return Err(LinalgError::ValueError(format!(
            "bits must be between {min_bits} and {max_bits} for this quantization mode, got {bits}"
        )));
    }

    let (initial_min, initial_max) = find_min_max_vec(vector);
    let mut min_val = to_f::<F>(initial_min);
    let mut max_val = to_f::<F>(initial_max);
    let vector_f32 = vector.mapv(|x| -> f32 { x.as_() });
    let element_count = vector_f32.len() as f32;

    // EMA blend of the current bound with its proposed replacement.
    let blend = |current: f32, target: f32| current * (1.0 - ema_factor) + target * ema_factor;

    let mut prev_mse = f32::MAX;
    for _ in 0..max_iterations {
        let (scale, zero_point) =
            scale_and_zero_point(to_f32(min_val), to_f32(max_val), bits, symmetric);
        let dequantized = simulate_quantization_vector_f32(
            &vector_f32.view(),
            scale,
            zero_point,
            bits,
            symmetric,
        );

        let mse = vector_f32
            .iter()
            .zip(dequantized.iter())
            .map(|(&orig, &deq)| (orig - deq).powi(2))
            .sum::<f32>()
            / element_count;
        if (prev_mse - mse).abs() < convergence_threshold {
            // Converged: fall through to build parameters for the current range.
            break;
        }
        prev_mse = mse;

        if symmetric {
            let abs_max_f32 = to_f32(max_val.abs().max(min_val.abs()));
            let mean_abs_error = vector_f32
                .iter()
                .zip(dequantized.iter())
                .map(|(&orig, &deq)| (orig - deq).abs())
                .sum::<f32>()
                / element_count;
            // Only widen the range when the average error is non-negligible.
            let scale_adjustment = if mean_abs_error > 0.01 {
                1.0 + mean_abs_error
            } else {
                1.0
            };
            let updated_abs_max = to_f::<F>(blend(abs_max_f32, abs_max_f32 * scale_adjustment));
            min_val = -updated_abs_max;
            max_val = updated_abs_max;
        } else {
            // Single pass: error mass split by direction (original below/above
            // its reconstruction) and element counts split by the sign of the
            // original value.
            let mut negative_errors = 0.0_f32;
            let mut positive_errors = 0.0_f32;
            let mut neg_elements = 0_usize;
            let mut pos_elements = 0_usize;
            for (&orig, &deq) in vector_f32.iter().zip(dequantized.iter()) {
                if orig < deq {
                    negative_errors += (orig - deq).abs();
                } else if orig > deq {
                    positive_errors += (orig - deq).abs();
                }
                if orig < 0.0 {
                    neg_elements += 1;
                } else if orig > 0.0 {
                    pos_elements += 1;
                }
            }
            let neg_count = neg_elements as f32;
            let pos_count = pos_elements as f32;
            let neg_adjustment = if neg_count > 0.0 {
                1.0 + (negative_errors / neg_count)
            } else {
                1.0
            };
            let pos_adjustment = if pos_count > 0.0 {
                1.0 + (positive_errors / pos_count)
            } else {
                1.0
            };

            let current_min = to_f32(min_val);
            let current_max = to_f32(max_val);
            min_val = to_f::<F>(blend(current_min, current_min * neg_adjustment));
            max_val = to_f::<F>(blend(current_max, current_max * pos_adjustment));
        }
    }

    create_params_from_range(bits, to_f32(min_val), to_f32(max_val), symmetric)
}
/// Simulates a quantize/dequantize round trip over an `f32` vector.
///
/// * Symmetric mode: each value is divided by `scale`, rounded, clamped to
///   `[-(2^(bits-1)), 2^(bits-1) - 1]` and multiplied by `scale` again;
///   `zero_point` is not used.
/// * Asymmetric mode: each value is divided by `scale`, shifted by
///   `zero_point`, rounded, clamped to `[0, 2^bits - 1]`, then shifted back
///   and multiplied by `scale`.
///
/// The result is returned as a single-column matrix of shape `(vector.len(), 1)`.
#[allow(dead_code)]
fn simulate_quantization_vector_f32(
    vector: &ArrayView1<f32>,
    scale: f32,
    zero_point: i32,
    bits: u8,
    symmetric: bool,
) -> Array2<f32> {
    let mut dequantized = Array2::<f32>::zeros((vector.len(), 1));
    if symmetric {
        let lower = -(1 << (bits - 1)) as f32;
        let upper = ((1 << (bits - 1)) - 1) as f32;
        for (slot, &value) in dequantized.iter_mut().zip(vector.iter()) {
            let level = (value / scale).round().clamp(lower, upper);
            *slot = level * scale;
        }
    } else {
        let upper = ((1 << bits) - 1) as f32;
        let offset = zero_point as f32;
        for (slot, &value) in dequantized.iter_mut().zip(vector.iter()) {
            let level = ((value / scale) + offset).round().clamp(0.0, upper);
            *slot = (level - offset) * scale;
        }
    }
    dequantized
}
/// Simulates a quantize/dequantize round trip over an `f32` matrix according
/// to `params.method`.
///
/// * `Symmetric` / `Int4`: divide by `params.scale`, round, clamp to
///   `[-(2^(bits-1)), 2^(bits-1) - 1]`, multiply by the scale.
/// * `Affine` / `UInt4`: divide by `params.scale`, add `params.zero_point`,
///   round, clamp to `[0, 2^bits - 1]`, then undo the shift and scale.
/// * `PerChannelSymmetric` / `PerChannelAffine`: the same two formulas applied
///   column by column using `params.channel_scales` (and
///   `params.channel_zero_points` for the affine variant). If
///   `channel_scales` is `None`, an all-zero matrix is returned.
/// * Any other method: the input is returned unchanged (cloned).
///
/// # Panics
///
/// Panics if `channel_scales` has fewer entries than the matrix has columns,
/// or, for `PerChannelAffine` with at least one column, if
/// `channel_zero_points` is `None` or too short.
#[allow(dead_code)]
fn simulate_quantization(
    matrix: &Array2<f32>,
    params: &QuantizationParams,
    bits: u8,
) -> Array2<f32> {
    match params.method {
        QuantizationMethod::Symmetric | QuantizationMethod::Int4 => {
            let scale = params.scale;
            let clamp_min = -(1 << (bits - 1)) as f32;
            let clamp_max = ((1 << (bits - 1)) - 1) as f32;
            matrix.mapv(|x| {
                let quantized = (x / scale).round().clamp(clamp_min, clamp_max);
                quantized * scale
            })
        }
        QuantizationMethod::Affine | QuantizationMethod::UInt4 => {
            let scale = params.scale;
            let zero_point = params.zero_point as f32;
            let clamp_max = ((1 << bits) - 1) as f32;
            matrix.mapv(|x| {
                let quantized = ((x / scale) + zero_point).round().clamp(0.0, clamp_max);
                (quantized - zero_point) * scale
            })
        }
        QuantizationMethod::PerChannelSymmetric | QuantizationMethod::PerChannelAffine => {
            let mut result = Array2::<f32>::zeros(matrix.dim());
            let (_, cols) = matrix.dim();
            let per_channel_symmetric =
                matches!(params.method, QuantizationMethod::PerChannelSymmetric);
            // Without per-channel scales there is nothing to simulate; the
            // zero-filled matrix is returned as-is.
            if let Some(channel_scales) = &params.channel_scales {
                for col_idx in 0..cols {
                    let col_view = matrix.column(col_idx);
                    let scale = channel_scales[col_idx];
                    if per_channel_symmetric {
                        let clamp_min = -(1 << (bits - 1)) as f32;
                        let clamp_max = ((1 << (bits - 1)) - 1) as f32;
                        for (row_idx, &val) in col_view.iter().enumerate() {
                            let quantized = (val / scale).round().clamp(clamp_min, clamp_max);
                            result[[row_idx, col_idx]] = quantized * scale;
                        }
                    } else {
                        let clamp_max = ((1 << bits) - 1) as f32;
                        let zero_point = params
                            .channel_zero_points
                            .as_ref()
                            .expect("Operation failed")[col_idx]
                            as f32;
                        for (row_idx, &val) in col_view.iter().enumerate() {
                            let quantized =
                                ((val / scale) + zero_point).round().clamp(0.0, clamp_max);
                            result[[row_idx, col_idx]] = (quantized - zero_point) * scale;
                        }
                    }
                }
            }
            result
        }
        // Methods without an integer grid handled here pass through untouched.
        _ => matrix.clone(),
    }
}