use ndarray::{Array2, ArrayView2};
use thiserror::Error;
use crate::float::Float;
use crate::image::stats::median_in_place;
/// Default full width at half maximum, in pixels, of the Gaussian window that weights
/// pixels during [`build_stamp`]'s centroid refinement.
pub const DEFAULT_WEIGHT_FWHM: f64 = 3.0_f64;
/// Default upper bound on the number of centroid refinement passes.
pub const DEFAULT_CENTROID_MAX_ITER: usize = 10_usize;
/// Default convergence tolerance: refinement stops once the centroid moves by less than
/// this many pixels in a single pass.
pub const DEFAULT_CENTROID_TOL: f64 = 0.001;
/// A square window cut out of a cutout by [`build_stamp`], together with the refined
/// centroid information that located it.
#[derive(Clone, Debug, PartialEq)]
pub struct StampResult {
    /// Cutout values inside the window, converted to `f64`.
    pub stamp: Array2<f64>,
    /// The same window taken from the caller's error map (converted to `f64`), or `None`
    /// when no error map was supplied.
    pub error: Option<Array2<f64>>,
    /// Per-pixel usability flag: `true` when the stamp value is finite, the pixel is not
    /// masked, and (if an error map was supplied) its error is finite and strictly positive.
    pub valid: Array2<bool>,
    /// Refined centroid minus the window's central pixel, as `(row, column)`.
    pub delta: (f64, f64),
    /// `(row, column)` of the window's top-left pixel in cutout coordinates.
    pub origin: (i64, i64),
}
#[derive(Debug, Error, PartialEq)]
pub enum StampError {
#[error("stamp_size must be odd; got {stamp_size}")]
StampSizeEven { stamp_size: usize },
#[error(
"stamp_size ({stamp_size}) must not exceed cutout dimensions; cutout is ({rows}, {cols})"
)]
StampSizeTooLarge {
stamp_size: usize,
rows: usize,
cols: usize,
},
#[error("error shape {error_shape:?} must equal cutout shape {cutout_shape:?}")]
ErrorShapeMismatch {
error_shape: (usize, usize),
cutout_shape: (usize, usize),
},
#[error("mask shape {mask_shape:?} must equal cutout shape {cutout_shape:?}")]
MaskShapeMismatch {
mask_shape: (usize, usize),
cutout_shape: (usize, usize),
},
}
pub fn build_stamp<T: Float>(
cutout: ArrayView2<T>,
stamp_size: usize,
error: Option<ArrayView2<T>>,
mask: Option<ArrayView2<bool>>,
weight_fwhm: f64,
max_iter: usize,
tol: f64,
) -> Result<Option<StampResult>, StampError> {
let cutout_rows = cutout.shape()[0];
let cutout_cols = cutout.shape()[1];
if stamp_size.is_multiple_of(2) {
return Err(StampError::StampSizeEven { stamp_size });
}
if stamp_size > cutout_rows || stamp_size > cutout_cols {
return Err(StampError::StampSizeTooLarge {
stamp_size,
rows: cutout_rows,
cols: cutout_cols,
});
}
if let Some(error_view) = error.as_ref() {
let error_shape = (error_view.shape()[0], error_view.shape()[1]);
if error_shape != (cutout_rows, cutout_cols) {
return Err(StampError::ErrorShapeMismatch {
error_shape,
cutout_shape: (cutout_rows, cutout_cols),
});
}
}
if let Some(mask_view) = mask.as_ref() {
let mask_shape = (mask_view.shape()[0], mask_view.shape()[1]);
if mask_shape != (cutout_rows, cutout_cols) {
return Err(StampError::MaskShapeMismatch {
mask_shape,
cutout_shape: (cutout_rows, cutout_cols),
});
}
}
let cutout_f64: Array2<f64> = cutout.mapv(|value| value.to_f64().unwrap_or(f64::NAN));
let is_masked = |row: usize, column: usize| -> bool {
mask.as_ref()
.map(|mask_view| mask_view[(row, column)])
.unwrap_or(false)
};
let participates = |row: usize, column: usize| -> bool {
cutout_f64[(row, column)].is_finite() && !is_masked(row, column)
};
let mut border_values: Vec<f64> = Vec::new();
for column in 0..cutout_cols {
for &row in &[0usize, cutout_rows - 1] {
if participates(row, column) {
border_values.push(cutout_f64[(row, column)]);
}
}
}
if cutout_rows > 2 {
for row in 1..(cutout_rows - 1) {
for &column in &[0usize, cutout_cols - 1] {
if participates(row, column) {
border_values.push(cutout_f64[(row, column)]);
}
}
}
}
let background = median_in_place(&mut border_values).unwrap_or(0.0);
let sigma = weight_fwhm / (2.0 * (2.0 * 2_f64.ln()).sqrt());
let mut centroid_row = (cutout_rows as f64 - 1.0) / 2.0;
let mut centroid_column = (cutout_cols as f64 - 1.0) / 2.0;
let mut converged_weight_sum = 0.0_f64;
let mut participating_count: usize = 0;
for _iteration in 0..max_iter.max(1) {
let mut weight_sum = 0.0_f64;
let mut weighted_row_sum = 0.0_f64;
let mut weighted_column_sum = 0.0_f64;
let mut count_this_pass: usize = 0;
for row in 0..cutout_rows {
for column in 0..cutout_cols {
if !participates(row, column) {
continue;
}
count_this_pass += 1;
let background_subtracted = cutout_f64[(row, column)] - background;
let positive_part = background_subtracted.max(0.0);
let row_distance = row as f64 - centroid_row;
let column_distance = column as f64 - centroid_column;
let squared_distance =
row_distance * row_distance + column_distance * column_distance;
let gaussian = (-0.5 * squared_distance / (sigma * sigma)).exp();
let pixel_weight = positive_part * gaussian;
weight_sum += pixel_weight;
weighted_row_sum += pixel_weight * row as f64;
weighted_column_sum += pixel_weight * column as f64;
}
}
participating_count = count_this_pass;
converged_weight_sum = weight_sum;
if weight_sum <= 0.0 {
break;
}
let new_centroid_row = weighted_row_sum / weight_sum;
let new_centroid_column = weighted_column_sum / weight_sum;
let shift = ((new_centroid_row - centroid_row).powi(2)
+ (new_centroid_column - centroid_column).powi(2))
.sqrt();
centroid_row = new_centroid_row;
centroid_column = new_centroid_column;
if shift < tol {
break;
}
}
if participating_count < 4 || converged_weight_sum <= 0.0 {
return Ok(None);
}
let center_row = centroid_row.round();
let center_column = centroid_column.round();
let half = (stamp_size / 2) as i64;
let origin_row = center_row as i64 - half;
let origin_column = center_column as i64 - half;
if origin_row < 0
|| origin_column < 0
|| origin_row + stamp_size as i64 > cutout_rows as i64
|| origin_column + stamp_size as i64 > cutout_cols as i64
{
return Ok(None);
}
let delta_row = centroid_row - center_row;
let delta_column = centroid_column - center_column;
let mut stamp = Array2::<f64>::zeros((stamp_size, stamp_size));
let mut valid = Array2::<bool>::from_elem((stamp_size, stamp_size), false);
let mut windowed_error: Option<Array2<f64>> = error
.as_ref()
.map(|_| Array2::<f64>::zeros((stamp_size, stamp_size)));
for a in 0..stamp_size {
for b in 0..stamp_size {
let source_row = (origin_row + a as i64) as usize;
let source_column = (origin_column + b as i64) as usize;
let value = cutout_f64[(source_row, source_column)];
stamp[(a, b)] = value;
let error_ok = match error.as_ref() {
Some(error_view) => {
let error_value = error_view[(source_row, source_column)]
.to_f64()
.unwrap_or(f64::NAN);
windowed_error.as_mut().unwrap()[(a, b)] = error_value;
error_value.is_finite() && error_value > 0.0
}
None => true,
};
valid[(a, b)] = value.is_finite() && !is_masked(source_row, source_column) && error_ok;
}
}
Ok(Some(StampResult {
stamp,
error: windowed_error,
valid,
delta: (delta_row, delta_column),
origin: (origin_row, origin_column),
}))
}
#[cfg(test)]
mod tests {
    // Unit tests for `build_stamp`; each scenario is exercised with both f64 and f32 input.
    use super::*;
    use crate::float::Float;
    use ndarray::{Array2, ArrayView2};

    const TOL: f64 = 1e-9;

    /// Calls [`build_stamp`] with the default FWHM, iteration cap and tolerance.
    fn stamp_with_defaults<T: Float>(
        cutout: ArrayView2<T>,
        stamp_size: usize,
        error: Option<ArrayView2<T>>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Option<StampResult>, StampError> {
        build_stamp(
            cutout,
            stamp_size,
            error,
            mask,
            DEFAULT_WEIGHT_FWHM,
            DEFAULT_CENTROID_MAX_ITER,
            DEFAULT_CENTROID_TOL,
        )
    }

    /// Builds a `rows x cols` image holding a circular Gaussian of the given `amplitude`
    /// and `sigma`, centred at (`source_row`, `source_column`), on a constant `background`.
    fn gaussian_cutout(
        rows: usize,
        cols: usize,
        source_row: f64,
        source_column: f64,
        amplitude: f64,
        sigma: f64,
        background: f64,
    ) -> Array2<f64> {
        let mut image = Array2::<f64>::from_elem((rows, cols), background);
        for row in 0..rows {
            for column in 0..cols {
                let row_distance = row as f64 - source_row;
                let column_distance = column as f64 - source_column;
                let squared_distance =
                    row_distance * row_distance + column_distance * column_distance;
                image[(row, column)] +=
                    amplitude * (-0.5 * squared_distance / (sigma * sigma)).exp();
            }
        }
        image
    }

    #[test]
    fn centered_point_source_zero_delta_f64() {
        let cutout = gaussian_cutout(11, 11, 5.0, 5.0, 100.0, 1.5, 1.0);
        let result = stamp_with_defaults(cutout.view(), 5, None, None)
            .unwrap()
            .expect("expected a stamp");
        assert!(result.delta.0.abs() < 1e-6, "delta_row = {}", result.delta.0);
        assert!(result.delta.1.abs() < 1e-6, "delta_col = {}", result.delta.1);
        assert_eq!(result.origin, (3, 3));
        assert!((result.stamp[(2, 2)] - 101.0).abs() < TOL);
        assert!((result.stamp[(0, 0)] - cutout[(3, 3)]).abs() < TOL);
        assert!(result.error.is_none());
        assert!(result.valid.iter().all(|&v| v));
    }

    #[test]
    fn centered_point_source_zero_delta_f32() {
        let cutout_f64 = gaussian_cutout(11, 11, 5.0, 5.0, 100.0, 1.5, 1.0);
        let cutout_f32: Array2<f32> = cutout_f64.mapv(|v| v as f32);
        let result = stamp_with_defaults(cutout_f32.view(), 5, None, None)
            .unwrap()
            .expect("expected a stamp");
        assert!(result.delta.0.abs() < 1e-4, "delta_row = {}", result.delta.0);
        assert!(result.delta.1.abs() < 1e-4, "delta_col = {}", result.delta.1);
        assert_eq!(result.origin, (3, 3));
        assert!((result.stamp[(2, 2)] - 101.0).abs() < 1e-3);
        assert!(result.error.is_none());
        assert!(result.valid.iter().all(|&v| v));
    }

    #[test]
    fn off_center_source_integer_window_and_delta_range_f64() {
        let cutout = gaussian_cutout(13, 13, 6.3, 4.7, 80.0, 1.5, 2.0);
        let result = stamp_with_defaults(cutout.view(), 5, None, None)
            .unwrap()
            .expect("expected a stamp");
        assert!(
            result.delta.0 >= -0.5 && result.delta.0 < 0.5,
            "delta_row = {}",
            result.delta.0
        );
        assert!(
            result.delta.1 >= -0.5 && result.delta.1 < 0.5,
            "delta_col = {}",
            result.delta.1
        );
        let center_row = result.origin.0 + 2;
        let center_column = result.origin.1 + 2;
        let centroid_row = center_row as f64 + result.delta.0;
        let centroid_column = center_column as f64 + result.delta.1;
        assert!(
            (centroid_row - 6.3).abs() < 0.3,
            "centroid_row = {centroid_row}"
        );
        assert!(
            (centroid_column - 4.7).abs() < 0.3,
            "centroid_col = {centroid_column}"
        );
        assert!((result.delta.0 - (centroid_row - centroid_row.round())).abs() < TOL);
        assert!((result.delta.1 - (centroid_column - centroid_column.round())).abs() < TOL);
    }

    #[test]
    fn off_center_source_integer_window_and_delta_range_f32() {
        let cutout_f64 = gaussian_cutout(13, 13, 6.3, 4.7, 80.0, 1.5, 2.0);
        let cutout_f32: Array2<f32> = cutout_f64.mapv(|v| v as f32);
        let result = stamp_with_defaults(cutout_f32.view(), 5, None, None)
            .unwrap()
            .expect("expected a stamp");
        assert!(result.delta.0 >= -0.5 && result.delta.0 < 0.5);
        assert!(result.delta.1 >= -0.5 && result.delta.1 < 0.5);
        let center_row = result.origin.0 + 2;
        let center_column = result.origin.1 + 2;
        let centroid_row = center_row as f64 + result.delta.0;
        let centroid_column = center_column as f64 + result.delta.1;
        assert!((centroid_row - 6.3).abs() < 0.3);
        assert!((centroid_column - 4.7).abs() < 0.3);
    }

    #[test]
    fn window_cannot_fit_returns_none_f64() {
        let cutout = gaussian_cutout(9, 9, 0.0, 0.0, 100.0, 1.0, 0.0);
        let result = stamp_with_defaults(cutout.view(), 7, None, None).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn window_cannot_fit_returns_none_f32() {
        let cutout_f64 = gaussian_cutout(9, 9, 0.0, 0.0, 100.0, 1.0, 0.0);
        let cutout_f32: Array2<f32> = cutout_f64.mapv(|v| v as f32);
        let result = stamp_with_defaults(cutout_f32.view(), 7, None, None).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn too_few_valid_pixels_returns_none_f64() {
        let cutout = Array2::<f64>::from_elem((5, 5), 10.0);
        let mut mask = Array2::<bool>::from_elem((5, 5), true);
        mask[(2, 2)] = false;
        mask[(2, 3)] = false;
        mask[(3, 2)] = false;
        let result = stamp_with_defaults(cutout.view(), 3, None, Some(mask.view())).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn too_few_valid_pixels_returns_none_f32() {
        let cutout = Array2::<f32>::from_elem((5, 5), 10.0);
        let mut mask = Array2::<bool>::from_elem((5, 5), true);
        mask[(2, 2)] = false;
        mask[(2, 3)] = false;
        mask[(3, 2)] = false;
        let result = stamp_with_defaults(cutout.view(), 3, None, Some(mask.view())).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn mask_excludes_pixels_from_centroid_and_marks_invalid_f64() {
        let mut cutout = gaussian_cutout(11, 11, 5.0, 5.0, 50.0, 1.5, 1.0);
        cutout[(1, 1)] = 10_000.0;
        let mut mask = Array2::<bool>::from_elem((11, 11), false);
        mask[(1, 1)] = true;
        let result = stamp_with_defaults(cutout.view(), 5, None, Some(mask.view()))
            .unwrap()
            .expect("expected a stamp");
        assert_eq!(result.origin, (3, 3));
        assert!(result.delta.0.abs() < 0.5 && result.delta.1.abs() < 0.5);

        let mut mask2 = Array2::<bool>::from_elem((11, 11), false);
        mask2[(4, 4)] = true;
        let result2 = stamp_with_defaults(cutout.view(), 5, None, Some(mask2.view()))
            .unwrap()
            .expect("expected a stamp");
        assert!(!result2.valid[(1, 1)], "masked pixel must be valid=false");
        assert!(result2.valid[(0, 0)]);
    }

    #[test]
    fn mask_excludes_pixels_from_centroid_and_marks_invalid_f32() {
        let mut cutout_f64 = gaussian_cutout(11, 11, 5.0, 5.0, 50.0, 1.5, 1.0);
        cutout_f64[(1, 1)] = 10_000.0;
        let cutout_f32: Array2<f32> = cutout_f64.mapv(|v| v as f32);
        let mut mask = Array2::<bool>::from_elem((11, 11), false);
        mask[(1, 1)] = true;
        let result = stamp_with_defaults(cutout_f32.view(), 5, None, Some(mask.view()))
            .unwrap()
            .expect("expected a stamp");
        assert_eq!(result.origin, (3, 3));
    }

    #[test]
    fn error_none_means_result_error_none_and_valid_ignores_error_f64() {
        let cutout = gaussian_cutout(11, 11, 5.0, 5.0, 100.0, 1.5, 1.0);
        let result = stamp_with_defaults(cutout.view(), 5, None, None)
            .unwrap()
            .expect("expected a stamp");
        assert!(result.error.is_none());
        assert!(result.valid.iter().all(|&v| v));
    }

    #[test]
    fn error_none_means_result_error_none_and_valid_ignores_error_f32() {
        let cutout_f64 = gaussian_cutout(11, 11, 5.0, 5.0, 100.0, 1.5, 1.0);
        let cutout_f32: Array2<f32> = cutout_f64.mapv(|v| v as f32);
        let result = stamp_with_defaults(cutout_f32.view(), 5, None, None)
            .unwrap()
            .expect("expected a stamp");
        assert!(result.error.is_none());
        assert!(result.valid.iter().all(|&v| v));
    }

    #[test]
    fn error_some_returns_windowed_error_and_valid_requires_positive_f64() {
        let cutout = gaussian_cutout(11, 11, 5.0, 5.0, 100.0, 1.5, 1.0);
        let mut error = Array2::<f64>::from_elem((11, 11), 2.0);
        error[(5, 5)] = 0.0;
        error[(4, 4)] = f64::NAN;
        let result = stamp_with_defaults(cutout.view(), 5, Some(error.view()), None)
            .unwrap()
            .expect("expected a stamp");
        let windowed_error = result.error.as_ref().expect("error must be Some");
        assert!((windowed_error[(0, 0)] - 2.0).abs() < TOL);
        assert!((windowed_error[(2, 2)] - 0.0).abs() < TOL);
        assert!(!result.valid[(2, 2)]);
        assert!(windowed_error[(1, 1)].is_nan());
        assert!(!result.valid[(1, 1)]);
        assert!(result.valid[(0, 0)]);
    }

    #[test]
    fn error_some_returns_windowed_error_and_valid_requires_positive_f32() {
        let cutout_f64 = gaussian_cutout(11, 11, 5.0, 5.0, 100.0, 1.5, 1.0);
        let cutout_f32: Array2<f32> = cutout_f64.mapv(|v| v as f32);
        let mut error = Array2::<f32>::from_elem((11, 11), 2.0);
        error[(5, 5)] = 0.0;
        error[(4, 4)] = f32::NAN;
        let result = stamp_with_defaults(cutout_f32.view(), 5, Some(error.view()), None)
            .unwrap()
            .expect("expected a stamp");
        let windowed_error = result.error.as_ref().expect("error must be Some");
        assert!((windowed_error[(0, 0)] - 2.0).abs() < 1e-5);
        assert!(!result.valid[(2, 2)]);
        assert!(!result.valid[(1, 1)]);
        assert!(result.valid[(0, 0)]);
    }

    #[test]
    fn nan_in_cutout_excluded_from_centroid_and_invalid_f64() {
        let mut cutout = gaussian_cutout(11, 11, 5.0, 5.0, 100.0, 1.5, 1.0);
        cutout[(5, 6)] = f64::NAN;
        let result = stamp_with_defaults(cutout.view(), 5, None, None)
            .unwrap()
            .expect("expected a stamp");
        assert!(result.stamp[(2, 3)].is_nan());
        assert!(!result.valid[(2, 3)], "NaN pixel must be valid=false");
        assert!(result.valid[(0, 0)]);
    }

    #[test]
    fn nan_in_cutout_excluded_from_centroid_and_invalid_f32() {
        let mut cutout_f64 = gaussian_cutout(11, 11, 5.0, 5.0, 100.0, 1.5, 1.0);
        cutout_f64[(5, 6)] = f64::NAN;
        let cutout_f32: Array2<f32> = cutout_f64.mapv(|v| v as f32);
        let result = stamp_with_defaults(cutout_f32.view(), 5, None, None)
            .unwrap()
            .expect("expected a stamp");
        assert!(result.stamp[(2, 3)].is_nan());
        assert!(!result.valid[(2, 3)]);
        assert!(result.valid[(0, 0)]);
    }

    #[test]
    fn error_precondition_stamp_size_even() {
        let cutout = Array2::<f64>::zeros((9, 9));
        let err = stamp_with_defaults(cutout.view(), 4, None, None).unwrap_err();
        assert_eq!(err, StampError::StampSizeEven { stamp_size: 4 });
    }

    #[test]
    fn error_precondition_stamp_size_too_large() {
        let cutout = Array2::<f64>::zeros((5, 9));
        let err = stamp_with_defaults(cutout.view(), 7, None, None).unwrap_err();
        assert_eq!(
            err,
            StampError::StampSizeTooLarge {
                stamp_size: 7,
                rows: 5,
                cols: 9,
            }
        );
    }

    #[test]
    fn error_precondition_error_shape_mismatch() {
        let cutout = Array2::<f64>::zeros((9, 9));
        let error = Array2::<f64>::zeros((9, 8));
        let err = stamp_with_defaults(cutout.view(), 5, Some(error.view()), None).unwrap_err();
        assert_eq!(
            err,
            StampError::ErrorShapeMismatch {
                error_shape: (9, 8),
                cutout_shape: (9, 9),
            }
        );
    }

    #[test]
    fn error_precondition_mask_shape_mismatch() {
        let cutout = Array2::<f64>::zeros((9, 9));
        let mask = Array2::<bool>::from_elem((8, 9), false);
        let err = stamp_with_defaults(cutout.view(), 5, None, Some(mask.view())).unwrap_err();
        assert_eq!(
            err,
            StampError::MaskShapeMismatch {
                mask_shape: (8, 9),
                cutout_shape: (9, 9),
            }
        );
    }
}