differential_equations/solout/
event.rs

//! Solout implementation that takes an `Event` object and uses it to detect events.
2
3use super::*;
4
5pub struct EventConfig {
6    /// Direction of zero crossing to detect
7    pub direction: CrossingDirection,
8    /// Number of events before termination
9    pub terminate: Option<u32>,
10}
11
12impl Default for EventConfig {
13    fn default() -> Self {
14        Self {
15            direction: CrossingDirection::Both,
16            terminate: None,
17        }
18    }
19}
20
21impl EventConfig {
22    /// Create a new EventConfig with specified direction and termination count
23    pub fn new(direction: impl Into<CrossingDirection>, terminate: Option<u32>) -> Self {
24        Self {
25            direction: direction.into(),
26            terminate,
27        }
28    }
29
30    pub fn direction(mut self, direction: impl Into<CrossingDirection>) -> Self {
31        self.direction = direction.into();
32        self
33    }
34
35    /// Set the number of events before termination
36    pub fn terminate_after(mut self, count: u32) -> Self {
37        self.terminate = Some(count);
38        self
39    }
40
41    /// Set to terminate after the first event
42    pub fn terminal(mut self) -> Self {
43        self.terminate = Some(1);
44        self
45    }
46}
47
48pub trait Event<T: Real = f64, Y: State<T> = f64> {
49    /// Configure the event detection parameters (called once at initialization).
50    fn config(&self) -> EventConfig {
51        EventConfig::default()
52    }
53
54    /// Event function g(t,y) whose zero crossings are detected.
55    fn event(&self, t: T, y: &Y) -> T;
56}
57
58/// Solout implementation that evaluates user-provided Event objects similar to SciPy events.
59///
60/// The `EventSolout` monitors the sign of `event.event(t, y)` across solver steps. When a sign
61/// change consistent with the configured `CrossingDirection` is detected, a Brent-Dekker root
62/// finding procedure is applied (using the solver's interpolation) to locate the event time with
63/// high accuracy. The event point `(t_event, y_event)` is then appended to the solution. Depending
64/// on the `EventConfig::terminate` setting the integration may terminate after collecting the
65/// specified number of events.
66pub struct EventSolout<'a, T: Real, Y: State<T>, E: Event<T, Y>> {
67    /// User provided event object implementing `Event`
68    event: &'a E,
69    /// Configuration (direction filtering and termination count)
70    config: EventConfig,
71    /// Last event function value g(t_prev, y_prev)
72    last_g: Option<T>,
73    /// Number of events detected so far
74    event_count: u32,
75    /// Integration direction cached (+1 or -1)
76    direction: T,
77    /// Tolerance factor for root finding termination
78    rel_tol: T,
79    /// Absolute tolerance floor for root search
80    abs_tol: T,
81    /// State type marker
82    _marker: std::marker::PhantomData<Y>,
83}
84
85impl<'a, T: Real, Y: State<T>, E: Event<T, Y>> EventSolout<'a, T, Y, E> {
86    pub fn new(event: &'a E, t0: T, tf: T) -> Self {
87        let direction = (tf - t0).signum();
88        let config = event.config();
89        EventSolout {
90            event,
91            config,
92            last_g: None,
93            event_count: 0,
94            direction,
95            rel_tol: T::from_f64(1e-12).unwrap_or(T::default_epsilon()),
96            abs_tol: T::from_f64(1e-14).unwrap_or(T::default_epsilon()),
97            _marker: std::marker::PhantomData,
98        }
99    }
100
101    /// Brent-Dekker root finding for locating g(t)=0 within [a,b] where g(a)*g(b) <= 0.
102    /// Uses interpolator to obtain y(t) for evaluating g(t) = event.event(t, y(t)).
103    fn brent_dekker<I>(
104        &mut self,
105        mut a: T,
106        mut b: T,
107        mut fa: T,
108        mut fb: T,
109        interpolator: &mut I,
110    ) -> Option<T>
111    where
112        I: Interpolation<T, Y>,
113    {
114        // Ensure that |f(a)| < |f(b)| swapping if necessary
115        if fa.abs() < fb.abs() {
116            std::mem::swap(&mut a, &mut b);
117            std::mem::swap(&mut fa, &mut fb);
118        }
119
120        let mut c = a;
121        let mut fc = fa;
122        let mut d = b - a;
123        let mut e = d;
124
125        let one = T::one();
126        let two = T::from_f64(2.0).unwrap();
127        let half = one / two;
128        let three = T::from_f64(3.0).unwrap();
129
130        let max_iter = 50u32;
131        for _ in 0..max_iter {
132            if fb == T::zero() {
133                return Some(b);
134            }
135            if fa.signum() == fb.signum() {
136                // Rename a -> c
137                a = c;
138                fa = fc;
139                c = b;
140                fc = fb;
141                d = b - a;
142                e = d;
143            }
144            if fa.abs() < fb.abs() {
145                c = b;
146                b = a;
147                a = c;
148                fc = fb;
149                fb = fa;
150                fa = fc;
151            }
152
153            // Convergence check
154            let tol = self.abs_tol.max(self.rel_tol * b.abs());
155            let m = half * (a - b);
156            if m.abs() <= tol || fb == T::zero() {
157                return Some(b);
158            }
159
160            // Attempt inverse quadratic interpolation or secant
161            let mut use_bisection = true;
162            if e.abs() > tol && fa.abs() > fb.abs() {
163                // Inverse quadratic interpolation
164                let s = fb / fa;
165                let p;
166                let q;
167                if a == c {
168                    // Secant method
169                    p = two * m * s;
170                    q = one - s;
171                } else {
172                    // Inverse quadratic interpolation
173                    let q1 = fa / fc;
174                    let r = fb / fc;
175                    p = s * (two * m * q1 * (q1 - r) - (b - a) * (r - one));
176                    q = (q1 - one) * (r - one) * (s - one);
177                }
178                let mut q_mod = q;
179                let mut p_mod = p;
180                if q_mod > T::zero() {
181                    p_mod = -p_mod;
182                } else {
183                    q_mod = -q_mod;
184                }
185                // Accept interpolation only if conditions satisfied
186                if (two * p_mod).abs() < (three * m * q_mod - (tol * q_mod).abs())
187                    && p_mod < (e * half * q_mod).abs()
188                {
189                    e = d;
190                    d = p_mod / q_mod;
191                    use_bisection = false;
192                }
193            }
194            if use_bisection {
195                d = m;
196                e = m;
197            }
198            // Move last best b
199            a = b;
200            fa = fb;
201            if d.abs() > tol {
202                b = b + d;
203            } else {
204                b = b + if m > T::zero() { tol } else { -tol };
205            }
206            // Evaluate at new b via interpolation
207            let yb = interpolator.interpolate(b).ok()?;
208            fb = self.event.event(b, &yb);
209            c = a;
210            fc = fa;
211        }
212        None
213    }
214}
215
216impl<'a, T, Y, E> Solout<T, Y> for EventSolout<'a, T, Y, E>
217where
218    T: Real,
219    Y: State<T>,
220    E: Event<T, Y>,
221{
222    fn solout<I>(
223        &mut self,
224        t_curr: T,
225        t_prev: T,
226        y_curr: &Y,
227        y_prev: &Y,
228        interpolator: &mut I,
229        solution: &mut Solution<T, Y>,
230    ) -> ControlFlag<T, Y>
231    where
232        I: Interpolation<T, Y>,
233    {
234        // Evaluate event function at current endpoint
235        let g_curr = self.event.event(t_curr, y_curr);
236
237        // Initialize previous value if first call
238        let g_prev = match self.last_g {
239            Some(g) => g,
240            None => {
241                let g0 = self.event.event(t_prev, y_prev);
242                self.last_g = Some(g0);
243                // We don't attempt detection on first initialization
244                self.last_g = Some(g_curr);
245                return ControlFlag::Continue;
246            }
247        };
248
249        // Detect sign change according to direction config
250        let zero = T::zero();
251        let sign_change = g_prev.signum() != g_curr.signum();
252
253        let direction_ok = match self.config.direction {
254            CrossingDirection::Both => sign_change,
255            CrossingDirection::Positive => sign_change && g_prev < zero && g_curr >= zero,
256            CrossingDirection::Negative => sign_change && g_prev > zero && g_curr <= zero,
257        };
258
259        if direction_ok {
260            // Root find for precise event time
261            let (mut a, mut b, mut fa, mut fb) = (t_prev, t_curr, g_prev, g_curr);
262            // Ensure bracket ordering consistent with integration direction
263            if (self.direction > zero && a > b) || (self.direction < zero && a < b) {
264                std::mem::swap(&mut a, &mut b);
265                std::mem::swap(&mut fa, &mut fb);
266            }
267
268            // Only proceed if fa*fb <= 0
269            if fa * fb <= zero {
270                if let Some(t_event) = self.brent_dekker(a, b, fa, fb, interpolator) {
271                    let y_event = interpolator.interpolate(t_event).unwrap();
272                    // Avoid duplicate near-equal times
273                    let push_point = match solution.t.last() {
274                        Some(&last_t) => (t_event - last_t).abs() > self.abs_tol,
275                        None => true,
276                    };
277                    if push_point {
278                        solution.push(t_event, y_event);
279                    }
280                    self.event_count += 1;
281
282                    if let Some(limit) = self.config.terminate {
283                        if self.event_count >= limit {
284                            self.last_g = Some(g_curr);
285                            return ControlFlag::Terminate;
286                        }
287                    }
288                }
289            }
290        }
291
292        self.last_g = Some(g_curr);
293        ControlFlag::Continue
294    }
295}
296
297/// Wrapper solout that decorates an existing solout with event detection.
298pub struct EventWrappedSolout<'a, T: Real, Y: State<T>, O, E>
299where
300    O: Solout<T, Y>,
301    E: Event<T, Y>,
302{
303    base: O,
304    event: &'a E,
305    config: EventConfig,
306    last_g: Option<T>,
307    event_count: u32,
308    direction: T,
309    rel_tol: T,
310    abs_tol: T,
311    _marker: std::marker::PhantomData<Y>,
312}
313
314impl<'a, T: Real, Y: State<T>, O, E> EventWrappedSolout<'a, T, Y, O, E>
315where
316    O: Solout<T, Y>,
317    E: Event<T, Y>,
318{
319    pub fn new(base: O, event: &'a E, t0: T, tf: T) -> Self {
320        let config = event.config();
321        EventWrappedSolout {
322            base,
323            event,
324            config,
325            last_g: None,
326            event_count: 0,
327            direction: (tf - t0).signum(),
328            rel_tol: T::from_f64(1e-12).unwrap_or(T::default_epsilon()),
329            abs_tol: T::from_f64(1e-14).unwrap_or(T::default_epsilon()),
330            _marker: std::marker::PhantomData,
331        }
332    }
333
334    fn detect_event<I>(
335        &mut self,
336        t_curr: T,
337        t_prev: T,
338        y_curr: &Y,
339        y_prev: &Y,
340        interpolator: &mut I,
341        solution: &mut Solution<T, Y>,
342    ) -> ControlFlag<T, Y>
343    where
344        I: Interpolation<T, Y>,
345    {
346        let g_curr = self.event.event(t_curr, y_curr);
347        let g_prev = match self.last_g {
348            Some(g) => g,
349            None => {
350                let g0 = self.event.event(t_prev, y_prev);
351                self.last_g = Some(g0);
352                self.last_g = Some(g_curr);
353                return ControlFlag::Continue;
354            }
355        };
356
357        let zero = T::zero();
358        let sign_change = g_prev.signum() != g_curr.signum();
359        let direction_ok = match self.config.direction {
360            CrossingDirection::Both => sign_change,
361            CrossingDirection::Positive => sign_change && g_prev < zero && g_curr >= zero,
362            CrossingDirection::Negative => sign_change && g_prev > zero && g_curr <= zero,
363        };
364        if direction_ok {
365            let (mut a, mut b, mut fa, mut fb) = (t_prev, t_curr, g_prev, g_curr);
366            if (self.direction > zero && a > b) || (self.direction < zero && a < b) {
367                std::mem::swap(&mut a, &mut b);
368                std::mem::swap(&mut fa, &mut fb);
369            }
370            if fa * fb <= zero {
371                if let Some(t_event) = self.brent_dekker(a, b, fa, fb, interpolator) {
372                    let y_event = interpolator.interpolate(t_event).unwrap();
373                    let push_point = match solution.t.last() {
374                        Some(&last_t) => (t_event - last_t).abs() > self.abs_tol,
375                        None => true,
376                    };
377                    if push_point {
378                        solution.push(t_event, y_event);
379                    }
380                    self.event_count += 1;
381                    if let Some(limit) = self.config.terminate {
382                        if self.event_count >= limit {
383                            self.last_g = Some(g_curr);
384                            return ControlFlag::Terminate;
385                        }
386                    }
387                }
388            }
389        }
390        self.last_g = Some(g_curr);
391        ControlFlag::Continue
392    }
393
394    fn brent_dekker<I>(
395        &mut self,
396        mut a: T,
397        mut b: T,
398        mut fa: T,
399        mut fb: T,
400        interpolator: &mut I,
401    ) -> Option<T>
402    where
403        I: Interpolation<T, Y>,
404    {
405        if fa.abs() < fb.abs() {
406            std::mem::swap(&mut a, &mut b);
407            std::mem::swap(&mut fa, &mut fb);
408        }
409        let mut c = a;
410        let mut fc = fa;
411        let mut d = b - a;
412        let mut e = d;
413        let one = T::one();
414        let two = T::from_f64(2.0).unwrap();
415        let half = one / two;
416        let three = T::from_f64(3.0).unwrap();
417        for _ in 0..50u32 {
418            if fb == T::zero() {
419                return Some(b);
420            }
421            if fa.signum() == fb.signum() {
422                a = c;
423                fa = fc;
424                c = b;
425                fc = fb;
426                d = b - a;
427                e = d;
428            }
429            if fa.abs() < fb.abs() {
430                c = b;
431                b = a;
432                a = c;
433                fc = fb;
434                fb = fa;
435                fa = fc;
436            }
437            let tol = self.abs_tol.max(self.rel_tol * b.abs());
438            let m = half * (a - b);
439            if m.abs() <= tol || fb == T::zero() {
440                return Some(b);
441            }
442            let mut use_bis = true;
443            if e.abs() > tol && fa.abs() > fb.abs() {
444                let s = fb / fa;
445                let p;
446                let q;
447                if a == c {
448                    p = two * m * s;
449                    q = one - s;
450                } else {
451                    let q1 = fa / fc;
452                    let r = fb / fc;
453                    p = s * (two * m * q1 * (q1 - r) - (b - a) * (r - one));
454                    q = (q1 - one) * (r - one) * (s - one);
455                }
456                let mut q_mod = q;
457                let mut p_mod = p;
458                if q_mod > T::zero() {
459                    p_mod = -p_mod;
460                } else {
461                    q_mod = -q_mod;
462                }
463                if (two * p_mod).abs() < (three * m * q_mod - (tol * q_mod).abs())
464                    && p_mod < (e * half * q_mod).abs()
465                {
466                    e = d;
467                    d = p_mod / q_mod;
468                    use_bis = false;
469                }
470            }
471            if use_bis {
472                d = m;
473                e = m;
474            }
475            a = b;
476            fa = fb;
477            b = if d.abs() > tol {
478                b + d
479            } else {
480                b + if m > T::zero() { tol } else { -tol }
481            };
482            let yb = interpolator.interpolate(b).ok()?;
483            fb = self.event.event(b, &yb);
484            c = a;
485            fc = fa;
486        }
487        None
488    }
489}
490
491impl<'a, T, Y, O, E> Solout<T, Y> for EventWrappedSolout<'a, T, Y, O, E>
492where
493    T: Real,
494    Y: State<T>,
495    O: Solout<T, Y>,
496    E: Event<T, Y>,
497{
498    fn solout<I>(
499        &mut self,
500        t_curr: T,
501        t_prev: T,
502        y_curr: &Y,
503        y_prev: &Y,
504        interpolator: &mut I,
505        solution: &mut Solution<T, Y>,
506    ) -> ControlFlag<T, Y>
507    where
508        I: Interpolation<T, Y>,
509    {
510        let flag = self
511            .base
512            .solout(t_curr, t_prev, y_curr, y_prev, interpolator, solution);
513        if let ControlFlag::Terminate = flag {
514            return flag;
515        }
516        let evt_flag = self.detect_event(t_curr, t_prev, y_curr, y_prev, interpolator, solution);
517        match evt_flag {
518            ControlFlag::Terminate => ControlFlag::Terminate,
519            _ => flag,
520        }
521    }
522}