1use crate::error::SpecialResult;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::process::Command;
11
/// Identifies the external reference implementation a set of expected values
/// is attributed to.
///
/// NOTE(review): nothing in this file constructs or matches on this enum;
/// `TestCase::source` stores the origin as a free-form `String` instead.
#[derive(Debug, Clone, Copy)]
pub enum ReferenceSource {
    /// Values attributed to SciPy.
    SciPy,
    /// Values attributed to GSL.
    GSL,
    /// Values attributed to Mathematica.
    Mathematica,
    /// Values attributed to MPFR.
    MPFR,
    /// Values attributed to Boost.
    Boost,
}
21
/// A single reference data point: a function name, the arguments to call it
/// with, and the value a reference implementation reported for them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestCase {
    /// Name of the function under test (for example `"gamma"`).
    pub function: String,
    /// Arguments handed to the function under test as a slice.
    pub inputs: Vec<f64>,
    /// Value the reference implementation produced for `inputs`.
    pub expected: f64,
    /// Free-form label of the reference implementation (for example `"SciPy"`).
    pub source: String,
    /// Largest relative error accepted by `CrossValidator::validate_function`
    /// when deciding whether this case passes.
    pub tolerance: f64,
}
31
/// Outcome of evaluating one [`TestCase`] against an implementation.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// The reference case that was evaluated.
    pub test_case: TestCase,
    /// Value returned by the implementation under test.
    pub computed: f64,
    /// Absolute difference between `computed` and the expected value.
    pub error: f64,
    /// `error` divided by the magnitude of the expected value; equal to
    /// `error` itself when the expected value is zero.
    pub relative_error: f64,
    /// Distance between `computed` and the expected value as reported by
    /// `compute_ulp_error`.
    pub ulp_error: i64,
    /// Whether `relative_error` was within the case's tolerance.
    pub passed: bool,
}
42
/// Aggregate statistics returned by `CrossValidator::validate_function` for
/// one function name.
#[derive(Debug)]
pub struct ValidationSummary {
    /// Name the validation was run under.
    pub function: String,
    /// Number of reference cases that were evaluated.
    pub total_tests: usize,
    /// Number of cases counted as passing.
    pub passed: usize,
    /// `total_tests - passed`.
    pub failed: usize,
    /// Largest absolute error over the evaluated cases.
    pub max_error: f64,
    /// Arithmetic mean of the absolute errors.
    pub mean_error: f64,
    /// Largest ULP distance over the evaluated cases.
    pub max_ulp_error: i64,
    /// Full results for the cases whose `passed` flag is `false`.
    pub failed_cases: Vec<ValidationResult>,
}
55
/// Holds reference test cases and validation results, both keyed by a
/// function name string.
pub struct CrossValidator {
    /// Reference cases looked up by the name given to `validate_function`.
    test_cases: HashMap<String, Vec<TestCase>>,
    /// Per-function validation results; this is what `generate_report` reads.
    results: HashMap<String, Vec<ValidationResult>>,
}
61
62impl Default for CrossValidator {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68impl CrossValidator {
69 pub fn new() -> Self {
70 Self {
71 test_cases: HashMap::new(),
72 results: HashMap::new(),
73 }
74 }
75
76 pub fn load_test_cases(&mut self) -> SpecialResult<()> {
78 self.load_scipy_references()?;
80
81 self.load_gsl_references()?;
83
84 self.load_mpfr_references()?;
86
87 Ok(())
88 }
89
90 fn load_scipy_references(&mut self) -> SpecialResult<()> {
92 let gamma_tests = vec![
96 TestCase {
97 function: "gamma".to_string(),
98 inputs: vec![0.5],
99 expected: 1.7724538509055159, source: "SciPy".to_string(),
101 tolerance: 1e-15,
102 },
103 TestCase {
104 function: "gamma".to_string(),
105 inputs: vec![5.0],
106 expected: 24.0,
107 source: "SciPy".to_string(),
108 tolerance: 1e-15,
109 },
110 TestCase {
111 function: "gamma".to_string(),
112 inputs: vec![10.5],
113 expected: 1133278.3889487855,
114 source: "SciPy".to_string(),
115 tolerance: 1e-10,
116 },
117 ];
118
119 self.test_cases.insert("gamma".to_string(), gamma_tests);
120
121 let bessel_tests = vec![
122 TestCase {
123 function: "j0".to_string(),
124 inputs: vec![1.0],
125 expected: 0.7651976865579666,
126 source: "SciPy".to_string(),
127 tolerance: 1e-15,
128 },
129 TestCase {
130 function: "j0".to_string(),
131 inputs: vec![10.0],
132 expected: -0.245_935_764_451_348_3,
133 source: "SciPy".to_string(),
134 tolerance: 1e-15,
135 },
136 ];
137
138 self.test_cases
139 .insert("bessel_j0".to_string(), bessel_tests);
140
141 Ok(())
142 }
143
144 fn load_gsl_references(&mut self) -> SpecialResult<()> {
146 let erf_tests = vec![
148 TestCase {
149 function: "erf".to_string(),
150 inputs: vec![1.0],
151 expected: 0.8427007929497149,
152 source: "GSL".to_string(),
153 tolerance: 1e-15,
154 },
155 TestCase {
156 function: "erf".to_string(),
157 inputs: vec![2.0],
158 expected: 0.9953222650189527,
159 source: "GSL".to_string(),
160 tolerance: 1e-15,
161 },
162 ];
163
164 self.test_cases
165 .entry("erf".to_string())
166 .or_default()
167 .extend(erf_tests);
168
169 Ok(())
170 }
171
172 fn load_mpfr_references(&mut self) -> SpecialResult<()> {
174 let edge_cases = vec![
176 TestCase {
177 function: "gamma".to_string(),
178 inputs: vec![1e-10],
179 expected: 9999999999.422784,
180 source: "MPFR".to_string(),
181 tolerance: 1e-6,
182 },
183 TestCase {
184 function: "gamma".to_string(),
185 inputs: vec![170.5],
186 expected: 4.269_068_009_016_085_7e304,
187 source: "MPFR".to_string(),
188 tolerance: 1e-10,
189 },
190 ];
191
192 self.test_cases
193 .entry("gamma".to_string())
194 .or_default()
195 .extend(edge_cases);
196
197 Ok(())
198 }
199
200 pub fn validate_function<F>(&mut self, name: &str, func: F) -> ValidationSummary
202 where
203 F: Fn(&[f64]) -> f64,
204 {
205 let test_cases = self.test_cases.get(name).cloned().unwrap_or_default();
206 let mut results = Vec::new();
207 let mut errors = Vec::new();
208 let mut ulp_errors = Vec::new();
209
210 for test in test_cases {
211 let computed = func(&test.inputs);
212 let error = (computed - test.expected).abs();
213 let relative_error = if test.expected != 0.0 {
214 error / test.expected.abs()
215 } else {
216 error
217 };
218
219 let ulp_error = compute_ulp_error(computed, test.expected);
220 let passed = relative_error <= test.tolerance;
221
222 let result = ValidationResult {
223 test_case: test.clone(),
224 computed,
225 error,
226 relative_error,
227 ulp_error,
228 passed,
229 };
230
231 if !passed {
232 results.push(result.clone());
233 }
234
235 errors.push(error);
236 ulp_errors.push(ulp_error);
237 }
238
239 let total = errors.len();
240 let passed = errors.iter().filter(|&&e| e <= 1e-10).count();
241
242 ValidationSummary {
243 function: name.to_string(),
244 total_tests: total,
245 passed,
246 failed: total - passed,
247 max_error: errors.iter().cloned().fold(f64::NEG_INFINITY, f64::max),
248 mean_error: errors.iter().sum::<f64>() / total as f64,
249 max_ulp_error: ulp_errors.iter().cloned().max().unwrap_or(0),
250 failed_cases: results,
251 }
252 }
253
254 pub fn generate_report(&self) -> String {
256 let mut report = String::from("# Cross-Validation Report\n\n");
257
258 for (function, results) in &self.results {
259 report.push_str(&format!("## {function}\n\n"));
260
261 let total: usize = results.len();
263 let passed = results.iter().filter(|r| r.passed).count();
264 let failed = total - passed;
265
266 report.push_str(&format!("- Total tests: {total}\n"));
267 report.push_str(&format!(
268 "- Passed: {passed} ({:.1}%)\n",
269 100.0 * passed as f64 / total as f64
270 ));
271 report.push_str(&format!(
272 "- Failed: {failed} ({:.1}%)\n",
273 100.0 * failed as f64 / total as f64
274 ));
275
276 if failed > 0 {
278 report.push_str("\n### Failed Cases\n\n");
279 report.push_str(
280 "| Inputs | Expected | Computed | Rel Error | ULP Error | Source |\n",
281 );
282 report.push_str(
283 "|--------|----------|----------|-----------|-----------|--------|\n",
284 );
285
286 for result in results.iter().filter(|r| !r.passed).take(10) {
287 report.push_str(&format!(
288 "| {inputs:?} | {expected:.6e} | {computed:.6e} | {rel_error:.2e} | {ulp_error} | {source} |\n",
289 inputs = result.test_case.inputs,
290 expected = result.test_case.expected,
291 computed = result.computed,
292 rel_error = result.relative_error,
293 ulp_error = result.ulp_error,
294 source = result.test_case.source,
295 ));
296 }
297
298 if failed > 10 {
299 let more_failed = failed - 10;
300 report.push_str(&format!("\n... and {more_failed} more failed cases\n"));
301 }
302 }
303
304 report.push('\n');
305 }
306
307 report
308 }
309}
310
/// Returns the distance between `a` and `b` measured in units in the last
/// place, i.e. the number of representable `f64` steps separating them.
///
/// * Equal values (including `0.0` and `-0.0`) are 0 ULPs apart.
/// * If either argument is NaN the distance is reported as `i64::MAX`.
/// * Values of opposite sign are measured through zero, so the smallest
///   positive and smallest negative subnormals are 2 ULPs apart.
/// * Distances that do not fit in an `i64` saturate at `i64::MAX`.
// Kept so the helper does not warn in builds where nothing reachable uses it.
#[allow(dead_code)]
fn compute_ulp_error(a: f64, b: f64) -> i64 {
    /// Maps a non-NaN float to an integer that sorts in the same order as the
    /// float itself; both zeros map to 0.
    fn ordered_bits(x: f64) -> i64 {
        // Reinterpreting the bit pattern is intended: the sign bit of the
        // float becomes the sign bit of the integer.
        let bits = x.to_bits() as i64;
        if bits < 0 {
            // Sign-magnitude -> two's complement: yields `-magnitude`.
            i64::MIN - bits
        } else {
            bits
        }
    }

    if a == b {
        return 0;
    }
    if a.is_nan() || b.is_nan() {
        return i64::MAX;
    }

    // Widen before subtracting: -inf to +inf spans more than `i64::MAX` steps.
    let distance = (i128::from(ordered_bits(a)) - i128::from(ordered_bits(b))).abs();
    distance.min(i128::from(i64::MAX)) as i64
}
328
/// Computes reference values by running a Python interpreter as a subprocess.
pub struct PythonValidator {
    /// Program name or path used to launch the interpreter.
    python_path: String,
}
333
334impl Default for PythonValidator {
335 fn default() -> Self {
336 Self::new()
337 }
338}
339
340impl PythonValidator {
341 pub fn new() -> Self {
342 Self {
343 python_path: "python3".to_string(),
344 }
345 }
346
347 pub fn compute_reference(&self, function: &str, args: &[f64]) -> SpecialResult<f64> {
349 let args_str = args
350 .iter()
351 .map(|x| x.to_string())
352 .collect::<Vec<_>>()
353 .join(", ");
354 let script = format!(
355 r#"
356import scipy.special as sp
357import sys
358
359result = sp.{function}({args_str})
360print(result)
361"#
362 );
363
364 let output = Command::new(&self.python_path)
365 .arg("-c")
366 .arg(&script)
367 .output()
368 .map_err(|e| crate::error::SpecialError::ComputationError(e.to_string()))?;
369
370 if !output.status.success() {
371 return Err(crate::error::SpecialError::ComputationError(
372 String::from_utf8_lossy(&output.stderr).to_string(),
373 ));
374 }
375
376 let result_str = String::from_utf8_lossy(&output.stdout);
377 result_str
378 .trim()
379 .parse::<f64>()
380 .map_err(|e| crate::error::SpecialError::ComputationError(e.to_string()))
381 }
382}
383
384#[allow(dead_code)]
386pub fn generate_test_suite() -> SpecialResult<()> {
387 let mut validator = CrossValidator::new();
388 validator.load_test_cases()?;
389
390 let mut test_code = String::from("// Auto-generated cross-validation tests\n\n");
392 test_code.push_str("#[cfg(test)]\nmod cross_validation_tests {\n");
393 test_code.push_str(" use super::*;\n");
394 test_code.push_str(" use approx::assert_relative_eq;\n\n");
395
396 for (function, cases) in validator.test_cases {
397 for (i, case) in cases.iter().enumerate() {
398 let source_lower = case.source.to_lowercase();
399 let input_str = case.inputs[0]
400 .to_string()
401 .replace('.', "_")
402 .replace('-', "neg");
403 let args_str = case
404 .inputs
405 .iter()
406 .map(|x| x.to_string())
407 .collect::<Vec<_>>()
408 .join(", ");
409 test_code.push_str(&format!(
410 r#"
411 #[test]
412 fn test_{function}_{source_lower}_{i}_{input_str}() {{
413 let result = {function}({args_str});
414 assert_relative_eq!(result, {expected}, epsilon = {tolerance});
415 }}
416"#,
417 expected = case.expected,
418 tolerance = case.tolerance,
419 ));
420 }
421 }
422
423 test_code.push_str("}\n");
424
425 std::fs::write("src/generated_cross_validation_tests.rs", test_code)
426 .map_err(|e| crate::error::SpecialError::ComputationError(e.to_string()))?;
427
428 Ok(())
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gamma;

    /// Loading the built-in tables and validating the crate's `gamma` must
    /// evaluate at least one case and count at least one as passing.
    #[test]
    fn test_cross_validator() {
        let mut cross_validator = CrossValidator::new();
        let loaded = cross_validator.load_test_cases();
        loaded.expect("Operation failed");

        let gamma_summary = cross_validator.validate_function("gamma", |args| gamma(args[0]));

        assert!(gamma_summary.total_tests > 0);
        assert!(gamma_summary.passed > 0);
    }

    /// Identical values are zero ULPs apart; adjacent values are at most two.
    #[test]
    fn test_ulp_error() {
        let same = compute_ulp_error(1.0, 1.0);
        assert_eq!(same, 0);

        let neighbour = compute_ulp_error(1.0, 1.0 + f64::EPSILON);
        assert!(neighbour <= 2);
    }
}