use std::collections::VecDeque;
use std::time::Instant;
use super::dtw_algorithm::dtw_distance;
use super::mfcc::MfccExtractor;
use super::{WakeWordDetection, WakeWordDetector};
use crate::audio::error::{AudioError, AudioResult};
/// Sample rate handed to the MFCC extractor by `DtwWakeWordDetector::new`
/// (presumably in Hz).
const DEFAULT_SAMPLE_RATE: u32 = 16_000;

/// Analysis frame length handed to the MFCC extractor (presumably in samples).
const DEFAULT_FRAME_LEN: usize = 400;

/// Hop between successive analysis frames handed to the MFCC extractor
/// (presumably in samples).
const DEFAULT_HOP: usize = 160;

/// Number of cepstral coefficients requested from the MFCC extractor.
const DEFAULT_NUM_COEFFS: usize = 12;

/// Detection threshold used until `with_threshold` overrides it.
const DEFAULT_THRESHOLD: f32 = 30.0;

/// Shortest recording, in samples, accepted by `enroll_template`.
const MIN_ENROLL_SAMPLES: usize = 3_200;

/// Longest recording, in samples, accepted by `enroll_template`.
const MAX_ENROLL_SAMPLES: usize = 32_000;
/// Wake-word detector that compares MFCC features of recent audio against
/// enrolled templates using DTW distance.
///
/// Templates are added with `enroll_template`; live audio is fed through the
/// `WakeWordDetector::process_frame` implementation.
pub struct DtwWakeWordDetector {
/// Distance cut-off: a best DTW distance at or above this value yields no
/// detection.
threshold: f32,
/// Feature extractor shared by enrollment and live processing.
extractor: MfccExtractor,
/// One MFCC frame sequence per enrolled recording.
templates: Vec<Vec<Vec<f32>>>,
/// Frame count of the longest enrolled template; used to size the rolling
/// sample window.
longest_template_frames: usize,
/// Most recent live samples, oldest at the front.
rolling: VecDeque<i16>,
/// Instant captured at construction; detection timestamps are measured
/// from it.
start: Instant,
}
impl Default for DtwWakeWordDetector {
fn default() -> Self {
Self::new()
}
}
impl DtwWakeWordDetector {
    /// Creates a detector with the default threshold and MFCC settings, no
    /// enrolled templates, and an empty rolling window.
    ///
    /// The internal clock used for detection timestamps starts now.
    pub fn new() -> Self {
        Self {
            threshold: DEFAULT_THRESHOLD,
            extractor: MfccExtractor::new(
                DEFAULT_SAMPLE_RATE,
                DEFAULT_FRAME_LEN,
                DEFAULT_HOP,
                DEFAULT_NUM_COEFFS,
            ),
            templates: Vec::new(),
            longest_template_frames: 0,
            rolling: VecDeque::new(),
            start: Instant::now(),
        }
    }

    /// Builder-style setter: returns the detector with its distance
    /// threshold replaced by `threshold`.
    #[must_use]
    pub fn with_threshold(mut self, threshold: f32) -> Self {
        self.threshold = threshold;
        self
    }

    /// Enrolls one recording as a wake-word template.
    ///
    /// # Errors
    ///
    /// Returns [`AudioError::Format`] when `samples` holds fewer than
    /// `MIN_ENROLL_SAMPLES` or more than `MAX_ENROLL_SAMPLES` samples, or
    /// when the extractor produces no MFCC frames for it.
    pub fn enroll_template(&mut self, samples: &[i16]) -> AudioResult<()> {
        let sample_count = samples.len();
        if sample_count < MIN_ENROLL_SAMPLES {
            return Err(AudioError::Format(format!(
                "wake-word enrollment too short: {} samples (need ≥{})",
                sample_count, MIN_ENROLL_SAMPLES
            )));
        }
        if sample_count > MAX_ENROLL_SAMPLES {
            return Err(AudioError::Format(format!(
                "wake-word enrollment too long: {} samples (max {})",
                sample_count, MAX_ENROLL_SAMPLES
            )));
        }

        let mfcc = self.extractor.extract(samples);
        if mfcc.is_empty() {
            return Err(AudioError::Format(
                "wake-word enrollment yielded zero MFCC frames".into(),
            ));
        }

        // The rolling window is sized from the longest template seen so far.
        self.longest_template_frames = self.longest_template_frames.max(mfcc.len());
        self.templates.push(mfcc);
        Ok(())
    }

    /// Discards all buffered live samples; enrolled templates are kept.
    pub fn reset_window(&mut self) {
        self.rolling.clear();
    }

    /// Number of templates enrolled so far.
    pub fn template_count(&self) -> usize {
        self.templates.len()
    }

    /// Current distance threshold.
    pub fn threshold(&self) -> f32 {
        self.threshold
    }

    /// Appends `samples` to the rolling window and returns the smallest
    /// finite DTW distance between the window's trailing MFCC frames and any
    /// enrolled template.
    ///
    /// Returns `None` when no template is enrolled (the samples are then not
    /// buffered), when the window is still shorter than one analysis frame,
    /// when extraction yields no frames, or when no template produces a
    /// finite distance.
    fn dtw_best(&mut self, samples: &[i16]) -> Option<f32> {
        if self.templates.is_empty() {
            return None;
        }

        self.rolling.extend(samples.iter().copied());

        // Keep roughly 1.5x the longest template's worth of hops plus one
        // frame. Integer `* 3 / 2` floors like the float form would, without
        // a lossy usize -> f32 -> usize round trip.
        let frame_len = self.extractor.frame_len();
        let max_samples = self.longest_template_frames * 3 / 2 * self.extractor.hop() + frame_len;

        // Drop the oldest samples in one pass instead of popping one by one.
        let excess = self.rolling.len().saturating_sub(max_samples);
        self.rolling.drain(..excess);

        if self.rolling.len() < frame_len {
            return None;
        }

        // Rearrange the deque in place so it can be viewed as a single slice;
        // this avoids allocating a fresh copy of the window for every frame.
        let buffered: &[i16] = self.rolling.make_contiguous();
        let live = self.extractor.extract(buffered);
        if live.is_empty() {
            return None;
        }

        self.templates
            .iter()
            .filter(|template| !template.is_empty())
            .map(|template| {
                // Compare against the most recent `template.len()` frames, or
                // against everything buffered when less than that is available.
                let first = live.len().saturating_sub(template.len());
                let window: &[Vec<f32>] = &live[first..];
                dtw_distance(window, template)
            })
            .filter(|distance| distance.is_finite())
            .min_by(f32::total_cmp)
    }
}
impl WakeWordDetector for DtwWakeWordDetector {
    /// Sample rate reported by the MFCC extractor.
    fn sample_rate(&self) -> u32 {
        self.extractor.sample_rate()
    }

    /// Analysis frame length reported by the MFCC extractor.
    fn frame_size(&self) -> usize {
        self.extractor.frame_len()
    }

    /// Feeds `samples` into the detector and reports a detection when the
    /// best DTW distance is strictly below the threshold.
    ///
    /// The score is `1 - distance / threshold`, clamped to `[0, 1]`, and the
    /// timestamp is the number of milliseconds since the detector was built.
    fn process_frame(&mut self, samples: &[i16]) -> Option<WakeWordDetection> {
        let distance = self.dtw_best(samples)?;

        // Phrased as a positive `<` test on purpose: every comparison with a
        // NaN threshold is false, so a NaN threshold never fires instead of
        // firing on every frame with a NaN score.
        if distance < self.threshold {
            let timestamp_ms = self.start.elapsed().as_millis() as u64;
            Some(WakeWordDetection {
                keyword: "wake".to_string(),
                score: (1.0 - distance / self.threshold).clamp(0.0, 1.0),
                timestamp_ms,
            })
        } else {
            None
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::PI;

    /// Builds `samples` samples of a sine tone at `freq_hz`, timed against a
    /// 16 kHz clock, with a peak amplitude of 10_000.
    fn sine_16k(samples: usize, freq_hz: f32) -> Vec<i16> {
        let mut tone = Vec::with_capacity(samples);
        for n in 0..samples {
            let t = n as f32 / 16_000.0;
            tone.push(((2.0 * PI * freq_hz * t).sin() * 10_000.0) as i16);
        }
        tone
    }

    /// Deterministic pseudo-random samples in `0..16_000`, produced by a
    /// 64-bit linear congruential generator seeded with `seed`.
    fn white_noise_16k(samples: usize, seed: u32) -> Vec<i16> {
        let mut lcg = u64::from(seed);
        let mut noise = Vec::with_capacity(samples);
        for _ in 0..samples {
            lcg = lcg
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1);
            let top_bits = (lcg >> 33) as i32;
            noise.push((top_bits % 16_000) as i16);
        }
        noise
    }

    /// Every successful enrollment adds exactly one template.
    #[test]
    fn detector_template_count_matches_enrollments() {
        let mut detector = DtwWakeWordDetector::new();
        for freq_hz in [400.0, 500.0, 600.0] {
            detector.enroll_template(&sine_16k(8_000, freq_hz)).unwrap();
        }
        assert_eq!(detector.template_count(), 3);
    }

    /// Recordings below the minimum length are rejected with a "too short" error.
    #[test]
    fn detector_rejects_too_short_recording() {
        let mut detector = DtwWakeWordDetector::new();
        let err = detector
            .enroll_template(&sine_16k(1_600, 440.0))
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.to_lowercase().contains("too short"),
            "expected 'too short' in error message, got: {msg}"
        );
    }

    /// Recordings above the maximum length are rejected with a "too long" error.
    #[test]
    fn detector_rejects_too_long_recording() {
        let mut detector = DtwWakeWordDetector::new();
        let err = detector
            .enroll_template(&sine_16k(48_000, 440.0))
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.to_lowercase().contains("too long"),
            "expected 'too long' in error message, got: {msg}"
        );
    }

    /// Replaying the enrolled audio must trigger at least one detection.
    #[test]
    fn detector_fires_on_identical_audio() {
        let mut detector = DtwWakeWordDetector::new().with_threshold(100.0);
        let template_audio = sine_16k(8_000, 440.0);
        detector.enroll_template(&template_audio).unwrap();

        let fired = template_audio
            .chunks(160)
            .any(|chunk| detector.process_frame(chunk).is_some());
        assert!(fired, "detector should fire when fed the enrolled audio");
    }

    /// Noise unrelated to the template must never trigger a detection.
    #[test]
    fn detector_does_not_fire_on_unrelated_noise() {
        let mut detector = DtwWakeWordDetector::new().with_threshold(20.0);
        detector.enroll_template(&sine_16k(8_000, 440.0)).unwrap();

        let noise = white_noise_16k(32_000, 0xCAFEBABE);
        let first_hit = noise
            .chunks(160)
            .find_map(|chunk| detector.process_frame(chunk));
        if let Some(det) = first_hit {
            panic!(
                "unexpected wake-word fire on white noise; score = {}, keyword = {}",
                det.score, det.keyword
            );
        }
    }

    /// Without any template there is nothing to match against.
    #[test]
    fn detector_with_no_templates_never_fires() {
        let mut detector = DtwWakeWordDetector::new();
        let audio = sine_16k(16_000, 440.0);
        audio.chunks(160).for_each(|chunk| {
            assert!(
                detector.process_frame(chunk).is_none(),
                "no-template detector must never fire"
            );
        });
    }

    /// After `reset_window`, less than one frame of audio cannot fire.
    #[test]
    fn detector_reset_window_clears_state() {
        let mut detector = DtwWakeWordDetector::new().with_threshold(100.0);
        let template_audio = sine_16k(8_000, 440.0);
        detector.enroll_template(&template_audio).unwrap();

        // Fill the rolling window first so the reset has something to clear.
        for chunk in template_audio.chunks(160) {
            let _ = detector.process_frame(chunk);
        }
        detector.reset_window();

        let tiny = vec![0i16; 80];
        for i in 0..3 {
            assert!(
                detector.process_frame(&tiny).is_none(),
                "after reset_window, sub-frame input must not fire (iteration {i})"
            );
        }
    }

    /// A zero threshold rejects every distance, even for the enrolled audio.
    #[test]
    fn detector_with_threshold_low_does_not_fire_on_identical_audio() {
        let mut detector = DtwWakeWordDetector::new().with_threshold(0.0);
        let template_audio = sine_16k(8_000, 440.0);
        detector.enroll_template(&template_audio).unwrap();

        template_audio.chunks(160).for_each(|chunk| {
            assert!(
                detector.process_frame(chunk).is_none(),
                "threshold=0.0 must never fire (even identical audio has non-zero DTW)"
            );
        });
    }

    /// A huge threshold accepts unrelated audio.
    #[test]
    fn detector_with_threshold_high_fires_on_unrelated_audio() {
        let mut detector = DtwWakeWordDetector::new().with_threshold(1.0e9);
        detector.enroll_template(&sine_16k(8_000, 440.0)).unwrap();

        let noise = white_noise_16k(16_000, 0xDEADBEEF);
        let fired = noise
            .chunks(160)
            .any(|chunk| detector.process_frame(chunk).is_some());
        assert!(
            fired,
            "threshold=1e9 must accept anything once the rolling buffer fills"
        );
    }

    /// `threshold()` reports the value installed by `with_threshold`.
    #[test]
    fn detector_threshold_getter_reflects_with_threshold() {
        let detector = DtwWakeWordDetector::new().with_threshold(42.5);
        assert!(
            (detector.threshold() - 42.5).abs() < 1e-6,
            "with_threshold must update the field exposed by threshold()"
        );
    }

    /// Enrolling the same audio twice must not lower the best score.
    #[test]
    fn detector_second_enrollment_of_same_audio_does_not_regress() {
        let template_audio = sine_16k(8_000, 440.0);

        let mut single = DtwWakeWordDetector::new().with_threshold(200.0);
        single.enroll_template(&template_audio).unwrap();
        let best_score_1 = template_audio
            .chunks(160)
            .filter_map(|chunk| single.process_frame(chunk))
            .fold(0.0_f32, |best, det| best.max(det.score));

        let mut double = DtwWakeWordDetector::new().with_threshold(200.0);
        double.enroll_template(&template_audio).unwrap();
        double.enroll_template(&template_audio).unwrap();
        let best_score_2 = template_audio
            .chunks(160)
            .filter_map(|chunk| double.process_frame(chunk))
            .fold(0.0_f32, |best, det| best.max(det.score));

        assert!(
            best_score_2 >= best_score_1 - 1e-4,
            "two identical templates regressed score: 1-tpl={best_score_1}, 2-tpl={best_score_2}"
        );
    }
}