use crate::events::event_manager::EventManager;
use crate::integrators::integrator_trait::Integrator;
use crate::ode_state::ode_state_trait::OdeState;
use crate::solution::Solution;
/// Solves the initial value problem `dy/dt = f(t, y)`, `y(t0) = y0` with a fixed-step integrator.
///
/// The state is advanced one step at a time with `I::propagate`, and every sample is appended to
/// the returned [`Solution`]. The last step is adjusted so that the final sample time is exactly
/// `tf`.
///
/// # Type Parameters
///
/// * `T` - State type.
/// * `I` - Integrator used to advance the state by a single step.
///
/// # Arguments
///
/// * `f` - Right-hand side of the differential equation, called as `f(t, &y)`.
/// * `t0` - Initial time.
/// * `y0` - Initial state.
/// * `tf` - Final time.
/// * `h` - Step size.
/// * `event_manager` - Optional event manager. When provided, every step is checked for events:
///   a detected event replaces the sample at the current index with the time and state of the
///   event, is recorded through `EventManager::store`, and ends the integration once the event's
///   detection count equals its configured termination count. If the event defines a state reset
///   function, the reset state is stored in the solution **and** becomes the state from which the
///   integration continues.
///
/// # Returns
///
/// The solution containing the sampled times (`t`) and states (`y`), shrunk to the samples that
/// were actually produced.
///
/// # Notes
///
/// * The number of samples is taken from the capacity reserved by `Solution::new_for_ivp`, and
///   index 0 is assumed to already hold `(t0, y0)` — TODO confirm against `Solution::new_for_ivp`.
/// * NOTE(review): after a state reset that happens strictly between two grid points, the next
///   step is still taken with the full step size `h` from the event time, while the sample is
///   labelled with the nominal grid time `t0 + i * h`.
pub fn solve_ivp<T: OdeState + 'static, I: Integrator<T>>(
    f: &impl Fn(f64, &T) -> T,
    t0: f64,
    y0: &T,
    tf: f64,
    mut h: f64,
    mut event_manager: Option<&mut EventManager<T>>,
) -> Solution<T> {
    let mut sol = Solution::new_for_ivp(y0, t0, tf, h);

    // The loop bound is the capacity reserved for the time vector. It is loop-invariant because
    // the loop never pushes beyond it, so it is read once up front.
    let num_samples = sol.t.capacity();

    // Running state that the integrator advances in place.
    let mut y = y0.clone();

    for i in 1..num_samples {
        // Nominal grid time of this sample.
        sol.t.push(t0 + (i as f64) * h);

        // Adjust the last step so that the final sample lands exactly on `tf`.
        if i == num_samples - 1 {
            h = tf - sol.t[i - 1];
            sol.t[i] = tf;
        }

        // Advance the running state across the step and record it.
        I::propagate(f, sol.t[i - 1], h, &mut y);
        sol.y.push(y.clone());

        if let Some(event_manager) = event_manager.as_deref_mut() {
            let (idx_event, h_event) =
                event_manager.detect_events::<I>(f, sol.t[i - 1], &sol.y[i - 1], &y, h);

            if let (Some(idx_event), Some(h_event)) = (idx_event, h_event) {
                // Time and state at which the event occurred.
                let (t_event, y_event) = if h_event == 0.0 {
                    // Event located at the start of the step.
                    (sol.t[i - 1], sol.y[i - 1].clone())
                } else if h_event == h {
                    // Event located at the end of the step.
                    (sol.t[i], sol.y[i].clone())
                } else {
                    // Event located inside the step: take a partial step from the previous
                    // sample to reach it.
                    let mut y_partial = sol.y[i - 1].clone();
                    I::propagate(f, sol.t[i - 1], h_event, &mut y_partial);
                    (sol.t[i - 1] + h_event, y_partial)
                };

                // The event sample replaces the regular sample at this index.
                sol.t[i] = t_event;
                sol.y[i] = y_event.clone();
                event_manager.store(t_event, &y_event, idx_event);

                // Stop once this event has been detected its configured number of times.
                if event_manager.num_detections[idx_event]
                    == event_manager[idx_event].termination.num_detections
                {
                    break;
                }

                // Apply the state reset, if any. The reset function must see the state *at the
                // event* (not the state after the full step), and the reset state must also
                // become the running state so that the integration continues from it.
                if let Some(s) = &event_manager[idx_event].s {
                    y = s(t_event, &y_event);
                    sol.y[i] = y.clone();
                }
            }
        }
    }

    sol.shrink_to_fit();
    sol
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Euler;
    use crate::StateIndex;
    use crate::events::event::{Event, Termination};
    use crate::events::event_manager::EventManager;
    use numtest::*;

    #[cfg(feature = "nalgebra")]
    use nalgebra::{DVector, SMatrix, dvector};

    /// Right-hand side of `dy/dt = y`, shared by the scalar tests below.
    fn growth(_t: f64, y: &f64) -> f64 {
        *y
    }

    /// Event function `t - 1`, whose root coincides with a grid point of the unit-step grid.
    #[test]
    fn test_solve_ivp_event_at_current_time() {
        let (t_start, t_end, step) = (0.0, 3.0, 1.0);
        let initial = 1.0;
        let at_one = Event::new(|t: f64, _y: &f64| t - 1.0);
        let mut manager = EventManager::new(vec![&at_one]);

        let solution =
            solve_ivp::<f64, Euler>(&growth, t_start, &initial, t_end, step, Some(&mut manager));

        assert!(solution.t.contains(&1.0));
        assert!(solution.y.len() >= 2);
    }

    /// Event function `t - 0`, whose root coincides with the initial time.
    #[test]
    fn test_solve_ivp_event_at_previous_time() {
        let (t_start, t_end, step) = (0.0, 2.0, 1.0);
        let initial = 1.0;
        let at_start = Event::new(|t: f64, _y: &f64| t - 0.0);
        let mut manager = EventManager::new(vec![&at_start]);

        let solution =
            solve_ivp::<f64, Euler>(&growth, t_start, &initial, t_end, step, Some(&mut manager));

        assert_eq!(solution.t[0], 0.0);
    }

    /// Event function `y - 1.5`, which is expected to trigger inside the first step.
    #[test]
    fn test_solve_ivp_event_between_time_steps() {
        let (t_start, t_end, step) = (0.0, 3.0, 1.0);
        let initial = 1.0;
        let threshold = Event::new(|_t: f64, y: &f64| y - 1.5);
        let mut manager = EventManager::new(vec![&threshold]);

        let solution =
            solve_ivp::<f64, Euler>(&growth, t_start, &initial, t_end, step, Some(&mut manager));

        assert!(solution.t.len() >= 2);

        // The last stored sample is the event itself.
        let last_time = solution.t.last().unwrap();
        assert!(*last_time > 0.0 && *last_time < 1.0);
        let last_state = solution.y.last().unwrap();
        assert_equal_to_decimal!(*last_state, 1.5, 10);
    }

    /// Plain scalar integration without any events.
    #[test]
    fn test_solve_ivp_scalar() {
        let (t_start, t_end, step) = (0.0, 3.0, 1.0);
        let initial = 1.0;

        let solution = solve_ivp::<f64, Euler>(&growth, t_start, &initial, t_end, step, None);

        assert_eq!(solution.t, [0.0, 1.0, 2.0, 3.0]);
        assert_eq!(solution.y, [1.0, 2.0, 4.0, 8.0]);
    }

    /// Event function `y - 3.5`; the solution is expected to end at the event.
    #[test]
    fn test_solve_ivp_event_detection_on_state() {
        let (t_start, t_end, step) = (0.0, 3.0, 1.0);
        let initial = 1.0;
        let threshold = Event::new(|_t: f64, y: &f64| y - 3.5);
        let mut manager = EventManager::new(vec![&threshold]);

        let solution =
            solve_ivp::<f64, Euler>(&growth, t_start, &initial, t_end, step, Some(&mut manager));

        assert_eq!(solution.t, [0.0, 1.0, 1.7499999999999998]);
        assert_eq!(solution.y, [1.0, 2.0, 3.4999999999999996]);
    }

    /// Decaying state with an event at `y = 5` that resets the state to `10`.
    #[test]
    fn test_solve_ivp_state_reset() {
        let decay = |_t: f64, y: &f64| -y;
        let (t_start, t_end, step) = (0.0, 3.0, 0.5);
        let initial = 10.0;
        let refill = Event::new(|_t: f64, y: &f64| y - 5.0)
            .s(|_t: f64, _y: &f64| 10.0)
            .termination(Termination::new(0));
        let mut manager = EventManager::new(vec![&refill]);

        let solution =
            solve_ivp::<f64, Euler>(&decay, t_start, &initial, t_end, step, Some(&mut manager));

        let last_state = solution.y.last().unwrap();
        assert!(*last_state > 3.0);
        assert!(manager.num_detections[0] > 0);

        // Count the stored states that lie within 1 of the reset value.
        let near_reset = solution
            .y
            .iter()
            .filter(|state| (**state - 10.0).abs() < 1.0)
            .count();
        assert!(near_reset >= 2);
    }

    /// Two-element vector state with a sinusoidal forcing term.
    #[test]
    #[cfg(feature = "nalgebra")]
    fn test_solve_ivp_vector() {
        // Coefficients of the right-hand side.
        let damping = 5.0;
        let stiffness = 1.0;
        let mass = 2.0;

        // Initial values of the two state elements.
        let position0 = 1.0;
        let velocity0 = 0.0;

        let rhs = |t: f64, y: &DVector<f64>| {
            DVector::<f64>::from_row_slice(&[
                y[1],
                -(damping / mass) * y[1] - (stiffness / mass) * y[0] + (1.0 / mass) * t.sin(),
            ])
        };
        let initial = dvector![position0, velocity0];
        let (t_start, t_end, step) = (0.0, 1.0, 0.1);

        let solution = solve_ivp::<DVector<f64>, Euler>(&rhs, t_start, &initial, t_end, step, None);

        assert_arrays_equal_to_decimal!(
            solution.t,
            [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
            15
        );
        assert_arrays_equal!(
            solution.get_state_variable::<DVector<f64>>(&StateIndex::Vector(0)),
            [
                1.0,
                1.0,
                0.995,
                0.9867491670832341,
                0.976579389049635,
                0.9654959107223262,
                0.9542474967431397,
                0.9433808343981592,
                0.9332828125226833,
                0.9242134803802741,
                0.9163318476653514
            ]
        );
        assert_arrays_equal!(
            solution.get_state_variable::<DVector<f64>>(&StateIndex::Vector(1)),
            [
                0.0,
                -0.05,
                -0.0825083291676586,
                -0.10169778033599089,
                -0.11083478327308789,
                -0.11248413979186514,
                -0.10866662344980502,
                -0.10098021875475897,
                -0.09069332142409263,
                -0.0788163271492275,
                -0.06615657389956016
            ]
        );
    }

    /// 2x2 matrix state: the first row uses the same right-hand side as the vector test, the
    /// second row holds two independent growth equations.
    #[test]
    #[cfg(feature = "nalgebra")]
    fn test_solve_ivp_matrix() {
        let rhs = |t: f64, y: &SMatrix<f64, 2, 2>| {
            SMatrix::<f64, 2, 2>::from_row_slice(&[
                y[(0, 1)],
                -2.5 * y[(0, 1)] - 0.5 * y[(0, 0)] + 0.5 * t.sin(),
                y[(1, 0)],
                0.5 * y[(1, 1)],
            ])
        };
        let initial = SMatrix::<f64, 2, 2>::from_row_slice(&[1.0, 0.0, 1.0, 1.0]);
        let (t_start, t_end, step) = (0.0, 1.0, 0.1);

        let solution =
            solve_ivp::<SMatrix<f64, 2, 2>, Euler>(&rhs, t_start, &initial, t_end, step, None);

        assert_arrays_equal_to_decimal!(
            solution.t,
            [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
            15
        );
        assert_eq!(
            solution.get_state_variable::<Vec<f64>>(&StateIndex::Matrix(0, 0)),
            [
                1.0,
                1.0,
                0.995,
                0.9867491670832341,
                0.976579389049635,
                0.9654959107223262,
                0.9542474967431397,
                0.9433808343981592,
                0.9332828125226833,
                0.9242134803802741,
                0.9163318476653514
            ]
        );
        assert_eq!(
            solution.get_state_variable::<Vec<f64>>(&StateIndex::Matrix(0, 1)),
            [
                0.0,
                -0.05,
                -0.0825083291676586,
                -0.10169778033599089,
                -0.11083478327308789,
                -0.11248413979186514,
                -0.10866662344980502,
                -0.10098021875475897,
                -0.09069332142409263,
                -0.0788163271492275,
                -0.06615657389956016
            ]
        );
        assert_eq!(
            solution.get_state_variable::<Vec<f64>>(&StateIndex::Matrix(1, 0)),
            [
                1.0,
                1.1,
                1.2100000000000002,
                1.3310000000000002,
                1.4641000000000002,
                1.61051,
                1.7715610000000002,
                1.9487171,
                2.1435888100000002,
                2.357947691,
                2.5937424601
            ]
        );
        assert_eq!(
            solution.get_state_variable::<Vec<f64>>(&StateIndex::Matrix(1, 1)),
            [
                1.0,
                1.05,
                1.1025,
                1.1576250000000001,
                1.2155062500000002,
                1.2762815625000004,
                1.3400956406250004,
                1.4071004226562505,
                1.477455443789063,
                1.5513282159785162,
                1.628894626777442
            ]
        );
    }

    /// The final time is not a multiple of the step size, so the last step must be shortened.
    #[test]
    fn test_solve_ivp_stress_time_termination() {
        let unit_rate = |_t: f64, _y: &f64| 1.0;
        let initial = 0.0;
        let (t_start, t_end, step) = (0.0, 4.5, 1.0);

        let solution = solve_ivp::<f64, Euler>(&unit_rate, t_start, &initial, t_end, step, None);

        assert_eq!(solution.t, [0.0, 1.0, 2.0, 3.0, 4.0, 4.5]);
        assert_arrays_equal!(solution.y, [0.0, 1.0, 2.0, 3.0, 4.0, 4.5]);
    }
}