use super::{VoiceError, VoiceResult};
/// Selects which voice-isolation algorithm a configuration asks for.
///
/// The value is stored in [`IsolationConfig::method`]. The variants are plain
/// tags; the algorithms themselves are not implemented in this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IsolationMethod {
/// Spectral subtraction. This is the variant produced by `Default`.
#[default]
SpectralSubtraction,
/// Wiener filtering.
WienerFilter,
/// Selected by the [`IsolationConfig::neural`] preset.
/// NOTE(review): presumably a learned time-frequency mask — confirm.
NeuralMask,
/// NOTE(review): presumably a U-Net based separator — confirm.
UNet,
/// NOTE(review): presumably a Conv-TasNet based separator — confirm.
ConvTasNet,
}
/// Selects how the noise estimate is obtained.
///
/// The value is stored in [`IsolationConfig::noise_estimation`]. The variants
/// are plain tags; the estimators themselves are not implemented in this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NoiseEstimation {
/// This is the variant produced by `Default`.
/// NOTE(review): presumably estimates noise from the leading frames of the
/// input, assuming they contain no speech — confirm.
#[default]
InitialSilence,
/// NOTE(review): presumably minimum-statistics noise tracking — confirm.
MinimumStatistics,
/// Selected by the [`IsolationConfig::neural`] preset.
/// NOTE(review): presumably a continuously updated estimate — confirm.
Adaptive,
/// NOTE(review): presumably uses a caller-supplied [`NoiseProfile`] — confirm.
FixedProfile,
}
/// Tunable parameters for a voice-isolation run.
///
/// Construct it with `Default`, one of the presets (`aggressive`, `mild`,
/// `neural`, `realtime`), or field by field, then call `validate` to check the
/// numeric constraints noted on the individual fields.
#[derive(Debug, Clone)]
pub struct IsolationConfig {
/// Which isolation algorithm is requested.
pub method: IsolationMethod,
/// Which noise-estimation strategy is requested.
pub noise_estimation: NoiseEstimation,
/// Sample rate in samples per second (`frame_duration_secs` divides
/// `fft_size` by it to obtain seconds). `validate` rejects 0.
pub sample_rate: u32,
/// FFT length. `validate` requires a non-zero power of two; `freq_bins`
/// derives `fft_size / 2 + 1` from it.
pub fft_size: usize,
/// `validate` requires `0 < hop_length <= fft_size`.
/// NOTE(review): presumably the step in samples between successive frames — confirm.
pub hop_length: usize,
/// `validate` requires a value in `[0.0, 1.0]`.
/// NOTE(review): presumably larger means stronger suppression (the
/// `aggressive` preset uses 0.95, `mild` uses 0.5) — confirm.
pub reduction_strength: f32,
/// `validate` requires a value in `[0.0, 1.0]`.
/// NOTE(review): presumably a lower bound applied to the processed
/// spectrum — confirm against the isolator implementations.
pub spectral_floor: f32,
/// `true` in every preset except `aggressive`.
/// NOTE(review): exact effect is defined by the isolator implementations — confirm.
pub preserve_musical_noise: bool,
/// Not checked by `validate`.
/// NOTE(review): presumably the number of frames used to estimate noise — confirm.
pub noise_frames: usize,
}
impl Default for IsolationConfig {
fn default() -> Self {
Self {
method: IsolationMethod::default(),
noise_estimation: NoiseEstimation::default(),
sample_rate: 16000,
fft_size: 512,
hop_length: 128,
reduction_strength: 0.8,
spectral_floor: 0.01,
preserve_musical_noise: true,
noise_frames: 10,
}
}
}
impl IsolationConfig {
    /// Preset favouring strong suppression.
    ///
    /// Starts from the default configuration and raises `reduction_strength`
    /// to 0.95, lowers `spectral_floor` to 0.001, and turns
    /// `preserve_musical_noise` off.
    #[must_use]
    pub fn aggressive() -> Self {
        let base = Self::default();
        Self {
            preserve_musical_noise: false,
            spectral_floor: 0.001,
            reduction_strength: 0.95,
            ..base
        }
    }

    /// Preset favouring gentle suppression.
    ///
    /// Starts from the default configuration, lowers `reduction_strength` to
    /// 0.5, raises `spectral_floor` to 0.1, and keeps
    /// `preserve_musical_noise` on.
    #[must_use]
    pub fn mild() -> Self {
        let base = Self::default();
        Self {
            preserve_musical_noise: true,
            spectral_floor: 0.1,
            reduction_strength: 0.5,
            ..base
        }
    }

    /// Preset selecting the neural-mask method.
    ///
    /// Starts from the default configuration, switches `method` to
    /// [`IsolationMethod::NeuralMask`] and `noise_estimation` to
    /// [`NoiseEstimation::Adaptive`], and sets `reduction_strength` to 0.9.
    #[must_use]
    pub fn neural() -> Self {
        let base = Self::default();
        Self {
            reduction_strength: 0.9,
            noise_estimation: NoiseEstimation::Adaptive,
            method: IsolationMethod::NeuralMask,
            ..base
        }
    }

    /// Preset with smaller analysis frames.
    ///
    /// Starts from the default configuration and halves the frame geometry:
    /// `fft_size` 256, `hop_length` 64, `noise_frames` 5.
    #[must_use]
    pub fn realtime() -> Self {
        let base = Self::default();
        Self {
            noise_frames: 5,
            hop_length: 64,
            fft_size: 256,
            ..base
        }
    }

    /// Checks the numeric constraints on this configuration.
    ///
    /// # Errors
    ///
    /// Returns `VoiceError::InvalidConfig` carrying the message of the first
    /// rule that fails, tested in this order:
    ///
    /// 1. `sample_rate` is non-zero;
    /// 2. `fft_size` is a power of two (zero is rejected);
    /// 3. `hop_length` is in `1..=fft_size`;
    /// 4. `reduction_strength` is in `[0.0, 1.0]`;
    /// 5. `spectral_floor` is in `[0.0, 1.0]`.
    ///
    /// NaN is outside every range, so NaN values are rejected too.
    pub fn validate(&self) -> VoiceResult<()> {
        let unit_interval = 0.0_f32..=1.0_f32;

        // Each entry pairs "is this rule satisfied?" with the message to
        // report when it is not. Order matters: the first failing rule wins.
        let rules: [(bool, &str); 5] = [
            (self.sample_rate != 0, "sample_rate must be > 0"),
            // `is_power_of_two` is already false for 0, so no separate zero test.
            (
                self.fft_size.is_power_of_two(),
                "fft_size must be a power of 2",
            ),
            (
                (1..=self.fft_size).contains(&self.hop_length),
                "hop_length must be > 0 and <= fft_size",
            ),
            (
                unit_interval.contains(&self.reduction_strength),
                "reduction_strength must be in [0.0, 1.0]",
            ),
            (
                unit_interval.contains(&self.spectral_floor),
                "spectral_floor must be in [0.0, 1.0]",
            ),
        ];

        match rules.iter().find(|(satisfied, _)| !*satisfied) {
            Some((_, reason)) => Err(VoiceError::InvalidConfig((*reason).to_string())),
            None => Ok(()),
        }
    }

    /// Returns `fft_size / 2 + 1`, the number of bins in a one-sided spectrum
    /// of an `fft_size`-point transform.
    #[must_use]
    pub fn freq_bins(&self) -> usize {
        let half = self.fft_size / 2;
        half + 1
    }

    /// Returns the duration of one `fft_size`-sample frame in seconds.
    ///
    /// No zero check is made here: with `sample_rate == 0` the float division
    /// yields infinity (or NaN when `fft_size` is also 0). Call `validate`
    /// first if the configuration is untrusted.
    #[must_use]
    pub fn frame_duration_secs(&self) -> f32 {
        let samples_per_frame = self.fft_size as f32;
        let samples_per_second = self.sample_rate as f32;
        samples_per_frame / samples_per_second
    }
}
/// Output of an isolation call, together with optional diagnostics.
///
/// `IsolationResult::new` fills in only `audio` and `sample_rate`; the SNR
/// fields start at 0.0 and `noise_floor` at `None` until the `with_*` builder
/// methods set them.
#[derive(Debug, Clone)]
pub struct IsolationResult {
/// Audio samples carried by this result.
/// NOTE(review): presumably the processed (isolated) signal — confirm.
pub audio: Vec<f32>,
/// Optional noise-floor data; `None` unless set via `with_noise_floor`.
/// NOTE(review): layout (e.g. one value per frequency bin) is not fixed here — confirm.
pub noise_floor: Option<Vec<f32>>,
/// Sample rate associated with `audio`, stored as passed to `new`.
pub sample_rate: u32,
/// `output_snr_db - input_snr_db`, as computed by `with_snr`.
pub snr_improvement_db: f32,
/// Input SNR as passed to `with_snr`.
pub input_snr_db: f32,
/// Output SNR as passed to `with_snr`.
pub output_snr_db: f32,
}
impl IsolationResult {
#[must_use]
pub fn new(audio: Vec<f32>, sample_rate: u32) -> Self {
Self {
audio,
noise_floor: None,
sample_rate,
snr_improvement_db: 0.0,
input_snr_db: 0.0,
output_snr_db: 0.0,
}
}
#[must_use]
pub fn with_snr(mut self, input_snr: f32, output_snr: f32) -> Self {
self.input_snr_db = input_snr;
self.output_snr_db = output_snr;
self.snr_improvement_db = output_snr - input_snr;
self
}
#[must_use]
pub fn with_noise_floor(mut self, noise_floor: Vec<f32>) -> Self {
self.noise_floor = Some(noise_floor);
self
}
}
/// Per-bin noise statistics, as computed by `NoiseProfile::from_frames`.
#[derive(Debug, Clone)]
pub struct NoiseProfile {
/// Per-bin mean of the frames passed to `from_frames`; its length is the
/// length of the first frame. `noise_magnitude` reads from this vector.
pub mean_spectrum: Vec<f32>,
/// Per-bin standard deviation around `mean_spectrum`, computed by
/// `from_frames` with a divisor of `num_frames` (population form).
pub std_spectrum: Vec<f32>,
/// Number of frames the statistics were computed from; 0 for an empty profile.
pub num_frames: usize,
/// Sample rate stored as passed to `from_frames`.
pub sample_rate: u32,
}
impl NoiseProfile {
    /// Computes per-bin mean and standard deviation over `frames`.
    ///
    /// The number of bins is the length of the first frame. Frames are
    /// combined positionally: values beyond that length in longer frames are
    /// ignored, and shorter frames simply contribute nothing to the bins they
    /// lack. Both statistics always divide by the total number of frames
    /// (the standard deviation is the population form).
    ///
    /// An empty `frames` slice yields an empty profile with `num_frames == 0`,
    /// for which `is_valid` returns `false`.
    #[must_use]
    pub fn from_frames(frames: &[Vec<f32>], sample_rate: u32) -> Self {
        let bin_count = match frames.first() {
            Some(first_frame) => first_frame.len(),
            None => {
                return Self {
                    mean_spectrum: Vec::new(),
                    std_spectrum: Vec::new(),
                    num_frames: 0,
                    sample_rate,
                };
            }
        };
        let frame_count = frames.len();
        let divisor = frame_count as f32;

        // Pass 1: per-bin sum, then mean. `zip` stops at the shorter side,
        // which gives the truncate-long / skip-short behaviour described above.
        let mut mean_spectrum = vec![0.0f32; bin_count];
        for frame in frames {
            for (total, &sample) in mean_spectrum.iter_mut().zip(frame) {
                *total += sample;
            }
        }
        mean_spectrum.iter_mut().for_each(|total| *total /= divisor);

        // Pass 2: per-bin sum of squared deviations, then square root of the
        // average.
        let mut std_spectrum = vec![0.0f32; bin_count];
        for frame in frames {
            let bins = std_spectrum.iter_mut().zip(&mean_spectrum).zip(frame);
            for ((sum_sq, &mean), &sample) in bins {
                let deviation = sample - mean;
                *sum_sq += deviation * deviation;
            }
        }
        std_spectrum
            .iter_mut()
            .for_each(|sum_sq| *sum_sq = (*sum_sq / divisor).sqrt());

        Self {
            mean_spectrum,
            std_spectrum,
            num_frames: frame_count,
            sample_rate,
        }
    }

    /// Returns the mean noise value stored for `bin`, or 0.0 when `bin` is
    /// past the end of `mean_spectrum`.
    #[must_use]
    pub fn noise_magnitude(&self, bin: usize) -> f32 {
        match self.mean_spectrum.get(bin) {
            Some(&magnitude) => magnitude,
            None => 0.0,
        }
    }

    /// Returns `true` when the profile was built from at least one frame and
    /// has at least one bin.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        let unusable = self.num_frames == 0 || self.mean_spectrum.is_empty();
        !unusable
    }
}
/// Common interface for voice-isolation back ends.
///
/// Implementors must be `Send + Sync`. All processing methods take `&self`
/// and report failure through `VoiceResult`.
pub trait VoiceIsolator: Send + Sync {
/// Returns a reference to the isolator's [`IsolationConfig`].
fn config(&self) -> &IsolationConfig;
/// Processes `audio` and returns an [`IsolationResult`].
///
/// NOTE(review): how noise is estimated when no profile is supplied is left
/// to each implementation — confirm per implementor.
fn isolate(&self, audio: &[f32]) -> VoiceResult<IsolationResult>;
/// Processes `audio` using the caller-supplied `noise_profile` and returns
/// an [`IsolationResult`].
///
/// NOTE(review): whether an invalid profile (`NoiseProfile::is_valid` is
/// `false`) is an error is left to each implementation — confirm.
fn isolate_with_profile(
&self,
audio: &[f32],
noise_profile: &NoiseProfile,
) -> VoiceResult<IsolationResult>;
/// Builds a [`NoiseProfile`] from `noise_audio`.
///
/// NOTE(review): presumably `noise_audio` should contain noise only — confirm.
fn estimate_noise(&self, noise_audio: &[f32]) -> VoiceResult<NoiseProfile>;
}
/// State for an isolator named after spectral subtraction.
///
/// Both fields are private and no constructor or methods are defined next to
/// this declaration.
/// NOTE(review): its `impl` blocks (including any `VoiceIsolator` impl) are
/// presumably in one of the files pulled in via `include!` — confirm.
#[derive(Debug, Clone)]
pub struct SpectralSubtractionIsolator {
// Configuration held by this isolator.
config: IsolationConfig,
// NOTE(review): presumably a multiplier applied to the noise estimate before
// it is subtracted — confirm against the implementation.
over_subtraction: f32,
}
// `include!` splices the named file's contents into this module at compile
// time; the paths are resolved relative to this source file, and the included
// items share this module's scope.
// NOTE(review): by its name, presumably holds the Wiener-filter isolator — confirm.
include!("wiener_filter_isolator.rs");
// NOTE(review): by its name, presumably holds this module's tests — confirm.
include!("isolation_tests.rs");