Skip to main content

scirs2_integrate/ode/
events.rs

1//! Enhanced event detection for ODE integration
2//!
3//! This module provides advanced zero-crossing detection and event handling
4//! capabilities for use with ODE solvers. It extends the basic event detection
5//! in `ode::utils::events` with:
6//!
7//! - **Illinois method**: Modified regula falsi with guaranteed convergence
8//! - **Brent's method**: Combining bisection, secant, and inverse quadratic
9//!   interpolation for robust root finding
10//! - **Multiple simultaneous events**: Proper ordering when multiple events
11//!   fire in the same step
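//! - **Dense output integration**: Accepts an optional interpolant (for example
//!   a solver's dense output) for sub-step event location, and falls back to
//!   linear interpolation between the step endpoints when none is given
//! - **Event chaining**: `EventDetector` does not modify the state itself; after
//!   applying a state change for one event, call `EventDetector::initialize`
//!   again so that crossings of other events are measured from the new state
//!
//! # Usage
//!
//! ```rust,ignore
//! use scirs2_core::ndarray::Array1;
//! use scirs2_integrate::ode::events::{
//!     CrossingDirection, EventDef, EventDetector, EventResponse,
//! };
//!
//! // Detect when y[0] (the height) crosses zero while falling, and stop there.
//! let detector = EventDetector::new().add_event(
//!     EventDef::new("impact", |_t, y: &Array1<f64>| y[0])
//!         .with_direction(CrossingDirection::Falling)
//!         .with_response(EventResponse::Terminate),
//! );
//! ```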
29
30use crate::common::IntegrateFloat;
31use crate::error::{IntegrateError, IntegrateResult};
32use scirs2_core::ndarray::Array1;
33
34// ---------------------------------------------------------------------------
35// Types
36// ---------------------------------------------------------------------------
37
/// Which sign changes of an event function count as an occurrence.
///
/// The default, [`CrossingDirection::Both`], accepts a sign change either way.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CrossingDirection {
    /// Only transitions where the event function goes from negative to non-negative.
    Rising,
    /// Only transitions where the event function goes from positive to non-positive.
    Falling,
    /// Transitions in either direction.
    #[default]
    Both,
}
49
/// Action requested from the integrator once an event has been located.
///
/// By default ([`EventResponse::Continue`]) the event is only recorded.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum EventResponse {
    /// Record the event and keep integrating.
    #[default]
    Continue,
    /// Record the event and stop integrating at the event time.
    Terminate,
}
59
/// Algorithm used to pin down the exact crossing time inside a step.
///
/// [`RootFindingMethod::Brent`] is used unless another method is chosen.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RootFindingMethod {
    /// Plain interval halving: robust, but only linearly convergent.
    Bisection,
    /// Illinois variant of regula falsi: secant steps with a halving fix that
    /// prevents one endpoint from getting stuck.
    Illinois,
    /// Brent's method: a safeguarded mix of bisection, secant, and inverse
    /// quadratic interpolation.
    #[default]
    Brent,
}
71
72/// Configuration for a single event.
73pub struct EventDef<F: IntegrateFloat> {
74    /// Unique name for this event
75    pub name: String,
76    /// Direction of zero-crossing to detect
77    pub direction: CrossingDirection,
78    /// Action when event is detected
79    pub response: EventResponse,
80    /// Root-finding method for precise location
81    pub root_method: RootFindingMethod,
82    /// Tolerance for root finding
83    pub tolerance: F,
84    /// Maximum iterations for root finding
85    pub max_root_iter: usize,
86    /// Maximum number of times this event can fire (None = unlimited)
87    pub max_count: Option<usize>,
88    /// The event function g(t, y): event occurs when g crosses zero
89    event_fn: Box<dyn Fn(F, &Array1<F>) -> F + Send + Sync>,
90}
91
92impl<F: IntegrateFloat> EventDef<F> {
93    /// Create a new event definition with a name and event function.
94    pub fn new<G>(name: &str, event_fn: G) -> Self
95    where
96        G: Fn(F, &Array1<F>) -> F + Send + Sync + 'static,
97    {
98        EventDef {
99            name: name.to_string(),
100            direction: CrossingDirection::default(),
101            response: EventResponse::default(),
102            root_method: RootFindingMethod::default(),
103            tolerance: F::from_f64(1e-12).unwrap_or_else(|| F::epsilon()),
104            max_root_iter: 100,
105            max_count: None,
106            event_fn: Box::new(event_fn),
107        }
108    }
109
110    /// Set the crossing direction.
111    pub fn with_direction(mut self, dir: CrossingDirection) -> Self {
112        self.direction = dir;
113        self
114    }
115
116    /// Set the event response.
117    pub fn with_response(mut self, resp: EventResponse) -> Self {
118        self.response = resp;
119        self
120    }
121
122    /// Set the root-finding method.
123    pub fn with_root_method(mut self, method: RootFindingMethod) -> Self {
124        self.root_method = method;
125        self
126    }
127
128    /// Set maximum fire count.
129    pub fn with_max_count(mut self, count: usize) -> Self {
130        self.max_count = Some(count);
131        self
132    }
133
134    /// Set root-finding tolerance.
135    pub fn with_tolerance(mut self, tol: F) -> Self {
136        self.tolerance = tol;
137        self
138    }
139
140    /// Evaluate the event function.
141    pub fn evaluate(&self, t: F, y: &Array1<F>) -> F {
142        (self.event_fn)(t, y)
143    }
144}
145
146/// A detected event occurrence.
147#[derive(Debug, Clone)]
148pub struct DetectedEvent<F: IntegrateFloat> {
149    /// Name of the event that fired
150    pub name: String,
151    /// Precise time of the event
152    pub t: F,
153    /// State at the event time
154    pub y: Array1<F>,
155    /// Value of the event function (should be near zero)
156    pub g_value: F,
157    /// Direction of crossing: +1 rising, -1 falling
158    pub crossing_sign: i8,
159    /// How many times this event has fired so far
160    pub count: usize,
161}
162
163// ---------------------------------------------------------------------------
164// Event Detector
165// ---------------------------------------------------------------------------
166
167/// Multi-event detector for ODE integration.
168///
169/// Manages a collection of event definitions, tracks state between steps,
170/// and locates events precisely using root-finding algorithms.
171pub struct EventDetector<F: IntegrateFloat> {
172    /// Event definitions
173    events: Vec<EventDef<F>>,
174    /// Last evaluated g-values for each event
175    last_g: Vec<Option<F>>,
176    /// Fire counts for each event
177    fire_counts: Vec<usize>,
178    /// All detected events in chronological order
179    pub detected: Vec<DetectedEvent<F>>,
180}
181
182impl<F: IntegrateFloat> EventDetector<F> {
183    /// Create an empty event detector.
184    pub fn new() -> Self {
185        EventDetector {
186            events: Vec::new(),
187            last_g: Vec::new(),
188            fire_counts: Vec::new(),
189            detected: Vec::new(),
190        }
191    }
192
193    /// Add an event definition. Returns self for chaining.
194    pub fn add_event(mut self, event: EventDef<F>) -> Self {
195        self.events.push(event);
196        self.last_g.push(None);
197        self.fire_counts.push(0);
198        self
199    }
200
201    /// Number of registered events.
202    pub fn n_events(&self) -> usize {
203        self.events.len()
204    }
205
206    /// Initialize at t0, y0 (must be called before check_step).
207    pub fn initialize(&mut self, t: F, y: &Array1<F>) {
208        for (i, ev) in self.events.iter().enumerate() {
209            self.last_g[i] = Some(ev.evaluate(t, y));
210        }
211    }
212
213    /// Check for events between (t_old, y_old) and (t_new, y_new).
214    ///
215    /// If an interpolant is provided, it is used for precise event location;
216    /// otherwise linear interpolation between the endpoints is used.
217    ///
218    /// Returns `true` if a terminal event was detected (integration should stop).
219    pub fn check_step<I>(
220        &mut self,
221        t_old: F,
222        y_old: &Array1<F>,
223        t_new: F,
224        y_new: &Array1<F>,
225        interpolant: Option<&I>,
226    ) -> IntegrateResult<bool>
227    where
228        I: Fn(F) -> Array1<F>,
229    {
230        let mut terminal = false;
231
232        // Collect candidate events that have a sign change
233        let mut candidates: Vec<(usize, F, F)> = Vec::new(); // (index, g_old, g_new)
234
235        for (i, ev) in self.events.iter().enumerate() {
236            // Check max count
237            if let Some(max) = ev.max_count {
238                if self.fire_counts[i] >= max {
239                    continue;
240                }
241            }
242
243            let g_old = match self.last_g[i] {
244                Some(g) => g,
245                None => {
246                    let g = ev.evaluate(t_old, y_old);
247                    self.last_g[i] = Some(g);
248                    g
249                }
250            };
251
252            let g_new = ev.evaluate(t_new, y_new);
253
254            // Check for sign change
255            let rising = g_old < F::zero() && g_new >= F::zero();
256            let falling = g_old > F::zero() && g_new <= F::zero();
257
258            let triggered = match ev.direction {
259                CrossingDirection::Rising => rising,
260                CrossingDirection::Falling => falling,
261                CrossingDirection::Both => rising || falling,
262            };
263
264            if triggered {
265                candidates.push((i, g_old, g_new));
266            }
267
268            // Update last_g
269            self.last_g[i] = Some(g_new);
270        }
271
272        // Sort candidates by estimated event time (linear interpolation estimate)
273        candidates.sort_by(|a, b| {
274            let t_a = estimate_crossing_time(t_old, t_new, a.1, a.2);
275            let t_b = estimate_crossing_time(t_old, t_new, b.1, b.2);
276            t_a.partial_cmp(&t_b).unwrap_or(std::cmp::Ordering::Equal)
277        });
278
279        // Process candidates in chronological order
280        for (idx, g_old, g_new) in candidates {
281            let ev = &self.events[idx];
282
283            // Find precise event time using root-finding
284            let (t_event, y_event, g_event) = match ev.root_method {
285                RootFindingMethod::Bisection => bisection_root(
286                    ev,
287                    t_old,
288                    y_old,
289                    t_new,
290                    y_new,
291                    g_old,
292                    g_new,
293                    interpolant,
294                    ev.tolerance,
295                    ev.max_root_iter,
296                )?,
297                RootFindingMethod::Illinois => illinois_root(
298                    ev,
299                    t_old,
300                    y_old,
301                    t_new,
302                    y_new,
303                    g_old,
304                    g_new,
305                    interpolant,
306                    ev.tolerance,
307                    ev.max_root_iter,
308                )?,
309                RootFindingMethod::Brent => brent_root(
310                    ev,
311                    t_old,
312                    y_old,
313                    t_new,
314                    y_new,
315                    g_old,
316                    g_new,
317                    interpolant,
318                    ev.tolerance,
319                    ev.max_root_iter,
320                )?,
321            };
322
323            let crossing_sign = if g_old < F::zero() { 1i8 } else { -1i8 };
324
325            self.fire_counts[idx] += 1;
326            let count = self.fire_counts[idx];
327
328            self.detected.push(DetectedEvent {
329                name: ev.name.clone(),
330                t: t_event,
331                y: y_event,
332                g_value: g_event,
333                crossing_sign,
334                count,
335            });
336
337            if ev.response == EventResponse::Terminate {
338                terminal = true;
339            }
340        }
341
342        Ok(terminal)
343    }
344
345    /// Get all detected events.
346    pub fn get_detected(&self) -> &[DetectedEvent<F>] {
347        &self.detected
348    }
349
350    /// Get events by name.
351    pub fn events_by_name(&self, name: &str) -> Vec<&DetectedEvent<F>> {
352        self.detected.iter().filter(|e| e.name == name).collect()
353    }
354
355    /// Get the first terminal event (if any).
356    pub fn first_terminal_event(&self) -> Option<&DetectedEvent<F>> {
357        for det in &self.detected {
358            for ev in &self.events {
359                if ev.name == det.name && ev.response == EventResponse::Terminate {
360                    return Some(det);
361                }
362            }
363        }
364        None
365    }
366}
367
368impl<F: IntegrateFloat> Default for EventDetector<F> {
369    fn default() -> Self {
370        Self::new()
371    }
372}
373
374// ---------------------------------------------------------------------------
375// Root-finding helpers
376// ---------------------------------------------------------------------------
377
378/// Estimate crossing time via linear interpolation.
379fn estimate_crossing_time<F: IntegrateFloat>(t_a: F, t_b: F, g_a: F, g_b: F) -> F {
380    if (g_b - g_a).abs() < F::from_f64(1e-30).unwrap_or_else(|| F::epsilon()) {
381        (t_a + t_b) / (F::one() + F::one())
382    } else {
383        t_a - g_a * (t_b - t_a) / (g_b - g_a)
384    }
385}
386
387/// Interpolate state at time t between (t_old, y_old) and (t_new, y_new).
388fn interpolate_state<F: IntegrateFloat, I>(
389    t: F,
390    t_old: F,
391    y_old: &Array1<F>,
392    t_new: F,
393    y_new: &Array1<F>,
394    interpolant: Option<&I>,
395) -> Array1<F>
396where
397    I: Fn(F) -> Array1<F>,
398{
399    if let Some(interp) = interpolant {
400        interp(t)
401    } else {
402        // Linear interpolation
403        let dt = t_new - t_old;
404        if dt.abs() < F::from_f64(1e-30).unwrap_or_else(|| F::epsilon()) {
405            y_old.clone()
406        } else {
407            let s = (t - t_old) / dt;
408            y_old * (F::one() - s) + y_new * s
409        }
410    }
411}
412
413/// Bisection root-finding for event location.
414#[allow(clippy::too_many_arguments)]
415fn bisection_root<F: IntegrateFloat, I>(
416    ev: &EventDef<F>,
417    t_old: F,
418    y_old: &Array1<F>,
419    t_new: F,
420    y_new: &Array1<F>,
421    g_old: F,
422    _g_new: F,
423    interpolant: Option<&I>,
424    tol: F,
425    max_iter: usize,
426) -> IntegrateResult<(F, Array1<F>, F)>
427where
428    I: Fn(F) -> Array1<F>,
429{
430    let mut a = t_old;
431    let mut b = t_new;
432    let mut ga = g_old;
433
434    let two = F::one() + F::one();
435    let mut t_mid = (a + b) / two;
436    let mut y_mid;
437    let mut g_mid = F::zero();
438
439    for _ in 0..max_iter {
440        t_mid = (a + b) / two;
441        y_mid = interpolate_state(t_mid, t_old, y_old, t_new, y_new, interpolant);
442        g_mid = ev.evaluate(t_mid, &y_mid);
443
444        if g_mid.abs() < tol || (b - a) < tol {
445            return Ok((t_mid, y_mid, g_mid));
446        }
447
448        if ga * g_mid < F::zero() {
449            b = t_mid;
450        } else {
451            a = t_mid;
452            ga = g_mid;
453        }
454    }
455
456    let y_final = interpolate_state(t_mid, t_old, y_old, t_new, y_new, interpolant);
457    Ok((t_mid, y_final, g_mid))
458}
459
460/// Illinois method (modified regula falsi) for event location.
461///
462/// The Illinois method modifies the regula falsi method by halving the
463/// function value at the retained endpoint when the same endpoint is
464/// retained twice. This prevents the "stalling" behavior of standard
465/// regula falsi and guarantees superlinear convergence.
466#[allow(clippy::too_many_arguments)]
467fn illinois_root<F: IntegrateFloat, I>(
468    ev: &EventDef<F>,
469    t_old: F,
470    y_old: &Array1<F>,
471    t_new: F,
472    y_new: &Array1<F>,
473    g_old: F,
474    g_new: F,
475    interpolant: Option<&I>,
476    tol: F,
477    max_iter: usize,
478) -> IntegrateResult<(F, Array1<F>, F)>
479where
480    I: Fn(F) -> Array1<F>,
481{
482    let mut a = t_old;
483    let mut b = t_new;
484    let mut ga = g_old;
485    let mut gb = g_new;
486    let mut last_side: i8 = 0; // 0 = none, 1 = left retained, -1 = right retained
487
488    let two = F::one() + F::one();
489    let mut t_c = (a + b) / two;
490    let mut g_c = F::zero();
491
492    for _ in 0..max_iter {
493        // Regula falsi step
494        let dg = gb - ga;
495        if dg.abs() < F::from_f64(1e-30).unwrap_or_else(|| F::epsilon()) {
496            t_c = (a + b) / two;
497        } else {
498            t_c = a - ga * (b - a) / dg;
499        }
500
501        // Clamp to interval
502        if t_c <= a || t_c >= b {
503            t_c = (a + b) / two;
504        }
505
506        let y_c = interpolate_state(t_c, t_old, y_old, t_new, y_new, interpolant);
507        g_c = ev.evaluate(t_c, &y_c);
508
509        if g_c.abs() < tol || (b - a) < tol {
510            return Ok((t_c, y_c, g_c));
511        }
512
513        if ga * g_c < F::zero() {
514            // Root is in [a, t_c]
515            b = t_c;
516            gb = g_c;
517
518            if last_side == 1 {
519                // Illinois modification: halve ga
520                ga /= two;
521            }
522            last_side = 1;
523        } else {
524            // Root is in [t_c, b]
525            a = t_c;
526            ga = g_c;
527
528            if last_side == -1 {
529                // Illinois modification: halve gb
530                gb /= two;
531            }
532            last_side = -1;
533        }
534    }
535
536    let y_final = interpolate_state(t_c, t_old, y_old, t_new, y_new, interpolant);
537    Ok((t_c, y_final, g_c))
538}
539
540/// Brent's method for event location.
541///
542/// Combines bisection, secant method, and inverse quadratic interpolation.
543/// Guaranteed to converge and typically faster than bisection.
544#[allow(clippy::too_many_arguments)]
545fn brent_root<F: IntegrateFloat, I>(
546    ev: &EventDef<F>,
547    t_old: F,
548    y_old: &Array1<F>,
549    t_new: F,
550    y_new: &Array1<F>,
551    g_old: F,
552    g_new: F,
553    interpolant: Option<&I>,
554    tol: F,
555    max_iter: usize,
556) -> IntegrateResult<(F, Array1<F>, F)>
557where
558    I: Fn(F) -> Array1<F>,
559{
560    let mut a = t_old;
561    let mut b = t_new;
562    let mut fa = g_old;
563    let mut fb = g_new;
564
565    // Ensure |f(b)| <= |f(a)|
566    if fa.abs() < fb.abs() {
567        std::mem::swap(&mut a, &mut b);
568        std::mem::swap(&mut fa, &mut fb);
569    }
570
571    let mut c = a;
572    let mut fc = fa;
573    let mut d = b - a;
574    let mut e = d;
575
576    let two = F::one() + F::one();
577
578    for _ in 0..max_iter {
579        if fb.abs() < tol {
580            let y_b = interpolate_state(b, t_old, y_old, t_new, y_new, interpolant);
581            return Ok((b, y_b, fb));
582        }
583
584        if (b - a).abs() < tol {
585            let y_b = interpolate_state(b, t_old, y_old, t_new, y_new, interpolant);
586            return Ok((b, y_b, fb));
587        }
588
589        let mut s;
590
591        if fa.abs() > fb.abs() && fc.abs() > fb.abs() {
592            // Try inverse quadratic interpolation
593            if (fa - fc).abs() > F::from_f64(1e-30).unwrap_or_else(|| F::epsilon())
594                && (fb - fc).abs() > F::from_f64(1e-30).unwrap_or_else(|| F::epsilon())
595            {
596                s = a * fb * fc / ((fa - fb) * (fa - fc))
597                    + b * fa * fc / ((fb - fa) * (fb - fc))
598                    + c * fa * fb / ((fc - fa) * (fc - fb));
599            } else {
600                // Secant method
601                s = b - fb * (b - a) / (fb - fa);
602            }
603        } else {
604            // Secant method
605            if (fb - fa).abs() > F::from_f64(1e-30).unwrap_or_else(|| F::epsilon()) {
606                s = b - fb * (b - a) / (fb - fa);
607            } else {
608                s = (a + b) / two;
609            }
610        }
611
612        // Acceptance conditions for Brent
613        let three = F::one() + F::one() + F::one();
614        let cond1 = (s - (three * a + b) / (F::one() + three)) * (s - b) >= F::zero();
615        let cond2 = (s - b).abs() >= (b - c).abs() / two;
616        let cond3 = (b - c).abs() < tol;
617
618        if cond1 || cond2 || cond3 {
619            // Bisection
620            s = (a + b) / two;
621        }
622
623        let y_s = interpolate_state(s, t_old, y_old, t_new, y_new, interpolant);
624        let fs = ev.evaluate(s, &y_s);
625
626        c = b;
627        fc = fb;
628
629        if fa * fs < F::zero() {
630            b = s;
631            fb = fs;
632        } else {
633            a = s;
634            fa = fs;
635        }
636
637        // Ensure |f(b)| <= |f(a)|
638        if fa.abs() < fb.abs() {
639            std::mem::swap(&mut a, &mut b);
640            std::mem::swap(&mut fa, &mut fb);
641        }
642    }
643
644    let y_final = interpolate_state(b, t_old, y_old, t_new, y_new, interpolant);
645    Ok((b, y_final, fb))
646}
647
648// ---------------------------------------------------------------------------
649// Tests
650// ---------------------------------------------------------------------------
651
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Straight-line path `y(t) = y0 + slope * (t - t0)` used as a dense interpolant.
    fn line(y0: f64, slope: f64, t0: f64) -> impl Fn(f64) -> Array1<f64> {
        move |t| array![y0 + slope * (t - t0)]
    }

    /// Event function equal to the first state component.
    fn first_component(_t: f64, y: &Array1<f64>) -> f64 {
        y[0]
    }

    #[test]
    fn test_event_def_creation() {
        let def = EventDef::<f64>::new("test", first_component)
            .with_direction(CrossingDirection::Falling)
            .with_response(EventResponse::Terminate)
            .with_max_count(3);

        assert_eq!(def.name, "test");
        assert_eq!(def.direction, CrossingDirection::Falling);
        assert_eq!(def.response, EventResponse::Terminate);
        assert_eq!(def.max_count, Some(3));
    }

    #[test]
    fn test_bisection_root_finding() {
        // g = y[0] - 0.5 along y(t) = 1 - t, so the root is at t = 0.5.
        let half = EventDef::<f64>::new("half", |_t, y: &Array1<f64>| y[0] - 0.5)
            .with_direction(CrossingDirection::Falling);
        let path = line(1.0, -1.0, 0.0);
        let (y_start, y_end) = (array![1.0], array![0.0]);

        let (t_hit, y_hit, g_hit) = bisection_root(
            &half,
            0.0,
            &y_start,
            1.0,
            &y_end,
            0.5,
            -0.5,
            Some(&path),
            1e-12,
            100,
        )
        .expect("bisection should succeed");

        assert!(
            (t_hit - 0.5).abs() < 1e-10,
            "event at t = {}, expected 0.5",
            t_hit
        );
        assert!(
            (y_hit[0] - 0.5).abs() < 1e-10,
            "y at event = {}, expected 0.5",
            y_hit[0]
        );
        assert!(g_hit.abs() < 1e-10, "g at event = {}", g_hit);
    }

    #[test]
    fn test_illinois_root_finding() {
        // y(t) = cos(pi t) crosses zero at t = 0.5.
        let zero_event = EventDef::<f64>::new("zero", first_component);
        let cosine = |t: f64| array![(std::f64::consts::PI * t).cos()];
        let (y_start, y_end) = (array![1.0], array![-1.0]);

        let (t_hit, _, _) = illinois_root(
            &zero_event,
            0.0,
            &y_start,
            1.0,
            &y_end,
            1.0,
            -1.0,
            Some(&cosine),
            1e-12,
            100,
        )
        .expect("Illinois should succeed");

        assert!(
            (t_hit - 0.5).abs() < 1e-10,
            "Illinois found t = {}, expected 0.5",
            t_hit
        );
    }

    #[test]
    fn test_brent_root_finding() {
        // y(t) = 1 - t on [0, 2] crosses zero at t = 1.
        let zero_event = EventDef::<f64>::new("zero", first_component);
        let path = line(1.0, -1.0, 0.0);
        let (y_start, y_end) = (array![1.0], array![-1.0]);

        let (t_hit, _, _) = brent_root(
            &zero_event,
            0.0,
            &y_start,
            2.0,
            &y_end,
            1.0,
            -1.0,
            Some(&path),
            1e-12,
            100,
        )
        .expect("Brent should succeed");

        assert!(
            (t_hit - 1.0).abs() < 1e-10,
            "Brent found t = {}, expected 1.0",
            t_hit
        );
    }

    #[test]
    fn test_event_detector_single_event() {
        let mut det = EventDetector::new().add_event(
            EventDef::new("zero_crossing", first_component)
                .with_direction(CrossingDirection::Falling)
                .with_response(EventResponse::Terminate),
        );

        let start = array![1.0];
        det.initialize(0.0, &start);

        // y(t) = 1 - 1.5 t goes negative, crossing zero at t = 2/3.
        let end = array![-0.5];
        let path = line(1.0, -1.5, 0.0);

        let stop = det
            .check_step(0.0, &start, 1.0, &end, Some(&path))
            .expect("check_step should succeed");

        assert!(stop, "should detect terminal event");
        assert_eq!(det.detected.len(), 1);
        let hit = &det.detected[0];
        assert_eq!(hit.name, "zero_crossing");
        assert!((hit.t - 2.0 / 3.0).abs() < 1e-8, "event at t = {}", hit.t);
    }

    #[test]
    fn test_event_detector_multiple_events() {
        let mut det = EventDetector::new()
            .add_event(
                EventDef::new("event_a", |_t, y: &Array1<f64>| y[0] - 0.5)
                    .with_direction(CrossingDirection::Falling),
            )
            .add_event(
                EventDef::new("event_b", |_t, y: &Array1<f64>| y[0] - 0.25)
                    .with_direction(CrossingDirection::Falling),
            );

        let start = array![1.0];
        det.initialize(0.0, &start);

        // y falls linearly from 1 to 0: event_a at t = 0.5, event_b at t = 0.75.
        let end = array![0.0];
        let path = line(1.0, -1.0, 0.0);

        let _stop = det
            .check_step(0.0, &start, 1.0, &end, Some(&path))
            .expect("check_step should succeed");

        assert_eq!(det.detected.len(), 2);

        let (first, second) = (&det.detected[0], &det.detected[1]);
        assert!(first.t <= second.t, "events should be ordered by time");

        assert_eq!(first.name, "event_a");
        assert!((first.t - 0.5).abs() < 1e-8, "event_a at t = {}", first.t);
    }

    #[test]
    fn test_event_max_count() {
        let mut det = EventDetector::new().add_event(
            EventDef::new("bounce", first_component)
                .with_direction(CrossingDirection::Both)
                .with_max_count(2),
        );

        let up = array![1.0];
        let down = array![-1.0];
        det.initialize(0.0, &up);

        // First crossing: 1 -> -1 on [0, 1].
        let ramp_down = line(1.0, -2.0, 0.0);
        det.check_step(0.0, &up, 1.0, &down, Some(&ramp_down))
            .expect("step 1");
        assert_eq!(det.detected.len(), 1);

        // Second crossing: -1 -> 1 on [1, 2].
        let ramp_up = line(-1.0, 2.0, 1.0);
        det.check_step(1.0, &down, 2.0, &up, Some(&ramp_up))
            .expect("step 2");
        assert_eq!(det.detected.len(), 2);

        // Third crossing on [2, 3] must be suppressed by max_count = 2.
        let ramp_down_again = line(1.0, -2.0, 2.0);
        det.check_step(2.0, &up, 3.0, &down, Some(&ramp_down_again))
            .expect("step 3");
        assert_eq!(det.detected.len(), 2, "should not fire beyond max_count");
    }

    #[test]
    fn test_rising_direction_only() {
        let mut det = EventDetector::new().add_event(
            EventDef::new("rising", first_component).with_direction(CrossingDirection::Rising),
        );

        let up = array![1.0];
        let down = array![-1.0];
        det.initialize(0.0, &up);

        // Downward crossing on [0, 1]: must be ignored.
        let ramp_down = line(1.0, -2.0, 0.0);
        det.check_step(0.0, &up, 1.0, &down, Some(&ramp_down))
            .expect("step 1");
        assert_eq!(
            det.detected.len(),
            0,
            "falling should not trigger rising event"
        );

        // Upward crossing on [1, 2]: must be recorded.
        let ramp_up = line(-1.0, 2.0, 1.0);
        det.check_step(1.0, &down, 2.0, &up, Some(&ramp_up))
            .expect("step 2");
        assert_eq!(det.detected.len(), 1, "rising should trigger");
    }

    #[test]
    fn test_no_interpolant_fallback() {
        // Without an interpolant the detector interpolates the endpoints linearly.
        let mut det = EventDetector::new()
            .add_event(EventDef::new("cross", first_component).with_direction(CrossingDirection::Both));

        let start = array![1.0];
        det.initialize(0.0, &start);

        let end = array![-1.0];
        let without_dense: Option<&fn(f64) -> Array1<f64>> = None;
        det.check_step(0.0, &start, 1.0, &end, without_dense)
            .expect("no interp step");

        assert_eq!(det.detected.len(), 1);
        // Linear blend of 1 and -1 crosses zero at t = 0.5.
        let hit_t = det.detected[0].t;
        assert!((hit_t - 0.5).abs() < 1e-8, "t = {}", hit_t);
    }

    #[test]
    fn test_events_by_name() {
        let mut det = EventDetector::new()
            .add_event(EventDef::new("bounce", first_component).with_direction(CrossingDirection::Both))
            .add_event(
                EventDef::new("threshold", |_t, y: &Array1<f64>| y[0] - 0.5)
                    .with_direction(CrossingDirection::Falling),
            );

        let start = array![1.0];
        det.initialize(0.0, &start);

        let end = array![-1.0];
        let path = line(1.0, -2.0, 0.0);
        det.check_step(0.0, &start, 1.0, &end, Some(&path))
            .expect("step");

        assert_eq!(det.events_by_name("bounce").len(), 1);
        assert_eq!(det.events_by_name("threshold").len(), 1);
    }
}