Skip to main content

convolve_rs/
convolve_uv.rs

1//! FFT-based UV-plane beam convolution.
2//!
3//! This is a port of `racs_tools.convolve_uv.convolve` and `racs_tools.gaussft.gaussft`.
//! The "robust" mode is implemented: the FT of the convolving Gaussian is evaluated
//! analytically at each UV point, so no kernel image is needed. NaNs are handled
//! separately: they are zero-filled before the FFT and re-blanked afterwards from a
//! filtered NaN mask.
6use ndarray::Array2;
7use realfft::RealFftPlanner;
8use rustfft::{FftPlanner, num_complex::Complex};
9use thiserror::Error;
10
11use crate::beam::{Beam, gauss_factor};
12
13#[derive(Debug, Error)]
14pub enum ConvolveError {
15    #[error("image is entirely NaN")]
16    AllNaN,
17    #[error("beam larger than cutoff — image blanked")]
18    AboveCutoff,
19}
20
21pub struct ConvolutionResult {
22    /// Convolved image (NaNs propagated from input).
23    pub image: Array2<f32>,
24    /// Flux scaling factor for Jy/beam.
25    pub scaling_factor: f64,
26}
27
28/// Convolve `image` from `old_beam` to `new_beam` in the UV plane.
29///
30/// `dx_deg` / `dy_deg` are the pixel sizes in degrees (FITS |CDELT1|, |CDELT2|).
31/// `cutoff_arcsec` blanks images whose current beam exceeds this size.
32///
33/// The returned [`ConvolutionResult::scaling_factor`] is `√(Ω_new/Ω_old)`; see
34/// [`crate::smooth::smooth`] for how this becomes the Jy/beam or Kelvin factor.
35///
36/// # Examples
37///
38/// ```
39/// use convolve_rs::{Beam, convolve_uv};
40/// use ndarray::Array2;
41///
42/// let old = Beam::from_arcsec(10.0, 10.0, 0.0)?;
43/// let new = Beam::from_arcsec(20.0, 20.0, 0.0)?;
44/// let image = Array2::<f32>::from_elem((64, 64), 1.0);
45/// let dx = 2.5 / 3600.0;
46///
47/// let result = convolve_uv(&image, &old, &new, dx, dx, None)?;
48/// // √(Ω_new/Ω_old) = √4 = 2 for a doubling of both axes.
49/// assert!((result.scaling_factor - 2.0).abs() < 1e-9);
50/// assert_eq!(result.image.dim(), (64, 64));
51/// # Ok::<(), Box<dyn std::error::Error>>(())
52/// ```
53pub fn convolve_uv(
54    image: &Array2<f32>,
55    old_beam: &Beam,
56    new_beam: &Beam,
57    dx_deg: f64,
58    dy_deg: f64,
59    cutoff_arcsec: Option<f64>,
60) -> Result<ConvolutionResult, ConvolveError> {
61    // Cutoff check.
62    if let Some(cutoff) = cutoff_arcsec
63        && old_beam.major_arcsec() > cutoff
64    {
65        return Err(ConvolveError::AboveCutoff);
66    }
67
68    // Beams identical → no-op with unit scaling.
69    if old_beam.approx_eq(new_beam) {
70        return Ok(ConvolutionResult {
71            image: image.clone(),
72            scaling_factor: 1.0,
73        });
74    }
75
76    // Compute the convolving beam (new² - old² in quadrature) and flux scaling.
77    let conv_beam = new_beam.deconvolve_or_zero(old_beam);
78    let (fac, ..) = gauss_factor(
79        &conv_beam,
80        old_beam,
81        dx_deg.abs() * 3600.0,
82        dy_deg.abs() * 3600.0,
83    );
84
85    // All-NaN fast path.
86    if image.iter().all(|x| x.is_nan()) {
87        return Ok(ConvolutionResult {
88            image: image.clone(),
89            scaling_factor: fac,
90        });
91    }
92
93    let (nrows, ncols) = image.dim();
94
95    // Handle NaNs: zero-fill and track a mask.
96    let has_nan = image.iter().any(|x| x.is_nan());
97    let (clean_image, nan_mask): (Vec<f64>, Option<Vec<f64>>) = if has_nan {
98        let vals: Vec<f64> = image
99            .iter()
100            .map(|&x| if x.is_nan() { 0.0 } else { x as f64 })
101            .collect();
102        let mask: Vec<f64> = image
103            .iter()
104            .map(|&x| if x.is_nan() { 1.0 } else { 0.0 })
105            .collect();
106        (vals, Some(mask))
107    } else {
108        let vals: Vec<f64> = image.iter().map(|&x| x as f64).collect();
109        (vals, None)
110    };
111
112    // UV coordinates: fftfreq(n, d_rad) where d_rad = pixel_size_in_radians.
113    // The data is real, so we use a real-input FFT: the column (ncols) axis only
114    // needs its non-negative half, `nhalf = ncols/2 + 1` bins. We slice the full
115    // `fftfreq` rather than using `rfftfreq` so the filter is evaluated at exactly
116    // the frequencies the equivalent full FFT assigns to bins 0..nhalf (incl. the
117    // signed Nyquist), keeping results bit-for-bit aligned with the full-FFT port.
118    let nhalf = ncols / 2 + 1;
119    let dx_rad = dx_deg.to_radians();
120    let dy_rad = dy_deg.to_radians();
121    let u_freqs = fftfreq(nrows, dx_rad); // shape (nrows,)
122    let v_freqs_full = fftfreq(ncols, dy_rad);
123    let v_freqs = &v_freqs_full[..nhalf]; // half spectrum
124
125    // UV-plane filter on the half spectrum (shape nrows × nhalf), real-valued.
126    let (g_final, g_ratio) = gaussft(old_beam, new_beam, &u_freqs, v_freqs);
127
128    // Forward real FFT, apply the filter in place, inverse real FFT.
129    let mut im_f = rfft2(&clean_image, nrows, ncols);
130    for (s, &g) in im_f.iter_mut().zip(g_final.iter()) {
131        *s *= g;
132    }
133    let im_conv_flat = irfft2(im_f, nrows, ncols);
134
135    // NaN propagation.
136    let out_flat: Vec<f32> = if let Some(mask) = nan_mask {
137        let mut mask_f = rfft2(&mask, nrows, ncols);
138        for (s, &g) in mask_f.iter_mut().zip(g_final.iter()) {
139            *s *= g;
140        }
141        let mask_conv = irfft2(mask_f, nrows, ncols);
142        im_conv_flat
143            .iter()
144            .zip(mask_conv.iter())
145            .map(|(&v, &m)| if m >= 1.0 { f32::NAN } else { v as f32 })
146            .collect()
147    } else {
148        im_conv_flat.iter().map(|&v| v as f32).collect()
149    };
150
151    let out = Array2::from_shape_vec((nrows, ncols), out_flat)
152        .expect("shape mismatch in convolve_uv output");
153
154    Ok(ConvolutionResult {
155        image: out,
156        scaling_factor: g_ratio,
157    })
158}
159
160// ── gaussft ───────────────────────────────────────────────────────────────────
161
162/// Compute the UV-plane filter that deconvolves `old_beam` and re-convolves with
163/// `new_beam`. Direct port of `racs_tools.gaussft.gaussft`.
164///
165/// `u_freqs` has length `nrows`, `v_freqs` has length `ncols` (or `nhalf` for a
166/// half-spectrum / real-FFT layout). The filter is real-valued, so it is returned
167/// as `Vec<f64>` of length `nrows * v_freqs.len()` in row-major order.
168pub fn gaussft(
169    old_beam: &Beam,
170    new_beam: &Beam,
171    u_freqs: &[f64],
172    v_freqs: &[f64],
173) -> (Vec<f64>, f64) {
174    let deg2rad = std::f64::consts::PI / 180.0;
175    let two_ln2 = 2.0 * 2_f64.ln();
176    let fwhm_to_sigma = 2.0 * two_ln2.sqrt(); // = 2*sqrt(2*ln2)
177
178    // New beam (target).
179    let bmaj_rad = new_beam.major_deg * deg2rad;
180    let bmin_rad = new_beam.minor_deg * deg2rad;
181    let bpa_rad = new_beam.pa_deg * deg2rad;
182    let sx = bmaj_rad / fwhm_to_sigma;
183    let sy = bmin_rad / fwhm_to_sigma;
184
185    // Old beam (input PSF).
186    let bmaj_in_rad = old_beam.major_deg * deg2rad;
187    let bmin_in_rad = old_beam.minor_deg * deg2rad;
188    let bpa_in_rad = old_beam.pa_deg * deg2rad;
189    let sx_in = bmaj_in_rad / fwhm_to_sigma;
190    let sy_in = bmin_in_rad / fwhm_to_sigma;
191
192    // Amplitude ratio (= flux scaling factor).
193    let g_amp = (2.0 * std::f64::consts::PI * sx * sy).sqrt();
194    let dg_amp = (2.0 * std::f64::consts::PI * sx_in * sy_in).sqrt();
195    let g_ratio = g_amp / dg_amp;
196
197    let pi2 = std::f64::consts::PI * std::f64::consts::PI;
198    let nrows = u_freqs.len();
199    let ncols = v_freqs.len();
200    let mut g_final = vec![0.0_f64; nrows * ncols];
201
202    // Pre-rotate u and v for new beam.
203    let u_cos = u_freqs
204        .iter()
205        .map(|&u| u * bpa_rad.cos())
206        .collect::<Vec<_>>();
207    let u_sin = u_freqs
208        .iter()
209        .map(|&u| u * bpa_rad.sin())
210        .collect::<Vec<_>>();
211    let v_cos = v_freqs
212        .iter()
213        .map(|&v| v * bpa_rad.cos())
214        .collect::<Vec<_>>();
215    let v_sin = v_freqs
216        .iter()
217        .map(|&v| v * bpa_rad.sin())
218        .collect::<Vec<_>>();
219
220    // Pre-rotate u and v for old beam.
221    let u_cos_in = u_freqs
222        .iter()
223        .map(|&u| u * bpa_in_rad.cos())
224        .collect::<Vec<_>>();
225    let u_sin_in = u_freqs
226        .iter()
227        .map(|&u| u * bpa_in_rad.sin())
228        .collect::<Vec<_>>();
229    let v_cos_in = v_freqs
230        .iter()
231        .map(|&v| v * bpa_in_rad.cos())
232        .collect::<Vec<_>>();
233    let v_sin_in = v_freqs
234        .iter()
235        .map(|&v| v * bpa_in_rad.sin())
236        .collect::<Vec<_>>();
237
238    for i in 0..nrows {
239        for j in 0..ncols {
240            // Rotated UV coordinates for new beam.
241            let ur = u_cos[i] - v_sin[j];
242            let vr = u_sin[i] + v_cos[j];
243            // Rotated UV coordinates for old beam.
244            let ur_in = u_cos_in[i] - v_sin_in[j];
245            let vr_in = u_sin_in[i] + v_cos_in[j];
246
247            let g_arg = -2.0 * pi2 * ((sx * ur).powi(2) + (sy * vr).powi(2));
248            let dg_arg = -2.0 * pi2 * ((sx_in * ur_in).powi(2) + (sy_in * vr_in).powi(2));
249
250            g_final[i * ncols + j] = g_ratio * (g_arg - dg_arg).exp();
251        }
252    }
253
254    (g_final, g_ratio)
255}
256
257// ── FFT helpers ───────────────────────────────────────────────────────────────
258
/// Sample frequencies of a length-`n` DFT with sample spacing `d`, in the same
/// order and convention as numpy's `fftfreq(n, d)`.
///
/// The first `ceil(n/2)` bins are non-negative; the remainder wrap around to
/// negative frequencies, so for even `n` the Nyquist bin (index `n/2`) is
/// negative. `n == 0` yields an empty vector.
///
/// # Examples
///
/// ```
/// use convolve_rs::fftfreq;
///
/// assert_eq!(fftfreq(4, 1.0), vec![0.0, 0.25, -0.5, -0.25]);
/// assert_eq!(fftfreq(5, 1.0), vec![0.0, 0.2, 0.4, -0.4, -0.2]);
/// ```
pub fn fftfreq(n: usize, d: f64) -> Vec<f64> {
    let spacing = 1.0 / (n as f64 * d);
    let first_negative = n.div_ceil(2);
    (0..n)
        .map(|k| {
            let signed_index = if k < first_negative {
                k as f64
            } else {
                k as f64 - n as f64
            };
            signed_index * spacing
        })
        .collect()
}
283
284/// 2D forward FFT of real-valued data stored row-major in `data` (shape nrows×ncols).
285///
286/// Uses a real-input FFT along the contiguous (ncols) axis, so only the
287/// non-negative half of that axis is kept: the returned spectrum is
288/// `nrows × nhalf` (`nhalf = ncols/2 + 1`) complex values, row-major. This roughly
289/// halves the spectrum memory versus a full complex FFT — the dominant cost at
290/// large image sizes.
291fn rfft2(data: &[f64], nrows: usize, ncols: usize) -> Vec<Complex<f64>> {
292    let nhalf = ncols / 2 + 1;
293
294    // Row-wise real→complex FFT.
295    let mut rplanner = RealFftPlanner::<f64>::new();
296    let r2c = rplanner.plan_fft_forward(ncols);
297    let mut scratch = r2c.make_scratch_vec();
298    let mut inrow = r2c.make_input_vec();
299    let mut spectrum = vec![Complex::new(0.0, 0.0); nrows * nhalf];
300    for (i, chunk) in data.chunks(ncols).enumerate() {
301        inrow.copy_from_slice(chunk);
302        r2c.process_with_scratch(
303            &mut inrow,
304            &mut spectrum[i * nhalf..(i + 1) * nhalf],
305            &mut scratch,
306        )
307        .expect("r2c FFT");
308    }
309
310    // Column-wise complex FFT over the `nhalf` columns (gather, process, scatter).
311    let col_fft = FftPlanner::new().plan_fft_forward(nrows);
312    let mut col_buf = vec![Complex::new(0.0, 0.0); nrows];
313    for j in 0..nhalf {
314        for i in 0..nrows {
315            col_buf[i] = spectrum[i * nhalf + j];
316        }
317        col_fft.process(&mut col_buf);
318        for i in 0..nrows {
319            spectrum[i * nhalf + j] = col_buf[i];
320        }
321    }
322
323    spectrum
324}
325
326/// 2D inverse of [`rfft2`] (un-normalised → divide by N = nrows*ncols).
327/// Consumes the half `nrows × nhalf` spectrum and returns the real nrows×ncols image.
328fn irfft2(mut spectrum: Vec<Complex<f64>>, nrows: usize, ncols: usize) -> Vec<f64> {
329    let nhalf = ncols / 2 + 1;
330
331    // Column-wise inverse complex FFT over the `nhalf` columns.
332    let col_ifft = FftPlanner::new().plan_fft_inverse(nrows);
333    let mut col_buf = vec![Complex::new(0.0, 0.0); nrows];
334    for j in 0..nhalf {
335        for i in 0..nrows {
336            col_buf[i] = spectrum[i * nhalf + j];
337        }
338        col_ifft.process(&mut col_buf);
339        for i in 0..nrows {
340            spectrum[i * nhalf + j] = col_buf[i];
341        }
342    }
343
344    // Row-wise complex→real FFT.
345    let mut rplanner = RealFftPlanner::<f64>::new();
346    let c2r = rplanner.plan_fft_inverse(ncols);
347    let mut scratch = c2r.make_scratch_vec();
348    let mut inrow = c2r.make_input_vec();
349    let mut out = vec![0.0_f64; nrows * ncols];
350    let even = ncols.is_multiple_of(2);
351    for i in 0..nrows {
352        inrow.copy_from_slice(&spectrum[i * nhalf..(i + 1) * nhalf]);
353        // c2r requires the DC (and, for even ncols, Nyquist) bins to be purely
354        // real; they are up to rounding, so zero the imaginary parts explicitly.
355        inrow[0].im = 0.0;
356        if even {
357            inrow[nhalf - 1].im = 0.0;
358        }
359        c2r.process_with_scratch(
360            &mut inrow,
361            &mut out[i * ncols..(i + 1) * ncols],
362            &mut scratch,
363        )
364        .expect("c2r FFT");
365    }
366
367    let norm = (nrows * ncols) as f64;
368    for v in out.iter_mut() {
369        *v /= norm;
370    }
371    out
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Pixel size used throughout: 2.5 arcsec, in degrees.
    const PIX_DEG: f64 = 2.5 / 3600.0;

    /// Circular beam with the given FWHM in arcseconds.
    fn circular(fwhm_arcsec: f64) -> Beam {
        Beam::new(fwhm_arcsec / 3600.0, fwhm_arcsec / 3600.0, 0.0).unwrap()
    }

    #[test]
    fn test_fftfreq() {
        // Match numpy: fftfreq(4, 1) = [0, 0.25, -0.5, -0.25]
        let f = fftfreq(4, 1.0);
        let expected = [0.0, 0.25, -0.5, -0.25];
        assert_eq!(f.len(), expected.len());
        for (a, b) in f.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-12, "got {a}, want {b}");
        }
    }

    #[test]
    fn test_fftfreq_odd_and_empty() {
        // Match numpy: fftfreq(5, 1) = [0, 0.2, 0.4, -0.4, -0.2]
        let f = fftfreq(5, 1.0);
        let expected = [0.0, 0.2, 0.4, -0.4, -0.2];
        assert_eq!(f.len(), expected.len());
        for (a, b) in f.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-12, "got {a}, want {b}");
        }
        assert!(fftfreq(0, 1.0).is_empty());
    }

    #[test]
    fn test_rfft2_irfft2_roundtrip() {
        // Use even dimensions to exercise the Nyquist handling in irfft2.
        let data = vec![
            1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
            16.0,
        ];
        let (nrows, ncols) = (4, 4);
        let spectrum = rfft2(&data, nrows, ncols);
        let recovered = irfft2(spectrum, nrows, ncols);
        assert_eq!(recovered.len(), data.len());
        for (a, b) in data.iter().zip(recovered.iter()) {
            assert!((a - b).abs() < 1e-10, "roundtrip failed: {a} vs {b}");
        }
    }

    #[test]
    fn test_rfft2_irfft2_roundtrip_odd() {
        // Odd ncols has no Nyquist bin; odd, non-square nrows exercises the
        // column FFT with a different length.
        let (nrows, ncols) = (3, 5);
        let data: Vec<f64> = (0..nrows * ncols).map(|k| (k as f64).sin()).collect();
        let recovered = irfft2(rfft2(&data, nrows, ncols), nrows, ncols);
        assert_eq!(recovered.len(), data.len());
        for (a, b) in data.iter().zip(recovered.iter()) {
            assert!((a - b).abs() < 1e-10, "roundtrip failed: {a} vs {b}");
        }
    }

    #[test]
    fn test_convolve_uv_no_change_when_beams_equal() {
        let beam = circular(10.0);
        let img = Array2::from_elem((16, 16), 1.0_f32);
        let result = convolve_uv(&img, &beam, &beam, PIX_DEG, PIX_DEG, None).unwrap();
        assert!((result.scaling_factor - 1.0).abs() < 1e-10);
        assert_eq!(result.image, img);
    }

    #[test]
    fn test_convolve_uv_constant_image_scaled_by_dc_gain() {
        // A constant image only has a DC component, and the filter's DC value
        // is g_ratio = √(Ω_new/Ω_old) = 2 here.
        let img = Array2::from_elem((16, 16), 1.0_f32);
        let result =
            convolve_uv(&img, &circular(10.0), &circular(20.0), PIX_DEG, PIX_DEG, None).unwrap();
        assert!((result.scaling_factor - 2.0).abs() < 1e-9);
        assert!(result.image.iter().all(|&v| (v - 2.0).abs() < 1e-4));
    }

    #[test]
    fn test_convolve_uv_all_nan_passthrough() {
        let img = Array2::from_elem((8, 8), f32::NAN);
        let result =
            convolve_uv(&img, &circular(10.0), &circular(20.0), PIX_DEG, PIX_DEG, None).unwrap();
        assert_eq!(result.image.dim(), (8, 8));
        assert!(result.image.iter().all(|v| v.is_nan()));
    }

    #[test]
    fn test_convolve_uv_above_cutoff() {
        let img = Array2::from_elem((8, 8), 1.0_f32);
        let result = convolve_uv(
            &img,
            &circular(10.0),
            &circular(20.0),
            PIX_DEG,
            PIX_DEG,
            Some(5.0),
        );
        assert!(matches!(result, Err(ConvolveError::AboveCutoff)));
    }

    #[test]
    fn test_convolve_uv_nan_region_threshold() {
        // Left half NaN, right half finite. Deep inside the NaN half the filtered
        // mask is close to g_ratio = 2 (>= 1), so pixels stay NaN. Deep inside
        // the finite half it is close to 0, so pixels stay finite.
        let n = 32;
        let img = Array2::from_shape_fn((n, n), |(_, c)| if c < n / 2 { f32::NAN } else { 1.0 });
        let result =
            convolve_uv(&img, &circular(10.0), &circular(20.0), PIX_DEG, PIX_DEG, None).unwrap();
        for r in 0..n {
            assert!(result.image[[r, 8]].is_nan(), "row {r}: expected NaN");
            assert!(result.image[[r, 24]].is_finite(), "row {r}: expected finite");
        }
    }
}