Skip to main content

opencode_voice/transcribe/
engine.rs

1//! whisper-rs transcription engine: in-process speech-to-text.
2//!
3//! Replaces the whisper-cli subprocess with native whisper-rs bindings.
4
5use anyhow::{Context, Result};
6use std::path::Path;
7
/// Result of a transcription operation.
///
/// Derives the common value-type traits so callers can log, clone, compare,
/// and default-construct results without writing wrappers.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TranscriptionResult {
    /// The transcribed text produced for the audio input.
    pub text: String,
}
12
/// Reports whether `path` looks like a usable model file.
///
/// Returns `true` only when the file's metadata can be read and its length is
/// strictly greater than 1,000,000 bytes. A missing or unreadable path yields
/// `false`.
pub fn is_model_valid(path: &Path) -> bool {
    // Anything at or below this size is treated as a truncated/failed download.
    const MIN_MODEL_BYTES: u64 = 1_000_000;

    // A failed metadata lookup covers the "does not exist" case as well.
    match path.metadata() {
        Ok(meta) => meta.len() > MIN_MODEL_BYTES,
        Err(_) => false,
    }
}
21
22/// In-process whisper transcription engine.
23pub struct WhisperEngine {
24    ctx: whisper_rs::WhisperContext,
25}
26
27impl WhisperEngine {
28    /// Loads a GGML model file and creates a WhisperEngine.
29    pub fn new(model_path: &Path) -> Result<Self> {
30        if !model_path.exists() {
31            anyhow::bail!(
32                "Whisper model not found at {}. Run 'opencode-voice setup' to download it.",
33                model_path.display()
34            );
35        }
36
37        let path_str = model_path
38            .to_str()
39            .context("Model path contains invalid UTF-8")?;
40
41        // Suppress whisper.cpp's verbose C-level logging during model load.
42        // whisper-rs 0.13.2 doesn't expose `no_prints` in WhisperContextParameters,
43        // so we install a no-op log callback via the sys crate.
44        suppress_whisper_logging();
45
46        let ctx = whisper_rs::WhisperContext::new_with_params(
47            path_str,
48            whisper_rs::WhisperContextParameters::default(),
49        )
50        .map_err(|e| anyhow::anyhow!("Failed to load whisper model: {:?}", e))?;
51
52        Ok(WhisperEngine { ctx })
53    }
54
55    /// Transcribes a WAV file and returns the text.
56    ///
57    /// Note: This is CPU-bound and blocking. Call via `tokio::task::spawn_blocking`
58    /// in an async context to avoid blocking the async runtime.
59    pub fn transcribe(&self, wav_path: &Path) -> Result<TranscriptionResult> {
60        // Read WAV file
61        let mut reader = hound::WavReader::open(wav_path)
62            .with_context(|| format!("Failed to open WAV file: {}", wav_path.display()))?;
63
64        // Convert i16 samples to f32 (whisper-rs expects f32 in range [-1.0, 1.0])
65        let samples: Vec<f32> = reader
66            .samples::<i16>()
67            .filter_map(|s| s.ok())
68            .map(|s| s as f32 / 32768.0)
69            .collect();
70
71        if samples.is_empty() {
72            return Ok(TranscriptionResult {
73                text: String::new(),
74            });
75        }
76
77        // Set up whisper params: no timestamps, no progress output
78        let mut params =
79            whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
80        params.set_print_special(false);
81        params.set_print_progress(false);
82        params.set_print_realtime(false);
83        params.set_print_timestamps(false);
84        params.set_no_timestamps(true);
85        params.set_single_segment(false);
86
87        // Run transcription
88        let mut state = self
89            .ctx
90            .create_state()
91            .map_err(|e| anyhow::anyhow!("Failed to create whisper state: {:?}", e))?;
92
93        state
94            .full(params, &samples)
95            .map_err(|e| anyhow::anyhow!("Whisper transcription failed: {:?}", e))?;
96
97        // Collect segments
98        let num_segments = state
99            .full_n_segments()
100            .map_err(|e| anyhow::anyhow!("Failed to get segment count: {:?}", e))?;
101
102        let mut text_parts: Vec<String> = Vec::new();
103        for i in 0..num_segments {
104            if let Ok(segment) = state.full_get_segment_text(i) {
105                text_parts.push(segment);
106            }
107        }
108
109        let raw_text = text_parts.join(" ");
110
111        // Strip timestamp brackets like [HH:MM:SS.mmm --> HH:MM:SS.mmm]
112        let clean_text = strip_timestamps(&raw_text);
113        let final_text = clean_text.trim().to_string();
114
115        Ok(TranscriptionResult { text: final_text })
116    }
117}
118
119/// Installs a no-op log callback to suppress whisper.cpp's C-level stderr output.
120///
121/// This must be called before `WhisperContext::new_with_params` to prevent
122/// the verbose model-loading messages from cluttering the terminal.
123fn suppress_whisper_logging() {
124    unsafe {
125        // A C-compatible no-op callback that discards all whisper log messages.
126        unsafe extern "C" fn noop_log(
127            _level: whisper_rs::whisper_rs_sys::ggml_log_level,
128            _text: *const std::ffi::c_char,
129            _user_data: *mut std::ffi::c_void,
130        ) {
131        }
132        whisper_rs::whisper_rs_sys::whisper_log_set(Some(noop_log), std::ptr::null_mut());
133        whisper_rs::whisper_rs_sys::ggml_log_set(Some(noop_log), std::ptr::null_mut());
134    }
135}
136
/// Strips whisper timestamp annotations from transcribed text.
///
/// A bracketed span `[...]` is removed when its content contains `-->`.
/// Every other bracketed span is kept verbatim, and scanning continues after
/// it, so timestamps that follow an ordinary bracket are still removed. An
/// unterminated `[` and everything after it is kept as-is.
///
/// Surrounding whitespace is not touched; callers trim the result if needed.
///
/// Example: "[00:00:00.000 --> 00:00:05.000] Hello world" → " Hello world"
fn strip_timestamps(text: &str) -> String {
    let mut result = String::with_capacity(text.len());
    // Unprocessed tail of the input; everything before it is already emitted.
    let mut rest = text;

    while let Some(open) = rest.find('[') {
        let close = match rest[open..].find(']') {
            Some(offset) => open + offset,
            // No closing bracket: nothing further can be a timestamp.
            None => break,
        };

        if rest[open + 1..close].contains("-->") {
            // Timestamp: keep only the text in front of the bracket.
            result.push_str(&rest[..open]);
        } else {
            // Ordinary bracket: keep it, including the closing `]`.
            result.push_str(&rest[..=close]);
        }

        // `]` is a single byte, so `close + 1` is always a char boundary.
        rest = &rest[close + 1..];
    }

    result.push_str(rest);
    result
}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// A path that does not exist is never a valid model.
    #[test]
    fn test_is_model_valid_nonexistent() {
        assert!(!is_model_valid(Path::new("/nonexistent/path/model.bin")));
    }

    /// A file far below the size threshold is rejected.
    #[test]
    fn test_is_model_valid_small_file() {
        // Include the process id so concurrent test runs sharing the temp
        // directory do not overwrite or delete each other's file.
        let tmp = std::env::temp_dir().join(format!("test-tiny-{}.bin", std::process::id()));
        std::fs::write(&tmp, b"tiny").unwrap();

        // Evaluate first and clean up before asserting, so a failing
        // assertion does not leave the temp file behind.
        let valid = is_model_valid(&tmp);
        std::fs::remove_file(&tmp).ok();

        assert!(!valid);
    }

    /// A leading timestamp bracket is removed and the text is kept.
    #[test]
    fn test_strip_timestamps_with_arrow() {
        let input = "[00:00:00.000 --> 00:00:05.000]  Hello world";
        let result = strip_timestamps(input);
        assert!(!result.contains("-->"));
        assert!(result.contains("Hello world"));
    }

    /// Text without brackets passes through unchanged.
    #[test]
    fn test_strip_timestamps_no_timestamps() {
        let input = "Hello world";
        assert_eq!(strip_timestamps(input), "Hello world");
    }

    /// Brackets that do not contain `-->` are left alone.
    #[test]
    fn test_strip_timestamps_preserves_non_timestamp_brackets() {
        let input = "Hello [world]";
        let result = strip_timestamps(input);
        assert!(result.contains("[world]")); // Non-timestamp bracket preserved
    }
}
197}