Skip to main content

math_audio_bem/testing/
mod.rs

1//! Testing and validation infrastructure
2//!
3//! Tools for comparing BEM results with analytical solutions,
4//! computing error metrics, and exporting to JSON for visualization.
5
6use math_audio_wave::analytical::AnalyticalSolution;
7use num_complex::Complex64;
8use serde::{Deserialize, Serialize};
9use std::path::Path;
10
11pub mod json_output;
12pub mod validation;
13
// Glob re-exports of both submodules. The `unused_imports` allowances silence
// warnings in builds that do not use these re-exports; currently they are only
// used when the testing feature is active.
16#[allow(unused_imports)]
17pub use json_output::*;
18#[allow(unused_imports)]
19pub use validation::*;
20
21/// Comparison between BEM and analytical solution
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ValidationResult {
24    /// Test name
25    pub test_name: String,
26
27    /// Dimensionality (1, 2, or 3)
28    pub dimensions: usize,
29
30    /// Test parameters
31    pub parameters: TestParameters,
32
33    /// Analytical solution data
34    pub analytical: SolutionData,
35
36    /// BEM solution data
37    pub bem: SolutionData,
38
39    /// Error metrics
40    pub errors: ErrorMetrics,
41
42    /// Execution metadata
43    pub metadata: ExecutionMetadata,
44}
45
46/// Test parameters
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct TestParameters {
49    /// Wave number k
50    pub wave_number: f64,
51
52    /// Frequency (Hz)
53    pub frequency: f64,
54
55    /// Wavelength (m)
56    pub wavelength: f64,
57
58    /// Characteristic dimension (radius, length, etc.)
59    pub characteristic_dimension: f64,
60
61    /// Dimensionless parameter (ka, kL, etc.)
62    pub dimensionless_param: f64,
63
64    /// Number of elements in BEM mesh
65    pub num_elements: Option<usize>,
66
67    /// Elements per wavelength
68    pub elements_per_wavelength: Option<f64>,
69
70    /// Additional custom parameters
71    #[serde(flatten)]
72    pub custom: serde_json::Value,
73}
74
75/// Solution data (positions and pressure)
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct SolutionData {
78    /// Evaluation positions [[x, y, z], ...]
79    pub positions: Vec<[f64; 3]>,
80
81    /// Real part of pressure
82    pub pressure_real: Vec<f64>,
83
84    /// Imaginary part of pressure
85    pub pressure_imag: Vec<f64>,
86
87    /// Magnitude |p|
88    pub magnitude: Vec<f64>,
89
90    /// Phase arg(p) in radians
91    pub phase: Vec<f64>,
92}
93
94/// Error metrics
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct ErrorMetrics {
97    /// Relative L2 error: ||p_bem - p_analytical||₂ / ||p_analytical||₂
98    pub l2_relative: f64,
99
100    /// Absolute L2 error: ||p_bem - p_analytical||₂
101    pub l2_absolute: f64,
102
103    /// L∞ error: max|p_bem - p_analytical|
104    pub linf: f64,
105
106    /// Mean absolute error
107    pub mean_absolute: f64,
108
109    /// RMS error
110    pub rms: f64,
111
112    /// Maximum relative error at any point
113    pub max_relative: f64,
114
115    /// Correlation coefficient
116    pub correlation: f64,
117
118    /// Pointwise errors (for plotting)
119    pub pointwise_errors: Vec<f64>,
120}
121
122/// Execution metadata
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct ExecutionMetadata {
125    /// Timestamp (ISO 8601)
126    pub timestamp: String,
127
128    /// Git commit hash
129    pub git_commit: String,
130
131    /// Execution time (milliseconds)
132    pub execution_time_ms: u64,
133
134    /// Peak memory usage (MB)
135    pub memory_peak_mb: f64,
136
137    /// Rust version
138    pub rust_version: String,
139
140    /// Library version
141    pub bem_version: String,
142}
143
144impl ValidationResult {
145    /// Create from analytical and BEM solutions
146    pub fn new(
147        test_name: impl Into<String>,
148        analytical: &AnalyticalSolution,
149        bem_pressure: Vec<Complex64>,
150        execution_time_ms: u64,
151        memory_peak_mb: f64,
152    ) -> Self {
153        // Ensure same number of points
154        assert_eq!(
155            analytical.positions.len(),
156            bem_pressure.len(),
157            "Analytical and BEM must have same number of points"
158        );
159
160        // Compute errors
161        let errors = ErrorMetrics::compute(&analytical.pressure, &bem_pressure);
162
163        // Convert positions to arrays
164        let positions: Vec<[f64; 3]> = analytical
165            .positions
166            .iter()
167            .map(|p| [p.x, p.y, p.z])
168            .collect();
169
170        // Analytical data
171        let analytical_data = SolutionData {
172            positions: positions.clone(),
173            pressure_real: analytical.real(),
174            pressure_imag: analytical.imag(),
175            magnitude: analytical.magnitude(),
176            phase: analytical.phase(),
177        };
178
179        // BEM data
180        let bem_data = SolutionData {
181            positions,
182            pressure_real: bem_pressure.iter().map(|p| p.re).collect(),
183            pressure_imag: bem_pressure.iter().map(|p| p.im).collect(),
184            magnitude: bem_pressure.iter().map(|p| p.norm()).collect(),
185            phase: bem_pressure.iter().map(|p| p.arg()).collect(),
186        };
187
188        // Parameters
189        let wavelength = 2.0 * std::f64::consts::PI / analytical.wave_number;
190        let characteristic_dimension = analytical
191            .metadata
192            .get("radius")
193            .and_then(|v| v.as_f64())
194            .unwrap_or(1.0);
195
196        let parameters = TestParameters {
197            wave_number: analytical.wave_number,
198            frequency: analytical.frequency,
199            wavelength,
200            characteristic_dimension,
201            dimensionless_param: analytical.wave_number * characteristic_dimension,
202            num_elements: None, // Fill from BEM metadata
203            elements_per_wavelength: None,
204            custom: analytical.metadata.clone(),
205        };
206
207        // Metadata
208        let metadata = ExecutionMetadata {
209            timestamp: chrono::Utc::now().to_rfc3339(),
210            git_commit: env!("GIT_HASH").to_string(),
211            execution_time_ms,
212            memory_peak_mb,
213            rust_version: env!("CARGO_PKG_RUST_VERSION").to_string(),
214            bem_version: crate::VERSION.to_string(),
215        };
216
217        Self {
218            test_name: test_name.into(),
219            dimensions: analytical.dimensions,
220            parameters,
221            analytical: analytical_data,
222            bem: bem_data,
223            errors,
224            metadata,
225        }
226    }
227
228    /// Save to JSON file
229    pub fn save_json(&self, path: impl AsRef<Path>) -> anyhow::Result<()> {
230        let json = serde_json::to_string_pretty(self)?;
231        std::fs::write(path, json)?;
232        Ok(())
233    }
234
235    /// Load from JSON file
236    pub fn load_json(path: impl AsRef<Path>) -> anyhow::Result<Self> {
237        let json = std::fs::read_to_string(path)?;
238        let result = serde_json::from_str(&json)?;
239        Ok(result)
240    }
241
242    /// Print summary to stdout
243    pub fn print_summary(&self) {
244        println!("╔══════════════════════════════════════════════════════╗");
245        println!("║  BEM Validation: {}  ║", self.test_name);
246        println!("╠══════════════════════════════════════════════════════╣");
247        println!(
248            "║  Dimensions: {}D                                      ║",
249            self.dimensions
250        );
251        println!(
252            "║  Wave number k: {:.4}                              ║",
253            self.parameters.wave_number
254        );
255        println!(
256            "║  Frequency: {:.2} Hz                               ║",
257            self.parameters.frequency
258        );
259        println!(
260            "║  ka: {:.4}                                         ║",
261            self.parameters.dimensionless_param
262        );
263        println!("╠══════════════════════════════════════════════════════╣");
264        println!("║  Error Metrics:                                      ║");
265        println!(
266            "║    L2 (relative): {:.6}                          ║",
267            self.errors.l2_relative
268        );
269        println!(
270            "║    L∞:            {:.6}                          ║",
271            self.errors.linf
272        );
273        println!(
274            "║    Mean abs:      {:.6}                          ║",
275            self.errors.mean_absolute
276        );
277        println!(
278            "║    RMS:           {:.6}                          ║",
279            self.errors.rms
280        );
281        println!(
282            "║    Max relative:  {:.6}                          ║",
283            self.errors.max_relative
284        );
285        println!(
286            "║    Correlation:   {:.6}                          ║",
287            self.errors.correlation
288        );
289        println!("╠══════════════════════════════════════════════════════╣");
290        println!(
291            "║  Execution time: {} ms                            ║",
292            self.metadata.execution_time_ms
293        );
294        println!(
295            "║  Memory peak: {:.2} MB                             ║",
296            self.metadata.memory_peak_mb
297        );
298        println!("╚══════════════════════════════════════════════════════╝");
299    }
300
301    /// Check if test passed (based on error threshold)
302    pub fn passed(&self, l2_threshold: f64) -> bool {
303        self.errors.l2_relative < l2_threshold
304    }
305}
306
307impl ErrorMetrics {
308    /// Compute all error metrics
309    pub fn compute(analytical: &[Complex64], bem: &[Complex64]) -> Self {
310        assert_eq!(analytical.len(), bem.len());
311
312        let n = analytical.len() as f64;
313
314        // Pointwise errors
315        let pointwise_errors: Vec<f64> = analytical
316            .iter()
317            .zip(bem.iter())
318            .map(|(a, b)| (a - b).norm())
319            .collect();
320
321        // L2 absolute error
322        let l2_absolute = pointwise_errors.iter().map(|e| e * e).sum::<f64>().sqrt();
323
324        // L2 relative error
325        let analytical_norm = analytical.iter().map(|a| a.norm_sqr()).sum::<f64>().sqrt();
326        let l2_relative = if analytical_norm > 1e-15 {
327            l2_absolute / analytical_norm
328        } else {
329            l2_absolute
330        };
331
332        // L∞ error
333        let linf = pointwise_errors.iter().cloned().fold(0.0_f64, f64::max);
334
335        // Mean absolute error
336        let mean_absolute = pointwise_errors.iter().sum::<f64>() / n;
337
338        // RMS error
339        let rms = (pointwise_errors.iter().map(|e| e * e).sum::<f64>() / n).sqrt();
340
341        // Maximum relative error
342        let max_relative = analytical
343            .iter()
344            .zip(bem.iter())
345            .map(|(a, b)| {
346                let a_norm = a.norm();
347                if a_norm > 1e-15 {
348                    (a - b).norm() / a_norm
349                } else {
350                    (a - b).norm()
351                }
352            })
353            .fold(0.0_f64, f64::max);
354
355        // Correlation coefficient
356        let correlation = compute_correlation(analytical, bem);
357
358        Self {
359            l2_relative,
360            l2_absolute,
361            linf,
362            mean_absolute,
363            rms,
364            max_relative,
365            correlation,
366            pointwise_errors,
367        }
368    }
369}
370
371/// Compute correlation coefficient between two complex signals
372fn compute_correlation(a: &[Complex64], b: &[Complex64]) -> f64 {
373    let n = a.len() as f64;
374
375    // Use magnitude for correlation
376    let a_mag: Vec<f64> = a.iter().map(|x| x.norm()).collect();
377    let b_mag: Vec<f64> = b.iter().map(|x| x.norm()).collect();
378
379    let a_mean = a_mag.iter().sum::<f64>() / n;
380    let b_mean = b_mag.iter().sum::<f64>() / n;
381
382    let numerator: f64 = a_mag
383        .iter()
384        .zip(b_mag.iter())
385        .map(|(a_i, b_i)| (a_i - a_mean) * (b_i - b_mean))
386        .sum();
387
388    let a_var: f64 = a_mag.iter().map(|a_i| (a_i - a_mean).powi(2)).sum();
389    let b_var: f64 = b_mag.iter().map(|b_i| (b_i - b_mean).powi(2)).sum();
390
391    if a_var > 1e-15 && b_var > 1e-15 {
392        numerator / (a_var * b_var).sqrt()
393    } else {
394        0.0
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison with a readable failure message.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() <= tol,
            "expected {} ± {}, got {}",
            expected,
            tol,
            actual
        );
    }

    #[test]
    fn test_error_metrics_perfect_match() {
        let data = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.5, 0.5),
            Complex64::new(0.0, 1.0),
        ];

        let errors = ErrorMetrics::compute(&data, &data);

        assert!(errors.l2_relative < 1e-10);
        assert!(errors.l2_absolute < 1e-10);
        assert!(errors.linf < 1e-10);
        assert!(errors.mean_absolute < 1e-10);
        assert!(errors.rms < 1e-10);
        assert!(errors.max_relative < 1e-10);
        assert_eq!(errors.pointwise_errors.len(), data.len());
        assert!(errors.pointwise_errors.iter().all(|&e| e == 0.0));
        // Identical, non-constant magnitude series are perfectly correlated.
        assert_close(errors.correlation, 1.0, 1e-12);
    }

    #[test]
    fn test_error_metrics_nonzero() {
        let analytical = vec![Complex64::new(1.0, 0.0), Complex64::new(0.5, 0.5)];
        let bem = vec![Complex64::new(1.01, 0.01), Complex64::new(0.51, 0.51)];

        let errors = ErrorMetrics::compute(&analytical, &bem);

        // Both points are off by (0.01 + 0.01i), i.e. |e| = 0.01·√2.
        let point_err = 0.01 * std::f64::consts::SQRT_2;
        assert_eq!(errors.pointwise_errors.len(), 2);
        for &e in &errors.pointwise_errors {
            assert_close(e, point_err, 1e-12);
        }
        assert_close(errors.l2_absolute, 0.02, 1e-12);
        // ‖p_analytical‖₂ = sqrt(1 + 0.25 + 0.25) = sqrt(1.5)
        assert_close(errors.l2_relative, 0.02 / 1.5_f64.sqrt(), 1e-12);
        assert_close(errors.linf, point_err, 1e-12);
        assert_close(errors.mean_absolute, point_err, 1e-12);
        assert_close(errors.rms, point_err, 1e-12);
        // Worst relative error is at the second point, where |p| = √0.5.
        assert_close(errors.max_relative, 0.02, 1e-12);
        // Two points whose magnitudes both decrease: correlation is +1.
        assert_close(errors.correlation, 1.0, 1e-9);
    }

    #[test]
    #[should_panic]
    fn test_error_metrics_length_mismatch_panics() {
        let analytical = vec![Complex64::new(1.0, 0.0)];
        let bem: Vec<Complex64> = Vec::new();
        let _ = ErrorMetrics::compute(&analytical, &bem);
    }

    #[test]
    fn test_correlation_zero_for_constant_magnitudes() {
        // Every sample has magnitude exactly 1, so the variance is zero.
        let a = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)];
        let b = vec![Complex64::new(0.0, -1.0), Complex64::new(-1.0, 0.0)];
        assert_eq!(compute_correlation(&a, &b), 0.0);
    }
}