use super::*;
/// Solution-output hook that records the points where one component of the
/// state crosses a fixed threshold value.
///
/// Crossings are detected between the previous and the current solver step and
/// the crossing point is obtained from the step interpolator before being
/// pushed into the solution.
pub struct CrossingSolout<T: Real> {
    /// Value the monitored component is compared against.
    threshold: T,
    /// Index of the state component being monitored.
    component_idx: usize,
    /// Which crossing direction(s) are recorded.
    direction: CrossingDirection,
    /// `component - threshold` at the previous step; `None` before the first step.
    last_offset_value: Option<T>,
}
impl<T: Real> CrossingSolout<T> {
    /// Creates a detector for crossings of `threshold` by state component
    /// `component_idx`. Crossings in both directions are recorded by default.
    pub fn new(component_idx: usize, threshold: T) -> Self {
        Self {
            threshold,
            component_idx,
            direction: CrossingDirection::Both,
            last_offset_value: None,
        }
    }

    /// Builder-style setter choosing which crossing direction(s) are recorded.
    pub fn with_direction(self, direction: CrossingDirection) -> Self {
        Self { direction, ..self }
    }

    /// Records only crossings where the component goes from below the
    /// threshold to at-or-above it.
    pub fn positive_only(self) -> Self {
        self.with_direction(CrossingDirection::Positive)
    }

    /// Records only crossings where the component goes from above the
    /// threshold to at-or-below it.
    pub fn negative_only(self) -> Self {
        self.with_direction(CrossingDirection::Negative)
    }
}
impl<T, Y> Solout<T, Y> for CrossingSolout<T>
where
    T: Real,
    Y: State<T>,
{
    /// Checks whether the monitored component crossed the threshold during the
    /// step from `t_prev` to `t_curr`. If a crossing in the configured direction
    /// occurred, the interpolated crossing point is pushed into `solution`.
    ///
    /// Crossing convention (half-open on the "arriving" side):
    /// - upward: previous offset `< 0` and current offset `>= 0`;
    /// - downward: previous offset `> 0` and current offset `<= 0`.
    ///
    /// A step that lands exactly on the threshold is therefore recorded once, at
    /// the landing step, and not again when leaving it. Comparing `signum()`
    /// values is deliberately avoided: IEEE `signum(+0.0) == 1.0` would hide a
    /// downward step that lands exactly on the threshold.
    ///
    /// The crossing time is refined with `find_crossing_newton`. If that fails,
    /// linear interpolation between the step endpoints is used.
    /// Always returns `ControlFlag::Continue`.
    fn solout<I>(
        &mut self,
        t_curr: T,
        t_prev: T,
        y_curr: &Y,
        _y_prev: &Y,
        interpolator: &mut I,
        solution: &mut Solution<T, Y>,
    ) -> ControlFlag<T, Y>
    where
        I: Interpolation<T, Y>,
    {
        let offset_value = y_curr.get(self.component_idx) - self.threshold;

        if let Some(last_offset) = self.last_offset_value {
            let zero = T::zero();
            // NaN offsets fail both comparisons and are never reported as crossings.
            let upward = last_offset < zero && offset_value >= zero;
            let downward = last_offset > zero && offset_value <= zero;

            let record_crossing = match self.direction {
                CrossingDirection::Positive => upward,
                CrossingDirection::Negative => downward,
                CrossingDirection::Both => upward || downward,
            };

            if record_crossing {
                let t_cross = self
                    .find_crossing_newton(interpolator, t_prev, t_curr, last_offset, offset_value)
                    .unwrap_or_else(|| {
                        // `last_offset` is strictly non-zero and of opposite sign
                        // (or `offset_value == 0`), so the denominator is non-zero
                        // and `frac` lies in (0, 1].
                        let frac = -last_offset / (offset_value - last_offset);
                        t_prev + frac * (t_curr - t_prev)
                    });
                let y_cross = interpolator
                    .interpolate(t_cross)
                    .expect("crossing time lies within the current step");
                solution.push(t_cross, y_cross);
            }
        }

        self.last_offset_value = Some(offset_value);
        ControlFlag::Continue
    }
}
impl<T: Real> CrossingSolout<T> {
    /// Refines the time between `t_lower` and `t_upper` at which the monitored
    /// component equals `threshold`.
    ///
    /// `offset_lower` / `offset_upper` are `value - threshold` at `t_lower` /
    /// `t_upper`; the caller is expected to pass values that bracket zero. The
    /// endpoints may come in either order, so backward integration
    /// (`t_upper < t_lower`) is handled.
    ///
    /// Uses Newton's method with a one-sided finite-difference derivative. A
    /// bracket that always contains the sign change safeguards it: if a Newton
    /// step would leave the bracket, or the derivative is too small to divide by,
    /// the bracket is bisected instead. The interpolator is only queried inside
    /// the original step.
    ///
    /// Returns `None` in either of these cases, and the caller then falls back to
    /// linear interpolation:
    /// - the interpolator reports an error;
    /// - no point with `|offset| < 10 * tolerance` is reached within the
    ///   iteration budget.
    fn find_crossing_newton<I, Y>(
        &self,
        interpolator: &mut I,
        t_lower: T,
        t_upper: T,
        offset_lower: T,
        offset_upper: T,
    ) -> Option<T>
    where
        I: Interpolation<T, Y>,
        Y: State<T>,
    {
        if t_upper == t_lower {
            // Degenerate step: nothing to refine.
            return None;
        }

        let zero = T::zero();
        let two = T::from_f64(2.0).unwrap();
        let max_iterations = 10;
        // NOTE(review): this tolerance is absolute (100 machine epsilons). For
        // large-magnitude components it may be unreachable; the caller's linear
        // fallback is then used.
        let tolerance = T::default_epsilon() * T::from_f64(100.0).unwrap();
        let min_derivative = T::default_epsilon() * T::from_f64(10.0).unwrap();
        let step_tolerance = tolerance * T::from_f64(0.1).unwrap();

        // Larger endpoint of the step, used to keep finite-difference probes inside it.
        let t_hi = if t_lower <= t_upper { t_upper } else { t_lower };
        let fd_step = (t_upper - t_lower).abs() * T::from_f64(1e-6).unwrap();

        // `true` when `t` lies between `a` and `b`, whichever order they are in.
        let within = |t: T, a: T, b: T| {
            if a <= b {
                a <= t && t <= b
            } else {
                b <= t && t <= a
            }
        };

        // Unordered bracket [a, b] containing the sign change; `fa` is the offset at `a`.
        let (mut a, mut fa) = (t_lower, offset_lower);
        let mut b = t_upper;

        // Secant (linear-interpolation) initial guess; exact if the offset is linear.
        let mut t = t_lower - offset_lower * (t_upper - t_lower) / (offset_upper - offset_lower);
        if !within(t, a, b) {
            t = (a + b) / two;
        }

        for _ in 0..max_iterations {
            let offset = self.offset_at::<I, Y>(interpolator, t)?;
            if offset.abs() < tolerance {
                return Some(t);
            }

            // Replace whichever bracket end has the same sign as `offset`.
            if (offset < zero) == (fa < zero) {
                a = t;
                fa = offset;
            } else {
                b = t;
            }

            // One-sided difference. Step backwards near the upper end so the
            // probe never leaves the step interval.
            let probe = if t + fd_step <= t_hi { t + fd_step } else { t - fd_step };
            let offset_probe = self.offset_at::<I, Y>(interpolator, probe)?;
            let derivative = (offset_probe - offset) / (probe - t);

            let t_newton = if derivative.abs() < min_derivative {
                None
            } else {
                Some(t - offset / derivative)
            };

            match t_newton {
                Some(t_new) if within(t_new, a, b) => {
                    let converged = (t_new - t).abs() < step_tolerance;
                    t = t_new;
                    if converged {
                        break;
                    }
                }
                // Newton step unusable (flat, NaN, or outside the bracket): bisect.
                _ => t = (a + b) / two,
            }
        }

        let offset = self.offset_at::<I, Y>(interpolator, t)?;
        if offset.abs() < tolerance * T::from_f64(10.0).unwrap() {
            Some(t)
        } else {
            None
        }
    }

    /// Offset `value - threshold` of the monitored component at time `t`, or
    /// `None` if the interpolator cannot evaluate `t`.
    fn offset_at<I, Y>(&self, interpolator: &mut I, t: T) -> Option<T>
    where
        I: Interpolation<T, Y>,
        Y: State<T>,
    {
        let y = interpolator.interpolate(t).ok()?;
        Some(y.get(self.component_idx) - self.threshold)
    }
}