Skip to main content

oximedia_codec/
quality_metrics.rs

1//! Quality metrics for codec evaluation.
2//!
//! Provides MSE, PSNR, SNR, and simplified (global-statistics) SSIM
//! calculations for comparing original and reconstructed media data. These
//! metrics are used to validate codec round-trip fidelity and to verify that
//! encoding quality meets minimum thresholds.
7//!
8//! # Metrics
9//!
10//! - **MSE** (Mean Squared Error): Average squared difference between samples.
11//! - **PSNR** (Peak Signal-to-Noise Ratio): Logarithmic quality measure in dB.
12//! - **SNR** (Signal-to-Noise Ratio): Ratio of signal power to noise power.
13//!
14//! # Usage
15//!
16//! ```rust
17//! use oximedia_codec::quality_metrics::{compute_psnr_u8, compute_mse_f32, compute_snr_f32};
18//!
19//! let original = [100u8, 150, 200, 50, 30];
20//! let decoded  = [101u8, 149, 201, 49, 31];
21//! let psnr = compute_psnr_u8(&original, &decoded);
22//! assert!(psnr > 40.0, "near-identical samples should have high PSNR");
23//! ```
24
/// Mean Squared Error of two `u8` sample buffers.
///
/// When the slices differ in length, only their common prefix (the first
/// `min(original.len(), decoded.len())` samples) is compared. If that
/// prefix is empty the result is `0.0`.
#[must_use]
pub fn compute_mse_u8(original: &[u8], decoded: &[u8]) -> f64 {
    let count = original.len().min(decoded.len());
    if count == 0 {
        return 0.0;
    }

    // `zip` stops at the shorter slice, i.e. exactly at the common prefix.
    let mut squared_error_total = 0.0_f64;
    for (&reference, &candidate) in original.iter().zip(decoded) {
        let delta = f64::from(reference) - f64::from(candidate);
        squared_error_total += delta * delta;
    }

    squared_error_total / count as f64
}
45
46/// Compute PSNR (Peak Signal-to-Noise Ratio) for `u8` data.
47///
48/// Peak value is 255. Returns `f64::INFINITY` when the signals are identical
49/// (MSE == 0).
50#[must_use]
51pub fn compute_psnr_u8(original: &[u8], decoded: &[u8]) -> f64 {
52    let mse = compute_mse_u8(original, decoded);
53    if mse < f64::EPSILON {
54        return f64::INFINITY;
55    }
56    let peak = 255.0_f64;
57    10.0 * (peak * peak / mse).log10()
58}
59
/// Mean Squared Error of two `f32` sample buffers.
///
/// Only the common prefix of the two slices is compared; an empty prefix
/// yields `0.0`. Samples are widened to `f64` before subtraction, so the
/// accumulation is carried out in double precision.
#[must_use]
pub fn compute_mse_f32(original: &[f32], decoded: &[f32]) -> f64 {
    let shared = original.len().min(decoded.len());
    if shared == 0 {
        return 0.0;
    }

    // `zip` ends with the shorter input, which is the common prefix.
    let squared_error_total = original
        .iter()
        .zip(decoded.iter())
        .fold(0.0_f64, |running, (&reference, &candidate)| {
            let error = f64::from(reference) - f64::from(candidate);
            running + error * error
        });

    squared_error_total / shared as f64
}
79
80/// Compute PSNR for normalised `f32` audio data (peak = 1.0).
81///
82/// Returns `f64::INFINITY` when signals are identical.
83#[must_use]
84pub fn compute_psnr_f32(original: &[f32], decoded: &[f32]) -> f64 {
85    let mse = compute_mse_f32(original, decoded);
86    if mse < f64::EPSILON {
87        return f64::INFINITY;
88    }
89    let peak = 1.0_f64;
90    10.0 * (peak * peak / mse).log10()
91}
92
/// Signal-to-Noise Ratio, in dB, for `f32` data.
///
/// `SNR = 10 * log10(signal_power / noise_power)`, where the signal power is
/// the mean square of `original` and the noise power is the mean square of
/// `original - decoded`, both taken over the common prefix of the slices.
///
/// Special cases, checked in this order:
/// - empty common prefix: `0.0`
/// - noise power below `f64::EPSILON`: `f64::INFINITY`
/// - signal power below `f64::EPSILON`: `0.0`
#[must_use]
pub fn compute_snr_f32(original: &[f32], decoded: &[f32]) -> f64 {
    let overlap = original.len().min(decoded.len());
    if overlap == 0 {
        return 0.0;
    }
    let sample_count = overlap as f64;

    // One pass over the common prefix, accumulating both energies.
    let mut signal_energy = 0.0_f64;
    let mut noise_energy = 0.0_f64;
    for (&clean, &degraded) in original.iter().zip(decoded) {
        let clean = f64::from(clean);
        let error = clean - f64::from(degraded);
        signal_energy += clean * clean;
        noise_energy += error * error;
    }

    let signal_power = signal_energy / sample_count;
    let noise_power = noise_energy / sample_count;

    if noise_power < f64::EPSILON {
        f64::INFINITY
    } else if signal_power < f64::EPSILON {
        0.0
    } else {
        10.0 * (signal_power / noise_power).log10()
    }
}
126
/// Compute PSNR, in dB, for `u16` data (useful for 10-bit or 12-bit video).
///
/// `bit_depth` determines the peak value (`2^bit_depth - 1`). The peak is
/// evaluated in floating point, so any `u8` bit depth is accepted without
/// integer overflow. A `bit_depth` of `0` gives a peak of `0` and therefore
/// a result of negative infinity for non-identical data.
///
/// Only the common prefix of the two slices is compared.
///
/// Returns:
/// - `0.0` when the common prefix is empty,
/// - `f64::INFINITY` when the MSE is below `f64::EPSILON`,
/// - `10 * log10(peak^2 / MSE)` otherwise.
#[must_use]
pub fn compute_psnr_u16(original: &[u16], decoded: &[u16], bit_depth: u8) -> f64 {
    let len = original.len().min(decoded.len());
    if len == 0 {
        return 0.0;
    }

    // `zip` stops at the shorter slice, i.e. at the common prefix.
    let sum_sq: f64 = original
        .iter()
        .zip(decoded)
        .map(|(&a, &b)| {
            let diff = f64::from(a) - f64::from(b);
            diff * diff
        })
        .sum();
    let mse = sum_sq / len as f64;
    if mse < f64::EPSILON {
        return f64::INFINITY;
    }

    // Powers of two are exact in f64, so this matches the integer shift for
    // every depth that the shift could represent, while avoiding the
    // `1u32 << bit_depth` overflow for depths of 32 and above.
    let peak = 2.0_f64.powi(i32::from(bit_depth)) - 1.0;
    10.0 * (peak * peak / mse).log10()
}
151
/// Simplified Structural Similarity Index for `u8` data.
///
/// Instead of the standard 11x11 Gaussian-windowed computation, this variant
/// evaluates the SSIM formula once over global statistics (means, variances
/// and covariance of the common prefix of both slices). It is intended for
/// quick quality checks in codec testing.
///
/// The result is clamped to `[0.0, 1.0]`; `1.0` indicates identical signals
/// and is also returned when the common prefix is empty.
#[must_use]
pub fn compute_ssim_simplified_u8(original: &[u8], decoded: &[u8]) -> f64 {
    let shared = original.len().min(decoded.len());
    if shared == 0 {
        return 1.0;
    }

    let reference = &original[..shared];
    let candidate = &decoded[..shared];
    let count = shared as f64;

    // Global means of both signals.
    let mean_of = |samples: &[u8]| samples.iter().map(|&s| f64::from(s)).sum::<f64>() / count;
    let mean_ref = mean_of(reference);
    let mean_cand = mean_of(candidate);

    // Sums of squared deviations and of cross deviations, in a single pass.
    let (sq_dev_ref, sq_dev_cand, cross_dev) = reference.iter().zip(candidate).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(sq_ref, sq_cand, cross), (&r, &c)| {
            let dev_ref = f64::from(r) - mean_ref;
            let dev_cand = f64::from(c) - mean_cand;
            (
                sq_ref + dev_ref * dev_ref,
                sq_cand + dev_cand * dev_cand,
                cross + dev_ref * dev_cand,
            )
        },
    );
    let variance_ref = sq_dev_ref / count;
    let variance_cand = sq_dev_cand / count;
    let covariance = cross_dev / count;

    // Stabilising constants for an 8-bit dynamic range (L = 255):
    // c1 = (0.01 * L)^2, c2 = (0.03 * L)^2.
    let k1_l = 0.01 * 255.0;
    let k2_l = 0.03 * 255.0;
    let c1 = k1_l * k1_l;
    let c2 = k2_l * k2_l;

    let top = (2.0 * mean_ref * mean_cand + c1) * (2.0 * covariance + c2);
    let bottom =
        (mean_ref * mean_ref + mean_cand * mean_cand + c1) * (variance_ref + variance_cand + c2);

    if bottom < f64::EPSILON {
        return 1.0;
    }

    (top / bottom).clamp(0.0, 1.0)
}
201
/// Quality assessment result for a codec round-trip test.
///
/// All fields are plain values, so the type is `Copy` and can be compared
/// with `==` (note that, as with any `f64` comparison, a `NaN` field never
/// compares equal).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QualityAssessment {
    /// PSNR in dB. May be `f64::INFINITY` when the signals are identical.
    pub psnr_db: f64,
    /// Mean Squared Error.
    pub mse: f64,
    /// SNR in dB. May be `f64::INFINITY` when there is no noise.
    pub snr_db: f64,
    /// Whether the measured PSNR meets the minimum threshold
    /// (`psnr_db >= threshold_db`).
    pub passes_threshold: bool,
    /// Minimum PSNR threshold, in dB, that was applied.
    pub threshold_db: f64,
}
216
217/// Perform a full quality assessment on `u8` data.
218///
219/// `min_psnr_db` is the threshold: if the measured PSNR is at least this
220/// value, `passes_threshold` is set to `true`.
221#[must_use]
222pub fn assess_quality_u8(original: &[u8], decoded: &[u8], min_psnr_db: f64) -> QualityAssessment {
223    let mse = compute_mse_u8(original, decoded);
224    let psnr_db = compute_psnr_u8(original, decoded);
225
226    // Convert u8 to f32 for SNR
227    let orig_f32: Vec<f32> = original.iter().map(|&v| v as f32 / 255.0).collect();
228    let dec_f32: Vec<f32> = decoded.iter().map(|&v| v as f32 / 255.0).collect();
229    let snr_db = compute_snr_f32(&orig_f32, &dec_f32);
230
231    QualityAssessment {
232        psnr_db,
233        mse,
234        snr_db,
235        passes_threshold: psnr_db >= min_psnr_db,
236        threshold_db: min_psnr_db,
237    }
238}
239
240/// Perform a full quality assessment on `f32` data.
241#[must_use]
242pub fn assess_quality_f32(
243    original: &[f32],
244    decoded: &[f32],
245    min_psnr_db: f64,
246) -> QualityAssessment {
247    let mse = compute_mse_f32(original, decoded);
248    let psnr_db = compute_psnr_f32(original, decoded);
249    let snr_db = compute_snr_f32(original, decoded);
250
251    QualityAssessment {
252        psnr_db,
253        mse,
254        snr_db,
255        passes_threshold: psnr_db >= min_psnr_db,
256        threshold_db: min_psnr_db,
257    }
258}
259
260// =============================================================================
261// Tests — Round-trip encode/decode quality (PSNR > threshold)
262// =============================================================================
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::audio::{AudioFrame, SampleFormat};
    use crate::pcm::{ByteOrder, PcmConfig, PcmDecoder, PcmEncoder, PcmFormat};
    use std::f32::consts::PI;

    // ── Helpers ─────────────────────────────────────────────────────────────

    /// Serialise `f32` samples as consecutive little-endian 4-byte groups.
    fn f32_to_le_bytes(samples: &[f32]) -> Vec<u8> {
        samples
            .iter()
            .flat_map(|sample| sample.to_le_bytes())
            .collect()
    }

    /// Rebuild one `f32` from a 4-byte little-endian chunk.
    fn f32_from_le_chunk(chunk: &[u8]) -> f32 {
        f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])
    }

    // ── MSE tests ───────────────────────────────────────────────────────────

    #[test]
    fn test_mse_u8_identical() {
        let samples = [10u8, 20, 30, 40, 50];
        assert!(
            compute_mse_u8(&samples, &samples) < f64::EPSILON,
            "MSE of identical signals should be 0"
        );
    }

    #[test]
    fn test_mse_u8_known_difference() {
        // Every sample is off by exactly one, so each squared error is 1
        // and the mean over three samples is 1.0.
        let reference = [100u8, 200, 150];
        let shifted = [101u8, 199, 151];
        let mse = compute_mse_u8(&reference, &shifted);
        assert!((mse - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_mse_u8_empty() {
        assert!(compute_mse_u8(&[], &[]) < f64::EPSILON);
    }

    #[test]
    fn test_mse_f32_identical() {
        let samples = [0.1f32, 0.5, -0.3, 0.9, -0.8];
        assert!(compute_mse_f32(&samples, &samples) < f64::EPSILON);
    }

    #[test]
    fn test_mse_f32_known_difference() {
        // A single unit error across three samples → MSE = 1/3.
        let impulse = [1.0f32, 0.0, 0.0];
        let silence = [0.0f32, 0.0, 0.0];
        let mse = compute_mse_f32(&impulse, &silence);
        assert!((mse - 1.0 / 3.0).abs() < 1e-10);
    }

    // ── PSNR tests ──────────────────────────────────────────────────────────

    #[test]
    fn test_psnr_u8_identical() {
        let flat = [128u8; 100];
        assert!(
            compute_psnr_u8(&flat, &flat).is_infinite(),
            "identical signals should have infinite PSNR"
        );
    }

    #[test]
    fn test_psnr_u8_small_error() {
        let original: Vec<u8> = (0..=255).collect();
        // Bump every other sample by one (even indices, so no saturation).
        let mut decoded = original.clone();
        for sample in decoded.iter_mut().step_by(2) {
            *sample = sample.saturating_add(1);
        }
        // Half the samples differ by 1 → MSE = 0.5 → PSNR ≈ 51.1 dB.
        let psnr = compute_psnr_u8(&original, &decoded);
        assert!(psnr > 40.0, "±1 error should yield high PSNR, got {psnr}");
    }

    #[test]
    fn test_psnr_u8_large_error() {
        let original = vec![128u8; 100];
        let decoded = vec![0u8; 100];
        // MSE = 128² = 16384 → PSNR ≈ 6.0 dB.
        let psnr = compute_psnr_u8(&original, &decoded);
        assert!(psnr < 10.0, "large error should yield low PSNR, got {psnr}");
    }

    #[test]
    fn test_psnr_f32_identical() {
        let flat = [0.5f32; 50];
        assert!(compute_psnr_f32(&flat, &flat).is_infinite());
    }

    #[test]
    fn test_psnr_f32_small_error() {
        let original: Vec<f32> = (0..100).map(|i| i as f32 / 100.0).collect();
        let decoded: Vec<f32> = original.iter().map(|&s| s + 0.001).collect();
        let psnr = compute_psnr_f32(&original, &decoded);
        assert!(psnr > 50.0, "tiny error should yield high PSNR, got {psnr}");
    }

    // ── SNR tests ───────────────────────────────────────────────────────────

    #[test]
    fn test_snr_f32_identical() {
        let samples = [0.5f32, -0.3, 0.8, -0.1, 0.9];
        assert!(compute_snr_f32(&samples, &samples).is_infinite());
    }

    #[test]
    fn test_snr_f32_noisy() {
        // Ramp from -1.0 upwards with a constant 0.1 offset as "noise".
        let original: Vec<f32> = (0..100).map(|i| (i as f32 / 50.0) - 1.0).collect();
        let decoded: Vec<f32> = original.iter().map(|&s| s + 0.1).collect();
        let snr = compute_snr_f32(&original, &decoded);
        assert!(
            snr > 0.0,
            "signal with small noise should have positive SNR"
        );
    }

    #[test]
    fn test_snr_f32_zero_signal() {
        let silence = [0.0f32; 100];
        let offset = [0.1f32; 100];
        let snr = compute_snr_f32(&silence, &offset);
        assert!(
            snr <= 0.0,
            "zero signal with noise should have non-positive SNR"
        );
    }

    // ── PSNR u16 tests ──────────────────────────────────────────────────────

    #[test]
    fn test_psnr_u16_identical() {
        // 0, 10, 20, …, 990.
        let samples: Vec<u16> = (0..1000).step_by(10).collect();
        assert!(compute_psnr_u16(&samples, &samples, 10).is_infinite());
    }

    #[test]
    fn test_psnr_u16_10bit_small_error() {
        let original: Vec<u16> = (0..1024).collect();
        let mut decoded = original.clone();
        for sample in decoded.iter_mut().step_by(2) {
            *sample = sample.saturating_add(1);
        }
        // The 10-bit peak is 1023.
        let psnr = compute_psnr_u16(&original, &decoded, 10);
        assert!(
            psnr > 50.0,
            "±1 error on 10-bit should yield high PSNR, got {psnr}"
        );
    }

    // ── SSIM tests ──────────────────────────────────────────────────────────

    #[test]
    fn test_ssim_identical() {
        let ramp: Vec<u8> = (0..=255).collect();
        let ssim = compute_ssim_simplified_u8(&ramp, &ramp);
        assert!(
            (ssim - 1.0).abs() < 1e-6,
            "identical signals should have SSIM=1, got {ssim}"
        );
    }

    #[test]
    fn test_ssim_similar() {
        let original: Vec<u8> = (0..=255).collect();
        // Perturb every third sample by one.
        let mut decoded = original.clone();
        for sample in decoded.iter_mut().step_by(3) {
            *sample = sample.saturating_add(1);
        }
        let ssim = compute_ssim_simplified_u8(&original, &decoded);
        assert!(
            ssim > 0.99,
            "nearly identical signals should have high SSIM, got {ssim}"
        );
    }

    #[test]
    fn test_ssim_empty() {
        let ssim = compute_ssim_simplified_u8(&[], &[]);
        assert!((ssim - 1.0).abs() < 1e-6);
    }

    // ── Quality assessment tests ────────────────────────────────────────────

    #[test]
    fn test_assess_quality_u8_passes() {
        let original: Vec<u8> = (0..=255).collect();
        let decoded = original.clone();
        let report = assess_quality_u8(&original, &decoded, 30.0);
        assert!(
            report.passes_threshold,
            "identical data should pass any threshold"
        );
        assert!(report.psnr_db.is_infinite());
        assert!(report.mse < f64::EPSILON);
    }

    #[test]
    fn test_assess_quality_u8_fails() {
        let original = vec![128u8; 100];
        let decoded = vec![0u8; 100];
        let report = assess_quality_u8(&original, &decoded, 30.0);
        assert!(
            !report.passes_threshold,
            "large error should fail threshold"
        );
        assert!(report.psnr_db < 30.0);
    }

    #[test]
    fn test_assess_quality_f32_passes() {
        let original: Vec<f32> = (0..100).map(|i| i as f32 / 100.0).collect();
        let decoded: Vec<f32> = original.iter().map(|&s| s + 0.0001).collect();
        let report = assess_quality_f32(&original, &decoded, 60.0);
        assert!(
            report.passes_threshold,
            "tiny error should pass 60 dB threshold, got {} dB",
            report.psnr_db
        );
    }

    // ── PCM codec round-trip PSNR tests ─────────────────────────────────────

    #[test]
    fn test_pcm_f32_roundtrip_psnr() {
        let config = PcmConfig {
            format: PcmFormat::F32,
            byte_order: ByteOrder::Little,
            sample_rate: 48000,
            channels: 1,
        };
        let enc = PcmEncoder::new(config.clone());
        let dec = PcmDecoder::new(config);

        // 4800 samples of a 440 Hz sine at 48 kHz.
        let samples: Vec<f32> = (0..4800)
            .map(|i| (2.0 * PI * 440.0 * i as f32 / 48000.0).sin())
            .collect();

        let frame = AudioFrame::new(
            f32_to_le_bytes(&samples),
            4800,
            48000,
            1,
            SampleFormat::F32,
        );
        let encoded = enc.encode_frame(&frame).expect("encode");
        let decoded_frame = dec.decode_bytes(&encoded).expect("decode");

        // Reinterpret the decoded byte buffer as little-endian f32 samples.
        let decoded_samples: Vec<f32> = decoded_frame
            .samples
            .chunks_exact(4)
            .map(f32_from_le_chunk)
            .collect();

        let psnr = compute_psnr_f32(&samples, &decoded_samples);
        assert!(
            psnr.is_infinite(),
            "PCM F32 round-trip should be lossless (infinite PSNR), got {psnr}"
        );
    }

    #[test]
    fn test_pcm_i16_roundtrip_psnr_above_threshold() {
        let config = PcmConfig {
            format: PcmFormat::I16,
            byte_order: ByteOrder::Little,
            sample_rate: 44100,
            channels: 2,
        };
        let enc = PcmEncoder::new(config.clone());
        let dec = PcmDecoder::new(config);

        // Interleaved stereo test signal: even indices use 440 Hz, odd
        // indices 880 Hz, both scaled to 0.8. Note that `t` advances per
        // interleaved sample index, not per stereo frame.
        let samples: Vec<f32> = (0..8820)
            .map(|i| {
                let t = i as f32 / 44100.0;
                let freq = if i % 2 == 0 { 440.0 } else { 880.0 };
                (2.0 * PI * freq * t).sin() * 0.8
            })
            .collect();

        let frame = AudioFrame::new(
            f32_to_le_bytes(&samples),
            4410,
            44100,
            2,
            SampleFormat::F32,
        );
        let encoded = enc.encode_frame(&frame).expect("encode");
        let decoded_frame = dec.decode_bytes(&encoded).expect("decode");

        let decoded_samples: Vec<f32> = decoded_frame
            .samples
            .chunks_exact(4)
            .map(f32_from_le_chunk)
            .collect();

        // 16-bit quantisation is lossy, but the error is expected to stay
        // small enough for the PSNR to clear 80 dB.
        let psnr = compute_psnr_f32(&samples, &decoded_samples);
        assert!(
            psnr > 80.0,
            "PCM I16 round-trip should have PSNR > 80 dB, got {psnr}"
        );
    }

    #[test]
    fn test_pcm_u8_roundtrip_psnr_above_threshold() {
        let config = PcmConfig {
            format: PcmFormat::U8,
            byte_order: ByteOrder::Little,
            sample_rate: 22050,
            channels: 1,
        };
        let enc = PcmEncoder::new(config.clone());
        let dec = PcmDecoder::new(config);

        // 2205 samples of a 440 Hz sine at 22.05 kHz, scaled to 0.5.
        let samples: Vec<f32> = (0..2205)
            .map(|i| (2.0 * PI * 440.0 * i as f32 / 22050.0).sin() * 0.5)
            .collect();

        let frame = AudioFrame::new(
            f32_to_le_bytes(&samples),
            2205,
            22050,
            1,
            SampleFormat::F32,
        );
        let encoded = enc.encode_frame(&frame).expect("encode");
        let decoded_frame = dec.decode_bytes(&encoded).expect("decode");

        let decoded_samples: Vec<f32> = decoded_frame
            .samples
            .chunks_exact(4)
            .map(f32_from_le_chunk)
            .collect();

        // 8-bit quantisation is coarse, so only a modest 30 dB is required.
        let psnr = compute_psnr_f32(&samples, &decoded_samples);
        assert!(
            psnr > 30.0,
            "PCM U8 round-trip should have PSNR > 30 dB, got {psnr}"
        );
    }
}