1#[cfg(not(feature = "no-std"))]
7use std::collections::HashMap;
8#[cfg(not(feature = "no-std"))]
9use std::string::{String, ToString};
10#[cfg(not(feature = "no-std"))]
11use std::time::Instant;
12#[cfg(not(feature = "no-std"))]
13use std::vec::Vec;
14
15#[cfg(feature = "no-std")]
16use alloc::collections::BTreeMap as HashMap;
17#[cfg(feature = "no-std")]
18use alloc::string::{String, ToString};
19#[cfg(feature = "no-std")]
20use alloc::vec::Vec;
21#[cfg(feature = "no-std")]
22use alloc::{format, vec};
23
/// Stand-in for `std::time::Instant` used when the `no-std` feature is enabled.
///
/// It stores no clock reading: its `now` and `elapsed` methods return fixed
/// placeholder values, so nothing is actually timed in `no-std` builds.
#[cfg(feature = "no-std")]
#[derive(Debug, Clone, Copy)]
pub struct Instant;
28
/// Stand-in for `std::time::Duration` used when the `no-std` feature is enabled.
///
/// Its `as_nanos` always reports `0`, so benchmarks built without `std`
/// measure zero elapsed time.
#[cfg(feature = "no-std")]
#[derive(Debug, Clone, Copy)]
pub struct Duration;
32
#[cfg(feature = "no-std")]
impl Instant {
    /// Returns the placeholder instant; no clock is consulted.
    pub fn now() -> Self {
        Self
    }

    /// Returns a placeholder [`Duration`], since no start time was recorded.
    pub fn elapsed(&self) -> Duration {
        Duration
    }
}
43
#[cfg(feature = "no-std")]
impl Duration {
    /// Always `0`: the `no-std` placeholder cannot measure elapsed time.
    pub fn as_nanos(&self) -> u128 {
        0u128
    }
}
50
51pub mod precision {
53 use super::*;
54
55 #[derive(Debug, Clone, Copy)]
57 pub struct Tolerance {
58 pub absolute: f64,
59 pub relative: f64,
60 }
61
62 impl Tolerance {
63 pub const STRICT: Self = Self {
64 absolute: 1e-15,
65 relative: 1e-14,
66 };
67
68 pub const NORMAL: Self = Self {
69 absolute: 1e-12,
70 relative: 1e-11,
71 };
72
73 pub const RELAXED: Self = Self {
74 absolute: 1e-9,
75 relative: 1e-8,
76 };
77
78 pub const VERY_RELAXED: Self = Self {
79 absolute: 1e-6,
80 relative: 1e-5,
81 };
82 }
83
84 pub fn compare_f32(a: f32, b: f32, tolerance: Tolerance) -> bool {
86 let abs_diff = (a - b).abs() as f64;
87 let rel_diff = if b != 0.0 {
88 abs_diff / (b.abs() as f64)
89 } else {
90 abs_diff
91 };
92
93 abs_diff <= tolerance.absolute || rel_diff <= tolerance.relative
94 }
95
96 pub fn compare_f64(a: f64, b: f64, tolerance: Tolerance) -> bool {
98 let abs_diff = (a - b).abs();
99 let rel_diff = if b != 0.0 {
100 abs_diff / b.abs()
101 } else {
102 abs_diff
103 };
104
105 abs_diff <= tolerance.absolute || rel_diff <= tolerance.relative
106 }
107
108 pub fn compare_f32_slice(a: &[f32], b: &[f32], tolerance: Tolerance) -> ValidationResult {
110 if a.len() != b.len() {
111 return ValidationResult::error("Length mismatch");
112 }
113
114 let mut mismatches = Vec::new();
115 let mut max_abs_error = 0.0f64;
116 let mut max_rel_error = 0.0f64;
117
118 for (i, (&val_a, &val_b)) in a.iter().zip(b.iter()).enumerate() {
119 if !compare_f32(val_a, val_b, tolerance) {
120 let abs_error = (val_a - val_b).abs() as f64;
121 let rel_error = if val_b != 0.0 {
122 abs_error / (val_b.abs() as f64)
123 } else {
124 abs_error
125 };
126
127 max_abs_error = max_abs_error.max(abs_error);
128 max_rel_error = max_rel_error.max(rel_error);
129
130 mismatches.push(ValidationError {
131 index: Some(i),
132 expected: val_b as f64,
133 actual: val_a as f64,
134 abs_error,
135 rel_error,
136 description: format!("Mismatch at index {}", i),
137 });
138
139 if mismatches.len() >= 10 {
140 break; }
142 }
143 }
144
145 if mismatches.is_empty() {
146 ValidationResult::success()
147 } else {
148 let failed_count = mismatches.len();
149 ValidationResult {
150 passed: false,
151 errors: mismatches,
152 statistics: Some(ValidationStatistics {
153 max_abs_error,
154 max_rel_error,
155 total_comparisons: a.len(),
156 failed_comparisons: failed_count,
157 }),
158 }
159 }
160 }
161
162 pub fn compare_f64_slice(a: &[f64], b: &[f64], tolerance: Tolerance) -> ValidationResult {
164 if a.len() != b.len() {
165 return ValidationResult::error("Length mismatch");
166 }
167
168 let mut mismatches = Vec::new();
169 let mut max_abs_error = 0.0f64;
170 let mut max_rel_error = 0.0f64;
171
172 for (i, (&val_a, &val_b)) in a.iter().zip(b.iter()).enumerate() {
173 if !compare_f64(val_a, val_b, tolerance) {
174 let abs_error = (val_a - val_b).abs();
175 let rel_error = if val_b != 0.0 {
176 abs_error / val_b.abs()
177 } else {
178 abs_error
179 };
180
181 max_abs_error = max_abs_error.max(abs_error);
182 max_rel_error = max_rel_error.max(rel_error);
183
184 mismatches.push(ValidationError {
185 index: Some(i),
186 expected: val_b,
187 actual: val_a,
188 abs_error,
189 rel_error,
190 description: format!("Mismatch at index {}", i),
191 });
192
193 if mismatches.len() >= 10 {
194 break;
195 }
196 }
197 }
198
199 if mismatches.is_empty() {
200 ValidationResult::success()
201 } else {
202 let failed_count = mismatches.len();
203 ValidationResult {
204 passed: false,
205 errors: mismatches,
206 statistics: Some(ValidationStatistics {
207 max_abs_error,
208 max_rel_error,
209 total_comparisons: a.len(),
210 failed_comparisons: failed_count,
211 }),
212 }
213 }
214 }
215}
216
217pub mod edge_cases {
219 use super::*;
220
221 pub fn get_special_f32_values() -> Vec<f32> {
223 vec![
224 0.0,
225 -0.0,
226 1.0,
227 -1.0,
228 f32::INFINITY,
229 f32::NEG_INFINITY,
230 f32::NAN,
231 f32::MIN,
232 f32::MAX,
233 f32::MIN_POSITIVE,
234 f32::EPSILON,
235 core::f32::consts::PI,
236 core::f32::consts::E,
237 1e-30,
238 1e30,
239 -1e-30,
240 -1e30,
241 ]
242 }
243
244 pub fn get_special_f64_values() -> Vec<f64> {
246 vec![
247 0.0,
248 -0.0,
249 1.0,
250 -1.0,
251 f64::INFINITY,
252 f64::NEG_INFINITY,
253 f64::NAN,
254 f64::MIN,
255 f64::MAX,
256 f64::MIN_POSITIVE,
257 f64::EPSILON,
258 core::f64::consts::PI,
259 core::f64::consts::E,
260 1e-100,
261 1e100,
262 -1e-100,
263 -1e100,
264 ]
265 }
266
267 pub fn test_unary_f32<F>(
269 func: F,
270 reference_func: F,
271 tolerance: precision::Tolerance,
272 ) -> ValidationResult
273 where
274 F: Fn(f32) -> f32,
275 {
276 let test_values = get_special_f32_values();
277 let mut errors = Vec::new();
278
279 for &val in &test_values {
280 let result = func(val);
281 let expected = reference_func(val);
282
283 if !are_equal_with_nan_handling_f32(result, expected, tolerance) {
284 errors.push(ValidationError {
285 index: None,
286 expected: expected as f64,
287 actual: result as f64,
288 abs_error: (result - expected).abs() as f64,
289 rel_error: if expected != 0.0 {
290 ((result - expected) / expected).abs() as f64
291 } else {
292 (result - expected).abs() as f64
293 },
294 description: format!("Edge case failure for input: {}", val),
295 });
296 }
297 }
298
299 if errors.is_empty() {
300 ValidationResult::success()
301 } else {
302 ValidationResult {
303 passed: false,
304 errors,
305 statistics: None,
306 }
307 }
308 }
309
310 pub fn test_binary_f32<F>(
312 func: F,
313 reference_func: F,
314 tolerance: precision::Tolerance,
315 ) -> ValidationResult
316 where
317 F: Fn(f32, f32) -> f32,
318 {
319 let test_values = get_special_f32_values();
320 let mut errors = Vec::new();
321
322 for &a in &test_values {
323 for &b in &test_values {
324 let result = func(a, b);
325 let expected = reference_func(a, b);
326
327 if !are_equal_with_nan_handling_f32(result, expected, tolerance) {
328 errors.push(ValidationError {
329 index: None,
330 expected: expected as f64,
331 actual: result as f64,
332 abs_error: (result - expected).abs() as f64,
333 rel_error: if expected != 0.0 {
334 ((result - expected) / expected).abs() as f64
335 } else {
336 (result - expected).abs() as f64
337 },
338 description: format!("Edge case failure for inputs: {}, {}", a, b),
339 });
340
341 if errors.len() >= 20 {
342 break;
343 }
344 }
345 }
346 if errors.len() >= 20 {
347 break;
348 }
349 }
350
351 if errors.is_empty() {
352 ValidationResult::success()
353 } else {
354 ValidationResult {
355 passed: false,
356 errors,
357 statistics: None,
358 }
359 }
360 }
361
362 fn are_equal_with_nan_handling_f32(a: f32, b: f32, tolerance: precision::Tolerance) -> bool {
363 if a.is_nan() && b.is_nan() {
364 true
365 } else if a.is_infinite() && b.is_infinite() {
366 a.signum() == b.signum()
367 } else {
368 precision::compare_f32(a, b, tolerance)
369 }
370 }
371}
372
373pub mod correctness {
375 use super::*;
376
377 pub fn verify_against_scalar<F1, F2, T, R>(
379 simd_func: F1,
380 scalar_func: F2,
381 test_data: &[T],
382 _tolerance: precision::Tolerance,
383 operation_name: &str,
384 ) -> ValidationResult
385 where
386 F1: Fn(&[T]) -> R,
387 F2: Fn(&[T]) -> R,
388 R: PartialEq + core::fmt::Debug + Clone,
389 {
390 let simd_result = simd_func(test_data);
391 let scalar_result = scalar_func(test_data);
392
393 if simd_result == scalar_result {
394 ValidationResult::success()
395 } else {
396 ValidationResult::error(&format!(
397 "SIMD result {:?} does not match scalar result {:?} for operation: {}",
398 simd_result, scalar_result, operation_name
399 ))
400 }
401 }
402
403 pub fn verify_f32_slice_operation<F1, F2>(
405 simd_func: F1,
406 scalar_func: F2,
407 test_data: &[f32],
408 tolerance: precision::Tolerance,
409 operation_name: &str,
410 ) -> ValidationResult
411 where
412 F1: Fn(&[f32]) -> Vec<f32>,
413 F2: Fn(&[f32]) -> Vec<f32>,
414 {
415 let simd_result = simd_func(test_data);
416 let scalar_result = scalar_func(test_data);
417
418 let mut validation_result =
419 precision::compare_f32_slice(&simd_result, &scalar_result, tolerance);
420
421 if !validation_result.passed {
422 for error in &mut validation_result.errors {
423 error.description = format!("{}: {}", operation_name, error.description);
424 }
425 }
426
427 validation_result
428 }
429
430 pub fn verify_f64_slice_operation<F1, F2>(
432 simd_func: F1,
433 scalar_func: F2,
434 test_data: &[f64],
435 tolerance: precision::Tolerance,
436 operation_name: &str,
437 ) -> ValidationResult
438 where
439 F1: Fn(&[f64]) -> Vec<f64>,
440 F2: Fn(&[f64]) -> Vec<f64>,
441 {
442 let simd_result = simd_func(test_data);
443 let scalar_result = scalar_func(test_data);
444
445 let mut validation_result =
446 precision::compare_f64_slice(&simd_result, &scalar_result, tolerance);
447
448 if !validation_result.passed {
449 for error in &mut validation_result.errors {
450 error.description = format!("{}: {}", operation_name, error.description);
451 }
452 }
453
454 validation_result
455 }
456
457 pub fn generate_test_datasets_f32() -> Vec<Vec<f32>> {
459 vec![
460 vec![],
462 vec![1.0],
464 vec![1.0, 2.0, 3.0],
466 vec![-1.0, 0.0, 1.0],
467 (0..4).map(|i| i as f32).collect(),
469 (0..8).map(|i| i as f32).collect(),
470 (0..16).map(|i| i as f32).collect(),
471 (0..32).map(|i| i as f32).collect(),
472 (0..7).map(|i| i as f32).collect(),
474 (0..15).map(|i| i as f32).collect(),
475 (0..31).map(|i| i as f32).collect(),
476 (0..1000).map(|i| (i as f32) * 0.1).collect(),
478 vec![
480 0.1, -2.3, 4.7, -0.9, 8.2, -3.1, 5.6, -7.4, 1.8, -6.5, 9.3, -4.7, 2.1, -8.9, 3.4,
481 -1.2,
482 ],
483 vec![1e10, -1e10, 1e20, -1e20],
485 vec![1e-10, -1e-10, 1e-20, -1e-20],
487 vec![1e-10, 1.0, 1e10, -1e-10, -1.0, -1e10],
489 ]
490 }
491
492 pub fn generate_test_datasets_f64() -> Vec<Vec<f64>> {
494 vec![
495 vec![],
497 vec![1.0],
499 vec![1.0, 2.0, 3.0],
501 vec![-1.0, 0.0, 1.0],
502 (0..4).map(|i| i as f64).collect(),
504 (0..8).map(|i| i as f64).collect(),
505 (0..16).map(|i| i as f64).collect(),
506 (0..1000).map(|i| (i as f64) * 0.1).collect(),
508 vec![
510 core::f64::consts::PI,
511 core::f64::consts::E,
512 core::f64::consts::SQRT_2,
513 core::f64::consts::LN_2,
514 ],
515 vec![f64::MIN, f64::MAX, f64::MIN_POSITIVE],
517 ]
518 }
519}
520
521pub mod performance {
523 use super::*;
524
525 #[derive(Debug, Clone)]
527 pub struct PerformanceResult {
528 pub operation_name: String,
529 pub duration_ns: u64,
530 pub throughput_ops_per_sec: f64,
531 pub data_size: usize,
532 }
533
534 pub fn benchmark_function<F, T, R>(
536 func: F,
537 data: &[T],
538 operation_name: &str,
539 iterations: usize,
540 ) -> PerformanceResult
541 where
542 F: Fn(&[T]) -> R,
543 T: Clone,
544 {
545 let start = Instant::now();
546
547 for _ in 0..iterations {
548 let _ = func(data);
549 }
550
551 let duration = start.elapsed();
552 let duration_ns = duration.as_nanos() as u64;
553 let avg_duration_ns = duration_ns / iterations as u64;
554 let throughput = if avg_duration_ns > 0 {
555 1_000_000_000.0 / (avg_duration_ns as f64)
556 } else {
557 f64::INFINITY
558 };
559
560 PerformanceResult {
561 operation_name: operation_name.to_string(),
562 duration_ns: avg_duration_ns,
563 throughput_ops_per_sec: throughput,
564 data_size: data.len(),
565 }
566 }
567
568 pub fn compare_simd_vs_scalar<F1, F2, T, R>(
570 simd_func: F1,
571 scalar_func: F2,
572 data: &[T],
573 operation_name: &str,
574 iterations: usize,
575 ) -> PerformanceComparison
576 where
577 F1: Fn(&[T]) -> R,
578 F2: Fn(&[T]) -> R,
579 T: Clone,
580 {
581 let simd_result = benchmark_function(
582 simd_func,
583 data,
584 &format!("{operation_name}_simd"),
585 iterations,
586 );
587
588 let scalar_result = benchmark_function(
589 scalar_func,
590 data,
591 &format!("{operation_name}_scalar"),
592 iterations,
593 );
594
595 let speedup = if scalar_result.duration_ns > 0 {
596 scalar_result.duration_ns as f64 / simd_result.duration_ns as f64
597 } else {
598 1.0
599 };
600
601 PerformanceComparison {
602 operation_name: operation_name.to_string(),
603 simd_result,
604 scalar_result,
605 speedup,
606 }
607 }
608
609 pub fn check_performance_regression(
611 current: &PerformanceResult,
612 baseline: &PerformanceResult,
613 max_regression_percent: f64,
614 ) -> ValidationResult {
615 if baseline.duration_ns == 0 {
616 return ValidationResult::error("Baseline duration is zero");
617 }
618
619 let regression_ratio = current.duration_ns as f64 / baseline.duration_ns as f64;
620 let regression_percent = (regression_ratio - 1.0) * 100.0;
621
622 if regression_percent > max_regression_percent {
623 ValidationResult::error(&format!(
624 "Performance regression detected: {regression_percent:.2}% slower than baseline (max allowed: {max_regression_percent:.2}%)"
625 ))
626 } else {
627 ValidationResult::success()
628 }
629 }
630
631 #[derive(Debug, Clone)]
632 pub struct PerformanceComparison {
633 pub operation_name: String,
634 pub simd_result: PerformanceResult,
635 pub scalar_result: PerformanceResult,
636 pub speedup: f64,
637 }
638}
639
/// One failed check recorded in a [`ValidationResult`].
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Element index for slice comparisons. `None` when the failure is not tied
    /// to a position (edge-case checks, message-only errors).
    pub index: Option<usize>,
    /// Reference value the check compared against (`0.0` for message-only errors).
    pub expected: f64,
    /// Value produced by the code under test (`0.0` for message-only errors).
    pub actual: f64,
    /// Absolute difference between `actual` and `expected`.
    pub abs_error: f64,
    /// Relative difference. Producers in this file fall back to the absolute
    /// difference when `expected` is zero.
    pub rel_error: f64,
    /// Human-readable explanation of the failure.
    pub description: String,
}
650
/// Aggregate error statistics attached to a failed slice comparison.
#[derive(Debug, Clone)]
pub struct ValidationStatistics {
    /// Largest absolute error among the failing elements that were examined.
    pub max_abs_error: f64,
    /// Largest relative error among the failing elements that were examined.
    pub max_rel_error: f64,
    /// Number of elements compared (the slice length).
    pub total_comparisons: usize,
    /// Count of failed comparisons as reported by the producing function.
    pub failed_comparisons: usize,
}
658
/// Outcome of a validation check: a pass/fail flag, the individual failures,
/// and optional aggregate statistics.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `true` when the check passed.
    pub passed: bool,
    /// Recorded failures; empty for a result created by `success()`.
    pub errors: Vec<ValidationError>,
    /// Aggregate statistics. In this file, only failing slice comparisons set this.
    pub statistics: Option<ValidationStatistics>,
}
665
666impl ValidationResult {
667 pub fn success() -> Self {
668 Self {
669 passed: true,
670 errors: Vec::new(),
671 statistics: None,
672 }
673 }
674
675 pub fn error(message: &str) -> Self {
676 Self {
677 passed: false,
678 errors: vec![ValidationError {
679 index: None,
680 expected: 0.0,
681 actual: 0.0,
682 abs_error: 0.0,
683 rel_error: 0.0,
684 description: message.to_string(),
685 }],
686 statistics: None,
687 }
688 }
689
690 pub fn combine(mut self, other: ValidationResult) -> Self {
691 self.passed = self.passed && other.passed;
692 self.errors.extend(other.errors);
693 self
694 }
695}
696
/// Named collection of validation and performance results.
///
/// With the `no-std` feature, `HashMap` is an alias for `BTreeMap` (see the
/// imports), so iteration is ordered by name there; with `std` the iteration
/// order is unspecified.
pub struct ValidationSuite {
    /// Validation outcomes keyed by test name.
    pub results: HashMap<String, ValidationResult>,
    /// Benchmark measurements keyed by name.
    pub performance_results: HashMap<String, performance::PerformanceResult>,
}
702
703impl Default for ValidationSuite {
704 fn default() -> Self {
705 Self::new()
706 }
707}
708
709impl ValidationSuite {
710 pub fn new() -> Self {
711 Self {
712 results: HashMap::new(),
713 performance_results: HashMap::new(),
714 }
715 }
716
717 pub fn add_result(&mut self, name: String, result: ValidationResult) {
718 self.results.insert(name, result);
719 }
720
721 pub fn add_performance_result(&mut self, name: String, result: performance::PerformanceResult) {
722 self.performance_results.insert(name, result);
723 }
724
725 pub fn all_passed(&self) -> bool {
726 self.results.values().all(|r| r.passed)
727 }
728
729 pub fn print_summary(&self) {
730 #[cfg(not(feature = "no-std"))]
731 {
732 let total_tests = self.results.len();
733 let passed_tests = self.results.values().filter(|r| r.passed).count();
734
735 println!("Validation Summary:");
736 println!(" Total tests: {total_tests}");
737 println!(" Passed: {passed_tests}");
738 println!(" Failed: {}", total_tests - passed_tests);
739
740 for (name, result) in &self.results {
741 if !result.passed {
742 println!(" FAILED: {name}");
743 for error in &result.errors {
744 println!(" {}", error.description);
745 }
746 }
747 }
748
749 if !self.performance_results.is_empty() {
750 println!("\nPerformance Results:");
751 for (name, perf) in &self.performance_results {
752 println!(
753 " {}: {:.2} ns/op ({:.2e} ops/sec)",
754 name, perf.duration_ns, perf.throughput_ops_per_sec
755 );
756 }
757 }
758 }
759 }
760}
761
/// Unit tests; compiled only for `std` test builds.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    #[test]
    fn test_precision_comparison() {
        use precision::{compare_f32, compare_f64, Tolerance};

        assert!(compare_f32(1.0, 1.0, Tolerance::STRICT));
        // `1.0 + 1e-12` rounds to exactly 1.0 in f32, so this only checks equality.
        assert!(compare_f32(1.0, 1.0 + 1e-12, Tolerance::NORMAL));
        assert!(!compare_f32(1.0, 1.1, Tolerance::STRICT));

        // One ULP above 1.0 (~1.19e-7) is outside STRICT but inside VERY_RELAXED.
        assert!(!compare_f32(1.0, 1.0 + f32::EPSILON, Tolerance::STRICT));
        assert!(compare_f32(1.0, 1.0 + f32::EPSILON, Tolerance::VERY_RELAXED));

        // In f64 the 1e-12 offset survives and is accepted via the relative bound.
        assert!(compare_f64(1.0, 1.0 + 1e-12, Tolerance::NORMAL));
        assert!(!compare_f64(1.0, 1.0 + 1e-12, Tolerance::STRICT));
    }

    #[test]
    fn test_edge_cases() {
        let special_values = edge_cases::get_special_f32_values();
        assert!(special_values.iter().any(|x| x.is_nan()));
        assert!(special_values.contains(&f32::INFINITY));
        assert!(special_values.contains(&0.0));
    }

    #[test]
    fn test_slice_comparison() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let result = precision::compare_f32_slice(&a, &b, precision::Tolerance::NORMAL);
        assert!(result.passed);

        let c = vec![1.0, 2.1, 3.0];
        let result2 = precision::compare_f32_slice(&a, &c, precision::Tolerance::STRICT);
        assert!(!result2.passed);
    }

    #[test]
    fn test_validation_suite() {
        let mut suite = ValidationSuite::new();
        suite.add_result("test1".to_string(), ValidationResult::success());
        suite.add_result("test2".to_string(), ValidationResult::error("Test error"));

        assert!(!suite.all_passed());
        assert_eq!(suite.results.len(), 2);
    }

    #[test]
    fn test_performance_measurement() {
        let data = vec![1.0f32; 1000];
        let result = performance::benchmark_function(
            |slice| slice.iter().sum::<f32>(),
            &data,
            "sum_test",
            100,
        );

        assert_eq!(result.operation_name, "sum_test");
        assert!(result.duration_ns > 0);
        assert!(result.throughput_ops_per_sec > 0.0);
    }

    #[test]
    fn test_test_data_generation() {
        let datasets = correctness::generate_test_datasets_f32();
        assert!(!datasets.is_empty());
        assert!(datasets.iter().any(|d| d.is_empty()));
        assert!(datasets.iter().any(|d| d.len() == 1));
        assert!(datasets.iter().any(|d| d.len() > 100));
    }
}