opencode_voice/transcribe/
setup.rs1use anyhow::{bail, Context, Result};
7use futures::StreamExt;
8use indicatif::{ProgressBar, ProgressStyle};
9use std::path::PathBuf;
10use tokio::fs as tokio_fs;
11use tokio::io::AsyncWriteExt;
12
13use crate::config::ModelSize;
14use crate::transcribe::engine::is_model_valid;
15
16const HUGGINGFACE_BASE_URL: &str =
18 "https://huggingface.co/ggerganov/whisper.cpp/resolve/main";
19
20pub fn get_model_path(data_dir: &PathBuf, model_size: &ModelSize) -> PathBuf {
25 data_dir
26 .join("models")
27 .join(format!("ggml-{}.bin", model_size))
28}
29
30pub fn is_whisper_ready(data_dir: &PathBuf, model_size: &ModelSize) -> bool {
34 let path = get_model_path(data_dir, model_size);
35 is_model_valid(&path)
36}
37
38pub async fn download_model(data_dir: &PathBuf, model_size: &ModelSize) -> Result<()> {
45 let models_dir = data_dir.join("models");
46 tokio_fs::create_dir_all(&models_dir)
47 .await
48 .with_context(|| format!("Failed to create models directory: {}", models_dir.display()))?;
49
50 let model_filename = format!("ggml-{}.bin", model_size);
51 let url = format!("{}/{}", HUGGINGFACE_BASE_URL, model_filename);
52 let final_path = models_dir.join(&model_filename);
53 let tmp_path = models_dir.join(format!("{}.tmp", model_filename));
54
55 println!("Downloading whisper model {} from HuggingFace…", model_size);
56 println!(" URL: {}", url);
57
58 let client = reqwest::Client::new();
59 let response = client
60 .get(&url)
61 .send()
62 .await
63 .with_context(|| format!("Failed to connect to {}", url))?;
64
65 if !response.status().is_success() {
66 bail!(
67 "Download failed: HTTP {} for {}",
68 response.status(),
69 url
70 );
71 }
72
73 let total_bytes = response.content_length();
75
76 let pb = ProgressBar::new(total_bytes.unwrap_or(0));
77 pb.set_style(
78 ProgressStyle::with_template(
79 "{percent}% [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})",
80 )
81 .unwrap_or_else(|_| ProgressStyle::default_bar())
82 .progress_chars("=>-"),
83 );
84
85 let mut tmp_file = tokio_fs::File::create(&tmp_path)
87 .await
88 .with_context(|| format!("Failed to create temp file: {}", tmp_path.display()))?;
89
90 let mut stream = response.bytes_stream();
91 while let Some(chunk) = stream.next().await {
92 let chunk = chunk.with_context(|| "Error reading download stream")?;
93 tmp_file
94 .write_all(&chunk)
95 .await
96 .with_context(|| "Failed to write chunk to temp file")?;
97 pb.inc(chunk.len() as u64);
98 }
99
100 tmp_file
101 .flush()
102 .await
103 .with_context(|| "Failed to flush temp file")?;
104 drop(tmp_file);
105
106 pb.finish_with_message("Download complete");
107
108 tokio_fs::rename(&tmp_path, &final_path)
110 .await
111 .with_context(|| {
112 format!(
113 "Failed to rename {} to {}",
114 tmp_path.display(),
115 final_path.display()
116 )
117 })?;
118
119 if !is_model_valid(&final_path) {
121 tokio_fs::remove_file(&final_path).await.ok();
123 bail!(
124 "Downloaded model file is invalid or too small (< 1MB): {}",
125 final_path.display()
126 );
127 }
128
129 println!("Model saved to {}", final_path.display());
130 Ok(())
131}
132
133pub async fn setup_whisper(data_dir: &PathBuf, model_size: &ModelSize) -> Result<()> {
138 if is_whisper_ready(data_dir, model_size) {
139 let path = get_model_path(data_dir, model_size);
140 println!(
141 "Whisper model already present at {}. Skipping download.",
142 path.display()
143 );
144 return Ok(());
145 }
146
147 download_model(data_dir, model_size).await
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Returns a unique path under the system temp dir that is never created.
    fn nonexistent_data_dir() -> PathBuf {
        std::env::temp_dir().join(format!(
            "opencode-voice-test-{}",
            uuid::Uuid::new_v4()
        ))
    }

    /// Creates a unique data dir containing `models/<file_name>` filled with
    /// `contents` and returns the data dir. The caller removes it afterwards.
    fn data_dir_with_model(label: &str, file_name: &str, contents: &[u8]) -> PathBuf {
        let data_dir = std::env::temp_dir().join(format!(
            "opencode-voice-test-{}-{}",
            label,
            uuid::Uuid::new_v4()
        ));
        let models_dir = data_dir.join("models");
        std::fs::create_dir_all(&models_dir).unwrap();
        std::fs::write(models_dir.join(file_name), contents).unwrap();
        data_dir
    }

    #[test]
    fn test_get_model_path() {
        let data_dir = PathBuf::from("/tmp/opencode-voice");
        let path = get_model_path(&data_dir, &ModelSize::TinyEn);
        assert_eq!(path, PathBuf::from("/tmp/opencode-voice/models/ggml-tiny.en.bin"));

        let path_base = get_model_path(&data_dir, &ModelSize::BaseEn);
        assert_eq!(path_base, PathBuf::from("/tmp/opencode-voice/models/ggml-base.en.bin"));

        let path_small = get_model_path(&data_dir, &ModelSize::SmallEn);
        assert_eq!(path_small, PathBuf::from("/tmp/opencode-voice/models/ggml-small.en.bin"));
    }

    #[test]
    fn test_is_whisper_ready_missing_file() {
        let data_dir = nonexistent_data_dir();
        assert!(!is_whisper_ready(&data_dir, &ModelSize::TinyEn));
    }

    #[test]
    fn test_is_whisper_ready_small_file() {
        let tmp_dir = data_dir_with_model("small", "ggml-tiny.en.bin", b"this is way too small");

        // Capture the result first so the temp dir is removed even when the
        // assertion below fails.
        let ready = is_whisper_ready(&tmp_dir, &ModelSize::TinyEn);
        std::fs::remove_dir_all(&tmp_dir).ok();

        assert!(!ready);
    }

    #[test]
    fn test_is_whisper_ready_valid_file() {
        // 1_000_001 zero bytes: a file this size is expected to be accepted.
        let big_data = vec![0u8; 1_000_001];
        let tmp_dir = data_dir_with_model("valid", "ggml-base.en.bin", &big_data);

        // Capture the result first so the temp dir is removed even when the
        // assertion below fails.
        let ready = is_whisper_ready(&tmp_dir, &ModelSize::BaseEn);
        std::fs::remove_dir_all(&tmp_dir).ok();

        assert!(ready);
    }

    #[test]
    fn test_get_model_path_contains_model_size() {
        let data_dir = PathBuf::from("/data");
        for (size, expected_fragment) in [
            (ModelSize::TinyEn, "tiny.en"),
            (ModelSize::BaseEn, "base.en"),
            (ModelSize::SmallEn, "small.en"),
        ] {
            let path = get_model_path(&data_dir, &size);
            let path_str = path.to_string_lossy();
            assert!(
                path_str.contains(expected_fragment),
                "Expected path to contain '{}', got '{}'",
                expected_fragment,
                path_str
            );
            assert!(
                path_str.ends_with(".bin"),
                "Expected path to end with '.bin', got '{}'",
                path_str
            );
        }
    }
}