Skip to main content

proof_engine/solver/
conservation.rs

1//! Conservation law verification — energy, momentum, mass conservation checks.
2
3use super::ode::OdeState;
4
5/// A conservation law to verify during simulation.
6pub trait ConservationLaw: Send + Sync {
7    fn name(&self) -> &str;
8    fn evaluate(&self, state: &OdeState) -> f64;
9}
10
/// Outcome of comparing one conservation law's current value against its reference value.
#[derive(Clone, Debug)]
pub struct ConservationCheck {
    /// Name reported by the law.
    pub law_name: String,
    /// Reference value captured when the monitor was initialized.
    pub initial_value: f64,
    /// Value of the law at the state being checked.
    pub current_value: f64,
    /// `|current_value - initial_value|`.
    pub absolute_error: f64,
    /// `absolute_error / |initial_value|`, or `absolute_error` itself when the reference
    /// value is (nearly) zero.
    pub relative_error: f64,
    /// Whether the law was judged broken, based on `relative_error` and the monitor's
    /// tolerance.
    pub violated: bool,
}
21
22/// Monitor for tracking conservation laws.
23pub struct ConservationMonitor {
24    laws: Vec<Box<dyn ConservationLaw>>,
25    initial_values: Vec<f64>,
26    tolerance: f64,
27}
28
29impl ConservationMonitor {
30    pub fn new(tolerance: f64) -> Self {
31        Self { laws: Vec::new(), initial_values: Vec::new(), tolerance }
32    }
33
34    pub fn add_law(&mut self, law: Box<dyn ConservationLaw>) {
35        self.laws.push(law);
36        self.initial_values.push(f64::NAN);
37    }
38
39    /// Initialize with the starting state.
40    pub fn initialize(&mut self, state: &OdeState) {
41        for (i, law) in self.laws.iter().enumerate() {
42            self.initial_values[i] = law.evaluate(state);
43        }
44    }
45
46    /// Check all conservation laws at the current state.
47    pub fn check(&self, state: &OdeState) -> Vec<ConservationCheck> {
48        self.laws.iter().zip(self.initial_values.iter()).map(|(law, &initial)| {
49            let current = law.evaluate(state);
50            let abs_err = (current - initial).abs();
51            let rel_err = if initial.abs() > 1e-15 { abs_err / initial.abs() } else { abs_err };
52            ConservationCheck {
53                law_name: law.name().to_string(),
54                initial_value: initial,
55                current_value: current,
56                absolute_error: abs_err,
57                relative_error: rel_err,
58                violated: rel_err > self.tolerance,
59            }
60        }).collect()
61    }
62}
63
64/// Built-in: total energy for harmonic oscillator (E = 0.5*x² + 0.5*v²).
65pub struct HarmonicEnergy;
66impl ConservationLaw for HarmonicEnergy {
67    fn name(&self) -> &str { "Harmonic Energy" }
68    fn evaluate(&self, state: &OdeState) -> f64 {
69        if state.y.len() >= 2 { 0.5 * state.y[0].powi(2) + 0.5 * state.y[1].powi(2) } else { 0.0 }
70    }
71}
72
73/// Built-in: L2 norm (total "mass").
74pub struct L2Norm;
75impl ConservationLaw for L2Norm {
76    fn name(&self) -> &str { "L2 Norm" }
77    fn evaluate(&self, state: &OdeState) -> f64 {
78        state.y.iter().map(|v| v * v).sum::<f64>().sqrt()
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    // `OdeMethod` was imported here but never used, which triggered an `unused_imports` warning.
    use crate::solver::ode::{HarmonicOscillator, OdeSolver, OdeState};

    /// RK4 (built with `0.01`, presumably the step size) should keep the energy of a
    /// unit-frequency oscillator within 1% relative error after `solve(.., 10.0)`.
    #[test]
    fn conservation_monitor_detects_good_solver() {
        let sys = HarmonicOscillator { omega: 1.0 };
        let initial = OdeState { t: 0.0, y: vec![1.0, 0.0] };

        let mut monitor = ConservationMonitor::new(0.01);
        monitor.add_law(Box::new(HarmonicEnergy));
        monitor.initialize(&initial);

        let mut solver = OdeSolver::rk4(0.01);
        let final_state = solver.solve(&sys, &initial, 10.0);
        let checks = monitor.check(&final_state);

        // One check per registered law, in registration order.
        assert_eq!(checks.len(), 1);
        assert!(!checks[0].violated, "RK4 should conserve energy well: err={}", checks[0].relative_error);
    }
}