use crate::utils::compute_1d_median_filter;
use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use crate::float_trait::Bm3dFloat;
use crate::orchestration::{
bm3d_ring_artifact_removal, bm3d_ring_artifact_removal_with_plans, Bm3dConfig, RingRemovalMode,
};
const MIN_SCALE_WIDTH: usize = 40;
const BINNING_FACTOR: usize = 2;
const DEFAULT_DEBIN_ITERATIONS: usize = 10;
const DEFAULT_MULTISCALE_FILTER_STRENGTH: f64 = 1.0;
const DEFAULT_MULTISCALE_THRESHOLD: f64 = 3.5;
#[derive(Debug, Clone)]
pub struct MultiscaleConfig<F: Bm3dFloat> {
pub num_scales: Option<usize>,
pub filter_strength: F,
pub threshold: F,
pub debin_iterations: usize,
pub bm3d_config: Bm3dConfig<F>,
}
impl<F: Bm3dFloat> Default for MultiscaleConfig<F> {
fn default() -> Self {
let bm3d_config = Bm3dConfig {
threshold: F::from_f64_c(DEFAULT_MULTISCALE_THRESHOLD),
filter_strength: F::from_f64_c(DEFAULT_MULTISCALE_FILTER_STRENGTH),
..Default::default()
};
Self {
num_scales: None,
filter_strength: F::from_f64_c(DEFAULT_MULTISCALE_FILTER_STRENGTH),
threshold: F::from_f64_c(DEFAULT_MULTISCALE_THRESHOLD),
debin_iterations: DEFAULT_DEBIN_ITERATIONS,
bm3d_config,
}
}
}
impl<F: Bm3dFloat> MultiscaleConfig<F> {
pub fn new() -> Self {
Self::default()
}
pub fn validate(&self) -> Result<(), String> {
if self.filter_strength <= F::zero() {
return Err("filter_strength must be > 0".to_string());
}
if self.threshold < F::zero() {
return Err("threshold must be >= 0".to_string());
}
if self.debin_iterations == 0 {
return Err("debin_iterations must be > 0".to_string());
}
self.bm3d_config.validate()
}
}
pub fn compute_num_scales(width: usize) -> usize {
if width <= MIN_SCALE_WIDTH {
return 0;
}
let ratio = width as f64 / MIN_SCALE_WIDTH as f64;
let k = ratio.log2().floor().max(0.0) as usize;
k.min(3)
}
fn symmetric_pad_right<F: Bm3dFloat>(arr: ArrayView2<F>, pad_amount: usize) -> Array2<F> {
if pad_amount == 0 {
return arr.to_owned();
}
let (rows, cols) = arr.dim();
let new_cols = cols + pad_amount;
let mut result = Array2::zeros((rows, new_cols));
result.slice_mut(s![.., ..cols]).assign(&arr);
for p in 0..pad_amount {
let src_col = cols - 1 - (p % cols);
for r in 0..rows {
result[[r, cols + p]] = arr[[r, src_col]];
}
}
result
}
pub fn bin_horizontal<F: Bm3dFloat>(arr: ArrayView2<F>, factor: usize) -> Array2<F> {
if factor <= 1 {
return arr.to_owned();
}
let (rows, cols) = arr.dim();
let mod_pad = factor - ((cols - 1) % factor) - 1;
let padded = symmetric_pad_right(arr, mod_pad);
let padded_cols = padded.ncols();
let mut convolved = Array2::zeros((rows, padded_cols));
for r in 0..rows {
for c in 0..padded_cols {
let mut sum = F::zero();
for k in 0..factor {
let src_c = c as isize + k as isize - (factor as isize / 2);
if src_c >= 0 && (src_c as usize) < padded_cols {
sum += padded[[r, src_c as usize]];
}
}
convolved[[r, c]] = sum;
}
}
let h_half = factor / 2;
let start = h_half + if factor % 2 == 1 { 1 } else { 0 };
let end = padded_cols - h_half + 1;
let num_bins = (end - start).div_ceil(factor);
let mut result = Array2::zeros((rows, num_bins));
for r in 0..rows {
for (i, c) in (start..end).step_by(factor).enumerate() {
if i < num_bins && c < padded_cols {
result[[r, i]] = convolved[[r, c]];
}
}
}
result
}
fn compute_bin_weights<F: Bm3dFloat>(target_width: usize, factor: usize) -> Array1<F> {
if factor <= 1 {
return Array1::ones(target_width);
}
let ones = Array2::from_elem((1, target_width), F::one());
let mod_pad = factor - ((target_width - 1) % factor) - 1;
let padded = symmetric_pad_right(ones.view(), mod_pad);
let padded_cols = padded.ncols();
let mut convolved = Array1::zeros(padded_cols);
for c in 0..padded_cols {
let mut sum = F::zero();
for k in 0..factor {
let src_c = c as isize + k as isize - (factor as isize / 2);
if src_c >= 0 && (src_c as usize) < padded_cols {
sum += padded[[0, src_c as usize]];
}
}
convolved[c] = sum;
}
convolved.slice(s![..target_width]).to_owned()
}
fn cubic_spline_interpolate_row<F: Bm3dFloat>(
x_coords: &[f64],
y_values: ArrayView1<F>,
target_x: &[f64],
output: &mut [F],
) {
let len = x_coords.len();
if len < 2 {
output.fill(F::zero());
return;
}
let catmull_rom = |p0: f64, p1: f64, p2: f64, p3: f64, t: f64| -> f64 {
let t2 = t * t;
let t3 = t2 * t;
0.5 * ((2.0 * p1)
+ (-p0 + p2) * t
+ (2.0 * p0 - 5.0 * p1 + 4.0 * p2 - p3) * t2
+ (-p0 + 3.0 * p1 - 3.0 * p2 + p3) * t3)
};
let mut last_idx = 0;
for (out_val, &tx) in output.iter_mut().zip(target_x.iter()) {
let tx = tx.clamp(x_coords[0], x_coords[len - 1]);
while last_idx < len - 2 && x_coords[last_idx + 1] <= tx {
last_idx += 1;
}
let i = last_idx;
let p1_idx = i;
let p2_idx = (i + 1).min(len - 1);
let p0_idx = i.saturating_sub(1);
let p3_idx = (i + 2).min(len - 1);
let x1 = x_coords[p1_idx];
let x2 = x_coords[p2_idx];
let dist = x2 - x1;
let val = if dist.abs() < 1e-6 {
y_values[p1_idx].to_f64().unwrap_or(0.0)
} else {
let t = (tx - x1) / dist;
let y0 = y_values[p0_idx].to_f64().unwrap_or(0.0);
let y1 = y_values[p1_idx].to_f64().unwrap_or(0.0);
let y2 = y_values[p2_idx].to_f64().unwrap_or(0.0);
let y3 = y_values[p3_idx].to_f64().unwrap_or(0.0);
catmull_rom(y0, y1, y2, y3, t)
};
*out_val = F::from_f64_c(val);
}
}
pub fn debin_horizontal<F: Bm3dFloat>(
arr: ArrayView2<F>,
target_width: usize,
factor: usize,
iterations: usize,
) -> Array2<F> {
if factor <= 1 {
return arr.to_owned();
}
let (rows, binned_cols) = arr.dim();
let h_half = factor / 2;
let n_counter = compute_bin_weights::<F>(target_width, factor);
let x1c: Vec<usize> = (0..binned_cols)
.map(|i| h_half + (if factor % 2 == 1 { 1 } else { 0 }) + i * factor)
.filter(|&c| c < target_width)
.collect();
let x1: Vec<f64> = (0..binned_cols)
.map(|i| {
(h_half as f64) + 1.0 - (if factor.is_multiple_of(2) { 0.5 } else { 0.0 })
+ (i * factor) as f64
})
.collect();
let ix1: Vec<f64> = (1..=target_width).map(|i| i as f64).collect();
let mut y_j = Array2::zeros((rows, target_width));
let mut row_buffer = vec![F::zero(); target_width];
for iter in 0..iterations.max(1) {
let r_j = if iter > 0 {
let binned_y_j = bin_horizontal(y_j.view(), factor);
let min_cols = arr.ncols().min(binned_y_j.ncols());
let mut residual = Array2::zeros((rows, min_cols));
for r in 0..rows {
for c in 0..min_cols {
residual[[r, c]] = arr[[r, c]] - binned_y_j[[r, c]];
}
}
residual
} else {
arr.to_owned()
};
let mut normalized = Array2::zeros((rows, r_j.ncols()));
for r in 0..rows {
for (c, &idx) in x1c.iter().enumerate() {
if c < r_j.ncols() && idx < target_width {
let weight = n_counter[idx];
if weight > F::zero() {
normalized[[r, c]] = r_j[[r, c]] / weight;
}
}
}
}
let x1_trimmed: Vec<f64> = x1.iter().take(normalized.ncols()).copied().collect();
if x1_trimmed.len() >= 2 {
for r in 0..rows {
let row_values = normalized.row(r);
let row_slice = row_values.slice(s![..x1_trimmed.len()]);
cubic_spline_interpolate_row(&x1_trimmed, row_slice, &ix1, &mut row_buffer);
for (c, &val) in row_buffer.iter().enumerate() {
if c < target_width {
y_j[[r, c]] += val;
}
}
}
}
}
y_j
}
pub fn generate_psd_shapes<F: Bm3dFloat>(denoise_sizes: &[usize]) -> Vec<Array1<F>> {
let mut psd_shapes = Vec::with_capacity(denoise_sizes.len());
let kernel_half_vec = compute_residual_kernel::<f64>();
for i in 0..denoise_sizes.len() {
let sz = denoise_sizes[i];
if i < denoise_sizes.len() - 1 {
let kernel_half_len = kernel_half_vec.len();
let full_kernel_len = kernel_half_len * 2 - 1;
let trim_amount = if full_kernel_len > sz {
(full_kernel_len - sz) / 2
} else {
0
};
let trimmed_start = trim_amount.min(kernel_half_len - 1);
let trimmed_half: Vec<f64> = kernel_half_vec[trimmed_start..].to_vec();
let mut residual_kernel: Vec<f64> = trimmed_half.clone();
for j in (0..trimmed_half.len() - 1).rev() {
residual_kernel.push(trimmed_half[j]);
}
let psd = compute_fft_psd::<F>(&residual_kernel, sz);
psd_shapes.push(psd);
} else {
let scale = F::from_f64_c(1.0 / (sz as f64).sqrt());
psd_shapes.push(Array1::from_elem(sz, scale));
}
}
psd_shapes
}
fn compute_fft_psd<F: Bm3dFloat>(kernel: &[f64], target_size: usize) -> Array1<F> {
use rustfft::{num_complex::Complex, FftPlanner};
let mut padded = vec![Complex::new(0.0, 0.0); target_size];
for (i, &val) in kernel.iter().enumerate() {
if i < target_size {
padded[i] = Complex::new(val, 0.0);
}
}
let mut planner = FftPlanner::new();
let fft = planner.plan_fft_forward(target_size);
fft.process(&mut padded);
Array1::from_iter(padded.iter().map(|c| F::from_f64_c(c.norm_sqr())))
}
fn estimate_noise_sigma_robust<F: Bm3dFloat>(sinogram: ArrayView2<F>) -> F {
let (rows, cols) = sinogram.dim();
if rows < 16 || cols < 16 {
return F::zero();
}
let patch_size = 64;
let mut patch_sigmas = Vec::new();
for r in (0..rows.saturating_sub(patch_size)).step_by(patch_size / 2) {
for c in (0..cols.saturating_sub(patch_size)).step_by(patch_size / 2) {
let patch = sinogram.slice(s![r..r + patch_size, c..c + patch_size]);
let sigma = estimate_patch_sigma_internal(patch);
if sigma > F::zero() {
patch_sigmas.push(sigma);
}
}
}
if patch_sigmas.is_empty() {
return estimate_patch_sigma_internal(sinogram); }
patch_sigmas.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
patch_sigmas[0]
}
fn estimate_patch_sigma_internal<F: Bm3dFloat>(patch: ArrayView2<F>) -> F {
let (rows, cols) = patch.dim();
if rows < 5 || cols < 5 {
return F::zero();
}
let sigma_v = 2.0;
let radius = (4.0f64 * sigma_v).ceil() as usize;
let width = 2 * radius + 1;
let mut v_kernel = Vec::with_capacity(width);
let mut v_sum = 0.0;
for i in 0..width {
let x = i as f64 - radius as f64;
let val = (-0.5 * (x / sigma_v).powi(2)).exp();
v_kernel.push(F::from_f64_c(val));
v_sum += val;
}
let v_norm = F::from_f64_c(v_sum);
for k in &mut v_kernel {
*k /= v_norm;
}
let db3_coeffs: [f64; 6] = [
0.03522629,
-0.08544127,
0.13501102,
0.45987750,
-0.80689151,
0.33267055,
];
let h_kernel: Vec<F> = db3_coeffs.iter().map(|&x| F::from_f64_c(x)).collect();
let smoothed = gaussian_filter_1d_vertical_internal(patch, &v_kernel);
let filtered = convolve_1d_horizontal_internal(smoothed.view(), &h_kernel);
let mad = compute_mad_internal(filtered.view());
mad * F::from_f64_c(3.96)
}
fn gaussian_filter_1d_vertical_internal<F: Bm3dFloat>(
data: ArrayView2<F>,
kernel: &[F],
) -> Array2<F> {
let (rows, cols) = data.dim();
let k_len = kernel.len();
let radius = k_len / 2;
let mut output = Array2::zeros((rows, cols));
for c in 0..cols {
for r in 0..rows {
let mut sum = F::zero();
for (k, &k_val) in kernel.iter().enumerate() {
let k_idx = k as isize - radius as isize;
let src_r = (r as isize + k_idx).clamp(0, (rows - 1) as isize);
sum += data[[src_r as usize, c]] * k_val;
}
output[[r, c]] = sum;
}
}
output
}
fn convolve_1d_horizontal_internal<F: Bm3dFloat>(data: ArrayView2<F>, kernel: &[F]) -> Array2<F> {
let (rows, cols) = data.dim();
let k_len = kernel.len();
let radius = k_len / 2;
let mut output = Array2::zeros((rows, cols));
for r in 0..rows {
for c in 0..cols {
let mut sum = F::zero();
for (k, &k_val) in kernel.iter().enumerate() {
let k_idx = k as isize - radius as isize;
let src_c = (c as isize + k_idx).clamp(0, (cols - 1) as isize);
sum += data[[r, src_c as usize]] * k_val;
}
output[[r, c]] = sum;
}
}
output
}
fn compute_mad_internal<F: Bm3dFloat>(data: ArrayView2<F>) -> F {
let mut flat_data: Vec<F> = data.iter().cloned().collect();
if flat_data.is_empty() {
return F::zero();
}
let len = flat_data.len();
let mid = len / 2;
let (_, &mut median, _) =
flat_data.select_nth_unstable_by(mid, |a, b| a.partial_cmp(b).unwrap());
let median_val = if len % 2 == 1 {
median
} else {
let left = &flat_data[..mid];
let prev = left
.iter()
.fold(F::neg_infinity(), |a, &b| if b > a { b } else { a });
(prev + median) / F::from_f64_c(2.0)
};
let mut deviations: Vec<F> = flat_data.iter().map(|&x| (x - median_val).abs()).collect();
let (_, &mut dev_median, _) =
deviations.select_nth_unstable_by(mid, |a, b| a.partial_cmp(b).unwrap());
if len % 2 == 1 {
dev_median
} else {
let left = &deviations[..mid];
let prev = left
.iter()
.fold(F::neg_infinity(), |a, &b| if b > a { b } else { a });
(prev + dev_median) / F::from_f64_c(2.0)
}
}
pub fn multiscale_bm3d_streak_removal<F: Bm3dFloat>(
sinogram: ArrayView2<F>,
config: &MultiscaleConfig<F>,
) -> Result<Array2<F>, String> {
let plans = crate::pipeline::Bm3dPlans::new(
config.bm3d_config.patch_size,
config.bm3d_config.max_matches,
);
multiscale_bm3d_streak_removal_with_plans(sinogram, config, &plans)
}
pub fn multiscale_bm3d_streak_removal_with_plans<F: Bm3dFloat>(
sinogram: ArrayView2<F>,
config: &MultiscaleConfig<F>,
plans: &crate::pipeline::Bm3dPlans<F>,
) -> Result<Array2<F>, String> {
config.validate()?;
let (rows, cols) = sinogram.dim();
let estimated_global = estimate_noise_sigma_robust(sinogram);
let sigma_global = if config.bm3d_config.sigma_random > F::from_f64_c(1e-6) {
config.bm3d_config.sigma_random
} else {
estimated_global
};
let noise_var_threshold = (sigma_global * sigma_global) * F::from_f64_c(2.0);
let mut fine_col_vars = Vec::with_capacity(cols);
let n_rows_f = F::from_f64_c(rows as f64);
for c in 0..cols {
let col = sinogram.column(c);
let mut col_data: Vec<F> = col.to_vec();
let mid = col_data.len() / 2;
col_data.select_nth_unstable_by(mid, |a, b| {
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
});
let med = col_data[mid];
let mut sum_sq = F::zero();
for &val in col {
let diff = val - med;
sum_sq += diff * diff;
}
let var = sum_sq / n_rows_f;
fine_col_vars.push(var);
}
let num_scales = config
.num_scales
.unwrap_or_else(|| compute_num_scales(cols));
if num_scales == 0 {
return bm3d_ring_artifact_removal(sinogram, RingRemovalMode::Streak, &config.bm3d_config);
}
let mut pyramid_orig: Vec<Array2<F>> = Vec::with_capacity(num_scales + 1);
pyramid_orig.push(sinogram.to_owned());
for _ in 0..num_scales {
let last = pyramid_orig.last().unwrap();
let binned = bin_horizontal(last.view(), BINNING_FACTOR);
pyramid_orig.push(binned);
}
let mut pyramid_work: Vec<Array2<F>> = vec![Array2::zeros((0, 0)); num_scales + 1];
pyramid_work[num_scales] = pyramid_orig[num_scales].clone();
let denoise_sizes: Vec<usize> = pyramid_orig.iter().map(|arr| arr.ncols()).collect();
let _psd_shapes = generate_psd_shapes::<F>(&denoise_sizes);
let mut denoised: Option<Array2<F>> = None;
for scale in (0..=num_scales).rev() {
let img = pyramid_work[scale].clone();
let norm_factor = F::from_f64_c(2.0f64.powi(scale as i32));
if img.nrows() < config.bm3d_config.patch_size
|| img.ncols() < config.bm3d_config.patch_size
{
if scale > 0 {
let residual = &img - &pyramid_orig[scale];
let scale_correction = F::from_f64_c(BINNING_FACTOR as f64);
let residual_corrected = residual.mapv(|x| x / scale_correction);
let target_width = pyramid_orig[scale - 1].ncols();
let debinned_residual = debin_horizontal(
residual_corrected.view(),
target_width,
BINNING_FACTOR,
config.debin_iterations,
);
let target_rows = pyramid_orig[scale - 1].nrows();
let mut adjusted_residual = Array2::zeros((target_rows, target_width));
let copy_rows = debinned_residual.nrows().min(target_rows);
let copy_cols = debinned_residual.ncols().min(target_width);
adjusted_residual
.slice_mut(s![..copy_rows, ..copy_cols])
.assign(&debinned_residual.slice(s![..copy_rows, ..copy_cols]));
pyramid_work[scale - 1] = &pyramid_orig[scale - 1] + &adjusted_residual;
}
denoised = Some(img);
continue;
}
let mut scale_config = config.bm3d_config.clone();
scale_config.threshold = config.threshold * config.filter_strength;
let scaled_window = (config.bm3d_config.search_window >> scale).max(8);
scale_config.search_window = scaled_window;
let img_normalized = img.mapv(|x| x / norm_factor);
let estimated_sigma = estimate_noise_sigma_robust(img_normalized.view());
let min_sigma = F::from_f64_c(0.001);
if estimated_sigma > min_sigma {
scale_config.sigma_random = estimated_sigma;
}
let den_normalized = bm3d_ring_artifact_removal_with_plans(
img_normalized.view(),
RingRemovalMode::Streak,
&scale_config,
plans,
)?;
let den = den_normalized.mapv(|x| x * norm_factor);
if scale > 0 {
let residual = &den - &pyramid_orig[scale];
let (res_rows, res_cols) = residual.dim();
let n_rows_f = F::from_f64_c(res_rows as f64);
let mut v_profile = Vec::with_capacity(res_cols);
let bin_size = 1 << scale;
for c in 0..res_cols {
let col = residual.column(c);
let mut col_data: Vec<F> = col.to_vec();
let mid = col_data.len() / 2;
col_data.select_nth_unstable_by(mid, |a, b| {
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
});
let med = col_data[mid];
let med_sq = med * med;
let mut sum_ac_sq = F::zero();
for &val in col {
let ac = val - med;
sum_ac_sq += ac * ac;
}
let var_ac = sum_ac_sq / n_rows_f;
let mut score = if med_sq + var_ac > F::zero() {
med_sq / (med_sq + var_ac)
} else {
F::zero()
};
let start_idx = c * bin_size;
let end_idx = (start_idx + bin_size).min(cols);
let mut max_fine_var = F::zero();
for &val in &fine_col_vars[start_idx..end_idx] {
if val > max_fine_var {
max_fine_var = val;
}
}
if max_fine_var > noise_var_threshold {
score = F::zero();
}
let s2 = score * score;
let s4 = s2 * s2;
let gate = s4 * s4;
v_profile.push(med * gate);
}
let smooth_hump = compute_1d_median_filter(&v_profile, 201);
let mut spiky_profile = Array2::<F>::zeros((res_rows, res_cols));
for c in 0..res_cols {
let spike = v_profile[c] - smooth_hump[c];
for r in 0..res_rows {
spiky_profile[[r, c]] = spike;
}
}
let residual_filtered = spiky_profile;
let scale_correction = F::from_f64_c(BINNING_FACTOR as f64);
let residual_corrected = residual_filtered.mapv(|x| x / scale_correction);
let target_width = pyramid_orig[scale - 1].ncols();
let debinned_residual = debin_horizontal(
residual_corrected.view(),
target_width,
BINNING_FACTOR,
config.debin_iterations,
);
let target_rows = pyramid_orig[scale - 1].nrows();
let mut adjusted_residual = Array2::zeros((target_rows, target_width));
let copy_rows = debinned_residual.nrows().min(target_rows);
let copy_cols = debinned_residual.ncols().min(target_width);
adjusted_residual
.slice_mut(s![..copy_rows, ..copy_cols])
.assign(&debinned_residual.slice(s![..copy_rows, ..copy_cols]));
pyramid_work[scale - 1] = &pyramid_orig[scale - 1] + &adjusted_residual;
}
denoised = Some(den);
}
let final_denoised = denoised.ok_or("No scales processed")?;
let total_diff = &final_denoised - &sinogram;
let mut total_profile = Vec::with_capacity(cols);
for c in 0..cols {
let mut col_data: Vec<F> = total_diff.column(c).to_vec();
let mid = col_data.len() / 2;
col_data.select_nth_unstable_by(mid, |a, b| {
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
});
total_profile.push(col_data[mid]);
}
let final_hump = compute_1d_median_filter(&total_profile, 201);
let mag_threshold_wall = sigma_global * F::from_f64_c(8.0);
let silence_threshold = sigma_global * F::from_f64_c(0.5);
let mut diff_profile = Vec::with_capacity(cols);
for c in 0..cols {
diff_profile.push(total_profile[c] - final_hump[c]);
}
let mut c = 0;
while c < cols {
if diff_profile[c].abs() > silence_threshold {
let start = c;
let mut end = c;
let mut peak_mag = F::zero();
while end < cols && diff_profile[end].abs() > silence_threshold {
let mag = diff_profile[end].abs();
if mag > peak_mag {
peak_mag = mag;
}
end += 1;
}
let width = end - start;
if width > 20 && peak_mag > mag_threshold_wall {
diff_profile[start..end].fill(F::zero());
}
c = end;
} else {
c += 1;
}
}
let mut final_output = sinogram.to_owned();
for c in 0..cols {
let streak_val = diff_profile[c];
for r in 0..rows {
final_output[[r, c]] += streak_val;
}
}
Ok(final_output)
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array2;
struct SimpleLcg {
state: u64,
}
impl SimpleLcg {
fn new(seed: u64) -> Self {
Self { state: seed }
}
fn next_u64(&mut self) -> u64 {
self.state = self.state.wrapping_mul(6364136223846793005).wrapping_add(1);
self.state
}
fn next_f32(&mut self) -> f32 {
let u = self.next_u64();
(u >> 40) as f32 / (1u64 << 24) as f32
}
}
fn random_matrix(rows: usize, cols: usize, seed: u64) -> Array2<f32> {
let mut rng = SimpleLcg::new(seed);
Array2::from_shape_fn((rows, cols), |_| rng.next_f32())
}
fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
(a - b).abs() < eps
}
#[test]
fn test_compute_num_scales_small() {
assert_eq!(compute_num_scales(40), 0);
assert_eq!(compute_num_scales(30), 0);
assert_eq!(compute_num_scales(1), 0);
}
#[test]
fn test_compute_num_scales_medium() {
assert_eq!(compute_num_scales(80), 1);
assert_eq!(compute_num_scales(160), 2);
}
#[test]
fn test_compute_num_scales_large() {
assert_eq!(compute_num_scales(640), 3);
assert_eq!(compute_num_scales(1280), 3);
}
#[test]
fn test_binning_reduces_width() {
let arr = random_matrix(32, 100, 12345);
let binned = bin_horizontal(arr.view(), 2);
assert!(binned.ncols() <= 51); assert!(binned.ncols() >= 49);
assert_eq!(binned.nrows(), 32); }
#[test]
fn test_binning_factor_1_identity() {
let arr = random_matrix(32, 64, 11111);
let binned = bin_horizontal(arr.view(), 1);
assert_eq!(binned.dim(), arr.dim());
for (a, b) in arr.iter().zip(binned.iter()) {
assert!(approx_eq(*a, *b, 1e-6));
}
}
#[test]
fn test_binning_preserves_sum_approximately() {
let arr = random_matrix(16, 64, 22222);
let sum_orig: f32 = arr.iter().sum();
let binned = bin_horizontal(arr.view(), 2);
let sum_binned: f32 = binned.iter().sum();
assert!(
(sum_binned - sum_orig).abs() < sum_orig * 0.5,
"Sum changed too much: {} vs {}",
sum_binned,
sum_orig
);
}
#[test]
fn test_debin_restores_width() {
let original = random_matrix(32, 100, 33333);
let binned = bin_horizontal(original.view(), 2);
let debinned = debin_horizontal(binned.view(), 100, 2, 30);
assert_eq!(debinned.dim(), (32, 100));
}
#[test]
fn test_bin_debin_approximate_recovery() {
let original = random_matrix(16, 64, 44444);
let binned = bin_horizontal(original.view(), 2);
let debinned = debin_horizontal(binned.view(), 64, 2, 30);
let mae: f32 = original
.iter()
.zip(debinned.iter())
.map(|(a, b)| (a - b).abs())
.sum::<f32>()
/ (original.len() as f32);
let mean_val: f32 = original.iter().sum::<f32>() / (original.len() as f32);
assert!(
mae < mean_val * 0.5,
"MAE {} too large relative to mean {}",
mae,
mean_val
);
}
#[test]
fn test_debin_factor_1_identity() {
let arr = random_matrix(16, 32, 55555);
let debinned = debin_horizontal(arr.view(), 32, 1, 30);
assert_eq!(debinned.dim(), arr.dim());
}
#[test]
fn test_psd_shapes_count() {
let sizes = vec![256, 128, 64, 32];
let shapes = generate_psd_shapes::<f32>(&sizes);
assert_eq!(shapes.len(), 4);
}
#[test]
fn test_psd_shapes_sizes() {
let sizes = vec![256, 128, 64, 32];
let shapes = generate_psd_shapes::<f32>(&sizes);
for (i, shape) in shapes.iter().enumerate() {
assert_eq!(shape.len(), sizes[i], "Shape {} has wrong size", i);
}
}
#[test]
fn test_psd_coarsest_flat() {
let sizes = vec![128, 64, 32];
let shapes = generate_psd_shapes::<f32>(&sizes);
let coarsest = &shapes[shapes.len() - 1];
let first_val = coarsest[0];
for &val in coarsest.iter() {
assert!(
approx_eq(val, first_val, 1e-6),
"Coarsest PSD should be flat"
);
}
}
#[test]
fn test_psd_all_non_negative() {
let sizes = vec![256, 128, 64];
let shapes = generate_psd_shapes::<f32>(&sizes);
for (i, shape) in shapes.iter().enumerate() {
for &val in shape.iter() {
assert!(val >= 0.0, "PSD shape {} has negative value", i);
}
}
}
#[test]
fn test_default_config() {
let config: MultiscaleConfig<f32> = MultiscaleConfig::default();
assert!(config.num_scales.is_none());
assert!(approx_eq(config.filter_strength, 1.0, 1e-6));
assert!(approx_eq(config.threshold, 3.5, 1e-6));
assert_eq!(config.debin_iterations, 10);
}
#[test]
fn test_config_validation_valid() {
let config: MultiscaleConfig<f32> = MultiscaleConfig::default();
assert!(config.validate().is_ok());
}
#[test]
fn test_config_validation_invalid_filter_strength() {
let config: MultiscaleConfig<f32> = MultiscaleConfig {
filter_strength: 0.0,
..MultiscaleConfig::default()
};
assert!(config.validate().is_err());
}
#[test]
fn test_config_validation_invalid_debin_iterations() {
let config: MultiscaleConfig<f32> = MultiscaleConfig {
debin_iterations: 0,
..MultiscaleConfig::default()
};
assert!(config.validate().is_err());
}
#[test]
fn test_multiscale_smoke_small() {
let image = random_matrix(32, 32, 66666);
let config = MultiscaleConfig::default();
let result = multiscale_bm3d_streak_removal(image.view(), &config);
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.dim(), image.dim());
assert!(output.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_multiscale_smoke_medium() {
let image = random_matrix(64, 128, 77777);
let config = MultiscaleConfig::default();
let result = multiscale_bm3d_streak_removal(image.view(), &config);
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.dim(), image.dim());
assert!(output.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_k0_equals_single_scale() {
let image = random_matrix(32, 32, 88888);
let config = MultiscaleConfig::default();
let multiscale_result = multiscale_bm3d_streak_removal(image.view(), &config).unwrap();
let single_result =
bm3d_ring_artifact_removal(image.view(), RingRemovalMode::Streak, &config.bm3d_config)
.unwrap();
for (a, b) in multiscale_result.iter().zip(single_result.iter()) {
assert!(
approx_eq(*a, *b, 1e-5),
"K=0 results differ: {} vs {}",
a,
b
);
}
}
#[test]
fn test_multiscale_differs_from_single() {
let image = random_matrix(64, 256, 99999);
let single_config = Bm3dConfig {
threshold: 3.5, ..Bm3dConfig::default()
};
let single_result =
bm3d_ring_artifact_removal(image.view(), RingRemovalMode::Streak, &single_config)
.unwrap();
let multi_config = MultiscaleConfig::default();
let multi_result = multiscale_bm3d_streak_removal(image.view(), &multi_config).unwrap();
let diff: f32 = single_result
.iter()
.zip(multi_result.iter())
.map(|(a, b)| (a - b).abs())
.sum();
assert!(
diff > 0.01,
"Multiscale should produce different results than single-scale"
);
}
#[test]
fn test_multiscale_handles_wide_streak() {
let mut image = Array2::from_elem((64, 256), 0.5f32);
for r in 0..64 {
for c in 60..64 {
image[[r, c]] = 0.9;
}
}
let config = MultiscaleConfig {
bm3d_config: Bm3dConfig {
sigma_random: 0.5,
..Bm3dConfig::default()
},
..MultiscaleConfig::default()
};
let result = multiscale_bm3d_streak_removal(image.view(), &config);
assert!(result.is_ok());
let output = result.unwrap();
let col_means: Vec<f32> = (0..256)
.map(|c| {
let sum: f32 = (0..64).map(|r| output[[r, c]]).sum();
sum / 64.0
})
.collect();
let overall_mean: f32 = col_means.iter().sum::<f32>() / 256.0;
let col_variance: f32 = col_means
.iter()
.map(|m| (m - overall_mean).powi(2))
.sum::<f32>()
/ 256.0;
let orig_col_means: Vec<f32> = (0..256)
.map(|c| {
let sum: f32 = (0..64).map(|r| image[[r, c]]).sum();
sum / 64.0
})
.collect();
let orig_overall_mean: f32 = orig_col_means.iter().sum::<f32>() / 256.0;
let orig_col_variance: f32 = orig_col_means
.iter()
.map(|m| (m - orig_overall_mean).powi(2))
.sum::<f32>()
/ 256.0;
assert!(
col_variance < orig_col_variance,
"Multiscale should reduce wide streak variance: {} >= {}",
col_variance,
orig_col_variance
);
}
#[test]
fn test_multiscale_f64() {
let image = Array2::from_shape_fn((32, 128), |(r, c)| (r * 128 + c) as f64 / 4096.0);
let config: MultiscaleConfig<f64> = MultiscaleConfig::default();
let result = multiscale_bm3d_streak_removal(image.view(), &config);
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.dim(), image.dim());
assert!(output.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_override_num_scales() {
let image = random_matrix(64, 256, 10101);
let config = MultiscaleConfig {
num_scales: Some(1),
..MultiscaleConfig::default()
};
let result = multiscale_bm3d_streak_removal(image.view(), &config);
assert!(result.is_ok());
}
#[test]
fn test_force_k0() {
let image = random_matrix(64, 256, 20202);
let config = MultiscaleConfig {
num_scales: Some(0),
..MultiscaleConfig::default()
};
let multi_result = multiscale_bm3d_streak_removal(image.view(), &config).unwrap();
let single_result =
bm3d_ring_artifact_removal(image.view(), RingRemovalMode::Streak, &config.bm3d_config)
.unwrap();
for (a, b) in multi_result.iter().zip(single_result.iter()) {
assert!(approx_eq(*a, *b, 1e-5), "Forced K=0 differs from single");
}
}
}
pub fn compute_residual_kernel<F: Bm3dFloat>() -> Vec<f64> {
let kernel0 = compute_residual_kernel_shift::<F>(0);
let kernel1 = compute_residual_kernel_shift::<F>(1);
kernel0
.iter()
.zip(kernel1.iter())
.map(|(a, b)| (a + b) / 2.0)
.collect()
}
fn compute_residual_kernel_shift<F: Bm3dFloat>(shift: usize) -> Vec<f64> {
let factor = BINNING_FACTOR; let iterations = DEFAULT_DEBIN_ITERATIONS;
let width = 512;
let center = width / 2 + shift;
let mut impulse = Array2::<F>::zeros((1, width));
impulse[[0, center]] = F::one();
let binned = bin_horizontal(impulse.view(), factor);
let scale_correction = F::from_f64_c(factor as f64);
let binned_normalized = binned.mapv(|x| x / scale_correction);
let rebuilt = debin_horizontal(binned_normalized.view(), width, factor, iterations);
let mut residual_kernel = Vec::new();
let kernel_radius = 18;
for i in (center - kernel_radius)..=(center) {
let val = impulse[[0, i]] - rebuilt[[0, i]];
residual_kernel.push(val.to_f64().unwrap());
}
residual_kernel
}
#[cfg(test)]
mod kernel_tests {
use super::*;
#[test]
fn test_verify_residual_kernel_properties() {
let computed = compute_residual_kernel::<f64>();
let center_idx = 18;
let center_val = computed[center_idx];
assert!(
center_val > 0.7,
"Center value should be significant (residual of peak)"
);
let mut full_kernel = computed.clone();
for i in (0..center_idx).rev() {
full_kernel.push(computed[i]);
}
let sum: f64 = full_kernel.iter().sum();
assert!(
(sum - 0.5).abs() < 0.1,
"Residual kernel sum reflects pipeline DC damping (0.5)"
);
}
}