Skip to main content

scirs2_integrate/adaptive/
events.rs

1//! Event detection for ODE integration.
2//!
3//! This module provides infrastructure for detecting and locating **zero
4//! crossings** (events) of user-defined scalar functions of the ODE state
5//! during integration.  Events are used to stop integration early, record
6//! exact crossing times, or switch between different ODE systems.
7//!
8//! # Event detection algorithm
9//!
10//! 1. After each accepted ODE step the event function `g(t, y)` is evaluated
11//!    at the new time point.
12//! 2. If the sign of `g` changes compared with the previous step the solver
13//!    knows a zero crossing occurred somewhere in `(t_prev, t_curr)`.
14//! 3. The **Illinois algorithm** (a bracket-based secant method with
15//!    superlinear convergence) is used to find the exact crossing time to
16//!    within a small tolerance.
17//! 4. If the event is marked `terminal = true` integration stops at that
18//!    point; otherwise the crossing is recorded and integration continues.
19//!
20//! # Usage
21//!
22//! Combine `EventSpec` and `EventSet` with `dopri5_with_events` to obtain
23//! both the solution trajectory and a list of detected crossings.
24
25use super::embedded_rk::{dopri5, OdeResult};
26use crate::error::{IntegrateError, IntegrateResult};
27
28// ─── Public types ────────────────────────────────────────────────────────────
29
/// Which sign changes of an event function count as an event.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EventDirection {
    /// Report only upward crossings: `g` changes from negative to positive.
    Rising,
    /// Report only downward crossings: `g` changes from positive to negative.
    Falling,
    /// Report every sign change, whichever way it goes.
    Both,
}
40
41/// A single event specification.
42///
43/// An event is triggered when the scalar function `func(t, y)` passes
44/// through zero in the given `direction`.
45pub struct EventSpec {
46    /// The event function.  An event triggers when this crosses zero.
47    pub func: Box<dyn Fn(f64, &[f64]) -> f64 + Send + Sync>,
48    /// Which sign change directions to detect.
49    pub direction: EventDirection,
50    /// Whether to halt integration when this event fires.
51    pub terminal: bool,
52}
53
54impl std::fmt::Debug for EventSpec {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        f.debug_struct("EventSpec")
57            .field("direction", &self.direction)
58            .field("terminal", &self.terminal)
59            .finish()
60    }
61}
62
/// One located event crossing.
#[derive(Debug, Clone)]
pub struct EventResult {
    /// Time at which the event function was found to cross zero.
    pub t_event: f64,
    /// State vector at `t_event`, obtained by interpolation.
    pub y_event: Vec<f64>,
    /// Position of the firing event in its `EventSpec` collection.
    pub event_idx: usize,
}
73
74/// A collection of events to be monitored during integration.
75pub struct EventSet {
76    /// The list of events, indexed starting from zero.
77    pub specs: Vec<EventSpec>,
78}
79
80impl EventSet {
81    /// Create a new `EventSet` from a vector of `EventSpec`.
82    pub fn new(specs: Vec<EventSpec>) -> Self {
83        Self { specs }
84    }
85}
86
87// ─── Illinois bracket root finder ────────────────────────────────────────────
88
89/// Maximum number of Illinois iterations for root polishing.
90const MAX_ILLINOIS: usize = 50;
91/// Tolerance for the Illinois root-finding iteration (in time).
92const ILLINOIS_TOL: f64 = 1e-12;
93
94/// Check whether the sign change between `g_prev` and `g_curr` matches the
95/// requested `direction`.
96fn direction_matches(g_prev: f64, g_curr: f64, direction: EventDirection) -> bool {
97    match direction {
98        EventDirection::Both => g_prev * g_curr < 0.0,
99        EventDirection::Rising => g_prev < 0.0 && g_curr > 0.0,
100        EventDirection::Falling => g_prev > 0.0 && g_curr < 0.0,
101    }
102}
103
104/// Internal Illinois iteration.
105///
106/// Performs the Illinois secant method on a user-supplied evaluation
107/// function `eval(t) -> g`.  The bracket `[ta, tb]` must satisfy
108/// `ga * gb < 0`.  Returns the located crossing time and state.
109fn illinois_bracket<E>(mut ta: f64, mut tb: f64, mut ga: f64, mut gb: f64, eval: E) -> f64
110where
111    E: Fn(f64) -> f64,
112{
113    // Illinois state: which side was most recently *not* updated
114    // (we halve that side's function value to improve convergence).
115    let mut side = 0i32; // 0 = neutral, +1 = tb stale, -1 = ta stale
116
117    for _ in 0..MAX_ILLINOIS {
118        // Secant step
119        let dg = gb - ga;
120        let t_new = if dg.abs() < 1e-300 {
121            (ta + tb) / 2.0
122        } else {
123            ta - ga * (tb - ta) / dg
124        };
125        let t_new = t_new.clamp(ta.min(tb), ta.max(tb));
126
127        if (tb - ta).abs() < ILLINOIS_TOL {
128            return t_new;
129        }
130
131        let g_new = eval(t_new);
132
133        if g_new.abs() < ILLINOIS_TOL {
134            return t_new;
135        }
136
137        if ga * g_new < 0.0 {
138            // Root in [ta, t_new]; tb moves to t_new
139            if side == 1 {
140                // tb was already stale; halve ga (Illinois modification)
141                ga /= 2.0;
142            }
143            tb = t_new;
144            gb = g_new;
145            side = 1; // tb just moved → ta is now the stale side
146        } else {
147            // Root in [t_new, tb]; ta moves to t_new
148            if side == -1 {
149                // ta was already stale; halve gb
150                gb /= 2.0;
151            }
152            ta = t_new;
153            ga = g_new;
154            side = -1; // ta just moved → tb is now the stale side
155        }
156    }
157
158    (ta + tb) / 2.0
159}
160
161/// Locate the zero crossing of `event.func` in the interval `[t_prev, t_curr]`
162/// using the **Illinois algorithm**.
163///
164/// The ODE solution at intermediate times is approximated by linearly
165/// interpolating the state vectors `y_prev` and `y_curr`.  For higher
166/// accuracy use `find_event_root_dense` with a dense-output interpolant.
167///
168/// Returns `Some(EventResult)` if a crossing is found, `None` if there is no
169/// bracketed zero (e.g. the direction filter rejects the crossing).
170///
171/// # Parameters
172///
173/// * `g_prev`    – Event function value at `t_prev`.
174/// * `g_curr`    – Event function value at `t_curr`.
175/// * `t_prev`    – Left bracket time.
176/// * `t_curr`    – Right bracket time.
177/// * `y_prev`    – State vector at `t_prev`.
178/// * `y_curr`    – State vector at `t_curr`.
179/// * `event_idx` – Index of the triggering event in the surrounding slice.
180/// * `event`     – The `EventSpec` whose zero we are locating.
181pub fn find_event_root(
182    g_prev: f64,
183    g_curr: f64,
184    t_prev: f64,
185    t_curr: f64,
186    y_prev: &[f64],
187    y_curr: &[f64],
188    event_idx: usize,
189    event: &EventSpec,
190) -> Option<EventResult> {
191    if !direction_matches(g_prev, g_curr, event.direction) {
192        return None;
193    }
194
195    let n = y_prev.len();
196    let dt = t_curr - t_prev;
197
198    // Linear interpolation helper
199    let interp = |t: f64| -> Vec<f64> {
200        let alpha = if dt.abs() < 1e-300 {
201            0.5
202        } else {
203            (t - t_prev) / dt
204        };
205        (0..n)
206            .map(|i| y_prev[i] + alpha * (y_curr[i] - y_prev[i]))
207            .collect()
208    };
209
210    let eval = |t: f64| -> f64 {
211        let y = interp(t);
212        (event.func)(t, &y)
213    };
214
215    let t_event = illinois_bracket(t_prev, t_curr, g_prev, g_curr, eval);
216    let y_event = interp(t_event);
217
218    Some(EventResult {
219        t_event,
220        y_event,
221        event_idx,
222    })
223}
224
225/// Locate a zero crossing using a callable ODE solution interpolant instead
226/// of linear interpolation between steps.
227///
228/// `interp(t)` must return the (approximate) ODE state at any time in
229/// `[t_prev, t_curr]`.  This is typically the dense-output polynomial from
230/// the underlying solver step.
231///
232/// Returns `Some(EventResult)` if a crossing is found, `None` otherwise.
233pub fn find_event_root_dense<I>(
234    g_prev: f64,
235    g_curr: f64,
236    t_prev: f64,
237    t_curr: f64,
238    interp: I,
239    event_idx: usize,
240    event: &EventSpec,
241) -> Option<EventResult>
242where
243    I: Fn(f64) -> Vec<f64>,
244{
245    if !direction_matches(g_prev, g_curr, event.direction) {
246        return None;
247    }
248
249    let eval = |t: f64| -> f64 {
250        let y = interp(t);
251        (event.func)(t, &y)
252    };
253
254    let t_event = illinois_bracket(t_prev, t_curr, g_prev, g_curr, eval);
255    let y_event = interp(t_event);
256
257    Some(EventResult {
258        t_event,
259        y_event,
260        event_idx,
261    })
262}
263
264// ─── Complete result type ────────────────────────────────────────────────────
265
266/// Combined result from ODE integration with event detection.
267#[derive(Debug)]
268pub struct OdeEventResult {
269    /// The standard ODE trajectory.
270    pub ode: OdeResult,
271    /// All detected events, in chronological order.
272    pub events: Vec<EventResult>,
273    /// Whether integration terminated due to a terminal event.
274    pub terminated: bool,
275}
276
277// ─── High-level solver with events ──────────────────────────────────────────
278
279/// Solve an ODE with DOPRI5 while monitoring a set of events.
280///
281/// Integration proceeds step by step.  After each accepted step the event
282/// functions are evaluated and any zero crossings located with
283/// [`find_event_root`].  If a terminal event fires integration stops at the
284/// event time; otherwise it continues to `t_end`.
285///
286/// # Arguments
287///
288/// * `f`       – Right-hand side `dy/dt = f(t, y)`.
289/// * `t0`      – Initial time.
290/// * `y0`      – Initial state vector.
291/// * `t_end`   – Final time (may not be reached if a terminal event fires).
292/// * `rtol`    – Relative tolerance for DOPRI5.
293/// * `atol`    – Absolute tolerance for DOPRI5.
294/// * `events`  – The set of events to monitor.
295///
296/// # Errors
297///
298/// Propagates any errors from the underlying DOPRI5 integrator.
299pub fn dopri5_with_events<F>(
300    f: F,
301    t0: f64,
302    y0: &[f64],
303    t_end: f64,
304    rtol: f64,
305    atol: f64,
306    events: EventSet,
307) -> IntegrateResult<OdeEventResult>
308where
309    F: Fn(f64, &[f64]) -> Vec<f64> + Clone,
310{
311    if y0.is_empty() {
312        return Err(IntegrateError::ValueError(
313            "y0 must be non-empty".to_string(),
314        ));
315    }
316    if t_end <= t0 {
317        return Err(IntegrateError::ValueError("t_end must be > t0".to_string()));
318    }
319
320    let mut all_t: Vec<f64> = vec![t0];
321    let mut all_y: Vec<Vec<f64>> = vec![y0.to_vec()];
322    let mut all_events: Vec<EventResult> = Vec::new();
323    let mut n_steps_total: usize = 0;
324    let mut n_rejected_total: usize = 0;
325    let mut n_evals_total: usize = 0;
326    let mut terminated = false;
327
328    // Evaluate all event functions at t0
329    let mut g_prev: Vec<f64> = events.specs.iter().map(|s| (s.func)(t0, y0)).collect();
330
331    // Step through using DOPRI5 in segments.  We run one "short" integration
332    // at a time to keep the segment granularity coarse; then we scan for
333    // events within each returned step.
334    //
335    // For simplicity we drive DOPRI5 with a per-segment call and inspect the
336    // resulting trajectory pairwise.
337    let n_seg_max = 10_000_usize;
338    let seg_hint = ((t_end - t0) / 0.1).ceil() as usize; // ~100 points per segment
339    let n_seg = seg_hint.min(n_seg_max).max(1);
340
341    let dt_seg = (t_end - t0) / n_seg as f64;
342    let mut t_start = t0;
343    let mut y_start = y0.to_vec();
344
345    for _seg in 0..n_seg {
346        if terminated || t_start >= t_end - 1e-14 * (t_end - t0) {
347            break;
348        }
349
350        let t_seg_end = (t_start + dt_seg).min(t_end);
351
352        let seg_result = dopri5(f.clone(), t_start, &y_start, t_seg_end, rtol, atol)?;
353
354        n_steps_total += seg_result.n_steps;
355        n_rejected_total += seg_result.n_rejected;
356        n_evals_total += seg_result.n_evals;
357
358        // Scan each consecutive pair in the segment for events
359        let seg_len = seg_result.t.len();
360        let mut early_stop_idx: Option<usize> = None;
361
362        'step_scan: for step_i in 1..seg_len {
363            let t_p = seg_result.t[step_i - 1];
364            let t_c = seg_result.t[step_i];
365            let y_p = &seg_result.y[step_i - 1];
366            let y_c = &seg_result.y[step_i];
367
368            for (ev_idx, spec) in events.specs.iter().enumerate() {
369                let g_c = (spec.func)(t_c, y_c);
370                let g_p = g_prev[ev_idx];
371
372                if direction_matches(g_p, g_c, spec.direction) {
373                    if let Some(ev) = find_event_root(g_p, g_c, t_p, t_c, y_p, y_c, ev_idx, spec) {
374                        all_events.push(ev);
375                        if spec.terminal {
376                            early_stop_idx = Some(step_i);
377                            terminated = true;
378                            break 'step_scan;
379                        }
380                    }
381                }
382
383                g_prev[ev_idx] = g_c;
384            }
385        }
386
387        // Append trajectory points
388        let append_up_to = early_stop_idx.unwrap_or(seg_len);
389        for step_i in 1..append_up_to {
390            all_t.push(seg_result.t[step_i]);
391            all_y.push(seg_result.y[step_i].clone());
392        }
393
394        // If a terminal event fired add the event location as the final point
395        if terminated {
396            if let Some(last_ev) = all_events.last() {
397                all_t.push(last_ev.t_event);
398                all_y.push(last_ev.y_event.clone());
399            }
400            break;
401        }
402
403        // Advance to next segment
404        if let (Some(t_last), Some(y_last)) = (seg_result.t.last(), seg_result.y.last()) {
405            t_start = *t_last;
406            y_start = y_last.clone();
407        } else {
408            break;
409        }
410    }
411
412    let n_out = all_t.len();
413    Ok(OdeEventResult {
414        ode: OdeResult {
415            t: all_t,
416            y: all_y,
417            n_steps: n_steps_total,
418            n_rejected: n_rejected_total,
419            n_evals: n_evals_total + n_out, // approximate
420        },
421        events: all_events,
422        terminated,
423    })
424}
425
426// ─── Tests ───────────────────────────────────────────────────────────────────
427
#[cfg(test)]
mod tests {
    use super::*;

    // ── Illinois root finder ─────────────────────────────────────────────────

    #[test]
    fn illinois_finds_exact_midpoint() {
        // g(t) = t - 0.5, crosses zero at t = 0.5
        let spec = EventSpec {
            func: Box::new(|t: f64, _y: &[f64]| t - 0.5),
            direction: EventDirection::Rising,
            terminal: false,
        };
        let y_prev = vec![1.0_f64];
        let y_curr = vec![1.0_f64];
        let result = find_event_root(-0.5, 0.5, 0.0, 1.0, &y_prev, &y_curr, 0, &spec)
            .expect("should detect rising crossing");
        assert!(
            (result.t_event - 0.5).abs() < 1e-10,
            "t_event={} expected 0.5",
            result.t_event
        );
        assert_eq!(result.event_idx, 0);
    }

    #[test]
    fn illinois_direction_filter_falling() {
        // g goes from +1 to -1 → falling crossing
        let spec_rising = EventSpec {
            func: Box::new(|t: f64, _y: &[f64]| 1.0 - 2.0 * t), // crosses 0 at 0.5
            direction: EventDirection::Rising,                  // should NOT match
            terminal: false,
        };
        let y = vec![0.0_f64];
        let res = find_event_root(1.0, -1.0, 0.0, 1.0, &y, &y, 0, &spec_rising);
        assert!(
            res.is_none(),
            "Rising filter should reject falling crossing"
        );

        let spec_falling = EventSpec {
            func: Box::new(|t: f64, _y: &[f64]| 1.0 - 2.0 * t),
            direction: EventDirection::Falling,
            terminal: false,
        };
        let res2 = find_event_root(1.0, -1.0, 0.0, 1.0, &y, &y, 0, &spec_falling)
            .expect("Falling filter should accept falling crossing");
        assert!((res2.t_event - 0.5).abs() < 1e-8);
    }

    #[test]
    fn illinois_both_directions() {
        // g(t) = sin(t - 0.3) changes sign at t = 0.3 inside [0, 0.6].
        let spec = EventSpec {
            func: Box::new(|t: f64, _y: &[f64]| (t - 0.3).sin()),
            direction: EventDirection::Both,
            terminal: false,
        };
        let y = vec![0.0_f64];
        // The bracket values must be the event function's real endpoint values.
        let g_prev = (spec.func)(0.0, &y);
        let g_curr = (spec.func)(0.6, &y);
        let ev = find_event_root(g_prev, g_curr, 0.0, 0.6, &y, &y, 2, &spec)
            .expect("should find crossing");
        assert_eq!(ev.event_idx, 2);
        assert!(
            (ev.t_event - 0.3).abs() < 1e-9,
            "t_event={} expected 0.3",
            ev.t_event
        );
    }

    #[test]
    fn dense_root_uses_interpolant() {
        // The state follows y(t) = t²; the event y - 0.25 fires at t = 0.5.
        let spec = EventSpec {
            func: Box::new(|_t: f64, y: &[f64]| y[0] - 0.25),
            direction: EventDirection::Rising,
            terminal: false,
        };
        let interp = |t: f64| vec![t * t];
        let ev = find_event_root_dense(-0.25, 0.75, 0.0, 1.0, interp, 1, &spec)
            .expect("should detect rising crossing");
        assert_eq!(ev.event_idx, 1);
        assert!(
            (ev.t_event - 0.5).abs() < 1e-9,
            "t_event={} expected 0.5",
            ev.t_event
        );
        assert!((ev.y_event[0] - 0.25).abs() < 1e-9);
    }

    #[test]
    fn dense_root_respects_direction_filter() {
        let spec = EventSpec {
            func: Box::new(|_t: f64, y: &[f64]| y[0] - 0.25),
            direction: EventDirection::Falling,
            terminal: false,
        };
        let interp = |t: f64| vec![t * t];
        assert!(find_event_root_dense(-0.25, 0.75, 0.0, 1.0, interp, 0, &spec).is_none());
    }

    // ── dopri5_with_events ───────────────────────────────────────────────────

    #[test]
    fn events_detect_zero_crossing_sin() {
        // dy/dt = cos(t), y(0) = 0 → y(t) = sin(t)
        // Event: y crosses zero again at t = π
        let f = |t: f64, _y: &[f64]| vec![t.cos()];
        let event_spec = EventSpec {
            func: Box::new(|_t: f64, y: &[f64]| y[0]),
            direction: EventDirection::Falling, // sin goes positive → negative at π
            terminal: false,
        };
        let events = EventSet::new(vec![event_spec]);
        let result = dopri5_with_events(f, 0.0, &[0.0], 4.0, 1e-8, 1e-10, events)
            .expect("integration failed");

        // Should detect a crossing near t = π ≈ 3.14159
        let pi = std::f64::consts::PI;
        let found = result.events.iter().any(|e| (e.t_event - pi).abs() < 0.05);
        assert!(
            found,
            "Expected crossing near t=π, got events: {:?}",
            result.events.iter().map(|e| e.t_event).collect::<Vec<_>>()
        );
        assert!(!result.terminated);
    }

    #[test]
    fn events_terminal_stops_integration() {
        // dy/dt = -y, y(0) = 1  →  y(t) = exp(-t)
        // Terminal event: y < 0.5 (triggers when exp(-t) = 0.5, i.e. t = ln 2 ≈ 0.693)
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let threshold = EventSpec {
            func: Box::new(|_t: f64, y: &[f64]| y[0] - 0.5), // crosses 0 from above
            direction: EventDirection::Falling,
            terminal: true,
        };
        let events = EventSet::new(vec![threshold]);
        let result = dopri5_with_events(f, 0.0, &[1.0], 5.0, 1e-8, 1e-10, events)
            .expect("integration failed");

        assert!(result.terminated, "Expected terminal stop");
        // Integration should stop well before t = 5
        let t_final = result.ode.t.last().copied().unwrap_or(0.0);
        let ln2 = 2.0_f64.ln();
        assert!(
            (t_final - ln2).abs() < 0.1,
            "Expected termination near t=ln2≈{ln2:.4}, got t={t_final:.4}"
        );
        assert!(!result.events.is_empty());
    }

    #[test]
    fn events_multiple_crossings() {
        // dy/dt = 1, y(0) = 0  →  y(t) = t
        // Detect crossings of thresholds at t = 1, 2, 3
        let f = |_t: f64, _y: &[f64]| vec![1.0];
        let mut specs = Vec::new();
        for thresh in [1.0_f64, 2.0, 3.0] {
            specs.push(EventSpec {
                func: Box::new(move |_t: f64, y: &[f64]| y[0] - thresh),
                direction: EventDirection::Rising,
                terminal: false,
            });
        }
        let events = EventSet::new(specs);
        let result = dopri5_with_events(f, 0.0, &[0.0], 4.0, 1e-8, 1e-10, events)
            .expect("integration failed");

        // Should detect 3 crossings
        assert!(
            result.events.len() >= 3,
            "expected ≥3 events, got {}",
            result.events.len()
        );
    }

    #[test]
    fn events_validates_empty_y0() {
        let f = |_t: f64, _y: &[f64]| vec![];
        let events = EventSet::new(vec![]);
        assert!(dopri5_with_events(f, 0.0, &[], 1.0, 1e-6, 1e-8, events).is_err());
    }

    #[test]
    fn events_validates_t_end_leq_t0() {
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let events = EventSet::new(vec![]);
        assert!(dopri5_with_events(f, 1.0, &[1.0], 0.5, 1e-6, 1e-8, events).is_err());
    }
}