use super::config::{AnomalyThresholds, QuantizationParameters};
use crate::TorshResult;
use std::collections::VecDeque;
use std::time::Instant;
use torsh_tensor::Tensor;
/// Tracks the quality of quantized tensors over time.
///
/// Every successful `assess_quality` call appends a [`QualityMeasurement`] to a
/// bounded history and pushes the measurement's perceptual score into the
/// sliding window held by the [`DegradationDetector`].
#[derive(Debug, Clone)]
pub struct QualityAssessor {
    /// Set to `QualityMetrics::default()` by `new`; nothing in this file reads
    /// it afterwards, hence the lint allowance.
    #[allow(dead_code)]
    metrics: QualityMetrics,
    /// Recorded measurements, oldest at the front.
    quality_history: VecDeque<QualityMeasurement>,
    /// Threshold values configured by `new`; not consulted by any code in this
    /// file, hence the lint allowance.
    #[allow(dead_code)]
    anomaly_thresholds: AnomalyThresholds,
    /// Sliding window of recent perceptual scores used for trend detection.
    degradation_detector: DegradationDetector,
}
/// Quality figures comparing a quantized tensor against its original.
#[derive(Debug, Clone)]
pub struct QualityMetrics {
    /// Signal-to-noise ratio, in decibels.
    pub snr: f32,
    /// Mean squared error between original and quantized values.
    pub mse: f32,
    /// Peak signal-to-noise ratio, in decibels.
    pub psnr: f32,
    /// Heuristic score derived from the MSE as `1 / (1 + mse)`; `1.0` means
    /// no error.
    pub perceptual_score: f32,
    /// Structural similarity index; `1.0` for identical inputs.
    pub ssim: f32,
}
/// One entry of the assessor's quality history.
#[derive(Debug, Clone)]
pub struct QualityMeasurement {
    /// Metrics computed for this assessment.
    pub metrics: QualityMetrics,
    /// Copy of the quantization parameters that were passed to the assessment.
    pub quant_params: QuantizationParameters,
    /// Moment at which the measurement was recorded.
    pub timestamp: Instant,
    /// Counter assigned by the assessor when the measurement was recorded;
    /// the first measurement stored into an empty history gets `0`.
    pub operation_count: usize,
}
/// Sliding window of recent quality scores used to spot a downward trend.
#[derive(Debug, Clone)]
pub struct DegradationDetector {
    /// Maximum number of scores kept in `recent_measurements`.
    window_size: usize,
    /// Slope threshold set at construction; the detection logic in this file
    /// does not read it, hence the lint allowance.
    #[allow(dead_code)]
    degradation_slope_threshold: f32,
    /// Most recent scores, oldest at the front.
    recent_measurements: VecDeque<f32>,
}
impl Default for QualityMetrics {
    /// Baseline metrics: the three error/ratio figures start at `0.0`, the two
    /// similarity scores start at `1.0`.
    fn default() -> Self {
        let (snr, mse, psnr) = (0.0, 0.0, 0.0);
        let (perceptual_score, ssim) = (1.0, 1.0);
        Self {
            snr,
            mse,
            psnr,
            perceptual_score,
            ssim,
        }
    }
}
impl QualityAssessor {
    /// Maximum number of measurements kept in the quality history.
    const HISTORY_CAPACITY: usize = 1000;
    /// Maximum number of most-recent measurements aggregated by
    /// `get_quality_statistics`.
    const STATISTICS_WINDOW: usize = 100;
    /// Upper bound applied to the SNR and PSNR figures (in decibels), so that a
    /// lossless result is reported as a finite value.
    const MAX_DECIBELS: f32 = 100.0;
    /// Minimum number of scores required before a trend is evaluated.
    const MIN_TREND_SAMPLES: usize = 3;

    /// Creates an assessor with an empty history, the built-in anomaly
    /// thresholds and a degradation window of 10 scores.
    pub fn new() -> Self {
        Self {
            metrics: QualityMetrics::default(),
            quality_history: VecDeque::new(),
            anomaly_thresholds: AnomalyThresholds {
                snr_threshold: 15.0,
                mse_threshold: 0.1,
                perceptual_threshold: 0.7,
            },
            degradation_detector: DegradationDetector {
                window_size: 10,
                degradation_slope_threshold: -0.05,
                recent_measurements: VecDeque::new(),
            },
        }
    }

    /// Compares `quantized` against `original`, records the result and returns it.
    ///
    /// The returned metrics are:
    /// * `mse` - mean squared element-wise difference;
    /// * `snr` - `10 * log10(signal_power / mse)`, capped at 100;
    /// * `psnr` - `20 * log10(peak / sqrt(mse))`, capped at 100, where `peak`
    ///   is the largest absolute value in `original`;
    /// * `perceptual_score` - `1 / (1 + mse)`;
    /// * `ssim` - structural similarity computed over the whole data as a
    ///   single window.
    ///
    /// A zero `mse` (including the case of two empty tensors) yields the capped
    /// value for both ratios. An all-zero original with a non-zero error yields
    /// negative infinity for both ratios.
    ///
    /// The measurement is appended to the history together with a copy of
    /// `params`, and its perceptual score is pushed into the degradation window.
    ///
    /// # Errors
    ///
    /// Returns an error if reading either tensor's data fails, or if the two
    /// tensors do not contain the same number of elements.
    pub fn assess_quality(
        &mut self,
        original: &Tensor,
        quantized: &Tensor,
        params: &QuantizationParameters,
    ) -> TorshResult<QualityMetrics> {
        let orig_data = original.data()?;
        let quant_data = quantized.data()?;
        if orig_data.len() != quant_data.len() {
            return Err(torsh_core::TorshError::operation_error(
                "quality assessment: Tensor size mismatch",
            ));
        }

        let metrics = self.compute_metrics(&orig_data, &quant_data);
        self.record_measurement(&metrics, params);
        Ok(metrics)
    }

    /// Computes all quality figures for two equally long value slices.
    fn compute_metrics(&self, orig: &[f32], quant: &[f32]) -> QualityMetrics {
        // Two empty inputs are trivially identical. Report them as lossless
        // instead of dividing by zero and storing NaN in the history.
        if orig.is_empty() {
            return QualityMetrics {
                snr: Self::MAX_DECIBELS,
                mse: 0.0,
                psnr: Self::MAX_DECIBELS,
                perceptual_score: 1.0,
                ssim: 1.0,
            };
        }

        let count = orig.len() as f32;
        let mse = orig
            .iter()
            .zip(quant.iter())
            .map(|(o, q)| (o - q).powi(2))
            .sum::<f32>()
            / count;
        let signal_power = orig.iter().map(|x| x.powi(2)).sum::<f32>() / count;

        // Use the peak magnitude, not the signed maximum: a negative peak
        // would make `log10` return NaN, which `f32::min` then silently turns
        // into the "perfect" cap.
        let peak = orig.iter().fold(0.0_f32, |acc, &x| acc.max(x.abs()));

        let (snr, psnr) = if mse > 0.0 {
            (
                10.0 * (signal_power / mse).log10(),
                20.0 * (peak / mse.sqrt()).log10(),
            )
        } else {
            (f32::INFINITY, f32::INFINITY)
        };

        QualityMetrics {
            snr: snr.min(Self::MAX_DECIBELS),
            mse,
            psnr: psnr.min(Self::MAX_DECIBELS),
            perceptual_score: 1.0 / (1.0 + mse),
            ssim: self.calculate_ssim(orig, quant),
        }
    }

    /// Appends `metrics` to the bounded history and to the degradation window.
    fn record_measurement(&mut self, metrics: &QualityMetrics, params: &QuantizationParameters) {
        // Continue from the newest stored counter rather than using the
        // history length, which stops growing once the history is full.
        let operation_count = self
            .quality_history
            .back()
            .map_or(0, |last| last.operation_count + 1);

        self.quality_history.push_back(QualityMeasurement {
            metrics: metrics.clone(),
            quant_params: params.clone(),
            timestamp: Instant::now(),
            operation_count,
        });
        if self.quality_history.len() > Self::HISTORY_CAPACITY {
            self.quality_history.pop_front();
        }

        let detector = &mut self.degradation_detector;
        detector
            .recent_measurements
            .push_back(metrics.perceptual_score);
        if detector.recent_measurements.len() > detector.window_size {
            detector.recent_measurements.pop_front();
        }
    }

    /// Computes the structural similarity index of `orig` and `quant`, treating
    /// each slice as one window.
    ///
    /// Uses the stabilising constants `0.01^2` and `0.03^2`. Returns `1.0` for
    /// empty input or when the denominator is not greater than `1e-10`.
    /// Callers must pass slices of equal length.
    fn calculate_ssim(&self, orig: &[f32], quant: &[f32]) -> f32 {
        if orig.is_empty() || quant.is_empty() {
            return 1.0;
        }

        let orig_len = orig.len() as f32;
        let quant_len = quant.len() as f32;

        let mu1 = orig.iter().sum::<f32>() / orig_len;
        let mu2 = quant.iter().sum::<f32>() / quant_len;
        let sigma1_sq = orig.iter().map(|x| (x - mu1).powi(2)).sum::<f32>() / orig_len;
        let sigma2_sq = quant.iter().map(|x| (x - mu2).powi(2)).sum::<f32>() / quant_len;
        let sigma12 = orig
            .iter()
            .zip(quant.iter())
            .map(|(x, y)| (x - mu1) * (y - mu2))
            .sum::<f32>()
            / orig_len;

        let c1 = 0.01_f32.powi(2);
        let c2 = 0.03_f32.powi(2);

        let numerator = (2.0 * mu1 * mu2 + c1) * (2.0 * sigma12 + c2);
        let denominator = (mu1.powi(2) + mu2.powi(2) + c1) * (sigma1_sq + sigma2_sq + c2);
        if denominator > 1e-10 {
            numerator / denominator
        } else {
            1.0
        }
    }

    /// Reports whether the recent perceptual scores are trending downwards.
    ///
    /// Returns `false` while fewer than three scores are available. Otherwise
    /// returns `true` when the number of consecutive pairs in which the later
    /// score is lower than the earlier one exceeds half the window length.
    pub fn detect_degradation(&self) -> bool {
        let scores = &self.degradation_detector.recent_measurements;
        if scores.len() < Self::MIN_TREND_SAMPLES {
            return false;
        }

        let declines = scores
            .iter()
            .zip(scores.iter().skip(1))
            .filter(|(previous, current)| current < previous)
            .count();
        declines > scores.len() / 2
    }

    /// Aggregates the most recent measurements (at most 100) into averages and
    /// SNR extremes, and includes the current degradation verdict.
    ///
    /// Returns `QualityStatistics::default()` when the history is empty.
    pub fn get_quality_statistics(&self) -> QualityStatistics {
        if self.quality_history.is_empty() {
            return QualityStatistics::default();
        }

        let sample_count = self.quality_history.len().min(Self::STATISTICS_WINDOW);
        // Newest first, as the history stores the oldest entry at the front.
        let recent: Vec<&QualityMetrics> = self
            .quality_history
            .iter()
            .rev()
            .take(sample_count)
            .map(|measurement| &measurement.metrics)
            .collect();

        let min_snr = recent
            .iter()
            .map(|m| m.snr)
            .fold(f32::INFINITY, f32::min);
        let max_snr = recent
            .iter()
            .map(|m| m.snr)
            .fold(f32::NEG_INFINITY, f32::max);

        QualityStatistics {
            sample_count,
            avg_snr: Self::mean_metric(&recent, |m| m.snr),
            avg_mse: Self::mean_metric(&recent, |m| m.mse),
            avg_psnr: Self::mean_metric(&recent, |m| m.psnr),
            avg_perceptual_score: Self::mean_metric(&recent, |m| m.perceptual_score),
            avg_ssim: Self::mean_metric(&recent, |m| m.ssim),
            min_snr,
            max_snr,
            degradation_detected: self.detect_degradation(),
        }
    }

    /// Arithmetic mean of one metric, selected by `field`, over `samples`.
    fn mean_metric(samples: &[&QualityMetrics], field: impl Fn(&QualityMetrics) -> f32) -> f32 {
        samples.iter().copied().map(field).sum::<f32>() / samples.len() as f32
    }

    /// Returns the recorded measurements, oldest first.
    pub fn get_quality_history(&self) -> &VecDeque<QualityMeasurement> {
        &self.quality_history
    }

    /// Discards all recorded measurements and empties the degradation window.
    pub fn clear_history(&mut self) {
        self.quality_history.clear();
        self.degradation_detector.recent_measurements.clear();
    }
}
impl Default for QualityAssessor {
fn default() -> Self {
Self::new()
}
}
/// Summary of the assessor's most recent quality measurements.
#[derive(Debug, Clone)]
pub struct QualityStatistics {
    /// Number of measurements that were aggregated.
    pub sample_count: usize,
    /// Mean signal-to-noise ratio.
    pub avg_snr: f32,
    /// Mean of the mean-squared-error values.
    pub avg_mse: f32,
    /// Mean peak signal-to-noise ratio.
    pub avg_psnr: f32,
    /// Mean perceptual score.
    pub avg_perceptual_score: f32,
    /// Mean structural similarity index.
    pub avg_ssim: f32,
    /// Lowest signal-to-noise ratio among the aggregated measurements.
    pub min_snr: f32,
    /// Highest signal-to-noise ratio among the aggregated measurements.
    pub max_snr: f32,
    /// Whether the degradation check reported a downward trend.
    pub degradation_detected: bool,
}
impl Default for QualityStatistics {
    /// Statistics for an empty history: no samples, zeroed SNR/MSE/PSNR
    /// figures, similarity scores of `1.0` and no degradation.
    fn default() -> Self {
        let zeroed = 0.0_f32;
        let perfect = 1.0_f32;
        Self {
            sample_count: 0,
            avg_snr: zeroed,
            avg_mse: zeroed,
            avg_psnr: zeroed,
            avg_perceptual_score: perfect,
            avg_ssim: perfect,
            min_snr: zeroed,
            max_snr: zeroed,
            degradation_detected: false,
        }
    }
}