1use neco_complex::Complex;
2
3use crate::dsp_float::DspFloat;
4use crate::FftError;
5
6pub type SpectrumFrame<T = f64> = Vec<Complex<T>>;
7
8#[derive(Debug, Clone, PartialEq, Eq)]
9pub enum StftError {
10 InvalidFrameLen { expected: usize, got: usize },
11 Fft(FftError),
12}
13
14impl core::fmt::Display for StftError {
15 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
16 match self {
17 Self::InvalidFrameLen { expected, got } => {
18 write!(
19 f,
20 "wrong spectrum frame length: expected {expected}, got {got}"
21 )
22 }
23 Self::Fft(err) => err.fmt(f),
24 }
25 }
26}
27
28impl std::error::Error for StftError {}
29
30impl From<FftError> for StftError {
31 fn from(value: FftError) -> Self {
32 Self::Fft(value)
33 }
34}
35
36pub struct StftProcessor<T: DspFloat = f64> {
37 fft_size: usize,
38 hop_size: usize,
39 analysis_window: Vec<T>,
40 synthesis_window: Vec<T>,
41 norm: Vec<T>,
42}
43
44impl<T: DspFloat> StftProcessor<T> {
45 pub fn new(fft_size: usize, hop_size: usize, window: Vec<T>) -> Self {
46 assert_eq!(window.len(), fft_size);
47 assert!(hop_size > 0 && hop_size <= fft_size);
48
49 let norm = compute_wola_norm(&window, &window, fft_size, hop_size);
50 Self {
51 fft_size,
52 hop_size,
53 analysis_window: window.clone(),
54 synthesis_window: window,
55 norm,
56 }
57 }
58
59 pub fn fft_size(&self) -> usize {
60 self.fft_size
61 }
62
63 pub fn hop_size(&self) -> usize {
64 self.hop_size
65 }
66
67 pub fn num_bins(&self) -> usize {
68 self.fft_size / 2 + 1
69 }
70
71 pub fn analyze(&self, input: &[T]) -> Result<Vec<SpectrumFrame<T>>, StftError> {
72 T::with_fft_planner(|planner| {
73 let fft = planner.plan_fft_forward(self.fft_size);
74 let mut frames = Vec::new();
75 let mut pos = 0usize;
76 let mut windowed = vec![T::zero(); self.fft_size];
77 let mut spectrum = fft.make_output_vec();
78
79 while pos + self.fft_size <= input.len() {
80 for i in 0..self.fft_size {
81 windowed[i] = input[pos + i] * self.analysis_window[i];
82 }
83 fft.process(&mut windowed, &mut spectrum)?;
84 frames.push(spectrum.clone());
85 pos += self.hop_size;
86 }
87
88 Ok(frames)
89 })
90 }
91
92 pub fn synthesize(
93 &self,
94 frames: &[SpectrumFrame<T>],
95 output_len: usize,
96 ) -> Result<Vec<T>, StftError> {
97 T::with_fft_planner(|planner| {
98 let ifft = planner.plan_fft_inverse(self.fft_size);
99 let mut output = vec![T::zero(); output_len];
100 let mut pos = 0usize;
101 let scale = T::one() / T::from_usize(self.fft_size);
102 let mut spectrum = ifft.make_input_vec();
103 let mut time_buf = ifft.make_output_vec();
104
105 for frame in frames {
106 if pos + self.fft_size > output_len {
107 break;
108 }
109
110 if frame.len() != spectrum.len() {
111 return Err(StftError::InvalidFrameLen {
112 expected: spectrum.len(),
113 got: frame.len(),
114 });
115 }
116 spectrum.copy_from_slice(frame);
117 ifft.process(&mut spectrum, &mut time_buf)?;
118
119 for i in 0..self.fft_size {
120 output[pos + i] += time_buf[i] * scale * self.synthesis_window[i];
121 }
122 pos += self.hop_size;
123 }
124
125 self.apply_norm(&mut output);
126 Ok(output)
127 })
128 }
129
130 fn apply_norm(&self, output: &mut [T]) {
131 let epsilon = T::from_f64(1e-10);
132 let norm_len = self.norm.len();
133 for (index, sample) in output.iter_mut().enumerate() {
134 let norm_value = self.norm[index % norm_len];
135 if norm_value > epsilon {
136 *sample /= norm_value;
137 }
138 }
139 }
140}
141
142fn compute_wola_norm<T: DspFloat>(
143 analysis_window: &[T],
144 synthesis_window: &[T],
145 fft_size: usize,
146 hop_size: usize,
147) -> Vec<T> {
148 let num_frames = fft_size.div_ceil(hop_size) + 1;
149 let total_len = num_frames * hop_size + fft_size;
150 let mut norm = vec![T::zero(); total_len];
151 let mut pos = 0usize;
152 for _ in 0..num_frames + 1 {
153 if pos + fft_size <= total_len {
154 for i in 0..fft_size {
155 norm[pos + i] += analysis_window[i] * synthesis_window[i];
156 }
157 }
158 pos += hop_size;
159 }
160 let steady_start = fft_size;
161 norm[steady_start..steady_start + hop_size].to_vec()
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{cast_vec, hann};

    /// Angular factor shared by the `f64` test tones.
    const TWO_PI: f64 = 2.0 * std::f64::consts::PI;

    /// Evaluates `wave` at `n` consecutive sample instants of a 48 kHz clock.
    fn sample_f64(n: usize, wave: impl Fn(f64) -> f64) -> Vec<f64> {
        (0..n).map(|i| wave(i as f64 / 48000.0)).collect()
    }

    /// Largest absolute difference between `output` and `input`, skipping
    /// `margin` samples at both ends of `input`.
    fn interior_peak_error(input: &[f64], output: &[f64], margin: usize) -> f64 {
        let end = input.len() - margin;
        input[margin..end]
            .iter()
            .zip(&output[margin..end])
            .map(|(x, y)| (y - x).abs())
            .fold(0.0, f64::max)
    }

    /// Analysis followed by synthesis reproduces a two-tone signal away from
    /// the edges.
    #[test]
    fn roundtrip_identity() {
        let (fft_size, hop_size, n) = (1024, 256, 16384);
        let processor = StftProcessor::new(fft_size, hop_size, hann(fft_size));
        let input = sample_f64(n, |t| {
            (TWO_PI * 440.0 * t).sin() + 0.5 * (TWO_PI * 1000.0 * t).sin()
        });

        let frames = processor.analyze(&input).expect("analyze");
        let output = processor.synthesize(&frames, n).expect("synthesize");

        let max_err = interior_peak_error(&input, &output, fft_size * 2);
        assert!(max_err < 1e-10, "roundtrip error: {max_err:.2e}");
    }

    /// Same roundtrip with a hop of 128 samples on a 4096-sample window.
    #[test]
    fn roundtrip_high_overlap() {
        let (fft_size, hop_size, n) = (4096, 128, 32768);
        let processor = StftProcessor::new(fft_size, hop_size, hann(fft_size));
        let input = sample_f64(n, |t| (TWO_PI * 100.0 * t).sin());

        let frames = processor.analyze(&input).expect("analyze");
        let output = processor.synthesize(&frames, n).expect("synthesize");

        let max_err = interior_peak_error(&input, &output, fft_size * 2);
        assert!(max_err < 1e-10, "roundtrip error: {max_err:.2e}");
    }

    /// Roundtrip in single precision, with a correspondingly looser tolerance.
    #[test]
    fn roundtrip_f32() {
        let (fft_size, hop_size, n) = (1024, 256, 16384);
        let window = cast_vec(&hann(fft_size));
        let processor = StftProcessor::<f32>::new(fft_size, hop_size, window);

        let two_pi = 2.0f32 * std::f32::consts::PI;
        let input: Vec<f32> = (0..n)
            .map(|i| i as f32 / 48000.0)
            .map(|t| (two_pi * 440.0 * t).sin() + 0.5 * (two_pi * 1000.0 * t).sin())
            .collect();

        let frames = processor.analyze(&input).expect("analyze");
        let output = processor.synthesize(&frames, n).expect("synthesize");

        let margin = fft_size * 2;
        let max_err = input[margin..n - margin]
            .iter()
            .zip(&output[margin..n - margin])
            .map(|(x, y)| (y - x).abs())
            .fold(0.0, f32::max);
        assert!(max_err < 1e-5, "roundtrip error: {max_err:.2e}");
    }

    /// Roundtrip with a transform length that is not a power of two.
    #[test]
    fn roundtrip_non_power_of_two() {
        let (fft_size, hop_size, n) = (1001, 143, 24024);
        let processor = StftProcessor::new(fft_size, hop_size, hann(fft_size));
        let input = sample_f64(n, |t| {
            (TWO_PI * 330.0 * t).sin() + 0.35 * (TWO_PI * 1234.0 * t).cos()
        });

        let frames = processor.analyze(&input).expect("analyze");
        let output = processor.synthesize(&frames, n).expect("synthesize");

        let max_err = interior_peak_error(&input, &output, fft_size * 2);
        assert!(max_err < 1e-9, "roundtrip error: {max_err:.2e}");
    }

    /// A frame with the wrong number of bins is rejected with both lengths
    /// reported.
    #[test]
    fn synthesize_rejects_wrong_frame_len() {
        let processor = StftProcessor::new(16, 8, hann(16));
        let short_frames = [vec![Complex::new(0.0, 0.0); 3]];

        let err = processor
            .synthesize(&short_frames, 16)
            .expect_err("invalid frame len");

        let wanted = StftError::InvalidFrameLen {
            expected: processor.num_bins(),
            got: 3,
        };
        assert_eq!(err, wanted);
    }
}