use std::fmt;
use std::time::Instant;
/// Result of comparing a single computed value against its reference value.
#[derive(Debug, Clone)]
pub struct ValidationOutcome {
    /// `true` when the absolute tolerance or the relative tolerance was met.
    pub passed: bool,
    /// Input associated with this comparison (stored for reporting only).
    pub input: f64,
    /// Reference value the computed value is compared against.
    pub expected: f64,
    /// Value under test.
    pub computed: f64,
    /// `|computed - expected|`, or `0.0` when the two values compare equal.
    pub abs_error: f64,
    /// `abs_error / |expected|`, or `None` when `|expected| <= f64::EPSILON`.
    pub rel_error: Option<f64>,
}

impl ValidationOutcome {
    /// Compares `computed` against `expected` and records the errors.
    ///
    /// The outcome passes when `abs_error <= tol_abs`, or when a relative
    /// error is defined and `rel_error <= tol_rel`. The relative error is only
    /// defined when `|expected| > f64::EPSILON`.
    ///
    /// A NaN on either side never passes, because every comparison against
    /// NaN is false.
    pub fn new(input: f64, expected: f64, computed: f64, tol_rel: f64, tol_abs: f64) -> Self {
        // `inf - inf` is NaN, so two identical infinities would otherwise be
        // reported as a failure with a NaN error. Values that compare equal
        // are an exact match; for finite values this branch yields the same
        // 0.0 the subtraction would have produced.
        let abs_error = if computed == expected {
            0.0
        } else {
            (computed - expected).abs()
        };

        let reference_magnitude = expected.abs();
        let rel_error = if reference_magnitude > f64::EPSILON {
            Some(abs_error / reference_magnitude)
        } else {
            None
        };

        let passed = abs_error <= tol_abs || rel_error.is_some_and(|r| r <= tol_rel);

        Self {
            passed,
            input,
            expected,
            computed,
            abs_error,
            rel_error,
        }
    }
}

impl fmt::Display for ValidationOutcome {
    /// Writes a single-line summary ending in `PASS` or `FAIL`; an undefined
    /// relative error is rendered as `N/A`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rel = match self.rel_error {
            Some(r) => format!("{r:.3e}"),
            None => "N/A".to_string(),
        };
        let verdict = if self.passed { "PASS" } else { "FAIL" };
        write!(
            f,
            "x={:.6e} | expected={:.6e} | computed={:.6e} | abs_err={:.3e} | rel_err={} | {}",
            self.input, self.expected, self.computed, self.abs_error, rel, verdict
        )
    }
}
/// Aggregated result of validating one function over a set of inputs.
#[derive(Debug, Clone)]
pub struct NumericalValidationReport {
    /// Name shown in the report heading and in the one-line summary.
    pub function_name: String,
    /// Number of comparisons performed.
    pub total_checks: usize,
    /// Number of comparisons that passed.
    pub passed_checks: usize,
    /// Outcomes of the comparisons that did not pass.
    pub failures: Vec<ValidationOutcome>,
    /// Elapsed time in nanoseconds, rendered verbatim in the report.
    pub timing_ns: u64,
    /// Largest absolute error recorded for the run.
    pub max_abs_error: f64,
    /// Largest relative error recorded for the run.
    pub max_rel_error: f64,
}

impl NumericalValidationReport {
    /// Fraction of checks that passed, in `[0, 1]` for consistent counters.
    ///
    /// A report without any checks counts as fully passing (`1.0`).
    pub fn pass_rate(&self) -> f64 {
        match self.total_checks {
            0 => 1.0,
            total => self.passed_checks as f64 / total as f64,
        }
    }

    /// `true` when no failure has been recorded.
    ///
    /// Only `failures` is consulted; the counters are not cross-checked.
    pub fn all_passed(&self) -> bool {
        self.failures.is_empty()
    }

    /// Renders the report as Markdown: a metrics table, followed by a table
    /// of failing comparisons when there are any.
    pub fn to_markdown(&self) -> String {
        let status = if self.all_passed() {
            "✓ PASS"
        } else {
            "✗ FAIL"
        };

        let summary = [
            format!("## Validation Report: `{}`\n\n", self.function_name),
            String::from("| Metric | Value |\n"),
            String::from("|--------|-------|\n"),
            format!("| Total checks | {} |\n", self.total_checks),
            format!("| Passed checks | {} |\n", self.passed_checks),
            format!("| Pass rate | {:.1}% |\n", self.pass_rate() * 100.0),
            format!("| Max absolute error | {:.3e} |\n", self.max_abs_error),
            format!("| Max relative error | {:.3e} |\n", self.max_rel_error),
            format!("| Timing | {} ns |\n", self.timing_ns),
            format!("| Status | {} |\n\n", status),
        ];
        let mut md: String = summary.concat();

        // Nothing more to print when every check passed.
        if self.failures.is_empty() {
            return md;
        }

        md.push_str("### Failures\n\n");
        md.push_str("| Input | Expected | Computed | Abs Error | Rel Error |\n");
        md.push_str("|-------|----------|----------|-----------|-----------|\n");
        md.extend(self.failures.iter().map(|failure| {
            let rel = match failure.rel_error {
                Some(r) => format!("{r:.3e}"),
                None => "N/A".to_string(),
            };
            format!(
                "| {:.6e} | {:.6e} | {:.6e} | {:.3e} | {} |\n",
                failure.input, failure.expected, failure.computed, failure.abs_error, rel
            )
        }));
        md
    }
}

impl fmt::Display for NumericalValidationReport {
    /// Writes a one-line summary: name, pass counts, pass rate and maximum errors.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let percent = self.pass_rate() * 100.0;
        write!(
            f,
            "{}: {}/{} passed ({:.1}%), max_abs={:.3e}, max_rel={:.3e}",
            self.function_name,
            self.passed_checks,
            self.total_checks,
            percent,
            self.max_abs_error,
            self.max_rel_error,
        )
    }
}
/// Settings for a numerical validation run.
#[derive(Debug, Clone)]
pub struct NumericalValidationConfig {
    /// Relative tolerance.
    pub relative_tolerance: f64,
    /// Absolute tolerance.
    pub absolute_tolerance: f64,
    /// Number of test points.
    // NOTE(review): not read by any code visible in this file — presumably
    // consumed by an input generator elsewhere; confirm.
    pub num_test_points: usize,
    /// Seed value.
    // NOTE(review): not read by any code visible in this file — presumably
    // feeds a random input generator elsewhere; confirm.
    pub seed: u64,
}

impl Default for NumericalValidationConfig {
    /// Default settings: relative tolerance `1e-6`, absolute tolerance
    /// `1e-10`, 100 test points, seed 42.
    fn default() -> Self {
        const RELATIVE_TOLERANCE: f64 = 1e-6;
        const ABSOLUTE_TOLERANCE: f64 = 1e-10;
        const NUM_TEST_POINTS: usize = 100;
        const SEED: u64 = 42;

        Self {
            relative_tolerance: RELATIVE_TOLERANCE,
            absolute_tolerance: ABSOLUTE_TOLERANCE,
            num_test_points: NUM_TEST_POINTS,
            seed: SEED,
        }
    }
}
/// Common interface for validators that compare a function under test
/// against a reference function.
///
/// Implementors must be `Send + Sync`.
pub trait NumericalValidator: Send + Sync {
/// Returns the validator's name.
fn name(&self) -> &str;
/// Judges a single `computed` value against `reference` for the given
/// `input`, returning the resulting [`ValidationOutcome`].
///
/// How `config` is applied is up to the implementor.
fn validate(
&self,
input: f64,
computed: f64,
reference: f64,
config: &NumericalValidationConfig,
) -> ValidationOutcome;
/// Validates `computed_fn` against `reference_fn` over `inputs` and
/// returns the aggregated [`NumericalValidationReport`].
fn run_validation(
&self,
computed_fn: &dyn Fn(f64) -> f64,
reference_fn: &dyn Fn(f64) -> f64,
inputs: &[f64],
config: &NumericalValidationConfig,
) -> NumericalValidationReport;
}
/// Validator that compares a function point by point against a reference
/// function, using the tolerances from [`NumericalValidationConfig`].
#[derive(Debug, Clone)]
pub struct ComparisonValidator {
    /// Name copied into every report this validator produces.
    name: String,
}

impl ComparisonValidator {
    /// Creates a validator labelled `name`.
    pub fn new(name: impl Into<String>) -> Self {
        Self { name: name.into() }
    }
}

impl NumericalValidator for ComparisonValidator {
    /// Returns the name given at construction.
    fn name(&self) -> &str {
        &self.name
    }

    /// Compares `computed` against `reference` using the relative and
    /// absolute tolerances from `config`.
    fn validate(
        &self,
        input: f64,
        computed: f64,
        reference: f64,
        config: &NumericalValidationConfig,
    ) -> ValidationOutcome {
        ValidationOutcome::new(
            input,
            reference,
            computed,
            config.relative_tolerance,
            config.absolute_tolerance,
        )
    }

    /// Evaluates both functions at every input, validates each pair and
    /// aggregates the outcomes into a report.
    ///
    /// `timing_ns` covers evaluating the two functions and validating each
    /// pair, but not the aggregation that follows. Failures keep the order of
    /// `inputs`. Both maxima start at `0.0`; the relative maximum only
    /// considers outcomes whose relative error is defined.
    fn run_validation(
        &self,
        computed_fn: &dyn Fn(f64) -> f64,
        reference_fn: &dyn Fn(f64) -> f64,
        inputs: &[f64],
        config: &NumericalValidationConfig,
    ) -> NumericalValidationReport {
        let start = Instant::now();
        let outcomes: Vec<ValidationOutcome> = inputs
            .iter()
            .map(|&x| {
                let computed = computed_fn(x);
                let reference = reference_fn(x);
                self.validate(x, computed, reference, config)
            })
            .collect();
        // `as_nanos` yields a u128; clamp before narrowing so an out-of-range
        // duration saturates at `u64::MAX` instead of silently wrapping.
        let timing_ns = start.elapsed().as_nanos().min(u128::from(u64::MAX)) as u64;

        let total_checks = outcomes.len();
        let passed_checks = outcomes.iter().filter(|o| o.passed).count();
        let max_abs_error = outcomes
            .iter()
            .map(|o| o.abs_error)
            .fold(0.0_f64, f64::max);
        let max_rel_error = outcomes
            .iter()
            .filter_map(|o| o.rel_error)
            .fold(0.0_f64, f64::max);

        // The statistics above are done with `outcomes`, so the failing
        // entries can be moved out rather than cloned.
        let failures: Vec<ValidationOutcome> =
            outcomes.into_iter().filter(|o| !o.passed).collect();

        NumericalValidationReport {
            function_name: self.name.clone(),
            total_checks,
            passed_checks,
            failures,
            timing_ns,
            max_abs_error,
            max_rel_error,
        }
    }
}
/// Checks that a function is monotone over an ordered list of sample inputs.
#[derive(Debug, Clone)]
pub struct MonotonicityChecker {
    /// `true` to require non-decreasing values, `false` for non-increasing.
    pub increasing: bool,
    /// Slack allowed in the "wrong" direction between consecutive values.
    pub tolerance: f64,
}

impl MonotonicityChecker {
    /// Checker for non-decreasing functions, allowing dips of up to `tolerance`.
    pub fn new_increasing(tolerance: f64) -> Self {
        Self {
            increasing: true,
            tolerance,
        }
    }

    /// Checker for non-increasing functions, allowing rises of up to `tolerance`.
    pub fn new_decreasing(tolerance: f64) -> Self {
        Self {
            increasing: false,
            tolerance,
        }
    }

    /// Returns `true` when `f` is monotone (within tolerance) across `inputs`,
    /// taken in the order given.
    ///
    /// Fewer than two inputs always pass, and `f` is not evaluated in that case.
    pub fn check(&self, f: &dyn Fn(f64) -> f64, inputs: &[f64]) -> bool {
        self.check_with_first_violation(f, inputs).0
    }

    /// Like [`check`](Self::check), but also reports where monotonicity first
    /// breaks.
    ///
    /// On failure the second element is `Some(i)`, meaning the step from
    /// `inputs[i]` to `inputs[i + 1]` is the first violating one.
    pub fn check_with_first_violation(
        &self,
        f: &dyn Fn(f64) -> f64,
        inputs: &[f64],
    ) -> (bool, Option<usize>) {
        if inputs.len() < 2 {
            return (true, None);
        }
        // Evaluate every point up front so `f` is called once per input,
        // regardless of where the first violation sits.
        let values: Vec<f64> = inputs.iter().map(|&x| f(x)).collect();
        let first_violation = values
            .windows(2)
            .position(|pair| self.violates(pair[0], pair[1]));
        (first_violation.is_none(), first_violation)
    }

    /// Decides whether the step `prev -> next` breaks the required direction.
    fn violates(&self, prev: f64, next: f64) -> bool {
        // Comparisons against NaN are always false, so a NaN value would
        // otherwise slip through as "monotone". A NaN is never ordered.
        if prev.is_nan() || next.is_nan() {
            return true;
        }
        if self.increasing {
            next < prev - self.tolerance
        } else {
            next > prev + self.tolerance
        }
    }
}
/// Checks function values at individual points against expected values.
#[derive(Debug, Clone)]
pub struct BoundaryChecker {
    /// Label for this checker; none of the checks below read it.
    pub name: String,
}

impl BoundaryChecker {
    /// Creates a checker labelled `name`.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self { name }
    }

    /// Checks `f` at both ends of an interval.
    ///
    /// Returns `(lower_ok, upper_ok)`, where each flag tells whether `f` at
    /// that end lies within `tolerance` of its expected value. `f(lower)` is
    /// evaluated before `f(upper)`.
    pub fn check_bounds(
        &self,
        f: &dyn Fn(f64) -> f64,
        lower: f64,
        upper: f64,
        expected_lower: f64,
        expected_upper: f64,
        tolerance: f64,
    ) -> (bool, bool) {
        let lower_ok = self.check_point(f, lower, expected_lower, tolerance);
        let upper_ok = self.check_point(f, upper, expected_upper, tolerance);
        (lower_ok, upper_ok)
    }

    /// Returns `true` when `|f(point) - expected| <= tolerance`.
    pub fn check_point(
        &self,
        f: &dyn Fn(f64) -> f64,
        point: f64,
        expected: f64,
        tolerance: f64,
    ) -> bool {
        let deviation = (f(point) - expected).abs();
        deviation <= tolerance
    }

    /// Counts how many `(x, expected)` pairs in `points` satisfy
    /// `|f(x) - expected| <= tolerance`.
    pub fn check_multiple(
        &self,
        f: &dyn Fn(f64) -> f64,
        points: &[(f64, f64)],
        tolerance: f64,
    ) -> usize {
        let mut within_tolerance = 0;
        for &(x, expected) in points {
            if self.check_point(f, x, expected, tolerance) {
                within_tolerance += 1;
            }
        }
        within_tolerance
    }
}
/// Checks whether a function is even (`f(-x) == f(x)`) or odd
/// (`f(-x) == -f(x)`) at a set of sample points.
#[derive(Debug, Clone)]
pub struct SymmetryChecker {
    /// `true` to test for even symmetry, `false` for odd symmetry.
    pub even: bool,
    /// Largest allowed `|f(-x) - expected|`.
    pub tolerance: f64,
}

impl SymmetryChecker {
    /// Checker for even symmetry.
    pub fn new_even(tolerance: f64) -> Self {
        Self {
            even: true,
            tolerance,
        }
    }

    /// Checker for odd symmetry.
    pub fn new_odd(tolerance: f64) -> Self {
        Self {
            even: false,
            tolerance,
        }
    }

    /// Returns `true` when the symmetry holds (within tolerance) at every
    /// input. An empty input list passes.
    pub fn check(&self, f: &dyn Fn(f64) -> f64, inputs: &[f64]) -> bool {
        self.check_with_first_violation(f, inputs).0
    }

    /// Like [`check`](Self::check), but also returns the index into `inputs`
    /// of the first point where the symmetry breaks.
    ///
    /// Evaluation stops at the first violation; for each tested `x`, `f(x)`
    /// is evaluated before `f(-x)`.
    pub fn check_with_first_violation(
        &self,
        f: &dyn Fn(f64) -> f64,
        inputs: &[f64],
    ) -> (bool, Option<usize>) {
        let first_violation = inputs.iter().position(|&x| self.violates(f(x), f(-x)));
        (first_violation.is_none(), first_violation)
    }

    /// Decides whether the pair `f(x)`, `f(-x)` breaks the required symmetry.
    fn violates(&self, fx: f64, fnx: f64) -> bool {
        // `NaN > tolerance` is false, so a NaN function value would otherwise
        // be accepted as symmetric. A NaN result is never a valid match.
        if fx.is_nan() || fnx.is_nan() {
            return true;
        }
        let expected = if self.even { fx } else { -fx };
        (fnx - expected).abs() > self.tolerance
    }
}
// Unit tests for the validation helpers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
use std::f64::consts::PI;
// ---- ValidationOutcome ----
// An error of about 1e-11 is below the 1e-10 absolute tolerance.
#[test]
fn test_outcome_pass_within_abs_tolerance() {
let outcome = ValidationOutcome::new(1.0, 2.0, 2.0 + 1e-11, 1e-6, 1e-10);
assert!(outcome.passed);
assert!(outcome.abs_error < 1e-10);
}
// The absolute error (about 5e-4) exceeds the absolute tolerance, but the
// relative error (about 5e-7) is within the 1e-6 relative tolerance.
#[test]
fn test_outcome_pass_within_rel_tolerance() {
let outcome = ValidationOutcome::new(1.0, 1000.0, 1000.0 * (1.0 + 5e-7), 1e-6, 1e-10);
assert!(outcome.passed);
}
// Both tolerances are exceeded: absolute and relative error are 1.0.
#[test]
fn test_outcome_fail() {
let outcome = ValidationOutcome::new(1.0, 1.0, 2.0, 1e-6, 1e-10);
assert!(!outcome.passed);
assert!((outcome.abs_error - 1.0).abs() < 1e-12);
}
// With an expected value of 0.0 the relative error is left undefined.
#[test]
fn test_outcome_rel_error_undefined_near_zero() {
let outcome = ValidationOutcome::new(0.0, 0.0, 1e-15, 1e-6, 1e-10);
assert!(outcome.rel_error.is_none());
}
// The Display output of a passing outcome contains "PASS".
#[test]
fn test_outcome_display() {
let outcome = ValidationOutcome::new(0.5, 0.5, 0.5, 1e-6, 1e-10);
let s = format!("{outcome}");
assert!(s.contains("PASS"));
}
// ---- NumericalValidationReport ----
// A report with zero checks has a pass rate of 1.0.
#[test]
fn test_report_pass_rate_empty() {
let report = NumericalValidationReport {
function_name: "empty".to_string(),
total_checks: 0,
passed_checks: 0,
failures: vec![],
timing_ns: 0,
max_abs_error: 0.0,
max_rel_error: 0.0,
};
assert!((report.pass_rate() - 1.0).abs() < 1e-12);
}
// 3 of 4 checks passed gives 0.75; one recorded failure means not all passed.
#[test]
fn test_report_pass_rate_partial() {
let report = NumericalValidationReport {
function_name: "partial".to_string(),
total_checks: 4,
passed_checks: 3,
failures: vec![ValidationOutcome::new(0.0, 1.0, 2.0, 1e-6, 1e-10)],
timing_ns: 1000,
max_abs_error: 1.0,
max_rel_error: 1.0,
};
assert!((report.pass_rate() - 0.75).abs() < 1e-12);
assert!(!report.all_passed());
}
// The Markdown for a clean report has the heading, a PASS status and 100.0%.
#[test]
fn test_report_markdown_format() {
let report = NumericalValidationReport {
function_name: "cosine".to_string(),
total_checks: 10,
passed_checks: 10,
failures: vec![],
timing_ns: 12345,
max_abs_error: 1e-12,
max_rel_error: 1e-11,
};
let md = report.to_markdown();
assert!(md.contains("## Validation Report: `cosine`"));
assert!(md.contains("PASS"));
assert!(md.contains("100.0%"));
}
// With a recorded failure the Markdown gains a "### Failures" section and
// the status row reads FAIL.
#[test]
fn test_report_markdown_includes_failure_table() {
let report = NumericalValidationReport {
function_name: "broken_fn".to_string(),
total_checks: 2,
passed_checks: 1,
failures: vec![ValidationOutcome::new(
std::f64::consts::PI,
0.0,
1.0,
1e-6,
1e-10,
)],
timing_ns: 500,
max_abs_error: 1.0,
max_rel_error: 0.0,
};
let md = report.to_markdown();
assert!(md.contains("### Failures"));
assert!(md.contains("FAIL"));
}
// ---- ComparisonValidator ----
// A function compared against itself passes at every input.
#[test]
fn test_comparison_validator_identity() {
let v = ComparisonValidator::new("identity");
let config = NumericalValidationConfig::default();
let inputs: Vec<f64> = (0..=20).map(|i| i as f64 * 0.1 - 1.0).collect();
let report = v.run_validation(&|x| x, &|x| x, &inputs, &config);
assert!(report.all_passed(), "Identity should pass: {report}");
assert_eq!(report.total_checks, inputs.len());
}
#[test]
fn test_comparison_validator_sin_reference() {
let v = ComparisonValidator::new("sin_vs_itself");
let config = NumericalValidationConfig::default();
let inputs: Vec<f64> = (0..50).map(|i| i as f64 * PI / 50.0).collect();
let report = v.run_validation(&|x| x.sin(), &|x| x.sin(), &inputs, &config);
assert!(report.all_passed());
}
// An offset of 1.0 exceeds both 1e-3 tolerances at all three inputs.
#[test]
fn test_comparison_validator_detects_errors() {
let v = ComparisonValidator::new("wrong_fn");
let config = NumericalValidationConfig {
relative_tolerance: 1e-3,
absolute_tolerance: 1e-3,
..Default::default()
};
let inputs = vec![1.0, 2.0, 3.0];
let report = v.run_validation(&|x| x + 1.0, &|x| x, &inputs, &config);
assert!(!report.all_passed());
assert_eq!(report.failures.len(), 3);
}
#[test]
fn test_comparison_validator_name() {
let v = ComparisonValidator::new("my_validator");
assert_eq!(v.name(), "my_validator");
}
// ---- MonotonicityChecker ----
#[test]
fn test_monotonicity_increasing_ok() {
let checker = MonotonicityChecker::new_increasing(1e-12);
let inputs: Vec<f64> = (0..=20).map(|i| i as f64 * 0.1).collect();
assert!(checker.check(&|x: f64| x.powi(2), &inputs));
}
#[test]
fn test_monotonicity_decreasing_ok() {
let checker = MonotonicityChecker::new_decreasing(1e-12);
let inputs: Vec<f64> = (0..=20).map(|i| i as f64 * 0.1).collect();
assert!(checker.check(&|x: f64| -x, &inputs));
}
// sin is not increasing over the whole of [0, 2*PI], so a violation is found.
#[test]
fn test_monotonicity_violation_detected() {
let checker = MonotonicityChecker::new_increasing(1e-12);
let inputs: Vec<f64> = (0..=100).map(|i| i as f64 * 2.0 * PI / 100.0).collect();
let (ok, idx) = checker.check_with_first_violation(&|x: f64| x.sin(), &inputs);
assert!(!ok);
assert!(idx.is_some());
}
// Fewer than two inputs leave nothing to compare.
#[test]
fn test_monotonicity_single_point_always_passes() {
let checker = MonotonicityChecker::new_increasing(0.0);
assert!(checker.check(&|x| x, &[1.0]));
}
// The logistic function is increasing on [-5, 5].
#[test]
fn test_monotonicity_cdf_like() {
let checker = MonotonicityChecker::new_increasing(1e-12);
let logistic = |x: f64| 1.0 / (1.0 + (-x).exp());
let inputs: Vec<f64> = (-50..=50).map(|i| i as f64 * 0.1).collect();
assert!(checker.check(&logistic, &inputs));
}
// ---- BoundaryChecker ----
// The logistic function is within 1e-6 of 0 at -100 and of 1 at +100.
#[test]
fn test_boundary_check_bounds_pass() {
let bc = BoundaryChecker::new("logistic_cdf");
let logistic = |x: f64| 1.0 / (1.0 + (-x).exp());
let (lo_ok, hi_ok) = bc.check_bounds(&logistic, -100.0, 100.0, 0.0, 1.0, 1e-6);
assert!(lo_ok);
assert!(hi_ok);
}
#[test]
fn test_boundary_check_point() {
let bc = BoundaryChecker::new("identity");
assert!(bc.check_point(&|x| x, 0.0, 0.0, 1e-12));
assert!(!bc.check_point(&|x| x + 1.0, 0.0, 0.0, 0.5));
}
// All four (x, x*x) pairs match, so the count equals the number of points.
#[test]
fn test_boundary_check_multiple() {
let bc = BoundaryChecker::new("quadratic");
let points = vec![(0.0, 0.0), (1.0, 1.0), (2.0, 4.0), (3.0, 9.0)];
let ok = bc.check_multiple(&|x| x * x, &points, 1e-12);
assert_eq!(ok, 4);
}
// A constant 0.5 matches neither the expected 0.0 nor the expected 1.0.
#[test]
fn test_boundary_fail() {
let bc = BoundaryChecker::new("bad");
let (lo_ok, hi_ok) = bc.check_bounds(&|_| 0.5, -1.0, 1.0, 0.0, 1.0, 1e-6);
assert!(!lo_ok);
assert!(!hi_ok);
}
// ---- SymmetryChecker ----
#[test]
fn test_symmetry_even_cosine() {
let sc = SymmetryChecker::new_even(1e-12);
let inputs: Vec<f64> = (1..=20).map(|i| i as f64 * 0.1).collect();
assert!(sc.check(&|x: f64| x.cos(), &inputs));
}
#[test]
fn test_symmetry_odd_sine() {
let sc = SymmetryChecker::new_odd(1e-12);
let inputs: Vec<f64> = (1..=20).map(|i| i as f64 * 0.1).collect();
assert!(sc.check(&|x: f64| x.sin(), &inputs));
}
// sin is odd, so the even-symmetry check must fail.
#[test]
fn test_symmetry_even_violation() {
let sc = SymmetryChecker::new_even(1e-12);
let inputs = vec![0.5, 1.0];
assert!(!sc.check(&|x: f64| x.sin(), &inputs));
}
// cos is even, so the odd-symmetry check must fail.
#[test]
fn test_symmetry_odd_violation() {
let sc = SymmetryChecker::new_odd(1e-12);
let inputs = vec![0.5, 1.0];
assert!(!sc.check(&|x: f64| x.cos(), &inputs));
}
// The reported index refers to the first violating input (here index 0).
#[test]
fn test_symmetry_with_first_violation() {
let sc = SymmetryChecker::new_even(1e-12);
let (ok, idx) = sc.check_with_first_violation(&|x: f64| x.sin(), &[0.5, 1.0]);
assert!(!ok);
assert_eq!(idx, Some(0));
}
// With no inputs there is nothing to violate, so the check passes.
#[test]
fn test_symmetry_empty_input() {
let sc = SymmetryChecker::new_even(1e-12);
assert!(sc.check(&|x| x, &[]));
}
// ---- NumericalValidationConfig ----
#[test]
fn test_config_default_values() {
let cfg = NumericalValidationConfig::default();
assert_eq!(cfg.relative_tolerance, 1e-6);
assert_eq!(cfg.absolute_tolerance, 1e-10);
assert_eq!(cfg.num_test_points, 100);
assert_eq!(cfg.seed, 42);
}
}