1use ndarray::Array2;
7use realfft::RealFftPlanner;
8use rustfft::{FftPlanner, num_complex::Complex};
9use thiserror::Error;
10
11use crate::beam::{Beam, gauss_factor};
12
13#[derive(Debug, Error)]
14pub enum ConvolveError {
15 #[error("image is entirely NaN")]
16 AllNaN,
17 #[error("beam larger than cutoff — image blanked")]
18 AboveCutoff,
19}
20
21pub struct ConvolutionResult {
22 pub image: Array2<f32>,
24 pub scaling_factor: f64,
26}
27
28pub fn convolve_uv(
54 image: &Array2<f32>,
55 old_beam: &Beam,
56 new_beam: &Beam,
57 dx_deg: f64,
58 dy_deg: f64,
59 cutoff_arcsec: Option<f64>,
60) -> Result<ConvolutionResult, ConvolveError> {
61 if let Some(cutoff) = cutoff_arcsec
63 && old_beam.major_arcsec() > cutoff
64 {
65 return Err(ConvolveError::AboveCutoff);
66 }
67
68 if old_beam.approx_eq(new_beam) {
70 return Ok(ConvolutionResult {
71 image: image.clone(),
72 scaling_factor: 1.0,
73 });
74 }
75
76 let conv_beam = new_beam.deconvolve_or_zero(old_beam);
78 let (fac, ..) = gauss_factor(
79 &conv_beam,
80 old_beam,
81 dx_deg.abs() * 3600.0,
82 dy_deg.abs() * 3600.0,
83 );
84
85 if image.iter().all(|x| x.is_nan()) {
87 return Ok(ConvolutionResult {
88 image: image.clone(),
89 scaling_factor: fac,
90 });
91 }
92
93 let (nrows, ncols) = image.dim();
94
95 let has_nan = image.iter().any(|x| x.is_nan());
97 let (clean_image, nan_mask): (Vec<f64>, Option<Vec<f64>>) = if has_nan {
98 let vals: Vec<f64> = image
99 .iter()
100 .map(|&x| if x.is_nan() { 0.0 } else { x as f64 })
101 .collect();
102 let mask: Vec<f64> = image
103 .iter()
104 .map(|&x| if x.is_nan() { 1.0 } else { 0.0 })
105 .collect();
106 (vals, Some(mask))
107 } else {
108 let vals: Vec<f64> = image.iter().map(|&x| x as f64).collect();
109 (vals, None)
110 };
111
112 let nhalf = ncols / 2 + 1;
119 let dx_rad = dx_deg.to_radians();
120 let dy_rad = dy_deg.to_radians();
121 let u_freqs = fftfreq(nrows, dx_rad); let v_freqs_full = fftfreq(ncols, dy_rad);
123 let v_freqs = &v_freqs_full[..nhalf]; let (g_final, g_ratio) = gaussft(old_beam, new_beam, &u_freqs, v_freqs);
127
128 let mut im_f = rfft2(&clean_image, nrows, ncols);
130 for (s, &g) in im_f.iter_mut().zip(g_final.iter()) {
131 *s *= g;
132 }
133 let im_conv_flat = irfft2(im_f, nrows, ncols);
134
135 let out_flat: Vec<f32> = if let Some(mask) = nan_mask {
137 let mut mask_f = rfft2(&mask, nrows, ncols);
138 for (s, &g) in mask_f.iter_mut().zip(g_final.iter()) {
139 *s *= g;
140 }
141 let mask_conv = irfft2(mask_f, nrows, ncols);
142 im_conv_flat
143 .iter()
144 .zip(mask_conv.iter())
145 .map(|(&v, &m)| if m >= 1.0 { f32::NAN } else { v as f32 })
146 .collect()
147 } else {
148 im_conv_flat.iter().map(|&v| v as f32).collect()
149 };
150
151 let out = Array2::from_shape_vec((nrows, ncols), out_flat)
152 .expect("shape mismatch in convolve_uv output");
153
154 Ok(ConvolutionResult {
155 image: out,
156 scaling_factor: g_ratio,
157 })
158}
159
160pub fn gaussft(
169 old_beam: &Beam,
170 new_beam: &Beam,
171 u_freqs: &[f64],
172 v_freqs: &[f64],
173) -> (Vec<f64>, f64) {
174 let deg2rad = std::f64::consts::PI / 180.0;
175 let two_ln2 = 2.0 * 2_f64.ln();
176 let fwhm_to_sigma = 2.0 * two_ln2.sqrt(); let bmaj_rad = new_beam.major_deg * deg2rad;
180 let bmin_rad = new_beam.minor_deg * deg2rad;
181 let bpa_rad = new_beam.pa_deg * deg2rad;
182 let sx = bmaj_rad / fwhm_to_sigma;
183 let sy = bmin_rad / fwhm_to_sigma;
184
185 let bmaj_in_rad = old_beam.major_deg * deg2rad;
187 let bmin_in_rad = old_beam.minor_deg * deg2rad;
188 let bpa_in_rad = old_beam.pa_deg * deg2rad;
189 let sx_in = bmaj_in_rad / fwhm_to_sigma;
190 let sy_in = bmin_in_rad / fwhm_to_sigma;
191
192 let g_amp = (2.0 * std::f64::consts::PI * sx * sy).sqrt();
194 let dg_amp = (2.0 * std::f64::consts::PI * sx_in * sy_in).sqrt();
195 let g_ratio = g_amp / dg_amp;
196
197 let pi2 = std::f64::consts::PI * std::f64::consts::PI;
198 let nrows = u_freqs.len();
199 let ncols = v_freqs.len();
200 let mut g_final = vec![0.0_f64; nrows * ncols];
201
202 let u_cos = u_freqs
204 .iter()
205 .map(|&u| u * bpa_rad.cos())
206 .collect::<Vec<_>>();
207 let u_sin = u_freqs
208 .iter()
209 .map(|&u| u * bpa_rad.sin())
210 .collect::<Vec<_>>();
211 let v_cos = v_freqs
212 .iter()
213 .map(|&v| v * bpa_rad.cos())
214 .collect::<Vec<_>>();
215 let v_sin = v_freqs
216 .iter()
217 .map(|&v| v * bpa_rad.sin())
218 .collect::<Vec<_>>();
219
220 let u_cos_in = u_freqs
222 .iter()
223 .map(|&u| u * bpa_in_rad.cos())
224 .collect::<Vec<_>>();
225 let u_sin_in = u_freqs
226 .iter()
227 .map(|&u| u * bpa_in_rad.sin())
228 .collect::<Vec<_>>();
229 let v_cos_in = v_freqs
230 .iter()
231 .map(|&v| v * bpa_in_rad.cos())
232 .collect::<Vec<_>>();
233 let v_sin_in = v_freqs
234 .iter()
235 .map(|&v| v * bpa_in_rad.sin())
236 .collect::<Vec<_>>();
237
238 for i in 0..nrows {
239 for j in 0..ncols {
240 let ur = u_cos[i] - v_sin[j];
242 let vr = u_sin[i] + v_cos[j];
243 let ur_in = u_cos_in[i] - v_sin_in[j];
245 let vr_in = u_sin_in[i] + v_cos_in[j];
246
247 let g_arg = -2.0 * pi2 * ((sx * ur).powi(2) + (sy * vr).powi(2));
248 let dg_arg = -2.0 * pi2 * ((sx_in * ur_in).powi(2) + (sy_in * vr_in).powi(2));
249
250 g_final[i * ncols + j] = g_ratio * (g_arg - dg_arg).exp();
251 }
252 }
253
254 (g_final, g_ratio)
255}
256
/// Sample frequencies for an `n`-point FFT with sample spacing `d`, in the
/// same order as NumPy's `numpy.fft.fftfreq`.
///
/// The first `ceil(n / 2)` bins are the non-negative frequencies
/// `0, 1, 2, … / (n·d)`. The remaining bins are the negative frequencies
/// `…, -2, -1 / (n·d)`.
///
/// Returns an empty vector when `n == 0`.
pub fn fftfreq(n: usize, d: f64) -> Vec<f64> {
    let step = 1.0 / (n as f64 * d);
    let first_negative = n.div_ceil(2);
    (0..n)
        .map(|k| {
            let bin = if k < first_negative {
                k as f64
            } else {
                k as f64 - n as f64
            };
            bin * step
        })
        .collect()
}
283
284fn rfft2(data: &[f64], nrows: usize, ncols: usize) -> Vec<Complex<f64>> {
292 let nhalf = ncols / 2 + 1;
293
294 let mut rplanner = RealFftPlanner::<f64>::new();
296 let r2c = rplanner.plan_fft_forward(ncols);
297 let mut scratch = r2c.make_scratch_vec();
298 let mut inrow = r2c.make_input_vec();
299 let mut spectrum = vec![Complex::new(0.0, 0.0); nrows * nhalf];
300 for (i, chunk) in data.chunks(ncols).enumerate() {
301 inrow.copy_from_slice(chunk);
302 r2c.process_with_scratch(
303 &mut inrow,
304 &mut spectrum[i * nhalf..(i + 1) * nhalf],
305 &mut scratch,
306 )
307 .expect("r2c FFT");
308 }
309
310 let col_fft = FftPlanner::new().plan_fft_forward(nrows);
312 let mut col_buf = vec![Complex::new(0.0, 0.0); nrows];
313 for j in 0..nhalf {
314 for i in 0..nrows {
315 col_buf[i] = spectrum[i * nhalf + j];
316 }
317 col_fft.process(&mut col_buf);
318 for i in 0..nrows {
319 spectrum[i * nhalf + j] = col_buf[i];
320 }
321 }
322
323 spectrum
324}
325
326fn irfft2(mut spectrum: Vec<Complex<f64>>, nrows: usize, ncols: usize) -> Vec<f64> {
329 let nhalf = ncols / 2 + 1;
330
331 let col_ifft = FftPlanner::new().plan_fft_inverse(nrows);
333 let mut col_buf = vec![Complex::new(0.0, 0.0); nrows];
334 for j in 0..nhalf {
335 for i in 0..nrows {
336 col_buf[i] = spectrum[i * nhalf + j];
337 }
338 col_ifft.process(&mut col_buf);
339 for i in 0..nrows {
340 spectrum[i * nhalf + j] = col_buf[i];
341 }
342 }
343
344 let mut rplanner = RealFftPlanner::<f64>::new();
346 let c2r = rplanner.plan_fft_inverse(ncols);
347 let mut scratch = c2r.make_scratch_vec();
348 let mut inrow = c2r.make_input_vec();
349 let mut out = vec![0.0_f64; nrows * ncols];
350 let even = ncols.is_multiple_of(2);
351 for i in 0..nrows {
352 inrow.copy_from_slice(&spectrum[i * nhalf..(i + 1) * nhalf]);
353 inrow[0].im = 0.0;
356 if even {
357 inrow[nhalf - 1].im = 0.0;
358 }
359 c2r.process_with_scratch(
360 &mut inrow,
361 &mut out[i * ncols..(i + 1) * ncols],
362 &mut scratch,
363 )
364 .expect("c2r FFT");
365 }
366
367 let norm = (nrows * ncols) as f64;
368 for v in out.iter_mut() {
369 *v /= norm;
370 }
371 out
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Even-length layout matches numpy.fft.fftfreq(4, 1.0).
    #[test]
    fn test_fftfreq() {
        let f = fftfreq(4, 1.0);
        let expected = [0.0, 0.25, -0.5, -0.25];
        assert_eq!(f.len(), expected.len());
        for (a, b) in f.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-12, "got {a}, want {b}");
        }
    }

    /// Odd lengths put the extra bin on the positive side; n == 0 yields nothing.
    #[test]
    fn test_fftfreq_odd_and_empty() {
        let f = fftfreq(5, 0.5);
        let expected = [0.0, 0.4, 0.8, -0.8, -0.4];
        assert_eq!(f.len(), expected.len());
        for (a, b) in f.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-12, "got {a}, want {b}");
        }
        assert!(fftfreq(0, 1.0).is_empty());
    }

    #[test]
    fn test_rfft2_irfft2_roundtrip() {
        let data = vec![
            1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
            16.0,
        ];
        let (nrows, ncols) = (4, 4);
        let spectrum = rfft2(&data, nrows, ncols);
        let recovered = irfft2(spectrum, nrows, ncols);
        assert_eq!(recovered.len(), data.len());
        for (a, b) in data.iter().zip(recovered.iter()) {
            assert!((a - b).abs() < 1e-10, "roundtrip failed: {a} vs {b}");
        }
    }

    /// Odd widths have no Nyquist column; make sure that branch round-trips too.
    #[test]
    fn test_rfft2_irfft2_roundtrip_odd_shape() {
        let (nrows, ncols) = (3, 5);
        let data: Vec<f64> = (0..nrows * ncols).map(|k| (k as f64).sin()).collect();
        let recovered = irfft2(rfft2(&data, nrows, ncols), nrows, ncols);
        assert_eq!(recovered.len(), data.len());
        for (a, b) in data.iter().zip(recovered.iter()) {
            assert!((a - b).abs() < 1e-10, "roundtrip failed: {a} vs {b}");
        }
    }

    #[test]
    fn test_convolve_uv_no_change_when_beams_equal() {
        let beam = Beam::new(10.0 / 3600.0, 10.0 / 3600.0, 0.0).unwrap();
        let img = Array2::from_elem((16, 16), 1.0_f32);
        let result = convolve_uv(&img, &beam, &beam, 2.5 / 3600.0, 2.5 / 3600.0, None).unwrap();
        assert!((result.scaling_factor - 1.0).abs() < 1e-10);
        assert_eq!(result.image, img);
    }

    #[test]
    fn test_convolve_uv_above_cutoff_is_error() {
        let old = Beam::new(10.0 / 3600.0, 10.0 / 3600.0, 0.0).unwrap();
        let new = Beam::new(20.0 / 3600.0, 20.0 / 3600.0, 0.0).unwrap();
        let img = Array2::from_elem((8, 8), 1.0_f32);
        let res = convolve_uv(&img, &old, &new, 2.5 / 3600.0, 2.5 / 3600.0, Some(5.0));
        assert!(matches!(res, Err(ConvolveError::AboveCutoff)));
    }

    #[test]
    fn test_convolve_uv_all_nan_passthrough() {
        let old = Beam::new(10.0 / 3600.0, 10.0 / 3600.0, 0.0).unwrap();
        let new = Beam::new(20.0 / 3600.0, 20.0 / 3600.0, 0.0).unwrap();
        let img = Array2::from_elem((4, 4), f32::NAN);
        let result = convolve_uv(&img, &old, &new, 2.5 / 3600.0, 2.5 / 3600.0, None).unwrap();
        assert_eq!(result.image.dim(), (4, 4));
        assert!(result.image.iter().all(|v| v.is_nan()));
    }

    /// A constant image only has a DC component, so the output is the input
    /// scaled by the kernel's DC value, which equals the reported factor.
    #[test]
    fn test_convolve_uv_constant_image_scales_by_factor() {
        let old = Beam::new(10.0 / 3600.0, 10.0 / 3600.0, 0.0).unwrap();
        let new = Beam::new(20.0 / 3600.0, 20.0 / 3600.0, 0.0).unwrap();
        let img = Array2::from_elem((16, 16), 1.0_f32);
        let result = convolve_uv(&img, &old, &new, 2.5 / 3600.0, 2.5 / 3600.0, None).unwrap();
        assert_eq!(result.image.dim(), (16, 16));
        let expected = result.scaling_factor;
        let tol = 1e-4 * expected.abs().max(1.0);
        for &v in result.image.iter() {
            assert!(
                (f64::from(v) - expected).abs() < tol,
                "pixel {v} differs from scaling factor {expected}"
            );
        }
    }
}