use super::PropagationError;
use crate::State;
use crate::dynamics::Dynamics;
use crate::linalg::DefaultAllocator;
use crate::linalg::allocator::Allocator;
use crate::md::trajectory::{Interpolatable, Traj};
use crate::propagators::{PropAlmanacSnafu, PropAnalysisSnafu, TrajectoryEventSnafu};
use crate::time::{Duration, Epoch};
use anise::analysis::event::Event;
use anise::analysis::{AnalysisError, brent_solver};
use anise::frames::Frame;
use log::info;
use rayon::iter::ParallelBridge;
use rayon::prelude::ParallelIterator;
use snafu::ResultExt;
use std::f64;
use std::sync::mpsc::channel;
#[cfg(not(target_arch = "wasm32"))]
use std::time::Instant;
use super::PropInstance;
impl<D: Dynamics> PropInstance<'_, D>
where
DefaultAllocator: Allocator<<D::StateType as State>::Size>
+ Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>
+ Allocator<<D::StateType as State>::VecLength>,
{
pub fn until_event(
&mut self,
max_duration: Duration,
event: &Event,
event_frame: Option<Frame>,
) -> Result<(D::StateType, Traj<D::StateType>), PropagationError>
where
<DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
D::StateType: Interpolatable,
{
self.until_nth_event(max_duration, event, event_frame, 1)
}
pub fn until_nth_event(
&mut self,
max_duration: Duration,
event: &Event,
event_frame: Option<Frame>,
trigger: usize,
) -> Result<(D::StateType, Traj<D::StateType>), PropagationError>
where
<DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
D::StateType: Interpolatable,
{
info!("Propagating until {event} or {max_duration}");
let mut crossing_counts = 0;
let closure_almanac = self.almanac.clone();
let orbit = if let Some(observer_frame) = event_frame {
self.almanac
.transform_to(self.state.orbit(), observer_frame, None)
.context(PropAlmanacSnafu)?
} else {
self.state.orbit()
};
let mut y_prev = event
.eval(orbit, &self.almanac)
.context(PropAnalysisSnafu)?;
let enough_crossings = |next_state: D::StateType| -> Result<bool, PropagationError> {
let orbit = if let Some(observer_frame) = event_frame {
closure_almanac
.transform_to(next_state.orbit(), observer_frame, None)
.context(PropAlmanacSnafu)?
} else {
next_state.orbit()
};
let y_next = event
.eval(orbit, &closure_almanac)
.context(PropAnalysisSnafu)?;
let delta = (y_next - y_prev).abs();
if event.scalar.is_angle() {
if y_prev.signum() != y_next.signum() && delta < 180.0 {
crossing_counts += 1;
}
} else {
if y_prev * y_next < 0.0 {
crossing_counts += 1;
}
}
y_prev = y_next;
Ok(crossing_counts >= trigger)
};
let end_state;
let mut traj = Traj::new();
let start_state = self.state;
#[cfg(not(target_arch = "wasm32"))]
let tick = Instant::now();
let rx = {
let (tx, rx) = channel();
end_state = self.propagate(max_duration, Some(tx), Some(enough_crossings))?;
rx
};
traj.states = rx.into_iter().par_bridge().collect();
traj.states.push(start_state);
traj.finalize();
let last_traj_state = traj.states.last().cloned().unwrap();
if end_state == last_traj_state {
return Err(PropagationError::NthEventError {
nth: trigger,
found: crossing_counts,
});
}
traj.states.push(end_state);
let traj_at = |epoch: Epoch| -> Result<f64, AnalysisError> {
let state = traj
.at(epoch)
.map_err(|e| AnalysisError::GenericAnalysisError {
err: format!("{e}"),
})?;
event.eval(state.orbit(), &self.almanac)
};
let event_epoch = brent_solver(traj_at, event, last_traj_state.epoch(), end_state.epoch())
.context(PropAnalysisSnafu)?;
#[cfg(not(target_arch = "wasm32"))]
{
if self.log_progress {
let tock: Duration = tick.elapsed().into();
if trigger > 1 {
info!("\t... event #{trigger} found in {tock}");
} else {
info!("\t... event found in {tock}");
}
}
}
Ok((traj.at(event_epoch).context(TrajectoryEventSnafu)?, traj))
}
}