Skip to main content

numra_fft/
convolution.rs

//! FFT-based convolution and correlation.
//!
//! Author: Moussa Leblouba
//! Date: 9 February 2026
//! Modified: 2 May 2026
6
7use crate::complex::Complex;
8use crate::fft_core::{fft, ifft};
9use numra_core::Scalar;
10
11/// FFT-based linear convolution of two real signals.
12///
13/// Returns a vector of length `a.len() + b.len() - 1`.
14/// Equivalent to `numpy.convolve(a, b, mode='full')`.
15pub fn fftconvolve<S: Scalar>(a: &[S], b: &[S]) -> Vec<S> {
16    if a.is_empty() || b.is_empty() {
17        return vec![];
18    }
19
20    let n = a.len() + b.len() - 1;
21    // Pad to next power of 2 for FFT efficiency
22    let fft_len = n.next_power_of_two();
23
24    let mut ca = vec![Complex::zero(); fft_len];
25    let mut cb = vec![Complex::zero(); fft_len];
26
27    for (i, &v) in a.iter().enumerate() {
28        ca[i] = Complex::new(v, S::ZERO);
29    }
30    for (i, &v) in b.iter().enumerate() {
31        cb[i] = Complex::new(v, S::ZERO);
32    }
33
34    let fa = fft(&ca);
35    let fb = fft(&cb);
36
37    // Pointwise multiply in frequency domain
38    let fc: Vec<Complex<S>> = fa.iter().zip(fb.iter()).map(|(&a, &b)| a * b).collect();
39
40    let result = ifft(&fc);
41    result[..n].iter().map(|c| c.re).collect()
42}
43
44/// FFT-based cross-correlation of two real signals.
45///
46/// Returns a vector of length `a.len() + b.len() - 1`.
47/// Equivalent to `numpy.correlate(a, b, mode='full')`.
48///
49/// Implemented as `convolve(a, reverse(b))`.
50pub fn fftcorrelate<S: Scalar>(a: &[S], b: &[S]) -> Vec<S> {
51    if a.is_empty() || b.is_empty() {
52        return vec![];
53    }
54
55    // Correlation = convolution with time-reversed second signal
56    let b_rev: Vec<S> = b.iter().rev().copied().collect();
57    fftconvolve(a, &b_rev)
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing FFT output with exact values.
    const TOL: f64 = 1e-12;

    /// Direct O(n*m) full linear convolution, used as an independent reference.
    fn direct_convolve(a: &[f64], b: &[f64]) -> Vec<f64> {
        if a.is_empty() || b.is_empty() {
            return Vec::new();
        }
        let mut out = vec![0.0; a.len() + b.len() - 1];
        for (i, &x) in a.iter().enumerate() {
            for (j, &y) in b.iter().enumerate() {
                out[i + j] += x * y;
            }
        }
        out
    }

    /// Direct O(n*m) full cross-correlation, used as an independent reference.
    ///
    /// `out[k] = sum_i a[i] * b[i + (m - 1) - k]` with `m = b.len()`, i.e. the
    /// product `a[i] * b[j]` contributes to output index `i + (m - 1) - j`.
    fn direct_correlate(a: &[f64], b: &[f64]) -> Vec<f64> {
        if a.is_empty() || b.is_empty() {
            return Vec::new();
        }
        let m = b.len();
        let mut out = vec![0.0; a.len() + m - 1];
        for (i, &x) in a.iter().enumerate() {
            for (j, &y) in b.iter().enumerate() {
                out[i + (m - 1) - j] += x * y;
            }
        }
        out
    }

    /// Asserts that two slices have equal length and agree element-wise within `tol`.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (x, y)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!((x - y).abs() < tol, "index {}: {} vs {}", i, x, y);
        }
    }

    #[test]
    fn test_fftconvolve_impulse() {
        // Convolution with impulse [1] is identity
        let a = [1.0_f64, 2.0, 3.0, 4.0];
        let b = [1.0_f64];
        let result = fftconvolve(&a, &b);
        assert_close(&result, &a, TOL);
    }

    #[test]
    fn test_fftconvolve_known() {
        // [1, 1] * [1, 1, 1] = [1, 2, 2, 1]
        let a = [1.0_f64, 1.0];
        let b = [1.0_f64, 1.0, 1.0];
        let result = fftconvolve(&a, &b);
        assert_close(&result, &[1.0, 2.0, 2.0, 1.0], TOL);
        // The FFT path must also agree with the direct reference.
        assert_close(&result, &direct_convolve(&a, &b), TOL);
    }

    #[test]
    fn test_fftconvolve_commutative() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 5.0];
        let ab = fftconvolve(&a, &b);
        let ba = fftconvolve(&b, &a);
        assert_close(&ab, &ba, TOL);
        // Pin the actual values so a symmetric bug cannot hide behind commutativity.
        assert_close(&ab, &direct_convolve(&a, &b), TOL);
    }

    #[test]
    fn test_fftcorrelate_autocorrelation() {
        // Autocorrelation: peak should be the maximum value (at zero lag)
        let a = [1.0_f64, 2.0, 3.0, 2.0, 1.0];
        let result = fftcorrelate(&a, &a);
        assert_eq!(result.len(), 9); // 2*5-1

        // Peak value = sum(a[i]^2) = 1+4+9+4+1 = 19
        let peak = result.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        assert!((peak - 19.0).abs() < 1e-10);

        // Zero lag sits at the centre of the full-mode output (index len - 1).
        let zero_lag = a.len() - 1;
        assert!((result[zero_lag] - 19.0).abs() < 1e-10);

        // Autocorrelation of a real signal is symmetric about zero lag.
        for k in 0..result.len() {
            let mirrored = result[result.len() - 1 - k];
            assert!((result[k] - mirrored).abs() < 1e-10);
        }
    }

    #[test]
    fn test_fftcorrelate_vs_direct() {
        // Verify FFT correlation matches an independent direct computation
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 5.0];
        let result = fftcorrelate(&a, &b);
        assert_eq!(result.len(), 4); // 3+2-1 = 4

        // Hand-computed: [1*5, 1*4 + 2*5, 2*4 + 3*5, 3*4]
        assert_close(&result, &[5.0, 14.0, 23.0, 12.0], TOL);
        assert_close(&result, &direct_correlate(&a, &b), TOL);

        // Since correlate(a,b) = convolve(a, rev(b)), both paths must agree too.
        let b_rev = [5.0_f64, 4.0];
        assert_close(&result, &fftconvolve(&a, &b_rev), TOL);
    }

    #[test]
    fn test_fftconvolve_empty() {
        // An empty operand on either side (or both) yields an empty result.
        assert!(fftconvolve::<f64>(&[], &[1.0]).is_empty());
        assert!(fftconvolve::<f64>(&[1.0], &[]).is_empty());
        assert!(fftconvolve::<f64>(&[], &[]).is_empty());
        assert!(fftcorrelate::<f64>(&[], &[1.0]).is_empty());
        assert!(fftcorrelate::<f64>(&[1.0], &[]).is_empty());
        assert!(fftcorrelate::<f64>(&[], &[]).is_empty());
    }
}