1use super::{
8 progress::{DownloadStatus, ProgressTracker},
9 registry::ModelInfo,
10 storage::{ModelStorage, ModelMetadata, HardwareRequirements},
11};
12use anyhow::{Context, Result};
13use reqwest::Client;
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16use tokio::{fs::File, io::AsyncWriteExt};
17use tracing::{error, info};
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub enum DownloadSource {
22 HuggingFace { repo_id: String, filename: String },
23 OpenRouter { model_id: String },
24}
25
26pub struct ModelDownloader {
28 storage: Arc<ModelStorage>,
29 progress_tracker: Arc<ProgressTracker>,
30 http_client: Client,
31}
32
33impl ModelDownloader {
34 pub fn new(storage: Arc<ModelStorage>) -> Self {
35 let client = Client::builder()
36 .user_agent("Aud.io/0.1.1")
37 .timeout(std::time::Duration::from_secs(3300)) .build()
39 .expect("Failed to create HTTP client");
40
41 Self {
42 storage,
43 progress_tracker: Arc::new(ProgressTracker::new()),
44 http_client: client,
45 }
46 }
47
48 pub async fn download_model(
52 &self,
53 model_info: ModelInfo,
54 source: DownloadSource,
55 existing_download_id: Option<String>,
56 hf_token: Option<String>,
57 ) -> Result<String> {
58 info!("Starting download of model: {} from {:?}", model_info.name, source);
59
60 let download_id = if let Some(id) = existing_download_id {
62 info!("Using existing download ID: {}", id);
63 id
64 } else {
65 self.progress_tracker
67 .start_download(
68 model_info.id.clone(),
69 model_info.name.clone(),
70 Some(model_info.size_bytes),
71 )
72 .await
73 };
74
75 self.progress_tracker
77 .update_progress(&download_id, 0, DownloadStatus::Starting, None)
78 .await;
79
80 let model_dir = self.storage.create_model_directory(&model_info.id)
82 .context("Failed to create model directory")?;
83
84 let result = match source {
85 DownloadSource::HuggingFace { repo_id, filename } => {
86 self.download_from_huggingface(&download_id, &repo_id, &filename, &model_dir, &model_info, hf_token).await
87 }
88 DownloadSource::OpenRouter { model_id } => {
89 self.download_from_openrouter(&download_id, &model_id, &model_dir, &model_info).await
90 }
91 };
92
93 match result {
94 Ok(_) => {
95 info!("Successfully downloaded model: {}", model_info.name);
96 self.progress_tracker
97 .update_progress(&download_id, model_info.size_bytes, DownloadStatus::Completed, None)
98 .await;
99 Ok(download_id)
100 }
101 Err(e) => {
102 error!("Failed to download model {}: {}", model_info.name, e);
103 self.progress_tracker
104 .update_progress(&download_id, 0, DownloadStatus::Failed, Some(e.to_string()))
105 .await;
106 Err(e)
107 }
108 }
109 }
110
111 async fn download_from_huggingface(
113 &self,
114 download_id: &str,
115 repo_id: &str,
116 filename: &str,
117 model_dir: &std::path::Path,
118 model_info: &ModelInfo,
119 hf_token: Option<String>,
120 ) -> Result<()> {
121 if let Some(total_shards) = self.detect_shard_pattern(filename) {
123 info!("Detected sharded model with {} parts, downloading all shards", total_shards);
124 self.download_sharded_model(download_id, repo_id, filename, model_dir, total_shards, hf_token).await
125 } else {
126 let url = format!("https://huggingface.co/{}/resolve/main/{}", repo_id, filename);
127 info!("Downloading from HuggingFace: {}", url);
128 self.download_file_with_progress(download_id, &url, model_dir, filename, model_info, hf_token).await
129 }
130 }
131
132 fn detect_shard_pattern(&self, filename: &str) -> Option<u32> {
134 let re = regex::Regex::new(r".*-(\d{5})-of-(\d{5})\.[^.]+$").ok()?;
136 if let Some(caps) = re.captures(filename) {
137 if let Some(total_str) = caps.get(2) {
138 if let Ok(total) = total_str.as_str().parse::<u32>() {
139 return Some(total);
140 }
141 }
142 }
143 None
144 }
145
146 async fn download_sharded_model(
148 &self,
149 download_id: &str,
150 repo_id: &str,
151 first_filename: &str,
152 model_dir: &std::path::Path,
153 total_shards: u32,
154 hf_token: Option<String>,
155 ) -> Result<()> {
156 let re = regex::Regex::new(r"(.*-)(\d{5})(-of-\d{5}\.[^.]+)$").unwrap();
158 let caps = re.captures(first_filename).ok_or_else(|| {
159 anyhow::anyhow!("Invalid shard filename format: {}", first_filename)
160 })?;
161
162 let prefix = &caps[1];
163 let suffix = &caps[3];
164
165 let mut total_expected_size = 0u64;
167 for i in 1..=total_shards {
168 let shard_filename = format!("{}{:05}{}", prefix, i, suffix);
169 let url = format!("https://huggingface.co/{}/resolve/main/{}", repo_id, shard_filename);
170
171 let response = self.http_client.head(&url).send().await?;
173 if response.status().is_success() {
174 if let Some(content_length) = response.content_length() {
175 total_expected_size += content_length;
176 }
177 }
178 }
179
180 self.progress_tracker.update_total_bytes(download_id, total_expected_size).await;
182
183 let mut downloaded_so_far = 0u64;
185 for i in 1..=total_shards {
186 let shard_filename = format!("{}{:05}{}", prefix, i, suffix);
187 let url = format!("https://huggingface.co/{}/resolve/main/{}", repo_id, shard_filename);
188
189 info!("Downloading shard {}/{}: {}", i, total_shards, shard_filename);
190
191 self.download_single_shard_with_progress(
193 download_id,
194 &url,
195 model_dir,
196 &shard_filename,
197 hf_token.clone(),
198 &mut downloaded_so_far,
199 ).await?;
200 }
201
202 info!("Successfully downloaded all {} shards", total_shards);
203 Ok(())
204 }
205
206 async fn download_single_shard_with_progress(
208 &self,
209 download_id: &str,
210 url: &str,
211 model_dir: &std::path::Path,
212 filename: &str,
213 hf_token: Option<String>,
214 downloaded_so_far: &mut u64,
215 ) -> Result<()> {
216 use futures_util::StreamExt;
217
218 let mut request = self.http_client.get(url);
220
221 let token_to_use = hf_token.or_else(|| self.get_hf_token());
223 if let Some(hf_token) = token_to_use {
224 request = request.header("Authorization", format!("Bearer {}", hf_token));
225 }
226
227 let response = request
228 .send()
229 .await
230 .context("Failed to start download")?;
231
232 if !response.status().is_success() {
233 let status = response.status();
234 let error_msg = if status == 401 {
235 let repo_id = url.strip_prefix("https://huggingface.co/")
237 .and_then(|s| s.split("/resolve/").next())
238 .unwrap_or("unknown");
239
240 format!(
241 "Download failed with HTTP status: 401 Unauthorized.\n\n\
242 This may be because:\n\
243 1. The model requires authentication - check your HF_TOKEN\n\
244 2. The model is gated and requires terms acceptance at huggingface.co\n\
245 3. You've hit the unauthenticated rate limit (100 req/hour)\n\
246 4. The model is private or has been removed\n\n\
247 To fix:\n\
248 - Get a token from https://huggingface.co/settings/tokens\n\
249 - Visit https://huggingface.co/{} to request access to gated models\n\
250 - Set your token in the app settings and try again\n\n\
251 REPO_ID:{}",
252 repo_id, repo_id
253 )
254 } else {
255 format!("Download failed with HTTP status: {}", status)
256 };
257 return Err(anyhow::anyhow!(error_msg));
258 }
259
260 let shard_size = response.content_length().unwrap_or(0);
262
263 let mut file = tokio::fs::File::create(model_dir.join(filename)).await?;
264let mut downloaded: u64 = 0;
265 let start_time = std::time::Instant::now();
266
267 let mut stream = response.bytes_stream();
269
270 while let Some(chunk_result) = stream.next().await {
271 if let Some(progress) = self.progress_tracker.get_progress(download_id).await {
273 if progress.status == DownloadStatus::Cancelled {
274 let _ = tokio::fs::remove_file(model_dir.join(filename)).await;
276 return Err(anyhow::anyhow!("Download cancelled by user"));
277 }
278
279 if progress.status == DownloadStatus::Paused {
281 loop {
283 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; if let Some(updated_progress) = self.progress_tracker.get_progress(download_id).await {
285 if updated_progress.status != DownloadStatus::Paused {
286 break; }
288 } else {
289 break;
291 }
292 }
293
294 if let Some(updated_progress) = self.progress_tracker.get_progress(download_id).await {
296 if updated_progress.status == DownloadStatus::Cancelled {
297 let _ = tokio::fs::remove_file(model_dir.join(filename)).await;
299 return Err(anyhow::anyhow!("Download cancelled by user"));
300 }
301 }
302 }
303 }
304
305 let chunk = chunk_result.context("Error reading download stream")?;
306 file.write_all(&chunk).await?;
307 downloaded += chunk.len() as u64;
308 }
309
310 file.flush().await?;
311
312 *downloaded_so_far += shard_size;
314 let elapsed = start_time.elapsed();
315 self.progress_tracker.update_elapsed_time(download_id, elapsed).await;
316 self.progress_tracker
317 .update_progress(download_id, *downloaded_so_far, DownloadStatus::Downloading, None)
318 .await;
319
320 Ok(())
321 }
322
323 fn get_hf_token(&self) -> Option<String> {
325 std::env::var("HF_TOKEN").ok()
326 }
327
328 async fn download_from_openrouter(
330 &self,
331 _download_id: &str,
332 model_id: &str,
333 model_dir: &std::path::Path,
334 model_info: &ModelInfo,
335 ) -> Result<()> {
336 let config_content = serde_json::json!({
339 "model_id": model_id,
340 "provider": "openrouter",
341 "api_key_required": true,
342 "downloaded_at": chrono::Utc::now().to_rfc3339(),
343 "size_bytes": model_info.size_bytes
344 });
345
346 let config_path = model_dir.join("openrouter_config.json");
347 tokio::fs::write(&config_path, serde_json::to_string_pretty(&config_content)?).await?;
348
349 info!("Created OpenRouter configuration for model: {}", model_id);
350 Ok(())
351 }
352
353 async fn download_file_with_progress(
355 &self,
356 download_id: &str,
357 url: &str,
358 model_dir: &std::path::Path,
359 filename: &str,
360 model_info: &ModelInfo,
361 hf_token: Option<String>,
362 ) -> Result<()> {
363 use futures_util::StreamExt;
364
365 let mut request = self.http_client.get(url);
367
368 let token_to_use = hf_token.or_else(|| self.get_hf_token());
370 if let Some(hf_token) = token_to_use {
371 request = request.header("Authorization", format!("Bearer {}", hf_token));
372 }
373
374 let response = request
375 .send()
376 .await
377 .context("Failed to start download")?;
378
379 if !response.status().is_success() {
380 let status = response.status();
381 let error_msg = if status == 401 {
382 let repo_id = url.strip_prefix("https://huggingface.co/")
384 .and_then(|s| s.split("/resolve/").next())
385 .unwrap_or("unknown");
386
387 format!(
388 "Download failed with HTTP status: 401 Unauthorized.\n\n\
389 This may be because:\n\
390 1. The model requires authentication - check your HF_TOKEN\n\
391 2. The model is gated and requires terms acceptance at huggingface.co\n\
392 3. You've hit the unauthenticated rate limit (100 req/hour)\n\
393 4. The model is private or has been removed\n\n\
394 To fix:\n\
395 - Get a token from https://huggingface.co/settings/tokens\n\
396 - Visit https://huggingface.co/{} to request access to gated models\n\
397 - Set your token in the app settings and try again\n\n\
398 REPO_ID:{}",
399 repo_id, repo_id
400 )
401 } else {
402 format!("Download failed with HTTP status: {}", status)
403 };
404 return Err(anyhow::anyhow!(error_msg));
405 }
406
407 let total_size = response.content_length().unwrap_or(model_info.size_bytes);
409
410 if total_size > 0 {
412 self.progress_tracker
413 .update_total_bytes(download_id, total_size)
414 .await;
415 }
416
417 let mut file = File::create(model_dir.join(filename)).await?;
418 let mut downloaded: u64 = 0;
419 let start_time = std::time::Instant::now();
420
421 self.progress_tracker
422 .update_progress(download_id, 0, DownloadStatus::Downloading, None)
423 .await;
424
425 let mut stream = response.bytes_stream();
427
428 while let Some(chunk_result) = stream.next().await {
429 if let Some(progress) = self.progress_tracker.get_progress(download_id).await {
431 if progress.status == DownloadStatus::Cancelled {
432 let _ = tokio::fs::remove_file(model_dir.join(filename)).await;
434 return Err(anyhow::anyhow!("Download cancelled by user"));
435 }
436
437 if progress.status == DownloadStatus::Paused {
439 loop {
441 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; if let Some(updated_progress) = self.progress_tracker.get_progress(download_id).await {
443 if updated_progress.status != DownloadStatus::Paused {
444 break; }
446 } else {
447 break;
449 }
450 }
451
452 if let Some(updated_progress) = self.progress_tracker.get_progress(download_id).await {
454 if updated_progress.status == DownloadStatus::Cancelled {
455 let _ = tokio::fs::remove_file(model_dir.join(filename)).await;
457 return Err(anyhow::anyhow!("Download cancelled by user"));
458 }
459 }
460 }
461 }
462
463 let chunk = chunk_result.context("Error reading download stream")?;
464 file.write_all(&chunk).await?;
465 downloaded += chunk.len() as u64;
466
467 let elapsed = start_time.elapsed();
469 self.progress_tracker.update_elapsed_time(download_id, elapsed).await;
470 self.progress_tracker
471 .update_progress(download_id, downloaded, DownloadStatus::Downloading, None)
472 .await;
473 }
474
475 file.flush().await?;
476 Ok(())
477 }
478
479 pub async fn save_model_metadata(
481 &self,
482 model_info: &ModelInfo,
483 source: &DownloadSource,
484 ) -> Result<()> {
485 let metadata = ModelMetadata {
486 id: model_info.id.clone(),
487 name: model_info.name.clone(),
488 description: model_info.description.clone(),
489 author: model_info.author.clone(),
490 size_bytes: model_info.size_bytes,
491 format: model_info.format.clone(),
492 download_source: match source {
493 DownloadSource::HuggingFace { .. } => "huggingface".to_string(),
494 DownloadSource::OpenRouter { .. } => "openrouter".to_string(),
495 },
496 download_date: chrono::Utc::now(),
497 last_used: None,
498 tags: model_info.tags.clone(),
499 hardware_requirements: HardwareRequirements::default(),
500 compatibility_notes: None,
501 runtime_binaries: self.get_appropriate_runtime_binaries_for_model(&model_info.format).await, };
503
504 let metadata_path = self.storage.metadata_path(&model_info.id);
505 let metadata_json = serde_json::to_string_pretty(&metadata)?;
506 tokio::fs::write(&metadata_path, metadata_json).await?;
507
508 Ok(())
509 }
510
511 pub fn progress_tracker(&self) -> &Arc<ProgressTracker> {
513 &self.progress_tracker
514 }
515
516 pub async fn cancel_download(&self, download_id: &str) -> bool {
518 self.progress_tracker.cancel_download(download_id).await
519 }
520
521 async fn get_appropriate_runtime_binaries_for_model(&self, _model_format: &str) -> std::collections::HashMap<String, std::path::PathBuf> {
523 use crate::model_runtime::platform_detector::HardwareCapabilities;
524
525 let mut binaries = std::collections::HashMap::new();
526 let hw_caps = HardwareCapabilities::default();
527
528 if let Some(platform_binary) = hw_caps.get_runtime_binary_path() {
530 let platform_name = match hw_caps.platform {
532 crate::model_runtime::platform_detector::Platform::Windows => "windows",
533 crate::model_runtime::platform_detector::Platform::Linux => "linux",
534 crate::model_runtime::platform_detector::Platform::MacOS => "macos",
535 }.to_string();
536
537 binaries.insert(platform_name, platform_binary);
538 }
539
540 binaries.insert("default".to_string(), std::path::PathBuf::from("llama-server"));
542
543 binaries
544 }
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A freshly constructed downloader must not report any downloads.
    #[tokio::test]
    async fn test_downloader_creation() -> Result<()> {
        let scratch = TempDir::new()?;
        let root = scratch.path();

        let location = super::super::storage::StorageLocation {
            app_data_dir: root.to_path_buf(),
            models_dir: root.join("models"),
            registry_dir: root.join("registry"),
        };
        let downloader = ModelDownloader::new(Arc::new(ModelStorage { location }));

        let downloads = downloader.progress_tracker().get_all_downloads().await;
        assert!(downloads.is_empty());

        Ok(())
    }
}