use crate::error::{Error, Result};
use crate::time::{AudioDuration, AudioInstant};
use tracing::{debug, trace};
/// Metrics returned by `QualityAssessor::assess` for one buffer of samples.
#[derive(Debug, Clone, Copy)]
pub struct QualityMetrics {
    /// Estimated signal-to-noise ratio in dB, clamped to `[0.0, 60.0]`.
    pub snr_db: f32,
    /// Root-mean-square amplitude of the whole buffer (despite the name, not a
    /// sum of squares).
    pub energy: f32,
    /// Spectral-centroid estimate in Hz, clamped to `[0, sample_rate / 2]`.
    pub spectral_centroid: f32,
    /// Weighted combination of the SNR, energy and centroid sub-scores,
    /// clamped to `[0.0, 1.0]`.
    pub quality_score: f32,
}
/// Computes `QualityMetrics` for a slice of `f32` audio samples.
///
/// NOTE(review): no channel handling is performed, so the slice is presumably
/// mono — confirm with callers.
#[derive(Debug, Clone, Copy)]
pub struct QualityAssessor {
    /// Sample rate of the analysed audio in Hz; half of it is used as the
    /// Nyquist frequency for the spectral-centroid metric.
    sample_rate: u32,
}
impl QualityAssessor {
pub fn new(sample_rate: u32) -> Self {
Self { sample_rate }
}
pub fn assess(self, samples: &[f32]) -> Result<QualityMetrics> {
trace!(sample_count = samples.len(), "Assessing audio quality");
if samples.is_empty() {
return Err(Error::InvalidInput("Cannot assess empty audio".into()));
}
let processing_start = AudioInstant::now();
let energy = Self::calculate_rms(samples);
let snr_db = Self::calculate_snr(samples, energy)?;
let spectral_centroid = self.calculate_spectral_centroid(samples)?;
let quality_score = self.aggregate_score(snr_db, energy, spectral_centroid);
debug!(
snr_db,
energy, spectral_centroid, quality_score, "Audio quality metrics computed"
);
let metrics = QualityMetrics {
snr_db,
energy,
spectral_centroid,
quality_score,
};
let _latency = elapsed_duration(processing_start);
Ok(metrics)
}
fn calculate_rms(samples: &[f32]) -> f32 {
let sum_squares: f32 = samples.iter().map(|&s| s * s).sum();
let mean_square = sum_squares / samples.len() as f32;
mean_square.sqrt()
}
fn calculate_snr(samples: &[f32], signal_rms: f32) -> Result<f32> {
let frame_energies = Self::frame_energy(samples);
let mut valid_energies: Vec<f32> =
frame_energies.into_iter().filter(|x| !x.is_nan()).collect();
if valid_energies.is_empty() {
return Err(Error::Processing(
"All frame energies are NaN; cannot estimate noise floor".into(),
));
}
valid_energies.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let noise_frame_count = (valid_energies.len() / 10).max(1);
let noise_frames = valid_energies
.get(0..noise_frame_count)
.ok_or_else(|| Error::Processing("Insufficient frames for noise estimation".into()))?;
let noise_floor = noise_frames.iter().sum::<f32>() / noise_frames.len() as f32;
if signal_rms < 1e-6 {
return Ok(0.0);
}
if noise_floor < 1e-10 {
return Ok(60.0);
}
let snr = 20.0 * (signal_rms / noise_floor).log10();
Ok(snr.clamp(0.0, 60.0))
}
fn frame_energy(samples: &[f32]) -> Vec<f32> {
const FRAME_SIZE: usize = 256;
samples
.chunks(FRAME_SIZE)
.map(|frame| {
let sum_sq: f32 = frame.iter().map(|&s| s * s).sum();
(sum_sq / frame.len() as f32).sqrt()
})
.collect()
}
fn calculate_spectral_centroid(self, samples: &[f32]) -> Result<f32> {
if samples.len() < 512 {
return Ok(self.sample_rate as f32 / 4.0);
}
let window = samples.get(0..512).ok_or_else(|| {
Error::Processing("Insufficient samples for spectral analysis".into())
})?;
let (magnitude_sum, weighted_sum) =
window
.iter()
.enumerate()
.fold((0.0f32, 0.0f32), |(mag_acc, weighted_acc), (i, &s)| {
let magnitude = s.abs();
(
mag_acc + magnitude,
magnitude.mul_add(i as f32, weighted_acc),
)
});
if magnitude_sum < 1e-10 {
return Ok(self.sample_rate as f32 / 4.0);
}
let centroid_bin = weighted_sum / magnitude_sum;
let centroid_hz = (centroid_bin / 512.0) * (self.sample_rate as f32 / 2.0);
Ok(centroid_hz.clamp(0.0, self.sample_rate as f32 / 2.0))
}
fn aggregate_score(self, snr_db: f32, energy: f32, spectral_centroid: f32) -> f32 {
let snr_score = (snr_db / 60.0).clamp(0.0, 1.0);
let energy_score = (energy / 0.5).clamp(0.0, 1.0);
let centroid_score = (spectral_centroid / (self.sample_rate as f32 / 2.0)).clamp(0.0, 1.0);
let score = 0.5f32.mul_add(
snr_score,
0.3f32.mul_add(energy_score, 0.2 * centroid_score),
);
score.clamp(0.0, 1.0)
}
}
/// Returns how much time has passed between `start` and the current instant.
fn elapsed_duration(start: AudioInstant) -> AudioDuration {
    let now = AudioInstant::now();
    now.duration_since(start)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for approximate floating-point comparisons.
    const EPSILON: f32 = 0.01;

    #[test]
    fn test_high_quality_audio() {
        let assessor = QualityAssessor::new(16000);
        // 0.25 s of silence, 0.5 s of a 440 Hz tone, then 0.25 s of silence, so
        // the quietest frames provide a clean noise floor.
        let mut samples = vec![0.0f32; 16000];
        for (i, sample) in samples.iter_mut().enumerate().take(12000).skip(4000) {
            *sample = (2.0 * std::f32::consts::PI * 440.0 * i as f32 / 16000.0).sin() * 0.5;
        }
        let metrics = assessor.assess(&samples).unwrap();
        assert!(
            metrics.snr_db > 20.0,
            "Expected SNR > 20 dB, got {:.1}",
            metrics.snr_db
        );
        assert!((0.0..=1.0).contains(&metrics.quality_score));
        assert!(
            metrics.quality_score > 0.5,
            "Expected quality > 0.5, got {:.2}",
            metrics.quality_score
        );
    }

    #[test]
    fn test_noisy_audio() {
        let assessor = QualityAssessor::new(16000);
        let mut noisy = vec![0.0f32; 16000];
        for (i, sample) in noisy.iter_mut().enumerate() {
            let signal = (2.0 * std::f32::consts::PI * 440.0 * i as f32 / 16000.0).sin() * 0.2;
            let noise = (i as f32 * 0.1).sin().mul_add(0.1, (i % 7) as f32 * 0.01);
            *sample = signal + noise;
        }
        let metrics = assessor.assess(&noisy).unwrap();
        assert!(
            metrics.snr_db < 40.0,
            "Expected SNR < 40 dB for noisy audio, got {:.1} dB",
            metrics.snr_db
        );
        assert!((0.0..=1.0).contains(&metrics.quality_score));
    }

    #[test]
    fn test_energy_calculation() {
        let assessor = QualityAssessor::new(16000);
        let audio = vec![0.5f32; 1000];
        let metrics = assessor.assess(&audio).unwrap();
        assert!(
            (metrics.energy - 0.5).abs() < EPSILON,
            "Expected energy ~0.5, got {:.3}",
            metrics.energy
        );
    }

    #[test]
    fn test_quality_score_bounds() {
        let assessor = QualityAssessor::new(16000);
        let audio = vec![0.3f32; 5000];
        let metrics = assessor.assess(&audio).unwrap();
        assert!(
            (0.0..=1.0).contains(&metrics.quality_score),
            "Quality score {:.2} out of bounds [0.0, 1.0]",
            metrics.quality_score
        );
        assert!(
            (0.0..=60.0).contains(&metrics.snr_db),
            "SNR {:.1} dB out of bounds [0.0, 60.0]",
            metrics.snr_db
        );
    }

    #[test]
    fn test_spectral_centroid_computed() {
        let assessor = QualityAssessor::new(16000);
        let audio = vec![0.2f32; 1024];
        let metrics = assessor.assess(&audio).unwrap();
        assert!(metrics.spectral_centroid >= 0.0);
        assert!(
            metrics.spectral_centroid <= 8000.0,
            "Spectral centroid {:.1} Hz exceeds Nyquist (8000 Hz)",
            metrics.spectral_centroid
        );
    }

    #[test]
    fn test_empty_audio() {
        let assessor = QualityAssessor::new(16000);
        let result = assessor.assess(&[]);
        assert!(result.is_err(), "Should reject empty audio");
        match result.unwrap_err() {
            Error::InvalidInput(msg) => {
                assert!(msg.contains("empty"), "Expected 'empty' error, got: {}", msg);
            }
            other => panic!("Expected InvalidInput error, got: {:?}", other),
        }
    }

    #[test]
    fn test_silence_handling() {
        let assessor = QualityAssessor::new(16000);
        let silence = vec![0.0f32; 16000];
        let metrics = assessor.assess(&silence).unwrap();
        assert!(
            metrics.energy < EPSILON,
            "Expected near-zero energy for silence, got {:.6}",
            metrics.energy
        );
        assert!(
            metrics.snr_db < 1.0,
            "Expected SNR ~0 dB for silence, got {:.1} dB",
            metrics.snr_db
        );
        assert!(
            metrics.quality_score < 0.2,
            "Expected quality <0.2 for silence, got {:.2}",
            metrics.quality_score
        );
        assert!((0.0..=1.0).contains(&metrics.quality_score));
    }

    #[test]
    fn test_short_audio() {
        let assessor = QualityAssessor::new(16000);
        let short_audio = vec![0.5f32; 256];
        let metrics = assessor.assess(&short_audio).unwrap();
        assert!((0.0..=1.0).contains(&metrics.quality_score));
        assert!(metrics.spectral_centroid > 0.0);
    }

    #[test]
    fn test_very_quiet_audio() {
        let assessor = QualityAssessor::new(16000);
        let very_quiet = vec![1e-7f32; 16000];
        let metrics = assessor.assess(&very_quiet).unwrap();
        assert!(
            metrics.energy < 1e-6,
            "Expected near-zero energy for very quiet audio, got {:.9}",
            metrics.energy
        );
        assert!(
            metrics.snr_db < 5.0,
            "Expected low SNR for very quiet audio, got {:.1} dB",
            metrics.snr_db
        );
        assert!(
            metrics.quality_score < 0.3,
            "Expected low quality for very quiet audio, got {:.2}",
            metrics.quality_score
        );
    }
}