1use super::embedded_rk::{dopri5, OdeResult};
26use crate::error::{IntegrateError, IntegrateResult};
27
/// Direction of the sign change an event function must show to count as an event.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EventDirection {
    /// The event function goes from negative to positive.
    Rising,
    /// The event function goes from positive to negative.
    Falling,
    /// A sign change in either direction.
    Both,
}
40
41pub struct EventSpec {
46 pub func: Box<dyn Fn(f64, &[f64]) -> f64 + Send + Sync>,
48 pub direction: EventDirection,
50 pub terminal: bool,
52}
53
54impl std::fmt::Debug for EventSpec {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 f.debug_struct("EventSpec")
57 .field("direction", &self.direction)
58 .field("terminal", &self.terminal)
59 .finish()
60 }
61}
62
/// One located event: when it happened, the state there, and which spec fired.
#[derive(Clone, Debug)]
pub struct EventResult {
    /// Time at which the event function's root was located.
    pub t_event: f64,
    /// State vector interpolated at `t_event`.
    pub y_event: Vec<f64>,
    /// Index of the event that fired, as supplied by the caller of the root finder.
    pub event_idx: usize,
}
73
74pub struct EventSet {
76 pub specs: Vec<EventSpec>,
78}
79
80impl EventSet {
81 pub fn new(specs: Vec<EventSpec>) -> Self {
83 Self { specs }
84 }
85}
86
/// Upper bound on the number of iterations `illinois_bracket` performs.
const MAX_ILLINOIS: usize = 50;
/// Absolute tolerance used by `illinois_bracket` for both the bracket width
/// and the magnitude of the event function.
const ILLINOIS_TOL: f64 = 1.0e-12;
93
94fn direction_matches(g_prev: f64, g_curr: f64, direction: EventDirection) -> bool {
97 match direction {
98 EventDirection::Both => g_prev * g_curr < 0.0,
99 EventDirection::Rising => g_prev < 0.0 && g_curr > 0.0,
100 EventDirection::Falling => g_prev > 0.0 && g_curr < 0.0,
101 }
102}
103
104fn illinois_bracket<E>(mut ta: f64, mut tb: f64, mut ga: f64, mut gb: f64, eval: E) -> f64
110where
111 E: Fn(f64) -> f64,
112{
113 let mut side = 0i32; for _ in 0..MAX_ILLINOIS {
118 let dg = gb - ga;
120 let t_new = if dg.abs() < 1e-300 {
121 (ta + tb) / 2.0
122 } else {
123 ta - ga * (tb - ta) / dg
124 };
125 let t_new = t_new.clamp(ta.min(tb), ta.max(tb));
126
127 if (tb - ta).abs() < ILLINOIS_TOL {
128 return t_new;
129 }
130
131 let g_new = eval(t_new);
132
133 if g_new.abs() < ILLINOIS_TOL {
134 return t_new;
135 }
136
137 if ga * g_new < 0.0 {
138 if side == 1 {
140 ga /= 2.0;
142 }
143 tb = t_new;
144 gb = g_new;
145 side = 1; } else {
147 if side == -1 {
149 gb /= 2.0;
151 }
152 ta = t_new;
153 ga = g_new;
154 side = -1; }
156 }
157
158 (ta + tb) / 2.0
159}
160
161pub fn find_event_root(
182 g_prev: f64,
183 g_curr: f64,
184 t_prev: f64,
185 t_curr: f64,
186 y_prev: &[f64],
187 y_curr: &[f64],
188 event_idx: usize,
189 event: &EventSpec,
190) -> Option<EventResult> {
191 if !direction_matches(g_prev, g_curr, event.direction) {
192 return None;
193 }
194
195 let n = y_prev.len();
196 let dt = t_curr - t_prev;
197
198 let interp = |t: f64| -> Vec<f64> {
200 let alpha = if dt.abs() < 1e-300 {
201 0.5
202 } else {
203 (t - t_prev) / dt
204 };
205 (0..n)
206 .map(|i| y_prev[i] + alpha * (y_curr[i] - y_prev[i]))
207 .collect()
208 };
209
210 let eval = |t: f64| -> f64 {
211 let y = interp(t);
212 (event.func)(t, &y)
213 };
214
215 let t_event = illinois_bracket(t_prev, t_curr, g_prev, g_curr, eval);
216 let y_event = interp(t_event);
217
218 Some(EventResult {
219 t_event,
220 y_event,
221 event_idx,
222 })
223}
224
225pub fn find_event_root_dense<I>(
234 g_prev: f64,
235 g_curr: f64,
236 t_prev: f64,
237 t_curr: f64,
238 interp: I,
239 event_idx: usize,
240 event: &EventSpec,
241) -> Option<EventResult>
242where
243 I: Fn(f64) -> Vec<f64>,
244{
245 if !direction_matches(g_prev, g_curr, event.direction) {
246 return None;
247 }
248
249 let eval = |t: f64| -> f64 {
250 let y = interp(t);
251 (event.func)(t, &y)
252 };
253
254 let t_event = illinois_bracket(t_prev, t_curr, g_prev, g_curr, eval);
255 let y_event = interp(t_event);
256
257 Some(EventResult {
258 t_event,
259 y_event,
260 event_idx,
261 })
262}
263
264#[derive(Debug)]
268pub struct OdeEventResult {
269 pub ode: OdeResult,
271 pub events: Vec<EventResult>,
273 pub terminated: bool,
275}
276
277pub fn dopri5_with_events<F>(
300 f: F,
301 t0: f64,
302 y0: &[f64],
303 t_end: f64,
304 rtol: f64,
305 atol: f64,
306 events: EventSet,
307) -> IntegrateResult<OdeEventResult>
308where
309 F: Fn(f64, &[f64]) -> Vec<f64> + Clone,
310{
311 if y0.is_empty() {
312 return Err(IntegrateError::ValueError(
313 "y0 must be non-empty".to_string(),
314 ));
315 }
316 if t_end <= t0 {
317 return Err(IntegrateError::ValueError("t_end must be > t0".to_string()));
318 }
319
320 let mut all_t: Vec<f64> = vec![t0];
321 let mut all_y: Vec<Vec<f64>> = vec![y0.to_vec()];
322 let mut all_events: Vec<EventResult> = Vec::new();
323 let mut n_steps_total: usize = 0;
324 let mut n_rejected_total: usize = 0;
325 let mut n_evals_total: usize = 0;
326 let mut terminated = false;
327
328 let mut g_prev: Vec<f64> = events.specs.iter().map(|s| (s.func)(t0, y0)).collect();
330
331 let n_seg_max = 10_000_usize;
338 let seg_hint = ((t_end - t0) / 0.1).ceil() as usize; let n_seg = seg_hint.min(n_seg_max).max(1);
340
341 let dt_seg = (t_end - t0) / n_seg as f64;
342 let mut t_start = t0;
343 let mut y_start = y0.to_vec();
344
345 for _seg in 0..n_seg {
346 if terminated || t_start >= t_end - 1e-14 * (t_end - t0) {
347 break;
348 }
349
350 let t_seg_end = (t_start + dt_seg).min(t_end);
351
352 let seg_result = dopri5(f.clone(), t_start, &y_start, t_seg_end, rtol, atol)?;
353
354 n_steps_total += seg_result.n_steps;
355 n_rejected_total += seg_result.n_rejected;
356 n_evals_total += seg_result.n_evals;
357
358 let seg_len = seg_result.t.len();
360 let mut early_stop_idx: Option<usize> = None;
361
362 'step_scan: for step_i in 1..seg_len {
363 let t_p = seg_result.t[step_i - 1];
364 let t_c = seg_result.t[step_i];
365 let y_p = &seg_result.y[step_i - 1];
366 let y_c = &seg_result.y[step_i];
367
368 for (ev_idx, spec) in events.specs.iter().enumerate() {
369 let g_c = (spec.func)(t_c, y_c);
370 let g_p = g_prev[ev_idx];
371
372 if direction_matches(g_p, g_c, spec.direction) {
373 if let Some(ev) = find_event_root(g_p, g_c, t_p, t_c, y_p, y_c, ev_idx, spec) {
374 all_events.push(ev);
375 if spec.terminal {
376 early_stop_idx = Some(step_i);
377 terminated = true;
378 break 'step_scan;
379 }
380 }
381 }
382
383 g_prev[ev_idx] = g_c;
384 }
385 }
386
387 let append_up_to = early_stop_idx.unwrap_or(seg_len);
389 for step_i in 1..append_up_to {
390 all_t.push(seg_result.t[step_i]);
391 all_y.push(seg_result.y[step_i].clone());
392 }
393
394 if terminated {
396 if let Some(last_ev) = all_events.last() {
397 all_t.push(last_ev.t_event);
398 all_y.push(last_ev.y_event.clone());
399 }
400 break;
401 }
402
403 if let (Some(t_last), Some(y_last)) = (seg_result.t.last(), seg_result.y.last()) {
405 t_start = *t_last;
406 y_start = y_last.clone();
407 } else {
408 break;
409 }
410 }
411
412 let n_out = all_t.len();
413 Ok(OdeEventResult {
414 ode: OdeResult {
415 t: all_t,
416 y: all_y,
417 n_steps: n_steps_total,
418 n_rejected: n_rejected_total,
419 n_evals: n_evals_total + n_out, },
421 events: all_events,
422 terminated,
423 })
424}
425
#[cfg(test)]
mod tests {
    use super::*;

    /// Boxes `func` into an [`EventSpec`] so each test states only what it varies.
    fn make_spec<G>(func: G, direction: EventDirection, terminal: bool) -> EventSpec
    where
        G: Fn(f64, &[f64]) -> f64 + Send + Sync + 'static,
    {
        EventSpec {
            func: Box::new(func),
            direction,
            terminal,
        }
    }

    /// A linear event function on [0, 1] has its root exactly at the midpoint.
    #[test]
    fn illinois_finds_exact_midpoint() {
        let spec = make_spec(|t: f64, _y: &[f64]| t - 0.5, EventDirection::Rising, false);
        let y_prev = vec![1.0_f64];
        let y_curr = vec![1.0_f64];
        let result = find_event_root(-0.5, 0.5, 0.0, 1.0, &y_prev, &y_curr, 0, &spec)
            .expect("should detect rising crossing");
        assert!(
            (result.t_event - 0.5).abs() < 1e-10,
            "t_event={} expected 0.5",
            result.t_event
        );
        assert_eq!(result.event_idx, 0);
        assert_eq!(result.y_event, vec![1.0]);
    }

    /// A falling crossing is rejected by `Rising` and accepted by `Falling`.
    #[test]
    fn illinois_direction_filter_falling() {
        let y = vec![0.0_f64];

        let spec_rising = make_spec(
            |t: f64, _y: &[f64]| 1.0 - 2.0 * t,
            EventDirection::Rising,
            false,
        );
        let res = find_event_root(1.0, -1.0, 0.0, 1.0, &y, &y, 0, &spec_rising);
        assert!(
            res.is_none(),
            "Rising filter should reject falling crossing"
        );

        let spec_falling = make_spec(
            |t: f64, _y: &[f64]| 1.0 - 2.0 * t,
            EventDirection::Falling,
            false,
        );
        let res2 = find_event_root(1.0, -1.0, 0.0, 1.0, &y, &y, 0, &spec_falling)
            .expect("Falling filter should accept falling crossing");
        assert!((res2.t_event - 0.5).abs() < 1e-8);
    }

    /// `Both` accepts a crossing in either direction and locates the root.
    #[test]
    fn illinois_both_directions() {
        let y = vec![0.0_f64];
        let (t_lo, t_hi) = (0.0_f64, 0.6_f64);

        // Rising through zero at t = 0.3; bracket values come from the
        // event function itself so they are consistent with it.
        let rising = make_spec(
            |t: f64, _y: &[f64]| (t - 0.3).sin(),
            EventDirection::Both,
            false,
        );
        let g_lo = (rising.func)(t_lo, &y);
        let g_hi = (rising.func)(t_hi, &y);
        let ev = find_event_root(g_lo, g_hi, t_lo, t_hi, &y, &y, 2, &rising)
            .expect("should find rising crossing");
        assert_eq!(ev.event_idx, 2);
        assert!(
            (ev.t_event - 0.3).abs() < 1e-9,
            "t_event={} expected 0.3",
            ev.t_event
        );

        // Same root, approached from above.
        let falling = make_spec(
            |t: f64, _y: &[f64]| -(t - 0.3).sin(),
            EventDirection::Both,
            false,
        );
        let g_lo = (falling.func)(t_lo, &y);
        let g_hi = (falling.func)(t_hi, &y);
        let ev = find_event_root(g_lo, g_hi, t_lo, t_hi, &y, &y, 5, &falling)
            .expect("should find falling crossing");
        assert_eq!(ev.event_idx, 5);
        assert!(
            (ev.t_event - 0.3).abs() < 1e-9,
            "t_event={} expected 0.3",
            ev.t_event
        );
    }

    /// y' = cos(t), y(0) = 0 gives y = sin(t), which falls through zero at π.
    #[test]
    fn events_detect_zero_crossing_sin() {
        let f = |t: f64, _y: &[f64]| vec![t.cos()];
        let event_spec = make_spec(|_t: f64, y: &[f64]| y[0], EventDirection::Falling, false);
        let events = EventSet::new(vec![event_spec]);
        let result = dopri5_with_events(f, 0.0, &[0.0], 4.0, 1e-8, 1e-10, events)
            .expect("integration failed");

        let pi = std::f64::consts::PI;
        let found = result.events.iter().any(|e| (e.t_event - pi).abs() < 0.05);
        assert!(
            found,
            "Expected crossing near t=π, got events: {:?}",
            result.events.iter().map(|e| e.t_event).collect::<Vec<_>>()
        );
        assert!(!result.terminated);
    }

    /// y' = -y, y(0) = 1 reaches 0.5 at t = ln 2; a terminal event stops there.
    #[test]
    fn events_terminal_stops_integration() {
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let threshold = make_spec(
            |_t: f64, y: &[f64]| y[0] - 0.5,
            EventDirection::Falling,
            true,
        );
        let events = EventSet::new(vec![threshold]);
        let result = dopri5_with_events(f, 0.0, &[1.0], 5.0, 1e-8, 1e-10, events)
            .expect("integration failed");

        assert!(result.terminated, "Expected terminal stop");

        let t_final = *result.ode.t.last().expect("trajectory has samples");
        let ln2 = 2.0_f64.ln();
        assert!(
            (t_final - ln2).abs() < 0.1,
            "Expected termination near t=ln2≈{ln2:.4}, got t={t_final:.4}"
        );

        // The trajectory must end exactly at the recorded terminal event.
        let last_event = result.events.last().expect("terminal event recorded");
        assert_eq!(last_event.event_idx, 0);
        assert_eq!(t_final, last_event.t_event);

        // The stop state is the root of the event function, i.e. y = 0.5.
        let y_final = result.ode.y.last().expect("trajectory has samples");
        assert!(
            (y_final[0] - 0.5).abs() < 1e-9,
            "Expected y=0.5 at the terminal event, got {}",
            y_final[0]
        );
    }

    /// y' = 1, y(0) = 0 gives y = t, which rises through 1, 2 and 3 in turn.
    #[test]
    fn events_multiple_crossings() {
        let f = |_t: f64, _y: &[f64]| vec![1.0];
        let thresholds = [1.0_f64, 2.0, 3.0];
        let specs: Vec<EventSpec> = thresholds
            .iter()
            .map(|&thresh| {
                make_spec(
                    move |_t: f64, y: &[f64]| y[0] - thresh,
                    EventDirection::Rising,
                    false,
                )
            })
            .collect();
        let events = EventSet::new(specs);
        let result = dopri5_with_events(f, 0.0, &[0.0], 4.0, 1e-8, 1e-10, events)
            .expect("integration failed");

        assert!(
            result.events.len() >= 3,
            "expected ≥3 events, got {}",
            result.events.len()
        );
        assert!(!result.terminated);

        // Each threshold must have fired under its own index, at t = threshold.
        for (idx, thresh) in thresholds.iter().enumerate() {
            let hit = result
                .events
                .iter()
                .any(|e| e.event_idx == idx && (e.t_event - thresh).abs() < 1e-6);
            assert!(
                hit,
                "no event for threshold {thresh} (index {idx}); got {:?}",
                result
                    .events
                    .iter()
                    .map(|e| (e.event_idx, e.t_event))
                    .collect::<Vec<_>>()
            );
        }
    }

    #[test]
    fn events_validates_empty_y0() {
        let f = |_t: f64, _y: &[f64]| vec![];
        let events = EventSet::new(vec![]);
        assert!(dopri5_with_events(f, 0.0, &[], 1.0, 1e-6, 1e-8, events).is_err());
    }

    #[test]
    fn events_validates_t_end_leq_t0() {
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let events = EventSet::new(vec![]);
        assert!(dopri5_with_events(f, 1.0, &[1.0], 0.5, 1e-6, 1e-8, events).is_err());
    }
}