Skip to main content

axonml_nn/layers/
fft.rs

1//! FFT Layers - Fast Fourier Transform for Neural Networks
2//!
3//! # File
4//! `crates/axonml-nn/src/layers/fft.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_tensor::Tensor;
21use rustfft::{FftPlanner, num_complex::Complex};
22
23use crate::module::Module;
24use crate::parameter::Parameter;
25
26// =============================================================================
27// FFT1d
28// =============================================================================
29
/// 1D Fast Fourier Transform layer.
///
/// Computes the magnitude spectrum of the input signal using an efficient
/// FFT algorithm. Returns the first `n_fft/2 + 1` frequency bins (real FFT).
///
/// The layer is cheap to copy and print: it derives [`Clone`] and [`Debug`]
/// so it can be duplicated into several model branches and shown in
/// diagnostics.
///
/// # Input Shape
/// `(batch, channels, time)` or `(batch, time)` — panics on other ranks
///
/// # Output Shape
/// `(batch, channels, n_fft/2+1)` or `(batch, n_fft/2+1)`
///
/// # Example
/// ```ignore
/// use axonml_nn::layers::FFT1d;
///
/// let fft = FFT1d::new(256);
/// let signal = Variable::new(Tensor::randn(&[2, 1, 256]), true);
/// let spectrum = fft.forward(&signal);
/// // spectrum shape: (2, 1, 129)
/// ```
#[derive(Debug, Clone)]
pub struct FFT1d {
    /// Transform length; inputs are zero-padded or truncated to this size.
    n_fft: usize,
    /// When `true`, magnitudes are scaled by `1 / sqrt(n_fft)`.
    normalized: bool,
}
54
55impl FFT1d {
56    /// Creates a new FFT1d layer.
57    ///
58    /// # Arguments
59    /// * `n_fft` - FFT size. Input will be zero-padded or truncated to this length.
60    pub fn new(n_fft: usize) -> Self {
61        Self {
62            n_fft,
63            normalized: false,
64        }
65    }
66
67    /// Creates an FFT1d layer with optional normalization.
68    ///
69    /// When normalized, the output is divided by `sqrt(n_fft)`.
70    pub fn with_normalization(n_fft: usize, normalized: bool) -> Self {
71        Self { n_fft, normalized }
72    }
73
74    /// Returns the number of output frequency bins.
75    pub fn output_bins(&self) -> usize {
76        self.n_fft / 2 + 1
77    }
78
79    /// Compute FFT magnitude for a single 1D signal.
80    fn fft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
81        let n = self.n_fft;
82        let n_out = n / 2 + 1;
83
84        // Prepare complex input (zero-pad or truncate)
85        let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
86        let copy_len = signal.len().min(n);
87        for i in 0..copy_len {
88            buffer[i] = Complex::new(signal[i], 0.0);
89        }
90
91        // Compute FFT
92        let mut planner = FftPlanner::new();
93        let fft = planner.plan_fft_forward(n);
94        fft.process(&mut buffer);
95
96        // Extract magnitude for positive frequencies
97        let norm_factor = if self.normalized {
98            1.0 / (n as f32).sqrt()
99        } else {
100            1.0
101        };
102
103        let mut magnitude = Vec::with_capacity(n_out);
104        for i in 0..n_out {
105            let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
106            magnitude.push(mag * norm_factor);
107        }
108
109        magnitude
110    }
111}
112
113impl Module for FFT1d {
114    fn forward(&self, input: &Variable) -> Variable {
115        let shape = input.shape();
116        let data = input.data().to_vec();
117        let n_out = self.output_bins();
118
119        match shape.len() {
120            2 => {
121                // (batch, time) → (batch, n_fft/2+1)
122                let batch = shape[0];
123                let time = shape[1];
124                let mut output = Vec::with_capacity(batch * n_out);
125
126                for b in 0..batch {
127                    let start = b * time;
128                    let end = start + time;
129                    let signal = &data[start..end];
130                    output.extend_from_slice(&self.fft_magnitude(signal));
131                }
132
133                Variable::new(
134                    Tensor::from_vec(output, &[batch, n_out]).expect("tensor creation failed"),
135                    input.requires_grad(),
136                )
137            }
138            3 => {
139                // (batch, channels, time) → (batch, channels, n_fft/2+1)
140                let batch = shape[0];
141                let channels = shape[1];
142                let time = shape[2];
143                let mut output = Vec::with_capacity(batch * channels * n_out);
144
145                for b in 0..batch {
146                    for c in 0..channels {
147                        let start = (b * channels + c) * time;
148                        let end = start + time;
149                        let signal = &data[start..end];
150                        output.extend_from_slice(&self.fft_magnitude(signal));
151                    }
152                }
153
154                Variable::new(
155                    Tensor::from_vec(output, &[batch, channels, n_out])
156                        .expect("tensor creation failed"),
157                    input.requires_grad(),
158                )
159            }
160            _ => panic!(
161                "FFT1d expects input of shape (batch, time) or (batch, channels, time), got {:?}",
162                shape
163            ),
164        }
165    }
166
167    fn parameters(&self) -> Vec<Parameter> {
168        Vec::new() // FFT has no learnable parameters
169    }
170
171    fn named_parameters(&self) -> HashMap<String, Parameter> {
172        HashMap::new()
173    }
174
175    fn name(&self) -> &'static str {
176        "FFT1d"
177    }
178}
179
180// =============================================================================
181// STFT
182// =============================================================================
183
/// Short-Time Fourier Transform layer.
///
/// Applies a sliding window FFT to compute time-frequency representations.
/// Uses a Hann window by default.
///
/// The layer derives [`Clone`] and [`Debug`] so it can be duplicated and
/// shown in diagnostics.
///
/// # Input Shape
/// `(batch, channels, time)` or `(batch, time)` — panics on other ranks
///
/// # Output Shape
/// `(batch, channels, n_frames, n_fft/2+1)` or `(batch, n_frames, n_fft/2+1)`
///
/// # Example
/// ```ignore
/// use axonml_nn::layers::STFT;
///
/// let stft = STFT::new(256, 128); // n_fft=256, hop=128
/// let signal = Variable::new(Tensor::randn(&[2, 1, 1024]), false);
/// let spec = stft.forward(&signal);
/// // spec shape: (2, 1, 7, 129)
/// ```
#[derive(Debug, Clone)]
pub struct STFT {
    /// FFT size of every analysis frame.
    n_fft: usize,
    /// Number of samples between the starts of successive frames.
    hop_length: usize,
    /// Analysis window coefficients, one per sample of a frame.
    window: Vec<f32>,
    /// When `true`, magnitudes are scaled by `1 / sqrt(n_fft)`.
    normalized: bool,
}
210
211impl STFT {
212    /// Creates a new STFT layer with a Hann window.
213    ///
214    /// # Arguments
215    /// * `n_fft` - FFT window size
216    /// * `hop_length` - Number of samples between successive frames
217    pub fn new(n_fft: usize, hop_length: usize) -> Self {
218        let window = hann_window(n_fft);
219        Self {
220            n_fft,
221            hop_length,
222            window,
223            normalized: false,
224        }
225    }
226
227    /// Creates an STFT layer with normalization.
228    pub fn with_normalization(n_fft: usize, hop_length: usize, normalized: bool) -> Self {
229        let window = hann_window(n_fft);
230        Self {
231            n_fft,
232            hop_length,
233            window,
234            normalized,
235        }
236    }
237
238    /// Returns the number of output frequency bins.
239    pub fn output_bins(&self) -> usize {
240        self.n_fft / 2 + 1
241    }
242
243    /// Computes the number of frames for a given signal length.
244    pub fn n_frames(&self, signal_length: usize) -> usize {
245        if signal_length < self.n_fft {
246            1
247        } else {
248            (signal_length - self.n_fft) / self.hop_length + 1
249        }
250    }
251
252    /// Compute STFT magnitude for a single 1D signal.
253    fn stft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
254        let n = self.n_fft;
255        let n_out = n / 2 + 1;
256        let n_frames = self.n_frames(signal.len());
257
258        let norm_factor = if self.normalized {
259            1.0 / (n as f32).sqrt()
260        } else {
261            1.0
262        };
263
264        let mut planner = FftPlanner::new();
265        let fft = planner.plan_fft_forward(n);
266
267        let mut output = Vec::with_capacity(n_frames * n_out);
268
269        for frame in 0..n_frames {
270            let start = frame * self.hop_length;
271
272            // Apply window and create complex buffer
273            let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
274            for i in 0..n {
275                let idx = start + i;
276                let sample = if idx < signal.len() { signal[idx] } else { 0.0 };
277                buffer[i] = Complex::new(sample * self.window[i], 0.0);
278            }
279
280            fft.process(&mut buffer);
281
282            for i in 0..n_out {
283                let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
284                output.push(mag * norm_factor);
285            }
286        }
287
288        output
289    }
290}
291
292impl Module for STFT {
293    fn forward(&self, input: &Variable) -> Variable {
294        let shape = input.shape();
295        let data = input.data().to_vec();
296        let n_out = self.output_bins();
297
298        match shape.len() {
299            2 => {
300                // (batch, time) → (batch, n_frames, n_fft/2+1)
301                let batch = shape[0];
302                let time = shape[1];
303                let n_frames = self.n_frames(time);
304                let mut output = Vec::with_capacity(batch * n_frames * n_out);
305
306                for b in 0..batch {
307                    let start = b * time;
308                    let end = start + time;
309                    let signal = &data[start..end];
310                    output.extend_from_slice(&self.stft_magnitude(signal));
311                }
312
313                Variable::new(
314                    Tensor::from_vec(output, &[batch, n_frames, n_out])
315                        .expect("tensor creation failed"),
316                    input.requires_grad(),
317                )
318            }
319            3 => {
320                // (batch, channels, time) → (batch, channels, n_frames, n_fft/2+1)
321                let batch = shape[0];
322                let channels = shape[1];
323                let time = shape[2];
324                let n_frames = self.n_frames(time);
325                let mut output = Vec::with_capacity(batch * channels * n_frames * n_out);
326
327                for b in 0..batch {
328                    for c in 0..channels {
329                        let start = (b * channels + c) * time;
330                        let end = start + time;
331                        let signal = &data[start..end];
332                        output.extend_from_slice(&self.stft_magnitude(signal));
333                    }
334                }
335
336                Variable::new(
337                    Tensor::from_vec(output, &[batch, channels, n_frames, n_out])
338                        .expect("tensor creation failed"),
339                    input.requires_grad(),
340                )
341            }
342            _ => panic!(
343                "STFT expects input of shape (batch, time) or (batch, channels, time), got {:?}",
344                shape
345            ),
346        }
347    }
348
349    fn parameters(&self) -> Vec<Parameter> {
350        Vec::new()
351    }
352
353    fn named_parameters(&self) -> HashMap<String, Parameter> {
354        HashMap::new()
355    }
356
357    fn name(&self) -> &'static str {
358        "STFT"
359    }
360}
361
362// =============================================================================
363// Utility Functions
364// =============================================================================
365
/// Generates a symmetric Hann window of the given size.
///
/// Coefficients follow `0.5 * (1 - cos(2π·i / (size - 1)))`, so a window of
/// two or more points starts and ends at zero.
///
/// Degenerate sizes are handled explicitly:
/// * `size == 0` returns an empty window.
/// * `size == 1` returns `[1.0]`. The general formula would divide zero by
///   zero here and produce `NaN`, which would then contaminate every value
///   multiplied by the window.
fn hann_window(size: usize) -> Vec<f32> {
    match size {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            // Safe: `size >= 2`, so the denominator is at least 1.
            let last = (size - 1) as f32;
            (0..size)
                .map(|i| {
                    let phase = 2.0 * std::f32::consts::PI * i as f32 / last;
                    0.5 * (1.0 - phase.cos())
                })
                .collect()
        }
    }
}
375
376// =============================================================================
377// Tests
378// =============================================================================
379
#[cfg(test)]
mod tests {
    use super::*;

    /// `(batch, time)` input produces `(batch, n_fft/2 + 1)` output.
    #[test]
    fn test_fft1d_shape_2d() {
        let zeros = Tensor::from_vec(vec![0.0; 128], &[2, 64]).expect("tensor creation failed");
        let signal = Variable::new(zeros, false);

        let spectrum = FFT1d::new(64).forward(&signal);

        // 64 / 2 + 1 = 33 bins
        assert_eq!(spectrum.shape(), vec![2, 33]);
    }

    /// `(batch, channels, time)` input keeps both leading dimensions.
    #[test]
    fn test_fft1d_shape_3d() {
        let zeros =
            Tensor::from_vec(vec![0.0; 2 * 3 * 128], &[2, 3, 128]).expect("tensor creation failed");
        let signal = Variable::new(zeros, false);

        let spectrum = FFT1d::new(128).forward(&signal);

        // 128 / 2 + 1 = 65 bins
        assert_eq!(spectrum.shape(), vec![2, 3, 65]);
    }

    /// A pure tone must peak in the bin that matches its frequency.
    #[test]
    fn test_fft1d_known_sinusoid() {
        // 10 Hz sine sampled at 64 Hz for 64 samples → one bin per Hz.
        let n = 64;
        let freq = 10.0;
        let sample_rate = 64.0;
        let samples: Vec<f32> = (0..n)
            .map(|i| {
                let t = i as f32 / sample_rate;
                (2.0 * std::f32::consts::PI * freq * t).sin()
            })
            .collect();

        let layer = FFT1d::new(n);
        let tone = Variable::new(
            Tensor::from_vec(samples, &[1, n]).expect("tensor creation failed"),
            false,
        );
        let bins = layer.forward(&tone).data().to_vec();

        // Expected peak index: freq * n / sample_rate = 10
        let (peak_bin, _) = bins
            .iter()
            .enumerate()
            .max_by(|lhs, rhs| lhs.1.partial_cmp(rhs.1).unwrap())
            .unwrap();
        assert_eq!(peak_bin, 10);
    }

    /// Inputs shorter than `n_fft` are zero-padded, not rejected.
    #[test]
    fn test_fft1d_zero_padding() {
        let ones = Tensor::from_vec(vec![1.0; 32], &[1, 32]).expect("tensor creation failed");
        let short_signal = Variable::new(ones, false);

        let spectrum = FFT1d::new(128).forward(&short_signal);

        assert_eq!(spectrum.shape(), vec![1, 65]);
    }

    /// Normalization divides the raw magnitudes by `sqrt(n_fft)`.
    #[test]
    fn test_fft1d_normalized() {
        let normalized_layer = FFT1d::with_normalization(64, true);
        let raw_layer = FFT1d::new(64);

        let constant = Variable::new(
            Tensor::from_vec(vec![1.0; 64], &[1, 64]).expect("tensor creation failed"),
            false,
        );

        let normalized_bins = normalized_layer.forward(&constant).data().to_vec();
        let raw_bins = raw_layer.forward(&constant).data().to_vec();

        // sqrt(64) = 8, so the DC bins should differ by that factor.
        let ratio = raw_bins[0] / normalized_bins[0];
        assert!((ratio - 8.0).abs() < 0.01);
    }

    /// 2D STFT output is `(batch, n_frames, n_fft/2 + 1)`.
    #[test]
    fn test_stft_shape() {
        let layer = STFT::new(256, 128);
        let zeros =
            Tensor::from_vec(vec![0.0; 2 * 1024], &[2, 1024]).expect("tensor creation failed");
        let spectrogram = layer.forward(&Variable::new(zeros, false));

        // (1024 - 256) / 128 + 1 = 7
        let n_frames = layer.n_frames(1024);
        assert_eq!(spectrogram.shape(), vec![2, n_frames, 129]);
        assert_eq!(n_frames, 7);
    }

    /// 3D STFT output is `(batch, channels, n_frames, n_fft/2 + 1)`.
    #[test]
    fn test_stft_shape_3d() {
        let layer = STFT::new(64, 32);
        let zeros =
            Tensor::from_vec(vec![0.0; 2 * 3 * 256], &[2, 3, 256]).expect("tensor creation failed");
        let spectrogram = layer.forward(&Variable::new(zeros, false));

        // (256 - 64) / 32 + 1 = 7
        let n_frames = layer.n_frames(256);
        assert_eq!(spectrogram.shape(), vec![2, 3, n_frames, 33]);
    }

    /// The STFT layer exposes no learnable parameters.
    #[test]
    fn test_stft_no_parameters() {
        let layer = STFT::new(256, 128);
        assert_eq!(layer.parameters().len(), 0);
    }

    /// `output_bins` is `n_fft / 2 + 1` for several sizes.
    #[test]
    fn test_fft1d_output_bins() {
        assert_eq!(FFT1d::new(64).output_bins(), 33);
        assert_eq!(FFT1d::new(256).output_bins(), 129);
        assert_eq!(FFT1d::new(512).output_bins(), 257);
    }

    /// A 4-point Hann window is approximately `[0, 0.75, 0.75, 0]`.
    #[test]
    fn test_hann_window() {
        let window = hann_window(4);
        assert!((window[0]).abs() < 1e-6);
        assert!((window[1] - 0.75).abs() < 0.01);
        assert!((window[2] - 0.75).abs() < 0.01);
        assert!((window[3]).abs() < 1e-6);
    }
}