use crate::events::event::Direction;
use crate::events::event::Event;
use crate::integrators::integrator_trait::Integrator;
use crate::ode_state::ode_state_trait::OdeState;
use rootfinder::{Interval, root_bisection};
/// Strategy for locating an event inside an integration step once one has been detected.
///
/// NOTE(review): the dispatch from these variants to the functions in this module is not visible
/// here; the pairing documented on each variant is inferred from the matching names — confirm at
/// the call site.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EventDetectionMethod {
    /// Root-find (bisection) on the event function along the integrator's own trajectory
    /// (`exact_event_detection`).
    Exact,
    /// Zero of the straight line through the event function values at both ends of the step
    /// (`linear_event_detection`).
    LinearInterpolation,
    /// Event placed at the start of the step (`left_event_detection`).
    LeftInterpolation,
    /// Event placed at the end of the step (`right_event_detection`).
    RightInterpolation,
}
/// Evaluates the event function at both ends of a step and decides whether an event lies in it.
///
/// # Arguments
///
/// * `event` - Event whose function `g` and required crossing direction are checked.
/// * `t_prev` - Time at the start of the step.
/// * `y_prev` - State at the start of the step.
/// * `y_curr` - State at the end of the step (time `t_prev + h`).
/// * `h` - Step size.
///
/// # Returns
///
/// `Some((g_prev, g_curr))`, the event function values at the start and end of the step, when:
///
/// * `g` reaches zero on the step (the values do not lie strictly on the same side of zero; a
///   zero at either endpoint counts), and
/// * for a directional event, `g` strictly increases (`Direction::Increasing`) or strictly
///   decreases (`Direction::Decreasing`) over the step.
///
/// `None` otherwise, including when either value is NaN.
fn event_detection_helper<T: OdeState>(
    event: &Event<T>,
    t_prev: f64,
    y_prev: &T,
    y_curr: &T,
    h: f64,
) -> Option<(f64, f64)> {
    let g_prev = (event.g)(t_prev, y_prev);
    let g_curr = (event.g)(t_prev + h, y_curr);

    // A NaN value cannot bracket a root; report "no event" instead of propagating it.
    if g_prev.is_nan() || g_curr.is_nan() {
        return None;
    }

    // No zero crossing: both values strictly on the same side of zero. Signs are compared
    // explicitly rather than via `g_prev * g_curr > 0.0`, which underflows to 0 for tiny values.
    if (g_prev > 0.0 && g_curr > 0.0) || (g_prev < 0.0 && g_curr < 0.0) {
        return None;
    }

    // Directional events additionally require g to move strictly in the requested direction.
    if !matches!(event.direction, Direction::Either)
        && ((g_curr == g_prev)
            || ((g_curr > g_prev) && matches!(event.direction, Direction::Decreasing))
            || ((g_curr < g_prev) && matches!(event.direction, Direction::Increasing)))
    {
        return None;
    }

    Some((g_prev, g_curr))
}
/// Locates an event by root-finding on the event function along the integrator's trajectory.
///
/// For a trial sub-step `h_sub`, `y_prev` is propagated from `t_prev` by `h_sub` with integrator
/// `I`, and the event function is evaluated at `t_prev + h_sub` on that state. The zero in
/// `h_sub ∈ [0, h]` is found with `root_bisection`, called with `None` for its two optional
/// arguments (presumably tolerance / iteration settings — confirm against `rootfinder`).
///
/// # Arguments
///
/// * `f` - Right-hand side of the ODE, passed to the integrator.
/// * `event` - Event to locate.
/// * `t_prev` - Time at the start of the step.
/// * `y_prev` - State at the start of the step.
/// * `y_curr` - State at the end of the step.
/// * `h` - Step size.
///
/// # Returns
///
/// The offset from `t_prev` at which the event occurs. Returns `None` if
/// `event_detection_helper` rejects the step or if the bisection returns an error.
pub(crate) fn exact_event_detection<T: OdeState, I: Integrator<T>>(
    f: &impl Fn(f64, &T) -> T,
    event: &Event<T>,
    t_prev: f64,
    y_prev: &T,
    y_curr: &T,
    h: f64,
) -> Option<f64> {
    // Only the yes/no answer matters here; the root finder re-evaluates g itself.
    event_detection_helper(event, t_prev, y_prev, y_curr, h)?;

    // Event function as a function of the sub-step size, always integrated from the step start.
    let g_of_substep = |h_sub: f64| {
        let mut y_sub = y_prev.clone();
        I::propagate(&f, t_prev, h_sub, &mut y_sub);
        (event.g)(t_prev + h_sub, &y_sub)
    };

    root_bisection(&g_of_substep, Interval::new(0.0, h), None, None).ok()
}
/// Locates an event by linearly interpolating the event function across the step.
///
/// # Arguments
///
/// * `event` - Event to locate.
/// * `t_prev` - Time at the start of the step.
/// * `y_prev` - State at the start of the step.
/// * `y_curr` - State at the end of the step.
/// * `h` - Step size.
///
/// # Returns
///
/// The offset from `t_prev` where the straight line through `(0, g_prev)` and `(h, g_curr)`
/// crosses zero. Returns `None` if `event_detection_helper` rejects the step.
///
/// If the event function has the same value at both ends, the line is flat and the crossing is
/// undefined. In that case the step start (`0.0`) is returned when that value is zero, and `None`
/// otherwise, rather than dividing by zero and producing NaN or infinity.
pub(crate) fn linear_event_detection<T: OdeState>(
    event: &Event<T>,
    t_prev: f64,
    y_prev: &T,
    y_curr: &T,
    h: f64,
) -> Option<f64> {
    let (g_prev, g_curr) = event_detection_helper(event, t_prev, y_prev, y_curr, h)?;
    let dg = g_curr - g_prev;
    if dg == 0.0 {
        // Flat interpolant: either g is zero across the whole line, or it never reaches zero.
        return if g_prev == 0.0 { Some(0.0) } else { None };
    }
    Some(-h * g_prev / dg)
}
/// Places an event at the left end (start) of the step in which it is detected.
///
/// # Arguments
///
/// * `event` - Event to locate.
/// * `t_prev` - Time at the start of the step.
/// * `y_prev` - State at the start of the step.
/// * `y_curr` - State at the end of the step.
/// * `h` - Step size.
///
/// # Returns
///
/// `None` if `event_detection_helper` rejects the step. Otherwise the offset from `t_prev`:
/// `h` when the event function is exactly zero at the end of the step, and `0.0` in every other
/// case.
pub(crate) fn left_event_detection<T: OdeState>(
    event: &Event<T>,
    t_prev: f64,
    y_prev: &T,
    y_curr: &T,
    h: f64,
) -> Option<f64> {
    event_detection_helper(event, t_prev, y_prev, y_curr, h)
        .map(|(_, g_end)| if g_end == 0.0 { h } else { 0.0 })
}
/// Places an event at the right end (end) of the step in which it is detected.
///
/// # Arguments
///
/// * `event` - Event to locate.
/// * `t_prev` - Time at the start of the step.
/// * `y_prev` - State at the start of the step.
/// * `y_curr` - State at the end of the step.
/// * `h` - Step size.
///
/// # Returns
///
/// `None` if `event_detection_helper` rejects the step. Otherwise the offset from `t_prev`:
/// `0.0` when the event function is exactly zero at the start of the step, and `h` in every other
/// case.
pub(crate) fn right_event_detection<T: OdeState>(
    event: &Event<T>,
    t_prev: f64,
    y_prev: &T,
    y_curr: &T,
    h: f64,
) -> Option<f64> {
    event_detection_helper(event, t_prev, y_prev, y_curr, h)
        .map(|(g_start, _)| if g_start != 0.0 { h } else { 0.0 })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Euler, RK4};
    use numtest::*;

    /// Asserts that the event function is numerically zero at a detected event.
    ///
    /// Propagates `y_prev` from `t_prev` by `h_event` with integrator `I`, evaluates the event
    /// function at `t_prev + h_event`, and compares it with zero using `assert_equal_to_atol!`
    /// with tolerance `2.0 * f64::EPSILON` (presumably an absolute tolerance, per the macro name).
    fn check_event_function_value<I: Integrator<f64>>(
        f: &impl Fn(f64, &f64) -> f64,
        event: &Event<f64>,
        t_prev: f64,
        y_prev: &f64,
        h_event: f64,
    ) {
        let mut y_event = *y_prev;
        I::propagate(f, t_prev, h_event, &mut y_event);
        let t_event = t_prev + h_event;
        let g_event = (event.g)(t_event, &y_event);
        assert_equal_to_atol!(g_event, 0.0, 2.0 * f64::EPSILON);
    }

    // Time-only event g = sqrt(t) - 0.5, which is zero at t = 0.25.
    // Over [0, 1], g goes from -0.5 to 0.5, so linear interpolation gives 0.5.
    #[test]
    fn test_event_detection_on_time_basic() {
        let event = Event::new(|t: f64, _y: &f64| t.sqrt() - 0.5);
        let f = |_t: f64, y: &f64| *y;
        let t_prev = 0.0;
        let y_prev = 1.0;
        let h = 1.0;
        let mut y_curr = y_prev;
        Euler::propagate(&f, t_prev, h, &mut y_curr);
        let h_event_exact =
            exact_event_detection::<f64, Euler>(&f, &event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_equal_to_decimal!(h_event_exact, 0.25, 15);
        check_event_function_value::<Euler>(&f, &event, t_prev, &y_prev, h_event_exact);
        let h_event_linear =
            linear_event_detection::<f64>(&event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event_linear, 0.5);
        let h_event_left =
            left_event_detection::<f64>(&event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event_left, 0.0);
        let h_event_right =
            right_event_detection::<f64>(&event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event_right, h);
    }

    // g = y - t - 0.5 with default direction: g goes from 0.5 to 0.0, so the helper returns the
    // endpoint values.
    #[test]
    fn test_event_detection_helper_1() {
        let event = Event::new(|t: f64, y: &f64| y - t - 0.5);
        let t_prev = 0.0;
        let y_prev = 1.0; let y_curr = 1.5; let h = 1.0;
        let result = event_detection_helper(&event, t_prev, &y_prev, &y_curr, h);
        assert!(result.is_some());
        let (g_prev, g_curr) = result.unwrap();
        assert_equal_to_decimal!(g_prev, 0.5, 15);
        assert_equal_to_decimal!(g_curr, 0.0, 15);
    }

    // Same data as helper_1, but with a Decreasing event; g decreases (0.5 -> 0.0), so it is
    // accepted.
    #[test]
    fn test_event_detection_helper_2() {
        let event = Event::new(|t: f64, y: &f64| y - t - 0.5).direction(Direction::Decreasing);
        let t_prev = 0.0;
        let y_prev = 1.0; let y_curr = 1.5; let h = 1.0;
        let result = event_detection_helper(&event, t_prev, &y_prev, &y_curr, h);
        assert!(result.is_some());
        let (g_prev, g_curr) = result.unwrap();
        assert_equal_to_decimal!(g_prev, 0.5, 15);
        assert_equal_to_decimal!(g_curr, 0.0, 15);
    }

    // Decreasing event where g = 0 at both ends (no change), so it is rejected.
    #[test]
    fn test_event_detection_helper_3() {
        let event = Event::new(|t: f64, y: &f64| y - t - 0.5).direction(Direction::Decreasing);
        let t_prev = 0.0;
        let y_prev = 0.5; let y_curr = 1.5; let h = 1.0;
        let result = event_detection_helper(&event, t_prev, &y_prev, &y_curr, h);
        assert!(result.is_none());
    }

    // Increasing event where g goes from -0.5 to 0.0, so it is accepted.
    #[test]
    fn test_event_detection_helper_4() {
        let event = Event::new(|t: f64, y: &f64| y - t - 0.5).direction(Direction::Increasing);
        let t_prev = 0.0;
        let y_prev = 0.0; let y_curr = 1.5; let h = 1.0;
        let result = event_detection_helper(&event, t_prev, &y_prev, &y_curr, h);
        assert!(result.is_some());
        let (g_prev, g_curr) = result.unwrap();
        assert_equal_to_decimal!(g_prev, -0.5, 15);
        assert_equal_to_decimal!(g_curr, 0.0, 15);
    }

    // Increasing event where g decreases (0.5 -> 0.0), so it is rejected.
    #[test]
    fn test_event_detection_helper_5() {
        let event = Event::new(|t: f64, y: &f64| y - t - 0.5).direction(Direction::Increasing);
        let t_prev = 0.0;
        let y_prev = 1.0; let y_curr = 1.5; let h = 1.0;
        let result = event_detection_helper(&event, t_prev, &y_prev, &y_curr, h);
        assert!(result.is_none());
    }

    // g = t is zero exactly at the start of the step; every method should report offset 0.
    #[test]
    fn test_event_detection_on_time_lower_bound() {
        let event = Event::new(|t: f64, _y: &f64| t);
        let f = |_t: f64, y: &f64| *y;
        let t_prev = 0.0;
        let y_prev = 1.0;
        let h = 1.0;
        let mut y_curr = y_prev;
        Euler::propagate(&f, t_prev, h, &mut y_curr);
        let h_event_exact =
            exact_event_detection::<f64, Euler>(&f, &event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event_exact, 0.0);
        check_event_function_value::<Euler>(&f, &event, t_prev, &y_prev, h_event_exact);
        let h_event_linear =
            linear_event_detection::<f64>(&event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event_linear, 0.0);
        let h_event_left =
            left_event_detection::<f64>(&event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event_left, 0.0);
        let h_event_right =
            right_event_detection::<f64>(&event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event_right, 0.0);
    }

    // g = t - 1 is zero exactly at the end of the step; every method should report offset h.
    #[test]
    fn test_event_detection_on_time_upper_bound() {
        let event = Event::new(|t: f64, _y: &f64| t - 1.0);
        let f = |_t: f64, y: &f64| *y;
        let t_prev = 0.0;
        let y_prev = 1.0;
        let h = 1.0;
        let mut y_curr = y_prev;
        Euler::propagate(&f, t_prev, h, &mut y_curr);
        let h_event_exact =
            exact_event_detection::<f64, Euler>(&f, &event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event_exact, 1.0);
        check_event_function_value::<Euler>(&f, &event, t_prev, &y_prev, h_event_exact);
        let h_event_linear =
            linear_event_detection::<f64>(&event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event_linear, 1.0);
        let h_event_left =
            left_event_detection::<f64>(&event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event_left, h);
        let h_event_right =
            right_event_detection::<f64>(&event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event_right, h);
    }

    // State event g = y - 1.5 for y' = y starting at y = 1. The exact-method assertion of 0.5
    // matches an Euler sub-step state of 1 + h_sub (the expected values are pinned below).
    #[test]
    fn test_exact_event_detection_on_state_basic() {
        let event = Event::new(|_t: f64, y: &f64| *y - 1.5);
        let f = |_t: f64, y: &f64| *y;
        let t_prev = 0.0;
        let y_prev = 1.0;
        let h = 1.0;
        let mut y_curr = y_prev;
        Euler::propagate(&f, t_prev, h, &mut y_curr);
        let h_event =
            exact_event_detection::<f64, Euler>(&f, &event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_equal_to_decimal!(h_event, 0.5, 15);
        check_event_function_value::<Euler>(&f, &event, t_prev, &y_prev, h_event);
        let h_event_linear =
            linear_event_detection::<f64>(&event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event_linear, 0.5);
        let h_event_left =
            left_event_detection::<f64>(&event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event_left, 0.0);
        let h_event_right =
            right_event_detection::<f64>(&event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event_right, h);
    }

    // Same event with y_curr = e supplied directly (not propagated). The exact method's root
    // depends on the integrator used for sub-steps (Euler vs RK4). The Decreasing variant is
    // rejected because g increases over the step; the Increasing variant matches the RK4 result.
    #[test]
    fn test_exact_event_detection_on_state_different_integrators_different_directions() {
        let event = Event::new(|_t: f64, y: &f64| *y - 1.5);
        let f = |_t: f64, y: &f64| *y;
        let t_prev = 0.0;
        let y_prev = 1.0;
        let y_curr = 1.0_f64.exp();
        let h = 1.0;
        assert!(matches!(event.direction, Direction::Either));
        let h_event =
            exact_event_detection::<f64, Euler>(&f, &event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_equal_to_decimal!(h_event, 0.5, 15);
        check_event_function_value::<Euler>(&f, &event, t_prev, &y_prev, h_event);
        let h_event =
            exact_event_detection::<f64, RK4>(&f, &event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event, 0.40553040739646273);
        check_event_function_value::<RK4>(&f, &event, t_prev, &y_prev, h_event);
        let event = event.direction(Direction::Decreasing);
        assert!(
            exact_event_detection::<f64, Euler>(&f, &event, t_prev, &y_prev, &y_curr, h).is_none()
        );
        let event = event.direction(Direction::Increasing);
        let h_event =
            exact_event_detection::<f64, RK4>(&f, &event, t_prev, &y_prev, &y_curr, h).unwrap();
        assert_eq!(h_event, 0.40553040739646273);
    }
}