1#![allow(unused_imports)]
7
8#[cfg(feature = "std")]
9use std::{f32::consts::PI, vec::Vec, marker::PhantomData};
10
11#[cfg(not(feature = "std"))]
12use core::{f32::consts::PI, marker::PhantomData};
13
14#[cfg(all(not(feature = "std"), feature = "alloc"))]
15use alloc::vec::Vec;
16
17use num_complex::Complex;
18use num_traits::Float;
19use num_traits::FromPrimitive;
20use num_traits::NumCast;
21
22use crate::fft;
23use crate::windows;
24
/// Window family used by [`STFT::set_interval`] for the analysis and synthesis windows.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WindowShape {
    /// Rectangular window: every coefficient is one.
    Ignore,
    /// Approximate confined Gaussian window (built with bandwidth 2.5).
    ACG,
    /// Kaiser window (built with bandwidth 2.5).
    Kaiser,
}
35
36pub struct STFT<T: Float> {
38 fft: fft::Pow2RealFFT<T>,
40
41 analysis_channels: usize,
43 synthesis_channels: usize,
44 block_samples: usize,
45 fft_samples: usize,
46 fft_bins: usize,
47 input_length_samples: usize,
48 default_interval: usize,
49
50 analysis_window: Vec<T>,
52 synthesis_window: Vec<T>,
53 analysis_offset: usize,
54 synthesis_offset: usize,
55
56 input_buffer: Vec<T>,
58 input_pos: usize,
59 output_buffer: Vec<T>,
60 output_pos: usize,
61 window_products: Vec<T>,
62 spectrum_buffer: Vec<Complex<T>>,
63 time_buffer: Vec<T>,
64
65 almost_zero: T,
67 modified: bool,
68}
69
70#[cfg(feature = "std")]
71use std::ops::AddAssign;
72
73#[cfg(not(feature = "std"))]
74use core::ops::AddAssign;
75
76impl<T: Float + FromPrimitive + NumCast + AddAssign> STFT<T> {
77 pub fn new(modified: bool) -> Self {
79 Self {
80 fft: fft::Pow2RealFFT::new(0),
81 analysis_channels: 0,
82 synthesis_channels: 0,
83 block_samples: 0,
84 fft_samples: 0,
85 fft_bins: 0,
86 input_length_samples: 0,
87 default_interval: 0,
88 analysis_window: Vec::new(),
89 synthesis_window: Vec::new(),
90 analysis_offset: 0,
91 synthesis_offset: 0,
92 input_buffer: Vec::new(),
93 input_pos: 0,
94 output_buffer: Vec::new(),
95 output_pos: 0,
96 window_products: Vec::new(),
97 spectrum_buffer: Vec::new(),
98 time_buffer: Vec::new(),
99 almost_zero: T::from_f32(1e-20).unwrap(),
100 modified,
101 }
102 }
103
104 pub fn configure(
106 &mut self,
107 in_channels: usize,
108 out_channels: usize,
109 block_samples: usize,
110 extra_input_history: usize,
111 interval_samples: usize,
112 ) {
113 self.analysis_channels = in_channels;
114 self.synthesis_channels = out_channels;
115 self.block_samples = block_samples;
116
117 let mut fft_samples = 1;
119 while fft_samples < block_samples {
120 fft_samples *= 2;
121 }
122 self.fft_samples = fft_samples;
123 self.fft.resize(fft_samples);
124 self.fft_bins = fft_samples / 2 + 1; self.input_length_samples = block_samples + extra_input_history;
127 self.input_buffer.resize(self.input_length_samples * in_channels, T::zero());
128
129 self.output_buffer.resize(block_samples * out_channels, T::zero());
130 self.window_products.resize(block_samples, T::zero());
131 self.spectrum_buffer.resize(self.fft_bins * in_channels.max(out_channels), Complex::new(T::zero(), T::zero()));
132 self.time_buffer.resize(fft_samples, T::zero());
133
134 self.analysis_window.resize(block_samples, T::zero());
135 self.synthesis_window.resize(block_samples, T::zero());
136
137 let interval = if interval_samples > 0 {
139 interval_samples
140 } else {
141 block_samples / 4
142 };
143 self.set_interval(interval, WindowShape::ACG);
144
145 self.reset_default();
146 }
147
148 pub fn block_samples(&self) -> usize {
150 self.block_samples
151 }
152
153 pub fn fft_samples(&self) -> usize {
155 self.fft_samples
156 }
157
158 pub fn default_interval(&self) -> usize {
160 self.default_interval
161 }
162
163 pub fn bands(&self) -> usize {
165 self.fft_bins
166 }
167
168 pub fn analysis_latency(&self) -> usize {
170 self.block_samples - self.analysis_offset
171 }
172
173 pub fn synthesis_latency(&self) -> usize {
175 self.synthesis_offset
176 }
177
178 pub fn latency(&self) -> usize {
180 self.synthesis_latency() + self.analysis_latency()
181 }
182
183 pub fn bin_to_freq(&self, bin: T) -> T {
185 if self.modified {
186 (bin + T::from_f32(0.5).unwrap()) / T::from_usize(self.fft_samples).unwrap()
187 } else {
188 bin / T::from_usize(self.fft_samples).unwrap()
189 }
190 }
191
192 pub fn freq_to_bin(&self, freq: T) -> T {
194 if self.modified {
195 freq * T::from_usize(self.fft_samples).unwrap() - T::from_f32(0.5).unwrap()
196 } else {
197 freq * T::from_usize(self.fft_samples).unwrap()
198 }
199 }
200
201 pub fn reset(&mut self, product_weight: T) {
203 self.input_pos = self.block_samples;
204 self.output_pos = 0;
205
206 for v in &mut self.input_buffer {
208 *v = T::zero();
209 }
210 for v in &mut self.output_buffer {
211 *v = T::zero();
212 }
213 for v in &mut self.spectrum_buffer {
214 *v = Complex::new(T::zero(), T::zero());
215 }
216 for v in &mut self.window_products {
217 *v = T::zero();
218 }
219
220 self.add_window_product();
222
223 for i in (0..self.block_samples - self.default_interval).rev() {
225 self.window_products[i] = self.window_products[i] + self.window_products[i + self.default_interval];
226 }
227
228 for v in &mut self.window_products {
230 *v = *v * product_weight + self.almost_zero;
231 }
232
233 self.move_output(self.default_interval);
235 }
236
237 pub fn reset_default(&mut self) {
239 self.reset(T::one());
240 }
241
242 pub fn write_input(&mut self, channel: usize, offset: usize, length: usize, input_array: &[T]) {
244 assert!(channel < self.analysis_channels, "Channel index out of bounds");
245 assert!(offset + length <= input_array.len(), "Input array too small");
246
247 let buffer_start = channel * self.input_length_samples;
248 let offset_pos = (self.input_pos + offset) % self.input_length_samples;
249
250 let input_wrap_index = self.input_length_samples - offset_pos;
252 let chunk1 = length.min(input_wrap_index);
253
254 for i in 0..chunk1 {
256 let buffer_index = buffer_start + offset_pos + i;
257 self.input_buffer[buffer_index] = input_array[i];
258 }
259
260 for i in chunk1..length {
262 let buffer_index = buffer_start + i + offset_pos - self.input_length_samples;
263 self.input_buffer[buffer_index] = input_array[i];
264 }
265 }
266
267 pub fn write_input_simple(&mut self, channel: usize, input_array: &[T]) {
269 self.write_input(channel, 0, input_array.len(), input_array);
270 }
271
272 pub fn read_output(&self, channel: usize, offset: usize, length: usize, output_array: &mut [T]) {
274 assert!(channel < self.synthesis_channels, "Channel index out of bounds");
275 assert!(offset + length <= output_array.len(), "Output array too small");
276
277 let buffer_start = channel * self.block_samples;
278 let offset_pos = (self.output_pos + offset) % self.block_samples;
279
280 let output_wrap_index = self.block_samples - offset_pos;
282 let chunk1 = length.min(output_wrap_index);
283
284 for i in 0..chunk1 {
286 let buffer_index = buffer_start + offset_pos + i;
287 output_array[i] = self.output_buffer[buffer_index];
288 }
289
290 for i in chunk1..length {
292 let buffer_index = buffer_start + i + offset_pos - self.block_samples;
293 output_array[i] = self.output_buffer[buffer_index];
294 }
295 }
296
297 pub fn read_output_simple(&self, channel: usize, output_array: &mut [T]) {
299 self.read_output(channel, 0, output_array.len(), output_array);
300 }
301
302 pub fn move_input(&mut self, samples: usize) {
304 self.input_pos = (self.input_pos + samples) % self.input_length_samples;
305 }
306
307 pub fn move_output(&mut self, samples: usize) {
309 self.output_pos = (self.output_pos + samples) % self.block_samples;
310 }
311
312 pub fn set_interval(&mut self, interval: usize, window_shape: WindowShape) {
314 self.default_interval = interval;
315
316 self.analysis_offset = self.block_samples / 2;
318 self.synthesis_offset = self.block_samples / 2;
319
320 match window_shape {
322 WindowShape::Ignore => {
323 for i in 0..self.block_samples {
325 self.analysis_window[i] = T::one();
326 self.synthesis_window[i] = T::one();
327 }
328 },
329 WindowShape::ACG => {
330 let acg = windows::ApproximateConfinedGaussian::with_bandwidth(T::from_f32(2.5).unwrap());
332 acg.fill(self.analysis_window.as_mut_slice());
333 acg.fill(self.synthesis_window.as_mut_slice());
334 },
335 WindowShape::Kaiser => {
336 let kaiser = windows::Kaiser::with_bandwidth(T::from_f32(2.5).unwrap(), true);
338 kaiser.fill(self.analysis_window.as_mut_slice());
339 kaiser.fill(self.synthesis_window.as_mut_slice());
340 },
341 }
342
343 windows::force_perfect_reconstruction(&mut self.synthesis_window, self.block_samples, interval);
345 }
346
347 fn add_window_product(&mut self) {
349 for i in 0..self.block_samples {
350 self.window_products[i] += self.analysis_window[i] * self.synthesis_window[i];
351 }
352 }
353
354 pub fn process_block_to_spectrum(&mut self, channel: usize) -> &[Complex<T>] {
356 assert!(channel < self.analysis_channels, "Channel index out of bounds");
357
358 let buffer_start = channel * self.input_length_samples;
360 for i in 0..self.block_samples {
361 let input_index = (self.input_pos + self.block_samples - self.analysis_offset + i) % self.input_length_samples;
362 self.time_buffer[i] = self.input_buffer[buffer_start + input_index] * self.analysis_window[i];
363 }
364
365 for i in self.block_samples..self.fft_samples {
367 self.time_buffer[i] = T::zero();
368 }
369
370 let spectrum_start = channel * self.fft_bins;
372 let spectrum_slice = &mut self.spectrum_buffer[spectrum_start..spectrum_start + self.fft_bins];
373 self.fft.fft(&self.time_buffer, spectrum_slice);
374
375 &self.spectrum_buffer[spectrum_start..spectrum_start + self.fft_bins]
377 }
378
379 pub fn process_spectrum_to_block(&mut self, channel: usize, spectrum: &[Complex<T>]) {
381 assert!(channel < self.synthesis_channels, "Channel index out of bounds");
382 assert!(spectrum.len() >= self.fft_bins, "Spectrum too small");
383
384 self.fft.ifft(spectrum, &mut self.time_buffer);
386
387 let buffer_start = channel * self.block_samples;
389 for i in 0..self.block_samples {
390 let output_index = (self.output_pos + self.synthesis_offset + i) % self.block_samples;
392 let window_product = self.window_products[i];
393 let value = self.time_buffer[i] * self.synthesis_window[i] / window_product;
394 self.output_buffer[buffer_start + output_index] += value;
395 }
396 }
397
398 pub fn process_block(&mut self, in_channel: usize, out_channel: usize) {
400 let spectrum = self.process_block_to_spectrum(in_channel);
402
403 let spectrum_copy = spectrum.to_vec();
405
406 self.process_spectrum_to_block(out_channel, &spectrum_copy);
408 }
409
410 pub fn process_channels(&mut self, in_channels: &[usize], out_channels: &[usize]) {
412 assert!(in_channels.len() == out_channels.len(), "Channel arrays must have the same length");
413
414 for (in_ch, out_ch) in in_channels.iter().zip(out_channels.iter()) {
415 self.process_block(*in_ch, *out_ch);
416 }
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stft_configuration() {
        let mut transform = STFT::<f32>::new(false);
        transform.configure(2, 2, 1024, 0, 256);

        assert_eq!(transform.block_samples(), 1024);
        assert_eq!(transform.fft_samples(), 1024);
        assert_eq!(transform.default_interval(), 256);
        assert_eq!(transform.bands(), 513);
    }

    #[test]
    fn test_stft_io() {
        let mut transform = STFT::<f32>::new(false);
        transform.configure(1, 1, 16, 0, 4);

        // First block: a unit impulse at the start.
        let mut impulse = [0.0f32; 16];
        impulse[0] = 1.0;
        transform.write_input_simple(0, &impulse);
        transform.process_block(0, 0);

        // Three further hops of silence.
        let silence = [0.0f32; 4];
        for _ in 0..3 {
            transform.move_input(4);
            transform.write_input(0, 0, silence.len(), &silence);
            transform.process_block(0, 0);
        }

        let mut rendered = [0.0f32; 16];
        transform.read_output_simple(0, &mut rendered);

        let (peak, _) = rendered
            .iter()
            .enumerate()
            .max_by(|(_, lhs), (_, rhs)| lhs.partial_cmp(rhs).unwrap())
            .unwrap();

        assert_eq!(peak, 4);
    }

    #[test]
    fn test_stft_frequency_conversion() {
        let mut transform = STFT::<f32>::new(false);
        transform.configure(1, 1, 1024, 0, 256);

        let original_bin = 100.0;
        let round_trip = transform.freq_to_bin(transform.bin_to_freq(original_bin));

        assert!((original_bin - round_trip).abs() < 1e-10);
    }
}