1use crate::complex::Complex;
8use crate::fft_core::{fft, ifft};
9use numra_core::Scalar;
10
11pub fn fftconvolve<S: Scalar>(a: &[S], b: &[S]) -> Vec<S> {
16 if a.is_empty() || b.is_empty() {
17 return vec![];
18 }
19
20 let n = a.len() + b.len() - 1;
21 let fft_len = n.next_power_of_two();
23
24 let mut ca = vec![Complex::zero(); fft_len];
25 let mut cb = vec![Complex::zero(); fft_len];
26
27 for (i, &v) in a.iter().enumerate() {
28 ca[i] = Complex::new(v, S::ZERO);
29 }
30 for (i, &v) in b.iter().enumerate() {
31 cb[i] = Complex::new(v, S::ZERO);
32 }
33
34 let fa = fft(&ca);
35 let fb = fft(&cb);
36
37 let fc: Vec<Complex<S>> = fa.iter().zip(fb.iter()).map(|(&a, &b)| a * b).collect();
39
40 let result = ifft(&fc);
41 result[..n].iter().map(|c| c.re).collect()
42}
43
44pub fn fftcorrelate<S: Scalar>(a: &[S], b: &[S]) -> Vec<S> {
51 if a.is_empty() || b.is_empty() {
52 return vec![];
53 }
54
55 let b_rev: Vec<S> = b.iter().rev().copied().collect();
57 fftconvolve(a, &b_rev)
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two sequences have the same length and agree element-wise
    /// to within the absolute tolerance `tol`.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (x, y)) in actual.iter().zip(expected).enumerate() {
            assert!((x - y).abs() < tol, "index {}: {} vs {}", i, x, y);
        }
    }

    /// Convolving with a unit impulse must return the input unchanged.
    #[test]
    fn test_fftconvolve_impulse() {
        let a = [1.0_f64, 2.0, 3.0, 4.0];
        let b = [1.0_f64];
        let result = fftconvolve(&a, &b);
        assert_eq!(result.len(), 4);
        assert_close(&result, &a, 1e-12);
    }

    /// Hand-computed reference: [1, 1] * [1, 1, 1] = [1, 2, 2, 1].
    #[test]
    fn test_fftconvolve_known() {
        let a = [1.0_f64, 1.0];
        let b = [1.0_f64, 1.0, 1.0];
        let result = fftconvolve(&a, &b);
        assert_eq!(result.len(), 4);
        assert_close(&result, &[1.0, 2.0, 2.0, 1.0], 1e-12);
    }

    /// Convolution must not depend on operand order.
    #[test]
    fn test_fftconvolve_commutative() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 5.0];
        let ab = fftconvolve(&a, &b);
        let ba = fftconvolve(&b, &a);
        assert_close(&ab, &ba, 1e-12);
    }

    /// The autocorrelation peak equals the signal energy: 1 + 4 + 9 + 4 + 1 = 19.
    #[test]
    fn test_fftcorrelate_autocorrelation() {
        let a = [1.0_f64, 2.0, 3.0, 2.0, 1.0];
        let result = fftcorrelate(&a, &a);
        assert_eq!(result.len(), 9);
        let peak = result.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        assert!((peak - 19.0).abs() < 1e-10);
    }

    /// Checks `fftcorrelate` against a direct O(n*m) evaluation of the lag sums,
    /// so the test does not merely restate how `fftcorrelate` is implemented.
    #[test]
    fn test_fftcorrelate_vs_direct() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 5.0];
        let result = fftcorrelate(&a, &b);
        assert_eq!(result.len(), 4);

        // Output index k holds the sum of a[i] * b[j] over all pairs with
        // i - j + (b.len() - 1) == k.
        let mut direct = vec![0.0_f64; a.len() + b.len() - 1];
        for (i, &x) in a.iter().enumerate() {
            for (j, &y) in b.iter().enumerate() {
                direct[i + (b.len() - 1 - j)] += x * y;
            }
        }
        assert_close(&direct, &[5.0, 14.0, 23.0, 12.0], 1e-12);
        assert_close(&result, &direct, 1e-12);

        // Identity retained from the original test: correlation equals
        // convolution with the reversed second operand.
        let b_rev = [5.0_f64, 4.0];
        let conv = fftconvolve(&a, &b_rev);
        assert_close(&result, &conv, 1e-12);
    }

    /// An empty operand on either side (or both) yields an empty result.
    #[test]
    fn test_fftconvolve_empty() {
        assert!(fftconvolve::<f64>(&[], &[1.0]).is_empty());
        assert!(fftconvolve::<f64>(&[1.0], &[]).is_empty());
        assert!(fftconvolve::<f64>(&[], &[]).is_empty());
        assert!(fftcorrelate::<f64>(&[], &[1.0]).is_empty());
        assert!(fftcorrelate::<f64>(&[1.0], &[]).is_empty());
        assert!(fftcorrelate::<f64>(&[], &[]).is_empty());
    }
}