Skip to main content

differential_equations/solout/
event.rs

1//! Solout implementation which takes a Event object and uses it to detect events
2
3use super::*;
4use crate::traits::DefaultState;
5
6pub struct EventConfig {
7    /// Direction of zero crossing to detect
8    pub direction: CrossingDirection,
9    /// Number of events before termination
10    pub terminate: Option<u32>,
11}
12
13impl Default for EventConfig {
14    fn default() -> Self {
15        Self {
16            direction: CrossingDirection::Both,
17            terminate: None,
18        }
19    }
20}
21
22impl EventConfig {
23    /// Create a new EventConfig with specified direction and termination count
24    pub fn new(direction: impl Into<CrossingDirection>, terminate: Option<u32>) -> Self {
25        Self {
26            direction: direction.into(),
27            terminate,
28        }
29    }
30
31    pub fn direction(mut self, direction: impl Into<CrossingDirection>) -> Self {
32        self.direction = direction.into();
33        self
34    }
35
36    /// Set the number of events before termination
37    pub fn terminate_after(mut self, count: u32) -> Self {
38        self.terminate = Some(count);
39        self
40    }
41
42    /// Set to terminate after the first event
43    pub fn terminal(mut self) -> Self {
44        self.terminate = Some(1);
45        self
46    }
47}
48
49pub trait Event<T: Real = f64, Y: State<T> = DefaultState<T>> {
50    /// Configure the event detection parameters (called once at initialization).
51    fn config(&self) -> EventConfig {
52        EventConfig::default()
53    }
54
55    /// Event function g(t,y) whose zero crossings are detected.
56    fn event(&self, t: T, y: &Y) -> T;
57}
58
59/// Solout implementation that evaluates user-provided Event objects similar to SciPy events.
60///
61/// The `EventSolout` monitors the sign of `event.event(t, y)` across solver steps. When a sign
62/// change consistent with the configured `CrossingDirection` is detected, a Brent-Dekker root
63/// finding procedure is applied (using the solver's interpolation) to locate the event time with
64/// high accuracy. The event point `(t_event, y_event)` is then appended to the solution. Depending
65/// on the `EventConfig::terminate` setting the integration may terminate after collecting the
66/// specified number of events.
67pub struct EventSolout<'a, T: Real, Y: State<T>, E: Event<T, Y> + ?Sized> {
68    /// User provided event object implementing `Event`
69    event: &'a E,
70    /// Configuration (direction filtering and termination count)
71    config: EventConfig,
72    /// Last event function value g(t_prev, y_prev)
73    last_g: Option<T>,
74    /// Number of events detected so far
75    event_count: u32,
76    /// Integration direction cached (+1 or -1)
77    direction: T,
78    /// Tolerance factor for root finding termination
79    rel_tol: T,
80    /// Absolute tolerance floor for root search
81    abs_tol: T,
82    /// State type marker
83    _marker: std::marker::PhantomData<Y>,
84}
85
86impl<'a, T: Real, Y: State<T>, E: Event<T, Y> + ?Sized> EventSolout<'a, T, Y, E> {
87    pub fn new(event: &'a E, t0: T, tf: T) -> Self {
88        let direction = (tf - t0).signum();
89        let config = event.config();
90        EventSolout {
91            event,
92            config,
93            last_g: None,
94            event_count: 0,
95            direction,
96            rel_tol: T::from_f64(1e-12).unwrap_or(T::default_epsilon()),
97            abs_tol: T::from_f64(1e-14).unwrap_or(T::default_epsilon()),
98            _marker: std::marker::PhantomData,
99        }
100    }
101
102    /// Brent-Dekker root finding for locating g(t)=0 within [a,b] where g(a)*g(b) <= 0.
103    /// Uses interpolator to obtain y(t) for evaluating g(t) = event.event(t, y(t)).
104    fn brent_dekker<I>(
105        &mut self,
106        mut a: T,
107        mut b: T,
108        mut fa: T,
109        mut fb: T,
110        interpolator: &mut I,
111    ) -> Option<T>
112    where
113        I: Interpolation<T, Y> + ?Sized,
114    {
115        // Ensure that |f(a)| < |f(b)| swapping if necessary
116        if fa.abs() < fb.abs() {
117            std::mem::swap(&mut a, &mut b);
118            std::mem::swap(&mut fa, &mut fb);
119        }
120
121        let mut c = a;
122        let mut fc = fa;
123        let mut d = b - a;
124        let mut e = d;
125
126        let one = T::one();
127        let two = T::from_f64(2.0).unwrap();
128        let half = one / two;
129        let three = T::from_f64(3.0).unwrap();
130
131        let max_iter = 50u32;
132        for _ in 0..max_iter {
133            if fb == T::zero() {
134                return Some(b);
135            }
136            if fa.signum() == fb.signum() {
137                // Rename a -> c
138                a = c;
139                fa = fc;
140                c = b;
141                fc = fb;
142                d = b - a;
143                e = d;
144            }
145            if fa.abs() < fb.abs() {
146                c = b;
147                b = a;
148                a = c;
149                fc = fb;
150                fb = fa;
151                fa = fc;
152            }
153
154            // Convergence check
155            let tol = self.abs_tol.max(self.rel_tol * b.abs());
156            let m = half * (a - b);
157            if m.abs() <= tol || fb == T::zero() {
158                return Some(b);
159            }
160
161            // Attempt inverse quadratic interpolation or secant
162            let mut use_bisection = true;
163            if e.abs() > tol && fa.abs() > fb.abs() {
164                // Inverse quadratic interpolation
165                let s = fb / fa;
166                let p;
167                let q;
168                if a == c {
169                    // Secant method
170                    p = two * m * s;
171                    q = one - s;
172                } else {
173                    // Inverse quadratic interpolation
174                    let q1 = fa / fc;
175                    let r = fb / fc;
176                    p = s * (two * m * q1 * (q1 - r) - (b - a) * (r - one));
177                    q = (q1 - one) * (r - one) * (s - one);
178                }
179                let mut q_mod = q;
180                let mut p_mod = p;
181                if q_mod > T::zero() {
182                    p_mod = -p_mod;
183                } else {
184                    q_mod = -q_mod;
185                }
186                // Accept interpolation only if conditions satisfied
187                if (two * p_mod).abs() < (three * m * q_mod - (tol * q_mod).abs())
188                    && p_mod < (e * half * q_mod).abs()
189                {
190                    e = d;
191                    d = p_mod / q_mod;
192                    use_bisection = false;
193                }
194            }
195            if use_bisection {
196                d = m;
197                e = m;
198            }
199            // Move last best b
200            a = b;
201            fa = fb;
202            if d.abs() > tol {
203                b += d;
204            } else {
205                b += if m > T::zero() { tol } else { -tol };
206            }
207            // Evaluate at new b via interpolation
208            let yb = interpolator.interpolate(b).ok()?;
209            fb = self.event.event(b, &yb);
210            c = a;
211            fc = fa;
212        }
213        None
214    }
215}
216
217impl<'a, T, Y, E> Solout<T, Y> for EventSolout<'a, T, Y, E>
218where
219    T: Real,
220    Y: State<T>,
221    E: Event<T, Y> + ?Sized,
222{
223    fn solout<I>(
224        &mut self,
225        t_curr: T,
226        t_prev: T,
227        y_curr: &Y,
228        y_prev: &Y,
229        interpolator: &mut I,
230        solution: &mut Solution<T, Y>,
231    ) -> ControlFlag<T, Y>
232    where
233        I: Interpolation<T, Y> + ?Sized,
234    {
235        // Evaluate event function at current endpoint
236        let g_curr = self.event.event(t_curr, y_curr);
237
238        // Initialize previous value if first call
239        let g_prev = match self.last_g {
240            Some(g) => g,
241            None => {
242                let g0 = self.event.event(t_prev, y_prev);
243                self.last_g = Some(g0);
244                // We don't attempt detection on first initialization
245                self.last_g = Some(g_curr);
246                return ControlFlag::Continue;
247            }
248        };
249
250        // Detect sign change according to direction config
251        let zero = T::zero();
252        let sign_change = g_prev.signum() != g_curr.signum();
253
254        let direction_ok = match self.config.direction {
255            CrossingDirection::Both => sign_change,
256            CrossingDirection::Positive => sign_change && g_prev < zero && g_curr >= zero,
257            CrossingDirection::Negative => sign_change && g_prev > zero && g_curr <= zero,
258        };
259
260        if direction_ok {
261            // Root find for precise event time
262            let (mut a, mut b, mut fa, mut fb) = (t_prev, t_curr, g_prev, g_curr);
263            // Ensure bracket ordering consistent with integration direction
264            if (self.direction > zero && a > b) || (self.direction < zero && a < b) {
265                std::mem::swap(&mut a, &mut b);
266                std::mem::swap(&mut fa, &mut fb);
267            }
268
269            // Only proceed if fa*fb <= 0
270            if fa * fb <= zero
271                && let Some(t_event) = self.brent_dekker(a, b, fa, fb, interpolator)
272            {
273                let y_event = interpolator.interpolate(t_event).unwrap();
274                // Avoid duplicate near-equal times
275                let push_point = match solution.t.last() {
276                    Some(&last_t) => (t_event - last_t).abs() > self.abs_tol,
277                    None => true,
278                };
279                if push_point {
280                    solution.push(t_event, y_event);
281                }
282                self.event_count += 1;
283
284                if let Some(limit) = self.config.terminate
285                    && self.event_count >= limit
286                {
287                    self.last_g = Some(g_curr);
288                    return ControlFlag::Terminate;
289                }
290            }
291        }
292
293        self.last_g = Some(g_curr);
294        ControlFlag::Continue
295    }
296}
297
298/// Wrapper solout that decorates an existing solout with event detection.
299pub struct EventWrappedSolout<'a, T: Real, Y: State<T>, O, E>
300where
301    O: Solout<T, Y>,
302    E: Event<T, Y> + ?Sized,
303{
304    base: O,
305    event: &'a E,
306    config: EventConfig,
307    last_g: Option<T>,
308    event_count: u32,
309    direction: T,
310    rel_tol: T,
311    abs_tol: T,
312    _marker: std::marker::PhantomData<Y>,
313}
314
315impl<'a, T: Real, Y: State<T>, O, E> EventWrappedSolout<'a, T, Y, O, E>
316where
317    O: Solout<T, Y>,
318    E: Event<T, Y> + ?Sized,
319{
320    pub fn new(base: O, event: &'a E, t0: T, tf: T) -> Self {
321        let config = event.config();
322        EventWrappedSolout {
323            base,
324            event,
325            config,
326            last_g: None,
327            event_count: 0,
328            direction: (tf - t0).signum(),
329            rel_tol: T::from_f64(1e-12).unwrap_or(T::default_epsilon()),
330            abs_tol: T::from_f64(1e-14).unwrap_or(T::default_epsilon()),
331            _marker: std::marker::PhantomData,
332        }
333    }
334
335    fn detect_event<I>(
336        &mut self,
337        t_curr: T,
338        t_prev: T,
339        y_curr: &Y,
340        y_prev: &Y,
341        interpolator: &mut I,
342        solution: &mut Solution<T, Y>,
343    ) -> ControlFlag<T, Y>
344    where
345        I: Interpolation<T, Y> + ?Sized,
346    {
347        let g_curr = self.event.event(t_curr, y_curr);
348        let g_prev = match self.last_g {
349            Some(g) => g,
350            None => {
351                let g0 = self.event.event(t_prev, y_prev);
352                self.last_g = Some(g0);
353                self.last_g = Some(g_curr);
354                return ControlFlag::Continue;
355            }
356        };
357
358        let zero = T::zero();
359        let sign_change = g_prev.signum() != g_curr.signum();
360        let direction_ok = match self.config.direction {
361            CrossingDirection::Both => sign_change,
362            CrossingDirection::Positive => sign_change && g_prev < zero && g_curr >= zero,
363            CrossingDirection::Negative => sign_change && g_prev > zero && g_curr <= zero,
364        };
365        if direction_ok {
366            let (mut a, mut b, mut fa, mut fb) = (t_prev, t_curr, g_prev, g_curr);
367            if (self.direction > zero && a > b) || (self.direction < zero && a < b) {
368                std::mem::swap(&mut a, &mut b);
369                std::mem::swap(&mut fa, &mut fb);
370            }
371            if fa * fb <= zero
372                && let Some(t_event) = self.brent_dekker(a, b, fa, fb, interpolator)
373            {
374                let y_event = interpolator.interpolate(t_event).unwrap();
375                let push_point = match solution.t.last() {
376                    Some(&last_t) => (t_event - last_t).abs() > self.abs_tol,
377                    None => true,
378                };
379                if push_point {
380                    solution.push(t_event, y_event);
381                }
382                self.event_count += 1;
383                if let Some(limit) = self.config.terminate
384                    && self.event_count >= limit
385                {
386                    self.last_g = Some(g_curr);
387                    return ControlFlag::Terminate;
388                }
389            }
390        }
391        self.last_g = Some(g_curr);
392        ControlFlag::Continue
393    }
394
395    fn brent_dekker<I>(
396        &mut self,
397        mut a: T,
398        mut b: T,
399        mut fa: T,
400        mut fb: T,
401        interpolator: &mut I,
402    ) -> Option<T>
403    where
404        I: Interpolation<T, Y> + ?Sized,
405    {
406        if fa.abs() < fb.abs() {
407            std::mem::swap(&mut a, &mut b);
408            std::mem::swap(&mut fa, &mut fb);
409        }
410        let mut c = a;
411        let mut fc = fa;
412        let mut d = b - a;
413        let mut e = d;
414        let one = T::one();
415        let two = T::from_f64(2.0).unwrap();
416        let half = one / two;
417        let three = T::from_f64(3.0).unwrap();
418        for _ in 0..50u32 {
419            if fb == T::zero() {
420                return Some(b);
421            }
422            if fa.signum() == fb.signum() {
423                a = c;
424                fa = fc;
425                c = b;
426                fc = fb;
427                d = b - a;
428                e = d;
429            }
430            if fa.abs() < fb.abs() {
431                c = b;
432                b = a;
433                a = c;
434                fc = fb;
435                fb = fa;
436                fa = fc;
437            }
438            let tol = self.abs_tol.max(self.rel_tol * b.abs());
439            let m = half * (a - b);
440            if m.abs() <= tol || fb == T::zero() {
441                return Some(b);
442            }
443            let mut use_bis = true;
444            if e.abs() > tol && fa.abs() > fb.abs() {
445                let s = fb / fa;
446                let p;
447                let q;
448                if a == c {
449                    p = two * m * s;
450                    q = one - s;
451                } else {
452                    let q1 = fa / fc;
453                    let r = fb / fc;
454                    p = s * (two * m * q1 * (q1 - r) - (b - a) * (r - one));
455                    q = (q1 - one) * (r - one) * (s - one);
456                }
457                let mut q_mod = q;
458                let mut p_mod = p;
459                if q_mod > T::zero() {
460                    p_mod = -p_mod;
461                } else {
462                    q_mod = -q_mod;
463                }
464                if (two * p_mod).abs() < (three * m * q_mod - (tol * q_mod).abs())
465                    && p_mod < (e * half * q_mod).abs()
466                {
467                    e = d;
468                    d = p_mod / q_mod;
469                    use_bis = false;
470                }
471            }
472            if use_bis {
473                d = m;
474                e = m;
475            }
476            a = b;
477            fa = fb;
478            b = if d.abs() > tol {
479                b + d
480            } else {
481                b + if m > T::zero() { tol } else { -tol }
482            };
483            let yb = interpolator.interpolate(b).ok()?;
484            fb = self.event.event(b, &yb);
485            c = a;
486            fc = fa;
487        }
488        None
489    }
490}
491
492impl<'a, T, Y, O, E> Solout<T, Y> for EventWrappedSolout<'a, T, Y, O, E>
493where
494    T: Real,
495    Y: State<T>,
496    O: Solout<T, Y>,
497    E: Event<T, Y> + ?Sized,
498{
499    fn solout<I>(
500        &mut self,
501        t_curr: T,
502        t_prev: T,
503        y_curr: &Y,
504        y_prev: &Y,
505        interpolator: &mut I,
506        solution: &mut Solution<T, Y>,
507    ) -> ControlFlag<T, Y>
508    where
509        I: Interpolation<T, Y> + ?Sized,
510    {
511        let flag = self
512            .base
513            .solout(t_curr, t_prev, y_curr, y_prev, interpolator, solution);
514        if let ControlFlag::Terminate = flag {
515            return flag;
516        }
517        let evt_flag = self.detect_event(t_curr, t_prev, y_curr, y_prev, interpolator, solution);
518        match evt_flag {
519            ControlFlag::Terminate => ControlFlag::Terminate,
520            _ => flag,
521        }
522    }
523}