1use crate::error::{CirculantError, Result};
11use crate::fft::backend::FftBackend;
12use crate::traits::Scalar;
13use num_complex::Complex;
14use rustfft::{Fft, FftPlanner};
15use std::sync::Arc;
16
17pub struct RustFftBackend<T: Scalar + rustfft::FftNum> {
21 forward: Arc<dyn Fft<T>>,
22 inverse: Arc<dyn Fft<T>>,
23 size: usize,
24 scratch_len: usize,
25}
26
27impl<T: Scalar + rustfft::FftNum> RustFftBackend<T> {
28 pub fn new(size: usize) -> Result<Self> {
34 if size == 0 {
35 return Err(CirculantError::InvalidFftSize(0));
36 }
37
38 let mut planner = FftPlanner::new();
39 let forward = planner.plan_fft_forward(size);
40 let inverse = planner.plan_fft_inverse(size);
41
42 let scratch_len = forward
43 .get_inplace_scratch_len()
44 .max(inverse.get_inplace_scratch_len());
45
46 Ok(Self {
47 forward,
48 inverse,
49 size,
50 scratch_len,
51 })
52 }
53}
54
55impl<T: Scalar + rustfft::FftNum> FftBackend<T> for RustFftBackend<T> {
56 fn fft_forward(&self, buffer: &mut [Complex<T>]) {
57 let mut scratch = vec![Complex::new(T::zero(), T::zero()); self.scratch_len];
58 self.forward.process_with_scratch(buffer, &mut scratch);
59 }
60
61 fn fft_inverse(&self, buffer: &mut [Complex<T>]) {
62 let mut scratch = vec![Complex::new(T::zero(), T::zero()); self.scratch_len];
63 self.inverse.process_with_scratch(buffer, &mut scratch);
64
65 let scale = T::one() / T::from(self.size).unwrap_or_else(|| T::one());
67 for val in buffer.iter_mut() {
68 *val = Complex::new(val.re * scale, val.im * scale);
69 }
70 }
71
72 fn size(&self) -> usize {
73 self.size
74 }
75
76 fn make_scratch(&self) -> Vec<Complex<T>> {
77 vec![Complex::new(T::zero(), T::zero()); self.scratch_len]
78 }
79}
80
81unsafe impl<T: Scalar + rustfft::FftNum> Send for RustFftBackend<T> {}
83unsafe impl<T: Scalar + rustfft::FftNum> Sync for RustFftBackend<T> {}
84
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use std::f64::consts::PI;

    /// Absolute tolerance for `f64` comparisons.
    const EPS_F64: f64 = 1e-10;

    /// Builds a purely real complex vector from `values`.
    fn real(values: &[f64]) -> Vec<Complex<f64>> {
        values.iter().map(|&v| Complex::new(v, 0.0)).collect()
    }

    /// Asserts two complex `f64` slices have equal length and agree element-wise
    /// (real and imaginary parts) within `EPS_F64`.
    fn assert_complex_eq(actual: &[Complex<f64>], expected: &[Complex<f64>]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert_relative_eq!(a.re, e.re, epsilon = EPS_F64);
            assert_relative_eq!(a.im, e.im, epsilon = EPS_F64);
        }
    }

    #[test]
    fn test_fft_forward_inverse_roundtrip() {
        let backend = RustFftBackend::<f64>::new(8).unwrap();
        let original = real(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        let mut buffer = original.clone();
        backend.fft_forward(&mut buffer);
        backend.fft_inverse(&mut buffer);

        assert_complex_eq(&buffer, &original);
    }

    #[test]
    fn test_fft_known_values() {
        // The FFT of a unit impulse is a constant spectrum of ones.
        let backend = RustFftBackend::<f64>::new(4).unwrap();
        let mut buffer = real(&[1.0, 0.0, 0.0, 0.0]);

        backend.fft_forward(&mut buffer);

        assert_complex_eq(&buffer, &real(&[1.0, 1.0, 1.0, 1.0]));
    }

    #[test]
    fn test_fft_known_values_sinusoid() {
        // cos(2*pi*k/n) has energy n/2 in bins 1 and n-1 only. Its spectrum is
        // purely real because the input is real and even, so every imaginary part
        // is checked as well.
        let n = 8;
        let backend = RustFftBackend::<f64>::new(n).unwrap();
        let mut buffer: Vec<Complex<f64>> = (0..n)
            .map(|k| {
                let theta = 2.0 * PI * (k as f64) / (n as f64);
                Complex::new(theta.cos(), 0.0)
            })
            .collect();

        backend.fft_forward(&mut buffer);

        assert_complex_eq(&buffer, &real(&[0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0]));
    }

    #[test]
    fn test_fft_linearity() {
        let backend = RustFftBackend::<f64>::new(4).unwrap();
        let x = real(&[1.0, 2.0, 3.0, 4.0]);
        let y = real(&[5.0, 6.0, 7.0, 8.0]);

        let mut fft_x = x.clone();
        let mut fft_y = y.clone();
        backend.fft_forward(&mut fft_x);
        backend.fft_forward(&mut fft_y);

        // Additivity: FFT(x + y) == FFT(x) + FFT(y).
        let mut fft_sum: Vec<Complex<f64>> = x.iter().zip(&y).map(|(a, b)| a + b).collect();
        backend.fft_forward(&mut fft_sum);
        let expected_sum: Vec<Complex<f64>> =
            fft_x.iter().zip(&fft_y).map(|(a, b)| a + b).collect();
        assert_complex_eq(&fft_sum, &expected_sum);

        // Homogeneity: FFT(alpha * x) == alpha * FFT(x).
        let alpha = Complex::new(2.5, 0.0);
        let mut fft_scaled: Vec<Complex<f64>> = x.iter().map(|v| alpha * v).collect();
        backend.fft_forward(&mut fft_scaled);
        let expected_scaled: Vec<Complex<f64>> = fft_x.iter().map(|v| alpha * v).collect();
        assert_complex_eq(&fft_scaled, &expected_scaled);
    }

    #[test]
    fn test_fft_inverse_is_normalized() {
        // The unnormalized inverse of a constant spectrum of ones is N at index 0.
        // After dividing by N it must be exactly the unit impulse.
        let backend = RustFftBackend::<f64>::new(4).unwrap();
        let mut buffer = real(&[1.0, 1.0, 1.0, 1.0]);

        backend.fft_inverse(&mut buffer);

        assert_complex_eq(&buffer, &real(&[1.0, 0.0, 0.0, 0.0]));
    }

    #[test]
    fn test_fft_forward_transforms_each_chunk() {
        // A buffer of k * size elements is treated as k independent transforms.
        let backend = RustFftBackend::<f64>::new(4).unwrap();
        let mut buffer = real(&[1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0]);

        backend.fft_forward(&mut buffer);

        assert_complex_eq(&buffer, &real(&[1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0]));
    }

    #[test]
    fn test_fft_size() {
        let backend = RustFftBackend::<f64>::new(16).unwrap();
        assert_eq!(backend.size(), 16);
    }

    #[test]
    fn test_fft_zero_size_returns_error() {
        let result = RustFftBackend::<f64>::new(0);
        assert!(matches!(
            result,
            Err(crate::error::CirculantError::InvalidFftSize(0))
        ));
    }

    #[test]
    fn test_fft_f32() {
        let backend = RustFftBackend::<f32>::new(4).unwrap();
        let original: Vec<Complex<f32>> = [1.0f32, 2.0, 3.0, 4.0]
            .iter()
            .map(|&v| Complex::new(v, 0.0))
            .collect();

        let mut buffer = original.clone();
        backend.fft_forward(&mut buffer);
        backend.fft_inverse(&mut buffer);

        for (orig, result) in original.iter().zip(&buffer) {
            assert_relative_eq!(orig.re, result.re, epsilon = 1e-5);
            assert_relative_eq!(orig.im, result.im, epsilon = 1e-5);
        }
    }
}