use crate::{QScheme, TorshResult};
use std::collections::HashMap;
/// Tunable thresholds and weights that drive quantization sensitivity
/// analysis and per-layer scheme recommendations.
#[derive(Debug, Clone)]
pub struct AnalysisConfig {
/// Sensitivity score above which a layer counts as high-sensitivity.
pub sensitivity_threshold: f32,
/// Sensitivity score above which a layer should be kept in FP32.
pub fp32_threshold: f32,
/// Sensitivity score below which aggressive (int4) quantization is considered.
pub aggressive_threshold: f32,
/// Maximum tolerated accuracy drop, as a percentage of the original accuracy.
pub max_accuracy_drop_percent: f32,
// NOTE(review): not read anywhere in this file — presumably consumed by an
// efficiency scorer elsewhere; confirm against callers.
pub efficiency_weights: EfficiencyWeights,
// NOTE(review): not read anywhere in this file — presumably consumed by a
// normalization step elsewhere; confirm against callers.
pub normalization_factors: NormalizationFactors,
}
/// Relative weights for the accuracy / size / speed components of an
/// efficiency score.
// NOTE(review): the weights are not read within this file; presumably they
// should sum to 1.0 — confirm with the consuming code.
#[derive(Debug, Clone)]
pub struct EfficiencyWeights {
pub accuracy: f32,
pub size: f32,
pub speed: f32,
}
/// Upper bounds used to normalize size-reduction and speed-improvement gains.
// NOTE(review): not read within this file — presumably gains are divided by
// these ceilings elsewhere; confirm with the consuming code.
#[derive(Debug, Clone)]
pub struct NormalizationFactors {
pub max_size_reduction: f32,
pub max_speed_improvement: f32,
}
impl Default for AnalysisConfig {
fn default() -> Self {
Self {
sensitivity_threshold: 0.05,
fp32_threshold: 0.05,
aggressive_threshold: 0.01,
max_accuracy_drop_percent: 5.0,
efficiency_weights: EfficiencyWeights::default(),
normalization_factors: NormalizationFactors::default(),
}
}
}
impl Default for EfficiencyWeights {
fn default() -> Self {
Self {
accuracy: 0.5,
size: 0.3,
speed: 0.2,
}
}
}
impl Default for NormalizationFactors {
fn default() -> Self {
Self {
max_size_reduction: 8.0,
max_speed_improvement: 10.0,
}
}
}
impl AnalysisConfig {
    /// Config with custom sensitivity thresholds; all other fields defaulted.
    pub fn with_sensitivity_thresholds(
        sensitivity_threshold: f32,
        fp32_threshold: f32,
        aggressive_threshold: f32,
    ) -> Self {
        let mut cfg = Self::default();
        cfg.sensitivity_threshold = sensitivity_threshold;
        cfg.fp32_threshold = fp32_threshold;
        cfg.aggressive_threshold = aggressive_threshold;
        cfg
    }

    /// Config with custom efficiency weights; all other fields defaulted.
    pub fn with_efficiency_weights(accuracy: f32, size: f32, speed: f32) -> Self {
        let mut cfg = Self::default();
        cfg.efficiency_weights = EfficiencyWeights {
            accuracy,
            size,
            speed,
        };
        cfg
    }

    /// Preset with tighter thresholds: flags sensitivity earlier and
    /// tolerates at most a 2% accuracy drop.
    pub fn conservative() -> Self {
        let mut cfg = Self::default();
        cfg.sensitivity_threshold = 0.02;
        cfg.fp32_threshold = 0.02;
        cfg.aggressive_threshold = 0.005;
        cfg.max_accuracy_drop_percent = 2.0;
        cfg
    }

    /// Preset with looser thresholds: quantizes more layers and tolerates
    /// up to a 10% accuracy drop.
    pub fn aggressive() -> Self {
        let mut cfg = Self::default();
        cfg.sensitivity_threshold = 0.1;
        cfg.fp32_threshold = 0.1;
        cfg.aggressive_threshold = 0.05;
        cfg.max_accuracy_drop_percent = 10.0;
        cfg
    }
}
/// Per-layer outcome of a quantization sensitivity measurement.
#[derive(Debug, Clone)]
pub struct LayerSensitivityResult {
/// Name of the analyzed layer.
pub layer_name: String,
/// Accuracy measured before quantization.
pub original_accuracy: f32,
/// Accuracy measured after quantization.
pub quantized_accuracy: f32,
/// `original_accuracy - quantized_accuracy`; higher means more sensitive.
pub sensitivity_score: f32,
/// Quantization scheme suggested for this layer based on the score.
pub recommended_scheme: QScheme,
/// True when the score exceeds the FP32 threshold and the layer should
/// stay unquantized.
pub keep_fp32: bool,
}
impl LayerSensitivityResult {
    /// Build a result using the default [`AnalysisConfig`] thresholds.
    pub fn new(layer_name: String, original_accuracy: f32, quantized_accuracy: f32) -> Self {
        Self::new_with_config(
            layer_name,
            original_accuracy,
            quantized_accuracy,
            &AnalysisConfig::default(),
        )
    }

    /// Build a result, deriving the sensitivity score
    /// (`original_accuracy - quantized_accuracy`), the FP32 recommendation,
    /// and the suggested scheme from the supplied config thresholds.
    pub fn new_with_config(
        layer_name: String,
        original_accuracy: f32,
        quantized_accuracy: f32,
        config: &AnalysisConfig,
    ) -> Self {
        let sensitivity_score = original_accuracy - quantized_accuracy;
        // Layers that lose more than the FP32 threshold stay unquantized.
        let keep_fp32 = sensitivity_score > config.fp32_threshold;
        let recommended_scheme = Self::determine_recommended_scheme(sensitivity_score, config);
        Self {
            layer_name,
            original_accuracy,
            quantized_accuracy,
            sensitivity_score,
            recommended_scheme,
            keep_fp32,
        }
    }

    /// Map a sensitivity score onto a scheme: the more sensitive the layer,
    /// the more precise (and less compact) the recommended quantization.
    fn determine_recommended_scheme(sensitivity_score: f32, config: &AnalysisConfig) -> QScheme {
        if sensitivity_score > config.fp32_threshold {
            QScheme::PerTensorAffine
        } else if sensitivity_score > config.aggressive_threshold {
            QScheme::PerChannelAffine
        } else if sensitivity_score > config.aggressive_threshold / 2.0 {
            QScheme::Int4PerTensor
        } else {
            QScheme::Int4PerChannel
        }
    }

    /// Accuracy drop as a percentage of the original accuracy.
    ///
    /// Returns 0.0 when `original_accuracy` is zero (mirroring the guard in
    /// `AccuracyComparison::efficiency_score`) instead of producing NaN/inf,
    /// which would otherwise leak into `is_high_sensitivity_with_config`.
    pub fn accuracy_drop_percentage(&self) -> f32 {
        if self.original_accuracy == 0.0 {
            0.0
        } else {
            (self.sensitivity_score / self.original_accuracy) * 100.0
        }
    }

    /// High-sensitivity check against the default config.
    pub fn is_high_sensitivity(&self) -> bool {
        self.is_high_sensitivity_with_config(&AnalysisConfig::default())
    }

    /// A layer is high-sensitivity when its absolute score or its relative
    /// accuracy drop exceeds the configured limits.
    pub fn is_high_sensitivity_with_config(&self, config: &AnalysisConfig) -> bool {
        self.sensitivity_score > config.sensitivity_threshold
            || self.accuracy_drop_percentage() > config.max_accuracy_drop_percent
    }
}
/// Aggregated output of a full-model sensitivity analysis.
#[derive(Debug, Clone)]
pub struct SensitivityAnalysisResults {
/// All per-layer results, in the order they were supplied.
pub layer_results: Vec<LayerSensitivityResult>,
/// Mean sensitivity score across all layers (0.0 when there are none).
pub overall_sensitivity: f32,
/// Names of the top ~10% most sensitive layers (rounded up).
pub most_sensitive_layers: Vec<String>,
/// Names of the bottom ~10% least sensitive layers (rounded up).
pub least_sensitive_layers: Vec<String>,
/// Layer name -> recommended quantization scheme.
pub recommended_config: HashMap<String, QScheme>,
}
impl SensitivityAnalysisResults {
pub fn new(layer_results: Vec<LayerSensitivityResult>) -> Self {
let overall_sensitivity = if layer_results.is_empty() {
0.0
} else {
layer_results
.iter()
.map(|r| r.sensitivity_score)
.sum::<f32>()
/ layer_results.len() as f32
};
let mut sorted_results = layer_results.clone();
sorted_results.sort_by(|a, b| {
b.sensitivity_score
.partial_cmp(&a.sensitivity_score)
.unwrap()
});
let num_layers = sorted_results.len();
let top_10_percent = (num_layers as f32 * 0.1).ceil() as usize;
let bottom_10_percent = (num_layers as f32 * 0.1).ceil() as usize;
let most_sensitive_layers = sorted_results
.iter()
.take(top_10_percent)
.map(|r| r.layer_name.clone())
.collect();
let least_sensitive_layers = sorted_results
.iter()
.rev()
.take(bottom_10_percent)
.map(|r| r.layer_name.clone())
.collect();
let mut recommended_config = HashMap::new();
for result in &layer_results {
recommended_config.insert(result.layer_name.clone(), result.recommended_scheme);
}
Self {
layer_results,
overall_sensitivity,
most_sensitive_layers,
least_sensitive_layers,
recommended_config,
}
}
pub fn get_fp32_layers(&self) -> Vec<&String> {
self.layer_results
.iter()
.filter(|r| r.keep_fp32)
.map(|r| &r.layer_name)
.collect()
}
pub fn average_sensitivity(&self) -> f32 {
self.overall_sensitivity
}
pub fn get_aggressive_quantization_candidates(&self) -> Vec<&String> {
self.get_aggressive_quantization_candidates_with_config(&AnalysisConfig::default())
}
pub fn get_aggressive_quantization_candidates_with_config(
&self,
config: &AnalysisConfig,
) -> Vec<&String> {
self.layer_results
.iter()
.filter(|r| r.sensitivity_score < config.aggressive_threshold)
.map(|r| &r.layer_name)
.collect()
}
pub fn summary_report(&self) -> String {
format!(
"Sensitivity Analysis Summary:\n\
- Total layers analyzed: {}\n\
- Average sensitivity: {:.4}\n\
- Most sensitive layers ({}):\n{}\n\
- Least sensitive layers ({}):\n{}\n\
- Layers recommended for FP32: {}",
self.layer_results.len(),
self.overall_sensitivity,
self.most_sensitive_layers.len(),
self.most_sensitive_layers
.iter()
.map(|name| format!(" - {}", name))
.collect::<Vec<_>>()
.join("\n"),
self.least_sensitive_layers.len(),
self.least_sensitive_layers
.iter()
.map(|name| format!(" - {}", name))
.collect::<Vec<_>>()
.join("\n"),
self.get_fp32_layers().len()
)
}
}
/// Comparison of model accuracy before and after quantization.
#[derive(Debug, Clone)]
pub struct AccuracyComparison {
/// Accuracy before quantization.
pub original_accuracy: f32,
/// Accuracy after quantization.
pub quantized_accuracy: f32,
/// Absolute drop: `original_accuracy - quantized_accuracy`.
pub accuracy_drop: f32,
/// Drop as a percentage of the original accuracy.
pub accuracy_drop_percentage: f32,
/// True when the percentage drop is within the accepted threshold.
pub is_acceptable: bool,
/// Optional named metrics (e.g. per-class scores) attached via `add_metric`.
pub detailed_metrics: HashMap<String, f32>,
}
impl AccuracyComparison {
    /// Compare accuracies with the default 5% acceptable-drop threshold.
    pub fn new(original_accuracy: f32, quantized_accuracy: f32) -> Self {
        Self::new_with_threshold(original_accuracy, quantized_accuracy, 5.0)
    }

    /// Compare accuracies against an explicit acceptable-drop percentage.
    ///
    /// Guards against `original_accuracy == 0.0`: the percentage is reported
    /// as 0.0 (consistent with `efficiency_score`) instead of NaN/inf, which
    /// would otherwise make `is_acceptable` meaningless.
    pub fn new_with_threshold(
        original_accuracy: f32,
        quantized_accuracy: f32,
        acceptable_drop_percentage: f32,
    ) -> Self {
        let accuracy_drop = original_accuracy - quantized_accuracy;
        let accuracy_drop_percentage = if original_accuracy == 0.0 {
            0.0
        } else {
            (accuracy_drop / original_accuracy) * 100.0
        };
        let is_acceptable = accuracy_drop_percentage <= acceptable_drop_percentage;
        Self {
            original_accuracy,
            quantized_accuracy,
            accuracy_drop,
            accuracy_drop_percentage,
            is_acceptable,
            detailed_metrics: HashMap::new(),
        }
    }

    /// Attach a named metric; overwrites any metric with the same name.
    pub fn add_metric(&mut self, name: String, value: f32) {
        self.detailed_metrics.insert(name, value);
    }

    /// Ratio of quantized to original accuracy (0.0 when original is zero).
    pub fn efficiency_score(&self) -> f32 {
        if self.original_accuracy == 0.0 {
            0.0
        } else {
            self.quantized_accuracy / self.original_accuracy
        }
    }

    /// Quantization is recommended when the drop is acceptable AND at least
    /// 95% of the original accuracy is retained.
    pub fn is_quantization_recommended(&self) -> bool {
        self.is_acceptable && self.efficiency_score() > 0.95
    }

    /// Human-readable multi-line report, including detailed metrics when any
    /// have been attached.
    pub fn report(&self) -> String {
        let mut report = format!(
            "Accuracy Comparison Report:\n\
             - Original Accuracy: {:.4} ({:.2}%)\n\
             - Quantized Accuracy: {:.4} ({:.2}%)\n\
             - Accuracy Drop: {:.4} ({:.2}%)\n\
             - Efficiency Score: {:.4}\n\
             - Acceptable: {}\n\
             - Quantization Recommended: {}",
            self.original_accuracy,
            self.original_accuracy * 100.0,
            self.quantized_accuracy,
            self.quantized_accuracy * 100.0,
            self.accuracy_drop,
            self.accuracy_drop_percentage,
            self.efficiency_score(),
            if self.is_acceptable { "Yes" } else { "No" },
            if self.is_quantization_recommended() {
                "Yes"
            } else {
                "No"
            }
        );
        if !self.detailed_metrics.is_empty() {
            report.push_str("\n\nDetailed Metrics:");
            // NOTE: HashMap iteration order is unspecified, so metric lines
            // may appear in any order between runs.
            for (name, value) in &self.detailed_metrics {
                report.push_str(&format!("\n  - {}: {:.4}", name, value));
            }
        }
        report
    }
}