use crate::common::IntegrateFloat;
use crate::error::{IntegrateError, IntegrateResult};
use crate::ode::utils::dense_output::DenseSolution;
use scirs2_core::ndarray::{Array1, ArrayView1};
/// Which sign changes of an event function trigger an event.
///
/// `EventHandler::check_events` compares the event-function value at the
/// previous point (`prev`) with the value at the current point (`curr`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EventDirection {
    /// Triggers when `prev < 0` and `curr >= 0`.
    Rising,
    /// Triggers when `prev > 0` and `curr <= 0`.
    Falling,
    /// Triggers on either a rising or a falling crossing (the default).
    #[default]
    Both,
}
/// What an event asks of the integration once it has fired.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EventAction {
    /// No stop request. `EventHandler::check_events` returns this unless an
    /// event whose action is `Stop` fired during the step (the default).
    #[default]
    Continue,
    /// When an event with this action fires, `EventHandler::check_events`
    /// returns `Stop` for that step.
    Stop,
}
/// Configuration of a single event to be monitored during integration.
#[derive(Debug, Clone)]
pub struct EventSpec<F: IntegrateFloat> {
    /// Identifier copied into each recorded `Event`; occurrence counts in
    /// `EventRecord` are keyed by it.
    pub id: String,
    /// Which sign changes of the event function trigger the event.
    pub direction: EventDirection,
    /// Action reported by `EventHandler::check_events` when the event fires.
    pub action: EventAction,
    /// NOTE(review): not read anywhere in the visible code — the root
    /// refinement in `EventHandler` uses a fixed `1e-10` tolerance instead.
    /// Confirm whether this is consumed elsewhere.
    pub threshold: F,
    /// Maximum number of times the event is recorded; once reached,
    /// `EventHandler::check_events` skips it. `None` means unlimited.
    pub max_count: Option<usize>,
    /// When `true` and dense output is supplied to `check_events`, the
    /// crossing time is refined by bisection; otherwise the step end is used.
    pub precise_time: bool,
}
impl<F: IntegrateFloat> EventSpec<F> {
    /// Reports whether this event has already been recorded `max_count` times.
    ///
    /// `currentcount` is the number of occurrences recorded so far. `None` is
    /// treated as zero occurrences, because callers pass `None` when the id is
    /// not yet present in the record.
    ///
    /// Always returns `false` when `max_count` is `None` (unlimited).
    pub fn max_count_reached(&self, currentcount: Option<usize>) -> bool {
        // A missing count means "no occurrences yet". Treating it as 0 makes
        // `max_count: Some(0)` disable the event entirely instead of letting
        // it fire once. This matches `EventRecord::max_count_reached`.
        matches!(self.max_count, Some(max) if currentcount.unwrap_or(0) >= max)
    }
}
impl<F: IntegrateFloat> Default for EventSpec<F> {
    /// Builds a spec named `"default"` with these settings:
    ///
    /// - triggers on both crossing directions;
    /// - never requests a stop;
    /// - has no occurrence limit;
    /// - refines the event time when dense output is available;
    /// - has a threshold of `1e-6`.
    fn default() -> Self {
        let threshold = F::from_f64(1e-6).expect("Operation failed");
        Self {
            id: String::from("default"),
            direction: Default::default(),
            action: Default::default(),
            threshold,
            max_count: None,
            precise_time: true,
        }
    }
}
/// A single detected event occurrence.
#[derive(Debug, Clone)]
pub struct Event<F: IntegrateFloat> {
    /// `id` of the `EventSpec` that produced this event.
    pub id: String,
    /// Time of the event. This is the bisection-refined crossing time when the
    /// spec requests `precise_time` and dense output was supplied. Otherwise
    /// it is the end of the step where the sign change was seen.
    pub time: F,
    /// Solution state at `time`.
    pub state: Array1<F>,
    /// Event-function value at `time` (close to zero when refined).
    pub value: F,
    /// `1` for a rising crossing (previous value `< 0`, current `>= 0`),
    /// `-1` otherwise.
    pub direction: i8,
}
/// Log of all detected events plus a per-id occurrence counter.
#[derive(Debug, Clone)]
pub struct EventRecord<F: IntegrateFloat> {
    /// Events in the order they were added.
    pub events: Vec<Event<F>>,
    /// Number of recorded events per event id.
    pub counts: std::collections::HashMap<String, usize>,
}
impl<F: IntegrateFloat> Default for EventRecord<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: IntegrateFloat> EventRecord<F> {
pub fn new() -> Self {
EventRecord {
events: Vec::new(),
counts: std::collections::HashMap::new(),
}
}
pub fn add_event(&mut self, event: Event<F>) {
*self.counts.entry(event.id.clone()).or_insert(0) += 1;
self.events.push(event);
}
pub fn get_count(&self, id: &str) -> usize {
*self.counts.get(id).unwrap_or(&0)
}
pub fn get_events(&self, id: &str) -> Vec<&Event<F>> {
self.events.iter().filter(|e| e.id == id).collect()
}
pub fn max_count_reached(&self, _id: &str, maxcount: Option<usize>) -> bool {
if let Some(max) = maxcount {
self.get_count(_id) >= max
} else {
false
}
}
}
/// Detects sign changes of user-supplied event functions between successive
/// calls to `check_events` and records them in `record`.
#[derive(Debug)]
pub struct EventHandler<F: IntegrateFloat> {
    /// One specification per event function, in the same order as the
    /// function slice passed to `check_events` / `initialize`.
    pub specs: Vec<EventSpec<F>>,
    /// Every event detected so far.
    pub record: EventRecord<F>,
    /// Event-function values at the previous point, one slot per spec
    /// (`None` until evaluated).
    last_values: Vec<Option<F>>,
    /// `(t, y)` from the previous `check_events` / `initialize` call;
    /// `None` before the first one.
    last_state: Option<(F, Array1<F>)>,
}
impl<F: IntegrateFloat> EventHandler<F> {
    /// Creates a handler for `specs` with an empty event record.
    ///
    /// No event function is evaluated here. The starting values are taken by
    /// [`EventHandler::initialize`] or by the first [`EventHandler::check_events`].
    pub fn new(specs: Vec<EventSpec<F>>) -> Self {
        let last_values = vec![None; specs.len()];
        EventHandler {
            specs,
            record: EventRecord::new(),
            last_values,
            last_state: None,
        }
    }

    /// Returns a `ValueError` unless exactly one event function exists per spec.
    fn ensure_func_count(&self, n_funcs: usize) -> IntegrateResult<()> {
        if n_funcs != self.specs.len() {
            return Err(IntegrateError::ValueError(
                "Number of event functions does not match number of event specifications"
                    .to_string(),
            ));
        }
        Ok(())
    }

    /// Stores `(t, y)` and the value of every event function there.
    ///
    /// The next `check_events` call compares against these values.
    ///
    /// # Errors
    /// Returns `IntegrateError::ValueError` if `event_funcs.len()` differs from
    /// the number of specs.
    pub fn initialize<Func>(
        &mut self,
        t: F,
        y: &Array1<F>,
        event_funcs: &[Func],
    ) -> IntegrateResult<()>
    where
        Func: Fn(F, ArrayView1<F>) -> F,
    {
        // Validate before touching state. Indexing per function would
        // otherwise panic when more functions than specs are supplied.
        self.ensure_func_count(event_funcs.len())?;
        // Rebuild rather than index, so the buffer always matches the current
        // number of specs (`specs` is public and may have been modified).
        self.last_values = event_funcs
            .iter()
            .map(|func| Some(func(t, y.view())))
            .collect();
        self.last_state = Some((t, y.clone()));
        Ok(())
    }

    /// Evaluates every event function at `(t, y)` and records any sign change
    /// relative to the previous point.
    ///
    /// On the very first call this only initializes the handler and returns
    /// `Continue`. A crossing counts when the value moves from strictly
    /// negative to `>= 0` (rising) or from strictly positive to `<= 0`
    /// (falling). It must also match the spec's `direction`.
    ///
    /// If the spec has `precise_time` set and `dense_output` is supplied, the
    /// event time is refined by bisection. Otherwise the event is recorded at
    /// `t`. Specs that have reached their `max_count` are skipped.
    ///
    /// Returns `Stop` if any event whose action is `Stop` fired in this step,
    /// otherwise `Continue`.
    ///
    /// # Errors
    /// - `IntegrateError::ValueError` if `event_funcs.len()` differs from the
    ///   number of specs.
    /// - Any error raised while evaluating `dense_output` during refinement.
    pub fn check_events<Func>(
        &mut self,
        t: F,
        y: &Array1<F>,
        dense_output: Option<&DenseSolution<F>>,
        event_funcs: &[Func],
    ) -> IntegrateResult<EventAction>
    where
        Func: Fn(F, ArrayView1<F>) -> F,
    {
        self.ensure_func_count(event_funcs.len())?;
        if self.last_state.is_none() {
            self.initialize(t, y, event_funcs)?;
            return Ok(EventAction::Continue);
        }
        // `specs` is public and may have changed since the previous call. Keep
        // exactly one previous-value slot per spec so `last_values[i]` below
        // cannot go out of bounds. New specs start with `None`.
        self.last_values.resize(self.specs.len(), None);
        // Invariant: `last_state` was checked to be `Some` just above.
        let (t_prev, y_prev) = self.last_state.as_ref().expect("Operation failed");
        let mut action = EventAction::Continue;
        for (i, (func, spec)) in event_funcs.iter().zip(self.specs.iter()).enumerate() {
            // An event that has fired `max_count` times is ignored from then on.
            // A missing count is treated as zero, so `Some(0)` disables the event.
            if self.record.max_count_reached(&spec.id, spec.max_count) {
                continue;
            }
            let value = func(t, y.view());
            if let Some(prev_value) = self.last_values[i] {
                // Landing exactly on zero counts as a crossing. Leaving zero
                // does not, because of the strict inequality on `prev_value`.
                let rising = prev_value < F::zero() && value >= F::zero();
                let falling = prev_value > F::zero() && value <= F::zero();
                let triggered = match spec.direction {
                    EventDirection::Rising => rising,
                    EventDirection::Falling => falling,
                    EventDirection::Both => rising || falling,
                };
                if triggered {
                    let (event_t, event_y, event_val, dir) =
                        if let (true, Some(dense)) = (spec.precise_time, dense_output) {
                            self.refine_event_time(
                                *t_prev, y_prev, t, y, prev_value, value, func, dense,
                            )?
                        } else {
                            let dir = if rising { 1 } else { -1 };
                            (t, y.clone(), value, dir)
                        };
                    self.record.add_event(Event {
                        id: spec.id.clone(),
                        time: event_t,
                        state: event_y,
                        value: event_val,
                        direction: dir,
                    });
                    if spec.action == EventAction::Stop {
                        action = EventAction::Stop;
                    }
                }
            }
            self.last_values[i] = Some(value);
        }
        self.last_state = Some((t, y.clone()));
        Ok(action)
    }

    /// Locates the zero crossing of `event_func` in `[t_prev, t_curr]` by
    /// bisection on `dense_output`.
    ///
    /// `value_prev` and `value_curr` are the event-function values at the two
    /// interval ends. The result is `(time, state, value, direction)`, where
    /// `direction` is `1` for a rising crossing (`value_prev < 0 <= value_curr`)
    /// and `-1` otherwise.
    ///
    /// If either endpoint value is already within `1e-10` of zero, that
    /// endpoint is returned as-is. Otherwise bisection stops when one of these
    /// holds:
    /// - the midpoint value drops below `1e-10`;
    /// - the bracket width drops below `1e-10`;
    /// - 50 midpoints have been evaluated.
    ///
    /// # Errors
    /// Propagates any error from evaluating `dense_output`.
    // Each argument is a distinct piece of the bracketing data; grouping them
    // into a struct would only move the same list somewhere else.
    #[allow(clippy::too_many_arguments)]
    fn refine_event_time<Func>(
        &self,
        t_prev: F,
        y_prev: &Array1<F>,
        t_curr: F,
        y_curr: &Array1<F>,
        value_prev: F,
        value_curr: F,
        event_func: &Func,
        dense_output: &DenseSolution<F>,
    ) -> IntegrateResult<(F, Array1<F>, F, i8)>
    where
        Func: Fn(F, ArrayView1<F>) -> F,
    {
        const MAX_ITER: usize = 50;
        let direction: i8 = if value_prev < F::zero() && value_curr >= F::zero() {
            1
        } else {
            -1
        };
        let tol = F::from_f64(1e-10).expect("Operation failed");
        if value_prev.abs() < tol {
            return Ok((t_prev, y_prev.clone(), value_prev, direction));
        }
        if value_curr.abs() < tol {
            return Ok((t_curr, y_curr.clone(), value_curr, direction));
        }
        let two = F::from_f64(2.0).expect("Operation failed");
        let mut t_left = t_prev;
        let mut t_right = t_curr;
        let mut f_left = value_prev;
        let mut iterations = 0;
        loop {
            let t_mid = (t_left + t_right) / two;
            let y_mid = dense_output.evaluate(t_mid)?;
            let f_mid = event_func(t_mid, y_mid.view());
            iterations += 1;
            let converged = f_mid.abs() < tol || (t_right - t_left).abs() < tol;
            if converged || iterations >= MAX_ITER {
                return Ok((t_mid, y_mid, f_mid, direction));
            }
            // Keep the half whose endpoints still have opposite signs. Only the
            // sign of the left value is needed, so the right value is not tracked.
            if f_left * f_mid < F::zero() {
                t_right = t_mid;
            } else {
                t_left = t_mid;
                f_left = f_mid;
            }
        }
    }

    /// Borrows the record of all events detected so far.
    pub fn get_record(&self) -> &EventRecord<F> {
        &self.record
    }

    /// Whether any recorded event belongs to a spec whose action is `Stop`.
    ///
    /// Each event is matched to the first spec with the same id. Events whose
    /// id matches no spec (possible because `record` is public) are ignored
    /// rather than causing a panic.
    pub fn should_stop(&self) -> bool {
        self.record.events.iter().any(|event| {
            matches!(
                self.specs.iter().find(|spec| spec.id == event.id),
                Some(spec) if spec.action == EventAction::Stop
            )
        })
    }
}
/// Builds a spec that stops the integration the first time it fires.
///
/// The spec uses the given `id` and crossing `direction`, records at most one
/// occurrence, refines its time when dense output is available, and has a
/// threshold of `1e-6`.
// NOTE(review): the reason for silencing `dead_code` on a `pub fn` is not
// visible here — confirm whether it is still needed.
#[allow(dead_code)]
pub fn terminal_event<F: IntegrateFloat>(id: &str, direction: EventDirection) -> EventSpec<F> {
    let threshold = F::from_f64(1e-6).expect("Operation failed");
    EventSpec {
        id: String::from(id),
        direction,
        action: EventAction::Stop,
        threshold,
        max_count: Some(1),
        precise_time: true,
    }
}
/// Solver options bundled with the events to monitor during integration.
#[derive(Debug, Clone)]
pub struct ODEOptionsWithEvents<F: IntegrateFloat> {
    /// The regular ODE solver options.
    pub base_options: super::super::types::ODEOptions<F>,
    /// Event specifications, presumably consumed in order by an
    /// `EventHandler` — confirm against the solver that uses this type.
    pub event_specs: Vec<EventSpec<F>>,
}
impl<F: IntegrateFloat> ODEOptionsWithEvents<F> {
    /// Pairs `base_options` with the `event_specs` to watch for.
    pub fn new(
        base_options: super::super::types::ODEOptions<F>,
        event_specs: Vec<EventSpec<F>>,
    ) -> Self {
        Self {
            base_options,
            event_specs,
        }
    }
}
/// Integration result extended with detected events and optional dense output.
#[derive(Debug)]
pub struct ODEResultWithEvents<F: IntegrateFloat> {
    /// The regular solver result (time points and states).
    pub base_result: super::super::types::ODEResult<F>,
    /// All events recorded during the integration.
    pub events: EventRecord<F>,
    /// Continuous solution; when present, `at_time` evaluates it directly.
    pub dense_output: Option<DenseSolution<F>>,
    /// Set by the constructor's caller. Presumably `true` when integration
    /// ended because a terminal event fired — confirm with the producing solver.
    pub event_termination: bool,
}
impl<F: IntegrateFloat> ODEResultWithEvents<F> {
    /// Bundles a base result with its events, dense output and termination flag.
    pub fn new(
        base_result: super::super::types::ODEResult<F>,
        events: EventRecord<F>,
        dense_output: Option<DenseSolution<F>>,
        event_termination: bool,
    ) -> Self {
        ODEResultWithEvents {
            base_result,
            events,
            dense_output,
            event_termination,
        }
    }

    /// Returns the solution state at time `t`.
    ///
    /// With dense output, the interpolant is evaluated at `t`. Without it, the
    /// first stored point whose time lies within `1e-10` (absolute) of `t` is
    /// returned, or `None` if there is no such point.
    ///
    /// # Errors
    /// Propagates any error from evaluating the dense output.
    pub fn at_time(&self, t: F) -> IntegrateResult<Option<Array1<F>>> {
        if let Some(dense) = &self.dense_output {
            return Ok(Some(dense.evaluate(t)?));
        }
        // Loop-invariant tolerance, computed once instead of per stored point.
        let tol = F::from_f64(1e-10).expect("Operation failed");
        let index = self
            .base_result
            .t
            .iter()
            .position(|&ti| (ti - t).abs() < tol);
        Ok(index.map(|i| self.base_result.y[i].clone()))
    }

    /// All recorded events with the given id, in recording order.
    pub fn get_events(&self, id: &str) -> Vec<&Event<F>> {
        self.events.get_events(id)
    }

    /// The first recorded event with the given id, if any.
    pub fn first_event(&self, id: &str) -> Option<&Event<F>> {
        // Stop at the first match instead of collecting every match into a Vec.
        self.events.events.iter().find(|event| event.id == id)
    }
}