use std::borrow::Cow;
use crate::{
Complex, RnnModel, CEPS_MEM, FRAME_SIZE, FREQ_SIZE, NB_BANDS, NB_DELTA_CEPS, NB_FEATURES,
PITCH_BUF_SIZE, WINDOW_SIZE,
};
/// All state carried from one processed frame to the next.
///
/// The `'model` lifetime is that of the network weights, which are held either
/// owned or borrowed (see `from_model_owned`).
#[derive(Clone)]
pub struct DenoiseState<'model> {
/// History of filtered input samples. Each call to `process_frame` shifts it
/// left by `FRAME_SIZE` and writes the newly filtered frame at the end.
input_mem: Vec<f32>,
/// The `NB_BANDS` coefficients stored for each of the last `CEPS_MEM`
/// non-silent frames, used as a ring buffer indexed by `mem_id`.
cepstral_mem: [[f32; crate::NB_BANDS]; crate::CEPS_MEM],
/// Slot of `cepstral_mem` that the next non-silent frame will overwrite;
/// wraps back to 0 when it reaches `CEPS_MEM`.
mem_id: usize,
/// Second half of the previous windowed inverse transform, added to the
/// first half of the next one in `frame_synthesis`.
synthesis_mem: [f32; FRAME_SIZE],
/// The two state values of the `biquad` filter applied to incoming samples.
mem_hp_x: [f32; 2],
/// Band gains applied to the previous non-silent frame; new gains are not
/// allowed to fall below 0.6 times these.
lastg: [f32; crate::NB_BANDS],
/// State of the network that turns features into band gains.
rnn: crate::rnn::RnnState<'model>,
/// Transform used by `forward_transform` and `inverse_transform`.
fft: crate::fft::RealFft,
/// Pitch estimator run by `find_pitch`.
pitch_finder: crate::pitch::PitchFinder,
}
/// Constructors for states that own their model and therefore have no
/// borrowed lifetime.
impl DenoiseState<'static> {
    /// The number of samples handled by one call to `process_frame`.
    pub const FRAME_SIZE: usize = FRAME_SIZE;

    /// Builds an unboxed state around `RnnModel::default()`.
    pub(crate) fn default() -> Self {
        let model = Cow::Owned(RnnModel::default());
        Self::from_model_owned(model)
    }

    /// Builds a boxed state around `RnnModel::default()`.
    pub fn new() -> Box<DenoiseState<'static>> {
        let state = Self::default();
        Box::new(state)
    }

    /// Builds a boxed state that takes ownership of `model`.
    pub fn from_model(model: RnnModel) -> Box<DenoiseState<'static>> {
        let state = Self::from_model_owned(Cow::Owned(model));
        Box::new(state)
    }
}
impl<'model> DenoiseState<'model> {
    /// Builds a boxed state that borrows `model`; the state cannot outlive it.
    pub fn with_model(model: &'model RnnModel) -> Box<DenoiseState<'model>> {
        let state = Self::from_model_owned(Cow::Borrowed(model));
        Box::new(state)
    }

    /// Builds a state around an owned or borrowed model.
    ///
    /// Every buffer starts zeroed. The input history is sized to the larger of
    /// `FRAME_SIZE` and `PITCH_BUF_SIZE`.
    pub(crate) fn from_model_owned(model: Cow<'model, RnnModel>) -> DenoiseState<'model> {
        let history_len = std::cmp::max(FRAME_SIZE, PITCH_BUF_SIZE);
        Self {
            input_mem: vec![0.0; history_len],
            cepstral_mem: [[0.0; NB_BANDS]; CEPS_MEM],
            mem_id: 0,
            synthesis_mem: [0.0; FRAME_SIZE],
            mem_hp_x: [0.0; 2],
            lastg: [0.0; NB_BANDS],
            fft: crate::fft::RealFft::new(crate::sin_cos_table()),
            rnn: crate::rnn::RnnState::new(model),
            pitch_finder: crate::pitch::PitchFinder::new(),
        }
    }

    /// Returns the most recent `len` samples of the input history.
    ///
    /// # Panics
    ///
    /// Panics if `len` exceeds the length of the history buffer.
    fn input(&self, len: usize) -> &[f32] {
        let start = self.input_mem.len().checked_sub(len).unwrap();
        &self.input_mem[start..]
    }

    /// Runs the pitch finder on the most recent `PITCH_BUF_SIZE` samples and
    /// returns the first element of its result; the second one is discarded.
    ///
    /// # Panics
    ///
    /// Panics if the history buffer is shorter than `PITCH_BUF_SIZE`.
    fn find_pitch(&mut self) -> usize {
        let start = self.input_mem.len().checked_sub(PITCH_BUF_SIZE).unwrap();
        let recent = &self.input_mem[start..];
        let (pitch, _gain) = self.pitch_finder.process(recent);
        pitch
    }

    /// Transforms `input` into `output`, then scales every output value by
    /// `1 / WINDOW_SIZE`.
    fn forward_transform(&mut self, output: &mut [Complex], input: &mut [f32]) {
        self.fft.forward(input, output);
        let norm = 1.0 / WINDOW_SIZE as f32;
        output.iter_mut().for_each(|bin| *bin *= norm);
    }

    /// Inverse-transforms `input` into `output`; no scaling is applied here.
    fn inverse_transform(&mut self, output: &mut [f32], input: &mut [Complex]) {
        self.fft.inverse(input, output);
    }

    /// Processes one frame of audio and writes the result to `output`.
    ///
    /// At most `FRAME_SIZE` samples of `input` are consumed, and `FRAME_SIZE`
    /// samples are written to `output`. The result also depends on earlier
    /// calls through the state kept in `self`.
    ///
    /// Returns the `vad_prob` value produced by the network, or `0.0` when the
    /// frame was treated as silent and the network was not run.
    ///
    /// # Panics
    ///
    /// Panics if `output` holds fewer than `FRAME_SIZE` samples.
    pub fn process_frame(&mut self, output: &mut [f32], input: &[f32]) -> f32 {
        process_frame(self, output, input)
    }
}
/// Windows the most recent `WINDOW_SIZE` input samples, writes their forward
/// transform to `x`, and writes the band-wise correlation of `x` with itself
/// (as computed by `crate::compute_band_corr`) to `ex`.
fn frame_analysis(state: &mut DenoiseState, x: &mut [Complex], ex: &mut [f32]) {
    let mut windowed = [0.0_f32; WINDOW_SIZE];
    let newest = state.input(WINDOW_SIZE);
    crate::apply_window(&mut windowed[..], newest);
    state.forward_transform(x, &mut windowed[..]);
    crate::compute_band_corr(ex, x, x);
}
/// Analyses the newest input window and fills `features` for the network.
///
/// Outputs besides `features`:
/// * `x`, `ex`: transform of the newest window and its band correlation with itself.
/// * `p`, `ep`: the same for the window taken `pitch_idx` samples further back.
/// * `exp`: band correlation between `x` and `p`, divided by `sqrt(0.001 + ex * ep)`.
///
/// Layout of `features` on a non-silent frame:
/// * `0..NB_BANDS`: `crate::dct` of the log band levels, with 12 and 4 subtracted
///   from the first two entries; the first `NB_DELTA_CEPS` entries are then
///   replaced by the sum over the last three stored frames.
/// * next `NB_DELTA_CEPS`: newest stored frame minus the one two frames back.
/// * next `NB_DELTA_CEPS`: newest minus twice the previous plus the one two back.
/// * next `NB_DELTA_CEPS`: `crate::dct` of `exp`, with 1.3 and 0.9 subtracted
///   from the first two entries.
/// * then `0.01 * (pitch_idx - 300)`.
/// * then the mean, over the stored frames, of the squared distance to the
///   closest other stored frame, minus 2.1.
///
/// Returns `1` when the summed band energy is below 0.04; in that case the
/// first `NB_FEATURES` entries of `features` are zeroed and the stored frame
/// history is left untouched. Returns `0` otherwise.
fn compute_frame_features(
    state: &mut DenoiseState,
    x: &mut [Complex],
    p: &mut [Complex],
    ex: &mut [f32],
    ep: &mut [f32],
    exp: &mut [f32],
    features: &mut [f32],
) -> usize {
    // Start offsets of the feature groups that follow the first `NB_BANDS` entries.
    let first_delta_at = NB_BANDS;
    let second_delta_at = NB_BANDS + NB_DELTA_CEPS;
    let pitch_corr_at = NB_BANDS + 2 * NB_DELTA_CEPS;
    let pitch_period_at = NB_BANDS + 3 * NB_DELTA_CEPS;

    frame_analysis(state, x, ex);

    // The slice handed to `apply_window` reaches `pitch_idx` samples further
    // back than the newest window (presumably only its first `WINDOW_SIZE`
    // samples are used; confirm against `apply_window`).
    let pitch_idx = state.find_pitch();
    let mut delayed = [0.0_f32; WINDOW_SIZE];
    crate::apply_window(&mut delayed[..], state.input(WINDOW_SIZE + pitch_idx));
    state.forward_transform(p, &mut delayed[..]);
    crate::compute_band_corr(ep, p, p);
    crate::compute_band_corr(exp, x, p);
    for band in 0..NB_BANDS {
        let scale = (0.001 + ex[band] * ep[band]).sqrt();
        exp[band] /= scale;
    }

    let mut pitch_ceps = [0.0_f32; NB_BANDS];
    crate::dct(&mut pitch_ceps[..], exp);
    features[pitch_corr_at..pitch_corr_at + NB_DELTA_CEPS]
        .copy_from_slice(&pitch_ceps[..NB_DELTA_CEPS]);
    features[pitch_corr_at] -= 1.3;
    features[pitch_corr_at + 1] -= 0.9;
    features[pitch_period_at] = 0.01 * (pitch_idx as f32 - 300.0);

    // Log band levels, kept within 7 of the running maximum and not allowed to
    // drop by more than 1.5 from one band to the next.
    let mut log_energy = [0.0_f32; NB_BANDS];
    let mut peak = -2.0_f32;
    let mut follow = -2.0_f32;
    let mut total_energy = 0.0_f32;
    for (band, slot) in log_energy.iter_mut().enumerate() {
        let level = (1e-2 + ex[band]).log10().max(peak - 7.0).max(follow - 1.5);
        *slot = level;
        peak = peak.max(level);
        follow = (follow - 1.5).max(level);
        total_energy += ex[band];
    }

    // Silent frame: report it without touching the stored history.
    if total_energy < 0.04 {
        for value in features[..NB_FEATURES].iter_mut() {
            *value = 0.0;
        }
        return 1;
    }

    crate::dct(features, &log_energy[..]);
    features[0] -= 12.0;
    features[1] -= 4.0;

    // Ring-buffer slots of the newest frame and of the two before it.
    let newest = state.mem_id;
    let slot_back = |steps: usize| {
        if newest < steps {
            CEPS_MEM + newest - steps
        } else {
            newest - steps
        }
    };
    let previous = slot_back(1);
    let before_previous = slot_back(2);

    state.cepstral_mem[newest].copy_from_slice(&features[..NB_BANDS]);
    state.mem_id += 1;
    if state.mem_id == CEPS_MEM {
        state.mem_id = 0;
    }

    let history = &state.cepstral_mem;
    let ceps_0 = &history[newest];
    let ceps_1 = &history[previous];
    let ceps_2 = &history[before_previous];
    for i in 0..NB_DELTA_CEPS {
        features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
        features[first_delta_at + i] = ceps_0[i] - ceps_2[i];
        features[second_delta_at + i] = ceps_0[i] - 2.0 * ceps_1[i] + ceps_2[i];
    }

    // For every stored frame, the squared distance to its closest other frame.
    let mut spec_variability = 0.0_f32;
    for (i, frame) in history.iter().enumerate() {
        let nearest = history
            .iter()
            .enumerate()
            .filter(|&(j, _)| j != i)
            .map(|(_, other)| {
                frame.iter().zip(other.iter()).fold(0.0_f32, |acc, (a, b)| {
                    let diff = a - b;
                    acc + diff * diff
                })
            })
            .fold(1e15_f32, f32::min);
        spec_variability += nearest;
    }
    features[pitch_period_at + 1] = spec_variability / CEPS_MEM as f32 - 2.1;
    0
}
/// Turns the spectrum `y` back into `FRAME_SIZE` output samples.
///
/// `y` is inverse-transformed into a `WINDOW_SIZE` buffer and windowed in
/// place. The first `FRAME_SIZE` samples, added to the tail saved by the
/// previous call, go to `out`; the samples starting at `FRAME_SIZE` are saved
/// for the next call.
///
/// # Panics
///
/// Panics if `out` holds fewer than `FRAME_SIZE` samples.
fn frame_synthesis(state: &mut DenoiseState, out: &mut [f32], y: &mut [Complex]) {
    let mut time_domain = [0.0_f32; WINDOW_SIZE];
    state.inverse_transform(&mut time_domain[..], y);
    crate::apply_window_in_place(&mut time_domain[..]);
    for (i, carry) in state.synthesis_mem.iter_mut().enumerate() {
        out[i] = time_domain[i] + *carry;
        *carry = time_domain[FRAME_SIZE + i];
    }
}
/// Filters `xs` into `ys` with a two-state recursive filter.
///
/// For each sample `x`, with state `mem`:
///
/// ```text
/// y      = x + mem[0]
/// mem[0] = mem[1] + (b[0] * x - a[0] * y)
/// mem[1] = b[1] * x - a[1] * y
/// ```
///
/// The arithmetic is done in `f64`; the state and the output are rounded to
/// `f32` after every sample. `mem` is updated in place so that consecutive
/// calls continue the same filter. Only as many samples as the shorter of
/// `xs` and `ys` are processed.
///
/// # Panics
///
/// Panics if `a` or `b` has fewer than two elements, or if `mem` has fewer
/// than two elements and at least one sample is processed.
fn biquad(ys: &mut [f32], mem: &mut [f32], xs: &[f32], b: &[f32], a: &[f32]) {
    let (a0, a1) = (f64::from(a[0]), f64::from(a[1]));
    let (b0, b1) = (f64::from(b[0]), f64::from(b[1]));
    for (out, &sample) in ys.iter_mut().zip(xs) {
        let input = f64::from(sample);
        let filtered = input + f64::from(mem[0]);
        let next_first = f64::from(mem[1]) + (b0 * input - a0 * filtered);
        let next_second = b1 * input - a1 * filtered;
        mem[0] = next_first as f32;
        mem[1] = next_second as f32;
        *out = filtered as f32;
    }
}
/// Mixes the delayed spectrum `p` into `x` and restores the band energies.
///
/// For each band a mixing weight is derived from `exp` and `g`: it is 1 when
/// `exp` exceeds `g`, otherwise a ratio of their squares, limited to the range
/// 0 to 1 and square-rooted. The weight is then scaled by
/// `sqrt(ex / (1e-8 + ep))`. The band weights are expanded with
/// `crate::interp_band_gain` and `p` is added to `x` with those weights.
/// Finally `x` is rescaled band by band by `sqrt(ex / (1e-8 + new))`, where
/// `new` is the band correlation of the mixed `x` with itself.
fn pitch_filter(x: &mut [Complex], p: &[Complex], ex: &[f32], ep: &[f32], exp: &[f32], g: &[f32]) {
    let mut band_mix = [0.0_f32; NB_BANDS];
    for (band, mix) in band_mix.iter_mut().enumerate() {
        let raw = if exp[band] > g[band] {
            1.0
        } else {
            let corr_sq = exp[band] * exp[band];
            let gain_sq = g[band] * g[band];
            corr_sq * (1.0 - gain_sq) / (0.001 + gain_sq * (1.0 - corr_sq))
        };
        let strength = 1.0_f32.min(0.0_f32.max(raw)).sqrt();
        *mix = strength * (ex[band] / (1e-8 + ep[band])).sqrt();
    }

    let mut bin_mix = [0.0; FREQ_SIZE];
    crate::interp_band_gain(&mut bin_mix[..], &band_mix[..]);
    for (bin, weight) in bin_mix.iter().enumerate() {
        x[bin] += *weight * p[bin];
    }

    // Undo the energy change caused by the mix, band by band.
    let mut mixed_energy = [0.0; NB_BANDS];
    crate::compute_band_corr(&mut mixed_energy[..], x, x);
    let mut band_norm = [0.0; NB_BANDS];
    for (band, norm) in band_norm.iter_mut().enumerate() {
        *norm = (ex[band] / (1e-8 + mixed_energy[band])).sqrt();
    }

    let mut bin_norm = [0.0; FREQ_SIZE];
    crate::interp_band_gain(&mut bin_norm[..], &band_norm[..]);
    for (bin, norm) in bin_norm.iter().enumerate() {
        x[bin] *= *norm;
    }
}
/// Runs the whole pipeline for one frame.
///
/// Steps: shift the input history by `FRAME_SIZE`, filter `input` into its
/// tail with `biquad`, compute the features, and, unless the frame was
/// reported as silent, run the network, apply `pitch_filter`, and scale the
/// spectrum by the interpolated band gains. Each band gain is kept at or above
/// 0.6 times the gain stored from the previous non-silent frame. The spectrum
/// is finally synthesised into `output`.
///
/// Returns the single `vad_prob` value written by the network, or `0.0` when
/// the network was skipped.
fn process_frame(state: &mut DenoiseState, output: &mut [f32], input: &[f32]) -> f32 {
    // Coefficients handed to `biquad` for the input pre-filter.
    const A_HP: [f32; 2] = [-1.99599, 0.99600];
    const B_HP: [f32; 2] = [-2.0, 1.0];

    // Drop the oldest `FRAME_SIZE` samples; the freed tail receives the
    // filtered input.
    let tail_start = state.input_mem.len() - FRAME_SIZE;
    state.input_mem.copy_within(FRAME_SIZE.., 0);
    biquad(
        &mut state.input_mem[tail_start..],
        &mut state.mem_hp_x[..],
        input,
        &B_HP[..],
        &A_HP[..],
    );

    let mut spectrum = [Complex::from(0.0); FREQ_SIZE];
    let mut pitch_spectrum = [Complex::from(0.0); FREQ_SIZE];
    let mut ex = [0.0_f32; NB_BANDS];
    let mut ep = [0.0_f32; NB_BANDS];
    let mut exp = [0.0_f32; NB_BANDS];
    let mut features = [0.0_f32; NB_FEATURES];
    let silence = compute_frame_features(
        state,
        &mut spectrum[..],
        &mut pitch_spectrum[..],
        &mut ex[..],
        &mut ep[..],
        &mut exp[..],
        &mut features[..],
    );

    let mut vad_prob = [0.0_f32];
    if silence == 0 {
        let mut band_gain = [0.0_f32; NB_BANDS];
        state
            .rnn
            .compute(&mut band_gain[..], &mut vad_prob[..], &features[..]);
        pitch_filter(
            &mut spectrum[..],
            &pitch_spectrum[..],
            &ex[..],
            &ep[..],
            &exp[..],
            &band_gain[..],
        );

        // Limit how fast a band gain may fall from one frame to the next.
        for (gain, last) in band_gain.iter_mut().zip(state.lastg.iter_mut()) {
            *gain = gain.max(0.6 * *last);
            *last = *gain;
        }

        let mut bin_gain = [1.0; FREQ_SIZE];
        crate::interp_band_gain(&mut bin_gain[..], &band_gain[..]);
        for (bin, gain) in spectrum.iter_mut().zip(bin_gain.iter()) {
            *bin *= *gain;
        }
    }

    frame_synthesis(state, output, &mut spectrum[..]);
    vad_prob[0]
}
// Compile-time checks only; this module contains no runtime tests.
#[cfg(test)]
mod tests {
use super::*;
extern crate static_assertions as sa;
// Fails to compile if `DenoiseState` ever stops implementing `Send` or `Sync`.
sa::assert_impl_all!(DenoiseState: Send, Sync);
}