use super::error_ctrl::{ErrorCtrl, RSSCartesianStep};
use super::rayon::iter::ParallelBridge;
use super::rayon::prelude::ParallelIterator;
use super::{Dormand78, IntegrationDetails, RK, RK89};
use crate::dynamics::Dynamics;
use crate::errors::NyxError;
use crate::linalg::allocator::Allocator;
use crate::linalg::{DefaultAllocator, OVector};
use crate::md::trajectory::spline::INTERPOLATION_SAMPLES;
use crate::md::trajectory::{interpolate, InterpState, Traj, TrajError};
use crate::md::EventEvaluator;
use crate::time::{Duration, Epoch, Unit};
use crate::State;
use std::collections::BTreeMap;
use std::f64;
use std::sync::mpsc::{channel, Sender};
/// Bundles the dynamics to integrate, the propagation options and the coefficient
/// tables of the Runge-Kutta scheme selected at construction.
///
/// Bind it to an initial state with `with` to obtain a `PropInstance` that can be stepped.
#[derive(Clone, Debug)]
pub struct Propagator<'a, D: Dynamics, E: ErrorCtrl>
where
    DefaultAllocator: Allocator<f64, <D::StateType as State>::Size>
        + Allocator<f64, <D::StateType as State>::Size, <D::StateType as State>::Size>
        + Allocator<usize, <D::StateType as State>::Size, <D::StateType as State>::Size>
        + Allocator<f64, <D::StateType as State>::VecLength>,
{
    /// Dynamics whose equations of motion are integrated.
    pub dynamics: D,
    /// Step size bounds, tolerance and attempt limits.
    pub opts: PropOpts<E>,
    /// Order of the integrator, copied from `RK::ORDER` by `new`.
    order: u8,
    /// Number of stages of the integrator, copied from `RK::STAGES` by `new`.
    stages: usize,
    /// Flattened `a` coefficients, copied from `RK::A_COEFFS` by `new`.
    a_coeffs: &'a [f64],
    /// Flattened `b` coefficients, copied from `RK::B_COEFFS` by `new`.
    b_coeffs: &'a [f64],
}
/// Constructors and option setters available for any dynamics and error control.
impl<'a, D: Dynamics, E: ErrorCtrl> Propagator<'a, D, E>
where
    DefaultAllocator: Allocator<f64, <D::StateType as State>::Size>
        + Allocator<f64, <D::StateType as State>::Size, <D::StateType as State>::Size>
        + Allocator<usize, <D::StateType as State>::Size, <D::StateType as State>::Size>
        + Allocator<f64, <D::StateType as State>::VecLength>,
{
    /// Builds a propagator for `dynamics` using the integration scheme `T` and the
    /// provided options.
    pub fn new<T: RK>(dynamics: D, opts: PropOpts<E>) -> Self {
        Self {
            dynamics,
            opts,
            order: T::ORDER,
            stages: T::STAGES,
            a_coeffs: T::A_COEFFS,
            b_coeffs: T::B_COEFFS,
        }
    }

    /// Overwrites the tolerance stored in the options.
    pub fn set_tolerance(&mut self, tol: f64) {
        self.opts.tolerance = tol;
    }

    /// Forwards to `PropOpts::set_max_step`.
    pub fn set_max_step(&mut self, step: Duration) {
        self.opts.set_max_step(step);
    }

    /// Forwards to `PropOpts::set_min_step`.
    pub fn set_min_step(&mut self, step: Duration) {
        self.opts.set_min_step(step);
    }

    /// Shortcut for `new::<RK89>`.
    pub fn rk89(dynamics: D, opts: PropOpts<E>) -> Self {
        Self::new::<RK89>(dynamics, opts)
    }

    /// Shortcut for `new::<Dormand78>`.
    pub fn dp78(dynamics: D, opts: PropOpts<E>) -> Self {
        Self::new::<Dormand78>(dynamics, opts)
    }

    /// Binds this propagator to an initial `state` and returns the instance that
    /// performs the actual stepping.
    ///
    /// The instance starts from the initial step and step kind stored in the options,
    /// and owns one zeroed derivative buffer per stage.
    pub fn with(&'a self, state: D::StateType) -> PropInstance<'a, D, E> {
        let init_step = self.opts.init_step;
        let k = (0..self.stages)
            .map(|_| OVector::<f64, <D::StateType as State>::VecLength>::zeros())
            .collect::<Vec<_>>();
        let details = IntegrationDetails {
            step: init_step,
            error: 0.0,
            attempts: 1,
        };
        PropInstance {
            state,
            prop: self,
            details,
            step_size: init_step,
            fixed_step: self.opts.fixed_step,
            k,
        }
    }
}
/// Convenience constructors that use the default options with `RSSCartesianStep`
/// error control.
impl<'a, D: Dynamics> Propagator<'a, D, RSSCartesianStep>
where
    DefaultAllocator: Allocator<f64, <D::StateType as State>::Size>
        + Allocator<f64, <D::StateType as State>::Size, <D::StateType as State>::Size>
        + Allocator<usize, <D::StateType as State>::Size, <D::StateType as State>::Size>
        + Allocator<f64, <D::StateType as State>::VecLength>,
{
    /// Default propagator: `RK89` scheme with `PropOpts::default()`.
    pub fn default(dynamics: D) -> Self {
        Self::rk89(dynamics, PropOpts::default())
    }

    /// Same default options as `default`, but with the `Dormand78` scheme.
    pub fn default_dp78(dynamics: D) -> Self {
        Self::dp78(dynamics, PropOpts::default())
    }
}
/// A propagator bound to a state: this is the object that is actually stepped.
///
/// Instances are created by `Propagator::with`.
#[derive(Debug)]
pub struct PropInstance<'a, D: Dynamics, E: ErrorCtrl>
where
    DefaultAllocator: Allocator<f64, <D::StateType as State>::Size>
        + Allocator<f64, <D::StateType as State>::Size, <D::StateType as State>::Size>
        + Allocator<usize, <D::StateType as State>::Size, <D::StateType as State>::Size>
        + Allocator<f64, <D::StateType as State>::VecLength>,
{
    /// Current state; updated after every successful step.
    pub state: D::StateType,
    /// Propagator providing the dynamics, options and integrator coefficients.
    pub prop: &'a Propagator<'a, D, E>,
    /// Details (step, error, attempts) of the latest integration step.
    pub details: IntegrationDetails,
    /// Step size that the next call to `single_step` will attempt.
    step_size: Duration,
    /// When true, the step is taken as-is without error control.
    fixed_step: bool,
    /// Scratch buffer holding one derivative vector per integrator stage.
    k: Vec<OVector<f64, <D::StateType as State>::VecLength>>,
}
impl<'a, D: Dynamics, E: ErrorCtrl> PropInstance<'a, D, E>
where
DefaultAllocator: Allocator<f64, <D::StateType as State>::Size>
+ Allocator<f64, <D::StateType as State>::Size, <D::StateType as State>::Size>
+ Allocator<usize, <D::StateType as State>::Size, <D::StateType as State>::Size>
+ Allocator<f64, <D::StateType as State>::VecLength>,
{
/// Overrides the step size of the next step and whether it is taken as a fixed step
/// (`fixed == true`) or subject to adaptive error control.
pub fn set_step(&mut self, step_size: Duration, fixed: bool) {
    self.fixed_step = fixed;
    self.step_size = step_size;
}
#[allow(clippy::erasing_op)]
fn for_duration_channel_option(
&mut self,
duration: Duration,
maybe_tx_chan: Option<Sender<D::StateType>>,
) -> Result<D::StateType, NyxError> {
if duration == 0 * Unit::Second {
return Ok(self.state);
}
let stop_time = self.state.epoch() + duration;
if duration > 2 * Unit::Minute || duration < -2 * Unit::Minute {
info!("Propagating for {} until {}", duration, stop_time);
}
self.state = self.prop.dynamics.finally(self.state)?;
let backprop = duration < Unit::Nanosecond;
if backprop {
self.step_size = -self.step_size; }
loop {
let dt = self.state.epoch();
if (!backprop && dt + self.step_size > stop_time)
|| (backprop && dt + self.step_size <= stop_time)
{
if stop_time == dt {
return Ok(self.state);
}
let prev_step_size = self.step_size;
let prev_step_kind = self.fixed_step;
self.set_step(stop_time - dt, true);
self.single_step()?;
if let Some(ref chan) = maybe_tx_chan {
if let Err(e) = chan.send(self.state) {
warn!("could not publish to channel: {}", e)
}
}
self.set_step(prev_step_size, prev_step_kind);
if backprop {
self.step_size = -self.step_size; }
return Ok(self.state);
} else {
self.single_step()?;
if let Some(ref chan) = maybe_tx_chan {
if let Err(e) = chan.send(self.state) {
warn!("could not publish to channel: {}", e)
}
}
}
}
}
/// Propagates the state for `duration` without publishing intermediate states,
/// and returns the final state.
pub fn for_duration(&mut self, duration: Duration) -> Result<D::StateType, NyxError> {
    let no_channel = None;
    self.for_duration_channel_option(duration, no_channel)
}
/// Propagates the state for `duration`, sending every intermediate state (and the
/// final one) on `tx_chan`, and returns the final state.
pub fn for_duration_with_channel(
    &mut self,
    duration: Duration,
    tx_chan: Sender<D::StateType>,
) -> Result<D::StateType, NyxError> {
    let publisher = Some(tx_chan);
    self.for_duration_channel_option(duration, publisher)
}
/// Propagates until `end_time`, i.e. for the duration separating the current state's
/// epoch from `end_time`, and returns the final state.
pub fn until_epoch(&mut self, end_time: Epoch) -> Result<D::StateType, NyxError> {
    let now = self.state.epoch();
    self.for_duration(end_time - now)
}
/// Propagates until `end_time` while publishing each state on `tx_chan`, and returns
/// the final state.
pub fn until_epoch_with_channel(
    &mut self,
    end_time: Epoch,
    tx_chan: Sender<D::StateType>,
) -> Result<D::StateType, NyxError> {
    let now = self.state.epoch();
    self.for_duration_with_channel(end_time - now, tx_chan)
}
#[allow(clippy::map_clone)]
pub fn for_duration_with_traj(
&mut self,
duration: Duration,
) -> Result<(D::StateType, Traj<D::StateType>), NyxError>
where
<DefaultAllocator as Allocator<f64, <D::StateType as State>::VecLength>>::Buffer: Send,
D::StateType: InterpState,
{
let start_state = self.state;
let end_state;
let rx = {
let (tx_bucket, rx_bucket) = channel();
let rx = {
let (tx, rx) = channel();
end_state = self.for_duration_with_channel(duration, tx)?;
rx
};
let items_per_segments = INTERPOLATION_SAMPLES;
let mut window_states = Vec::with_capacity(2 * items_per_segments);
window_states.push(start_state);
while let Ok(state) = rx.recv() {
window_states.push(state);
if window_states.len() == 2 * items_per_segments {
let this_wdn = window_states[..items_per_segments]
.iter()
.map(|&x| x)
.collect::<Vec<D::StateType>>();
tx_bucket.send(this_wdn).map_err(|_| {
NyxError::from(TrajError::CreationError(
"could not send onto channel".to_string(),
))
})?;
for _ in 0..items_per_segments - 1 {
window_states.remove(0);
}
}
}
if window_states.len() < items_per_segments {
let step_size =
(end_state.epoch() - start_state.epoch()) / ((items_per_segments - 1) as f64);
self.state = start_state;
window_states.clear();
self.set_step(step_size, true);
let rx = {
let (tx, rx) = channel();
self.for_duration_with_channel(duration, tx)?;
rx
};
window_states.push(start_state);
while let Ok(state) = rx.recv() {
window_states.push(state);
}
}
let mut start_idx = 0;
loop {
tx_bucket
.send(
window_states
[start_idx..(start_idx + items_per_segments).min(window_states.len())]
.iter()
.map(|&x| x)
.collect::<Vec<D::StateType>>(),
)
.map_err(|_| {
NyxError::from(TrajError::CreationError(
"could not send onto channel".to_string(),
))
})?;
if start_idx > 0 || window_states.len() < items_per_segments {
break;
}
start_idx = window_states.len() - items_per_segments;
if start_idx == 0 {
break;
}
}
rx_bucket
};
let splines: Vec<_> = rx.into_iter().par_bridge().map(interpolate).collect();
let mut traj = Traj {
segments: BTreeMap::new(),
start_state,
backward: false,
};
for maybe_spline in splines {
let spline = maybe_spline?;
traj.append_spline(spline)?;
}
Ok((end_state, traj))
}
/// Propagates until `end_time` and returns the final state along with the
/// trajectory built by `for_duration_with_traj`.
pub fn until_epoch_with_traj(
    &mut self,
    end_time: Epoch,
) -> Result<(D::StateType, Traj<D::StateType>), NyxError>
where
    <DefaultAllocator as Allocator<f64, <D::StateType as State>::VecLength>>::Buffer: Send,
    D::StateType: InterpState,
{
    let now = self.state.epoch();
    self.for_duration_with_traj(end_time - now)
}
/// Propagates for `max_duration` and returns the state at the first occurrence
/// (trigger index 0) of `event`, along with the full trajectory.
pub fn until_event<F: EventEvaluator<D::StateType>>(
    &mut self,
    max_duration: Duration,
    event: &F,
) -> Result<(D::StateType, Traj<D::StateType>), NyxError>
where
    <DefaultAllocator as Allocator<f64, <D::StateType as State>::VecLength>>::Buffer: Send,
    D::StateType: InterpState,
{
    let first_trigger = 0;
    self.until_nth_event(max_duration, event, first_trigger)
}
/// Propagates for `max_duration`, searches the resulting trajectory for `event`,
/// and returns the state found at index `trigger` together with the trajectory.
///
/// # Errors
/// Returns `NyxError::UnsufficientTriggers` when fewer than `trigger + 1` events
/// were found, and forwards propagation and search errors.
pub fn until_nth_event<F: EventEvaluator<D::StateType>>(
    &mut self,
    max_duration: Duration,
    event: &F,
    trigger: usize,
) -> Result<(D::StateType, Traj<D::StateType>), NyxError>
where
    <DefaultAllocator as Allocator<f64, <D::StateType as State>::VecLength>>::Buffer: Send,
    D::StateType: InterpState,
{
    info!("Searching for {}", event);
    // The end state of the propagation is not needed, only the trajectory.
    let (_, traj) = self.for_duration_with_traj(max_duration)?;
    let found = traj.find_all(event)?;
    if let Some(&hit) = found.get(trigger) {
        Ok((hit, traj))
    } else {
        Err(NyxError::UnsufficientTriggers(trigger, found.len()))
    }
}
/// Takes one integration step: computes the next state vector with `derive`,
/// stores it (with the advanced epoch) in the state, and then applies the
/// `finally` hook of the dynamics.
pub fn single_step(&mut self) -> Result<(), NyxError> {
    let (elapsed, next_vec) = self.derive()?;
    let next_epoch = self.state.epoch() + elapsed;
    self.state.set(next_epoch, &next_vec)?;
    self.state = self.prop.dynamics.finally(self.state)?;
    Ok(())
}
/// Computes one integration step from the current state without modifying it.
///
/// Returns the duration of the step that was taken and the state vector at the end
/// of that step. `self.details` is updated with the step, error and attempt count,
/// and in adaptive mode `self.step_size` is updated with the step to try next.
///
/// # Errors
/// Forwards errors from `as_vector` and from the equations of motion (`eom`).
fn derive(
&mut self,
) -> Result<(Duration, OVector<f64, <D::StateType as State>::VecLength>), NyxError> {
let state = &self.state.as_vector()?;
let ctx = &self.state;
self.details.attempts = 1;
let mut step_size = self.step_size.in_seconds();
// Each iteration of this loop is one attempt at a step of `step_size` seconds.
loop {
// First stage: derivative at the start of the step.
let ki = self.prop.dynamics.eom(0.0, state, ctx)?;
self.k[0] = ki;
// `a_coeffs` is read sequentially: stage i + 1 consumes the next i + 1 entries.
let mut a_idx: usize = 0;
for i in 0..(self.prop.stages - 1) {
// `ci` accumulates the row sum (time offset factor of the stage) and `wi` the
// weighted sum of the derivatives computed so far.
let mut ci: f64 = 0.0;
let mut wi = OVector::<f64, <D::StateType as State>::VecLength>::from_element(0.0);
for kj in &self.k[0..i + 1] {
let a_ij = self.prop.a_coeffs[a_idx];
ci += a_ij;
wi += a_ij * kj;
a_idx += 1;
}
let ki = self
.prop
.dynamics
.eom(ci * step_size, &(state + step_size * wi), ctx)?;
self.k[i + 1] = ki;
}
// Combine the stages: the first `stages` entries of `b_coeffs` weight the
// solution, and the following `stages` entries give the second set of weights
// whose difference with the first set yields the error estimate.
let mut next_state = state.clone();
let mut error_est =
OVector::<f64, <D::StateType as State>::VecLength>::from_element(0.0);
for (i, ki) in self.k.iter().enumerate() {
let b_i = self.prop.b_coeffs[i];
if !self.fixed_step {
let b_i_star = self.prop.b_coeffs[i + self.prop.stages];
error_est += step_size * (b_i - b_i_star) * ki;
}
next_state += step_size * b_i * ki;
}
if self.fixed_step {
// Fixed step: no error control, the step is accepted as-is.
self.details.step = self.step_size;
return Ok(((self.details.step), next_state));
} else {
self.details.error = E::estimate(&error_est, &next_state, state);
// Accept the step if the error is within tolerance, or if the step cannot be
// reduced further (minimum step or maximum number of attempts reached).
// NOTE(review): for a negative `step_size` (back-propagation inverts the step),
// `step_size <= min_step` presumably always holds, so the first attempt is
// accepted regardless of the error — confirm this is intended.
if self.details.error <= self.prop.opts.tolerance
|| step_size <= self.prop.opts.min_step.in_seconds()
|| self.details.attempts >= self.prop.opts.attempts
{
if self.details.attempts >= self.prop.opts.attempts {
warn!(
"Could not further decrease step size: maximum number of attempts reached ({})",
self.details.attempts
);
}
// Record the step actually taken before adapting it for the next call.
self.details.step = step_size * Unit::Second;
if self.details.error < self.prop.opts.tolerance {
// Error below tolerance: propose a larger step for the next call, capped
// at the maximum step.
// NOTE(review): an error of exactly 0.0 makes the ratio infinite; for a
// positive step the cap below then selects `max_step` — confirm for
// negative steps.
let proposed_step = 0.9
* step_size
* (self.prop.opts.tolerance / self.details.error)
.powf(1.0 / f64::from(self.prop.order));
step_size = if proposed_step > self.prop.opts.max_step.in_seconds() {
self.prop.opts.max_step.in_seconds()
} else {
proposed_step
};
}
self.step_size = step_size * Unit::Second;
return Ok((self.details.step, next_state));
} else {
// Step rejected: shrink the step (never below the minimum step) and retry.
self.details.attempts += 1;
let proposed_step = 0.9
* step_size
* (self.prop.opts.tolerance / self.details.error)
.powf(1.0 / f64::from(self.prop.order - 1));
step_size = if proposed_step < self.prop.opts.min_step.in_seconds() {
self.prop.opts.min_step.in_seconds()
} else {
proposed_step
};
}
}
}
}
/// Borrows the details (step, error, attempts) recorded for the latest
/// integration step.
pub fn latest_details(&self) -> &IntegrationDetails {
    let details = &self.details;
    details
}
}
/// Options of a propagator: step size bounds, tolerance and attempt limit used by
/// the adaptive step control, or a single step size when `fixed_step` is set.
#[derive(Clone, Copy, Debug)]
pub struct PropOpts<E: ErrorCtrl> {
/// Step size a new propagator instance starts with.
pub init_step: Duration,
/// Smallest step the adaptive step control may select.
pub min_step: Duration,
/// Largest step the adaptive step control may select.
pub max_step: Duration,
/// Error threshold under which an adaptive step is accepted.
pub tolerance: f64,
/// Number of attempts after which an adaptive step is accepted regardless of its error.
pub attempts: u8,
/// When true, steps are taken without error control.
pub fixed_step: bool,
/// Value of the error control type; the type itself selects the error estimator.
pub _errctrl: E,
}
/// Constructors and setters available for any error control.
impl<E: ErrorCtrl> PropOpts<E> {
    /// Adaptive step options: the initial step is `max_step`, and up to 50 attempts
    /// are allowed per step.
    pub fn with_adaptive_step(
        min_step: Duration,
        max_step: Duration,
        tolerance: f64,
        errctrl: E,
    ) -> Self {
        Self {
            _errctrl: errctrl,
            fixed_step: false,
            attempts: 50,
            tolerance,
            init_step: max_step,
            min_step,
            max_step,
        }
    }

    /// Same as `with_adaptive_step`, with the step bounds expressed in seconds.
    pub fn with_adaptive_step_s(min_step: f64, max_step: f64, tolerance: f64, errctrl: E) -> Self {
        let seconds = |value: f64| value * Unit::Second;
        Self::with_adaptive_step(seconds(min_step), seconds(max_step), tolerance, errctrl)
    }

    /// One-line, human readable summary of these options.
    pub fn info(&self) -> String {
        let summary = format!(
            "[min_step: {:.e}, max_step: {:.e}, tol: {:.e}, attempts: {}]",
            self.min_step, self.max_step, self.tolerance, self.attempts,
        );
        summary
    }

    /// Sets the maximum step, lowering the initial step if it exceeds the new maximum.
    pub fn set_max_step(&mut self, max_step: Duration) {
        self.max_step = max_step;
        if self.init_step > max_step {
            self.init_step = max_step;
        }
    }

    /// Sets the minimum step, raising the initial step if it is below the new minimum.
    pub fn set_min_step(&mut self, min_step: Duration) {
        self.min_step = min_step;
        if self.init_step < min_step {
            self.init_step = min_step;
        }
    }
}
/// Constructors specific to the `RSSCartesianStep` error control.
impl PropOpts<RSSCartesianStep> {
    /// Fixed step options: the initial, minimum and maximum steps are all `step`,
    /// the tolerance is zero and no extra attempt is allowed.
    pub fn with_fixed_step(step: Duration) -> Self {
        Self {
            _errctrl: RSSCartesianStep {},
            fixed_step: true,
            attempts: 0,
            tolerance: 0.0,
            init_step: step,
            min_step: step,
            max_step: step,
        }
    }

    /// Same as `with_fixed_step`, with the step expressed in seconds.
    pub fn with_fixed_step_s(step: f64) -> Self {
        let step = step * Unit::Second;
        Self::with_fixed_step(step)
    }

    /// Default options with a custom tolerance.
    pub fn with_tolerance(tolerance: f64) -> Self {
        Self {
            tolerance,
            ..Self::default()
        }
    }

    /// Default options with a custom maximum step (the initial step is lowered to
    /// the maximum step if needed).
    pub fn with_max_step(max_step: Duration) -> Self {
        let mut capped = Self::default();
        capped.set_max_step(max_step);
        capped
    }
}
/// Default options: adaptive step between 0.001 s and 2700 s starting at 60 s,
/// tolerance of 1e-12 and up to 50 attempts per step.
impl Default for PropOpts<RSSCartesianStep> {
    fn default() -> Self {
        let init_step = 60.0 * Unit::Second;
        let min_step = 0.001 * Unit::Second;
        let max_step = 2700.0 * Unit::Second;
        Self {
            init_step,
            min_step,
            max_step,
            tolerance: 1e-12,
            attempts: 50,
            fixed_step: false,
            _errctrl: RSSCartesianStep {},
        }
    }
}
/// Checks the values set by each `PropOpts` constructor.
#[test]
fn test_options() {
    use super::error_ctrl::RSSStep;

    // Fixed step: both bounds equal the step, and the tolerance is zero.
    let fixed = PropOpts::with_fixed_step_s(1e-1);
    assert_eq!(fixed.min_step, 1e-1 * Unit::Second);
    assert_eq!(fixed.max_step, 1e-1 * Unit::Second);
    assert!(fixed.tolerance.abs() < f64::EPSILON);
    assert!(fixed.fixed_step);

    // Adaptive step with explicit bounds and error control.
    let adaptive = PropOpts::with_adaptive_step_s(1e-2, 10.0, 1e-12, RSSStep {});
    assert_eq!(adaptive.min_step, 1e-2 * Unit::Second);
    assert_eq!(adaptive.max_step, 10.0 * Unit::Second);
    assert!((adaptive.tolerance - 1e-12).abs() < f64::EPSILON);
    assert!(!adaptive.fixed_step);

    // Default options.
    let defaults: PropOpts<RSSCartesianStep> = Default::default();
    assert_eq!(defaults.init_step, 60.0 * Unit::Second);
    assert_eq!(defaults.min_step, 0.001 * Unit::Second);
    assert_eq!(defaults.max_step, 2700.0 * Unit::Second);
    assert!((defaults.tolerance - 1e-12).abs() < f64::EPSILON);
    assert_eq!(defaults.attempts, 50);
    assert!(!defaults.fixed_step);

    // Custom maximum step: the initial step is lowered to the maximum step.
    let capped = PropOpts::with_max_step(1.0 * Unit::Second);
    assert_eq!(capped.init_step, 1.0 * Unit::Second);
    assert_eq!(capped.min_step, 0.001 * Unit::Second);
    assert_eq!(capped.max_step, 1.0 * Unit::Second);
    assert!((capped.tolerance - 1e-12).abs() < f64::EPSILON);
    assert_eq!(capped.attempts, 50);
    assert!(!capped.fixed_step);
}