use super::solution::TargeterSolution;
use crate::cosmic::{AstroAlmanacSnafu, AstroPhysicsSnafu};
use crate::errors::TargetingError;
use crate::linalg::{DMatrix, SVector};
use crate::md::{AstroSnafu, StateParameter};
use crate::md::{PropSnafu, UnderdeterminedProblemSnafu, prelude::*};
pub use crate::md::{Variable, Vary};
use crate::pseudo_inverse;
use crate::utils::are_eigenvalues_stable;
use anise::astro::orbit_gradient::OrbitGrad;
use log::{debug, info, warn};
use snafu::{ResultExt, ensure};
#[cfg(not(target_arch = "wasm32"))]
use std::time::Instant;
impl<const V: usize, const O: usize> Targeter<'_, V, O> {
#[allow(clippy::comparison_chain)]
pub fn try_achieve_dual(
&self,
initial_state: Spacecraft,
correction_epoch: Epoch,
achievement_epoch: Epoch,
almanac: Arc<Almanac>,
) -> Result<TargeterSolution<V, O>, TargetingError> {
ensure!(!self.objectives.is_empty(), UnderdeterminedProblemSnafu);
let mut is_bplane_tgt = false;
for obj in &self.objectives {
if obj.parameter.is_b_plane() {
is_bplane_tgt = true;
break;
}
}
let xi_start = self
.prop
.with(initial_state, almanac.clone())
.until_epoch(correction_epoch)
.context(PropSnafu)?;
debug!("initial_state = {initial_state:?}");
debug!("xi_start = {xi_start:?}");
let mut xi = xi_start;
let mut total_correction = SVector::<f64, V>::zeros();
for (i, var) in self.variables.iter().enumerate() {
match var.component {
Vary::PositionX => {
xi.orbit.radius_km.x += var.init_guess;
}
Vary::PositionY => {
xi.orbit.radius_km.y += var.init_guess;
}
Vary::PositionZ => {
xi.orbit.radius_km.z += var.init_guess;
}
Vary::VelocityX => {
xi.orbit.velocity_km_s.x += var.init_guess;
}
Vary::VelocityY => {
xi.orbit.velocity_km_s.y += var.init_guess;
}
Vary::VelocityZ => {
xi.orbit.velocity_km_s.z += var.init_guess;
}
_ => {
return Err(TargetingError::UnsupportedVariable {
var: var.to_string(),
});
}
}
total_correction[i] += var.init_guess;
}
let mut prev_err_norm = f64::INFINITY;
let max_obj_val = self
.objectives
.iter()
.map(|obj| {
obj.desired_value.abs().ceil() as i32
* 10_i32.pow(obj.tolerance.abs().log10().ceil() as u32)
})
.max()
.unwrap();
let max_obj_tol = self
.objectives
.iter()
.map(|obj| obj.tolerance.log10().abs().ceil() as usize)
.max()
.unwrap();
let width = f64::from(max_obj_val).log10() as usize + 2 + max_obj_tol;
#[cfg(not(target_arch = "wasm32"))]
let start_instant = Instant::now();
for it in 0..=self.iterations {
xi.enable_stm();
let xf = self
.prop
.with(xi, almanac.clone())
.until_epoch(achievement_epoch)
.context(PropSnafu)?;
if !are_eigenvalues_stable(&xf.stm().unwrap().complex_eigenvalues()) {
warn!(
"STM linearization assumption is wrong for a time step of {}",
achievement_epoch - correction_epoch
);
}
let xf_dual_obj_frame = match &self.objective_frame {
Some(frame) => {
let orbit_obj_frame = almanac
.transform_to(xf.orbit, *frame, None)
.context(AstroAlmanacSnafu)
.context(AstroSnafu)?;
OrbitGrad::from(orbit_obj_frame)
}
None => OrbitGrad::from(xf.orbit),
};
let mut err_vector = SVector::<f64, O>::zeros();
let mut converged = true;
let b_plane = if is_bplane_tgt {
Some(BPlane::from_dual(xf_dual_obj_frame).context(AstroSnafu)?)
} else {
None
};
let mut objmsg = Vec::new();
let mut jac = DMatrix::from_element(self.objectives.len(), self.variables.len(), 0.0);
for (i, obj) in self.objectives.iter().enumerate() {
let xf_partial = if obj.parameter.is_b_plane() {
match obj.parameter {
StateParameter::BdotR() => b_plane.unwrap().b_r_km,
StateParameter::BdotT() => b_plane.unwrap().b_t_km,
StateParameter::BLTOF() => b_plane.unwrap().ltof_s,
_ => unreachable!(),
}
} else if let StateParameter::Element(oe) = obj.parameter {
xf_dual_obj_frame
.partial_for(oe)
.context(AstroPhysicsSnafu)
.context(AstroSnafu)?
} else {
unreachable!()
};
let achieved = xf_partial.real();
let (ok, param_err) = obj.assess_value(achieved);
if !ok {
converged = false;
}
err_vector[i] = param_err;
objmsg.push(format!(
"\t{:?}: achieved = {:>width$.prec$}\t desired = {:>width$.prec$}\t scaled error = {:>width$.prec$}",
obj.parameter,
achieved,
obj.desired_value,
param_err, width=width, prec=max_obj_tol
));
let mut partial_vec = DMatrix::from_element(1, 6, 0.0);
for (i, val) in [
xf_partial.wrt_x(),
xf_partial.wrt_y(),
xf_partial.wrt_z(),
xf_partial.wrt_vx(),
xf_partial.wrt_vy(),
xf_partial.wrt_vz(),
]
.iter()
.enumerate()
{
partial_vec[(0, i)] = *val;
}
for (j, var) in self.variables.iter().enumerate() {
let sc_stm = xf.stm().unwrap();
let stm = sc_stm.fixed_view::<6, 6>(0, 0);
let idx = var.component.vec_index();
let rslt = &partial_vec * stm.fixed_columns::<1>(idx);
jac[(i, j)] = rslt[(0, 0)];
}
}
if converged {
#[cfg(not(target_arch = "wasm32"))]
let conv_dur = Instant::now() - start_instant;
#[cfg(target_arch = "wasm32")]
let conv_dur = Duration::ZERO.into();
let mut state = xi_start;
for (i, var) in self.variables.iter().enumerate() {
match var.component {
Vary::PositionX => state.orbit.radius_km.x += total_correction[i],
Vary::PositionY => state.orbit.radius_km.y += total_correction[i],
Vary::PositionZ => state.orbit.radius_km.z += total_correction[i],
Vary::VelocityX => state.orbit.velocity_km_s.x += total_correction[i],
Vary::VelocityY => state.orbit.velocity_km_s.y += total_correction[i],
Vary::VelocityZ => state.orbit.velocity_km_s.z += total_correction[i],
_ => {
return Err(TargetingError::UnsupportedVariable {
var: var.to_string(),
});
}
}
}
let sol = TargeterSolution {
corrected_state: state,
achieved_state: xf,
correction: total_correction,
computation_dur: conv_dur,
variables: self.variables,
achieved_errors: err_vector,
achieved_objectives: self.objectives,
iterations: it,
};
info!("Targeter -- CONVERGED in {it} iterations");
for obj in &objmsg {
info!("{obj}");
}
return Ok(sol);
}
if (err_vector.norm() - prev_err_norm).abs() < 1e-10 {
return Err(TargetingError::CorrectionIneffective {
cur_val: err_vector.norm(),
prev_val: prev_err_norm,
action: "No change in objective errors",
});
}
prev_err_norm = err_vector.norm();
debug!("Jacobian {jac}");
let jac_inv = pseudo_inverse!(&jac)?;
debug!("Inverse Jacobian {jac_inv}");
let mut delta = jac_inv * err_vector;
debug!("Error vector: {err_vector}\nRaw correction: {delta}");
for (i, var) in self.variables.iter().enumerate() {
if delta[i].abs() > var.max_step {
delta[i] = var.max_step * delta[i].signum();
} else if delta[i] > var.max_value {
delta[i] = var.max_value;
} else if delta[i] < var.min_value {
delta[i] = var.min_value;
}
info!(
"Correction {:?} (element {}): {}",
var.component, i, delta[i]
);
match var.component {
Vary::PositionX => {
xi.orbit.radius_km.x += delta[i];
}
Vary::PositionY => {
xi.orbit.radius_km.y += delta[i];
}
Vary::PositionZ => {
xi.orbit.radius_km.z += delta[i];
}
Vary::VelocityX => {
xi.orbit.velocity_km_s.x += delta[i];
}
Vary::VelocityY => {
xi.orbit.velocity_km_s.y += delta[i];
}
Vary::VelocityZ => {
xi.orbit.velocity_km_s.z += delta[i];
}
_ => {
return Err(TargetingError::UnsupportedVariable {
var: var.to_string(),
});
}
}
}
total_correction += delta;
debug!("Total correction: {total_correction:e}");
info!("Targeter -- Iteration #{it} -- {achievement_epoch}");
for obj in &objmsg {
info!("{obj}");
}
}
Err(TargetingError::TooManyIterations)
}
}