use crate::error::{Error, Result};
use crate::time::{AudioDuration, AudioInstant};
use tracing::{info, warn};
/// Tunable parameters for [`DcHighPassFilter`].
///
/// Values are checked by [`PreprocessingConfig::validate`], which
/// [`DcHighPassFilter::new`] calls before building the filter.
#[derive(Debug, Clone, Copy)]
pub struct PreprocessingConfig {
    /// High-pass cutoff frequency in Hz; must be at least 20 Hz and below Nyquist
    /// (`sample_rate_hz / 2`).
    pub highpass_cutoff_hz: f32,
    /// Input sample rate in Hz; must be non-zero.
    pub sample_rate_hz: u32,
    /// Weight given to the previous DC-bias estimate in the per-chunk EMA update
    /// `bias = alpha * bias + (1 - alpha) * chunk_mean`; must lie strictly between
    /// 0 and 1. Larger values make the estimate adapt more slowly.
    pub dc_bias_alpha: f32,
    /// When `true`, the tracked DC bias is updated and subtracted from every sample.
    pub enable_dc_removal: bool,
    /// When `true`, samples are passed through the cascaded biquad high-pass.
    pub enable_highpass: bool,
    /// Filter order; determines how many biquad stages are cascaded.
    pub highpass_order: HighpassOrder,
}
impl Default for PreprocessingConfig {
    /// Default configuration: 16 kHz input, an 80 Hz fourth-order high-pass, and DC
    /// tracking that keeps 95 % of the previous bias estimate on each update.
    /// Both processing stages are enabled.
    fn default() -> Self {
        let sample_rate_hz = 16_000;
        let highpass_cutoff_hz = 80.0;
        let dc_bias_alpha = 0.95;

        Self {
            sample_rate_hz,
            highpass_cutoff_hz,
            highpass_order: HighpassOrder::FourthOrder,
            dc_bias_alpha,
            enable_highpass: true,
            enable_dc_removal: true,
        }
    }
}
impl PreprocessingConfig {
    /// Checks that the configuration describes a realizable filter.
    ///
    /// # Errors
    ///
    /// Returns `Error::Configuration` when:
    /// - `sample_rate_hz` is zero;
    /// - `highpass_cutoff_hz` is not finite, is below 20 Hz, or is at or above
    ///   Nyquist (`sample_rate_hz / 2`);
    /// - `dc_bias_alpha` is not strictly inside `(0.0, 1.0)`, including NaN.
    // The `&self` receiver is kept for API compatibility even though the type is `Copy`.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    pub fn validate(&self) -> Result<()> {
        if self.sample_rate_hz == 0 {
            return Err(Error::Configuration(
                "sample_rate_hz must be greater than zero".into(),
            ));
        }
        // NaN compares false against every bound below, so reject non-finite values up front.
        if !self.highpass_cutoff_hz.is_finite() {
            return Err(Error::Configuration(format!(
                "Cutoff {:.1} Hz is not a finite frequency",
                self.highpass_cutoff_hz
            )));
        }
        if self.highpass_cutoff_hz < 20.0 {
            return Err(Error::Configuration(format!(
                "Cutoff {:.1} Hz too low (minimum 20 Hz)",
                self.highpass_cutoff_hz
            )));
        }
        let nyquist = self.sample_rate_hz as f32 / 2.0;
        if self.highpass_cutoff_hz >= nyquist {
            return Err(Error::Configuration(format!(
                "Cutoff {:.1} Hz exceeds Nyquist {:.1} Hz",
                self.highpass_cutoff_hz, nyquist
            )));
        }
        // Phrased as a positive range test so NaN, which fails both comparisons, is
        // rejected. A NaN alpha would otherwise turn the DC bias, and every output
        // sample after it, into NaN.
        let alpha_in_range = self.dc_bias_alpha > 0.0 && self.dc_bias_alpha < 1.0;
        if !alpha_in_range {
            return Err(Error::Configuration(format!(
                "Invalid EMA alpha: {:.3} (must be in range 0.0 < α < 1.0)",
                self.dc_bias_alpha
            )));
        }
        Ok(())
    }
}
/// Order of the high-pass filter, realized as a cascade of identical biquad stages.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum HighpassOrder {
    /// A single biquad stage.
    SecondOrder,
    /// Two cascaded biquad stages (the default).
    #[default]
    FourthOrder,
}
impl HighpassOrder {
    /// Number of biquad sections to cascade for this order.
    #[must_use]
    fn stage_count(self) -> usize {
        // Each biquad section contributes two poles to the overall filter.
        const POLES_PER_BIQUAD: usize = 2;
        let poles = match self {
            Self::SecondOrder => 2,
            Self::FourthOrder => 4,
        };
        poles / POLES_PER_BIQUAD
    }
}
/// Voice-activity hint that can be passed to [`DcHighPassFilter::process`].
#[derive(Debug, Clone, Copy)]
pub struct VadContext {
    /// `true` when the chunk is silence. The DC-bias estimate is updated only for
    /// silent chunks, or when no context is supplied at all.
    pub is_silence: bool,
}
/// Streaming DC-removal and high-pass filter over `f32` sample chunks.
///
/// Holds the validated configuration, one set of biquad coefficients shared by all
/// stages, per-stage delay-line state, and a running DC-bias estimate. All state
/// persists across [`DcHighPassFilter::process`] calls until
/// [`DcHighPassFilter::reset`] is called.
// No `allow(missing_copy_implementations)`: the `Vec` field means this type can never
// be `Copy`, so that lint cannot fire here.
#[derive(Debug, Clone)]
pub struct DcHighPassFilter {
    /// Configuration, validated when the filter was constructed.
    config: PreprocessingConfig,
    /// Coefficients shared by every cascaded stage.
    coeffs: BiquadCoefficients,
    /// One delay-line state per biquad stage; the length follows `highpass_order`.
    stages: Vec<BiquadState>,
    /// Exponential moving average of chunk means, i.e. the estimated DC offset.
    dc_bias: f32,
}
/// Normalized (`a0 == 1`) coefficients of one biquad section, for the difference
/// equation `y[n] = b0·x[n] + b1·x[n-1] + b2·x[n-2] - a1·y[n-1] - a2·y[n-2]`.
#[derive(Clone, Copy, Debug)]
struct BiquadCoefficients {
    b0: f32,
    b1: f32,
    b2: f32,
    a1: f32,
    a2: f32,
}
/// Delay-line memory of one direct-form-I biquad section.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
struct BiquadState {
    /// Previous input `x[n-1]`.
    x1: f32,
    /// Input two samples back, `x[n-2]`.
    x2: f32,
    /// Previous output `y[n-1]`.
    y1: f32,
    /// Output two samples back, `y[n-2]`.
    y2: f32,
}
impl BiquadState {
    /// Runs one sample through the section and advances the delay lines.
    #[inline]
    fn process(&mut self, coeffs: &BiquadCoefficients, input: f32) -> f32 {
        let BiquadCoefficients { b0, b1, b2, a1, a2 } = *coeffs;
        let feedforward = b0.mul_add(input, b1.mul_add(self.x1, b2 * self.x2));
        let feedback = a1.mul_add(self.y1, a2 * self.y2);
        let output = feedforward - feedback;
        // Shift both delay lines by one sample.
        (self.x1, self.x2) = (input, self.x1);
        (self.y1, self.y2) = (output, self.y1);
        output
    }
    /// Zeroes all delay-line memory.
    fn reset(&mut self) {
        let Self { x1, x2, y1, y2 } = self;
        for slot in [x1, x2, y1, y2] {
            *slot = 0.0;
        }
    }
    /// `true` when every delay-line value equals zero.
    #[cfg(test)]
    fn is_reset(self) -> bool {
        let Self { x1, x2, y1, y2 } = self;
        [x1, x2, y1, y2].into_iter().all(|value| value == 0.0)
    }
}
impl DcHighPassFilter {
    /// Chunks shorter than this many samples are not reported by
    /// `record_performance_metrics`.
    const METRICS_MIN_SAMPLES: usize = 8000;
    /// Per-call processing latency, in milliseconds, above which a warning is logged.
    const LATENCY_WARN_MS: f64 = 2.0;

    /// Creates a filter from `config`, validating it and precomputing the biquad
    /// coefficients shared by every cascaded stage.
    ///
    /// All stage state and the DC-bias estimate start at zero.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`PreprocessingConfig::validate`] for an invalid
    /// configuration. Also propagates the coefficient-computation error if any
    /// resulting coefficient is non-finite.
    pub fn new(config: PreprocessingConfig) -> Result<Self> {
        config.validate()?;
        let (b0, b1, b2, a1, a2) = compute_butterworth_highpass_coefficients(
            config.highpass_cutoff_hz,
            config.sample_rate_hz,
        )?;
        let stage_count = config.highpass_order.stage_count();
        Ok(Self {
            config,
            coeffs: BiquadCoefficients { b0, b1, b2, a1, a2 },
            stages: vec![BiquadState::default(); stage_count],
            dc_bias: 0.0,
        })
    }

    /// Filters one chunk of samples and returns the filtered copy.
    ///
    /// Stage state and the DC-bias estimate persist across calls, so consecutive
    /// chunks of one stream are treated as a continuous signal.
    ///
    /// When DC removal is enabled, the bias estimate is first updated from this
    /// chunk's mean. That update only happens if `vad_context` is `None` or reports
    /// silence. The current bias is then subtracted from every sample before the
    /// high-pass cascade runs (if the cascade is enabled).
    ///
    /// # Errors
    ///
    /// Currently never returns an error; the `Result` is kept for API stability.
    // Both allows preserve the existing public signature: the `Result` wrapper and
    // the borrowed `VadContext` are part of the API.
    #[allow(clippy::unnecessary_wraps)]
    #[allow(clippy::trivially_copy_pass_by_ref)]
    pub fn process(
        &mut self,
        samples: &[f32],
        vad_context: Option<&VadContext>,
    ) -> Result<Vec<f32>> {
        let processing_start = AudioInstant::now();
        if samples.is_empty() {
            return Ok(Vec::new());
        }
        let should_update_bias = vad_context.is_none_or(|ctx| ctx.is_silence);
        if self.config.enable_dc_removal && should_update_bias {
            self.update_dc_bias(samples);
        }
        let output = self.process_samples(samples);
        let latency_ms = elapsed_duration(processing_start).as_secs_f64() * 1000.0;
        self.record_performance_metrics(samples.len(), latency_ms);
        Ok(output)
    }

    /// Subtracts the DC bias (if enabled), then runs each biquad stage in order
    /// (if the high-pass is enabled), for every sample.
    ///
    /// The enable flags are resolved once per chunk instead of once per sample.
    #[inline]
    fn process_samples(&mut self, samples: &[f32]) -> Vec<f32> {
        let bias = self.config.enable_dc_removal.then_some(self.dc_bias);
        let coeffs = &self.coeffs;
        // A disabled high-pass runs through zero stages, which is the identity.
        let active_stages = if self.config.enable_highpass {
            self.stages.len()
        } else {
            0
        };
        let stages = &mut self.stages[..active_stages];
        samples
            .iter()
            .map(|&sample| {
                let input = bias.map_or(sample, |b| sample - b);
                stages
                    .iter_mut()
                    .fold(input, |acc, stage| stage.process(coeffs, acc))
            })
            .collect()
    }

    /// Logs latency and filter metrics for large chunks, and warns when latency
    /// exceeds the target.
    ///
    /// Chunks shorter than `METRICS_MIN_SAMPLES` are skipped. The warning fires when
    /// `latency_ms` is above `LATENCY_WARN_MS`.
    fn record_performance_metrics(&self, sample_count: usize, latency_ms: f64) {
        if sample_count < Self::METRICS_MIN_SAMPLES {
            return;
        }
        if latency_ms > Self::LATENCY_WARN_MS {
            warn!(
                target: "audio.preprocess.highpass",
                latency_ms,
                samples = sample_count,
                cutoff_hz = self.config.highpass_cutoff_hz,
                order = ?self.config.highpass_order,
                "high-pass latency exceeded target"
            );
        }
        info!(
            target: "audio.preprocess.highpass",
            dc_bias = self.dc_bias,
            latency_ms,
            samples = sample_count,
            cutoff_hz = self.config.highpass_cutoff_hz,
            order = ?self.config.highpass_order,
            "audio preprocess high-pass metrics"
        );
    }

    /// Clears every stage's delay line and resets the DC-bias estimate to zero.
    ///
    /// The configuration and coefficients are kept.
    pub fn reset(&mut self) {
        for stage in &mut self.stages {
            stage.reset();
        }
        self.dc_bias = 0.0;
    }

    /// Current DC-bias estimate that is subtracted from each sample.
    #[must_use]
    pub fn dc_bias(&self) -> f32 {
        self.dc_bias
    }

    /// Configuration this filter was built with.
    #[must_use]
    pub fn config(&self) -> &PreprocessingConfig {
        &self.config
    }

    /// Folds this chunk's mean into the exponential moving average of the DC bias:
    /// `bias = alpha * bias + (1 - alpha) * mean(samples)`.
    ///
    /// Does nothing for an empty chunk.
    fn update_dc_bias(&mut self, samples: &[f32]) {
        if samples.is_empty() {
            return;
        }
        // Accumulate in f64. A sequential f32 sum over thousands of samples loses
        // low-order bits as the running total grows, which skews the mean.
        let sum: f64 = samples.iter().copied().map(f64::from).sum();
        let current_mean = (sum / samples.len() as f64) as f32;
        let alpha = self.config.dc_bias_alpha;
        self.dc_bias = alpha.mul_add(self.dc_bias, (1.0 - alpha) * current_mean);
    }
}
/// Computes `a0`-normalized coefficients for one 2nd-order Butterworth high-pass
/// biquad, using the RBJ Audio-EQ-Cookbook high-pass formulas with `Q = 1/√2`.
///
/// The tuple is `(b0, b1, b2, a1, a2)` for the difference equation
/// `y[n] = b0·x[n] + b1·x[n−1] + b2·x[n−2] − a1·y[n−1] − a2·y[n−2]`.
///
/// Cascading two identical sections, as [`HighpassOrder::FourthOrder`] does, yields a
/// 4th-order Linkwitz–Riley response (about −6 dB at the cutoff). A true 4th-order
/// Butterworth (−3 dB at the cutoff) would need two sections with different Q values.
///
/// # Errors
///
/// Returns `Error::Processing` if any coefficient is non-finite (for example, a NaN
/// cutoff).
fn compute_butterworth_highpass_coefficients(
    cutoff_hz: f32,
    sample_rate_hz: u32,
) -> Result<(f32, f32, f32, f32, f32)> {
    use std::f32::consts::{FRAC_1_SQRT_2, PI};

    // Exact Butterworth quality factor; the former literal 0.707 was a rounded stand-in.
    const Q: f32 = FRAC_1_SQRT_2;

    let w0 = 2.0 * PI * cutoff_hz / sample_rate_hz as f32;
    let (sin_w0, cos_w0) = w0.sin_cos();
    let alpha = sin_w0 / (2.0 * Q);
    let a0 = 1.0 + alpha;

    // High-pass numerator before normalization:
    // b0 = b2 = (1 + cos w0) / 2 and b1 = −(1 + cos w0).
    let b0 = f32::midpoint(1.0, cos_w0) / a0;
    let b1 = -(1.0 + cos_w0) / a0;
    let b2 = b0;
    let a1 = -2.0 * cos_w0 / a0;
    let a2 = (1.0 - alpha) / a0;

    if ![b0, b1, b2, a1, a2].iter().all(|c| c.is_finite()) {
        return Err(Error::Processing(format!(
            "Invalid filter coefficients for fc={cutoff_hz:.1}Hz, fs={sample_rate_hz}: \
             b0={b0:.6}, b1={b1:.6}, b2={b2:.6}, a1={a1:.6}, a2={a2:.6}"
        )));
    }
    Ok((b0, b1, b2, a1, a2))
}
/// Time that has passed between `start` and the moment of this call.
fn elapsed_duration(start: AudioInstant) -> AudioDuration {
    let now = AudioInstant::now();
    now.duration_since(start)
}
// Unit tests covering:
// - configuration validation;
// - DC-bias tracking, with and without VAD hints;
// - high-pass frequency response;
// - chunk-boundary continuity;
// - reset and empty input.
#[cfg(test)]
mod tests {
    use super::*;
    // Library errors are mapped to `String` via `map_err` so tests can use `?`.
    type TestResult<T> = std::result::Result<T, String>;
    /// Returns `round(sample_rate * duration_secs)` samples of
    /// `amplitude * sin(2π · frequency · t)`, starting at phase zero.
    fn generate_sine_wave(
        frequency: f32,
        sample_rate: u32,
        duration_secs: f32,
        amplitude: f32,
    ) -> Vec<f32> {
        use std::f32::consts::PI;
        let samples = (sample_rate as f32 * duration_secs).round() as usize;
        (0..samples)
            .map(|i| {
                let t = i as f32 / sample_rate as f32;
                (2.0 * PI * frequency * t).sin() * amplitude
            })
            .collect()
    }
    /// Root-mean-square level of `samples`; 0.0 for an empty slice.
    fn calculate_rms(samples: &[f32]) -> f32 {
        if samples.is_empty() {
            return 0.0;
        }
        let sum_sq: f32 = samples.iter().map(|&s| s * s).sum();
        (sum_sq / samples.len() as f32).sqrt()
    }
    /// Output level relative to input, in dB (`20·log10(rms_out / rms_in)`).
    /// Negative values mean attenuation.
    ///
    /// Returns 0.0 when either RMS is zero. An "attenuation ≤ −N dB" check would
    /// then fail rather than pass.
    fn calculate_attenuation_db(input: &[f32], output: &[f32]) -> f32 {
        let rms_in = calculate_rms(input);
        let rms_out = calculate_rms(output);
        if rms_in == 0.0 || rms_out == 0.0 {
            return 0.0;
        }
        20.0 * (rms_out / rms_in).log10()
    }
    #[test]
    fn test_dc_offset_removal_synthetic_bias() -> TestResult<()> {
        let dc_offset = 0.5;
        let mut samples_with_dc: Vec<f32> = generate_sine_wave(440.0, 16000, 0.5, 0.3);
        for sample in &mut samples_with_dc {
            *sample += dc_offset;
        }
        // High-pass disabled so only the DC-removal path is exercised.
        // alpha = 0.5 makes the EMA converge within a few chunks.
        let config = PreprocessingConfig {
            enable_dc_removal: true,
            enable_highpass: false,
            dc_bias_alpha: 0.5,
            ..Default::default()
        };
        let mut filter = DcHighPassFilter::new(config).map_err(|e| e.to_string())?;
        // Feed the same chunk 10 times; only the last output is checked.
        // Starting from zero, after n updates the estimate is about
        // chunk_mean · (1 − 0.5ⁿ).
        let mut final_output = Vec::new();
        for _ in 0..10 {
            final_output = filter
                .process(&samples_with_dc, None)
                .map_err(|e| e.to_string())?;
        }
        let mean: f32 = final_output.iter().sum::<f32>() / final_output.len() as f32;
        assert!(
            mean.abs() < 0.001,
            "DC residual too high after convergence: {:.6} (expected < 0.001)",
            mean
        );
        assert!(
            (filter.dc_bias() - dc_offset).abs() < 0.005,
            "DC bias estimate {:.6} not converged to {:.6}",
            filter.dc_bias(),
            dc_offset
        );
        Ok(())
    }
    #[test]
    fn test_highpass_frequency_response() -> TestResult<()> {
        // DC removal is disabled so the measurement reflects only the biquad cascade
        // (default order and sample rate).
        let config = PreprocessingConfig {
            highpass_cutoff_hz: 80.0,
            enable_dc_removal: false,
            ..Default::default()
        };
        let mut filter = DcHighPassFilter::new(config).map_err(|e| e.to_string())?;
        // Each one-second tone includes the filter's start-up transient.
        // `reset()` between tones gives each measurement a zero initial state.
        let input_20hz = generate_sine_wave(20.0, 16000, 1.0, 1.0);
        let output_20hz = filter
            .process(&input_20hz, None)
            .map_err(|e| e.to_string())?;
        filter.reset();
        let attenuation_20hz = calculate_attenuation_db(&input_20hz, &output_20hz);
        assert!(
            attenuation_20hz <= -30.0,
            "Insufficient attenuation at 20Hz: {:.1} dB (expected ≤ -30 dB)",
            attenuation_20hz
        );
        let input_40hz = generate_sine_wave(40.0, 16000, 1.0, 1.0);
        let output_40hz = filter
            .process(&input_40hz, None)
            .map_err(|e| e.to_string())?;
        filter.reset();
        let attenuation_40hz = calculate_attenuation_db(&input_40hz, &output_40hz);
        assert!(
            attenuation_40hz <= -20.0,
            "Insufficient attenuation at 40Hz: {:.1} dB (expected ≤ -20 dB)",
            attenuation_40hz
        );
        // Passband check: a tone well above the 80 Hz cutoff should be nearly unchanged.
        let input_150hz = generate_sine_wave(150.0, 16000, 1.0, 1.0);
        let output_150hz = filter
            .process(&input_150hz, None)
            .map_err(|e| e.to_string())?;
        let loss_150hz = calculate_attenuation_db(&input_150hz, &output_150hz);
        assert!(
            loss_150hz > -1.0,
            "Excessive loss at 150Hz: {:.1} dB (expected > -1 dB)",
            loss_150hz
        );
        Ok(())
    }
    #[test]
    fn test_chunk_boundary_continuity() -> TestResult<()> {
        let long_signal = generate_sine_wave(440.0, 16000, 1.0, 0.5);
        // NOTE(review): the default config enables DC removal. The single pass updates
        // the bias once over the whole signal, while the chunked pass updates it once
        // per half. The outputs presumably agree within 5e-5 only because this tone
        // has a near-zero mean. Disabling DC removal would isolate biquad-state
        // continuity.
        let config = PreprocessingConfig::default();
        let mut filter1 = DcHighPassFilter::new(config).map_err(|e| e.to_string())?;
        let output_single = filter1
            .process(&long_signal, None)
            .map_err(|e| e.to_string())?;
        let mut filter2 = DcHighPassFilter::new(config).map_err(|e| e.to_string())?;
        let mid = long_signal.len() / 2;
        let chunk1 = &long_signal[0..mid];
        let chunk2 = &long_signal[mid..];
        let output_chunk1 = filter2.process(chunk1, None).map_err(|e| e.to_string())?;
        let output_chunk2 = filter2.process(chunk2, None).map_err(|e| e.to_string())?;
        let output_chunked: Vec<f32> = output_chunk1.into_iter().chain(output_chunk2).collect();
        for (i, (single, chunked)) in output_single.iter().zip(output_chunked.iter()).enumerate() {
            let diff = (single - chunked).abs();
            assert!(
                diff < 5e-5,
                "Discontinuity at sample {}: diff={:.9} (single={:.9}, chunked={:.9})",
                i,
                diff,
                single,
                chunked
            );
        }
        Ok(())
    }
    #[test]
    fn test_vad_informed_dc_update() -> TestResult<()> {
        let config = PreprocessingConfig::default();
        let mut filter = DcHighPassFilter::new(config).map_err(|e| e.to_string())?;
        // A chunk flagged as speech must leave the DC-bias estimate untouched.
        let speech_samples = vec![0.1, 0.2, -0.1, 0.3];
        let speech_ctx = VadContext { is_silence: false };
        let initial_bias = filter.dc_bias();
        filter
            .process(&speech_samples, Some(&speech_ctx))
            .map_err(|e| e.to_string())?;
        assert_eq!(
            filter.dc_bias(),
            initial_bias,
            "DC bias changed during speech"
        );
        // A constant positive chunk flagged as silence must pull the estimate upward.
        let silence_samples = vec![0.5; 1000];
        let silence_ctx = VadContext { is_silence: true };
        filter
            .process(&silence_samples, Some(&silence_ctx))
            .map_err(|e| e.to_string())?;
        assert!(
            filter.dc_bias() > initial_bias,
            "DC bias did not adapt during silence (initial={:.6}, after={:.6})",
            initial_bias,
            filter.dc_bias()
        );
        Ok(())
    }
    #[test]
    fn test_configuration_validation() {
        let valid_config = PreprocessingConfig::default();
        assert!(valid_config.validate().is_ok());
        // Below the 20 Hz minimum cutoff.
        let config_low = PreprocessingConfig {
            highpass_cutoff_hz: 10.0,
            ..Default::default()
        };
        assert!(config_low.validate().is_err());
        // Above Nyquist (8000 Hz at the default 16 kHz sample rate).
        let config_high = PreprocessingConfig {
            highpass_cutoff_hz: 9000.0,
            ..Default::default()
        };
        assert!(config_high.validate().is_err());
        // Alpha must be strictly less than 1.0.
        let config_alpha = PreprocessingConfig {
            dc_bias_alpha: 1.0,
            ..Default::default()
        };
        assert!(config_alpha.validate().is_err());
        let config_zero_sr = PreprocessingConfig {
            sample_rate_hz: 0,
            ..Default::default()
        };
        assert!(config_zero_sr.validate().is_err());
    }
    #[test]
    fn test_reset_clears_state() -> TestResult<()> {
        let config = PreprocessingConfig::default();
        let mut filter = DcHighPassFilter::new(config).map_err(|e| e.to_string())?;
        let samples = generate_sine_wave(440.0, 16000, 0.5, 0.8);
        filter.process(&samples, None).map_err(|e| e.to_string())?;
        // NOTE(review): this sine spans a whole number of cycles. The assertion relies
        // on its floating-point chunk mean not being exactly 0.0.
        assert_ne!(
            filter.dc_bias(),
            0.0,
            "DC bias should be non-zero after processing"
        );
        assert!(
            filter.stages.iter().copied().any(|stage| !stage.is_reset()),
            "Filter stages should accumulate state after processing"
        );
        filter.reset();
        assert_eq!(filter.dc_bias(), 0.0, "DC bias should be zero after reset");
        assert!(
            filter.stages.iter().copied().all(BiquadState::is_reset),
            "Filter stages should be reset to zero state"
        );
        Ok(())
    }
    #[test]
    fn test_empty_input() -> TestResult<()> {
        let config = PreprocessingConfig::default();
        let mut filter = DcHighPassFilter::new(config).map_err(|e| e.to_string())?;
        // An empty chunk yields an empty output.
        let output = filter.process(&[], None).map_err(|e| e.to_string())?;
        assert!(output.is_empty());
        Ok(())
    }
}