Skip to main content

mofa_plugins/tts/
model_downloader.rs

1//! Hugging Face Model Downloader
2//!
3//! Handles downloading TTS models from Hugging Face Hub
4//! with progress tracking, checksum validation, and retry logic.
5
6use super::cache::{ModelCache, ModelMetadata};
7use anyhow::{Context, Result, anyhow, bail};
8use backoff::ExponentialBackoff;
9use backoff::future::retry;
10use futures::stream::StreamExt;
11use md5::{Digest, Md5};
12use reqwest::Client;
13use std::fs;
14use std::io::Read;
15use std::time::{Duration, SystemTime};
16use tokio::io::AsyncWriteExt;
17use tokio::sync::Semaphore;
18use tracing::{debug, error, info, warn};
19
/// Download progress callback type.
///
/// The downloader invokes the callback after every chunk it writes, passing
/// `(bytes_downloaded_so_far, total_bytes)`. `total_bytes` is `0` when the
/// response does not report a content length.
pub type ProgressCallback = Box<dyn Fn(u64, u64) + Send + Sync>;
22
/// Model downloader configuration.
///
/// Describes which file to fetch from the hub and how the download should be
/// validated, retried, and reported.
pub struct DownloadConfig {
    /// Hugging Face model ID (e.g., "hexgrad/Kokoro-82M").
    ///
    /// Also used as the key for cache lookups and for the stored metadata.
    pub model_id: String,
    /// Specific file to download (e.g., "kokoro-v0_19.onnx").
    pub filename: String,
    /// Expected checksum (MD5 hex string).
    ///
    /// When `None`, the downloaded file is not validated against a checksum.
    pub checksum: Option<String>,
    /// Download timeout in seconds.
    ///
    /// Used as the maximum total elapsed time of the retry loop (the backoff
    /// policy's `max_elapsed_time`).
    pub timeout_secs: u64,
    /// Maximum retry attempts.
    pub max_retries: u32,
    /// Progress callback (optional).
    ///
    /// Not preserved when the configuration is cloned.
    pub progress_callback: Option<ProgressCallback>,
}
38
39impl Default for DownloadConfig {
40    fn default() -> Self {
41        Self {
42            model_id: "hexgrad/Kokoro-82M".to_string(),
43            filename: "kokoro-v0_19.onnx".to_string(),
44            checksum: None,
45            timeout_secs: 600,
46            max_retries: 3,
47            progress_callback: None,
48        }
49    }
50}
51
52impl Clone for DownloadConfig {
53    fn clone(&self) -> Self {
54        Self {
55            model_id: self.model_id.clone(),
56            filename: self.filename.clone(),
57            checksum: self.checksum.clone(),
58            timeout_secs: self.timeout_secs,
59            max_retries: self.max_retries,
60            progress_callback: None, // Callbacks cannot be cloned
61        }
62    }
63}
64
/// Hugging Face Hub API client.
///
/// Holds the HTTP client, the hub base URL, and a semaphore that bounds how
/// many downloads run at the same time.
pub struct HFHubClient {
    /// HTTP client used for all download requests.
    client: Client,
    /// Base URL that download URLs are built from (the hub itself or a mirror).
    api_base: String,
    /// Limits concurrent downloads; `download_model` holds one permit for the
    /// duration of each call. The constructors create it with three permits.
    download_semaphore: std::sync::Arc<Semaphore>,
}
71
72impl HFHubClient {
73    /// Create a new Hugging Face Hub client
74    pub fn new() -> Self {
75        Self {
76            client: Client::builder()
77                .timeout(Duration::from_secs(600))
78                .build()
79                .expect("Failed to create HTTP client"),
80            api_base: "https://huggingface.co".to_string(),
81            download_semaphore: std::sync::Arc::new(Semaphore::new(3)), // Max 3 concurrent downloads
82        }
83    }
84
85    /// Create client with custom API base (useful for mirrors)
86    pub fn with_api_base(api_base: String) -> Self {
87        Self {
88            client: Client::builder()
89                .timeout(Duration::from_secs(600))
90                .build()
91                .expect("Failed to create HTTP client"),
92            api_base,
93            download_semaphore: std::sync::Arc::new(Semaphore::new(3)),
94        }
95    }
96
97    /// Get direct download URL for a model file
98    pub fn get_download_url(&self, model_id: &str, filename: &str) -> String {
99        format!("{}/{}/resolve/main/{}", self.api_base, model_id, filename)
100    }
101
102    /// Download model file with progress tracking
103    pub async fn download_model(
104        &self,
105        config: DownloadConfig,
106        cache: &ModelCache,
107    ) -> Result<std::path::PathBuf> {
108        let _permit = self
109            .download_semaphore
110            .acquire()
111            .await
112            .map_err(|e| anyhow!("Failed to acquire download semaphore: {}", e))?;
113
114        info!(
115            "Starting download: {} / {}",
116            config.model_id, config.filename
117        );
118
119        // Check if model already exists and is valid
120        if cache.exists(&config.model_id).await
121            && let Some(expected_checksum) = &config.checksum
122            && cache
123                .validate(&config.model_id, Some(expected_checksum))
124                .await?
125        {
126            info!("Model already cached and valid: {}", config.model_id);
127            return Ok(cache.model_path(&config.model_id));
128        }
129
130        // Get download URL
131        let download_url = self.get_download_url(&config.model_id, &config.filename);
132        let output_path = cache.model_path(&config.model_id);
133
134        info!("Downloading from: {}", download_url);
135        info!("Saving to: {:?}", output_path);
136
137        // Perform download with retry logic
138        let file_size = self
139            .download_file_with_retry(&download_url, &output_path, &config)
140            .await?;
141
142        // Calculate checksum
143        let actual_checksum = self.calculate_checksum(&output_path)?;
144
145        // Validate checksum if provided
146        if let Some(expected) = &config.checksum
147            && actual_checksum != *expected
148        {
149            error!("Checksum validation failed");
150            fs::remove_file(&output_path)?;
151            bail!(
152                "Downloaded file checksum mismatch. Expected: {}, Got: {}",
153                expected,
154                actual_checksum
155            );
156        }
157
158        // Save metadata
159        let metadata = ModelMetadata {
160            model_id: config.model_id.clone(),
161            version: "latest".to_string(),
162            file_size,
163            checksum: actual_checksum,
164            downloaded_at: SystemTime::now(),
165            last_accessed: SystemTime::now(),
166            access_count: 0,
167        };
168
169        cache.save_metadata(&metadata).await?;
170
171        info!(
172            "Download completed successfully: {} ({})",
173            config.model_id,
174            format_bytes(file_size)
175        );
176
177        Ok(output_path)
178    }
179
180    /// Download file with retry logic and progress tracking
181    async fn download_file_with_retry(
182        &self,
183        url: &str,
184        output_path: &std::path::PathBuf,
185        config: &DownloadConfig,
186    ) -> Result<u64> {
187        let backoff = ExponentialBackoff {
188            max_elapsed_time: Some(Duration::from_secs(config.timeout_secs)),
189            max_interval: Duration::from_secs(60),
190            ..Default::default()
191        };
192
193        retry(backoff, || async {
194            self.download_file_once(url, output_path, config)
195                .await
196                .map_err(|e| {
197                    warn!("Download attempt failed: {}", e);
198                    backoff::Error::transient(e)
199                })
200        })
201        .await
202        .context("Failed to download file after retries")
203    }
204
205    /// Single download attempt
206    async fn download_file_once(
207        &self,
208        url: &str,
209        output_path: &std::path::PathBuf,
210        config: &DownloadConfig,
211    ) -> Result<u64> {
212        // Send HTTP request
213        let response = self
214            .client
215            .get(url)
216            .send()
217            .await
218            .context("Failed to initiate download")?;
219
220        if !response.status().is_success() {
221            bail!("Download failed with HTTP {}", response.status());
222        }
223
224        // Get content length for progress tracking
225        let total_size = response.content_length().unwrap_or(0);
226
227        debug!("Download size: {} bytes", total_size);
228
229        // Create output file
230        let mut file = tokio::fs::File::create(output_path)
231            .await
232            .context("Failed to create output file")?;
233
234        // Download with progress tracking
235        let mut downloaded = 0u64;
236        let mut stream = response.bytes_stream();
237
238        while let Some(chunk) = stream.next().await {
239            let chunk = chunk.context("Failed to read download chunk")?;
240            file.write_all(&chunk)
241                .await
242                .context("Failed to write to output file")?;
243
244            downloaded += chunk.len() as u64;
245
246            // Call progress callback if provided
247            if let Some(ref callback) = config.progress_callback {
248                callback(downloaded, total_size);
249            }
250
251            // Log progress every 10%
252            if total_size > 0 {
253                let progress = (downloaded as f64 / total_size as f64) * 100.0;
254                if (progress as u64).is_multiple_of(10) {
255                    debug!("Download progress: {:.1}%", progress);
256                }
257            }
258        }
259
260        file.sync_all()
261            .await
262            .context("Failed to sync file to disk")?;
263
264        Ok(downloaded)
265    }
266
267    /// Calculate MD5 checksum of a file
268    fn calculate_checksum(&self, path: &std::path::PathBuf) -> Result<String> {
269        let file = fs::File::open(path).context(format!("Failed to open file: {:?}", path))?;
270
271        let mut hasher = Md5::new();
272        let mut reader = std::io::BufReader::new(file);
273        let mut buffer = [0u8; 8192];
274
275        loop {
276            let n = reader
277                .read(&mut buffer)
278                .context("Failed to read file for checksum")?;
279            if n == 0 {
280                break;
281            }
282            hasher.update(&buffer[..n]);
283        }
284
285        Ok(format!("{:x}", hasher.finalize()))
286    }
287}
288
289impl Default for HFHubClient {
290    fn default() -> Self {
291        Self::new()
292    }
293}
294
/// Render a byte count as a human-readable string.
///
/// Uses binary units (1 KB = 1024 B). Values of at least 1 KB are shown with
/// two decimals in the largest unit that fits (KB, MB or GB); smaller values
/// are shown as a whole number of bytes.
fn format_bytes(bytes: u64) -> String {
    // Largest unit first so the first match is the right one.
    const UNITS: [(u64, &str); 3] = [(1 << 30, "GB"), (1 << 20, "MB"), (1 << 10, "KB")];

    UNITS
        .iter()
        .find(|(threshold, _)| bytes >= *threshold)
        .map_or_else(
            || format!("{} B", bytes),
            |(threshold, suffix)| format!("{:.2} {}", bytes as f64 / *threshold as f64, suffix),
        )
}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed client targets the official hub.
    #[test]
    fn test_hf_client_creation() {
        let hub = HFHubClient::new();
        assert_eq!(hub.api_base, "https://huggingface.co");
    }

    /// Download URLs follow the `<base>/<model>/resolve/main/<file>` layout.
    #[test]
    fn test_download_url_generation() {
        let expected = "https://huggingface.co/hexgrad/Kokoro-82M/resolve/main/model.onnx";
        let actual = HFHubClient::new().get_download_url("hexgrad/Kokoro-82M", "model.onnx");
        assert_eq!(actual, expected);
    }

    /// Byte counts are rendered in the largest fitting binary unit.
    #[test]
    fn test_format_bytes() {
        let cases: [(u64, &str); 3] = [
            (500, "500 B"),
            (2048, "2.00 KB"),
            (3 * 1024 * 1024, "3.00 MB"),
        ];

        for (input, expected) in cases {
            assert_eq!(format_bytes(input), expected);
        }
    }

    /// The default configuration points at the Kokoro model with the
    /// documented timeout and retry settings.
    #[test]
    fn test_download_config_default() {
        let DownloadConfig {
            model_id,
            filename,
            timeout_secs,
            max_retries,
            ..
        } = DownloadConfig::default();

        assert_eq!(model_id, "hexgrad/Kokoro-82M");
        assert_eq!(filename, "kokoro-v0_19.onnx");
        assert_eq!(timeout_secs, 600);
        assert_eq!(max_retries, 3);
    }
}