Skip to main content

opencode_voice/transcribe/
engine.rs

//! whisper-rs transcription engine: in-process speech-to-text.
//!
//! Replaces the whisper-cli subprocess with native whisper-rs bindings.
4
5use anyhow::{Context, Result};
6use std::path::Path;
7
/// Result of a transcription operation.
///
/// Derives the common value-type traits so callers can log, copy, compare
/// and default-construct results (an empty `text` is what `transcribe`
/// returns for a recording with no samples).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TranscriptionResult {
    /// The transcribed text, already cleaned and trimmed by the engine.
    pub text: String,
}
12
/// Reports whether `path` points at a plausible model file.
///
/// A path is considered valid when its metadata can be read and the
/// reported size is strictly greater than 1 MB (1,000,000 bytes). Any
/// metadata error (including a missing file) yields `false`.
pub fn is_model_valid(path: &Path) -> bool {
    // Anything at or below this size is treated as a missing/partial download.
    const MIN_MODEL_BYTES: u64 = 1_000_000;

    match std::fs::metadata(path) {
        Ok(meta) => meta.len() > MIN_MODEL_BYTES,
        Err(_) => false,
    }
}
21
22/// In-process whisper transcription engine.
23pub struct WhisperEngine {
24    ctx: whisper_rs::WhisperContext,
25    /// Whether the loaded model is multilingual (requires explicit language hint).
26    multilingual: bool,
27}
28
29impl WhisperEngine {
30    /// Loads a GGML model file and creates a WhisperEngine.
31    ///
32    /// When `multilingual` is `true`, the engine will set `language = "en"` on
33    /// each transcription request to avoid auto-detection overhead.
34    pub fn new(model_path: &Path, multilingual: bool) -> Result<Self> {
35        if !model_path.exists() {
36            anyhow::bail!(
37                "Whisper model not found at {}. Run 'opencode-voice setup' to download it.",
38                model_path.display()
39            );
40        }
41
42        let path_str = model_path
43            .to_str()
44            .context("Model path contains invalid UTF-8")?;
45
46        // Suppress whisper.cpp's verbose C-level logging during model load.
47        // whisper-rs 0.13.2 doesn't expose `no_prints` in WhisperContextParameters,
48        // so we install a no-op log callback via the sys crate.
49        suppress_whisper_logging();
50
51        let ctx = whisper_rs::WhisperContext::new_with_params(
52            path_str,
53            whisper_rs::WhisperContextParameters::default(),
54        )
55        .map_err(|e| anyhow::anyhow!("Failed to load whisper model: {:?}", e))?;
56
57        Ok(WhisperEngine { ctx, multilingual })
58    }
59
60    /// Transcribes a WAV file and returns the text.
61    ///
62    /// Note: This is CPU-bound and blocking. Call via `tokio::task::spawn_blocking`
63    /// in an async context to avoid blocking the async runtime.
64    pub fn transcribe(&self, wav_path: &Path) -> Result<TranscriptionResult> {
65        // Read WAV file
66        let mut reader = hound::WavReader::open(wav_path)
67            .with_context(|| format!("Failed to open WAV file: {}", wav_path.display()))?;
68
69        // Convert i16 samples to f32 (whisper-rs expects f32 in range [-1.0, 1.0])
70        let samples: Vec<f32> = reader
71            .samples::<i16>()
72            .filter_map(|s| s.ok())
73            .map(|s| s as f32 / 32768.0)
74            .collect();
75
76        if samples.is_empty() {
77            return Ok(TranscriptionResult {
78                text: String::new(),
79            });
80        }
81
82        // Set up whisper params: no timestamps, no progress output
83        let mut params =
84            whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
85        params.set_print_special(false);
86        params.set_print_progress(false);
87        params.set_print_realtime(false);
88        params.set_print_timestamps(false);
89        params.set_no_timestamps(true);
90        params.set_single_segment(false);
91
92        // Multilingual models need an explicit language hint to avoid
93        // auto-detection overhead and ensure English transcription.
94        if self.multilingual {
95            params.set_language(Some("en"));
96        }
97
98        // Run transcription
99        let mut state = self
100            .ctx
101            .create_state()
102            .map_err(|e| anyhow::anyhow!("Failed to create whisper state: {:?}", e))?;
103
104        state
105            .full(params, &samples)
106            .map_err(|e| anyhow::anyhow!("Whisper transcription failed: {:?}", e))?;
107
108        // Collect segments
109        let num_segments = state
110            .full_n_segments()
111            .map_err(|e| anyhow::anyhow!("Failed to get segment count: {:?}", e))?;
112
113        let mut text_parts: Vec<String> = Vec::new();
114        for i in 0..num_segments {
115            if let Ok(segment) = state.full_get_segment_text(i) {
116                text_parts.push(segment);
117            }
118        }
119
120        let raw_text = text_parts.join(" ");
121
122        // Strip timestamp brackets like [HH:MM:SS.mmm --> HH:MM:SS.mmm]
123        let clean_text = strip_timestamps(&raw_text);
124        // Filter out Whisper hallucination artifacts (e.g. "[BLANK_AUDIO]")
125        let clean_text = strip_hallucinations(&clean_text);
126        let final_text = clean_text.trim().to_string();
127
128        Ok(TranscriptionResult { text: final_text })
129    }
130}
131
132/// Installs a no-op log callback to suppress whisper.cpp's C-level stderr output.
133///
134/// This must be called before `WhisperContext::new_with_params` to prevent
135/// the verbose model-loading messages from cluttering the terminal.
136fn suppress_whisper_logging() {
137    unsafe {
138        // A C-compatible no-op callback that discards all whisper log messages.
139        unsafe extern "C" fn noop_log(
140            _level: whisper_rs::whisper_rs_sys::ggml_log_level,
141            _text: *const std::ffi::c_char,
142            _user_data: *mut std::ffi::c_void,
143        ) {
144        }
145        whisper_rs::whisper_rs_sys::whisper_log_set(Some(noop_log), std::ptr::null_mut());
146        whisper_rs::whisper_rs_sys::ggml_log_set(Some(noop_log), std::ptr::null_mut());
147    }
148}
149
/// Known Whisper hallucination phrases that should be treated as silence.
///
/// These are bracketed tags or repeated filler phrases that Whisper emits
/// when the audio contains silence, noise, or non-speech content.
///
/// Every entry must be pure ASCII: `strip_hallucinations` relies on ASCII
/// case folding to keep byte offsets stable.
const WHISPER_HALLUCINATIONS: &[&str] = &[
    "[BLANK_AUDIO]",
    "[NO_SPEECH]",
    "(blank audio)",
    "(no speech)",
    "[silence]",
    "(silence)",
];

/// Removes known Whisper hallucination artifacts from transcribed text.
///
/// Matching is ASCII case-insensitive. Each pattern is removed repeatedly
/// until it no longer occurs; surrounding whitespace is left untouched, so
/// callers should trim the result. If the text consisted only of artifacts
/// the returned string is empty (or whitespace), which the caller treats
/// the same as silence.
fn strip_hallucinations(text: &str) -> String {
    let mut result = text.to_string();
    // ASCII-only folding never changes byte lengths, so an offset found in
    // `folded` is valid in `result` too. Full Unicode lowercasing can grow or
    // shrink characters, which would make the offsets point at the wrong bytes
    // (or inside a multi-byte character, causing a slicing panic).
    let mut folded = text.to_ascii_lowercase();

    for pattern in WHISPER_HALLUCINATIONS {
        let needle = pattern.to_ascii_lowercase();
        while let Some(pos) = folded.find(&needle) {
            let span = pos..pos + needle.len();
            // Keep both buffers in lock-step so later offsets stay aligned.
            folded.replace_range(span.clone(), "");
            result.replace_range(span, "");
        }
    }
    result
}
177
/// Strips whisper timestamp annotations from transcribed text.
///
/// A bracketed span `[...]` is treated as a timestamp, and removed, only when
/// its content contains `-->`. All other bracketed spans are kept verbatim,
/// and scanning continues past them so timestamps that appear later in the
/// text are still removed. An unmatched `[` ends the scan and the remainder
/// of the text is kept as-is.
///
/// Example: "[00:00:00.000 --> 00:00:05.000] Hello world" → " Hello world"
fn strip_timestamps(text: &str) -> String {
    let mut cleaned = String::with_capacity(text.len());
    // Portion of the input that has not been examined yet.
    let mut rest = text;

    while let Some(open) = rest.find('[') {
        let close = match rest[open..].find(']') {
            Some(offset) => open + offset,
            // No closing bracket: nothing further can be a timestamp.
            None => break,
        };

        if rest[open + 1..close].contains("-->") {
            // Timestamp: keep only what precedes the bracket.
            cleaned.push_str(&rest[..open]);
        } else {
            // Ordinary bracketed text: keep it, brackets included.
            cleaned.push_str(&rest[..=close]);
        }
        rest = &rest[close + 1..];
    }

    cleaned.push_str(rest);
    cleaned
}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_model_valid_nonexistent() {
        assert!(!is_model_valid(Path::new("/nonexistent/path/model.bin")));
    }

    #[test]
    fn test_is_model_valid_small_file() {
        // Create a tiny temp file (< 1MB). The name includes the process id so
        // concurrent test runs sharing the temp dir cannot clobber each other.
        let tmp = std::env::temp_dir().join(format!(
            "opencode-voice-test-tiny-{}.bin",
            std::process::id()
        ));
        std::fs::write(&tmp, b"tiny").unwrap();
        let valid = is_model_valid(&tmp);
        // Clean up before asserting so a failure does not leave the file behind.
        std::fs::remove_file(&tmp).ok();
        assert!(!valid);
    }

    #[test]
    fn test_strip_timestamps_with_arrow() {
        let input = "[00:00:00.000 --> 00:00:05.000]  Hello world";
        let result = strip_timestamps(input);
        assert!(!result.contains("-->"));
        assert!(result.contains("Hello world"));
    }

    #[test]
    fn test_strip_timestamps_no_timestamps() {
        let input = "Hello world";
        assert_eq!(strip_timestamps(input), "Hello world");
    }

    #[test]
    fn test_strip_timestamps_preserves_non_timestamp_brackets() {
        let input = "Hello [world]";
        let result = strip_timestamps(input);
        assert!(result.contains("[world]")); // Non-timestamp bracket preserved
    }

    #[test]
    fn test_strip_hallucinations_blank_audio() {
        assert_eq!(strip_hallucinations("[BLANK_AUDIO]").trim(), "");
    }

    #[test]
    fn test_strip_hallucinations_case_insensitive() {
        assert_eq!(strip_hallucinations("[blank_audio]").trim(), "");
        assert_eq!(strip_hallucinations("[Blank_Audio]").trim(), "");
    }

    #[test]
    fn test_strip_hallucinations_preserves_real_text() {
        assert_eq!(strip_hallucinations("hello world"), "hello world");
    }

    #[test]
    fn test_strip_hallucinations_mixed() {
        let result = strip_hallucinations("[BLANK_AUDIO] hello [BLANK_AUDIO]");
        assert_eq!(result.trim(), "hello");
    }
}
260}