opencode_voice/transcribe/
engine.rs1use anyhow::{Context, Result};
6use std::path::Path;
7
/// Output of a single [`WhisperEngine::transcribe`] call.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TranscriptionResult {
    /// Recognised text, trimmed, with segments joined by single spaces.
    /// Empty when the input audio contained no samples.
    pub text: String,
}
12
/// Returns `true` if `path` looks like a usable Whisper model file.
///
/// This is a cheap sanity check, not a format validation. The path must resolve to
/// a regular file (symlinks are followed) strictly larger than 1 MB. Missing paths,
/// directories and tiny placeholder files are rejected. A tiny file is presumably a
/// failed or partial download. The contents are never inspected.
pub fn is_model_valid(path: &Path) -> bool {
    /// Files at or below this many bytes are treated as stubs rather than real models.
    const MIN_MODEL_BYTES: u64 = 1_000_000;

    // A single `metadata` call covers existence, file type and size; a separate
    // `exists()` check would just stat the path twice.
    path.metadata()
        .map(|meta| meta.is_file() && meta.len() > MIN_MODEL_BYTES)
        .unwrap_or(false)
}
21
22pub struct WhisperEngine {
24 ctx: whisper_rs::WhisperContext,
25}
26
27impl WhisperEngine {
28 pub fn new(model_path: &Path) -> Result<Self> {
30 if !model_path.exists() {
31 anyhow::bail!(
32 "Whisper model not found at {}. Run 'opencode-voice setup' to download it.",
33 model_path.display()
34 );
35 }
36
37 let path_str = model_path
38 .to_str()
39 .context("Model path contains invalid UTF-8")?;
40
41 suppress_whisper_logging();
45
46 let ctx = whisper_rs::WhisperContext::new_with_params(
47 path_str,
48 whisper_rs::WhisperContextParameters::default(),
49 )
50 .map_err(|e| anyhow::anyhow!("Failed to load whisper model: {:?}", e))?;
51
52 Ok(WhisperEngine { ctx })
53 }
54
55 pub fn transcribe(&self, wav_path: &Path) -> Result<TranscriptionResult> {
60 let mut reader = hound::WavReader::open(wav_path)
62 .with_context(|| format!("Failed to open WAV file: {}", wav_path.display()))?;
63
64 let samples: Vec<f32> = reader
66 .samples::<i16>()
67 .filter_map(|s| s.ok())
68 .map(|s| s as f32 / 32768.0)
69 .collect();
70
71 if samples.is_empty() {
72 return Ok(TranscriptionResult {
73 text: String::new(),
74 });
75 }
76
77 let mut params =
79 whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
80 params.set_print_special(false);
81 params.set_print_progress(false);
82 params.set_print_realtime(false);
83 params.set_print_timestamps(false);
84 params.set_no_timestamps(true);
85 params.set_single_segment(false);
86
87 let mut state = self
89 .ctx
90 .create_state()
91 .map_err(|e| anyhow::anyhow!("Failed to create whisper state: {:?}", e))?;
92
93 state
94 .full(params, &samples)
95 .map_err(|e| anyhow::anyhow!("Whisper transcription failed: {:?}", e))?;
96
97 let num_segments = state
99 .full_n_segments()
100 .map_err(|e| anyhow::anyhow!("Failed to get segment count: {:?}", e))?;
101
102 let mut text_parts: Vec<String> = Vec::new();
103 for i in 0..num_segments {
104 if let Ok(segment) = state.full_get_segment_text(i) {
105 text_parts.push(segment);
106 }
107 }
108
109 let raw_text = text_parts.join(" ");
110
111 let clean_text = strip_timestamps(&raw_text);
113 let final_text = clean_text.trim().to_string();
114
115 Ok(TranscriptionResult { text: final_text })
116 }
117}
118
119fn suppress_whisper_logging() {
124 unsafe {
125 unsafe extern "C" fn noop_log(
127 _level: whisper_rs::whisper_rs_sys::ggml_log_level,
128 _text: *const std::ffi::c_char,
129 _user_data: *mut std::ffi::c_void,
130 ) {
131 }
132 whisper_rs::whisper_rs_sys::whisper_log_set(Some(noop_log), std::ptr::null_mut());
133 whisper_rs::whisper_rs_sys::ggml_log_set(Some(noop_log), std::ptr::null_mut());
134 }
135}
136
/// Removes timestamp markers such as `[00:00:00.000 --> 00:00:05.000]` from `text`.
///
/// Every `[...]` group whose contents contain `-->` is deleted, wherever it appears.
/// Other bracketed text (e.g. `[world]`) is kept verbatim, and so is everything after
/// an unmatched `[`. A group runs from a `[` to the first `]` after it; nested
/// brackets are not balanced. Surrounding whitespace is left untouched, so callers
/// may want to trim the result.
///
/// The input is scanned once, so the cost is linear in its length.
fn strip_timestamps(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut rest = text;

    while let Some(open) = rest.find('[') {
        // `[` and `]` are single-byte ASCII, so the byte offsets around them are
        // always valid char boundaries.
        let after_open = &rest[open + 1..];
        let close = match after_open.find(']') {
            Some(close) => close,
            // Unmatched `[`: keep the remainder verbatim.
            None => break,
        };

        out.push_str(&rest[..open]);
        if !after_open[..close].contains("-->") {
            // Not a timestamp: keep the whole `[...]` group, including both brackets.
            out.push_str(&rest[open..open + close + 2]);
        }
        rest = &after_open[close + 1..];
    }

    out.push_str(rest);
    out
}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// A temporary file that is removed when dropped, even if the test panics first.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        /// Writes `contents` to a temp path that includes the process id and `name`.
        /// Parallel tests and concurrent test runs therefore never share a file.
        fn with_contents(name: &str, contents: &[u8]) -> Self {
            let path = std::env::temp_dir()
                .join(format!("opencode-voice-test-{}-{}", std::process::id(), name));
            std::fs::write(&path, contents).expect("failed to write temp file");
            TempFile(path)
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp file is not worth failing a test over.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_is_model_valid_nonexistent() {
        assert!(!is_model_valid(Path::new("/nonexistent/path/model.bin")));
    }

    #[test]
    fn test_is_model_valid_small_file() {
        let tmp = TempFile::with_contents("tiny.bin", b"tiny");
        assert!(!is_model_valid(&tmp.0));
    }

    #[test]
    fn test_is_model_valid_large_file() {
        let contents = vec![0u8; 1_000_001];
        let tmp = TempFile::with_contents("large.bin", &contents);
        assert!(is_model_valid(&tmp.0));
    }

    #[test]
    fn test_strip_timestamps_with_arrow() {
        let result = strip_timestamps("[00:00:00.000 --> 00:00:05.000] Hello world");
        assert_eq!(result, " Hello world");
    }

    #[test]
    fn test_strip_timestamps_consecutive_timestamps() {
        assert_eq!(strip_timestamps("[0 --> 1] a [1 --> 2] b"), " a  b");
    }

    #[test]
    fn test_strip_timestamps_no_timestamps() {
        assert_eq!(strip_timestamps("Hello world"), "Hello world");
    }

    #[test]
    fn test_strip_timestamps_preserves_non_timestamp_brackets() {
        assert_eq!(strip_timestamps("Hello [world]"), "Hello [world]");
    }

    #[test]
    fn test_strip_timestamps_unmatched_bracket() {
        assert_eq!(strip_timestamps("open [bracket"), "open [bracket");
    }
}