use ndarray::Array2;
use realfft::RealFftPlanner;
use rustfft::{FftPlanner, num_complex::Complex};
use thiserror::Error;
use crate::beam::{Beam, gauss_factor};
/// Reasons a uv-plane convolution can fail.
#[derive(Debug, Error)]
pub enum ConvolveError {
    /// Every pixel of the input image is NaN.
    ///
    /// NOTE(review): `convolve_uv` in this module does not return this variant; it
    /// passes an all-NaN image through unchanged instead. Confirm whether other
    /// callers construct it.
    #[error("image is entirely NaN")]
    AllNaN,

    /// The input beam's major axis is larger than the requested cutoff.
    #[error("beam larger than cutoff — image blanked")]
    AboveCutoff,
}
/// Output of [`convolve_uv`]: the resulting image plus an accompanying factor.
///
/// Derives `Debug` so that `Result<ConvolutionResult, _>` supports `unwrap_err`,
/// `{:?}` formatting and assertion messages. Derives `Clone` so results can be kept
/// independently of the caller's buffers.
#[derive(Debug, Clone)]
pub struct ConvolutionResult {
    /// The output image. On the early-return paths of [`convolve_uv`] this is a
    /// copy of the input.
    pub image: Array2<f32>,
    /// Factor reported alongside the image. Which quantity it holds depends on the
    /// path taken in [`convolve_uv`]:
    /// - `1.0` when the beams are approximately equal;
    /// - the first value returned by `gauss_factor` when the image is all NaN;
    /// - otherwise the `g_ratio` returned by [`gaussft`].
    pub scaling_factor: f64,
}
/// Convolve `image` from `old_beam` to `new_beam` in the Fourier (uv) domain.
///
/// The image goes through a 2-D real FFT ([`rfft2`]). It is then multiplied by the
/// Gaussian ratio kernel from [`gaussft`] and transformed back with [`irfft2`].
///
/// # Arguments
/// * `image` - input plane; NaN pixels are treated as blanked.
/// * `old_beam` / `new_beam` - current and target beams.
/// * `dx_deg` / `dy_deg` - pixel increments in degrees.
///   - `dx_deg` sets the frequency grid of the row axis (axis 0), and `dy_deg` that
///     of the column axis (axis 1).
///   - NOTE(review): their signs are passed to [`fftfreq`] unchanged, while
///     `gauss_factor` receives absolute values. Confirm this asymmetry is intended.
/// * `cutoff_arcsec` - optional upper limit on `old_beam.major_arcsec()`.
///
/// # Returns
/// * The input unchanged with `scaling_factor == 1.0` when the beams compare
///   approximately equal.
/// * The input unchanged with the first value returned by `gauss_factor` when every
///   pixel is NaN. This includes an empty image.
/// * Otherwise the convolved image together with the `g_ratio` returned by
///   [`gaussft`].
///
/// # NaN handling
/// NaN pixels are zero-filled before the FFT, and a 0/1 NaN mask is convolved with
/// the same kernel. An output pixel is NaN when either:
/// * the input pixel was NaN, so blanked data is never replaced by a zero-biased
///   estimate; or
/// * the convolved mask there is `>= 1.0`.
///
/// NOTE(review): the kernel's DC gain is `g_ratio`, so the convolved mask is scaled
/// by it. The `>= 1.0` growth threshold therefore depends on the beam ratio.
///
/// # Errors
/// [`ConvolveError::AboveCutoff`] when `cutoff_arcsec` is `Some` and the old beam's
/// major axis exceeds it.
pub fn convolve_uv(
    image: &Array2<f32>,
    old_beam: &Beam,
    new_beam: &Beam,
    dx_deg: f64,
    dy_deg: f64,
    cutoff_arcsec: Option<f64>,
) -> Result<ConvolutionResult, ConvolveError> {
    if let Some(cutoff) = cutoff_arcsec
        && old_beam.major_arcsec() > cutoff
    {
        return Err(ConvolveError::AboveCutoff);
    }
    if old_beam.approx_eq(new_beam) {
        return Ok(ConvolutionResult {
            image: image.clone(),
            scaling_factor: 1.0,
        });
    }
    if image.iter().all(|x| x.is_nan()) {
        // Nothing to convolve. This also covers an empty image, for which `all` is
        // vacuously true. The `gauss_factor` value is only reported on this path,
        // so it is computed here rather than up front.
        let conv_beam = new_beam.deconvolve_or_zero(old_beam);
        let (fac, ..) = gauss_factor(
            &conv_beam,
            old_beam,
            dx_deg.abs() * 3600.0,
            dy_deg.abs() * 3600.0,
        );
        return Ok(ConvolutionResult {
            image: image.clone(),
            scaling_factor: fac,
        });
    }

    let (nrows, ncols) = image.dim();
    let has_nan = image.iter().any(|x| x.is_nan());
    // Blanked pixels contribute zero to the transform.
    let clean_image: Vec<f64> = image
        .iter()
        .map(|&x| if x.is_nan() { 0.0 } else { f64::from(x) })
        .collect();

    // The real FFT keeps only the first `ncols / 2 + 1` column frequencies.
    let nhalf = ncols / 2 + 1;
    let u_freqs = fftfreq(nrows, dx_deg.to_radians());
    let v_freqs_full = fftfreq(ncols, dy_deg.to_radians());
    let (g_final, g_ratio) = gaussft(old_beam, new_beam, &u_freqs, &v_freqs_full[..nhalf]);

    let im_conv = apply_uv_kernel(&clean_image, &g_final, nrows, ncols);

    let out_flat: Vec<f32> = if has_nan {
        let mask: Vec<f64> = image
            .iter()
            .map(|&x| if x.is_nan() { 1.0 } else { 0.0 })
            .collect();
        let mask_conv = apply_uv_kernel(&mask, &g_final, nrows, ncols);
        image
            .iter()
            .zip(im_conv.iter().zip(&mask_conv))
            .map(|(&orig, (&v, &m))| {
                // Input NaNs must stay NaN. The zero fill above is only a
                // placeholder for the FFT, not a valid pixel value.
                if orig.is_nan() || m >= 1.0 {
                    f32::NAN
                } else {
                    v as f32
                }
            })
            .collect()
    } else {
        im_conv.iter().map(|&v| v as f32).collect()
    };

    let out = Array2::from_shape_vec((nrows, ncols), out_flat)
        .expect("shape mismatch in convolve_uv output");
    Ok(ConvolutionResult {
        image: out,
        scaling_factor: g_ratio,
    })
}

/// Convolve row-major `nrows x ncols` real data with a uv-plane kernel.
///
/// The kernel uses the same layout as the [`rfft2`] half-spectrum
/// (`nrows x (ncols / 2 + 1)`).
fn apply_uv_kernel(data: &[f64], kernel: &[f64], nrows: usize, ncols: usize) -> Vec<f64> {
    let mut spectrum = rfft2(data, nrows, ncols);
    for (s, &g) in spectrum.iter_mut().zip(kernel) {
        *s *= g;
    }
    irfft2(spectrum, nrows, ncols)
}
/// Build the uv-plane kernel that converts `old_beam` into `new_beam`.
///
/// The output is row-major: `u_freqs` indexes rows and `v_freqs` indexes columns.
/// For every `(u, v)` pair the kernel value is
/// `g_ratio * exp(G_new(u, v) - G_old(u, v))`, where:
/// * each `G` is the exponent `-2 π² ((σx·u')² + (σy·v')²)` of an elliptical
///   Gaussian;
/// * `(u', v') = (u cos θ - v sin θ, u sin θ + v cos θ)` is the frequency rotated by
///   that beam's position angle θ;
/// * σ is derived from the FWHM via `FWHM = 2·sqrt(2 ln 2)·σ`.
///
/// Beam axes and angles are read in degrees (`major_deg`, `minor_deg`, `pa_deg`)
/// and converted to radians. The frequencies are presumably in cycles per radian,
/// as produced by [`fftfreq`] with a radian spacing (TODO confirm for other
/// callers).
///
/// # Returns
/// `(kernel, g_ratio)`, where:
/// * `kernel.len() == u_freqs.len() * v_freqs.len()`;
/// * `g_ratio = sqrt(2π σx σy) / sqrt(2π σx_in σy_in)`, i.e. new beam over old beam.
pub fn gaussft(
    old_beam: &Beam,
    new_beam: &Beam,
    u_freqs: &[f64],
    v_freqs: &[f64],
) -> (Vec<f64>, f64) {
    use std::f64::consts::PI;

    let deg_to_rad = PI / 180.0;
    // FWHM = 2 * sqrt(2 ln 2) * sigma
    let fwhm_per_sigma = 2.0 * (2.0 * 2_f64.ln()).sqrt();

    // Target beam: sigmas in radians and its position-angle rotation.
    let sig_x = new_beam.major_deg * deg_to_rad / fwhm_per_sigma;
    let sig_y = new_beam.minor_deg * deg_to_rad / fwhm_per_sigma;
    let (sin_pa, cos_pa) = (new_beam.pa_deg * deg_to_rad).sin_cos();

    // Input beam: same quantities.
    let sig_x_in = old_beam.major_deg * deg_to_rad / fwhm_per_sigma;
    let sig_y_in = old_beam.minor_deg * deg_to_rad / fwhm_per_sigma;
    let (sin_pa_in, cos_pa_in) = (old_beam.pa_deg * deg_to_rad).sin_cos();

    let amp_new = (2.0 * PI * sig_x * sig_y).sqrt();
    let amp_old = (2.0 * PI * sig_x_in * sig_y_in).sqrt();
    let g_ratio = amp_new / amp_old;

    // Common prefactor of both Gaussian exponents.
    let neg_two_pi_sq = -2.0 * (PI * PI);

    let kernel: Vec<f64> = u_freqs
        .iter()
        .flat_map(|&u| {
            let (u_cos, u_sin) = (u * cos_pa, u * sin_pa);
            let (u_cos_in, u_sin_in) = (u * cos_pa_in, u * sin_pa_in);
            v_freqs.iter().map(move |&v| {
                let ur = u_cos - v * sin_pa;
                let vr = u_sin + v * cos_pa;
                let ur_in = u_cos_in - v * sin_pa_in;
                let vr_in = u_sin_in + v * cos_pa_in;
                let expo_new = neg_two_pi_sq * ((sig_x * ur).powi(2) + (sig_y * vr).powi(2));
                let expo_old =
                    neg_two_pi_sq * ((sig_x_in * ur_in).powi(2) + (sig_y_in * vr_in).powi(2));
                g_ratio * (expo_new - expo_old).exp()
            })
        })
        .collect();

    (kernel, g_ratio)
}
/// Sample frequencies for an `n`-point DFT with sample spacing `d`, in standard FFT
/// order: `[0, 1, ..., ceil(n/2) - 1, -floor(n/2), ..., -1] / (n * d)`.
///
/// Matches the layout of NumPy's `numpy.fft.fftfreq`. Returns an empty vector when
/// `n == 0`.
pub fn fftfreq(n: usize, d: f64) -> Vec<f64> {
    let step = 1.0 / (n as f64 * d);
    // Bin indices below this map to non-negative multiples of `step`; the rest wrap
    // around to negative multiples.
    let first_negative = n.div_ceil(2);
    (0..n)
        .map(|k| {
            let bin = if k < first_negative {
                k as f64
            } else {
                k as f64 - n as f64
            };
            bin * step
        })
        .collect()
}
/// Forward 2-D FFT of a real, row-major `nrows x ncols` array.
///
/// The transform runs in two passes:
/// 1. Each row goes through a real-to-complex FFT, which keeps the first
///    `ncols / 2 + 1` bins.
/// 2. Each of those columns goes through a full-length complex FFT.
///
/// The result is row-major `nrows x (ncols / 2 + 1)` and is not normalised.
///
/// # Panics
/// * `data` contains a trailing partial row.
/// * `data` contains more than `nrows` rows.
/// * `ncols == 0`.
/// * An FFT call reports an error.
fn rfft2(data: &[f64], nrows: usize, ncols: usize) -> Vec<Complex<f64>> {
    let half = ncols / 2 + 1;
    let mut spectrum = vec![Complex::new(0.0, 0.0); nrows * half];

    // Pass 1: real-to-complex transform of every row.
    let mut real_planner = RealFftPlanner::<f64>::new();
    let row_fft = real_planner.plan_fft_forward(ncols);
    let mut row_in = row_fft.make_input_vec();
    let mut row_scratch = row_fft.make_scratch_vec();
    for (row, samples) in data.chunks(ncols).enumerate() {
        row_in.copy_from_slice(samples);
        let bins = &mut spectrum[row * half..(row + 1) * half];
        row_fft
            .process_with_scratch(&mut row_in, bins, &mut row_scratch)
            .expect("r2c FFT");
    }

    // Pass 2: complex transform down every retained column. Columns are strided by
    // `half`, so each one is gathered into a buffer, transformed, and scattered
    // back.
    let col_fft = FftPlanner::<f64>::new().plan_fft_forward(nrows);
    let mut column = vec![Complex::new(0.0, 0.0); nrows];
    for col in 0..half {
        for (dst, src) in column
            .iter_mut()
            .zip(spectrum.iter().skip(col).step_by(half))
        {
            *dst = *src;
        }
        col_fft.process(&mut column);
        for (dst, src) in spectrum.iter_mut().skip(col).step_by(half).zip(&column) {
            *dst = *src;
        }
    }

    spectrum
}
/// Inverse of [`rfft2`].
///
/// Turns a row-major `nrows x (ncols / 2 + 1)` half-spectrum back into a real
/// row-major `nrows x ncols` array, scaled by `1 / (nrows * ncols)`.
///
/// The transform runs in two passes:
/// 1. Columns are inverse-transformed in place.
/// 2. Each row goes through a complex-to-real inverse FFT.
///
/// Before each row transform, the imaginary part of the DC bin is zeroed, and for
/// even `ncols` so is that of the Nyquist bin. The real-FFT inverse expects these
/// bins to be real, so any imaginary residue left there is discarded.
///
/// # Panics
/// * `spectrum` holds fewer than `nrows * (ncols / 2 + 1)` values.
/// * An FFT call reports an error.
fn irfft2(mut spectrum: Vec<Complex<f64>>, nrows: usize, ncols: usize) -> Vec<f64> {
    let half = ncols / 2 + 1;

    // Pass 1: inverse complex transform down every retained column, in place.
    let col_ifft = FftPlanner::<f64>::new().plan_fft_inverse(nrows);
    let mut column = vec![Complex::new(0.0, 0.0); nrows];
    for col in 0..half {
        for (dst, src) in column
            .iter_mut()
            .zip(spectrum.iter().skip(col).step_by(half))
        {
            *dst = *src;
        }
        col_ifft.process(&mut column);
        for (dst, src) in spectrum.iter_mut().skip(col).step_by(half).zip(&column) {
            *dst = *src;
        }
    }

    // Pass 2: complex-to-real inverse transform of every row.
    let mut real_planner = RealFftPlanner::<f64>::new();
    let row_ifft = real_planner.plan_fft_inverse(ncols);
    let mut row_in = row_ifft.make_input_vec();
    let mut row_scratch = row_ifft.make_scratch_vec();
    let has_nyquist = ncols.is_multiple_of(2);
    let mut image = vec![0.0_f64; nrows * ncols];
    for row in 0..nrows {
        row_in.copy_from_slice(&spectrum[row * half..(row + 1) * half]);
        row_in[0].im = 0.0;
        if has_nyquist {
            row_in[half - 1].im = 0.0;
        }
        row_ifft
            .process_with_scratch(
                &mut row_in,
                &mut image[row * ncols..(row + 1) * ncols],
                &mut row_scratch,
            )
            .expect("c2r FFT");
    }

    // Neither transform normalises, so divide by the total sample count once.
    let samples = (nrows * ncols) as f64;
    image.iter_mut().for_each(|x| *x /= samples);
    image
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Assert that two float slices have the same length and agree element-wise
    /// within `tol`.
    ///
    /// A bare `zip` would silently ignore a length mismatch.
    fn assert_close(got: &[f64], want: &[f64], tol: f64) {
        assert_eq!(got.len(), want.len(), "length mismatch: {got:?} vs {want:?}");
        for (a, b) in got.iter().zip(want) {
            assert!((a - b).abs() < tol, "got {a}, want {b}");
        }
    }

    #[test]
    fn test_fftfreq() {
        assert_close(&fftfreq(4, 1.0), &[0.0, 0.25, -0.5, -0.25], 1e-12);
    }

    #[test]
    fn test_fftfreq_odd_and_degenerate_lengths() {
        assert_close(&fftfreq(5, 0.5), &[0.0, 0.4, 0.8, -0.8, -0.4], 1e-12);
        assert_close(&fftfreq(1, 2.0), &[0.0], 1e-12);
        assert!(fftfreq(0, 1.0).is_empty());
    }

    #[test]
    fn test_rfft2_irfft2_roundtrip() {
        let data = vec![
            1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
            16.0,
        ];
        let (nrows, ncols) = (4, 4);
        let recovered = irfft2(rfft2(&data, nrows, ncols), nrows, ncols);
        assert_close(&recovered, &data, 1e-10);
    }

    #[test]
    fn test_rfft2_irfft2_roundtrip_odd_non_square() {
        // An odd `ncols` exercises the path with no Nyquist bin.
        let (nrows, ncols) = (3, 5);
        let data: Vec<f64> = (0..nrows * ncols)
            .map(|k| ((k * 7) % 11) as f64 - 3.5)
            .collect();
        let spectrum = rfft2(&data, nrows, ncols);
        assert_eq!(spectrum.len(), nrows * (ncols / 2 + 1));
        let recovered = irfft2(spectrum, nrows, ncols);
        assert_close(&recovered, &data, 1e-10);
    }

    #[test]
    fn test_convolve_uv_no_change_when_beams_equal() {
        let beam = Beam::new(10.0 / 3600.0, 10.0 / 3600.0, 0.0).unwrap();
        let img = Array2::from_elem((16, 16), 1.0_f32);
        let result = convolve_uv(&img, &beam, &beam, 2.5 / 3600.0, 2.5 / 3600.0, None).unwrap();
        assert!((result.scaling_factor - 1.0).abs() < 1e-10);
        assert_eq!(result.image, img);
    }

    #[test]
    fn test_convolve_uv_constant_image_scales_by_reported_factor() {
        // A constant image has only a DC component. The output is therefore the
        // input value times the kernel's DC gain, which is the reported factor.
        let old = Beam::new(10.0 / 3600.0, 10.0 / 3600.0, 0.0).unwrap();
        let new = Beam::new(20.0 / 3600.0, 20.0 / 3600.0, 0.0).unwrap();
        let img = Array2::from_elem((16, 16), 1.0_f32);
        let result = convolve_uv(&img, &old, &new, 2.5 / 3600.0, 2.5 / 3600.0, None).unwrap();
        let sf = result.scaling_factor;
        assert_eq!(result.image.dim(), (16, 16));
        for &v in result.image.iter() {
            assert!(
                (f64::from(v) - sf).abs() < 1e-5 * sf.abs().max(1.0),
                "pixel {v} vs factor {sf}"
            );
        }
    }

    #[test]
    fn test_convolve_uv_all_nan_image_passes_through() {
        let old = Beam::new(10.0 / 3600.0, 10.0 / 3600.0, 0.0).unwrap();
        let new = Beam::new(20.0 / 3600.0, 20.0 / 3600.0, 0.0).unwrap();
        let img = Array2::from_elem((8, 8), f32::NAN);
        let result = convolve_uv(&img, &old, &new, 2.5 / 3600.0, 2.5 / 3600.0, None).unwrap();
        assert_eq!(result.image.dim(), (8, 8));
        assert!(result.image.iter().all(|x| x.is_nan()));
    }
}