use crate::astro::dynamics::StateCovariance;
use crate::coordinates::frames::GCRS;
use affn::matrix3::{FrameMatrix3, SymmetricFrameMatrix3};
use faer::linalg::solvers::Solve;
use faer::{Mat, Side};
use thiserror::Error;
/// Failures reported while assembling or solving weighted least-squares
/// normal equations.
#[derive(Debug, Error)]
pub enum WlsSolverError {
    /// The Cholesky factorisation of the normal matrix failed, or an
    /// observation was supplied with an unusable `sigma`. The payload is a
    /// human-readable detail message.
    #[error("normal equations not positive definite: {0}")]
    NotPositiveDefinite(String),

    /// A partial referenced a parameter index outside `0..n_params`.
    #[error("invalid parameter index {index} (n_params = {n_params})")]
    InvalidIndex { index: usize, n_params: usize },

    /// Any other row-assembly failure, such as a non-finite residual or
    /// partial derivative value.
    #[error("solver row assembly failed: {0}")]
    Other(String),
}
impl WlsSolverError {
pub fn other<S: Into<String>>(msg: S) -> Self {
Self::Other(msg.into())
}
}
/// Accumulator for the weighted least-squares normal equations `N·Δp = b`.
///
/// Observations are folded in one row at a time with
/// [`NormalEquations::add_row`]; the system is then solved with
/// [`NormalEquations::solve`].
#[derive(Debug, Clone)]
pub struct NormalEquations {
    /// Number of parameters; fixes the dimensions of `n_matrix` and `b`.
    n: usize,
    /// `n × n` normal matrix `Σ w·h·hᵀ`. `add_row` only writes the upper
    /// triangle (row <= column); `solve` mirrors it before factorising.
    pub n_matrix: Mat<f64>,
    /// `n × 1` right-hand side `Σ w·h·r`.
    pub b: Mat<f64>,
    /// Weighted sum of squared residuals `Σ w·r²` of the supplied rows.
    pub chi2: f64,
    /// Number of rows accumulated so far.
    pub n_obs: usize,
}
impl NormalEquations {
    /// Creates an empty accumulator for a problem with `n_params` unknowns.
    ///
    /// The normal matrix and right-hand side start at zero and no
    /// observations are recorded.
    pub fn new(n_params: usize) -> Self {
        Self {
            n: n_params,
            n_matrix: Mat::zeros(n_params, n_params),
            b: Mat::zeros(n_params, 1),
            chi2: 0.0,
            n_obs: 0,
        }
    }

    /// Number of parameters this system was created with.
    pub fn n_params(&self) -> usize {
        self.n
    }

    /// Accumulates one scalar observation into the normal equations.
    ///
    /// * `partials` – the non-zero entries of this observation's design-matrix
    ///   row `h`, as `(parameter index, value)` pairs. Order does not matter;
    ///   repeated indices behave as if their values were summed.
    /// * `residual` – the observation residual `r`.
    /// * `sigma` – the observation standard deviation; the row weight is
    ///   `w = 1 / sigma²`.
    ///
    /// Adds `w·h·hᵀ` to the upper triangle of `n_matrix`, `w·h·r` to `b`,
    /// `w·r²` to `chi2`, and increments `n_obs`.
    ///
    /// # Errors
    ///
    /// * [`WlsSolverError::NotPositiveDefinite`] if `sigma` is not finite, is
    ///   not strictly positive, or is so small that `1 / sigma²` overflows.
    /// * [`WlsSolverError::Other`] if `residual` or any partial value is not
    ///   finite.
    /// * [`WlsSolverError::InvalidIndex`] if any index is `>= n_params`.
    ///
    /// The whole row is validated before anything is accumulated, so on error
    /// the accumulator is left exactly as it was.
    pub fn add_row(
        &mut self,
        partials: &[(usize, f64)],
        residual: f64,
        sigma: f64,
    ) -> Result<(), WlsSolverError> {
        if !sigma.is_finite() || sigma <= 0.0 {
            return Err(WlsSolverError::NotPositiveDefinite(format!(
                "sigma must be finite and > 0 (got {sigma})"
            )));
        }
        if !residual.is_finite() {
            return Err(WlsSolverError::Other(format!(
                "residual must be finite (got {residual})"
            )));
        }
        let w = 1.0 / (sigma * sigma);
        // A tiny (but positive, finite) sigma can underflow sigma² to a value
        // whose reciprocal is infinite, which would poison the system.
        if !w.is_finite() {
            return Err(WlsSolverError::NotPositiveDefinite(format!(
                "1/sigma^2 must be finite (got sigma = {sigma})"
            )));
        }
        // Validate every entry up front: the accumulation loop below pairs
        // each entry with all others, so an unchecked later entry would be
        // indexed (and could panic or inject NaN) before being validated.
        for &(i, hi) in partials {
            if i >= self.n {
                return Err(WlsSolverError::InvalidIndex {
                    index: i,
                    n_params: self.n,
                });
            }
            if !hi.is_finite() {
                return Err(WlsSolverError::Other(format!(
                    "partial at index {i} must be finite (got {hi})"
                )));
            }
        }
        for &(i, hi) in partials {
            self.b[(i, 0)] += w * hi * residual;
            for &(j, hj) in partials {
                // Only the upper triangle is stored; `solve` mirrors it.
                if j >= i {
                    self.n_matrix[(i, j)] += w * hi * hj;
                }
            }
        }
        self.chi2 += w * residual * residual;
        self.n_obs += 1;
        Ok(())
    }

    /// Solves `N·Δp = b` by Cholesky (LLᵀ) factorisation, consuming the
    /// accumulator.
    ///
    /// Returns the parameter update `Δp`, the inverse normal matrix `N⁻¹`
    /// (reported as the covariance), and the accumulated `chi2`/`n_obs`.
    ///
    /// # Errors
    ///
    /// [`WlsSolverError::NotPositiveDefinite`] if the factorisation fails,
    /// e.g. because some parameter is never constrained by any row.
    pub fn solve(self) -> Result<WlsResult, WlsSolverError> {
        let n = self.n;
        // `add_row` fills only the upper triangle (row <= column); copy it
        // into the lower triangle so the factorisation sees the full
        // symmetric matrix. `self` is owned, so the matrix is moved, not cloned.
        let mut a = self.n_matrix;
        for i in 0..n {
            for j in i + 1..n {
                let upper = a[(i, j)];
                a[(j, i)] = upper;
            }
        }
        let llt = a
            .as_ref()
            .llt(Side::Lower)
            .map_err(|e| WlsSolverError::NotPositiveDefinite(format!("{e:?}")))?;
        let dp = llt.solve(&self.b);
        // Solving against the identity yields N⁻¹ column by column.
        let mut identity = Mat::<f64>::zeros(n, n);
        for i in 0..n {
            identity[(i, i)] = 1.0;
        }
        let cov = llt.solve(&identity);

        let update: Vec<f64> = (0..n).map(|i| dp[(i, 0)]).collect();
        let covariance: Vec<Vec<f64>> = (0..n)
            .map(|i| (0..n).map(|j| cov[(i, j)]).collect::<Vec<f64>>())
            .collect();
        Ok(WlsResult {
            update,
            covariance,
            chi2: self.chi2,
            n_obs: self.n_obs,
            n_params: n,
        })
    }
}
/// Outcome of [`NormalEquations::solve`].
#[derive(Debug, Clone)]
pub struct WlsResult {
    /// Parameter correction `Δp = N⁻¹·b`, one entry per parameter.
    pub update: Vec<f64>,
    /// Inverse normal matrix `N⁻¹`, indexed as `covariance[row][col]`;
    /// `n_params × n_params` when produced by `solve`.
    pub covariance: Vec<Vec<f64>>,
    /// Weighted sum of squared residuals as accumulated by `add_row`, i.e. of
    /// the residuals supplied, not recomputed after applying `update`.
    pub chi2: f64,
    /// Number of observation rows used.
    pub n_obs: usize,
    /// Number of estimated parameters.
    pub n_params: usize,
}
impl WlsResult {
    /// `chi2` divided by the degrees of freedom `n_obs - n_params`.
    ///
    /// When there are no more observations than parameters the degrees of
    /// freedom are clamped to 1, so the raw `chi2` is returned rather than
    /// dividing by zero.
    pub fn reduced_chi2(&self) -> f64 {
        let dof = self.n_obs.saturating_sub(self.n_params).max(1) as f64;
        self.chi2 / dof
    }

    /// Converts a six-parameter covariance into a GCRS [`StateCovariance`].
    ///
    /// Parameters 0..3 form the `rr` block, 3..6 the `vv` block, and the
    /// 0..3 × 3..6 cross terms the `rv` block (presumably position and
    /// velocity — the parameter ordering is assumed here, not checked).
    ///
    /// Returns `None` unless `n_params == 6` and `covariance` is actually
    /// 6 × 6. The fields are public, so a hand-built result may be
    /// inconsistent; the shape check prevents an indexing panic.
    pub fn to_state_covariance(&self) -> Option<StateCovariance<GCRS>> {
        let c = &self.covariance;
        if self.n_params != 6 || c.len() != 6 || c.iter().any(|row| row.len() != 6) {
            return None;
        }
        // NOTE(review): full rows are passed to `from_upper`; presumably only
        // the upper triangle is read — confirm against the affn API.
        let rr = SymmetricFrameMatrix3::<GCRS>::from_upper([
            [c[0][0], c[0][1], c[0][2]],
            [c[1][0], c[1][1], c[1][2]],
            [c[2][0], c[2][1], c[2][2]],
        ]);
        let rv = FrameMatrix3::<GCRS>::from_array([
            [c[0][3], c[0][4], c[0][5]],
            [c[1][3], c[1][4], c[1][5]],
            [c[2][3], c[2][4], c[2][5]],
        ]);
        let vv = SymmetricFrameMatrix3::<GCRS>::from_upper([
            [c[3][3], c[3][4], c[3][5]],
            [c[4][3], c[4][4], c[4][5]],
            [c[5][3], c[5][4], c[5][5]],
        ]);
        Some(StateCovariance::<GCRS>::from_blocks(rr, rv, vv))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two floats agree to well within double-precision round-off
    /// for the O(1) values used in these tests.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-12,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn solves_simple_2d_system() {
        let mut ne = NormalEquations::new(2);
        for (x, y) in [(0.0, 1.0), (1.0, 2.0), (2.0, 3.0)] {
            ne.add_row(&[(0, 1.0), (1, x)], y, 1.0).unwrap();
        }
        let r = ne.solve().unwrap();
        assert!((r.update[0] - 1.0).abs() < 1e-12);
        assert!((r.update[1] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn reports_covariance_and_chi2_for_simple_2d_system() {
        let mut ne = NormalEquations::new(2);
        for (x, y) in [(0.0, 1.0), (1.0, 2.0), (2.0, 3.0)] {
            ne.add_row(&[(0, 1.0), (1, x)], y, 1.0).unwrap();
        }
        let r = ne.solve().unwrap();
        // N = [[3, 3], [3, 5]], so N⁻¹ = [[5/6, -1/2], [-1/2, 1/2]].
        assert_close(r.covariance[0][0], 5.0 / 6.0);
        assert_close(r.covariance[0][1], -0.5);
        assert_close(r.covariance[1][0], -0.5);
        assert_close(r.covariance[1][1], 0.5);
        // chi2 is of the supplied residuals: 1² + 2² + 3².
        assert_close(r.chi2, 14.0);
        assert_eq!(r.n_obs, 3);
        assert_eq!(r.n_params, 2);
        assert_close(r.reduced_chi2(), 14.0);
    }

    #[test]
    fn applies_inverse_variance_weights() {
        let mut ne = NormalEquations::new(1);
        ne.add_row(&[(0, 1.0)], 4.0, 2.0).unwrap();
        ne.add_row(&[(0, 1.0)], 4.0, 2.0).unwrap();
        let r = ne.solve().unwrap();
        // w = 1/4 per row: N = 0.5, b = 2, chi2 = 2 · 0.25 · 16.
        assert_close(r.update[0], 4.0);
        assert_close(r.covariance[0][0], 2.0);
        assert_close(r.chi2, 8.0);
    }

    #[test]
    fn partial_order_does_not_matter() {
        let mut ne = NormalEquations::new(2);
        for (x, y) in [(0.0, 1.0), (1.0, 2.0), (2.0, 3.0)] {
            ne.add_row(&[(1, x), (0, 1.0)], y, 1.0).unwrap();
        }
        let r = ne.solve().unwrap();
        assert_close(r.update[0], 1.0);
        assert_close(r.update[1], 1.0);
    }

    #[test]
    fn repeated_indices_accumulate() {
        let mut ne = NormalEquations::new(1);
        ne.add_row(&[(0, 1.0), (0, 2.0)], 3.0, 1.0).unwrap();
        let r = ne.solve().unwrap();
        // Equivalent to a single partial of 3.0: N = 9, b = 9.
        assert_close(r.update[0], 1.0);
        assert_close(r.covariance[0][0], 1.0 / 9.0);
    }

    #[test]
    fn flags_singular_system() {
        let mut ne = NormalEquations::new(2);
        ne.add_row(&[(0, 1.0)], 1.0, 1.0).unwrap();
        ne.add_row(&[(0, 1.0)], 1.0, 1.0).unwrap();
        assert!(matches!(
            ne.solve(),
            Err(WlsSolverError::NotPositiveDefinite(_))
        ));
    }

    #[test]
    fn rejects_out_of_range_index() {
        let mut ne = NormalEquations::new(2);
        assert!(matches!(
            ne.add_row(&[(2, 1.0)], 1.0, 1.0),
            Err(WlsSolverError::InvalidIndex {
                index: 2,
                n_params: 2
            })
        ));
        assert_eq!(ne.n_obs, 0);
    }

    #[test]
    fn rejects_nan_sigma() {
        let mut ne = NormalEquations::new(1);
        assert!(matches!(
            ne.add_row(&[(0, 1.0)], 1.0, f64::NAN),
            Err(WlsSolverError::NotPositiveDefinite(_))
        ));
    }

    #[test]
    fn rejects_inf_sigma() {
        let mut ne = NormalEquations::new(1);
        assert!(matches!(
            ne.add_row(&[(0, 1.0)], 1.0, f64::INFINITY),
            Err(WlsSolverError::NotPositiveDefinite(_))
        ));
    }

    #[test]
    fn rejects_zero_sigma() {
        let mut ne = NormalEquations::new(1);
        assert!(matches!(
            ne.add_row(&[(0, 1.0)], 1.0, 0.0),
            Err(WlsSolverError::NotPositiveDefinite(_))
        ));
    }

    #[test]
    fn rejects_nan_residual() {
        let mut ne = NormalEquations::new(1);
        assert!(matches!(
            ne.add_row(&[(0, 1.0)], f64::NAN, 1.0),
            Err(WlsSolverError::Other(_))
        ));
    }

    #[test]
    fn rejects_nan_partial() {
        let mut ne = NormalEquations::new(1);
        assert!(matches!(
            ne.add_row(&[(0, f64::NAN)], 1.0, 1.0),
            Err(WlsSolverError::Other(_))
        ));
    }

    #[test]
    fn reduced_chi2_clamps_degrees_of_freedom() {
        let r = WlsResult {
            update: vec![0.0, 0.0],
            covariance: vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            chi2: 3.0,
            n_obs: 1,
            n_params: 2,
        };
        assert_close(r.reduced_chi2(), 3.0);
    }

    #[test]
    fn state_covariance_requires_six_parameters() {
        let mut ne = NormalEquations::new(1);
        ne.add_row(&[(0, 1.0)], 1.0, 1.0).unwrap();
        let r = ne.solve().unwrap();
        assert!(r.to_state_covariance().is_none());
    }
}