1use std::collections::HashMap;
10
11use axonml_autograd::Variable;
12use axonml_tensor::Tensor;
13use rustfft::{FftPlanner, num_complex::Complex};
14
15use crate::module::Module;
16use crate::parameter::Parameter;
17
/// Magnitude spectrum of a real 1-D signal, computed with a single FFT.
///
/// Each input signal is zero-padded or truncated to `n_fft` samples and the
/// first `n_fft / 2 + 1` magnitude bins are returned. The layer has no
/// learnable parameters.
#[derive(Debug, Clone)]
pub struct FFT1d {
    /// FFT length; shorter signals are zero-padded, longer ones truncated.
    n_fft: usize,
    /// When `true`, magnitudes are scaled by `1 / sqrt(n_fft)`.
    normalized: bool,
}
46
47impl FFT1d {
48 pub fn new(n_fft: usize) -> Self {
53 Self {
54 n_fft,
55 normalized: false,
56 }
57 }
58
59 pub fn with_normalization(n_fft: usize, normalized: bool) -> Self {
63 Self { n_fft, normalized }
64 }
65
66 pub fn output_bins(&self) -> usize {
68 self.n_fft / 2 + 1
69 }
70
71 fn fft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
73 let n = self.n_fft;
74 let n_out = n / 2 + 1;
75
76 let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
78 let copy_len = signal.len().min(n);
79 for i in 0..copy_len {
80 buffer[i] = Complex::new(signal[i], 0.0);
81 }
82
83 let mut planner = FftPlanner::new();
85 let fft = planner.plan_fft_forward(n);
86 fft.process(&mut buffer);
87
88 let norm_factor = if self.normalized {
90 1.0 / (n as f32).sqrt()
91 } else {
92 1.0
93 };
94
95 let mut magnitude = Vec::with_capacity(n_out);
96 for i in 0..n_out {
97 let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
98 magnitude.push(mag * norm_factor);
99 }
100
101 magnitude
102 }
103}
104
105impl Module for FFT1d {
106 fn forward(&self, input: &Variable) -> Variable {
107 let shape = input.shape();
108 let data = input.data().to_vec();
109 let n_out = self.output_bins();
110
111 match shape.len() {
112 2 => {
113 let batch = shape[0];
115 let time = shape[1];
116 let mut output = Vec::with_capacity(batch * n_out);
117
118 for b in 0..batch {
119 let start = b * time;
120 let end = start + time;
121 let signal = &data[start..end];
122 output.extend_from_slice(&self.fft_magnitude(signal));
123 }
124
125 Variable::new(
126 Tensor::from_vec(output, &[batch, n_out]).unwrap(),
127 input.requires_grad(),
128 )
129 }
130 3 => {
131 let batch = shape[0];
133 let channels = shape[1];
134 let time = shape[2];
135 let mut output = Vec::with_capacity(batch * channels * n_out);
136
137 for b in 0..batch {
138 for c in 0..channels {
139 let start = (b * channels + c) * time;
140 let end = start + time;
141 let signal = &data[start..end];
142 output.extend_from_slice(&self.fft_magnitude(signal));
143 }
144 }
145
146 Variable::new(
147 Tensor::from_vec(output, &[batch, channels, n_out]).unwrap(),
148 input.requires_grad(),
149 )
150 }
151 _ => panic!(
152 "FFT1d expects input of shape (batch, time) or (batch, channels, time), got {:?}",
153 shape
154 ),
155 }
156 }
157
158 fn parameters(&self) -> Vec<Parameter> {
159 Vec::new() }
161
162 fn named_parameters(&self) -> HashMap<String, Parameter> {
163 HashMap::new()
164 }
165
166 fn name(&self) -> &'static str {
167 "FFT1d"
168 }
169}
170
/// Short-time Fourier transform magnitude (a linear-magnitude spectrogram).
///
/// Frames of `n_fft` samples start every `hop_length` samples. Each frame is
/// multiplied by a Hann window and transformed, and yields `n_fft / 2 + 1`
/// magnitude bins. Frames are not centred, and the signal is only padded
/// when it is shorter than one frame. The layer has no learnable parameters.
#[derive(Debug, Clone)]
pub struct STFT {
    /// Frame length and FFT size.
    n_fft: usize,
    /// Distance in samples between the starts of consecutive frames.
    hop_length: usize,
    /// Analysis window applied to every frame (length `n_fft`).
    window: Vec<f32>,
    /// When `true`, magnitudes are scaled by `1 / sqrt(n_fft)`.
    normalized: bool,
}
201
202impl STFT {
203 pub fn new(n_fft: usize, hop_length: usize) -> Self {
209 let window = hann_window(n_fft);
210 Self {
211 n_fft,
212 hop_length,
213 window,
214 normalized: false,
215 }
216 }
217
218 pub fn with_normalization(n_fft: usize, hop_length: usize, normalized: bool) -> Self {
220 let window = hann_window(n_fft);
221 Self {
222 n_fft,
223 hop_length,
224 window,
225 normalized,
226 }
227 }
228
229 pub fn output_bins(&self) -> usize {
231 self.n_fft / 2 + 1
232 }
233
234 pub fn n_frames(&self, signal_length: usize) -> usize {
236 if signal_length < self.n_fft {
237 1
238 } else {
239 (signal_length - self.n_fft) / self.hop_length + 1
240 }
241 }
242
243 fn stft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
245 let n = self.n_fft;
246 let n_out = n / 2 + 1;
247 let n_frames = self.n_frames(signal.len());
248
249 let norm_factor = if self.normalized {
250 1.0 / (n as f32).sqrt()
251 } else {
252 1.0
253 };
254
255 let mut planner = FftPlanner::new();
256 let fft = planner.plan_fft_forward(n);
257
258 let mut output = Vec::with_capacity(n_frames * n_out);
259
260 for frame in 0..n_frames {
261 let start = frame * self.hop_length;
262
263 let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
265 for i in 0..n {
266 let idx = start + i;
267 let sample = if idx < signal.len() { signal[idx] } else { 0.0 };
268 buffer[i] = Complex::new(sample * self.window[i], 0.0);
269 }
270
271 fft.process(&mut buffer);
272
273 for i in 0..n_out {
274 let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
275 output.push(mag * norm_factor);
276 }
277 }
278
279 output
280 }
281}
282
283impl Module for STFT {
284 fn forward(&self, input: &Variable) -> Variable {
285 let shape = input.shape();
286 let data = input.data().to_vec();
287 let n_out = self.output_bins();
288
289 match shape.len() {
290 2 => {
291 let batch = shape[0];
293 let time = shape[1];
294 let n_frames = self.n_frames(time);
295 let mut output = Vec::with_capacity(batch * n_frames * n_out);
296
297 for b in 0..batch {
298 let start = b * time;
299 let end = start + time;
300 let signal = &data[start..end];
301 output.extend_from_slice(&self.stft_magnitude(signal));
302 }
303
304 Variable::new(
305 Tensor::from_vec(output, &[batch, n_frames, n_out]).unwrap(),
306 input.requires_grad(),
307 )
308 }
309 3 => {
310 let batch = shape[0];
312 let channels = shape[1];
313 let time = shape[2];
314 let n_frames = self.n_frames(time);
315 let mut output = Vec::with_capacity(batch * channels * n_frames * n_out);
316
317 for b in 0..batch {
318 for c in 0..channels {
319 let start = (b * channels + c) * time;
320 let end = start + time;
321 let signal = &data[start..end];
322 output.extend_from_slice(&self.stft_magnitude(signal));
323 }
324 }
325
326 Variable::new(
327 Tensor::from_vec(output, &[batch, channels, n_frames, n_out]).unwrap(),
328 input.requires_grad(),
329 )
330 }
331 _ => panic!(
332 "STFT expects input of shape (batch, time) or (batch, channels, time), got {:?}",
333 shape
334 ),
335 }
336 }
337
338 fn parameters(&self) -> Vec<Parameter> {
339 Vec::new()
340 }
341
342 fn named_parameters(&self) -> HashMap<String, Parameter> {
343 HashMap::new()
344 }
345
346 fn name(&self) -> &'static str {
347 "STFT"
348 }
349}
350
/// Builds a symmetric Hann window of `size` samples.
///
/// `w[i] = 0.5 * (1 - cos(2π i / (size - 1)))`, so both endpoints are zero.
/// A length-1 window is `[1.0]`, because the formula would divide 0 by 0 and
/// yield `NaN`. A length-0 window is empty.
fn hann_window(size: usize) -> Vec<f32> {
    match size {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            let denom = (size - 1) as f32;
            (0..size)
                .map(|i| {
                    let phase = 2.0 * std::f32::consts::PI * i as f32 / denom;
                    0.5 * (1.0 - phase.cos())
                })
                .collect()
        }
    }
}
364
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fft1d_shape_2d() {
        let fft = FFT1d::new(64);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 128], &[2, 64]).unwrap(),
            false,
        );
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![2, 33]);
    }

    #[test]
    fn test_fft1d_shape_3d() {
        let fft = FFT1d::new(128);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 3 * 128], &[2, 3, 128]).unwrap(),
            false,
        );
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![2, 3, 65]);
    }

    #[test]
    fn test_fft1d_known_sinusoid() {
        // 10 full cycles over 64 samples lands exactly on bin 10.
        let n = 64;
        let freq = 10.0;
        let sample_rate = 64.0;
        let signal: Vec<f32> = (0..n)
            .map(|i| {
                let t = i as f32 / sample_rate;
                (2.0 * std::f32::consts::PI * freq * t).sin()
            })
            .collect();

        let fft = FFT1d::new(n);
        let input = Variable::new(Tensor::from_vec(signal, &[1, n]).unwrap(), false);
        let output = fft.forward(&input);
        let spectrum = output.data().to_vec();

        let peak_bin = spectrum
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .unwrap()
            .0;
        assert_eq!(peak_bin, 10);
    }

    #[test]
    fn test_fft1d_zero_padding() {
        // 32 samples padded up to n_fft = 128.
        let fft = FFT1d::new(128);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[1, 32]).unwrap(),
            false,
        );
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![1, 65]);

        // The DC bin is the plain sum of the samples; padding adds only zeros.
        let spectrum = output.data().to_vec();
        assert!((spectrum[0] - 32.0).abs() < 1e-3);
    }

    #[test]
    fn test_fft1d_normalized() {
        let fft_norm = FFT1d::with_normalization(64, true);
        let fft_raw = FFT1d::new(64);

        let signal = vec![1.0; 64];
        let input = Variable::new(Tensor::from_vec(signal, &[1, 64]).unwrap(), false);

        let out_norm = fft_norm.forward(&input).data().to_vec();
        let out_raw = fft_raw.forward(&input).data().to_vec();

        // Normalization divides by sqrt(64) = 8.
        let ratio = out_raw[0] / out_norm[0];
        assert!((ratio - 8.0).abs() < 0.01);
    }

    #[test]
    fn test_stft_shape() {
        let stft = STFT::new(256, 128);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 1024], &[2, 1024]).unwrap(),
            false,
        );
        let output = stft.forward(&input);

        let n_frames = stft.n_frames(1024);
        assert_eq!(output.shape(), vec![2, n_frames, 129]);
        assert_eq!(n_frames, 7);
    }

    #[test]
    fn test_stft_shape_3d() {
        let stft = STFT::new(64, 32);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 3 * 256], &[2, 3, 256]).unwrap(),
            false,
        );
        let output = stft.forward(&input);

        let n_frames = stft.n_frames(256);
        assert_eq!(output.shape(), vec![2, 3, n_frames, 33]);
    }

    #[test]
    fn test_stft_short_signal_single_frame() {
        // A signal shorter than n_fft produces exactly one zero-padded frame.
        let stft = STFT::new(64, 16);
        assert_eq!(stft.n_frames(10), 1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 10], &[1, 10]).unwrap(),
            false,
        );
        let output = stft.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 33]);
    }

    #[test]
    fn test_stft_known_sinusoid() {
        // 8 cycles per 64-sample frame: every frame should peak at bin 8.
        let n_fft = 64;
        let len = 256;
        let signal: Vec<f32> = (0..len)
            .map(|i| (2.0 * std::f32::consts::PI * 8.0 * i as f32 / n_fft as f32).sin())
            .collect();

        let stft = STFT::new(n_fft, 32);
        let input = Variable::new(Tensor::from_vec(signal, &[1, len]).unwrap(), false);
        let output = stft.forward(&input);
        let n_frames = stft.n_frames(len);
        assert_eq!(output.shape(), vec![1, n_frames, 33]);

        let spectrum = output.data().to_vec();
        for frame in spectrum.chunks(33) {
            let peak = frame
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.total_cmp(b.1))
                .unwrap()
                .0;
            assert_eq!(peak, 8);
        }
    }

    #[test]
    fn test_stft_no_parameters() {
        let stft = STFT::new(256, 128);
        assert_eq!(stft.parameters().len(), 0);
    }

    #[test]
    fn test_fft1d_output_bins() {
        assert_eq!(FFT1d::new(64).output_bins(), 33);
        assert_eq!(FFT1d::new(256).output_bins(), 129);
        assert_eq!(FFT1d::new(512).output_bins(), 257);
    }

    #[test]
    fn test_hann_window() {
        // Symmetric Hann: zero at both ends.
        let w = hann_window(4);
        assert!((w[0]).abs() < 1e-6);
        assert!((w[1] - 0.75).abs() < 0.01);
        assert!((w[2] - 0.75).abs() < 0.01);
        assert!((w[3]).abs() < 1e-6);

        let w8 = hann_window(8);
        for i in 0..8 {
            assert!((w8[i] - w8[7 - i]).abs() < 1e-5);
        }
    }
}