Skip to main content

axonml_nn/layers/
fft.rs

1//! FFT Layers - Fast Fourier Transform for Neural Networks
2//!
3//! Provides FFT1d and STFT as neural network layers for spectral
4//! feature extraction in signal processing and time series models.
5//!
6//! @version 0.1.0
7//! @author AutomataNexus Development Team
8
9use std::collections::HashMap;
10
11use axonml_autograd::Variable;
12use axonml_tensor::Tensor;
13use rustfft::{FftPlanner, num_complex::Complex};
14
15use crate::module::Module;
16use crate::parameter::Parameter;
17
18// =============================================================================
19// FFT1d
20// =============================================================================
21
/// 1D Fast Fourier Transform layer.
///
/// Computes the magnitude spectrum of each input signal with an FFT of length
/// `n_fft` and keeps the first `n_fft/2 + 1` frequency bins (the non-negative
/// frequencies of a real-valued input). Signals are zero-padded or truncated to
/// `n_fft` samples along the last axis.
///
/// The layer has no learnable parameters. Its output is built from a copy of the
/// input data. NOTE(review): no backward function is attached in this file, so
/// gradients presumably do not propagate back to the input. Confirm before using
/// this layer inside a trained graph.
///
/// # Input Shape
/// `(batch, channels, time)` or `(batch, time)`
///
/// # Output Shape
/// `(batch, channels, n_fft/2+1)` or `(batch, n_fft/2+1)`
///
/// # Example
/// ```ignore
/// use axonml_nn::layers::FFT1d;
///
/// let fft = FFT1d::new(256);
/// let signal = Variable::new(Tensor::randn(&[2, 1, 256]), true);
/// let spectrum = fft.forward(&signal);
/// // spectrum shape: (2, 1, 129)
/// ```
#[derive(Debug, Clone)]
pub struct FFT1d {
    /// FFT length; input signals are zero-padded or truncated to this size.
    n_fft: usize,
    /// When true, output magnitudes are scaled by `1 / sqrt(n_fft)`.
    normalized: bool,
}
46
47impl FFT1d {
48    /// Creates a new FFT1d layer.
49    ///
50    /// # Arguments
51    /// * `n_fft` - FFT size. Input will be zero-padded or truncated to this length.
52    pub fn new(n_fft: usize) -> Self {
53        Self {
54            n_fft,
55            normalized: false,
56        }
57    }
58
59    /// Creates an FFT1d layer with optional normalization.
60    ///
61    /// When normalized, the output is divided by `sqrt(n_fft)`.
62    pub fn with_normalization(n_fft: usize, normalized: bool) -> Self {
63        Self { n_fft, normalized }
64    }
65
66    /// Returns the number of output frequency bins.
67    pub fn output_bins(&self) -> usize {
68        self.n_fft / 2 + 1
69    }
70
71    /// Compute FFT magnitude for a single 1D signal.
72    fn fft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
73        let n = self.n_fft;
74        let n_out = n / 2 + 1;
75
76        // Prepare complex input (zero-pad or truncate)
77        let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
78        let copy_len = signal.len().min(n);
79        for i in 0..copy_len {
80            buffer[i] = Complex::new(signal[i], 0.0);
81        }
82
83        // Compute FFT
84        let mut planner = FftPlanner::new();
85        let fft = planner.plan_fft_forward(n);
86        fft.process(&mut buffer);
87
88        // Extract magnitude for positive frequencies
89        let norm_factor = if self.normalized {
90            1.0 / (n as f32).sqrt()
91        } else {
92            1.0
93        };
94
95        let mut magnitude = Vec::with_capacity(n_out);
96        for i in 0..n_out {
97            let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
98            magnitude.push(mag * norm_factor);
99        }
100
101        magnitude
102    }
103}
104
105impl Module for FFT1d {
106    fn forward(&self, input: &Variable) -> Variable {
107        let shape = input.shape();
108        let data = input.data().to_vec();
109        let n_out = self.output_bins();
110
111        match shape.len() {
112            2 => {
113                // (batch, time) → (batch, n_fft/2+1)
114                let batch = shape[0];
115                let time = shape[1];
116                let mut output = Vec::with_capacity(batch * n_out);
117
118                for b in 0..batch {
119                    let start = b * time;
120                    let end = start + time;
121                    let signal = &data[start..end];
122                    output.extend_from_slice(&self.fft_magnitude(signal));
123                }
124
125                Variable::new(
126                    Tensor::from_vec(output, &[batch, n_out]).unwrap(),
127                    input.requires_grad(),
128                )
129            }
130            3 => {
131                // (batch, channels, time) → (batch, channels, n_fft/2+1)
132                let batch = shape[0];
133                let channels = shape[1];
134                let time = shape[2];
135                let mut output = Vec::with_capacity(batch * channels * n_out);
136
137                for b in 0..batch {
138                    for c in 0..channels {
139                        let start = (b * channels + c) * time;
140                        let end = start + time;
141                        let signal = &data[start..end];
142                        output.extend_from_slice(&self.fft_magnitude(signal));
143                    }
144                }
145
146                Variable::new(
147                    Tensor::from_vec(output, &[batch, channels, n_out]).unwrap(),
148                    input.requires_grad(),
149                )
150            }
151            _ => panic!(
152                "FFT1d expects input of shape (batch, time) or (batch, channels, time), got {:?}",
153                shape
154            ),
155        }
156    }
157
158    fn parameters(&self) -> Vec<Parameter> {
159        Vec::new() // FFT has no learnable parameters
160    }
161
162    fn named_parameters(&self) -> HashMap<String, Parameter> {
163        HashMap::new()
164    }
165
166    fn name(&self) -> &'static str {
167        "FFT1d"
168    }
169}
170
171// =============================================================================
172// STFT
173// =============================================================================
174
/// Short-Time Fourier Transform layer.
///
/// Slides a frame of `n_fft` samples across each signal in steps of
/// `hop_length` and multiplies every frame by a precomputed Hann window. It then
/// returns the magnitude of the first `n_fft/2 + 1` FFT bins per frame.
///
/// Frames start at sample 0 with no centering padding. A signal shorter than
/// `n_fft` produces a single zero-padded frame. Otherwise
/// `n_frames = (time - n_fft) / hop_length + 1`, and trailing samples that do not
/// fill a whole frame are ignored.
///
/// NOTE(review): the output is built from a copy of the input data with no
/// backward function attached in this file. Gradients presumably do not
/// propagate to the input, so confirm before training through this layer.
///
/// # Input Shape
/// `(batch, channels, time)` or `(batch, time)`
///
/// # Output Shape
/// `(batch, channels, n_frames, n_fft/2+1)` or `(batch, n_frames, n_fft/2+1)`
///
/// # Example
/// ```ignore
/// use axonml_nn::layers::STFT;
///
/// let stft = STFT::new(256, 128); // n_fft=256, hop=128
/// let signal = Variable::new(Tensor::randn(&[2, 1, 1024]), false);
/// let spec = stft.forward(&signal);
/// // spec shape: (2, 1, 7, 129)
/// ```
#[derive(Debug, Clone)]
pub struct STFT {
    /// FFT size, which is also the frame (window) length.
    n_fft: usize,
    /// Number of samples between the starts of successive frames.
    hop_length: usize,
    /// Precomputed window of length `n_fft`, applied to every frame.
    window: Vec<f32>,
    /// When true, output magnitudes are scaled by `1 / sqrt(n_fft)`.
    normalized: bool,
}
201
202impl STFT {
203    /// Creates a new STFT layer with a Hann window.
204    ///
205    /// # Arguments
206    /// * `n_fft` - FFT window size
207    /// * `hop_length` - Number of samples between successive frames
208    pub fn new(n_fft: usize, hop_length: usize) -> Self {
209        let window = hann_window(n_fft);
210        Self {
211            n_fft,
212            hop_length,
213            window,
214            normalized: false,
215        }
216    }
217
218    /// Creates an STFT layer with normalization.
219    pub fn with_normalization(n_fft: usize, hop_length: usize, normalized: bool) -> Self {
220        let window = hann_window(n_fft);
221        Self {
222            n_fft,
223            hop_length,
224            window,
225            normalized,
226        }
227    }
228
229    /// Returns the number of output frequency bins.
230    pub fn output_bins(&self) -> usize {
231        self.n_fft / 2 + 1
232    }
233
234    /// Computes the number of frames for a given signal length.
235    pub fn n_frames(&self, signal_length: usize) -> usize {
236        if signal_length < self.n_fft {
237            1
238        } else {
239            (signal_length - self.n_fft) / self.hop_length + 1
240        }
241    }
242
243    /// Compute STFT magnitude for a single 1D signal.
244    fn stft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
245        let n = self.n_fft;
246        let n_out = n / 2 + 1;
247        let n_frames = self.n_frames(signal.len());
248
249        let norm_factor = if self.normalized {
250            1.0 / (n as f32).sqrt()
251        } else {
252            1.0
253        };
254
255        let mut planner = FftPlanner::new();
256        let fft = planner.plan_fft_forward(n);
257
258        let mut output = Vec::with_capacity(n_frames * n_out);
259
260        for frame in 0..n_frames {
261            let start = frame * self.hop_length;
262
263            // Apply window and create complex buffer
264            let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
265            for i in 0..n {
266                let idx = start + i;
267                let sample = if idx < signal.len() { signal[idx] } else { 0.0 };
268                buffer[i] = Complex::new(sample * self.window[i], 0.0);
269            }
270
271            fft.process(&mut buffer);
272
273            for i in 0..n_out {
274                let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
275                output.push(mag * norm_factor);
276            }
277        }
278
279        output
280    }
281}
282
283impl Module for STFT {
284    fn forward(&self, input: &Variable) -> Variable {
285        let shape = input.shape();
286        let data = input.data().to_vec();
287        let n_out = self.output_bins();
288
289        match shape.len() {
290            2 => {
291                // (batch, time) → (batch, n_frames, n_fft/2+1)
292                let batch = shape[0];
293                let time = shape[1];
294                let n_frames = self.n_frames(time);
295                let mut output = Vec::with_capacity(batch * n_frames * n_out);
296
297                for b in 0..batch {
298                    let start = b * time;
299                    let end = start + time;
300                    let signal = &data[start..end];
301                    output.extend_from_slice(&self.stft_magnitude(signal));
302                }
303
304                Variable::new(
305                    Tensor::from_vec(output, &[batch, n_frames, n_out]).unwrap(),
306                    input.requires_grad(),
307                )
308            }
309            3 => {
310                // (batch, channels, time) → (batch, channels, n_frames, n_fft/2+1)
311                let batch = shape[0];
312                let channels = shape[1];
313                let time = shape[2];
314                let n_frames = self.n_frames(time);
315                let mut output = Vec::with_capacity(batch * channels * n_frames * n_out);
316
317                for b in 0..batch {
318                    for c in 0..channels {
319                        let start = (b * channels + c) * time;
320                        let end = start + time;
321                        let signal = &data[start..end];
322                        output.extend_from_slice(&self.stft_magnitude(signal));
323                    }
324                }
325
326                Variable::new(
327                    Tensor::from_vec(output, &[batch, channels, n_frames, n_out]).unwrap(),
328                    input.requires_grad(),
329                )
330            }
331            _ => panic!(
332                "STFT expects input of shape (batch, time) or (batch, channels, time), got {:?}",
333                shape
334            ),
335        }
336    }
337
338    fn parameters(&self) -> Vec<Parameter> {
339        Vec::new()
340    }
341
342    fn named_parameters(&self) -> HashMap<String, Parameter> {
343        HashMap::new()
344    }
345
346    fn name(&self) -> &'static str {
347        "STFT"
348    }
349}
350
351// =============================================================================
352// Utility Functions
353// =============================================================================
354
/// Generates a symmetric Hann window of the given size.
///
/// `w[i] = 0.5 * (1 - cos(2*pi*i / (size - 1)))`, so both endpoints are zero and
/// the window is symmetric about its centre.
///
/// Degenerate sizes follow NumPy's `np.hanning`: `size == 0` gives an empty
/// window and `size == 1` gives `[1.0]`. For `size == 1` the general formula
/// would compute 0/0 and return `[NaN]`, which would poison every frame it is
/// applied to.
fn hann_window(size: usize) -> Vec<f32> {
    if size <= 1 {
        return vec![1.0; size];
    }

    let denom = (size - 1) as f32;
    (0..size)
        .map(|i| {
            let phase = 2.0 * std::f32::consts::PI * i as f32 / denom;
            0.5 * (1.0 - phase.cos())
        })
        .collect()
}
364
365// =============================================================================
366// Tests
367// =============================================================================
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the index of the largest value (the last one on ties, per
    /// `Iterator::max_by`).
    fn argmax(values: &[f32]) -> usize {
        values
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0
    }

    /// Asserts element-wise closeness of two float sequences.
    fn assert_close(got: &[f32], want: &[f32]) {
        assert_eq!(got.len(), want.len());
        for (g, w) in got.iter().zip(want.iter()) {
            assert!((g - w).abs() < 1e-5, "got {:?}, want {:?}", got, want);
        }
    }

    #[test]
    fn test_fft1d_shape_2d() {
        let fft = FFT1d::new(64);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 128], &[2, 64]).unwrap(),
            false,
        );
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![2, 33]); // n_fft/2+1 = 33
    }

    #[test]
    fn test_fft1d_shape_3d() {
        let fft = FFT1d::new(128);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 3 * 128], &[2, 3, 128]).unwrap(),
            false,
        );
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![2, 3, 65]); // n_fft/2+1 = 65
    }

    #[test]
    fn test_fft1d_known_sinusoid() {
        // Create a pure 10 Hz sinusoid sampled at 64 Hz
        let n = 64;
        let freq = 10.0;
        let sample_rate = 64.0;
        let signal: Vec<f32> = (0..n)
            .map(|i| {
                let t = i as f32 / sample_rate;
                (2.0 * std::f32::consts::PI * freq * t).sin()
            })
            .collect();

        let fft = FFT1d::new(n);
        let input = Variable::new(
            Tensor::from_vec(signal, &[1, n]).unwrap(),
            false,
        );
        let output = fft.forward(&input);
        let spectrum = output.data().to_vec();

        // The peak should be at bin 10 (freq * n / sample_rate = 10)
        assert_eq!(argmax(&spectrum), 10);
    }

    #[test]
    fn test_fft1d_zero_padding() {
        // Input shorter than n_fft gets zero-padded
        let fft = FFT1d::new(128);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[1, 32]).unwrap(),
            false,
        );
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![1, 65]);
    }

    #[test]
    fn test_fft1d_truncation() {
        // Only the first n_fft = 4 samples (all ones) may reach the FFT, so the
        // trailing fives must not show up in the spectrum.
        let fft = FFT1d::new(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0], &[1, 8]).unwrap(),
            false,
        );
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![1, 3]);
        assert_close(&output.data().to_vec(), &[4.0, 0.0, 0.0]);
    }

    #[test]
    fn test_fft1d_channel_layout() {
        // Channel 0: constant ones, so all energy is at DC.
        // Channel 1: unit impulse, so the spectrum is flat with unit magnitude.
        let fft = FFT1d::new(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0], &[1, 2, 4]).unwrap(),
            false,
        );
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![1, 2, 3]);
        assert_close(&output.data().to_vec(), &[4.0, 0.0, 0.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_fft1d_normalized() {
        let fft_norm = FFT1d::with_normalization(64, true);
        let fft_raw = FFT1d::new(64);

        let signal = vec![1.0; 64];
        let input = Variable::new(
            Tensor::from_vec(signal, &[1, 64]).unwrap(),
            false,
        );

        let out_norm = fft_norm.forward(&input).data().to_vec();
        let out_raw = fft_raw.forward(&input).data().to_vec();

        // Normalized should be raw / sqrt(64) = raw / 8
        let ratio = out_raw[0] / out_norm[0];
        assert!((ratio - 8.0).abs() < 0.01);
    }

    #[test]
    fn test_fft1d_no_parameters() {
        let fft = FFT1d::new(64);
        assert!(fft.parameters().is_empty());
        assert!(fft.named_parameters().is_empty());
    }

    #[test]
    fn test_stft_shape() {
        let stft = STFT::new(256, 128);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 1024], &[2, 1024]).unwrap(),
            false,
        );
        let output = stft.forward(&input);

        let n_frames = stft.n_frames(1024); // (1024 - 256) / 128 + 1 = 7
        assert_eq!(output.shape(), vec![2, n_frames, 129]);
        assert_eq!(n_frames, 7);
    }

    #[test]
    fn test_stft_shape_3d() {
        let stft = STFT::new(64, 32);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 3 * 256], &[2, 3, 256]).unwrap(),
            false,
        );
        let output = stft.forward(&input);

        let n_frames = stft.n_frames(256); // (256 - 64) / 32 + 1 = 7
        assert_eq!(n_frames, 7);
        assert_eq!(output.shape(), vec![2, 3, n_frames, 33]);
    }

    #[test]
    fn test_stft_short_signal_single_frame() {
        // A signal shorter than n_fft yields exactly one zero-padded frame.
        let stft = STFT::new(64, 16);
        assert_eq!(stft.n_frames(10), 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 10], &[1, 10]).unwrap(),
            false,
        );
        let output = stft.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 33]);
    }

    #[test]
    fn test_stft_sinusoid_peak_in_every_frame() {
        // A sinusoid completing exactly 8 cycles per 64-sample frame must peak
        // at bin 8 in every frame.
        let n_fft = 64;
        let len = 256;
        let signal: Vec<f32> = (0..len)
            .map(|i| (2.0 * std::f32::consts::PI * 8.0 * i as f32 / n_fft as f32).sin())
            .collect();

        let stft = STFT::new(n_fft, 32);
        let input = Variable::new(
            Tensor::from_vec(signal, &[1, len]).unwrap(),
            false,
        );
        let output = stft.forward(&input);
        let bins = stft.output_bins();
        assert_eq!(output.shape(), vec![1, 7, bins]);

        let spec = output.data().to_vec();
        for frame in spec.chunks(bins) {
            assert_eq!(argmax(frame), 8);
        }
    }

    #[test]
    fn test_stft_no_parameters() {
        let stft = STFT::new(256, 128);
        assert_eq!(stft.parameters().len(), 0);
    }

    #[test]
    fn test_fft1d_output_bins() {
        assert_eq!(FFT1d::new(64).output_bins(), 33);
        assert_eq!(FFT1d::new(256).output_bins(), 129);
        assert_eq!(FFT1d::new(512).output_bins(), 257);
    }

    #[test]
    fn test_hann_window() {
        let w = hann_window(4);
        // Hann window for size 4: [0, 0.75, 0.75, 0]
        assert!((w[0]).abs() < 1e-6);
        assert!((w[1] - 0.75).abs() < 0.01);
        assert!((w[2] - 0.75).abs() < 0.01);
        assert!((w[3]).abs() < 1e-6);
    }

    #[test]
    fn test_hann_window_odd_size_peaks_at_centre() {
        // Hann window for size 5: [0, 0.5, 1, 0.5, 0]
        let w = hann_window(5);
        assert_close(&w, &[0.0, 0.5, 1.0, 0.5, 0.0]);
    }
}