1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_tensor::Tensor;
21use rustfft::{FftPlanner, num_complex::Complex};
22
23use crate::module::Module;
24use crate::parameter::Parameter;
25
/// One-sided FFT magnitude spectrum computed along the last axis of its input.
///
/// Each signal is zero-padded (or truncated) to `n_fft` samples before the
/// transform, and the output holds `n_fft / 2 + 1` magnitude bins.
#[derive(Debug, Clone)]
pub struct FFT1d {
    /// Transform length: number of points in each FFT.
    n_fft: usize,
    /// When `true`, every magnitude is scaled by `1 / sqrt(n_fft)`.
    normalized: bool,
}
54
55impl FFT1d {
56 pub fn new(n_fft: usize) -> Self {
61 Self {
62 n_fft,
63 normalized: false,
64 }
65 }
66
67 pub fn with_normalization(n_fft: usize, normalized: bool) -> Self {
71 Self { n_fft, normalized }
72 }
73
74 pub fn output_bins(&self) -> usize {
76 self.n_fft / 2 + 1
77 }
78
79 fn fft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
81 let n = self.n_fft;
82 let n_out = n / 2 + 1;
83
84 let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
86 let copy_len = signal.len().min(n);
87 for i in 0..copy_len {
88 buffer[i] = Complex::new(signal[i], 0.0);
89 }
90
91 let mut planner = FftPlanner::new();
93 let fft = planner.plan_fft_forward(n);
94 fft.process(&mut buffer);
95
96 let norm_factor = if self.normalized {
98 1.0 / (n as f32).sqrt()
99 } else {
100 1.0
101 };
102
103 let mut magnitude = Vec::with_capacity(n_out);
104 for i in 0..n_out {
105 let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
106 magnitude.push(mag * norm_factor);
107 }
108
109 magnitude
110 }
111}
112
113impl Module for FFT1d {
114 fn forward(&self, input: &Variable) -> Variable {
115 let shape = input.shape();
116 let data = input.data().to_vec();
117 let n_out = self.output_bins();
118
119 match shape.len() {
120 2 => {
121 let batch = shape[0];
123 let time = shape[1];
124 let mut output = Vec::with_capacity(batch * n_out);
125
126 for b in 0..batch {
127 let start = b * time;
128 let end = start + time;
129 let signal = &data[start..end];
130 output.extend_from_slice(&self.fft_magnitude(signal));
131 }
132
133 Variable::new(
134 Tensor::from_vec(output, &[batch, n_out]).expect("tensor creation failed"),
135 input.requires_grad(),
136 )
137 }
138 3 => {
139 let batch = shape[0];
141 let channels = shape[1];
142 let time = shape[2];
143 let mut output = Vec::with_capacity(batch * channels * n_out);
144
145 for b in 0..batch {
146 for c in 0..channels {
147 let start = (b * channels + c) * time;
148 let end = start + time;
149 let signal = &data[start..end];
150 output.extend_from_slice(&self.fft_magnitude(signal));
151 }
152 }
153
154 Variable::new(
155 Tensor::from_vec(output, &[batch, channels, n_out])
156 .expect("tensor creation failed"),
157 input.requires_grad(),
158 )
159 }
160 _ => panic!(
161 "FFT1d expects input of shape (batch, time) or (batch, channels, time), got {:?}",
162 shape
163 ),
164 }
165 }
166
167 fn parameters(&self) -> Vec<Parameter> {
168 Vec::new() }
170
171 fn named_parameters(&self) -> HashMap<String, Parameter> {
172 HashMap::new()
173 }
174
175 fn name(&self) -> &'static str {
176 "FFT1d"
177 }
178}
179
/// Short-time Fourier transform magnitude (spectrogram) along the last axis.
///
/// Each frame of `n_fft` samples is multiplied by the stored analysis window
/// before its one-sided FFT magnitude is taken. Consecutive frames start
/// `hop_length` samples apart.
#[derive(Debug, Clone)]
pub struct STFT {
    /// Frame length and transform size, in samples.
    n_fft: usize,
    /// Distance in samples between the starts of consecutive frames.
    hop_length: usize,
    /// Analysis window applied element-wise to every frame.
    window: Vec<f32>,
    /// When `true`, every magnitude is scaled by `1 / sqrt(n_fft)`.
    normalized: bool,
}
210
211impl STFT {
212 pub fn new(n_fft: usize, hop_length: usize) -> Self {
218 let window = hann_window(n_fft);
219 Self {
220 n_fft,
221 hop_length,
222 window,
223 normalized: false,
224 }
225 }
226
227 pub fn with_normalization(n_fft: usize, hop_length: usize, normalized: bool) -> Self {
229 let window = hann_window(n_fft);
230 Self {
231 n_fft,
232 hop_length,
233 window,
234 normalized,
235 }
236 }
237
238 pub fn output_bins(&self) -> usize {
240 self.n_fft / 2 + 1
241 }
242
243 pub fn n_frames(&self, signal_length: usize) -> usize {
245 if signal_length < self.n_fft {
246 1
247 } else {
248 (signal_length - self.n_fft) / self.hop_length + 1
249 }
250 }
251
252 fn stft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
254 let n = self.n_fft;
255 let n_out = n / 2 + 1;
256 let n_frames = self.n_frames(signal.len());
257
258 let norm_factor = if self.normalized {
259 1.0 / (n as f32).sqrt()
260 } else {
261 1.0
262 };
263
264 let mut planner = FftPlanner::new();
265 let fft = planner.plan_fft_forward(n);
266
267 let mut output = Vec::with_capacity(n_frames * n_out);
268
269 for frame in 0..n_frames {
270 let start = frame * self.hop_length;
271
272 let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
274 for i in 0..n {
275 let idx = start + i;
276 let sample = if idx < signal.len() { signal[idx] } else { 0.0 };
277 buffer[i] = Complex::new(sample * self.window[i], 0.0);
278 }
279
280 fft.process(&mut buffer);
281
282 for i in 0..n_out {
283 let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
284 output.push(mag * norm_factor);
285 }
286 }
287
288 output
289 }
290}
291
292impl Module for STFT {
293 fn forward(&self, input: &Variable) -> Variable {
294 let shape = input.shape();
295 let data = input.data().to_vec();
296 let n_out = self.output_bins();
297
298 match shape.len() {
299 2 => {
300 let batch = shape[0];
302 let time = shape[1];
303 let n_frames = self.n_frames(time);
304 let mut output = Vec::with_capacity(batch * n_frames * n_out);
305
306 for b in 0..batch {
307 let start = b * time;
308 let end = start + time;
309 let signal = &data[start..end];
310 output.extend_from_slice(&self.stft_magnitude(signal));
311 }
312
313 Variable::new(
314 Tensor::from_vec(output, &[batch, n_frames, n_out])
315 .expect("tensor creation failed"),
316 input.requires_grad(),
317 )
318 }
319 3 => {
320 let batch = shape[0];
322 let channels = shape[1];
323 let time = shape[2];
324 let n_frames = self.n_frames(time);
325 let mut output = Vec::with_capacity(batch * channels * n_frames * n_out);
326
327 for b in 0..batch {
328 for c in 0..channels {
329 let start = (b * channels + c) * time;
330 let end = start + time;
331 let signal = &data[start..end];
332 output.extend_from_slice(&self.stft_magnitude(signal));
333 }
334 }
335
336 Variable::new(
337 Tensor::from_vec(output, &[batch, channels, n_frames, n_out])
338 .expect("tensor creation failed"),
339 input.requires_grad(),
340 )
341 }
342 _ => panic!(
343 "STFT expects input of shape (batch, time) or (batch, channels, time), got {:?}",
344 shape
345 ),
346 }
347 }
348
349 fn parameters(&self) -> Vec<Parameter> {
350 Vec::new()
351 }
352
353 fn named_parameters(&self) -> HashMap<String, Parameter> {
354 HashMap::new()
355 }
356
357 fn name(&self) -> &'static str {
358 "STFT"
359 }
360}
361
/// Builds a symmetric Hann window of `size` samples.
///
/// `w[i] = 0.5 * (1 - cos(2*pi*i / (size - 1)))`, so both endpoints are zero.
///
/// Degenerate sizes are handled explicitly:
/// - `size == 0` returns an empty vector.
/// - `size == 1` returns `[1.0]`, the single-point convention used by NumPy's
///   `hanning`. The general formula would divide by zero there and yield NaN.
fn hann_window(size: usize) -> Vec<f32> {
    match size {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            let denom = (size - 1) as f32;
            (0..size)
                .map(|i| {
                    let phase = 2.0 * std::f32::consts::PI * i as f32 / denom;
                    0.5 * (1.0 - phase.cos())
                })
                .collect()
        }
    }
}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `data` in a `Variable` of the given `shape` that does not track gradients.
    fn var(data: Vec<f32>, shape: &[usize]) -> Variable {
        Variable::new(
            Tensor::from_vec(data, shape).expect("tensor creation failed"),
            false,
        )
    }

    /// Index of the largest value in `values` (panics on empty input or NaN).
    fn argmax(values: &[f32]) -> usize {
        values
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0
    }

    /// `len` samples of a unit sine completing `cycles` periods every `period` samples.
    fn sine(len: usize, cycles: f32, period: f32) -> Vec<f32> {
        (0..len)
            .map(|i| (2.0 * std::f32::consts::PI * cycles * i as f32 / period).sin())
            .collect()
    }

    #[test]
    fn test_fft1d_shape_2d() {
        let fft = FFT1d::new(64);
        let output = fft.forward(&var(vec![0.0; 128], &[2, 64]));
        assert_eq!(output.shape(), vec![2, 33]);
    }

    #[test]
    fn test_fft1d_shape_3d() {
        let fft = FFT1d::new(128);
        let output = fft.forward(&var(vec![0.0; 2 * 3 * 128], &[2, 3, 128]));
        assert_eq!(output.shape(), vec![2, 3, 65]);
    }

    #[test]
    fn test_fft1d_known_sinusoid() {
        // 10 whole cycles across a 64-point transform put the peak exactly on bin 10.
        let n = 64;
        let fft = FFT1d::new(n);
        let output = fft.forward(&var(sine(n, 10.0, n as f32), &[1, n]));
        assert_eq!(argmax(&output.data().to_vec()), 10);
    }

    #[test]
    fn test_fft1d_zero_padding() {
        // A 32-sample signal is zero-padded up to the 128-point transform.
        let fft = FFT1d::new(128);
        let output = fft.forward(&var(vec![1.0; 32], &[1, 32]));
        assert_eq!(output.shape(), vec![1, 65]);
    }

    #[test]
    fn test_fft1d_normalized() {
        let x = var(vec![1.0; 64], &[1, 64]);
        let out_norm = FFT1d::with_normalization(64, true).forward(&x).data().to_vec();
        let out_raw = FFT1d::new(64).forward(&x).data().to_vec();
        // Normalization divides every magnitude by sqrt(64) = 8.
        let ratio = out_raw[0] / out_norm[0];
        assert!((ratio - 8.0).abs() < 0.01);
    }

    #[test]
    fn test_stft_shape() {
        let stft = STFT::new(256, 128);
        let output = stft.forward(&var(vec![0.0; 2 * 1024], &[2, 1024]));
        // (1024 - 256) / 128 + 1 = 7 complete frames.
        let n_frames = stft.n_frames(1024);
        assert_eq!(n_frames, 7);
        assert_eq!(output.shape(), vec![2, n_frames, 129]);
    }

    #[test]
    fn test_stft_shape_3d() {
        let stft = STFT::new(64, 32);
        let output = stft.forward(&var(vec![0.0; 2 * 3 * 256], &[2, 3, 256]));
        let n_frames = stft.n_frames(256);
        assert_eq!(output.shape(), vec![2, 3, n_frames, 33]);
    }

    #[test]
    fn test_stft_short_signal_single_frame() {
        // A signal shorter than n_fft is zero-padded into exactly one frame.
        let stft = STFT::new(64, 32);
        assert_eq!(stft.n_frames(16), 1);
        let output = stft.forward(&var(vec![1.0; 16], &[1, 16]));
        assert_eq!(output.shape(), vec![1, 1, 33]);
    }

    #[test]
    fn test_stft_known_sinusoid() {
        // 8 cycles per 64-sample frame: every frame's peak must land on bin 8.
        let n_fft = 64;
        let stft = STFT::new(n_fft, 32);
        let output = stft.forward(&var(sine(256, 8.0, n_fft as f32), &[1, 256]));
        let n_frames = stft.n_frames(256);
        assert_eq!(output.shape(), vec![1, n_frames, 33]);

        let spectrogram = output.data().to_vec();
        let frames: Vec<&[f32]> = spectrogram.chunks_exact(33).collect();
        assert_eq!(frames.len(), n_frames);
        for frame in frames {
            assert_eq!(argmax(frame), 8);
        }
    }

    #[test]
    fn test_stft_normalized() {
        let x = var(vec![1.0; 128], &[1, 128]);
        let raw = STFT::new(64, 32).forward(&x).data().to_vec();
        let norm = STFT::with_normalization(64, 32, true).forward(&x).data().to_vec();
        // Normalization divides every magnitude by sqrt(64) = 8.
        assert!((raw[0] / norm[0] - 8.0).abs() < 0.01);
    }

    #[test]
    fn test_stft_no_parameters() {
        let stft = STFT::new(256, 128);
        assert_eq!(stft.parameters().len(), 0);
    }

    #[test]
    fn test_fft1d_output_bins() {
        assert_eq!(FFT1d::new(64).output_bins(), 33);
        assert_eq!(FFT1d::new(256).output_bins(), 129);
        assert_eq!(FFT1d::new(512).output_bins(), 257);
    }

    #[test]
    fn test_hann_window() {
        // A symmetric Hann window of length 4 is [0, 0.75, 0.75, 0].
        let w = hann_window(4);
        assert!(w[0].abs() < 1e-6);
        assert!((w[1] - 0.75).abs() < 0.01);
        assert!((w[2] - 0.75).abs() < 0.01);
        assert!(w[3].abs() < 1e-6);
    }
}
517}