1use super::cache::{ModelCache, ModelMetadata};
7use anyhow::{Context, Result, anyhow, bail};
8use backoff::ExponentialBackoff;
9use backoff::future::retry;
10use futures::stream::StreamExt;
11use md5::{Digest, Md5};
12use reqwest::Client;
13use std::fs;
14use std::io::Read;
15use std::time::{Duration, SystemTime};
16use tokio::io::AsyncWriteExt;
17use tokio::sync::Semaphore;
18use tracing::{debug, error, info, warn};
19
/// Progress hook invoked during a download as `callback(bytes_downloaded, total_bytes)`.
///
/// `total_bytes` is `0` when the server did not report a content length.
pub type ProgressCallback = Box<dyn Fn(u64, u64) + Send + Sync>;
22
/// Settings describing a single model download.
pub struct DownloadConfig {
    /// Repository identifier; interpolated into the download URL and used as
    /// the key for every cache lookup.
    pub model_id: String,
    /// Name of the file inside the repository to download.
    pub filename: String,
    /// Expected digest of the downloaded file, compared against the lowercase
    /// hex MD5 digest computed after the download. `None` skips the comparison.
    pub checksum: Option<String>,
    /// Upper bound, in seconds, on the total time the retry loop may spend
    /// (used as the backoff policy's `max_elapsed_time`).
    pub timeout_secs: u64,
    /// Intended cap on retry attempts.
    /// NOTE(review): confirm that the retry loop actually enforces this limit.
    pub max_retries: u32,
    /// Optional hook called after each chunk is written to disk.
    pub progress_callback: Option<ProgressCallback>,
}
38
39impl Default for DownloadConfig {
40 fn default() -> Self {
41 Self {
42 model_id: "hexgrad/Kokoro-82M".to_string(),
43 filename: "kokoro-v0_19.onnx".to_string(),
44 checksum: None,
45 timeout_secs: 600,
46 max_retries: 3,
47 progress_callback: None,
48 }
49 }
50}
51
52impl Clone for DownloadConfig {
53 fn clone(&self) -> Self {
54 Self {
55 model_id: self.model_id.clone(),
56 filename: self.filename.clone(),
57 checksum: self.checksum.clone(),
58 timeout_secs: self.timeout_secs,
59 max_retries: self.max_retries,
60 progress_callback: None, }
62 }
63}
64
/// Client that downloads model files from a Hugging Face-style hub.
pub struct HFHubClient {
    /// HTTP client used for every download request.
    client: Client,
    /// Base URL that download URLs are built from.
    api_base: String,
    /// Permit pool acquired at the start of each `download_model` call; both
    /// constructors create it with 3 permits.
    download_semaphore: std::sync::Arc<Semaphore>,
}
71
72impl HFHubClient {
73 pub fn new() -> Self {
75 Self {
76 client: Client::builder()
77 .timeout(Duration::from_secs(600))
78 .build()
79 .expect("Failed to create HTTP client"),
80 api_base: "https://huggingface.co".to_string(),
81 download_semaphore: std::sync::Arc::new(Semaphore::new(3)), }
83 }
84
85 pub fn with_api_base(api_base: String) -> Self {
87 Self {
88 client: Client::builder()
89 .timeout(Duration::from_secs(600))
90 .build()
91 .expect("Failed to create HTTP client"),
92 api_base,
93 download_semaphore: std::sync::Arc::new(Semaphore::new(3)),
94 }
95 }
96
97 pub fn get_download_url(&self, model_id: &str, filename: &str) -> String {
99 format!("{}/{}/resolve/main/{}", self.api_base, model_id, filename)
100 }
101
102 pub async fn download_model(
104 &self,
105 config: DownloadConfig,
106 cache: &ModelCache,
107 ) -> Result<std::path::PathBuf> {
108 let _permit = self
109 .download_semaphore
110 .acquire()
111 .await
112 .map_err(|e| anyhow!("Failed to acquire download semaphore: {}", e))?;
113
114 info!(
115 "Starting download: {} / {}",
116 config.model_id, config.filename
117 );
118
119 if cache.exists(&config.model_id).await
121 && let Some(expected_checksum) = &config.checksum
122 && cache
123 .validate(&config.model_id, Some(expected_checksum))
124 .await?
125 {
126 info!("Model already cached and valid: {}", config.model_id);
127 return Ok(cache.model_path(&config.model_id));
128 }
129
130 let download_url = self.get_download_url(&config.model_id, &config.filename);
132 let output_path = cache.model_path(&config.model_id);
133
134 info!("Downloading from: {}", download_url);
135 info!("Saving to: {:?}", output_path);
136
137 let file_size = self
139 .download_file_with_retry(&download_url, &output_path, &config)
140 .await?;
141
142 let actual_checksum = self.calculate_checksum(&output_path)?;
144
145 if let Some(expected) = &config.checksum
147 && actual_checksum != *expected
148 {
149 error!("Checksum validation failed");
150 fs::remove_file(&output_path)?;
151 bail!(
152 "Downloaded file checksum mismatch. Expected: {}, Got: {}",
153 expected,
154 actual_checksum
155 );
156 }
157
158 let metadata = ModelMetadata {
160 model_id: config.model_id.clone(),
161 version: "latest".to_string(),
162 file_size,
163 checksum: actual_checksum,
164 downloaded_at: SystemTime::now(),
165 last_accessed: SystemTime::now(),
166 access_count: 0,
167 };
168
169 cache.save_metadata(&metadata).await?;
170
171 info!(
172 "Download completed successfully: {} ({})",
173 config.model_id,
174 format_bytes(file_size)
175 );
176
177 Ok(output_path)
178 }
179
180 async fn download_file_with_retry(
182 &self,
183 url: &str,
184 output_path: &std::path::PathBuf,
185 config: &DownloadConfig,
186 ) -> Result<u64> {
187 let backoff = ExponentialBackoff {
188 max_elapsed_time: Some(Duration::from_secs(config.timeout_secs)),
189 max_interval: Duration::from_secs(60),
190 ..Default::default()
191 };
192
193 retry(backoff, || async {
194 self.download_file_once(url, output_path, config)
195 .await
196 .map_err(|e| {
197 warn!("Download attempt failed: {}", e);
198 backoff::Error::transient(e)
199 })
200 })
201 .await
202 .context("Failed to download file after retries")
203 }
204
205 async fn download_file_once(
207 &self,
208 url: &str,
209 output_path: &std::path::PathBuf,
210 config: &DownloadConfig,
211 ) -> Result<u64> {
212 let response = self
214 .client
215 .get(url)
216 .send()
217 .await
218 .context("Failed to initiate download")?;
219
220 if !response.status().is_success() {
221 bail!("Download failed with HTTP {}", response.status());
222 }
223
224 let total_size = response.content_length().unwrap_or(0);
226
227 debug!("Download size: {} bytes", total_size);
228
229 let mut file = tokio::fs::File::create(output_path)
231 .await
232 .context("Failed to create output file")?;
233
234 let mut downloaded = 0u64;
236 let mut stream = response.bytes_stream();
237
238 while let Some(chunk) = stream.next().await {
239 let chunk = chunk.context("Failed to read download chunk")?;
240 file.write_all(&chunk)
241 .await
242 .context("Failed to write to output file")?;
243
244 downloaded += chunk.len() as u64;
245
246 if let Some(ref callback) = config.progress_callback {
248 callback(downloaded, total_size);
249 }
250
251 if total_size > 0 {
253 let progress = (downloaded as f64 / total_size as f64) * 100.0;
254 if (progress as u64).is_multiple_of(10) {
255 debug!("Download progress: {:.1}%", progress);
256 }
257 }
258 }
259
260 file.sync_all()
261 .await
262 .context("Failed to sync file to disk")?;
263
264 Ok(downloaded)
265 }
266
267 fn calculate_checksum(&self, path: &std::path::PathBuf) -> Result<String> {
269 let file = fs::File::open(path).context(format!("Failed to open file: {:?}", path))?;
270
271 let mut hasher = Md5::new();
272 let mut reader = std::io::BufReader::new(file);
273 let mut buffer = [0u8; 8192];
274
275 loop {
276 let n = reader
277 .read(&mut buffer)
278 .context("Failed to read file for checksum")?;
279 if n == 0 {
280 break;
281 }
282 hasher.update(&buffer[..n]);
283 }
284
285 Ok(format!("{:x}", hasher.finalize()))
286 }
287}
288
289impl Default for HFHubClient {
290 fn default() -> Self {
291 Self::new()
292 }
293}
294
/// Renders a byte count for log output using binary (1024-based) units.
///
/// Values below 1 KiB are printed as an integer followed by `B`; larger
/// values are printed with two decimals in the largest of `KB`, `MB` or `GB`
/// that does not exceed the value.
fn format_bytes(bytes: u64) -> String {
    // Largest unit first, so the first threshold that fits wins.
    const UNITS: [(u64, &str); 3] = [(1 << 30, "GB"), (1 << 20, "MB"), (1 << 10, "KB")];

    UNITS
        .iter()
        .find(|(threshold, _)| bytes >= *threshold)
        .map_or_else(
            || format!("{} B", bytes),
            |(threshold, unit)| format!("{:.2} {}", bytes as f64 / *threshold as f64, unit),
        )
}
311
#[cfg(test)]
mod tests {
    //! Unit tests for URL construction, size formatting and configuration
    //! defaults. None of these tests perform network I/O.

    use super::*;

    #[test]
    fn test_hf_client_creation() {
        let client = HFHubClient::new();
        assert_eq!(client.api_base, "https://huggingface.co");
    }

    #[test]
    fn test_with_api_base_overrides_host() {
        let client = HFHubClient::with_api_base("http://localhost:8080".to_string());
        assert_eq!(client.api_base, "http://localhost:8080");
        assert_eq!(
            client.get_download_url("org/model", "weights.onnx"),
            "http://localhost:8080/org/model/resolve/main/weights.onnx"
        );
    }

    #[test]
    fn test_download_url_generation() {
        let client = HFHubClient::new();
        let url = client.get_download_url("hexgrad/Kokoro-82M", "model.onnx");
        assert_eq!(
            url,
            "https://huggingface.co/hexgrad/Kokoro-82M/resolve/main/model.onnx"
        );
    }

    #[test]
    fn test_format_bytes() {
        assert_eq!(format_bytes(500), "500 B");
        assert_eq!(format_bytes(2048), "2.00 KB");
        assert_eq!(format_bytes(3 * 1024 * 1024), "3.00 MB");
    }

    #[test]
    fn test_format_bytes_unit_boundaries() {
        assert_eq!(format_bytes(0), "0 B");
        assert_eq!(format_bytes(1023), "1023 B");
        assert_eq!(format_bytes(1024), "1.00 KB");
        assert_eq!(format_bytes(1024 * 1024), "1.00 MB");
        assert_eq!(format_bytes(1024 * 1024 * 1024), "1.00 GB");
    }

    #[test]
    fn test_download_config_default() {
        let config = DownloadConfig::default();
        assert_eq!(config.model_id, "hexgrad/Kokoro-82M");
        assert_eq!(config.filename, "kokoro-v0_19.onnx");
        assert_eq!(config.timeout_secs, 600);
        assert_eq!(config.max_retries, 3);
        assert!(config.checksum.is_none());
        assert!(config.progress_callback.is_none());
    }

    #[test]
    fn test_download_config_clone_drops_callback() {
        let original = DownloadConfig {
            checksum: Some("abc123".to_string()),
            progress_callback: Some(Box::new(|_done: u64, _total: u64| {})),
            ..DownloadConfig::default()
        };

        let copy = original.clone();

        assert_eq!(copy.model_id, original.model_id);
        assert_eq!(copy.filename, original.filename);
        assert_eq!(copy.checksum, original.checksum);
        assert_eq!(copy.timeout_secs, original.timeout_secs);
        assert_eq!(copy.max_retries, original.max_retries);
        // The boxed closure is not cloneable, so the copy carries none.
        assert!(original.progress_callback.is_some());
        assert!(copy.progress_callback.is_none());
    }
}