// opencode_voice/transcribe/engine.rs

//! whisper-rs transcription engine: in-process speech-to-text.
//!
//! Replaces the whisper-cli subprocess with native whisper-rs bindings.

5use anyhow::{Context, Result};
6use std::path::Path;
7
/// Result of a transcription operation.
///
/// The text may be empty, for example when the audio had no samples or
/// contained only whisper's silence/hallucination markers.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TranscriptionResult {
    /// The transcript with timestamp annotations and known hallucination
    /// markers removed, trimmed at both ends.
    pub text: String,
}
12
/// Checks if a model file is valid: it must exist, be a regular file (symlinks
/// are followed), and be larger than 1 MB (1,000,000 bytes).
///
/// Any error while reading the file's metadata (missing file, permission
/// denied, ...) is reported as "not valid" rather than propagated.
pub fn is_model_valid(path: &Path) -> bool {
    // Files at or below this many bytes are rejected.
    const MIN_MODEL_BYTES: u64 = 1_000_000;

    // One metadata lookup both checks existence and avoids the window between a
    // separate `exists()` and `metadata()` call. `is_file` rejects directories,
    // whose reported length says nothing about a model.
    std::fs::metadata(path)
        .map(|meta| meta.is_file() && meta.len() > MIN_MODEL_BYTES)
        .unwrap_or(false)
}
21
22/// In-process whisper transcription engine.
23pub struct WhisperEngine {
24    ctx: whisper_rs::WhisperContext,
25}
26
27impl WhisperEngine {
28    /// Loads a GGML model file and creates a WhisperEngine.
29    pub fn new(model_path: &Path) -> Result<Self> {
30        if !model_path.exists() {
31            anyhow::bail!(
32                "Whisper model not found at {}. Run 'opencode-voice setup' to download it.",
33                model_path.display()
34            );
35        }
36
37        let path_str = model_path
38            .to_str()
39            .context("Model path contains invalid UTF-8")?;
40
41        // Suppress whisper.cpp's verbose C-level logging during model load.
42        // whisper-rs 0.13.2 doesn't expose `no_prints` in WhisperContextParameters,
43        // so we install a no-op log callback via the sys crate.
44        suppress_whisper_logging();
45
46        let ctx = whisper_rs::WhisperContext::new_with_params(
47            path_str,
48            whisper_rs::WhisperContextParameters::default(),
49        )
50        .map_err(|e| anyhow::anyhow!("Failed to load whisper model: {:?}", e))?;
51
52        Ok(WhisperEngine { ctx })
53    }
54
55    /// Transcribes a WAV file and returns the text.
56    ///
57    /// Note: This is CPU-bound and blocking. Call via `tokio::task::spawn_blocking`
58    /// in an async context to avoid blocking the async runtime.
59    pub fn transcribe(&self, wav_path: &Path) -> Result<TranscriptionResult> {
60        // Read WAV file
61        let mut reader = hound::WavReader::open(wav_path)
62            .with_context(|| format!("Failed to open WAV file: {}", wav_path.display()))?;
63
64        // Convert i16 samples to f32 (whisper-rs expects f32 in range [-1.0, 1.0])
65        let samples: Vec<f32> = reader
66            .samples::<i16>()
67            .filter_map(|s| s.ok())
68            .map(|s| s as f32 / 32768.0)
69            .collect();
70
71        if samples.is_empty() {
72            return Ok(TranscriptionResult {
73                text: String::new(),
74            });
75        }
76
77        // Set up whisper params: no timestamps, no progress output
78        let mut params =
79            whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
80        params.set_print_special(false);
81        params.set_print_progress(false);
82        params.set_print_realtime(false);
83        params.set_print_timestamps(false);
84        params.set_no_timestamps(true);
85        params.set_single_segment(false);
86
87        // Run transcription
88        let mut state = self
89            .ctx
90            .create_state()
91            .map_err(|e| anyhow::anyhow!("Failed to create whisper state: {:?}", e))?;
92
93        state
94            .full(params, &samples)
95            .map_err(|e| anyhow::anyhow!("Whisper transcription failed: {:?}", e))?;
96
97        // Collect segments
98        let num_segments = state
99            .full_n_segments()
100            .map_err(|e| anyhow::anyhow!("Failed to get segment count: {:?}", e))?;
101
102        let mut text_parts: Vec<String> = Vec::new();
103        for i in 0..num_segments {
104            if let Ok(segment) = state.full_get_segment_text(i) {
105                text_parts.push(segment);
106            }
107        }
108
109        let raw_text = text_parts.join(" ");
110
111        // Strip timestamp brackets like [HH:MM:SS.mmm --> HH:MM:SS.mmm]
112        let clean_text = strip_timestamps(&raw_text);
113        // Filter out Whisper hallucination artifacts (e.g. "[BLANK_AUDIO]")
114        let clean_text = strip_hallucinations(&clean_text);
115        let final_text = clean_text.trim().to_string();
116
117        Ok(TranscriptionResult { text: final_text })
118    }
119}
120
121/// Installs a no-op log callback to suppress whisper.cpp's C-level stderr output.
122///
123/// This must be called before `WhisperContext::new_with_params` to prevent
124/// the verbose model-loading messages from cluttering the terminal.
125fn suppress_whisper_logging() {
126    unsafe {
127        // A C-compatible no-op callback that discards all whisper log messages.
128        unsafe extern "C" fn noop_log(
129            _level: whisper_rs::whisper_rs_sys::ggml_log_level,
130            _text: *const std::ffi::c_char,
131            _user_data: *mut std::ffi::c_void,
132        ) {
133        }
134        whisper_rs::whisper_rs_sys::whisper_log_set(Some(noop_log), std::ptr::null_mut());
135        whisper_rs::whisper_rs_sys::ggml_log_set(Some(noop_log), std::ptr::null_mut());
136    }
137}
138
/// Known Whisper hallucination phrases that should be treated as silence.
///
/// These are bracketed tags or repeated filler phrases that Whisper emits
/// when the audio contains silence, noise, or non-speech content. They are
/// matched ignoring ASCII letter case.
const WHISPER_HALLUCINATIONS: &[&str] = &[
    "[BLANK_AUDIO]",
    "[NO_SPEECH]",
    "(blank audio)",
    "(no speech)",
    "[silence]",
    "(silence)",
];

/// Removes known Whisper hallucination artifacts from transcribed text.
///
/// Matching ignores ASCII case, so `[blank_audio]` and `[Blank_Audio]` are
/// removed too. Only the marker itself is removed; surrounding whitespace is
/// kept, so callers trim the result. If the text consisted only of markers
/// (and whitespace), the trimmed result is empty, and the caller treats it the
/// same as silence.
///
/// Removal repeats until no marker remains, so a marker that only forms once an
/// inner one is cut out (e.g. `[BLANK[BLANK_AUDIO]_AUDIO]`) is removed as well.
fn strip_hallucinations(text: &str) -> String {
    let mut result = text.to_string();
    for pattern in WHISPER_HALLUCINATIONS {
        // Offsets are found in `result` itself, not in a lowercased copy whose
        // byte length can differ for non-ASCII text, so they are always valid
        // char boundaries of `result`.
        while let Some(pos) = find_ignore_ascii_case(&result, pattern) {
            result.replace_range(pos..pos + pattern.len(), "");
        }
    }
    result
}

/// Returns the byte offset of the first occurrence of `needle` in `haystack`,
/// comparing ASCII letters case-insensitively and all other bytes exactly.
///
/// Because non-ASCII bytes must match exactly and UTF-8 is self-synchronizing,
/// a match always starts and ends on char boundaries of `haystack`. An empty
/// `needle` never matches, so callers that loop until no match remains always
/// terminate.
fn find_ignore_ascii_case(haystack: &str, needle: &str) -> Option<usize> {
    let needle = needle.as_bytes();
    if needle.is_empty() {
        return None;
    }
    haystack
        .as_bytes()
        .windows(needle.len())
        .position(|window| window.eq_ignore_ascii_case(needle))
}
166
/// Strips whisper timestamp annotations from transcribed text.
///
/// A bracketed span counts as a timestamp when the text between a `[` and the
/// next `]` contains `-->`. Every such span is removed, wherever it appears;
/// other bracketed text (e.g. `[BLANK_AUDIO]` or `[world]`) is kept verbatim,
/// as is an unclosed `[`. Surrounding whitespace is not touched.
///
/// Example: "[00:00:00.000 --> 00:00:05.000] Hello world" → " Hello world"
fn strip_timestamps(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut rest = text;

    while let Some(open) = rest.find('[') {
        // `[` is ASCII, so `open + 1` is a char boundary.
        let after_open = &rest[open + 1..];
        match after_open.find(']') {
            Some(close) if after_open[..close].contains("-->") => {
                // Timestamp: keep the text before it, drop the bracketed span.
                out.push_str(&rest[..open]);
                rest = &after_open[close + 1..];
            }
            Some(_) => {
                // Ordinary bracket: keep it and keep scanning for later timestamps.
                out.push_str(&rest[..=open]);
                rest = after_open;
            }
            None => break,
        }
    }

    out.push_str(rest);
    out
}
188
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Returns a temp path unique to this process and `name`, so parallel test
    /// threads and concurrent `cargo test` runs never share a file.
    fn unique_temp_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("opencode-voice-{}-{}", std::process::id(), name))
    }

    #[test]
    fn test_is_model_valid_nonexistent() {
        assert!(!is_model_valid(Path::new("/nonexistent/path/model.bin")));
    }

    #[test]
    fn test_is_model_valid_small_file() {
        // Create a tiny temp file (< 1MB)
        let tmp = unique_temp_path("tiny.bin");
        std::fs::write(&tmp, b"tiny").unwrap();
        let valid = is_model_valid(&tmp);
        // Clean up before asserting so a failure doesn't leave the file behind.
        std::fs::remove_file(&tmp).ok();
        assert!(!valid);
    }

    #[test]
    fn test_is_model_valid_large_file() {
        // Just over the 1,000,000-byte threshold.
        let tmp = unique_temp_path("large.bin");
        std::fs::write(&tmp, vec![0u8; 1_000_001]).unwrap();
        let valid = is_model_valid(&tmp);
        std::fs::remove_file(&tmp).ok();
        assert!(valid);
    }

    #[test]
    fn test_strip_timestamps_with_arrow() {
        let input = "[00:00:00.000 --> 00:00:05.000]  Hello world";
        let result = strip_timestamps(input);
        assert!(!result.contains("-->"));
        assert!(result.contains("Hello world"));
    }

    #[test]
    fn test_strip_timestamps_no_timestamps() {
        let input = "Hello world";
        assert_eq!(strip_timestamps(input), "Hello world");
    }

    #[test]
    fn test_strip_timestamps_preserves_non_timestamp_brackets() {
        let input = "Hello [world]";
        let result = strip_timestamps(input);
        assert!(result.contains("[world]")); // Non-timestamp bracket preserved
    }

    #[test]
    fn test_strip_timestamps_multiple() {
        let input = "[00:00:00.000 --> 00:00:01.000] a [00:00:01.000 --> 00:00:02.000] b";
        assert_eq!(strip_timestamps(input), " a  b");
    }

    #[test]
    fn test_strip_hallucinations_blank_audio() {
        assert_eq!(strip_hallucinations("[BLANK_AUDIO]").trim(), "");
    }

    #[test]
    fn test_strip_hallucinations_case_insensitive() {
        assert_eq!(strip_hallucinations("[blank_audio]").trim(), "");
        assert_eq!(strip_hallucinations("[Blank_Audio]").trim(), "");
    }

    #[test]
    fn test_strip_hallucinations_other_markers() {
        assert_eq!(strip_hallucinations("(No Speech) [silence]").trim(), "");
    }

    #[test]
    fn test_strip_hallucinations_preserves_real_text() {
        assert_eq!(strip_hallucinations("hello world"), "hello world");
    }

    #[test]
    fn test_strip_hallucinations_mixed() {
        let result = strip_hallucinations("[BLANK_AUDIO] hello [BLANK_AUDIO]");
        assert_eq!(result.trim(), "hello");
    }
}