use ndarray::{Array2, ArrayView3};
use rayon::prelude::*;
use thiserror::Error;
use crate::float::Float;
use crate::image::stats::median_in_place;
/// Per-pixel strategy used by `robust_combine` to reduce the samples of one pixel.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CombineMethod {
/// Iterative kappa-clipping followed by a weighted mean of the survivors.
///
/// Each pass uses the median of the surviving samples as the centre and the
/// RMS deviation from that median as the scale, then rejects every sample
/// farther than `kappa * scale` from the centre. At most `max_iter` passes
/// run; iteration stops early once a pass rejects nothing.
/// `robust_combine` rejects `kappa <= 0` or `max_iter == 0`.
ClippedMean { kappa: f64, max_iter: usize },
/// Median of the valid samples. Weights only decide which samples are
/// included; they do not weight the median itself.
Median,
}
/// Output planes of `robust_combine`, each of shape `(height, width)`.
#[derive(Debug, Clone, PartialEq)]
pub struct RobustCombined {
/// Combined pixel value; `NaN` where no sample contributed.
pub combined: Array2<f64>,
/// Sum of the weights of the contributing samples (each sample counts as
/// `1.0` when no weight cube is given); `0.0` where no sample contributed.
pub weight: Array2<f64>,
/// Number of contributing samples: all valid samples for `Median`, the
/// clipping survivors for `ClippedMean`.
pub count: Array2<u32>,
}
/// Precondition failures reported by `robust_combine` before any pixel is processed.
#[derive(Debug, Error, PartialEq)]
pub enum RobustError {
/// The weight cube does not have the same `(sample, row, column)` shape as the stack.
#[error("weight shape {weight:?} must equal stack shape {stack:?}")]
WeightShapeMismatch {
/// Shape of the weight cube that was supplied.
weight: (usize, usize, usize),
/// Shape of the image stack.
stack: (usize, usize, usize),
},
/// `CombineMethod::ClippedMean` was requested with unusable parameters;
/// the offending values are echoed back.
#[error(
"ClippedMean requires kappa > 0 and max_iter > 0; got kappa = {kappa}, max_iter = {max_iter}"
)]
ClippedMeanInvalidParams { kappa: f64, max_iter: usize },
}
/// Robustly combines a stack of images, pixel by pixel, into a single plane.
///
/// `stack` is indexed as `(sample, row, column)`. For every `(row, column)`
/// the samples along the first axis are gathered, non-finite values are
/// dropped, and the remainder is reduced with `method`. When `weight` is
/// given it must have exactly the same shape as `stack`; a sample whose
/// weight is non-finite or not strictly positive is excluded. Without a
/// weight cube every sample has weight `1.0`.
///
/// Pixels are processed in parallel with rayon; the result does not depend
/// on scheduling because each pixel is computed independently.
///
/// # Errors
///
/// * [`RobustError::WeightShapeMismatch`] if `weight` is present and its
///   shape differs from the stack shape. This is checked first.
/// * [`RobustError::ClippedMeanInvalidParams`] if `method` is
///   [`CombineMethod::ClippedMean`] and `kappa` is NaN, `kappa <= 0`, or
///   `max_iter == 0`.
pub fn robust_combine<T: Float>(
    stack: ArrayView3<T>,
    weight: Option<ArrayView3<T>>,
    method: CombineMethod,
) -> Result<RobustCombined, RobustError> {
    let stack_dim = stack.dim();
    let (sample_count, height, width) = stack_dim;

    if let Some(weight_view) = weight.as_ref() {
        let weight_dim = weight_view.dim();
        if weight_dim != stack_dim {
            return Err(RobustError::WeightShapeMismatch {
                weight: weight_dim,
                stack: stack_dim,
            });
        }
    }

    if let CombineMethod::ClippedMean { kappa, max_iter } = method {
        // NaN must be rejected explicitly: `kappa <= 0.0` is false for NaN,
        // and a NaN threshold would silently disable clipping altogether.
        let params_invalid = kappa.is_nan() || kappa <= 0.0 || max_iter == 0;
        if params_invalid {
            return Err(RobustError::ClippedMeanInvalidParams { kappa, max_iter });
        }
    }

    let pixel_count = height * width;
    // Row-major flat index -> (row, column); the three output vectors come
    // back in the same row-major order, ready for `from_shape_vec`.
    let ((combined_values, weight_values), count_values): (
        (Vec<f64>, Vec<f64>),
        Vec<u32>,
    ) = (0..pixel_count)
        .into_par_iter()
        .map(|flat_index| {
            let row = flat_index / width;
            let column = flat_index % width;
            let (combined, combined_weight, survivors) =
                combine_one_pixel(&stack, weight.as_ref(), row, column, sample_count, method);
            ((combined, combined_weight), survivors)
        })
        .unzip();

    Ok(RobustCombined {
        combined: Array2::from_shape_vec((height, width), combined_values)
            .expect("combined length == height * width"),
        weight: Array2::from_shape_vec((height, width), weight_values)
            .expect("weight length == height * width"),
        count: Array2::from_shape_vec((height, width), count_values)
            .expect("count length == height * width"),
    })
}
/// Gathers the usable samples of pixel `(row, column)` and reduces them with `method`.
///
/// A sample is usable when its value is finite and, if a weight cube is
/// present, its weight is finite and strictly positive. Without a weight
/// cube every usable sample gets weight `1.0`. Values that cannot be
/// converted to `f64` are treated as `NaN` and therefore dropped.
///
/// Returns `(combined value, summed weight, contributing sample count)`, or
/// `(NaN, 0.0, 0)` when no sample is usable.
fn combine_one_pixel<T: Float>(
    stack: &ArrayView3<T>,
    weight: Option<&ArrayView3<T>>,
    row: usize,
    column: usize,
    sample_count: usize,
    method: CombineMethod,
) -> (f64, f64, u32) {
    let as_f64 = |raw: &T| raw.to_f64().unwrap_or(f64::NAN);

    // Lazily yields `(value, weight)` for every sample that passes the gates.
    let usable = (0..sample_count).filter_map(|sample| {
        let position = (sample, row, column);
        let value = as_f64(&stack[position]);
        if !value.is_finite() {
            return None;
        }
        let sample_weight = match weight {
            None => 1.0,
            Some(weight_view) => {
                let candidate = as_f64(&weight_view[position]);
                if candidate.is_finite() && candidate > 0.0 {
                    candidate
                } else {
                    return None;
                }
            }
        };
        Some((value, sample_weight))
    });

    let mut kept_values: Vec<f64> = Vec::with_capacity(sample_count);
    let mut kept_weights: Vec<f64> = Vec::with_capacity(sample_count);
    for (value, sample_weight) in usable {
        kept_values.push(value);
        kept_weights.push(sample_weight);
    }

    if kept_values.is_empty() {
        return (f64::NAN, 0.0, 0);
    }

    match method {
        CombineMethod::ClippedMean { kappa, max_iter } => {
            combine_clipped_mean(&kept_values, &kept_weights, kappa, max_iter)
        }
        CombineMethod::Median => combine_median(&kept_values, &kept_weights),
    }
}
/// Median reduction for one pixel.
///
/// Returns the median of `values` (computed on a scratch copy so the input
/// order is untouched), the plain sum of `weights`, and the number of
/// samples. The weights do not influence the median itself.
fn combine_median(values: &[f64], weights: &[f64]) -> (f64, f64, u32) {
    let sample_total = values.len() as u32;
    let total_weight = weights.iter().sum::<f64>();

    let mut scratch: Vec<f64> = Vec::from(values);
    let median = median_in_place(&mut scratch).unwrap_or(f64::NAN);

    (median, total_weight, sample_total)
}
/// Kappa-clipped weighted mean for one pixel.
///
/// Runs at most `max_iter` clipping passes. Each pass takes the median of
/// the surviving values as the centre and the RMS deviation from that centre
/// as the scale, then discards every survivor whose distance from the centre
/// exceeds `kappa * scale`. Iteration ends early when a pass discards nothing
/// or no survivor is left.
///
/// Returns `(weighted mean of survivors, sum of survivor weights, survivor
/// count)`, or `(NaN, 0.0, 0)` when nothing survives.
fn combine_clipped_mean(
    values: &[f64],
    weights: &[f64],
    kappa: f64,
    max_iter: usize,
) -> (f64, f64, u32) {
    // Indices of the samples still in play, always in ascending order so all
    // floating-point sums run in the original sample order.
    let mut kept: Vec<usize> = (0..values.len()).collect();

    for _pass in 0..max_iter {
        if kept.is_empty() {
            break;
        }

        let mut scratch: Vec<f64> = kept.iter().map(|&i| values[i]).collect();
        let center = median_in_place(&mut scratch).unwrap_or(f64::NAN);

        let squared_deviation_sum = kept
            .iter()
            .map(|&i| {
                let deviation = values[i] - center;
                deviation * deviation
            })
            .sum::<f64>();
        let scale = (squared_deviation_sum / kept.len() as f64).sqrt();
        let limit = kappa * scale;

        // Expressed as "is an outlier" (strict `>`) so that a NaN limit or
        // deviation keeps the sample instead of rejecting it.
        let is_outlier = |x: f64| (x - center).abs() > limit;

        let before = kept.len();
        kept.retain(|&i| !is_outlier(values[i]));
        if kept.len() == before {
            break;
        }
    }

    if kept.is_empty() {
        return (f64::NAN, 0.0, 0);
    }

    let mut weighted_total = 0.0;
    let mut weight_total = 0.0;
    for &i in &kept {
        weighted_total += weights[i] * values[i];
        weight_total += weights[i];
    }

    (weighted_total / weight_total, weight_total, kept.len() as u32)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array2, Array3};

    /// Small deterministic SplitMix64 generator so the tests need no RNG crate.
    struct SplitMix64 {
        state: u64,
    }

    impl SplitMix64 {
        fn new(seed: u64) -> Self {
            Self { state: seed }
        }

        fn next_u64(&mut self) -> u64 {
            self.state = self.state.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut z = self.state;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            z ^ (z >> 31)
        }

        /// Uniform value in `[0, 1)` built from the top 53 bits.
        fn unit(&mut self) -> f64 {
            (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
        }

        /// Uniform value in `[lo, hi)`.
        fn range(&mut self, lo: f64, hi: f64) -> f64 {
            lo + (hi - lo) * self.unit()
        }
    }

    /// Reference gather: finite values whose weight (if any) is finite and positive.
    fn naive_gather(
        stack: &Array3<f64>,
        weight: Option<&Array3<f64>>,
        row: usize,
        column: usize,
    ) -> (Vec<f64>, Vec<f64>) {
        let mut values = Vec::new();
        let mut weights = Vec::new();
        for n in 0..stack.shape()[0] {
            let value = stack[(n, row, column)];
            if !value.is_finite() {
                continue;
            }
            let w = match weight {
                Some(weight_array) => {
                    let w = weight_array[(n, row, column)];
                    if !(w.is_finite() && w > 0.0) {
                        continue;
                    }
                    w
                }
                None => 1.0,
            };
            values.push(value);
            weights.push(w);
        }
        (values, weights)
    }

    /// Reference median by full sort; callers guarantee a non-empty, NaN-free slice.
    fn naive_median(values: &mut [f64]) -> f64 {
        values.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let length = values.len();
        if length % 2 == 1 {
            values[length / 2]
        } else {
            0.5 * (values[length / 2 - 1] + values[length / 2])
        }
    }

    /// Sequential reference implementation of the `Median` method.
    fn naive_median_combine(
        stack: &Array3<f64>,
        weight: Option<&Array3<f64>>,
    ) -> (Array2<f64>, Array2<f64>, Array2<u32>) {
        let (_, height, width) = (stack.shape()[0], stack.shape()[1], stack.shape()[2]);
        let mut combined = Array2::<f64>::zeros((height, width));
        let mut weight_out = Array2::<f64>::zeros((height, width));
        let mut count = Array2::<u32>::zeros((height, width));
        for row in 0..height {
            for column in 0..width {
                let (mut values, weights) = naive_gather(stack, weight, row, column);
                if values.is_empty() {
                    combined[(row, column)] = f64::NAN;
                    weight_out[(row, column)] = 0.0;
                    count[(row, column)] = 0;
                    continue;
                }
                count[(row, column)] = values.len() as u32;
                weight_out[(row, column)] = weights.iter().sum();
                combined[(row, column)] = naive_median(&mut values);
            }
        }
        (combined, weight_out, count)
    }

    /// Sequential reference implementation of the `ClippedMean` method.
    fn naive_clipped_mean_combine(
        stack: &Array3<f64>,
        weight: Option<&Array3<f64>>,
        kappa: f64,
        max_iter: usize,
    ) -> (Array2<f64>, Array2<f64>, Array2<u32>) {
        let (_, height, width) = (stack.shape()[0], stack.shape()[1], stack.shape()[2]);
        let mut combined = Array2::<f64>::zeros((height, width));
        let mut weight_out = Array2::<f64>::zeros((height, width));
        let mut count = Array2::<u32>::zeros((height, width));
        for row in 0..height {
            for column in 0..width {
                let (values, weights) = naive_gather(stack, weight, row, column);
                if values.is_empty() {
                    // weight and count planes are already zero-initialised.
                    combined[(row, column)] = f64::NAN;
                    continue;
                }
                let mut alive = vec![true; values.len()];
                for _ in 0..max_iter {
                    let survivors: Vec<f64> = (0..values.len())
                        .filter(|&i| alive[i])
                        .map(|i| values[i])
                        .collect();
                    if survivors.is_empty() {
                        break;
                    }
                    let mut sorted = survivors.clone();
                    let center = naive_median(&mut sorted);
                    let mean_square: f64 = survivors
                        .iter()
                        .map(|&x| (x - center) * (x - center))
                        .sum::<f64>()
                        / survivors.len() as f64;
                    let threshold = kappa * mean_square.sqrt();
                    let mut removed = false;
                    for i in 0..values.len() {
                        if alive[i] && (values[i] - center).abs() > threshold {
                            alive[i] = false;
                            removed = true;
                        }
                    }
                    if !removed {
                        break;
                    }
                }
                let mut wv = 0.0;
                let mut ws = 0.0;
                let mut c: u32 = 0;
                for i in 0..values.len() {
                    if alive[i] {
                        wv += weights[i] * values[i];
                        ws += weights[i];
                        c += 1;
                    }
                }
                if c == 0 {
                    combined[(row, column)] = f64::NAN;
                    weight_out[(row, column)] = 0.0;
                    count[(row, column)] = 0;
                } else {
                    combined[(row, column)] = wv / ws;
                    weight_out[(row, column)] = ws;
                    count[(row, column)] = c;
                }
            }
        }
        (combined, weight_out, count)
    }

    /// Asserts that `got` matches the reference planes.
    ///
    /// Shapes must be identical (checked up front, because `zip` would
    /// otherwise silently compare only the common prefix). Floating-point
    /// planes are compared with relative tolerance `tol` (absolute below
    /// magnitude 1); a reference `NaN` requires a `NaN` result; counts must
    /// match exactly.
    fn assert_planes_close(
        got: &RobustCombined,
        want: &(Array2<f64>, Array2<f64>, Array2<u32>),
        tol: f64,
    ) {
        let (want_combined, want_weight, want_count) = want;
        assert_eq!(
            got.combined.dim(),
            want_combined.dim(),
            "combined shape mismatch"
        );
        assert_eq!(got.weight.dim(), want_weight.dim(), "weight shape mismatch");
        assert_eq!(got.count.dim(), want_count.dim(), "count shape mismatch");
        for (g, w) in got.combined.iter().zip(want_combined.iter()) {
            if w.is_nan() {
                assert!(g.is_nan(), "expected NaN sentinel, got {g}");
            } else {
                assert!(
                    (g - w).abs() <= tol * w.abs().max(1.0),
                    "combined {g} != {w}"
                );
            }
        }
        for (g, w) in got.weight.iter().zip(want_weight.iter()) {
            assert!((g - w).abs() <= tol * w.abs().max(1.0), "weight {g} != {w}");
        }
        for (g, w) in got.count.iter().zip(want_count.iter()) {
            assert_eq!(g, w, "count {g} != {w}");
        }
    }

    /// Stack of shape `(n, h, w)` filled with uniform values in `[-5, 5)`.
    fn random_stack(rng: &mut SplitMix64, n: usize, h: usize, w: usize) -> Array3<f64> {
        Array3::from_shape_fn((n, h, w), |_| rng.range(-5.0, 5.0))
    }

    #[test]
    fn median_matches_naive_reference_equal_weight() {
        let mut rng = SplitMix64::new(0xC0FF_EE00_1234_5678);
        let stack = random_stack(&mut rng, 9, 6, 7);
        let got = robust_combine(stack.view(), None, CombineMethod::Median).unwrap();
        let want = naive_median_combine(&stack, None);
        assert_planes_close(&got, &want, 1e-12);
    }

    #[test]
    fn median_matches_naive_reference_weighted_gate() {
        let mut rng = SplitMix64::new(0x1234_5678_9ABC_DEF0);
        let stack = random_stack(&mut rng, 11, 5, 4);
        // Range straddles zero so some samples are gated out by weight.
        let weight = Array3::from_shape_fn((11, 5, 4), |_| rng.range(-0.5, 2.0));
        let got = robust_combine(stack.view(), Some(weight.view()), CombineMethod::Median).unwrap();
        let want = naive_median_combine(&stack, Some(&weight));
        assert_planes_close(&got, &want, 1e-12);
    }

    #[test]
    fn clipped_mean_matches_naive_reference_equal_weight() {
        let mut rng = SplitMix64::new(0xDEAD_BEEF_F00D_BABE);
        let stack = random_stack(&mut rng, 15, 5, 6);
        let method = CombineMethod::ClippedMean {
            kappa: 2.5,
            max_iter: 5,
        };
        let got = robust_combine(stack.view(), None, method).unwrap();
        let want = naive_clipped_mean_combine(&stack, None, 2.5, 5);
        assert_planes_close(&got, &want, 1e-9);
    }

    #[test]
    fn clipped_mean_matches_naive_reference_weighted() {
        let mut rng = SplitMix64::new(0x0BAD_C0DE_1234_5678);
        let stack = random_stack(&mut rng, 13, 4, 5);
        let weight = Array3::from_shape_fn((13, 4, 5), |_| rng.range(0.1, 3.0));
        let method = CombineMethod::ClippedMean {
            kappa: 3.0,
            max_iter: 8,
        };
        let got = robust_combine(stack.view(), Some(weight.view()), method).unwrap();
        let want = naive_clipped_mean_combine(&stack, Some(&weight), 3.0, 8);
        assert_planes_close(&got, &want, 1e-9);
    }

    #[test]
    fn weight_zero_sample_excluded_by_inclusion_gate() {
        let stack = Array3::from_shape_vec((3, 1, 1), vec![1.0, 999.0, 3.0]).unwrap();
        let weight = Array3::from_shape_vec((3, 1, 1), vec![2.0, 0.0, 4.0]).unwrap();
        let got = robust_combine(stack.view(), Some(weight.view()), CombineMethod::Median).unwrap();
        assert!((got.combined[(0, 0)] - 2.0).abs() < 1e-12);
        assert_eq!(got.count[(0, 0)], 2);
        assert!((got.weight[(0, 0)] - 6.0).abs() < 1e-12);
    }

    #[test]
    fn all_invalid_pixel_is_nan_weight0_count0() {
        // Column 0 has only NaN values; column 1 has only zero weights.
        let stack = Array3::from_shape_vec((2, 1, 2), vec![f64::NAN, 5.0, f64::NAN, 7.0]).unwrap();
        let weight = Array3::from_shape_vec((2, 1, 2), vec![1.0, 0.0, 1.0, 0.0]).unwrap();
        for method in [
            CombineMethod::Median,
            CombineMethod::ClippedMean {
                kappa: 3.0,
                max_iter: 5,
            },
        ] {
            let got = robust_combine(stack.view(), Some(weight.view()), method).unwrap();
            for column in 0..2 {
                assert!(
                    got.combined[(0, column)].is_nan(),
                    "expected NaN sentinel at column {column}"
                );
                assert_eq!(got.weight[(0, column)], 0.0);
                assert_eq!(got.count[(0, column)], 0);
            }
        }
    }

    #[test]
    fn clipped_mean_rejects_outlier_sign_agnostically() {
        let inlier_pattern = [-2.0, -1.0, 0.0, 1.0, 2.0];
        // Four repetitions of the inlier pattern plus one outlier sample.
        let n = inlier_pattern.len() * 4 + 1;
        let mut data = Vec::with_capacity(n * 2);
        for sample_index in 0..n {
            let (v0, v1) = if sample_index < n - 1 {
                let v = inlier_pattern[sample_index % inlier_pattern.len()];
                (v, v)
            } else {
                // Same-magnitude outlier with opposite sign in the two pixels.
                (50.0, -50.0)
            };
            data.push(v0);
            data.push(v1);
        }
        let stack = Array3::from_shape_vec((n, 1, 2), data).unwrap();
        let method = CombineMethod::ClippedMean {
            kappa: 3.0,
            max_iter: 5,
        };
        let got = robust_combine(stack.view(), None, method).unwrap();
        assert_eq!(got.count[(0, 0)], 20);
        assert_eq!(got.count[(0, 1)], 20);
        assert!(
            (got.combined[(0, 0)] - got.combined[(0, 1)]).abs() < 1e-12,
            "sign-agnostic clip must give identical results: {} vs {}",
            got.combined[(0, 0)],
            got.combined[(0, 1)]
        );
        assert!(got.combined[(0, 0)].abs() < 1e-12);
        let want = naive_clipped_mean_combine(&stack, None, 3.0, 5);
        assert_planes_close(&got, &want, 1e-9);
    }

    #[test]
    fn clipped_mean_huge_kappa_is_weighted_mean() {
        let mut rng = SplitMix64::new(0xABCD_1234_5678_9F00);
        let stack = random_stack(&mut rng, 7, 3, 3);
        let weight = Array3::from_shape_fn((7, 3, 3), |_| rng.range(0.2, 2.0));
        let got = robust_combine(
            stack.view(),
            Some(weight.view()),
            CombineMethod::ClippedMean {
                kappa: 1e30,
                max_iter: 10,
            },
        )
        .unwrap();
        for row in 0..3 {
            for column in 0..3 {
                let mut wv = 0.0;
                let mut ws = 0.0;
                for n in 0..7 {
                    wv += weight[(n, row, column)] * stack[(n, row, column)];
                    ws += weight[(n, row, column)];
                }
                assert!(
                    (got.combined[(row, column)] - wv / ws).abs() < 1e-9 * (wv / ws).abs().max(1.0),
                    "expected plain weighted mean"
                );
                assert_eq!(got.count[(row, column)], 7);
                assert!((got.weight[(row, column)] - ws).abs() < 1e-9 * ws.max(1.0));
            }
        }
    }

    #[test]
    fn single_stamp_passes_through() {
        let stack = Array3::from_shape_vec((1, 2, 2), vec![3.0, -7.0, 0.5, 11.0]).unwrap();
        for method in [
            CombineMethod::Median,
            CombineMethod::ClippedMean {
                kappa: 3.0,
                max_iter: 4,
            },
        ] {
            let got = robust_combine(stack.view(), None, method).unwrap();
            for (g, s) in got.combined.iter().zip(stack.iter()) {
                assert!((g - s).abs() < 1e-12, "N=1 passthrough: {g} != {s}");
            }
            assert!(got.count.iter().all(|&c| c == 1));
            assert!(got.weight.iter().all(|&w| (w - 1.0).abs() < 1e-12));
        }
    }

    #[test]
    fn empty_n_yields_all_sentinel_plane() {
        let stack = Array3::<f64>::zeros((0, 3, 4));
        for method in [
            CombineMethod::Median,
            CombineMethod::ClippedMean {
                kappa: 2.0,
                max_iter: 3,
            },
        ] {
            let got = robust_combine(stack.view(), None, method).unwrap();
            assert_eq!(got.combined.shape(), &[3, 4]);
            assert!(got.combined.iter().all(|v| v.is_nan()));
            assert!(got.weight.iter().all(|&w| w == 0.0));
            assert!(got.count.iter().all(|&c| c == 0));
        }
    }

    #[test]
    fn empty_spatial_axis_yields_empty_output() {
        let stack = Array3::<f64>::zeros((4, 0, 5));
        let got = robust_combine(stack.view(), None, CombineMethod::Median).unwrap();
        assert_eq!(got.combined.shape(), &[0, 5]);
        let stack = Array3::<f64>::zeros((4, 5, 0));
        let got = robust_combine(
            stack.view(),
            None,
            CombineMethod::ClippedMean {
                kappa: 3.0,
                max_iter: 2,
            },
        )
        .unwrap();
        assert_eq!(got.combined.shape(), &[5, 0]);
    }

    #[test]
    fn count_and_weight_semantics() {
        let stack = Array3::from_shape_vec((4, 1, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let equal = robust_combine(stack.view(), None, CombineMethod::Median).unwrap();
        assert_eq!(equal.count[(0, 0)], 4);
        assert!((equal.weight[(0, 0)] - 4.0).abs() < 1e-12);
        let weight = Array3::from_shape_vec((4, 1, 1), vec![0.5, 1.5, 2.0, 1.0]).unwrap();
        let weighted =
            robust_combine(stack.view(), Some(weight.view()), CombineMethod::Median).unwrap();
        assert_eq!(weighted.count[(0, 0)], 4);
        assert!((weighted.weight[(0, 0)] - 5.0).abs() < 1e-12);
    }

    #[test]
    fn f32_and_f64_dual_path_agree() {
        let mut rng = SplitMix64::new(0xFEED_FACE_CAFE_0001);
        let stack_f64 = random_stack(&mut rng, 9, 4, 4);
        let stack_f32: Array3<f32> = stack_f64.mapv(|v| v as f32);
        let method = CombineMethod::ClippedMean {
            kappa: 2.5,
            max_iter: 5,
        };
        let from_f64 = robust_combine(stack_f64.view(), None, method).unwrap();
        let from_f32 = robust_combine(stack_f32.view(), None, method).unwrap();
        for (a, b) in from_f64.combined.iter().zip(from_f32.combined.iter()) {
            if a.is_nan() {
                assert!(b.is_nan());
            } else {
                assert!((a - b).abs() < 1e-4 * a.abs().max(1.0), "{a} vs {b}");
            }
        }
        assert_eq!(from_f64.count, from_f32.count);
        let mf64 = robust_combine(stack_f64.view(), None, CombineMethod::Median).unwrap();
        let mf32 = robust_combine(stack_f32.view(), None, CombineMethod::Median).unwrap();
        for (a, b) in mf64.combined.iter().zip(mf32.combined.iter()) {
            assert!((a - b).abs() < 1e-4 * a.abs().max(1.0), "{a} vs {b}");
        }
    }

    #[test]
    fn error_weight_shape_mismatch() {
        let stack = Array3::<f64>::zeros((3, 4, 5));
        let weight = Array3::<f64>::zeros((3, 4, 6));
        let err =
            robust_combine(stack.view(), Some(weight.view()), CombineMethod::Median).unwrap_err();
        assert_eq!(
            err,
            RobustError::WeightShapeMismatch {
                weight: (3, 4, 6),
                stack: (3, 4, 5),
            }
        );
    }

    #[test]
    fn error_clipped_mean_invalid_kappa() {
        let stack = Array3::<f64>::zeros((3, 2, 2));
        for bad_kappa in [0.0, -1.0] {
            let err = robust_combine(
                stack.view(),
                None,
                CombineMethod::ClippedMean {
                    kappa: bad_kappa,
                    max_iter: 5,
                },
            )
            .unwrap_err();
            assert_eq!(
                err,
                RobustError::ClippedMeanInvalidParams {
                    kappa: bad_kappa,
                    max_iter: 5,
                }
            );
        }
    }

    #[test]
    fn error_clipped_mean_zero_max_iter() {
        let stack = Array3::<f64>::zeros((3, 2, 2));
        let err = robust_combine(
            stack.view(),
            None,
            CombineMethod::ClippedMean {
                kappa: 3.0,
                max_iter: 0,
            },
        )
        .unwrap_err();
        assert_eq!(
            err,
            RobustError::ClippedMeanInvalidParams {
                kappa: 3.0,
                max_iter: 0,
            }
        );
    }

    #[test]
    fn error_weight_shape_checked_before_clipped_mean_params() {
        // Both preconditions are violated; the shape error must win.
        let stack = Array3::<f64>::zeros((2, 2, 2));
        let weight = Array3::<f64>::zeros((2, 2, 3));
        let err = robust_combine(
            stack.view(),
            Some(weight.view()),
            CombineMethod::ClippedMean {
                kappa: -1.0,
                max_iter: 0,
            },
        )
        .unwrap_err();
        assert_eq!(
            err,
            RobustError::WeightShapeMismatch {
                weight: (2, 2, 3),
                stack: (2, 2, 2),
            }
        );
    }

    #[test]
    fn median_no_extra_precondition() {
        let stack = Array3::<f64>::zeros((1, 1, 1));
        let got = robust_combine(stack.view(), None, CombineMethod::Median).unwrap();
        assert_eq!(got.combined[(0, 0)], 0.0);
        assert_eq!(got.count[(0, 0)], 1);
    }
}