use serde::{Deserialize, Serialize};
use serde_dhall::StaticType;
use crate::linalg::allocator::Allocator;
use crate::linalg::{DefaultAllocator, DimName, OVector, U3};
/// Reference-magnitude threshold for switching between relative and absolute error.
///
/// When the reference magnitude chosen by an error-control strategy exceeds this value (or its
/// square root, for the RSS step measure), the error is divided by that magnitude; otherwise
/// the absolute error is reported unchanged.
const REL_ERR_THRESH: f64 = 1e-1;
/// Strategy used by `ErrorControl::estimate` to reduce a per-component error estimate to a
/// single scalar error.
///
/// NOTE(review): presumably consumed by an adaptive step-size controller; confirm against the
/// propagator that uses it.
///
/// Keep the variant order stable: derived (de)serialization may depend on it.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default, StaticType)]
pub enum ErrorControl {
/// RSS error of components 0..3 and 3..6, each measured against the state magnitude, taking
/// the larger of the two; vectors with fewer than six components use a whole-vector measure.
RSSCartesianState,
/// The default. Like `RSSCartesianState`, but each group of three components is measured
/// against the step size (`candidate - cur_state`) instead of the state magnitude.
#[default]
RSSCartesianStep,
/// Norm of the error divided by `0.5 * |candidate + cur_state|` when that magnitude exceeds
/// `REL_ERR_THRESH`; otherwise the absolute error norm.
RSSState,
/// Norm of the error divided by `|candidate - cur_state|` when that magnitude exceeds
/// `sqrt(REL_ERR_THRESH)`; otherwise the absolute error norm.
RSSStep,
/// Largest per-component absolute error, where each component is divided by its change
/// (`candidate - cur_state`) when that change exceeds `REL_ERR_THRESH`.
LargestError,
/// Sum of absolute errors divided by the sum of absolute components of
/// `0.5 * (candidate + cur_state)` when that sum exceeds `REL_ERR_THRESH`.
/// NOTE(review): despite the name, this is a sum (L1 measure), not a maximum.
LargestState,
/// Sum of absolute errors divided by the sum of absolute components of
/// `candidate - cur_state` when that sum exceeds `REL_ERR_THRESH`.
/// NOTE(review): despite the name, this is a sum (L1 measure), not a maximum.
LargestStep,
}
impl ErrorControl {
    /// Reduces a per-component error estimate to a single scalar error using the strategy
    /// selected by `self`.
    ///
    /// # Arguments
    /// * `error_est` - per-component error estimate.
    /// * `candidate` - the candidate state (presumably the proposed next state).
    /// * `cur_state` - the current state.
    ///
    /// # Returns
    /// A scalar error. Each strategy divides by a reference magnitude only when that magnitude
    /// exceeds `REL_ERR_THRESH` (or its square root for the RSS step measure); otherwise it
    /// returns the absolute error.
    pub fn estimate<N: DimName>(
        self,
        error_est: &OVector<f64, N>,
        candidate: &OVector<f64, N>,
        cur_state: &OVector<f64, N>,
    ) -> f64
    where
        DefaultAllocator: Allocator<N>,
    {
        match self {
            ErrorControl::RSSCartesianState => {
                if N::dim() >= 6 {
                    Self::cartesian_max(
                        RSSState::estimate::<U3>,
                        error_est,
                        candidate,
                        cur_state,
                    )
                } else {
                    // Fall back to the whole-vector *state* measure, consistent with this
                    // variant's name (it previously fell back to the step measure).
                    RSSState::estimate(error_est, candidate, cur_state)
                }
            }
            ErrorControl::RSSCartesianStep => {
                if N::dim() >= 6 {
                    Self::cartesian_max(
                        RSSStep::estimate::<U3>,
                        error_est,
                        candidate,
                        cur_state,
                    )
                } else {
                    RSSStep::estimate(error_est, candidate, cur_state)
                }
            }
            ErrorControl::RSSState => RSSState::estimate(error_est, candidate, cur_state),
            ErrorControl::RSSStep => RSSStep::estimate(error_est, candidate, cur_state),
            ErrorControl::LargestError => {
                let state_delta = candidate - cur_state;
                error_est
                    .iter()
                    .zip(state_delta.iter())
                    .map(|(err_i, delta_i)| {
                        // Compare the *magnitude* of the change: a large negative change must be
                        // normalized just like a large positive one.
                        if delta_i.abs() > REL_ERR_THRESH {
                            (err_i / delta_i).abs()
                        } else {
                            err_i.abs()
                        }
                    })
                    // `f64::max` ignores NaN, matching a `>`-based running maximum from 0.0.
                    .fold(0.0, f64::max)
            }
            ErrorControl::LargestState => {
                // NOTE(review): despite the name, this is an L1 (sum of absolute values)
                // measure, not a maximum; kept as-is so existing results do not change.
                let sum_state = candidate + cur_state;
                let mag: f64 = sum_state.iter().map(|x| 0.5 * x.abs()).sum();
                let err: f64 = error_est.iter().map(|e| e.abs()).sum();
                if mag > REL_ERR_THRESH { err / mag } else { err }
            }
            ErrorControl::LargestStep => {
                // NOTE(review): L1 measure as above, not a maximum.
                let state_delta = candidate - cur_state;
                let mag: f64 = state_delta.iter().map(|x| x.abs()).sum();
                let err: f64 = error_est.iter().map(|e| e.abs()).sum();
                if mag > REL_ERR_THRESH { err / mag } else { err }
            }
        }
    }

    /// Applies `rss` separately to components 0..3 and 3..6 (presumably position and
    /// velocity) and returns the larger of the two errors.
    ///
    /// Callers must ensure the vectors have at least six components; the `fixed_rows`
    /// views would otherwise be out of bounds.
    fn cartesian_max<N: DimName>(
        rss: fn(&OVector<f64, U3>, &OVector<f64, U3>, &OVector<f64, U3>) -> f64,
        error_est: &OVector<f64, N>,
        candidate: &OVector<f64, N>,
        cur_state: &OVector<f64, N>,
    ) -> f64
    where
        DefaultAllocator: Allocator<N>,
    {
        let chunk_error = |first_row: usize| {
            rss(
                &error_est.fixed_rows::<3>(first_row).into_owned(),
                &candidate.fixed_rows::<3>(first_row).into_owned(),
                &cur_state.fixed_rows::<3>(first_row).into_owned(),
            )
        };
        chunk_error(0).max(chunk_error(3))
    }
}
/// Root-sum-square error measured against the size of the step just taken.
#[derive(Clone, Copy)]
// The acronym-style name mirrors the matching `ErrorControl` variant.
#[allow(clippy::upper_case_acronyms)]
struct RSSStep;
impl RSSStep {
    /// Returns `|error_est| / |candidate - cur_state|` when the step norm exceeds
    /// `sqrt(REL_ERR_THRESH)`, and the bare `|error_est|` otherwise.
    fn estimate<N: DimName>(
        error_est: &OVector<f64, N>,
        candidate: &OVector<f64, N>,
        cur_state: &OVector<f64, N>,
    ) -> f64
    where
        DefaultAllocator: Allocator<N>,
    {
        let threshold = REL_ERR_THRESH.sqrt();
        let step_norm = (candidate - cur_state).norm();
        // Dividing by exactly 1.0 leaves the absolute error bit-for-bit unchanged.
        let divisor = if step_norm > threshold { step_norm } else { 1.0 };
        error_est.norm() / divisor
    }
}
/// Root-sum-square error measured against the magnitude of the state itself.
#[derive(Clone, Copy)]
// The acronym-style name mirrors the matching `ErrorControl` variant.
#[allow(clippy::upper_case_acronyms)]
struct RSSState;
impl RSSState {
    /// Returns `|error_est| / m` with `m = 0.5 * |candidate + cur_state|` when `m` exceeds
    /// `REL_ERR_THRESH`, and the bare `|error_est|` otherwise.
    fn estimate<N: DimName>(
        error_est: &OVector<f64, N>,
        candidate: &OVector<f64, N>,
        cur_state: &OVector<f64, N>,
    ) -> f64
    where
        DefaultAllocator: Allocator<N>,
    {
        let err_norm = error_est.norm();
        let mean_norm = 0.5 * (candidate + cur_state).norm();
        // Dividing by exactly 1.0 leaves the absolute error bit-for-bit unchanged.
        let divisor = if mean_norm > REL_ERR_THRESH {
            mean_norm
        } else {
            1.0
        };
        err_norm / divisor
    }
}