use ndarray::{Array3, ArrayView1, ArrayView2, ArrayViewMut2, Axis};
use rayon::prelude::*;
use thiserror::Error;
use super::kernel::catmull_rom_sample;
#[derive(Debug, Error, PartialEq)]
pub enum RenderError {
#[error("epsf must be square; got ({rows}, {cols})")]
EpsfNotSquare { rows: usize, cols: usize },
#[error("oversample must be odd; got {oversample}")]
OversampleNotOdd { oversample: usize },
#[error("epsf side ({epsf_side}) must be an integer multiple of oversample ({oversample})")]
EpsfSizeNotMultiple { epsf_side: usize, oversample: usize },
#[error("derived stamp_size (epsf_side / oversample) must be odd; got {stamp_size}")]
DerivedStampSizeEven { stamp_size: usize },
#[error(
"batch dimensions disagree: delta shape {delta:?} must be (N, 2) with N == flux len ({flux}) == background len ({background})"
)]
BatchLengthMismatch {
delta: (usize, usize),
flux: usize,
background: usize,
},
}
/// Renders one detector-resolution stamp per batch entry from an oversampled ePSF model.
///
/// For batch entry `n`, output pixel `(i, j)` is
/// `flux[n] * catmull_rom_sample(epsf, k_u, k_v) + background[n]`, where `(k_u, k_v)` is the
/// model-grid coordinate of that pixel after shifting by `delta[n]`. Stamps are filled in
/// parallel via rayon.
///
/// # Arguments
/// * `epsf` - square model grid, `oversample` samples per detector pixel along each axis.
/// * `oversample` - odd oversampling factor.
/// * `delta` - `(N, 2)` array; column 0 is the row offset and column 1 the column offset,
///   both scaled by `oversample` before sampling (i.e. expressed in detector pixels).
/// * `flux` - length-`N` multiplicative scale applied to the sampled model.
/// * `background` - length-`N` additive constant per stamp.
///
/// # Returns
/// An `(N, stamp_size, stamp_size)` array with `stamp_size = epsf_side / oversample`.
///
/// # Errors
/// Checks run in the order below and the first failure is returned:
/// * [`RenderError::EpsfNotSquare`] - `epsf` rows and columns differ.
/// * [`RenderError::OversampleNotOdd`] - `oversample` is even (zero included).
/// * [`RenderError::EpsfSizeNotMultiple`] - the model side is not a multiple of `oversample`.
/// * [`RenderError::DerivedStampSizeEven`] - the derived stamp size is even.
/// * [`RenderError::BatchLengthMismatch`] - `delta`, `flux` and `background` disagree on `N`,
///   or `delta` does not have two columns.
pub fn render(
    epsf: ArrayView2<f64>,
    oversample: usize,
    delta: ArrayView2<f64>,
    flux: ArrayView1<f64>,
    background: ArrayView1<f64>,
) -> Result<Array3<f64>, RenderError> {
    // --- Model geometry -------------------------------------------------------------------
    let (rows, cols) = epsf.dim();
    if rows != cols {
        return Err(RenderError::EpsfNotSquare { rows, cols });
    }
    // Zero counts as a multiple of two, so it is rejected here and never reaches the division.
    if oversample.is_multiple_of(2) {
        return Err(RenderError::OversampleNotOdd { oversample });
    }
    let epsf_side = rows;
    if !epsf_side.is_multiple_of(oversample) {
        return Err(RenderError::EpsfSizeNotMultiple {
            epsf_side,
            oversample,
        });
    }
    let stamp_size = epsf_side / oversample;
    if stamp_size.is_multiple_of(2) {
        return Err(RenderError::DerivedStampSizeEven { stamp_size });
    }

    // --- Batch shape ----------------------------------------------------------------------
    let batch_size = flux.len();
    let (delta_rows, delta_cols) = delta.dim();
    let shapes_agree =
        delta_cols == 2 && delta_rows == batch_size && background.len() == batch_size;
    if !shapes_agree {
        return Err(RenderError::BatchLengthMismatch {
            delta: (delta_rows, delta_cols),
            flux: batch_size,
            background: background.len(),
        });
    }

    // --- Rendering ------------------------------------------------------------------------
    // Centre index of a grid with `side` samples (fractional when `side` is even).
    let grid_center = |side: usize| (side as f64 - 1.0) / 2.0;
    let psf_center = grid_center(epsf_side);
    let detector_center = grid_center(stamp_size);
    let oversample_f = oversample as f64;

    let mut output = Array3::<f64>::zeros((batch_size, stamp_size, stamp_size));
    output
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .for_each(|(n, stamp_out)| {
            let shift = delta.row(n);
            render_stamp(
                stamp_out,
                epsf,
                shift[0],
                shift[1],
                flux[n],
                background[n],
                oversample_f,
                stamp_size,
                psf_center,
                detector_center,
            );
        });
    Ok(output)
}
/// Fills a single `stamp_size` x `stamp_size` stamp in place.
///
/// Detector pixel `(i, j)` is mapped to the model-grid coordinate
/// `psf_center + oversample * ((index - detector_center) - delta)` along each axis, the model
/// is sampled there with `catmull_rom_sample`, and the result is scaled by `flux` and offset
/// by `background`.
// Private helper: every precomputed scalar is passed directly, hence the long argument list.
#[allow(clippy::too_many_arguments)]
fn render_stamp(
    mut stamp_out: ArrayViewMut2<f64>,
    epsf: ArrayView2<f64>,
    delta_row: f64,
    delta_column: f64,
    flux: f64,
    background: f64,
    oversample: f64,
    stamp_size: usize,
    psf_center: f64,
    detector_center: f64,
) {
    // Detector pixel index -> fractional coordinate on the oversampled model grid.
    let model_coord = |pixel: usize, shift: f64| {
        psf_center + oversample * ((pixel as f64 - detector_center) - shift)
    };
    for row in 0..stamp_size {
        let k_u = model_coord(row, delta_row);
        for col in 0..stamp_size {
            let k_v = model_coord(col, delta_column);
            let model_value = catmull_rom_sample(&epsf, k_u, k_v);
            stamp_out[(row, col)] = flux * model_value + background;
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::image::psf::kernel::{catmull_rom_sample, catmull_rom_weights};
use ndarray::{Array1, Array2, arr2};
/// Absolute tolerance used by the weight-sum and flat-background assertions in this module.
const TOL: f64 = 1e-9;
/// Builds an `epsf_side` x `epsf_side` circular Gaussian centred on `(epsf_side - 1) / 2`,
/// with peak value `amplitude` and standard deviation `sigma_model` in grid samples.
fn gaussian_epsf(epsf_side: usize, sigma_model: f64, amplitude: f64) -> Array2<f64> {
    let midpoint = (epsf_side as f64 - 1.0) / 2.0;
    Array2::<f64>::from_shape_fn((epsf_side, epsf_side), |(p, q)| {
        let dp = p as f64 - midpoint;
        let dq = q as f64 - midpoint;
        amplitude * (-0.5 * (dp * dp + dq * dq) / (sigma_model * sigma_model)).exp()
    })
}
/// Builds an `epsf_side` x `epsf_side` model with value `1 + p + 3q + 0.5pq` at `(p, q)`.
///
/// The unequal row and column slopes make the model differ from its transpose, so tests can
/// detect swapped axes.
fn ramp_epsf(epsf_side: usize) -> Array2<f64> {
    Array2::<f64>::from_shape_fn((epsf_side, epsf_side), |(p, q)| {
        1.0 + p as f64 + 3.0 * q as f64 + 0.5 * (p * q) as f64
    })
}
/// The four kernel weights must pick the second tap exactly at zero offset and sum to one at
/// every fractional offset tried.
#[test]
fn catmull_rom_weights_partition_of_unity_and_integer_pick() {
    assert_eq!(catmull_rom_weights(0.0), [0.0, 1.0, 0.0, 0.0]);
    for frac in [0.1, 0.25, 0.5, 0.7, 0.9999] {
        let sum: f64 = catmull_rom_weights(frac).iter().sum();
        assert!((sum - 1.0).abs() < TOL, "sum = {sum} at frac = {frac}");
    }
}
/// With zero offset every stamp pixel lands exactly on a model sample, so the render is
/// compared against direct indexing of the model.
#[test]
fn delta_zero_is_exact_grid_extraction() {
    let oversample = 5;
    let stamp_size = 7;
    let epsf_side = oversample * stamp_size;
    let epsf = ramp_epsf(epsf_side);
    let delta = arr2(&[[0.0, 0.0]]);
    let flux = Array1::from_elem(1, 1.0);
    let background = Array1::from_elem(1, 0.0);
    let out = render(
        epsf.view(),
        oversample,
        delta.view(),
        flux.view(),
        background.view(),
    )
    .unwrap();
    assert_eq!(out.shape(), &[1, stamp_size, stamp_size]);

    // Both side lengths are odd here, so the integer halves are the exact centre indices.
    let psf_center = (epsf_side - 1) / 2;
    let detector_center = (stamp_size - 1) / 2;
    // Model index hit by stamp pixel 0; consecutive stamp pixels are `oversample` samples apart.
    let first_sample = psf_center - oversample * detector_center;
    for i in 0..stamp_size {
        for j in 0..stamp_size {
            let p = first_sample + oversample * i;
            let q = first_sample + oversample * j;
            assert!(
                (out[(0, i, j)] - epsf[(p, q)]).abs() < 1e-12,
                "({i},{j}) rendered {} != model {}",
                out[(0, i, j)],
                epsf[(p, q)]
            );
        }
    }
}
/// Rendering a Gaussian model with a sub-pixel offset must approximate the same Gaussian
/// evaluated analytically at the shifted detector coordinates.
#[test]
fn subpixel_shift_matches_analytic_gaussian() {
    let oversample = 5;
    let stamp_size = 9;
    let sigma_det = 1.5;
    let amplitude = 50.0;
    let (delta_row, delta_col) = (0.3, -0.2);
    let (flux, back) = (3.0, 2.0);

    // The model width is the detector width expressed in model samples.
    let epsf = gaussian_epsf(
        oversample * stamp_size,
        oversample as f64 * sigma_det,
        amplitude,
    );
    let delta = arr2(&[[delta_row, delta_col]]);
    let flux_in = Array1::from_elem(1, flux);
    let back_in = Array1::from_elem(1, back);
    let out = render(
        epsf.view(),
        oversample,
        delta.view(),
        flux_in.view(),
        back_in.view(),
    )
    .unwrap();

    let detector_center = (stamp_size as f64 - 1.0) / 2.0;
    let max_abs_error = (0..stamp_size)
        .flat_map(|i| (0..stamp_size).map(move |j| (i, j)))
        .map(|(i, j)| {
            let du = (i as f64 - detector_center) - delta_row;
            let dv = (j as f64 - detector_center) - delta_col;
            let expected = flux
                * amplitude
                * (-0.5 * (du * du + dv * dv) / (sigma_det * sigma_det)).exp()
                + back;
            (out[(0, i, j)] - expected).abs()
        })
        .fold(0.0_f64, f64::max);

    // Interpolation error is bounded relative to the peak signal, not per pixel.
    assert!(
        max_abs_error < 5e-3 * flux * amplitude,
        "max abs error = {max_abs_error}"
    );
}
/// A render with arbitrary flux and background must equal `flux * unit_render + background`,
/// where `unit_render` uses flux 1 and background 0.
#[test]
fn flux_and_background_are_linear() {
    let oversample = 3;
    let stamp_size = 5;
    let epsf = gaussian_epsf(oversample * stamp_size, oversample as f64 * 1.2, 10.0);
    let delta = arr2(&[[0.27, -0.13]]);

    // Renders the single stamp with the given photometric scale and offset.
    let render_with = |flux: f64, back: f64| {
        render(
            epsf.view(),
            oversample,
            delta.view(),
            Array1::from_elem(1, flux).view(),
            Array1::from_elem(1, back).view(),
        )
        .unwrap()
    };

    let unit = render_with(1.0, 0.0);
    let flux = 7.5;
    let back = -3.25;
    let scaled = render_with(flux, back);

    for i in 0..stamp_size {
        for j in 0..stamp_size {
            let expected = flux * unit[(0, i, j)] + back;
            let difference = (scaled[(0, i, j)] - expected).abs();
            assert!(
                difference < 1e-12,
                "({i},{j}) {} != {}",
                scaled[(0, i, j)],
                expected
            );
        }
    }
}
/// When the offset pushes every sample coordinate far outside the model grid, the stamp is
/// expected to contain only the background level.
#[test]
fn source_pushed_off_grid_renders_flat_background() {
    let oversample = 5;
    let stamp_size = 7;
    let epsf = gaussian_epsf(oversample * stamp_size, oversample as f64 * 1.5, 100.0);
    // Offsets of +/-1000 detector pixels dwarf the 7-pixel stamp.
    let delta = arr2(&[[1000.0, -1000.0]]);
    let flux = Array1::from_elem(1, 9.0);
    let back = 4.0;
    let background = Array1::from_elem(1, back);

    let out = render(
        epsf.view(),
        oversample,
        delta.view(),
        flux.view(),
        background.view(),
    )
    .unwrap();

    for &value in &out {
        assert!((value - back).abs() < TOL, "value = {value}");
    }
}
/// Sampling a constant model of ones must return one at every interior stamp pixel, and a
/// finite value everywhere (border pixels are only checked for finiteness).
#[test]
fn uniform_model_interior_is_partition_of_unity() {
    let oversample = 5;
    let stamp_size = 9;
    let epsf_side = oversample * stamp_size;
    let epsf = Array2::<f64>::from_elem((epsf_side, epsf_side), 1.0);
    let delta = arr2(&[[0.3, 0.3]]);
    let flux = Array1::from_elem(1, 1.0);
    let background = Array1::from_elem(1, 0.0);

    let out = render(
        epsf.view(),
        oversample,
        delta.view(),
        flux.view(),
        background.view(),
    )
    .unwrap();

    out.iter()
        .for_each(|value| assert!(value.is_finite(), "non-finite render value {value}"));

    // Skip the outermost ring of pixels.
    let interior = 1..stamp_size - 1;
    for i in interior.clone() {
        for j in interior.clone() {
            let deviation = (out[(0, i, j)] - 1.0).abs();
            assert!(deviation < 1e-12, "interior ({i},{j}) = {}", out[(0, i, j)]);
        }
    }
}
/// Each stamp of a three-entry batch must match the stamp obtained by rendering that entry
/// alone.
#[test]
fn batch_stamps_are_independent() {
    let oversample = 5;
    let stamp_size = 7;
    let epsf = gaussian_epsf(oversample * stamp_size, oversample as f64 * 1.4, 80.0);
    let delta = arr2(&[[0.10, -0.20], [-0.35, 0.05], [0.40, 0.40]]);
    let flux = Array1::from(vec![1.0, 2.5, 0.7]);
    let background = Array1::from(vec![0.0, -1.0, 3.0]);

    let batch = render(
        epsf.view(),
        oversample,
        delta.view(),
        flux.view(),
        background.view(),
    )
    .unwrap();
    assert_eq!(batch.shape(), &[3, stamp_size, stamp_size]);

    for n in 0..flux.len() {
        // Length-one slices keep the batch axis so `render` accepts them unchanged.
        let single = render(
            epsf.view(),
            oversample,
            delta.slice(ndarray::s![n..n + 1, ..]),
            flux.slice(ndarray::s![n..n + 1]),
            background.slice(ndarray::s![n..n + 1]),
        )
        .unwrap();
        for i in 0..stamp_size {
            for j in 0..stamp_size {
                let difference = (batch[(n, i, j)] - single[(0, i, j)]).abs();
                assert!(
                    difference < 1e-12,
                    "stamp {n} ({i},{j}) differs from single render"
                );
            }
        }
    }
}
/// A 10 x 15 model is rejected and the error reports both dimensions.
#[test]
fn error_epsf_not_square() {
    let epsf = Array2::<f64>::zeros((10, 15));
    let delta = arr2(&[[0.0, 0.0]]);
    let flux = Array1::from_elem(1, 1.0);
    let background = Array1::from_elem(1, 0.0);

    let outcome = render(epsf.view(), 5, delta.view(), flux.view(), background.view());

    let expected = RenderError::EpsfNotSquare { rows: 10, cols: 15 };
    assert_eq!(outcome.unwrap_err(), expected);
}
/// An even oversampling factor (4) is rejected.
#[test]
fn error_oversample_not_odd() {
    let epsf = Array2::<f64>::zeros((20, 20));
    let delta = arr2(&[[0.0, 0.0]]);
    let flux = Array1::from_elem(1, 1.0);
    let background = Array1::from_elem(1, 0.0);

    let outcome = render(epsf.view(), 4, delta.view(), flux.view(), background.view());

    let expected = RenderError::OversampleNotOdd { oversample: 4 };
    assert_eq!(outcome.unwrap_err(), expected);
}
/// A model side of 34 is not divisible by an oversampling factor of 5 and is rejected.
#[test]
fn error_epsf_size_not_multiple() {
    let epsf = Array2::<f64>::zeros((34, 34));
    let delta = arr2(&[[0.0, 0.0]]);
    let flux = Array1::from_elem(1, 1.0);
    let background = Array1::from_elem(1, 0.0);

    let outcome = render(epsf.view(), 5, delta.view(), flux.view(), background.view());

    let expected = RenderError::EpsfSizeNotMultiple {
        epsf_side: 34,
        oversample: 5,
    };
    assert_eq!(outcome.unwrap_err(), expected);
}
/// An 18-sample model with oversampling 3 gives a 6-pixel stamp, which is rejected as even.
#[test]
fn error_derived_stamp_size_even() {
    let epsf = Array2::<f64>::zeros((18, 18));
    let delta = arr2(&[[0.0, 0.0]]);
    let flux = Array1::from_elem(1, 1.0);
    let background = Array1::from_elem(1, 0.0);

    let outcome = render(epsf.view(), 3, delta.view(), flux.view(), background.view());

    let expected = RenderError::DerivedStampSizeEven { stamp_size: 6 };
    assert_eq!(outcome.unwrap_err(), expected);
}
/// Two offset rows with only one flux and one background value are rejected as a batch
/// mismatch.
#[test]
fn error_batch_length_mismatch_row_count() {
    let epsf = Array2::<f64>::zeros((35, 35));
    let delta = arr2(&[[0.0, 0.0], [0.1, 0.1]]);
    let flux = Array1::from_elem(1, 1.0);
    let background = Array1::from_elem(1, 0.0);

    let outcome = render(epsf.view(), 5, delta.view(), flux.view(), background.view());

    let expected = RenderError::BatchLengthMismatch {
        delta: (2, 2),
        flux: 1,
        background: 1,
    };
    assert_eq!(outcome.unwrap_err(), expected);
}
/// An offset array with three columns is rejected even though its row count matches.
#[test]
fn error_batch_length_mismatch_delta_not_two_columns() {
    let epsf = Array2::<f64>::zeros((35, 35));
    let delta = arr2(&[[0.0, 0.0, 0.0]]);
    let flux = Array1::from_elem(1, 1.0);
    let background = Array1::from_elem(1, 0.0);

    let outcome = render(epsf.view(), 5, delta.view(), flux.view(), background.view());

    let expected = RenderError::BatchLengthMismatch {
        delta: (1, 3),
        flux: 1,
        background: 1,
    };
    assert_eq!(outcome.unwrap_err(), expected);
}
/// A batch of zero stamps succeeds and returns a `(0, 7, 7)` array for a 35-sample model with
/// oversampling 5.
#[test]
fn empty_batch_yields_empty_output() {
    let epsf = Array2::<f64>::zeros((35, 35));
    let delta = Array2::<f64>::zeros((0, 2));
    let flux = Array1::<f64>::zeros(0);
    let background = Array1::<f64>::zeros(0);

    let out = render(epsf.view(), 5, delta.view(), flux.view(), background.view()).unwrap();

    assert_eq!(out.shape(), &[0, 7, 7]);
}
/// End-to-end check: the offset recovered by `build_stamp` from a synthetic Gaussian source is
/// fed to `render`, and the rendered stamp must reproduce the extracted one.
#[test]
fn build_stamp_delta_round_trip() {
    use crate::image::stamp::{
        DEFAULT_CENTROID_MAX_ITER, DEFAULT_CENTROID_TOL, DEFAULT_WEIGHT_FWHM, build_stamp,
    };

    // Synthetic scene: one Gaussian source at a non-integer position over a flat background.
    let (true_row, true_col) = (7.35_f64, 6.80_f64);
    let amplitude = 100.0;
    let sigma_det = 1.6;
    let back = 5.0_f64;
    let cutout_size = 15;
    let cutout = Array2::<f64>::from_shape_fn((cutout_size, cutout_size), |(r, c)| {
        let dr = r as f64 - true_row;
        let dc = c as f64 - true_col;
        back + amplitude * (-0.5 * (dr * dr + dc * dc) / (sigma_det * sigma_det)).exp()
    });

    let stamp_size = 7;
    let stamp_result = build_stamp(
        cutout.view(),
        stamp_size,
        None,
        None,
        DEFAULT_WEIGHT_FWHM,
        DEFAULT_CENTROID_MAX_ITER,
        DEFAULT_CENTROID_TOL,
    )
    .unwrap()
    .expect("expected a stamp for an isolated source");

    // Reconstructed centroid = origin + centre pixel index + sub-pixel offset.
    // NOTE(review): presumably `origin` is the cutout index of the stamp's (0, 0) pixel - confirm
    // against `build_stamp`.
    let detector_center = (stamp_size as i64 - 1) / 2;
    let recon_row = (stamp_result.origin.0 + detector_center) as f64 + stamp_result.delta.0;
    let recon_col = (stamp_result.origin.1 + detector_center) as f64 + stamp_result.delta.1;
    let row_ok = (recon_row - true_row).abs() < 0.05;
    let col_ok = (recon_col - true_col).abs() < 0.05;
    assert!(
        row_ok && col_ok,
        "reconstructed centroid ({recon_row}, {recon_col}) far from ({true_row}, {true_col})"
    );

    // Render the matching Gaussian model with the recovered offset.
    let oversample = 5;
    let epsf = gaussian_epsf(
        oversample * stamp_size,
        oversample as f64 * sigma_det,
        amplitude,
    );
    let delta = arr2(&[[stamp_result.delta.0, stamp_result.delta.1]]);
    let flux_in = Array1::from_elem(1, 1.0);
    let back_in = Array1::from_elem(1, back);
    let rendered = render(
        epsf.view(),
        oversample,
        delta.view(),
        flux_in.view(),
        back_in.view(),
    )
    .unwrap();

    let max_abs_error = (0..stamp_size)
        .flat_map(|i| (0..stamp_size).map(move |j| (i, j)))
        .map(|(i, j)| (rendered[(0, i, j)] - stamp_result.stamp[(i, j)]).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_abs_error < 0.02 * amplitude,
        "round-trip max abs error = {max_abs_error}"
    );
}
/// Guards against swapped row/column handling: with an asymmetric model and unequal offsets,
/// the corner pixel must equal a direct `catmull_rom_sample` call made in (row, column) order.
#[test]
fn rejects_transposed_indexing_with_asymmetric_delta() {
    let oversample = 5;
    let stamp_size = 5;
    let epsf_side = oversample * stamp_size;
    let epsf = ramp_epsf(epsf_side);
    let (shift_row, shift_col) = (0.4, -0.1);
    let delta = arr2(&[[shift_row, shift_col]]);
    let flux = Array1::from_elem(1, 1.0);
    let background = Array1::from_elem(1, 0.0);

    let out = render(
        epsf.view(),
        oversample,
        delta.view(),
        flux.view(),
        background.view(),
    )
    .unwrap();

    // Model coordinate of stamp pixel 0 along one axis for the given offset.
    let psf_center = (epsf_side as f64 - 1.0) / 2.0;
    let detector_center = (stamp_size as f64 - 1.0) / 2.0;
    let scale = oversample as f64;
    let pixel_zero_coord = |shift: f64| psf_center + scale * ((0.0 - detector_center) - shift);
    let k_u = pixel_zero_coord(shift_row);
    let k_v = pixel_zero_coord(shift_col);

    let expected = catmull_rom_sample(&epsf.view(), k_u, k_v);
    assert!(
        (out[(0, 0, 0)] - expected).abs() < 1e-12,
        "{} != {}",
        out[(0, 0, 0)],
        expected
    );
    // Fixture sanity: the ramp really does differ from its transpose.
    assert!((epsf[(0, 1)] - epsf[(1, 0)]).abs() > 1.0);
}
}