/// A single pass/fail criterion applied to a [`ComparisonResult`].
///
/// Each variant carries its threshold; see `ComparisonResult::passes` for the
/// exact comparison performed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ToleranceMetric {
    /// Passes when `max_abs_diff` is less than or equal to the threshold.
    MaxAbsDiff(u32),
    /// Passes when `mean_abs_diff` is less than or equal to the threshold.
    MeanAbsDiff(f64),
    /// Passes when `psnr_db` is greater than or equal to the threshold (in dB).
    Psnr(f64),
    /// Passes when `rms_error` is less than or equal to the threshold.
    RmsError(f64),
}
41
42impl std::fmt::Display for ToleranceMetric {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 match self {
45 Self::MaxAbsDiff(t) => write!(f, "max_abs_diff <= {t}"),
46 Self::MeanAbsDiff(t) => write!(f, "mean_abs_diff <= {t:.4}"),
47 Self::Psnr(t) => write!(f, "psnr >= {t:.2} dB"),
48 Self::RmsError(t) => write!(f, "rms_error <= {t:.4}"),
49 }
50 }
51}
52
/// Element-wise difference statistics between two buffers.
///
/// All difference values are expressed on a 0..=255 scale: `u8` inputs are
/// used directly, `f32` inputs have their differences multiplied by 255.
#[derive(Debug, Clone)]
pub struct ComparisonResult {
    /// Largest per-element absolute difference (rounded for `f32` inputs).
    pub max_abs_diff: u32,
    /// Mean of the per-element absolute differences.
    pub mean_abs_diff: f64,
    /// Peak signal-to-noise ratio in dB with a peak of 255; infinite when
    /// the mean squared error is (effectively) zero.
    pub psnr_db: f64,
    /// Square root of the mean squared difference.
    pub rms_error: f64,
    /// Index of the first element that reached `max_abs_diff`
    /// (0 when every element matches).
    pub worst_index: usize,
    /// Number of elements compared: the length of the shorter input.
    pub element_count: usize,
    /// Number of elements whose (rounded) difference is non-zero.
    pub differing_elements: usize,
    /// Count of elements per difference value; bin 255 also collects any
    /// larger differences.
    pub diff_histogram: [u64; 256],
}
76
77impl ComparisonResult {
78 #[must_use]
83 pub fn compare_u8(a: &[u8], b: &[u8]) -> Self {
84 let len = a.len().min(b.len());
85 if len == 0 {
86 return Self::empty();
87 }
88
89 let mut max_diff: u32 = 0;
90 let mut sum_diff: f64 = 0.0;
91 let mut sum_sq_diff: f64 = 0.0;
92 let mut worst_idx: usize = 0;
93 let mut differing: usize = 0;
94 let mut histogram = [0u64; 256];
95
96 for i in 0..len {
97 let diff = (a[i] as i32 - b[i] as i32).unsigned_abs();
98 if diff > max_diff {
99 max_diff = diff;
100 worst_idx = i;
101 }
102 sum_diff += diff as f64;
103 sum_sq_diff += (diff as f64) * (diff as f64);
104 if diff > 0 {
105 differing += 1;
106 }
107 let bin = (diff as usize).min(255);
108 histogram[bin] += 1;
109 }
110
111 let n = len as f64;
112 let mean_diff = sum_diff / n;
113 let mse = sum_sq_diff / n;
114 let rms = mse.sqrt();
115 let psnr = if mse < 1e-12 {
116 f64::INFINITY
117 } else {
118 10.0 * (255.0 * 255.0 / mse).log10()
119 };
120
121 Self {
122 max_abs_diff: max_diff,
123 mean_abs_diff: mean_diff,
124 psnr_db: psnr,
125 rms_error: rms,
126 worst_index: worst_idx,
127 element_count: len,
128 differing_elements: differing,
129 diff_histogram: histogram,
130 }
131 }
132
133 #[must_use]
137 pub fn compare_f32(a: &[f32], b: &[f32]) -> Self {
138 let len = a.len().min(b.len());
139 if len == 0 {
140 return Self::empty();
141 }
142
143 let mut max_diff: u32 = 0;
144 let mut sum_diff: f64 = 0.0;
145 let mut sum_sq_diff: f64 = 0.0;
146 let mut worst_idx: usize = 0;
147 let mut differing: usize = 0;
148 let mut histogram = [0u64; 256];
149
150 for i in 0..len {
151 let diff_f = ((a[i] - b[i]) as f64).abs();
152 let diff_u = (diff_f * 255.0).round() as u32;
153 if diff_u > max_diff {
154 max_diff = diff_u;
155 worst_idx = i;
156 }
157 sum_diff += diff_f * 255.0;
158 sum_sq_diff += diff_f * diff_f * 255.0 * 255.0;
159 if diff_u > 0 {
160 differing += 1;
161 }
162 let bin = (diff_u as usize).min(255);
163 histogram[bin] += 1;
164 }
165
166 let n = len as f64;
167 let mean_diff = sum_diff / n;
168 let mse = sum_sq_diff / n;
169 let rms = mse.sqrt();
170 let psnr = if mse < 1e-12 {
171 f64::INFINITY
172 } else {
173 10.0 * (255.0 * 255.0 / mse).log10()
174 };
175
176 Self {
177 max_abs_diff: max_diff,
178 mean_abs_diff: mean_diff,
179 psnr_db: psnr,
180 rms_error: rms,
181 worst_index: worst_idx,
182 element_count: len,
183 differing_elements: differing,
184 diff_histogram: histogram,
185 }
186 }
187
188 #[must_use]
190 pub fn passes(&self, metric: &ToleranceMetric) -> bool {
191 match metric {
192 ToleranceMetric::MaxAbsDiff(t) => self.max_abs_diff <= *t,
193 ToleranceMetric::MeanAbsDiff(t) => self.mean_abs_diff <= *t,
194 ToleranceMetric::Psnr(t) => self.psnr_db >= *t,
195 ToleranceMetric::RmsError(t) => self.rms_error <= *t,
196 }
197 }
198
199 #[must_use]
201 pub fn diff_percentage(&self) -> f64 {
202 if self.element_count == 0 {
203 return 0.0;
204 }
205 (self.differing_elements as f64 / self.element_count as f64) * 100.0
206 }
207
208 #[must_use]
210 pub fn is_exact_match(&self) -> bool {
211 self.max_abs_diff == 0
212 }
213
214 fn empty() -> Self {
215 Self {
216 max_abs_diff: 0,
217 mean_abs_diff: 0.0,
218 psnr_db: f64::INFINITY,
219 rms_error: 0.0,
220 worst_index: 0,
221 element_count: 0,
222 differing_elements: 0,
223 diff_histogram: [0u64; 256],
224 }
225 }
226}
227
228#[derive(Debug, Clone)]
232pub struct VerificationCase {
233 pub name: String,
235 pub result: ComparisonResult,
237 pub metric: ToleranceMetric,
239 pub passed: bool,
241}
242
243#[derive(Debug, Clone)]
245pub struct VerificationSuite {
246 pub name: String,
248 pub cases: Vec<VerificationCase>,
250}
251
252impl VerificationSuite {
253 #[must_use]
255 pub fn new(name: impl Into<String>) -> Self {
256 Self {
257 name: name.into(),
258 cases: Vec::new(),
259 }
260 }
261
262 pub fn add_u8_case(
264 &mut self,
265 name: impl Into<String>,
266 gpu_output: &[u8],
267 cpu_output: &[u8],
268 metric: ToleranceMetric,
269 ) {
270 let result = ComparisonResult::compare_u8(gpu_output, cpu_output);
271 let passed = result.passes(&metric);
272 self.cases.push(VerificationCase {
273 name: name.into(),
274 result,
275 metric,
276 passed,
277 });
278 }
279
280 pub fn add_f32_case(
282 &mut self,
283 name: impl Into<String>,
284 gpu_output: &[f32],
285 cpu_output: &[f32],
286 metric: ToleranceMetric,
287 ) {
288 let result = ComparisonResult::compare_f32(gpu_output, cpu_output);
289 let passed = result.passes(&metric);
290 self.cases.push(VerificationCase {
291 name: name.into(),
292 result,
293 metric,
294 passed,
295 });
296 }
297
298 #[must_use]
300 pub fn case_count(&self) -> usize {
301 self.cases.len()
302 }
303
304 #[must_use]
306 pub fn passed_count(&self) -> usize {
307 self.cases.iter().filter(|c| c.passed).count()
308 }
309
310 #[must_use]
312 pub fn failed_count(&self) -> usize {
313 self.cases.iter().filter(|c| !c.passed).count()
314 }
315
316 #[must_use]
318 pub fn all_passed(&self) -> bool {
319 self.cases.iter().all(|c| c.passed)
320 }
321
322 #[must_use]
324 pub fn failures(&self) -> Vec<&VerificationCase> {
325 self.cases.iter().filter(|c| !c.passed).collect()
326 }
327
328 #[must_use]
330 pub fn summary(&self) -> String {
331 let mut report = format!(
332 "Verification Suite: {}\n Cases: {} total, {} passed, {} failed\n",
333 self.name,
334 self.case_count(),
335 self.passed_count(),
336 self.failed_count(),
337 );
338
339 for case in &self.cases {
340 let status = if case.passed { "PASS" } else { "FAIL" };
341 report.push_str(&format!(
342 " [{status}] {} — max_diff={}, mean_diff={:.4}, psnr={:.2}dB, rms={:.4}\n",
343 case.name,
344 case.result.max_abs_diff,
345 case.result.mean_abs_diff,
346 case.result.psnr_db,
347 case.result.rms_error,
348 ));
349 }
350
351 report
352 }
353}
354
355#[must_use]
359pub fn buffers_within_tolerance(a: &[u8], b: &[u8], max_diff: u32) -> bool {
360 ComparisonResult::compare_u8(a, b).max_abs_diff <= max_diff
361}
362
363#[must_use]
365pub fn compute_buffer_psnr(a: &[u8], b: &[u8]) -> f64 {
366 ComparisonResult::compare_u8(a, b).psnr_db
367}
368
369#[must_use]
371pub fn compute_buffer_rms(a: &[u8], b: &[u8]) -> f64 {
372 ComparisonResult::compare_u8(a, b).rms_error
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_identical_buffers_exact_match() {
        let pixels = [0u8, 128, 255, 64, 192];
        let cmp = ComparisonResult::compare_u8(&pixels, &pixels);
        assert!(cmp.is_exact_match());
        assert_eq!(cmp.max_abs_diff, 0);
        assert_eq!(cmp.mean_abs_diff, 0.0);
        assert_eq!(cmp.psnr_db, f64::INFINITY);
        assert_eq!(cmp.rms_error, 0.0);
        assert_eq!(cmp.differing_elements, 0);
    }

    #[test]
    fn test_single_element_difference() {
        let cmp = ComparisonResult::compare_u8(&[100], &[105]);
        assert_eq!(cmp.max_abs_diff, 5);
        assert_eq!(cmp.mean_abs_diff, 5.0);
        assert_eq!(cmp.differing_elements, 1);
        assert_eq!(cmp.worst_index, 0);
    }

    #[test]
    fn test_max_diff_is_worst_case() {
        let cmp = ComparisonResult::compare_u8(&[0; 4], &[1, 2, 10, 3]);
        assert_eq!(cmp.max_abs_diff, 10);
        assert_eq!(cmp.worst_index, 2);
    }

    #[test]
    fn test_psnr_high_for_small_differences() {
        let ramp: Vec<u8> = (0..=255u8).collect();
        let shifted: Vec<u8> = ramp.iter().map(|&v| v.saturating_add(1)).collect();
        let cmp = ComparisonResult::compare_u8(&ramp, &shifted);
        assert!(
            cmp.psnr_db > 40.0,
            "PSNR should be high for +-1 diff, got {}",
            cmp.psnr_db
        );
    }

    #[test]
    fn test_tolerance_max_abs_diff_pass() {
        let cmp = ComparisonResult::compare_u8(&[100], &[103]);
        assert!(cmp.passes(&ToleranceMetric::MaxAbsDiff(5)));
        assert!(!cmp.passes(&ToleranceMetric::MaxAbsDiff(2)));
    }

    #[test]
    fn test_tolerance_mean_abs_diff_pass() {
        let cmp = ComparisonResult::compare_u8(&[100; 4], &[102, 98, 101, 99]);
        assert!(cmp.passes(&ToleranceMetric::MeanAbsDiff(2.0)));
    }

    #[test]
    fn test_tolerance_psnr_pass() {
        let cmp = ComparisonResult::compare_u8(&[128u8; 1000], &[129u8; 1000]);
        assert!(cmp.passes(&ToleranceMetric::Psnr(40.0)));
    }

    #[test]
    fn test_tolerance_rms_error_pass() {
        let cmp = ComparisonResult::compare_u8(&[100; 4], &[101, 99, 100, 100]);
        assert!(cmp.passes(&ToleranceMetric::RmsError(1.0)));
    }

    #[test]
    fn test_diff_percentage() {
        let reference = [0u8; 10];
        let mut candidate = [0u8; 10];
        // Two of ten elements differ -> 20%.
        candidate[1] = 1;
        candidate[3] = 1;
        let cmp = ComparisonResult::compare_u8(&reference, &candidate);
        assert!((cmp.diff_percentage() - 20.0).abs() < 1e-6);
    }

    #[test]
    fn test_empty_buffers() {
        let cmp = ComparisonResult::compare_u8(&[], &[]);
        assert!(cmp.is_exact_match());
        assert_eq!(cmp.element_count, 0);
        assert_eq!(cmp.diff_percentage(), 0.0);
    }

    #[test]
    fn test_compare_f32_identical() {
        let values = [0.0f32, 0.5, 1.0];
        let cmp = ComparisonResult::compare_f32(&values, &values);
        assert!(cmp.is_exact_match());
        assert_eq!(cmp.psnr_db, f64::INFINITY);
    }

    #[test]
    fn test_compare_f32_small_diff() {
        let cmp = ComparisonResult::compare_f32(&[0.5; 3], &[0.502, 0.498, 0.5]);
        assert!(cmp.max_abs_diff <= 2, "max_abs_diff={}", cmp.max_abs_diff);
    }

    #[test]
    fn test_verification_suite_all_pass() {
        let mut suite = VerificationSuite::new("test suite");
        suite.add_u8_case(
            "close match",
            &[100u8; 16],
            &[101u8; 16],
            ToleranceMetric::MaxAbsDiff(2),
        );
        assert!(suite.all_passed());
        assert_eq!(suite.passed_count(), 1);
        assert_eq!(suite.failed_count(), 0);
    }

    #[test]
    fn test_verification_suite_with_failure() {
        let mut suite = VerificationSuite::new("mixed");
        suite.add_u8_case(
            "too different",
            &[100u8; 16],
            &[110u8; 16],
            ToleranceMetric::MaxAbsDiff(5),
        );
        assert!(!suite.all_passed());
        assert_eq!(suite.failed_count(), 1);
        let failures = suite.failures();
        assert_eq!(failures.len(), 1);
        assert_eq!(failures[0].name, "too different");
    }

    #[test]
    fn test_verification_suite_summary_format() {
        let mut suite = VerificationSuite::new("blur verification");
        let uniform = [128u8; 4];
        suite.add_u8_case(
            "uniform image",
            &uniform,
            &uniform,
            ToleranceMetric::MaxAbsDiff(0),
        );
        let text = suite.summary();
        for needle in ["blur verification", "PASS", "1 total"] {
            assert!(text.contains(needle), "summary missing {needle:?}: {text}");
        }
    }

    #[test]
    fn test_buffers_within_tolerance_convenience() {
        assert!(buffers_within_tolerance(&[100, 200], &[101, 199], 1));
        assert!(!buffers_within_tolerance(&[100, 200], &[110, 190], 5));
    }

    #[test]
    fn test_compute_buffer_psnr_convenience() {
        assert_eq!(compute_buffer_psnr(&[128; 100], &[128; 100]), f64::INFINITY);
    }

    #[test]
    fn test_compute_buffer_rms_convenience() {
        assert_eq!(compute_buffer_rms(&[128; 100], &[128; 100]), 0.0);
    }

    #[test]
    fn test_diff_histogram_correct() {
        // Per-element differences: 0, 1, 2, 0, 3.
        let cmp = ComparisonResult::compare_u8(&[10, 20, 30, 40, 50], &[10, 21, 32, 40, 53]);
        assert_eq!(&cmp.diff_histogram[..4], &[2u64, 1, 1, 1]);
    }

    #[test]
    fn test_tolerance_metric_display() {
        assert!(ToleranceMetric::MaxAbsDiff(5).to_string().contains("max_abs_diff"));
        assert!(ToleranceMetric::Psnr(40.0).to_string().contains("psnr"));
    }

    #[test]
    fn test_suite_f32_case() {
        let mut suite = VerificationSuite::new("f32 test");
        let halves = [0.5f32; 3];
        suite.add_f32_case("exact", &halves, &halves, ToleranceMetric::MaxAbsDiff(0));
        assert!(suite.all_passed());
    }

    #[test]
    fn test_different_length_buffers_uses_shorter() {
        let cmp = ComparisonResult::compare_u8(&[100, 200, 150], &[100, 200]);
        assert_eq!(cmp.element_count, 2);
        assert!(cmp.is_exact_match());
    }
}
585}