use crate::error::{IntegrateError, IntegrateResult};
use super::embedded_rk::{dopri5, OdeResult};
/// Which sign change of an event function counts as a crossing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EventDirection {
/// The event function goes from strictly negative to strictly positive.
Rising,
/// The event function goes from strictly positive to strictly negative.
Falling,
/// Either sign change is accepted.
Both,
}
/// One event to watch for while integrating.
pub struct EventSpec {
    /// Event function `g(t, y)`; a crossing is a sign change of its value.
    pub func: Box<dyn Fn(f64, &[f64]) -> f64 + Send + Sync>,
    /// Which sign changes of `func` are reported.
    pub direction: EventDirection,
    /// When `true`, the integration driver stops at the first crossing.
    pub terminal: bool,
}

// The boxed closure has no `Debug` impl, so only the plain fields are printed.
impl std::fmt::Debug for EventSpec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("EventSpec");
        out.field("direction", &self.direction);
        out.field("terminal", &self.terminal);
        out.finish()
    }
}
/// A located event crossing.
#[derive(Debug, Clone)]
pub struct EventResult {
/// Time at which the root finder placed the crossing.
pub t_event: f64,
/// State interpolated at `t_event`.
pub y_event: Vec<f64>,
/// Index supplied by the caller to identify the event; `dopri5_with_events`
/// passes the position of the spec inside `EventSet::specs`.
pub event_idx: usize,
}
/// The collection of events monitored during one integration run.
pub struct EventSet {
    /// Event specifications; an event's position here is its `event_idx`.
    pub specs: Vec<EventSpec>,
}

impl EventSet {
    /// Wraps the given specifications without reordering them.
    pub fn new(specs: Vec<EventSpec>) -> Self {
        EventSet { specs }
    }
}
/// Upper bound on the number of refinement iterations in `illinois_bracket`.
const MAX_ILLINOIS: usize = 50;
/// Absolute tolerance used by `illinois_bracket` both for the bracket width
/// and for the magnitude of the event function at the candidate root.
const ILLINOIS_TOL: f64 = 1e-12;
/// Reports whether the change from `g_prev` to `g_curr` is a crossing in the
/// requested `direction`.
///
/// Both values must be strictly non-zero and of opposite sign; an exact zero
/// (or NaN) on either side is never a crossing.
///
/// The test compares signs directly rather than checking `g_prev * g_curr < 0.0`,
/// because that product underflows to `-0.0` for very small values and would
/// then hide a genuine sign change.
fn direction_matches(g_prev: f64, g_curr: f64, direction: EventDirection) -> bool {
    let rising = g_prev < 0.0 && g_curr > 0.0;
    let falling = g_prev > 0.0 && g_curr < 0.0;
    match direction {
        EventDirection::Rising => rising,
        EventDirection::Falling => falling,
        EventDirection::Both => rising || falling,
    }
}
/// Refines a bracketed root of `eval` with the Illinois variant of regula falsi.
///
/// `ta`/`tb` are the bracket ends and `ga`/`gb` the function values the caller
/// already holds there; they are expected to have opposite signs.
///
/// Returns the root estimate. Iteration stops when
/// * the bracket is narrower than `ILLINOIS_TOL` (or a few ulps of the bracket
///   ends when those are large enough that `ILLINOIS_TOL` is below the `f64`
///   spacing and could never be reached),
/// * `|eval(t)|` drops below `ILLINOIS_TOL`, or
/// * `MAX_ILLINOIS` iterations have run, in which case the bracket midpoint is
///   returned.
fn illinois_bracket<E>(
    mut ta: f64,
    mut tb: f64,
    mut ga: f64,
    mut gb: f64,
    eval: E,
) -> f64
where
    E: Fn(f64) -> f64,
{
    // Which bracket end survived the previous update.
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum Kept {
        Neither,
        A,
        B,
    }

    // Sign comparison instead of `x * y < 0.0`: the product can underflow to
    // zero for subnormal values and would then pick the wrong half.
    fn opposite_signs(x: f64, y: f64) -> bool {
        (x < 0.0 && y > 0.0) || (x > 0.0 && y < 0.0)
    }

    let mut kept = Kept::Neither;
    for _ in 0..MAX_ILLINOIS {
        let dg = gb - ga;
        // Secant step; fall back to bisection when the secant is degenerate.
        let secant = if dg.abs() < 1e-300 {
            (ta + tb) / 2.0
        } else {
            ta - ga * (tb - ta) / dg
        };
        let t_new = secant.clamp(ta.min(tb), ta.max(tb));

        // Never ask for a bracket narrower than the floats around it can express.
        let width_tol = ILLINOIS_TOL.max(4.0 * f64::EPSILON * ta.abs().max(tb.abs()));
        if (tb - ta).abs() < width_tol {
            return t_new;
        }

        let g_new = eval(t_new);
        if g_new.abs() < ILLINOIS_TOL {
            return t_new;
        }

        if opposite_signs(ga, g_new) {
            // Root lies in [ta, t_new]: end A is retained. Retaining it twice in
            // a row triggers the Illinois halving of its function value.
            if kept == Kept::A {
                ga /= 2.0;
            }
            tb = t_new;
            gb = g_new;
            kept = Kept::A;
        } else {
            // Root lies in [t_new, tb]: end B is retained.
            if kept == Kept::B {
                gb /= 2.0;
            }
            ta = t_new;
            ga = g_new;
            kept = Kept::B;
        }
    }
    (ta + tb) / 2.0
}
pub fn find_event_root(
g_prev: f64,
g_curr: f64,
t_prev: f64,
t_curr: f64,
y_prev: &[f64],
y_curr: &[f64],
event_idx: usize,
event: &EventSpec,
) -> Option<EventResult> {
if !direction_matches(g_prev, g_curr, event.direction) {
return None;
}
let n = y_prev.len();
let dt = t_curr - t_prev;
let interp = |t: f64| -> Vec<f64> {
let alpha = if dt.abs() < 1e-300 {
0.5
} else {
(t - t_prev) / dt
};
(0..n)
.map(|i| y_prev[i] + alpha * (y_curr[i] - y_prev[i]))
.collect()
};
let eval = |t: f64| -> f64 {
let y = interp(t);
(event.func)(t, &y)
};
let t_event = illinois_bracket(t_prev, t_curr, g_prev, g_curr, eval);
let y_event = interp(t_event);
Some(EventResult {
t_event,
y_event,
event_idx,
})
}
/// Locates an event crossing inside a single step using a caller-supplied
/// interpolant for the state.
///
/// * `g_prev`, `g_curr` - event function values at `t_prev` and `t_curr`.
/// * `interp` - maps a time inside the step to the state at that time.
/// * `event_idx` - copied verbatim into the returned [`EventResult`].
///
/// Returns `None` when the sign change does not match `event.direction`;
/// otherwise the crossing time and `interp` evaluated there.
pub fn find_event_root_dense<I>(
    g_prev: f64,
    g_curr: f64,
    t_prev: f64,
    t_curr: f64,
    interp: I,
    event_idx: usize,
    event: &EventSpec,
) -> Option<EventResult>
where
    I: Fn(f64) -> Vec<f64>,
{
    direction_matches(g_prev, g_curr, event.direction).then(|| {
        // Event function along the interpolated trajectory.
        let g_along_step = |t: f64| -> f64 { (event.func)(t, &interp(t)) };
        let t_event = illinois_bracket(t_prev, t_curr, g_prev, g_curr, g_along_step);
        EventResult {
            t_event,
            y_event: interp(t_event),
            event_idx,
        }
    })
}
/// Outcome of `dopri5_with_events`.
#[derive(Debug)]
pub struct OdeEventResult {
/// Trajectory and step/evaluation counters accumulated over the whole run.
pub ode: OdeResult,
/// Every event crossing that was recorded, in the order it was recorded.
pub events: Vec<EventResult>,
/// `true` when a terminal event stopped the integration early.
pub terminated: bool,
}
/// Integrates `y' = f(t, y)` from `t0` to `t_end` with `dopri5` while watching
/// for event crossings.
///
/// The interval is split into segments of nominal length 0.1 (at most 10 000
/// of them); each segment is integrated by `dopri5` and every step it reports
/// is scanned for sign changes of the event functions. Crossings are refined
/// with [`find_event_root`], i.e. on a linear interpolant of the step.
///
/// Within one step, crossings are recorded in chronological order. The first
/// terminal crossing ends the run: later crossings in the same step are
/// discarded, the trajectory is truncated before that step and the event
/// point itself is appended as the final sample.
///
/// # Arguments
/// * `f` - right-hand side of the ODE.
/// * `t0`, `y0` - initial time and state (`y0` must be non-empty).
/// * `t_end` - final time, strictly greater than `t0`.
/// * `rtol`, `atol` - tolerances forwarded unchanged to `dopri5`.
/// * `events` - events to monitor.
///
/// # Errors
/// Returns `IntegrateError::ValueError` when `y0` is empty, when `t_end` is
/// not greater than `t0`, or when either time is NaN. Any error from `dopri5`
/// is propagated.
pub fn dopri5_with_events<F>(
    f: F,
    t0: f64,
    y0: &[f64],
    t_end: f64,
    rtol: f64,
    atol: f64,
    events: EventSet,
) -> IntegrateResult<OdeEventResult>
where
    F: Fn(f64, &[f64]) -> Vec<f64> + Clone,
{
    // Nominal segment length and hard cap on the number of segments.
    const SEGMENT_LEN: f64 = 0.1;
    const MAX_SEGMENTS: usize = 10_000;

    if y0.is_empty() {
        return Err(IntegrateError::ValueError(
            "y0 must be non-empty".to_string(),
        ));
    }
    // NaN compares false against everything, so it has to be rejected explicitly.
    if t0.is_nan() || t_end.is_nan() || t_end <= t0 {
        return Err(IntegrateError::ValueError(
            "t_end must be > t0".to_string(),
        ));
    }

    let mut all_t: Vec<f64> = vec![t0];
    let mut all_y: Vec<Vec<f64>> = vec![y0.to_vec()];
    let mut all_events: Vec<EventResult> = Vec::new();
    let mut n_steps_total: usize = 0;
    let mut n_rejected_total: usize = 0;
    let mut n_evals_total: usize = 0;
    let mut terminated = false;

    // Event function values at the most recently scanned point.
    let mut g_prev: Vec<f64> = events
        .specs
        .iter()
        .map(|s| (s.func)(t0, y0))
        .collect();

    let span = t_end - t0;
    let n_seg = ((span / SEGMENT_LEN).ceil() as usize).clamp(1, MAX_SEGMENTS);
    let dt_seg = span / n_seg as f64;

    let mut t_start = t0;
    let mut y_start = y0.to_vec();

    for seg in 0..n_seg {
        if t_start >= t_end - 1e-14 * span {
            break;
        }
        // Pin the last segment to `t_end` so accumulated rounding in
        // `t_start + dt_seg` cannot leave the trajectory a few ulps short.
        let t_seg_end = if seg + 1 == n_seg {
            t_end
        } else {
            (t_start + dt_seg).min(t_end)
        };
        let seg_result = dopri5(f.clone(), t_start, &y_start, t_seg_end, rtol, atol)?;
        n_steps_total += seg_result.n_steps;
        n_rejected_total += seg_result.n_rejected;
        n_evals_total += seg_result.n_evals;

        let seg_len = seg_result.t.len();
        let mut early_stop_idx: Option<usize> = None;

        'step_scan: for step_i in 1..seg_len {
            let t_p = seg_result.t[step_i - 1];
            let t_c = seg_result.t[step_i];
            let y_p = &seg_result.y[step_i - 1];
            let y_c = &seg_result.y[step_i];

            // Gather every crossing in this step before deciding anything, so
            // that the outcome does not depend on the order of the specs.
            let mut step_events: Vec<EventResult> = Vec::new();
            for (ev_idx, spec) in events.specs.iter().enumerate() {
                let g_c = (spec.func)(t_c, y_c);
                let g_p = std::mem::replace(&mut g_prev[ev_idx], g_c);
                if let Some(ev) =
                    find_event_root(g_p, g_c, t_p, t_c, y_p, y_c, ev_idx, spec)
                {
                    step_events.push(ev);
                }
            }

            // Stable sort: simultaneous events keep their spec order.
            step_events.sort_by(|a, b| {
                a.t_event
                    .partial_cmp(&b.t_event)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

            for ev in step_events {
                let is_terminal = events.specs[ev.event_idx].terminal;
                all_events.push(ev);
                if is_terminal {
                    // Anything later in this step happens after the stop.
                    early_stop_idx = Some(step_i);
                    break 'step_scan;
                }
            }
        }

        // Keep the steps strictly before the one containing a terminal event
        // (or all of them); index 0 duplicates the previous segment's end.
        let append_up_to = early_stop_idx.unwrap_or(seg_len);
        all_t.extend(seg_result.t.iter().take(append_up_to).skip(1).copied());
        all_y.extend(seg_result.y.iter().take(append_up_to).skip(1).cloned());

        if early_stop_idx.is_some() {
            terminated = true;
            // The terminal event is the last one recorded; it closes the trajectory.
            if let Some(last_ev) = all_events.last() {
                all_t.push(last_ev.t_event);
                all_y.push(last_ev.y_event.clone());
            }
            break;
        }

        match (seg_result.t.last(), seg_result.y.last()) {
            (Some(t_last), Some(y_last)) => {
                t_start = *t_last;
                y_start = y_last.clone();
            }
            _ => break,
        }
    }

    let n_out = all_t.len();
    Ok(OdeEventResult {
        ode: OdeResult {
            t: all_t,
            y: all_y,
            n_steps: n_steps_total,
            n_rejected: n_rejected_total,
            // NOTE(review): one extra evaluation is counted per output point;
            // presumably an allowance for event-function calls - confirm.
            n_evals: n_evals_total + n_out,
        },
        events: all_events,
        terminated,
    })
}
// Unit tests for the root finder and for the event-aware integration driver.
#[cfg(test)]
mod tests {
use super::*;
// g(t) = t - 0.5 is linear, so the crossing on [0, 1] must be found at 0.5.
#[test]
fn illinois_finds_exact_midpoint() {
let spec = EventSpec {
func: Box::new(|t: f64, _y: &[f64]| t - 0.5),
direction: EventDirection::Rising,
terminal: false,
};
let y_prev = vec![1.0_f64];
let y_curr = vec![1.0_f64];
let result =
find_event_root(-0.5, 0.5, 0.0, 1.0, &y_prev, &y_curr, 0, &spec)
.expect("should detect rising crossing");
assert!(
(result.t_event - 0.5).abs() < 1e-10,
"t_event={} expected 0.5",
result.t_event
);
assert_eq!(result.event_idx, 0);
}
// The same falling crossing (g = 1 - 2t) must be rejected by a Rising filter
// and accepted by a Falling filter.
#[test]
fn illinois_direction_filter_falling() {
let spec_rising = EventSpec {
func: Box::new(|t: f64, _y: &[f64]| 1.0 - 2.0 * t), direction: EventDirection::Rising, terminal: false,
};
let y = vec![0.0_f64];
let res = find_event_root(1.0, -1.0, 0.0, 1.0, &y, &y, 0, &spec_rising);
assert!(res.is_none(), "Rising filter should reject falling crossing");
let spec_falling = EventSpec {
func: Box::new(|t: f64, _y: &[f64]| 1.0 - 2.0 * t),
direction: EventDirection::Falling,
terminal: false,
};
let res2 = find_event_root(1.0, -1.0, 0.0, 1.0, &y, &y, 0, &spec_falling)
.expect("Falling filter should accept falling crossing");
assert!((res2.t_event - 0.5).abs() < 1e-8);
}
// A Both filter accepts a rising crossing and reports the caller's event index.
// NOTE(review): the bracket values passed (-0.5, 0.5) are not the values of
// sin(t - 0.3) at t = 0 and t = 0.6, and the located time is never asserted.
#[test]
fn illinois_both_directions() {
let spec = EventSpec {
func: Box::new(|t: f64, _y: &[f64]| (t - 0.3).sin()),
direction: EventDirection::Both,
terminal: false,
};
let y = vec![0.0_f64];
let res = find_event_root(-0.5, 0.5, 0.0, 0.6, &y, &y, 2, &spec);
let ev = res.expect("should find crossing");
assert_eq!(ev.event_idx, 2);
}
// y' = cos(t) with y(0) = 0: the non-terminal Falling event on y[0] is
// expected near t = π, and the run must not be flagged as terminated.
#[test]
fn events_detect_zero_crossing_sin() {
let f = |t: f64, _y: &[f64]| vec![t.cos()];
let event_spec = EventSpec {
func: Box::new(|_t: f64, y: &[f64]| y[0]),
direction: EventDirection::Falling, terminal: false,
};
let events = EventSet::new(vec![event_spec]);
let result =
dopri5_with_events(f, 0.0, &[0.0], 4.0, 1e-8, 1e-10, events)
.expect("integration failed");
let pi = std::f64::consts::PI;
let found = result
.events
.iter()
.any(|e| (e.t_event - pi).abs() < 0.05);
assert!(
found,
"Expected crossing near t=π, got events: {:?}",
result.events.iter().map(|e| e.t_event).collect::<Vec<_>>()
);
assert!(!result.terminated);
}
// y' = -y with y(0) = 1: the terminal event y = 0.5 must stop the run, with
// the last output time near ln 2.
#[test]
fn events_terminal_stops_integration() {
let f = |_t: f64, y: &[f64]| vec![-y[0]];
let threshold = EventSpec {
func: Box::new(|_t: f64, y: &[f64]| y[0] - 0.5), direction: EventDirection::Falling,
terminal: true,
};
let events = EventSet::new(vec![threshold]);
let result = dopri5_with_events(f, 0.0, &[1.0], 5.0, 1e-8, 1e-10, events)
.expect("integration failed");
assert!(result.terminated, "Expected terminal stop");
let t_final = result.ode.t.last().copied().unwrap_or(0.0);
let ln2 = 2.0_f64.ln();
assert!(
(t_final - ln2).abs() < 0.1,
"Expected termination near t=ln2≈{ln2:.4}, got t={t_final:.4}"
);
assert!(!result.events.is_empty());
}
// y' = 1 with y(0) = 0 over [0, 4]: three Rising thresholds (1, 2, 3) should
// yield at least three recorded events.
#[test]
fn events_multiple_crossings() {
let f = |_t: f64, _y: &[f64]| vec![1.0];
let mut specs = Vec::new();
for thresh in [1.0_f64, 2.0, 3.0] {
specs.push(EventSpec {
func: Box::new(move |_t: f64, y: &[f64]| y[0] - thresh),
direction: EventDirection::Rising,
terminal: false,
});
}
let events = EventSet::new(specs);
let result = dopri5_with_events(f, 0.0, &[0.0], 4.0, 1e-8, 1e-10, events)
.expect("integration failed");
assert!(
result.events.len() >= 3,
"expected ≥3 events, got {}",
result.events.len()
);
}
// An empty initial state is rejected with an error.
#[test]
fn events_validates_empty_y0() {
let f = |_t: f64, _y: &[f64]| vec![];
let events = EventSet::new(vec![]);
assert!(dopri5_with_events(f, 0.0, &[], 1.0, 1e-6, 1e-8, events).is_err());
}
// A final time that is not after the initial time is rejected with an error.
#[test]
fn events_validates_t_end_leq_t0() {
let f = |_t: f64, y: &[f64]| vec![-y[0]];
let events = EventSet::new(vec![]);
assert!(dopri5_with_events(f, 1.0, &[1.0], 0.5, 1e-6, 1e-8, events).is_err());
}
}