Skip to main content

murmur_core/input/
wake_word.rs

1//! Wake word detection using a dedicated Whisper tiny model.
2//!
3//! Continuously captures audio via a dedicated cpal stream, runs VAD to
4//! detect speech, and transcribes short windows with Whisper tiny to check
5//! for the configured wake/stop phrase. This keeps CPU usage low: the
6//! neural network only runs when the VAD detects speech.
7
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::mpsc;
10use std::sync::{Arc, Mutex};
11
12use crate::audio::capture::TARGET_RATE;
13use crate::transcription::transcriber::Transcriber;
14use crate::transcription::vad;
15
16/// Duration of the detection window in seconds.
17const WINDOW_SECS: f32 = 3.0;
18
19/// Samples in one detection window.
20const WINDOW_SAMPLES: usize = (TARGET_RATE as f32 * WINDOW_SECS) as usize;
21
22/// How often to check for speech (milliseconds) — base interval.
23const POLL_INTERVAL_MS: u64 = 300;
24
25/// Maximum poll interval after sustained silence (exponential backoff).
26const MAX_POLL_INTERVAL_MS: u64 = 1200;
27
28/// Minimum silence gap between detections to avoid re-triggering.
29const COOLDOWN_MS: u64 = 2000;
30
/// Events emitted by the wake word detector.
///
/// The enum carries no payload, so it is `Copy` and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WakeWordEvent {
    /// The wake phrase was detected — start dictation.
    WakeWordDetected,
    /// The stop phrase was detected — stop dictation.
    StopPhraseDetected,
}
39
40/// Handle to control the wake word detector thread.
41pub struct WakeWordHandle {
42    stop_tx: mpsc::Sender<()>,
43    join_handle: Option<std::thread::JoinHandle<()>>,
44    paused: Arc<AtomicBool>,
45}
46
47impl WakeWordHandle {
48    /// Pause detection (e.g., while dictation is active).
49    pub fn pause(&self) {
50        self.paused.store(true, Ordering::Relaxed);
51        log::debug!("Wake word detection paused");
52    }
53
54    /// Resume detection.
55    pub fn resume(&self) {
56        self.paused.store(false, Ordering::Relaxed);
57        log::debug!("Wake word detection resumed");
58    }
59
60    /// Stop and join the detector thread.
61    pub fn stop(mut self) {
62        let _ = self.stop_tx.send(());
63        if let Some(handle) = self.join_handle.take() {
64            let _ = handle.join();
65        }
66    }
67}
68
69impl Drop for WakeWordHandle {
70    fn drop(&mut self) {
71        let _ = self.stop_tx.send(());
72        if let Some(handle) = self.join_handle.take() {
73            let _ = handle.join();
74        }
75    }
76}
77
78/// Start the wake word detector.
79///
80/// Loads Whisper tiny (downloading if needed), opens a dedicated audio
81/// stream, and monitors for the wake/stop phrases. Sends events via `tx`.
82pub fn start_detector(
83    wake_phrase: String,
84    stop_phrase: String,
85    tx: mpsc::Sender<WakeWordEvent>,
86) -> anyhow::Result<WakeWordHandle> {
87    let (stop_tx, stop_rx) = mpsc::channel::<()>();
88    let paused = Arc::new(AtomicBool::new(false));
89    let paused_clone = paused.clone();
90
91    let join_handle = std::thread::spawn(move || {
92        if let Err(e) = detector_thread(wake_phrase, stop_phrase, tx, stop_rx, paused_clone) {
93            log::error!("Wake word detector failed: {e}");
94        }
95    });
96
97    Ok(WakeWordHandle {
98        stop_tx,
99        join_handle: Some(join_handle),
100        paused,
101    })
102}
103
104fn detector_thread(
105    wake_phrase: String,
106    stop_phrase: String,
107    tx: mpsc::Sender<WakeWordEvent>,
108    stop_rx: mpsc::Receiver<()>,
109    paused: Arc<AtomicBool>,
110) -> anyhow::Result<()> {
111    // Ensure the tiny model is available
112    let model_size = "tiny.en";
113    if !crate::transcription::transcriber::model_exists(model_size) {
114        log::info!("Downloading {model_size} model for wake word detection...");
115        crate::transcription::model::download(model_size, |_| {})?;
116    }
117
118    let model_path = crate::transcription::transcriber::find_model(model_size)
119        .ok_or_else(|| anyhow::anyhow!("Wake word model '{model_size}' not found"))?;
120
121    let transcriber = Transcriber::new(&model_path, "en")?;
122    log::info!("Wake word detector ready (phrase: \"{wake_phrase}\")");
123
124    // Audio capture ring buffer shared with the cpal callback
125    let ring_buffer: Arc<Mutex<Vec<f32>>> =
126        Arc::new(Mutex::new(Vec::with_capacity(WINDOW_SAMPLES * 2)));
127
128    // Open a dedicated cpal audio stream for wake word detection
129    let ring_clone = ring_buffer.clone();
130    let _stream = open_capture_stream(ring_clone)?;
131
132    let wake_lower = wake_phrase.to_lowercase();
133    let stop_lower = stop_phrase.to_lowercase();
134    let mut last_detection = std::time::Instant::now()
135        .checked_sub(std::time::Duration::from_millis(COOLDOWN_MS * 2))
136        .unwrap_or_else(std::time::Instant::now);
137
138    // Adaptive poll interval: starts at POLL_INTERVAL_MS and backs off
139    // when consecutive polls find no speech, up to MAX_POLL_INTERVAL_MS.
140    // Resets to base on speech detection.
141    let mut current_poll_ms = POLL_INTERVAL_MS;
142
143    loop {
144        // Check for stop signal
145        match stop_rx.try_recv() {
146            Ok(()) | Err(mpsc::TryRecvError::Disconnected) => break,
147            Err(mpsc::TryRecvError::Empty) => {}
148        }
149
150        // Skip if paused
151        if paused.load(Ordering::Relaxed) {
152            std::thread::sleep(std::time::Duration::from_millis(current_poll_ms));
153            continue;
154        }
155
156        // Wait for enough audio
157        let samples: Vec<f32> = {
158            let buf = ring_buffer.lock().unwrap_or_else(|e| e.into_inner());
159            if buf.len() < WINDOW_SAMPLES {
160                drop(buf);
161                std::thread::sleep(std::time::Duration::from_millis(current_poll_ms));
162                continue;
163            }
164            // Take the most recent window
165            let start = buf.len().saturating_sub(WINDOW_SAMPLES);
166            buf[start..].to_vec()
167        };
168
169        // Trim the ring buffer to prevent unbounded growth
170        {
171            let mut buf = ring_buffer.lock().unwrap_or_else(|e| e.into_inner());
172            if buf.len() > WINDOW_SAMPLES * 3 {
173                let drain_to = buf.len() - WINDOW_SAMPLES * 2;
174                buf.drain(..drain_to);
175            }
176        }
177
178        // Only transcribe if VAD detects speech
179        if !vad::contains_speech(&samples) {
180            // Back off: increase poll interval on consecutive silence
181            current_poll_ms = (current_poll_ms * 3 / 2).min(MAX_POLL_INTERVAL_MS);
182            std::thread::sleep(std::time::Duration::from_millis(current_poll_ms));
183            continue;
184        }
185
186        // Speech detected — reset to base poll interval
187        current_poll_ms = POLL_INTERVAL_MS;
188
189        // Cooldown check
190        if last_detection.elapsed() < std::time::Duration::from_millis(COOLDOWN_MS) {
191            std::thread::sleep(std::time::Duration::from_millis(current_poll_ms));
192            continue;
193        }
194
195        // Transcribe the window
196        match transcriber.transcribe_samples(&samples, false) {
197            Ok(text) => {
198                let text_lower = text.to_lowercase();
199                log::debug!("Wake word heard: \"{text}\"");
200
201                if contains_phrase(&text_lower, &wake_lower) {
202                    log::info!("Wake word detected!");
203                    last_detection = std::time::Instant::now();
204                    if tx.send(WakeWordEvent::WakeWordDetected).is_err() {
205                        break;
206                    }
207                } else if contains_phrase(&text_lower, &stop_lower) {
208                    log::info!("Stop phrase detected!");
209                    last_detection = std::time::Instant::now();
210                    if tx.send(WakeWordEvent::StopPhraseDetected).is_err() {
211                        break;
212                    }
213                }
214            }
215            Err(e) => {
216                log::warn!("Wake word transcription failed: {e}");
217            }
218        }
219
220        std::thread::sleep(std::time::Duration::from_millis(current_poll_ms));
221    }
222
223    log::info!("Wake word detector stopped");
224    Ok(())
225}
226
227/// Check if `text` contains the given `phrase` (fuzzy word-boundary match).
228///
229/// Uses exact matching for most words but fuzzy matching (edit distance ≤ 2)
230/// for short words that Whisper often mistranscribes (e.g. "murmur" → "mama").
231fn contains_phrase(text: &str, phrase: &str) -> bool {
232    if phrase.is_empty() {
233        return false;
234    }
235
236    let phrase_words: Vec<&str> = phrase.split_whitespace().collect();
237    let text_words: Vec<&str> = text.split_whitespace().collect();
238
239    if phrase_words.len() > text_words.len() {
240        return false;
241    }
242
243    text_words.windows(phrase_words.len()).any(|window| {
244        window.iter().zip(phrase_words.iter()).all(|(tw, pw)| {
245            let tw_clean = tw.trim_matches(|c: char| c.is_ascii_punctuation());
246            let pw_clean = pw.trim_matches(|c: char| c.is_ascii_punctuation());
247            words_match(tw_clean, pw_clean)
248        })
249    })
250}
251
252/// Check whether two words match, using fuzzy matching for short words
253/// that are prone to mistranscription and exact matching otherwise.
254fn words_match(heard: &str, expected: &str) -> bool {
255    if heard == expected {
256        return true;
257    }
258    // Check known aliases for the app name (Whisper commonly mistranscribes these)
259    if is_known_alias(heard, expected) {
260        return true;
261    }
262    // Use edit distance ≤ 2 for words ≤ 8 chars to catch minor transcription errors
263    if expected.len() <= 8 {
264        return edit_distance(heard, expected) <= 2;
265    }
266    false
267}
268
/// Words that Whisper tiny is known to produce in place of "murmur".
const MURMUR_ALIASES: &[&str] = &[
    "mama", "mamma", "mirror", "murmured", "memo", "memer", "merma", "mermer",
];

/// Report whether `heard` is a listed mistranscription of `expected`.
///
/// Only "murmur" has aliases; every other expected word yields `false`.
/// Both comparisons ignore ASCII case.
fn is_known_alias(heard: &str, expected: &str) -> bool {
    expected.eq_ignore_ascii_case("murmur")
        && MURMUR_ALIASES
            .iter()
            .any(|alias| heard.eq_ignore_ascii_case(alias))
}
283
/// Compute the Levenshtein edit distance between two strings, by character.
///
/// When the character counts differ by more than 2, the difference itself is
/// returned without running the full computation. That value is only a lower
/// bound on the true distance, which is enough for callers that compare the
/// result against a threshold of at most 2.
fn edit_distance(a: &str, b: &str) -> usize {
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();

    let length_gap = source.len().abs_diff(target.len());
    if length_gap > 2 {
        return length_gap;
    }

    // `row[j]` holds the distance between the source prefix processed so far
    // and the first `j` characters of `target`.
    let mut row: Vec<usize> = (0..=target.len()).collect();

    for (i, &source_char) in source.iter().enumerate() {
        // Value from the previous row, one column to the left.
        let mut diagonal = row[0];
        row[0] = i + 1;

        for (j, &target_char) in target.iter().enumerate() {
            let above = row[j + 1];
            let substitution = diagonal + usize::from(source_char != target_char);
            let deletion = above + 1;
            let insertion = row[j] + 1;
            row[j + 1] = deletion.min(insertion).min(substitution);
            diagonal = above;
        }
    }

    row[target.len()]
}
310
311/// Open a cpal input stream that pushes 16 kHz mono samples into `buffer`.
312fn open_capture_stream(buffer: Arc<Mutex<Vec<f32>>>) -> anyhow::Result<cpal::Stream> {
313    use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
314
315    let host = cpal::default_host();
316    let device = host
317        .default_input_device()
318        .ok_or_else(|| anyhow::anyhow!("No audio input device"))?;
319
320    let supported = device.default_input_config()?;
321    let sample_rate = supported.sample_rate();
322    let channels = supported.channels() as usize;
323
324    let config: cpal::StreamConfig = supported.into();
325
326    let stream = device.build_input_stream(
327        &config,
328        move |data: &[f32], _: &cpal::InputCallbackInfo| {
329            // Mix to mono
330            let mono: Vec<f32> = if channels == 1 {
331                data.to_vec()
332            } else {
333                data.chunks(channels)
334                    .map(|frame| frame.iter().sum::<f32>() / channels as f32)
335                    .collect()
336            };
337
338            // Resample to 16 kHz if needed
339            let samples_16k = if sample_rate == TARGET_RATE {
340                mono
341            } else {
342                resample_simple(&mono, sample_rate, TARGET_RATE)
343            };
344
345            if let Ok(mut buf) = buffer.try_lock() {
346                buf.extend_from_slice(&samples_16k);
347            }
348        },
349        |err| {
350            log::error!("Wake word audio error: {err}");
351        },
352        None,
353    )?;
354
355    stream.play()?;
356    Ok(stream)
357}
358
/// Resample `input` from `from_rate` to `to_rate` using linear interpolation.
///
/// Returns a copy of `input` when the rates are equal or `input` is empty.
/// The output holds `floor(input.len() * to_rate / from_rate)` samples; each
/// one is interpolated between the two nearest input samples, and the last
/// input sample is repeated where no right-hand neighbour exists.
///
/// A rate of zero (with the other rate non-zero) cannot be resampled
/// meaningfully and yields an empty vector.
fn resample_simple(input: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
    if from_rate == to_rate || input.is_empty() {
        return input.to_vec();
    }
    // A zero `from_rate` would make the output length saturate to
    // `usize::MAX` and the allocation panic; a zero `to_rate` produces no
    // output anyway.
    if from_rate == 0 || to_rate == 0 {
        return Vec::new();
    }

    // Distance, in input samples, between two consecutive output samples.
    let step = f64::from(from_rate) / f64::from(to_rate);
    let output_len = (input.len() as f64 / step) as usize;

    (0..output_len)
        .map(|i| {
            let position = i as f64 * step;
            let idx = position as usize;
            let frac = (position - idx as f64) as f32;

            match (input.get(idx), input.get(idx + 1)) {
                (Some(&here), Some(&next)) => here * (1.0 - frac) + next * frac,
                (Some(&here), None) => here,
                _ => 0.0,
            }
        })
        .collect()
}
385
386/// Check streaming partial text for the stop phrase and return
387/// the text with the stop phrase removed if found.
388pub fn check_and_strip_stop_phrase(text: &str, stop_phrase: &str) -> Option<String> {
389    let text_lower = text.to_lowercase();
390    let stop_lower = stop_phrase.to_lowercase();
391
392    if !contains_phrase(&text_lower, &stop_lower) {
393        return None;
394    }
395
396    // Remove the stop phrase from the text
397    let phrase_words: Vec<&str> = stop_phrase.split_whitespace().collect();
398    let text_words: Vec<&str> = text.split_whitespace().collect();
399
400    // Find the position of the stop phrase in the text
401    let phrase_lower_words: Vec<&str> = stop_lower.split_whitespace().collect();
402    let text_lower_words: Vec<String> = text_words
403        .iter()
404        .map(|w| {
405            w.to_lowercase()
406                .trim_matches(|c: char| c.is_ascii_punctuation())
407                .to_string()
408        })
409        .collect();
410
411    for i in 0..=text_words.len().saturating_sub(phrase_words.len()) {
412        let matches = text_lower_words[i..i + phrase_lower_words.len()]
413            .iter()
414            .zip(phrase_lower_words.iter())
415            .all(|(tw, pw)| {
416                let pw_clean = pw.trim_matches(|c: char| c.is_ascii_punctuation());
417                words_match(tw, pw_clean)
418            });
419
420        if matches {
421            let mut result_words: Vec<&str> = Vec::new();
422            result_words.extend_from_slice(&text_words[..i]);
423            result_words.extend_from_slice(&text_words[i + phrase_words.len()..]);
424            let result = result_words.join(" ").trim().to_string();
425            return Some(result);
426        }
427    }
428
429    // Fallback: couldn't pinpoint location, return text as-is
430    Some(text.to_string())
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Wake phrase used throughout the tests.
    const WAKE: &str = "murmur start dictation";
    /// Stop phrase used throughout the tests.
    const STOP: &str = "murmur stop dictation";

    #[test]
    fn test_contains_phrase_basic() {
        for &text in &["hello murmur start dictation please", WAKE] {
            assert!(contains_phrase(text, WAKE));
        }
        assert!(!contains_phrase("hello world", WAKE));
    }

    #[test]
    fn test_contains_phrase_punctuation() {
        for &text in &[
            "hello, murmur start dictation.",
            "\"murmur start dictation\"",
        ] {
            assert!(contains_phrase(text, WAKE));
        }
    }

    #[test]
    fn test_contains_phrase_empty() {
        assert!(!contains_phrase("hello", ""));
        assert!(!contains_phrase("", WAKE));
    }

    #[test]
    fn test_contains_phrase_partial() {
        for &text in &["murmur", "start dictation"] {
            assert!(!contains_phrase(text, WAKE));
        }
    }

    #[test]
    fn test_contains_phrase_fuzzy_murmur() {
        // Common Whisper mistranscriptions of "murmur"
        for &text in &[
            "mama start dictation",
            "mirror start dictation",
            "murder start dictation",
            "murmer start dictation",
        ] {
            assert!(contains_phrase(text, WAKE));
        }
        // Too far away — should NOT match
        for &text in &["banana start dictation", "tomorrow start dictation"] {
            assert!(!contains_phrase(text, WAKE));
        }
    }

    #[test]
    fn test_contains_phrase_fuzzy_stop() {
        for &text in &["mama stop dictation", "mirror stop dictation"] {
            assert!(contains_phrase(text, STOP));
        }
    }

    #[test]
    fn test_edit_distance() {
        let exact_cases: &[(&str, usize)] = &[
            ("murmur", 0),
            ("murder", 2),
            ("murmer", 1),
            ("mama", 4),
            ("mirror", 3),
        ];
        for &(word, expected) in exact_cases {
            assert_eq!(edit_distance(word, "murmur"), expected);
        }
        assert!(edit_distance("banana", "murmur") > 2);
    }

    #[test]
    fn test_words_match_exact() {
        assert!(words_match("start", "start"));
        assert!(words_match("murmur", "murmur"));
        assert!(!words_match("start", "stop"));
    }

    #[test]
    fn test_words_match_fuzzy() {
        // Known aliases, then edit distance ≤ 2
        for &heard in &["mama", "mirror", "mamma", "murder", "murmer"] {
            assert!(words_match(heard, "murmur"));
        }
        // Too different
        for &heard in &["banana", "number"] {
            assert!(!words_match(heard, "murmur"));
        }
    }

    #[test]
    fn test_is_known_alias() {
        assert!(is_known_alias("mama", "murmur"));
        assert!(is_known_alias("mirror", "murmur"));
        assert!(!is_known_alias("mama", "start"));
        assert!(!is_known_alias("banana", "murmur"));
    }

    #[test]
    fn test_check_and_strip_stop_phrase() {
        let stripped =
            check_and_strip_stop_phrase("hello world murmur stop dictation thanks", STOP);
        assert_eq!(stripped, Some("hello world thanks".to_string()));
    }

    #[test]
    fn test_check_and_strip_stop_phrase_at_end() {
        let stripped = check_and_strip_stop_phrase("hello world murmur stop dictation", STOP);
        assert_eq!(stripped, Some("hello world".to_string()));
    }

    #[test]
    fn test_check_and_strip_stop_phrase_at_start() {
        let stripped = check_and_strip_stop_phrase("murmur stop dictation hello world", STOP);
        assert_eq!(stripped, Some("hello world".to_string()));
    }

    #[test]
    fn test_check_and_strip_stop_phrase_not_found() {
        assert_eq!(check_and_strip_stop_phrase("hello world", STOP), None);
    }

    #[test]
    fn test_check_and_strip_stop_phrase_fuzzy() {
        let stripped = check_and_strip_stop_phrase("hello mama stop dictation thanks", STOP);
        assert_eq!(stripped, Some("hello thanks".to_string()));
    }

    #[test]
    fn test_resample_simple_same_rate() {
        let input = vec![1.0, 2.0, 3.0];
        assert_eq!(resample_simple(&input, 16000, 16000), input);
    }

    #[test]
    fn test_resample_simple_downsample() {
        let input: Vec<f32> = (0..48000).map(|i| (i as f32 / 48000.0).sin()).collect();
        let resampled = resample_simple(&input, 48000, 16000);
        // Should be roughly 1/3 the length
        assert!((resampled.len() as f32 - 16000.0).abs() < 2.0);
    }

    #[test]
    fn test_resample_simple_empty() {
        assert!(resample_simple(&[], 48000, 16000).is_empty());
    }
}