use ndarray::{Array1, Array2, Array3, ArrayView2, ArrayView3, Axis};
use thiserror::Error;
use crate::float::Float;
use crate::image::stats::median_in_place;
use super::kernel::catmull_rom_sample;
use super::nuisance::solve_flux_background;
use super::robust::{CombineMethod, robust_combine};
/// Number of equally spaced points on the circle used by the azimuthal averages.
const NUM_AZIMUTH: usize = 720;

/// Default seam radius, in native (wing) pixels, where core and wing are flux-matched.
pub const DEFAULT_MATCH_RADIUS: f64 = 8.0;
/// Default width, in native pixels, of the raised-cosine blend centred on the seam radius.
pub const DEFAULT_FEATHER_WIDTH: f64 = 4.0;
/// Default radius, in native pixels, of the aperture whose summed model is normalized.
pub const DEFAULT_EE_APERTURE_RADIUS: f64 = 15.0;
/// Default radius, in native pixels, of the fallback photometric aperture for per-star flux.
pub const DEFAULT_SCALE_APERTURE_RADIUS: f64 = 6.0;
/// Default `(inner, outer)` radii, in native pixels, of the per-star background annulus.
pub const DEFAULT_SCALE_BACKGROUND_ANNULUS: (f64, f64) = (18.0, 24.0);

/// Geometry of the core/wing seam used by `stitch_psf`.
///
/// All radii are measured in native (wing) pixels from the PSF centre.
#[derive(Debug, Clone, PartialEq)]
pub struct StitchParams {
    /// Radius at which the wing is scaled to match the core's azimuthal average.
    pub match_radius: f64,
    /// Width of the blend; the wing weight ramps from 0 to 1 across
    /// `match_radius ± feather_width / 2`.
    pub feather_width: f64,
    /// Radius of the aperture inside which the stitched model is summed for normalization.
    pub ee_aperture_radius: f64,
}

impl Default for StitchParams {
    /// Builds the parameters from the `DEFAULT_*` module constants.
    fn default() -> Self {
        let match_radius = DEFAULT_MATCH_RADIUS;
        let feather_width = DEFAULT_FEATHER_WIDTH;
        let ee_aperture_radius = DEFAULT_EE_APERTURE_RADIUS;
        Self {
            match_radius,
            feather_width,
            ee_aperture_radius,
        }
    }
}
/// Parameters for `build_extended_psf`.
///
/// Radii are in native (wing) pixels measured from the stamp centre.
#[derive(Debug, Clone, PartialEq)]
pub struct ExtendedPsfParams {
    /// Seam geometry used for the final core/wing stitch.
    pub stitch: StitchParams,
    /// How the normalized star stamps are combined into one wing image.
    pub combine: CombineMethod,
    /// Radius of the fallback photometric aperture used when the core fit gives no flux.
    pub scale_aperture_radius: f64,
    /// `(r_in, r_out)` radii of the annulus whose pixels estimate each star's background.
    pub scale_background_annulus: (f64, f64),
}

impl Default for ExtendedPsfParams {
    /// Default stitch geometry, `ClippedMean { kappa: 3.0, max_iter: 5 }` combining, and the
    /// `DEFAULT_SCALE_*` aperture/annulus constants.
    fn default() -> Self {
        let combine = CombineMethod::ClippedMean {
            kappa: 3.0,
            max_iter: 5,
        };
        Self {
            stitch: Default::default(),
            combine,
            scale_aperture_radius: DEFAULT_SCALE_APERTURE_RADIUS,
            scale_background_annulus: DEFAULT_SCALE_BACKGROUND_ANNULUS,
        }
    }
}
/// A stitched PSF: an oversampled core plus a native-resolution wing plane, both multiplied by
/// the same global normalization factor, together with the seam parameters used to build it.
///
/// At native offset `(dr, dc)` with radius `r`, the model is
/// `(1 - f_wing(r)) * core(dr, dc) + wing[(row, col)]`, where `f_wing` is the raised-cosine
/// weight set by `match_radius` and `feather_width`. The wing plane already has `f_wing`
/// multiplied in.
#[derive(Debug, Clone, PartialEq)]
pub struct ExtendedPsf {
/// The input core multiplied by the global normalization factor (same shape as the input).
pub core: Array2<f64>,
/// Oversampling factor of `core` relative to the wing's native pixel grid.
pub oversample: usize,
/// Feathered, flux-matched wing multiplied by the global normalization factor. It is exactly
/// 0 wherever `r <= match_radius - feather_width / 2` and wherever the input wing was non-finite.
pub wing: Array2<f64>,
/// Seam radius (native pixels) at which wing and core were matched.
pub match_radius: f64,
/// Width (native pixels) of the blend around `match_radius`.
pub feather_width: f64,
/// Radius (native pixels) of the aperture whose summed model is scaled to 1. The scaling is
/// applied only when that sum is finite and positive; otherwise the factor is left at 1.
pub ee_aperture_radius: f64,
}
/// Result of `build_extended_psf`: the stitched PSF plus per-star diagnostics.
///
/// Each per-star array has one entry per input star, in input order.
#[derive(Debug, Clone, PartialEq)]
pub struct ExtendedPsfBuilt {
/// PSF stitched from `core` and the robustly combined wing stack.
pub extended: ExtendedPsf,
/// Flux scale each star's stamp was divided by; NaN for stars left out of the stack.
pub star_flux: Array1<f64>,
/// Background (from the annulus pixels, via `median_in_place`) subtracted from each star;
/// NaN for stars left out of the stack.
pub star_background: Array1<f64>,
/// `true` when the star received a usable scale and was added to the stack.
pub star_ok: Array1<bool>,
/// `true` when the scale came from the core fit (`solve_flux_background`), `false` when it came
/// from aperture photometry or the star was rejected.
pub star_scale_from_core: Array1<bool>,
}
/// Reasons `stitch_psf` rejects its inputs.
#[derive(Debug, Error, PartialEq)]
pub enum StitchError {
/// `oversample` is even; 0 counts as even and is rejected here too.
#[error("oversample must be odd; got {oversample}")]
OversampleNotOdd { oversample: usize },
/// The core array is not square.
#[error("core must be square; got ({rows}, {cols})")]
CoreNotSquare { rows: usize, cols: usize },
/// The core side length is not divisible by `oversample`.
#[error("core side ({core_side}) must be an integer multiple of oversample ({oversample})")]
CoreSizeNotMultiple { core_side: usize, oversample: usize },
/// `core_side / oversample` is even, so the native stamp has no central pixel.
#[error("derived stamp_size (core_side / oversample) must be odd; got {stamp_size}")]
DerivedStampSizeEven { stamp_size: usize },
/// A wing dimension is even, so the wing has no central pixel.
#[error("wing must have odd dimensions to have a defined center; got ({rows}, {cols})")]
WingNotOdd { rows: usize, cols: usize },
/// The wing has odd dimensions but is not square.
#[error("wing must be square; got ({rows}, {cols})")]
WingNotSquare { rows: usize, cols: usize },
/// The optional confidence map does not have the wing's shape.
#[error("wing_confidence shape {confidence:?} must equal wing shape {wing:?}")]
WingConfidenceShapeMismatch {
confidence: (usize, usize),
wing: (usize, usize),
},
/// A stitch parameter is non-finite or not positive, or the seam/aperture does not fit inside
/// the core and wing native extents.
#[error(
"stitch params must be finite and > 0 with 0 <= match_radius - feather_width/2 and match_radius + feather_width/2 within both the core and wing native extent and ee_aperture_radius within the wing native extent; got match_radius = {match_radius}, feather_width = {feather_width}, ee_aperture_radius = {ee_aperture_radius}"
)]
StitchParamsInvalid {
match_radius: f64,
feather_width: f64,
ee_aperture_radius: f64,
},
}
/// Reasons `build_extended_psf` rejects its inputs.
#[derive(Debug, Error, PartialEq)]
pub enum ExtendedPsfError {
/// `oversample` is even; 0 counts as even and is rejected here too.
#[error("oversample must be odd; got {oversample}")]
OversampleNotOdd { oversample: usize },
/// The core array is not square.
#[error("core must be square; got ({rows}, {cols})")]
CoreNotSquare { rows: usize, cols: usize },
/// The core side length is not divisible by `oversample`.
#[error("core side ({core_side}) must be an integer multiple of oversample ({oversample})")]
CoreSizeNotMultiple { core_side: usize, oversample: usize },
/// `core_side / oversample` is even, so the native stamp has no central pixel.
#[error("derived stamp_size (core_side / oversample) must be odd; got {stamp_size}")]
DerivedStampSizeEven { stamp_size: usize },
/// A spatial dimension (axis 1 or 2) of `wing_data` is even.
#[error(
"wing_data must have odd spatial dimensions to have a defined center; got ({rows}, {cols})"
)]
WingNotOdd { rows: usize, cols: usize },
/// The spatial dimensions of `wing_data` are odd but not equal.
#[error("wing_data spatial dimensions must be square; got ({rows}, {cols})")]
WingNotSquare { rows: usize, cols: usize },
/// `wing_delta` is not `(M, 2)`, where `M` is the number of stamps in `wing_data`.
#[error(
"batch dimensions disagree: wing_data shape {wing_data:?} must be (M, H, W) and wing_delta shape {wing_delta:?} must be (M, 2)"
)]
BatchLengthMismatch {
wing_data: (usize, usize, usize),
wing_delta: (usize, usize),
},
/// `wing_weight` does not have exactly the shape of `wing_data`.
#[error("wing_weight shape {weight:?} must equal wing_data shape {wing_data:?}")]
WeightShapeMismatch {
weight: (usize, usize, usize),
wing_data: (usize, usize, usize),
},
/// One of the stitch, combine, aperture or annulus parameters is out of range. Only the
/// stitch radii and the aperture radius are carried as fields.
#[error(
"params invalid: stitch params (see StitchParamsInvalid), the combine method (ClippedMean kappa > 0 and max_iter > 0), scale_aperture_radius finite and > 0 and within the wing extent, and scale_background_annulus (r_in, r_out) finite with 0 <= r_in < r_out within the wing extent; got match_radius = {match_radius}, feather_width = {feather_width}, ee_aperture_radius = {ee_aperture_radius}, scale_aperture_radius = {scale_aperture_radius}"
)]
ParamsInvalid {
match_radius: f64,
feather_width: f64,
ee_aperture_radius: f64,
scale_aperture_radius: f64,
},
}
/// Geometry derived from a core that passed `validate_core`.
struct CoreGeometry {
/// Native-pixel side of the core stamp, `core_side / oversample`; always odd.
stamp_size: usize,
/// Centre of the oversampled core in its own pixel coordinates, `(core_side - 1) / 2`.
core_center: f64,
/// `oversample` as `f64`; maps native offsets onto core pixel offsets.
oversample_f: f64,
/// Native half-width of the core, `(stamp_size - 1) / 2`; the seam's outer edge must not exceed it.
core_native_half: f64,
}
/// Checks that `core` is a usable oversampled PSF stamp and derives its geometry.
///
/// Checks run in this order, and the first violation is returned:
/// 1. the core is square;
/// 2. `oversample` is odd (this also rejects 0);
/// 3. the side is a multiple of `oversample`;
/// 4. the native `stamp_size = side / oversample` is odd.
fn validate_core(
    core: &ArrayView2<f64>,
    oversample: usize,
) -> Result<CoreGeometry, CoreValidationError> {
    let (rows, cols) = core.dim();
    if rows != cols {
        return Err(CoreValidationError::NotSquare { rows, cols });
    }
    // 0 is a multiple of 2, so a zero factor is rejected before it can be used as a divisor.
    if oversample.is_multiple_of(2) {
        return Err(CoreValidationError::OversampleNotOdd { oversample });
    }
    let core_side = rows;
    if !core_side.is_multiple_of(oversample) {
        return Err(CoreValidationError::SizeNotMultiple {
            core_side,
            oversample,
        });
    }
    let stamp_size = core_side / oversample;
    if stamp_size.is_multiple_of(2) {
        return Err(CoreValidationError::StampSizeEven { stamp_size });
    }
    let core_center = (core_side as f64 - 1.0) / 2.0;
    let core_native_half = (stamp_size as f64 - 1.0) / 2.0;
    Ok(CoreGeometry {
        stamp_size,
        core_center,
        oversample_f: oversample as f64,
        core_native_half,
    })
}
/// Core-shape failures detected by `validate_core`.
///
/// Each public entry point converts these into its own error type through the `From` impls.
#[derive(Debug)]
enum CoreValidationError {
/// The core array is not square.
NotSquare { rows: usize, cols: usize },
/// `oversample` is even (including 0).
OversampleNotOdd { oversample: usize },
/// The core side is not divisible by `oversample`.
SizeNotMultiple { core_side: usize, oversample: usize },
/// `core_side / oversample` is even.
StampSizeEven { stamp_size: usize },
}
impl From<CoreValidationError> for StitchError {
fn from(error: CoreValidationError) -> Self {
match error {
CoreValidationError::NotSquare { rows, cols } => {
StitchError::CoreNotSquare { rows, cols }
}
CoreValidationError::OversampleNotOdd { oversample } => {
StitchError::OversampleNotOdd { oversample }
}
CoreValidationError::SizeNotMultiple {
core_side,
oversample,
} => StitchError::CoreSizeNotMultiple {
core_side,
oversample,
},
CoreValidationError::StampSizeEven { stamp_size } => {
StitchError::DerivedStampSizeEven { stamp_size }
}
}
}
}
impl From<CoreValidationError> for ExtendedPsfError {
fn from(error: CoreValidationError) -> Self {
match error {
CoreValidationError::NotSquare { rows, cols } => {
ExtendedPsfError::CoreNotSquare { rows, cols }
}
CoreValidationError::OversampleNotOdd { oversample } => {
ExtendedPsfError::OversampleNotOdd { oversample }
}
CoreValidationError::SizeNotMultiple {
core_side,
oversample,
} => ExtendedPsfError::CoreSizeNotMultiple {
core_side,
oversample,
},
CoreValidationError::StampSizeEven { stamp_size } => {
ExtendedPsfError::DerivedStampSizeEven { stamp_size }
}
}
}
}
fn stitch_params_infeasible(params: &StitchParams, core_native_half: f64, wing_half: f64) -> bool {
let positive_finite = |x: f64| x.is_finite() && x > 0.0;
if !positive_finite(params.match_radius)
|| !positive_finite(params.feather_width)
|| !positive_finite(params.ee_aperture_radius)
{
return true;
}
let inner = params.match_radius - 0.5 * params.feather_width;
let outer = params.match_radius + 0.5 * params.feather_width;
inner < 0.0
|| outer > core_native_half
|| outer > wing_half
|| params.ee_aperture_radius > wing_half
}
/// Returns `true` when `method` carries parameters outside the accepted range.
///
/// `ClippedMean` needs a clipping threshold `kappa > 0` and at least one iteration. A NaN
/// `kappa` must be rejected explicitly: `NaN <= 0.0` is `false`, so the plain comparison alone
/// would let it through. `Median` has no parameters and is always valid.
fn combine_method_invalid(method: CombineMethod) -> bool {
    match method {
        CombineMethod::ClippedMean { kappa, max_iter } => {
            kappa.is_nan() || kappa <= 0.0 || max_iter == 0
        }
        CombineMethod::Median => false,
    }
}
/// Wing share of the raised-cosine blend at native radius `r`.
///
/// The weight is:
/// - 0 at or inside `match_radius - feather_width / 2`;
/// - 1 at or beyond `match_radius + feather_width / 2`;
/// - `0.5 * (1 - cos(pi * t))` in between, where `t` is the fractional position across the blend.
///   This ramp has zero slope at both edges.
///
/// The core share is `1 - feather_wing_weight(..)`, so the two always sum to 1.
fn feather_wing_weight(r: f64, match_radius: f64, feather_width: f64) -> f64 {
    let half = 0.5 * feather_width;
    let (inner, outer) = (match_radius - half, match_radius + half);
    match r {
        _ if r <= inner => 0.0,
        _ if r >= outer => 1.0,
        _ => {
            let phase = std::f64::consts::PI * (r - inner) / feather_width;
            0.5 * (1.0 - phase.cos())
        }
    }
}
/// Mean of `image` sampled on a circle of `radius` native pixels around
/// `(center_row, center_col)`.
///
/// `sample_scale` converts native offsets into `image` pixel offsets: callers pass the
/// oversampling factor for the core and 1 for the wing. `NUM_AZIMUTH` equally spaced points are
/// interpolated with `catmull_rom_sample`, and non-finite samples are skipped. When `confidence`
/// is given, each sample is weighted by the interpolated confidence, and samples whose
/// confidence is not finite and positive are skipped. Returns NaN when no sample survives.
fn azimuthal_average(
    image: &ArrayView2<f64>,
    center_row: f64,
    center_col: f64,
    radius: f64,
    sample_scale: f64,
    confidence: Option<&ArrayView2<f64>>,
) -> f64 {
    let full_turn = 2.0 * std::f64::consts::PI;
    let samples = (0..NUM_AZIMUTH).filter_map(|step| {
        let theta = full_turn * (step as f64) / (NUM_AZIMUTH as f64);
        let row = center_row + sample_scale * (radius * theta.sin());
        let col = center_col + sample_scale * (radius * theta.cos());
        let value = Some(catmull_rom_sample(image, row, col)).filter(|v| v.is_finite())?;
        let weight = match confidence {
            None => 1.0,
            Some(confidence_view) => Some(catmull_rom_sample(confidence_view, row, col))
                .filter(|c| c.is_finite() && *c > 0.0)?,
        };
        Some((weight, value))
    });
    // Sequential fold keeps the original accumulation order (angle by angle).
    let (weighted_sum, weight_sum) = samples.fold((0.0, 0.0), |(acc, norm), (weight, value)| {
        (acc + weight * value, norm + weight)
    });
    if weight_sum > 0.0 {
        weighted_sum / weight_sum
    } else {
        f64::NAN
    }
}
/// Background level of `stamp` from pixels whose distance to `(center_row, center_col)` lies in
/// `[r_in, r_out]`.
///
/// Non-finite pixels are ignored. When `weight` is given, pixels whose weight is not finite and
/// positive are ignored as well. The surviving values go to `median_in_place`; if it yields
/// nothing (e.g. no pixel qualifies), the background is 0.
fn annulus_background(
    stamp: &ArrayView2<f64>,
    weight: Option<&ArrayView2<f64>>,
    center_row: f64,
    center_col: f64,
    r_in: f64,
    r_out: f64,
) -> f64 {
    let in_ring = |row: usize, col: usize, value: f64| -> bool {
        let dr = row as f64 - center_row;
        let dc = col as f64 - center_col;
        let radius = (dr * dr + dc * dc).sqrt();
        if radius < r_in || radius > r_out || !value.is_finite() {
            return false;
        }
        match weight {
            Some(weight_view) => {
                let w = weight_view[(row, col)];
                w.is_finite() && w > 0.0
            }
            None => true,
        }
    };
    // `indexed_iter` visits pixels in logical row-major order, matching a nested row/col loop.
    let mut ring: Vec<f64> = stamp
        .indexed_iter()
        .filter(|&((row, col), &value)| in_ring(row, col, value))
        .map(|(_, &value)| value)
        .collect();
    median_in_place(&mut ring).unwrap_or(0.0)
}
fn aperture_photometry_scale(
stamp: &ArrayView2<f64>,
weight: Option<&ArrayView2<f64>>,
center_row: f64,
center_col: f64,
aperture_radius: f64,
background: f64,
) -> f64 {
let (rows, cols) = (stamp.shape()[0], stamp.shape()[1]);
let mut sum = 0.0;
for row in 0..rows {
for col in 0..cols {
let dr = row as f64 - center_row;
let dc = col as f64 - center_col;
if (dr * dr + dc * dc).sqrt() > aperture_radius {
continue;
}
let value = stamp[(row, col)];
if !value.is_finite() {
continue;
}
if let Some(weight_view) = weight {
let w = weight_view[(row, col)];
if !(w.is_finite() && w > 0.0) {
continue;
}
}
sum += value - background;
}
}
if sum.is_finite() && sum > 0.0 {
sum
} else {
f64::NAN
}
}
/// Stitches an oversampled PSF `core` onto a native-resolution `wing` image.
///
/// The wing is scaled so its azimuthal average at `params.match_radius` equals the core's. It
/// is then blended in with a raised-cosine feather of width `params.feather_width`. Finally the
/// core and the feathered wing are divided by the model's sum inside
/// `params.ee_aperture_radius`. When `wing_confidence` is given, it weights the wing's
/// azimuthal average at the match radius.
///
/// # Errors
///
/// * [`StitchError::CoreNotSquare`], [`StitchError::OversampleNotOdd`],
///   [`StitchError::CoreSizeNotMultiple`], [`StitchError::DerivedStampSizeEven`]: core checks,
///   reported in that order.
/// * [`StitchError::WingNotOdd`], then [`StitchError::WingNotSquare`]: wing shape checks.
/// * [`StitchError::WingConfidenceShapeMismatch`]: `wing_confidence` shape differs from `wing`.
/// * [`StitchError::StitchParamsInvalid`]: a parameter is non-finite or not positive, or the
///   seam/aperture does not fit inside the core and wing native extents.
pub fn stitch_psf(
    core: ArrayView2<f64>,
    oversample: usize,
    wing: ArrayView2<f64>,
    wing_confidence: Option<ArrayView2<f64>>,
    params: StitchParams,
) -> Result<ExtendedPsf, StitchError> {
    let geometry = validate_core(&core, oversample)?;
    let (wing_rows, wing_cols) = wing.dim();
    if wing_rows.is_multiple_of(2) || wing_cols.is_multiple_of(2) {
        return Err(StitchError::WingNotOdd {
            rows: wing_rows,
            cols: wing_cols,
        });
    }
    if wing_rows != wing_cols {
        return Err(StitchError::WingNotSquare {
            rows: wing_rows,
            cols: wing_cols,
        });
    }
    if let Some(confidence_view) = wing_confidence.as_ref() {
        let confidence_shape = confidence_view.dim();
        if confidence_shape != (wing_rows, wing_cols) {
            return Err(StitchError::WingConfidenceShapeMismatch {
                confidence: confidence_shape,
                wing: (wing_rows, wing_cols),
            });
        }
    }
    let wing_half = ((wing_rows as f64 - 1.0) / 2.0).min((wing_cols as f64 - 1.0) / 2.0);
    if stitch_params_infeasible(&params, geometry.core_native_half, wing_half) {
        return Err(StitchError::StitchParamsInvalid {
            match_radius: params.match_radius,
            feather_width: params.feather_width,
            ee_aperture_radius: params.ee_aperture_radius,
        });
    }
    Ok(stitch_core_and_wing(
        &core,
        &geometry,
        &wing,
        wing_confidence.as_ref(),
        &params,
    ))
}
/// Flux-matches `wing` to `core`, feathers it, and normalizes both by one global factor.
///
/// Inputs are expected to be validated by the caller (`stitch_psf` / `build_extended_psf`).
///
/// 1. `scale` is the ratio of the core's and the wing's azimuthal averages at
///    `params.match_radius`. A non-finite ratio becomes 0, which leaves the wing plane empty.
/// 2. Each wing pixel becomes `f_wing(r) * scale * wing`; pixels that are non-finite or have
///    `f_wing == 0` stay 0.
/// 3. Within `params.ee_aperture_radius` the model `(1 - f_wing) * core + wing_plane` is
///    summed, with the core sampled at each native pixel offset (non-finite core samples count
///    as 0). Core and wing are divided by that sum when it is finite and positive; otherwise
///    they are left unscaled.
fn stitch_core_and_wing(
    core: &ArrayView2<f64>,
    geometry: &CoreGeometry,
    wing: &ArrayView2<f64>,
    wing_confidence: Option<&ArrayView2<f64>>,
    params: &StitchParams,
) -> ExtendedPsf {
    let (wing_rows, wing_cols) = wing.dim();
    let wing_center_row = (wing_rows as f64 - 1.0) / 2.0;
    let wing_center_col = (wing_cols as f64 - 1.0) / 2.0;

    let core_at_match = azimuthal_average(
        core,
        geometry.core_center,
        geometry.core_center,
        params.match_radius,
        geometry.oversample_f,
        None,
    );
    let wing_at_match = azimuthal_average(
        wing,
        wing_center_row,
        wing_center_col,
        params.match_radius,
        1.0,
        wing_confidence,
    );
    let ratio = core_at_match / wing_at_match;
    let scale = if ratio.is_finite() { ratio } else { 0.0 };

    // One row-major pass fills the feathered wing plane and accumulates the aperture sum; the
    // sum visits in-aperture pixels in the same order as a separate second pass would.
    let mut wing_plane = Array2::<f64>::zeros((wing_rows, wing_cols));
    let mut ee_raw = 0.0;
    for ((row, col), slot) in wing_plane.indexed_iter_mut() {
        let dr = row as f64 - wing_center_row;
        let dc = col as f64 - wing_center_col;
        let radius = (dr * dr + dc * dc).sqrt();
        let f_wing = feather_wing_weight(radius, params.match_radius, params.feather_width);
        let raw = wing[(row, col)];
        if f_wing > 0.0 && scale != 0.0 && raw.is_finite() {
            *slot = f_wing * scale * raw;
        }
        if radius > params.ee_aperture_radius {
            continue;
        }
        let core_native = catmull_rom_sample(
            core,
            geometry.core_center + geometry.oversample_f * dr,
            geometry.core_center + geometry.oversample_f * dc,
        );
        let core_term = if core_native.is_finite() {
            (1.0 - f_wing) * core_native
        } else {
            0.0
        };
        ee_raw += core_term + *slot;
    }

    let global_factor = if ee_raw.is_finite() && ee_raw > 0.0 {
        1.0 / ee_raw
    } else {
        1.0
    };
    wing_plane.mapv_inplace(|v| v * global_factor);
    ExtendedPsf {
        core: core.mapv(|v| v * global_factor),
        oversample: geometry.oversample_f as usize,
        wing: wing_plane,
        match_radius: params.match_radius,
        feather_width: params.feather_width,
        ee_aperture_radius: params.ee_aperture_radius,
    }
}
/// Builds an [`ExtendedPsf`] from a stack of bright-star stamps and an oversampled `core`.
///
/// Stamp samples (and weights) are converted to `f64`; values that fail `to_f64` become NaN.
/// For every star along axis 0 of `wing_data`:
/// 1. The background comes from `annulus_background` over `params.scale_background_annulus`.
/// 2. The flux scale comes from the `solve_flux_background` fit on the central `stamp_size`
///    window when that fit succeeds with a finite positive flux. Otherwise it comes from
///    background-subtracted aperture photometry within `params.scale_aperture_radius`.
/// 3. The stamp is shifted by the rounded `wing_delta` row
///    (`out[r, c] = in[r + delta_r, c + delta_c]`; pixels whose source falls outside stay
///    NaN with weight 0), background-subtracted, and divided by the scale.
///
/// A star is left out of the stack (`star_ok = false`, NaN flux/background) when it has no
/// usable scale or a non-finite `wing_delta`. The stack is combined with `robust_combine` and
/// stitched onto `core` with `params.stitch`.
///
/// # Errors
///
/// * Core errors ([`ExtendedPsfError::CoreNotSquare`], [`ExtendedPsfError::OversampleNotOdd`],
///   [`ExtendedPsfError::CoreSizeNotMultiple`], [`ExtendedPsfError::DerivedStampSizeEven`]).
/// * [`ExtendedPsfError::WingNotOdd`], then [`ExtendedPsfError::WingNotSquare`]: stamp shape.
/// * [`ExtendedPsfError::BatchLengthMismatch`]: `wing_delta` is not `(M, 2)`.
/// * [`ExtendedPsfError::WeightShapeMismatch`]: `wing_weight` shape differs from `wing_data`.
/// * [`ExtendedPsfError::ParamsInvalid`]: stitch, combine, aperture or annulus parameters.
///
/// # Panics
///
/// Panics if `solve_flux_background` or `robust_combine` reject inputs that pass the
/// validation above.
/// NOTE(review): whether they accept non-finite `wing_delta` entries is not visible here.
pub fn build_extended_psf<T: Float>(
    wing_data: ArrayView3<T>,
    wing_weight: Option<ArrayView3<T>>,
    wing_delta: ArrayView2<f64>,
    core: ArrayView2<f64>,
    oversample: usize,
    params: ExtendedPsfParams,
) -> Result<ExtendedPsfBuilt, ExtendedPsfError> {
    let geometry = validate_core(&core, oversample)?;
    let (star_count, wing_rows, wing_cols) = wing_data.dim();
    if wing_rows.is_multiple_of(2) || wing_cols.is_multiple_of(2) {
        return Err(ExtendedPsfError::WingNotOdd {
            rows: wing_rows,
            cols: wing_cols,
        });
    }
    if wing_rows != wing_cols {
        return Err(ExtendedPsfError::WingNotSquare {
            rows: wing_rows,
            cols: wing_cols,
        });
    }
    let delta_shape = wing_delta.dim();
    if delta_shape != (star_count, 2) {
        return Err(ExtendedPsfError::BatchLengthMismatch {
            wing_data: (star_count, wing_rows, wing_cols),
            wing_delta: delta_shape,
        });
    }
    if let Some(weight_view) = wing_weight.as_ref() {
        let weight_shape = weight_view.dim();
        if weight_shape != (star_count, wing_rows, wing_cols) {
            return Err(ExtendedPsfError::WeightShapeMismatch {
                weight: weight_shape,
                wing_data: (star_count, wing_rows, wing_cols),
            });
        }
    }

    let wing_half = ((wing_rows as f64 - 1.0) / 2.0).min((wing_cols as f64 - 1.0) / 2.0);
    let (annulus_in, annulus_out) = params.scale_background_annulus;
    let annulus_invalid = !annulus_in.is_finite()
        || !annulus_out.is_finite()
        || annulus_in < 0.0
        || annulus_out <= annulus_in
        || annulus_out > wing_half;
    let scale_aperture_invalid = !(params.scale_aperture_radius.is_finite()
        && params.scale_aperture_radius > 0.0)
        || params.scale_aperture_radius > wing_half;
    let params_invalid =
        stitch_params_infeasible(&params.stitch, geometry.core_native_half, wing_half)
            || combine_method_invalid(params.combine)
            || scale_aperture_invalid
            || annulus_invalid;
    if params_invalid {
        return Err(ExtendedPsfError::ParamsInvalid {
            match_radius: params.stitch.match_radius,
            feather_width: params.stitch.feather_width,
            ee_aperture_radius: params.stitch.ee_aperture_radius,
            scale_aperture_radius: params.scale_aperture_radius,
        });
    }

    let wing_data_f: Array3<f64> = wing_data.mapv(|value| value.to_f64().unwrap_or(f64::NAN));
    let wing_weight_f: Option<Array3<f64>> = wing_weight
        .as_ref()
        .map(|w| w.mapv(|value| value.to_f64().unwrap_or(f64::NAN)));
    let wing_center_row = (wing_rows as f64 - 1.0) / 2.0;
    let wing_center_col = (wing_cols as f64 - 1.0) / 2.0;
    let stamp_size = geometry.stamp_size;

    // The core fit needs at least one star and a stamp large enough to hold the core window.
    let core_path_available = star_count > 0 && wing_rows >= stamp_size && wing_cols >= stamp_size;
    let core_solution = core_path_available.then(|| {
        let (central, central_weight) =
            crop_central_stamps(&wing_data_f, wing_weight_f.as_ref(), stamp_size);
        solve_flux_background::<f64>(
            core,
            oversample,
            central.view(),
            central_weight.as_ref().map(|w| w.view()),
            wing_delta,
        )
        .expect("solve_flux_background preconditions are guaranteed by build_extended_psf validation")
    });

    let mut normalized = Array3::<f64>::from_elem((star_count, wing_rows, wing_cols), f64::NAN);
    let mut normalized_weight = wing_weight_f
        .as_ref()
        .map(|_| Array3::<f64>::zeros((star_count, wing_rows, wing_cols)));
    let mut star_flux = vec![f64::NAN; star_count];
    let mut star_background = vec![f64::NAN; star_count];
    let mut star_ok = vec![false; star_count];
    let mut star_scale_from_core = vec![false; star_count];
    for star in 0..star_count {
        let stamp = wing_data_f.index_axis(Axis(0), star);
        let weight_stamp = wing_weight_f.as_ref().map(|w| w.index_axis(Axis(0), star));
        let background = annulus_background(
            &stamp,
            weight_stamp.as_ref(),
            wing_center_row,
            wing_center_col,
            annulus_in,
            annulus_out,
        );
        let core_flux = core_solution.as_ref().and_then(|solution| {
            if solution.ok[star] && solution.flux[star].is_finite() && solution.flux[star] > 0.0 {
                Some(solution.flux[star])
            } else {
                None
            }
        });
        let (scale, from_core) = match core_flux {
            Some(flux) => (flux, true),
            None => {
                let aperture = aperture_photometry_scale(
                    &stamp,
                    weight_stamp.as_ref(),
                    wing_center_row,
                    wing_center_col,
                    params.scale_aperture_radius,
                    background,
                );
                (aperture, false)
            }
        };
        if !(scale.is_finite() && scale > 0.0) {
            continue;
        }
        let (delta_row, delta_col) = (wing_delta[(star, 0)], wing_delta[(star, 1)]);
        // A non-finite offset cannot recentre the stamp (`NaN as i64` would silently be a zero
        // shift), so the star is left out of the stack.
        if !(delta_row.is_finite() && delta_col.is_finite()) {
            continue;
        }
        // `as` saturates huge offsets; `shifted_index` uses checked arithmetic so they simply
        // fall outside the stamp instead of overflowing.
        let shift_row = delta_row.round() as i64;
        let shift_col = delta_col.round() as i64;
        for row in 0..wing_rows {
            let Some(source_row) = shifted_index(row, shift_row, wing_rows) else {
                continue;
            };
            for col in 0..wing_cols {
                let Some(source_col) = shifted_index(col, shift_col, wing_cols) else {
                    continue;
                };
                let value = wing_data_f[(star, source_row, source_col)];
                if value.is_finite() {
                    normalized[(star, row, col)] = (value - background) / scale;
                }
                if let (Some(dst), Some(src)) = (normalized_weight.as_mut(), wing_weight_f.as_ref())
                {
                    dst[(star, row, col)] = src[(star, source_row, source_col)];
                }
            }
        }
        star_flux[star] = scale;
        star_background[star] = background;
        star_ok[star] = true;
        star_scale_from_core[star] = from_core;
    }

    let combined = robust_combine::<f64>(
        normalized.view(),
        normalized_weight.as_ref().map(|w| w.view()),
        params.combine,
    )
    .expect("robust_combine preconditions are guaranteed by build_extended_psf validation");
    let extended = stitch_core_and_wing(
        &core,
        &geometry,
        &combined.combined.view(),
        Some(&combined.weight.view()),
        &params.stitch,
    );
    Ok(ExtendedPsfBuilt {
        extended,
        star_flux: Array1::from(star_flux),
        star_background: Array1::from(star_background),
        star_ok: Array1::from(star_ok),
        star_scale_from_core: Array1::from(star_scale_from_core),
    })
}

/// Copies the centred `stamp_size x stamp_size` window of every stamp (and of its weights).
///
/// Requires `stamp_size` to be no larger than either spatial dimension of `wing_data`.
fn crop_central_stamps(
    wing_data: &Array3<f64>,
    wing_weight: Option<&Array3<f64>>,
    stamp_size: usize,
) -> (Array3<f64>, Option<Array3<f64>>) {
    let (star_count, wing_rows, wing_cols) = wing_data.dim();
    let row_offset = (wing_rows - stamp_size) / 2;
    let col_offset = (wing_cols - stamp_size) / 2;
    let shape = (star_count, stamp_size, stamp_size);
    let crop = |source: &Array3<f64>| {
        Array3::from_shape_fn(shape, |(star, i, j)| {
            source[(star, row_offset + i, col_offset + j)]
        })
    };
    (crop(wing_data), wing_weight.map(crop))
}

/// Returns `index + shift` when it lies in `0..len`.
///
/// Returns `None` when the shifted index falls outside that range or the addition would
/// overflow `i64`.
fn shifted_index(index: usize, shift: i64, len: usize) -> Option<usize> {
    let source = (index as i64).checked_add(shift)?;
    (0..len as i64).contains(&source).then_some(source as usize)
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array2, Array3};
struct SplitMix64 {
state: u64,
}
impl SplitMix64 {
fn new(seed: u64) -> Self {
Self { state: seed }
}
fn next_u64(&mut self) -> u64 {
self.state = self.state.wrapping_add(0x9E37_79B9_7F4A_7C15);
let mut z = self.state;
z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
z ^ (z >> 31)
}
fn unit(&mut self) -> f64 {
(self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
}
fn range(&mut self, lo: f64, hi: f64) -> f64 {
lo + (hi - lo) * self.unit()
}
}
fn gaussian_core(oversample: usize, stamp_size: usize, sigma: f64) -> Array2<f64> {
let side = oversample * stamp_size;
let center = (side as f64 - 1.0) / 2.0;
let sigma_os = oversample as f64 * sigma;
let mut core = Array2::<f64>::zeros((side, side));
for r in 0..side {
for c in 0..side {
let dr = r as f64 - center;
let dc = c as f64 - center;
core[(r, c)] = (-(dr * dr + dc * dc) / (2.0 * sigma_os * sigma_os)).exp();
}
}
let volume: f64 = core.iter().sum::<f64>() / (oversample * oversample) as f64;
core.mapv(|v| v / volume)
}
fn core_native_value(core: &Array2<f64>, oversample: usize, dr: f64, dc: f64) -> f64 {
let side = core.shape()[0];
let center = (side as f64 - 1.0) / 2.0;
catmull_rom_sample(
&core.view(),
center + oversample as f64 * dr,
center + oversample as f64 * dc,
)
}
fn core_peak(core: &Array2<f64>) -> f64 {
let mid = (core.shape()[0] - 1) / 2;
core[(mid, mid)]
}
fn truth_native(peak: f64, sigma: f64, dr: f64, dc: f64) -> f64 {
peak * (-(dr * dr + dc * dc) / (2.0 * sigma * sigma)).exp()
}
#[test]
fn stitch_seam_is_c0_c1_continuous_and_ee_normalized() {
let oversample = 5;
let stamp_size = 21;
let core = gaussian_core(oversample, stamp_size, 1.6);
let wing_side = 81usize;
let wing_center = (wing_side as f64 - 1.0) / 2.0;
let wing = Array2::<f64>::from_shape_fn((wing_side, wing_side), |(r, c)| {
let dr = r as f64 - wing_center;
let dc = c as f64 - wing_center;
17.0 * (-(dr * dr + dc * dc) / (2.0 * 3.4 * 3.4)).exp()
});
let params = StitchParams {
match_radius: 5.0,
feather_width: 3.0,
ee_aperture_radius: 18.0,
};
let extended =
stitch_psf(core.view(), oversample, wing.view(), None, params.clone()).unwrap();
assert_eq!(
extended.core.shape(),
&[oversample * stamp_size, oversample * stamp_size]
);
assert_eq!(extended.wing.shape(), &[wing_side, wing_side]);
assert_eq!(extended.oversample, oversample);
assert_eq!(extended.match_radius, params.match_radius);
assert_eq!(extended.feather_width, params.feather_width);
assert_eq!(extended.ee_aperture_radius, params.ee_aperture_radius);
let ec = (extended.core.shape()[0] as f64 - 1.0) / 2.0;
let recon = |radius: f64| -> f64 {
let f_wing = feather_wing_weight(radius, params.match_radius, params.feather_width);
let core_az = azimuthal_average(
&extended.core.view(),
ec,
ec,
radius,
oversample as f64,
None,
);
let wing_az = azimuthal_average(
&extended.wing.view(),
wing_center,
wing_center,
radius,
1.0,
None,
);
(1.0 - f_wing) * core_az + wing_az
};
let step = 0.02;
let mut radius = 1.0;
while radius < 12.0 {
let here = recon(radius);
let ahead = recon(radius + step);
assert!(
(here - ahead).abs() < 5e-2 * here.abs().max(1e-3),
"C0 break at r = {radius}: {here} vs {ahead}"
);
radius += step;
}
for &edge in &[
params.match_radius - 0.5 * params.feather_width,
params.match_radius + 0.5 * params.feather_width,
] {
let h = 0.05;
let slope_in = (recon(edge) - recon(edge - h)) / h;
let slope_out = (recon(edge + h) - recon(edge)) / h;
assert!(
(slope_in - slope_out).abs() < 5e-2 * slope_in.abs().max(1e-2),
"C1 break at edge {edge}: {slope_in} vs {slope_out}"
);
}
let mut ee = 0.0;
for r in 0..wing_side {
for c in 0..wing_side {
let dr = r as f64 - wing_center;
let dc = c as f64 - wing_center;
let radius = (dr * dr + dc * dc).sqrt();
if radius > params.ee_aperture_radius {
continue;
}
let f_wing = feather_wing_weight(radius, params.match_radius, params.feather_width);
let core_native = core_native_value(&extended.core, oversample, dr, dc);
ee += (1.0 - f_wing) * core_native + extended.wing[(r, c)];
}
}
assert!((ee - 1.0).abs() < 1e-9, "EE = {ee}, expected 1");
}
#[test]
fn stitch_feather_partitions_and_wing_is_zero_in_core_region() {
let match_radius = 6.0;
let feather_width = 4.0;
let inner = match_radius - 0.5 * feather_width;
let outer = match_radius + 0.5 * feather_width;
let mut previous = -1.0;
let mut radius = 0.0;
while radius < 12.0 {
let f_wing = feather_wing_weight(radius, match_radius, feather_width);
assert!((0.0..=1.0).contains(&f_wing));
assert!(((1.0 - f_wing) + f_wing - 1.0).abs() < 1e-15);
if radius <= inner {
assert_eq!(f_wing, 0.0);
}
if radius >= outer {
assert_eq!(f_wing, 1.0);
}
assert!(f_wing + 1e-12 >= previous, "f_wing must be monotone");
previous = f_wing;
radius += 0.05;
}
let oversample = 3;
let stamp_size = 21;
let core = gaussian_core(oversample, stamp_size, 1.3);
let wing_side = 61usize;
let wing_center = (wing_side as f64 - 1.0) / 2.0;
let wing = Array2::<f64>::from_shape_fn((wing_side, wing_side), |(r, c)| {
let dr = r as f64 - wing_center;
let dc = c as f64 - wing_center;
5.0 * (-(dr * dr + dc * dc) / (2.0 * 3.0 * 3.0)).exp()
});
let extended = stitch_psf(
core.view(),
oversample,
wing.view(),
None,
StitchParams {
match_radius,
feather_width,
ee_aperture_radius: 16.0,
},
)
.unwrap();
for r in 0..wing_side {
for c in 0..wing_side {
let dr = r as f64 - wing_center;
let dc = c as f64 - wing_center;
let radius = (dr * dr + dc * dc).sqrt();
if radius <= inner {
assert_eq!(
extended.wing[(r, c)],
0.0,
"wing must be 0 in the core region at r = {radius}"
);
}
}
}
}
#[test]
fn stitch_all_sentinel_wing_yields_pure_core() {
let oversample = 5;
let stamp_size = 17;
let core = gaussian_core(oversample, stamp_size, 1.5);
let wing_side = 71usize;
let wing = Array2::<f64>::from_elem((wing_side, wing_side), f64::NAN);
let params = StitchParams {
match_radius: 6.0,
feather_width: 3.0,
ee_aperture_radius: 14.0,
};
let extended =
stitch_psf(core.view(), oversample, wing.view(), None, params.clone()).unwrap();
assert!(extended.wing.iter().all(|&v| v == 0.0));
assert!(extended.core.iter().all(|v| v.is_finite()));
let wing_center = (wing_side as f64 - 1.0) / 2.0;
let mut ee = 0.0;
for r in 0..wing_side {
for c in 0..wing_side {
let dr = r as f64 - wing_center;
let dc = c as f64 - wing_center;
let radius = (dr * dr + dc * dc).sqrt();
if radius > params.ee_aperture_radius {
continue;
}
let f_wing = feather_wing_weight(radius, params.match_radius, params.feather_width);
ee += (1.0 - f_wing) * core_native_value(&extended.core, oversample, dr, dc);
}
}
assert!((ee - 1.0).abs() < 1e-9, "pure-core EE = {ee}");
}
#[test]
fn stitch_params_default_fields() {
let p = StitchParams::default();
assert_eq!(p.match_radius, DEFAULT_MATCH_RADIUS);
assert_eq!(p.feather_width, DEFAULT_FEATHER_WIDTH);
assert_eq!(p.ee_aperture_radius, DEFAULT_EE_APERTURE_RADIUS);
let ep = ExtendedPsfParams::default();
assert_eq!(ep.stitch, StitchParams::default());
assert_eq!(
ep.combine,
CombineMethod::ClippedMean {
kappa: 3.0,
max_iter: 5
}
);
assert_eq!(ep.scale_aperture_radius, DEFAULT_SCALE_APERTURE_RADIUS);
assert_eq!(
ep.scale_background_annulus,
DEFAULT_SCALE_BACKGROUND_ANNULUS
);
}
#[test]
fn stitch_preconditions() {
let oversample = 5;
let stamp_size = 15;
let good_core = gaussian_core(oversample, stamp_size, 1.4);
let good_wing = Array2::<f64>::from_elem((61, 61), 0.01);
let good = StitchParams {
match_radius: 6.0,
feather_width: 3.0,
ee_aperture_radius: 14.0,
};
assert_eq!(
stitch_psf(good_core.view(), 4, good_wing.view(), None, good.clone()).unwrap_err(),
StitchError::OversampleNotOdd { oversample: 4 }
);
assert_eq!(
stitch_psf(good_core.view(), 0, good_wing.view(), None, good.clone()).unwrap_err(),
StitchError::OversampleNotOdd { oversample: 0 }
);
let rect = Array2::<f64>::zeros((10, 12));
assert_eq!(
stitch_psf(rect.view(), 5, good_wing.view(), None, good.clone()).unwrap_err(),
StitchError::CoreNotSquare { rows: 10, cols: 12 }
);
let not_mult = Array2::<f64>::zeros((76, 76));
assert_eq!(
stitch_psf(not_mult.view(), 5, good_wing.view(), None, good.clone()).unwrap_err(),
StitchError::CoreSizeNotMultiple {
core_side: 76,
oversample: 5
}
);
let even_s = Array2::<f64>::zeros((20, 20));
assert_eq!(
stitch_psf(even_s.view(), 5, good_wing.view(), None, good.clone()).unwrap_err(),
StitchError::DerivedStampSizeEven { stamp_size: 4 }
);
let even_wing = Array2::<f64>::zeros((60, 61));
assert_eq!(
stitch_psf(good_core.view(), 5, even_wing.view(), None, good.clone()).unwrap_err(),
StitchError::WingNotOdd { rows: 60, cols: 61 }
);
let rectangular_wing = Array2::<f64>::zeros((59, 61));
assert_eq!(
stitch_psf(
good_core.view(),
5,
rectangular_wing.view(),
None,
good.clone()
)
.unwrap_err(),
StitchError::WingNotSquare { rows: 59, cols: 61 }
);
let bad_conf = Array2::<f64>::zeros((59, 61));
assert_eq!(
stitch_psf(
good_core.view(),
5,
good_wing.view(),
Some(bad_conf.view()),
good.clone()
)
.unwrap_err(),
StitchError::WingConfidenceShapeMismatch {
confidence: (59, 61),
wing: (61, 61)
}
);
for bad in [
StitchParams {
match_radius: 0.0,
feather_width: 3.0,
ee_aperture_radius: 14.0,
},
StitchParams {
match_radius: 6.0,
feather_width: f64::NAN,
ee_aperture_radius: 14.0,
},
StitchParams {
match_radius: 1.0,
feather_width: 4.0,
ee_aperture_radius: 14.0,
}, StitchParams {
match_radius: 6.0,
feather_width: 3.0,
ee_aperture_radius: 100.0,
}, ] {
let err =
stitch_psf(good_core.view(), 5, good_wing.view(), None, bad.clone()).unwrap_err();
assert!(
matches!(err, StitchError::StitchParamsInvalid { .. }),
"expected StitchParamsInvalid for {bad:?}, got {err:?}"
);
}
assert_eq!(
stitch_psf(
even_s.view(),
5,
good_wing.view(),
None,
StitchParams {
match_radius: 0.0,
feather_width: 0.0,
ee_aperture_radius: 0.0,
}
)
.unwrap_err(),
StitchError::DerivedStampSizeEven { stamp_size: 4 }
);
}
fn synth_bright_stars(
core: &Array2<f64>,
sigma: f64,
wing_side: usize,
fluxes: &[f64],
backgrounds: &[f64],
) -> Array3<f64> {
let m = fluxes.len();
let peak = core_peak(core);
let wing_center = (wing_side as f64 - 1.0) / 2.0;
Array3::<f64>::from_shape_fn((m, wing_side, wing_side), |(star, r, c)| {
let dr = r as f64 - wing_center;
let dc = c as f64 - wing_center;
fluxes[star] * truth_native(peak, sigma, dr, dc) + backgrounds[star]
})
}
#[test]
fn build_recovers_wing_vs_robust_combine_reference() {
    // Noise-free stars synthesized with the same `sigma` as the core. The orchestrator must
    // recover every star's flux and background, reproduce a hand-built
    // robust_combine + stitch reference wing, and give EE == 1 inside the EE aperture.
    let oversample = 5;
    let stamp_size = 15;
    let sigma = 1.1;
    let core = gaussian_core(oversample, stamp_size, sigma);
    let wing_side = 61usize;
    let fluxes = [1000.0, 2500.0, 700.0, 1800.0, 3300.0];
    let backgrounds = [10.0, -4.0, 25.0, 0.0, 7.0];
    let data = synth_bright_stars(&core, sigma, wing_side, &fluxes, &backgrounds);
    let wing_delta = Array2::<f64>::zeros((fluxes.len(), 2));
    let params = ExtendedPsfParams {
        stitch: StitchParams {
            match_radius: 5.0,
            feather_width: 3.0,
            ee_aperture_radius: 18.0,
        },
        combine: CombineMethod::ClippedMean {
            kappa: 3.0,
            max_iter: 5,
        },
        scale_aperture_radius: 5.0,
        scale_background_annulus: (22.0, 28.0),
    };
    let built = build_extended_psf(
        data.view(),
        None,
        wing_delta.view(),
        core.view(),
        oversample,
        params.clone(),
    )
    .unwrap();
    for star in 0..fluxes.len() {
        assert!(built.star_ok[star]);
        assert!(built.star_scale_from_core[star]);
        assert!(
            (built.star_flux[star] - fluxes[star]).abs() < 1e-5 * fluxes[star],
            "star {star} flux {} vs {}",
            built.star_flux[star],
            fluxes[star]
        );
        assert!(
            (built.star_background[star] - backgrounds[star]).abs() < 1e-3,
            "star {star} bg {} vs {}",
            built.star_background[star],
            backgrounds[star]
        );
    }
    // Reference stack: the truth profile plus the residual left by subtracting the *fitted*
    // background, normalized by the *true* flux. It is then combined and stitched by hand.
    let wing_center = (wing_side as f64 - 1.0) / 2.0;
    let peak = core_peak(&core);
    let reference_norm =
        Array3::<f64>::from_shape_fn((fluxes.len(), wing_side, wing_side), |(star, r, c)| {
            let dr = r as f64 - wing_center;
            let dc = c as f64 - wing_center;
            truth_native(peak, sigma, dr, dc)
                + (backgrounds[star] - built.star_background[star]) / fluxes[star]
        });
    let reference = robust_combine::<f64>(reference_norm.view(), None, params.combine)
        .unwrap()
        .combined;
    let geometry = validate_core(&core.view(), oversample).unwrap();
    let reference_extended = stitch_core_and_wing(
        &core.view(),
        &geometry,
        &reference.view(),
        None,
        &params.stitch,
    );
    for (a, b) in built
        .extended
        .wing
        .iter()
        .zip(reference_extended.wing.iter())
    {
        assert!(
            (a - b).abs() < 1e-6 * a.abs().max(1.0),
            "wing mismatch vs reference: {a} vs {b}"
        );
    }
    // Encircled energy inside the EE aperture is the feathered core plus the wing; it must be 1.
    let mut ee = 0.0;
    for r in 0..wing_side {
        for c in 0..wing_side {
            let dr = r as f64 - wing_center;
            let dc = c as f64 - wing_center;
            let radius = (dr * dr + dc * dc).sqrt();
            if radius > params.stitch.ee_aperture_radius {
                continue;
            }
            let f_wing = feather_wing_weight(
                radius,
                params.stitch.match_radius,
                params.stitch.feather_width,
            );
            ee += (1.0 - f_wing) * core_native_value(&built.extended.core, oversample, dr, dc)
                + built.extended.wing[(r, c)];
        }
    }
    assert!((ee - 1.0).abs() < 1e-9, "recovered EE = {ee}");
}
#[test]
fn build_saturated_core_uses_aperture_fallback_still_in_stack() {
    // Four identical stars. Star 1 has its central core-sized window zero-weighted (a
    // "saturated" core), so its scale must not come from the core. It must still be ok,
    // have a positive finite flux, and contribute to a finite wing.
    let oversample = 5;
    let stamp_size = 15;
    let sigma = 1.1;
    let wing_side = 61usize;
    let core = gaussian_core(oversample, stamp_size, sigma);
    let fluxes = [1500.0; 4];
    let backgrounds = [5.0; 4];
    let data = synth_bright_stars(&core, sigma, wing_side, &fluxes, &backgrounds);

    let window_lo = (wing_side - stamp_size) / 2;
    let window_hi = window_lo + stamp_size;
    let mut weight = Array3::<f64>::ones((fluxes.len(), wing_side, wing_side));
    for row in window_lo..window_hi {
        for col in window_lo..window_hi {
            weight[(1, row, col)] = 0.0;
        }
    }

    let params = ExtendedPsfParams {
        stitch: StitchParams {
            match_radius: 5.0,
            feather_width: 3.0,
            ee_aperture_radius: 18.0,
        },
        combine: CombineMethod::Median,
        scale_aperture_radius: 9.0,
        scale_background_annulus: (22.0, 28.0),
    };
    let no_shift = Array2::<f64>::zeros((fluxes.len(), 2));
    let built = build_extended_psf(
        data.view(),
        Some(weight.view()),
        no_shift.view(),
        core.view(),
        oversample,
        params,
    )
    .unwrap();

    assert!(built.star_ok.iter().all(|&ok| ok));
    assert!(built.star_scale_from_core[0]);
    assert!(!built.star_scale_from_core[1], "saturated star -> fallback");
    assert!(built.star_scale_from_core[2]);
    assert!(built.star_scale_from_core[3]);
    let fallback_flux = built.star_flux[1];
    assert!(fallback_flux.is_finite() && fallback_flux > 0.0);
    let center = (wing_side - 1) / 2;
    assert!(built.extended.wing[(center, center + 8)].is_finite());
}
#[test]
fn build_uncalibratable_star_excluded_wing_from_rest() {
    // The last star has zero weight everywhere, so it cannot be calibrated. It must be flagged
    // not-ok with NaN flux/background, and the wing must match a reference built from the
    // surviving stars only.
    let oversample = 5;
    let stamp_size = 15;
    let sigma = 1.1;
    let core = gaussian_core(oversample, stamp_size, sigma);
    let wing_side = 61usize;
    let fluxes = [1200.0, 1200.0, 1200.0];
    let backgrounds = [3.0, 3.0, 3.0];
    let data = synth_bright_stars(&core, sigma, wing_side, &fluxes, &backgrounds);
    let excluded_star = 2;
    let weight =
        Array3::<f64>::from_shape_fn((fluxes.len(), wing_side, wing_side), |(star, _, _)| {
            if star == excluded_star {
                0.0
            } else {
                1.0
            }
        });
    let params = ExtendedPsfParams {
        stitch: StitchParams {
            match_radius: 5.0,
            feather_width: 3.0,
            ee_aperture_radius: 18.0,
        },
        combine: CombineMethod::ClippedMean {
            kappa: 3.0,
            max_iter: 5,
        },
        scale_aperture_radius: 6.0,
        scale_background_annulus: (22.0, 28.0),
    };
    let built = build_extended_psf(
        data.view(),
        Some(weight.view()),
        Array2::<f64>::zeros((fluxes.len(), 2)).view(),
        core.view(),
        oversample,
        params.clone(),
    )
    .unwrap();
    assert!(built.star_ok[0]);
    assert!(built.star_ok[1]);
    assert!(
        !built.star_ok[excluded_star],
        "all-zero-weight star must be excluded"
    );
    assert!(built.star_flux[excluded_star].is_nan());
    assert!(built.star_background[excluded_star].is_nan());
    // The excluded star is the last one, so the survivors keep their indices
    // in the reference stack (stack index == star index).
    let n_surviving = excluded_star;
    let wing_center = (wing_side as f64 - 1.0) / 2.0;
    let peak = core_peak(&core);
    let reference_norm =
        Array3::<f64>::from_shape_fn((n_surviving, wing_side, wing_side), |(star, r, c)| {
            let dr = r as f64 - wing_center;
            let dc = c as f64 - wing_center;
            truth_native(peak, sigma, dr, dc)
                + (backgrounds[star] - built.star_background[star]) / fluxes[star]
        });
    let reference = robust_combine::<f64>(reference_norm.view(), None, params.combine)
        .unwrap()
        .combined;
    let geometry = validate_core(&core.view(), oversample).unwrap();
    let reference_extended = stitch_core_and_wing(
        &core.view(),
        &geometry,
        &reference.view(),
        None,
        &params.stitch,
    );
    for (a, b) in built
        .extended
        .wing
        .iter()
        .zip(reference_extended.wing.iter())
    {
        assert!(
            (a - b).abs() < 1e-6 * a.abs().max(1.0),
            "wing must be recovered from the surviving stars"
        );
    }
}
#[test]
fn build_integer_recenter_drops_subpixel_delta() {
    // Every |delta| component is below 0.5 pixel, so an integer-only recenter must treat
    // these deltas exactly like zero. The two wings must be bit-identical.
    let oversample = 5;
    let stamp_size = 15;
    let sigma = 1.1;
    let wing_side = 61usize;
    let core = gaussian_core(oversample, stamp_size, sigma);
    let fluxes = [900.0, 1400.0, 2000.0];
    let backgrounds = [2.0, 6.0, -3.0];
    let data = synth_bright_stars(&core, sigma, wing_side, &fluxes, &backgrounds);

    // Zero-weight the central core-sized window of every star.
    let window_start = (wing_side - stamp_size) / 2;
    let core_window = window_start..window_start + stamp_size;
    let weight =
        Array3::<f64>::from_shape_fn((fluxes.len(), wing_side, wing_side), |(_, row, col)| {
            if core_window.contains(&row) && core_window.contains(&col) {
                0.0
            } else {
                1.0
            }
        });

    let params = ExtendedPsfParams {
        stitch: StitchParams {
            match_radius: 5.0,
            feather_width: 3.0,
            ee_aperture_radius: 18.0,
        },
        combine: CombineMethod::Median,
        scale_aperture_radius: 9.0,
        scale_background_annulus: (22.0, 28.0),
    };
    let zero_delta = Array2::<f64>::zeros((fluxes.len(), 2));
    let subpixel_delta = Array2::<f64>::from_shape_vec(
        (fluxes.len(), 2),
        vec![0.3, -0.4, 0.49, 0.1, -0.49, 0.2],
    )
    .unwrap();

    let from_zero = build_extended_psf(
        data.view(),
        Some(weight.view()),
        zero_delta.view(),
        core.view(),
        oversample,
        params.clone(),
    )
    .unwrap();
    let from_subpixel = build_extended_psf(
        data.view(),
        Some(weight.view()),
        subpixel_delta.view(),
        core.view(),
        oversample,
        params,
    )
    .unwrap();

    let wing_zero = &from_zero.extended.wing;
    let wing_subpixel = &from_subpixel.extended.wing;
    for (a, b) in wing_zero.iter().zip(wing_subpixel) {
        assert_eq!(
            a, b,
            "sub-pixel delta must not move the wing (integer recenter only)"
        );
    }
}
#[test]
fn build_m1_and_m0() {
    // Edge batch sizes:
    // - a single star (M = 1) must be calibrated;
    // - an empty batch (M = 0) must produce an all-zero wing, and the feathered core alone
    //   must then carry EE == 1 inside the EE aperture.
    let oversample = 5;
    let stamp_size = 15;
    let sigma = 1.2;
    let core = gaussian_core(oversample, stamp_size, sigma);
    let wing_side = 61usize;
    let params = ExtendedPsfParams {
        stitch: StitchParams {
            match_radius: 5.0,
            feather_width: 3.0,
            ee_aperture_radius: 18.0,
        },
        combine: CombineMethod::ClippedMean {
            kappa: 3.0,
            max_iter: 5,
        },
        scale_aperture_radius: 6.0,
        scale_background_annulus: (22.0, 28.0),
    };

    // M = 1.
    let single = synth_bright_stars(&core, sigma, wing_side, &[1700.0], &[4.0]);
    let single_delta = Array2::<f64>::zeros((1, 2));
    let built1 = build_extended_psf(
        single.view(),
        None,
        single_delta.view(),
        core.view(),
        oversample,
        params.clone(),
    )
    .unwrap();
    assert_eq!(built1.star_ok.len(), 1);
    assert!(built1.star_ok[0]);
    assert!((built1.star_flux[0] - 1700.0).abs() < 1e-2);

    // M = 0.
    let empty = Array3::<f64>::zeros((0, wing_side, wing_side));
    let empty_delta = Array2::<f64>::zeros((0, 2));
    let built0 = build_extended_psf(
        empty.view(),
        None,
        empty_delta.view(),
        core.view(),
        oversample,
        params.clone(),
    )
    .unwrap();
    assert_eq!(built0.star_ok.len(), 0);
    assert!(built0.extended.wing.iter().all(|&v| v == 0.0));

    // Sum the feathered core over the EE aperture in row-major order.
    let center = (wing_side as f64 - 1.0) / 2.0;
    let stitch = &params.stitch;
    let ee: f64 = (0..wing_side)
        .flat_map(|r| (0..wing_side).map(move |c| (r as f64 - center, c as f64 - center)))
        .filter_map(|(dr, dc)| {
            let radius = (dr * dr + dc * dc).sqrt();
            if radius > stitch.ee_aperture_radius {
                None
            } else {
                let f_wing =
                    feather_wing_weight(radius, stitch.match_radius, stitch.feather_width);
                Some((1.0 - f_wing) * core_native_value(&built0.extended.core, oversample, dr, dc))
            }
        })
        .sum();
    assert!((ee - 1.0).abs() < 1e-9, "M=0 pure-core EE = {ee}");
}
#[test]
fn build_f32_and_f64_dual_path_agree() {
    // Feeding the same stamps as f32 instead of f64 must give nearly the same wing
    // and nearly the same per-star fluxes.
    let oversample = 5;
    let stamp_size = 15;
    let sigma = 1.1;
    let wing_side = 61usize;
    let core = gaussian_core(oversample, stamp_size, sigma);
    let fluxes = [1000.0, 2000.0, 1500.0];
    let backgrounds = [5.0, -2.0, 8.0];
    let data64 = synth_bright_stars(&core, sigma, wing_side, &fluxes, &backgrounds);
    let data32: Array3<f32> = data64.mapv(|v| v as f32);
    let wing_delta = Array2::<f64>::zeros((fluxes.len(), 2));
    let params = ExtendedPsfParams {
        stitch: StitchParams {
            match_radius: 5.0,
            feather_width: 3.0,
            ee_aperture_radius: 18.0,
        },
        scale_aperture_radius: 6.0,
        scale_background_annulus: (22.0, 28.0),
        ..ExtendedPsfParams::default()
    };

    let from64 = build_extended_psf(
        data64.view(),
        None,
        wing_delta.view(),
        core.view(),
        oversample,
        params.clone(),
    )
    .unwrap();
    let from32 = build_extended_psf(
        data32.view(),
        None,
        wing_delta.view(),
        core.view(),
        oversample,
        params,
    )
    .unwrap();

    for (a, b) in from64.extended.wing.iter().zip(&from32.extended.wing) {
        assert!(
            (a - b).abs() < 1e-3 * a.abs().max(1.0),
            "f32/f64 wing mismatch: {a} vs {b}"
        );
    }
    for star in 0..fluxes.len() {
        let reference = from64.star_flux[star];
        let tolerance = 1e-2 * reference.abs().max(1.0);
        assert!((reference - from32.star_flux[star]).abs() < tolerance);
    }
}
#[test]
fn build_weight_none_vs_some_consistent() {
    // Passing no weight must be equivalent to passing an all-ones weight cube.
    let oversample = 5;
    let stamp_size = 15;
    let sigma = 1.2;
    let wing_side = 61usize;
    let core = gaussian_core(oversample, stamp_size, sigma);
    let fluxes = [1100.0, 1700.0, 2300.0, 800.0];
    let backgrounds = [4.0, -1.0, 9.0, 2.0];
    let data = synth_bright_stars(&core, sigma, wing_side, &fluxes, &backgrounds);
    let wing_delta = Array2::<f64>::zeros((fluxes.len(), 2));
    let params = ExtendedPsfParams {
        stitch: StitchParams {
            match_radius: 5.0,
            feather_width: 3.0,
            ee_aperture_radius: 18.0,
        },
        combine: CombineMethod::ClippedMean {
            kappa: 3.0,
            max_iter: 5,
        },
        scale_aperture_radius: 6.0,
        scale_background_annulus: (22.0, 28.0),
    };

    let unweighted = build_extended_psf(
        data.view(),
        None,
        wing_delta.view(),
        core.view(),
        oversample,
        params.clone(),
    )
    .unwrap();
    let unit_weight = Array3::<f64>::ones((fluxes.len(), wing_side, wing_side));
    let weighted = build_extended_psf(
        data.view(),
        Some(unit_weight.view()),
        wing_delta.view(),
        core.view(),
        oversample,
        params,
    )
    .unwrap();

    let pairs = unweighted
        .extended
        .wing
        .iter()
        .zip(&weighted.extended.wing);
    for (a, b) in pairs {
        assert!(
            (a - b).abs() < 1e-9 * a.abs().max(1.0),
            "None vs all-ones weight must agree: {a} vs {b}"
        );
    }
}
#[test]
fn build_robust_recovery_noise_and_symmetric_outlier() {
    // Nine noisy stars. One wing pixel of star 0 receives a +5e4 or -5e4 spike. The clipped
    // mean must reject the spike regardless of sign, so both builds give the same wing, and
    // that wing must match a hand-built robust_combine + stitch reference.
    let oversample = 5;
    let stamp_size = 15;
    let sigma = 1.1;
    let core = gaussian_core(oversample, stamp_size, sigma);
    let wing_side = 61usize;
    let fluxes = [
        1000.0, 1400.0, 900.0, 2100.0, 1700.0, 1200.0, 2600.0, 800.0, 1500.0,
    ];
    let backgrounds = [5.0, -3.0, 9.0, 1.0, 7.0, 0.0, 12.0, -2.0, 4.0];
    let clean = synth_bright_stars(&core, sigma, wing_side, &fluxes, &backgrounds);
    // Fixed seed keeps the uniform +/-0.4 noise reproducible.
    let mut rng = SplitMix64::new(0x7151_3771_2DEA_D17F);
    let mut noisy = clean.clone();
    for value in noisy.iter_mut() {
        *value += rng.range(-0.4, 0.4);
    }
    let wc = (wing_side - 1) / 2;
    let outlier_pixel = (0usize, wc, wc + 12);
    let params = ExtendedPsfParams {
        stitch: StitchParams {
            match_radius: 5.0,
            feather_width: 3.0,
            ee_aperture_radius: 18.0,
        },
        combine: CombineMethod::ClippedMean {
            kappa: 2.5,
            max_iter: 5,
        },
        scale_aperture_radius: 6.0,
        scale_background_annulus: (22.0, 28.0),
    };
    let wing_delta = Array2::<f64>::zeros((fluxes.len(), 2));
    let mut plus = noisy.clone();
    plus[outlier_pixel] += 5.0e4;
    let mut minus = noisy.clone();
    minus[outlier_pixel] -= 5.0e4;
    let built_plus = build_extended_psf(
        plus.view(),
        None,
        wing_delta.view(),
        core.view(),
        oversample,
        params.clone(),
    )
    .unwrap();
    let built_minus = build_extended_psf(
        minus.view(),
        None,
        wing_delta.view(),
        core.view(),
        oversample,
        params.clone(),
    )
    .unwrap();
    for (a, b) in built_plus
        .extended
        .wing
        .iter()
        .zip(built_minus.extended.wing.iter())
    {
        assert!(
            (a - b).abs() < 1e-9 * a.abs().max(1.0),
            "sign-agnostic outlier rejection must propagate: {a} vs {b}"
        );
    }
    // Reference: normalize each stamp with the orchestrator's own fitted background and flux,
    // combine by hand, and stitch using the combine's weight map.
    let geometry = validate_core(&core.view(), oversample).unwrap();
    let reference_norm =
        Array3::<f64>::from_shape_fn((fluxes.len(), wing_side, wing_side), |(star, r, c)| {
            (plus[(star, r, c)] - built_plus.star_background[star]) / built_plus.star_flux[star]
        });
    let reference = robust_combine::<f64>(reference_norm.view(), None, params.combine).unwrap();
    let reference_extended = stitch_core_and_wing(
        &core.view(),
        &geometry,
        &reference.combined.view(),
        Some(&reference.weight.view()),
        &params.stitch,
    );
    for (a, b) in built_plus
        .extended
        .wing
        .iter()
        .zip(reference_extended.wing.iter())
    {
        assert!(
            (a - b).abs() < 1e-6 * a.abs().max(1.0),
            "orchestrator wing must equal the hand-built reference: {a} vs {b}"
        );
    }
    assert!(built_plus.star_ok.iter().all(|&ok| ok));
    assert!(built_plus.star_scale_from_core.iter().all(|&v| v));
    assert!(
        built_plus
            .star_flux
            .iter()
            .all(|f| f.is_finite() && *f > 0.0)
    );
}
#[test]
fn build_preconditions() {
    // Every invalid input must be rejected with the matching ExtendedPsfError variant.
    // The last case shows that geometry errors are reported before parameter errors.
    let oversample = 5;
    let stamp_size = 15;
    let core = gaussian_core(oversample, stamp_size, 1.2);
    let wing_side = 61usize;
    let data = Array3::<f64>::zeros((3, wing_side, wing_side));
    let delta = Array2::<f64>::zeros((3, 2));
    let good = ExtendedPsfParams {
        stitch: StitchParams {
            match_radius: 5.0,
            feather_width: 3.0,
            ee_aperture_radius: 18.0,
        },
        combine: CombineMethod::ClippedMean {
            kappa: 3.0,
            max_iter: 5,
        },
        scale_aperture_radius: 6.0,
        scale_background_annulus: (22.0, 28.0),
    };

    // Runs the orchestrator on owned arrays and returns the error it must produce.
    let expect_err = |data: &Array3<f64>,
                      weight: Option<&Array3<f64>>,
                      delta: &Array2<f64>,
                      core: &Array2<f64>,
                      oversample: usize,
                      params: ExtendedPsfParams| {
        build_extended_psf(
            data.view(),
            weight.map(|w| w.view()),
            delta.view(),
            core.view(),
            oversample,
            params,
        )
        .unwrap_err()
    };

    assert_eq!(
        expect_err(&data, None, &delta, &core, 4, good.clone()),
        ExtendedPsfError::OversampleNotOdd { oversample: 4 }
    );

    let rect = Array2::<f64>::zeros((10, 12));
    assert_eq!(
        expect_err(&data, None, &delta, &rect, 5, good.clone()),
        ExtendedPsfError::CoreNotSquare { rows: 10, cols: 12 }
    );

    let nm = Array2::<f64>::zeros((76, 76));
    assert_eq!(
        expect_err(&data, None, &delta, &nm, 5, good.clone()),
        ExtendedPsfError::CoreSizeNotMultiple {
            core_side: 76,
            oversample: 5
        }
    );

    let es = Array2::<f64>::zeros((20, 20));
    assert_eq!(
        expect_err(&data, None, &delta, &es, 5, good.clone()),
        ExtendedPsfError::DerivedStampSizeEven { stamp_size: 4 }
    );

    let even_wing = Array3::<f64>::zeros((3, 60, 61));
    assert_eq!(
        expect_err(&even_wing, None, &delta, &core, 5, good.clone()),
        ExtendedPsfError::WingNotOdd { rows: 60, cols: 61 }
    );

    let rectangular_wing = Array3::<f64>::zeros((3, 59, 61));
    assert_eq!(
        expect_err(&rectangular_wing, None, &delta, &core, 5, good.clone()),
        ExtendedPsfError::WingNotSquare { rows: 59, cols: 61 }
    );

    let bad_delta = Array2::<f64>::zeros((2, 2));
    assert_eq!(
        expect_err(&data, None, &bad_delta, &core, 5, good.clone()),
        ExtendedPsfError::BatchLengthMismatch {
            wing_data: (3, wing_side, wing_side),
            wing_delta: (2, 2)
        }
    );

    let bad_weight = Array3::<f64>::zeros((3, wing_side, wing_side - 2));
    assert_eq!(
        expect_err(&data, Some(&bad_weight), &delta, &core, 5, good.clone()),
        ExtendedPsfError::WeightShapeMismatch {
            weight: (3, wing_side, wing_side - 2),
            wing_data: (3, wing_side, wing_side)
        }
    );

    let invalid_params = [
        ExtendedPsfParams {
            combine: CombineMethod::ClippedMean {
                kappa: 0.0,
                max_iter: 5,
            },
            ..good.clone()
        },
        ExtendedPsfParams {
            scale_aperture_radius: -1.0,
            ..good.clone()
        },
        ExtendedPsfParams {
            scale_background_annulus: (10.0, 5.0),
            ..good.clone()
        },
        ExtendedPsfParams {
            stitch: StitchParams {
                match_radius: f64::NAN,
                feather_width: 3.0,
                ee_aperture_radius: 18.0,
            },
            ..good.clone()
        },
    ];
    for bad in invalid_params {
        let err = expect_err(&data, None, &delta, &core, 5, bad.clone());
        assert!(
            matches!(err, ExtendedPsfError::ParamsInvalid { .. }),
            "expected ParamsInvalid for {bad:?}, got {err:?}"
        );
    }

    // Bad geometry and bad params together: the geometry error wins.
    let bad_aperture = ExtendedPsfParams {
        scale_aperture_radius: -1.0,
        ..good.clone()
    };
    assert_eq!(
        expect_err(&data, None, &delta, &es, 5, bad_aperture),
        ExtendedPsfError::DerivedStampSizeEven { stamp_size: 4 }
    );
}
}