use log::info;
use snafu::ResultExt;
pub use super::CostFunction;
use super::{MultipleShootingError, TargetingSnafu};
use crate::linalg::{DMatrix, DVector, SVector};
use crate::md::opti::solution::TargeterSolution;
use crate::md::targeter::Targeter;
use crate::md::{TargetingError, prelude::*};
use crate::pseudo_inverse;
use crate::{Orbit, Spacecraft};
use std::fmt;
/// A node of a multiple shooting problem, carrying `O` objectives.
///
/// A node must be cheaply copyable and convertible into the array of
/// [`Objective`]s that a targeter is asked to achieve at this node.
pub trait MultishootNode<const O: usize>
where
    Self: Copy + Into<[Objective; O]>,
{
    /// Epoch associated with this node. The solver uses it as the epoch at
    /// which the targeter for the segment ending at this node must achieve
    /// the node's objectives.
    fn epoch(&self) -> Epoch;

    /// Applies a correction to one component of this node.
    ///
    /// The solver calls this once per entry of its correction vector, with
    /// `component` in `0..O`.
    // NOTE(review): the parameter name suggests `add_val` is added to the
    // current value of the component -- confirm against the implementors.
    fn update_component(&mut self, component: usize, add_val: f64);
}
/// State of a multiple shooting optimization over a list of nodes.
///
/// * `VT` is the number of control variables handed to the targeter of each
///   segment.
/// * `OT` is the number of objectives carried by each node.
pub struct MultipleShooting<'a, T, const VT: usize, const OT: usize>
where
    T: MultishootNode<OT>,
{
    /// Propagator handed to every targeter built by the solver.
    pub prop: &'a Propagator<SpacecraftDynamics>,
    /// Nodes of the trajectory; the solver updates them in place.
    pub targets: Vec<T>,
    /// Initial state, used as the departure state of the first segment.
    pub x0: Spacecraft,
    /// Final orbit; it is copied as-is into the returned solution.
    pub xf: Orbit,
    /// Number of completed outer iterations (incremented by `solve`).
    pub current_iteration: usize,
    /// Upper bound on the number of outer iterations run by `solve`.
    pub max_iterations: usize,
    /// `solve` stops iterating once the absolute relative change in cost
    /// between two consecutive iterations drops below this value.
    pub improvement_threshold: f64,
    /// Control variables given to the targeter of each segment.
    pub variables: [Variable; VT],
    /// Correction found by the targeter for each segment during the latest
    /// outer iteration (one entry per node, rebuilt at every iteration).
    pub all_dvs: Vec<SVector<f64, VT>>,
}
impl<T: MultishootNode<OT>, const VT: usize, const OT: usize> MultipleShooting<'_, T, VT, OT> {
pub fn solve(
&mut self,
cost: CostFunction,
almanac: Arc<Almanac>,
) -> Result<MultipleShootingSolution<T, OT>, MultipleShootingError> {
let mut prev_cost = 1e12; for it in 0..self.max_iterations {
let mut initial_states = Vec::with_capacity(self.targets.len());
initial_states.push(self.x0);
let mut outer_jacobian =
DMatrix::from_element(3 * self.targets.len(), OT * (self.targets.len() - 1), 0.0);
let mut cost_vec = DVector::from_element(3 * self.targets.len(), 0.0);
self.all_dvs = Vec::with_capacity(self.all_dvs.len());
for i in 0..self.targets.len() {
let tgt = Targeter {
prop: self.prop,
objectives: self.targets[i].into(),
variables: self.variables,
iterations: 100,
objective_frame: None,
correction_frame: None,
};
let sol = tgt
.try_achieve_dual(
initial_states[i],
initial_states[i].epoch(),
self.targets[i].epoch(),
almanac.clone(),
)
.context(TargetingSnafu { segment: i })?;
let nominal_delta_v = sol.correction;
self.all_dvs.push(nominal_delta_v);
initial_states.push(sol.achieved_state);
}
for i in 0..(self.targets.len() - 1) {
for axis in 0..OT {
let mut next_node = self.targets[i].into();
next_node[axis].desired_value += next_node[axis].tolerance;
let inner_tgt_a = Targeter::delta_v(self.prop, next_node);
let inner_sol_a = inner_tgt_a
.try_achieve_dual(
initial_states[i],
initial_states[i].epoch(),
self.targets[i].epoch(),
almanac.clone(),
)
.context(TargetingSnafu { segment: i })?;
outer_jacobian[(3 * i, OT * i + axis)] = (inner_sol_a.correction[0]
- self.all_dvs[i][0])
/ next_node[axis].tolerance;
outer_jacobian[(3 * i + 1, OT * i + axis)] = (inner_sol_a.correction[1]
- self.all_dvs[i][1])
/ next_node[axis].tolerance;
outer_jacobian[(3 * i + 2, OT * i + axis)] = (inner_sol_a.correction[2]
- self.all_dvs[i][2])
/ next_node[axis].tolerance;
let inner_tgt_b = Targeter::delta_v(self.prop, self.targets[i + 1].into());
let inner_sol_b = inner_tgt_b
.try_achieve_dual(
inner_sol_a.achieved_state,
inner_sol_a.achieved_state.epoch(),
self.targets[i + 1].epoch(),
almanac.clone(),
)
.context(TargetingSnafu { segment: i })?;
outer_jacobian[(3 * (i + 1), OT * i + axis)] = (inner_sol_b.correction[0]
- self.all_dvs[i + 1][0])
/ next_node[axis].tolerance;
outer_jacobian[(3 * (i + 1) + 1, OT * i + axis)] = (inner_sol_b.correction[1]
- self.all_dvs[i + 1][1])
/ next_node[axis].tolerance;
outer_jacobian[(3 * (i + 1) + 2, OT * i + axis)] = (inner_sol_b.correction[2]
- self.all_dvs[i + 1][2])
/ next_node[axis].tolerance;
if i < self.targets.len() - 3 {
let dv_ip1 = inner_sol_b.achieved_state.orbit.velocity_km_s
- initial_states[i + 2].orbit.velocity_km_s;
outer_jacobian[(3 * (i + 2), OT * i + axis)] =
dv_ip1[0] / next_node[axis].tolerance;
outer_jacobian[(3 * (i + 2) + 1, OT * i + axis)] =
dv_ip1[1] / next_node[axis].tolerance;
outer_jacobian[(3 * (i + 2) + 2, OT * i + axis)] =
dv_ip1[2] / next_node[axis].tolerance;
}
}
}
for i in 0..self.targets.len() {
for j in 0..3 {
cost_vec[3 * i + j] = self.all_dvs[i][j];
}
}
let new_cost = match cost {
CostFunction::MinimumEnergy => cost_vec.dot(&cost_vec),
CostFunction::MinimumFuel => cost_vec.dot(&cost_vec).sqrt(),
};
let cost_improvmt = (prev_cost - new_cost) / new_cost.abs();
match cost {
CostFunction::MinimumEnergy => info!(
"Multiple shooting iteration #{}\t\tCost = {:.3} km^2/s^2\timprovement = {:.2}%",
it,
new_cost,
100.0 * cost_improvmt
),
CostFunction::MinimumFuel => info!(
"Multiple shooting iteration #{}\t\tCost = {:.3} km/s\timprovement = {:.2}%",
it,
new_cost,
100.0 * cost_improvmt
),
};
if cost_improvmt.abs() < self.improvement_threshold {
info!("Improvement below desired threshold. Running targeter on computed nodes.");
let mut ms_sol = MultipleShootingSolution {
x0: self.x0,
xf: self.xf,
nodes: self.targets.clone(),
solutions: Vec::with_capacity(self.targets.len()),
};
let mut initial_states = Vec::with_capacity(self.targets.len());
initial_states.push(self.x0);
for (i, node) in self.targets.iter().enumerate() {
let tgt = Targeter::delta_v(self.prop, (*node).into());
let sol = tgt
.try_achieve_dual(
initial_states[i],
initial_states[i].epoch(),
node.epoch(),
almanac.clone(),
)
.context(TargetingSnafu { segment: i })?;
initial_states.push(sol.achieved_state);
ms_sol.solutions.push(sol);
}
return Ok(ms_sol);
}
prev_cost = new_cost;
let inv_jac =
pseudo_inverse!(&outer_jacobian).context(TargetingSnafu { segment: 0_usize })?;
let delta_r = inv_jac * cost_vec;
let node_vector = -delta_r;
for (i, val) in node_vector.iter().enumerate() {
let node_no = i / 3;
let component_no = i % OT;
self.targets[node_no].update_component(component_no, *val);
}
self.current_iteration += 1;
}
Err(MultipleShootingError::TargetingError {
segment: 0_usize,
source: Box::new(TargetingError::TooManyIterations),
})
}
}
impl<T: MultishootNode<OT>, const VT: usize, const OT: usize> fmt::Display
    for MultipleShooting<'_, T, VT, OT>
{
    /// Writes one bracketed, comma-terminated row per line.
    ///
    /// The first row describes the starting point: the position of `x0`, the
    /// current iteration, four zero placeholders and node number 0.
    ///
    /// Each following row describes one node: the desired value of each of
    /// its objectives (three decimals), the current iteration, the
    /// components of the correction stored for that node (zeros if none has
    /// been computed yet) followed by its norm when `VT == 3`, and the
    /// 1-based node number.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Starting point, padded so that it lines up with the node rows.
        writeln!(
            f,
            "[{:.3}, {:.3}, {:.3}, {}, {}, {}, {}, {}, {}],",
            self.x0.orbit.radius_km.x,
            self.x0.orbit.radius_km.y,
            self.x0.orbit.radius_km.z,
            self.current_iteration,
            0.0,
            0.0,
            0.0,
            0.0,
            0
        )?;

        for (i, node) in self.targets.iter().enumerate() {
            write!(f, "[")?;

            let objectives: [Objective; OT] = (*node).into();
            for obj in &objectives {
                write!(f, "{:.3}, ", obj.desired_value)?;
            }

            write!(f, "{}, ", self.current_iteration)?;

            // The corrections are only available once `solve` has run.
            let dv = self
                .all_dvs
                .get(i)
                .copied()
                .unwrap_or_else(SVector::<f64, VT>::zeros);
            for val in dv.iter() {
                write!(f, "{val}, ")?;
            }
            if VT == 3 {
                // Also report the magnitude of the correction.
                write!(f, "{}, ", dv.norm())?;
            }

            writeln!(f, "{}],", i + 1)?;
        }
        Ok(())
    }
}
/// Result of a converged multiple shooting run.
#[derive(Clone, Debug)]
pub struct MultipleShootingSolution<T, const O: usize>
where
    T: MultishootNode<O>,
{
    /// Initial state the run started from.
    pub x0: Spacecraft,
    /// Final orbit, copied from the multiple shooting problem.
    pub xf: Orbit,
    /// Nodes as they stood when the run converged.
    pub nodes: Vec<T>,
    /// Targeter solution of each segment; `solutions[i]` is the segment that
    /// ends at `nodes[i]`.
    pub solutions: Vec<TargeterSolution<3, O>>,
}
impl<T, const O: usize> fmt::Display for MultipleShootingSolution<T, O>
where
    T: MultishootNode<O>,
{
    /// Writes every targeter solution back to back, without any separator,
    /// and stops at the first formatting error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.solutions
            .iter()
            .try_for_each(|sol| write!(f, "{sol}"))
    }
}
impl<T, const O: usize> MultipleShootingSolution<T, O>
where
    T: MultishootNode<O>,
{
    /// Builds one trajectory per node by applying the stored targeter
    /// solution of each segment with the provided propagator.
    ///
    /// # Errors
    /// Returns the first targeter failure, tagged with the index of the
    /// segment on which it happened.
    ///
    /// # Panics
    /// Panics if `solutions` holds fewer entries than `nodes`.
    pub fn build_trajectories(
        &self,
        prop: &Propagator<SpacecraftDynamics>,
        almanac: Arc<Almanac>,
    ) -> Result<Vec<Trajectory>, MultipleShootingError> {
        self.nodes
            .iter()
            .enumerate()
            .map(|(segment, node)| {
                Targeter::delta_v(prop, (*node).into())
                    .apply_with_traj(&self.solutions[segment], almanac.clone())
                    .context(TargetingSnafu { segment })
                    // Only the trajectory is kept; the first tuple element is dropped.
                    .map(|(_, traj)| traj)
            })
            .collect()
    }
}