use ndarray::{Array1, Array2, Array3, ArrayView2, ArrayView3, Axis};
use thiserror::Error;
use crate::float::Float;
use crate::image::stats::{mad_sigma, median_in_place};
use super::accumulate::accumulate;
use super::nuisance::{refine_nuisance, solve_flux_background};
use super::render::render;
use super::robust::{CombineMethod, robust_combine};
/// Default cap on outer (model-update) iterations in `build_epsf`.
pub const DEFAULT_MAX_ITER: usize = 50;
/// Default threshold on the relative L2 change of the model between iterations.
pub const DEFAULT_TOL: f64 = 1.0e-4;
/// Default step numerator; the applied step is this divided by an operator-norm estimate.
pub const DEFAULT_STEP: f64 = 1.0;
/// Default iteration cap handed to the per-star nuisance refinement.
pub const DEFAULT_NUISANCE_MAX_ITER: usize = 3;
/// Default tolerance handed to the per-star nuisance refinement.
pub const DEFAULT_NUISANCE_TOL: f64 = 1.0e-4;
/// Number of power-iteration steps used to estimate the update operator norm.
const POWER_ITERATIONS: usize = 12;
/// Clipping threshold (in sigmas) for the clipped-mean coadd that seeds the model.
const SEED_CLIP_KAPPA: f64 = 3.0;
/// Maximum clipping passes for the seed coadd.
const SEED_CLIP_MAX_ITER: usize = 5;
/// Robust down-weighting of large residuals during the model update.
///
/// Each pixel is scored by `z = |r / sigma|`. Here `sigma` is `1 / sqrt(weight)`
/// when per-pixel weights are supplied; otherwise it is one MAD-based scale shared
/// by all residuals.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ResidualReweight {
    /// No reweighting: every residual keeps unit weight.
    None,
    /// Huber weights: `1` for `z <= c`, otherwise `c / z`.
    Huber { c: f64 },
    /// Tukey biweight: `(1 - (z / c)^2)^2` for `z <= c`, otherwise `0`.
    Tukey { c: f64 },
}
/// Iteration controls for `build_epsf`.
///
/// `build_epsf` rejects them unless:
/// - both iteration counts are non-zero;
/// - `tol`, `step` and `nuisance_tol` are finite and positive;
/// - any `ResidualReweight` threshold `c` is finite and positive.
#[derive(Clone, Debug, PartialEq)]
pub struct BuildEpsfParams {
    /// Maximum number of outer (model-update) iterations.
    pub max_iter: usize,
    /// Stop once the relative L2 change of the model drops below this value.
    pub tol: f64,
    /// Step numerator; the applied step is `step` over a power-iteration norm estimate.
    pub step: f64,
    /// Robust reweighting applied to residuals before each model update.
    pub residual_reweight: ResidualReweight,
    /// Iteration cap passed to the per-star nuisance refinement.
    pub nuisance_max_iter: usize,
    /// Tolerance passed to the per-star nuisance refinement.
    pub nuisance_tol: f64,
}
impl Default for BuildEpsfParams {
    /// Uses the `DEFAULT_*` constants and no residual reweighting.
    fn default() -> Self {
        BuildEpsfParams {
            residual_reweight: ResidualReweight::None,
            max_iter: DEFAULT_MAX_ITER,
            nuisance_max_iter: DEFAULT_NUISANCE_MAX_ITER,
            tol: DEFAULT_TOL,
            nuisance_tol: DEFAULT_NUISANCE_TOL,
            step: DEFAULT_STEP,
        }
    }
}
/// Output of `build_epsf`: the fitted model plus the final per-star table.
#[derive(Clone, Debug, PartialEq)]
pub struct BuildEpsf {
    /// Oversampled model of shape `(oversample * s, oversample * s)`, normalized
    /// with `gauge_unit_volume` (`sum / oversample^2 == 1` when the volume is usable).
    pub epsf: Array2<f64>,
    /// Per-star flux from the last nuisance refinement.
    pub flux: Array1<f64>,
    /// Per-star background from the last nuisance refinement.
    pub background: Array1<f64>,
    /// Per-star offsets, shape `(N, 2)`, from the last nuisance refinement.
    pub delta: Array2<f64>,
    /// Per-star success flags; stars flagged `false` do not contribute to model updates.
    pub ok: Array1<bool>,
    /// Number of outer iterations actually performed.
    pub iterations: usize,
    /// `true` if the tolerance was met before the iteration cap (always `true` for N == 0).
    pub converged: bool,
}
/// Input validation failures reported by `build_epsf`.
///
/// Shape errors are checked first, in variant order. `ParamsInvalid` is checked last.
#[derive(Debug, Error, PartialEq)]
pub enum BuildEpsfError {
    /// `oversample` is even (this includes `0`).
    #[error("oversample must be odd; got {oversample}")]
    OversampleNotOdd { oversample: usize },
    /// The stamp side `data.shape()[1]` is even.
    #[error("recovered stamp_size (data.shape()[1]) must be odd; got {stamp_size}")]
    StampSizeEven { stamp_size: usize },
    /// Stamps are not square, or `delta_init` is not `(N, 2)`.
    #[error(
        "batch dimensions disagree: data shape {data:?} must be (N, s, s) with s odd and delta_init shape {delta_init:?} must be (N, 2)"
    )]
    BatchLengthMismatch {
        data: (usize, usize, usize),
        delta_init: (usize, usize),
    },
    /// `weight` does not have exactly the shape of `data`.
    #[error("weight shape {weight:?} must equal data shape {data:?}")]
    WeightShapeMismatch {
        weight: (usize, usize, usize),
        data: (usize, usize, usize),
    },
    /// `psi_init` is not `(oversample * s, oversample * s)`.
    #[error("psi_init shape {psi_init:?} must equal the model grid {expected:?}")]
    PsiInitShapeMismatch {
        psi_init: (usize, usize),
        expected: (usize, usize),
    },
    /// Some field of `BuildEpsfParams` is out of range.
    /// Only `max_iter`, `tol` and `step` are echoed back.
    #[error(
        "build_epsf requires max_iter > 0, tol/step/nuisance_tol finite and > 0, nuisance_max_iter > 0, and any ResidualReweight c finite and > 0; got max_iter = {max_iter}, tol = {tol}, step = {step}"
    )]
    ParamsInvalid {
        max_iter: usize,
        tol: f64,
        step: f64,
    },
}
/// Fit an oversampled effective PSF jointly with per-star flux, background and
/// sub-pixel offset to a batch of stamps.
///
/// # Arguments
///
/// * `data` - stamps of shape `(N, s, s)`; `s` must be odd. Values are converted
///   to `f64`, and values that fail conversion become `NaN`.
/// * `weight` - optional per-pixel weights with the same shape as `data`.
/// * `delta_init` - initial per-star offsets, shape `(N, 2)`.
/// * `oversample` - odd oversampling factor. The model grid is
///   `(oversample * s, oversample * s)`.
/// * `psi_init` - optional warm-start model on that grid. When absent, the model
///   is seeded from a clipped-mean coadd of the stamps (`seed_psi`).
/// * `params` - iteration controls; see `BuildEpsfParams`.
///
/// # Algorithm
///
/// An initial linear flux/background solve (`solve_flux_background`) comes
/// first. Each outer iteration then:
/// 1. renders the current model and forms residuals `data - model`;
/// 2. derives per-pixel weights, optionally down-weighting large residuals
///    according to `params.residual_reweight`;
/// 3. sets the model to `max(psi + lambda * update, 0)`:
///    - `update` comes from `accumulate` over the weighted residuals;
///    - `lambda = params.step / norm`, where `norm` is a power-iteration estimate
///      for the `render` -> `accumulate` composition;
///    - `lambda = 0` when that estimate is non-finite or zero;
/// 4. rescales the model to unit volume (`sum / oversample^2 == 1`), moving
///    the scale into the fluxes;
/// 5. re-fits flux, background and offsets with `refine_nuisance`.
///
/// Iteration stops when the relative L2 change of the model drops below
/// `params.tol`, or after `params.max_iter` iterations.
///
/// `converged` is `true` only if the tolerance was met *before* the final
/// permitted iteration. A run that first meets the tolerance on iteration
/// `max_iter` reports `false`.
///
/// An empty batch (`N == 0`) returns the normalized seed or initial model with
/// empty per-star outputs, `iterations == 0` and `converged == true`.
///
/// # Errors
///
/// Checked in this order:
/// * `BuildEpsfError::OversampleNotOdd` - `oversample` is even (including 0).
/// * `BuildEpsfError::StampSizeEven` - `data.shape()[1]` is even.
/// * `BuildEpsfError::BatchLengthMismatch` - stamps are not square, or
///   `delta_init` is not `(N, 2)`.
/// * `BuildEpsfError::WeightShapeMismatch` - `weight` shape differs from `data`.
/// * `BuildEpsfError::PsiInitShapeMismatch` - `psi_init` is not on the model grid.
/// * `BuildEpsfError::ParamsInvalid` - any field of `params` is out of range.
///
/// # Panics
///
/// Panics (via `expect`) if any of these helpers returns an error:
/// `solve_flux_background`, `render`, `accumulate` or `refine_nuisance`.
/// The validation above is intended to make that impossible.
pub fn build_epsf<T: Float>(
    data: ArrayView3<T>,
    weight: Option<ArrayView3<T>>,
    delta_init: ArrayView2<f64>,
    oversample: usize,
    psi_init: Option<ArrayView2<f64>>,
    params: BuildEpsfParams,
) -> Result<BuildEpsf, BuildEpsfError> {
    // Shape checks all run before the `params` checks further down.
    if oversample.is_multiple_of(2) {
        return Err(BuildEpsfError::OversampleNotOdd { oversample });
    }
    let batch_size = data.shape()[0];
    let stamp_size = data.shape()[1];
    if stamp_size.is_multiple_of(2) {
        return Err(BuildEpsfError::StampSizeEven { stamp_size });
    }
    if data.shape()[2] != stamp_size
        || delta_init.shape()[0] != batch_size
        || delta_init.shape()[1] != 2
    {
        return Err(BuildEpsfError::BatchLengthMismatch {
            data: (batch_size, data.shape()[1], data.shape()[2]),
            delta_init: (delta_init.shape()[0], delta_init.shape()[1]),
        });
    }
    if let Some(weight_view) = weight.as_ref() {
        let weight_shape = weight_view.shape();
        if weight_shape[0] != batch_size
            || weight_shape[1] != stamp_size
            || weight_shape[2] != stamp_size
        {
            return Err(BuildEpsfError::WeightShapeMismatch {
                weight: (weight_shape[0], weight_shape[1], weight_shape[2]),
                data: (batch_size, stamp_size, stamp_size),
            });
        }
    }
    // Side length of the oversampled model grid.
    let side = oversample * stamp_size;
    if let Some(psi_view) = psi_init.as_ref() {
        let psi_shape = psi_view.shape();
        if psi_shape[0] != side || psi_shape[1] != side {
            return Err(BuildEpsfError::PsiInitShapeMismatch {
                psi_init: (psi_shape[0], psi_shape[1]),
                expected: (side, side),
            });
        }
    }
    let reweight_c_invalid = match params.residual_reweight {
        ResidualReweight::None => false,
        ResidualReweight::Huber { c } | ResidualReweight::Tukey { c } => {
            !(c.is_finite() && c > 0.0)
        }
    };
    let params_invalid = params.max_iter == 0
        || !(params.tol.is_finite() && params.tol > 0.0)
        || !(params.step.is_finite() && params.step > 0.0)
        || params.nuisance_max_iter == 0
        || !(params.nuisance_tol.is_finite() && params.nuisance_tol > 0.0)
        || reweight_c_invalid;
    if params_invalid {
        return Err(BuildEpsfError::ParamsInvalid {
            max_iter: params.max_iter,
            tol: params.tol,
            step: params.step,
        });
    }
    // All arithmetic below is in f64. Samples that cannot be converted become NaN;
    // non-finite residuals are excluded from the model update below.
    let data_f: Array3<f64> = data.mapv(|value| value.to_f64().unwrap_or(f64::NAN));
    let weight_base: Option<Array3<f64>> = weight
        .as_ref()
        .map(|w| w.mapv(|value| value.to_f64().unwrap_or(f64::NAN)));
    let mut psi: Array2<f64> = match psi_init.as_ref() {
        Some(view) => view.to_owned(),
        None => seed_psi(&data_f, oversample, stamp_size, side),
    };
    if batch_size == 0 {
        // Nothing to fit: return the normalized starting model with empty per-star
        // tables. The zero-length array only fills the gauge's flux argument.
        gauge_unit_volume(&mut psi, &mut Array1::<f64>::zeros(0), oversample);
        return Ok(BuildEpsf {
            epsf: psi,
            flux: Array1::<f64>::zeros(0),
            background: Array1::<f64>::zeros(0),
            delta: Array2::<f64>::zeros((0, 2)),
            ok: Array1::from_vec(Vec::<bool>::new()),
            iterations: 0,
            converged: true,
        });
    }
    let mut delta: Array2<f64> = delta_init.to_owned();
    // The operator-norm estimate renders probes without the additive background term.
    let zeros_background = Array1::<f64>::zeros(batch_size);
    // Fix the model's volume before the initial linear flux/background solve.
    gauge_unit_volume(&mut psi, &mut Array1::<f64>::zeros(0), oversample);
    let initial = solve_flux_background::<f64>(
        psi.view(),
        oversample,
        data_f.view(),
        weight_base.as_ref().map(|w| w.view()),
        delta.view(),
    )
    .expect("solve_flux_background preconditions are guaranteed by build_epsf validation");
    let mut flux = initial.flux;
    let mut background = initial.background;
    let mut ok = initial.ok;
    let mut iterations = 0usize;
    let mut converged = false;
    for iter in 0..params.max_iter {
        // Non-finite nuisance values are replaced by 0 so the renderer and the
        // accumulator only see finite star parameters.
        let flux_safe = sanitize_1d(&flux);
        let background_safe = sanitize_1d(&background);
        let delta_safe = sanitize_2d(&delta);
        let model = render(
            psi.view(),
            oversample,
            delta_safe.view(),
            flux_safe.view(),
            background_safe.view(),
        )
        .expect("render preconditions are guaranteed by build_epsf validation");
        let mut residual = Array3::<f64>::zeros((batch_size, stamp_size, stamp_size));
        for n in 0..batch_size {
            for i in 0..stamp_size {
                for j in 0..stamp_size {
                    residual[(n, i, j)] = data_f[(n, i, j)] - model[(n, i, j)];
                }
            }
        }
        // Robust factors (all ones when reweighting is off), folded into the base
        // weights. The same weights are reused for the nuisance refinement below.
        let factor = reweight_factor(
            &residual,
            weight_base.as_ref(),
            &ok,
            params.residual_reweight,
        );
        let weight_refine: Option<Array3<f64>> =
            effective_weight(weight_base.as_ref(), &factor, params.residual_reweight);
        // A pixel contributes to the model update (zero weight and zero residual
        // otherwise) only if all hold:
        // - its star is `ok`;
        // - its residual is finite;
        // - its weight is finite and positive.
        let mut weight_acc = Array3::<f64>::zeros((batch_size, stamp_size, stamp_size));
        let mut residual_acc = Array3::<f64>::zeros((batch_size, stamp_size, stamp_size));
        for n in 0..batch_size {
            let star_ok = ok[n];
            for i in 0..stamp_size {
                for j in 0..stamp_size {
                    let r = residual[(n, i, j)];
                    let w = match &weight_refine {
                        Some(w_eff) => w_eff[(n, i, j)],
                        None => 1.0,
                    };
                    if star_ok && r.is_finite() && w.is_finite() && w > 0.0 {
                        weight_acc[(n, i, j)] = w;
                        residual_acc[(n, i, j)] = r;
                    }
                }
            }
        }
        let operator_norm = power_iteration_norm(
            oversample,
            stamp_size,
            &delta_safe,
            &flux_safe,
            &zeros_background,
            &weight_acc,
            side,
        );
        let update = accumulate(
            residual_acc.view(),
            Some(weight_acc.view()),
            oversample,
            stamp_size,
            delta_safe.view(),
            flux_safe.view(),
        )
        .expect("accumulate preconditions are guaranteed by build_epsf validation");
        // Step size: `params.step` over the norm estimate. A degenerate estimate
        // disables the update for this iteration.
        let lambda = if operator_norm.is_finite() && operator_norm > 0.0 {
            params.step / operator_norm
        } else {
            0.0
        };
        // Projected step: negative values (and NaN, since `NaN > 0.0` is false)
        // are clamped to zero.
        let mut psi_next = Array2::<f64>::zeros((side, side));
        for p in 0..side {
            for q in 0..side {
                let value = psi[(p, q)] + lambda * update[(p, q)];
                psi_next[(p, q)] = if value > 0.0 { value } else { 0.0 };
            }
        }
        // Restore unit volume, moving the scale factor into the fluxes.
        gauge_unit_volume(&mut psi_next, &mut flux, oversample);
        let relative_change = relative_l2_change(&psi, &psi_next);
        psi = psi_next;
        iterations = iter + 1;
        // Re-fit the nuisances against the new model, starting from the current
        // offsets. `delta` is cloned because it is also the mutable output.
        run_refine(
            &psi,
            oversample,
            &data_f,
            weight_refine.as_ref(),
            &delta.clone(),
            params.nuisance_max_iter,
            params.nuisance_tol,
            &mut flux,
            &mut background,
            &mut delta,
            &mut ok,
        );
        if relative_change < params.tol {
            // Convergence is reported only if the tolerance was met before the
            // iteration budget was exhausted.
            converged = iter + 1 < params.max_iter;
            break;
        }
    }
    Ok(BuildEpsf {
        epsf: psi,
        flux,
        background,
        delta,
        ok,
        iterations,
        converged,
    })
}
/// Copy of a 1-D array in which every NaN or infinite entry becomes `0.0`.
fn sanitize_1d(values: &Array1<f64>) -> Array1<f64> {
    let finite_or_zero = |x: &f64| if x.is_nan() || x.is_infinite() { 0.0 } else { *x };
    values.map(finite_or_zero)
}
/// Copy of a 2-D array in which every NaN or infinite entry becomes `0.0`.
fn sanitize_2d(values: &Array2<f64>) -> Array2<f64> {
    let finite_or_zero = |x: &f64| if x.is_nan() || x.is_infinite() { 0.0 } else { *x };
    values.map(finite_or_zero)
}
fn relative_l2_change(psi: &Array2<f64>, psi_next: &Array2<f64>) -> f64 {
let mut diff_sq = 0.0;
let mut base_sq = 0.0;
for (a, b) in psi.iter().zip(psi_next.iter()) {
let d = b - a;
diff_sq += d * d;
base_sq += a * a;
}
let base = base_sq.sqrt();
if base > 1e-300 {
diff_sq.sqrt() / base
} else {
diff_sq.sqrt()
}
}
/// Rescale `psi` so that `sum(psi) / oversample^2 == 1`.
///
/// `flux` is multiplied by the same volume, so every `flux * psi` product is
/// unchanged. Does nothing when the current volume is non-finite or non-positive.
fn gauge_unit_volume(psi: &mut Array2<f64>, flux: &mut Array1<f64>, oversample: usize) {
    let cells_per_pixel = (oversample * oversample) as f64;
    let volume = psi.sum() / cells_per_pixel;
    if !volume.is_finite() || volume <= 0.0 {
        return;
    }
    psi.mapv_inplace(|x| x / volume);
    flux.mapv_inplace(|f| f * volume);
}
/// Per-pixel robust weight factors derived from the current residuals.
///
/// For `ResidualReweight::None` every factor is `1`. Otherwise each pixel of a
/// star flagged `ok` with a finite residual `r` gets the Huber/Tukey weight of
/// `z = |r / sigma|`. The scale `sigma` is chosen as follows:
/// - with `weight_base`, it is `1 / sqrt(w)`; pixels whose weight is non-finite
///   or non-positive are skipped;
/// - without it, one `global_mad_sigma` over all `ok` residuals is shared.
///
/// Skipped pixels, and pixels whose `sigma` is not finite and positive, keep
/// factor `1`.
fn reweight_factor(
    residual: &Array3<f64>,
    weight_base: Option<&Array3<f64>>,
    ok: &Array1<bool>,
    method: ResidualReweight,
) -> Array3<f64> {
    /// Weight for a residual of `z` noise sigmas under `method`.
    fn weight_for(z: f64, method: ResidualReweight) -> f64 {
        match method {
            ResidualReweight::None => 1.0,
            ResidualReweight::Huber { c } if z <= c => 1.0,
            ResidualReweight::Huber { c } => c / z,
            ResidualReweight::Tukey { c } if z <= c => {
                let ratio = z / c;
                let t = 1.0 - ratio * ratio;
                t * t
            }
            ResidualReweight::Tukey { .. } => 0.0,
        }
    }
    let dims = residual.dim();
    let mut factor = Array3::<f64>::ones(dims);
    if matches!(method, ResidualReweight::None) {
        return factor;
    }
    // Without per-pixel weights, one robust noise scale is shared by every pixel.
    let shared_sigma = match weight_base {
        None => global_mad_sigma(residual, ok),
        Some(_) => f64::NAN,
    };
    let (batch, height, width) = dims;
    for n in (0..batch).filter(|&n| ok[n]) {
        for i in 0..height {
            for j in 0..width {
                let r = residual[(n, i, j)];
                if !r.is_finite() {
                    continue;
                }
                let sigma = match weight_base {
                    None => shared_sigma,
                    Some(w) => {
                        let wv = w[(n, i, j)];
                        if !(wv.is_finite() && wv > 0.0) {
                            continue;
                        }
                        (1.0 / wv).sqrt()
                    }
                };
                if sigma.is_finite() && sigma > 0.0 {
                    factor[(n, i, j)] = weight_for((r / sigma).abs(), method);
                }
            }
        }
    }
    factor
}
/// MAD-based noise scale (`mad_sigma`) of all finite residuals from stars flagged `ok`.
fn global_mad_sigma(residual: &Array3<f64>, ok: &Array1<bool>) -> f64 {
    let mut pooled: Vec<f64> = Vec::new();
    for (n, stamp) in residual.outer_iter().enumerate() {
        if !ok[n] {
            continue;
        }
        // Row-major within the stamp, matching an (i, j) nested loop.
        pooled.extend(stamp.iter().copied().filter(|r| r.is_finite()));
    }
    mad_sigma(&pooled)
}
/// Combine optional base weights with robust reweighting factors.
///
/// - No reweighting: the base weights are returned unchanged (`None` stays `None`).
/// - Reweighting without base weights: the factors themselves.
/// - Reweighting with base weights: their element-wise product.
fn effective_weight(
    weight_base: Option<&Array3<f64>>,
    factor: &Array3<f64>,
    method: ResidualReweight,
) -> Option<Array3<f64>> {
    if matches!(method, ResidualReweight::None) {
        return weight_base.cloned();
    }
    let combined = match weight_base {
        None => factor.clone(),
        Some(base) => {
            let mut product = base.clone();
            product
                .iter_mut()
                .zip(factor.iter())
                .for_each(|(w, f)| *w *= *f);
            product
        }
    };
    Some(combined)
}
/// Estimate the norm of the weighted update operator by power iteration.
///
/// The operator maps a probe model to `accumulate(render(probe))`:
/// - `render` uses `delta_safe` and `flux_safe` with a zero background;
/// - `accumulate` uses the per-pixel weights `weight_acc`.
///
/// Starts from an all-ones probe and runs `POWER_ITERATIONS` steps. Returns the
/// last norm ratio, or `0.0` if any intermediate norm is non-finite or zero.
#[allow(clippy::too_many_arguments)]
fn power_iteration_norm(
    oversample: usize,
    stamp_size: usize,
    delta_safe: &Array2<f64>,
    flux_safe: &Array1<f64>,
    zeros_background: &Array1<f64>,
    weight_acc: &Array3<f64>,
    side: usize,
) -> f64 {
    // One application of the operator: render the probe, then accumulate it back.
    let apply = |input: &Array2<f64>| {
        let rendered = render(
            input.view(),
            oversample,
            delta_safe.view(),
            flux_safe.view(),
            zeros_background.view(),
        )
        .expect("render preconditions are guaranteed by build_epsf validation");
        accumulate(
            rendered.view(),
            Some(weight_acc.view()),
            oversample,
            stamp_size,
            delta_safe.view(),
            flux_safe.view(),
        )
        .expect("accumulate preconditions are guaranteed by build_epsf validation")
    };
    let mut probe = Array2::<f64>::ones((side, side));
    let mut estimate = 0.0;
    for _ in 0..POWER_ITERATIONS {
        let probe_norm = l2_norm(probe.iter());
        if !probe_norm.is_finite() || probe_norm <= 0.0 {
            return 0.0;
        }
        let image = apply(&probe);
        let image_norm = l2_norm(image.iter());
        if !image_norm.is_finite() || image_norm == 0.0 {
            return 0.0;
        }
        estimate = image_norm / probe_norm;
        probe = image.mapv(|x| x / image_norm);
    }
    estimate
}
/// Euclidean (L2) norm of the values yielded by `iter`.
fn l2_norm<'a, I: Iterator<Item = &'a f64>>(iter: I) -> f64 {
    let sum_of_squares: f64 = iter.map(|v| v * v).sum();
    sum_of_squares.sqrt()
}
/// Re-fit per-star flux, background and offsets against the model `psi` with
/// `refine_nuisance`, starting from `delta_init`.
///
/// The results are copied into the caller's `flux`, `background`, `delta` and
/// `ok` arrays via `assign`.
///
/// # Panics
///
/// Panics if `refine_nuisance` returns an error. Its preconditions are expected
/// to be guaranteed by `build_epsf`'s validation.
// The refined outputs are written into caller-owned arrays, hence the long signature.
#[allow(clippy::too_many_arguments)]
fn run_refine(
    psi: &Array2<f64>,
    oversample: usize,
    data_f: &Array3<f64>,
    weight: Option<&Array3<f64>>,
    delta_init: &Array2<f64>,
    nuisance_max_iter: usize,
    nuisance_tol: f64,
    flux: &mut Array1<f64>,
    background: &mut Array1<f64>,
    delta: &mut Array2<f64>,
    ok: &mut Array1<bool>,
) {
    let weight_view = weight.map(|w| w.view());
    let fit = refine_nuisance::<f64>(
        psi.view(),
        oversample,
        data_f.view(),
        weight_view,
        delta_init.view(),
        nuisance_max_iter,
        nuisance_tol,
    )
    .expect("refine_nuisance preconditions are guaranteed by build_epsf validation");
    flux.assign(&fit.flux);
    background.assign(&fit.background);
    delta.assign(&fit.delta);
    ok.assign(&fit.ok);
}
/// Build a starting model when no `psi_init` is supplied.
///
/// Steps:
/// 1. Each stamp has its border-ring median subtracted as a background estimate
///    and is divided by its background-subtracted sum over finite pixels.
///    Stamps whose sum is not finite and positive are left all-`NaN`.
/// 2. The normalized stamps are combined with a clipped mean
///    (`SEED_CLIP_KAPPA`, `SEED_CLIP_MAX_ITER`).
/// 3. Each native pixel of the coadd is replicated over its
///    `oversample x oversample` block of the model grid, keeping only finite
///    positive values.
///
/// Returns a flat model of ones for an empty batch, or when no coadd pixel is positive.
fn seed_psi(
    data_f: &Array3<f64>,
    oversample: usize,
    stamp_size: usize,
    side: usize,
) -> Array2<f64> {
    let flat = || Array2::<f64>::from_elem((side, side), 1.0);
    let batch_size = data_f.shape()[0];
    if batch_size == 0 {
        return flat();
    }
    let mut normalized =
        Array3::<f64>::from_elem((batch_size, stamp_size, stamp_size), f64::NAN);
    for n in 0..batch_size {
        let stamp = data_f.index_axis(Axis(0), n);
        let background = border_ring_median(&stamp, stamp_size);
        // Background-subtracted sum over the finite pixels, in row-major order.
        let total = stamp
            .iter()
            .filter(|v| v.is_finite())
            .fold(0.0, |acc, &v| acc + (v - background));
        let usable = total.is_finite() && total > 0.0;
        if !usable {
            continue;
        }
        for ((i, j), &v) in stamp.indexed_iter() {
            if v.is_finite() {
                normalized[(n, i, j)] = (v - background) / total;
            }
        }
    }
    let coadd = robust_combine::<f64>(
        normalized.view(),
        None,
        CombineMethod::ClippedMean {
            kappa: SEED_CLIP_KAPPA,
            max_iter: SEED_CLIP_MAX_ITER,
        },
    )
    .expect("robust_combine preconditions hold for the seed coadd")
    .combined;
    let mut any_positive = false;
    let psi = Array2::<f64>::from_shape_fn((side, side), |(p, q)| {
        // Sub-pixel (p, q) lies inside native pixel (p / oversample, q / oversample).
        let value = coadd[(p / oversample, q / oversample)];
        if value.is_finite() && value > 0.0 {
            any_positive = true;
            value
        } else {
            0.0
        }
    });
    if any_positive {
        psi
    } else {
        flat()
    }
}
/// Median of the finite pixels on the outer ring of a `stamp_size x stamp_size`
/// stamp (its first and last rows plus the interior of its first and last
/// columns). Used as a background estimate.
///
/// Returns `0.0` for `stamp_size == 0`, or when `median_in_place` yields no value.
fn border_ring_median(stamp: &ArrayView2<f64>, stamp_size: usize) -> f64 {
    let Some(last) = stamp_size.checked_sub(1) else {
        return 0.0;
    };
    // Top/bottom pairs per column, then left/right pairs per interior row.
    let rows = (0..stamp_size).flat_map(|j| [stamp[(0, j)], stamp[(last, j)]]);
    let cols = (1..last).flat_map(|i| [stamp[(i, 0)], stamp[(i, last)]]);
    let mut ring: Vec<f64> = rows.chain(cols).filter(|v| v.is_finite()).collect();
    median_in_place(&mut ring).unwrap_or(0.0)
}
#[cfg(test)]
mod tests {
    use super::*;
    // Renderer used to synthesize noiseless test stamps.
    use crate::image::psf::render::render;
    use ndarray::{Array1, Array2, Array3};
    /// SplitMix64 generator: deterministic pseudo-random test inputs without an
    /// external RNG crate.
    struct SplitMix64 {
        state: u64,
    }
    impl SplitMix64 {
        fn new(seed: u64) -> Self {
            Self { state: seed }
        }
        fn next_u64(&mut self) -> u64 {
            self.state = self.state.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut z = self.state;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            z ^ (z >> 31)
        }
        /// Uniform sample in [0, 1) from the top 53 bits of `next_u64`.
        fn unit(&mut self) -> f64 {
            (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
        }
        /// Uniform sample between `lo` and `hi`.
        fn range(&mut self, lo: f64, hi: f64) -> f64 {
            lo + (hi - lo) * self.unit()
        }
    }
    /// Isotropic Gaussian centred on a `side x side` grid.
    /// Peak is `amplitude`; width is `sigma` in sub-pixels.
    fn gaussian_epsf(side: usize, sigma: f64, amplitude: f64) -> Array2<f64> {
        let center = (side as f64 - 1.0) / 2.0;
        Array2::from_shape_fn((side, side), |(p, q)| {
            let dp = p as f64 - center;
            let dq = q as f64 - center;
            amplitude * (-0.5 * (dp * dp + dq * dq) / (sigma * sigma)).exp()
        })
    }
    /// Gaussian modulated by linear ramps along both axes, so the truth is not symmetric.
    fn asymmetric_epsf(side: usize, sigma: f64) -> Array2<f64> {
        let center = (side as f64 - 1.0) / 2.0;
        Array2::from_shape_fn((side, side), |(p, q)| {
            let dp = p as f64 - center;
            let dq = q as f64 - center;
            let g = (-0.5 * (dp * dp + dq * dq) / (sigma * sigma)).exp();
            g * (1.0 + 0.20 * (p as f64 / side as f64)) * (1.0 + 0.12 * (q as f64 / side as f64))
        })
    }
    /// Rescale so that `sum / oversample^2 == 1` (the normalization `build_epsf` enforces).
    fn unit_volume(psi: &Array2<f64>, oversample: usize) -> Array2<f64> {
        let volume = psi.sum() / (oversample * oversample) as f64;
        psi.mapv(|x| x / volume)
    }
    /// Render noiseless stamps from a known model and star table.
    fn synth(
        epsf: &Array2<f64>,
        oversample: usize,
        delta: &Array2<f64>,
        flux: &Array1<f64>,
        background: &Array1<f64>,
    ) -> Array3<f64> {
        render(
            epsf.view(),
            oversample,
            delta.view(),
            flux.view(),
            background.view(),
        )
        .unwrap()
    }
    /// Relative L2 error between `data` and the stamps re-rendered from a fit result.
    fn recon_rel_error(result: &BuildEpsf, oversample: usize, data: &Array3<f64>) -> f64 {
        let model = render(
            result.epsf.view(),
            oversample,
            result.delta.view(),
            result.flux.view(),
            result.background.view(),
        )
        .unwrap();
        let mut diff = 0.0;
        let mut base = 0.0;
        for (m, d) in model.iter().zip(data.iter()) {
            diff += (m - d) * (m - d);
            base += d * d;
        }
        (diff / base).sqrt()
    }
    /// Relative L2 distance of `a` from the reference `b`.
    fn rel_l2(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        let mut diff = 0.0;
        let mut base = 0.0;
        for (x, y) in a.iter().zip(b.iter()) {
            diff += (x - y) * (x - y);
            base += y * y;
        }
        (diff / base).sqrt()
    }
    /// `n` random sub-pixel offset pairs in roughly [-0.45, 0.45).
    fn scattered_deltas(rng: &mut SplitMix64, n: usize) -> Array2<f64> {
        Array2::from_shape_fn((n, 2), |_| rng.range(-0.45, 0.45))
    }
    // Warm-started from the true model on noiseless data, the fit should:
    // - keep unit volume, non-negativity and all stars ok;
    // - converge within a few iterations;
    // - recover the model and offsets and reconstruct the data.
    #[test]
    fn warm_start_recovers_known_psi_and_star_table() {
        let oversample = 5;
        let stamp_size = 7;
        let side = oversample * stamp_size;
        let psi_true = asymmetric_epsf(side, oversample as f64 * 1.4);
        let mut rng = SplitMix64::new(0x5EED_1234_ABCD_0001);
        let n = 8;
        let delta = scattered_deltas(&mut rng, n);
        let flux = Array1::from_shape_fn(n, |_| rng.range(50.0, 150.0));
        let background = Array1::from_shape_fn(n, |_| rng.range(2.0, 9.0));
        let data = synth(&psi_true, oversample, &delta, &flux, &background);
        let params = BuildEpsfParams {
            max_iter: 30,
            tol: 1e-9,
            step: 1.0,
            residual_reweight: ResidualReweight::None,
            nuisance_max_iter: 10,
            nuisance_tol: 1e-12,
        };
        let result = build_epsf::<f64>(
            data.view(),
            None,
            delta.view(),
            oversample,
            Some(psi_true.view()),
            params,
        )
        .unwrap();
        assert_eq!(result.epsf.dim(), (side, side));
        assert!((result.epsf.sum() / (oversample * oversample) as f64 - 1.0).abs() < 1e-9);
        assert!(result.epsf.iter().all(|&x| x >= 0.0 && x.is_finite()));
        assert!(result.ok.iter().all(|&o| o));
        assert!(result.iterations <= 3, "iterations = {}", result.iterations);
        assert!(result.converged);
        assert!(rel_l2(&result.epsf, &unit_volume(&psi_true, oversample)) < 1e-4);
        assert!(recon_rel_error(&result, oversample, &data) < 1e-6);
        // The loop index shadows the star count `n`.
        for n in 0..n {
            assert!((result.delta[(n, 0)] - delta[(n, 0)]).abs() < 1e-4);
            assert!((result.delta[(n, 1)] - delta[(n, 1)]).abs() < 1e-4);
        }
    }
    // Without `psi_init`, the seeded model must still reconstruct noiseless data
    // to 1% and roughly match the true shape.
    #[test]
    fn seed_path_runs_and_converges() {
        let oversample = 3;
        let stamp_size = 7;
        let side = oversample * stamp_size;
        let psi_true = gaussian_epsf(side, oversample as f64 * 1.5, 100.0);
        let mut rng = SplitMix64::new(0x5EED_1234_ABCD_0002);
        let n = 14;
        let delta = scattered_deltas(&mut rng, n);
        let flux = Array1::from_shape_fn(n, |_| rng.range(60.0, 140.0));
        let background = Array1::from_shape_fn(n, |_| rng.range(1.0, 5.0));
        let data = synth(&psi_true, oversample, &delta, &flux, &background);
        let params = BuildEpsfParams {
            max_iter: 250,
            tol: 1e-7,
            step: 1.0,
            residual_reweight: ResidualReweight::None,
            nuisance_max_iter: 3,
            nuisance_tol: 1e-8,
        };
        let result =
            build_epsf::<f64>(data.view(), None, delta.view(), oversample, None, params).unwrap();
        assert_eq!(result.epsf.dim(), (side, side));
        assert!(result.epsf.iter().all(|&x| x >= 0.0 && x.is_finite()));
        assert!(result.ok.iter().all(|&o| o));
        assert!((result.epsf.sum() / (oversample * oversample) as f64 - 1.0).abs() < 1e-6);
        let err = recon_rel_error(&result, oversample, &data);
        assert!(err < 0.01, "seed-path reconstruction error = {err}");
        let shape = rel_l2(&result.epsf, &unit_volume(&psi_true, oversample));
        assert!(shape < 0.25, "seed-path psi shape error = {shape}");
        assert!(result.iterations <= 250);
    }
    // With `max_iter = 1`, the run stops at the cap and reports `converged == false`.
    // A generous budget converges before the cap.
    #[test]
    fn max_iter_is_respected_and_converged_semantics() {
        let oversample = 5;
        let stamp_size = 7;
        let side = oversample * stamp_size;
        let psi_true = gaussian_epsf(side, oversample as f64 * 1.4, 100.0);
        let mut rng = SplitMix64::new(0x5EED_1234_ABCD_0003);
        let n = 6;
        let delta = scattered_deltas(&mut rng, n);
        let flux = Array1::from_shape_fn(n, |_| rng.range(50.0, 120.0));
        let background = Array1::from_elem(n, 3.0);
        let data = synth(&psi_true, oversample, &delta, &flux, &background);
        let capped = build_epsf::<f64>(
            data.view(),
            None,
            delta.view(),
            oversample,
            Some(psi_true.view()),
            BuildEpsfParams {
                max_iter: 1,
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(capped.iterations, 1);
        assert!(!capped.converged);
        let free = build_epsf::<f64>(
            data.view(),
            None,
            delta.view(),
            oversample,
            Some(psi_true.view()),
            BuildEpsfParams {
                max_iter: 50,
                tol: 1e-8,
                nuisance_max_iter: 8,
                nuisance_tol: 1e-12,
                ..Default::default()
            },
        )
        .unwrap();
        assert!(free.converged);
        assert!(free.iterations < 50);
    }
    // A negative background makes some data pixels negative. The model must stay
    // non-negative and still reconstruct the data.
    #[test]
    fn output_nonnegative_with_negative_data_pixels() {
        let oversample = 5;
        let stamp_size = 7;
        let side = oversample * stamp_size;
        let psi_true = gaussian_epsf(side, oversample as f64 * 1.3, 80.0);
        let mut rng = SplitMix64::new(0x5EED_1234_ABCD_0004);
        let n = 6;
        let delta = scattered_deltas(&mut rng, n);
        let flux = Array1::from_shape_fn(n, |_| rng.range(40.0, 90.0));
        let background = Array1::from_elem(n, -7.0);
        let data = synth(&psi_true, oversample, &delta, &flux, &background);
        assert!(data.iter().any(|&v| v < 0.0));
        let result = build_epsf::<f64>(
            data.view(),
            None,
            delta.view(),
            oversample,
            Some(psi_true.view()),
            BuildEpsfParams {
                max_iter: 30,
                tol: 1e-8,
                nuisance_max_iter: 8,
                nuisance_tol: 1e-12,
                ..Default::default()
            },
        )
        .unwrap();
        assert!(result.epsf.iter().all(|&x| x >= 0.0 && x.is_finite()));
        assert!(result.ok.iter().all(|&o| o));
        assert!(recon_rel_error(&result, oversample, &data) < 1e-5);
    }
    // Gauge fixing rescales psi and flux together, so rendered stamps are unchanged.
    #[test]
    fn gauge_unit_volume_invariance() {
        let oversample = 5;
        let stamp_size = 7;
        let side = oversample * stamp_size;
        let mut psi = gaussian_epsf(side, oversample as f64 * 1.5, 42.0);
        let mut flux = Array1::from_vec(vec![3.0, 7.5]);
        let delta = Array2::from_shape_vec((2, 2), vec![0.1, -0.2, -0.3, 0.05]).unwrap();
        let background = Array1::from_vec(vec![1.0, 2.0]);
        let before = render(
            psi.view(),
            oversample,
            delta.view(),
            flux.view(),
            background.view(),
        )
        .unwrap();
        gauge_unit_volume(&mut psi, &mut flux, oversample);
        assert!((psi.sum() / (oversample * oversample) as f64 - 1.0).abs() < 1e-12);
        let after = render(
            psi.view(),
            oversample,
            delta.view(),
            flux.view(),
            background.view(),
        )
        .unwrap();
        for (a, b) in before.iter().zip(after.iter()) {
            assert!((a - b).abs() < 1e-9 * a.abs().max(1.0));
        }
    }
    /// Fit 8 stars where star 0 has three pixels shifted by `sign * 5000`.
    /// Returns the fit, the unit-volume truth and the oversampling factor.
    fn build_with_outlier(sign: f64, method: ResidualReweight) -> (BuildEpsf, Array2<f64>, usize) {
        let oversample = 5;
        let stamp_size = 7;
        let side = oversample * stamp_size;
        let psi_true = gaussian_epsf(side, oversample as f64 * 1.4, 100.0);
        let mut rng = SplitMix64::new(0x5EED_1234_ABCD_0005);
        let n = 8;
        let delta = scattered_deltas(&mut rng, n);
        let flux = Array1::from_shape_fn(n, |_| rng.range(60.0, 130.0));
        let background = Array1::from_elem(n, 4.0);
        let mut data = synth(&psi_true, oversample, &delta, &flux, &background);
        for (i, j) in [(2, 2), (2, 4), (4, 3)] {
            data[(0, i, j)] += sign * 5000.0;
        }
        let result = build_epsf::<f64>(
            data.view(),
            None,
            delta.view(),
            oversample,
            Some(psi_true.view()),
            BuildEpsfParams {
                max_iter: 25,
                tol: 1e-9,
                step: 1.0,
                residual_reweight: method,
                nuisance_max_iter: 6,
                nuisance_tol: 1e-10,
            },
        )
        .unwrap();
        (result, unit_volume(&psi_true, oversample), oversample)
    }
    // Huber and Tukey reweighting should each beat the unweighted fit when an outlier is present.
    #[test]
    fn residual_reweight_reduces_outlier_bias() {
        let (none_res, truth, os) = build_with_outlier(1.0, ResidualReweight::None);
        let (huber_res, _, _) = build_with_outlier(1.0, ResidualReweight::Huber { c: 3.0 });
        let (tukey_res, _, _) = build_with_outlier(1.0, ResidualReweight::Tukey { c: 4.0 });
        let err_none = rel_l2(&none_res.epsf, &truth);
        let err_huber = rel_l2(&huber_res.epsf, &truth);
        let err_tukey = rel_l2(&tukey_res.epsf, &truth);
        assert!(err_huber < err_none, "huber {err_huber} vs none {err_none}");
        assert!(err_tukey < err_none, "tukey {err_tukey} vs none {err_none}");
        // `os` is not needed here; this binding only silences the unused-variable warning.
        let _ = os;
    }
    // Positive and negative outliers of equal size should give nearly identical model errors.
    #[test]
    fn residual_reweight_is_sign_agnostic() {
        let (plus, truth, _) = build_with_outlier(1.0, ResidualReweight::Huber { c: 3.0 });
        let (minus, _, _) = build_with_outlier(-1.0, ResidualReweight::Huber { c: 3.0 });
        let err_plus = rel_l2(&plus.epsf, &truth);
        let err_minus = rel_l2(&minus.epsf, &truth);
        assert!(
            (err_plus - err_minus).abs() < 1e-6 * err_plus.max(1e-6),
            "asymmetric: +{err_plus} vs -{err_minus}"
        );
    }
    // An all-NaN star must be flagged not-ok with NaN flux, must not spoil the model,
    // and must not affect the other stars.
    #[test]
    fn per_star_not_ok_is_excluded_from_psi() {
        let oversample = 5;
        let stamp_size = 7;
        let side = oversample * stamp_size;
        let psi_true = gaussian_epsf(side, oversample as f64 * 1.4, 100.0);
        let mut rng = SplitMix64::new(0x5EED_1234_ABCD_0006);
        let n = 7;
        let delta = scattered_deltas(&mut rng, n);
        let flux = Array1::from_shape_fn(n, |_| rng.range(50.0, 120.0));
        let background = Array1::from_elem(n, 3.0);
        let mut data = synth(&psi_true, oversample, &delta, &flux, &background);
        for i in 0..stamp_size {
            for j in 0..stamp_size {
                data[(3, i, j)] = f64::NAN;
            }
        }
        let result = build_epsf::<f64>(
            data.view(),
            None,
            delta.view(),
            oversample,
            Some(psi_true.view()),
            BuildEpsfParams {
                max_iter: 30,
                tol: 1e-9,
                nuisance_max_iter: 8,
                nuisance_tol: 1e-12,
                ..Default::default()
            },
        )
        .unwrap();
        assert!(!result.ok[3]);
        assert!(result.flux[3].is_nan());
        // The loop index shadows the star count `n`.
        for n in 0..n {
            if n != 3 {
                assert!(result.ok[n], "star {n} should be ok");
            }
        }
        assert!(result.epsf.iter().all(|&x| x >= 0.0 && x.is_finite()));
        assert!(rel_l2(&result.epsf, &unit_volume(&psi_true, oversample)) < 1e-3);
    }
    // A batch containing a single star still reconstructs its stamp.
    #[test]
    fn single_star_recovers() {
        let oversample = 5;
        let stamp_size = 7;
        let side = oversample * stamp_size;
        let psi_true = gaussian_epsf(side, oversample as f64 * 1.4, 100.0);
        let delta = Array2::from_shape_vec((1, 2), vec![0.12, -0.18]).unwrap();
        let flux = Array1::from_vec(vec![90.0]);
        let background = Array1::from_vec(vec![4.0]);
        let data = synth(&psi_true, oversample, &delta, &flux, &background);
        let result = build_epsf::<f64>(
            data.view(),
            None,
            delta.view(),
            oversample,
            Some(psi_true.view()),
            BuildEpsfParams {
                max_iter: 30,
                tol: 1e-9,
                nuisance_max_iter: 8,
                nuisance_tol: 1e-12,
                ..Default::default()
            },
        )
        .unwrap();
        assert!(result.ok[0]);
        assert!(recon_rel_error(&result, oversample, &data) < 1e-5);
    }
    // N == 0 is accepted and yields a full-size model with empty per-star outputs.
    #[test]
    fn empty_batch_is_legal() {
        let oversample = 5;
        let stamp_size = 7;
        let side = oversample * stamp_size;
        let data = Array3::<f64>::zeros((0, stamp_size, stamp_size));
        let delta = Array2::<f64>::zeros((0, 2));
        let result = build_epsf::<f64>(
            data.view(),
            None,
            delta.view(),
            oversample,
            None,
            BuildEpsfParams::default(),
        )
        .unwrap();
        assert_eq!(result.epsf.dim(), (side, side));
        assert_eq!(result.flux.len(), 0);
        assert_eq!(result.background.len(), 0);
        assert_eq!(result.delta.dim(), (0, 2));
        assert_eq!(result.ok.len(), 0);
        assert_eq!(result.iterations, 0);
        assert!(result.converged);
    }
    // f32 input is promoted to f64 internally; both paths should agree closely.
    #[test]
    fn f32_and_f64_paths_agree() {
        let oversample = 5;
        let stamp_size = 7;
        let side = oversample * stamp_size;
        let psi_true = gaussian_epsf(side, oversample as f64 * 1.4, 100.0);
        let mut rng = SplitMix64::new(0x5EED_1234_ABCD_0007);
        let n = 6;
        let delta = scattered_deltas(&mut rng, n);
        let flux = Array1::from_shape_fn(n, |_| rng.range(50.0, 110.0));
        let background = Array1::from_elem(n, 3.0);
        let data64 = synth(&psi_true, oversample, &delta, &flux, &background);
        let data32 = data64.mapv(|x| x as f32);
        let params = BuildEpsfParams {
            max_iter: 25,
            tol: 1e-7,
            nuisance_max_iter: 6,
            nuisance_tol: 1e-9,
            ..Default::default()
        };
        let r64 = build_epsf::<f64>(
            data64.view(),
            None,
            delta.view(),
            oversample,
            Some(psi_true.view()),
            params.clone(),
        )
        .unwrap();
        let r32 = build_epsf::<f32>(
            data32.view(),
            None,
            delta.view(),
            oversample,
            Some(psi_true.view()),
            params,
        )
        .unwrap();
        assert!(r64.ok.iter().all(|&o| o));
        assert!(r32.ok.iter().all(|&o| o));
        assert!(rel_l2(&r64.epsf, &r32.epsf) < 1e-3);
        assert!(rel_l2(&r64.epsf, &unit_volume(&psi_true, oversample)) < 1e-3);
    }
    // `Default` mirrors the public DEFAULT_* constants.
    #[test]
    fn default_params_are_sane() {
        let p = BuildEpsfParams::default();
        assert_eq!(p.max_iter, DEFAULT_MAX_ITER);
        assert_eq!(p.tol, DEFAULT_TOL);
        assert_eq!(p.step, DEFAULT_STEP);
        assert_eq!(p.residual_reweight, ResidualReweight::None);
        assert_eq!(p.nuisance_max_iter, DEFAULT_NUISANCE_MAX_ITER);
        assert_eq!(p.nuisance_tol, DEFAULT_NUISANCE_TOL);
    }
    /// Valid parameters for the error-path tests, so only the shape under test is wrong.
    fn ok_params() -> BuildEpsfParams {
        BuildEpsfParams::default()
    }
    // Even oversampling factors, including 0, are rejected.
    #[test]
    fn err_oversample_not_odd() {
        let data = Array3::<f64>::zeros((2, 7, 7));
        let delta = Array2::<f64>::zeros((2, 2));
        let err =
            build_epsf::<f64>(data.view(), None, delta.view(), 4, None, ok_params()).unwrap_err();
        assert_eq!(err, BuildEpsfError::OversampleNotOdd { oversample: 4 });
        let err0 =
            build_epsf::<f64>(data.view(), None, delta.view(), 0, None, ok_params()).unwrap_err();
        assert_eq!(err0, BuildEpsfError::OversampleNotOdd { oversample: 0 });
    }
    // An even stamp side is rejected.
    #[test]
    fn err_stamp_size_even() {
        let data = Array3::<f64>::zeros((2, 6, 6));
        let delta = Array2::<f64>::zeros((2, 2));
        let err =
            build_epsf::<f64>(data.view(), None, delta.view(), 5, None, ok_params()).unwrap_err();
        assert_eq!(err, BuildEpsfError::StampSizeEven { stamp_size: 6 });
    }
    // Rejected when stamps are not square, or when `delta_init` has the wrong star count.
    #[test]
    fn err_batch_length_mismatch() {
        let data = Array3::<f64>::zeros((2, 7, 5));
        let delta = Array2::<f64>::zeros((2, 2));
        let err =
            build_epsf::<f64>(data.view(), None, delta.view(), 5, None, ok_params()).unwrap_err();
        assert_eq!(
            err,
            BuildEpsfError::BatchLengthMismatch {
                data: (2, 7, 5),
                delta_init: (2, 2),
            }
        );
        let data2 = Array3::<f64>::zeros((2, 7, 7));
        let delta_bad = Array2::<f64>::zeros((3, 2));
        let err2 = build_epsf::<f64>(data2.view(), None, delta_bad.view(), 5, None, ok_params())
            .unwrap_err();
        assert_eq!(
            err2,
            BuildEpsfError::BatchLengthMismatch {
                data: (2, 7, 7),
                delta_init: (3, 2),
            }
        );
    }
    // Weights must have exactly the data shape.
    #[test]
    fn err_weight_shape_mismatch() {
        let data = Array3::<f64>::zeros((2, 7, 7));
        let delta = Array2::<f64>::zeros((2, 2));
        let weight = Array3::<f64>::ones((2, 7, 5));
        let err = build_epsf::<f64>(
            data.view(),
            Some(weight.view()),
            delta.view(),
            5,
            None,
            ok_params(),
        )
        .unwrap_err();
        assert_eq!(
            err,
            BuildEpsfError::WeightShapeMismatch {
                weight: (2, 7, 5),
                data: (2, 7, 7),
            }
        );
    }
    // `psi_init` must lie on the `(oversample * s)^2` grid.
    #[test]
    fn err_psi_init_shape_mismatch() {
        let data = Array3::<f64>::zeros((2, 7, 7));
        let delta = Array2::<f64>::zeros((2, 2));
        let psi_bad = Array2::<f64>::zeros((34, 35));
        let err = build_epsf::<f64>(
            data.view(),
            None,
            delta.view(),
            5,
            Some(psi_bad.view()),
            ok_params(),
        )
        .unwrap_err();
        assert_eq!(
            err,
            BuildEpsfError::PsiInitShapeMismatch {
                psi_init: (34, 35),
                expected: (35, 35),
            }
        );
    }
    // Each out-of-range parameter field yields `ParamsInvalid`; a valid Huber `c` is accepted.
    #[test]
    fn err_params_invalid() {
        let data = Array3::<f64>::zeros((2, 7, 7));
        let delta = Array2::<f64>::zeros((2, 2));
        let cases = [
            BuildEpsfParams {
                max_iter: 0,
                ..Default::default()
            },
            BuildEpsfParams {
                tol: 0.0,
                ..Default::default()
            },
            BuildEpsfParams {
                tol: f64::NAN,
                ..Default::default()
            },
            BuildEpsfParams {
                step: -1.0,
                ..Default::default()
            },
            BuildEpsfParams {
                step: f64::INFINITY,
                ..Default::default()
            },
            BuildEpsfParams {
                nuisance_max_iter: 0,
                ..Default::default()
            },
            BuildEpsfParams {
                nuisance_tol: -3.0,
                ..Default::default()
            },
            BuildEpsfParams {
                residual_reweight: ResidualReweight::Huber { c: 0.0 },
                ..Default::default()
            },
            BuildEpsfParams {
                residual_reweight: ResidualReweight::Tukey { c: f64::NAN },
                ..Default::default()
            },
        ];
        for params in cases {
            let err =
                build_epsf::<f64>(data.view(), None, delta.view(), 5, None, params).unwrap_err();
            assert!(matches!(err, BuildEpsfError::ParamsInvalid { .. }));
        }
        let okp = BuildEpsfParams {
            residual_reweight: ResidualReweight::Huber { c: 3.0 },
            ..Default::default()
        };
        assert!(build_epsf::<f64>(data.view(), None, delta.view(), 5, None, okp).is_ok());
    }
    // With both an even oversample and invalid params, the shape error is reported first.
    #[test]
    fn shape_preconditions_precede_params_invalid() {
        let data = Array3::<f64>::zeros((2, 7, 7));
        let delta = Array2::<f64>::zeros((2, 2));
        let err = build_epsf::<f64>(
            data.view(),
            None,
            delta.view(),
            4,
            None,
            BuildEpsfParams {
                max_iter: 0,
                ..Default::default()
            },
        )
        .unwrap_err();
        assert_eq!(err, BuildEpsfError::OversampleNotOdd { oversample: 4 });
    }
}