Skip to main content

numra_ode/
events.rs

//! Event detection and handling for ODE solvers.
//!
//! This module provides zero-crossing detection (event detection) during ODE integration.
//! Events are defined by a function g(t, y) whose zero-crossings are of interest. When g
//! changes sign across an integration step, bisection on the step's interpolant is used to
//! locate the precise crossing time.
//!
//! # Example
//!
//! ```rust
//! use numra_ode::events::{EventFunction, EventDirection, EventAction};
//!
//! struct GroundContact;
//!
//! impl EventFunction<f64> for GroundContact {
//!     fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
//!         y[0] // Event when height = 0
//!     }
//!     fn direction(&self) -> EventDirection {
//!         EventDirection::Falling // Only detect when falling through zero
//!     }
//!     fn action(&self) -> EventAction {
//!         EventAction::Stop // Stop integration at event
//!     }
//! }
//! ```
//!
//! Author: Moussa Leblouba
//! Date: 5 March 2026
//! Modified: 2 May 2026
30
31use numra_core::Scalar;
32
// NOTE: these tolerances are absolute and sized for f64; f32 use cases would need them
// scaled to roughly 1e-6.

/// Magnitude below which an event-function value is accepted as the zero-crossing.
const EVENT_ZERO_TOL: f64 = 1.0e-12;
/// Bisection stops once the bracketing interval is narrower than this width.
const EVENT_BISECTION_TOL: f64 = 1.0e-13;
/// Upper bound on the number of bisection steps spent locating one event.
const EVENT_MAX_BISECTION: usize = 50;
41
/// Direction of zero-crossing to detect.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the direction can be used as
/// a map/set key or wherever full equality is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventDirection {
    /// Detect when g goes from negative to positive.
    Rising,
    /// Detect when g goes from positive to negative.
    Falling,
    /// Detect zero-crossings in either direction.
    Both,
}
52
/// Action to take when an event is detected.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the action can be used as
/// a map/set key or wherever full equality is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventAction {
    /// Stop integration at the event time.
    Stop,
    /// Record the event and continue integration.
    Continue,
}
61
62/// Trait for event functions.
63///
64/// An event function defines a scalar function g(t, y) whose zero-crossings
65/// are to be detected during integration.
66pub trait EventFunction<S: Scalar>: Send + Sync {
67    /// Evaluate the event function g(t, y).
68    ///
69    /// An event occurs when this function crosses zero.
70    fn evaluate(&self, t: S, y: &[S]) -> S;
71
72    /// Direction of zero-crossing to detect.
73    ///
74    /// Defaults to `Both` (detect crossings in either direction).
75    fn direction(&self) -> EventDirection {
76        EventDirection::Both
77    }
78
79    /// Action to take when event is detected.
80    ///
81    /// Defaults to `Continue` (record and continue integration).
82    fn action(&self) -> EventAction {
83        EventAction::Continue
84    }
85}
86
87/// Recorded event information.
88#[derive(Debug, Clone)]
89pub struct Event<S: Scalar> {
90    /// Time at which the event occurred.
91    pub t: S,
92    /// State at the event time.
93    pub y: Vec<S>,
94    /// Index of the event function that triggered this event.
95    pub event_index: usize,
96}
97
98/// Find the time of a zero-crossing of the event function in [t_lo, t_hi] using bisection.
99///
100/// Uses the provided interpolation function to evaluate the state at intermediate times.
101/// Returns `Some((t_event, y_event))` if a zero-crossing matching the requested direction
102/// is found, or `None` if no valid crossing exists.
103///
104/// # Arguments
105///
106/// * `event` - The event function to check
107/// * `t_lo` - Start time of the interval
108/// * `y_lo` - State at t_lo
109/// * `t_hi` - End time of the interval
110/// * `y_hi` - State at t_hi
111/// * `interpolate` - Function that returns the interpolated state at any time in [t_lo, t_hi]
112pub fn find_event_time<S: Scalar>(
113    event: &dyn EventFunction<S>,
114    t_lo: S,
115    y_lo: &[S],
116    t_hi: S,
117    y_hi: &[S],
118    interpolate: &dyn Fn(S) -> Vec<S>,
119) -> Option<(S, Vec<S>)> {
120    let g_lo = event.evaluate(t_lo, y_lo);
121    let g_hi = event.evaluate(t_hi, y_hi);
122
123    // Check for sign change
124    if g_lo * g_hi > S::ZERO {
125        return None;
126    }
127
128    // Check direction constraint
129    match event.direction() {
130        EventDirection::Rising if g_hi <= g_lo => return None,
131        EventDirection::Falling if g_hi >= g_lo => return None,
132        _ => {}
133    }
134
135    // Bisection to locate the zero-crossing
136    let mut t_a = t_lo;
137    let mut t_b = t_hi;
138    let mut y_a = y_lo.to_vec();
139    let zero_tol = S::from_f64(EVENT_ZERO_TOL);
140    let bisect_tol = S::from_f64(EVENT_BISECTION_TOL);
141
142    for _ in 0..EVENT_MAX_BISECTION {
143        let t_mid = (t_a + t_b) / S::TWO;
144        let y_mid = interpolate(t_mid);
145        let g_mid = event.evaluate(t_mid, &y_mid);
146
147        if g_mid.abs() < zero_tol {
148            return Some((t_mid, y_mid));
149        }
150
151        let g_a = event.evaluate(t_a, &y_a);
152        if g_mid * g_a < S::ZERO {
153            t_b = t_mid;
154        } else {
155            t_a = t_mid;
156            y_a = y_mid;
157        }
158
159        // Check convergence by interval width
160        if (t_b - t_a).abs() < bisect_tol {
161            break;
162        }
163    }
164
165    let t_root = (t_a + t_b) / S::TWO;
166    let y_root = interpolate(t_root);
167    Some((t_root, y_root))
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Event function g(t, y) = y[0] that keeps the trait's default direction and action.
    struct FirstComponent;

    impl EventFunction<f64> for FirstComponent {
        fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
            y[0]
        }
    }

    #[test]
    fn test_event_direction_default() {
        let event = FirstComponent;
        assert_eq!(event.direction(), EventDirection::Both);
        assert_eq!(event.action(), EventAction::Continue);
    }

    #[test]
    fn test_find_event_time_simple() {
        // Linear state y(t) = 1 - t: y = 1 at t = 0, y = -1 at t = 2, so g vanishes at t = 1.
        let line = |t: f64| -> Vec<f64> { vec![1.0 - t] };

        let (t_event, y_event) =
            find_event_time(&FirstComponent, 0.0, &[1.0], 2.0, &[-1.0], &line).unwrap();

        assert!(
            (t_event - 1.0).abs() < 1e-10,
            "Expected t=1, got t={}",
            t_event
        );
        assert!(y_event[0].abs() < 1e-10);
    }

    #[test]
    fn test_find_event_time_direction_filter() {
        /// Same g as `FirstComponent`, but only rising crossings are of interest.
        struct RisingOnly;

        impl EventFunction<f64> for RisingOnly {
            fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
                y[0]
            }

            fn direction(&self) -> EventDirection {
                EventDirection::Rising
            }
        }

        // g falls from +1 to -1 over [0, 2], so a rising-only event must not fire.
        let line = |t: f64| -> Vec<f64> { vec![1.0 - t] };
        let result = find_event_time(&RisingOnly, 0.0, &[1.0], 2.0, &[-1.0], &line);
        assert!(
            result.is_none(),
            "Rising event should not trigger on falling crossing"
        );
    }

    #[test]
    fn test_no_sign_change() {
        // y(t) = 1 + t stays strictly positive on [0, 2].
        let line = |t: f64| -> Vec<f64> { vec![1.0 + t] };
        let result = find_event_time(&FirstComponent, 0.0, &[1.0], 2.0, &[3.0], &line);
        assert!(result.is_none(), "No sign change should return None");
    }
}