use numra_core::Scalar;
/// Magnitude below which the event value at a bisection midpoint is treated as
/// zero, ending the root search in `find_event_time` early.
const EVENT_ZERO_TOL: f64 = 1e-12;
/// Bracket width (in the time variable) below which `find_event_time` stops
/// bisecting.
const EVENT_BISECTION_TOL: f64 = 1e-13;
/// Upper bound on the number of bisection iterations in `find_event_time`.
const EVENT_MAX_BISECTION: usize = 50;
/// Direction of change of the event value that an event responds to.
///
/// `find_event_time` applies this filter by comparing the event value at the
/// start and at the end of a step.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the direction can be
/// used as a map/set key and in `Eq`-bounded generic code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventDirection {
    /// Respond only when the event value at the end of the step is strictly
    /// greater than at the start.
    Rising,
    /// Respond only when the event value at the end of the step is strictly
    /// smaller than at the start.
    Falling,
    /// Respond regardless of the direction of change.
    Both,
}
/// What should happen once an event has been detected.
///
/// NOTE(review): the code that consumes this value is not part of this file;
/// the variant descriptions below are inferred from the names and should be
/// confirmed against the solver loop.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the action can be used
/// as a map/set key and in `Eq`-bounded generic code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventAction {
    /// Presumably: halt the integration at the event.
    Stop,
    /// Presumably: record the event and keep integrating. This is the default
    /// returned by `EventFunction::action`.
    Continue,
}
/// A scalar function of time and state whose zero marks an event.
///
/// Implementors must be `Send + Sync`. Only [`evaluate`](Self::evaluate) is
/// required; the remaining methods have defaults.
pub trait EventFunction<S: Scalar>: Send + Sync {
    /// Returns the event value at time `t` for state `y`.
    ///
    /// `find_event_time` searches for the time at which this value is zero.
    fn evaluate(&self, t: S, y: &[S]) -> S;

    /// Direction filter for this event.
    ///
    /// Defaults to [`EventDirection::Both`].
    fn direction(&self) -> EventDirection {
        EventDirection::Both
    }

    /// Action associated with this event.
    ///
    /// Defaults to [`EventAction::Continue`].
    fn action(&self) -> EventAction {
        EventAction::Continue
    }
}
/// Record describing a detected event.
///
/// NOTE(review): nothing in this file constructs an `Event`; the field
/// descriptions are inferred from the field names and should be confirmed
/// against the code that produces these records.
#[derive(Clone, Debug)]
pub struct Event<S: Scalar> {
    /// Time associated with the event.
    pub t: S,
    /// State vector associated with the event.
    pub y: Vec<S>,
    /// Index identifying the event; presumably its position in the caller's
    /// list of event functions.
    pub event_index: usize,
}
/// Locates a zero of `event` inside the step from `t_lo` to `t_hi`.
///
/// The event value is sampled at both ends of the step. When the two samples
/// have the same strict sign, or their change does not match
/// [`EventFunction::direction`], no event is reported. Otherwise the zero is
/// refined by bisection, using `interpolate` to obtain the state at trial
/// times.
///
/// # Arguments
///
/// * `event` - event function whose zero is sought.
/// * `t_lo`, `y_lo` - time and state at the start of the step.
/// * `t_hi`, `y_hi` - time and state at the end of the step.
/// * `interpolate` - returns the state for a given time; it is called with
///   bisection midpoints between `t_lo` and `t_hi`.
///
/// # Returns
///
/// * `None` when the endpoint values have the same strict sign or the
///   direction filter rejects the change.
/// * `Some((t_lo, y_lo))` when the event value at `t_lo` is exactly zero.
/// * `Some((t, y))` otherwise, where `t` is either a midpoint whose event
///   value is below `EVENT_ZERO_TOL` in magnitude, or the centre of the final
///   bracket once it is narrower than `EVENT_BISECTION_TOL` or
///   `EVENT_MAX_BISECTION` iterations have run. `y` is `interpolate(t)`.
pub fn find_event_time<S: Scalar>(
    event: &dyn EventFunction<S>,
    t_lo: S,
    y_lo: &[S],
    t_hi: S,
    y_hi: &[S],
    interpolate: &dyn Fn(S) -> Vec<S>,
) -> Option<(S, Vec<S>)> {
    let g_lo = event.evaluate(t_lo, y_lo);
    let g_hi = event.evaluate(t_hi, y_hi);

    // Same strict sign at both ends: no crossing bracketed by this step.
    if g_lo * g_hi > S::ZERO {
        return None;
    }

    match event.direction() {
        EventDirection::Rising if g_hi <= g_lo => return None,
        EventDirection::Falling if g_hi >= g_lo => return None,
        EventDirection::Rising | EventDirection::Falling | EventDirection::Both => {}
    }

    // A zero sitting exactly on the left endpoint must be reported there.
    // With a zero reference value the sign test below can never succeed, so
    // bisection would walk the bracket all the way to `t_hi` and report a
    // point where the event value is not zero.
    if g_lo == S::ZERO {
        return Some((t_lo, y_lo.to_vec()));
    }

    let zero_tol = S::from_f64(EVENT_ZERO_TOL);
    let bisect_tol = S::from_f64(EVENT_BISECTION_TOL);

    let mut t_a = t_lo;
    let mut t_b = t_hi;
    // Event value at `t_a`, carried along so it is not re-evaluated (and the
    // state at `t_a` need not be stored) on every iteration.
    let mut g_a = g_lo;

    for _ in 0..EVENT_MAX_BISECTION {
        let t_mid = (t_a + t_b) / S::TWO;
        let y_mid = interpolate(t_mid);
        let g_mid = event.evaluate(t_mid, &y_mid);

        if g_mid.abs() < zero_tol {
            return Some((t_mid, y_mid));
        }

        if g_mid * g_a < S::ZERO {
            // Sign change lies in the left half.
            t_b = t_mid;
        } else {
            // Sign change lies in the right half; the midpoint becomes the
            // new left end of the bracket.
            t_a = t_mid;
            g_a = g_mid;
        }

        if (t_b - t_a).abs() < bisect_tol {
            break;
        }
    }

    let t_root = (t_a + t_b) / S::TWO;
    Some((t_root, interpolate(t_root)))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Event whose value is the first state component; uses the trait's
    /// default direction and action.
    struct FirstComponent;

    impl EventFunction<f64> for FirstComponent {
        fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
            y[0]
        }
    }

    /// Same event value as `FirstComponent`, restricted to rising changes.
    struct RisingFirstComponent;

    impl EventFunction<f64> for RisingFirstComponent {
        fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
            y[0]
        }

        fn direction(&self) -> EventDirection {
            EventDirection::Rising
        }
    }

    /// The trait's default methods report `Both` and `Continue`.
    #[test]
    fn test_event_direction_default() {
        struct ConstantZero;

        impl EventFunction<f64> for ConstantZero {
            fn evaluate(&self, _t: f64, _y: &[f64]) -> f64 {
                0.0
            }
        }

        let event = ConstantZero;
        assert_eq!(event.direction(), EventDirection::Both);
        assert_eq!(event.action(), EventAction::Continue);
    }

    /// y(t) = 1 - t on [0, 2] crosses zero at t = 1.
    #[test]
    fn test_find_event_time_simple() {
        let line = |t: f64| -> Vec<f64> { vec![1.0 - t] };
        let start = (0.0, [1.0]);
        let end = (2.0, [-1.0]);

        let located = find_event_time(&FirstComponent, start.0, &start.1, end.0, &end.1, &line);
        assert!(located.is_some());

        let (t_event, y_event) = located.unwrap();
        assert!(
            (t_event - 1.0).abs() < 1e-10,
            "Expected t=1, got t={}",
            t_event
        );
        assert!(y_event[0].abs() < 1e-10);
    }

    /// A rising-only event ignores a value that falls across the step.
    #[test]
    fn test_find_event_time_direction_filter() {
        let line = |t: f64| -> Vec<f64> { vec![1.0 - t] };

        let located = find_event_time(&RisingFirstComponent, 0.0, &[1.0], 2.0, &[-1.0], &line);
        assert!(
            located.is_none(),
            "Rising event should not trigger on falling crossing"
        );
    }

    /// y(t) = 1 + t stays positive on [0, 2], so nothing is reported.
    #[test]
    fn test_no_sign_change() {
        let line = |t: f64| -> Vec<f64> { vec![1.0 + t] };

        let located = find_event_time(&FirstComponent, 0.0, &[1.0], 2.0, &[3.0], &line);
        assert!(located.is_none(), "No sign change should return None");
    }
}