use crate::analysis::config::{LayerSensitivityResult, SensitivityAnalysisResults};
use crate::{QScheme, QuantConfig, TorshResult};
use torsh_core::TorshError;
use torsh_tensor::Tensor;
/// Analyzes how sensitive individual model layers are to quantization,
/// either empirically (via a caller-supplied evaluation closure) or
/// heuristically (by layer-name pattern matching).
pub struct SensitivityAnalyzer {
// Held-out (input, expected_output) pairs. Not read by any method in this
// impl yet (hence the dead_code allow) — presumably intended for future
// empirical evaluation; confirm before removing.
#[allow(dead_code)]
test_data: Vec<(Tensor, Tensor)>,
// Maximum element-wise absolute difference still counted as a "correct
// prediction" in compare_tensor_accuracy. Defaults to 1e-6 in `new`.
tolerance: f32,
}
impl SensitivityAnalyzer {
    /// Creates an analyzer over held-out `(input, expected_output)` pairs
    /// with a default element-wise comparison tolerance of `1e-6`.
    pub fn new(test_data: Vec<(Tensor, Tensor)>) -> Self {
        Self {
            test_data,
            tolerance: 1e-6,
        }
    }

    /// Overrides the tolerance used by [`Self::compare_tensor_accuracy`].
    pub fn set_tolerance(&mut self, tolerance: f32) {
        self.tolerance = tolerance;
    }

    /// Empirically measures per-layer quantization sensitivity.
    ///
    /// `evaluation_fn(layer_name, config)` must return the model's accuracy
    /// with that layer quantized under `config`. It is first invoked with an
    /// empty layer name and the default config to obtain the baseline
    /// accuracy; it is then invoked once per (layer, scheme) pair. Schemes
    /// whose evaluation fails are skipped rather than aborting the analysis,
    /// so a layer for which every scheme fails reports a best accuracy of 0.0.
    ///
    /// # Errors
    /// Propagates only a failure of the baseline evaluation.
    pub fn analyze_layer_sensitivity(
        &self,
        layer_names: &[String],
        evaluation_fn: impl Fn(&str, &QuantConfig) -> TorshResult<f32>,
    ) -> TorshResult<SensitivityAnalysisResults> {
        // Fixed candidate set; a plain array avoids a needless heap allocation.
        const SCHEMES_TO_TEST: [QScheme; 5] = [
            QScheme::PerTensorAffine,
            QScheme::PerTensorSymmetric,
            QScheme::PerChannelAffine,
            QScheme::Int4PerTensor,
            QScheme::Binary,
        ];
        let baseline_accuracy = evaluation_fn("", &QuantConfig::default())?;
        let layer_results = layer_names
            .iter()
            .map(|layer_name| {
                // Best accuracy achieved by any scheme; per-scheme evaluation
                // errors are non-fatal and simply drop that scheme. NaN
                // accuracies never win against the 0.0 seed, matching the
                // original `accuracy > best` comparison.
                let best_accuracy = SCHEMES_TO_TEST
                    .iter()
                    .filter_map(|&scheme| {
                        let config = QuantConfig::new().with_scheme(scheme);
                        evaluation_fn(layer_name, &config).ok()
                    })
                    .fold(0.0_f32, f32::max);
                LayerSensitivityResult::new(layer_name.clone(), baseline_accuracy, best_accuracy)
            })
            .collect::<Vec<_>>();
        Ok(SensitivityAnalysisResults::new(layer_results))
    }

    /// Estimates per-layer sensitivity without running any evaluation, using
    /// the name-based heuristics in [`Self::estimate_layer_sensitivity`].
    /// The reported quantized accuracy is the assumed baseline minus the
    /// heuristic sensitivity score.
    pub fn heuristic_sensitivity_analysis(
        &self,
        layer_names: &[String],
    ) -> TorshResult<SensitivityAnalysisResults> {
        // Assumed full-precision accuracy; purely a heuristic anchor so the
        // estimated drop can be expressed as an absolute accuracy.
        const ASSUMED_BASELINE_ACCURACY: f32 = 0.95;
        let layer_results = layer_names
            .iter()
            .map(|layer_name| {
                let (sensitivity_score, _recommended_scheme) =
                    self.estimate_layer_sensitivity(layer_name);
                LayerSensitivityResult::new(
                    layer_name.clone(),
                    ASSUMED_BASELINE_ACCURACY,
                    ASSUMED_BASELINE_ACCURACY - sensitivity_score,
                )
            })
            .collect::<Vec<_>>();
        Ok(SensitivityAnalysisResults::new(layer_results))
    }

    /// Maps a layer name to `(estimated accuracy drop, recommended scheme)`
    /// by substring matching. Match order matters: more specific patterns
    /// (e.g. 1x1 convolutions) must be tested before general ones ("conv").
    fn estimate_layer_sensitivity(&self, layer_name: &str) -> (f32, QScheme) {
        let layer_name_lower = layer_name.to_lowercase();
        if layer_name_lower.contains("embedding") {
            // Embeddings are typically the most quantization-sensitive.
            (0.08, QScheme::PerTensorAffine)
        } else if layer_name_lower.contains("attention") || layer_name_lower.contains("self_attn") {
            (0.06, QScheme::PerChannelAffine)
        } else if layer_name_lower.contains("output") || layer_name_lower.contains("classifier") {
            (0.05, QScheme::PerTensorAffine)
        } else if layer_name_lower.contains("layer_norm") || layer_name_lower.contains("batch_norm")
        {
            (0.02, QScheme::Int4PerTensor)
        } else if layer_name_lower.contains("conv") && layer_name_lower.contains("1x1") {
            // Pointwise convolutions tolerate aggressive quantization well.
            (0.01, QScheme::Int4PerTensor)
        } else if layer_name_lower.contains("conv") {
            (0.03, QScheme::PerChannelAffine)
        } else if layer_name_lower.contains("linear") || layer_name_lower.contains("dense") {
            (0.025, QScheme::PerTensorAffine)
        } else {
            // Unknown layer type: assume moderate sensitivity.
            (0.03, QScheme::PerTensorAffine)
        }
    }

    /// Returns the fraction of elements whose absolute difference is within
    /// `self.tolerance`, treating each element as a "prediction".
    ///
    /// # Errors
    /// Returns `InvalidArgument` when the shapes differ or the tensors are
    /// empty (previously an empty tensor silently produced a NaN ratio).
    pub fn compare_tensor_accuracy(
        &self,
        original: &Tensor,
        quantized: &Tensor,
    ) -> TorshResult<f32> {
        if original.shape() != quantized.shape() {
            return Err(TorshError::InvalidArgument(
                "Tensors must have the same shape for accuracy comparison".to_string(),
            ));
        }
        let original_data = original.data()?;
        let quantized_data = quantized.data()?;
        // Guard the division below: 0 / 0 would yield NaN, not an accuracy.
        if original_data.is_empty() {
            return Err(TorshError::InvalidArgument(
                "Cannot compare accuracy of empty tensors".to_string(),
            ));
        }
        let correct_predictions = original_data
            .iter()
            .zip(quantized_data.iter())
            .filter(|(orig, quant)| (*orig - *quant).abs() <= self.tolerance)
            .count();
        Ok(correct_predictions as f32 / original_data.len() as f32)
    }

    /// Computes the mean squared error between two same-shaped tensors.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when the shapes differ or the tensors are
    /// empty (previously an empty tensor silently produced a NaN MSE).
    pub fn calculate_mse(&self, original: &Tensor, quantized: &Tensor) -> TorshResult<f32> {
        if original.shape() != quantized.shape() {
            return Err(TorshError::InvalidArgument(
                "Tensors must have the same shape for MSE calculation".to_string(),
            ));
        }
        let original_data = original.data()?;
        let quantized_data = quantized.data()?;
        // Guard the division below: 0 / 0 would yield NaN, not an MSE.
        if original_data.is_empty() {
            return Err(TorshError::InvalidArgument(
                "Cannot compute MSE of empty tensors".to_string(),
            ));
        }
        let mse = original_data
            .iter()
            .zip(quantized_data.iter())
            .map(|(orig, quant)| (orig - quant).powi(2))
            .sum::<f32>()
            / original_data.len() as f32;
        Ok(mse)
    }

    /// Computes the signal-to-noise ratio in dB, where "noise" is the
    /// quantization MSE and "signal" is the mean squared original value.
    /// Returns `f32::INFINITY` for a perfect (zero-error) reconstruction.
    ///
    /// # Errors
    /// Propagates shape-mismatch / empty-tensor errors from
    /// [`Self::calculate_mse`].
    pub fn calculate_snr(&self, original: &Tensor, quantized: &Tensor) -> TorshResult<f32> {
        let mse = self.calculate_mse(original, quantized)?;
        if mse == 0.0 {
            return Ok(f32::INFINITY);
        }
        let original_data = original.data()?;
        let signal_power =
            original_data.iter().map(|&x| x.powi(2)).sum::<f32>() / original_data.len() as f32;
        let snr_db = 10.0 * (signal_power / mse).log10();
        Ok(snr_db)
    }
}