// df/lib.rs

1#![allow(dead_code)]
2
3use std::ops::MulAssign;
4use std::sync::Arc;
5use std::vec::Vec;
6
7use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
8
/// Single-precision complex number; the element type of every spectrum handled in this crate.
pub type Complex32 = num_complex::Complex32;

// NOTE(review): the two `*_NORM_INIT` constants are presumably used to seed the running
// normalisation states (`band_mean_norm_erb` / `band_unit_norm`); what each of the two array
// entries corresponds to is not visible in this file — TODO confirm at the call sites.
/// Initial value for the running-mean normalisation state (presumably of the ERB features).
pub const MEAN_NORM_INIT: [f32; 2] = [-60., -90.];
/// Initial value for the running unit-norm normalisation state (presumably of the complex features).
pub const UNIT_NORM_INIT: [f32; 2] = [0.001, 0.0001];

// The `transforms` module is compiled when either the `transforms` or the `dataset` feature is on.
#[cfg(any(feature = "transforms", feature = "dataset"))]
pub mod transforms;
// Dataset-related modules, built only with the `dataset` feature. `#[path = ""]` makes the nested
// `mod` declarations resolve their source files relative to this file's directory instead of a
// `reexport_dataset_modules/` subdirectory; the glob `pub use` below re-exports them at crate root.
#[cfg(feature = "dataset")]
#[path = ""]
mod reexport_dataset_modules {
    pub mod augmentations;
    pub mod dataloader;
    pub mod dataset;
    pub mod util;
    pub mod wav_utils;
}
#[cfg(feature = "dataset")]
pub use reexport_dataset_modules::*;
// Private `cache` module, built only with the `cache` feature.
#[cfg(feature = "cache")]
mod cache;
29
/// Maps a frequency in Hz onto the ERB-rate scale.
///
/// Uses the Glasberg & Moore style constants `EAR_Q = 9.265` and `MIN_BW = 24.7` Hz:
/// `erb = EAR_Q * ln(1 + f / (MIN_BW * EAR_Q))`. `freq2erb(0.0) == 0.0`.
pub(crate) fn freq2erb(freq_hz: f32) -> f32 {
    const EAR_Q: f32 = 9.265;
    const MIN_BW: f32 = 24.7;
    let corner_hz = MIN_BW * EAR_Q;
    EAR_Q * (freq_hz / corner_hz).ln_1p()
}
/// Inverse of `freq2erb`: maps an ERB-rate value back to a frequency in Hz.
///
/// `f = MIN_BW * EAR_Q * (exp(erb / EAR_Q) - 1)`, with `EAR_Q = 9.265` and `MIN_BW = 24.7` Hz.
pub(crate) fn erb2freq(n_erb: f32) -> f32 {
    const EAR_Q: f32 = 9.265;
    const MIN_BW: f32 = 24.7;
    let growth = (n_erb / EAR_Q).exp() - 1.;
    MIN_BW * EAR_Q * growth
}
36
/// Streaming STFT state: frame-wise windowed analysis (`DFState::analysis`) and overlap-add
/// synthesis (`DFState::synthesis`), plus the ERB band layout used by the band helpers.
#[derive(Clone)]
pub struct DFState {
    /// Sampling rate in Hz.
    pub sr: usize,
    /// Hop size in samples: new input samples consumed / output samples produced per frame.
    pub frame_size: usize, // hop_size
    /// Window length in samples (same as the FFT size).
    pub window_size: usize, // Same as fft_size
    /// Number of bins in the real-FFT half spectrum.
    pub freq_size: usize, // fft_size / 2 + 1
    /// Forward real-to-complex FFT plan of length `window_size`.
    pub fft_forward: Arc<dyn RealToComplex<f32>>,
    /// Inverse complex-to-real FFT plan of length `window_size`.
    pub fft_inverse: Arc<dyn ComplexToReal<f32>>,
    /// Window of length `window_size`, applied before the forward FFT and again after the
    /// inverse FFT.
    pub window: Vec<f32>,
    /// Spectrum scale applied in analysis only: `frame_size / (window_size * Σ window²)`.
    pub wnorm: f32,
    /// Number of FFT bins in each ERB band (see `erb_fb`); the entries sum to `freq_size`.
    pub erb: Vec<usize>,
    /// The most recent `window_size - frame_size` input samples, prepended to the next frame.
    analysis_mem: Vec<f32>,
    /// Scratch buffer for the forward FFT, reused across frames.
    analysis_scratch: Vec<Complex32>,
    /// Overlap-add accumulator (`window_size - frame_size` samples) for upcoming output.
    synthesis_mem: Vec<f32>,
    /// Scratch buffer for the inverse FFT, reused across frames.
    synthesis_scratch: Vec<Complex32>,
}
53
54pub fn erb_fb(sr: usize, fft_size: usize, nb_bands: usize, min_nb_freqs: usize) -> Vec<usize> {
55    // Init ERB filter bank
56    let nyq_freq = sr / 2;
57    let freq_width = sr as f32 / fft_size as f32;
58    let erb_low: f32 = freq2erb(0.);
59    let erb_high: f32 = freq2erb(nyq_freq as f32);
60    let mut erb = vec![0; nb_bands];
61    let step = (erb_high - erb_low) / nb_bands as f32;
62    let min_nb_freqs = min_nb_freqs as i32; // Minimum number of frequency bands per erb band
63    let mut prev_freq = 0; // Last frequency band of the previous erb band
64    let mut freq_over = 0; // Number of frequency bands that are already stored in previous erb bands
65    for i in 1..nb_bands + 1 {
66        let f = erb2freq(erb_low + i as f32 * step);
67        let fb = (f / freq_width).round() as usize;
68        let mut nb_freqs = fb as i32 - prev_freq as i32 - freq_over;
69        if nb_freqs < min_nb_freqs {
70            // Not enough freq bins in current bark bin
71            freq_over = min_nb_freqs - nb_freqs; // keep track of number of enforced bins
72            nb_freqs = min_nb_freqs; // enforce min_nb_freqs
73        } else {
74            freq_over = 0
75        }
76        erb[i - 1] = nb_freqs as usize;
77        prev_freq = fb;
78    }
79    erb[nb_bands - 1] += 1; // since we have WINDOW_SIZE/2+1 frequency bins
80    let too_large = erb.iter().sum::<usize>() - (fft_size / 2 + 1);
81    if too_large > 0 {
82        erb[nb_bands - 1] -= too_large;
83    }
84    debug_assert!(erb.iter().sum::<usize>() == fft_size / 2 + 1);
85    erb
86}
87
// TODO: Check the delay for different hop sizes.
89impl DFState {
90    pub fn new(
91        sr: usize,
92        fft_size: usize,
93        hop_size: usize,
94        nb_bands: usize,
95        min_nb_freqs: usize,
96    ) -> Self {
97        assert!(hop_size * 2 <= fft_size);
98        let mut fft = RealFftPlanner::<f32>::new();
99        let frame_size = hop_size;
100        let window_size = fft_size;
101        let window_size_h = fft_size / 2;
102        let freq_size = fft_size / 2 + 1;
103        let forward = fft.plan_fft_forward(fft_size);
104        let backward = fft.plan_fft_inverse(fft_size);
105        let analysis_mem = vec![0.; fft_size - frame_size];
106        let synthesis_mem = vec![0.; fft_size - frame_size];
107        let analysis_scratch = forward.make_scratch_vec();
108        let synthesis_scratch = backward.make_scratch_vec();
109
110        let erb = erb_fb(sr, fft_size, nb_bands, min_nb_freqs);
111
112        let pi = std::f64::consts::PI;
113        let mut window = vec![0.0; fft_size];
114        for (i, w) in window.iter_mut().enumerate() {
115            let sin = (0.5 * pi * (i as f64 + 0.5) / window_size_h as f64).sin();
116            *w = (0.5 * pi * sin * sin).sin() as f32;
117        }
118        let wnorm =
119            1_f32 / window.iter().map(|x| x * x).sum::<f32>() * frame_size as f32 / fft_size as f32;
120
121        DFState {
122            sr,
123            frame_size,
124            window_size,
125            freq_size,
126            fft_forward: forward,
127            fft_inverse: backward,
128            erb,
129            analysis_mem,
130            analysis_scratch,
131            synthesis_mem,
132            synthesis_scratch,
133            window,
134            wnorm,
135        }
136    }
137
138    pub fn reset(&mut self) {
139        self.analysis_mem.fill(0.);
140        self.synthesis_mem.fill(0.);
141    }
142
143    pub fn process_frame(&mut self, input: &[f32], output: &mut [f32]) {
144        debug_assert_eq!(input.len(), self.frame_size);
145        debug_assert_eq!(output.len(), self.frame_size);
146        process_frame(input, output, self);
147    }
148
149    pub fn analysis(&mut self, input: &[f32], output: &mut [Complex32]) {
150        debug_assert_eq!(input.len(), self.frame_size);
151        frame_analysis(input, output, self)
152    }
153
154    pub fn synthesis(&mut self, input: &mut [Complex32], output: &mut [f32]) {
155        debug_assert_eq!(output.len(), self.frame_size);
156        frame_synthesis(input, output, self)
157    }
158}
159
160impl Default for DFState {
161    fn default() -> Self {
162        Self::new(48000, 960, 480, 32, 2)
163    }
164}
165
166pub fn band_mean_norm_freq(xs: &[Complex32], xout: &mut [f32], state: &mut [f32], alpha: f32) {
167    debug_assert_eq!(xs.len(), state.len());
168    debug_assert_eq!(xout.len(), state.len());
169    for ((x, s), xo) in xs.iter().zip(state.iter_mut()).zip(xout.iter_mut()) {
170        let xabs = x.norm();
171        *s = xabs * (1. - alpha) + *s * alpha;
172        *xo = xabs - *s;
173    }
174}
175
/// In-place running-mean normalisation of band features.
///
/// Each `state` entry becomes `x * (1 - alpha) + state * alpha`. Each `x` is then replaced by
/// `(x - state) / 40`, using the freshly updated mean. `xs` and `state` are expected to have
/// equal length (checked in debug builds).
pub fn band_mean_norm_erb(xs: &mut [f32], state: &mut [f32], alpha: f32) {
    debug_assert_eq!(xs.len(), state.len());
    let beta = 1. - alpha;
    xs.iter_mut().zip(state.iter_mut()).for_each(|(value, mean)| {
        *mean = *value * beta + *mean * alpha;
        *value = (*value - *mean) / 40.;
    });
}
184
185pub fn band_unit_norm(xs: &mut [Complex32], state: &mut [f32], alpha: f32) {
186    debug_assert_eq!(xs.len(), state.len());
187    for (x, s) in xs.iter_mut().zip(state.iter_mut()) {
188        *s = x.norm() * (1. - alpha) + *s * alpha;
189        *x /= s.sqrt();
190    }
191}
192
193pub fn compute_band_corr(out: &mut [f32], x: &[Complex32], p: &[Complex32], erb_fb: &[usize]) {
194    for y in out.iter_mut() {
195        *y = 0.0;
196    }
197    debug_assert_eq!(erb_fb.len(), out.len());
198
199    let mut bcsum = 0;
200    for (&band_size, out_b) in erb_fb.iter().zip(out.iter_mut()) {
201        let k = 1. / band_size as f32;
202        for j in 0..band_size {
203            let idx = bcsum + j;
204            *out_b += (x[idx].re * p[idx].re + x[idx].im * p[idx].im) * k;
205        }
206        bcsum += band_size;
207    }
208}
209
/// Compresses per-bin values to per-band means over consecutive ERB bands.
///
/// `erb_fb[b]` is the number of bins in band `b`; bands are laid out back to back starting at
/// bin 0. `out` is zeroed first, then `out[b]` receives the band average. `erb_fb` and `out`
/// are expected to have equal length (checked in debug builds).
pub fn band_compr(out: &mut [f32], x: &[f32], erb_fb: &[usize]) {
    out.fill(0.0);
    debug_assert_eq!(erb_fb.len(), out.len());

    let mut start = 0;
    for (dst, &width) in out.iter_mut().zip(erb_fb) {
        let end = start + width;
        let inv_width = 1. / width as f32;
        // Accumulate from 0.0 in bin order, each term pre-scaled by 1 / width.
        *dst = x[start..end].iter().fold(0.0, |acc, &v| acc + v * inv_width);
        start = end;
    }
}
226
/// Multiplies every bin of each ERB band by that band's gain, in place.
///
/// Despite the name, no interpolation happens: all bins of band `b` (`erb_fb[b]` consecutive
/// bins, bands back to back from bin 0) are scaled by the same `band_e[b]`. Works for any
/// element type that can be scaled by an `f32`. Bins past the last band are left untouched.
fn apply_interp_band_gain<T>(out: &mut [T], band_e: &[f32], erb_fb: &[usize])
where
    T: MulAssign<f32>,
{
    let mut start = 0;
    for (&width, &gain) in erb_fb.iter().zip(band_e) {
        let end = start + width;
        for value in &mut out[start..end] {
            *value *= gain;
        }
        start = end;
    }
}
240
/// Expands per-band values to per-bin values.
///
/// Every bin of band `b` (`erb_fb[b]` consecutive bins, bands back to back from bin 0) is set
/// to `band_e[b]`. This is a piecewise-constant expansion; no interpolation between bands
/// happens. Bins past the last band are left untouched.
fn interp_band_gain(out: &mut [f32], band_e: &[f32], erb_fb: &[usize]) {
    let mut start = 0;
    for (&width, &value) in erb_fb.iter().zip(band_e) {
        out[start..start + width].fill(value);
        start += width;
    }
}
251
252fn apply_band_gain(out: &mut [Complex32], band_e: &[f32], erb_fb: &[usize]) {
253    let mut bcsum = 0;
254    for (&band_size, b) in erb_fb.iter().zip(band_e.iter()) {
255        for j in 0..band_size {
256            let idx = bcsum + j;
257            out[idx] *= *b;
258        }
259        bcsum += band_size;
260    }
261}
262
263fn process_frame(input: &[f32], output: &mut [f32], state: &mut DFState) {
264    let mut freq_mem = vec![Complex32::default(); state.freq_size];
265    frame_analysis(input, &mut freq_mem, state);
266    frame_synthesis(&mut freq_mem, output, state);
267}
268
269fn frame_analysis(input: &[f32], output: &mut [Complex32], state: &mut DFState) {
270    debug_assert_eq!(input.len(), state.frame_size);
271    debug_assert_eq!(output.len(), state.freq_size);
272
273    let mut buf = state.fft_forward.make_input_vec();
274    // First part of the window on the previous frame
275    let (buf_first, buf_second) = buf.split_at_mut(state.window_size - state.frame_size);
276    let (window_first, window_second) = state.window.split_at(state.window_size - state.frame_size);
277    let analysis_split = state.analysis_mem.len() - state.frame_size;
278    for ((&y, &w), x) in
279        state.analysis_mem.iter().zip(window_first.iter()).zip(buf_first.iter_mut())
280    {
281        *x = y * w;
282    }
283    // Second part of the window on the new input frame
284    for ((&y, &w), x) in input.iter().zip(window_second.iter()).zip(buf_second.iter_mut()) {
285        *x = y * w;
286    }
287    // Shift analysis_mem
288    if analysis_split > 0 {
289        // hop_size is < window_size / 2
290        state.analysis_mem.rotate_left(state.frame_size);
291    }
292    // Copy input to analysis_mem for next iteration
293    for (x, &y) in state.analysis_mem[analysis_split..].iter_mut().zip(input) {
294        *x = y
295    }
296    state
297        .fft_forward
298        .process_with_scratch(&mut buf, output, &mut state.analysis_scratch)
299        .expect("FFT forward failed");
300    // Apply normalization in analysis only
301    let norm = state.wnorm;
302    for x in output.iter_mut() {
303        *x *= norm;
304    }
305}
306
307fn frame_synthesis(input: &mut [Complex32], output: &mut [f32], state: &mut DFState) {
308    let mut x = state.fft_inverse.make_output_vec();
309    match state
310        .fft_inverse
311        .process_with_scratch(input, &mut x[..], &mut state.synthesis_scratch)
312    {
313        Err(realfft::FftError::InputValues(_, _)) => (),
314        Err(e) => Err(e).unwrap(),
315        Ok(_) => (),
316    }
317    apply_window_in_place(&mut x, &state.window);
318    let (x_first, x_second) = x.split_at(state.frame_size);
319    for ((&xi, &mem), out) in x_first.iter().zip(state.synthesis_mem.iter()).zip(output.iter_mut())
320    {
321        *out = xi + mem;
322    }
323
324    let split = state.synthesis_mem.len() - state.frame_size;
325    if split > 0 {
326        state.synthesis_mem.rotate_left(state.frame_size);
327    }
328    let (s_first, s_second) = state.synthesis_mem.split_at_mut(split);
329    let (xs_first, xs_second) = x_second.split_at(split);
330    for (&xi, mem) in xs_first.iter().zip(s_first.iter_mut()) {
331        // Overlap add for next frame
332        *mem += xi;
333    }
334    for (&xi, mem) in xs_second.iter().zip(s_second.iter_mut()) {
335        // Override left shifted buffer
336        *mem = xi;
337    }
338}
339
/// Returns `xs` multiplied element-wise by `window`, as a new vector of `window.len()` samples.
///
/// Extra samples in `xs` are ignored; if `xs` is shorter than `window`, the tail is zero.
fn apply_window(xs: &[f32], window: &[f32]) -> Vec<f32> {
    let mut windowed = vec![0.0_f32; window.len()];
    windowed
        .iter_mut()
        .zip(xs.iter().zip(window))
        .for_each(|(dst, (&x, &w))| *dst = x * w);
    windowed
}
347
/// Multiplies `xs` element-wise by the window coefficients, in place.
///
/// Stops at whichever runs out first, `xs` or the window; remaining samples are unchanged.
fn apply_window_in_place<'a, I>(xs: &mut [f32], window: I)
where
    I: IntoIterator<Item = &'a f32>,
{
    xs.iter_mut()
        .zip(window)
        .for_each(|(sample, &coeff)| *sample *= coeff);
}
356
#[cfg(test)]
mod tests {
    use rand::distributions::{Distribution, Uniform};

    use super::*;

    /// `apply_band_gain` must scale every bin of ERB band `b` by exactly `mask[b]`.
    #[test]
    fn test_erb_inout() {
        let sr = 24000;
        let n_fft = 192;
        let hop = n_fft / 2;
        let nb_bands = 24;
        let n_freqs = n_fft / 2 + 1;
        let state = DFState::new(sr, n_fft, hop, nb_bands, 1);

        let dist = Uniform::new(-1., 1.);
        let mut rng = rand::thread_rng();
        let input: Vec<Complex32> = (0..n_freqs)
            .map(|_| Complex32::new(dist.sample(&mut rng), dist.sample(&mut rng)))
            .collect();

        let mut mask = vec![1.; nb_bands];
        mask[3] = 0.3;
        mask[nb_bands - 1] = 0.5;

        let mut output = input.clone();
        apply_band_gain(&mut output, mask.as_slice(), &state.erb);

        let mut start = 0;
        for (&gain, &width) in mask.iter().zip(state.erb.iter()) {
            for i in start..start + width {
                assert_eq!(input[i] * gain, output[i])
            }
            start += width;
        }
    }
}