use crate::float_trait::Bm3dFloat;
use crate::transforms;
use crate::utils::{compute_1d_median_filter, estimate_robust_sigma};
use ndarray::{Array1, Array2, ArrayView2};
use rustfft::num_complex::Complex;
use rustfft::FftPlanner;
/// Approximates the dominant singular triplet `(u, s, v)` of `matrix` by power iteration.
///
/// Starts from the uniform vector `v = [1/sqrt(cols); cols]`. For at most `max_iter`
/// rounds it computes `u = A v` and then `v = Aᵀ u`, rescaling each to unit length.
/// The norm of `Aᵀ u`, taken after `u` has been normalised, is reported as `s`.
///
/// If either product has a norm below `1e-10`, the loop stops immediately and that
/// product is returned without normalisation. In that case `s` keeps the most recent
/// value assigned to it, which is `0` if the stop happens on the very first `A v`.
///
/// `_tol` is currently unused. No convergence test is performed, so all `max_iter`
/// rounds run unless the near-zero-norm stop triggers.
///
/// Returns `(u, s, v)` with `u.len() == rows` and `v.len() == cols`.
fn power_iteration_k1<F: Bm3dFloat>(
    matrix: ArrayView2<F>,
    max_iter: usize,
    _tol: F,
) -> (Array1<F>, F, Array1<F>) {
    let (n_rows, n_cols) = matrix.dim();
    let floor = F::from_f64_c(1e-10);
    let l2_norm = |vec: &Array1<F>| vec.dot(vec).sqrt();

    let mut right: Array1<F> =
        Array1::from_elem(n_cols, F::one() / F::from_f64_c((n_cols as f64).sqrt()));
    let mut left: Array1<F> = Array1::zeros(n_rows);
    let mut sigma = F::zero();

    for _round in 0..max_iter {
        // Left step: u <- A v, then rescale to unit length.
        left = matrix.dot(&right);
        let left_len = l2_norm(&left);
        if left_len < floor {
            break;
        }
        left.mapv_inplace(|x| x / left_len);

        // Right step: v <- Aᵀ u; its length is the current singular-value estimate.
        right = matrix.t().dot(&left);
        sigma = l2_norm(&right);
        if sigma < floor {
            break;
        }
        right.mapv_inplace(|x| x / sigma);
    }

    (left, sigma, right)
}
/// Squared circular distance of frequency index `index` from DC in a spectrum of length `len`.
///
/// Indices up to `len / 2` are measured from 0; higher indices wrap around to `len - index`.
fn wrapped_freq_dist_sq<F: Bm3dFloat>(index: usize, len: usize) -> F {
    let len_f = F::usize_as(len);
    let idx_f = F::usize_as(index);
    let dist = if idx_f <= len_f / F::from_f64_c(2.0) {
        idx_f
    } else {
        len_f - idx_f
    };
    dist * dist
}

/// Gaussian low-pass gain `exp(-dist_sq / (2 * sigma_sq))`.
///
/// A non-positive (or NaN) `sigma_sq` is treated as the zero-width limit of the Gaussian:
/// gain 1 at DC (`dist_sq == 0`) and 0 elsewhere. This avoids the `0 / 0 = NaN` that the
/// plain formula would produce at DC.
fn notch_gaussian_low_pass<F: Bm3dFloat>(dist_sq: F, sigma_sq: F) -> F {
    if sigma_sq > F::zero() {
        (F::from_f64_c(-0.5) * dist_sq / sigma_sq).exp()
    } else if dist_sq == F::zero() {
        F::one()
    } else {
        F::zero()
    }
}

/// Median of `values`, using the mean of the two middle elements for even lengths.
///
/// Returns `None` for an empty input. NaN entries are ordered after every number, so the
/// comparator is a total order. `sort_by` may panic on non-total comparators (Rust >= 1.81).
fn median_of_profile<F: Bm3dFloat>(mut values: Vec<F>) -> Option<F> {
    if values.is_empty() {
        return None;
    }
    values.sort_by(|a, b| {
        a.partial_cmp(b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    });
    let len = values.len();
    let mid = len / 2;
    Some(if len % 2 == 1 {
        values[mid]
    } else {
        (values[mid - 1] + values[mid]) * F::from_f64_c(0.5)
    })
}

/// Per-column score of vertical-streak energy in `sinogram`, normalised by its median.
///
/// Steps:
/// 1. Transform the sinogram with `transforms::fft2d`.
/// 2. Multiply the spectrum by the separable weight `w(r, c) = g(r) * (1 - g(c))`.
///    Here `g(k) = exp(-d(k)^2 / (2 * notch_width^2))` and `d(k)` is the wrapped distance of
///    frequency index `k` from DC. This assumes `fft2d` returns an unshifted spectrum in the
///    same `(rows, cols)` layout (TODO confirm).
/// 3. The weight keeps content of low frequency along axis 0 and high frequency along axis 1.
/// 4. After `transforms::ifft2d`, each column's mean absolute value becomes the profile entry.
/// 5. The profile is divided by its median when that median exceeds `1e-10`.
///
/// Edge cases:
/// - A `notch_width` of zero uses the zero-width limit of the Gaussian instead of producing NaN.
/// - An input with no rows or no columns yields an all-zero profile.
///
/// Returns an array of length `cols`.
fn compute_vertical_energy_profile<F: Bm3dFloat>(
    sinogram: ArrayView2<F>,
    notch_width: F,
) -> Array1<F> {
    let (rows, cols) = sinogram.dim();
    // Nothing to analyse. This also avoids zero-length FFT plans and the `0 * (1 / 0) = NaN`
    // column mean that an empty row axis would otherwise produce.
    if rows == 0 || cols == 0 {
        return Array1::zeros(cols);
    }

    let mut planner = FftPlanner::<F>::new();
    let fft_row = planner.plan_fft_forward(cols);
    let fft_col = planner.plan_fft_forward(rows);
    let ifft_row = planner.plan_fft_inverse(cols);
    let ifft_col = planner.plan_fft_inverse(rows);

    let mut filtered_freq = transforms::fft2d(sinogram, &fft_row, &fft_col);

    let sigma_sq = notch_width * notch_width;
    // Axis-1 weights: high-pass, suppressing content that is constant across columns.
    let x_weights: Vec<F> = (0..cols)
        .map(|c| F::one() - notch_gaussian_low_pass(wrapped_freq_dist_sq(c, cols), sigma_sq))
        .collect();

    for r in 0..rows {
        // Axis-0 weight: low-pass, keeping content that is smooth down each column.
        let y_weight = notch_gaussian_low_pass(wrapped_freq_dist_sq(r, rows), sigma_sq);
        for (c, &x_weight) in x_weights.iter().enumerate() {
            filtered_freq[[r, c]] *= Complex::new(y_weight * x_weight, F::zero());
        }
    }

    let spatial_filtered = transforms::ifft2d(&filtered_freq, &ifft_row, &ifft_col);

    // Mean absolute value of each column of the filtered image.
    let rows_f_inv = F::one() / F::usize_as(rows);
    let mut energy_profile = Array1::<F>::zeros(cols);
    for c in 0..cols {
        let mut sum_abs = F::zero();
        for r in 0..rows {
            sum_abs += spatial_filtered[[r, c]].abs();
        }
        energy_profile[c] = sum_abs * rows_f_inv;
    }

    // Normalise by the median so the profile is scale-free; skip near-zero medians.
    if let Some(median) = median_of_profile(energy_profile.to_vec()) {
        if median > F::from_f64_c(1e-10) {
            let inv_med = F::one() / median;
            energy_profile.mapv_inplace(|x| x * inv_med);
        }
    }
    energy_profile
}
/// Removes column-wise (vertical) streaks from `sinogram` using its rank-1 SVD component.
///
/// Steps:
/// 1. Approximate the leading singular triplet `(u, s, v)` with 20 rounds of
///    `power_iteration_k1`.
/// 2. Split `v` into a smooth trend, via `compute_1d_median_filter` with window 51, and the
///    detail residual `v - trend`.
/// 3. Estimate the residual's noise level with `estimate_robust_sigma` and set a base
///    threshold of `3 * sigma`.
/// 4. When `fft_alpha > 1e-6`, scale the threshold per column by `1 + fft_alpha * e[c]`.
///    Here `e` is `compute_vertical_energy_profile(sinogram, notch_width)`.
/// 5. Apply the soft mask `1 / (1 + (|x| / t)^6)` to each residual entry `x`. Entries well
///    below their threshold `t` pass almost unchanged; entries well above it are suppressed.
///    A threshold at or below `1e-10` masks the column out entirely.
/// 6. Subtract the rank-1 streak estimate `s * u ⊗ v_streak` from a copy of the sinogram.
///
/// Returns a new array with the same shape as `sinogram`; the input view is not modified.
pub fn fourier_svd_removal<F: Bm3dFloat>(
    sinogram: ArrayView2<F>,
    fft_alpha: F,
    notch_width: F,
) -> Array2<F> {
    // Power-iteration rounds for the rank-1 approximation.
    const POWER_ITERS: usize = 20;
    // Window length of the median filter that extracts the smooth trend of `v`.
    const MEDIAN_WINDOW: usize = 51;
    // Steepness of the soft streak mask; larger values approach a hard threshold.
    const MASK_EXPONENT: i32 = 6;

    let (u, s, v) = power_iteration_k1(sinogram, POWER_ITERS, F::from_f64_c(1e-6));

    // `to_vec` works for any memory layout. `as_slice().unwrap()` would panic on a
    // non-contiguous array.
    let v_values = v.to_vec();
    let v_smooth = Array1::from(compute_1d_median_filter(v_values.as_slice(), MEDIAN_WINDOW));
    let v_detail = &v - &v_smooth;

    let sigma = estimate_robust_sigma(v_detail.view());
    let base_thresh = F::from_f64_c(sigma * 3.0);

    // Optional per-column threshold multiplier driven by the vertical energy profile.
    let modulator = if fft_alpha > F::from_f64_c(1e-6) {
        let energy = compute_vertical_energy_profile(sinogram, notch_width);
        Some(energy.mapv(|e| F::one() + fft_alpha * e))
    } else {
        None
    };

    let min_thresh = F::from_f64_c(1e-10);
    let v_streak: Vec<F> = v_detail
        .iter()
        .enumerate()
        .map(|(c, &x)| {
            let thresh = match &modulator {
                Some(m) => base_thresh * m[c],
                None => base_thresh,
            };
            let mask = if thresh > min_thresh {
                let ratio = x.abs() / thresh;
                F::one() / (F::one() + ratio.powi(MASK_EXPONENT))
            } else {
                F::zero()
            };
            x * mask
        })
        .collect();

    // Subtract the rank-1 streak estimate row by row: row_r -= (u[r] * s) * v_streak.
    let mut corrected = sinogram.to_owned();
    for (mut row, &u_r) in corrected.outer_iter_mut().zip(u.iter()) {
        let u_val = u_r * s;
        for (px, &streak) in row.iter_mut().zip(v_streak.iter()) {
            *px -= u_val * streak;
        }
    }
    corrected
}