use crate::error::SpecialResult;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::process::Command;
#[derive(Debug, Clone, Copy)]
pub enum ReferenceSource {
SciPy,
GSL,
Mathematica,
MPFR,
Boost,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestCase {
pub function: String,
pub inputs: Vec<f64>,
pub expected: f64,
pub source: String,
pub tolerance: f64,
}
#[derive(Debug, Clone)]
pub struct ValidationResult {
pub test_case: TestCase,
pub computed: f64,
pub error: f64,
pub relative_error: f64,
pub ulp_error: i64,
pub passed: bool,
}
#[derive(Debug)]
pub struct ValidationSummary {
pub function: String,
pub total_tests: usize,
pub passed: usize,
pub failed: usize,
pub max_error: f64,
pub mean_error: f64,
pub max_ulp_error: i64,
pub failed_cases: Vec<ValidationResult>,
}
pub struct CrossValidator {
test_cases: HashMap<String, Vec<TestCase>>,
results: HashMap<String, Vec<ValidationResult>>,
}
impl Default for CrossValidator {
fn default() -> Self {
Self::new()
}
}
impl CrossValidator {
pub fn new() -> Self {
Self {
test_cases: HashMap::new(),
results: HashMap::new(),
}
}
pub fn load_test_cases(&mut self) -> SpecialResult<()> {
self.load_scipy_references()?;
self.load_gsl_references()?;
self.load_mpfr_references()?;
Ok(())
}
fn load_scipy_references(&mut self) -> SpecialResult<()> {
let gamma_tests = vec![
TestCase {
function: "gamma".to_string(),
inputs: vec![0.5],
expected: 1.7724538509055159, source: "SciPy".to_string(),
tolerance: 1e-15,
},
TestCase {
function: "gamma".to_string(),
inputs: vec![5.0],
expected: 24.0,
source: "SciPy".to_string(),
tolerance: 1e-15,
},
TestCase {
function: "gamma".to_string(),
inputs: vec![10.5],
expected: 1133278.3889487855,
source: "SciPy".to_string(),
tolerance: 1e-10,
},
];
self.test_cases.insert("gamma".to_string(), gamma_tests);
let bessel_tests = vec![
TestCase {
function: "j0".to_string(),
inputs: vec![1.0],
expected: 0.7651976865579666,
source: "SciPy".to_string(),
tolerance: 1e-15,
},
TestCase {
function: "j0".to_string(),
inputs: vec![10.0],
expected: -0.245_935_764_451_348_3,
source: "SciPy".to_string(),
tolerance: 1e-15,
},
];
self.test_cases
.insert("bessel_j0".to_string(), bessel_tests);
Ok(())
}
fn load_gsl_references(&mut self) -> SpecialResult<()> {
let erf_tests = vec![
TestCase {
function: "erf".to_string(),
inputs: vec![1.0],
expected: 0.8427007929497149,
source: "GSL".to_string(),
tolerance: 1e-15,
},
TestCase {
function: "erf".to_string(),
inputs: vec![2.0],
expected: 0.9953222650189527,
source: "GSL".to_string(),
tolerance: 1e-15,
},
];
self.test_cases
.entry("erf".to_string())
.or_default()
.extend(erf_tests);
Ok(())
}
fn load_mpfr_references(&mut self) -> SpecialResult<()> {
let edge_cases = vec![
TestCase {
function: "gamma".to_string(),
inputs: vec![1e-10],
expected: 9999999999.422784,
source: "MPFR".to_string(),
tolerance: 1e-6,
},
TestCase {
function: "gamma".to_string(),
inputs: vec![170.5],
expected: 4.269_068_009_016_085_7e304,
source: "MPFR".to_string(),
tolerance: 1e-10,
},
];
self.test_cases
.entry("gamma".to_string())
.or_default()
.extend(edge_cases);
Ok(())
}
pub fn validate_function<F>(&mut self, name: &str, func: F) -> ValidationSummary
where
F: Fn(&[f64]) -> f64,
{
let test_cases = self.test_cases.get(name).cloned().unwrap_or_default();
let mut results = Vec::new();
let mut errors = Vec::new();
let mut ulp_errors = Vec::new();
for test in test_cases {
let computed = func(&test.inputs);
let error = (computed - test.expected).abs();
let relative_error = if test.expected != 0.0 {
error / test.expected.abs()
} else {
error
};
let ulp_error = compute_ulp_error(computed, test.expected);
let passed = relative_error <= test.tolerance;
let result = ValidationResult {
test_case: test.clone(),
computed,
error,
relative_error,
ulp_error,
passed,
};
if !passed {
results.push(result.clone());
}
errors.push(error);
ulp_errors.push(ulp_error);
}
let total = errors.len();
let passed = errors.iter().filter(|&&e| e <= 1e-10).count();
ValidationSummary {
function: name.to_string(),
total_tests: total,
passed,
failed: total - passed,
max_error: errors.iter().cloned().fold(f64::NEG_INFINITY, f64::max),
mean_error: errors.iter().sum::<f64>() / total as f64,
max_ulp_error: ulp_errors.iter().cloned().max().unwrap_or(0),
failed_cases: results,
}
}
pub fn generate_report(&self) -> String {
let mut report = String::from("# Cross-Validation Report\n\n");
for (function, results) in &self.results {
report.push_str(&format!("## {function}\n\n"));
let total: usize = results.len();
let passed = results.iter().filter(|r| r.passed).count();
let failed = total - passed;
report.push_str(&format!("- Total tests: {total}\n"));
report.push_str(&format!(
"- Passed: {passed} ({:.1}%)\n",
100.0 * passed as f64 / total as f64
));
report.push_str(&format!(
"- Failed: {failed} ({:.1}%)\n",
100.0 * failed as f64 / total as f64
));
if failed > 0 {
report.push_str("\n### Failed Cases\n\n");
report.push_str(
"| Inputs | Expected | Computed | Rel Error | ULP Error | Source |\n",
);
report.push_str(
"|--------|----------|----------|-----------|-----------|--------|\n",
);
for result in results.iter().filter(|r| !r.passed).take(10) {
report.push_str(&format!(
"| {inputs:?} | {expected:.6e} | {computed:.6e} | {rel_error:.2e} | {ulp_error} | {source} |\n",
inputs = result.test_case.inputs,
expected = result.test_case.expected,
computed = result.computed,
rel_error = result.relative_error,
ulp_error = result.ulp_error,
source = result.test_case.source,
));
}
if failed > 10 {
let more_failed = failed - 10;
report.push_str(&format!("\n... and {more_failed} more failed cases\n"));
}
}
report.push('\n');
}
report
}
}
#[allow(dead_code)]
fn compute_ulp_error(a: f64, b: f64) -> i64 {
if a == b {
return 0;
}
let a_bits = a.to_bits();
let b_bits = b.to_bits();
if a_bits >= b_bits {
(a_bits - b_bits) as i64
} else {
(b_bits - a_bits) as i64
}
}
pub struct PythonValidator {
python_path: String,
}
impl Default for PythonValidator {
fn default() -> Self {
Self::new()
}
}
impl PythonValidator {
pub fn new() -> Self {
Self {
python_path: "python3".to_string(),
}
}
pub fn compute_reference(&self, function: &str, args: &[f64]) -> SpecialResult<f64> {
let args_str = args
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(", ");
let script = format!(
r#"
import scipy.special as sp
import sys
result = sp.{function}({args_str})
print(result)
"#
);
let output = Command::new(&self.python_path)
.arg("-c")
.arg(&script)
.output()
.map_err(|e| crate::error::SpecialError::ComputationError(e.to_string()))?;
if !output.status.success() {
return Err(crate::error::SpecialError::ComputationError(
String::from_utf8_lossy(&output.stderr).to_string(),
));
}
let result_str = String::from_utf8_lossy(&output.stdout);
result_str
.trim()
.parse::<f64>()
.map_err(|e| crate::error::SpecialError::ComputationError(e.to_string()))
}
}
#[allow(dead_code)]
pub fn generate_test_suite() -> SpecialResult<()> {
let mut validator = CrossValidator::new();
validator.load_test_cases()?;
let mut test_code = String::from("// Auto-generated cross-validation tests\n\n");
test_code.push_str("#[cfg(test)]\nmod cross_validation_tests {\n");
test_code.push_str(" use super::*;\n");
test_code.push_str(" use approx::assert_relative_eq;\n\n");
for (function, cases) in validator.test_cases {
for (i, case) in cases.iter().enumerate() {
let source_lower = case.source.to_lowercase();
let input_str = case.inputs[0]
.to_string()
.replace('.', "_")
.replace('-', "neg");
let args_str = case
.inputs
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(", ");
test_code.push_str(&format!(
r#"
#[test]
fn test_{function}_{source_lower}_{i}_{input_str}() {{
let result = {function}({args_str});
assert_relative_eq!(result, {expected}, epsilon = {tolerance});
}}
"#,
expected = case.expected,
tolerance = case.tolerance,
));
}
}
test_code.push_str("}\n");
std::fs::write("src/generated_cross_validation_tests.rs", test_code)
.map_err(|e| crate::error::SpecialError::ComputationError(e.to_string()))?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::gamma;
#[test]
fn test_cross_validator() {
let mut validator = CrossValidator::new();
validator.load_test_cases().expect("Operation failed");
let summary = validator.validate_function("gamma", |args| gamma(args[0]));
assert!(summary.total_tests > 0);
assert!(summary.passed > 0);
}
#[test]
fn test_ulp_error() {
assert_eq!(compute_ulp_error(1.0, 1.0), 0);
assert!(compute_ulp_error(1.0, 1.0 + f64::EPSILON) <= 2);
}
}