use rustfft::{num_complex::Complex, FftPlanner};
use serde::Serialize;
use std::sync::Arc;
use crate::AgcConfig;
/// Speech-activity transition produced by the most recent call to
/// `SpeechDetector::process`.
///
/// Retrieved (and reset to `None`) with `SpeechDetector::take_state_change`.
#[derive(Clone, Debug)]
pub enum SpeechStateChange {
    /// No transition happened during the last processed chunk.
    None,
    /// A speech onset was confirmed. `lookback_samples` is the number of
    /// samples from the lookback history that cover the detected onset,
    /// including a safety margin before it.
    Started { lookback_samples: usize },
    /// Speech ended after the hold period. `duration_ms` counts the onset
    /// window, the speech itself and the trailing hold (silence) time.
    Ended { duration_ms: u64 },
}
/// A short pause between words detected inside an ongoing speech segment.
///
/// NOTE(review): the visible code never constructs this type.
/// `SpeechDetector::process` only clears the pending event; word-break
/// detection appears to be unimplemented so far.
#[derive(Clone, Debug)]
pub struct WordBreakEvent {
    /// Position of the break in milliseconds. The reference point (segment
    /// start? chunk start?) is not established here — TODO confirm.
    pub offset_ms: u32,
    /// Length of the silent gap, in milliseconds.
    pub gap_duration_ms: u32,
}
/// Per-chunk diagnostics of the speech detector, as returned by
/// `SpeechDetector::get_metrics`.
///
/// The declaration order of the fields is also their serialisation order.
#[derive(Clone, Debug, Serialize)]
pub struct SpeechMetrics {
    /// RMS level of the last chunk in dB, relative to an amplitude of 1.0.
    /// `-100.0` for a silent (all-zero) chunk.
    pub amplitude_db: f32,
    /// Fraction of adjacent sample pairs whose sign differs (0.0..=1.0).
    pub zcr: f32,
    /// Estimated spectral centroid in Hz. `0.0` when the chunk is below the
    /// detector's -55 dB centroid gate.
    pub centroid_hz: f32,
    /// Whether the detector is currently inside a speech segment.
    pub is_speaking: bool,
    /// Whether a voiced-speech onset is pending (started but not yet confirmed).
    pub is_voiced_pending: bool,
    /// Whether a whispered-speech onset is pending.
    pub is_whisper_pending: bool,
    /// Whether the last chunk was classified as a transient.
    pub is_transient: bool,
    /// Lookback-speech flag. `SpeechDetector::get_metrics` currently always
    /// reports `false`.
    pub is_lookback_speech: bool,
    /// Onset offset located in the lookback history. `Some` only for the
    /// chunk on which speech started.
    pub lookback_offset_ms: Option<u32>,
    /// Whether the last chunk was flagged as a word break.
    pub is_word_break: bool,
}
/// Payload passed to [`SpeechEventCallback`] on speech start and end.
#[derive(Clone, Debug, Serialize)]
pub struct SpeechEventPayload {
    /// Segment length in milliseconds. `Some` for speech-ended events,
    /// `None` for speech-started events.
    pub duration_ms: Option<u64>,
    /// Distance (ms) from the located onset to the end of the lookback
    /// history. `Some` for speech-started events, `None` for speech-ended.
    pub lookback_offset_ms: Option<u32>,
}
/// Payload passed to `SpeechEventCallback::on_word_break`.
#[derive(Clone, Debug, Serialize)]
pub struct WordBreakPayload {
    /// Position of the break in milliseconds. The reference point is not
    /// established here — TODO confirm.
    pub offset_ms: u32,
    /// Length of the gap between words, in milliseconds.
    pub gap_duration_ms: u32,
}
/// Receiver for speech-activity events emitted by `SpeechDetector::process`.
///
/// Methods are called synchronously from inside `process`, on whichever
/// thread drives the detector, so implementations should return quickly.
///
/// NOTE(review): the trait only requires `Send`, but implementors are stored
/// as `Arc<dyn SpeechEventCallback>`. An `Arc<T>` is `Send` only when
/// `T: Send + Sync`, so `SpeechDetector` (which holds
/// `Option<Arc<dyn SpeechEventCallback>>`) is not `Send`. Confirm whether a
/// `Sync` bound is intended.
pub trait SpeechEventCallback: Send {
    /// Called when a speech onset is confirmed. `payload.duration_ms` is `None`.
    fn on_speech_started(&self, payload: SpeechEventPayload);
    /// Called when speech ends after the hold period.
    /// `payload.lookback_offset_ms` is `None`.
    fn on_speech_ended(&self, payload: SpeechEventPayload);
    /// Called for a detected word break. Not invoked by the visible
    /// `SpeechDetector::process`.
    fn on_word_break(&self, payload: WordBreakPayload);
}
/// Detection thresholds for one speech mode (voiced or whisper).
#[derive(Clone)]
// `SpeechDetector::process` reads only `threshold_db` and `onset_samples` of
// the voiced config. The ranges are used only by the mode matchers, which
// are not wired in yet.
#[allow(dead_code)]
struct SpeechModeConfig {
    /// Minimum chunk RMS level, in dB, for a chunk to count toward this mode.
    threshold_db: f32,
    /// Inclusive `(min, max)` zero-crossing rate.
    zcr_range: (f32, f32),
    /// Inclusive `(min, max)` spectral centroid, in Hz.
    centroid_range: (f32, f32),
    /// Number of qualifying samples needed before an onset is confirmed.
    onset_samples: u32,
}
/// Amplitude-gated speech activity detector for chunks of mono `f32` samples.
///
/// Each call to `process` updates per-chunk metrics (level, zero-crossing
/// rate, centroid) and runs an onset/hold state machine driven by the voiced
/// threshold. Transitions are exposed through `take_state_change` and an
/// optional [`SpeechEventCallback`].
///
/// Whisper-mode, transient and word-break state is configured here but not
/// yet read by `process`; that is why `dead_code` is allowed.
#[allow(dead_code)]
pub struct SpeechDetector {
    /// Input sample rate in Hz; used for every millisecond/sample conversion.
    sample_rate: u32,
    /// Voiced-speech thresholds. `process` uses `threshold_db` and `onset_samples`.
    voiced_config: SpeechModeConfig,
    /// Whisper thresholds. Only read by `matches_whisper_mode`.
    whisper_config: SpeechModeConfig,
    /// Zero-crossing rate above which a chunk may be a transient (see `is_transient`).
    transient_zcr_threshold: f32,
    /// Centroid (Hz) above which a chunk may be a transient (see `is_transient`).
    transient_centroid_threshold: f32,
    /// Below-threshold samples after which an active speech segment ends.
    hold_samples: u32,
    /// Whether a speech segment is currently active.
    is_speaking: bool,
    /// Voiced onset-in-progress flag reported by `get_metrics`.
    is_pending_voiced: bool,
    /// Whisper onset-in-progress flag reported by `get_metrics`.
    is_pending_whisper: bool,
    /// Samples at or above the voiced threshold accumulated toward an onset.
    voiced_onset_count: u32,
    /// Whisper counterpart of `voiced_onset_count`.
    whisper_onset_count: u32,
    /// Below-threshold samples seen since the last loud chunk while speaking.
    silence_sample_count: u32,
    /// Length of the current segment in samples: onset + speech + hold so far.
    speech_sample_count: u64,
    /// Presumably a tolerance for brief dips during an onset. Not used by
    /// `process` — TODO confirm.
    onset_grace_samples: u32,
    /// Grace counters paired with `onset_grace_samples` (not used by `process`).
    voiced_grace_count: u32,
    whisper_grace_count: u32,
    /// `false` until the first chunk is seen; that chunk only primes state.
    initialized: bool,
    /// Metrics of the most recently processed chunk, reported by `get_metrics`.
    last_amplitude_db: f32,
    last_zcr: f32,
    last_centroid_hz: f32,
    last_is_transient: bool,
    /// Ring buffer of the most recent raw samples, searched for the onset.
    lookback_buffer: Vec<f32>,
    /// Next write position in `lookback_buffer`.
    lookback_write_index: usize,
    /// Length of `lookback_buffer`, in samples.
    lookback_capacity: usize,
    /// Set once the ring buffer has wrapped around at least once.
    lookback_filled: bool,
    /// Peak level (dB) that marks audio as speech when searching the lookback history.
    lookback_threshold_db: f32,
    /// Lookback offset found on the chunk where speech started; `None` otherwise.
    last_lookback_offset_ms: Option<u32>,
    /// Transition produced by the last `process` call; drained by `take_state_change`.
    last_state_change: SpeechStateChange,
    /// Word-break tuning and state. Not yet read by `process`.
    word_break_threshold_ratio: f32,
    min_word_break_samples: u32,
    max_word_break_samples: u32,
    /// Cap on the sample count of the running speech-amplitude average.
    recent_speech_window_samples: u32,
    recent_speech_amplitude_sum: f32,
    recent_speech_amplitude_count: u32,
    in_word_break: bool,
    word_break_sample_count: u32,
    word_break_start_speech_samples: u64,
    last_is_word_break: bool,
    /// Word-break event of the last chunk; drained by `take_word_break_event`.
    last_word_break_event: Option<WordBreakEvent>,
    /// Optional receiver for speech start/end events.
    callback: Option<Arc<dyn SpeechEventCallback>>,
}
impl SpeechDetector {
pub fn new(sample_rate: u32) -> Self {
Self::with_defaults(sample_rate)
}
/// Builds a detector with the built-in tuning for `sample_rate` Hz audio.
///
/// Defaults:
/// - voiced onset at -42 dB sustained for 80 ms;
/// - whisper thresholds of -52 dB / 120 ms;
/// - a 200 ms hold before speech ends;
/// - a 200 ms lookback history searched at -55 dB;
/// - word-break limits of 40–250 ms.
pub fn with_defaults(sample_rate: u32) -> Self {
    // Milliseconds -> samples. The u64 intermediate keeps the product from overflowing.
    let ms = |millis: u64| sample_rate as u64 * millis / 1000;
    let lookback_capacity = ms(200) as usize;

    Self {
        sample_rate,
        voiced_config: SpeechModeConfig {
            threshold_db: -42.0,
            zcr_range: (0.01, 0.30),
            centroid_range: (200.0, 5500.0),
            onset_samples: ms(80) as u32,
        },
        whisper_config: SpeechModeConfig {
            threshold_db: -52.0,
            zcr_range: (0.02, 0.45),
            centroid_range: (300.0, 7000.0),
            onset_samples: ms(120) as u32,
        },
        transient_zcr_threshold: 0.45,
        transient_centroid_threshold: 6500.0,
        hold_samples: ms(200) as u32,

        // Onset / hold state machine.
        is_speaking: false,
        is_pending_voiced: false,
        is_pending_whisper: false,
        voiced_onset_count: 0,
        whisper_onset_count: 0,
        silence_sample_count: 0,
        speech_sample_count: 0,
        onset_grace_samples: ms(30) as u32,
        voiced_grace_count: 0,
        whisper_grace_count: 0,
        initialized: false,

        // Last-chunk metrics.
        last_amplitude_db: -100.0,
        last_zcr: 0.0,
        last_centroid_hz: 0.0,
        last_is_transient: false,

        // Lookback ring buffer.
        lookback_buffer: vec![0.0; lookback_capacity],
        lookback_write_index: 0,
        lookback_capacity,
        lookback_filled: false,
        lookback_threshold_db: -55.0,
        last_lookback_offset_ms: None,
        last_state_change: SpeechStateChange::None,

        // Word-break tracking.
        word_break_threshold_ratio: 0.25,
        min_word_break_samples: ms(40) as u32,
        max_word_break_samples: ms(250) as u32,
        recent_speech_window_samples: ms(100) as u32,
        recent_speech_amplitude_sum: 0.0,
        recent_speech_amplitude_count: 0,
        in_word_break: false,
        word_break_sample_count: 0,
        word_break_start_speech_samples: 0,
        last_is_word_break: false,
        last_word_break_event: None,

        callback: None,
    }
}
/// Registers the receiver for speech start/end events, replacing any previous one.
pub fn set_callback(&mut self, callback: Arc<dyn SpeechEventCallback>) {
    self.callback = callback.into();
}
/// Root-mean-square amplitude of `samples`; `0.0` for an empty slice.
fn calculate_rms(samples: &[f32]) -> f32 {
    match samples.len() {
        0 => 0.0,
        count => {
            let energy = samples.iter().fold(0.0f32, |acc, &x| acc + x * x);
            (energy / count as f32).sqrt()
        }
    }
}
/// Converts a linear amplitude to decibels (`20 · log10(amplitude)`).
///
/// Zero or negative input maps to a floor of `-100.0` dB instead of `-inf`.
fn amplitude_to_db(amplitude: f32) -> f32 {
    const SILENCE_FLOOR_DB: f32 = -100.0;
    if amplitude <= 0.0 {
        SILENCE_FLOOR_DB
    } else {
        20.0 * amplitude.log10()
    }
}
/// Zero-crossing rate: the fraction of adjacent sample pairs whose sign
/// (treating `0.0` as non-negative) differs. `0.0` for fewer than two samples.
fn calculate_zcr(samples: &[f32]) -> f32 {
    if samples.len() < 2 {
        return 0.0;
    }
    let flips = samples
        .windows(2)
        .filter(|pair| (pair[0] >= 0.0) != (pair[1] >= 0.0))
        .count();
    flips as f32 / (samples.len() - 1) as f32
}
/// Estimates the spectral centroid of `samples`, in Hz, without an FFT.
///
/// For a sinusoid of frequency `f`, the mean absolute first difference
/// divided by the mean absolute value approaches `2π·f / sample_rate`.
/// The estimate is therefore `sample_rate · mean|Δx| / (2π · mean|x|)`.
/// For broadband signals it acts as a rough "brightness" measure.
///
/// Returns `0.0` in three cases:
/// - fewer than two samples;
/// - chunks quieter than `CENTROID_GATE_DB`, where noise dominates the estimate;
/// - an all-zero chunk.
fn estimate_spectral_centroid(&self, samples: &[f32], amplitude_db: f32) -> f32 {
    const CENTROID_GATE_DB: f32 = -55.0;
    if samples.len() < 2 || amplitude_db < CENTROID_GATE_DB {
        return 0.0;
    }
    let diff_sum: f32 = samples.windows(2).map(|w| (w[1] - w[0]).abs()).sum();
    let mean_diff = diff_sum / (samples.len() - 1) as f32;
    let mean_abs: f32 = samples.iter().map(|s| s.abs()).sum::<f32>() / samples.len() as f32;
    if mean_abs < 1e-10 {
        return 0.0;
    }
    // Divide by 2π, not 2: the old divisor overstated the centroid by a factor of π
    // (a 300 Hz tone was reported as ~940 Hz).
    self.sample_rate as f32 * mean_diff / (2.0 * std::f32::consts::PI * mean_abs)
}
/// Classifies a chunk as a transient when both its zero-crossing rate and
/// its centroid exceed the configured transient thresholds.
#[allow(dead_code)] // Transient rejection is not wired into `process` yet.
fn is_transient(&self, zcr: f32, centroid: f32) -> bool {
    let noisy = zcr > self.transient_zcr_threshold;
    let bright = centroid > self.transient_centroid_threshold;
    noisy && bright
}
/// Whether a chunk meets the voiced-mode level threshold and lies within its
/// inclusive ZCR and centroid ranges.
#[allow(dead_code)] // Mode matching is not wired into `process` yet.
fn matches_voiced_mode(&self, db: f32, zcr: f32, centroid: f32) -> bool {
    let cfg = &self.voiced_config;
    let (zcr_lo, zcr_hi) = cfg.zcr_range;
    let (cen_lo, cen_hi) = cfg.centroid_range;
    db >= cfg.threshold_db
        && (zcr_lo..=zcr_hi).contains(&zcr)
        && (cen_lo..=cen_hi).contains(&centroid)
}
/// Whether a chunk meets the whisper-mode level threshold and lies within its
/// inclusive ZCR and centroid ranges.
#[allow(dead_code)] // Mode matching is not wired into `process` yet.
fn matches_whisper_mode(&self, db: f32, zcr: f32, centroid: f32) -> bool {
    let cfg = &self.whisper_config;
    let (zcr_lo, zcr_hi) = cfg.zcr_range;
    let (cen_lo, cen_hi) = cfg.centroid_range;
    db >= cfg.threshold_db
        && (zcr_lo..=zcr_hi).contains(&zcr)
        && (cen_lo..=cen_hi).contains(&centroid)
}
/// Converts a sample count to whole milliseconds (truncating) at this
/// detector's sample rate.
fn samples_to_ms(&self, samples: u64) -> u64 {
    let rate = u64::from(self.sample_rate);
    samples * 1000 / rate
}
/// Clears all pending-onset flags, onset counters and grace counters for
/// both the voiced and whisper modes.
#[allow(dead_code)] // Not called by the current amplitude-only `process`.
fn reset_onset_state(&mut self) {
    (self.is_pending_voiced, self.is_pending_whisper) = (false, false);
    (self.voiced_onset_count, self.whisper_onset_count) = (0, 0);
    (self.voiced_grace_count, self.whisper_grace_count) = (0, 0);
}
/// Appends `samples` to the circular lookback history, overwriting the
/// oldest data once the buffer is full.
///
/// Only the newest `lookback_capacity` samples of a chunk can survive, so
/// any older ones are skipped. The rest is written with at most two slice
/// copies instead of a per-sample modulo.
///
/// A zero-capacity buffer (sample rates below 5 Hz) is now a no-op instead
/// of a `% 0` panic.
fn push_to_lookback_buffer(&mut self, samples: &[f32]) {
    let capacity = self.lookback_capacity;
    if capacity == 0 || samples.is_empty() {
        return;
    }
    // The write head wraps to 0 at least once during this push exactly when
    // it reaches the end of the buffer.
    if self.lookback_write_index + samples.len() >= capacity {
        self.lookback_filled = true;
    }
    let skipped = samples.len().saturating_sub(capacity);
    let tail = &samples[skipped..];
    let start = (self.lookback_write_index + skipped) % capacity;
    let first_len = tail.len().min(capacity - start);
    let (up_to_end, wrapped) = tail.split_at(first_len);
    self.lookback_buffer[start..start + first_len].copy_from_slice(up_to_end);
    self.lookback_buffer[..wrapped.len()].copy_from_slice(wrapped);
    self.lookback_write_index = (start + tail.len()) % capacity;
}
/// Returns the lookback history in chronological order (oldest first).
///
/// Before the ring has wrapped, only the samples written so far are returned.
fn get_lookback_buffer_contents(&self) -> Vec<f32> {
    let split = self.lookback_write_index;
    if self.lookback_filled {
        // Everything from the write head onward is older than what precedes it.
        let (newer, older) = self.lookback_buffer.split_at(split);
        [older, newer].concat()
    } else {
        self.lookback_buffer[..split].to_vec()
    }
}
/// Locates the start of the most recent above-threshold region in the
/// lookback history.
///
/// How the onset is found:
/// - The history is scanned backwards in 128-sample chunks aligned to the
///   newest sample.
/// - The earliest chunk of the latest contiguous run whose peak reaches
///   `lookback_threshold_db` marks the onset.
/// - A 20 ms margin is prepended to the onset.
/// - If no chunk reaches the threshold, only the trailing margin is returned.
///
/// Returns the samples from that point to the end of the history, together
/// with their duration in milliseconds.
fn find_lookback_start(&self) -> (Vec<f32>, u32) {
    const CHUNK_SIZE: usize = 128;
    let history = self.get_lookback_buffer_contents();
    let total = history.len();
    if total == 0 {
        return (Vec::new(), 0);
    }
    let margin = self.sample_rate as usize * 20 / 1000;
    let gate = 10.0f32.powf(self.lookback_threshold_db / 20.0);

    let mut onset = total;
    let mut chunk_end = total;
    // `rchunks` yields the same end-aligned partition as stepping back by
    // CHUNK_SIZE; the oldest chunk may be shorter.
    for chunk in history.rchunks(CHUNK_SIZE) {
        let chunk_begin = chunk_end - chunk.len();
        let peak = chunk.iter().fold(0.0f32, |loudest, s| loudest.max(s.abs()));
        if peak >= gate {
            onset = chunk_begin;
        } else if onset < total {
            // The latest loud run has ended; anything earlier is older audio.
            break;
        }
        chunk_end = chunk_begin;
    }

    let start = onset.saturating_sub(margin);
    let span = total - start;
    let offset_ms = (span as u64 * 1000 / self.sample_rate as u64) as u32;
    (history[start..].to_vec(), offset_ms)
}
/// Snapshot of the metrics computed for the most recently processed chunk,
/// plus the detector's current speaking and onset flags.
///
/// `is_lookback_speech` is always reported as `false`.
pub fn get_metrics(&self) -> SpeechMetrics {
    let level = self.last_amplitude_db;
    SpeechMetrics {
        is_speaking: self.is_speaking,
        is_voiced_pending: self.is_pending_voiced,
        is_whisper_pending: self.is_pending_whisper,
        is_transient: self.last_is_transient,
        is_word_break: self.last_is_word_break,
        is_lookback_speech: false,
        amplitude_db: level,
        zcr: self.last_zcr,
        centroid_hz: self.last_centroid_hz,
        lookback_offset_ms: self.last_lookback_offset_ms,
    }
}
/// Returns the transition produced by the last `process` call and resets it
/// to `SpeechStateChange::None`.
pub fn take_state_change(&mut self) -> SpeechStateChange {
    let slot = &mut self.last_state_change;
    std::mem::replace(slot, SpeechStateChange::None)
}
/// Returns the pending word-break event, if any, leaving `None` in its place.
pub fn take_word_break_event(&mut self) -> Option<WordBreakEvent> {
    std::mem::take(&mut self.last_word_break_event)
}
/// Adds a chunk (mean level `rms`, `sample_count` samples) to the running
/// speech-amplitude average.
///
/// Once the tracked count exceeds `recent_speech_window_samples`, the sum is
/// rescaled so the count stays capped at the window size.
#[allow(dead_code)] // Part of the not-yet-wired word-break detector.
fn update_speech_amplitude_average(&mut self, rms: f32, sample_count: u32) {
    self.recent_speech_amplitude_count += sample_count;
    self.recent_speech_amplitude_sum += rms * sample_count as f32;
    let window = self.recent_speech_window_samples;
    if self.recent_speech_amplitude_count <= window {
        return;
    }
    let scale = window as f32 / self.recent_speech_amplitude_count as f32;
    self.recent_speech_amplitude_sum *= scale;
    self.recent_speech_amplitude_count = window;
}
/// Mean level tracked by `update_speech_amplitude_average`; `0.0` when
/// nothing has been recorded.
#[allow(dead_code)] // Part of the not-yet-wired word-break detector.
fn get_recent_speech_amplitude(&self) -> f32 {
    match self.recent_speech_amplitude_count {
        0 => 0.0,
        count => self.recent_speech_amplitude_sum / count as f32,
    }
}
/// Clears all word-break tracking, including the running amplitude average
/// and any pending event.
#[allow(dead_code)] // Part of the not-yet-wired word-break detector.
fn reset_word_break_state(&mut self) {
    self.last_word_break_event = None;
    self.last_is_word_break = false;
    self.in_word_break = false;
    self.word_break_sample_count = 0;
    self.word_break_start_speech_samples = 0;
    self.recent_speech_amplitude_count = 0;
    self.recent_speech_amplitude_sum = 0.0;
}
pub fn process(&mut self, samples: &[f32]) {
self.last_state_change = SpeechStateChange::None;
self.last_word_break_event = None;
self.push_to_lookback_buffer(samples);
let rms = Self::calculate_rms(samples);
let db = Self::amplitude_to_db(rms);
let zcr = Self::calculate_zcr(samples);
let centroid = self.estimate_spectral_centroid(samples, db);
self.last_amplitude_db = db;
self.last_zcr = zcr;
self.last_centroid_hz = centroid;
self.last_is_transient = false; self.last_lookback_offset_ms = None;
self.last_is_word_break = false;
if !self.initialized {
self.initialized = true;
return;
}
let is_speech_candidate = db >= self.voiced_config.threshold_db;
let samples_len = samples.len() as u32;
if is_speech_candidate {
self.silence_sample_count = 0;
if self.is_speaking {
self.speech_sample_count += samples.len() as u64;
} else {
self.voiced_onset_count += samples_len;
if self.voiced_onset_count >= self.voiced_config.onset_samples {
self.is_speaking = true;
self.speech_sample_count = self.voiced_onset_count as u64;
self.voiced_onset_count = 0;
let (lookback_samples, lookback_offset_ms) = self.find_lookback_start();
self.last_lookback_offset_ms = Some(lookback_offset_ms);
self.last_state_change = SpeechStateChange::Started {
lookback_samples: lookback_samples.len(),
};
let payload = SpeechEventPayload {
duration_ms: None,
lookback_offset_ms: Some(lookback_offset_ms),
};
if let Some(ref callback) = self.callback {
callback.on_speech_started(payload);
}
tracing::debug!(
"Speech started (amplitude mode, lookback: {}ms)",
lookback_offset_ms
);
}
}
} else {
self.voiced_onset_count = 0;
if self.is_speaking {
self.silence_sample_count += samples_len;
self.speech_sample_count += samples.len() as u64;
if self.silence_sample_count >= self.hold_samples {
let duration_ms = self.samples_to_ms(self.speech_sample_count);
self.is_speaking = false;
self.speech_sample_count = 0;
self.last_state_change = SpeechStateChange::Ended { duration_ms };
let payload = SpeechEventPayload {
duration_ms: Some(duration_ms),
lookback_offset_ms: None,
};
if let Some(ref callback) = self.callback {
callback.on_speech_ended(payload);
}
tracing::debug!("Speech ended (duration: {}ms)", duration_ms);
}
}
}
}
}
/// One rendered spectrogram column.
#[derive(Clone, Debug, Serialize)]
pub struct SpectrogramColumn {
    /// Packed RGB bytes, three per pixel row (`output_height * 3` in total).
    /// Row 0 maps to the highest frequency.
    pub colors: Vec<u8>,
}
/// Data produced by `VisualizationProcessor::process` for one input chunk.
#[derive(Clone, Debug, Serialize)]
pub struct VisualizationPayload {
    /// Peak-preserving downsampled waveform of the chunk.
    pub waveform: Vec<f32>,
    /// One column per FFT frame completed during the chunk (possibly none).
    pub spectrogram: Vec<SpectrogramColumn>,
    /// Metrics queued with `set_speech_metrics`, if any.
    pub speech_metrics: Option<SpeechMetrics>,
}
/// Receiver for visualisation data.
///
/// Called synchronously from `VisualizationProcessor::process`, so
/// implementations should return quickly.
///
/// NOTE(review): implementors are held as `Arc<dyn VisualizationCallback>`
/// with only a `Send` bound; such an `Arc` is not `Send` without `Sync`.
/// Confirm whether `Sync` is intended.
pub trait VisualizationCallback: Send {
    /// Receives the waveform, spectrogram columns and metrics for one chunk.
    fn on_visualization_data(&self, payload: VisualizationPayload);
}
/// A control point of the spectrogram colour gradient.
struct ColorStop {
    /// Position along the gradient, 0.0..=1.0.
    position: f32,
    /// RGB colour at `position`.
    r: u8,
    g: u8,
    b: u8,
}
/// Turns audio chunks into spectrogram columns and a downsampled waveform.
pub struct VisualizationProcessor {
    /// Input sample rate in Hz.
    sample_rate: u32,
    /// Pixel rows per spectrogram column.
    output_height: usize,
    /// FFT length in samples (512, set in `new`).
    fft_size: usize,
    /// Planned forward FFT of `fft_size` points.
    fft: Arc<dyn rustfft::Fft<f32>>,
    /// Hann window coefficients, one per FFT input sample.
    hanning_window: Vec<f32>,
    /// Samples of the frame being collected. It grows to `fft_size` once and
    /// is then overwritten in place.
    fft_buffer: Vec<f32>,
    /// Next write position in `fft_buffer`. A column is emitted when it
    /// reaches `fft_size`.
    fft_write_index: usize,
    /// 256-entry RGB lookup table for normalised magnitudes.
    color_lut: Vec<[u8; 3]>,
    /// Scratch buffer used by `process` for waveform downsampling; cleared every call.
    waveform_buffer: Vec<f32>,
    /// Desired number of waveform points per chunk.
    waveform_target_samples: usize,
    /// Metrics to attach to the next `process` output.
    pending_speech_metrics: Option<SpeechMetrics>,
    /// Optional receiver for visualisation data.
    callback: Option<Arc<dyn VisualizationCallback>>,
}
impl VisualizationProcessor {
/// Creates a visualiser for `sample_rate` Hz audio whose spectrogram columns
/// are `output_height` pixels tall.
///
/// Uses a 512-point FFT with a Hann window and targets 64 waveform points
/// per processed chunk.
pub fn new(sample_rate: u32, output_height: usize) -> Self {
    const FFT_SIZE: usize = 512;
    let fft = FftPlanner::<f32>::new().plan_fft_forward(FFT_SIZE);
    // Symmetric Hann window: 0.5 * (1 - cos(2πi / (N - 1))).
    let span = (FFT_SIZE - 1) as f32;
    let hanning_window = (0..FFT_SIZE)
        .map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / span).cos()))
        .collect::<Vec<f32>>();
    Self {
        sample_rate,
        output_height,
        fft_size: FFT_SIZE,
        fft,
        hanning_window,
        fft_buffer: Vec::with_capacity(FFT_SIZE),
        fft_write_index: 0,
        color_lut: Self::build_color_lut(),
        waveform_buffer: Vec::with_capacity(256),
        waveform_target_samples: 64,
        pending_speech_metrics: None,
        callback: None,
    }
}
/// Registers the receiver for visualisation data, replacing any previous one.
pub fn set_callback(&mut self, callback: Arc<dyn VisualizationCallback>) {
    self.callback = callback.into();
}
/// Queues `metrics` to be attached to the next `process` output, replacing
/// any metrics not yet consumed.
pub fn set_speech_metrics(&mut self, metrics: SpeechMetrics) {
    let slot = &mut self.pending_speech_metrics;
    *slot = Some(metrics);
}
/// Builds the 256-entry RGB lookup table used to colour normalised
/// spectrogram magnitudes.
///
/// Each index `i` maps to `t = (i / 255)^0.7` on a dark-blue → blue →
/// green → yellow → orange → red gradient, interpolated linearly between
/// stops.
fn build_color_lut() -> Vec<[u8; 3]> {
    let stops = [
        ColorStop { position: 0.00, r: 10, g: 15, b: 26 },
        ColorStop { position: 0.15, r: 0, g: 50, b: 200 },
        ColorStop { position: 0.35, r: 0, g: 255, b: 150 },
        ColorStop { position: 0.60, r: 200, g: 255, b: 0 },
        ColorStop { position: 0.80, r: 255, g: 155, b: 0 },
        ColorStop { position: 1.00, r: 255, g: 0, b: 0 },
    ];
    let lerp = |a: u8, b: u8, s: f32| (a as f32 + s * (b as f32 - a as f32)).round() as u8;
    (0..256)
        .map(|i| {
            // The 0.7 exponent lifts low indices further up the gradient.
            let t = (i as f32 / 255.0).powf(0.7);
            stops
                .windows(2)
                .find(|pair| t >= pair[0].position && t <= pair[1].position)
                .map(|pair| {
                    let (lo, hi) = (&pair[0], &pair[1]);
                    let s = (t - lo.position) / (hi.position - lo.position);
                    [lerp(lo.r, hi.r, s), lerp(lo.g, hi.g, s), lerp(lo.b, hi.b, s)]
                })
                .unwrap_or([255, 0, 0])
        })
        .collect()
}
/// Maps a display position `pos` on a logarithmic frequency axis to a
/// fractional FFT bin index.
///
/// `pos` 0.0 corresponds to 20 Hz and 1.0 to 24 kHz. The result is clamped
/// to `[0, num_bins - 1]`. Frequencies above Nyquist (e.g. when the sample
/// rate is below 48 kHz) therefore all map to the last bin.
fn position_to_freq_bin(&self, pos: f32, num_bins: usize) -> f32 {
    const MIN_FREQ: f32 = 20.0;
    const MAX_FREQ: f32 = 24000.0;
    let (log_min, log_max) = (MIN_FREQ.log10(), MAX_FREQ.log10());
    let freq = 10.0f32.powf(log_min + pos * (log_max - log_min));
    let last_bin = (num_bins - 1) as f32;
    (freq * self.fft_size as f32 / self.sample_rate as f32).clamp(0.0, last_bin)
}
/// Returns the spectrum magnitude to draw for pixel row `y` of a column
/// `height` pixels tall. Row 0 maps to the highest frequency.
///
/// The row spans the log-frequency interval between positions
/// `(height - 1 - y) / height` and `(height - y) / height`, converted to
/// fractional FFT bins with `position_to_freq_bin`.
/// - If that interval is narrower than one bin, the magnitude is linearly
///   interpolated at its lower edge.
/// - Otherwise it is the overlap-weighted average of every bin it touches.
///
/// `magnitudes` must be non-empty and `y` must be less than `height`;
/// otherwise the unsigned subtractions underflow.
fn get_magnitude_for_pixel(&self, magnitudes: &[f32], y: usize, height: usize) -> f32 {
    let num_bins = magnitudes.len();
    // Row edges as fractions of the display height, measured from the bottom.
    let pos1 = (height - 1 - y) as f32 / height as f32;
    let pos2 = (height - y) as f32 / height as f32;
    let bin1 = self.position_to_freq_bin(pos1, num_bins);
    let bin2 = self.position_to_freq_bin(pos2, num_bins);
    let bin_low = bin1.min(bin2).max(0.0);
    let bin_high = bin1.max(bin2).min((num_bins - 1) as f32);
    // Narrow row (typical for low frequencies): interpolate between the two nearest bins.
    if bin_high - bin_low < 1.0 {
        let bin_floor = bin_low.floor() as usize;
        let bin_ceil = (bin_floor + 1).min(num_bins - 1);
        let frac = bin_low - bin_floor as f32;
        return magnitudes[bin_floor] * (1.0 - frac) + magnitudes[bin_ceil] * frac;
    }
    // Wide row: weight each bin [b, b + 1) by its overlap with [bin_low, bin_high].
    let mut sum = 0.0f32;
    let mut weight = 0.0f32;
    let start_bin = bin_low.floor() as usize;
    let end_bin = bin_high.ceil() as usize;
    // The index `b` is also needed to compute each bin's edges, so an iterator adds nothing.
    #[allow(clippy::needless_range_loop)]
    for b in start_bin..=end_bin.min(num_bins - 1) {
        let bin_start = b as f32;
        let bin_end = (b + 1) as f32;
        let overlap_start = bin_low.max(bin_start);
        let overlap_end = bin_high.min(bin_end);
        let overlap_weight = (overlap_end - overlap_start).max(0.0);
        if overlap_weight > 0.0 {
            sum += magnitudes[b] * overlap_weight;
            weight += overlap_weight;
        }
    }
    if weight > 0.0 {
        sum / weight
    } else {
        0.0
    }
}
/// Windows the current frame, runs the FFT and renders one spectrogram column.
///
/// How the column is coloured:
/// - Magnitudes of the first `fft_size / 2` bins are scaled by `1 / fft_size`.
/// - Each pixel's magnitude is normalised against the frame peak (at least
///   0.05) with `log10(1 + 9·m/ref)`, clamped to `[0, 1]`.
/// - The normalised value is looked up in `color_lut`.
fn process_fft(&self) -> SpectrogramColumn {
    let mut spectrum: Vec<Complex<f32>> = Vec::with_capacity(self.fft_size);
    spectrum.extend(
        self.fft_buffer
            .iter()
            .zip(&self.hanning_window)
            .map(|(&x, &w)| Complex::new(x * w, 0.0)),
    );
    spectrum.resize(self.fft_size, Complex::new(0.0, 0.0));
    self.fft.process(&mut spectrum);

    let half = self.fft_size / 2;
    let scale = self.fft_size as f32;
    let magnitudes: Vec<f32> = spectrum[..half]
        .iter()
        .map(|c| (c.re * c.re + c.im * c.im).sqrt() / scale)
        .collect();
    let loudest = magnitudes.iter().copied().fold(0.001f32, f32::max);
    let ref_level = loudest.max(0.05);

    let mut colors = Vec::with_capacity(self.output_height * 3);
    for row in 0..self.output_height {
        let magnitude = self.get_magnitude_for_pixel(&magnitudes, row, self.output_height);
        let normalized = (1.0 + magnitude / ref_level * 9.0).log10().clamp(0.0, 1.0);
        let lut_index = ((normalized * 255.0).floor() as usize).min(255);
        colors.extend_from_slice(&self.color_lut[lut_index]);
    }
    SpectrogramColumn { colors }
}
/// Reduces `samples` to a short waveform for display.
///
/// Each window of `max(1, len / waveform_target_samples)` samples is replaced
/// by its largest-magnitude sample, with the sign kept so peaks are not
/// averaged away.
///
/// NaN samples are ignored, and a window holding only NaNs yields `0.0`.
/// Previously `partial_cmp(..).unwrap()` panicked on NaN input and took the
/// audio path down with it.
fn downsample_waveform(&self, samples: &[f32]) -> Vec<f32> {
    if samples.is_empty() {
        return Vec::new();
    }
    let window_size = (samples.len() / self.waveform_target_samples).max(1);
    samples
        .chunks(window_size)
        .map(|chunk| {
            // `>=` keeps the *last* of equally large samples, matching the
            // tie-breaking of `Iterator::max_by`; NaN comparisons are false,
            // so NaNs never win.
            chunk
                .iter()
                .fold(0.0f32, |peak, &s| if s.abs() >= peak.abs() { s } else { peak })
        })
        .collect()
}
pub fn process(&mut self, samples: &[f32]) -> Option<crate::VisualizationData> {
let mut spectrogram_columns: Vec<SpectrogramColumn> = Vec::new();
for &sample in samples {
if self.fft_buffer.len() <= self.fft_write_index {
self.fft_buffer.push(sample);
} else {
self.fft_buffer[self.fft_write_index] = sample;
}
self.fft_write_index += 1;
if self.fft_write_index >= self.fft_size {
spectrogram_columns.push(self.process_fft());
self.fft_write_index = 0;
}
}
self.waveform_buffer.extend_from_slice(samples);
let waveform = self.downsample_waveform(&self.waveform_buffer);
self.waveform_buffer.clear();
let speech_metrics = self.pending_speech_metrics.take();
let payload = VisualizationPayload {
waveform: waveform.clone(),
spectrogram: spectrogram_columns.clone(),
speech_metrics: speech_metrics.clone(),
};
if let Some(ref callback) = self.callback {
callback.on_visualization_data(payload);
}
let frame_interval_ms = samples.len() as f32 / self.sample_rate as f32 * 1000.0;
let viz = crate::VisualizationData {
waveform,
spectrogram: spectrogram_columns
.into_iter()
.map(|s| crate::SpectrogramColumn { colors: s.colors })
.collect(),
speech_metrics: speech_metrics.map(|m| crate::SpeechMetrics {
amplitude_db: m.amplitude_db,
zcr: m.zcr,
centroid_hz: m.centroid_hz,
is_speaking: m.is_speaking,
voiced_onset_pending: m.is_voiced_pending,
whisper_onset_pending: m.is_whisper_pending,
is_transient: m.is_transient,
is_lookback_speech: m.is_lookback_speech,
is_word_break: m.is_word_break,
}),
sample_rate: self.sample_rate,
frame_interval_ms,
};
Some(viz)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const SAMPLE_RATE: u32 = 48000;
    /// 10 ms of audio at `SAMPLE_RATE`.
    const CHUNK_SIZE: usize = 480;

    /// `num_samples` of a sine at `freq_hz` with peak `amplitude`, starting at phase 0.
    fn sine_chunk(freq_hz: f32, amplitude: f32, sample_rate: u32, num_samples: usize) -> Vec<f32> {
        (0..num_samples)
            .map(|i| {
                amplitude
                    * (2.0 * std::f32::consts::PI * freq_hz * i as f32 / sample_rate as f32).sin()
            })
            .collect()
    }

    /// `num_samples` of digital silence.
    fn silence_chunk(num_samples: usize) -> Vec<f32> {
        vec![0.0; num_samples]
    }

    /// Feeds `chunk` (assumed to be `CHUNK_SIZE` long) repeatedly to cover
    /// `duration_ms`, rounded down to whole chunks.
    fn feed_ms(detector: &mut SpeechDetector, chunk: &[f32], duration_ms: u32) {
        let chunks_needed =
            (SAMPLE_RATE as u64 * duration_ms as u64 / 1000 / CHUNK_SIZE as u64) as u32;
        for _ in 0..chunks_needed {
            detector.process(chunk);
        }
    }

    #[test]
    fn voiced_speech_detected_after_onset() {
        let mut det = SpeechDetector::new(SAMPLE_RATE);
        // 0.032 peak is about -33 dB RMS, above the -42 dB voiced threshold.
        let chunk = sine_chunk(300.0, 0.032, SAMPLE_RATE, CHUNK_SIZE);
        for _ in 0..10 {
            det.process(&chunk);
        }
        assert!(
            det.is_speaking,
            "Speech should be detected after 100ms of voiced audio"
        );
    }

    #[test]
    fn rms_of_silence_is_zero() {
        let silence = silence_chunk(CHUNK_SIZE);
        let rms = SpeechDetector::calculate_rms(&silence);
        assert_eq!(rms, 0.0);
    }

    #[test]
    fn zcr_of_silence_is_zero() {
        let silence = silence_chunk(CHUNK_SIZE);
        let zcr = SpeechDetector::calculate_zcr(&silence);
        assert_eq!(zcr, 0.0);
    }

    #[test]
    fn rms_of_sine_matches_theory() {
        let amplitude = 0.5;
        let chunk = sine_chunk(440.0, amplitude, SAMPLE_RATE, CHUNK_SIZE * 10);
        let rms = SpeechDetector::calculate_rms(&chunk);
        let expected = amplitude / 2.0f32.sqrt();
        assert!(
            (rms - expected).abs() < 0.01,
            "RMS of sine should be ~{}, got {}",
            expected,
            rms
        );
    }

    #[test]
    fn silence_ends_speech_after_hold_time() {
        let mut det = SpeechDetector::new(SAMPLE_RATE);
        let speech = sine_chunk(300.0, 0.032, SAMPLE_RATE, CHUNK_SIZE);
        let silence = silence_chunk(CHUNK_SIZE);
        feed_ms(&mut det, &speech, 200);
        assert!(det.is_speaking, "Speech should be active after onset");
        feed_ms(&mut det, &silence, 210);
        assert!(!det.is_speaking, "Speech should end after hold time");
    }

    #[test]
    fn brief_silence_does_not_end_speech() {
        let mut det = SpeechDetector::new(SAMPLE_RATE);
        let speech = sine_chunk(300.0, 0.032, SAMPLE_RATE, CHUNK_SIZE);
        let silence = silence_chunk(CHUNK_SIZE);
        feed_ms(&mut det, &speech, 200);
        assert!(det.is_speaking);
        feed_ms(&mut det, &silence, 100);
        assert!(det.is_speaking, "Brief silence should not end speech");
        feed_ms(&mut det, &speech, 50);
        assert!(
            det.is_speaking,
            "Speech should still be active after brief gap"
        );
    }

    #[test]
    fn below_threshold_audio_does_not_trigger_speech() {
        let mut det = SpeechDetector::new(SAMPLE_RATE);
        // 0.003 peak is about -53 dB RMS, below the -42 dB voiced threshold.
        let quiet = sine_chunk(300.0, 0.003, SAMPLE_RATE, CHUNK_SIZE);
        feed_ms(&mut det, &quiet, 500);
        assert!(
            !det.is_speaking,
            "Audio below threshold should not trigger speech"
        );
    }

    /// Diagnostic dump: runs the detector over a WAV file two directories
    /// above the crate and writes per-10 ms metrics to `speech_metrics.csv`
    /// in the same directory. Skipped when the WAV file is absent.
    #[test]
    fn dump_wav_metrics() {
        use std::io::Write;
        // NOTE(review): "exampe.wav" may be a typo for "example.wav". If so,
        // this test always skips silently — confirm the intended file name.
        let wav_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("../..")
            .join("exampe.wav");
        if !wav_path.exists() {
            eprintln!("Skipping dump_wav_metrics: {:?} not found", wav_path);
            return;
        }
        let reader = hound::WavReader::open(&wav_path)
            .unwrap_or_else(|e| panic!("Failed to open {:?}: {}", wav_path, e));
        let spec = reader.spec();
        let wav_sr = spec.sample_rate;
        let wav_ch = spec.channels as usize;
        eprintln!(
            "WAV: {}Hz, {} ch, {} bit, {:?}",
            wav_sr, wav_ch, spec.bits_per_sample, spec.sample_format
        );
        let raw_samples: Vec<f32> = match spec.sample_format {
            hound::SampleFormat::Float => reader
                .into_samples::<f32>()
                .filter_map(|s| s.ok())
                .collect(),
            hound::SampleFormat::Int => {
                let bits = spec.bits_per_sample;
                let max_val = (1u32 << (bits - 1)) as f32;
                reader
                    .into_samples::<i32>()
                    .filter_map(|s| s.ok())
                    .map(|s| s as f32 / max_val)
                    .collect()
            }
        };
        let mono: Vec<f32> = if wav_ch > 1 {
            crate::audio::convert_to_mono(&raw_samples, wav_ch)
        } else {
            raw_samples
        };
        eprintln!(
            "Mono samples: {} ({:.2}s)",
            mono.len(),
            mono.len() as f64 / wav_sr as f64
        );
        // 10 ms frames, matching the `time_ms` column below.
        let chunk_size = (wav_sr as usize) / 100;
        let mut det = SpeechDetector::new(wav_sr);
        let csv_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("../..")
            .join("speech_metrics.csv");
        let csv_file = std::fs::File::create(&csv_path)
            .unwrap_or_else(|e| panic!("Cannot create {:?}: {}", csv_path, e));
        // Buffered: one row is written per 10 ms frame.
        let mut csv = std::io::BufWriter::new(csv_file);
        writeln!(
            csv,
            "frame,time_ms,amplitude_db,zcr,centroid_hz,is_speaking,is_voiced_pending,is_whisper_pending,is_transient,is_word_break,rms_linear"
        )
        .unwrap();
        let mut frame_idx = 0u32;
        for chunk in mono.chunks(chunk_size) {
            let time_ms = frame_idx as f64 * 10.0;
            let rms = SpeechDetector::calculate_rms(chunk);
            det.process(chunk);
            let m = det.get_metrics();
            writeln!(
                csv,
                "{},{:.1},{:.2},{:.4},{:.1},{},{},{},{},{},{:.6}",
                frame_idx,
                time_ms,
                m.amplitude_db,
                m.zcr,
                m.centroid_hz,
                m.is_speaking as u8,
                m.is_voiced_pending as u8,
                m.is_whisper_pending as u8,
                m.is_transient as u8,
                m.is_word_break as u8,
                rms,
            )
            .unwrap();
            frame_idx += 1;
        }
        csv.flush().unwrap();
        eprintln!("Wrote {} frames to {:?}", frame_idx, csv_path);
        let expected_frames = mono.len().div_ceil(chunk_size);
        assert_eq!(frame_idx as usize, expected_frames);
    }
}
/// Smoothed power at or below this value (-100 dB) is classified as
/// `PowerRegion::BelowNoiseFloor`. In that region `AgcProcessor::process`
/// leaves the gain unchanged unless the boost threshold is also reached.
const AGC_NOISE_FLOOR_POWER: f32 = 1e-10;
/// `AgcProcessor::process` returns the current gain once every this many chunks.
const AGC_EVENT_INTERVAL_CHUNKS: u32 = 10;
/// Time constant (ms) of the exponential return toward unity gain. Applies
/// while the level is below the boost threshold but above the noise floor.
const AGC_GATE_DECAY_TIME_MS: f32 = 500.0;
/// Maximum gain increase per processed chunk, in dB, once the hold period has expired.
const AGC_MAX_GAIN_RISE_DB_PER_CHUNK: f32 = 3.0;
/// Automatic gain control with three behaviours: a hold period, a
/// rate-limited gain rise, and decay toward unity gain on low-level input.
pub struct AgcProcessor {
    /// One-pole smoothed mean-square power of the input (linear power units).
    power_estimate: f32,
    /// Gain currently applied to samples (linear).
    current_gain_linear: f32,
    /// Active configuration.
    config: AgcConfig,
    /// `config.gate_threshold_db` converted to linear power.
    gate_threshold_power: f32,
    /// `config.boost_threshold_db` converted to linear power. The gain tracks
    /// the target only at or above this level.
    boost_threshold_power: f32,
    /// Chunks processed since the gain was last reported.
    chunks_since_event: u32,
    /// Remaining hold time (ms) after the level reaches the boost threshold;
    /// the gain may not rise until it reaches zero.
    hold_timer_ms: f32,
    /// Power region of the previous chunk. Written by `process`; not read in the visible code.
    last_power_region: PowerRegion,
    /// Copy of `config.gate_hold_time_ms`, used to re-arm `hold_timer_ms`.
    gate_hold_time_ms: f32,
}
/// Classification of the AGC's smoothed input power.
#[derive(Clone, Copy, Debug, PartialEq)]
enum PowerRegion {
    /// At or below `AGC_NOISE_FLOOR_POWER`.
    BelowNoiseFloor,
    /// Above the noise floor, but at or below the gate threshold.
    GateRegion,
    /// Above the gate threshold.
    AboveThreshold,
}
impl AgcProcessor {
/// Creates an AGC at unity gain from `config`, with the hold timer armed.
pub fn new(config: AgcConfig) -> Self {
    Self {
        power_estimate: 1e-6,
        current_gain_linear: 1.0,
        gate_threshold_power: db_to_power(config.gate_threshold_db),
        boost_threshold_power: db_to_power(config.boost_threshold_db),
        chunks_since_event: 0,
        hold_timer_ms: config.gate_hold_time_ms,
        last_power_region: PowerRegion::BelowNoiseFloor,
        gate_hold_time_ms: config.gate_hold_time_ms,
        config,
    }
}
/// Replaces the configuration and refreshes the derived thresholds.
///
/// The current gain, power estimate and running hold timer are kept as they are.
pub fn update_config(&mut self, config: AgcConfig) {
    self.gate_hold_time_ms = config.gate_hold_time_ms;
    self.boost_threshold_power = db_to_power(config.boost_threshold_db);
    self.gate_threshold_power = db_to_power(config.gate_threshold_db);
    self.config = config;
}
/// Applies automatic gain control to `samples` in place.
///
/// Per chunk:
/// 1. Update a one-pole smoothed power estimate, using the attack time when
///    power rises and the release time when it falls.
/// 2. If the estimate reaches the boost threshold, wait out
///    `gate_hold_time_ms`; during the hold the gain may only fall. After the
///    hold, move to the target gain, rising at most
///    `AGC_MAX_GAIN_RISE_DB_PER_CHUNK` per chunk.
/// 3. Below the boost threshold but above the noise floor, decay the gain
///    toward unity with time constant `AGC_GATE_DECAY_TIME_MS`. At or below
///    the noise floor, leave the gain unchanged.
/// 4. Multiply every sample by the gain and hard-clip to `[-1.0, 1.0]`.
///
/// Returns `Some(gain_db)` once every `AGC_EVENT_INTERVAL_CHUNKS` processed
/// chunks and `None` otherwise. It also returns `None`, without touching
/// `samples`, when the AGC is disabled or `samples` is empty.
pub fn process(&mut self, samples: &mut [f32], sample_rate: u32) -> Option<f32> {
    if !self.config.enabled || samples.is_empty() {
        return None;
    }
    let chunk_power = samples.iter().map(|s| s * s).sum::<f32>() / samples.len() as f32;
    let chunk_duration_s = samples.len() as f32 / sample_rate as f32;

    // One-pole smoother: attack constant when the level rises, release when it falls.
    let tau_ms = if chunk_power > self.power_estimate {
        self.config.attack_time_ms
    } else {
        self.config.release_time_ms
    };
    let alpha = (-chunk_duration_s / (tau_ms / 1000.0)).exp();
    self.power_estimate = alpha * self.power_estimate + (1.0 - alpha) * chunk_power;

    let current_region = if self.power_estimate > self.gate_threshold_power {
        PowerRegion::AboveThreshold
    } else if self.power_estimate > AGC_NOISE_FLOOR_POWER {
        PowerRegion::GateRegion
    } else {
        PowerRegion::BelowNoiseFloor
    };
    let boost_ready = self.power_estimate >= self.boost_threshold_power;
    let chunk_duration_ms = chunk_duration_s * 1000.0;

    if boost_ready {
        if self.hold_timer_ms > 0.0 {
            self.hold_timer_ms -= chunk_duration_ms;
            if self.hold_timer_ms < 0.0 {
                self.hold_timer_ms = 0.0;
            }
        }
    } else {
        // Re-arm the hold whenever the level drops below the boost threshold.
        self.hold_timer_ms = self.gate_hold_time_ms;
    }

    if boost_ready && self.hold_timer_ms > 0.0 {
        // Inside the hold period the gain may move down toward the target, never up.
        self.current_gain_linear = self.current_gain_linear.min(self.target_gain());
    } else if boost_ready {
        // Hold expired: follow the target, but rate-limit increases.
        let target_gain = self.target_gain();
        self.current_gain_linear = if target_gain > self.current_gain_linear {
            let allowed_gain =
                self.current_gain_linear * db_to_linear(AGC_MAX_GAIN_RISE_DB_PER_CHUNK);
            target_gain.min(allowed_gain)
        } else {
            target_gain
        };
    } else if current_region != PowerRegion::BelowNoiseFloor {
        // Low-level (non-silent) input: relax exponentially toward unity gain.
        let decay_alpha = (-chunk_duration_s / (AGC_GATE_DECAY_TIME_MS / 1000.0)).exp();
        self.current_gain_linear =
            decay_alpha * self.current_gain_linear + (1.0 - decay_alpha) * 1.0;
    }
    self.last_power_region = current_region;

    let gain = self.current_gain_linear;
    for s in samples.iter_mut() {
        *s = (*s * gain).clamp(-1.0, 1.0);
    }

    self.chunks_since_event += 1;
    if self.chunks_since_event >= AGC_EVENT_INTERVAL_CHUNKS {
        self.chunks_since_event = 0;
        Some(self.current_gain_db())
    } else {
        None
    }
}

/// Linear gain that would bring the smoothed level to `target_level_db`,
/// limited to `[min_gain_db, max_gain_db]`.
///
/// Uses `max`/`min` rather than `f32::clamp`, because `clamp` panics when
/// `min_gain_db > max_gain_db` (or either is NaN). Such a misconfiguration
/// must not abort the audio thread; in that case `max_gain_db` wins.
///
/// Only called when `power_estimate >= boost_threshold_power`, so the
/// divisor is positive for any finite boost threshold.
fn target_gain(&self) -> f32 {
    let target_rms = db_to_linear(self.config.target_level_db);
    let current_rms = self.power_estimate.sqrt();
    let raw_gain = target_rms / current_rms;
    let min_gain = db_to_linear(self.config.min_gain_db);
    let max_gain = db_to_linear(self.config.max_gain_db);
    raw_gain.max(min_gain).min(max_gain)
}
/// Gain currently applied to samples, in dB (`-inf` for a non-positive linear gain).
pub fn current_gain_db(&self) -> f32 {
    let gain = self.current_gain_linear;
    linear_to_db(gain)
}
}
/// Converts decibels to a linear amplitude ratio: `10^(db / 20)`.
#[inline]
fn db_to_linear(db: f32) -> f32 {
    let exponent = db / 20.0;
    10f32.powf(exponent)
}
/// Converts decibels to a linear power ratio: `10^(db / 10)`.
#[inline]
fn db_to_power(db: f32) -> f32 {
    let exponent = db / 10.0;
    10f32.powf(exponent)
}
/// Converts a linear amplitude ratio to decibels: `20 · log10(linear)`.
///
/// Zero or negative input yields `f32::NEG_INFINITY`.
#[inline]
fn linear_to_db(linear: f32) -> f32 {
    match linear {
        x if x <= 0.0 => f32::NEG_INFINITY,
        x => 20.0 * x.log10(),
    }
}
#[cfg(test)]
mod agc_tests {
use super::*;
fn sine_wave(rms: f32, num_samples: usize, sample_rate: u32) -> Vec<f32> {
let amplitude = rms * 2f32.sqrt();
let freq = 440.0_f32;
(0..num_samples)
.map(|i| {
amplitude
* (2.0 * std::f32::consts::PI * freq * i as f32 / sample_rate as f32).sin()
})
.collect()
}
fn rms(samples: &[f32]) -> f32 {
let sum_sq: f32 = samples.iter().map(|s| s * s).sum();
(sum_sq / samples.len() as f32).sqrt()
}
fn default_config() -> AgcConfig {
AgcConfig {
enabled: true,
target_level_db: -18.0,
attack_time_ms: 10.0,
release_time_ms: 200.0,
min_gain_db: -6.0,
max_gain_db: 30.0,
gate_threshold_db: -50.0,
boost_threshold_db: -40.0,
gate_hold_time_ms: 50.0,
}
}
#[test]
fn agc_unity_gain_convergence() {
let sample_rate = 16000_u32;
let target_rms = db_to_linear(-18.0);
let cfg = default_config();
let mut proc = AgcProcessor::new(cfg);
let chunk_size = 160; let num_chunks = 50;
for _ in 0..num_chunks {
let mut chunk = sine_wave(target_rms, chunk_size, sample_rate);
proc.process(&mut chunk, sample_rate);
}
let gain_db = proc.current_gain_db();
assert!(
gain_db.abs() < 1.5,
"Expected gain near 0 dB after convergence, got {:.2} dB",
gain_db
);
}
#[test]
fn agc_increases_gain_for_quiet_input() {
let sample_rate = 16000_u32;
let quiet_rms = db_to_linear(-40.0);
let cfg = default_config();
let mut proc = AgcProcessor::new(cfg);
let chunk_size = 160;
for _ in 0..100 {
let mut chunk = sine_wave(quiet_rms, chunk_size, sample_rate);
proc.process(&mut chunk, sample_rate);
}
let gain_db = proc.current_gain_db();
assert!(
gain_db > 10.0,
"Expected AGC to boost quiet input (>10 dB), got {:.2} dB",
gain_db
);
}
#[test]
fn agc_decreases_gain_for_loud_input() {
let sample_rate = 16000_u32;
let loud_rms = db_to_linear(-1.0);
let cfg = default_config();
let mut proc = AgcProcessor::new(cfg);
let chunk_size = 160;
for _ in 0..50 {
let mut chunk = sine_wave(loud_rms, chunk_size, sample_rate);
proc.process(&mut chunk, sample_rate);
}
let gain_db = proc.current_gain_db();
assert!(
gain_db < -1.0,
"Expected AGC to attenuate loud input (<-1 dB), got {:.2} dB",
gain_db
);
}
#[test]
fn agc_clamps_to_max_gain() {
let sample_rate = 16000_u32;
let cfg = AgcConfig {
max_gain_db: 10.0,
..default_config()
};
let mut proc = AgcProcessor::new(cfg);
let chunk_size = 160;
let near_silence_rms = db_to_linear(-60.0);
for _ in 0..200 {
let mut chunk = sine_wave(near_silence_rms, chunk_size, sample_rate);
proc.process(&mut chunk, sample_rate);
}
let gain_db = proc.current_gain_db();
assert!(
gain_db <= 10.0 + 1e-3,
"Gain {:.2} dB exceeds max_gain_db=10.0",
gain_db
);
}
#[test]
fn agc_silence_does_not_explode() {
    // Zero-energy input must not push the gain to infinity/NaN. After ~5 s of
    // digital silence (500 frames x 160 samples at 16 kHz), the gain must be
    // finite and within the configured ceiling.
    const SAMPLE_RATE: u32 = 16_000;
    const FRAME_LEN: usize = 160;
    let config = default_config();
    let ceiling_db = config.max_gain_db;
    let mut agc = AgcProcessor::new(config);

    for _ in 0..500 {
        let mut silent = vec![0.0f32; FRAME_LEN];
        agc.process(&mut silent, SAMPLE_RATE);
    }

    let settled_db = agc.current_gain_db();
    assert!(
        settled_db.is_finite(),
        "gain_db should be finite on silence, got {}",
        settled_db
    );
    assert!(
        settled_db <= ceiling_db + 1e-3,
        "gain {:.2} dB exceeds max_gain_db={} on silence",
        settled_db,
        ceiling_db
    );
}
#[test]
fn agc_output_level_near_target() {
    let sample_rate = 16000_u32;
    // Expected output level in dBFS.
    // NOTE(review): presumably mirrors the target level in `default_config()`;
    // confirm and keep in sync if that default ever changes.
    let target_db = -18.0_f32;
    let input_rms = db_to_linear(-35.0);
    let cfg = default_config();
    let mut proc = AgcProcessor::new(cfg);
    let chunk_size = 160;
    // Warm-up: ~2 s (200 frames x 160 samples at 16 kHz) of steady input so the
    // gain loop has time to converge before anything is measured.
    for _ in 0..200 {
        let mut chunk = sine_wave(input_rms, chunk_size, sample_rate);
        proc.process(&mut chunk, sample_rate);
    }
    // Measurement: process one longer block (10 frames' worth) after warm-up and
    // compare its RMS level, in dBFS, against the target.
    let mut output: Vec<f32> = sine_wave(input_rms, chunk_size * 10, sample_rate);
    proc.process(&mut output, sample_rate);
    let out_rms_db = 20.0 * rms(&output).log10();
    assert!(
        (out_rms_db - target_db).abs() < 3.0,
        "Output RMS {:.1} dBFS not within 3 dB of target {:.1} dBFS",
        out_rms_db,
        target_db
    );
}
#[test]
fn agc_gate_decays_gain_on_noise() {
    // Phase 1: ~1 s of -30 dBFS "speech" lets the AGC build up gain.
    // Phase 2: ~3 s of -55 dBFS "noise". The gain must fall by more than 5 dB
    // from its post-speech value and end up within 5 dB of unity.
    let rate = 16000_u32;
    let frame_len = 160;
    let mut agc = AgcProcessor::new(default_config());

    let speech_level = db_to_linear(-30.0);
    for _step in 0..100 {
        let mut frame = sine_wave(speech_level, frame_len, rate);
        agc.process(&mut frame, rate);
    }
    let speech_gain = agc.current_gain_db();

    let noise_level = db_to_linear(-55.0);
    for _step in 0..300 {
        let mut frame = sine_wave(noise_level, frame_len, rate);
        agc.process(&mut frame, rate);
    }
    let noise_gain = agc.current_gain_db();

    assert!(
        noise_gain < speech_gain - 5.0,
        "Expected gain to decay significantly, got {:.2} dB (was {:.2} dB after speech)",
        noise_gain,
        speech_gain
    );
    assert!(
        noise_gain.abs() < 5.0,
        "Expected gain near 0 dB after gate decay, got {:.2} dB",
        noise_gain
    );
}
#[test]
fn agc_gate_normal_processing_above_threshold() {
    // -35 dBFS is quiet but is expected to sit above the default boost threshold.
    // After ~1 s of it, the AGC should have raised the gain by more than 10 dB.
    const SAMPLE_RATE: u32 = 16_000;
    let frame_len = 160;
    let mut agc = AgcProcessor::new(default_config());
    let quiet_speech = db_to_linear(-35.0);

    std::iter::repeat_with(|| sine_wave(quiet_speech, frame_len, SAMPLE_RATE))
        .take(100)
        .for_each(|mut frame| {
            agc.process(&mut frame, SAMPLE_RATE);
        });

    let applied_db = agc.current_gain_db();
    assert!(
        applied_db > 10.0,
        "Expected AGC to boost quiet speech above boost threshold (>10 dB), got {:.2} dB",
        applied_db
    );
}
#[test]
fn agc_does_not_boost_borderline_noise() {
    // A borderline -45 dBFS tone fed for ~2 s (200 frames x 160 samples at
    // 16 kHz) must not be amplified much: the gain has to stay below +6 dB.
    let rate = 16000_u32;
    let frame_len = 160;
    let mut agc = AgcProcessor::new(default_config());
    let noise_floor = db_to_linear(-45.0);

    (0..200).for_each(|_| {
        let mut frame = sine_wave(noise_floor, frame_len, rate);
        agc.process(&mut frame, rate);
    });

    let final_db = agc.current_gain_db();
    assert!(
        final_db < 6.0,
        "Expected borderline noise to stay near unity, got {:.2} dB",
        final_db
    );
}
#[test]
fn agc_gate_smooth_resumption() {
    // 1. ~1 s of -30 dBFS speech builds up gain.
    // 2. ~2 s of -55 dBFS noise lets the gate pull the gain back.
    // 3. Speech resumes for 50 frames. Between any two consecutive frames the
    //    gain must change by less than 6 dB.
    let rate = 16000_u32;
    let frame_len = 160;
    let mut agc = AgcProcessor::new(default_config());

    let speech_level = db_to_linear(-30.0);
    for _ in 0..100 {
        let mut frame = sine_wave(speech_level, frame_len, rate);
        agc.process(&mut frame, rate);
    }

    let noise_level = db_to_linear(-55.0);
    for _ in 0..200 {
        let mut frame = sine_wave(noise_level, frame_len, rate);
        agc.process(&mut frame, rate);
    }

    // Thread the previous frame's gain through a fold so each step is
    // compared against its immediate predecessor.
    let resume_from = agc.current_gain_db();
    (0..50).fold(resume_from, |previous, _| {
        let mut frame = sine_wave(speech_level, frame_len, rate);
        agc.process(&mut frame, rate);
        let now = agc.current_gain_db();
        let step = (now - previous).abs();
        assert!(
            step < 6.0,
            "Gain jumped {:.2} dB between chunks (from {:.2} to {:.2}), expected smooth transition",
            step,
            previous,
            now
        );
        now
    });
}
#[test]
fn agc_gate_very_low_threshold_preserves_behavior() {
    // With both the gate and boost thresholds pushed down to -100 dB, a
    // -55 dBFS tone is no longer held back. After ~2 s the AGC must have
    // applied more than 20 dB of gain.
    let rate = 16000_u32;
    let frame_len = 160;
    let ungated = AgcConfig {
        gate_threshold_db: -100.0,
        boost_threshold_db: -100.0,
        ..default_config()
    };
    let mut agc = AgcProcessor::new(ungated);
    let noise_level = db_to_linear(-55.0);

    std::iter::repeat_with(|| sine_wave(noise_level, frame_len, rate))
        .take(200)
        .for_each(|mut frame| {
            agc.process(&mut frame, rate);
        });

    let boosted_db = agc.current_gain_db();
    assert!(
        boosted_db > 20.0,
        "Expected AGC to boost noise when gate is disabled (>20 dB), got {:.2} dB",
        boosted_db
    );
}
#[test]
fn agc_gate_hold_time_prevents_noise_burst() {
    // Timeline at 16 kHz with 160-sample (10 ms) frames and a 100 ms hold time:
    //   1. 50 frames of -55 dBFS noise   -> gain within 2 dB of unity.
    //   2. 5 frames (50 ms) of speech    -> still inside the hold time, so the
    //                                       gain must stay within 3 dB of unity.
    //   3. 10 more frames (150 ms total) -> hold time has elapsed, so the gain
    //                                       must exceed +5 dB.
    let rate = 16000_u32;
    let frame_len = 160;
    let mut agc = AgcProcessor::new(AgcConfig {
        gate_hold_time_ms: 100.0,
        ..default_config()
    });

    let noise_level = db_to_linear(-55.0);
    for _ in 0..50 {
        let mut frame = sine_wave(noise_level, frame_len, rate);
        agc.process(&mut frame, rate);
    }
    let settled_db = agc.current_gain_db();
    assert!(
        settled_db.abs() < 2.0,
        "Expected gain near unity after noise decay, got {:.2} dB",
        settled_db
    );

    let speech_level = db_to_linear(-30.0);
    for _ in 0..5 {
        let mut frame = sine_wave(speech_level, frame_len, rate);
        agc.process(&mut frame, rate);
    }
    let held_db = agc.current_gain_db();
    assert!(
        held_db.abs() < 3.0,
        "Expected gain near unity during hold time, got {:.2} dB",
        held_db
    );

    for _ in 0..10 {
        let mut frame = sine_wave(speech_level, frame_len, rate);
        agc.process(&mut frame, rate);
    }
    let released_db = agc.current_gain_db();
    assert!(
        released_db > 5.0,
        "Expected gain to increase after hold time expires, got {:.2} dB",
        released_db
    );
}
#[test]
fn agc_zero_hold_time_legacy_behavior() {
    // With `gate_hold_time_ms` set to zero there is no hold delay:
    //   1. 50 frames of -55 dBFS noise -> gain within 2 dB of unity.
    //   2. Only 5 frames of -30 dBFS speech -> the gain must already exceed +5 dB.
    let rate = 16000_u32;
    let frame_len = 160;
    let mut agc = AgcProcessor::new(AgcConfig {
        gate_hold_time_ms: 0.0,
        ..default_config()
    });

    let noise_level = db_to_linear(-55.0);
    (0..50).for_each(|_| {
        let mut frame = sine_wave(noise_level, frame_len, rate);
        agc.process(&mut frame, rate);
    });
    let settled_db = agc.current_gain_db();
    assert!(
        settled_db.abs() < 2.0,
        "Expected gain near unity after noise decay, got {:.2} dB",
        settled_db
    );

    let speech_level = db_to_linear(-30.0);
    (0..5).for_each(|_| {
        let mut frame = sine_wave(speech_level, frame_len, rate);
        agc.process(&mut frame, rate);
    });
    let reacted_db = agc.current_gain_db();
    assert!(
        reacted_db > 5.0,
        "Expected immediate gain increase with 0 hold time, got {:.2} dB",
        reacted_db
    );
}
}