use crate::astro::math::linear::{
invert_flat_first_tie_into, normal_equations_weighted, solve_flat_normal_first_tie_into,
solve_flat_normal_square_root_into, solve_linear_last_tie, FlatCholeskySolveScratch,
FlatLinearScratch, FlatNormalSolveScratch,
};
use crate::estimation::recipe::NormalRecipe;
/// Covariance model of a block of mutually correlated observations.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum CovarianceBlock {
/// Block whose covariance, as built by `inverse_into`, has
/// `sd_variance(i) + ref_variance(0)` on the diagonal and `ref_variance(0)`
/// in every off-diagonal entry.
SharedReferenceDoubleDifference,
}
/// Reusable work buffers for [`CovarianceBlock::fold_block_into`].
///
/// Keeping one instance alive across calls lets repeated folds reuse the
/// vectors' capacity instead of allocating for every block.
#[derive(Debug, Default)]
pub(crate) struct BlockFoldScratch {
    /// Inverse block covariance `R^-1`, row-major `m x m`.
    r_inv: Vec<f64>,
    /// Dense covariance workspace handed to the inverter; the equal-variance
    /// closed form leaves it untouched.
    cov: Vec<f64>,
    /// `R^-1 * y`, length `m`.
    r_inv_y: Vec<f64>,
    /// `R^-1 * H`, row-major `m x n`, computed once per fold.
    r_inv_h: Vec<f64>,
    /// Scratch passed through to `invert_flat_first_tie_into`.
    invert: FlatLinearScratch,
}

/// Read-only view of one block of correlated observations.
pub(crate) trait CorrelatedBlock {
    /// Number of observations (rows) in the block.
    fn len(&self) -> usize;
    /// Per-observation variance added to diagonal entry `k` of the covariance.
    fn sd_variance(&self, k: usize) -> f64;
    /// Variance of the shared term; index `0` supplies every off-diagonal entry.
    fn ref_variance(&self, k: usize) -> f64;
    /// Design-matrix row of observation `k`. It must hold at least as many
    /// entries as the `eta` vector the block is folded into.
    fn design(&self, k: usize) -> &[f64];
    /// Value of observation `k` that is folded into the right-hand side.
    fn value(&self, k: usize) -> f64;
}

impl CovarianceBlock {
    /// Writes the row-major inverse of the `m x m` block covariance into `r_inv`.
    ///
    /// The covariance `R` has `sd_variance(i) + ref_variance(0)` on the diagonal
    /// and `ref_variance(0)` everywhere else.
    ///
    /// * `m == 0`: `r_inv` is cleared and `Some(())` is returned.
    /// * Every `sd_variance(k)` and `ref_variance(k)` compares exactly equal to
    ///   `sd_variance(0)`: the inverse is written in closed form; `cov` and
    ///   `invert` are not touched.
    /// * Otherwise `cov` is filled with `R` and handed to
    ///   `invert_flat_first_tie_into`, whose result is returned unchanged.
    ///
    /// NOTE(review): the closed form divides by `sd_variance(0)` without a zero
    /// check, so an all-zero block yields non-finite entries rather than `None`
    /// -- confirm callers never pass zero variances.
    pub(crate) fn inverse_into(
        &self,
        m: usize,
        sd_variance: impl Fn(usize) -> f64,
        ref_variance: impl Fn(usize) -> f64,
        r_inv: &mut Vec<f64>,
        cov: &mut Vec<f64>,
        invert: &mut FlatLinearScratch,
    ) -> Option<()> {
        if m == 0 {
            r_inv.clear();
            return Some(());
        }
        let first = sd_variance(0);
        // Exact comparison is intentional: the closed form is only valid when
        // all variances are the very same value.
        let constant = (0..m).all(|k| sd_variance(k) == first && ref_variance(k) == first);
        if constant {
            // R = v * (I + 1*1^T)  =>  R^-1 = (1/v) * (I - 1*1^T / (m + 1)).
            let mf = m as f64;
            let diagonal_scale = 1.0 / first * (1.0 - 1.0 / (mf + 1.0));
            let off_diagonal = -1.0 / (first * (mf + 1.0));
            r_inv.clear();
            r_inv.resize(m * m, off_diagonal);
            // Stepping by m + 1 visits exactly the m diagonal cells.
            for diagonal in r_inv.iter_mut().step_by(m + 1) {
                *diagonal = diagonal_scale;
            }
            Some(())
        } else {
            let ref_v = ref_variance(0);
            cov.clear();
            cov.resize(m * m, ref_v);
            for (i, diagonal) in cov.iter_mut().step_by(m + 1).enumerate() {
                *diagonal = sd_variance(i) + ref_v;
            }
            invert_flat_first_tie_into(cov, m, r_inv, invert)
        }
    }

    /// Accumulates one correlated block into information-form normal equations.
    ///
    /// With `H` the block's design rows, `y` its values, `R` its covariance and
    /// `n = eta.len()`, this performs `lambda += H^T R^-1 H` (row-major `n x n`)
    /// and `eta += H^T R^-1 y`.
    ///
    /// Returns `Some(())` without touching anything when the block is empty, and
    /// `None` (leaving `lambda` and `eta` unmodified) when the covariance cannot
    /// be inverted.
    ///
    /// # Panics
    ///
    /// Panics if `lambda` holds fewer than `n * n` entries or a design row holds
    /// fewer than `n` entries.
    // Index loops are kept on purpose: they mirror the matrix subscripts and
    // fix the floating-point summation order.
    #[allow(clippy::needless_range_loop)]
    pub(crate) fn fold_block_into(
        &self,
        block: &impl CorrelatedBlock,
        lambda: &mut [f64],
        eta: &mut [f64],
        scratch: &mut BlockFoldScratch,
    ) -> Option<()> {
        let m = block.len();
        if m == 0 {
            return Some(());
        }
        let n = eta.len();
        self.inverse_into(
            m,
            |k| block.sd_variance(k),
            |k| block.ref_variance(k),
            &mut scratch.r_inv,
            &mut scratch.cov,
            &mut scratch.invert,
        )?;
        debug_assert!(
            lambda.len() >= n * n,
            "lambda must hold a row-major n x n matrix"
        );
        let r_inv = &scratch.r_inv;

        // R^-1 * y.
        scratch.r_inv_y.resize(m, 0.0);
        for (a, rinvy_a) in scratch.r_inv_y.iter_mut().enumerate() {
            let mut s = 0.0;
            for b in 0..m {
                s += r_inv[a * m + b] * block.value(b);
            }
            *rinvy_a = s;
        }

        // R^-1 * H, computed once. Each entry is independent of the output row
        // `i`, so hoisting it out of the `lambda` loops cuts the work from
        // O(n^2 m^2) to O(n m^2 + n^2 m) without changing any summation order.
        scratch.r_inv_h.resize(m * n, 0.0);
        for a in 0..m {
            for j in 0..n {
                let mut row_sum = 0.0;
                for b in 0..m {
                    row_sum += r_inv[a * m + b] * block.design(b)[j];
                }
                scratch.r_inv_h[a * n + j] = row_sum;
            }
        }
        let r_inv_h = &scratch.r_inv_h;

        // lambda += H^T * (R^-1 * H).
        for i in 0..n {
            let row = i * n;
            for j in 0..n {
                let mut acc = 0.0;
                for a in 0..m {
                    acc += block.design(a)[i] * r_inv_h[a * n + j];
                }
                lambda[row + j] += acc;
            }
        }

        // eta += H^T * (R^-1 * y).
        for (i, e) in eta.iter_mut().enumerate() {
            let mut acc = 0.0;
            for a in 0..m {
                acc += block.design(a)[i] * scratch.r_inv_y[a];
            }
            *e += acc;
        }
        Some(())
    }
}
/// Front end for assembling and solving normal equations under one fixed
/// [`NormalRecipe`].
#[derive(Clone, Copy, Debug)]
pub(crate) struct NormalAssembler {
// Recipe chosen at construction; each solver method debug-asserts that it
// is being called under the recipe it belongs to.
recipe: NormalRecipe,
}
impl NormalAssembler {
    /// Binds a new assembler to `recipe`.
    pub(crate) const fn new(recipe: NormalRecipe) -> Self {
        Self { recipe }
    }

    /// Assembles weighted normal equations from `rows` and solves them with
    /// `solve_linear_last_tie`.
    ///
    /// Returns `None` when either assembly or the solve fails. Debug builds
    /// assert that the recipe is `PppDenseLastTie`.
    pub(crate) fn solve_dense_last_tie<'a, I>(&self, rows: I, n: usize) -> Option<Vec<f64>>
    where
        I: IntoIterator<Item = (&'a [f64], f64, f64)>,
    {
        debug_assert_eq!(self.recipe, NormalRecipe::PppDenseLastTie);
        normal_equations_weighted(rows, n)
            .and_then(|(normal_matrix, rhs)| solve_linear_last_tie(normal_matrix, rhs))
    }

    /// Assembles weighted normal equations from `rows` without solving them.
    ///
    /// Returns whatever `normal_equations_weighted` produces. Debug builds
    /// assert that the recipe is `PppDenseLastTie`.
    pub(crate) fn assemble_dense<'a, I>(
        &self,
        rows: I,
        n: usize,
    ) -> Option<(Vec<Vec<f64>>, Vec<f64>)>
    where
        I: IntoIterator<Item = (&'a [f64], f64, f64)>,
    {
        debug_assert_eq!(self.recipe, NormalRecipe::PppDenseLastTie);
        normal_equations_weighted(rows, n)
    }

    /// Assembles weighted normal equations from `rows`, flattens the matrix to
    /// row-major order and solves with `solve_flat_normal_square_root_into`
    /// using a temporary scratch; the solution is returned as an owned vector.
    ///
    /// Debug builds assert that the recipe is `CanonicalSquareRoot`.
    pub(crate) fn solve_dense_square_root<'a, I>(&self, rows: I, n: usize) -> Option<Vec<f64>>
    where
        I: IntoIterator<Item = (&'a [f64], f64, f64)>,
    {
        debug_assert_eq!(self.recipe, NormalRecipe::CanonicalSquareRoot);
        let (normal_rows, rhs) = normal_equations_weighted(rows, n)?;
        // Row-major flattening of the nested matrix.
        let flat: Vec<f64> = normal_rows.concat();
        let mut workspace = FlatCholeskySolveScratch::default();
        let solution = solve_flat_normal_square_root_into(&flat, &rhs, &mut workspace)?;
        Some(solution.to_vec())
    }

    /// Solves the flat system `lambda * x = eta` with
    /// `solve_flat_normal_first_tie_into`; the solution borrows from `scratch`.
    ///
    /// Debug builds assert that the recipe is `RtkDoubleDifferenceBlockFirstTie`.
    pub(crate) fn solve_flat_first_tie<'s>(
        &self,
        lambda: &[f64],
        eta: &[f64],
        scratch: &'s mut FlatNormalSolveScratch,
    ) -> Option<&'s [f64]> {
        debug_assert_eq!(
            self.recipe,
            NormalRecipe::RtkDoubleDifferenceBlockFirstTie
        );
        solve_flat_normal_first_tie_into(lambda, eta, scratch)
    }

    /// Solves the flat system `lambda * x = eta` with
    /// `solve_flat_normal_square_root_into`; the solution borrows from `scratch`.
    ///
    /// Debug builds assert that the recipe is `CanonicalSquareRoot`.
    pub(crate) fn solve_square_root<'s>(
        &self,
        lambda: &[f64],
        eta: &[f64],
        scratch: &'s mut FlatCholeskySolveScratch,
    ) -> Option<&'s [f64]> {
        debug_assert_eq!(self.recipe, NormalRecipe::CanonicalSquareRoot);
        solve_flat_normal_square_root_into(lambda, eta, scratch)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unequal variances take the dense path; the result times the covariance
    /// must be the identity.
    #[test]
    fn shared_reference_dd_inverse_matches_dense_inverse() {
        let sd = [1.0_f64, 2.0];
        let mut r_inv = Vec::new();
        let mut cov = Vec::new();
        let mut invert = FlatLinearScratch::default();
        CovarianceBlock::SharedReferenceDoubleDifference
            .inverse_into(2, |k| sd[k], |_| 0.5, &mut r_inv, &mut cov, &mut invert)
            .unwrap();
        // Diagonal is sd[i] + 0.5, off-diagonal is 0.5.
        let r = [[1.5_f64, 0.5], [0.5, 2.5]];
        for (i, r_row) in r.iter().enumerate() {
            for j in 0..2 {
                let prod = r_row[0] * r_inv[j] + r_row[1] * r_inv[2 + j];
                let expect = if i == j { 1.0 } else { 0.0 };
                assert!((prod - expect).abs() < 1.0e-12);
            }
        }
    }

    /// A single equal-variance observation: `1 / (4 + 4)`.
    #[test]
    fn shared_reference_dd_inverse_uses_equal_variance_closed_form() {
        let mut r_inv = Vec::new();
        let mut cov = Vec::new();
        let mut invert = FlatLinearScratch::default();
        CovarianceBlock::SharedReferenceDoubleDifference
            .inverse_into(1, |_| 4.0, |_| 4.0, &mut r_inv, &mut cov, &mut invert)
            .unwrap();
        assert_eq!(r_inv, vec![0.125]);
    }

    /// A 1 x 1 block has no off-diagonal entries, so the closed form's
    /// off-diagonal term needs a larger block to be exercised.
    #[test]
    fn equal_variance_closed_form_fills_off_diagonals() {
        // Stale, over-long contents must be replaced, not merely extended.
        let mut r_inv = vec![9.0_f64; 16];
        let mut cov: Vec<f64> = Vec::new();
        let mut invert = FlatLinearScratch::default();
        CovarianceBlock::SharedReferenceDoubleDifference
            .inverse_into(3, |_| 4.0, |_| 4.0, &mut r_inv, &mut cov, &mut invert)
            .unwrap();
        // R = 4 * (I + 1*1^T)  =>  R^-1 = (1/4) * (I - 1*1^T / 4).
        // Both entries are exactly representable, so compare exactly.
        let d = 0.1875_f64;
        let o = -0.0625_f64;
        assert_eq!(r_inv, vec![d, o, o, o, d, o, o, o, d]);
        // The closed form never builds the dense covariance.
        assert!(cov.is_empty());
    }

    /// `m == 0` must clear whatever the output buffer held before.
    #[test]
    fn empty_block_clears_inverse() {
        let mut r_inv = vec![9.0, 9.0];
        let mut cov = Vec::new();
        let mut invert = FlatLinearScratch::default();
        CovarianceBlock::SharedReferenceDoubleDifference
            .inverse_into(0, |_| 1.0, |_| 1.0, &mut r_inv, &mut cov, &mut invert)
            .unwrap();
        assert!(r_inv.is_empty());
    }

    /// Identity design with unit third tuple entries: the solution equals the
    /// second tuple entries.
    #[test]
    fn dense_last_tie_solves_diagonal_system() {
        let r0: Vec<f64> = vec![1.0, 0.0];
        let r1: Vec<f64> = vec![0.0, 1.0];
        let rows = [(r0.as_slice(), 1.0, 1.0), (r1.as_slice(), 2.0, 1.0)];
        let x = NormalAssembler::new(NormalRecipe::PppDenseLastTie)
            .solve_dense_last_tie(rows.iter().copied(), 2)
            .unwrap();
        assert_eq!(x, vec![1.0, 2.0]);
    }
}