Skip to main content

offline_intelligence/model_management/
downloader.rs

1//! Model Downloader
2//!
3//! Downloads models from various sources:
4//! - Hugging Face Hub
5//! - OpenRouter (API-based models)
6
7use super::{
8    progress::{DownloadStatus, ProgressTracker},
9    registry::ModelInfo,
10    storage::{ModelStorage, ModelMetadata, HardwareRequirements},
11};
12use anyhow::{Context, Result};
13use reqwest::Client;
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16use tokio::{fs::File, io::AsyncWriteExt};
17use tracing::{error, info};
18
19/// Source from which to download a model
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub enum DownloadSource {
22    HuggingFace { repo_id: String, filename: String },
23    OpenRouter { model_id: String },
24}
25
26/// Model downloader service
27pub struct ModelDownloader {
28    storage: Arc<ModelStorage>,
29    progress_tracker: Arc<ProgressTracker>,
30    http_client: Client,
31}
32
33impl ModelDownloader {
34    pub fn new(storage: Arc<ModelStorage>) -> Self {
35        let client = Client::builder()
36            .user_agent("Aud.io/0.1.1")
37            .timeout(std::time::Duration::from_secs(3300))  // 55 minutes for large model downloads
38            .build()
39            .expect("Failed to create HTTP client");
40
41        Self {
42            storage,
43            progress_tracker: Arc::new(ProgressTracker::new()),
44            http_client: client,
45        }
46    }
47
48    /// Download a model from the specified source
49    /// If `existing_download_id` is provided, uses that for progress tracking instead of creating a new one
50    /// If `hf_token` is provided, it will be used for HuggingFace authentication
51    pub async fn download_model(
52        &self,
53        model_info: ModelInfo,
54        source: DownloadSource,
55        existing_download_id: Option<String>,
56        hf_token: Option<String>,
57    ) -> Result<String> {
58        info!("Starting download of model: {} from {:?}", model_info.name, source);
59
60        // Use existing download ID if provided, otherwise create a new one
61        let download_id = if let Some(id) = existing_download_id {
62            info!("Using existing download ID: {}", id);
63            id
64        } else {
65            // Start progress tracking with a new ID
66            self.progress_tracker
67                .start_download(
68                    model_info.id.clone(),
69                    model_info.name.clone(),
70                    Some(model_info.size_bytes),
71                )
72                .await
73        };
74
75        // Update status to starting
76        self.progress_tracker
77            .update_progress(&download_id, 0, DownloadStatus::Starting, None)
78            .await;
79
80        // Create model directory
81        let model_dir = self.storage.create_model_directory(&model_info.id)
82            .context("Failed to create model directory")?;
83
84        let result = match source {
85            DownloadSource::HuggingFace { repo_id, filename } => {
86                self.download_from_huggingface(&download_id, &repo_id, &filename, &model_dir, &model_info, hf_token).await
87            }
88            DownloadSource::OpenRouter { model_id } => {
89                self.download_from_openrouter(&download_id, &model_id, &model_dir, &model_info).await
90            }
91        };
92
93        match result {
94            Ok(_) => {
95                info!("Successfully downloaded model: {}", model_info.name);
96                self.progress_tracker
97                    .update_progress(&download_id, model_info.size_bytes, DownloadStatus::Completed, None)
98                    .await;
99                Ok(download_id)
100            }
101            Err(e) => {
102                error!("Failed to download model {}: {}", model_info.name, e);
103                self.progress_tracker
104                    .update_progress(&download_id, 0, DownloadStatus::Failed, Some(e.to_string()))
105                    .await;
106                Err(e)
107            }
108        }
109    }
110
111    /// Download from Hugging Face Hub
112    async fn download_from_huggingface(
113        &self,
114        download_id: &str,
115        repo_id: &str,
116        filename: &str,
117        model_dir: &std::path::Path,
118        model_info: &ModelInfo,
119        hf_token: Option<String>,
120    ) -> Result<()> {
121        // Check if this is a sharded model by looking for the shard pattern
122        if let Some(total_shards) = self.detect_shard_pattern(filename) {
123            info!("Detected sharded model with {} parts, downloading all shards", total_shards);
124            self.download_sharded_model(download_id, repo_id, filename, model_dir, total_shards, hf_token).await
125        } else {
126            let url = format!("https://huggingface.co/{}/resolve/main/{}", repo_id, filename);
127            info!("Downloading from HuggingFace: {}", url);
128            self.download_file_with_progress(download_id, &url, model_dir, filename, model_info, hf_token).await
129        }
130    }
131
132    /// Detect if the filename follows a shard pattern (e.g., model-00001-of-00003.gguf)
133    fn detect_shard_pattern(&self, filename: &str) -> Option<u32> {
134        // Pattern: some-name-00001-of-00003.ext
135        let re = regex::Regex::new(r".*-(\d{5})-of-(\d{5})\.[^.]+$").ok()?;
136        if let Some(caps) = re.captures(filename) {
137            if let Some(total_str) = caps.get(2) {
138                if let Ok(total) = total_str.as_str().parse::<u32>() {
139                    return Some(total);
140                }
141            }
142        }
143        None
144    }
145
146    /// Download all shards of a sharded model
147    async fn download_sharded_model(
148        &self,
149        download_id: &str,
150        repo_id: &str,
151        first_filename: &str,
152        model_dir: &std::path::Path,
153        total_shards: u32,
154        hf_token: Option<String>,
155    ) -> Result<()> {
156        // Extract the pattern from the first filename to construct other shard names
157        let re = regex::Regex::new(r"(.*-)(\d{5})(-of-\d{5}\.[^.]+)$").unwrap();
158        let caps = re.captures(first_filename).ok_or_else(|| {
159            anyhow::anyhow!("Invalid shard filename format: {}", first_filename)
160        })?;
161
162        let prefix = &caps[1];
163        let suffix = &caps[3];
164
165        // Calculate total size for progress tracking
166        let mut total_expected_size = 0u64;
167        for i in 1..=total_shards {
168            let shard_filename = format!("{}{:05}{}", prefix, i, suffix);
169            let url = format!("https://huggingface.co/{}/resolve/main/{}", repo_id, shard_filename);
170            
171            // Get the size of each shard
172            let response = self.http_client.head(&url).send().await?;
173            if response.status().is_success() {
174                if let Some(content_length) = response.content_length() {
175                    total_expected_size += content_length;
176                }
177            }
178        }
179
180        // Update progress tracker with total size
181        self.progress_tracker.update_total_bytes(download_id, total_expected_size).await;
182
183        // Download each shard
184        let mut downloaded_so_far = 0u64;
185        for i in 1..=total_shards {
186            let shard_filename = format!("{}{:05}{}", prefix, i, suffix);
187            let url = format!("https://huggingface.co/{}/resolve/main/{}", repo_id, shard_filename);
188            
189            info!("Downloading shard {}/{}: {}", i, total_shards, shard_filename);
190
191            // Download the shard file
192            self.download_single_shard_with_progress(
193                download_id,
194                &url,
195                model_dir,
196                &shard_filename,
197                hf_token.clone(),
198                &mut downloaded_so_far,
199            ).await?;
200        }
201
202        info!("Successfully downloaded all {} shards", total_shards);
203        Ok(())
204    }
205
206    /// Download a single shard with progress tracking
207    async fn download_single_shard_with_progress(
208        &self,
209        download_id: &str,
210        url: &str,
211        model_dir: &std::path::Path,
212        filename: &str,
213        hf_token: Option<String>,
214        downloaded_so_far: &mut u64,
215    ) -> Result<()> {
216        use futures_util::StreamExt;
217
218        // Build request with optional HF_TOKEN authentication
219        let mut request = self.http_client.get(url);
220        
221        // Add HuggingFace token if available (provided token takes precedence over env var)
222        let token_to_use = hf_token.or_else(|| self.get_hf_token());
223        if let Some(hf_token) = token_to_use {
224            request = request.header("Authorization", format!("Bearer {}", hf_token));
225        }
226
227        let response = request
228            .send()
229            .await
230            .context("Failed to start download")?;
231
232        if !response.status().is_success() {
233            let status = response.status();
234            let error_msg = if status == 401 {
235                // Extract repo_id from URL for the gated model link
236                let repo_id = url.strip_prefix("https://huggingface.co/")
237                    .and_then(|s| s.split("/resolve/").next())
238                    .unwrap_or("unknown");
239                
240                format!(
241                    "Download failed with HTTP status: 401 Unauthorized.\n\n\
242                    This may be because:\n\
243                    1. The model requires authentication - check your HF_TOKEN\n\
244                    2. The model is gated and requires terms acceptance at huggingface.co\n\
245                    3. You've hit the unauthenticated rate limit (100 req/hour)\n\
246                    4. The model is private or has been removed\n\n\
247                    To fix:\n\
248                    - Get a token from https://huggingface.co/settings/tokens\n\
249                    - Visit https://huggingface.co/{} to request access to gated models\n\
250                    - Set your token in the app settings and try again\n\n\
251                    REPO_ID:{}",
252                    repo_id, repo_id
253                )
254            } else {
255                format!("Download failed with HTTP status: {}", status)
256            };
257            return Err(anyhow::anyhow!(error_msg));
258        }
259
260        // Get size of this shard from Content-Length header
261        let shard_size = response.content_length().unwrap_or(0);
262        
263        let mut file = tokio::fs::File::create(model_dir.join(filename)).await?;
264let mut downloaded: u64 = 0;
265        let start_time = std::time::Instant::now();
266
267        // Stream the response in chunks for real-time progress
268        let mut stream = response.bytes_stream();
269
270        while let Some(chunk_result) = stream.next().await {
271            // Check for cancellation
272            if let Some(progress) = self.progress_tracker.get_progress(download_id).await {
273                if progress.status == DownloadStatus::Cancelled {
274                    // Clean up partial file
275                    let _ = tokio::fs::remove_file(model_dir.join(filename)).await;
276                    return Err(anyhow::anyhow!("Download cancelled by user"));
277                }
278                
279                // Check for pause status
280                if progress.status == DownloadStatus::Paused {
281                    // Wait until the download is resumed
282                    loop {
283                        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; // Check every 100ms
284                        if let Some(updated_progress) = self.progress_tracker.get_progress(download_id).await {
285                            if updated_progress.status != DownloadStatus::Paused {
286                                break; // Exit the pause loop when resumed or status changed
287                            }
288                        } else {
289                            // If progress is no longer tracked, exit
290                            break;
291                        }
292                    }
293                    
294                    // After resuming, check if we should continue or stop
295                    if let Some(updated_progress) = self.progress_tracker.get_progress(download_id).await {
296                        if updated_progress.status == DownloadStatus::Cancelled {
297                            // Clean up partial file
298                            let _ = tokio::fs::remove_file(model_dir.join(filename)).await;
299                            return Err(anyhow::anyhow!("Download cancelled by user"));
300                        }
301                    }
302                }
303            }
304
305            let chunk = chunk_result.context("Error reading download stream")?;
306            file.write_all(&chunk).await?;
307            downloaded += chunk.len() as u64;
308        }
309
310        file.flush().await?;
311        
312        // Update cumulative progress
313        *downloaded_so_far += shard_size;
314        let elapsed = start_time.elapsed();
315        self.progress_tracker.update_elapsed_time(download_id, elapsed).await;
316        self.progress_tracker
317            .update_progress(download_id, *downloaded_so_far, DownloadStatus::Downloading, None)
318            .await;
319
320        Ok(())
321    }
322
323    /// Get HuggingFace token from environment
324    fn get_hf_token(&self) -> Option<String> {
325        std::env::var("HF_TOKEN").ok()
326    }
327
328    /// Download from OpenRouter (API-based)
329    async fn download_from_openrouter(
330        &self,
331        _download_id: &str,
332        model_id: &str,
333        model_dir: &std::path::Path,
334        model_info: &ModelInfo,
335    ) -> Result<()> {
336        // OpenRouter provides API access rather than direct downloads
337        // We'll create a configuration file for API usage
338        let config_content = serde_json::json!({
339            "model_id": model_id,
340            "provider": "openrouter",
341            "api_key_required": true,
342            "downloaded_at": chrono::Utc::now().to_rfc3339(),
343            "size_bytes": model_info.size_bytes
344        });
345
346        let config_path = model_dir.join("openrouter_config.json");
347        tokio::fs::write(&config_path, serde_json::to_string_pretty(&config_content)?).await?;
348
349        info!("Created OpenRouter configuration for model: {}", model_id);
350        Ok(())
351    }
352
353    /// Download a file with progress tracking using chunked streaming
354    async fn download_file_with_progress(
355        &self,
356        download_id: &str,
357        url: &str,
358        model_dir: &std::path::Path,
359        filename: &str,
360        model_info: &ModelInfo,
361        hf_token: Option<String>,
362    ) -> Result<()> {
363        use futures_util::StreamExt;
364
365        // Build request with optional HF_TOKEN authentication
366        let mut request = self.http_client.get(url);
367        
368        // Add HuggingFace token if available (provided token takes precedence over env var)
369        let token_to_use = hf_token.or_else(|| self.get_hf_token());
370        if let Some(hf_token) = token_to_use {
371            request = request.header("Authorization", format!("Bearer {}", hf_token));
372        }
373
374        let response = request
375            .send()
376            .await
377            .context("Failed to start download")?;
378
379        if !response.status().is_success() {
380            let status = response.status();
381            let error_msg = if status == 401 {
382                // Extract repo_id from URL for the gated model link
383                let repo_id = url.strip_prefix("https://huggingface.co/")
384                    .and_then(|s| s.split("/resolve/").next())
385                    .unwrap_or("unknown");
386                
387                format!(
388                    "Download failed with HTTP status: 401 Unauthorized.\n\n\
389                    This may be because:\n\
390                    1. The model requires authentication - check your HF_TOKEN\n\
391                    2. The model is gated and requires terms acceptance at huggingface.co\n\
392                    3. You've hit the unauthenticated rate limit (100 req/hour)\n\
393                    4. The model is private or has been removed\n\n\
394                    To fix:\n\
395                    - Get a token from https://huggingface.co/settings/tokens\n\
396                    - Visit https://huggingface.co/{} to request access to gated models\n\
397                    - Set your token in the app settings and try again\n\n\
398                    REPO_ID:{}",
399                    repo_id, repo_id
400                )
401            } else {
402                format!("Download failed with HTTP status: {}", status)
403            };
404            return Err(anyhow::anyhow!(error_msg));
405        }
406
407        // Get total size from Content-Length header or fall back to model_info.size_bytes
408        let total_size = response.content_length().unwrap_or(model_info.size_bytes);
409        
410        // Update progress tracker with actual size from HTTP response if available
411        if total_size > 0 {
412            self.progress_tracker
413                .update_total_bytes(download_id, total_size)
414                .await;
415        }
416        
417        let mut file = File::create(model_dir.join(filename)).await?;
418        let mut downloaded: u64 = 0;
419        let start_time = std::time::Instant::now();
420
421        self.progress_tracker
422            .update_progress(download_id, 0, DownloadStatus::Downloading, None)
423            .await;
424
425        // Stream the response in chunks for real-time progress
426        let mut stream = response.bytes_stream();
427
428        while let Some(chunk_result) = stream.next().await {
429            // Check for cancellation
430            if let Some(progress) = self.progress_tracker.get_progress(download_id).await {
431                if progress.status == DownloadStatus::Cancelled {
432                    // Clean up partial file
433                    let _ = tokio::fs::remove_file(model_dir.join(filename)).await;
434                    return Err(anyhow::anyhow!("Download cancelled by user"));
435                }
436                
437                // Check for pause status
438                if progress.status == DownloadStatus::Paused {
439                    // Wait until the download is resumed
440                    loop {
441                        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; // Check every 100ms
442                        if let Some(updated_progress) = self.progress_tracker.get_progress(download_id).await {
443                            if updated_progress.status != DownloadStatus::Paused {
444                                break; // Exit the pause loop when resumed or status changed
445                            }
446                        } else {
447                            // If progress is no longer tracked, exit
448                            break;
449                        }
450                    }
451                    
452                    // After resuming, check if we should continue or stop
453                    if let Some(updated_progress) = self.progress_tracker.get_progress(download_id).await {
454                        if updated_progress.status == DownloadStatus::Cancelled {
455                            // Clean up partial file
456                            let _ = tokio::fs::remove_file(model_dir.join(filename)).await;
457                            return Err(anyhow::anyhow!("Download cancelled by user"));
458                        }
459                    }
460                }
461            }
462
463            let chunk = chunk_result.context("Error reading download stream")?;
464            file.write_all(&chunk).await?;
465            downloaded += chunk.len() as u64;
466
467            // Update elapsed time and progress
468            let elapsed = start_time.elapsed();
469            self.progress_tracker.update_elapsed_time(download_id, elapsed).await;
470            self.progress_tracker
471                .update_progress(download_id, downloaded, DownloadStatus::Downloading, None)
472                .await;
473        }
474
475        file.flush().await?;
476        Ok(())
477    }
478
479    /// Save model metadata after successful download
480    pub async fn save_model_metadata(
481        &self,
482        model_info: &ModelInfo,
483        source: &DownloadSource,
484    ) -> Result<()> {
485        let metadata = ModelMetadata {
486            id: model_info.id.clone(),
487            name: model_info.name.clone(),
488            description: model_info.description.clone(),
489            author: model_info.author.clone(),
490            size_bytes: model_info.size_bytes,
491            format: model_info.format.clone(),
492            download_source: match source {
493                DownloadSource::HuggingFace { .. } => "huggingface".to_string(),
494                DownloadSource::OpenRouter { .. } => "openrouter".to_string(),
495            },
496            download_date: chrono::Utc::now(),
497            last_used: None,
498            tags: model_info.tags.clone(),
499            hardware_requirements: HardwareRequirements::default(),
500            compatibility_notes: None,
501            runtime_binaries: self.get_appropriate_runtime_binaries_for_model(&model_info.format).await, // Populate with platform-appropriate binaries based on model format
502        };
503
504        let metadata_path = self.storage.metadata_path(&model_info.id);
505        let metadata_json = serde_json::to_string_pretty(&metadata)?;
506        tokio::fs::write(&metadata_path, metadata_json).await?;
507
508        Ok(())
509    }
510
511    /// Get reference to progress tracker
512    pub fn progress_tracker(&self) -> &Arc<ProgressTracker> {
513        &self.progress_tracker
514    }
515
516    /// Cancel an ongoing download
517    pub async fn cancel_download(&self, download_id: &str) -> bool {
518        self.progress_tracker.cancel_download(download_id).await
519    }
520    
521    /// Get appropriate runtime binaries for a given model format based on the current platform
522    async fn get_appropriate_runtime_binaries_for_model(&self, _model_format: &str) -> std::collections::HashMap<String, std::path::PathBuf> {
523        use crate::model_runtime::platform_detector::HardwareCapabilities;
524        
525        let mut binaries = std::collections::HashMap::new();
526        let hw_caps = HardwareCapabilities::default();
527        
528        // Get the platform-appropriate binary path
529        if let Some(platform_binary) = hw_caps.get_runtime_binary_path() {
530            // Use the platform-specific binary for this format
531            let platform_name = match hw_caps.platform {
532                crate::model_runtime::platform_detector::Platform::Windows => "windows",
533                crate::model_runtime::platform_detector::Platform::Linux => "linux",
534                crate::model_runtime::platform_detector::Platform::MacOS => "macos",
535            }.to_string();
536            
537            binaries.insert(platform_name, platform_binary);
538        }
539        
540        // Also add generic format mappings
541        binaries.insert("default".to_string(), std::path::PathBuf::from("llama-server"));
542        
543        binaries
544    }
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a downloader whose storage lives entirely inside `temp_dir`.
    fn downloader_in(temp_dir: &TempDir) -> ModelDownloader {
        let storage = Arc::new(ModelStorage {
            location: super::super::storage::StorageLocation {
                app_data_dir: temp_dir.path().to_path_buf(),
                models_dir: temp_dir.path().join("models"),
                registry_dir: temp_dir.path().join("registry"),
            },
        });
        ModelDownloader::new(storage)
    }

    #[tokio::test]
    async fn test_downloader_creation() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let downloader = downloader_in(&temp_dir);
        assert!(downloader.progress_tracker().get_all_downloads().await.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn test_detect_shard_pattern() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let downloader = downloader_in(&temp_dir);

        // Split GGUF names report the total shard count.
        assert_eq!(
            downloader.detect_shard_pattern("Qwen2-7B-Q4_K_M-00001-of-00003.gguf"),
            Some(3)
        );
        assert_eq!(
            downloader.detect_shard_pattern("Q8_0/big-model-00002-of-00004.gguf"),
            Some(4)
        );

        // Single-file models and non-five-digit counters are not sharded.
        assert_eq!(downloader.detect_shard_pattern("model-Q4_K_M.gguf"), None);
        assert_eq!(downloader.detect_shard_pattern("model-1-of-3.gguf"), None);
        Ok(())
    }
}