1use std::collections::HashMap;
19
20use axonml_autograd::Variable;
21use axonml_tensor::Tensor;
22use rustfft::{FftPlanner, num_complex::Complex};
23
24use crate::module::Module;
25use crate::parameter::Parameter;
26
/// Magnitude-spectrum layer built on a one-dimensional forward FFT.
///
/// Every signal is truncated or zero-padded to `n_fft` samples before the
/// transform, and only the first `n_fft / 2 + 1` bins are kept.
#[derive(Debug, Clone)]
pub struct FFT1d {
    /// Transform length in samples.
    n_fft: usize,
    /// When `true`, magnitudes are scaled by `1 / sqrt(n_fft)`.
    normalized: bool,
}
55
56impl FFT1d {
57 pub fn new(n_fft: usize) -> Self {
62 Self {
63 n_fft,
64 normalized: false,
65 }
66 }
67
68 pub fn with_normalization(n_fft: usize, normalized: bool) -> Self {
72 Self { n_fft, normalized }
73 }
74
75 pub fn output_bins(&self) -> usize {
77 self.n_fft / 2 + 1
78 }
79
80 fn fft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
82 let n = self.n_fft;
83 let n_out = n / 2 + 1;
84
85 let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
87 let copy_len = signal.len().min(n);
88 for i in 0..copy_len {
89 buffer[i] = Complex::new(signal[i], 0.0);
90 }
91
92 let mut planner = FftPlanner::new();
94 let fft = planner.plan_fft_forward(n);
95 fft.process(&mut buffer);
96
97 let norm_factor = if self.normalized {
99 1.0 / (n as f32).sqrt()
100 } else {
101 1.0
102 };
103
104 let mut magnitude = Vec::with_capacity(n_out);
105 for i in 0..n_out {
106 let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
107 magnitude.push(mag * norm_factor);
108 }
109
110 magnitude
111 }
112}
113
114impl Module for FFT1d {
115 fn forward(&self, input: &Variable) -> Variable {
116 let shape = input.shape();
117 let data = input.data().to_vec();
118 let n_out = self.output_bins();
119
120 match shape.len() {
121 2 => {
122 let batch = shape[0];
124 let time = shape[1];
125 let mut output = Vec::with_capacity(batch * n_out);
126
127 for b in 0..batch {
128 let start = b * time;
129 let end = start + time;
130 let signal = &data[start..end];
131 output.extend_from_slice(&self.fft_magnitude(signal));
132 }
133
134 Variable::new(
135 Tensor::from_vec(output, &[batch, n_out]).expect("tensor creation failed"),
136 input.requires_grad(),
137 )
138 }
139 3 => {
140 let batch = shape[0];
142 let channels = shape[1];
143 let time = shape[2];
144 let mut output = Vec::with_capacity(batch * channels * n_out);
145
146 for b in 0..batch {
147 for c in 0..channels {
148 let start = (b * channels + c) * time;
149 let end = start + time;
150 let signal = &data[start..end];
151 output.extend_from_slice(&self.fft_magnitude(signal));
152 }
153 }
154
155 Variable::new(
156 Tensor::from_vec(output, &[batch, channels, n_out])
157 .expect("tensor creation failed"),
158 input.requires_grad(),
159 )
160 }
161 _ => panic!(
162 "FFT1d expects input of shape (batch, time) or (batch, channels, time), got {:?}",
163 shape
164 ),
165 }
166 }
167
168 fn parameters(&self) -> Vec<Parameter> {
169 Vec::new() }
171
172 fn named_parameters(&self) -> HashMap<String, Parameter> {
173 HashMap::new()
174 }
175
176 fn name(&self) -> &'static str {
177 "FFT1d"
178 }
179}
180
/// Short-time Fourier transform layer producing magnitude spectrograms.
///
/// The signal is cut into frames of `n_fft` samples spaced `hop_length`
/// samples apart; each frame is multiplied by `window` before the FFT.
#[derive(Debug, Clone)]
pub struct STFT {
    /// Frame (and transform) length in samples.
    n_fft: usize,
    /// Distance in samples between the starts of consecutive frames.
    hop_length: usize,
    /// Per-sample window coefficients applied to each frame.
    window: Vec<f32>,
    /// When `true`, magnitudes are scaled by `1 / sqrt(n_fft)`.
    normalized: bool,
}
211
212impl STFT {
213 pub fn new(n_fft: usize, hop_length: usize) -> Self {
219 let window = hann_window(n_fft);
220 Self {
221 n_fft,
222 hop_length,
223 window,
224 normalized: false,
225 }
226 }
227
228 pub fn with_normalization(n_fft: usize, hop_length: usize, normalized: bool) -> Self {
230 let window = hann_window(n_fft);
231 Self {
232 n_fft,
233 hop_length,
234 window,
235 normalized,
236 }
237 }
238
239 pub fn output_bins(&self) -> usize {
241 self.n_fft / 2 + 1
242 }
243
244 pub fn n_frames(&self, signal_length: usize) -> usize {
246 if signal_length < self.n_fft {
247 1
248 } else {
249 (signal_length - self.n_fft) / self.hop_length + 1
250 }
251 }
252
253 fn stft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
255 let n = self.n_fft;
256 let n_out = n / 2 + 1;
257 let n_frames = self.n_frames(signal.len());
258
259 let norm_factor = if self.normalized {
260 1.0 / (n as f32).sqrt()
261 } else {
262 1.0
263 };
264
265 let mut planner = FftPlanner::new();
266 let fft = planner.plan_fft_forward(n);
267
268 let mut output = Vec::with_capacity(n_frames * n_out);
269
270 for frame in 0..n_frames {
271 let start = frame * self.hop_length;
272
273 let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
275 for i in 0..n {
276 let idx = start + i;
277 let sample = if idx < signal.len() { signal[idx] } else { 0.0 };
278 buffer[i] = Complex::new(sample * self.window[i], 0.0);
279 }
280
281 fft.process(&mut buffer);
282
283 for i in 0..n_out {
284 let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
285 output.push(mag * norm_factor);
286 }
287 }
288
289 output
290 }
291}
292
293impl Module for STFT {
294 fn forward(&self, input: &Variable) -> Variable {
295 let shape = input.shape();
296 let data = input.data().to_vec();
297 let n_out = self.output_bins();
298
299 match shape.len() {
300 2 => {
301 let batch = shape[0];
303 let time = shape[1];
304 let n_frames = self.n_frames(time);
305 let mut output = Vec::with_capacity(batch * n_frames * n_out);
306
307 for b in 0..batch {
308 let start = b * time;
309 let end = start + time;
310 let signal = &data[start..end];
311 output.extend_from_slice(&self.stft_magnitude(signal));
312 }
313
314 Variable::new(
315 Tensor::from_vec(output, &[batch, n_frames, n_out])
316 .expect("tensor creation failed"),
317 input.requires_grad(),
318 )
319 }
320 3 => {
321 let batch = shape[0];
323 let channels = shape[1];
324 let time = shape[2];
325 let n_frames = self.n_frames(time);
326 let mut output = Vec::with_capacity(batch * channels * n_frames * n_out);
327
328 for b in 0..batch {
329 for c in 0..channels {
330 let start = (b * channels + c) * time;
331 let end = start + time;
332 let signal = &data[start..end];
333 output.extend_from_slice(&self.stft_magnitude(signal));
334 }
335 }
336
337 Variable::new(
338 Tensor::from_vec(output, &[batch, channels, n_frames, n_out])
339 .expect("tensor creation failed"),
340 input.requires_grad(),
341 )
342 }
343 _ => panic!(
344 "STFT expects input of shape (batch, time) or (batch, channels, time), got {:?}",
345 shape
346 ),
347 }
348 }
349
350 fn parameters(&self) -> Vec<Parameter> {
351 Vec::new()
352 }
353
354 fn named_parameters(&self) -> HashMap<String, Parameter> {
355 HashMap::new()
356 }
357
358 fn name(&self) -> &'static str {
359 "STFT"
360 }
361}
362
/// Builds a symmetric Hann window of `size` samples:
/// `w[i] = 0.5 * (1 - cos(2 * pi * i / (size - 1)))`.
///
/// Degenerate sizes are handled explicitly: `size == 0` yields an empty
/// window and `size == 1` yields `[1.0]`. Without that guard the single-point
/// case would evaluate `0.0 / 0.0` and produce `NaN`.
fn hann_window(size: usize) -> Vec<f32> {
    match size {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            let last_index = (size - 1) as f32;
            (0..size)
                .map(|i| {
                    let phase = 2.0 * std::f32::consts::PI * i as f32 / last_index;
                    0.5 * (1.0 - phase.cos())
                })
                .collect()
        }
    }
}
376
// Unit tests for the FFT1d and STFT layers and the Hann window helper.
#[cfg(test)]
mod tests {
    use super::*;

    /// A (batch, time) input yields (batch, n_fft / 2 + 1).
    #[test]
    fn test_fft1d_shape_2d() {
        let fft = FFT1d::new(64);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 128], &[2, 64]).expect("tensor creation failed"),
            false,
        );
        let output = fft.forward(&input);
        // 33 = 64 / 2 + 1
        assert_eq!(output.shape(), vec![2, 33]);
    }

    /// A (batch, channels, time) input keeps its leading axes.
    #[test]
    fn test_fft1d_shape_3d() {
        let fft = FFT1d::new(128);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 3 * 128], &[2, 3, 128]).expect("tensor creation failed"),
            false,
        );
        let output = fft.forward(&input);
        // 65 = 128 / 2 + 1
        assert_eq!(output.shape(), vec![2, 3, 65]);
    }

    /// A sine with 10 cycles across the 64-sample frame must peak in bin 10.
    #[test]
    fn test_fft1d_known_sinusoid() {
        let n = 64;
        let freq = 10.0;
        let sample_rate = 64.0;
        let signal: Vec<f32> = (0..n)
            .map(|i| {
                let t = i as f32 / sample_rate;
                (2.0 * std::f32::consts::PI * freq * t).sin()
            })
            .collect();

        let fft = FFT1d::new(n);
        let input = Variable::new(
            Tensor::from_vec(signal, &[1, n]).expect("tensor creation failed"),
            false,
        );
        let output = fft.forward(&input);
        let spectrum = output.data().to_vec();

        // Index of the largest magnitude in the spectrum.
        let peak_bin = spectrum
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        assert_eq!(peak_bin, 10);
    }

    /// A signal shorter than n_fft is zero-padded, so the bin count is unchanged.
    #[test]
    fn test_fft1d_zero_padding() {
        let fft = FFT1d::new(128);
        // Only 32 samples are supplied for a 128-point transform.
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[1, 32]).expect("tensor creation failed"),
            false,
        );
        let output = fft.forward(&input);
        assert_eq!(output.shape(), vec![1, 65]);
    }

    /// Normalization scales magnitudes by 1 / sqrt(n_fft).
    #[test]
    fn test_fft1d_normalized() {
        let fft_norm = FFT1d::with_normalization(64, true);
        let fft_raw = FFT1d::new(64);

        let signal = vec![1.0; 64];
        let input = Variable::new(
            Tensor::from_vec(signal, &[1, 64]).expect("tensor creation failed"),
            false,
        );

        let out_norm = fft_norm.forward(&input).data().to_vec();
        let out_raw = fft_raw.forward(&input).data().to_vec();

        // Raw / normalized on the DC bin should equal sqrt(64) = 8.
        let ratio = out_raw[0] / out_norm[0];
        assert!((ratio - 8.0).abs() < 0.01);
    }

    /// A (batch, time) input yields (batch, n_frames, n_fft / 2 + 1).
    #[test]
    fn test_stft_shape() {
        let stft = STFT::new(256, 128);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 1024], &[2, 1024]).expect("tensor creation failed"),
            false,
        );
        let output = stft.forward(&input);

        // (1024 - 256) / 128 + 1 = 7 frames, 256 / 2 + 1 = 129 bins.
        let n_frames = stft.n_frames(1024);
        assert_eq!(output.shape(), vec![2, n_frames, 129]);
        assert_eq!(n_frames, 7);
    }

    /// A (batch, channels, time) input keeps its leading axes.
    #[test]
    fn test_stft_shape_3d() {
        let stft = STFT::new(64, 32);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 3 * 256], &[2, 3, 256]).expect("tensor creation failed"),
            false,
        );
        let output = stft.forward(&input);

        // (256 - 64) / 32 + 1 = 7 frames, 64 / 2 + 1 = 33 bins.
        let n_frames = stft.n_frames(256);
        assert_eq!(output.shape(), vec![2, 3, n_frames, 33]);
    }

    /// The STFT layer exposes no learnable parameters.
    #[test]
    fn test_stft_no_parameters() {
        let stft = STFT::new(256, 128);
        assert_eq!(stft.parameters().len(), 0);
    }

    /// output_bins is n_fft / 2 + 1.
    #[test]
    fn test_fft1d_output_bins() {
        assert_eq!(FFT1d::new(64).output_bins(), 33);
        assert_eq!(FFT1d::new(256).output_bins(), 129);
        assert_eq!(FFT1d::new(512).output_bins(), 257);
    }

    /// A 4-point window is zero at both ends and 0.75 at the two inner points.
    #[test]
    fn test_hann_window() {
        let w = hann_window(4);
        assert!((w[0]).abs() < 1e-6);
        assert!((w[1] - 0.75).abs() < 0.01);
        assert!((w[2] - 0.75).abs() < 0.01);
        assert!((w[3]).abs() < 1e-6);
    }
}