1use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::mpsc;
10use std::sync::{Arc, Mutex};
11
12use crate::audio::capture::TARGET_RATE;
13use crate::transcription::transcriber::Transcriber;
14use crate::transcription::vad;
15
16const WINDOW_SECS: f32 = 3.0;
18
19const WINDOW_SAMPLES: usize = (TARGET_RATE as f32 * WINDOW_SECS) as usize;
21
22const POLL_INTERVAL_MS: u64 = 300;
24
25const MAX_POLL_INTERVAL_MS: u64 = 1200;
27
28const COOLDOWN_MS: u64 = 2000;
30
/// Event sent by the detector thread when one of its configured phrases is heard.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WakeWordEvent {
    /// The wake phrase was recognised in the most recent audio window.
    WakeWordDetected,
    /// The stop phrase was recognised in the most recent audio window.
    StopPhraseDetected,
}
39
40pub struct WakeWordHandle {
42 stop_tx: mpsc::Sender<()>,
43 join_handle: Option<std::thread::JoinHandle<()>>,
44 paused: Arc<AtomicBool>,
45}
46
47impl WakeWordHandle {
48 pub fn pause(&self) {
50 self.paused.store(true, Ordering::Relaxed);
51 log::debug!("Wake word detection paused");
52 }
53
54 pub fn resume(&self) {
56 self.paused.store(false, Ordering::Relaxed);
57 log::debug!("Wake word detection resumed");
58 }
59
60 pub fn stop(mut self) {
62 let _ = self.stop_tx.send(());
63 if let Some(handle) = self.join_handle.take() {
64 let _ = handle.join();
65 }
66 }
67}
68
69impl Drop for WakeWordHandle {
70 fn drop(&mut self) {
71 let _ = self.stop_tx.send(());
72 if let Some(handle) = self.join_handle.take() {
73 let _ = handle.join();
74 }
75 }
76}
77
78pub fn start_detector(
83 wake_phrase: String,
84 stop_phrase: String,
85 tx: mpsc::Sender<WakeWordEvent>,
86) -> anyhow::Result<WakeWordHandle> {
87 let (stop_tx, stop_rx) = mpsc::channel::<()>();
88 let paused = Arc::new(AtomicBool::new(false));
89 let paused_clone = paused.clone();
90
91 let join_handle = std::thread::spawn(move || {
92 if let Err(e) = detector_thread(wake_phrase, stop_phrase, tx, stop_rx, paused_clone) {
93 log::error!("Wake word detector failed: {e}");
94 }
95 });
96
97 Ok(WakeWordHandle {
98 stop_tx,
99 join_handle: Some(join_handle),
100 paused,
101 })
102}
103
104fn detector_thread(
105 wake_phrase: String,
106 stop_phrase: String,
107 tx: mpsc::Sender<WakeWordEvent>,
108 stop_rx: mpsc::Receiver<()>,
109 paused: Arc<AtomicBool>,
110) -> anyhow::Result<()> {
111 let model_size = "tiny.en";
113 if !crate::transcription::transcriber::model_exists(model_size) {
114 log::info!("Downloading {model_size} model for wake word detection...");
115 crate::transcription::model::download(model_size, |_| {})?;
116 }
117
118 let model_path = crate::transcription::transcriber::find_model(model_size)
119 .ok_or_else(|| anyhow::anyhow!("Wake word model '{model_size}' not found"))?;
120
121 let transcriber = Transcriber::new(&model_path, "en")?;
122 log::info!("Wake word detector ready (phrase: \"{wake_phrase}\")");
123
124 let ring_buffer: Arc<Mutex<Vec<f32>>> =
126 Arc::new(Mutex::new(Vec::with_capacity(WINDOW_SAMPLES * 2)));
127
128 let ring_clone = ring_buffer.clone();
130 let _stream = open_capture_stream(ring_clone)?;
131
132 let wake_lower = wake_phrase.to_lowercase();
133 let stop_lower = stop_phrase.to_lowercase();
134 let mut last_detection = std::time::Instant::now()
135 .checked_sub(std::time::Duration::from_millis(COOLDOWN_MS * 2))
136 .unwrap_or_else(std::time::Instant::now);
137
138 let mut current_poll_ms = POLL_INTERVAL_MS;
142
143 loop {
144 match stop_rx.try_recv() {
146 Ok(()) | Err(mpsc::TryRecvError::Disconnected) => break,
147 Err(mpsc::TryRecvError::Empty) => {}
148 }
149
150 if paused.load(Ordering::Relaxed) {
152 std::thread::sleep(std::time::Duration::from_millis(current_poll_ms));
153 continue;
154 }
155
156 let samples: Vec<f32> = {
158 let buf = ring_buffer.lock().unwrap_or_else(|e| e.into_inner());
159 if buf.len() < WINDOW_SAMPLES {
160 drop(buf);
161 std::thread::sleep(std::time::Duration::from_millis(current_poll_ms));
162 continue;
163 }
164 let start = buf.len().saturating_sub(WINDOW_SAMPLES);
166 buf[start..].to_vec()
167 };
168
169 {
171 let mut buf = ring_buffer.lock().unwrap_or_else(|e| e.into_inner());
172 if buf.len() > WINDOW_SAMPLES * 3 {
173 let drain_to = buf.len() - WINDOW_SAMPLES * 2;
174 buf.drain(..drain_to);
175 }
176 }
177
178 if !vad::contains_speech(&samples) {
180 current_poll_ms = (current_poll_ms * 3 / 2).min(MAX_POLL_INTERVAL_MS);
182 std::thread::sleep(std::time::Duration::from_millis(current_poll_ms));
183 continue;
184 }
185
186 current_poll_ms = POLL_INTERVAL_MS;
188
189 if last_detection.elapsed() < std::time::Duration::from_millis(COOLDOWN_MS) {
191 std::thread::sleep(std::time::Duration::from_millis(current_poll_ms));
192 continue;
193 }
194
195 match transcriber.transcribe_samples(&samples, false) {
197 Ok(text) => {
198 let text_lower = text.to_lowercase();
199 log::debug!("Wake word heard: \"{text}\"");
200
201 if contains_phrase(&text_lower, &wake_lower) {
202 log::info!("Wake word detected!");
203 last_detection = std::time::Instant::now();
204 if tx.send(WakeWordEvent::WakeWordDetected).is_err() {
205 break;
206 }
207 } else if contains_phrase(&text_lower, &stop_lower) {
208 log::info!("Stop phrase detected!");
209 last_detection = std::time::Instant::now();
210 if tx.send(WakeWordEvent::StopPhraseDetected).is_err() {
211 break;
212 }
213 }
214 }
215 Err(e) => {
216 log::warn!("Wake word transcription failed: {e}");
217 }
218 }
219
220 std::thread::sleep(std::time::Duration::from_millis(current_poll_ms));
221 }
222
223 log::info!("Wake word detector stopped");
224 Ok(())
225}
226
227fn contains_phrase(text: &str, phrase: &str) -> bool {
232 if phrase.is_empty() {
233 return false;
234 }
235
236 let phrase_words: Vec<&str> = phrase.split_whitespace().collect();
237 let text_words: Vec<&str> = text.split_whitespace().collect();
238
239 if phrase_words.len() > text_words.len() {
240 return false;
241 }
242
243 text_words.windows(phrase_words.len()).any(|window| {
244 window.iter().zip(phrase_words.iter()).all(|(tw, pw)| {
245 let tw_clean = tw.trim_matches(|c: char| c.is_ascii_punctuation());
246 let pw_clean = pw.trim_matches(|c: char| c.is_ascii_punctuation());
247 words_match(tw_clean, pw_clean)
248 })
249 })
250}
251
252fn words_match(heard: &str, expected: &str) -> bool {
255 if heard == expected {
256 return true;
257 }
258 if is_known_alias(heard, expected) {
260 return true;
261 }
262 if expected.len() <= 8 {
264 return edit_distance(heard, expected) <= 2;
265 }
266 false
267}
268
/// Transcriptions accepted in place of the word "murmur" even though most of them are too
/// far from it for the edit-distance check. Compared ignoring ASCII case.
const MURMUR_ALIASES: &[&str] = &[
    "mama", "mamma", "mirror", "murmured", "memo", "memer", "merma", "mermer",
];

/// Returns `true` if `heard` is a listed mis-transcription of `expected`.
///
/// Only "murmur" (in any ASCII case) has aliases, listed in [`MURMUR_ALIASES`]. For every
/// other expected word the result is `false`.
fn is_known_alias(heard: &str, expected: &str) -> bool {
    let aliases: &[&str] = if expected.eq_ignore_ascii_case("murmur") {
        MURMUR_ALIASES
    } else {
        &[]
    };
    aliases
        .iter()
        .any(|candidate| candidate.eq_ignore_ascii_case(heard))
}
283
/// Computes the Levenshtein distance between `a` and `b`, counted in `char`s.
///
/// The result is the minimum number of single-character insertions, deletions and
/// substitutions that turn one string into the other. The function uses the two-row
/// dynamic-programming form: `O(len(a) * len(b))` time and `O(len(b))` extra memory, which is
/// trivial for the single words this module compares.
fn edit_distance(a: &str, b: &str) -> usize {
    let b: Vec<char> = b.chars().collect();
    // Before processing the (i+1)-th char of `a`, prev[j] is the distance between the first
    // i chars of `a` and the first j chars of `b`.
    let mut prev: Vec<usize> = (0..=b.len()).collect();
    let mut curr = vec![0usize; b.len() + 1];

    for (i, ca) in a.chars().enumerate() {
        curr[0] = i + 1;
        for (j, &cb) in b.iter().enumerate() {
            let substitution = prev[j] + usize::from(ca != cb);
            curr[j + 1] = substitution.min(prev[j + 1] + 1).min(curr[j] + 1);
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    prev[b.len()]
}
310
311fn open_capture_stream(buffer: Arc<Mutex<Vec<f32>>>) -> anyhow::Result<cpal::Stream> {
313 use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
314
315 let host = cpal::default_host();
316 let device = host
317 .default_input_device()
318 .ok_or_else(|| anyhow::anyhow!("No audio input device"))?;
319
320 let supported = device.default_input_config()?;
321 let sample_rate = supported.sample_rate();
322 let channels = supported.channels() as usize;
323
324 let config: cpal::StreamConfig = supported.into();
325
326 let stream = device.build_input_stream(
327 &config,
328 move |data: &[f32], _: &cpal::InputCallbackInfo| {
329 let mono: Vec<f32> = if channels == 1 {
331 data.to_vec()
332 } else {
333 data.chunks(channels)
334 .map(|frame| frame.iter().sum::<f32>() / channels as f32)
335 .collect()
336 };
337
338 let samples_16k = if sample_rate == TARGET_RATE {
340 mono
341 } else {
342 resample_simple(&mono, sample_rate, TARGET_RATE)
343 };
344
345 if let Ok(mut buf) = buffer.try_lock() {
346 buf.extend_from_slice(&samples_16k);
347 }
348 },
349 |err| {
350 log::error!("Wake word audio error: {err}");
351 },
352 None,
353 )?;
354
355 stream.play()?;
356 Ok(stream)
357}
358
/// Resamples mono `input` from `from_rate` Hz to `to_rate` Hz by linear interpolation.
///
/// Output sample `i` is interpolated at source position `i * from_rate / to_rate`; positions
/// at or past the last input sample repeat that sample. The output has
/// `input.len() * to_rate / from_rate` samples, rounded down.
///
/// No anti-aliasing filter is applied. Each call is independent, so no interpolation
/// happens across the boundary between consecutive chunks.
///
/// Equal rates or empty input return a copy of `input`. A zero rate (otherwise) returns an
/// empty vector.
fn resample_simple(input: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
    if from_rate == to_rate || input.is_empty() {
        return input.to_vec();
    }
    // Without this guard, `from_rate == 0` makes `output_len` infinite (saturating to
    // usize::MAX) and the allocation panics.
    if from_rate == 0 || to_rate == 0 {
        return Vec::new();
    }

    let ratio = f64::from(from_rate) / f64::from(to_rate);
    let output_len = (input.len() as f64 / ratio) as usize;

    (0..output_len)
        .map(|i| {
            let src_pos = i as f64 * ratio;
            let idx = src_pos as usize;
            let frac = (src_pos - idx as f64) as f32;
            match (input.get(idx), input.get(idx + 1)) {
                (Some(&a), Some(&b)) => a * (1.0 - frac) + b * frac,
                (Some(&a), None) => a,
                _ => 0.0,
            }
        })
        .collect()
}
385
386pub fn check_and_strip_stop_phrase(text: &str, stop_phrase: &str) -> Option<String> {
389 let text_lower = text.to_lowercase();
390 let stop_lower = stop_phrase.to_lowercase();
391
392 if !contains_phrase(&text_lower, &stop_lower) {
393 return None;
394 }
395
396 let phrase_words: Vec<&str> = stop_phrase.split_whitespace().collect();
398 let text_words: Vec<&str> = text.split_whitespace().collect();
399
400 let phrase_lower_words: Vec<&str> = stop_lower.split_whitespace().collect();
402 let text_lower_words: Vec<String> = text_words
403 .iter()
404 .map(|w| {
405 w.to_lowercase()
406 .trim_matches(|c: char| c.is_ascii_punctuation())
407 .to_string()
408 })
409 .collect();
410
411 for i in 0..=text_words.len().saturating_sub(phrase_words.len()) {
412 let matches = text_lower_words[i..i + phrase_lower_words.len()]
413 .iter()
414 .zip(phrase_lower_words.iter())
415 .all(|(tw, pw)| {
416 let pw_clean = pw.trim_matches(|c: char| c.is_ascii_punctuation());
417 words_match(tw, pw_clean)
418 });
419
420 if matches {
421 let mut result_words: Vec<&str> = Vec::new();
422 result_words.extend_from_slice(&text_words[..i]);
423 result_words.extend_from_slice(&text_words[i + phrase_words.len()..]);
424 let result = result_words.join(" ").trim().to_string();
425 return Some(result);
426 }
427 }
428
429 Some(text.to_string())
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    const START: &str = "murmur start dictation";
    const STOP: &str = "murmur stop dictation";

    #[test]
    fn test_contains_phrase_basic() {
        assert!(contains_phrase("hello murmur start dictation please", START));
        assert!(contains_phrase(START, START));
        assert!(!contains_phrase("hello world", START));
    }

    #[test]
    fn test_contains_phrase_punctuation() {
        for text in ["hello, murmur start dictation.", "\"murmur start dictation\""] {
            assert!(contains_phrase(text, START), "should match: {text}");
        }
    }

    #[test]
    fn test_contains_phrase_empty() {
        assert!(!contains_phrase("hello", ""));
        assert!(!contains_phrase("", START));
    }

    #[test]
    fn test_contains_phrase_partial() {
        for text in ["murmur", "start dictation"] {
            assert!(!contains_phrase(text, START), "partial phrase matched: {text}");
        }
    }

    #[test]
    fn test_contains_phrase_fuzzy_murmur() {
        // Common mis-hearings of "murmur" still count...
        for text in [
            "mama start dictation",
            "mirror start dictation",
            "murder start dictation",
            "murmer start dictation",
        ] {
            assert!(contains_phrase(text, START), "expected a match: {text}");
        }
        // ...but unrelated words do not.
        for text in ["banana start dictation", "tomorrow start dictation"] {
            assert!(!contains_phrase(text, START), "unexpected match: {text}");
        }
    }

    #[test]
    fn test_contains_phrase_fuzzy_stop() {
        for text in ["mama stop dictation", "mirror stop dictation"] {
            assert!(contains_phrase(text, STOP), "expected a match: {text}");
        }
    }

    #[test]
    fn test_edit_distance() {
        let cases = [
            ("murmur", 0),
            ("murder", 2),
            ("murmer", 1),
            ("mama", 4),
            ("mirror", 3),
        ];
        for (word, expected) in cases {
            assert_eq!(edit_distance(word, "murmur"), expected, "distance from {word}");
        }
        assert!(edit_distance("banana", "murmur") > 2);
    }

    #[test]
    fn test_words_match_exact() {
        assert!(words_match("start", "start"));
        assert!(words_match("murmur", "murmur"));
        assert!(!words_match("start", "stop"));
    }

    #[test]
    fn test_words_match_fuzzy() {
        for heard in ["mama", "mirror", "mamma", "murder", "murmer"] {
            assert!(words_match(heard, "murmur"), "{heard} should match murmur");
        }
        for heard in ["banana", "number"] {
            assert!(!words_match(heard, "murmur"), "{heard} should not match murmur");
        }
    }

    #[test]
    fn test_is_known_alias() {
        assert!(is_known_alias("mama", "murmur"));
        assert!(is_known_alias("mirror", "murmur"));
        assert!(!is_known_alias("mama", "start"));
        assert!(!is_known_alias("banana", "murmur"));
    }

    #[test]
    fn test_check_and_strip_stop_phrase() {
        assert_eq!(
            check_and_strip_stop_phrase("hello world murmur stop dictation thanks", STOP),
            Some("hello world thanks".to_string())
        );
    }

    #[test]
    fn test_check_and_strip_stop_phrase_at_end() {
        assert_eq!(
            check_and_strip_stop_phrase("hello world murmur stop dictation", STOP),
            Some("hello world".to_string())
        );
    }

    #[test]
    fn test_check_and_strip_stop_phrase_at_start() {
        assert_eq!(
            check_and_strip_stop_phrase("murmur stop dictation hello world", STOP),
            Some("hello world".to_string())
        );
    }

    #[test]
    fn test_check_and_strip_stop_phrase_not_found() {
        assert_eq!(check_and_strip_stop_phrase("hello world", STOP), None);
    }

    #[test]
    fn test_check_and_strip_stop_phrase_fuzzy() {
        assert_eq!(
            check_and_strip_stop_phrase("hello mama stop dictation thanks", STOP),
            Some("hello thanks".to_string())
        );
    }

    #[test]
    fn test_resample_simple_same_rate() {
        let input = vec![1.0, 2.0, 3.0];
        assert_eq!(resample_simple(&input, 16000, 16000), input);
    }

    #[test]
    fn test_resample_simple_downsample() {
        // One second of a slow sine at 48 kHz should become roughly one second at 16 kHz.
        let input: Vec<f32> = (0..48000).map(|i| (i as f32 / 48000.0).sin()).collect();
        let output = resample_simple(&input, 48000, 16000);
        assert!((output.len() as f32 - 16000.0).abs() < 2.0);
    }

    #[test]
    fn test_resample_simple_empty() {
        assert!(resample_simple(&[], 48000, 16000).is_empty());
    }
}