1use super::*;
4use crate::traits::DefaultState;
5
6pub struct EventConfig {
7 pub direction: CrossingDirection,
9 pub terminate: Option<u32>,
11}
12
13impl Default for EventConfig {
14 fn default() -> Self {
15 Self {
16 direction: CrossingDirection::Both,
17 terminate: None,
18 }
19 }
20}
21
22impl EventConfig {
23 pub fn new(direction: impl Into<CrossingDirection>, terminate: Option<u32>) -> Self {
25 Self {
26 direction: direction.into(),
27 terminate,
28 }
29 }
30
31 pub fn direction(mut self, direction: impl Into<CrossingDirection>) -> Self {
32 self.direction = direction.into();
33 self
34 }
35
36 pub fn terminate_after(mut self, count: u32) -> Self {
38 self.terminate = Some(count);
39 self
40 }
41
42 pub fn terminal(mut self) -> Self {
44 self.terminate = Some(1);
45 self
46 }
47}
48
49pub trait Event<T: Real = f64, Y: State<T> = DefaultState<T>> {
50 fn config(&self) -> EventConfig {
52 EventConfig::default()
53 }
54
55 fn event(&self, t: T, y: &Y) -> T;
57}
58
59pub struct EventSolout<'a, T: Real, Y: State<T>, E: Event<T, Y> + ?Sized> {
68 event: &'a E,
70 config: EventConfig,
72 last_g: Option<T>,
74 event_count: u32,
76 direction: T,
78 rel_tol: T,
80 abs_tol: T,
82 _marker: std::marker::PhantomData<Y>,
84}
85
86impl<'a, T: Real, Y: State<T>, E: Event<T, Y> + ?Sized> EventSolout<'a, T, Y, E> {
87 pub fn new(event: &'a E, t0: T, tf: T) -> Self {
88 let direction = (tf - t0).signum();
89 let config = event.config();
90 EventSolout {
91 event,
92 config,
93 last_g: None,
94 event_count: 0,
95 direction,
96 rel_tol: T::from_f64(1e-12).unwrap_or(T::default_epsilon()),
97 abs_tol: T::from_f64(1e-14).unwrap_or(T::default_epsilon()),
98 _marker: std::marker::PhantomData,
99 }
100 }
101
102 fn brent_dekker<I>(
105 &mut self,
106 mut a: T,
107 mut b: T,
108 mut fa: T,
109 mut fb: T,
110 interpolator: &mut I,
111 ) -> Option<T>
112 where
113 I: Interpolation<T, Y> + ?Sized,
114 {
115 if fa.abs() < fb.abs() {
117 std::mem::swap(&mut a, &mut b);
118 std::mem::swap(&mut fa, &mut fb);
119 }
120
121 let mut c = a;
122 let mut fc = fa;
123 let mut d = b - a;
124 let mut e = d;
125
126 let one = T::one();
127 let two = T::from_f64(2.0).unwrap();
128 let half = one / two;
129 let three = T::from_f64(3.0).unwrap();
130
131 let max_iter = 50u32;
132 for _ in 0..max_iter {
133 if fb == T::zero() {
134 return Some(b);
135 }
136 if fa.signum() == fb.signum() {
137 a = c;
139 fa = fc;
140 c = b;
141 fc = fb;
142 d = b - a;
143 e = d;
144 }
145 if fa.abs() < fb.abs() {
146 c = b;
147 b = a;
148 a = c;
149 fc = fb;
150 fb = fa;
151 fa = fc;
152 }
153
154 let tol = self.abs_tol.max(self.rel_tol * b.abs());
156 let m = half * (a - b);
157 if m.abs() <= tol || fb == T::zero() {
158 return Some(b);
159 }
160
161 let mut use_bisection = true;
163 if e.abs() > tol && fa.abs() > fb.abs() {
164 let s = fb / fa;
166 let p;
167 let q;
168 if a == c {
169 p = two * m * s;
171 q = one - s;
172 } else {
173 let q1 = fa / fc;
175 let r = fb / fc;
176 p = s * (two * m * q1 * (q1 - r) - (b - a) * (r - one));
177 q = (q1 - one) * (r - one) * (s - one);
178 }
179 let mut q_mod = q;
180 let mut p_mod = p;
181 if q_mod > T::zero() {
182 p_mod = -p_mod;
183 } else {
184 q_mod = -q_mod;
185 }
186 if (two * p_mod).abs() < (three * m * q_mod - (tol * q_mod).abs())
188 && p_mod < (e * half * q_mod).abs()
189 {
190 e = d;
191 d = p_mod / q_mod;
192 use_bisection = false;
193 }
194 }
195 if use_bisection {
196 d = m;
197 e = m;
198 }
199 a = b;
201 fa = fb;
202 if d.abs() > tol {
203 b += d;
204 } else {
205 b += if m > T::zero() { tol } else { -tol };
206 }
207 let yb = interpolator.interpolate(b).ok()?;
209 fb = self.event.event(b, &yb);
210 c = a;
211 fc = fa;
212 }
213 None
214 }
215}
216
217impl<'a, T, Y, E> Solout<T, Y> for EventSolout<'a, T, Y, E>
218where
219 T: Real,
220 Y: State<T>,
221 E: Event<T, Y> + ?Sized,
222{
223 fn solout<I>(
224 &mut self,
225 t_curr: T,
226 t_prev: T,
227 y_curr: &Y,
228 y_prev: &Y,
229 interpolator: &mut I,
230 solution: &mut Solution<T, Y>,
231 ) -> ControlFlag<T, Y>
232 where
233 I: Interpolation<T, Y> + ?Sized,
234 {
235 let g_curr = self.event.event(t_curr, y_curr);
237
238 let g_prev = match self.last_g {
240 Some(g) => g,
241 None => {
242 let g0 = self.event.event(t_prev, y_prev);
243 self.last_g = Some(g0);
244 self.last_g = Some(g_curr);
246 return ControlFlag::Continue;
247 }
248 };
249
250 let zero = T::zero();
252 let sign_change = g_prev.signum() != g_curr.signum();
253
254 let direction_ok = match self.config.direction {
255 CrossingDirection::Both => sign_change,
256 CrossingDirection::Positive => sign_change && g_prev < zero && g_curr >= zero,
257 CrossingDirection::Negative => sign_change && g_prev > zero && g_curr <= zero,
258 };
259
260 if direction_ok {
261 let (mut a, mut b, mut fa, mut fb) = (t_prev, t_curr, g_prev, g_curr);
263 if (self.direction > zero && a > b) || (self.direction < zero && a < b) {
265 std::mem::swap(&mut a, &mut b);
266 std::mem::swap(&mut fa, &mut fb);
267 }
268
269 if fa * fb <= zero
271 && let Some(t_event) = self.brent_dekker(a, b, fa, fb, interpolator)
272 {
273 let y_event = interpolator.interpolate(t_event).unwrap();
274 let push_point = match solution.t.last() {
276 Some(&last_t) => (t_event - last_t).abs() > self.abs_tol,
277 None => true,
278 };
279 if push_point {
280 solution.push(t_event, y_event);
281 }
282 self.event_count += 1;
283
284 if let Some(limit) = self.config.terminate
285 && self.event_count >= limit
286 {
287 self.last_g = Some(g_curr);
288 return ControlFlag::Terminate;
289 }
290 }
291 }
292
293 self.last_g = Some(g_curr);
294 ControlFlag::Continue
295 }
296}
297
298pub struct EventWrappedSolout<'a, T: Real, Y: State<T>, O, E>
300where
301 O: Solout<T, Y>,
302 E: Event<T, Y> + ?Sized,
303{
304 base: O,
305 event: &'a E,
306 config: EventConfig,
307 last_g: Option<T>,
308 event_count: u32,
309 direction: T,
310 rel_tol: T,
311 abs_tol: T,
312 _marker: std::marker::PhantomData<Y>,
313}
314
315impl<'a, T: Real, Y: State<T>, O, E> EventWrappedSolout<'a, T, Y, O, E>
316where
317 O: Solout<T, Y>,
318 E: Event<T, Y> + ?Sized,
319{
320 pub fn new(base: O, event: &'a E, t0: T, tf: T) -> Self {
321 let config = event.config();
322 EventWrappedSolout {
323 base,
324 event,
325 config,
326 last_g: None,
327 event_count: 0,
328 direction: (tf - t0).signum(),
329 rel_tol: T::from_f64(1e-12).unwrap_or(T::default_epsilon()),
330 abs_tol: T::from_f64(1e-14).unwrap_or(T::default_epsilon()),
331 _marker: std::marker::PhantomData,
332 }
333 }
334
335 fn detect_event<I>(
336 &mut self,
337 t_curr: T,
338 t_prev: T,
339 y_curr: &Y,
340 y_prev: &Y,
341 interpolator: &mut I,
342 solution: &mut Solution<T, Y>,
343 ) -> ControlFlag<T, Y>
344 where
345 I: Interpolation<T, Y> + ?Sized,
346 {
347 let g_curr = self.event.event(t_curr, y_curr);
348 let g_prev = match self.last_g {
349 Some(g) => g,
350 None => {
351 let g0 = self.event.event(t_prev, y_prev);
352 self.last_g = Some(g0);
353 self.last_g = Some(g_curr);
354 return ControlFlag::Continue;
355 }
356 };
357
358 let zero = T::zero();
359 let sign_change = g_prev.signum() != g_curr.signum();
360 let direction_ok = match self.config.direction {
361 CrossingDirection::Both => sign_change,
362 CrossingDirection::Positive => sign_change && g_prev < zero && g_curr >= zero,
363 CrossingDirection::Negative => sign_change && g_prev > zero && g_curr <= zero,
364 };
365 if direction_ok {
366 let (mut a, mut b, mut fa, mut fb) = (t_prev, t_curr, g_prev, g_curr);
367 if (self.direction > zero && a > b) || (self.direction < zero && a < b) {
368 std::mem::swap(&mut a, &mut b);
369 std::mem::swap(&mut fa, &mut fb);
370 }
371 if fa * fb <= zero
372 && let Some(t_event) = self.brent_dekker(a, b, fa, fb, interpolator)
373 {
374 let y_event = interpolator.interpolate(t_event).unwrap();
375 let push_point = match solution.t.last() {
376 Some(&last_t) => (t_event - last_t).abs() > self.abs_tol,
377 None => true,
378 };
379 if push_point {
380 solution.push(t_event, y_event);
381 }
382 self.event_count += 1;
383 if let Some(limit) = self.config.terminate
384 && self.event_count >= limit
385 {
386 self.last_g = Some(g_curr);
387 return ControlFlag::Terminate;
388 }
389 }
390 }
391 self.last_g = Some(g_curr);
392 ControlFlag::Continue
393 }
394
395 fn brent_dekker<I>(
396 &mut self,
397 mut a: T,
398 mut b: T,
399 mut fa: T,
400 mut fb: T,
401 interpolator: &mut I,
402 ) -> Option<T>
403 where
404 I: Interpolation<T, Y> + ?Sized,
405 {
406 if fa.abs() < fb.abs() {
407 std::mem::swap(&mut a, &mut b);
408 std::mem::swap(&mut fa, &mut fb);
409 }
410 let mut c = a;
411 let mut fc = fa;
412 let mut d = b - a;
413 let mut e = d;
414 let one = T::one();
415 let two = T::from_f64(2.0).unwrap();
416 let half = one / two;
417 let three = T::from_f64(3.0).unwrap();
418 for _ in 0..50u32 {
419 if fb == T::zero() {
420 return Some(b);
421 }
422 if fa.signum() == fb.signum() {
423 a = c;
424 fa = fc;
425 c = b;
426 fc = fb;
427 d = b - a;
428 e = d;
429 }
430 if fa.abs() < fb.abs() {
431 c = b;
432 b = a;
433 a = c;
434 fc = fb;
435 fb = fa;
436 fa = fc;
437 }
438 let tol = self.abs_tol.max(self.rel_tol * b.abs());
439 let m = half * (a - b);
440 if m.abs() <= tol || fb == T::zero() {
441 return Some(b);
442 }
443 let mut use_bis = true;
444 if e.abs() > tol && fa.abs() > fb.abs() {
445 let s = fb / fa;
446 let p;
447 let q;
448 if a == c {
449 p = two * m * s;
450 q = one - s;
451 } else {
452 let q1 = fa / fc;
453 let r = fb / fc;
454 p = s * (two * m * q1 * (q1 - r) - (b - a) * (r - one));
455 q = (q1 - one) * (r - one) * (s - one);
456 }
457 let mut q_mod = q;
458 let mut p_mod = p;
459 if q_mod > T::zero() {
460 p_mod = -p_mod;
461 } else {
462 q_mod = -q_mod;
463 }
464 if (two * p_mod).abs() < (three * m * q_mod - (tol * q_mod).abs())
465 && p_mod < (e * half * q_mod).abs()
466 {
467 e = d;
468 d = p_mod / q_mod;
469 use_bis = false;
470 }
471 }
472 if use_bis {
473 d = m;
474 e = m;
475 }
476 a = b;
477 fa = fb;
478 b = if d.abs() > tol {
479 b + d
480 } else {
481 b + if m > T::zero() { tol } else { -tol }
482 };
483 let yb = interpolator.interpolate(b).ok()?;
484 fb = self.event.event(b, &yb);
485 c = a;
486 fc = fa;
487 }
488 None
489 }
490}
491
492impl<'a, T, Y, O, E> Solout<T, Y> for EventWrappedSolout<'a, T, Y, O, E>
493where
494 T: Real,
495 Y: State<T>,
496 O: Solout<T, Y>,
497 E: Event<T, Y> + ?Sized,
498{
499 fn solout<I>(
500 &mut self,
501 t_curr: T,
502 t_prev: T,
503 y_curr: &Y,
504 y_prev: &Y,
505 interpolator: &mut I,
506 solution: &mut Solution<T, Y>,
507 ) -> ControlFlag<T, Y>
508 where
509 I: Interpolation<T, Y> + ?Sized,
510 {
511 let flag = self
512 .base
513 .solout(t_curr, t_prev, y_curr, y_prev, interpolator, solution);
514 if let ControlFlag::Terminate = flag {
515 return flag;
516 }
517 let evt_flag = self.detect_event(t_curr, t_prev, y_curr, y_prev, interpolator, solution);
518 match evt_flag {
519 ControlFlag::Terminate => ControlFlag::Terminate,
520 _ => flag,
521 }
522 }
523}