1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_tensor::Tensor;
21use rustfft::{FftPlanner, num_complex::Complex};
22
23use crate::module::Module;
24use crate::parameter::Parameter;
25
/// Magnitude-spectrum layer for real-valued 1-D signals.
///
/// Each signal is cut or zero-padded to `n_fft` samples, run through a forward
/// FFT, and reduced to the magnitudes of the first `n_fft / 2 + 1` bins.
#[derive(Debug, Clone)]
pub struct FFT1d {
    /// Transform length in samples.
    n_fft: usize,
    /// When `true`, every magnitude is multiplied by `1 / sqrt(n_fft)`.
    normalized: bool,
}
54
55impl FFT1d {
56 pub fn new(n_fft: usize) -> Self {
61 Self {
62 n_fft,
63 normalized: false,
64 }
65 }
66
67 pub fn with_normalization(n_fft: usize, normalized: bool) -> Self {
71 Self { n_fft, normalized }
72 }
73
74 pub fn output_bins(&self) -> usize {
76 self.n_fft / 2 + 1
77 }
78
79 fn fft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
81 let n = self.n_fft;
82 let n_out = n / 2 + 1;
83
84 let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
86 let copy_len = signal.len().min(n);
87 for i in 0..copy_len {
88 buffer[i] = Complex::new(signal[i], 0.0);
89 }
90
91 let mut planner = FftPlanner::new();
93 let fft = planner.plan_fft_forward(n);
94 fft.process(&mut buffer);
95
96 let norm_factor = if self.normalized {
98 1.0 / (n as f32).sqrt()
99 } else {
100 1.0
101 };
102
103 let mut magnitude = Vec::with_capacity(n_out);
104 for i in 0..n_out {
105 let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
106 magnitude.push(mag * norm_factor);
107 }
108
109 magnitude
110 }
111}
112
113impl Module for FFT1d {
114 fn forward(&self, input: &Variable) -> Variable {
115 let shape = input.shape();
116 let data = input.data().to_vec();
117 let n_out = self.output_bins();
118
119 match shape.len() {
120 2 => {
121 let batch = shape[0];
123 let time = shape[1];
124 let mut output = Vec::with_capacity(batch * n_out);
125
126 for b in 0..batch {
127 let start = b * time;
128 let end = start + time;
129 let signal = &data[start..end];
130 output.extend_from_slice(&self.fft_magnitude(signal));
131 }
132
133 Variable::new(
134 Tensor::from_vec(output, &[batch, n_out]).unwrap(),
135 input.requires_grad(),
136 )
137 }
138 3 => {
139 let batch = shape[0];
141 let channels = shape[1];
142 let time = shape[2];
143 let mut output = Vec::with_capacity(batch * channels * n_out);
144
145 for b in 0..batch {
146 for c in 0..channels {
147 let start = (b * channels + c) * time;
148 let end = start + time;
149 let signal = &data[start..end];
150 output.extend_from_slice(&self.fft_magnitude(signal));
151 }
152 }
153
154 Variable::new(
155 Tensor::from_vec(output, &[batch, channels, n_out]).unwrap(),
156 input.requires_grad(),
157 )
158 }
159 _ => panic!(
160 "FFT1d expects input of shape (batch, time) or (batch, channels, time), got {:?}",
161 shape
162 ),
163 }
164 }
165
166 fn parameters(&self) -> Vec<Parameter> {
167 Vec::new() }
169
170 fn named_parameters(&self) -> HashMap<String, Parameter> {
171 HashMap::new()
172 }
173
174 fn name(&self) -> &'static str {
175 "FFT1d"
176 }
177}
178
/// Short-time Fourier transform layer producing magnitude spectrograms.
///
/// A signal is split into frames of `n_fft` samples taken every `hop_length`
/// samples; each frame is multiplied by `window` before the FFT.
#[derive(Debug, Clone)]
pub struct STFT {
    /// Frame (and FFT) length in samples.
    n_fft: usize,
    /// Distance in samples between the starts of consecutive frames.
    hop_length: usize,
    /// Per-sample window coefficients applied to every frame.
    window: Vec<f32>,
    /// When `true`, every magnitude is multiplied by `1 / sqrt(n_fft)`.
    normalized: bool,
}
209
210impl STFT {
211 pub fn new(n_fft: usize, hop_length: usize) -> Self {
217 let window = hann_window(n_fft);
218 Self {
219 n_fft,
220 hop_length,
221 window,
222 normalized: false,
223 }
224 }
225
226 pub fn with_normalization(n_fft: usize, hop_length: usize, normalized: bool) -> Self {
228 let window = hann_window(n_fft);
229 Self {
230 n_fft,
231 hop_length,
232 window,
233 normalized,
234 }
235 }
236
237 pub fn output_bins(&self) -> usize {
239 self.n_fft / 2 + 1
240 }
241
242 pub fn n_frames(&self, signal_length: usize) -> usize {
244 if signal_length < self.n_fft {
245 1
246 } else {
247 (signal_length - self.n_fft) / self.hop_length + 1
248 }
249 }
250
251 fn stft_magnitude(&self, signal: &[f32]) -> Vec<f32> {
253 let n = self.n_fft;
254 let n_out = n / 2 + 1;
255 let n_frames = self.n_frames(signal.len());
256
257 let norm_factor = if self.normalized {
258 1.0 / (n as f32).sqrt()
259 } else {
260 1.0
261 };
262
263 let mut planner = FftPlanner::new();
264 let fft = planner.plan_fft_forward(n);
265
266 let mut output = Vec::with_capacity(n_frames * n_out);
267
268 for frame in 0..n_frames {
269 let start = frame * self.hop_length;
270
271 let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n];
273 for i in 0..n {
274 let idx = start + i;
275 let sample = if idx < signal.len() { signal[idx] } else { 0.0 };
276 buffer[i] = Complex::new(sample * self.window[i], 0.0);
277 }
278
279 fft.process(&mut buffer);
280
281 for i in 0..n_out {
282 let mag = (buffer[i].re * buffer[i].re + buffer[i].im * buffer[i].im).sqrt();
283 output.push(mag * norm_factor);
284 }
285 }
286
287 output
288 }
289}
290
291impl Module for STFT {
292 fn forward(&self, input: &Variable) -> Variable {
293 let shape = input.shape();
294 let data = input.data().to_vec();
295 let n_out = self.output_bins();
296
297 match shape.len() {
298 2 => {
299 let batch = shape[0];
301 let time = shape[1];
302 let n_frames = self.n_frames(time);
303 let mut output = Vec::with_capacity(batch * n_frames * n_out);
304
305 for b in 0..batch {
306 let start = b * time;
307 let end = start + time;
308 let signal = &data[start..end];
309 output.extend_from_slice(&self.stft_magnitude(signal));
310 }
311
312 Variable::new(
313 Tensor::from_vec(output, &[batch, n_frames, n_out]).unwrap(),
314 input.requires_grad(),
315 )
316 }
317 3 => {
318 let batch = shape[0];
320 let channels = shape[1];
321 let time = shape[2];
322 let n_frames = self.n_frames(time);
323 let mut output = Vec::with_capacity(batch * channels * n_frames * n_out);
324
325 for b in 0..batch {
326 for c in 0..channels {
327 let start = (b * channels + c) * time;
328 let end = start + time;
329 let signal = &data[start..end];
330 output.extend_from_slice(&self.stft_magnitude(signal));
331 }
332 }
333
334 Variable::new(
335 Tensor::from_vec(output, &[batch, channels, n_frames, n_out]).unwrap(),
336 input.requires_grad(),
337 )
338 }
339 _ => panic!(
340 "STFT expects input of shape (batch, time) or (batch, channels, time), got {:?}",
341 shape
342 ),
343 }
344 }
345
346 fn parameters(&self) -> Vec<Parameter> {
347 Vec::new()
348 }
349
350 fn named_parameters(&self) -> HashMap<String, Parameter> {
351 HashMap::new()
352 }
353
354 fn name(&self) -> &'static str {
355 "STFT"
356 }
357}
358
/// Builds a symmetric Hann window of `size` samples.
///
/// For `size >= 2`, sample `i` is `0.5 * (1 - cos(2 * pi * i / (size - 1)))`,
/// so the first and last coefficients are zero.
///
/// Degenerate sizes are handled explicitly: `0` yields an empty window and `1`
/// yields `[1.0]`. Evaluating the formula for `size == 1` would compute
/// `0.0 / 0.0` and produce a NaN coefficient.
fn hann_window(size: usize) -> Vec<f32> {
    match size {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            let last = (size - 1) as f32;
            (0..size)
                .map(|i| {
                    let phase = 2.0 * std::f32::consts::PI * i as f32 / last;
                    0.5 * (1.0 - phase.cos())
                })
                .collect()
        }
    }
}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// A `(2, 64)` input through a 64-point FFT yields `(2, 33)`.
    #[test]
    fn test_fft1d_shape_2d() {
        let layer = FFT1d::new(64);
        let signal = Variable::new(Tensor::from_vec(vec![0.0; 128], &[2, 64]).unwrap(), false);
        let spectrum = layer.forward(&signal);
        assert_eq!(spectrum.shape(), vec![2, 33]);
    }

    /// A `(2, 3, 128)` input through a 128-point FFT yields `(2, 3, 65)`.
    #[test]
    fn test_fft1d_shape_3d() {
        let layer = FFT1d::new(128);
        let signal = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 3 * 128], &[2, 3, 128]).unwrap(),
            false,
        );
        let spectrum = layer.forward(&signal);
        assert_eq!(spectrum.shape(), vec![2, 3, 65]);
    }

    /// A sine with 10 cycles over 64 samples must peak in bin 10.
    #[test]
    fn test_fft1d_known_sinusoid() {
        let n = 64;
        let freq = 10.0;
        let sample_rate = 64.0;
        let wave: Vec<f32> = (0..n)
            .map(|i| {
                let t = i as f32 / sample_rate;
                (2.0 * std::f32::consts::PI * freq * t).sin()
            })
            .collect();

        let layer = FFT1d::new(n);
        let signal = Variable::new(Tensor::from_vec(wave, &[1, n]).unwrap(), false);
        let magnitudes = layer.forward(&signal).data().to_vec();

        let strongest = magnitudes
            .iter()
            .enumerate()
            .max_by(|lhs, rhs| lhs.1.partial_cmp(rhs.1).unwrap())
            .unwrap()
            .0;
        assert_eq!(strongest, 10);
    }

    /// A 32-sample signal through a 128-point FFT still yields 65 bins.
    #[test]
    fn test_fft1d_zero_padding() {
        let layer = FFT1d::new(128);
        let signal = Variable::new(Tensor::from_vec(vec![1.0; 32], &[1, 32]).unwrap(), false);
        let spectrum = layer.forward(&signal);
        assert_eq!(spectrum.shape(), vec![1, 65]);
    }

    /// With `n_fft = 64`, raw and normalized magnitudes differ by `sqrt(64) = 8`.
    #[test]
    fn test_fft1d_normalized() {
        let scaled = FFT1d::with_normalization(64, true);
        let unscaled = FFT1d::new(64);

        let ones = vec![1.0; 64];
        let signal = Variable::new(Tensor::from_vec(ones, &[1, 64]).unwrap(), false);

        let scaled_out = scaled.forward(&signal).data().to_vec();
        let unscaled_out = unscaled.forward(&signal).data().to_vec();

        let ratio = unscaled_out[0] / scaled_out[0];
        assert!((ratio - 8.0).abs() < 0.01);
    }

    /// 1024 samples with `n_fft = 256`, `hop = 128` give 7 frames of 129 bins.
    #[test]
    fn test_stft_shape() {
        let layer = STFT::new(256, 128);
        let signal = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 1024], &[2, 1024]).unwrap(),
            false,
        );
        let spectrogram = layer.forward(&signal);

        let n_frames = layer.n_frames(1024);
        assert_eq!(spectrogram.shape(), vec![2, n_frames, 129]);
        assert_eq!(n_frames, 7);
    }

    /// A `(2, 3, 256)` input keeps its batch and channel axes in the output.
    #[test]
    fn test_stft_shape_3d() {
        let layer = STFT::new(64, 32);
        let signal = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 3 * 256], &[2, 3, 256]).unwrap(),
            false,
        );
        let spectrogram = layer.forward(&signal);

        let n_frames = layer.n_frames(256);
        assert_eq!(spectrogram.shape(), vec![2, 3, n_frames, 33]);
    }

    /// The STFT layer exposes no parameters.
    #[test]
    fn test_stft_no_parameters() {
        let layer = STFT::new(256, 128);
        assert_eq!(layer.parameters().len(), 0);
    }

    /// `output_bins` is `n_fft / 2 + 1`.
    #[test]
    fn test_fft1d_output_bins() {
        assert_eq!(FFT1d::new(64).output_bins(), 33);
        assert_eq!(FFT1d::new(256).output_bins(), 129);
        assert_eq!(FFT1d::new(512).output_bins(), 257);
    }

    /// A 4-point Hann window is `[0, 0.75, 0.75, 0]` within tolerance.
    #[test]
    fn test_hann_window() {
        let coeffs = hann_window(4);
        assert!((coeffs[0]).abs() < 1e-6);
        assert!((coeffs[1] - 0.75).abs() < 0.01);
        assert!((coeffs[2] - 0.75).abs() < 0.01);
        assert!((coeffs[3]).abs() < 1e-6);
    }
}