1use std::sync::Arc;
7
8use num_complex::Complex as NumComplex;
9use rustfft::{Fft, FftPlanner};
10
11use crate::error::{AudioFftError, Result};
12use crate::messages::Complex;
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq)]
16pub enum WindowFunction {
17 Rectangular,
19 Hann,
21 Hamming,
23 Blackman,
25 BlackmanHarris,
27 Kaiser(u8), }
30
31impl WindowFunction {
32 pub fn generate(&self, size: usize) -> Vec<f32> {
34 let n = size as f32;
35 (0..size)
36 .map(|i| {
37 let x = i as f32;
38 match self {
39 Self::Rectangular => 1.0,
40 Self::Hann => 0.5 * (1.0 - (2.0 * std::f32::consts::PI * x / n).cos()),
41 Self::Hamming => 0.54 - 0.46 * (2.0 * std::f32::consts::PI * x / n).cos(),
42 Self::Blackman => {
43 let a0 = 0.42;
44 let a1 = 0.5;
45 let a2 = 0.08;
46 a0 - a1 * (2.0 * std::f32::consts::PI * x / n).cos()
47 + a2 * (4.0 * std::f32::consts::PI * x / n).cos()
48 }
49 Self::BlackmanHarris => {
50 let a0 = 0.35875;
51 let a1 = 0.48829;
52 let a2 = 0.14128;
53 let a3 = 0.01168;
54 a0 - a1 * (2.0 * std::f32::consts::PI * x / n).cos()
55 + a2 * (4.0 * std::f32::consts::PI * x / n).cos()
56 - a3 * (6.0 * std::f32::consts::PI * x / n).cos()
57 }
58 Self::Kaiser(beta_10) => {
59 let beta = *beta_10 as f32 / 10.0;
60 let alpha = (n - 1.0) / 2.0;
61 let r = (x - alpha) / alpha;
62 bessel_i0(beta * (1.0 - r * r).sqrt()) / bessel_i0(beta)
63 }
64 }
65 })
66 .collect()
67 }
68
69 pub fn coherent_gain(&self) -> f32 {
71 match self {
72 Self::Rectangular => 1.0,
73 Self::Hann => 0.5,
74 Self::Hamming => 0.54,
75 Self::Blackman => 0.42,
76 Self::BlackmanHarris => 0.35875,
77 Self::Kaiser(_) => 0.5, }
79 }
80}
81
/// Modified Bessel function of the first kind, order zero: `I0(x)`.
///
/// Sums the power series `Σ_k ((x/2)^(2k)) / (k!)^2`, building each term from
/// the previous one. Summation stops once a term is too small to change the
/// `f32` accumulator, so the number of terms adapts to the size of `x`
/// instead of being cut off at a fixed count.
///
/// The function is even in `x`. Arguments large enough to overflow `f32`
/// return infinity; a NaN argument returns NaN.
fn bessel_i0(x: f32) -> f32 {
    // Safety net only: finite inputs converge long before this, and an
    // overflowing term satisfies the convergence test as soon as it appears.
    const MAX_TERMS: u32 = 512;

    let quarter_sq = x * x / 4.0;
    let mut sum = 1.0f32;
    let mut term = 1.0f32;

    for k in 1..=MAX_TERMS {
        let k = k as f32;
        term *= quarter_sq / (k * k);
        sum += term;
        // A term below a quarter of the relative epsilon is under half an ulp
        // of `sum`, so neither it nor any later (smaller) term can alter it.
        if term <= sum * (f32::EPSILON / 4.0) {
            break;
        }
    }
    sum
}
97
98pub struct FftProcessor {
100 fft_size: usize,
102 hop_size: usize,
104 sample_rate: u32,
106 #[allow(dead_code)]
108 window: WindowFunction,
109 window_coeffs: Vec<f32>,
111 fft: Arc<dyn Fft<f32>>,
113 scratch: Vec<NumComplex<f32>>,
115 input_buffer: Vec<f32>,
117}
118
119impl FftProcessor {
120 pub fn new(fft_size: usize, hop_size: usize, sample_rate: u32) -> Result<Self> {
122 Self::with_window(fft_size, hop_size, sample_rate, WindowFunction::Hann)
123 }
124
125 pub fn with_window(
127 fft_size: usize,
128 hop_size: usize,
129 sample_rate: u32,
130 window: WindowFunction,
131 ) -> Result<Self> {
132 if !fft_size.is_power_of_two() {
133 return Err(AudioFftError::config(format!(
134 "FFT size must be power of 2, got {}",
135 fft_size
136 )));
137 }
138
139 if hop_size > fft_size {
140 return Err(AudioFftError::config(format!(
141 "Hop size {} cannot exceed FFT size {}",
142 hop_size, fft_size
143 )));
144 }
145
146 let mut planner = FftPlanner::new();
147 let fft = planner.plan_fft_forward(fft_size);
148 let scratch_len = fft.get_inplace_scratch_len();
149
150 Ok(Self {
151 fft_size,
152 hop_size,
153 sample_rate,
154 window,
155 window_coeffs: window.generate(fft_size),
156 fft,
157 scratch: vec![NumComplex::default(); scratch_len],
158 input_buffer: Vec::with_capacity(fft_size * 2),
159 })
160 }
161
162 pub fn fft_size(&self) -> usize {
164 self.fft_size
165 }
166
167 pub fn hop_size(&self) -> usize {
169 self.hop_size
170 }
171
172 pub fn num_bins(&self) -> usize {
174 self.fft_size / 2 + 1
175 }
176
177 pub fn bin_to_frequency(&self, bin: usize) -> f32 {
179 bin as f32 * self.sample_rate as f32 / self.fft_size as f32
180 }
181
182 pub fn frequency_to_bin(&self, freq: f32) -> usize {
184 (freq * self.fft_size as f32 / self.sample_rate as f32).round() as usize
185 }
186
187 pub fn process_frame(&mut self, samples: &[f32]) -> Vec<Complex> {
189 self.input_buffer.extend_from_slice(samples);
191
192 if self.input_buffer.len() < self.fft_size {
194 return Vec::new();
195 }
196
197 let mut buffer: Vec<NumComplex<f32>> = self.input_buffer[..self.fft_size]
199 .iter()
200 .enumerate()
201 .map(|(i, &s)| NumComplex::new(s * self.window_coeffs[i], 0.0))
202 .collect();
203
204 self.fft
206 .process_with_scratch(&mut buffer, &mut self.scratch);
207
208 self.input_buffer.drain(..self.hop_size);
210
211 buffer[..self.num_bins()]
213 .iter()
214 .map(|c| Complex::new(c.re, c.im))
215 .collect()
216 }
217
218 pub fn process_all(&mut self, samples: &[f32]) -> Vec<Vec<Complex>> {
220 self.input_buffer.extend_from_slice(samples);
221
222 let mut frames = Vec::new();
223
224 while self.input_buffer.len() >= self.fft_size {
225 let mut buffer: Vec<NumComplex<f32>> = self.input_buffer[..self.fft_size]
227 .iter()
228 .enumerate()
229 .map(|(i, &s)| NumComplex::new(s * self.window_coeffs[i], 0.0))
230 .collect();
231
232 self.fft
234 .process_with_scratch(&mut buffer, &mut self.scratch);
235
236 self.input_buffer.drain(..self.hop_size);
238
239 frames.push(
241 buffer[..self.num_bins()]
242 .iter()
243 .map(|c| Complex::new(c.re, c.im))
244 .collect(),
245 );
246 }
247
248 frames
249 }
250
251 pub fn flush(&mut self) -> Option<Vec<Complex>> {
253 if self.input_buffer.is_empty() {
254 return None;
255 }
256
257 self.input_buffer.resize(self.fft_size, 0.0);
259
260 let mut buffer: Vec<NumComplex<f32>> = self
261 .input_buffer
262 .iter()
263 .enumerate()
264 .map(|(i, &s)| NumComplex::new(s * self.window_coeffs[i], 0.0))
265 .collect();
266
267 self.fft
268 .process_with_scratch(&mut buffer, &mut self.scratch);
269 self.input_buffer.clear();
270
271 Some(
272 buffer[..self.num_bins()]
273 .iter()
274 .map(|c| Complex::new(c.re, c.im))
275 .collect(),
276 )
277 }
278
279 pub fn reset(&mut self) {
281 self.input_buffer.clear();
282 }
283}
284
285pub struct IfftProcessor {
287 fft_size: usize,
289 hop_size: usize,
291 ifft: Arc<dyn Fft<f32>>,
293 scratch: Vec<NumComplex<f32>>,
295 synthesis_window: Vec<f32>,
297 output_buffer: Vec<f32>,
299 norm_factor: f32,
301}
302
303impl IfftProcessor {
304 pub fn new(fft_size: usize, hop_size: usize) -> Result<Self> {
306 Self::with_window(fft_size, hop_size, WindowFunction::Hann)
307 }
308
309 pub fn with_window(fft_size: usize, hop_size: usize, window: WindowFunction) -> Result<Self> {
311 if !fft_size.is_power_of_two() {
312 return Err(AudioFftError::config(format!(
313 "FFT size must be power of 2, got {}",
314 fft_size
315 )));
316 }
317
318 let mut planner = FftPlanner::new();
319 let ifft = planner.plan_fft_inverse(fft_size);
320 let scratch_len = ifft.get_inplace_scratch_len();
321
322 let window_coeffs = window.generate(fft_size);
325 let overlap_factor = fft_size / hop_size;
326
327 let mut cola_sum = vec![0.0f32; hop_size];
329 for offset in 0..overlap_factor {
330 for (i, sum) in cola_sum.iter_mut().enumerate() {
331 let window_idx = offset * hop_size + i;
332 if window_idx < fft_size {
333 *sum += window_coeffs[window_idx] * window_coeffs[window_idx];
334 }
335 }
336 }
337 let avg_cola = cola_sum.iter().sum::<f32>() / hop_size as f32;
338
339 Ok(Self {
340 fft_size,
341 hop_size,
342 ifft,
343 scratch: vec![NumComplex::default(); scratch_len],
344 synthesis_window: window_coeffs,
345 output_buffer: vec![0.0; fft_size * 2],
346 norm_factor: 1.0 / (fft_size as f32 * avg_cola.sqrt()),
347 })
348 }
349
350 pub fn process_frame(&mut self, bins: &[Complex]) -> Vec<f32> {
352 let mut buffer: Vec<NumComplex<f32>> = Vec::with_capacity(self.fft_size);
354
355 for bin in bins.iter().take(self.fft_size / 2 + 1) {
357 buffer.push(NumComplex::new(bin.re, bin.im));
358 }
359
360 for i in 1..self.fft_size / 2 {
362 let idx = self.fft_size / 2 - i;
363 if idx < bins.len() {
364 buffer.push(NumComplex::new(bins[idx].re, -bins[idx].im));
365 } else {
366 buffer.push(NumComplex::default());
367 }
368 }
369
370 while buffer.len() < self.fft_size {
372 buffer.push(NumComplex::default());
373 }
374
375 self.ifft
377 .process_with_scratch(&mut buffer, &mut self.scratch);
378
379 for (i, c) in buffer.iter().enumerate() {
381 self.output_buffer[i] += c.re * self.synthesis_window[i] * self.norm_factor;
382 }
383
384 let output: Vec<f32> = self.output_buffer[..self.hop_size].to_vec();
386
387 self.output_buffer.copy_within(self.hop_size.., 0);
389 for i in (self.output_buffer.len() - self.hop_size)..self.output_buffer.len() {
390 self.output_buffer[i] = 0.0;
391 }
392
393 output
394 }
395
396 pub fn flush(&mut self) -> Vec<f32> {
398 let mut output = Vec::new();
399
400 while self.output_buffer.iter().any(|&x| x.abs() > 1e-10) {
402 output.extend_from_slice(&self.output_buffer[..self.hop_size]);
403 self.output_buffer.copy_within(self.hop_size.., 0);
404 for i in (self.output_buffer.len() - self.hop_size)..self.output_buffer.len() {
405 self.output_buffer[i] = 0.0;
406 }
407 }
408
409 output
410 }
411
412 pub fn reset(&mut self) {
414 self.output_buffer.fill(0.0);
415 }
416}
417
418pub struct StftProcessor {
420 pub fft: FftProcessor,
422 pub ifft: IfftProcessor,
424}
425
426impl StftProcessor {
427 pub fn new(fft_size: usize, hop_size: usize, sample_rate: u32) -> Result<Self> {
429 Self::with_window(fft_size, hop_size, sample_rate, WindowFunction::Hann)
430 }
431
432 pub fn with_window(
434 fft_size: usize,
435 hop_size: usize,
436 sample_rate: u32,
437 window: WindowFunction,
438 ) -> Result<Self> {
439 Ok(Self {
440 fft: FftProcessor::with_window(fft_size, hop_size, sample_rate, window)?,
441 ifft: IfftProcessor::with_window(fft_size, hop_size, window)?,
442 })
443 }
444
445 pub fn process<F>(&mut self, samples: &[f32], mut processor: F) -> Vec<f32>
447 where
448 F: FnMut(&mut [Complex]),
449 {
450 let mut output = Vec::new();
451
452 for mut frame in self.fft.process_all(samples) {
453 processor(&mut frame);
454 output.extend(self.ifft.process_frame(&frame));
455 }
456
457 output
458 }
459
460 pub fn flush<F>(&mut self, mut processor: F) -> Vec<f32>
462 where
463 F: FnMut(&mut [Complex]),
464 {
465 let mut output = Vec::new();
466
467 if let Some(mut frame) = self.fft.flush() {
468 processor(&mut frame);
469 output.extend(self.ifft.process_frame(&frame));
470 }
471
472 output.extend(self.ifft.flush());
473 output
474 }
475
476 pub fn reset(&mut self) {
478 self.fft.reset();
479 self.ifft.reset();
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is strictly closer than `tol` to `expected`.
    fn near(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    /// Spot-checks known coefficients of the Hann and Hamming windows.
    #[test]
    fn test_window_functions() {
        const LEN: usize = 1024;

        let hann_coeffs = WindowFunction::Hann.generate(LEN);
        assert!(near(hann_coeffs[0], 0.0, 1e-6));
        assert!(near(hann_coeffs[LEN / 2], 1.0, 1e-6));

        let hamming_coeffs = WindowFunction::Hamming.generate(LEN);
        assert!(near(hamming_coeffs[0], 0.08, 0.01));
    }

    /// Runs a 440 Hz tone through analysis and synthesis without touching
    /// the spectra and checks that output is produced.
    #[test]
    fn test_fft_roundtrip() {
        const SAMPLE_RATE: u32 = 44100;

        let mut stft = StftProcessor::new(1024, 256, SAMPLE_RATE).unwrap();

        let seconds = 0.1;
        let sample_count = (SAMPLE_RATE as f32 * seconds) as usize;
        let tone: Vec<f32> = (0..sample_count)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / SAMPLE_RATE as f32).sin())
            .collect();

        let output = stft.process(&tone, |_bins| {});

        assert!(!output.is_empty());
    }

    /// Checks the bin/frequency conversions at DC, at Nyquist and for a
    /// round trip through the nearest bin.
    #[test]
    fn test_bin_frequency_conversion() {
        let analyzer = FftProcessor::new(2048, 512, 44100).unwrap();

        assert!(near(analyzer.bin_to_frequency(0), 0.0, 1e-6));
        assert!(near(analyzer.bin_to_frequency(1024), 22050.0, 1.0));

        let target = 1000.0;
        let nearest_bin = analyzer.frequency_to_bin(target);
        assert!(near(analyzer.bin_to_frequency(nearest_bin), target, 50.0));
    }
}