Skip to main content

axonml_nn/layers/
fft.rs

1//! FFT Layers - Fast Fourier Transform for Neural Networks
2//!
3//! # File
4//! `crates/axonml-nn/src/layers/fft.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr. — AutomataNexus LLC
8//! ORCID: 0009-0005-2158-7060
9//!
10//! # Updated
11//! April 14, 2026 11:15 PM EST
12//!
13//! # Disclaimer
14//! Use at own risk. This software is provided "as is", without warranty of any
15//! kind, express or implied. The author and AutomataNexus shall not be held
16//! liable for any damages arising from the use of this software.
17
18use std::collections::HashMap;
19
20use axonml_autograd::Variable;
21use axonml_tensor::Tensor;
22use rustfft::{FftPlanner, num_complex::Complex};
23
24use crate::module::Module;
25use crate::parameter::Parameter;
26
27// =============================================================================
28// FFT1d
29// =============================================================================
30
/// 1D Fast Fourier Transform layer.
///
/// Computes the magnitude spectrum of the input signal using an efficient
/// FFT algorithm. Returns the first `n_fft/2 + 1` frequency bins (the
/// non-redundant half of the spectrum of a real signal). Each signal along
/// the last (time) axis is zero-padded or truncated to `n_fft` samples before
/// the transform.
///
/// The layer has no learnable parameters.
///
/// # Input Shape
/// `(batch, channels, time)` or `(batch, time)` — panics on other ranks
///
/// # Output Shape
/// `(batch, channels, n_fft/2+1)` or `(batch, n_fft/2+1)`
///
/// # Gradients
/// NOTE(review): `forward` builds its output with `Variable::new`, copying only
/// `input.requires_grad()`. Presumably no gradient flows back to the input —
/// confirm before placing this layer upstream of trainable parameters.
///
/// # Example
/// ```ignore
/// use axonml_nn::layers::FFT1d;
///
/// let fft = FFT1d::new(256);
/// let signal = Variable::new(Tensor::randn(&[2, 1, 256]), true);
/// let spectrum = fft.forward(&signal);
/// // spectrum shape: (2, 1, 129)
/// ```
#[derive(Debug, Clone)]
pub struct FFT1d {
    /// FFT length; input signals are zero-padded or truncated to this many samples.
    n_fft: usize,
    /// When true, output magnitudes are scaled by `1 / sqrt(n_fft)`.
    normalized: bool,
}
55
56impl FFT1d {
57    /// Creates a new FFT1d layer.
58    ///
59    /// # Arguments
60    /// * `n_fft` - FFT size. Input will be zero-padded or truncated to this length.
61    pub fn new(n_fft: usize) -> Self {
62        Self {
63            n_fft,
64            normalized: false,
65        }
66    }
67
68    /// Creates an FFT1d layer with optional normalization.
69    ///
70    /// When normalized, the output is divided by `sqrt(n_fft)`.
71    pub fn with_normalization(n_fft: usize, normalized: bool) -> Self {
72        Self { n_fft, normalized }
73    }
74
75    /// Returns the number of output frequency bins.
76    pub fn output_bins(&self) -> usize {
77        self.n_fft / 2 + 1
78    }
79
80    /// Compute FFT magnitude for a single 1D signal.
81    fn fft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
82        let n = self.n_fft;
83        let n_out = n / 2 + 1;
84
85        // Prepare complex input (zero-pad or truncate)
86        let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
87        let copy_len = signal.len().min(n);
88        for i in 0..copy_len {
89            buffer[i] = Complex::new(signal[i], 0.0);
90        }
91
92        // Compute FFT
93        let mut planner = FftPlanner::new();
94        let fft = planner.plan_fft_forward(n);
95        fft.process(&mut buffer);
96
97        // Extract magnitude for positive frequencies
98        let norm_factor = if self.normalized {
99            1.0 / (n as f32).sqrt()
100        } else {
101            1.0
102        };
103
104        let mut magnitude = Vec::with_capacity(n_out);
105        for i in 0..n_out {
106            let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
107            magnitude.push(mag * norm_factor);
108        }
109
110        magnitude
111    }
112}
113
114impl Module for FFT1d {
115    fn forward(&self, input: &Variable) -> Variable {
116        let shape = input.shape();
117        let data = input.data().to_vec();
118        let n_out = self.output_bins();
119
120        match shape.len() {
121            2 => {
122                // (batch, time) → (batch, n_fft/2+1)
123                let batch = shape[0];
124                let time = shape[1];
125                let mut output = Vec::with_capacity(batch * n_out);
126
127                for b in 0..batch {
128                    let start = b * time;
129                    let end = start + time;
130                    let signal = &data[start..end];
131                    output.extend_from_slice(&self.fft_magnitude(signal));
132                }
133
134                Variable::new(
135                    Tensor::from_vec(output, &[batch, n_out]).expect("tensor creation failed"),
136                    input.requires_grad(),
137                )
138            }
139            3 => {
140                // (batch, channels, time) → (batch, channels, n_fft/2+1)
141                let batch = shape[0];
142                let channels = shape[1];
143                let time = shape[2];
144                let mut output = Vec::with_capacity(batch * channels * n_out);
145
146                for b in 0..batch {
147                    for c in 0..channels {
148                        let start = (b * channels + c) * time;
149                        let end = start + time;
150                        let signal = &data[start..end];
151                        output.extend_from_slice(&self.fft_magnitude(signal));
152                    }
153                }
154
155                Variable::new(
156                    Tensor::from_vec(output, &[batch, channels, n_out])
157                        .expect("tensor creation failed"),
158                    input.requires_grad(),
159                )
160            }
161            _ => panic!(
162                "FFT1d expects input of shape (batch, time) or (batch, channels, time), got {:?}",
163                shape
164            ),
165        }
166    }
167
168    fn parameters(&self) -> Vec<Parameter> {
169        Vec::new() // FFT has no learnable parameters
170    }
171
172    fn named_parameters(&self) -> HashMap<String, Parameter> {
173        HashMap::new()
174    }
175
176    fn name(&self) -> &'static str {
177        "FFT1d"
178    }
179}
180
181// =============================================================================
182// STFT
183// =============================================================================
184
/// Short-Time Fourier Transform layer.
///
/// Applies a sliding-window FFT to compute time-frequency magnitude
/// representations. Every frame of `n_fft` samples is multiplied by a
/// symmetric Hann window before the transform. The magnitudes of the first
/// `n_fft/2 + 1` bins are kept.
///
/// Frames start at sample 0 and advance by `hop_length`. No edge padding is
/// applied, so `n_frames = (time - n_fft) / hop_length + 1` when
/// `time >= n_fft`. A shorter signal yields a single zero-padded frame.
///
/// The layer has no learnable parameters.
///
/// # Input Shape
/// `(batch, channels, time)` or `(batch, time)` — panics on other ranks
///
/// # Output Shape
/// `(batch, channels, n_frames, n_fft/2+1)` or `(batch, n_frames, n_fft/2+1)`
///
/// # Gradients
/// NOTE(review): `forward` builds its output with `Variable::new`, copying only
/// `input.requires_grad()`. Presumably no gradient flows back to the input —
/// confirm before placing this layer upstream of trainable parameters.
///
/// # Example
/// ```ignore
/// use axonml_nn::layers::STFT;
///
/// let stft = STFT::new(256, 128); // n_fft=256, hop=128
/// let signal = Variable::new(Tensor::randn(&[2, 1, 1024]), false);
/// let spec = stft.forward(&signal);
/// // spec shape: (2, 1, 7, 129)
/// ```
#[derive(Debug, Clone)]
pub struct STFT {
    /// FFT window size (samples per frame).
    n_fft: usize,
    /// Number of samples between the starts of successive frames.
    hop_length: usize,
    /// Analysis window applied to every frame (built with length `n_fft`).
    window: Vec<f32>,
    /// When true, output magnitudes are scaled by `1 / sqrt(n_fft)`.
    normalized: bool,
}
211
212impl STFT {
213    /// Creates a new STFT layer with a Hann window.
214    ///
215    /// # Arguments
216    /// * `n_fft` - FFT window size
217    /// * `hop_length` - Number of samples between successive frames
218    pub fn new(n_fft: usize, hop_length: usize) -> Self {
219        let window = hann_window(n_fft);
220        Self {
221            n_fft,
222            hop_length,
223            window,
224            normalized: false,
225        }
226    }
227
228    /// Creates an STFT layer with normalization.
229    pub fn with_normalization(n_fft: usize, hop_length: usize, normalized: bool) -> Self {
230        let window = hann_window(n_fft);
231        Self {
232            n_fft,
233            hop_length,
234            window,
235            normalized,
236        }
237    }
238
239    /// Returns the number of output frequency bins.
240    pub fn output_bins(&self) -> usize {
241        self.n_fft / 2 + 1
242    }
243
244    /// Computes the number of frames for a given signal length.
245    pub fn n_frames(&self, signal_length: usize) -> usize {
246        if signal_length < self.n_fft {
247            1
248        } else {
249            (signal_length - self.n_fft) / self.hop_length + 1
250        }
251    }
252
253    /// Compute STFT magnitude for a single 1D signal.
254    fn stft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
255        let n = self.n_fft;
256        let n_out = n / 2 + 1;
257        let n_frames = self.n_frames(signal.len());
258
259        let norm_factor = if self.normalized {
260            1.0 / (n as f32).sqrt()
261        } else {
262            1.0
263        };
264
265        let mut planner = FftPlanner::new();
266        let fft = planner.plan_fft_forward(n);
267
268        let mut output = Vec::with_capacity(n_frames * n_out);
269
270        for frame in 0..n_frames {
271            let start = frame * self.hop_length;
272
273            // Apply window and create complex buffer
274            let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
275            for i in 0..n {
276                let idx = start + i;
277                let sample = if idx < signal.len() { signal[idx] } else { 0.0 };
278                buffer[i] = Complex::new(sample * self.window[i], 0.0);
279            }
280
281            fft.process(&mut buffer);
282
283            for i in 0..n_out {
284                let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
285                output.push(mag * norm_factor);
286            }
287        }
288
289        output
290    }
291}
292
293impl Module for STFT {
294    fn forward(&self, input: &Variable) -> Variable {
295        let shape = input.shape();
296        let data = input.data().to_vec();
297        let n_out = self.output_bins();
298
299        match shape.len() {
300            2 => {
301                // (batch, time) → (batch, n_frames, n_fft/2+1)
302                let batch = shape[0];
303                let time = shape[1];
304                let n_frames = self.n_frames(time);
305                let mut output = Vec::with_capacity(batch * n_frames * n_out);
306
307                for b in 0..batch {
308                    let start = b * time;
309                    let end = start + time;
310                    let signal = &data[start..end];
311                    output.extend_from_slice(&self.stft_magnitude(signal));
312                }
313
314                Variable::new(
315                    Tensor::from_vec(output, &[batch, n_frames, n_out])
316                        .expect("tensor creation failed"),
317                    input.requires_grad(),
318                )
319            }
320            3 => {
321                // (batch, channels, time) → (batch, channels, n_frames, n_fft/2+1)
322                let batch = shape[0];
323                let channels = shape[1];
324                let time = shape[2];
325                let n_frames = self.n_frames(time);
326                let mut output = Vec::with_capacity(batch * channels * n_frames * n_out);
327
328                for b in 0..batch {
329                    for c in 0..channels {
330                        let start = (b * channels + c) * time;
331                        let end = start + time;
332                        let signal = &data[start..end];
333                        output.extend_from_slice(&self.stft_magnitude(signal));
334                    }
335                }
336
337                Variable::new(
338                    Tensor::from_vec(output, &[batch, channels, n_frames, n_out])
339                        .expect("tensor creation failed"),
340                    input.requires_grad(),
341                )
342            }
343            _ => panic!(
344                "STFT expects input of shape (batch, time) or (batch, channels, time), got {:?}",
345                shape
346            ),
347        }
348    }
349
350    fn parameters(&self) -> Vec<Parameter> {
351        Vec::new()
352    }
353
354    fn named_parameters(&self) -> HashMap<String, Parameter> {
355        HashMap::new()
356    }
357
358    fn name(&self) -> &'static str {
359        "STFT"
360    }
361}
362
363// =============================================================================
364// Utility Functions
365// =============================================================================
366
/// Generates a symmetric Hann window of the given size.
///
/// `w[i] = 0.5 * (1 - cos(2*pi*i / (size - 1)))`, so both endpoints are zero
/// for `size >= 2`. A window of size 1 is `[1.0]` (the `size - 1` denominator
/// would otherwise be zero and produce `NaN`). Size 0 yields an empty vector.
fn hann_window(size: usize) -> Vec<f32> {
    match size {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            let denom = (size - 1) as f32;
            (0..size)
                .map(|i| {
                    let phase = 2.0 * std::f32::consts::PI * i as f32 / denom;
                    0.5 * (1.0 - phase.cos())
                })
                .collect()
        }
    }
}
376
377// =============================================================================
378// Tests
379// =============================================================================
380
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fft1d_shape_2d() {
        let fft = FFT1d::new(64);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 128], &[2, 64]).expect("tensor creation failed"),
            false,
        );
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![2, 33]); // n_fft/2+1 = 33
    }

    #[test]
    fn test_fft1d_shape_3d() {
        let fft = FFT1d::new(128);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 3 * 128], &[2, 3, 128]).expect("tensor creation failed"),
            false,
        );
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![2, 3, 65]); // n_fft/2+1 = 65
    }

    #[test]
    fn test_fft1d_known_sinusoid() {
        // Create a pure 10 Hz sinusoid sampled at 64 Hz
        let n = 64;
        let freq = 10.0;
        let sample_rate = 64.0;
        let signal: Vec<f32> = (0..n)
            .map(|i| {
                let t = i as f32 / sample_rate;
                (2.0 * std::f32::consts::PI * freq * t).sin()
            })
            .collect();

        let fft = FFT1d::new(n);
        let input = Variable::new(
            Tensor::from_vec(signal, &[1, n]).expect("tensor creation failed"),
            false,
        );
        let output = fft.forward(&input);
        let spectrum = output.data().to_vec();

        // The peak should be at bin 10 (freq * n / sample_rate = 10)
        let peak_bin = spectrum
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        assert_eq!(peak_bin, 10);
    }

    #[test]
    fn test_fft1d_zero_padding() {
        // Input shorter than n_fft gets zero-padded
        let fft = FFT1d::new(128);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[1, 32]).expect("tensor creation failed"),
            false,
        );
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![1, 65]);

        // The DC bin is the sum of the real samples; the padding contributes nothing.
        let spectrum = output.data().to_vec();
        assert!((spectrum[0] - 32.0).abs() < 1e-3);
    }

    #[test]
    fn test_fft1d_truncation() {
        // Input longer than n_fft is truncated to its first n_fft samples
        let fft = FFT1d::new(64);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 256], &[1, 256]).expect("tensor creation failed"),
            false,
        );
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![1, 33]);

        let spectrum = output.data().to_vec();
        assert!((spectrum[0] - 64.0).abs() < 1e-3);
        // A constant signal has no energy outside the DC bin.
        assert!(spectrum[1..].iter().all(|&m| m < 1e-3));
    }

    #[test]
    #[should_panic(expected = "FFT1d expects input")]
    fn test_fft1d_rejects_rank_1() {
        let fft = FFT1d::new(64);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 64], &[64]).expect("tensor creation failed"),
            false,
        );
        let _ = fft.forward(&input);
    }

    #[test]
    fn test_fft1d_normalized() {
        let fft_norm = FFT1d::with_normalization(64, true);
        let fft_raw = FFT1d::new(64);

        let signal = vec![1.0; 64];
        let input = Variable::new(
            Tensor::from_vec(signal, &[1, 64]).expect("tensor creation failed"),
            false,
        );

        let out_norm = fft_norm.forward(&input).data().to_vec();
        let out_raw = fft_raw.forward(&input).data().to_vec();

        // Normalized should be raw / sqrt(64) = raw / 8
        let ratio = out_raw[0] / out_norm[0];
        assert!((ratio - 8.0).abs() < 0.01);
    }

    #[test]
    fn test_stft_shape() {
        let stft = STFT::new(256, 128);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 1024], &[2, 1024]).expect("tensor creation failed"),
            false,
        );
        let output = stft.forward(&input);

        let n_frames = stft.n_frames(1024); // (1024 - 256) / 128 + 1 = 7
        assert_eq!(output.shape(), vec![2, n_frames, 129]);
        assert_eq!(n_frames, 7);
    }

    #[test]
    fn test_stft_shape_3d() {
        let stft = STFT::new(64, 32);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 3 * 256], &[2, 3, 256]).expect("tensor creation failed"),
            false,
        );
        let output = stft.forward(&input);

        let n_frames = stft.n_frames(256); // (256 - 64) / 32 + 1 = 7
        assert_eq!(output.shape(), vec![2, 3, n_frames, 33]);
        assert_eq!(n_frames, 7);
    }

    #[test]
    fn test_stft_short_signal_single_frame() {
        // A signal shorter than n_fft yields exactly one zero-padded frame
        let stft = STFT::new(256, 128);
        assert_eq!(stft.n_frames(100), 1);

        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 100], &[1, 100]).expect("tensor creation failed"),
            false,
        );
        let output = stft.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 129]);
    }

    #[test]
    fn test_stft_applies_hann_window() {
        // Hann(4) = [0, 0.75, 0.75, 0], so a constant-one frame has DC = 1.5
        let stft = STFT::new(4, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 8], &[1, 8]).expect("tensor creation failed"),
            false,
        );
        let output = stft.forward(&input);
        assert_eq!(output.shape(), vec![1, 3, 3]); // (8 - 4) / 2 + 1 = 3 frames

        let spec = output.data().to_vec();
        for frame in spec.chunks(3) {
            assert!((frame[0] - 1.5).abs() < 1e-4);
        }
    }

    #[test]
    fn test_stft_no_parameters() {
        let stft = STFT::new(256, 128);
        assert_eq!(stft.parameters().len(), 0);
    }

    #[test]
    fn test_fft1d_output_bins() {
        assert_eq!(FFT1d::new(64).output_bins(), 33);
        assert_eq!(FFT1d::new(256).output_bins(), 129);
        assert_eq!(FFT1d::new(512).output_bins(), 257);
    }

    #[test]
    fn test_hann_window() {
        let w = hann_window(4);
        // Hann window for size 4: [0, 0.75, 0.75, 0]
        assert!((w[0]).abs() < 1e-6);
        assert!((w[1] - 0.75).abs() < 0.01);
        assert!((w[2] - 0.75).abs() < 0.01);
        assert!((w[3]).abs() < 1e-6);
    }
}