Skip to main content

opencode_voice/transcribe/
setup.rs

1//! Whisper model download and setup functions.
2//!
//! Handles downloading GGML model files from HuggingFace and sanity-checking
//! them (the file must exist and be larger than 1MB) before use.
5
6use anyhow::{bail, Context, Result};
7use futures::StreamExt;
8use indicatif::{ProgressBar, ProgressStyle};
9use std::path::PathBuf;
10use tokio::fs as tokio_fs;
11use tokio::io::AsyncWriteExt;
12
13use crate::config::ModelSize;
14use crate::transcribe::engine::is_model_valid;
15
/// Root of the `whisper.cpp` model repository on HuggingFace. Individual model
/// files are fetched from `{HUGGINGFACE_BASE_URL}/ggml-{model_size}.bin`.
const HUGGINGFACE_BASE_URL: &str = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main";
19
20/// Returns the filesystem path where the whisper model file should be stored.
21///
22/// Models are stored as `ggml-{model_size}.bin` inside the `models/` subdirectory
23/// of the application data directory.
24pub fn get_model_path(data_dir: &PathBuf, model_size: &ModelSize) -> PathBuf {
25    data_dir
26        .join("models")
27        .join(format!("ggml-{}.bin", model_size))
28}
29
30/// Returns `true` if the whisper model file exists and is valid (> 1MB).
31///
32/// Uses [`is_model_valid`] from the engine module to check file existence and size.
33pub fn is_whisper_ready(data_dir: &PathBuf, model_size: &ModelSize) -> bool {
34    let path = get_model_path(data_dir, model_size);
35    is_model_valid(&path)
36}
37
38/// Downloads the whisper GGML model from HuggingFace with a progress bar.
39///
40/// The file is first written to a temporary path and then atomically renamed
41/// to the final destination. Creates the `models/` directory if it does not
42/// exist. Returns an error if the download fails or the resulting file is
43/// smaller than 1MB.
44pub async fn download_model(data_dir: &PathBuf, model_size: &ModelSize) -> Result<()> {
45    let models_dir = data_dir.join("models");
46    tokio_fs::create_dir_all(&models_dir)
47        .await
48        .with_context(|| format!("Failed to create models directory: {}", models_dir.display()))?;
49
50    let model_filename = format!("ggml-{}.bin", model_size);
51    let url = format!("{}/{}", HUGGINGFACE_BASE_URL, model_filename);
52    let final_path = models_dir.join(&model_filename);
53    let tmp_path = models_dir.join(format!("{}.tmp", model_filename));
54
55    println!("Downloading whisper model {} from HuggingFace…", model_size);
56    println!("  URL: {}", url);
57
58    let client = reqwest::Client::new();
59    let response = client
60        .get(&url)
61        .send()
62        .await
63        .with_context(|| format!("Failed to connect to {}", url))?;
64
65    if !response.status().is_success() {
66        bail!(
67            "Download failed: HTTP {} for {}",
68            response.status(),
69            url
70        );
71    }
72
73    // Use Content-Length header to set up the progress bar total.
74    let total_bytes = response.content_length();
75
76    let pb = ProgressBar::new(total_bytes.unwrap_or(0));
77    pb.set_style(
78        ProgressStyle::with_template(
79            "{percent}% [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})",
80        )
81        .unwrap_or_else(|_| ProgressStyle::default_bar())
82        .progress_chars("=>-"),
83    );
84
85    // Stream response body to a temporary file.
86    let mut tmp_file = tokio_fs::File::create(&tmp_path)
87        .await
88        .with_context(|| format!("Failed to create temp file: {}", tmp_path.display()))?;
89
90    let mut stream = response.bytes_stream();
91    while let Some(chunk) = stream.next().await {
92        let chunk = chunk.with_context(|| "Error reading download stream")?;
93        tmp_file
94            .write_all(&chunk)
95            .await
96            .with_context(|| "Failed to write chunk to temp file")?;
97        pb.inc(chunk.len() as u64);
98    }
99
100    tmp_file
101        .flush()
102        .await
103        .with_context(|| "Failed to flush temp file")?;
104    drop(tmp_file);
105
106    pb.finish_with_message("Download complete");
107
108    // Atomically rename temp file to final destination.
109    tokio_fs::rename(&tmp_path, &final_path)
110        .await
111        .with_context(|| {
112            format!(
113                "Failed to rename {} to {}",
114                tmp_path.display(),
115                final_path.display()
116            )
117        })?;
118
119    // Verify the downloaded file is valid.
120    if !is_model_valid(&final_path) {
121        // Clean up the bad file before returning an error.
122        tokio_fs::remove_file(&final_path).await.ok();
123        bail!(
124            "Downloaded model file is invalid or too small (< 1MB): {}",
125            final_path.display()
126        );
127    }
128
129    println!("Model saved to {}", final_path.display());
130    Ok(())
131}
132
133/// Ensures the whisper model is present and valid, downloading it if necessary.
134///
135/// If the model already exists and passes validation, this function returns
136/// immediately without downloading. Otherwise it calls [`download_model`].
137pub async fn setup_whisper(data_dir: &PathBuf, model_size: &ModelSize) -> Result<()> {
138    if is_whisper_ready(data_dir, model_size) {
139        let path = get_model_path(data_dir, model_size);
140        println!(
141            "Whisper model already present at {}. Skipping download.",
142            path.display()
143        );
144        return Ok(());
145    }
146
147    download_model(data_dir, model_size).await
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Helper that returns a temporary directory path that does not exist on disk.
    fn nonexistent_data_dir() -> PathBuf {
        std::env::temp_dir().join(format!("opencode-voice-test-{}", uuid::Uuid::new_v4()))
    }

    /// A uniquely named data directory, with a `models/` subdirectory, under the
    /// system temp dir.
    ///
    /// The directory is removed on drop, so cleanup also happens when an
    /// assertion fails and the test panics (unwinding runs `Drop`).
    struct TempDataDir(PathBuf);

    impl TempDataDir {
        /// Creates `opencode-voice-test-<label>-<uuid>/models` under the temp dir.
        fn new(label: &str) -> Self {
            let root = std::env::temp_dir().join(format!(
                "opencode-voice-test-{}-{}",
                label,
                uuid::Uuid::new_v4()
            ));
            std::fs::create_dir_all(root.join("models"))
                .expect("failed to create temp models dir");
            TempDataDir(root)
        }

        /// Writes `contents` to `models/<file_name>` inside this directory.
        fn write_model(&self, file_name: &str, contents: &[u8]) {
            std::fs::write(self.0.join("models").join(file_name), contents)
                .expect("failed to write test model file");
        }
    }

    impl Drop for TempDataDir {
        fn drop(&mut self) {
            // Best effort: a cleanup failure must not cause a second panic.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_get_model_path() {
        let data_dir = PathBuf::from("/tmp/opencode-voice");
        let path = get_model_path(&data_dir, &ModelSize::TinyEn);
        assert_eq!(path, PathBuf::from("/tmp/opencode-voice/models/ggml-tiny.en.bin"));

        let path_base = get_model_path(&data_dir, &ModelSize::BaseEn);
        assert_eq!(path_base, PathBuf::from("/tmp/opencode-voice/models/ggml-base.en.bin"));

        let path_small = get_model_path(&data_dir, &ModelSize::SmallEn);
        assert_eq!(path_small, PathBuf::from("/tmp/opencode-voice/models/ggml-small.en.bin"));
    }

    #[test]
    fn test_is_whisper_ready_missing_file() {
        let data_dir = nonexistent_data_dir();
        // Directory and file do not exist — should return false.
        assert!(!is_whisper_ready(&data_dir, &ModelSize::TinyEn));
    }

    #[test]
    fn test_is_whisper_ready_small_file() {
        // A real but tiny file (< 1MB) must be rejected.
        let data_dir = TempDataDir::new("small");
        data_dir.write_model("ggml-tiny.en.bin", b"this is way too small");

        assert!(!is_whisper_ready(&data_dir.0, &ModelSize::TinyEn));
    }

    #[test]
    fn test_is_whisper_ready_valid_file() {
        // One byte over 1MB (counting 1MB as 1_000_000 bytes): should be accepted.
        let data_dir = TempDataDir::new("valid");
        data_dir.write_model("ggml-base.en.bin", &vec![0u8; 1_000_001]);

        assert!(is_whisper_ready(&data_dir.0, &ModelSize::BaseEn));
    }

    #[test]
    fn test_get_model_path_contains_model_size() {
        let data_dir = PathBuf::from("/data");
        for (size, expected_fragment) in [
            (ModelSize::TinyEn, "tiny.en"),
            (ModelSize::BaseEn, "base.en"),
            (ModelSize::SmallEn, "small.en"),
        ] {
            let path = get_model_path(&data_dir, &size);
            let path_str = path.to_string_lossy();
            assert!(
                path_str.contains(expected_fragment),
                "Expected path to contain '{}', got '{}'",
                expected_fragment,
                path_str
            );
            assert!(
                path_str.ends_with(".bin"),
                "Expected path to end with '.bin', got '{}'",
                path_str
            );
        }
    }
}