Skip to main content

oxide_conservation/
lib.rs

//! # oxide-conservation
//!
//! Conservation law verification for GPU computations.
//!
//! Checks that energy, mass, and information are conserved across kernel
//! boundaries using ternary verification results:
//! - `+1` → conserved (exact match)
//! - ` 0` → approximate (within epsilon)
//! - `-1` → violated (exceeds threshold)
10
11use std::fmt;
12
/// Three-valued verdict produced by a conservation check.
///
/// The explicit discriminants mirror the `+1 / 0 / -1` encoding from the crate
/// docs, so e.g. `Ternary::Violated as i8 == -1`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Ternary {
    /// `+1`: the quantity is exactly conserved.
    Conserved = 1,
    /// ` 0`: the quantity is conserved to within the configured epsilon.
    Approximate = 0,
    /// `-1`: the deviation exceeds the configured epsilon.
    Violated = -1,
}

impl Ternary {
    /// Decode an `i8` into a verdict.
    ///
    /// `1` becomes [`Ternary::Conserved`] and `0` becomes [`Ternary::Approximate`].
    /// Every other value (not only `-1`) is treated as [`Ternary::Violated`].
    pub fn from_i8(v: i8) -> Self {
        if v == 1 {
            Ternary::Conserved
        } else if v == 0 {
            Ternary::Approximate
        } else {
            Ternary::Violated
        }
    }

    /// `true` for [`Ternary::Conserved`] and [`Ternary::Approximate`],
    /// `false` for [`Ternary::Violated`].
    pub fn is_ok(self) -> bool {
        !matches!(self, Ternary::Violated)
    }
}

impl fmt::Display for Ternary {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed labels; the leading space on " 0" keeps the three labels aligned.
        let label = match *self {
            Ternary::Conserved => "+1 (conserved)",
            Ternary::Approximate => " 0 (approximate)",
            Ternary::Violated => "-1 (violated)",
        };
        f.write_str(label)
    }
}
49
/// A conservation law that a check can be filed under.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConservationLaw {
    /// Conservation of energy: Σ E_in == Σ E_out
    Energy,
    /// Conservation of mass: Σ m_in == Σ m_out
    Mass,
    /// Conservation of information (e.g., entropy bounds, data integrity)
    Information,
    /// Any other law, identified by a caller-chosen label.
    Custom(String),
}

impl fmt::Display for ConservationLaw {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Built-in laws render as a fixed lowercase name; custom laws embed their label.
        let name = match self {
            ConservationLaw::Energy => "energy",
            ConservationLaw::Mass => "mass",
            ConservationLaw::Information => "information",
            ConservationLaw::Custom(label) => return write!(f, "custom({})", label),
        };
        f.write_str(name)
    }
}
73
74/// Result of a single conservation verification.
75#[derive(Debug, Clone)]
76pub struct VerificationResult {
77    /// The law that was checked.
78    pub law: ConservationLaw,
79    /// Quantity before kernel execution.
80    pub before: f64,
81    /// Quantity after kernel execution.
82    pub after: f64,
83    /// Absolute difference `|after - before|`.
84    pub delta: f64,
85    /// Epsilon threshold for approximate conservation.
86    pub epsilon: f64,
87    /// Ternary verdict.
88    pub verdict: Ternary,
89}
90
91impl VerificationResult {
92    /// Build a verification result by comparing before/after values.
93    pub fn check(law: ConservationLaw, before: f64, after: f64, epsilon: f64) -> Self {
94        let delta = (after - before).abs();
95        let verdict = if delta == 0.0 {
96            Ternary::Conserved
97        } else if delta <= epsilon {
98            Ternary::Approximate
99        } else {
100            Ternary::Violated
101        };
102        VerificationResult {
103            law,
104            before,
105            after,
106            delta,
107            epsilon,
108            verdict,
109        }
110    }
111}
112
113impl fmt::Display for VerificationResult {
114    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115        write!(
116            f,
117            "[{}] {} → {} (Δ={}, ε={})",
118            self.verdict, self.before, self.after, self.delta, self.epsilon
119        )
120    }
121}
122
123/// Monitors conservation of tracked quantities across kernel boundaries.
124#[derive(Debug, Clone)]
125pub struct ConservationMonitor {
126    /// Epsilon for approximate conservation checks.
127    epsilon: f64,
128    /// Named quantities: (law, label) → value before kernel.
129    snapshots: Vec<(ConservationLaw, String, f64)>,
130    /// Completed verification results.
131    results: Vec<VerificationResult>,
132}
133
134impl ConservationMonitor {
135    /// Create a new monitor with the given epsilon threshold.
136    pub fn new(epsilon: f64) -> Self {
137        ConservationMonitor {
138            epsilon,
139            snapshots: Vec::new(),
140            results: Vec::new(),
141        }
142    }
143
144    /// Record a quantity *before* kernel execution.
145    pub fn snapshot(&mut self, law: ConservationLaw, label: impl Into<String>, value: f64) {
146        self.snapshots.push((law, label.into(), value));
147    }
148
149    /// Verify a previously-snapshotted quantity *after* kernel execution.
150    ///
151    /// Returns `None` if no matching snapshot exists.
152    pub fn verify(&mut self, law: &ConservationLaw, label: &str, after_value: f64) -> Option<VerificationResult> {
153        let idx = self.snapshots.iter().position(|(l, lb, _)| l == law && lb == label)?;
154        let (law, _, before) = self.snapshots.remove(idx);
155        let result = VerificationResult::check(law, before, after_value, self.epsilon);
156        self.results.push(result.clone());
157        Some(result)
158    }
159
160    /// Verify all remaining snapshots against a closure that provides the "after" value.
161    pub fn verify_all<F>(&mut self, mut after_fn: F) -> Vec<VerificationResult>
162    where
163        F: FnMut(&ConservationLaw, &str) -> f64,
164    {
165        let snapshots = std::mem::take(&mut self.snapshots);
166        let mut batch = Vec::new();
167        for (law, label, before) in snapshots {
168            let after = after_fn(&law, &label);
169            let result = VerificationResult::check(law, before, after, self.epsilon);
170            batch.push(result.clone());
171            self.results.push(result);
172        }
173        batch
174    }
175
176    /// Return all verification results collected so far.
177    pub fn results(&self) -> &[VerificationResult] {
178        &self.results
179    }
180
181    /// Number of unverified snapshots remaining.
182    pub fn pending(&self) -> usize {
183        self.snapshots.len()
184    }
185
186    /// Reset the monitor for a new round of checks.
187    pub fn reset(&mut self) {
188        self.snapshots.clear();
189        self.results.clear();
190    }
191}
192
193/// Tracks cumulative conservation drift across multiple kernel executions.
194#[derive(Debug, Clone)]
195pub struct ConservationBudget {
196    /// Per-law cumulative drift.
197    drifts: Vec<(ConservationLaw, f64)>,
198    /// Threshold beyond which an alert is raised.
199    alert_threshold: f64,
200    /// Statistics.
201    total_verifications: u64,
202    total_violations: u64,
203    drift_history: Vec<f64>,
204}
205
206impl ConservationBudget {
207    /// Create a new budget with the given alert threshold.
208    pub fn new(alert_threshold: f64) -> Self {
209        ConservationBudget {
210            drifts: Vec::new(),
211            alert_threshold,
212            total_verifications: 0,
213            total_violations: 0,
214            drift_history: Vec::new(),
215        }
216    }
217
218    /// Record a verification result, accumulating drift.
219    pub fn record(&mut self, result: &VerificationResult) {
220        self.total_verifications += 1;
221        if result.verdict == Ternary::Violated {
222            self.total_violations += 1;
223        }
224        self.drift_history.push(result.delta);
225        if let Some(entry) = self.drifts.iter_mut().find(|(l, _)| l == &result.law) {
226            entry.1 += result.delta;
227        } else {
228            self.drifts.push((result.law.clone(), result.delta));
229        }
230    }
231
232    /// Record multiple results at once.
233    pub fn record_all(&mut self, results: &[VerificationResult]) {
234        for r in results {
235            self.record(r);
236        }
237    }
238
239    /// Get the cumulative drift for a specific law.
240    pub fn drift_for(&self, law: &ConservationLaw) -> f64 {
241        self.drifts
242            .iter()
243            .find(|(l, _)| l == law)
244            .map(|&(_, d)| d)
245            .unwrap_or(0.0)
246    }
247
248    /// Get the total cumulative drift across all laws.
249    pub fn total_drift(&self) -> f64 {
250        self.drifts.iter().map(|(_, d)| d).sum()
251    }
252
253    /// Check whether cumulative drift exceeds the alert threshold.
254    pub fn check_alert(&self) -> Option<Alert> {
255        let total = self.total_drift();
256        if total > self.alert_threshold {
257            Some(Alert {
258                total_drift: total,
259                threshold: self.alert_threshold,
260                worst_law: self.drifts.iter().max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()).map(|(l, _)| l.clone()),
261            })
262        } else {
263            None
264        }
265    }
266
267    /// Total number of verifications recorded.
268    pub fn verification_count(&self) -> u64 {
269        self.total_verifications
270    }
271
272    /// Total number of violations recorded.
273    pub fn violation_count(&self) -> u64 {
274        self.total_violations
275    }
276
277    /// Average drift per verification.
278    pub fn average_drift(&self) -> f64 {
279        if self.drift_history.is_empty() {
280            0.0
281        } else {
282            self.drift_history.iter().sum::<f64>() / self.drift_history.len() as f64
283        }
284    }
285
286    /// Violation rate as a fraction (0.0 to 1.0).
287    pub fn violation_rate(&self) -> f64 {
288        if self.total_verifications == 0 {
289            0.0
290        } else {
291            self.total_violations as f64 / self.total_verifications as f64
292        }
293    }
294}
295
296/// An alert raised when cumulative drift exceeds the budget threshold.
297#[derive(Debug, Clone)]
298pub struct Alert {
299    /// Total cumulative drift.
300    pub total_drift: f64,
301    /// Configured alert threshold.
302    pub threshold: f64,
303    /// The conservation law with the worst drift.
304    pub worst_law: Option<ConservationLaw>,
305}
306
307impl fmt::Display for Alert {
308    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
309        write!(
310            f,
311            "CONSERVATION ALERT: drift {:.6} exceeds threshold {:.6}",
312            self.total_drift, self.threshold
313        )?;
314        if let Some(ref law) = self.worst_law {
315            write!(f, " (worst: {})", law)?;
316        }
317        Ok(())
318    }
319}
320
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
324
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ternary_exact_conservation() {
        let result = VerificationResult::check(
            ConservationLaw::Energy,
            100.0,
            100.0,
            1e-9,
        );
        assert_eq!(result.verdict, Ternary::Conserved);
        assert_eq!(result.delta, 0.0);
    }

    #[test]
    fn ternary_approximate_conservation() {
        let result = VerificationResult::check(
            ConservationLaw::Mass,
            50.0,
            50.0005,
            0.001,
        );
        assert_eq!(result.verdict, Ternary::Approximate);
        assert!(result.delta > 0.0);
        assert!(result.delta <= 0.001);
    }

    #[test]
    fn ternary_violation() {
        let result = VerificationResult::check(
            ConservationLaw::Information,
            200.0,
            195.0,
            1.0,
        );
        assert_eq!(result.verdict, Ternary::Violated);
        assert_eq!(result.delta, 5.0);
    }

    #[test]
    fn ternary_from_i8_and_is_ok() {
        assert_eq!(Ternary::from_i8(1), Ternary::Conserved);
        assert_eq!(Ternary::from_i8(0), Ternary::Approximate);
        assert_eq!(Ternary::from_i8(-1), Ternary::Violated);
        // Out-of-range encodings are treated conservatively as violations.
        assert_eq!(Ternary::from_i8(42), Ternary::Violated);
        assert!(Ternary::Conserved.is_ok());
        assert!(Ternary::Approximate.is_ok());
        assert!(!Ternary::Violated.is_ok());
    }

    #[test]
    fn display_formats() {
        assert_eq!(Ternary::Approximate.to_string(), " 0 (approximate)");
        assert_eq!(
            ConservationLaw::Custom("spin".into()).to_string(),
            "custom(spin)"
        );
        let r = VerificationResult::check(ConservationLaw::Energy, 1.0, 1.0, 0.5);
        assert_eq!(r.to_string(), "[+1 (conserved)] 1 → 1 (Δ=0, ε=0.5)");
    }

    #[test]
    fn monitor_snapshot_and_verify() {
        let mut mon = ConservationMonitor::new(0.01);
        mon.snapshot(ConservationLaw::Energy, "kernel_a", 42.0);
        assert_eq!(mon.pending(), 1);

        let result = mon.verify(&ConservationLaw::Energy, "kernel_a", 42.005).unwrap();
        assert_eq!(result.verdict, Ternary::Approximate);
        assert_eq!(mon.pending(), 0);
        assert_eq!(mon.results().len(), 1);
    }

    #[test]
    fn monitor_verify_missing_returns_none() {
        let mut mon = ConservationMonitor::new(0.01);
        assert!(mon.verify(&ConservationLaw::Energy, "nope", 0.0).is_none());
    }

    #[test]
    fn monitor_verify_all_batch() {
        let mut mon = ConservationMonitor::new(0.1);
        mon.snapshot(ConservationLaw::Energy, "k1", 10.0);
        mon.snapshot(ConservationLaw::Mass, "k2", 20.0);
        mon.snapshot(ConservationLaw::Information, "k3", 30.0);

        let results = mon.verify_all(|law, _label| match law {
            ConservationLaw::Energy => 10.0,
            ConservationLaw::Mass => 20.05,
            ConservationLaw::Information => 30.2,
            _ => 0.0,
        });

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].verdict, Ternary::Conserved);
        assert_eq!(results[1].verdict, Ternary::Approximate);
        assert_eq!(results[2].verdict, Ternary::Violated);
    }

    #[test]
    fn monitor_reset_clears_state() {
        let mut mon = ConservationMonitor::new(0.1);
        mon.snapshot(ConservationLaw::Energy, "a", 1.0);
        mon.snapshot(ConservationLaw::Mass, "b", 2.0);
        assert!(mon.verify(&ConservationLaw::Energy, "a", 1.0).is_some());
        assert_eq!(mon.pending(), 1);
        assert_eq!(mon.results().len(), 1);

        mon.reset();
        assert_eq!(mon.pending(), 0);
        assert!(mon.results().is_empty());
    }

    #[test]
    fn budget_tracks_drift_and_alerts() {
        let mut budget = ConservationBudget::new(5.0);

        // Three small drifts that accumulate
        for _ in 0..3 {
            let r = VerificationResult::check(ConservationLaw::Energy, 100.0, 98.0, 0.5);
            budget.record(&r);
        }

        assert_eq!(budget.verification_count(), 3);
        assert_eq!(budget.violation_count(), 3); // delta=2.0 > epsilon=0.5
        assert!((budget.drift_for(&ConservationLaw::Energy) - 6.0).abs() < 1e-9);
        assert!(budget.check_alert().is_some()); // 6.0 > 5.0 threshold
    }

    #[test]
    fn budget_statistics() {
        let mut budget = ConservationBudget::new(100.0);

        let r1 = VerificationResult::check(ConservationLaw::Mass, 10.0, 10.0, 0.01);
        let r2 = VerificationResult::check(ConservationLaw::Mass, 10.0, 10.02, 0.01);
        let r3 = VerificationResult::check(ConservationLaw::Mass, 10.0, 10.5, 0.01);
        assert_eq!(r1.verdict, Ternary::Conserved);
        assert_eq!(r2.verdict, Ternary::Violated); // Δ≈0.02 > ε=0.01
        assert_eq!(r3.verdict, Ternary::Violated); // Δ=0.5  > ε=0.01
        let expected_avg = (r1.delta + r2.delta + r3.delta) / 3.0;
        budget.record_all(&[r1, r2, r3]);

        assert_eq!(budget.verification_count(), 3);
        // r1 is conserved; r2 and r3 both exceed epsilon.
        assert_eq!(budget.violation_count(), 2);
        let avg = budget.average_drift();
        assert!(avg > 0.0);
        assert!((avg - expected_avg).abs() < 1e-12);
        let rate = budget.violation_rate();
        assert!((rate - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn budget_reports_worst_law() {
        let mut budget = ConservationBudget::new(1.0);
        budget.record(&VerificationResult::check(ConservationLaw::Energy, 10.0, 10.5, 0.1));
        budget.record(&VerificationResult::check(ConservationLaw::Mass, 10.0, 11.5, 0.1));

        assert_eq!(budget.drift_for(&ConservationLaw::Information), 0.0);
        assert!((budget.total_drift() - 2.0).abs() < 1e-12);

        let alert = budget.check_alert().expect("2.0 exceeds the 1.0 threshold");
        assert_eq!(alert.worst_law, Some(ConservationLaw::Mass));
        assert_eq!(
            alert.to_string(),
            "CONSERVATION ALERT: drift 2.000000 exceeds threshold 1.000000 (worst: mass)"
        );
    }

    #[test]
    fn empty_budget_is_quiet() {
        let budget = ConservationBudget::new(0.0);
        assert_eq!(budget.verification_count(), 0);
        assert_eq!(budget.average_drift(), 0.0);
        assert_eq!(budget.violation_rate(), 0.0);
        assert_eq!(budget.total_drift(), 0.0);
        assert!(budget.check_alert().is_none());
    }

    #[test]
    fn custom_conservation_law() {
        let custom = ConservationLaw::Custom("angular_momentum".into());
        let result = VerificationResult::check(custom.clone(), 7.0, 7.0, 0.0);
        assert_eq!(result.verdict, Ternary::Conserved);
        assert_eq!(result.law, custom);
    }
}