Skip to main content

neco_stft/
stft.rs

1use neco_complex::Complex;
2
3use crate::dsp_float::DspFloat;
4use crate::FftError;
5
/// One spectrum frame: the complex FFT bins of a single windowed frame.
///
/// Frames are produced by [`StftProcessor::analyze`].
/// [`StftProcessor::synthesize`] requires each frame to hold exactly
/// [`StftProcessor::num_bins`] bins.
pub type SpectrumFrame<T = f64> = Vec<Complex<T>>;
7
/// Errors returned by [`StftProcessor`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StftError {
    /// A frame passed to [`StftProcessor::synthesize`] had `got` bins instead
    /// of the `expected` count.
    InvalidFrameLen { expected: usize, got: usize },
    /// The underlying FFT reported an error.
    Fft(FftError),
}
13
14impl core::fmt::Display for StftError {
15    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
16        match self {
17            Self::InvalidFrameLen { expected, got } => {
18                write!(
19                    f,
20                    "wrong spectrum frame length: expected {expected}, got {got}"
21                )
22            }
23            Self::Fft(err) => err.fmt(f),
24        }
25    }
26}
27
// Marker impl: the message comes from `Display`, and `source()` keeps the
// trait's default (`None`).
impl std::error::Error for StftError {}
29
30impl From<FftError> for StftError {
31    fn from(value: FftError) -> Self {
32        Self::Fft(value)
33    }
34}
35
36pub struct StftProcessor<T: DspFloat = f64> {
37    fft_size: usize,
38    hop_size: usize,
39    analysis_window: Vec<T>,
40    synthesis_window: Vec<T>,
41    norm: Vec<T>,
42}
43
44impl<T: DspFloat> StftProcessor<T> {
45    pub fn new(fft_size: usize, hop_size: usize, window: Vec<T>) -> Self {
46        assert_eq!(window.len(), fft_size);
47        assert!(hop_size > 0 && hop_size <= fft_size);
48
49        let norm = compute_wola_norm(&window, &window, fft_size, hop_size);
50        Self {
51            fft_size,
52            hop_size,
53            analysis_window: window.clone(),
54            synthesis_window: window,
55            norm,
56        }
57    }
58
59    pub fn fft_size(&self) -> usize {
60        self.fft_size
61    }
62
63    pub fn hop_size(&self) -> usize {
64        self.hop_size
65    }
66
67    pub fn num_bins(&self) -> usize {
68        self.fft_size / 2 + 1
69    }
70
71    pub fn analyze(&self, input: &[T]) -> Result<Vec<SpectrumFrame<T>>, StftError> {
72        T::with_fft_planner(|planner| {
73            let fft = planner.plan_fft_forward(self.fft_size);
74            let mut frames = Vec::new();
75            let mut pos = 0usize;
76            let mut windowed = vec![T::zero(); self.fft_size];
77            let mut spectrum = fft.make_output_vec();
78
79            while pos + self.fft_size <= input.len() {
80                for i in 0..self.fft_size {
81                    windowed[i] = input[pos + i] * self.analysis_window[i];
82                }
83                fft.process(&mut windowed, &mut spectrum)?;
84                frames.push(spectrum.clone());
85                pos += self.hop_size;
86            }
87
88            Ok(frames)
89        })
90    }
91
92    pub fn synthesize(
93        &self,
94        frames: &[SpectrumFrame<T>],
95        output_len: usize,
96    ) -> Result<Vec<T>, StftError> {
97        T::with_fft_planner(|planner| {
98            let ifft = planner.plan_fft_inverse(self.fft_size);
99            let mut output = vec![T::zero(); output_len];
100            let mut pos = 0usize;
101            let scale = T::one() / T::from_usize(self.fft_size);
102            let mut spectrum = ifft.make_input_vec();
103            let mut time_buf = ifft.make_output_vec();
104
105            for frame in frames {
106                if pos + self.fft_size > output_len {
107                    break;
108                }
109
110                if frame.len() != spectrum.len() {
111                    return Err(StftError::InvalidFrameLen {
112                        expected: spectrum.len(),
113                        got: frame.len(),
114                    });
115                }
116                spectrum.copy_from_slice(frame);
117                ifft.process(&mut spectrum, &mut time_buf)?;
118
119                for i in 0..self.fft_size {
120                    output[pos + i] += time_buf[i] * scale * self.synthesis_window[i];
121                }
122                pos += self.hop_size;
123            }
124
125            self.apply_norm(&mut output);
126            Ok(output)
127        })
128    }
129
130    fn apply_norm(&self, output: &mut [T]) {
131        let epsilon = T::from_f64(1e-10);
132        let norm_len = self.norm.len();
133        for (index, sample) in output.iter_mut().enumerate() {
134            let norm_value = self.norm[index % norm_len];
135            if norm_value > epsilon {
136                *sample /= norm_value;
137            }
138        }
139    }
140}
141
142fn compute_wola_norm<T: DspFloat>(
143    analysis_window: &[T],
144    synthesis_window: &[T],
145    fft_size: usize,
146    hop_size: usize,
147) -> Vec<T> {
148    let num_frames = fft_size.div_ceil(hop_size) + 1;
149    let total_len = num_frames * hop_size + fft_size;
150    let mut norm = vec![T::zero(); total_len];
151    let mut pos = 0usize;
152    for _ in 0..num_frames + 1 {
153        if pos + fft_size <= total_len {
154            for i in 0..fft_size {
155                norm[pos + i] += analysis_window[i] * synthesis_window[i];
156            }
157        }
158        pos += hop_size;
159    }
160    let steady_start = fft_size;
161    norm[steady_start..steady_start + hop_size].to_vec()
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{cast_vec, hann};
    use std::f64::consts::PI;

    /// Analyzes and resynthesizes `input` with a Hann window, then returns the
    /// largest absolute sample error outside a `2 * fft_size` margin at each
    /// end (the edges are covered by fewer overlapping frames).
    fn max_roundtrip_error(fft_size: usize, hop_size: usize, input: &[f64]) -> f64 {
        let processor = StftProcessor::new(fft_size, hop_size, hann(fft_size));
        let n = input.len();

        let frames = processor.analyze(input).expect("analyze");
        let output = processor.synthesize(&frames, n).expect("synthesize");

        let margin = fft_size * 2;
        (margin..n - margin)
            .map(|i| (output[i] - input[i]).abs())
            .fold(0.0, f64::max)
    }

    #[test]
    fn roundtrip_identity() {
        let n: usize = 16384;
        let input: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / 48000.0;
                (2.0 * PI * 440.0 * t).sin() + 0.5 * (2.0 * PI * 1000.0 * t).sin()
            })
            .collect();

        let max_err = max_roundtrip_error(1024, 256, &input);
        assert!(max_err < 1e-10, "roundtrip error: {max_err:.2e}");
    }

    #[test]
    fn roundtrip_high_overlap() {
        let n: usize = 32768;
        let input: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * 100.0 * (i as f64 / 48000.0)).sin())
            .collect();

        let max_err = max_roundtrip_error(4096, 128, &input);
        assert!(max_err < 1e-10, "roundtrip error: {max_err:.2e}");
    }

    #[test]
    fn roundtrip_f32() {
        let fft_size = 1024;
        let hop_size = 256;
        let window = cast_vec(&hann(fft_size));
        let processor = StftProcessor::<f32>::new(fft_size, hop_size, window);

        let n: usize = 16384;
        let input: Vec<f32> = (0..n)
            .map(|i| {
                let t = i as f32 / 48000.0;
                let tau = 2.0f32 * std::f32::consts::PI;
                (tau * 440.0 * t).sin() + 0.5 * (tau * 1000.0 * t).sin()
            })
            .collect();

        let frames = processor.analyze(&input).expect("analyze");
        let output = processor.synthesize(&frames, n).expect("synthesize");

        let margin = fft_size * 2;
        let max_err = input[margin..n - margin]
            .iter()
            .zip(&output[margin..n - margin])
            .map(|(&expected, &actual)| (actual - expected).abs())
            .fold(0.0, f32::max);
        assert!(max_err < 1e-5, "roundtrip error: {max_err:.2e}");
    }

    #[test]
    fn roundtrip_non_power_of_two() {
        let n: usize = 24024;
        let input: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / 48000.0;
                (2.0 * PI * 330.0 * t).sin() + 0.35 * (2.0 * PI * 1234.0 * t).cos()
            })
            .collect();

        let max_err = max_roundtrip_error(1001, 143, &input);
        assert!(max_err < 1e-9, "roundtrip error: {max_err:.2e}");
    }

    #[test]
    fn synthesize_rejects_wrong_frame_len() {
        let processor = StftProcessor::new(16, 8, hann(16));
        let bad_frame = vec![Complex::new(0.0, 0.0); 3];

        let err = processor
            .synthesize(&[bad_frame], 16)
            .expect_err("invalid frame len");

        let expected = StftError::InvalidFrameLen {
            expected: processor.num_bins(),
            got: 3,
        };
        assert_eq!(err, expected);
    }
}