1use super::*;
4
5pub struct EventConfig {
6 pub direction: CrossingDirection,
8 pub terminate: Option<u32>,
10}
11
12impl Default for EventConfig {
13 fn default() -> Self {
14 Self {
15 direction: CrossingDirection::Both,
16 terminate: None,
17 }
18 }
19}
20
21impl EventConfig {
22 pub fn new(direction: impl Into<CrossingDirection>, terminate: Option<u32>) -> Self {
24 Self {
25 direction: direction.into(),
26 terminate,
27 }
28 }
29
30 pub fn direction(mut self, direction: impl Into<CrossingDirection>) -> Self {
31 self.direction = direction.into();
32 self
33 }
34
35 pub fn terminate_after(mut self, count: u32) -> Self {
37 self.terminate = Some(count);
38 self
39 }
40
41 pub fn terminal(mut self) -> Self {
43 self.terminate = Some(1);
44 self
45 }
46}
47
48pub trait Event<T: Real = f64, Y: State<T> = f64> {
49 fn config(&self) -> EventConfig {
51 EventConfig::default()
52 }
53
54 fn event(&self, t: T, y: &Y) -> T;
56}
57
58pub struct EventSolout<'a, T: Real, Y: State<T>, E: Event<T, Y>> {
67 event: &'a E,
69 config: EventConfig,
71 last_g: Option<T>,
73 event_count: u32,
75 direction: T,
77 rel_tol: T,
79 abs_tol: T,
81 _marker: std::marker::PhantomData<Y>,
83}
84
85impl<'a, T: Real, Y: State<T>, E: Event<T, Y>> EventSolout<'a, T, Y, E> {
86 pub fn new(event: &'a E, t0: T, tf: T) -> Self {
87 let direction = (tf - t0).signum();
88 let config = event.config();
89 EventSolout {
90 event,
91 config,
92 last_g: None,
93 event_count: 0,
94 direction,
95 rel_tol: T::from_f64(1e-12).unwrap_or(T::default_epsilon()),
96 abs_tol: T::from_f64(1e-14).unwrap_or(T::default_epsilon()),
97 _marker: std::marker::PhantomData,
98 }
99 }
100
101 fn brent_dekker<I>(
104 &mut self,
105 mut a: T,
106 mut b: T,
107 mut fa: T,
108 mut fb: T,
109 interpolator: &mut I,
110 ) -> Option<T>
111 where
112 I: Interpolation<T, Y>,
113 {
114 if fa.abs() < fb.abs() {
116 std::mem::swap(&mut a, &mut b);
117 std::mem::swap(&mut fa, &mut fb);
118 }
119
120 let mut c = a;
121 let mut fc = fa;
122 let mut d = b - a;
123 let mut e = d;
124
125 let one = T::one();
126 let two = T::from_f64(2.0).unwrap();
127 let half = one / two;
128 let three = T::from_f64(3.0).unwrap();
129
130 let max_iter = 50u32;
131 for _ in 0..max_iter {
132 if fb == T::zero() {
133 return Some(b);
134 }
135 if fa.signum() == fb.signum() {
136 a = c;
138 fa = fc;
139 c = b;
140 fc = fb;
141 d = b - a;
142 e = d;
143 }
144 if fa.abs() < fb.abs() {
145 c = b;
146 b = a;
147 a = c;
148 fc = fb;
149 fb = fa;
150 fa = fc;
151 }
152
153 let tol = self.abs_tol.max(self.rel_tol * b.abs());
155 let m = half * (a - b);
156 if m.abs() <= tol || fb == T::zero() {
157 return Some(b);
158 }
159
160 let mut use_bisection = true;
162 if e.abs() > tol && fa.abs() > fb.abs() {
163 let s = fb / fa;
165 let p;
166 let q;
167 if a == c {
168 p = two * m * s;
170 q = one - s;
171 } else {
172 let q1 = fa / fc;
174 let r = fb / fc;
175 p = s * (two * m * q1 * (q1 - r) - (b - a) * (r - one));
176 q = (q1 - one) * (r - one) * (s - one);
177 }
178 let mut q_mod = q;
179 let mut p_mod = p;
180 if q_mod > T::zero() {
181 p_mod = -p_mod;
182 } else {
183 q_mod = -q_mod;
184 }
185 if (two * p_mod).abs() < (three * m * q_mod - (tol * q_mod).abs())
187 && p_mod < (e * half * q_mod).abs()
188 {
189 e = d;
190 d = p_mod / q_mod;
191 use_bisection = false;
192 }
193 }
194 if use_bisection {
195 d = m;
196 e = m;
197 }
198 a = b;
200 fa = fb;
201 if d.abs() > tol {
202 b = b + d;
203 } else {
204 b = b + if m > T::zero() { tol } else { -tol };
205 }
206 let yb = interpolator.interpolate(b).ok()?;
208 fb = self.event.event(b, &yb);
209 c = a;
210 fc = fa;
211 }
212 None
213 }
214}
215
216impl<'a, T, Y, E> Solout<T, Y> for EventSolout<'a, T, Y, E>
217where
218 T: Real,
219 Y: State<T>,
220 E: Event<T, Y>,
221{
222 fn solout<I>(
223 &mut self,
224 t_curr: T,
225 t_prev: T,
226 y_curr: &Y,
227 y_prev: &Y,
228 interpolator: &mut I,
229 solution: &mut Solution<T, Y>,
230 ) -> ControlFlag<T, Y>
231 where
232 I: Interpolation<T, Y>,
233 {
234 let g_curr = self.event.event(t_curr, y_curr);
236
237 let g_prev = match self.last_g {
239 Some(g) => g,
240 None => {
241 let g0 = self.event.event(t_prev, y_prev);
242 self.last_g = Some(g0);
243 self.last_g = Some(g_curr);
245 return ControlFlag::Continue;
246 }
247 };
248
249 let zero = T::zero();
251 let sign_change = g_prev.signum() != g_curr.signum();
252
253 let direction_ok = match self.config.direction {
254 CrossingDirection::Both => sign_change,
255 CrossingDirection::Positive => sign_change && g_prev < zero && g_curr >= zero,
256 CrossingDirection::Negative => sign_change && g_prev > zero && g_curr <= zero,
257 };
258
259 if direction_ok {
260 let (mut a, mut b, mut fa, mut fb) = (t_prev, t_curr, g_prev, g_curr);
262 if (self.direction > zero && a > b) || (self.direction < zero && a < b) {
264 std::mem::swap(&mut a, &mut b);
265 std::mem::swap(&mut fa, &mut fb);
266 }
267
268 if fa * fb <= zero {
270 if let Some(t_event) = self.brent_dekker(a, b, fa, fb, interpolator) {
271 let y_event = interpolator.interpolate(t_event).unwrap();
272 let push_point = match solution.t.last() {
274 Some(&last_t) => (t_event - last_t).abs() > self.abs_tol,
275 None => true,
276 };
277 if push_point {
278 solution.push(t_event, y_event);
279 }
280 self.event_count += 1;
281
282 if let Some(limit) = self.config.terminate {
283 if self.event_count >= limit {
284 self.last_g = Some(g_curr);
285 return ControlFlag::Terminate;
286 }
287 }
288 }
289 }
290 }
291
292 self.last_g = Some(g_curr);
293 ControlFlag::Continue
294 }
295}
296
297pub struct EventWrappedSolout<'a, T: Real, Y: State<T>, O, E>
299where
300 O: Solout<T, Y>,
301 E: Event<T, Y>,
302{
303 base: O,
304 event: &'a E,
305 config: EventConfig,
306 last_g: Option<T>,
307 event_count: u32,
308 direction: T,
309 rel_tol: T,
310 abs_tol: T,
311 _marker: std::marker::PhantomData<Y>,
312}
313
314impl<'a, T: Real, Y: State<T>, O, E> EventWrappedSolout<'a, T, Y, O, E>
315where
316 O: Solout<T, Y>,
317 E: Event<T, Y>,
318{
319 pub fn new(base: O, event: &'a E, t0: T, tf: T) -> Self {
320 let config = event.config();
321 EventWrappedSolout {
322 base,
323 event,
324 config,
325 last_g: None,
326 event_count: 0,
327 direction: (tf - t0).signum(),
328 rel_tol: T::from_f64(1e-12).unwrap_or(T::default_epsilon()),
329 abs_tol: T::from_f64(1e-14).unwrap_or(T::default_epsilon()),
330 _marker: std::marker::PhantomData,
331 }
332 }
333
334 fn detect_event<I>(
335 &mut self,
336 t_curr: T,
337 t_prev: T,
338 y_curr: &Y,
339 y_prev: &Y,
340 interpolator: &mut I,
341 solution: &mut Solution<T, Y>,
342 ) -> ControlFlag<T, Y>
343 where
344 I: Interpolation<T, Y>,
345 {
346 let g_curr = self.event.event(t_curr, y_curr);
347 let g_prev = match self.last_g {
348 Some(g) => g,
349 None => {
350 let g0 = self.event.event(t_prev, y_prev);
351 self.last_g = Some(g0);
352 self.last_g = Some(g_curr);
353 return ControlFlag::Continue;
354 }
355 };
356
357 let zero = T::zero();
358 let sign_change = g_prev.signum() != g_curr.signum();
359 let direction_ok = match self.config.direction {
360 CrossingDirection::Both => sign_change,
361 CrossingDirection::Positive => sign_change && g_prev < zero && g_curr >= zero,
362 CrossingDirection::Negative => sign_change && g_prev > zero && g_curr <= zero,
363 };
364 if direction_ok {
365 let (mut a, mut b, mut fa, mut fb) = (t_prev, t_curr, g_prev, g_curr);
366 if (self.direction > zero && a > b) || (self.direction < zero && a < b) {
367 std::mem::swap(&mut a, &mut b);
368 std::mem::swap(&mut fa, &mut fb);
369 }
370 if fa * fb <= zero {
371 if let Some(t_event) = self.brent_dekker(a, b, fa, fb, interpolator) {
372 let y_event = interpolator.interpolate(t_event).unwrap();
373 let push_point = match solution.t.last() {
374 Some(&last_t) => (t_event - last_t).abs() > self.abs_tol,
375 None => true,
376 };
377 if push_point {
378 solution.push(t_event, y_event);
379 }
380 self.event_count += 1;
381 if let Some(limit) = self.config.terminate {
382 if self.event_count >= limit {
383 self.last_g = Some(g_curr);
384 return ControlFlag::Terminate;
385 }
386 }
387 }
388 }
389 }
390 self.last_g = Some(g_curr);
391 ControlFlag::Continue
392 }
393
394 fn brent_dekker<I>(
395 &mut self,
396 mut a: T,
397 mut b: T,
398 mut fa: T,
399 mut fb: T,
400 interpolator: &mut I,
401 ) -> Option<T>
402 where
403 I: Interpolation<T, Y>,
404 {
405 if fa.abs() < fb.abs() {
406 std::mem::swap(&mut a, &mut b);
407 std::mem::swap(&mut fa, &mut fb);
408 }
409 let mut c = a;
410 let mut fc = fa;
411 let mut d = b - a;
412 let mut e = d;
413 let one = T::one();
414 let two = T::from_f64(2.0).unwrap();
415 let half = one / two;
416 let three = T::from_f64(3.0).unwrap();
417 for _ in 0..50u32 {
418 if fb == T::zero() {
419 return Some(b);
420 }
421 if fa.signum() == fb.signum() {
422 a = c;
423 fa = fc;
424 c = b;
425 fc = fb;
426 d = b - a;
427 e = d;
428 }
429 if fa.abs() < fb.abs() {
430 c = b;
431 b = a;
432 a = c;
433 fc = fb;
434 fb = fa;
435 fa = fc;
436 }
437 let tol = self.abs_tol.max(self.rel_tol * b.abs());
438 let m = half * (a - b);
439 if m.abs() <= tol || fb == T::zero() {
440 return Some(b);
441 }
442 let mut use_bis = true;
443 if e.abs() > tol && fa.abs() > fb.abs() {
444 let s = fb / fa;
445 let p;
446 let q;
447 if a == c {
448 p = two * m * s;
449 q = one - s;
450 } else {
451 let q1 = fa / fc;
452 let r = fb / fc;
453 p = s * (two * m * q1 * (q1 - r) - (b - a) * (r - one));
454 q = (q1 - one) * (r - one) * (s - one);
455 }
456 let mut q_mod = q;
457 let mut p_mod = p;
458 if q_mod > T::zero() {
459 p_mod = -p_mod;
460 } else {
461 q_mod = -q_mod;
462 }
463 if (two * p_mod).abs() < (three * m * q_mod - (tol * q_mod).abs())
464 && p_mod < (e * half * q_mod).abs()
465 {
466 e = d;
467 d = p_mod / q_mod;
468 use_bis = false;
469 }
470 }
471 if use_bis {
472 d = m;
473 e = m;
474 }
475 a = b;
476 fa = fb;
477 b = if d.abs() > tol {
478 b + d
479 } else {
480 b + if m > T::zero() { tol } else { -tol }
481 };
482 let yb = interpolator.interpolate(b).ok()?;
483 fb = self.event.event(b, &yb);
484 c = a;
485 fc = fa;
486 }
487 None
488 }
489}
490
491impl<'a, T, Y, O, E> Solout<T, Y> for EventWrappedSolout<'a, T, Y, O, E>
492where
493 T: Real,
494 Y: State<T>,
495 O: Solout<T, Y>,
496 E: Event<T, Y>,
497{
498 fn solout<I>(
499 &mut self,
500 t_curr: T,
501 t_prev: T,
502 y_curr: &Y,
503 y_prev: &Y,
504 interpolator: &mut I,
505 solution: &mut Solution<T, Y>,
506 ) -> ControlFlag<T, Y>
507 where
508 I: Interpolation<T, Y>,
509 {
510 let flag = self
511 .base
512 .solout(t_curr, t_prev, y_curr, y_prev, interpolator, solution);
513 if let ControlFlag::Terminate = flag {
514 return flag;
515 }
516 let evt_flag = self.detect_event(t_curr, t_prev, y_curr, y_prev, interpolator, solution);
517 match evt_flag {
518 ControlFlag::Terminate => ControlFlag::Terminate,
519 _ => flag,
520 }
521 }
522}