use super::*;
use crate::linalg::dot;
/// Plain function pointer that takes a reference to a `Y` and returns a `P`.
///
/// `HyperplaneCrossingSolout` stores one of these to turn the solver state it
/// receives into the value whose signed distance to the hyperplane is measured.
pub type ExtractorFn<Y, P> = fn(&Y) -> P;
/// Output handler that records the points at which a trajectory crosses a hyperplane.
///
/// The hyperplane is described by a `point` and a `normal`, both of type `Y1`.
/// The solver state has type `Y2`; `extractor` maps it to the `Y1` value that is
/// compared against the hyperplane.
pub struct HyperplaneCrossingSolout<T, Y1, Y2>
where
T: Real,
Y1: State<T>,
Y2: State<T>,
{
/// Reference point subtracted from extracted positions; its signed distance is zero.
point: Y1,
/// Hyperplane normal; `new` rescales it to unit length when its norm exceeds
/// `T::default_epsilon()`.
normal: Y1,
/// Maps a full solver state (`Y2`) to the position (`Y1`) tested against the hyperplane.
extractor: ExtractorFn<Y2, Y1>,
/// Signed distance computed on the previous `solout` call; `None` until the first call.
last_distance: Option<T>,
/// Which crossing directions are recorded.
direction: CrossingDirection,
/// Marker for the `Y2` type parameter; carries no data.
_phantom: std::marker::PhantomData<Y2>,
}
impl<T, Y1, Y2> HyperplaneCrossingSolout<T, Y1, Y2>
where
    T: Real,
    Y1: State<T>,
    Y2: State<T>,
{
    /// Builds a crossing detector for the hyperplane through `point` with normal `normal`.
    ///
    /// `extractor` converts each solver state into the `Y1` position that is tested
    /// against the hyperplane. The detector starts out recording crossings in both
    /// directions (`CrossingDirection::Both`).
    ///
    /// The normal is rescaled to unit Euclidean length. If its length is not greater
    /// than `T::default_epsilon()` it is stored unchanged, so a (near-)zero normal is
    /// accepted silently.
    pub fn new(point: Y1, mut normal: Y1, extractor: ExtractorFn<Y2, Y1>) -> Self {
        // Euclidean length: accumulate the squared components in index order.
        let squared_sum = (0..normal.len())
            .map(|i| normal.get(i).powi(2))
            .fold(T::zero(), |acc, term| acc + term);
        let magnitude = squared_sum.sqrt();

        if magnitude > T::default_epsilon() {
            normal = normal * T::one() / magnitude;
        }

        Self {
            point,
            normal,
            extractor,
            last_distance: None,
            direction: CrossingDirection::Both,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Replaces the crossing direction filter and returns the updated detector.
    pub fn with_direction(mut self, direction: CrossingDirection) -> Self {
        self.direction = direction;
        self
    }

    /// Shorthand for `with_direction(CrossingDirection::Positive)`.
    pub fn positive_only(self) -> Self {
        self.with_direction(CrossingDirection::Positive)
    }

    /// Shorthand for `with_direction(CrossingDirection::Negative)`.
    pub fn negative_only(self) -> Self {
        self.with_direction(CrossingDirection::Negative)
    }

    /// Signed distance from `pos` to the hyperplane: the dot product of
    /// `pos - point` with the stored normal.
    fn signed_distance(&self, pos: &Y1) -> T {
        let offset = *pos - self.point;
        dot(&offset, &self.normal)
    }
}
impl<T, Y1, Y2> Solout<T, Y2> for HyperplaneCrossingSolout<T, Y1, Y2>
where
    T: Real,
    Y1: State<T>,
    Y2: State<T>,
{
    /// Checks whether the step from `t_prev` to `t_curr` crossed the hyperplane and,
    /// if so, pushes the crossing time and interpolated state onto `solution`.
    ///
    /// The signed distance of the current state is compared with the one stored by
    /// the previous call. A crossing is detected when the two distances differ in
    /// `signum`, or when exactly one of them is zero. The configured direction then
    /// decides whether the crossing is recorded:
    ///
    /// * `Positive`: previous distance `< 0` and current distance `>= 0`;
    /// * `Negative`: previous distance `> 0` and current distance `<= 0`;
    /// * `Both`: always.
    ///
    /// The crossing time comes from `find_crossing_newton`; when that returns `None`
    /// a linear interpolation between the two distances is used instead.
    ///
    /// Always returns `ControlFlag::Continue`.
    ///
    /// # Panics
    ///
    /// Panics if the interpolator returns an error for the crossing time.
    fn solout<I>(
        &mut self,
        t_curr: T,
        t_prev: T,
        y_curr: &Y2,
        _y_prev: &Y2,
        interpolator: &mut I,
        solution: &mut Solution<T, Y2>,
    ) -> ControlFlag<T, Y2>
    where
        I: Interpolation<T, Y2>,
    {
        let distance = self.signed_distance(&(self.extractor)(y_curr));
        let zero = T::zero();

        // Keep the previous distance only if this step is a crossing that the
        // direction filter wants recorded. On the first call there is no previous
        // distance, so nothing is recorded.
        let recorded_from = self.last_distance.filter(|&previous| {
            let sign_flipped = previous.signum() != distance.signum();
            let left_plane = previous == zero && distance != zero;
            let reached_plane = previous != zero && distance == zero;

            (sign_flipped || left_plane || reached_plane)
                && match self.direction {
                    CrossingDirection::Positive => previous < zero && distance >= zero,
                    CrossingDirection::Negative => previous > zero && distance <= zero,
                    CrossingDirection::Both => true,
                }
        });

        if let Some(previous) = recorded_from {
            let t_cross = self
                .find_crossing_newton(interpolator, t_prev, t_curr, previous, distance)
                .unwrap_or_else(|| {
                    // Refinement failed: place the crossing where the straight line
                    // through the two distances hits zero.
                    let frac = -previous / (distance - previous);
                    t_prev + frac * (t_curr - t_prev)
                });
            let y_cross = interpolator.interpolate(t_cross).unwrap();
            solution.push(t_cross, y_cross);
        }

        self.last_distance = Some(distance);
        ControlFlag::Continue
    }
}
impl<T, Y1, Y2> HyperplaneCrossingSolout<T, Y1, Y2>
where
    T: Real,
    Y1: State<T>,
    Y2: State<T>,
{
    /// Signed distance to the hyperplane of the state interpolated at time `t`.
    ///
    /// Returns `None` when the interpolator reports an error for `t`.
    fn interpolated_distance<I>(&self, interpolator: &mut I, t: T) -> Option<T>
    where
        I: Interpolation<T, Y2>,
    {
        let y_t = interpolator.interpolate(t).ok()?;
        let pos_t = (self.extractor)(&y_t);
        Some(self.signed_distance(&pos_t))
    }

    /// Refines the time at which the interpolated trajectory hits the hyperplane.
    ///
    /// Starts from the secant estimate built from the signed distances at the two
    /// ends of the step (`dist_lower` at `t_lower`, `dist_upper` at `t_upper`) and
    /// runs at most 10 Newton iterations, using a finite-difference derivative of
    /// the signed distance.
    ///
    /// `t_lower` and `t_upper` may be given in either order, so steps taken backwards
    /// in time are handled too. Every time passed to the interpolator lies between
    /// them:
    ///
    /// * an initial guess outside the bracket is clamped to the nearest end (a NaN
    ///   guess is replaced by the midpoint);
    /// * the finite-difference probe steps backwards when a forward probe would
    ///   leave the bracket;
    /// * a Newton step that leaves the bracket (or is NaN) is replaced by the midpoint.
    ///
    /// Returns `Some(t)` as soon as the absolute signed distance drops below
    /// `100 * T::default_epsilon()`, or after the loop if it is below ten times that
    /// tolerance. Returns `None` if the iteration did not get that close or if the
    /// interpolator reported an error, so the caller can fall back to another estimate.
    fn find_crossing_newton<I>(
        &self,
        interpolator: &mut I,
        t_lower: T,
        t_upper: T,
        dist_lower: T,
        dist_upper: T,
    ) -> Option<T>
    where
        I: Interpolation<T, Y2>,
    {
        let max_iterations = 10;
        let tolerance = T::default_epsilon() * T::from_f64(100.0).unwrap();

        // Order the bracket once so range checks work for either time direction.
        let (t_min, t_max) = if t_lower <= t_upper {
            (t_lower, t_upper)
        } else {
            (t_upper, t_lower)
        };
        // NaN fails both comparisons and is therefore reported as outside.
        let in_bracket = |candidate: T| candidate >= t_min && candidate <= t_max;

        let span = t_upper - t_lower;
        let midpoint = (t_lower + t_upper) / T::from_f64(2.0).unwrap();
        // Signed probe step for the finite-difference derivative.
        let delta_t = span * T::from_f64(1e-6).unwrap();

        // Secant estimate of the root; it may land marginally outside the bracket
        // through rounding, or be NaN when both distances are equal.
        let secant = t_lower - dist_lower * span / (dist_upper - dist_lower);
        let mut t = if in_bracket(secant) {
            secant
        } else if secant < t_min {
            t_min
        } else if secant > t_max {
            t_max
        } else {
            midpoint
        };

        for _ in 0..max_iterations {
            let dist = self.interpolated_distance(interpolator, t)?;
            if dist.abs() < tolerance {
                return Some(t);
            }

            // Probe forwards when possible, otherwise backwards, so the interpolator
            // is never asked for a time outside the step.
            let step = if in_bracket(t + delta_t) {
                delta_t
            } else {
                -delta_t
            };
            let dist_probe = self.interpolated_distance(interpolator, t + step)?;
            let derivative = (dist_probe - dist) / step;

            // Stop on a vanishing derivative; a NaN derivative (zero-length step)
            // also fails this comparison and ends the iteration.
            let derivative_usable = derivative.abs() >= T::default_epsilon();
            if !derivative_usable {
                break;
            }

            let t_new = t - dist / derivative;
            t = if in_bracket(t_new) { t_new } else { midpoint };
        }

        // Accept the last iterate if it is within a relaxed tolerance.
        let dist = self.interpolated_distance(interpolator, t)?;
        if dist.abs() < tolerance * T::from_f64(10.0).unwrap() {
            Some(t)
        } else {
            None
        }
    }
}