Skip to main content

oximedia_gpu/
gpu_cpu_verify.rs

//! GPU vs CPU output comparison and verification utilities.
//!
//! Provides tools for verifying that GPU-accelerated operations produce
//! results within an acceptable tolerance of their CPU reference
//! implementations.
//!
//! All difference statistics are reported on a 0–255 scale; `f32` buffers
//! (nominally in `[0.0, 1.0]`) are scaled by 255 before being compared.
//!
//! # Tolerance Metrics
//!
//! | Metric | Description |
//! |--------|-------------|
//! | [`ToleranceMetric::MaxAbsDiff`] | Maximum absolute difference across all elements. |
//! | [`ToleranceMetric::MeanAbsDiff`] | Mean absolute difference. |
//! | [`ToleranceMetric::Psnr`] | Peak Signal-to-Noise Ratio (dB). |
//! | [`ToleranceMetric::RmsError`] | Root Mean Square error. |
//!
//! # Example
//!
//! ```rust
//! use oximedia_gpu::gpu_cpu_verify::{ComparisonResult, VerificationSuite, ToleranceMetric};
//!
//! let gpu_output = vec![128u8, 129, 127, 255];
//! let cpu_output = vec![128u8, 128, 128, 255];
//! let result = ComparisonResult::compare_u8(&gpu_output, &cpu_output);
//! assert!(result.max_abs_diff <= 1);
//! ```

// ─── Tolerance metric ──────────────────────────────────────────────────────
/// Which metric to use for pass/fail decisions.
///
/// Every threshold is checked against the statistics in [`ComparisonResult`],
/// which are expressed on a 0–255 scale (for `f32` inputs, differences are
/// multiplied by 255 before being measured).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ToleranceMetric {
    /// Maximum absolute difference must be ≤ threshold.
    MaxAbsDiff(u32),
    /// Mean absolute difference must be ≤ threshold.
    MeanAbsDiff(f64),
    /// PSNR must be ≥ threshold (in dB).  Higher is better.
    Psnr(f64),
    /// RMS error must be ≤ threshold.
    RmsError(f64),
}

impl std::fmt::Display for ToleranceMetric {
    /// Formats the metric as the pass condition it encodes, e.g. `psnr >= 40.00 dB`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MaxAbsDiff(t) => write!(f, "max_abs_diff <= {t}"),
            Self::MeanAbsDiff(t) => write!(f, "mean_abs_diff <= {t:.4}"),
            Self::Psnr(t) => write!(f, "psnr >= {t:.2} dB"),
            Self::RmsError(t) => write!(f, "rms_error <= {t:.4}"),
        }
    }
}

// ─── ComparisonResult ───────────────────────────────────────────────────────

/// Detailed result of comparing two buffers element-by-element.
///
/// All difference statistics use a 0–255 scale: `u8` samples are compared
/// directly and `f32` samples (nominally in `[0.0, 1.0]`) are scaled by 255.
#[derive(Debug, Clone)]
pub struct ComparisonResult {
    /// Maximum absolute difference between any two corresponding elements
    /// (rounded to the nearest integer for `f32` inputs; saturates at `u32::MAX`).
    pub max_abs_diff: u32,
    /// Mean absolute difference.
    pub mean_abs_diff: f64,
    /// Peak Signal-to-Noise Ratio (dB) with a peak value of 255.
    /// `f64::INFINITY` if the mean squared error is (numerically) zero,
    /// including when no elements were compared.
    pub psnr_db: f64,
    /// Root Mean Square error.
    pub rms_error: f64,
    /// Index of the first element with the largest absolute difference
    /// (0 when no element differs).
    pub worst_index: usize,
    /// Total number of elements compared (the length of the shorter buffer).
    pub element_count: usize,
    /// Number of elements whose (rounded) difference is at least 1.
    pub differing_elements: usize,
    /// Histogram of absolute differences: bin `i` counts elements whose
    /// (rounded) difference is exactly `i`, except bin 255, which also
    /// absorbs every difference of 255 or more.
    pub diff_histogram: [u64; 256],
}

impl ComparisonResult {
    /// Compare two u8 buffers element-by-element.
    ///
    /// If the buffers have different lengths, the comparison is performed
    /// over the shorter length (trailing elements of the longer buffer are
    /// ignored).
    #[must_use]
    pub fn compare_u8(a: &[u8], b: &[u8]) -> Self {
        Self::from_scaled_diffs(
            a.iter()
                .zip(b)
                .map(|(&x, &y)| f64::from((i32::from(x) - i32::from(y)).unsigned_abs())),
        )
    }

    /// Compare two f32 buffers element-by-element.
    ///
    /// Values are assumed to be in the range \[0.0, 1.0\]; differences are
    /// scaled by 255 so the statistics are comparable with [`Self::compare_u8`].
    /// As with `compare_u8`, only the common prefix of the buffers is compared.
    ///
    /// Non-finite values are handled explicitly:
    /// * equal values, including equal infinities, match exactly;
    /// * a NaN paired with a NaN counts as a match;
    /// * a NaN paired with any non-NaN value counts as an infinite difference.
    ///   `max_abs_diff` saturates at `u32::MAX`, the mean/RMS errors become
    ///   infinite and the PSNR becomes `-inf`. The result therefore fails
    ///   `Psnr`, `MeanAbsDiff` and `RmsError` checks with any finite threshold,
    ///   and any `MaxAbsDiff` threshold below `u32::MAX`.
    #[must_use]
    pub fn compare_f32(a: &[f32], b: &[f32]) -> Self {
        Self::from_scaled_diffs(
            a.iter()
                .zip(b)
                .map(|(&x, &y)| Self::f32_abs_diff(x, y) * 255.0),
        )
    }

    /// Check whether this result passes a given tolerance metric.
    #[must_use]
    pub fn passes(&self, metric: &ToleranceMetric) -> bool {
        match metric {
            ToleranceMetric::MaxAbsDiff(t) => self.max_abs_diff <= *t,
            ToleranceMetric::MeanAbsDiff(t) => self.mean_abs_diff <= *t,
            ToleranceMetric::Psnr(t) => self.psnr_db >= *t,
            ToleranceMetric::RmsError(t) => self.rms_error <= *t,
        }
    }

    /// Percentage (0–100) of compared elements whose difference is at least 1.
    ///
    /// Returns 0.0 when no elements were compared.
    #[must_use]
    pub fn diff_percentage(&self) -> f64 {
        if self.element_count == 0 {
            return 0.0;
        }
        (self.differing_elements as f64 / self.element_count as f64) * 100.0
    }

    /// Whether no compared element differs by 1 or more on the 0–255 scale.
    ///
    /// For `u8` inputs this means the compared elements are identical; for
    /// `f32` inputs, differences smaller than 0.5/255 round to zero and are
    /// not counted.
    #[must_use]
    pub fn is_exact_match(&self) -> bool {
        self.max_abs_diff == 0
    }

    /// Absolute difference between two `f32` samples, computed in `f64` so the
    /// subtraction itself neither overflows nor loses precision.
    ///
    /// NaN handling: NaN vs NaN is a match (0.0); NaN vs anything else is an
    /// infinite difference, so it can never slip through a tolerance check.
    fn f32_abs_diff(x: f32, y: f32) -> f64 {
        match (x.is_nan(), y.is_nan()) {
            (true, true) => 0.0,
            (true, false) | (false, true) => f64::INFINITY,
            // Checked before subtracting so equal infinities (whose difference
            // would be NaN) are treated as a match.
            (false, false) if x == y => 0.0,
            (false, false) => (f64::from(x) - f64::from(y)).abs(),
        }
    }

    /// Builds the full statistics from per-element absolute differences that
    /// are already on the 0–255 scale, yielded in buffer order.
    fn from_scaled_diffs(diffs: impl Iterator<Item = f64>) -> Self {
        let mut result = Self::empty();
        let mut sum_diff = 0.0_f64;
        let mut sum_sq_diff = 0.0_f64;

        for (index, diff) in diffs.enumerate() {
            // Float-to-int `as` saturates, so an infinite difference maps to u32::MAX.
            let diff_u = diff.round() as u32;
            // Strict `>` keeps the first index among equally bad elements.
            if diff_u > result.max_abs_diff {
                result.max_abs_diff = diff_u;
                result.worst_index = index;
            }
            sum_diff += diff;
            sum_sq_diff += diff * diff;
            if diff_u > 0 {
                result.differing_elements += 1;
            }
            // Bin 255 collects every difference of 255 or more.
            result.diff_histogram[diff_u.min(255) as usize] += 1;
            result.element_count += 1;
        }

        if result.element_count == 0 {
            return result;
        }

        let n = result.element_count as f64;
        let mse = sum_sq_diff / n;
        result.mean_abs_diff = sum_diff / n;
        result.rms_error = mse.sqrt();
        // A numerically zero MSE is reported as a perfect match instead of
        // dividing by (almost) zero.
        result.psnr_db = if mse < 1e-12 {
            f64::INFINITY
        } else {
            10.0 * (255.0 * 255.0 / mse).log10()
        };
        result
    }

    /// Result for zero compared elements: no differences and infinite PSNR.
    fn empty() -> Self {
        Self {
            max_abs_diff: 0,
            mean_abs_diff: 0.0,
            psnr_db: f64::INFINITY,
            rms_error: 0.0,
            worst_index: 0,
            element_count: 0,
            differing_elements: 0,
            diff_histogram: [0u64; 256],
        }
    }
}
227
228// ─── VerificationReport ─────────────────────────────────────────────────────
229
230/// A single test case result within a verification suite.
231#[derive(Debug, Clone)]
232pub struct VerificationCase {
233    /// Human-readable name of the test case.
234    pub name: String,
235    /// The comparison result.
236    pub result: ComparisonResult,
237    /// The tolerance metric used.
238    pub metric: ToleranceMetric,
239    /// Whether this case passed.
240    pub passed: bool,
241}
242
243/// A suite of GPU vs CPU verification tests.
244#[derive(Debug, Clone)]
245pub struct VerificationSuite {
246    /// Name of the suite.
247    pub name: String,
248    /// Individual test cases.
249    pub cases: Vec<VerificationCase>,
250}
251
252impl VerificationSuite {
253    /// Create a new empty suite.
254    #[must_use]
255    pub fn new(name: impl Into<String>) -> Self {
256        Self {
257            name: name.into(),
258            cases: Vec::new(),
259        }
260    }
261
262    /// Add a test case comparing two u8 buffers.
263    pub fn add_u8_case(
264        &mut self,
265        name: impl Into<String>,
266        gpu_output: &[u8],
267        cpu_output: &[u8],
268        metric: ToleranceMetric,
269    ) {
270        let result = ComparisonResult::compare_u8(gpu_output, cpu_output);
271        let passed = result.passes(&metric);
272        self.cases.push(VerificationCase {
273            name: name.into(),
274            result,
275            metric,
276            passed,
277        });
278    }
279
280    /// Add a test case comparing two f32 buffers.
281    pub fn add_f32_case(
282        &mut self,
283        name: impl Into<String>,
284        gpu_output: &[f32],
285        cpu_output: &[f32],
286        metric: ToleranceMetric,
287    ) {
288        let result = ComparisonResult::compare_f32(gpu_output, cpu_output);
289        let passed = result.passes(&metric);
290        self.cases.push(VerificationCase {
291            name: name.into(),
292            result,
293            metric,
294            passed,
295        });
296    }
297
298    /// Number of test cases.
299    #[must_use]
300    pub fn case_count(&self) -> usize {
301        self.cases.len()
302    }
303
304    /// Number of passing test cases.
305    #[must_use]
306    pub fn passed_count(&self) -> usize {
307        self.cases.iter().filter(|c| c.passed).count()
308    }
309
310    /// Number of failing test cases.
311    #[must_use]
312    pub fn failed_count(&self) -> usize {
313        self.cases.iter().filter(|c| !c.passed).count()
314    }
315
316    /// Whether all test cases passed.
317    #[must_use]
318    pub fn all_passed(&self) -> bool {
319        self.cases.iter().all(|c| c.passed)
320    }
321
322    /// Get all failing cases.
323    #[must_use]
324    pub fn failures(&self) -> Vec<&VerificationCase> {
325        self.cases.iter().filter(|c| !c.passed).collect()
326    }
327
328    /// Generate a summary report as a string.
329    #[must_use]
330    pub fn summary(&self) -> String {
331        let mut report = format!(
332            "Verification Suite: {}\n  Cases: {} total, {} passed, {} failed\n",
333            self.name,
334            self.case_count(),
335            self.passed_count(),
336            self.failed_count(),
337        );
338
339        for case in &self.cases {
340            let status = if case.passed { "PASS" } else { "FAIL" };
341            report.push_str(&format!(
342                "  [{status}] {} — max_diff={}, mean_diff={:.4}, psnr={:.2}dB, rms={:.4}\n",
343                case.name,
344                case.result.max_abs_diff,
345                case.result.mean_abs_diff,
346                case.result.psnr_db,
347                case.result.rms_error,
348            ));
349        }
350
351        report
352    }
353}
354
// ─── Convenience functions ──────────────────────────────────────────────────

/// Quick check: are two u8 buffers within `max_diff` of each other at every element?
///
/// Buffers of different lengths are never within tolerance, because an element
/// present in only one buffer has no counterpart to be close to. This differs
/// from [`ComparisonResult::compare_u8`], which compares only the common prefix.
/// Two empty buffers are trivially within tolerance.
///
/// Stops at the first element that exceeds the tolerance instead of computing
/// full comparison statistics.
#[must_use]
pub fn buffers_within_tolerance(a: &[u8], b: &[u8], max_diff: u32) -> bool {
    a.len() == b.len()
        && a
            .iter()
            .zip(b)
            .all(|(&x, &y)| (i32::from(x) - i32::from(y)).unsigned_abs() <= max_diff)
}
362
363/// Quick check: compute PSNR between two u8 buffers.
364#[must_use]
365pub fn compute_buffer_psnr(a: &[u8], b: &[u8]) -> f64 {
366    ComparisonResult::compare_u8(a, b).psnr_db
367}
368
369/// Quick check: compute RMS error between two u8 buffers.
370#[must_use]
371pub fn compute_buffer_rms(a: &[u8], b: &[u8]) -> f64 {
372    ComparisonResult::compare_u8(a, b).rms_error
373}
374
// ─── Tests ──────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a byte-buffer comparison (GPU output first, CPU second).
    fn compare(gpu: &[u8], cpu: &[u8]) -> ComparisonResult {
        ComparisonResult::compare_u8(gpu, cpu)
    }

    #[test]
    fn test_identical_buffers_exact_match() {
        let pixels = [0u8, 128, 255, 64, 192];
        let r = compare(&pixels, &pixels);
        assert!(r.is_exact_match());
        assert_eq!(r.max_abs_diff, 0);
        assert_eq!(r.mean_abs_diff, 0.0);
        assert_eq!(r.psnr_db, f64::INFINITY);
        assert_eq!(r.rms_error, 0.0);
        assert_eq!(r.differing_elements, 0);
    }

    #[test]
    fn test_single_element_difference() {
        let r = compare(&[100], &[105]);
        assert_eq!(r.max_abs_diff, 5);
        assert_eq!(r.mean_abs_diff, 5.0);
        assert_eq!(r.differing_elements, 1);
        assert_eq!(r.worst_index, 0);
    }

    #[test]
    fn test_max_diff_is_worst_case() {
        let r = compare(&[0; 4], &[1, 2, 10, 3]);
        assert_eq!(r.max_abs_diff, 10);
        assert_eq!(r.worst_index, 2);
    }

    #[test]
    fn test_psnr_high_for_small_differences() {
        let ramp: Vec<u8> = (0..=255).collect();
        let bumped: Vec<u8> = ramp.iter().map(|&v| v.saturating_add(1)).collect();
        let r = compare(&ramp, &bumped);
        assert!(
            r.psnr_db > 40.0,
            "PSNR should be high for +-1 diff, got {}",
            r.psnr_db
        );
    }

    #[test]
    fn test_tolerance_max_abs_diff_pass() {
        let r = compare(&[100], &[103]);
        assert!(r.passes(&ToleranceMetric::MaxAbsDiff(5)));
        assert!(!r.passes(&ToleranceMetric::MaxAbsDiff(2)));
    }

    #[test]
    fn test_tolerance_mean_abs_diff_pass() {
        let r = compare(&[100; 4], &[102, 98, 101, 99]);
        assert!(r.passes(&ToleranceMetric::MeanAbsDiff(2.0)));
    }

    #[test]
    fn test_tolerance_psnr_pass() {
        let r = compare(&[128; 1000], &[129; 1000]);
        assert!(r.passes(&ToleranceMetric::Psnr(40.0)));
    }

    #[test]
    fn test_tolerance_rms_error_pass() {
        let r = compare(&[100; 4], &[101, 99, 100, 100]);
        assert!(r.passes(&ToleranceMetric::RmsError(1.0)));
    }

    #[test]
    fn test_diff_percentage() {
        // Ten elements, two of which (indices 1 and 3) differ.
        let mut cpu = [0u8; 10];
        cpu[1] = 1;
        cpu[3] = 1;
        let r = compare(&[0; 10], &cpu);
        assert!((r.diff_percentage() - 20.0).abs() < 1e-6);
    }

    #[test]
    fn test_empty_buffers() {
        let r = compare(&[], &[]);
        assert!(r.is_exact_match());
        assert_eq!(r.element_count, 0);
        assert_eq!(r.diff_percentage(), 0.0);
    }

    #[test]
    fn test_compare_f32_identical() {
        let samples = [0.0f32, 0.5, 1.0];
        let r = ComparisonResult::compare_f32(&samples, &samples);
        assert!(r.is_exact_match());
        assert_eq!(r.psnr_db, f64::INFINITY);
    }

    #[test]
    fn test_compare_f32_small_diff() {
        let r = ComparisonResult::compare_f32(&[0.5, 0.5, 0.5], &[0.502, 0.498, 0.5]);
        assert!(r.max_abs_diff <= 2, "max_abs_diff={}", r.max_abs_diff);
    }

    #[test]
    fn test_verification_suite_all_pass() {
        let mut suite = VerificationSuite::new("test suite");
        suite.add_u8_case(
            "close match",
            &[100u8; 16],
            &[101u8; 16],
            ToleranceMetric::MaxAbsDiff(2),
        );
        assert!(suite.all_passed());
        assert_eq!(suite.passed_count(), 1);
        assert_eq!(suite.failed_count(), 0);
    }

    #[test]
    fn test_verification_suite_with_failure() {
        let mut suite = VerificationSuite::new("mixed");
        suite.add_u8_case(
            "too different",
            &[100u8; 16],
            &[110u8; 16],
            ToleranceMetric::MaxAbsDiff(5),
        );
        assert!(!suite.all_passed());
        assert_eq!(suite.failed_count(), 1);
        let failing = suite.failures();
        assert_eq!(failing.len(), 1);
        assert_eq!(failing[0].name, "too different");
    }

    #[test]
    fn test_verification_suite_summary_format() {
        let mut suite = VerificationSuite::new("blur verification");
        let uniform = [128u8; 4];
        suite.add_u8_case(
            "uniform image",
            &uniform,
            &uniform,
            ToleranceMetric::MaxAbsDiff(0),
        );
        let text = suite.summary();
        for needle in ["blur verification", "PASS", "1 total"] {
            assert!(text.contains(needle), "summary missing {needle:?}: {text}");
        }
    }

    #[test]
    fn test_buffers_within_tolerance_convenience() {
        assert!(buffers_within_tolerance(&[100, 200], &[101, 199], 1));
        assert!(!buffers_within_tolerance(&[100, 200], &[110, 190], 5));
    }

    #[test]
    fn test_compute_buffer_psnr_convenience() {
        let flat = [128u8; 100];
        assert_eq!(compute_buffer_psnr(&flat, &flat), f64::INFINITY);
    }

    #[test]
    fn test_compute_buffer_rms_convenience() {
        let flat = [128u8; 100];
        assert_eq!(compute_buffer_rms(&flat, &flat), 0.0);
    }

    #[test]
    fn test_diff_histogram_correct() {
        let r = compare(&[10, 20, 30, 40, 50], &[10, 21, 32, 40, 53]);
        // Diff 0 at indices 0 and 3; diffs 1, 2 and 3 occur once each.
        for (bin, expected) in [(0usize, 2u64), (1, 1), (2, 1), (3, 1)] {
            assert_eq!(r.diff_histogram[bin], expected, "histogram bin {bin}");
        }
    }

    #[test]
    fn test_tolerance_metric_display() {
        let max_rendered = ToleranceMetric::MaxAbsDiff(5).to_string();
        assert!(max_rendered.contains("max_abs_diff"));
        let psnr_rendered = ToleranceMetric::Psnr(40.0).to_string();
        assert!(psnr_rendered.contains("psnr"));
    }

    #[test]
    fn test_suite_f32_case() {
        let mut suite = VerificationSuite::new("f32 test");
        let half = [0.5f32; 3];
        suite.add_f32_case("exact", &half, &half, ToleranceMetric::MaxAbsDiff(0));
        assert!(suite.all_passed());
    }

    #[test]
    fn test_different_length_buffers_uses_shorter() {
        let r = compare(&[100, 200, 150], &[100, 200]);
        assert_eq!(r.element_count, 2);
        assert!(r.is_exact_match());
    }
}