use crate::astro::dynamics::StateCovariance;
use crate::astro::dynamics::{OrbitState, Velocity};
use crate::coordinates::frames::GCRS;
use affn::Displacement;
use faer::Mat;
use qtty::dynamics::KmPerSeconds;
use qtty::length::Kilometers;
use qtty::unit::Kilometer;
use thiserror::Error;
/// Errors reported by the measurement-update step of [`OrbitEkf`].
#[derive(Debug, Error)]
pub enum EkfError {
/// The scalar innovation covariance could not be used as a divisor.
/// The payload is the value that was computed for it.
#[error("singular innovation covariance: {0}")]
Singular(f64),
/// One of the caller-supplied update inputs was rejected; the payload is a
/// human-readable description of which input and why.
#[error("invalid measurement update input: {0}")]
InvalidInput(String),
}
/// Diagnostics returned by a successful [`OrbitEkf::update_scalar`] call.
#[derive(Debug, Clone, Copy)]
pub struct InnovationRecord {
/// The innovation value that was passed to the update, echoed back.
pub innovation: f64,
/// Innovation covariance used for the update: `r` plus `h · P · hᵀ`.
pub variance: f64,
/// `innovation² / variance` (normalized innovation squared).
pub nis: f64,
}
/// Filter holding an orbit state estimate together with its 6×6 covariance,
/// both expressed in the `GCRS` frame.
#[derive(Debug, Clone)]
pub struct OrbitEkf {
// Current state estimate (epoch, position, velocity).
state: OrbitState,
// Covariance associated with `state`; read and written as a 6×6 row-major array.
cov: StateCovariance<GCRS>,
}
impl OrbitEkf {
    /// Creates a filter from an initial state estimate and its covariance.
    pub fn new(state: OrbitState, cov: StateCovariance<GCRS>) -> Self {
        Self { state, cov }
    }

    /// Creates a filter whose initial covariance is built from per-axis
    /// standard deviations.
    ///
    /// `sigma_pos` values are wrapped as [`Kilometers`] and `sigma_vel` values
    /// as [`KmPerSeconds`] before being handed to
    /// `StateCovariance::diagonal_from_sigmas`. No validation is performed on
    /// the sigmas here.
    pub fn from_stddevs(state: OrbitState, sigma_pos: [f64; 3], sigma_vel: [f64; 3]) -> Self {
        let cov = StateCovariance::<GCRS>::diagonal_from_sigmas(
            sigma_pos.map(Kilometers::new),
            sigma_vel.map(KmPerSeconds::new),
        );
        Self { state, cov }
    }

    /// Returns the current state estimate.
    pub fn state(&self) -> &OrbitState {
        &self.state
    }

    /// Returns the covariance of the current state estimate.
    pub fn covariance(&self) -> &StateCovariance<GCRS> {
        &self.cov
    }

    /// Time-update step.
    ///
    /// Replaces the stored state with `state_pred` (the caller is responsible
    /// for propagating the state) and replaces the covariance `P` with
    /// `phi · P · phiᵀ`, plus `q` when one is supplied. The result is
    /// symmetrized before being stored, matching what
    /// [`update_scalar`](Self::update_scalar) does, so floating-point
    /// asymmetry cannot accumulate across predict steps.
    pub fn predict(
        &mut self,
        state_pred: OrbitState,
        phi: [[f64; 6]; 6],
        q: Option<StateCovariance<GCRS>>,
    ) {
        let phi_m: Mat<f64> = Mat::from_fn(6, 6, |i, j| phi[i][j]);
        let p_m = Self::cov_to_mat(&self.cov);

        let phi_p = &phi_m * &p_m;
        let mut p_next = &phi_p * phi_m.transpose();
        if let Some(q_cov) = q {
            p_next += Self::cov_to_mat(&q_cov);
        }
        Self::symmetrize(&mut p_next);

        self.state = state_pred;
        self.cov = Self::cov_from_mat(&p_next);
    }

    /// Measurement-update step for a single scalar measurement.
    ///
    /// * `h` - the 1×6 observation row (partials of the measurement with
    ///   respect to the six state components).
    /// * `innovation` - measurement residual to apply.
    /// * `r` - measurement noise variance.
    ///
    /// The gain is `K = P·hᵀ / s` with `s = h·P·hᵀ + r`. The first three gain
    /// components correct the position (as a kilometre displacement) and the
    /// last three correct the velocity. The covariance is updated with the
    /// Joseph form `(I − K·h)·P·(I − K·h)ᵀ + r·K·Kᵀ` and then symmetrized.
    ///
    /// # Errors
    ///
    /// * [`EkfError::InvalidInput`] if `r` is negative or non-finite, or if
    ///   `innovation` or any element of `h` is non-finite.
    /// * [`EkfError::Singular`] if `s` is not a finite, strictly positive
    ///   number, or if dividing by it yields a non-finite gain.
    ///
    /// Every error is returned before the filter is modified, so a failed
    /// update leaves the state and covariance untouched.
    pub fn update_scalar(
        &mut self,
        h: [f64; 6],
        innovation: f64,
        r: f64,
    ) -> Result<InnovationRecord, EkfError> {
        use faer::Scale;

        if !r.is_finite() || r < 0.0 {
            return Err(EkfError::InvalidInput(format!(
                "measurement noise variance r must be finite and ≥ 0 (got {r})"
            )));
        }
        if !innovation.is_finite() {
            return Err(EkfError::InvalidInput(format!(
                "innovation must be finite (got {innovation})"
            )));
        }
        if let Some((i, hi)) = h
            .iter()
            .copied()
            .enumerate()
            .find(|(_, hi)| !hi.is_finite())
        {
            return Err(EkfError::InvalidInput(format!(
                "observation matrix element h[{i}] must be finite (got {hi})"
            )));
        }

        let p_m = Self::cov_to_mat(&self.cov);
        let h_col: Mat<f64> = Mat::from_fn(6, 1, |i, _| h[i]);
        let ph = &p_m * &h_col;

        // s = r + h · (P hᵀ), accumulated in index order.
        let s = h
            .iter()
            .enumerate()
            .fold(r, |acc, (i, &hi)| acc + hi * ph[(i, 0)]);

        // NaN and ±inf must be rejected as well as s <= 0: an infinite `s`
        // would otherwise produce an `inf / inf = NaN` gain and silently
        // poison both the state and the covariance.
        if !s.is_finite() || s <= 0.0 {
            return Err(EkfError::Singular(s));
        }

        // A finite but tiny `s` can still overflow the gain; catch that before
        // anything is written back to `self`.
        let gain: [f64; 6] = std::array::from_fn(|i| ph[(i, 0)] / s);
        if gain.iter().any(|g| !g.is_finite()) {
            return Err(EkfError::Singular(s));
        }
        let k: Mat<f64> = Mat::from_fn(6, 1, |i, _| gain[i]);

        // State correction: x ← x + K · innovation.
        let pos_delta = Displacement::<GCRS, Kilometer>::new(
            gain[0] * innovation,
            gain[1] * innovation,
            gain[2] * innovation,
        );
        let vel_delta = Velocity::<GCRS>::new(
            gain[3] * innovation,
            gain[4] * innovation,
            gain[5] * innovation,
        );
        let new_pos = self.state.position + pos_delta;
        let new_vel = self.state.velocity + vel_delta;
        self.state = OrbitState::new(self.state.epoch, new_pos, new_vel);

        // Joseph-form covariance update.
        let h_row: Mat<f64> = Mat::from_fn(1, 6, |_, j| h[j]);
        let mut i_kh: Mat<f64> = Mat::identity(6, 6);
        i_kh -= &k * &h_row;
        let mut p_next = &i_kh * &p_m * i_kh.transpose() + Scale(r) * (&k * k.transpose());
        Self::symmetrize(&mut p_next);
        self.cov = Self::cov_from_mat(&p_next);

        Ok(InnovationRecord {
            innovation,
            variance: s,
            nis: innovation * innovation / s,
        })
    }

    /// Copies a covariance into a dense 6×6 matrix.
    fn cov_to_mat(cov: &StateCovariance<GCRS>) -> Mat<f64> {
        let rows = cov.to_row_major();
        Mat::from_fn(6, 6, |i, j| rows[i][j])
    }

    /// Builds a covariance from the entries of a dense 6×6 matrix.
    fn cov_from_mat(m: &Mat<f64>) -> StateCovariance<GCRS> {
        StateCovariance::from_row_major(std::array::from_fn(|i| {
            std::array::from_fn(|j| m[(i, j)])
        }))
    }

    /// Replaces each off-diagonal pair of a 6×6 matrix with its average.
    fn symmetrize(m: &mut Mat<f64>) {
        for i in 0..6 {
            for j in (i + 1)..6 {
                let avg = 0.5 * (m[(i, j)] + m[(j, i)]);
                m[(i, j)] = avg;
                m[(j, i)] = avg;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::astro::dynamics::Position;
    use crate::time::JulianDate;

    /// Fixture state shared by the tests: position (7000, 0, 0) and
    /// velocity (0, 7.5, 0) at Julian date 2451545.0.
    fn make_orbit_state() -> OrbitState {
        let position = Position::<GCRS>::new(7000.0, 0.0, 0.0);
        let velocity = Velocity::<GCRS>::new(0.0, 7.5, 0.0);
        let epoch = JulianDate::new(2_451_545.0);
        OrbitState::new(epoch.to_j2000s(), position, velocity)
    }

    /// A measurement of the x position must shrink the x variance and, with
    /// a positive innovation, move the x estimate upwards.
    #[test]
    fn orbit_ekf_update_reduces_variance() {
        let mut filter = OrbitEkf::from_stddevs(make_orbit_state(), [1.0; 3], [1e-3; 3]);
        let var_x_before = filter.covariance().to_row_major()[0][0];

        let observe_x = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let record = filter.update_scalar(observe_x, 0.5, 0.01).unwrap();

        let var_x_after = filter.covariance().to_row_major()[0][0];
        assert!(var_x_after < var_x_before, "pos variance must shrink");
        assert!(record.nis > 0.0);
        assert!(filter.state().position.x().value() > 7000.0);
    }

    /// With an identity transition the diagonal is preserved; adding process
    /// noise must then inflate both position and velocity variances.
    #[test]
    fn orbit_ekf_predict_propagates_covariance() {
        let s0 = make_orbit_state();
        let mut filter = OrbitEkf::from_stddevs(s0, [0.1; 3], [1e-4; 3]);
        let initial = filter.covariance().to_row_major();

        let identity: [[f64; 6]; 6] =
            std::array::from_fn(|i| std::array::from_fn(|j| if i == j { 1.0 } else { 0.0 }));

        filter.predict(s0, identity, None);
        let without_q = filter.covariance().to_row_major();
        for d in 0..6 {
            let drift = (without_q[d][d] - initial[d][d]).abs();
            assert!(
                drift < 1e-12,
                "identity phi should leave diagonal unchanged"
            );
        }

        let q_pos = [0.01_f64; 3].map(Kilometers::new);
        let q_vel = [1e-5_f64; 3].map(KmPerSeconds::new);
        let process_noise = StateCovariance::<GCRS>::diagonal_from_sigmas(q_pos, q_vel);

        filter.predict(s0, identity, Some(process_noise));
        let with_q = filter.covariance().to_row_major();
        assert!(with_q[0][0] > without_q[0][0], "Q must inflate pos variance");
        assert!(with_q[3][3] > without_q[3][3], "Q must inflate vel variance");
    }
}