1#![allow(dead_code)]
2
3use std::ops::MulAssign;
4use std::sync::Arc;
5use std::vec::Vec;
6
7use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
8
/// Single-precision complex number used for all spectra in this module
/// (alias of `num_complex::Complex32`).
pub type Complex32 = num_complex::Complex32;
10
/// Pair of initial values for mean normalization.
///
/// NOTE(review): presumably the start/end values used to initialise the running
/// state passed to `band_mean_norm_*` — confirm against callers.
pub const MEAN_NORM_INIT: [f32; 2] = [-60.0, -90.0];
/// Pair of initial values for unit normalization.
///
/// NOTE(review): presumably the start/end values used to initialise the running
/// state passed to `band_unit_norm` — confirm against callers.
pub const UNIT_NORM_INIT: [f32; 2] = [0.001, 0.0001];
13
// `transforms` is compiled when either the `transforms` or the `dataset`
// feature is enabled.
#[cfg(any(feature = "transforms", feature = "dataset"))]
pub mod transforms;
// Private wrapper module: it gates every dataset-related module behind a single
// `cfg` and lets them all be re-exported with one glob import below.
// NOTE(review): `#[path = ""]` presumably makes the nested `mod` declarations
// resolve to files next to this one instead of a `reexport_dataset_modules/`
// sub-directory — confirm against the crate layout.
#[cfg(feature = "dataset")]
#[path = ""]
mod reexport_dataset_modules {
    pub mod augmentations;
    pub mod dataloader;
    pub mod dataset;
    pub mod util;
    pub mod wav_utils;
}
// Surface the dataset modules at the crate root as if they were declared here.
#[cfg(feature = "dataset")]
pub use reexport_dataset_modules::*;
// Crate-private module, only built with the `cache` feature.
#[cfg(feature = "cache")]
mod cache;
29
/// Maps a frequency in Hz onto the ERB scale.
///
/// Evaluates `9.265 * ln(1 + freq_hz / (24.7 * 9.265))`, so `0 Hz` maps to `0`.
/// This is the inverse of `erb2freq`.
pub(crate) fn freq2erb(freq_hz: f32) -> f32 {
    const ERB_SCALE: f32 = 9.265;
    const ERB_BASE: f32 = 24.7;

    let ratio = freq_hz / (ERB_BASE * ERB_SCALE);
    ERB_SCALE * ratio.ln_1p()
}
/// Maps a value on the ERB scale back to a frequency in Hz.
///
/// Evaluates `24.7 * 9.265 * (exp(n_erb / 9.265) - 1)`, so `0` maps to `0 Hz`.
/// This is the inverse of `freq2erb`.
pub(crate) fn erb2freq(n_erb: f32) -> f32 {
    const ERB_SCALE: f32 = 9.265;
    const ERB_BASE: f32 = 24.7;

    let growth = (n_erb / ERB_SCALE).exp() - 1.0;
    ERB_BASE * ERB_SCALE * growth
}
36
/// Streaming STFT state: FFT plans, window, ERB band layout and the overlap
/// buffers that carry samples from one frame to the next.
#[derive(Clone)]
pub struct DFState {
    /// Sample rate in Hz, as passed to [`DFState::new`].
    pub sr: usize,
    /// Number of samples consumed/produced per frame (the hop size).
    pub frame_size: usize,
    /// Length of the analysis/synthesis window; equal to the FFT size.
    pub window_size: usize,
    /// Number of complex bins per spectrum: `fft_size / 2 + 1`.
    pub freq_size: usize,
    /// Real-to-complex FFT plan of length `window_size`.
    pub fft_forward: Arc<dyn RealToComplex<f32>>,
    /// Complex-to-real FFT plan of length `window_size`.
    pub fft_inverse: Arc<dyn ComplexToReal<f32>>,
    /// Window coefficients applied both before the forward FFT and after the
    /// inverse FFT.
    pub window: Vec<f32>,
    /// Scale factor multiplied onto every bin after the forward FFT.
    pub wnorm: f32,
    /// Width (number of frequency bins) of each band, as returned by `erb_fb`.
    pub erb: Vec<usize>,
    // Most recent `fft_size - frame_size` input samples, reused by the next
    // analysis frame.
    analysis_mem: Vec<f32>,
    // Scratch space required by the forward FFT plan.
    analysis_scratch: Vec<Complex32>,
    // Windowed output samples (`fft_size - frame_size` of them) that still have
    // to be overlap-added into upcoming output frames.
    synthesis_mem: Vec<f32>,
    // Scratch space required by the inverse FFT plan.
    synthesis_scratch: Vec<Complex32>,
}
53
54pub fn erb_fb(sr: usize, fft_size: usize, nb_bands: usize, min_nb_freqs: usize) -> Vec<usize> {
55 let nyq_freq = sr / 2;
57 let freq_width = sr as f32 / fft_size as f32;
58 let erb_low: f32 = freq2erb(0.);
59 let erb_high: f32 = freq2erb(nyq_freq as f32);
60 let mut erb = vec![0; nb_bands];
61 let step = (erb_high - erb_low) / nb_bands as f32;
62 let min_nb_freqs = min_nb_freqs as i32; let mut prev_freq = 0; let mut freq_over = 0; for i in 1..nb_bands + 1 {
66 let f = erb2freq(erb_low + i as f32 * step);
67 let fb = (f / freq_width).round() as usize;
68 let mut nb_freqs = fb as i32 - prev_freq as i32 - freq_over;
69 if nb_freqs < min_nb_freqs {
70 freq_over = min_nb_freqs - nb_freqs; nb_freqs = min_nb_freqs; } else {
74 freq_over = 0
75 }
76 erb[i - 1] = nb_freqs as usize;
77 prev_freq = fb;
78 }
79 erb[nb_bands - 1] += 1; let too_large = erb.iter().sum::<usize>() - (fft_size / 2 + 1);
81 if too_large > 0 {
82 erb[nb_bands - 1] -= too_large;
83 }
84 debug_assert!(erb.iter().sum::<usize>() == fft_size / 2 + 1);
85 erb
86}
87
88impl DFState {
90 pub fn new(
91 sr: usize,
92 fft_size: usize,
93 hop_size: usize,
94 nb_bands: usize,
95 min_nb_freqs: usize,
96 ) -> Self {
97 assert!(hop_size * 2 <= fft_size);
98 let mut fft = RealFftPlanner::<f32>::new();
99 let frame_size = hop_size;
100 let window_size = fft_size;
101 let window_size_h = fft_size / 2;
102 let freq_size = fft_size / 2 + 1;
103 let forward = fft.plan_fft_forward(fft_size);
104 let backward = fft.plan_fft_inverse(fft_size);
105 let analysis_mem = vec![0.; fft_size - frame_size];
106 let synthesis_mem = vec![0.; fft_size - frame_size];
107 let analysis_scratch = forward.make_scratch_vec();
108 let synthesis_scratch = backward.make_scratch_vec();
109
110 let erb = erb_fb(sr, fft_size, nb_bands, min_nb_freqs);
111
112 let pi = std::f64::consts::PI;
113 let mut window = vec![0.0; fft_size];
114 for (i, w) in window.iter_mut().enumerate() {
115 let sin = (0.5 * pi * (i as f64 + 0.5) / window_size_h as f64).sin();
116 *w = (0.5 * pi * sin * sin).sin() as f32;
117 }
118 let wnorm =
119 1_f32 / window.iter().map(|x| x * x).sum::<f32>() * frame_size as f32 / fft_size as f32;
120
121 DFState {
122 sr,
123 frame_size,
124 window_size,
125 freq_size,
126 fft_forward: forward,
127 fft_inverse: backward,
128 erb,
129 analysis_mem,
130 analysis_scratch,
131 synthesis_mem,
132 synthesis_scratch,
133 window,
134 wnorm,
135 }
136 }
137
138 pub fn reset(&mut self) {
139 self.analysis_mem.fill(0.);
140 self.synthesis_mem.fill(0.);
141 }
142
143 pub fn process_frame(&mut self, input: &[f32], output: &mut [f32]) {
144 debug_assert_eq!(input.len(), self.frame_size);
145 debug_assert_eq!(output.len(), self.frame_size);
146 process_frame(input, output, self);
147 }
148
149 pub fn analysis(&mut self, input: &[f32], output: &mut [Complex32]) {
150 debug_assert_eq!(input.len(), self.frame_size);
151 frame_analysis(input, output, self)
152 }
153
154 pub fn synthesis(&mut self, input: &mut [Complex32], output: &mut [f32]) {
155 debug_assert_eq!(output.len(), self.frame_size);
156 frame_synthesis(input, output, self)
157 }
158}
159
160impl Default for DFState {
161 fn default() -> Self {
162 Self::new(48000, 960, 480, 32, 2)
163 }
164}
165
166pub fn band_mean_norm_freq(xs: &[Complex32], xout: &mut [f32], state: &mut [f32], alpha: f32) {
167 debug_assert_eq!(xs.len(), state.len());
168 debug_assert_eq!(xout.len(), state.len());
169 for ((x, s), xo) in xs.iter().zip(state.iter_mut()).zip(xout.iter_mut()) {
170 let xabs = x.norm();
171 *s = xabs * (1. - alpha) + *s * alpha;
172 *xo = xabs - *s;
173 }
174}
175
/// Mean-normalizes `xs` in place with an exponentially smoothed mean.
///
/// For every element the running mean in `state` is updated as
/// `x * (1 - alpha) + state * alpha`; the element is then replaced by
/// `(x - state) / 40`. `xs` and `state` are expected to have the same length
/// (checked in debug builds only).
pub fn band_mean_norm_erb(xs: &mut [f32], state: &mut [f32], alpha: f32) {
    debug_assert_eq!(xs.len(), state.len());

    xs.iter_mut().zip(state.iter_mut()).for_each(|(value, mean)| {
        *mean = *value * (1. - alpha) + *mean * alpha;
        *value = (*value - *mean) / 40.;
    });
}
184
185pub fn band_unit_norm(xs: &mut [Complex32], state: &mut [f32], alpha: f32) {
186 debug_assert_eq!(xs.len(), state.len());
187 for (x, s) in xs.iter_mut().zip(state.iter_mut()) {
188 *s = x.norm() * (1. - alpha) + *s * alpha;
189 *x /= s.sqrt();
190 }
191}
192
193pub fn compute_band_corr(out: &mut [f32], x: &[Complex32], p: &[Complex32], erb_fb: &[usize]) {
194 for y in out.iter_mut() {
195 *y = 0.0;
196 }
197 debug_assert_eq!(erb_fb.len(), out.len());
198
199 let mut bcsum = 0;
200 for (&band_size, out_b) in erb_fb.iter().zip(out.iter_mut()) {
201 let k = 1. / band_size as f32;
202 for j in 0..band_size {
203 let idx = bcsum + j;
204 *out_b += (x[idx].re * p[idx].re + x[idx].im * p[idx].im) * k;
205 }
206 bcsum += band_size;
207 }
208}
209
/// Compresses a per-bin signal into per-band means.
///
/// `erb_fb` lists the width of each consecutive band in bins; `out` receives
/// the mean of `x` over each band and is zeroed first. `out` and `erb_fb` are
/// expected to have the same length (checked in debug builds only).
///
/// # Panics
///
/// Panics if a band reaches past the end of `x`.
pub fn band_compr(out: &mut [f32], x: &[f32], erb_fb: &[usize]) {
    out.fill(0.0);
    debug_assert_eq!(erb_fb.len(), out.len());

    let mut start = 0;
    for (acc, &width) in out.iter_mut().zip(erb_fb) {
        let end = start + width;
        let scale = 1. / width as f32;
        for &value in &x[start..end] {
            *acc += value * scale;
        }
        start = end;
    }
}
226
/// Multiplies every bin of `out` by the gain of the band it belongs to.
///
/// `erb_fb` lists the width of each consecutive band in bins and `band_e` the
/// gain of each band; bins beyond the last band are left untouched.
///
/// # Panics
///
/// Panics if a band reaches past the end of `out`.
fn apply_interp_band_gain<T>(out: &mut [T], band_e: &[f32], erb_fb: &[usize])
where
    T: MulAssign<f32>,
{
    let mut start = 0;
    for (&gain, &width) in band_e.iter().zip(erb_fb) {
        let end = start + width;
        out[start..end].iter_mut().for_each(|bin| *bin *= gain);
        start = end;
    }
}
240
/// Expands per-band gains to per-bin gains.
///
/// Every bin of `out` that belongs to a band is overwritten with that band's
/// value from `band_e`; `erb_fb` lists the width of each consecutive band in
/// bins. Bins beyond the last band are left untouched.
///
/// # Panics
///
/// Panics if a band reaches past the end of `out`.
fn interp_band_gain(out: &mut [f32], band_e: &[f32], erb_fb: &[usize]) {
    let mut start = 0;
    for (&gain, &width) in band_e.iter().zip(erb_fb) {
        let end = start + width;
        out[start..end].fill(gain);
        start = end;
    }
}
251
252fn apply_band_gain(out: &mut [Complex32], band_e: &[f32], erb_fb: &[usize]) {
253 let mut bcsum = 0;
254 for (&band_size, b) in erb_fb.iter().zip(band_e.iter()) {
255 for j in 0..band_size {
256 let idx = bcsum + j;
257 out[idx] *= *b;
258 }
259 bcsum += band_size;
260 }
261}
262
263fn process_frame(input: &[f32], output: &mut [f32], state: &mut DFState) {
264 let mut freq_mem = vec![Complex32::default(); state.freq_size];
265 frame_analysis(input, &mut freq_mem, state);
266 frame_synthesis(&mut freq_mem, output, state);
267}
268
269fn frame_analysis(input: &[f32], output: &mut [Complex32], state: &mut DFState) {
270 debug_assert_eq!(input.len(), state.frame_size);
271 debug_assert_eq!(output.len(), state.freq_size);
272
273 let mut buf = state.fft_forward.make_input_vec();
274 let (buf_first, buf_second) = buf.split_at_mut(state.window_size - state.frame_size);
276 let (window_first, window_second) = state.window.split_at(state.window_size - state.frame_size);
277 let analysis_split = state.analysis_mem.len() - state.frame_size;
278 for ((&y, &w), x) in
279 state.analysis_mem.iter().zip(window_first.iter()).zip(buf_first.iter_mut())
280 {
281 *x = y * w;
282 }
283 for ((&y, &w), x) in input.iter().zip(window_second.iter()).zip(buf_second.iter_mut()) {
285 *x = y * w;
286 }
287 if analysis_split > 0 {
289 state.analysis_mem.rotate_left(state.frame_size);
291 }
292 for (x, &y) in state.analysis_mem[analysis_split..].iter_mut().zip(input) {
294 *x = y
295 }
296 state
297 .fft_forward
298 .process_with_scratch(&mut buf, output, &mut state.analysis_scratch)
299 .expect("FFT forward failed");
300 let norm = state.wnorm;
302 for x in output.iter_mut() {
303 *x *= norm;
304 }
305}
306
307fn frame_synthesis(input: &mut [Complex32], output: &mut [f32], state: &mut DFState) {
308 let mut x = state.fft_inverse.make_output_vec();
309 match state
310 .fft_inverse
311 .process_with_scratch(input, &mut x[..], &mut state.synthesis_scratch)
312 {
313 Err(realfft::FftError::InputValues(_, _)) => (),
314 Err(e) => Err(e).unwrap(),
315 Ok(_) => (),
316 }
317 apply_window_in_place(&mut x, &state.window);
318 let (x_first, x_second) = x.split_at(state.frame_size);
319 for ((&xi, &mem), out) in x_first.iter().zip(state.synthesis_mem.iter()).zip(output.iter_mut())
320 {
321 *out = xi + mem;
322 }
323
324 let split = state.synthesis_mem.len() - state.frame_size;
325 if split > 0 {
326 state.synthesis_mem.rotate_left(state.frame_size);
327 }
328 let (s_first, s_second) = state.synthesis_mem.split_at_mut(split);
329 let (xs_first, xs_second) = x_second.split_at(split);
330 for (&xi, mem) in xs_first.iter().zip(s_first.iter_mut()) {
331 *mem += xi;
333 }
334 for (&xi, mem) in xs_second.iter().zip(s_second.iter_mut()) {
335 *mem = xi;
337 }
338}
339
/// Returns a new vector of `window.len()` elements holding `xs[i] * window[i]`.
///
/// If `xs` is shorter than `window` the remaining elements are `0`; surplus
/// elements of `xs` are ignored.
fn apply_window(xs: &[f32], window: &[f32]) -> Vec<f32> {
    xs.iter()
        .zip(window)
        .map(|(&sample, &w)| sample * w)
        .chain(std::iter::repeat(0.))
        .take(window.len())
        .collect()
}
347
/// Multiplies `xs` element-wise by the values yielded by `window`, in place.
///
/// Processing stops as soon as either side runs out; leftover elements of `xs`
/// keep their value.
fn apply_window_in_place<'a, I>(xs: &mut [f32], window: I)
where
    I: IntoIterator<Item = &'a f32>,
{
    xs.iter_mut()
        .zip(window)
        .for_each(|(sample, &w)| *sample *= w);
}
356
#[cfg(test)]
mod tests {
    use rand::distributions::{Distribution, Uniform};

    use super::*;

    /// Applying per-band gains must scale every bin by exactly the gain of the
    /// band that contains it.
    #[test]
    fn test_erb_inout() {
        let (sr, n_fft, nb_bands) = (24000, 192, 24);
        let n_freqs = n_fft / 2 + 1;
        let hop = n_fft / 2;
        let state = DFState::new(sr, n_fft, hop, nb_bands, 1);

        // Random spectrum with real and imaginary parts in [-1, 1).
        let dist = Uniform::new(-1., 1.);
        let mut rng = rand::thread_rng();
        let input: Vec<Complex32> = (0..n_freqs)
            .map(|_| Complex32::new(dist.sample(&mut rng), dist.sample(&mut rng)))
            .collect();

        let mut mask = vec![1.; nb_bands];
        mask[3] = 0.3;
        mask[nb_bands - 1] = 0.5;

        let mut output = input.clone();
        apply_band_gain(&mut output, &mask, &state.erb);

        let mut start = 0;
        for (&width, &gain) in state.erb.iter().zip(mask.iter()) {
            let end = start + width;
            for (x, y) in input[start..end].iter().zip(&output[start..end]) {
                assert_eq!(*x * gain, *y)
            }
            start = end;
        }
    }
}
390}