extern crate levenberg_marquardt;
use super::optimizer::Optimizer;
use crate::dynamics::guidance::Mnvr;
use crate::errors::TargetingError;
use crate::linalg::{storage::Owned, Const, SMatrix, SVector, Vector6};
use crate::linalg::{DimMax, DimMin, ToTypenum};
use crate::md::rayon::prelude::*;
use crate::md::ui::*;
use crate::md::StateParameter;
pub use crate::md::{Variable, Vary};
use crate::polyfit::CommonPolynomial;
use crate::propagators::error_ctrl::ErrorCtrl;
use hifitime::TimeUnits;
use levenberg_marquardt::{LeastSquaresProblem, LevenbergMarquardt};
/// State of one least-squares targeting problem handed to the Levenberg-Marquardt solver.
///
/// `N` is the number of control variables and `M` the number of objectives. The
/// `LeastSquaresProblem` implementation in this file re-evaluates `residuals` and `jacobian`
/// every time `set_params` is called with a new control vector.
pub struct OptimizerInstance<'a, E: ErrorCtrl, const N: usize, const M: usize>
where
Const<N>: ToTypenum,
Const<M>: ToTypenum,
Const<M>: DimMin<Const<N>, Output = Const<N>> + DimMax<Const<N>, Output = Const<N>>,
{
/// Propagator used to bring the spacecraft to the correction epoch and then to the achievement epoch.
pub prop: &'a Propagator<'a, SpacecraftDynamics<'a>, E>,
/// Objectives to achieve; each one contributes one entry of `residuals` and one row of `jacobian`.
pub objectives: [Objective; M],
/// When set, the propagated final orbit is converted into this frame (via the provided `Cosm`)
/// before the objectives are evaluated.
pub objective_frame: Option<(Frame, Arc<Cosm>)>,
/// Control variables; each one maps to one entry of `control` and one column of `jacobian`.
pub variables: [Variable; N],
/// When set, state corrections are treated as velocity corrections expressed in this frame and
/// rotated with `dcm_from_traj_frame` before being applied as a delta-v.
pub correction_frame: Option<Frame>,
/// Spacecraft state that is propagated to `correction_epoch` at the start of every evaluation.
pub spacecraft: Spacecraft,
/// Epoch at which the objectives are evaluated.
pub achievement_epoch: Epoch,
/// Epoch at which the correction is applied (also the default start of the maneuver).
pub correction_epoch: Epoch,
/// Control vector most recently passed to `set_params`.
pub control: SVector<f64, N>,
/// Residuals computed by the most recent call to `set_params`.
pub residuals: SVector<f64, M>,
/// Finite-difference Jacobian computed by the most recent call to `set_params`.
pub jacobian: SMatrix<f64, M, N>,
}
/// Exposes an [`OptimizerInstance`] to the `levenberg_marquardt` solver.
///
/// The solver parameters are the `V` control values (one per variable) and the residuals are the
/// `O` values taken from `Objective::assess_raw` (one per objective).
impl<'a, E: ErrorCtrl, const V: usize, const O: usize> LeastSquaresProblem<f64, Const<O>, Const<V>>
    for OptimizerInstance<'a, E, V, O>
where
    Const<V>: ToTypenum,
    Const<O>: ToTypenum,
    Const<O>: DimMin<Const<V>, Output = Const<V>> + DimMax<Const<V>, Output = Const<V>>,
{
    type ResidualStorage = Owned<f64, Const<O>>;
    type ParameterStorage = Owned<f64, Const<V>>;
    type JacobianStorage = Owned<f64, Const<O>, Const<V>>;

    /// Evaluates the problem at `attempted_control` and refreshes `self.control`,
    /// `self.residuals` and `self.jacobian`.
    ///
    /// The evaluation proceeds as follows:
    /// 1. propagate `self.spacecraft` to `self.correction_epoch`;
    /// 2. build the corrected state (and, for finite burn variables, the maneuver) from the
    ///    control vector;
    /// 3. propagate to `self.achievement_epoch` and compute one residual per objective;
    /// 4. for each objective, compute its Jacobian row with forward finite differences, one
    ///    perturbed propagation per variable, executed in parallel.
    ///
    /// # Panics
    /// The trait signature cannot return an error, so this method panics if a variable is
    /// invalid, if a correction frame is set together with a variable whose `vec_index()` is
    /// below 3, if a finite burn variable is used on a spacecraft without a thruster, or if any
    /// propagation, frame rotation, B-plane or partial computation fails.
    fn set_params(&mut self, attempted_control: &SVector<f64, V>) {
        println!("Ctrl: {}", attempted_control);
        self.control = *attempted_control;

        // The B-plane is only computed when at least one objective needs it.
        let is_bplane_tgt = self
            .objectives
            .iter()
            .any(|obj| obj.parameter.is_b_plane());

        let xi_start = self
            .prop
            .with(self.spacecraft)
            .until_epoch(self.correction_epoch)
            .unwrap();
        debug!("xi_start = {}", xi_start);

        let mut xi = xi_start;
        let mut state_correction = Vector6::<f64>::zeros();
        let mut has_state_correction = false;
        // Default maneuver spanning the whole correction-to-achievement arc; finite burn
        // variables adjust it below.
        let mut mnvr = Mnvr {
            start: self.correction_epoch,
            end: self.achievement_epoch,
            thrust_lvl: 1.0,
            alpha_inplane_radians: CommonPolynomial::Quadratic(0.0, 0.0, 0.0),
            delta_outofplane_radians: CommonPolynomial::Quadratic(0.0, 0.0, 0.0),
            frame: Frame::RCN,
        };
        let mut finite_burn_target = false;

        for (i, var) in self.variables.iter().enumerate() {
            var.valid().unwrap();

            if let Some(frame) = self.correction_frame {
                if var.component.vec_index() < 3 {
                    let msg = format!(
                        "Variable is in frame {} but that frame cannot be used for a {:?} correction",
                        frame,
                        var.component
                    );
                    error!("{}", msg);
                    // Carry the message in the panic payload instead of panicking silently.
                    panic!("{}", msg);
                }
            }

            if var.component.is_finite_burn() {
                if xi_start.thruster.is_none() {
                    panic!(
                        "{:?} is a finite burn variable but the spacecraft has no thruster",
                        var.component
                    );
                }
                finite_burn_target = true;
                match var.component {
                    Vary::Duration => mnvr.end = mnvr.start + attempted_control[i].seconds(),
                    Vary::EndEpoch => mnvr.end += attempted_control[i].seconds(),
                    Vary::StartEpoch => mnvr.start += attempted_control[i].seconds(),
                    Vary::MnvrAlpha | Vary::MnvrAlphaDot | Vary::MnvrAlphaDDot => {
                        mnvr.alpha_inplane_radians = mnvr
                            .alpha_inplane_radians
                            .add_val_in_order(attempted_control[i], var.component.vec_index())
                            .unwrap();
                    }
                    Vary::MnvrDelta | Vary::MnvrDeltaDot | Vary::MnvrDeltaDDot => {
                        mnvr.delta_outofplane_radians = mnvr
                            .delta_outofplane_radians
                            .add_val_in_order(attempted_control[i], var.component.vec_index())
                            .unwrap();
                    }
                    Vary::ThrustX | Vary::ThrustY | Vary::ThrustZ => {
                        let mut vector = mnvr.vector(mnvr.start);
                        vector[var.component.vec_index()] = attempted_control[i];
                        mnvr.set_direction(vector).unwrap();
                    }
                    Vary::ThrustLevel => {
                        mnvr.thrust_lvl += attempted_control[i];
                    }
                    _ => unreachable!(),
                }
                info!("Initial maneuver guess: {}", mnvr);
            } else {
                // Only accumulate here: the full correction is applied exactly once after the
                // loop. Applying the accumulated vector on every iteration would re-apply the
                // components of earlier variables each time another state variable is visited.
                // NOTE(review): the control is subtracted whereas the Jacobian perturbation
                // below is added; this looks deliberate if `assess_raw` yields
                // `desired - achieved` -- confirm against `Objective::assess_raw`.
                state_correction[var.component.vec_index()] -= attempted_control[i];
                has_state_correction = true;
            }
        }

        if has_state_correction {
            if let Some(frame) = self.correction_frame {
                // Rotate the velocity rows of the correction from the correction frame and
                // apply them as a single delta-v.
                let dcm_vnc2inertial = xi.orbit.dcm_from_traj_frame(frame).unwrap();
                let velocity_correction = dcm_vnc2inertial * state_correction.fixed_rows::<3>(3);
                xi.orbit.apply_dv(velocity_correction);
            } else {
                xi.orbit.x += state_correction[0];
                xi.orbit.y += state_correction[1];
                xi.orbit.z += state_correction[2];
                xi.orbit.vx += state_correction[3];
                xi.orbit.vy += state_correction[4];
                xi.orbit.vz += state_correction[5];
            }
        }

        // Column width and precision used to align the per-objective diagnostic messages.
        let max_obj_val = self
            .objectives
            .iter()
            .map(|obj| {
                (obj.desired_value.abs().ceil() as i32
                    * 10_i32.pow(obj.tolerance.abs().log10().ceil() as u32)) as i32
            })
            .max()
            .unwrap();
        let max_obj_tol = self
            .objectives
            .iter()
            .map(|obj| obj.tolerance.log10().abs().ceil() as usize)
            .max()
            .unwrap();
        let width = f64::from(max_obj_val).log10() as usize + 2 + max_obj_tol;

        // Nominal propagation to the achievement epoch.
        let cur_xi = xi;
        let xf = if finite_burn_target {
            info!("{}", mnvr);
            let mut prop = self.prop.clone();
            let prop_opts = prop.opts;
            let pre_mnvr = prop.with(cur_xi).until_epoch(mnvr.start).unwrap();
            prop.dynamics = prop.dynamics.with_guidance_law_no_decr(Arc::new(mnvr));
            prop.set_max_step(mnvr.end - mnvr.start);
            let post_mnvr = prop
                .with(pre_mnvr.with_guidance_mode(GuidanceMode::Thrust))
                .until_epoch(mnvr.end)
                .unwrap();
            // Restore the original propagator options for the coast after the maneuver.
            prop.opts = prop_opts;
            prop.with(post_mnvr)
                .until_epoch(self.achievement_epoch)
                .unwrap()
                .orbit
        } else {
            self.prop
                .with(cur_xi)
                .until_epoch(self.achievement_epoch)
                .unwrap()
                .orbit
        };

        let xf_dual_obj_frame = match &self.objective_frame {
            Some((frame, cosm)) => {
                let orbit_obj_frame = cosm.frame_chg(&xf, *frame);
                OrbitDual::from(orbit_obj_frame)
            }
            None => OrbitDual::from(xf),
        };
        let b_plane = if is_bplane_tgt {
            Some(BPlane::from_dual(xf_dual_obj_frame).unwrap())
        } else {
            None
        };

        // Diagnostic messages are collected here; they are not emitted within this method.
        let mut objmsg = Vec::with_capacity(self.objectives.len());
        for (i, obj) in self.objectives.iter().enumerate() {
            let partial = if obj.parameter.is_b_plane() {
                match obj.parameter {
                    StateParameter::BdotR => b_plane.unwrap().b_r,
                    StateParameter::BdotT => b_plane.unwrap().b_t,
                    StateParameter::BLTOF => b_plane.unwrap().ltof_s,
                    _ => unreachable!(),
                }
            } else {
                xf_dual_obj_frame.partial_for(&obj.parameter).unwrap()
            };
            let achieved = partial.real();
            self.residuals[i] = obj.assess_raw(achieved).1;
            objmsg.push(format!(
                "\t{:?}: achieved = {:>width$.prec$}\t desired = {:>width$.prec$}\t scaled error = {:>width$.prec$}",
                obj.parameter,
                achieved,
                obj.desired_value,
                self.residuals[i], width=width, prec=max_obj_tol
            ));

            // One slot per variable: (column index, variable, Jacobian entry to fill in).
            let mut pert_calc: Vec<_> = self
                .variables
                .iter()
                .enumerate()
                .map(|(j, var)| (j, var, 0.0_f64))
                .collect();

            // Each variable is perturbed independently and propagated on its own copy of the
            // state, propagator and maneuver, so the columns can be computed in parallel.
            pert_calc.par_iter_mut().for_each(|(_, var, jac_val)| {
                let mut this_xi = xi;
                let mut this_prop = self.prop.clone();
                let mut this_mnvr = mnvr;

                if var.component.is_finite_burn() {
                    let pert = var.perturbation;
                    match var.component {
                        // Epoch perturbations of magnitude 1e-3 or less leave the maneuver
                        // unchanged.
                        Vary::Duration => {
                            if pert.abs() > 1e-3 {
                                this_mnvr.end = mnvr.start + pert.seconds()
                            }
                        }
                        Vary::EndEpoch => {
                            if pert.abs() > 1e-3 {
                                this_mnvr.end = mnvr.end + pert.seconds()
                            }
                        }
                        Vary::StartEpoch => {
                            if pert.abs() > 1e-3 {
                                this_mnvr.start = mnvr.start + pert.seconds()
                            }
                        }
                        Vary::MnvrAlpha | Vary::MnvrAlphaDot | Vary::MnvrAlphaDDot => {
                            this_mnvr.alpha_inplane_radians = mnvr
                                .alpha_inplane_radians
                                .add_val_in_order(pert, var.component.vec_index())
                                .unwrap();
                        }
                        Vary::MnvrDelta | Vary::MnvrDeltaDot | Vary::MnvrDeltaDDot => {
                            this_mnvr.delta_outofplane_radians = mnvr
                                .delta_outofplane_radians
                                .add_val_in_order(pert, var.component.vec_index())
                                .unwrap();
                        }
                        Vary::ThrustX | Vary::ThrustY | Vary::ThrustZ => {
                            let mut vector = this_mnvr.vector(self.correction_epoch);
                            vector[var.component.vec_index()] += pert;
                            this_mnvr.set_direction(vector).unwrap();
                        }
                        Vary::ThrustLevel => {
                            this_mnvr.thrust_lvl += pert;
                        }
                        _ => unreachable!(),
                    }
                } else {
                    // Perturb only this variable's component on top of the corrected state.
                    let mut state_correction = Vector6::<f64>::zeros();
                    state_correction[var.component.vec_index()] += var.perturbation;
                    if let Some(frame) = self.correction_frame {
                        let dcm_vnc2inertial = this_xi.orbit.dcm_from_traj_frame(frame).unwrap();
                        let velocity_correction =
                            dcm_vnc2inertial * state_correction.fixed_rows::<3>(3);
                        this_xi.orbit.apply_dv(velocity_correction);
                    } else {
                        this_xi = xi + state_correction;
                    }
                }

                let this_xf = if finite_burn_target {
                    let pre_mnvr = this_prop.with(cur_xi).until_epoch(this_mnvr.start).unwrap();
                    let prop_opts = this_prop.opts;
                    this_prop.set_max_step(this_mnvr.duration());
                    this_prop.dynamics = this_prop.dynamics.with_guidance_law(Arc::new(this_mnvr));
                    let post_mnvr = this_prop
                        .with(pre_mnvr.with_guidance_mode(GuidanceMode::Thrust))
                        .until_epoch(this_mnvr.end)
                        .unwrap();
                    this_prop.opts = prop_opts;
                    this_prop
                        .with(post_mnvr)
                        .until_epoch(self.achievement_epoch)
                        .unwrap()
                        .orbit
                } else {
                    this_prop
                        .with(this_xi)
                        .until_epoch(self.achievement_epoch)
                        .unwrap()
                        .orbit
                };

                let xf_dual_obj_frame = match &self.objective_frame {
                    Some((frame, cosm)) => {
                        let orbit_obj_frame = cosm.frame_chg(&this_xf, *frame);
                        OrbitDual::from(orbit_obj_frame)
                    }
                    None => OrbitDual::from(this_xf),
                };
                let b_plane = if is_bplane_tgt {
                    Some(BPlane::from_dual(xf_dual_obj_frame).unwrap())
                } else {
                    None
                };
                let partial = if obj.parameter.is_b_plane() {
                    match obj.parameter {
                        StateParameter::BdotR => b_plane.unwrap().b_r,
                        StateParameter::BdotT => b_plane.unwrap().b_t,
                        StateParameter::BLTOF => b_plane.unwrap().ltof_s,
                        _ => unreachable!(),
                    }
                } else {
                    xf_dual_obj_frame.partial_for(&obj.parameter).unwrap()
                };
                let this_achieved = partial.real();
                // Forward finite difference of the achieved value with respect to this variable.
                *jac_val = (this_achieved - achieved) / var.perturbation;
            });

            for (j, var, jac_val) in &pert_calc {
                println!(
                    "jac[({}, {})] = {} for {:?} and {:?}",
                    i, *j, jac_val, var, obj
                );
                self.jacobian[(i, *j)] = *jac_val;
            }
        }
        println!("resid: {}", self.residuals);
    }

    /// Returns the control vector most recently passed to `set_params`.
    fn params(&self) -> SVector<f64, V> {
        self.control
    }

    /// Returns the residuals computed by the last call to `set_params`.
    fn residuals(&self) -> Option<SVector<f64, O>> {
        Some(self.residuals)
    }

    /// Returns the finite-difference Jacobian computed by the last call to `set_params`.
    fn jacobian(&self) -> Option<SMatrix<f64, O, V>> {
        Some(self.jacobian)
    }
}
impl<'a, E: ErrorCtrl, const V: usize, const O: usize> Optimizer<'a, E, V, O>
where
    Const<V>: ToTypenum,
    Const<O>: ToTypenum,
    Const<O>: DimMin<Const<V>, Output = Const<V>> + DimMax<Const<V>, Output = Const<V>>,
{
    /// Runs a Levenberg-Marquardt minimization of the objectives over the control variables.
    ///
    /// The initial control is made of each variable's `init_guess`. The problem is evaluated
    /// once at that guess, then handed to the solver (patience of 10). The solver report and
    /// the resulting control vector (with its norm) are printed to standard output.
    ///
    /// # Errors
    /// Returns the error of `Variable::valid` for the first invalid variable, or a
    /// `TargetingError::FrameError` when a correction frame is set and a variable's
    /// `vec_index()` is below 3. Once the solver has run, `Ok(())` is returned whatever the
    /// content of the report.
    #[allow(clippy::comparison_chain)]
    pub fn minimize(
        &self,
        initial_state: Spacecraft,
        correction_epoch: Epoch,
        achievement_epoch: Epoch,
    ) -> Result<(), NyxError> {
        // Validate every variable while seeding the control vector with the initial guesses.
        let mut initial_control = SVector::<f64, V>::zeros();
        for (slot, variable) in initial_control.iter_mut().zip(self.variables.iter()) {
            variable.valid()?;

            if let Some(frame) = self.correction_frame {
                if variable.component.vec_index() < 3 {
                    let msg = format!(
                        "Variable is in frame {} but that frame cannot be used for a {:?} correction",
                        frame,
                        variable.component
                    );
                    error!("{}", msg);
                    return Err(NyxError::Targeter(TargetingError::FrameError(msg)));
                }
            }

            *slot = variable.init_guess;
        }

        // The problem instance borrows its own copy of the propagator.
        let propagator = self.prop.clone();
        let mut problem = OptimizerInstance {
            prop: &propagator,
            objectives: self.objectives,
            objective_frame: self.objective_frame.clone(),
            variables: self.variables,
            correction_frame: self.correction_frame,
            spacecraft: initial_state,
            achievement_epoch,
            correction_epoch,
            control: initial_control,
            residuals: SVector::zeros(),
            jacobian: SMatrix::zeros(),
        };

        // Evaluate once at the initial guess so the starting residuals can be reported.
        problem.set_params(&initial_control);
        println!("Init resid: {}", problem.residuals);

        let solver = LevenbergMarquardt::new().with_patience(10);
        let (solution, report) = solver.minimize(problem);
        println!("{:?}", report);

        let correction = solution.control;
        println!(
            "Result correction: {}\t\t{} km/s",
            correction,
            correction.norm()
        );

        Ok(())
    }
}