opencode_voice/transcribe/
engine.rs1use anyhow::{Context, Result};
6use std::path::Path;
7
/// Text produced by a whisper transcription.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TranscriptionResult {
    /// The transcribed text with timestamp and non-speech markers removed and
    /// surrounding whitespace trimmed; empty when the audio contained no samples.
    pub text: String,
}
12
/// Returns `true` if `path` names a regular file larger than 1,000,000 bytes.
///
/// This is a cheap size-based sanity check only; it does not parse or validate the
/// model contents. Any error reading the metadata (missing path, permission denied)
/// yields `false`, as does a path that is not a regular file (e.g. a directory).
pub fn is_model_valid(path: &Path) -> bool {
    /// Files at or below this size are rejected as not a real model.
    const MIN_MODEL_BYTES: u64 = 1_000_000;

    // A single metadata call covers the existence check too: it errors on a missing path.
    path.metadata()
        .map(|meta| meta.is_file() && meta.len() > MIN_MODEL_BYTES)
        .unwrap_or(false)
}
21
22pub struct WhisperEngine {
24 ctx: whisper_rs::WhisperContext,
25 multilingual: bool,
27}
28
29impl WhisperEngine {
30 pub fn new(model_path: &Path, multilingual: bool) -> Result<Self> {
35 if !model_path.exists() {
36 anyhow::bail!(
37 "Whisper model not found at {}. Run 'opencode-voice setup' to download it.",
38 model_path.display()
39 );
40 }
41
42 let path_str = model_path
43 .to_str()
44 .context("Model path contains invalid UTF-8")?;
45
46 suppress_whisper_logging();
50
51 let ctx = whisper_rs::WhisperContext::new_with_params(
52 path_str,
53 whisper_rs::WhisperContextParameters::default(),
54 )
55 .map_err(|e| anyhow::anyhow!("Failed to load whisper model: {:?}", e))?;
56
57 Ok(WhisperEngine { ctx, multilingual })
58 }
59
60 pub fn transcribe(&self, wav_path: &Path) -> Result<TranscriptionResult> {
65 let mut reader = hound::WavReader::open(wav_path)
67 .with_context(|| format!("Failed to open WAV file: {}", wav_path.display()))?;
68
69 let samples: Vec<f32> = reader
71 .samples::<i16>()
72 .filter_map(|s| s.ok())
73 .map(|s| s as f32 / 32768.0)
74 .collect();
75
76 if samples.is_empty() {
77 return Ok(TranscriptionResult {
78 text: String::new(),
79 });
80 }
81
82 let mut params =
84 whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
85 params.set_print_special(false);
86 params.set_print_progress(false);
87 params.set_print_realtime(false);
88 params.set_print_timestamps(false);
89 params.set_no_timestamps(true);
90 params.set_single_segment(false);
91
92 if self.multilingual {
95 params.set_language(Some("en"));
96 }
97
98 let mut state = self
100 .ctx
101 .create_state()
102 .map_err(|e| anyhow::anyhow!("Failed to create whisper state: {:?}", e))?;
103
104 state
105 .full(params, &samples)
106 .map_err(|e| anyhow::anyhow!("Whisper transcription failed: {:?}", e))?;
107
108 let num_segments = state
110 .full_n_segments()
111 .map_err(|e| anyhow::anyhow!("Failed to get segment count: {:?}", e))?;
112
113 let mut text_parts: Vec<String> = Vec::new();
114 for i in 0..num_segments {
115 if let Ok(segment) = state.full_get_segment_text(i) {
116 text_parts.push(segment);
117 }
118 }
119
120 let raw_text = text_parts.join(" ");
121
122 let clean_text = strip_timestamps(&raw_text);
124 let clean_text = strip_hallucinations(&clean_text);
126 let final_text = clean_text.trim().to_string();
127
128 Ok(TranscriptionResult { text: final_text })
129 }
130}
131
132fn suppress_whisper_logging() {
137 unsafe {
138 unsafe extern "C" fn noop_log(
140 _level: whisper_rs::whisper_rs_sys::ggml_log_level,
141 _text: *const std::ffi::c_char,
142 _user_data: *mut std::ffi::c_void,
143 ) {
144 }
145 whisper_rs::whisper_rs_sys::whisper_log_set(Some(noop_log), std::ptr::null_mut());
146 whisper_rs::whisper_rs_sys::ggml_log_set(Some(noop_log), std::ptr::null_mut());
147 }
148}
149
/// Marker strings whisper emits for silent or non-speech audio instead of real words.
///
/// Removed (ignoring ASCII case) by [`strip_hallucinations`], in the order listed.
const WHISPER_HALLUCINATIONS: &[&str] = &[
    "[BLANK_AUDIO]",
    "[NO_SPEECH]",
    "(blank audio)",
    "(no speech)",
    "[silence]",
    "(silence)",
];

/// Removes every occurrence of the [`WHISPER_HALLUCINATIONS`] markers from `text`.
///
/// Matching ignores ASCII case. Whitespace around a removed marker is left in place,
/// so callers trim or normalize the result. After each removal the search restarts
/// from the beginning, so a marker formed by splicing text around a removed one is
/// removed as well.
fn strip_hallucinations(text: &str) -> String {
    let mut result = text.to_string();
    for pattern in WHISPER_HALLUCINATIONS {
        // Offsets are found in `result` itself, so they are always valid for it —
        // unlike offsets from a `to_lowercase()` copy, whose byte length can differ
        // from the original for non-ASCII text.
        while let Some(pos) = find_ignore_ascii_case(&result, pattern) {
            result.replace_range(pos..pos + pattern.len(), "");
        }
    }
    result
}

/// Returns the byte offset of the first ASCII-case-insensitive occurrence of
/// `needle` in `haystack`, or `None` if there is none.
///
/// ASCII bytes only fold to ASCII bytes, so any match lines up with whole
/// characters and the returned range is safe to slice. An empty `needle` never
/// matches, which keeps removal loops from spinning forever.
fn find_ignore_ascii_case(haystack: &str, needle: &str) -> Option<usize> {
    let needle = needle.as_bytes();
    if needle.is_empty() {
        return None;
    }
    haystack
        .as_bytes()
        .windows(needle.len())
        .position(|window| window.eq_ignore_ascii_case(needle))
}
177
/// Removes bracketed whisper timestamp ranges such as `[00:00:00.000 --> 00:00:05.000]`.
///
/// A bracketed span runs from a `[` to the next `]`. It is removed when its contents
/// contain `-->`; any other bracketed span is kept verbatim and scanning continues
/// after it. Text around a removed span, including whitespace, is left untouched.
/// An unterminated `[` stops the scan and the remainder is kept as-is.
fn strip_timestamps(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut rest = text;

    while let Some(open) = rest.find('[') {
        let Some(close_offset) = rest[open..].find(']') else {
            break;
        };
        let close = open + close_offset;

        out.push_str(&rest[..open]);
        // Keep non-timestamp brackets (e.g. `[BLANK_AUDIO]`, literal text) but continue
        // past them, so timestamps appearing later are still removed.
        if !rest[open + 1..close].contains("-->") {
            out.push_str(&rest[open..=close]);
        }
        rest = &rest[close + 1..];
    }

    out.push_str(rest);
    out
}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a temp-file path unique to this process, so concurrent test runs that
    /// share the system temp directory cannot clobber each other's fixtures.
    fn temp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("opencode-voice-{}-{}", std::process::id(), name))
    }

    #[test]
    fn test_is_model_valid_nonexistent() {
        assert!(!is_model_valid(Path::new("/nonexistent/path/model.bin")));
    }

    #[test]
    fn test_is_model_valid_small_file() {
        let tmp = temp_path("tiny.bin");
        std::fs::write(&tmp, b"tiny").unwrap();
        let valid = is_model_valid(&tmp);
        // Clean up before asserting so a failure does not leave the fixture behind.
        std::fs::remove_file(&tmp).ok();
        assert!(!valid);
    }

    #[test]
    fn test_is_model_valid_large_file() {
        let tmp = temp_path("large.bin");
        std::fs::write(&tmp, vec![0u8; 1_000_001]).unwrap();
        let valid = is_model_valid(&tmp);
        std::fs::remove_file(&tmp).ok();
        assert!(valid);
    }

    #[test]
    fn test_strip_timestamps_with_arrow() {
        let result = strip_timestamps("[00:00:00.000 --> 00:00:05.000] Hello world");
        assert!(!result.contains("-->"));
        assert_eq!(result.trim(), "Hello world");
    }

    #[test]
    fn test_strip_timestamps_no_timestamps() {
        assert_eq!(strip_timestamps("Hello world"), "Hello world");
    }

    #[test]
    fn test_strip_timestamps_preserves_non_timestamp_brackets() {
        assert_eq!(strip_timestamps("Hello [world]"), "Hello [world]");
    }

    #[test]
    fn test_strip_hallucinations_blank_audio() {
        assert_eq!(strip_hallucinations("[BLANK_AUDIO]").trim(), "");
    }

    #[test]
    fn test_strip_hallucinations_case_insensitive() {
        assert_eq!(strip_hallucinations("[blank_audio]").trim(), "");
        assert_eq!(strip_hallucinations("[Blank_Audio]").trim(), "");
    }

    #[test]
    fn test_strip_hallucinations_other_markers() {
        assert_eq!(strip_hallucinations("(no speech) ok [Silence]").trim(), "ok");
    }

    #[test]
    fn test_strip_hallucinations_preserves_real_text() {
        assert_eq!(strip_hallucinations("hello world"), "hello world");
    }

    #[test]
    fn test_strip_hallucinations_mixed() {
        let result = strip_hallucinations("[BLANK_AUDIO] hello [BLANK_AUDIO]");
        assert_eq!(result.trim(), "hello");
    }
}