Skip to main content

ferray_ufunc/ops/
convolution.rs

// ferray-ufunc: Convolution
//
// 1-D `convolve` with output modes Full, Same, and Valid.
4
5use ferray_core::Array;
6use ferray_core::dimension::Ix1;
7use ferray_core::dtype::Element;
8use ferray_core::error::{FerrayError, FerrayResult};
9
/// Selects which portion of the full linear convolution is returned.
///
/// In the variant descriptions `N` and `M` are the lengths of the two
/// input arrays.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvolveMode {
    /// Every point of overlap, partial overlaps included:
    /// `N + M - 1` samples.
    Full,
    /// A window of `max(N, M)` samples taken from the middle of the
    /// full result.
    Same,
    /// Only the samples where the two inputs overlap completely:
    /// `max(N, M) - min(N, M) + 1` samples.
    Valid,
}
20
21/// Discrete, linear convolution of two 1-D arrays.
22///
23/// Computes `convolve(a, v, mode)` following `NumPy` semantics — direct
24/// O(n·m) algorithm, matching `numpy.convolve`. For large floating-point
25/// inputs prefer [`fftconvolve`] (behind the `fft-convolve` feature),
26/// which is O((n+m) · log(n+m)) via FFT.
27///
28/// The inner loop is restructured to iterate over output positions
29/// rather than input pairs (#89). The previous form
30/// `full[i+j] += a[i] * v[j]` had a strided write pattern that
31/// confused auto-vectorisation; iterating over `k = i+j` and
32/// accumulating a dot product `Σ a[i] * v[k-i]` gives the auto-
33/// vectoriser a clean inner loop with linear writes.
34pub fn convolve<T>(
35    a: &Array<T, Ix1>,
36    v: &Array<T, Ix1>,
37    mode: ConvolveMode,
38) -> FerrayResult<Array<T, Ix1>>
39where
40    T: Element + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + Copy,
41{
42    let a_data: Vec<T> = a.iter().copied().collect();
43    let v_data: Vec<T> = v.iter().copied().collect();
44    let n = a_data.len();
45    let m = v_data.len();
46
47    if n == 0 || m == 0 {
48        return Err(FerrayError::invalid_value(
49            "convolve: input arrays must be non-empty",
50        ));
51    }
52
53    // Full convolution — output position k = i + j, inner loop is a
54    // dot product with linear indexing into both inputs.
55    let full_len = n + m - 1;
56    let mut full = vec![<T as Element>::zero(); full_len];
57
58    for k in 0..full_len {
59        // Range of valid i: max(0, k+1-m) <= i < min(n, k+1).
60        let i_lo = (k + 1).saturating_sub(m);
61        let i_hi = (k + 1).min(n);
62        let mut acc = <T as Element>::zero();
63        for i in i_lo..i_hi {
64            // SAFETY-style bounds rationale: the i_lo/i_hi math
65            // guarantees i < n and (k - i) < m. The compiler usually
66            // proves this and eliminates bounds checks; if it can't,
67            // the panics still produce correct (slightly slower) code.
68            acc = acc + a_data[i] * v_data[k - i];
69        }
70        full[k] = acc;
71    }
72
73    match mode {
74        ConvolveMode::Full => Array::from_vec(Ix1::new([full_len]), full),
75        ConvolveMode::Same => {
76            let out_len = n.max(m);
77            let start = (full_len - out_len) / 2;
78            let result = full[start..start + out_len].to_vec();
79            Array::from_vec(Ix1::new([out_len]), result)
80        }
81        ConvolveMode::Valid => {
82            let out_len = if n >= m { n - m + 1 } else { m - n + 1 };
83            let start = m.min(n) - 1;
84            let result = full[start..start + out_len].to_vec();
85            Array::from_vec(Ix1::new([out_len]), result)
86        }
87    }
88}
89
/// FFT-based linear convolution of two 1-D `f64` arrays —
/// O((n+m) · log(n+m)) via real-to-complex FFT.
///
/// Both inputs are zero-padded to the full convolution length
/// `n + m - 1`, transformed with `rfft`, multiplied element-wise, and
/// transformed back with `irfft`. The result is then trimmed with the
/// same window rules as [`convolve`]:
///
/// * [`ConvolveMode::Full`]  — all `n + m - 1` samples,
/// * [`ConvolveMode::Same`]  — the middle `max(n, m)` samples,
/// * [`ConvolveMode::Valid`] — the `max(n, m) - min(n, m) + 1` samples
///   of complete overlap.
///
/// Note on SciPy compatibility: `scipy.signal.fftconvolve(a, v, "same")`
/// returns an output the size of its *first* argument, whereas this
/// function returns `max(n, m)` samples (the `numpy.convolve`
/// convention). The two agree whenever `a` is at least as long as `v`.
///
/// The output is numerically equivalent to [`convolve`] up to
/// floating-point round-off. FFT round-off is relative to the largest
/// magnitudes present and grows slowly with the transform length, so
/// results that are tiny compared with the largest input products may
/// show a larger relative error than the direct method.
///
/// Available only with the `fft-convolve` feature enabled. As a rule of
/// thumb the direct path in [`convolve`] is faster for small inputs
/// (n·m below roughly 50_000); the exact cross-over depends on hardware.
///
/// # Errors
///
/// Returns an invalid-value error when either input is empty. Errors
/// from `rfft`, `irfft`, or array construction are propagated unchanged.
#[cfg(feature = "fft-convolve")]
pub fn fftconvolve(
    a: &Array<f64, Ix1>,
    v: &Array<f64, Ix1>,
    mode: ConvolveMode,
) -> FerrayResult<Array<f64, Ix1>> {
    use ferray_fft::{FftNorm, irfft, rfft};

    let n = a.size();
    let m = v.size();
    if n == 0 || m == 0 {
        return Err(FerrayError::invalid_value(
            "fftconvolve: input arrays must be non-empty",
        ));
    }

    // Zero-pad both inputs to the full convolution length so that the
    // circular convolution computed by the FFT equals the linear one.
    let full_len = n + m - 1;
    let zero_pad = |src: &Array<f64, Ix1>| -> Vec<f64> {
        src.iter()
            .copied()
            .chain(std::iter::repeat(0.0f64))
            .take(full_len)
            .collect()
    };
    let a_padded = Array::<f64, Ix1>::from_vec(Ix1::new([full_len]), zero_pad(a))?;
    let v_padded = Array::<f64, Ix1>::from_vec(Ix1::new([full_len]), zero_pad(v))?;

    // Forward real FFT on both, multiply element-wise, inverse real FFT.
    let a_fft = rfft(&a_padded, None, None, FftNorm::Backward)?;
    let v_fft = rfft(&v_padded, None, None, FftNorm::Backward)?;

    // Multiply the spectra straight from the FFT outputs; no intermediate
    // copies of either spectrum are needed.
    let prod: Vec<num_complex::Complex<f64>> = a_fft
        .iter()
        .zip(v_fft.iter())
        .map(|(x, y)| x * y)
        .collect();
    let prod_arr = Array::<num_complex::Complex<f64>, Ix1>::from_vec(Ix1::new([prod.len()]), prod)?;

    // The explicit output length is required: a half-spectrum alone does
    // not say whether the original signal length was even or odd.
    let inv = irfft(&prod_arr, Some(full_len), None, FftNorm::Backward)?;
    let inv_data: Vec<f64> = inv.iter().copied().collect();

    match mode {
        ConvolveMode::Full => Array::from_vec(Ix1::new([full_len]), inv_data),
        ConvolveMode::Same => {
            // Middle max(n, m) samples of the full result.
            let out_len = n.max(m);
            let start = (full_len - out_len) / 2;
            let window = inv_data[start..start + out_len].to_vec();
            Array::from_vec(Ix1::new([out_len]), window)
        }
        ConvolveMode::Valid => {
            // Samples where the shorter input lies fully inside the longer.
            let out_len = n.max(m) - n.min(m) + 1;
            let start = n.min(m) - 1;
            let window = inv_data[start..start + out_len].to_vec();
            Array::from_vec(Ix1::new([out_len]), window)
        }
    }
}
164
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_util::arr1;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOL: f64 = 1e-12;

    /// Asserts that `actual` and `expected` have the same length and that
    /// every pair of elements differs by less than `TOL`.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (got, want) in actual.iter().zip(expected.iter()) {
            assert!((got - want).abs() < TOL);
        }
    }

    #[test]
    fn test_convolve_full() {
        let signal = arr1(vec![1.0, 2.0, 3.0]);
        let kernel = arr1(vec![0.0, 1.0, 0.5]);
        let out = convolve(&signal, &kernel, ConvolveMode::Full).unwrap();
        // [0*1, 1*1+0*2, 0.5*1+1*2+0*3, 0.5*2+1*3, 0.5*3]
        // = [0, 1, 2.5, 4, 1.5]
        assert_close(out.as_slice().unwrap(), &[0.0, 1.0, 2.5, 4.0, 1.5]);
    }

    #[test]
    fn test_convolve_same() {
        let signal = arr1(vec![1.0, 2.0, 3.0]);
        let kernel = arr1(vec![0.0, 1.0, 0.5]);
        let out = convolve(&signal, &kernel, ConvolveMode::Same).unwrap();
        assert_eq!(out.size(), 3);
        // Full = [0, 1, 2.5, 4, 1.5]; Same keeps the middle three.
        assert_close(out.as_slice().unwrap(), &[1.0, 2.5, 4.0]);
    }

    #[test]
    fn test_convolve_valid() {
        let signal = arr1(vec![1.0, 2.0, 3.0]);
        let kernel = arr1(vec![0.0, 1.0, 0.5]);
        let out = convolve(&signal, &kernel, ConvolveMode::Valid).unwrap();
        assert_eq!(out.size(), 1);
        // Equal lengths leave exactly one fully-overlapping sample.
        assert_close(out.as_slice().unwrap(), &[2.5]);
    }

    #[test]
    fn test_convolve_simple() {
        let ones = arr1(vec![1.0, 1.0, 1.0]);
        let also_ones = arr1(vec![1.0, 1.0, 1.0]);
        let out = convolve(&ones, &also_ones, ConvolveMode::Full).unwrap();
        // Box convolved with box gives a triangle.
        assert_close(out.as_slice().unwrap(), &[1.0, 2.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_convolve_i32() {
        // Integer element types go through the same generic kernel.
        let signal = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let kernel = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![1, 1]).unwrap();
        let out = convolve(&signal, &kernel, ConvolveMode::Full).unwrap();
        assert_eq!(out.as_slice().unwrap(), &[1, 3, 5, 3]);
    }
}