numra_ode/events.rs
1//! Event detection and handling for ODE solvers.
2//!
3//! This module provides zero-crossing detection (event detection) during ODE integration.
4//! Events are defined by a function g(t, y) that crosses zero. When a sign change is
5//! detected between steps, bisection is used to locate the precise crossing time.
6//!
7//! # Example
8//!
9//! ```rust
10//! use numra_ode::events::{EventFunction, EventDirection, EventAction};
11//!
12//! struct GroundContact;
13//!
14//! impl EventFunction<f64> for GroundContact {
15//! fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
16//! y[0] // Event when height = 0
17//! }
18//! fn direction(&self) -> EventDirection {
19//! EventDirection::Falling // Only detect when falling through zero
20//! }
21//! fn action(&self) -> EventAction {
22//! EventAction::Stop // Stop integration at event
23//! }
24//! }
25//! ```
26//!
27//! Author: Moussa Leblouba
28//! Date: 5 March 2026
29//! Modified: 2 May 2026
30
31use numra_core::Scalar;
32
/// Absolute threshold on `|g|`: during bisection a midpoint whose event value is
/// smaller than this in magnitude is accepted as the root.
// NOTE: tuned for `f64`; for `f32` use cases this should be scaled to ~1e-6.
const EVENT_ZERO_TOL: f64 = 1.0e-12;

/// Absolute threshold on the bracket width `|t_b - t_a|` below which bisection stops.
// NOTE: tuned for `f64`; for `f32` use cases this should be scaled to ~1e-6.
const EVENT_BISECTION_TOL: f64 = 1.0e-13;

/// Upper bound on the number of bisection iterations spent locating one event.
const EVENT_MAX_BISECTION: usize = 50;
41
/// Which sign change of the event function `g` counts as an event.
///
/// The type is a plain fieldless enum, so it is `Copy`, totally comparable (`Eq`)
/// and hashable, which allows it to be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventDirection {
    /// Detect when `g` goes from negative to positive.
    Rising,
    /// Detect when `g` goes from positive to negative.
    Falling,
    /// Detect zero-crossings in either direction.
    Both,
}
52
/// What the integrator should do once an event has been detected.
///
/// The type is a plain fieldless enum, so it is `Copy`, totally comparable (`Eq`)
/// and hashable, which allows it to be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventAction {
    /// Stop integration at the event time.
    Stop,
    /// Record the event and continue integration.
    Continue,
}
61
62/// Trait for event functions.
63///
64/// An event function defines a scalar function g(t, y) whose zero-crossings
65/// are to be detected during integration.
66pub trait EventFunction<S: Scalar>: Send + Sync {
67 /// Evaluate the event function g(t, y).
68 ///
69 /// An event occurs when this function crosses zero.
70 fn evaluate(&self, t: S, y: &[S]) -> S;
71
72 /// Direction of zero-crossing to detect.
73 ///
74 /// Defaults to `Both` (detect crossings in either direction).
75 fn direction(&self) -> EventDirection {
76 EventDirection::Both
77 }
78
79 /// Action to take when event is detected.
80 ///
81 /// Defaults to `Continue` (record and continue integration).
82 fn action(&self) -> EventAction {
83 EventAction::Continue
84 }
85}
86
87/// Recorded event information.
88#[derive(Debug, Clone)]
89pub struct Event<S: Scalar> {
90 /// Time at which the event occurred.
91 pub t: S,
92 /// State at the event time.
93 pub y: Vec<S>,
94 /// Index of the event function that triggered this event.
95 pub event_index: usize,
96}
97
98/// Find the time of a zero-crossing of the event function in [t_lo, t_hi] using bisection.
99///
100/// Uses the provided interpolation function to evaluate the state at intermediate times.
101/// Returns `Some((t_event, y_event))` if a zero-crossing matching the requested direction
102/// is found, or `None` if no valid crossing exists.
103///
104/// # Arguments
105///
106/// * `event` - The event function to check
107/// * `t_lo` - Start time of the interval
108/// * `y_lo` - State at t_lo
109/// * `t_hi` - End time of the interval
110/// * `y_hi` - State at t_hi
111/// * `interpolate` - Function that returns the interpolated state at any time in [t_lo, t_hi]
112pub fn find_event_time<S: Scalar>(
113 event: &dyn EventFunction<S>,
114 t_lo: S,
115 y_lo: &[S],
116 t_hi: S,
117 y_hi: &[S],
118 interpolate: &dyn Fn(S) -> Vec<S>,
119) -> Option<(S, Vec<S>)> {
120 let g_lo = event.evaluate(t_lo, y_lo);
121 let g_hi = event.evaluate(t_hi, y_hi);
122
123 // Check for sign change
124 if g_lo * g_hi > S::ZERO {
125 return None;
126 }
127
128 // Check direction constraint
129 match event.direction() {
130 EventDirection::Rising if g_hi <= g_lo => return None,
131 EventDirection::Falling if g_hi >= g_lo => return None,
132 _ => {}
133 }
134
135 // Bisection to locate the zero-crossing
136 let mut t_a = t_lo;
137 let mut t_b = t_hi;
138 let mut y_a = y_lo.to_vec();
139 let zero_tol = S::from_f64(EVENT_ZERO_TOL);
140 let bisect_tol = S::from_f64(EVENT_BISECTION_TOL);
141
142 for _ in 0..EVENT_MAX_BISECTION {
143 let t_mid = (t_a + t_b) / S::TWO;
144 let y_mid = interpolate(t_mid);
145 let g_mid = event.evaluate(t_mid, &y_mid);
146
147 if g_mid.abs() < zero_tol {
148 return Some((t_mid, y_mid));
149 }
150
151 let g_a = event.evaluate(t_a, &y_a);
152 if g_mid * g_a < S::ZERO {
153 t_b = t_mid;
154 } else {
155 t_a = t_mid;
156 y_a = y_mid;
157 }
158
159 // Check convergence by interval width
160 if (t_b - t_a).abs() < bisect_tol {
161 break;
162 }
163 }
164
165 let t_root = (t_a + t_b) / S::TWO;
166 let y_root = interpolate(t_root);
167 Some((t_root, y_root))
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Event indicator `g(t, y) = y[0]` reacting to the configured crossing direction.
    struct FirstComponent(EventDirection);

    impl EventFunction<f64> for FirstComponent {
        fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
            y[0]
        }

        fn direction(&self) -> EventDirection {
            self.0
        }
    }

    /// Straight line `y(t) = 1 - t`: equals 1 at `t = 0`, 0 at `t = 1` and -1 at `t = 2`.
    fn descending_line(t: f64) -> Vec<f64> {
        vec![1.0 - t]
    }

    #[test]
    fn test_event_direction_default() {
        // Only `evaluate` is implemented, so both provided methods are exercised.
        struct Bare;
        impl EventFunction<f64> for Bare {
            fn evaluate(&self, _t: f64, _y: &[f64]) -> f64 {
                0.0
            }
        }

        let bare = Bare;
        assert_eq!(bare.direction(), EventDirection::Both);
        assert_eq!(bare.action(), EventAction::Continue);
    }

    #[test]
    fn test_find_event_time_simple() {
        // The state falls linearly from 1 (t = 0) to -1 (t = 2); the zero is at t = 1.
        let crossing = FirstComponent(EventDirection::Both);

        let located = find_event_time(&crossing, 0.0, &[1.0], 2.0, &[-1.0], &descending_line);
        assert!(located.is_some());

        let (t_event, y_event) = located.unwrap();
        assert!(
            (t_event - 1.0).abs() < 1e-10,
            "Expected t=1, got t={}",
            t_event
        );
        assert!(y_event[0].abs() < 1e-10);
    }

    #[test]
    fn test_find_event_time_direction_filter() {
        // Same falling trajectory, but only rising crossings are requested.
        let rising_only = FirstComponent(EventDirection::Rising);

        let located = find_event_time(&rising_only, 0.0, &[1.0], 2.0, &[-1.0], &descending_line);
        assert!(
            located.is_none(),
            "Rising event should not trigger on falling crossing"
        );
    }

    #[test]
    fn test_no_sign_change() {
        // y(t) = 1 + t stays positive on [0, 2], so there is nothing to find.
        let watcher = FirstComponent(EventDirection::Both);
        let ascending_line = |t: f64| -> Vec<f64> { vec![1.0 + t] };

        let located = find_event_time(&watcher, 0.0, &[1.0], 2.0, &[3.0], &ascending_line);
        assert!(located.is_none(), "No sign change should return None");
    }
}