use super::resample::AudioResampler;
use nnnoiseless::DenoiseState;
/// Sample rate (Hz) the denoising model runs at; audio at any other rate is
/// resampled to this before denoising and back afterwards.
pub const DENOISE_SAMPLE_RATE: u32 = 48_000;
/// Samples per model invocation (`process_frame` call) — 10 ms at 48 kHz.
const FRAME_SIZE: usize = 480;
/// Streaming `nnnoiseless` denoiser for `i16` audio at a caller-chosen sample rate.
///
/// Audio that is not already at [`DENOISE_SAMPLE_RATE`] is upsampled before
/// denoising and downsampled afterwards. Samples are accumulated across calls
/// until a full model frame is available.
pub struct Denoiser {
    // --- configuration ---
    /// Sample rate of the audio callers pass in and receive back.
    sample_rate: u32,
    /// Input-rate -> 48 kHz converter; `None` when no resampling is needed.
    upsampler: Option<AudioResampler>,
    /// 48 kHz -> input-rate converter; `None` when no resampling is needed.
    downsampler: Option<AudioResampler>,

    // --- model ---
    /// `nnnoiseless` model state, carried from frame to frame.
    state: Box<DenoiseState<'static>>,
    /// Set until the first frame has been run; that frame's output is replaced by silence.
    first_frame: bool,

    // --- buffering (48 kHz, f32) ---
    /// Upsampled samples waiting to fill a complete frame.
    input_buffer: Vec<f32>,
    /// Denoised samples produced during the current call, before conversion back to `i16`.
    output_buffer: Vec<f32>,
}
impl std::fmt::Debug for Denoiser {
    /// Summarises the configuration and buffer fill levels; the model state and
    /// resampler internals are intentionally left out (marked non-exhaustive).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            sample_rate,
            upsampler,
            input_buffer,
            output_buffer,
            first_frame,
            ..
        } = self;

        let mut builder = f.debug_struct("Denoiser");
        builder.field("sample_rate", sample_rate);
        builder.field("resampling", &upsampler.is_some());
        builder.field("input_buffer_len", &input_buffer.len());
        builder.field("output_buffer_len", &output_buffer.len());
        builder.field("first_frame", first_frame);
        builder.finish_non_exhaustive()
    }
}
impl Denoiser {
    /// Creates a denoiser for `i16` samples arriving at `sample_rate` Hz.
    ///
    /// Samples are treated as one continuous stream; no channel
    /// de-interleaving is done here, so callers presumably pass mono audio.
    /// When `sample_rate` differs from [`DENOISE_SAMPLE_RATE`], an upsampler
    /// and a downsampler are created so the model always sees 48 kHz audio.
    ///
    /// # Panics
    ///
    /// Panics if either resampler cannot be created for `sample_rate`
    /// (i.e. whenever `AudioResampler::new` returns an error).
    pub fn new(sample_rate: u32) -> Self {
        let (upsampler, downsampler) = if sample_rate == DENOISE_SAMPLE_RATE {
            (None, None)
        } else {
            let up = AudioResampler::new(sample_rate, DENOISE_SAMPLE_RATE)
                .expect("failed to create upsampler");
            let down = AudioResampler::new(DENOISE_SAMPLE_RATE, sample_rate)
                .expect("failed to create downsampler");
            (Some(up), Some(down))
        };
        Self {
            state: DenoiseState::new(),
            sample_rate,
            upsampler,
            downsampler,
            input_buffer: Vec::with_capacity(FRAME_SIZE),
            output_buffer: Vec::new(),
            first_frame: true,
        }
    }

    /// Sample rate (Hz) of the audio this denoiser accepts and returns.
    pub fn sample_rate(&self) -> u32 {
        self.sample_rate
    }

    /// Whether audio is resampled to/from [`DENOISE_SAMPLE_RATE`] around the model.
    pub fn is_resampling(&self) -> bool {
        self.upsampler.is_some()
    }

    /// Denoises a chunk of samples at the configured rate and returns whatever
    /// denoised audio is ready.
    ///
    /// Processing steps:
    /// - Input is upsampled to 48 kHz if needed, then appended to an internal buffer.
    /// - The buffer is consumed in whole frames of `FRAME_SIZE` samples.
    /// - A trailing partial frame stays buffered for the next call.
    ///
    /// Because of this buffering (and possibly buffering inside the resamplers),
    /// the returned length need not equal `samples.len()`. A short call may
    /// return an empty `Vec`.
    ///
    /// The first frame after construction or [`reset`](Self::reset) is emitted
    /// as silence rather than the model's output for it. That output is
    /// discarded, presumably because it is warm-up output; the stream stays
    /// length-aligned.
    pub fn process(&mut self, samples: &[i16]) -> Vec<i16> {
        // Convert directly into the f32 frame buffer; on the 48 kHz path this
        // avoids copying `samples` into an intermediate Vec first.
        if let Some(upsampler) = self.upsampler.as_mut() {
            let samples_48k = upsampler.process(samples);
            self.input_buffer
                .extend(samples_48k.iter().map(|&s| f32::from(s)));
        } else {
            self.input_buffer
                .extend(samples.iter().map(|&s| f32::from(s)));
        }

        // Run every complete frame first, then drop them with a single drain.
        // Draining one frame per iteration shifts the rest of the buffer each
        // time, which is quadratic in the number of frames per call.
        let complete_frames = self.input_buffer.len() / FRAME_SIZE;
        let mut input_frame = [0.0f32; FRAME_SIZE];
        let mut output_frame = [0.0f32; FRAME_SIZE];
        for frame in self.input_buffer.chunks_exact(FRAME_SIZE) {
            input_frame.copy_from_slice(frame);
            let _vad_prob = self.state.process_frame(&mut output_frame, &input_frame);
            if self.first_frame {
                // Discard the model's first output frame but keep the stream
                // length-aligned by emitting silence in its place.
                self.first_frame = false;
                self.output_buffer.extend_from_slice(&[0.0; FRAME_SIZE]);
            } else {
                self.output_buffer.extend_from_slice(&output_frame);
            }
        }
        self.input_buffer.drain(..complete_frames * FRAME_SIZE);

        let denoised_48k: Vec<i16> = self
            .output_buffer
            .drain(..)
            .map(Self::sample_to_i16)
            .collect();
        match self.downsampler.as_mut() {
            Some(downsampler) => downsampler.process(&denoised_48k),
            None => denoised_48k,
        }
    }

    /// Denoises input that is already a whole number of `FRAME_SIZE`-sample
    /// frames, returning exactly `samples.len()` samples.
    ///
    /// As with [`process`](Self::process), the first frame after construction
    /// or [`reset`](Self::reset) is emitted as silence.
    ///
    /// This path bypasses the internal buffers and the resamplers: `samples`
    /// are fed to the model as-is. That has two consequences:
    /// - It only matches the rate the model is run at elsewhere on a denoiser
    ///   created at [`DENOISE_SAMPLE_RATE`]. On a resampling denoiser, the
    ///   input-rate audio is treated as if it were 48 kHz.
    /// - It does not see any partial frame left buffered by
    ///   [`process`](Self::process), so the two should not be interleaved on
    ///   one stream.
    ///
    /// NOTE(review): consider rejecting resampling denoisers here once callers
    /// are audited.
    ///
    /// # Panics
    ///
    /// Panics if `samples.len()` is not a multiple of `FRAME_SIZE`.
    pub fn process_aligned(&mut self, samples: &[i16]) -> Vec<i16> {
        assert!(
            samples.len().is_multiple_of(FRAME_SIZE),
            "Input length {} is not a multiple of frame size {}",
            samples.len(),
            FRAME_SIZE
        );
        let mut output = Vec::with_capacity(samples.len());
        let mut input_frame = [0.0f32; FRAME_SIZE];
        let mut output_frame = [0.0f32; FRAME_SIZE];
        for chunk in samples.chunks_exact(FRAME_SIZE) {
            for (dst, &src) in input_frame.iter_mut().zip(chunk) {
                *dst = f32::from(src);
            }
            let _vad_prob = self.state.process_frame(&mut output_frame, &input_frame);
            if self.first_frame {
                self.first_frame = false;
                output.extend_from_slice(&[0i16; FRAME_SIZE]);
            } else {
                output.extend(output_frame.iter().copied().map(Self::sample_to_i16));
            }
        }
        output
    }

    /// Returns the denoiser to its freshly constructed state.
    ///
    /// The reset covers:
    /// - a new model state;
    /// - empty buffers;
    /// - both resamplers reset;
    /// - first-frame silencing re-armed.
    pub fn reset(&mut self) {
        self.state = DenoiseState::new();
        self.input_buffer.clear();
        self.output_buffer.clear();
        self.first_frame = true;
        if let Some(ref mut upsampler) = self.upsampler {
            upsampler.reset();
        }
        if let Some(ref mut downsampler) = self.downsampler {
            downsampler.reset();
        }
    }

    /// Number of samples waiting in the input buffer for a complete frame.
    ///
    /// The count is in 48 kHz samples (after upsampling), not at the input rate.
    /// It does not account for any samples the resamplers themselves may hold.
    pub fn buffered_samples(&self) -> usize {
        self.input_buffer.len()
    }

    /// Converts a model output sample back to `i16`, rounding to nearest and
    /// saturating at the `i16` range.
    fn sample_to_i16(sample: f32) -> i16 {
        sample.round().clamp(-32768.0, 32767.0) as i16
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_denoiser_creation_48k() {
        let denoiser = Denoiser::new(48000);
        assert_eq!(denoiser.buffered_samples(), 0);
        assert_eq!(denoiser.sample_rate(), 48000);
        assert!(!denoiser.is_resampling());
    }

    #[test]
    fn test_denoiser_creation_16k() {
        let denoiser = Denoiser::new(16000);
        assert_eq!(denoiser.buffered_samples(), 0);
        assert_eq!(denoiser.sample_rate(), 16000);
        assert!(denoiser.is_resampling());
    }

    #[test]
    fn test_denoiser_process_single_frame_48k() {
        let mut denoiser = Denoiser::new(48000);
        let input: Vec<i16> = vec![0; FRAME_SIZE];
        let output = denoiser.process(&input);
        assert_eq!(output.len(), FRAME_SIZE);
    }

    #[test]
    fn test_denoiser_process_multiple_frames_48k() {
        let mut denoiser = Denoiser::new(48000);
        let input: Vec<i16> = vec![0; FRAME_SIZE * 2];
        let output = denoiser.process(&input);
        assert_eq!(output.len(), FRAME_SIZE * 2);
    }

    #[test]
    fn test_denoiser_process_many_frames_in_one_call() {
        // Several complete frames plus a partial one in a single call: every
        // complete frame is emitted and only the remainder stays buffered.
        let mut denoiser = Denoiser::new(48000);
        let input: Vec<i16> = vec![0; FRAME_SIZE * 5 + 17];
        let output = denoiser.process(&input);
        assert_eq!(output.len(), FRAME_SIZE * 5);
        assert_eq!(denoiser.buffered_samples(), 17);
    }

    #[test]
    fn test_denoiser_process_partial_frame() {
        let mut denoiser = Denoiser::new(48000);
        let input: Vec<i16> = vec![0; 100];
        let output = denoiser.process(&input);
        assert_eq!(output.len(), 0);
        assert_eq!(denoiser.buffered_samples(), 100);
        let input2: Vec<i16> = vec![0; FRAME_SIZE - 100];
        let output2 = denoiser.process(&input2);
        assert_eq!(output2.len(), FRAME_SIZE);
        assert_eq!(denoiser.buffered_samples(), 0);
    }

    #[test]
    fn test_denoiser_first_frame_is_silenced() {
        let mut denoiser = Denoiser::new(48000);
        let input: Vec<i16> = vec![1000; FRAME_SIZE * 2];
        let output = denoiser.process(&input);
        assert_eq!(output.len(), FRAME_SIZE * 2);
        assert!(output[..FRAME_SIZE].iter().all(|&s| s == 0));
    }

    #[test]
    fn test_denoiser_reset() {
        let mut denoiser = Denoiser::new(48000);
        let input: Vec<i16> = vec![0; 100];
        denoiser.process(&input);
        assert_eq!(denoiser.buffered_samples(), 100);
        denoiser.reset();
        assert_eq!(denoiser.buffered_samples(), 0);
    }

    #[test]
    fn test_denoiser_reset_rearms_first_frame_silencing() {
        let mut denoiser = Denoiser::new(48000);
        let input: Vec<i16> = vec![1000; FRAME_SIZE];
        denoiser.process(&input);
        denoiser.process(&input);
        denoiser.reset();
        let output = denoiser.process(&input);
        assert_eq!(output, vec![0; FRAME_SIZE]);
    }

    #[test]
    fn test_denoiser_aligned() {
        let mut denoiser = Denoiser::new(48000);
        let input: Vec<i16> = vec![0; FRAME_SIZE * 3];
        let output = denoiser.process_aligned(&input);
        assert_eq!(output.len(), FRAME_SIZE * 3);
    }

    #[test]
    #[should_panic(expected = "is not a multiple of frame size")]
    fn test_denoiser_aligned_rejects_partial_frame() {
        let mut denoiser = Denoiser::new(48000);
        let input: Vec<i16> = vec![0; 100];
        denoiser.process_aligned(&input);
    }

    #[test]
    fn test_denoiser_16k_produces_output() {
        let mut denoiser = Denoiser::new(16000);
        let input: Vec<i16> = vec![0; 2048];
        let output = denoiser.process(&input);
        assert!(
            !output.is_empty() || denoiser.buffered_samples() > 0,
            "Should either produce output or buffer samples"
        );
    }

    #[test]
    fn test_denoiser_16k_continuous_processing() {
        let mut denoiser = Denoiser::new(16000);
        // 20 ms chunks at 16 kHz, 20 of them (400 ms total).
        let chunk: Vec<i16> = vec![0; 320];
        let mut total_output = 0;
        for _ in 0..20 {
            let output = denoiser.process(&chunk);
            total_output += output.len();
        }
        assert!(
            total_output > 5000,
            "Expected significant output, got {total_output}"
        );
    }
}