use crate::{QScheme, QuantConfig, TorshResult};
use std::collections::HashMap;
use torsh_core::{DType, TorshError};
use torsh_tensor::Tensor;
/// Central collector for quantization debugging state: per-operation
/// execution traces, aggregate error statistics, tensor range tracking,
/// and overflow/underflow detection.
#[derive(Debug)]
pub struct QuantizationDebugger {
    /// When `false`, `debug_quantization` becomes a no-op.
    pub debug_enabled: bool,
    /// Chronological record of every debugged quantization operation.
    pub execution_trace: Vec<DebugStep>,
    /// Aggregated error metrics across all debugged operations.
    pub error_stats: ErrorStatistics,
    /// Tracks observed min/max ranges per tensor name.
    pub range_tracker: RangeTracker,
    /// Flags values beyond configured overflow/underflow thresholds.
    pub overflow_detector: OverflowDetector,
}
/// Snapshot of a single debugged quantization operation.
#[derive(Debug, Clone)]
pub struct DebugStep {
    /// Operation/layer name supplied by the caller.
    pub name: String,
    /// Statistics of the tensor passed in as `input`.
    pub input_stats: TensorStatistics,
    /// Statistics of the tensor passed in as `output`.
    pub output_stats: TensorStatistics,
    /// Quantization parameters in effect for this operation.
    pub quant_params: QuantParams,
    /// Error metrics comparing the input and output tensors.
    pub error_metrics: ErrorMetrics,
    /// When this step was recorded.
    pub timestamp: std::time::Instant,
}
/// Basic summary statistics of a tensor's elements.
#[derive(Debug, Clone)]
pub struct TensorStatistics {
    /// Smallest element value.
    pub min: f32,
    /// Largest element value.
    pub max: f32,
    /// Arithmetic mean of all elements.
    pub mean: f32,
    /// Population standard deviation (divides by N, not N-1).
    pub std: f32,
    /// Tensor shape (dimension sizes).
    pub shape: Vec<usize>,
    /// Total number of elements.
    pub num_elements: usize,
    /// Element data type.
    pub dtype: DType,
}
/// Quantization parameters captured for a single operation.
#[derive(Debug, Clone)]
pub struct QuantParams {
    /// Quantization scale factor.
    pub scale: f32,
    /// Quantization zero point.
    pub zero_point: i32,
    /// Quantization scheme in use.
    pub scheme: QScheme,
    /// Inclusive (min, max) of the quantized integer range.
    pub qint_range: (i32, i32),
}
/// Error metrics comparing an original tensor with its quantized version.
#[derive(Debug, Clone)]
pub struct ErrorMetrics {
    /// Mean absolute error.
    pub mae: f32,
    /// Mean squared error.
    pub mse: f32,
    /// Root mean squared error.
    pub rmse: f32,
    /// Signal-to-noise ratio in dB (infinite when MSE is zero).
    pub snr: f32,
    /// Peak signal-to-noise ratio in dB (infinite when MSE is zero).
    pub psnr: f32,
    /// Cosine similarity between the two tensors (0.0 if either norm is zero).
    pub cosine_similarity: f32,
}
/// Accumulated error statistics across all debugged operations.
#[derive(Debug)]
pub struct ErrorStatistics {
    /// Number of operations recorded so far.
    pub total_ops: usize,
    /// Histogram counts of per-op MAE values (one bucket per bin).
    pub error_histogram: Vec<usize>,
    /// Lower edge of each histogram bin.
    pub error_bins: Vec<f32>,
    /// Sum of per-op MAE values (divide by `total_ops` for the average).
    pub cumulative_mae: f32,
    /// Sum of per-op MSE values (divide by `total_ops` for the average).
    pub cumulative_mse: f32,
    /// Per-layer history of MAE values, keyed by layer name.
    pub layer_errors: HashMap<String, Vec<f32>>,
}
/// Tracks observed value ranges per tensor and flags deviations from
/// caller-declared expected ranges.
#[derive(Debug)]
pub struct RangeTracker {
    /// History of (min, max) observations per tensor name.
    pub tensor_ranges: HashMap<String, Vec<(f32, f32)>>,
    /// All range violations detected so far.
    pub range_violations: Vec<RangeViolation>,
    /// Caller-declared expected (min, max) per tensor name.
    pub expected_ranges: HashMap<String, (f32, f32)>,
    /// Per-tensor range stability score (1.0 = perfectly stable range width).
    pub stability_metrics: HashMap<String, f32>,
}
/// A single deviation of an observed tensor range from its expected range.
#[derive(Debug, Clone)]
pub struct RangeViolation {
    /// Name of the offending tensor.
    pub tensor_name: String,
    /// Caller-declared expected (min, max).
    pub expected_range: (f32, f32),
    /// Observed (min, max).
    pub actual_range: (f32, f32),
    /// How large the deviation is relative to the expected range width.
    pub severity: ViolationSeverity,
    /// When the violation was detected.
    pub timestamp: std::time::Instant,
}
/// Severity of a range violation, graded by relative deviation
/// (see `RangeTracker::check_range_violation` thresholds).
#[derive(Debug, Clone, PartialEq)]
pub enum ViolationSeverity {
    /// Deviation below 10% of the expected range width.
    Minor,
    /// Deviation between 10% and 50%.
    Moderate,
    /// Deviation of 50% or more.
    Severe,
}
/// Detects tensor values beyond configured magnitude thresholds.
#[derive(Debug)]
pub struct OverflowDetector {
    /// Events where a value exceeded `overflow_threshold`.
    pub overflow_events: Vec<OverflowEvent>,
    /// Events where a value fell below `underflow_threshold`.
    pub underflow_events: Vec<OverflowEvent>,
    /// Values strictly greater than this trigger an overflow event.
    pub overflow_threshold: f32,
    /// Values strictly less than this trigger an underflow event.
    pub underflow_threshold: f32,
    /// When `false`, `detect_overflow` is a no-op.
    pub detection_enabled: bool,
}
/// A single overflow or underflow occurrence.
/// The same struct is used for both kinds of event.
#[derive(Debug, Clone)]
pub struct OverflowEvent {
    /// Name of the offending tensor.
    pub tensor_name: String,
    /// The offending value (also used for underflow events).
    pub overflow_value: f32,
    /// Threshold that was crossed.
    pub threshold: f32,
    /// Number of elements in the tensor crossing this threshold.
    pub num_elements: usize,
    /// N-dimensional index of the first offending element.
    pub position: Vec<usize>,
    /// When the event was detected.
    pub timestamp: std::time::Instant,
}
impl QuantizationDebugger {
    /// Creates a debugger with debugging enabled and all trackers empty.
    pub fn new() -> Self {
        Self {
            debug_enabled: true,
            execution_trace: Vec::new(),
            error_stats: ErrorStatistics::new(),
            range_tracker: RangeTracker::new(),
            overflow_detector: OverflowDetector::new(),
        }
    }

    /// Turns debugging on or off; while off, `debug_quantization` is a no-op.
    pub fn set_debug_enabled(&mut self, enabled: bool) {
        self.debug_enabled = enabled;
    }

    /// Records one quantization operation under `name`: computes statistics
    /// for `input` and `output`, derives error metrics between them, appends
    /// a [`DebugStep`] to the trace, and feeds the error, range, and
    /// overflow trackers (the latter two observe `input`).
    ///
    /// # Errors
    /// Fails if either tensor is empty or their element counts differ,
    /// or if tensor data access fails.
    pub fn debug_quantization(
        &mut self,
        name: &str,
        input: &Tensor,
        output: &Tensor,
        config: &QuantConfig,
        scale: f32,
        zero_point: i32,
    ) -> TorshResult<()> {
        if !self.debug_enabled {
            return Ok(());
        }
        let input_stats = self.compute_tensor_statistics(input)?;
        let output_stats = self.compute_tensor_statistics(output)?;
        let quant_params = QuantParams {
            scale,
            zero_point,
            scheme: config.scheme,
            qint_range: config.get_qint_range(),
        };
        let error_metrics = self.compute_error_metrics(input, output)?;
        let debug_step = DebugStep {
            name: name.to_string(),
            input_stats,
            output_stats,
            quant_params,
            error_metrics: error_metrics.clone(),
            timestamp: std::time::Instant::now(),
        };
        self.execution_trace.push(debug_step);
        self.error_stats.update(&error_metrics, name);
        self.range_tracker.track_range(name, input)?;
        self.overflow_detector.detect_overflow(name, input)?;
        Ok(())
    }

    /// Computes min/max/mean/std plus shape metadata for `tensor`.
    ///
    /// # Errors
    /// Returns `InvalidArgument` for an empty tensor.
    fn compute_tensor_statistics(&self, tensor: &Tensor) -> TorshResult<TensorStatistics> {
        let data = tensor.data()?;
        let num_elements = data.len();
        if num_elements == 0 {
            return Err(TorshError::InvalidArgument("Empty tensor".to_string()));
        }
        let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let mean = data.iter().sum::<f32>() / num_elements as f32;
        // Population variance (divide by N, not N-1).
        let variance = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / num_elements as f32;
        let std = variance.sqrt();
        Ok(TensorStatistics {
            min,
            max,
            mean,
            std,
            shape: tensor.shape().dims().to_vec(),
            num_elements,
            dtype: tensor.dtype(),
        })
    }

    /// Computes MAE/MSE/RMSE, SNR/PSNR (dB), and cosine similarity between
    /// `original` and `quantized`. Zero noise yields infinite SNR/PSNR;
    /// a zero-norm operand yields cosine similarity 0.0.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when element counts differ.
    fn compute_error_metrics(
        &self,
        original: &Tensor,
        quantized: &Tensor,
    ) -> TorshResult<ErrorMetrics> {
        let orig_data = original.data()?;
        let quant_data = quantized.data()?;
        if orig_data.len() != quant_data.len() {
            return Err(TorshError::InvalidArgument(
                "Tensor size mismatch".to_string(),
            ));
        }
        let n = orig_data.len() as f32;
        let mae = orig_data
            .iter()
            .zip(quant_data.iter())
            .map(|(&a, &b)| (a - b).abs())
            .sum::<f32>()
            / n;
        let mse = orig_data
            .iter()
            .zip(quant_data.iter())
            .map(|(&a, &b)| (a - b).powi(2))
            .sum::<f32>()
            / n;
        let rmse = mse.sqrt();
        // SNR uses the mean squared signal as power and MSE as noise power.
        let signal_power = orig_data.iter().map(|&x| x.powi(2)).sum::<f32>() / n;
        let noise_power = mse;
        let snr = if noise_power > 0.0 {
            10.0 * (signal_power / noise_power).log10()
        } else {
            f32::INFINITY
        };
        // PSNR is relative to the peak absolute value of the original signal.
        let max_val = orig_data.iter().fold(0.0f32, |a, &b| a.max(b.abs()));
        let psnr = if mse > 0.0 {
            20.0 * (max_val / rmse).log10()
        } else {
            f32::INFINITY
        };
        let dot_product = orig_data
            .iter()
            .zip(quant_data.iter())
            .map(|(&a, &b)| a * b)
            .sum::<f32>();
        let orig_norm = orig_data.iter().map(|&x| x.powi(2)).sum::<f32>().sqrt();
        let quant_norm = quant_data.iter().map(|&x| x.powi(2)).sum::<f32>().sqrt();
        let cosine_similarity = if orig_norm > 0.0 && quant_norm > 0.0 {
            dot_product / (orig_norm * quant_norm)
        } else {
            0.0
        };
        Ok(ErrorMetrics {
            mae,
            mse,
            rmse,
            snr,
            psnr,
            cosine_similarity,
        })
    }

    /// Renders a human-readable multi-section report: totals, aggregate and
    /// per-layer errors, range violations, overflow/underflow events, and
    /// the last ten entries of the execution trace.
    pub fn generate_report(&self) -> String {
        // Writing into a String is infallible, so the write results are ignored.
        use std::fmt::Write;
        let mut report = String::new();
        report.push_str("=== QUANTIZATION DEBUG REPORT ===\n\n");
        let _ = writeln!(report, "Total Operations: {}", self.execution_trace.len());
        let _ = writeln!(
            report,
            "Debug Mode: {}\n",
            if self.debug_enabled {
                "Enabled"
            } else {
                "Disabled"
            }
        );
        report.push_str("--- ERROR STATISTICS ---\n");
        let _ = writeln!(report, "Total Operations: {}", self.error_stats.total_ops);
        let _ = writeln!(report, "Cumulative MAE: {:.6}", self.error_stats.cumulative_mae);
        let _ = writeln!(report, "Cumulative MSE: {:.6}", self.error_stats.cumulative_mse);
        let _ = writeln!(
            report,
            "Cumulative RMSE: {:.6}",
            self.error_stats.cumulative_mse.sqrt()
        );
        report.push('\n');
        if !self.error_stats.layer_errors.is_empty() {
            report.push_str("--- PER-LAYER ERRORS ---\n");
            for (layer, errors) in &self.error_stats.layer_errors {
                let avg_error = errors.iter().sum::<f32>() / errors.len() as f32;
                let _ = writeln!(
                    report,
                    "{}: {:.6} (avg over {} ops)",
                    layer,
                    avg_error,
                    errors.len()
                );
            }
            report.push('\n');
        }
        if !self.range_tracker.range_violations.is_empty() {
            report.push_str("--- RANGE VIOLATIONS ---\n");
            for violation in &self.range_tracker.range_violations {
                let _ = writeln!(
                    report,
                    "{}: Expected [{:.3}, {:.3}], Got [{:.3}, {:.3}] - {:?}",
                    violation.tensor_name,
                    violation.expected_range.0,
                    violation.expected_range.1,
                    violation.actual_range.0,
                    violation.actual_range.1,
                    violation.severity
                );
            }
            report.push('\n');
        }
        if !self.overflow_detector.overflow_events.is_empty()
            || !self.overflow_detector.underflow_events.is_empty()
        {
            report.push_str("--- OVERFLOW/UNDERFLOW EVENTS ---\n");
            for event in &self.overflow_detector.overflow_events {
                let _ = writeln!(
                    report,
                    "OVERFLOW in {}: {:.3} > {:.3} ({} elements)",
                    event.tensor_name, event.overflow_value, event.threshold, event.num_elements
                );
            }
            for event in &self.overflow_detector.underflow_events {
                let _ = writeln!(
                    report,
                    "UNDERFLOW in {}: {:.3} < {:.3} ({} elements)",
                    event.tensor_name, event.overflow_value, event.threshold, event.num_elements
                );
            }
            report.push('\n');
        }
        if !self.execution_trace.is_empty() {
            report.push_str("--- RECENT EXECUTION TRACE ---\n");
            // Only show the last ten steps to keep the report readable.
            let start_idx = self.execution_trace.len().saturating_sub(10);
            for (i, step) in self.execution_trace[start_idx..].iter().enumerate() {
                let _ = writeln!(
                    report,
                    "{}. {} - MAE: {:.6}, SNR: {:.2} dB",
                    start_idx + i + 1,
                    step.name,
                    step.error_metrics.mae,
                    step.error_metrics.snr
                );
            }
        }
        report
    }

    /// Resets the trace and all trackers; `debug_enabled` is left unchanged.
    pub fn clear(&mut self) {
        self.execution_trace.clear();
        self.error_stats = ErrorStatistics::new();
        self.range_tracker = RangeTracker::new();
        self.overflow_detector = OverflowDetector::new();
    }

    /// Serializes summary counters to a minimal JSON object.
    ///
    /// Non-finite metric values (NaN/±infinity) are emitted as `null`:
    /// `format!`-style printing would otherwise produce `NaN`/`inf` tokens,
    /// which are not valid JSON numbers.
    pub fn export_to_json(&self) -> TorshResult<String> {
        use std::fmt::Write;
        // JSON has no representation for NaN or infinity; map them to null.
        fn json_f32(v: f32) -> String {
            if v.is_finite() {
                v.to_string()
            } else {
                "null".to_string()
            }
        }
        let mut json = String::new();
        json.push_str("{\n");
        let _ = writeln!(
            json,
            "  \"total_operations\": {},",
            self.execution_trace.len()
        );
        let _ = writeln!(json, "  \"debug_enabled\": {},", self.debug_enabled);
        let _ = writeln!(
            json,
            "  \"cumulative_mae\": {},",
            json_f32(self.error_stats.cumulative_mae)
        );
        let _ = writeln!(
            json,
            "  \"cumulative_mse\": {},",
            json_f32(self.error_stats.cumulative_mse)
        );
        let _ = writeln!(
            json,
            "  \"range_violations\": {},",
            self.range_tracker.range_violations.len()
        );
        let _ = writeln!(
            json,
            "  \"overflow_events\": {}",
            self.overflow_detector.overflow_events.len()
        );
        json.push_str("}\n");
        Ok(json)
    }
}
/// Default construction delegates to [`QuantizationDebugger::new`]
/// (debugging enabled, empty state).
impl Default for QuantizationDebugger {
    fn default() -> Self {
        Self::new()
    }
}
/// Default construction delegates to [`ErrorStatistics::new`].
impl Default for ErrorStatistics {
    fn default() -> Self {
        Self::new()
    }
}
impl ErrorStatistics {
    /// Width of each MAE histogram bin.
    const BIN_WIDTH: f32 = 0.01;
    /// Number of histogram bins; together with `BIN_WIDTH` this covers
    /// MAE values in `[0.0, 1.0)`.
    const NUM_BINS: usize = 100;

    /// Creates empty statistics with a zeroed 100-bin MAE histogram.
    pub fn new() -> Self {
        Self {
            total_ops: 0,
            error_histogram: vec![0; Self::NUM_BINS],
            error_bins: (0..Self::NUM_BINS)
                .map(|i| i as f32 * Self::BIN_WIDTH)
                .collect(),
            cumulative_mae: 0.0,
            cumulative_mse: 0.0,
            layer_errors: HashMap::new(),
        }
    }

    /// Folds one operation's metrics into the running totals, the per-layer
    /// history (keyed by `layer_name`), and the MAE histogram.
    pub fn update(&mut self, metrics: &ErrorMetrics, layer_name: &str) {
        self.total_ops += 1;
        self.cumulative_mae += metrics.mae;
        self.cumulative_mse += metrics.mse;
        self.layer_errors
            .entry(layer_name.to_string())
            .or_default()
            .push(metrics.mae);
        // A NaN MAE would saturate the `as usize` cast to 0 and be silently
        // miscounted in the smallest-error bin; only bin finite values.
        if metrics.mae.is_finite() {
            let bin_idx = (metrics.mae / Self::BIN_WIDTH).floor() as usize;
            if bin_idx < self.error_histogram.len() {
                self.error_histogram[bin_idx] += 1;
            }
        }
    }

    /// Returns `(average MAE, average MSE)` over all recorded operations,
    /// or `(0.0, 0.0)` when nothing has been recorded.
    pub fn get_averages(&self) -> (f32, f32) {
        if self.total_ops > 0 {
            (
                self.cumulative_mae / self.total_ops as f32,
                self.cumulative_mse / self.total_ops as f32,
            )
        } else {
            (0.0, 0.0)
        }
    }
}
/// Default construction delegates to [`RangeTracker::new`].
impl Default for RangeTracker {
    fn default() -> Self {
        Self::new()
    }
}
impl RangeTracker {
    /// Creates an empty tracker with no expected ranges registered.
    pub fn new() -> Self {
        Self {
            tensor_ranges: HashMap::new(),
            range_violations: Vec::new(),
            expected_ranges: HashMap::new(),
            stability_metrics: HashMap::new(),
        }
    }

    /// Records the (min, max) of `tensor` under `name`, checks it against
    /// any registered expected range, and refreshes the stability metric.
    /// Empty tensors are silently ignored.
    ///
    /// # Errors
    /// Propagates tensor data-access failures.
    pub fn track_range(&mut self, name: &str, tensor: &Tensor) -> TorshResult<()> {
        let data = tensor.data()?;
        if data.is_empty() {
            return Ok(());
        }
        let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        self.tensor_ranges
            .entry(name.to_string())
            .or_default()
            .push((min, max));
        if let Some(&expected_range) = self.expected_ranges.get(name) {
            let violation = self.check_range_violation(name, (min, max), expected_range);
            if let Some(v) = violation {
                self.range_violations.push(v);
            }
        }
        self.update_stability_metric(name);
        Ok(())
    }

    /// Registers the expected (min, max) range for tensor `name`.
    pub fn set_expected_range(&mut self, name: &str, range: (f32, f32)) {
        self.expected_ranges.insert(name.to_string(), range);
    }

    /// Compares an observed range against the expected one and returns a
    /// violation when either endpoint deviates by more than 1% of the
    /// expected range width. Note deviations in *either* direction count:
    /// a range strictly inside the expected one is also flagged.
    fn check_range_violation(
        &self,
        name: &str,
        actual: (f32, f32),
        expected: (f32, f32),
    ) -> Option<RangeViolation> {
        let (min_actual, max_actual) = actual;
        let (min_expected, max_expected) = expected;
        // Guard against a degenerate (zero-width) expected range: dividing
        // by 0.0 would produce NaN (which silently suppresses the violation,
        // since NaN comparisons are false) or ±infinity (which always
        // classifies as Severe).
        let width = (max_expected - min_expected).abs().max(f32::EPSILON);
        let min_violation = (min_actual - min_expected) / width;
        let max_violation = (max_actual - max_expected) / width;
        let max_violation_pct = min_violation.abs().max(max_violation.abs());
        if max_violation_pct > 0.01 {
            // Severity is graded on the relative deviation.
            let severity = if max_violation_pct < 0.1 {
                ViolationSeverity::Minor
            } else if max_violation_pct < 0.5 {
                ViolationSeverity::Moderate
            } else {
                ViolationSeverity::Severe
            };
            Some(RangeViolation {
                tensor_name: name.to_string(),
                expected_range: expected,
                actual_range: actual,
                severity,
                timestamp: std::time::Instant::now(),
            })
        } else {
            None
        }
    }

    /// Recomputes the stability score for `name` as
    /// `1 - coefficient_of_variation` of the observed range widths.
    /// Requires at least two observations; a zero mean width scores 0.0.
    fn update_stability_metric(&mut self, name: &str) {
        if let Some(ranges) = self.tensor_ranges.get(name) {
            if ranges.len() < 2 {
                return;
            }
            let range_sizes: Vec<f32> = ranges.iter().map(|(min, max)| max - min).collect();
            let mean_range = range_sizes.iter().sum::<f32>() / range_sizes.len() as f32;
            let variance = range_sizes
                .iter()
                .map(|&x| (x - mean_range).powi(2))
                .sum::<f32>()
                / range_sizes.len() as f32;
            let std_dev = variance.sqrt();
            let stability = if mean_range > 0.0 {
                1.0 - (std_dev / mean_range)
            } else {
                0.0
            };
            self.stability_metrics.insert(name.to_string(), stability);
        }
    }

    /// Returns the stability score for `name`, or 0.0 if none was computed.
    pub fn get_stability(&self, name: &str) -> f32 {
        self.stability_metrics.get(name).copied().unwrap_or(0.0)
    }
}
/// Default construction delegates to [`OverflowDetector::new`]
/// (detection enabled, thresholds ±1e6).
impl Default for OverflowDetector {
    fn default() -> Self {
        Self::new()
    }
}
impl OverflowDetector {
    /// Creates a detector with detection enabled and thresholds of ±1e6.
    pub fn new() -> Self {
        Self {
            overflow_events: Vec::new(),
            underflow_events: Vec::new(),
            overflow_threshold: 1e6,
            underflow_threshold: -1e6,
            detection_enabled: true,
        }
    }

    /// Sets the overflow (upper) and underflow (lower) thresholds.
    pub fn set_thresholds(&mut self, overflow: f32, underflow: f32) {
        self.overflow_threshold = overflow;
        self.underflow_threshold = underflow;
    }

    /// Scans `tensor` and records at most one overflow event and at most one
    /// underflow event per call. Each event captures the first offending
    /// element (value and position) and the total count of offenders of that
    /// kind.
    ///
    /// # Errors
    /// Propagates tensor data-access failures.
    pub fn detect_overflow(&mut self, name: &str, tensor: &Tensor) -> TorshResult<()> {
        if !self.detection_enabled {
            return Ok(());
        }
        let data = tensor.data()?;
        let binding = tensor.shape();
        let shape = binding.dims();
        // Single pass: remember the first offender of each kind and count all
        // offenders. (The previous implementation broke out of the loop after
        // the first event, so a tensor containing both an overflow and an
        // underflow only ever reported whichever appeared first, and it also
        // re-filtered the whole buffer to compute the count.)
        let mut first_over: Option<(usize, f32)> = None;
        let mut first_under: Option<(usize, f32)> = None;
        let mut over_count = 0usize;
        let mut under_count = 0usize;
        for (i, &value) in data.iter().enumerate() {
            if value > self.overflow_threshold {
                over_count += 1;
                if first_over.is_none() {
                    first_over = Some((i, value));
                }
            } else if value < self.underflow_threshold {
                under_count += 1;
                if first_under.is_none() {
                    first_under = Some((i, value));
                }
            }
        }
        if let Some((i, value)) = first_over {
            self.overflow_events.push(OverflowEvent {
                tensor_name: name.to_string(),
                overflow_value: value,
                threshold: self.overflow_threshold,
                num_elements: over_count,
                position: self.linear_to_nd_index(i, shape),
                timestamp: std::time::Instant::now(),
            });
        }
        if let Some((i, value)) = first_under {
            self.underflow_events.push(OverflowEvent {
                tensor_name: name.to_string(),
                overflow_value: value,
                threshold: self.underflow_threshold,
                num_elements: under_count,
                position: self.linear_to_nd_index(i, shape),
                timestamp: std::time::Instant::now(),
            });
        }
        Ok(())
    }

    /// Converts a flat (row-major) element index to N-dimensional indices
    /// for the given `shape`.
    fn linear_to_nd_index(&self, linear_idx: usize, shape: &[usize]) -> Vec<usize> {
        let mut indices = Vec::with_capacity(shape.len());
        let mut remaining = linear_idx;
        // Peel off the fastest-varying (last) dimension first, then reverse.
        for &dim_size in shape.iter().rev() {
            indices.push(remaining % dim_size);
            remaining /= dim_size;
        }
        indices.reverse();
        indices
    }

    /// Enables or disables detection; while disabled, `detect_overflow`
    /// is a no-op.
    pub fn set_detection_enabled(&mut self, enabled: bool) {
        self.detection_enabled = enabled;
    }

    /// Discards all recorded overflow and underflow events.
    pub fn clear_events(&mut self) {
        self.overflow_events.clear();
        self.underflow_events.clear();
    }

    /// Number of overflow events recorded so far.
    pub fn get_overflow_count(&self) -> usize {
        self.overflow_events.len()
    }

    /// Number of underflow events recorded so far.
    pub fn get_underflow_count(&self) -> usize {
        self.underflow_events.len()
    }
}
/// Compares two quantization schemes against a common reference tensor.
#[derive(Debug)]
pub struct QuantizationComparator {
    /// All comparisons performed so far.
    pub comparison_results: Vec<ComparisonResult>,
    /// Reference (unquantized) tensor; must be set before comparing.
    pub reference: Option<Tensor>,
}
/// Outcome of comparing two quantization schemes against the reference.
#[derive(Debug, Clone)]
pub struct ComparisonResult {
    /// Name of the first scheme.
    pub scheme_a: String,
    /// Name of the second scheme.
    pub scheme_b: String,
    /// Error metrics of scheme A versus the reference.
    pub metrics_a: ErrorMetrics,
    /// Error metrics of scheme B versus the reference.
    pub metrics_b: ErrorMetrics,
    /// Name of the winning scheme, or "Tie".
    pub winner: String,
    /// Improvement of scheme B relative to scheme A.
    pub improvement: ImprovementMetrics,
}
/// Improvement of scheme B over scheme A (positive = B is better).
#[derive(Debug, Clone)]
pub struct ImprovementMetrics {
    /// Relative MAE reduction, in percent.
    pub mae_improvement: f32,
    /// SNR gain in dB.
    pub snr_improvement: f32,
    /// PSNR gain in dB.
    pub psnr_improvement: f32,
    /// Absolute gain in cosine similarity.
    pub cosine_improvement: f32,
}
impl QuantizationComparator {
    /// Creates a comparator with no reference tensor and no results.
    pub fn new() -> Self {
        Self {
            comparison_results: Vec::new(),
            reference: None,
        }
    }

    /// Sets the reference (unquantized) tensor used by `compare_schemes`.
    pub fn set_reference(&mut self, reference: Tensor) {
        self.reference = Some(reference);
    }

    /// Compares two quantized tensors against the stored reference, records
    /// the result, and returns it.
    ///
    /// # Errors
    /// Fails if no reference has been set or if tensor sizes mismatch.
    pub fn compare_schemes(
        &mut self,
        scheme_a_name: &str,
        quantized_a: &Tensor,
        scheme_b_name: &str,
        quantized_b: &Tensor,
    ) -> TorshResult<ComparisonResult> {
        let reference = self
            .reference
            .as_ref()
            .ok_or_else(|| TorshError::InvalidArgument("Reference tensor not set".to_string()))?;
        let metrics_a = self.compute_error_metrics(reference, quantized_a)?;
        let metrics_b = self.compute_error_metrics(reference, quantized_b)?;
        let winner = self.determine_winner(&metrics_a, &metrics_b, scheme_a_name, scheme_b_name);
        // Guard the relative MAE improvement: dividing by a zero MAE (scheme A
        // is lossless) would produce NaN/infinity in the report.
        let mae_improvement = if metrics_a.mae != 0.0 {
            ((metrics_a.mae - metrics_b.mae) / metrics_a.mae) * 100.0
        } else {
            0.0
        };
        let improvement = ImprovementMetrics {
            mae_improvement,
            snr_improvement: metrics_b.snr - metrics_a.snr,
            psnr_improvement: metrics_b.psnr - metrics_a.psnr,
            cosine_improvement: metrics_b.cosine_similarity - metrics_a.cosine_similarity,
        };
        let result = ComparisonResult {
            scheme_a: scheme_a_name.to_string(),
            scheme_b: scheme_b_name.to_string(),
            metrics_a,
            metrics_b,
            winner,
            improvement,
        };
        self.comparison_results.push(result.clone());
        Ok(result)
    }

    /// Computes MAE/MSE/RMSE, SNR/PSNR (dB), and cosine similarity between
    /// `original` and `quantized`. Zero noise yields infinite SNR/PSNR;
    /// a zero-norm operand yields cosine similarity 0.0.
    ///
    /// NOTE(review): duplicates `QuantizationDebugger::compute_error_metrics`
    /// (modulo the unused `noise_power` alias) — consider extracting a shared
    /// free function.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when element counts differ.
    fn compute_error_metrics(
        &self,
        original: &Tensor,
        quantized: &Tensor,
    ) -> TorshResult<ErrorMetrics> {
        let orig_data = original.data()?;
        let quant_data = quantized.data()?;
        if orig_data.len() != quant_data.len() {
            return Err(TorshError::InvalidArgument(
                "Tensor size mismatch".to_string(),
            ));
        }
        let n = orig_data.len() as f32;
        let mae = orig_data
            .iter()
            .zip(quant_data.iter())
            .map(|(&a, &b)| (a - b).abs())
            .sum::<f32>()
            / n;
        let mse = orig_data
            .iter()
            .zip(quant_data.iter())
            .map(|(&a, &b)| (a - b).powi(2))
            .sum::<f32>()
            / n;
        let rmse = mse.sqrt();
        let signal_power = orig_data.iter().map(|&x| x.powi(2)).sum::<f32>() / n;
        let snr = if mse > 0.0 {
            10.0 * (signal_power / mse).log10()
        } else {
            f32::INFINITY
        };
        let max_val = orig_data.iter().fold(0.0f32, |a, &b| a.max(b.abs()));
        let psnr = if mse > 0.0 {
            20.0 * (max_val / rmse).log10()
        } else {
            f32::INFINITY
        };
        let dot_product = orig_data
            .iter()
            .zip(quant_data.iter())
            .map(|(&a, &b)| a * b)
            .sum::<f32>();
        let orig_norm = orig_data.iter().map(|&x| x.powi(2)).sum::<f32>().sqrt();
        let quant_norm = quant_data.iter().map(|&x| x.powi(2)).sum::<f32>().sqrt();
        let cosine_similarity = if orig_norm > 0.0 && quant_norm > 0.0 {
            dot_product / (orig_norm * quant_norm)
        } else {
            0.0
        };
        Ok(ErrorMetrics {
            mae,
            mse,
            rmse,
            snr,
            psnr,
            cosine_similarity,
        })
    }

    /// Majority vote across four metrics (MAE lower-is-better; SNR, PSNR, and
    /// cosine similarity higher-is-better). Ties on a metric favor scheme B;
    /// an overall tie returns "Tie".
    fn determine_winner(
        &self,
        metrics_a: &ErrorMetrics,
        metrics_b: &ErrorMetrics,
        name_a: &str,
        name_b: &str,
    ) -> String {
        let mut score_a = 0;
        let mut score_b = 0;
        if metrics_a.mae < metrics_b.mae {
            score_a += 1;
        } else {
            score_b += 1;
        }
        if metrics_a.snr > metrics_b.snr {
            score_a += 1;
        } else {
            score_b += 1;
        }
        if metrics_a.psnr > metrics_b.psnr {
            score_a += 1;
        } else {
            score_b += 1;
        }
        if metrics_a.cosine_similarity > metrics_b.cosine_similarity {
            score_a += 1;
        } else {
            score_b += 1;
        }
        if score_a > score_b {
            name_a.to_string()
        } else if score_b > score_a {
            name_b.to_string()
        } else {
            "Tie".to_string()
        }
    }

    /// Renders a human-readable report of all recorded comparisons.
    pub fn generate_comparison_report(&self) -> String {
        // Writing into a String is infallible, so the write results are ignored.
        use std::fmt::Write;
        let mut report = String::new();
        report.push_str("=== QUANTIZATION SCHEME COMPARISON REPORT ===\n\n");
        for (i, result) in self.comparison_results.iter().enumerate() {
            let _ = writeln!(
                report,
                "Comparison {}: {} vs {}",
                i + 1,
                result.scheme_a,
                result.scheme_b
            );
            let _ = writeln!(report, "Winner: {}", result.winner);
            let _ = writeln!(
                report,
                "MAE Improvement: {:.2}%",
                result.improvement.mae_improvement
            );
            let _ = writeln!(
                report,
                "SNR Improvement: {:.2} dB",
                result.improvement.snr_improvement
            );
            let _ = writeln!(
                report,
                "PSNR Improvement: {:.2} dB",
                result.improvement.psnr_improvement
            );
            let _ = writeln!(
                report,
                "Cosine Similarity Improvement: {:.4}",
                result.improvement.cosine_improvement
            );
            report.push('\n');
        }
        report
    }
}
/// Default construction delegates to [`QuantizationComparator::new`]
/// (no reference tensor set).
impl Default for QuantizationComparator {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::tensor_1d;

    /// End-to-end smoke test: one debugged op populates the trace and the
    /// error stats, the report mentions it, and `clear` resets everything.
    #[test]
    fn test_quantization_debugger() {
        let mut debugger = QuantizationDebugger::new();
        assert!(debugger.debug_enabled);
        let input = tensor_1d(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        let output = tensor_1d(&[1.1, 2.1, 2.9, 3.9]).unwrap();
        let config = QuantConfig::int8();
        debugger
            .debug_quantization("test_op", &input, &output, &config, 0.1, 0)
            .unwrap();
        assert_eq!(debugger.execution_trace.len(), 1);
        assert_eq!(debugger.error_stats.total_ops, 1);
        let report = debugger.generate_report();
        assert!(report.contains("QUANTIZATION DEBUG REPORT"));
        assert!(report.contains("Total Operations: 1"));
        debugger.clear();
        assert_eq!(debugger.execution_trace.len(), 0);
    }

    /// `update` accumulates totals, records per-layer MAE, and the averages
    /// reflect a single recorded op.
    #[test]
    fn test_error_statistics() {
        let mut stats = ErrorStatistics::new();
        assert_eq!(stats.total_ops, 0);
        let metrics = ErrorMetrics {
            mae: 0.1,
            mse: 0.01,
            rmse: 0.1,
            snr: 20.0,
            psnr: 30.0,
            cosine_similarity: 0.95,
        };
        stats.update(&metrics, "test_layer");
        assert_eq!(stats.total_ops, 1);
        assert_eq!(stats.cumulative_mae, 0.1);
        assert!(stats.layer_errors.contains_key("test_layer"));
        let (avg_mae, avg_mse) = stats.get_averages();
        assert_eq!(avg_mae, 0.1);
        assert_eq!(avg_mse, 0.01);
    }

    /// Ranges are recorded per tensor, and exceeding a registered expected
    /// range (max 5.0 > expected 3.0) produces a violation.
    #[test]
    fn test_range_tracker() {
        let mut tracker = RangeTracker::new();
        let tensor = tensor_1d(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        tracker.track_range("test_tensor", &tensor).unwrap();
        assert!(tracker.tensor_ranges.contains_key("test_tensor"));
        assert_eq!(tracker.tensor_ranges["test_tensor"].len(), 1);
        assert_eq!(tracker.tensor_ranges["test_tensor"][0], (1.0, 4.0));
        tracker.set_expected_range("test_tensor", (0.0, 3.0));
        let tensor2 = tensor_1d(&[0.0, 1.0, 2.0, 5.0]).unwrap();
        tracker.track_range("test_tensor", &tensor2).unwrap();
        assert!(!tracker.range_violations.is_empty());
        assert_eq!(tracker.range_violations[0].tensor_name, "test_tensor");
    }

    /// A single value above the overflow threshold yields exactly one event
    /// carrying the offending value; `clear_events` empties the log.
    #[test]
    fn test_overflow_detector() {
        let mut detector = OverflowDetector::new();
        assert!(detector.detection_enabled);
        detector.set_thresholds(10.0, -10.0);
        let tensor = tensor_1d(&[1.0, 2.0, 15.0, 4.0]).unwrap();
        detector.detect_overflow("test_tensor", &tensor).unwrap();
        assert_eq!(detector.get_overflow_count(), 1);
        assert_eq!(detector.overflow_events[0].tensor_name, "test_tensor");
        assert_eq!(detector.overflow_events[0].overflow_value, 15.0);
        detector.clear_events();
        assert_eq!(detector.get_overflow_count(), 0);
    }

    /// An exact copy of the reference (FP32) must beat a lossy scheme (INT8)
    /// on every metric, so it wins the comparison.
    #[test]
    fn test_quantization_comparator() {
        let mut comparator = QuantizationComparator::new();
        let reference = tensor_1d(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        comparator.set_reference(reference);
        let quantized_a = tensor_1d(&[1.1, 2.1, 2.9, 3.9]).unwrap();
        let quantized_b = tensor_1d(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        let result = comparator
            .compare_schemes("INT8", &quantized_a, "FP32", &quantized_b)
            .unwrap();
        assert_eq!(result.scheme_a, "INT8");
        assert_eq!(result.scheme_b, "FP32");
        assert_eq!(result.winner, "FP32");
        assert_eq!(comparator.comparison_results.len(), 1);
        let report = comparator.generate_comparison_report();
        assert!(report.contains("QUANTIZATION SCHEME COMPARISON REPORT"));
        assert!(report.contains("INT8 vs FP32"));
    }

    /// `DebugStep` can be constructed manually and its fields read back.
    #[test]
    fn test_debug_step() {
        let input_stats = TensorStatistics {
            min: 1.0,
            max: 4.0,
            mean: 2.5,
            std: 1.29,
            shape: vec![4],
            num_elements: 4,
            dtype: DType::F32,
        };
        let output_stats = input_stats.clone();
        let quant_params = QuantParams {
            scale: 0.1,
            zero_point: 0,
            scheme: QScheme::PerTensorAffine,
            qint_range: (-128, 127),
        };
        let error_metrics = ErrorMetrics {
            mae: 0.1,
            mse: 0.01,
            rmse: 0.1,
            snr: 20.0,
            psnr: 30.0,
            cosine_similarity: 0.95,
        };
        let debug_step = DebugStep {
            name: "test_step".to_string(),
            input_stats,
            output_stats,
            quant_params,
            error_metrics,
            timestamp: std::time::Instant::now(),
        };
        assert_eq!(debug_step.name, "test_step");
        assert_eq!(debug_step.quant_params.scale, 0.1);
        assert_eq!(debug_step.error_metrics.mae, 0.1);
    }

    /// `ViolationSeverity` equality/inequality follows the derived PartialEq.
    #[test]
    fn test_violation_severity() {
        assert_eq!(ViolationSeverity::Minor, ViolationSeverity::Minor);
        assert_ne!(ViolationSeverity::Minor, ViolationSeverity::Severe);
    }
}