Skip to main content

axonml_nn/layers/
fft.rs

1//! FFT Layers - Fast Fourier Transform for Neural Networks
2//!
3//! # File
4//! `crates/axonml-nn/src/layers/fft.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_tensor::Tensor;
21use rustfft::{FftPlanner, num_complex::Complex};
22
23use crate::module::Module;
24use crate::parameter::Parameter;
25
26// =============================================================================
27// FFT1d
28// =============================================================================
29
30/// 1D Fast Fourier Transform layer.
31///
32/// Computes the magnitude spectrum of the input signal using an efficient
33/// FFT algorithm. Returns the first `n_fft/2 + 1` frequency bins (real FFT).
34///
35/// # Input Shape
36/// `(batch, channels, time)` or `(batch, time)`
37///
38/// # Output Shape
39/// `(batch, channels, n_fft/2+1)` or `(batch, n_fft/2+1)`
40///
41/// # Example
42/// ```ignore
43/// use axonml_nn::layers::FFT1d;
44///
45/// let fft = FFT1d::new(256);
46/// let signal = Variable::new(Tensor::randn(&[2, 1, 256]), true);
47/// let spectrum = fft.forward(&signal);
48/// // spectrum shape: (2, 1, 129)
49/// ```
50pub struct FFT1d {
51    n_fft: usize,
52    normalized: bool,
53}
54
55impl FFT1d {
56    /// Creates a new FFT1d layer.
57    ///
58    /// # Arguments
59    /// * `n_fft` - FFT size. Input will be zero-padded or truncated to this length.
60    pub fn new(n_fft: usize) -> Self {
61        Self {
62            n_fft,
63            normalized: false,
64        }
65    }
66
67    /// Creates an FFT1d layer with optional normalization.
68    ///
69    /// When normalized, the output is divided by `sqrt(n_fft)`.
70    pub fn with_normalization(n_fft: usize, normalized: bool) -> Self {
71        Self { n_fft, normalized }
72    }
73
74    /// Returns the number of output frequency bins.
75    pub fn output_bins(&self) -> usize {
76        self.n_fft / 2 + 1
77    }
78
79    /// Compute FFT magnitude for a single 1D signal.
80    fn fft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
81        let n = self.n_fft;
82        let n_out = n / 2 + 1;
83
84        // Prepare complex input (zero-pad or truncate)
85        let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
86        let copy_len = signal.len().min(n);
87        for i in 0..copy_len {
88            buffer[i] = Complex::new(signal[i], 0.0);
89        }
90
91        // Compute FFT
92        let mut planner = FftPlanner::new();
93        let fft = planner.plan_fft_forward(n);
94        fft.process(&mut buffer);
95
96        // Extract magnitude for positive frequencies
97        let norm_factor = if self.normalized {
98            1.0 / (n as f32).sqrt()
99        } else {
100            1.0
101        };
102
103        let mut magnitude = Vec::with_capacity(n_out);
104        for i in 0..n_out {
105            let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
106            magnitude.push(mag * norm_factor);
107        }
108
109        magnitude
110    }
111}
112
113impl Module for FFT1d {
114    fn forward(&self, input: &Variable) -> Variable {
115        let shape = input.shape();
116        let data = input.data().to_vec();
117        let n_out = self.output_bins();
118
119        match shape.len() {
120            2 => {
121                // (batch, time) → (batch, n_fft/2+1)
122                let batch = shape[0];
123                let time = shape[1];
124                let mut output = Vec::with_capacity(batch * n_out);
125
126                for b in 0..batch {
127                    let start = b * time;
128                    let end = start + time;
129                    let signal = &data[start..end];
130                    output.extend_from_slice(&self.fft_magnitude(signal));
131                }
132
133                Variable::new(
134                    Tensor::from_vec(output, &[batch, n_out]).unwrap(),
135                    input.requires_grad(),
136                )
137            }
138            3 => {
139                // (batch, channels, time) → (batch, channels, n_fft/2+1)
140                let batch = shape[0];
141                let channels = shape[1];
142                let time = shape[2];
143                let mut output = Vec::with_capacity(batch * channels * n_out);
144
145                for b in 0..batch {
146                    for c in 0..channels {
147                        let start = (b * channels + c) * time;
148                        let end = start + time;
149                        let signal = &data[start..end];
150                        output.extend_from_slice(&self.fft_magnitude(signal));
151                    }
152                }
153
154                Variable::new(
155                    Tensor::from_vec(output, &[batch, channels, n_out]).unwrap(),
156                    input.requires_grad(),
157                )
158            }
159            _ => panic!(
160                "FFT1d expects input of shape (batch, time) or (batch, channels, time), got {:?}",
161                shape
162            ),
163        }
164    }
165
166    fn parameters(&self) -> Vec<Parameter> {
167        Vec::new() // FFT has no learnable parameters
168    }
169
170    fn named_parameters(&self) -> HashMap<String, Parameter> {
171        HashMap::new()
172    }
173
174    fn name(&self) -> &'static str {
175        "FFT1d"
176    }
177}
178
179// =============================================================================
180// STFT
181// =============================================================================
182
183/// Short-Time Fourier Transform layer.
184///
185/// Applies a sliding window FFT to compute time-frequency representations.
186/// Uses a Hann window by default.
187///
188/// # Input Shape
189/// `(batch, channels, time)` or `(batch, time)`
190///
191/// # Output Shape
192/// `(batch, channels, n_frames, n_fft/2+1)` or `(batch, n_frames, n_fft/2+1)`
193///
194/// # Example
195/// ```ignore
196/// use axonml_nn::layers::STFT;
197///
198/// let stft = STFT::new(256, 128); // n_fft=256, hop=128
199/// let signal = Variable::new(Tensor::randn(&[2, 1, 1024]), false);
200/// let spec = stft.forward(&signal);
201/// // spec shape: (2, 1, 7, 129)
202/// ```
203pub struct STFT {
204    n_fft: usize,
205    hop_length: usize,
206    window: Vec<f32>,
207    normalized: bool,
208}
209
210impl STFT {
211    /// Creates a new STFT layer with a Hann window.
212    ///
213    /// # Arguments
214    /// * `n_fft` - FFT window size
215    /// * `hop_length` - Number of samples between successive frames
216    pub fn new(n_fft: usize, hop_length: usize) -> Self {
217        let window = hann_window(n_fft);
218        Self {
219            n_fft,
220            hop_length,
221            window,
222            normalized: false,
223        }
224    }
225
226    /// Creates an STFT layer with normalization.
227    pub fn with_normalization(n_fft: usize, hop_length: usize, normalized: bool) -> Self {
228        let window = hann_window(n_fft);
229        Self {
230            n_fft,
231            hop_length,
232            window,
233            normalized,
234        }
235    }
236
237    /// Returns the number of output frequency bins.
238    pub fn output_bins(&self) -> usize {
239        self.n_fft / 2 + 1
240    }
241
242    /// Computes the number of frames for a given signal length.
243    pub fn n_frames(&self, signal_length: usize) -> usize {
244        if signal_length < self.n_fft {
245            1
246        } else {
247            (signal_length - self.n_fft) / self.hop_length + 1
248        }
249    }
250
251    /// Compute STFT magnitude for a single 1D signal.
252    fn stft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
253        let n = self.n_fft;
254        let n_out = n / 2 + 1;
255        let n_frames = self.n_frames(signal.len());
256
257        let norm_factor = if self.normalized {
258            1.0 / (n as f32).sqrt()
259        } else {
260            1.0
261        };
262
263        let mut planner = FftPlanner::new();
264        let fft = planner.plan_fft_forward(n);
265
266        let mut output = Vec::with_capacity(n_frames * n_out);
267
268        for frame in 0..n_frames {
269            let start = frame * self.hop_length;
270
271            // Apply window and create complex buffer
272            let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
273            for i in 0..n {
274                let idx = start + i;
275                let sample = if idx < signal.len() { signal[idx] } else { 0.0 };
276                buffer[i] = Complex::new(sample * self.window[i], 0.0);
277            }
278
279            fft.process(&mut buffer);
280
281            for i in 0..n_out {
282                let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
283                output.push(mag * norm_factor);
284            }
285        }
286
287        output
288    }
289}
290
291impl Module for STFT {
292    fn forward(&self, input: &Variable) -> Variable {
293        let shape = input.shape();
294        let data = input.data().to_vec();
295        let n_out = self.output_bins();
296
297        match shape.len() {
298            2 => {
299                // (batch, time) → (batch, n_frames, n_fft/2+1)
300                let batch = shape[0];
301                let time = shape[1];
302                let n_frames = self.n_frames(time);
303                let mut output = Vec::with_capacity(batch * n_frames * n_out);
304
305                for b in 0..batch {
306                    let start = b * time;
307                    let end = start + time;
308                    let signal = &data[start..end];
309                    output.extend_from_slice(&self.stft_magnitude(signal));
310                }
311
312                Variable::new(
313                    Tensor::from_vec(output, &[batch, n_frames, n_out]).unwrap(),
314                    input.requires_grad(),
315                )
316            }
317            3 => {
318                // (batch, channels, time) → (batch, channels, n_frames, n_fft/2+1)
319                let batch = shape[0];
320                let channels = shape[1];
321                let time = shape[2];
322                let n_frames = self.n_frames(time);
323                let mut output = Vec::with_capacity(batch * channels * n_frames * n_out);
324
325                for b in 0..batch {
326                    for c in 0..channels {
327                        let start = (b * channels + c) * time;
328                        let end = start + time;
329                        let signal = &data[start..end];
330                        output.extend_from_slice(&self.stft_magnitude(signal));
331                    }
332                }
333
334                Variable::new(
335                    Tensor::from_vec(output, &[batch, channels, n_frames, n_out]).unwrap(),
336                    input.requires_grad(),
337                )
338            }
339            _ => panic!(
340                "STFT expects input of shape (batch, time) or (batch, channels, time), got {:?}",
341                shape
342            ),
343        }
344    }
345
346    fn parameters(&self) -> Vec<Parameter> {
347        Vec::new()
348    }
349
350    fn named_parameters(&self) -> HashMap<String, Parameter> {
351        HashMap::new()
352    }
353
354    fn name(&self) -> &'static str {
355        "STFT"
356    }
357}
358
359// =============================================================================
360// Utility Functions
361// =============================================================================
362
/// Generates a symmetric Hann window of the given size.
///
/// Computes `w[i] = 0.5 * (1 - cos(2 * pi * i / (size - 1)))`, so both
/// endpoints are zero (size 4 gives `[0, 0.75, 0.75, 0]`).
///
/// Degenerate sizes are special-cased. Size 0 yields an empty window. Size 1
/// yields `[1.0]`, the conventional single-point window, instead of the `NaN`
/// the formula would produce from `0 / 0`.
fn hann_window(size: usize) -> Vec<f32> {
    match size {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            // Loop-invariant denominator of the symmetric form.
            let denom = (size - 1) as f32;
            (0..size)
                .map(|i| {
                    let phase = 2.0 * std::f32::consts::PI * i as f32 / denom;
                    0.5 * (1.0 - phase.cos())
                })
                .collect()
        }
    }
}
372
373// =============================================================================
374// Tests
375// =============================================================================
376
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fft1d_shape_2d() {
        let fft = FFT1d::new(64);
        let input = Variable::new(Tensor::from_vec(vec![0.0; 128], &[2, 64]).unwrap(), false);
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![2, 33]); // n_fft/2+1 = 33
    }

    #[test]
    fn test_fft1d_shape_3d() {
        let fft = FFT1d::new(128);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 3 * 128], &[2, 3, 128]).unwrap(),
            false,
        );
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![2, 3, 65]); // n_fft/2+1 = 65
    }

    #[test]
    fn test_fft1d_known_sinusoid() {
        // Create a pure 10 Hz sinusoid sampled at 64 Hz
        let n = 64;
        let freq = 10.0;
        let sample_rate = 64.0;
        let signal: Vec<f32> = (0..n)
            .map(|i| {
                let t = i as f32 / sample_rate;
                (2.0 * std::f32::consts::PI * freq * t).sin()
            })
            .collect();

        let fft = FFT1d::new(n);
        let input = Variable::new(Tensor::from_vec(signal, &[1, n]).unwrap(), false);
        let output = fft.forward(&input);
        let spectrum = output.data().to_vec();

        // The peak should be at bin 10 (freq * n / sample_rate = 10)
        let peak_bin = spectrum
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        assert_eq!(peak_bin, 10);
    }

    #[test]
    fn test_fft1d_constant_signal_is_dc_only() {
        // A constant signal has all its energy in bin 0: |X[0]| = sum = 64
        let fft = FFT1d::new(64);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 64], &[1, 64]).unwrap(), false);
        let spectrum = fft.forward(&input).data().to_vec();

        assert!((spectrum[0] - 64.0).abs() < 1e-3);
        assert!(spectrum[1..].iter().all(|&m| m < 1e-3));
    }

    #[test]
    fn test_fft1d_zero_padding() {
        // Input shorter than n_fft gets zero-padded
        let fft = FFT1d::new(128);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 32], &[1, 32]).unwrap(), false);
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![1, 65]);
    }

    #[test]
    fn test_fft1d_truncation() {
        // Input longer than n_fft is truncated to its first n_fft samples
        let fft = FFT1d::new(64);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 100], &[1, 100]).unwrap(), false);
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![1, 33]);

        // Only 64 ones survive truncation, so the DC bin is 64, not 100
        let spectrum = output.data().to_vec();
        assert!((spectrum[0] - 64.0).abs() < 1e-3);
    }

    #[test]
    fn test_fft1d_normalized() {
        let fft_norm = FFT1d::with_normalization(64, true);
        let fft_raw = FFT1d::new(64);

        let signal = vec![1.0; 64];
        let input = Variable::new(Tensor::from_vec(signal, &[1, 64]).unwrap(), false);

        let out_norm = fft_norm.forward(&input).data().to_vec();
        let out_raw = fft_raw.forward(&input).data().to_vec();

        // Normalized should be raw / sqrt(64) = raw / 8
        let ratio = out_raw[0] / out_norm[0];
        assert!((ratio - 8.0).abs() < 0.01);
    }

    #[test]
    #[should_panic(expected = "FFT1d expects input of shape")]
    fn test_fft1d_rejects_1d_input() {
        let fft = FFT1d::new(64);
        let input = Variable::new(Tensor::from_vec(vec![0.0; 64], &[64]).unwrap(), false);
        let _ = fft.forward(&input);
    }

    #[test]
    fn test_fft1d_no_parameters() {
        let fft = FFT1d::new(64);
        assert!(fft.parameters().is_empty());
        assert!(fft.named_parameters().is_empty());
    }

    #[test]
    fn test_stft_shape() {
        let stft = STFT::new(256, 128);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 1024], &[2, 1024]).unwrap(),
            false,
        );
        let output = stft.forward(&input);

        let n_frames = stft.n_frames(1024); // (1024 - 256) / 128 + 1 = 7
        assert_eq!(output.shape(), vec![2, n_frames, 129]);
        assert_eq!(n_frames, 7);
    }

    #[test]
    fn test_stft_shape_3d() {
        let stft = STFT::new(64, 32);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 3 * 256], &[2, 3, 256]).unwrap(),
            false,
        );
        let output = stft.forward(&input);

        let n_frames = stft.n_frames(256); // (256 - 64) / 32 + 1 = 7
        assert_eq!(n_frames, 7);
        assert_eq!(output.shape(), vec![2, 3, n_frames, 33]);
    }

    #[test]
    fn test_stft_short_signal_single_frame() {
        // A signal shorter than one window yields exactly one zero-padded frame
        let stft = STFT::new(64, 32);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 40], &[1, 40]).unwrap(), false);
        let output = stft.forward(&input);

        assert_eq!(stft.n_frames(40), 1);
        assert_eq!(output.shape(), vec![1, 1, 33]);
    }

    #[test]
    fn test_stft_normalized() {
        let stft_norm = STFT::with_normalization(64, 32, true);
        let stft_raw = STFT::new(64, 32);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 64], &[1, 64]).unwrap(), false);

        let out_norm = stft_norm.forward(&input).data().to_vec();
        let out_raw = stft_raw.forward(&input).data().to_vec();

        // DC bin of a windowed constant is nonzero; normalization divides it by sqrt(64) = 8
        let ratio = out_raw[0] / out_norm[0];
        assert!((ratio - 8.0).abs() < 0.01);
    }

    #[test]
    fn test_stft_no_parameters() {
        let stft = STFT::new(256, 128);
        assert_eq!(stft.parameters().len(), 0);
    }

    #[test]
    fn test_fft1d_output_bins() {
        assert_eq!(FFT1d::new(64).output_bins(), 33);
        assert_eq!(FFT1d::new(256).output_bins(), 129);
        assert_eq!(FFT1d::new(512).output_bins(), 257);
    }

    #[test]
    fn test_hann_window() {
        let w = hann_window(4);
        // Hann window for size 4: [0, 0.75, 0.75, 0]
        assert!((w[0]).abs() < 1e-6);
        assert!((w[1] - 0.75).abs() < 0.01);
        assert!((w[2] - 0.75).abs() < 0.01);
        assert!((w[3]).abs() < 1e-6);
    }
}