#![allow(clippy::type_complexity)] #![allow(unused_imports)] use crate::linalg::allocator::Allocator;
use crate::linalg::{Const, DefaultAllocator, DimName, OMatrix, OVector, U1}; use crate::md::trajectory::{Interpolatable, Traj}; pub use crate::od::estimate::*;
pub use crate::od::ground_station::*;
pub use crate::od::snc::*; pub use crate::od::*;
use crate::propagators::Propagator;
pub use crate::time::{Duration, Epoch, Unit};
use anise::prelude::Almanac;
use indexmap::IndexSet;
use log::error;
use log::{debug, info, trace, warn};
use msr::sensitivity::TrackerSensitivity; use nalgebra::{Cholesky, Dyn, Matrix, VecStorage};
use snafu::prelude::*;
use solution::msr::MeasurementType;
use std::collections::BTreeMap;
use std::marker::PhantomData;
use std::ops::Add;
use std::sync::Arc;
use typed_builder::TypedBuilder;
mod solution;
pub use solution::BLSSolution;
use self::msr::TrackingDataArc;
/// Algorithm used by [`BatchLeastSquares`] to compute each state correction.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BLSSolver {
    /// Solves the undamped normal equations `(H^T W H) dx = H^T W y` through a Cholesky
    /// factorization; every correction is applied.
    NormalEquations,
    /// Levenberg-Marquardt: adds a damping term `lambda * D^2` to the normal matrix and adapts
    /// `lambda` depending on whether the weighted residual RMS decreases.
    LevenbergMarquardt,
}
/// Batch least squares (BLS) estimator configuration.
///
/// Build it with `BatchLeastSquares::builder()`; only `prop`, `devices`, and `almanac` have no
/// builder default and must always be provided.
#[derive(Clone, TypedBuilder)]
#[builder(doc)]
pub struct BatchLeastSquares<
    D: Dynamics,
    Trk: TrackerSensitivity<D::StateType, D::StateType>,
> where
    D::StateType:
        Interpolatable + Add<OVector<f64, <D::StateType as State>::Size>, Output = D::StateType>,
    <D::StateType as State>::Size: DimName,
    <DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
    <DefaultAllocator as Allocator<<D::StateType as State>::Size>>::Buffer<f64>: Copy,
    <DefaultAllocator as Allocator<
        <D::StateType as State>::Size,
        <D::StateType as State>::Size,
    >>::Buffer<f64>: Copy,
    DefaultAllocator: Allocator<<D::StateType as State>::Size>
        + Allocator<<D::StateType as State>::VecLength>
        + Allocator<U1>
        + Allocator<U1, <D::StateType as State>::Size>
        + Allocator<<D::StateType as State>::Size, U1>
        + Allocator<U1, U1>
        + Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>,
{
    /// Propagator used to integrate the state (and its state transition matrix) between
    /// measurement epochs.
    pub prop: Propagator<D>,
    /// Tracking devices keyed by tracker name; each measurement's tracker is looked up here.
    pub devices: BTreeMap<String, Trk>,
    /// Algorithm used to compute the state corrections.
    #[builder(default = BLSSolver::NormalEquations)]
    pub solver: BLSSolver,
    /// Convergence threshold on the norm of the position part of the state correction, in km.
    #[builder(default = 1e-4)]
    pub tolerance_pos_km: f64,
    /// Maximum number of iterations before giving up on convergence.
    #[builder(default = 10)]
    pub max_iterations: usize,
    /// Upper bound on each propagation step taken between measurement epochs.
    #[builder(default_code = "30 * Unit::Second")]
    pub max_step: Duration,
    /// Maximum separation between the propagated epoch and a measurement epoch for that
    /// measurement to be processed.
    #[builder(default_code = "1 * Unit::Microsecond")]
    pub epoch_precision: Duration,
    /// Initial Levenberg-Marquardt damping factor (lambda).
    #[builder(default = 10.0)]
    pub lm_lambda_init: f64,
    /// Divisor applied to lambda after an accepted Levenberg-Marquardt step.
    #[builder(default = 10.0)]
    pub lm_lambda_decrease: f64,
    /// Multiplier applied to lambda after a rejected Levenberg-Marquardt step.
    #[builder(default = 10.0)]
    pub lm_lambda_increase: f64,
    /// Lower bound on lambda.
    #[builder(default = 1e-12)]
    pub lm_lambda_min: f64,
    /// Upper bound on lambda.
    #[builder(default = 1e12)]
    pub lm_lambda_max: f64,
    /// When true, the Levenberg-Marquardt damping term is scaled by diagonal entries of
    /// `H^T W H` instead of the identity.
    #[builder(default = true)]
    pub lm_use_diag_scaling: bool,
    /// Almanac handed to the propagator and to the tracking devices.
    pub almanac: Arc<Almanac>,
}
/// Square `f64` matrix whose dimension is the size of the dynamics' state (used for the
/// information matrix, the covariance, and the state transition matrix).
///
/// The projections are written fully qualified, so the alias needs no (ignored) bound on `D`.
type StateMatrix<D> = OMatrix<
    f64,
    <<D as Dynamics>::StateType as State>::Size,
    <<D as Dynamics>::StateType as State>::Size,
>;
impl<D, Trk> BatchLeastSquares<D, Trk>
where
    D: Dynamics,
    Trk: TrackerSensitivity<D::StateType, D::StateType> + Clone,
    D::StateType: Interpolatable
        + Add<OVector<f64, <D::StateType as State>::Size>, Output = D::StateType>
        + std::fmt::Debug,
    <D::StateType as State>::Size: DimName,
    <DefaultAllocator as Allocator<<D::StateType as State>::VecLength>>::Buffer<f64>: Send,
    <DefaultAllocator as Allocator<<D::StateType as State>::Size>>::Buffer<f64>: Copy,
    <DefaultAllocator as Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>>::Buffer<f64>: Copy,
    DefaultAllocator: Allocator<<D::StateType as State>::Size>
        + Allocator<<D::StateType as State>::VecLength>
        + Allocator<U1>
        + Allocator<U1, <D::StateType as State>::Size>
        + Allocator<<D::StateType as State>::Size, U1>
        + Allocator<U1, U1>
        + Allocator<<D::StateType as State>::Size, <D::StateType as State>::Size>,
{
    /// Estimates the state at the epoch of `initial_guess` that best fits the tracking data in
    /// `arc`, using the configured [`BLSSolver`].
    ///
    /// Each iteration solves for a state correction from the linearization about the current
    /// estimate, then propagates and re-linearizes about the corrected (trial) state. With
    /// [`BLSSolver::NormalEquations`] the trial is always accepted. With
    /// [`BLSSolver::LevenbergMarquardt`] it is accepted only if it lowers the weighted RMS;
    /// otherwise the estimate is kept, lambda is increased and the step is re-solved from the same
    /// linearization.
    ///
    /// Iteration stops when the position correction falls below `tolerance_pos_km` or after
    /// `max_iterations` iterations (then `converged` is false). The returned RMS and covariance
    /// both correspond to the returned estimate.
    ///
    /// # Errors
    /// - `TooFewMeasurements` if fewer than two non-rejected measurements are available, or if no
    ///   residual could be formed at the initial guess.
    /// - `SingularInformationMatrix` if the normal equations cannot be Cholesky-factorized.
    /// - Propagation, measurement, and measurement-noise errors raised while processing the arc.
    pub fn estimate(
        &self,
        initial_guess: D::StateType,
        arc: &TrackingDataArc,
    ) -> Result<BLSSolution<D::StateType>, ODError> {
        let num_measurements = arc.measurements.iter().filter(|m| !m.rejected).count();
        ensure!(
            num_measurements >= 2,
            TooFewMeasurementsSnafu {
                need: 2_usize,
                action: "BLSE"
            }
        );

        info!(
            "Using {:?} in the Batch Least Squares estimation with {num_measurements} measurements",
            self.solver
        );
        info!("Initial guess: {}", initial_guess.orbit());

        let mut devices = self.devices.clone();
        // Shared across all passes so that each unknown tracker is reported only once.
        let mut unknown_trackers = IndexSet::new();

        // Linearization about the current estimate: H^T W H, H^T W y and the weighted RMS.
        let (mut info_matrix, mut normal_vector, sum_sq, num_residuals) = self
            .accumulate_residuals(initial_guess, arc, true, &mut devices, &mut unknown_trackers)?;
        ensure!(
            num_residuals > 0,
            TooFewMeasurementsSnafu {
                need: 1_usize,
                action: "BLSE"
            }
        );
        let mut current_rms = Self::weighted_rms(sum_sq, num_residuals);

        let mut current_estimate = initial_guess;
        let mut converged = false;
        let mut corr_pos_km = f64::MAX;
        let mut lambda = self.lm_lambda_init;
        let mut iter: usize = 0;

        while iter < self.max_iterations {
            iter += 1;
            info!(
                "[{iter}/{}] Current estimate: {}",
                self.max_iterations,
                current_estimate.orbit()
            );

            let state_correction: OVector<f64, <D::StateType as State>::Size> = match self.solver {
                BLSSolver::NormalEquations => info_matrix
                    .cholesky()
                    .ok_or(ODError::SingularInformationMatrix)?
                    .solve(&normal_vector),
                BLSSolver::LevenbergMarquardt => {
                    let d_sq = self.lm_scaling(&info_matrix);
                    let augmented_matrix = info_matrix + d_sq * lambda;
                    match augmented_matrix.cholesky() {
                        Some(aug_chol) => aug_chol.solve(&normal_vector),
                        None => {
                            warn!("LM: Augmented matrix (H^TWH + lambda*D^2) singular with lambda={lambda}. Increasing lambda.");
                            lambda = (lambda * self.lm_lambda_increase * 10.0)
                                .min(self.lm_lambda_max);
                            // The linearization is unchanged: the next iteration only re-solves.
                            continue;
                        }
                    }
                }
            };

            // Evaluate (and re-linearize about) the corrected state before deciding to keep it.
            let trial_estimate = current_estimate + state_correction;
            let (trial_info, trial_normal, trial_sum_sq, trial_residuals) = self
                .accumulate_residuals(
                    trial_estimate,
                    arc,
                    true,
                    &mut devices,
                    &mut unknown_trackers,
                )?;
            let trial_rms = Self::weighted_rms(trial_sum_sq, trial_residuals);

            if self.solver == BLSSolver::LevenbergMarquardt {
                if trial_rms < current_rms {
                    lambda = (lambda / self.lm_lambda_decrease).max(self.lm_lambda_min);
                    debug!("LM: Cost decreased (RMS {current_rms} -> {trial_rms}). Decreasing lambda to {lambda}");
                } else {
                    // Reject: keep the current estimate and its linearization, damp harder.
                    lambda = (lambda * self.lm_lambda_increase).min(self.lm_lambda_max);
                    debug!("LM: Cost increased/stalled (RMS {current_rms} -> {trial_rms}). Increasing lambda to {lambda}");
                    info!(
                        "[{iter}/{}] LM: Step rejected, increasing lambda.",
                        self.max_iterations
                    );
                    corr_pos_km = f64::MAX;
                    continue;
                }
            }

            // Accept the trial: its linearization becomes the one used by the next iteration.
            current_estimate = trial_estimate;
            info_matrix = trial_info;
            normal_vector = trial_normal;
            current_rms = trial_rms;

            corr_pos_km = state_correction.fixed_rows::<3>(0).norm();
            let corr_vel_km_s = state_correction.fixed_rows::<3>(3).norm();
            info!(
                "[{iter}/{}] RMS: {current_rms:.3}; corrections: {:.3} m\t{:.3} m/s",
                self.max_iterations,
                corr_pos_km * 1e3,
                corr_vel_km_s * 1e3
            );

            if corr_pos_km < self.tolerance_pos_km {
                info!("Converged in {iter} iterations.");
                converged = true;
                break;
            }
        }

        if !converged {
            warn!("Not converged after {} iterations.", self.max_iterations);
        }

        // `info_matrix` is always the linearization about `current_estimate`.
        let current_covariance = Self::covariance_from_information(info_matrix);

        info!("Batch Least Squares estimation completed.");
        Ok(BLSSolution {
            estimated_state: current_estimate,
            covariance: current_covariance,
            num_iterations: iter,
            final_rms: current_rms,
            final_corr_pos_km: corr_pos_km,
            converged,
        })
    }

    /// Returns the weighted RMS of the residuals of `state` against the tracking data in `arc`,
    /// i.e. `sqrt(sum(w * y^2) / n)` over the `n` scalar residuals that could be formed.
    ///
    /// # Errors
    /// - `TooFewMeasurements` if there is no non-rejected measurement, or if no residual could be
    ///   formed (e.g. unknown trackers or no device produced a computed measurement).
    /// - Propagation, measurement, and measurement-noise errors raised while processing the arc.
    pub fn evaluate(
        &self,
        state: D::StateType,
        arc: &TrackingDataArc,
    ) -> Result<f64, ODError> {
        let num_measurements = arc.measurements.iter().filter(|m| !m.rejected).count();
        ensure!(
            num_measurements >= 1,
            TooFewMeasurementsSnafu {
                need: 1_usize,
                action: "BLSE Evaluate"
            }
        );

        let mut devices = self.devices.clone();
        let mut unknown_trackers = IndexSet::new();
        let (_, _, sum_sq, num_residuals) =
            self.accumulate_residuals(state, arc, false, &mut devices, &mut unknown_trackers)?;
        ensure!(
            num_residuals > 0,
            TooFewMeasurementsSnafu {
                need: 1_usize,
                action: "BLSE Evaluate"
            }
        );

        Ok(Self::weighted_rms(sum_sq, num_residuals))
    }

    /// Propagates `state` through the non-rejected measurements of `arc` and accumulates the
    /// weighted least-squares sums, one scalar residual per measurement type.
    ///
    /// Returns `(H^T W H, H^T W y, sum of w * y^2, number of residuals)`, where `y` is the
    /// observed-minus-computed residual and `w` the inverse measurement variance. The information
    /// matrix starts from the identity rather than zero: this adds unit a-priori information on
    /// every state component, which keeps it invertible even when some components are not
    /// observed by any measurement. When `with_partials` is false no sensitivities are computed
    /// and the first two entries keep their initial values.
    ///
    /// A measurement is skipped (and not counted) when the propagator cannot land within
    /// `epoch_precision` of its epoch (e.g. it precedes the reference epoch), when its tracker is
    /// not in `devices` (reported once per tracker via `unknown_trackers`), or when the device
    /// returns no computed measurement at that epoch.
    fn accumulate_residuals(
        &self,
        state: D::StateType,
        arc: &TrackingDataArc,
        with_partials: bool,
        devices: &mut BTreeMap<String, Trk>,
        unknown_trackers: &mut IndexSet<String>,
    ) -> Result<
        (
            StateMatrix<D>,
            OVector<f64, <D::StateType as State>::Size>,
            f64,
            usize,
        ),
        ODError,
    > {
        let mut info_matrix = StateMatrix::<D>::identity();
        let mut normal_vector = OVector::<f64, <D::StateType as State>::Size>::zeros();
        let mut sum_sq_weighted_residuals = 0.0;
        let mut num_residuals = 0_usize;

        let mut nominal = state.with_stm();
        let mut prop_inst = self.prop.with(nominal, self.almanac.clone()).quiet();
        // Accumulated state transition matrix from the reference epoch to `nominal`.
        let mut stm = StateMatrix::<D>::identity();

        for msr in &arc.measurements {
            if msr.rejected {
                continue;
            }
            let msr_epoch = msr.epoch;

            // Step forward, never past the measurement, until its epoch is reached.
            loop {
                let delta_t = msr_epoch - nominal.epoch();
                if delta_t <= Duration::ZERO {
                    break;
                }
                let next_step = delta_t.min(prop_inst.step_size).min(self.max_step);
                nominal = prop_inst.for_duration(next_step).context(ODPropSnafu)?;
                if with_partials {
                    let step_stm = nominal.stm().expect("STM unavailable");
                    stm = step_stm * stm;
                    // Restart the propagated STM so the next step yields only its own transition.
                    prop_inst.state.reset_stm();
                }
            }

            // Processed after propagation (not inside the step loop) so that measurements at the
            // current epoch are kept and each measurement is counted exactly once.
            if (nominal.epoch() - msr_epoch).abs() >= self.epoch_precision {
                continue;
            }

            let device = match devices.get_mut(&msr.tracker) {
                Some(device) => device,
                None => {
                    if unknown_trackers.insert(msr.tracker.clone()) {
                        error!(
                            "Tracker {} is not in the list of configured devices",
                            msr.tracker
                        );
                    }
                    continue;
                }
            };

            // The computed measurement does not depend on the measurement type: compute it once.
            let computed_meas = match device.measure_instantaneous(nominal, None, &self.almanac)? {
                Some(computed) => computed,
                None => {
                    debug!("Device {} does not expect measurement at epoch {msr_epoch}, skipping", msr.tracker);
                    continue;
                }
            };

            for msr_type in msr.data.keys().copied() {
                let mut msr_types = IndexSet::new();
                msr_types.insert(msr_type);

                let real_obs = msr.observation::<U1>(&msr_types)[0];
                ensure!(
                    real_obs.is_finite(),
                    InvalidMeasurementSnafu {
                        epoch: msr_epoch,
                        val: real_obs
                    }
                );
                let computed_obs = computed_meas.observation::<U1>(&msr_types)[0];
                let residual = real_obs - computed_obs;

                let r_matrix = device.measurement_covar_matrix::<U1>(&msr_types, msr_epoch)?;
                let r_variance = r_matrix[(0, 0)];
                ensure!(r_variance > 0.0, SingularNoiseRkSnafu);
                let weight = 1.0 / r_variance;

                if with_partials {
                    let h_tilde =
                        device.h_tilde::<U1>(msr, &msr_types, &nominal, &self.almanac)?;
                    // Map the sensitivity at the measurement epoch back to the reference epoch.
                    let h_matrix = h_tilde * stm;
                    info_matrix += h_matrix.transpose() * &h_matrix * weight;
                    normal_vector += h_matrix.transpose() * residual * weight;
                }

                sum_sq_weighted_residuals += weight * residual * residual;
                num_residuals += 1;
            }
        }

        Ok((
            info_matrix,
            normal_vector,
            sum_sq_weighted_residuals,
            num_residuals,
        ))
    }

    /// Weighted RMS `sqrt(sum_sq / count)`, or infinity when no residual was formed (so that an
    /// empty trial is never preferred by the Levenberg-Marquardt acceptance test).
    fn weighted_rms(sum_sq: f64, count: usize) -> f64 {
        if count == 0 {
            f64::INFINITY
        } else {
            (sum_sq / count as f64).sqrt()
        }
    }

    /// Levenberg-Marquardt scaling matrix `D^2`: the identity, or, when `lm_use_diag_scaling` is
    /// set, the diagonal of `info_matrix` over every state component, with non-positive entries
    /// replaced by a floor of 1e-6.
    fn lm_scaling(&self, info_matrix: &StateMatrix<D>) -> StateMatrix<D> {
        let mut d_sq = StateMatrix::<D>::identity();
        if self.lm_use_diag_scaling {
            for i in 0..info_matrix.nrows() {
                let h_ii = info_matrix[(i, i)];
                if h_ii > 0.0 {
                    d_sq[(i, i)] = h_ii;
                } else {
                    warn!(
                        "LM Scaling: Found non-positive diagonal element {} in H^TWH, using floor.",
                        h_ii
                    );
                    d_sq[(i, i)] = 1e-6;
                }
            }
        }
        d_sq
    }

    /// Inverts the information matrix through its UDU factorization to obtain the covariance.
    /// Falls back to the identity (with a warning) when the factorization or the inverse of `U`
    /// is unavailable.
    fn covariance_from_information(info_matrix: StateMatrix<D>) -> StateMatrix<D> {
        match info_matrix.udu() {
            Some(info_udu) => match info_udu.u.try_inverse() {
                Some(u_inv) => {
                    // info = U D U^T, hence info^-1 = U^-T D^-1 U^-1.
                    let d_inv_v = OVector::<f64, <D::StateType as State>::Size>::from_iterator(
                        info_udu.d.iter().map(|d_ii| 1.0 / d_ii),
                    );
                    let d_inv = OMatrix::from_diagonal(&d_inv_v);
                    let y = d_inv * u_inv;
                    u_inv.transpose() * y
                }
                None => {
                    warn!("Information matrix H^TWH is singular.");
                    StateMatrix::<D>::identity()
                }
            },
            None => {
                warn!("Information matrix H^TWH is singular.");
                StateMatrix::<D>::identity()
            }
        }
    }
}