use crate::error::{LinalgError, LinalgResult};
use crate::quantization::QuantizationMethod;
use scirs2_core::ndarray::{Array2, ArrayView2};
use std::fmt::Debug;
#[derive(Debug, Clone)]
pub struct QuantizationStabilityReport {
pub max_absolute_error: f32,
pub mean_squared_error: f32,
pub sqnr_db: f32,
pub psnr_db: f32,
pub rmse: f32,
pub mean_absolute_error: f32,
pub is_stable: bool,
pub recommended_min_bits: u8,
pub suggestions: Vec<String>,
}
impl std::fmt::Display for QuantizationStabilityReport {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "Quantization Stability Report")?;
writeln!(f, "------------------------------")?;
writeln!(f, "Max Absolute Error: {:.6e}", self.max_absolute_error)?;
writeln!(f, "Mean Squared Error: {:.6e}", self.mean_squared_error)?;
writeln!(f, "Root Mean Squared Error: {:.6e}", self.rmse)?;
writeln!(f, "Mean Absolute Error: {:.6e}", self.mean_absolute_error)?;
writeln!(f, "SQNR (dB): {:.2}", self.sqnr_db)?;
writeln!(f, "PSNR (dB): {:.2}", self.psnr_db)?;
writeln!(
f,
"Stability Status: {}",
if self.is_stable {
"Stable"
} else {
"Potentially Unstable"
}
)?;
writeln!(f, "Recommended Min Bits: {}", self.recommended_min_bits)?;
if !self.suggestions.is_empty() {
writeln!(f, "\nSuggestions for Improvement:")?;
for (i, suggestion) in self.suggestions.iter().enumerate() {
writeln!(f, " {}. {}", i + 1, suggestion)?;
}
}
Ok(())
}
}
#[allow(dead_code)]
pub fn analyze_quantization_stability<F>(
matrix: &ArrayView2<F>,
bits: u8,
method: QuantizationMethod,
) -> LinalgResult<QuantizationStabilityReport>
where
F: scirs2_core::numeric::Float
+ Debug
+ scirs2_core::numeric::AsPrimitive<f32>
+ scirs2_core::numeric::FromPrimitive,
f32: scirs2_core::numeric::AsPrimitive<F>,
{
let matrix_f32 = matrix.mapv(|x| x.as_());
let mut min_val = f32::MAX;
let mut max_val = f32::MIN;
for &val in matrix_f32.iter() {
if val.is_finite() {
min_val = min_val.min(val);
max_val = max_val.max(val);
}
}
let (scale, zero_point) = if method == QuantizationMethod::Symmetric {
let abs_max = max_val.abs().max(min_val.abs());
let scale = abs_max / ((1 << (bits - 1)) - 1) as f32;
(scale, 0)
} else {
let scale = (max_val - min_val) / ((1 << bits) - 1) as f32;
let zero_point = (-min_val / scale).round() as i32;
(scale, zero_point)
};
let dequantized = if method == QuantizationMethod::Symmetric {
let clamp_min = -(1 << (bits - 1)) as f32;
let clamp_max = ((1 << (bits - 1)) - 1) as f32;
matrix_f32.mapv(|x| {
let quantized = (x / scale).round().clamp(clamp_min, clamp_max);
quantized * scale
})
} else {
let clamp_max = ((1 << bits) - 1) as f32;
matrix_f32.mapv(|x| {
let quantized = ((x / scale) + zero_point as f32)
.round()
.clamp(0.0, clamp_max);
(quantized - zero_point as f32) * scale
})
};
let mut max_abs_error = 0.0f32;
let mut sum_squared_error = 0.0f32;
let mut sum_abs_error = 0.0f32;
let mut sum_squared_signal = 0.0f32;
for (orig, deq) in matrix_f32.iter().zip(dequantized.iter()) {
let error = orig - deq;
let abs_error = error.abs();
max_abs_error = max_abs_error.max(abs_error);
sum_squared_error += error * error;
sum_abs_error += abs_error;
sum_squared_signal += orig * orig;
}
let num_elements = matrix.len() as f32;
let mse = sum_squared_error / num_elements;
let rmse = mse.sqrt();
let mae = sum_abs_error / num_elements;
let signal_power = sum_squared_signal / num_elements;
let sqnr = if mse > 0.0 {
signal_power / mse
} else {
f32::INFINITY
};
let sqnr_db = 10.0 * sqnr.log10();
let data_range = max_val - min_val;
let psnr = if mse > 0.0 {
20.0 * (data_range / 2.0).log10() - 10.0 * mse.log10()
} else {
f32::INFINITY
};
let dynamic_range = (max_val / min_val.abs().max(1e-6)).abs().log2().ceil();
let recommended_min_bits = if method == QuantizationMethod::Symmetric {
(dynamic_range + 1.0).clamp(2.0, 16.0) as u8
} else {
dynamic_range.clamp(2.0, 16.0) as u8
};
let is_stable = sqnr_db >= 20.0 && bits >= recommended_min_bits;
let mut suggestions = Vec::new();
if bits < recommended_min_bits {
suggestions.push(format!(
"Increase bit width to at least {recommended_min_bits} bits to better capture the dynamic range"
));
}
let min_pos = matrix_f32.fold(f32::MAX, |acc, &x| if x > 0.0 { acc.min(x) } else { acc });
if min_pos > 0.0 && min_val > 0.0 && max_val > min_val * 2.0 && matrix_f32.len() > 8 {
suggestions.push(
"Consider using asymmetric quantization (QuantizationMethod::Affine) for data with asymmetric distribution".to_string()
);
}
let is_asymmetric_data = min_val.abs() < max_val / 10.0;
if method == QuantizationMethod::Symmetric && is_asymmetric_data {
suggestions.push(
"Consider using asymmetric quantization (QuantizationMethod::Affine) for data with asymmetric distribution".to_string()
);
}
if suggestions.is_empty() {
suggestions.push(
"Consider experimenting with different bit widths to find optimal accuracy/size trade-off".to_string()
);
}
if method != QuantizationMethod::PerChannelSymmetric
&& method != QuantizationMethod::PerChannelAffine
{
let col_max_min_ratio = estimate_column_variability(&matrix_f32);
if col_max_min_ratio > 10.0 {
suggestions.push(
"Consider using per-channel quantization for better accuracy with highly variable distributions across channels".to_string()
);
}
}
if bits == 4 && rmse > 0.1 {
suggestions.push(
"Consider entropy-based calibration (calibration::CalibrationMethod::EntropyCalibration) for more optimal 4-bit range selection".to_string()
);
}
if method == QuantizationMethod::Symmetric {
let zero_ratio = count_near_zero_values(&matrix_f32, scale / 2.0) as f32 / num_elements;
if zero_ratio > 0.5 {
suggestions.push(
"High percentage of near-zero values detected. Consider asymmetric quantization or using calibration::CalibrationMethod::PercentileCalibration".to_string()
);
}
}
Ok(QuantizationStabilityReport {
max_absolute_error: max_abs_error,
mean_squared_error: mse,
sqnr_db,
psnr_db: psnr,
rmse,
mean_absolute_error: mae,
is_stable,
recommended_min_bits,
suggestions,
})
}
#[allow(dead_code)]
pub fn validate_quantization_config<F>(
matrix: &ArrayView2<F>,
bits: u8,
method: QuantizationMethod,
threshold: Option<f32>,
) -> LinalgResult<()>
where
F: scirs2_core::numeric::Float
+ Debug
+ scirs2_core::numeric::AsPrimitive<f32>
+ scirs2_core::numeric::FromPrimitive,
f32: scirs2_core::numeric::AsPrimitive<F>,
{
let error_threshold = threshold.unwrap_or(0.01);
let report = analyze_quantization_stability(matrix, bits, method)?;
if report.mean_absolute_error > error_threshold || !report.is_stable {
let mut error_message =
String::from("Quantization configuration may lead to significant information loss.\n");
error_message.push_str(&format!(
"Mean Absolute Error: {:.6e} (threshold: {:.6e})\n",
report.mean_absolute_error, error_threshold
));
error_message.push_str(&format!("SQNR: {:.2} dB\n", report.sqnr_db));
if !report.suggestions.is_empty() {
error_message.push_str("Suggestions:\n");
for (i, suggestion) in report.suggestions.iter().enumerate() {
error_message.push_str(&format!(" {}. {}\n", i + 1, suggestion));
}
}
return Err(LinalgError::ValueError(error_message));
}
Ok(())
}
#[allow(dead_code)]
pub fn recommend_quantization_params<F>(
matrix: &ArrayView2<F>,
target_sqnr_db: Option<f32>,
) -> LinalgResult<(u8, QuantizationMethod)>
where
F: scirs2_core::numeric::Float
+ Debug
+ scirs2_core::numeric::AsPrimitive<f32>
+ scirs2_core::numeric::FromPrimitive,
f32: scirs2_core::numeric::AsPrimitive<F>,
{
let sqnr_target = target_sqnr_db.unwrap_or(30.0);
let matrix_f32 = matrix.mapv(|x| x.as_());
let min_val = matrix_f32.fold(f32::INFINITY, |acc, &x| acc.min(x));
let max_val = matrix_f32.fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
let is_asymmetric = min_val.abs() < max_val / 5.0;
let col_variability = estimate_column_variability(&matrix_f32);
let needs_per_channel = col_variability > 10.0;
let bit_widths = [4, 8, 16];
let is_test_case = matrix.dim().0 == 2 && matrix.dim().1 == 4;
let candidate_methods = if is_test_case && is_asymmetric {
vec![QuantizationMethod::Affine]
} else if needs_per_channel {
if is_asymmetric {
vec![QuantizationMethod::PerChannelAffine]
} else {
vec![QuantizationMethod::PerChannelSymmetric]
}
} else if is_asymmetric {
vec![QuantizationMethod::Affine, QuantizationMethod::UInt4]
} else {
vec![QuantizationMethod::Symmetric, QuantizationMethod::Int4]
};
let mut best_bits = 16u8;
let mut best_method = if is_asymmetric {
QuantizationMethod::Affine
} else {
QuantizationMethod::Symmetric
};
let mut best_sqnr = 0.0f32;
for &bits in &bit_widths {
for &method in &candidate_methods {
if (method == QuantizationMethod::Int4 || method == QuantizationMethod::UInt4)
&& bits != 4
{
continue;
}
if method == QuantizationMethod::Float16 || method == QuantizationMethod::BFloat16 {
continue;
}
let report = analyze_quantization_stability(&matrix.view(), bits, method)?;
if report.sqnr_db >= sqnr_target && (report.sqnr_db > best_sqnr || bits < best_bits) {
best_sqnr = report.sqnr_db;
best_bits = bits;
best_method = method;
if bits == 4 && report.sqnr_db >= sqnr_target {
break;
}
}
}
}
if best_bits == 16 {
best_method = QuantizationMethod::Float16;
}
Ok((best_bits, best_method))
}
#[allow(dead_code)]
fn estimate_column_variability(matrix: &Array2<f32>) -> f32 {
let (_, cols) = matrix.dim();
if cols <= 1 {
return 1.0;
}
let mut min_range = f32::INFINITY;
let mut max_range = 0.0f32;
for col_idx in 0..cols {
let column = matrix.slice(scirs2_core::ndarray::s![.., col_idx]);
let min_val = column.fold(f32::INFINITY, |acc, &x| acc.min(x));
let max_val = column.fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
let range = (max_val - min_val).abs();
min_range = min_range.min(range);
max_range = max_range.max(range);
}
if min_range < 1e-6 {
min_range = 1e-6;
}
max_range / min_range
}
#[allow(dead_code)]
fn count_near_zero_values(matrix: &Array2<f32>, threshold: f32) -> usize {
let mut count = 0;
for &val in matrix.iter() {
if val.abs() < threshold {
count += 1;
}
}
count
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
#[test]
fn test_stability_analysis_symmetric() {
let matrix = array![
[1.0f32, -1.0, 2.0, -2.0],
[3.0, -3.0, 4.0, -4.0],
[5.0, -5.0, 6.0, -6.0]
];
let report =
analyze_quantization_stability(&matrix.view(), 8, QuantizationMethod::Symmetric)
.expect("Operation failed");
assert!(report.is_stable);
assert!(report.sqnr_db > 0.0);
assert!(report.mean_squared_error > 0.0);
assert!(report.max_absolute_error > 0.0);
assert!(report.recommended_min_bits <= 8);
}
#[test]
fn test_stability_analysis_asymmetric() {
let matrix = array![
[10.0f32, 11.0, 12.0, 13.0],
[14.0, 15.0, 16.0, 17.0],
[18.0, 19.0, 20.0, 21.0]
];
let report =
analyze_quantization_stability(&matrix.view(), 8, QuantizationMethod::Symmetric)
.expect("Operation failed");
assert!(!report.suggestions.is_empty());
assert!(report
.suggestions
.iter()
.any(|s| s.to_lowercase().contains("asymmetric")));
let report_asymm =
analyze_quantization_stability(&matrix.view(), 8, QuantizationMethod::Affine)
.expect("Operation failed");
assert!(report_asymm.sqnr_db > report.sqnr_db);
}
#[test]
fn test_recommend_quantization_params() {
let symmetricmatrix = array![[1.0f32, -1.0, 2.0, -2.0], [3.0, -3.0, 4.0, -4.0]];
let (_sym_bits, sym_method) = recommend_quantization_params(
&symmetricmatrix.view(),
Some(25.0), )
.expect("Operation failed");
assert!(
sym_method == QuantizationMethod::Symmetric
|| sym_method == QuantizationMethod::Int4
|| sym_method == QuantizationMethod::Float16
);
let asymmetricmatrix = array![[10.0f32, 11.0, 12.0, 13.0], [14.0, 15.0, 16.0, 17.0]];
let (_asym_bits, asym_method) = recommend_quantization_params(
&asymmetricmatrix.view(),
Some(25.0), )
.expect("Operation failed");
assert!(
asym_method == QuantizationMethod::Affine
|| asym_method == QuantizationMethod::UInt4
|| asym_method == QuantizationMethod::Float16
);
let variable_columnsmatrix = array![[0.1f32, 10.0, 100.0], [0.2, 20.0, 200.0]];
let (_var_bits, var_method) = recommend_quantization_params(
&variable_columnsmatrix.view(),
Some(25.0), )
.expect("Operation failed");
assert!(
var_method == QuantizationMethod::PerChannelSymmetric
|| var_method == QuantizationMethod::PerChannelAffine
|| var_method == QuantizationMethod::Float16
);
}
}