opencode_voice/transcribe/
engine.rs1use anyhow::{Context, Result};
6use std::path::Path;
7
/// Text produced by [`WhisperEngine::transcribe`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TranscriptionResult {
    /// Cleaned, trimmed transcript. May be empty (e.g. for a WAV with no
    /// samples, or when only non-speech placeholder tokens were produced).
    pub text: String,
}
12
/// Returns `true` if `path` looks like a usable model file.
///
/// A path qualifies when it is a regular file (symlinks are followed) larger
/// than 1,000,000 bytes. That threshold is a cheap sanity check against empty
/// or truncated downloads, not a format validation. Any I/O error (missing
/// path, permission denied, …) yields `false`.
pub fn is_model_valid(path: &Path) -> bool {
    // Files at or below this size are treated as incomplete/corrupt downloads.
    const MIN_MODEL_BYTES: u64 = 1_000_000;

    // A single metadata() call covers existence as well, avoiding a separate
    // exists() syscall and the race between the two checks.
    std::fs::metadata(path)
        .map(|m| m.is_file() && m.len() > MIN_MODEL_BYTES)
        .unwrap_or(false)
}
21
22pub struct WhisperEngine {
24 ctx: whisper_rs::WhisperContext,
25}
26
27impl WhisperEngine {
28 pub fn new(model_path: &Path) -> Result<Self> {
30 if !model_path.exists() {
31 anyhow::bail!(
32 "Whisper model not found at {}. Run 'opencode-voice setup' to download it.",
33 model_path.display()
34 );
35 }
36
37 let path_str = model_path
38 .to_str()
39 .context("Model path contains invalid UTF-8")?;
40
41 suppress_whisper_logging();
45
46 let ctx = whisper_rs::WhisperContext::new_with_params(
47 path_str,
48 whisper_rs::WhisperContextParameters::default(),
49 )
50 .map_err(|e| anyhow::anyhow!("Failed to load whisper model: {:?}", e))?;
51
52 Ok(WhisperEngine { ctx })
53 }
54
55 pub fn transcribe(&self, wav_path: &Path) -> Result<TranscriptionResult> {
60 let mut reader = hound::WavReader::open(wav_path)
62 .with_context(|| format!("Failed to open WAV file: {}", wav_path.display()))?;
63
64 let samples: Vec<f32> = reader
66 .samples::<i16>()
67 .filter_map(|s| s.ok())
68 .map(|s| s as f32 / 32768.0)
69 .collect();
70
71 if samples.is_empty() {
72 return Ok(TranscriptionResult {
73 text: String::new(),
74 });
75 }
76
77 let mut params =
79 whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
80 params.set_print_special(false);
81 params.set_print_progress(false);
82 params.set_print_realtime(false);
83 params.set_print_timestamps(false);
84 params.set_no_timestamps(true);
85 params.set_single_segment(false);
86
87 let mut state = self
89 .ctx
90 .create_state()
91 .map_err(|e| anyhow::anyhow!("Failed to create whisper state: {:?}", e))?;
92
93 state
94 .full(params, &samples)
95 .map_err(|e| anyhow::anyhow!("Whisper transcription failed: {:?}", e))?;
96
97 let num_segments = state
99 .full_n_segments()
100 .map_err(|e| anyhow::anyhow!("Failed to get segment count: {:?}", e))?;
101
102 let mut text_parts: Vec<String> = Vec::new();
103 for i in 0..num_segments {
104 if let Ok(segment) = state.full_get_segment_text(i) {
105 text_parts.push(segment);
106 }
107 }
108
109 let raw_text = text_parts.join(" ");
110
111 let clean_text = strip_timestamps(&raw_text);
113 let clean_text = strip_hallucinations(&clean_text);
115 let final_text = clean_text.trim().to_string();
116
117 Ok(TranscriptionResult { text: final_text })
118 }
119}
120
121fn suppress_whisper_logging() {
126 unsafe {
127 unsafe extern "C" fn noop_log(
129 _level: whisper_rs::whisper_rs_sys::ggml_log_level,
130 _text: *const std::ffi::c_char,
131 _user_data: *mut std::ffi::c_void,
132 ) {
133 }
134 whisper_rs::whisper_rs_sys::whisper_log_set(Some(noop_log), std::ptr::null_mut());
135 whisper_rs::whisper_rs_sys::ggml_log_set(Some(noop_log), std::ptr::null_mut());
136 }
137}
138
/// Non-speech placeholder markers (blank audio / no speech / silence) that
/// whisper can emit. They are removed from transcripts case-insensitively.
const WHISPER_HALLUCINATIONS: &[&str] = &[
    "[BLANK_AUDIO]",
    "[NO_SPEECH]",
    "(blank audio)",
    "(no speech)",
    "[silence]",
    "(silence)",
];

/// Removes every occurrence of each `WHISPER_HALLUCINATIONS` marker from
/// `text`, ignoring ASCII case.
///
/// Removal repeats until no marker remains, so a marker that only forms after
/// an inner one is removed (e.g. `"[BLANK[BLANK_AUDIO]_AUDIO]"`) is removed
/// too. Surrounding whitespace is left in place; callers trim as needed.
///
/// Matching is done on the original string's bytes, so byte offsets always
/// line up with `text`. This holds even when it contains non-ASCII characters
/// whose lowercase form has a different UTF-8 length (e.g. `İ`).
fn strip_hallucinations(text: &str) -> String {
    let mut result = text.to_string();
    for pattern in WHISPER_HALLUCINATIONS {
        while let Some(pos) = find_ignore_ascii_case(&result, pattern) {
            result.replace_range(pos..pos + pattern.len(), "");
        }
    }
    result
}

/// Byte offset of the first ASCII-case-insensitive occurrence of `needle` in
/// `haystack`, or `None` if absent or `needle` is empty.
///
/// Non-ASCII bytes only match themselves and `needle` is valid UTF-8, so a
/// match always starts and ends on `haystack` char boundaries.
fn find_ignore_ascii_case(haystack: &str, needle: &str) -> Option<usize> {
    let needle = needle.as_bytes();
    if needle.is_empty() {
        // An empty needle would "match" forever in the caller's loop.
        return None;
    }
    haystack
        .as_bytes()
        .windows(needle.len())
        .position(|window| window.eq_ignore_ascii_case(needle))
}
166
/// Removes every `[... --> ...]` timestamp marker from `text`.
///
/// Brackets whose content does not contain `-->` (e.g. `[world]`, `[music]`)
/// are kept verbatim, and scanning continues past them. A timestamp that
/// follows such a bracket is therefore still removed. An unmatched `[` ends
/// the scan and the rest of the text is kept unchanged. Whitespace around
/// removed markers is left in place.
fn strip_timestamps(text: &str) -> String {
    let mut result = String::with_capacity(text.len());
    let mut rest = text;

    while let Some(open) = rest.find('[') {
        let Some(close_offset) = rest[open..].find(']') else {
            break;
        };
        let close = open + close_offset;

        if rest[open + 1..close].contains("-->") {
            // Drop the whole "[...]" marker.
            result.push_str(&rest[..open]);
            rest = &rest[close + 1..];
        } else {
            // Keep this '[' and look for the next bracket after it.
            result.push_str(&rest[..=open]);
            rest = &rest[open + 1..];
        }
    }

    result.push_str(rest);
    result
}
188
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_model_valid_nonexistent() {
        assert!(!is_model_valid(Path::new("/nonexistent/path/model.bin")));
    }

    #[test]
    fn test_is_model_valid_small_file() {
        // Process-unique name so concurrent runs (or other users) sharing the
        // temp dir don't clobber or block each other's fixture.
        let tmp = std::env::temp_dir().join(format!("test-tiny-{}.bin", std::process::id()));
        std::fs::write(&tmp, b"tiny").unwrap();
        let valid = is_model_valid(&tmp);
        // Clean up before asserting so a failure doesn't leak the file.
        std::fs::remove_file(&tmp).ok();
        assert!(!valid);
    }

    #[test]
    fn test_strip_timestamps_with_arrow() {
        let input = "[00:00:00.000 --> 00:00:05.000] Hello world";
        let result = strip_timestamps(input);
        assert!(!result.contains("-->"));
        assert!(result.contains("Hello world"));
    }

    #[test]
    fn test_strip_timestamps_no_timestamps() {
        let input = "Hello world";
        assert_eq!(strip_timestamps(input), "Hello world");
    }

    #[test]
    fn test_strip_timestamps_preserves_non_timestamp_brackets() {
        let input = "Hello [world]";
        let result = strip_timestamps(input);
        assert!(result.contains("[world]"));
    }

    #[test]
    fn test_strip_hallucinations_blank_audio() {
        assert_eq!(strip_hallucinations("[BLANK_AUDIO]").trim(), "");
    }

    #[test]
    fn test_strip_hallucinations_case_insensitive() {
        assert_eq!(strip_hallucinations("[blank_audio]").trim(), "");
        assert_eq!(strip_hallucinations("[Blank_Audio]").trim(), "");
    }

    #[test]
    fn test_strip_hallucinations_preserves_real_text() {
        assert_eq!(strip_hallucinations("hello world"), "hello world");
    }

    #[test]
    fn test_strip_hallucinations_mixed() {
        let result = strip_hallucinations("[BLANK_AUDIO] hello [BLANK_AUDIO]");
        assert_eq!(result.trim(), "hello");
    }
}