Skip to main content

offline_intelligence/model_management/
progress.rs

1//! Download Progress Tracking
2//!
3//! Provides real-time progress tracking for model downloads
4//! with support for multiple concurrent downloads.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::{RwLock, watch};
10use tracing::{debug, info};
11use uuid::Uuid;
12
13/// Serialize Duration as seconds (f64) for JSON compatibility with frontend
14mod duration_as_secs_f64 {
15    use serde::{self, Deserialize, Deserializer, Serializer};
16    use std::time::Duration;
17
18    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
19    where S: Serializer {
20        serializer.serialize_f64(duration.as_secs_f64())
21    }
22
23    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
24    where D: Deserializer<'de> {
25        let secs = f64::deserialize(deserializer)?;
26        Ok(Duration::from_secs_f64(secs))
27    }
28}
29
30/// Serialize Option<Duration> as Option<f64> seconds for JSON compatibility
31mod option_duration_as_secs_f64 {
32    use serde::{self, Deserialize, Deserializer, Serializer};
33    use std::time::Duration;
34
35    pub fn serialize<S>(duration: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
36    where S: Serializer {
37        match duration {
38            Some(d) => serializer.serialize_some(&d.as_secs_f64()),
39            None => serializer.serialize_none(),
40        }
41    }
42
43    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
44    where D: Deserializer<'de> {
45        let opt = Option::<f64>::deserialize(deserializer)?;
46        Ok(opt.map(Duration::from_secs_f64))
47    }
48}
49
50/// Download progress information
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct DownloadProgress {
53    pub download_id: String,
54    pub model_id: String,
55    pub model_name: String,
56    pub status: DownloadStatus,
57    pub bytes_downloaded: u64,
58    pub total_bytes: Option<u64>,
59    pub percentage: f32,
60    pub speed_bps: f64, // bytes per second
61    #[serde(with = "duration_as_secs_f64")]
62    pub elapsed_time: std::time::Duration,
63    #[serde(with = "option_duration_as_secs_f64")]
64    pub estimated_time_remaining: Option<std::time::Duration>,
65    pub error_message: Option<String>,
66}
67
68/// Status of a download
69#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
70pub enum DownloadStatus {
71    Queued,
72    Starting,
73    Downloading,
74    Paused,
75    Completed,
76    Failed,
77    Cancelled,
78}
79
80impl DownloadStatus {
81    pub fn is_active(&self) -> bool {
82        matches!(self, DownloadStatus::Queued | DownloadStatus::Starting | DownloadStatus::Downloading)
83    }
84
85    pub fn is_finished(&self) -> bool {
86        matches!(self, DownloadStatus::Completed | DownloadStatus::Failed | DownloadStatus::Cancelled)
87    }
88}
89
90/// Progress tracker for managing multiple downloads
91pub struct ProgressTracker {
92    downloads: Arc<RwLock<HashMap<String, DownloadProgress>>>,
93    watchers: Arc<RwLock<HashMap<String, watch::Sender<DownloadProgress>>>>,
94}
95
96impl ProgressTracker {
97    pub fn new() -> Self {
98        Self {
99            downloads: Arc::new(RwLock::new(HashMap::new())),
100            watchers: Arc::new(RwLock::new(HashMap::new())),
101        }
102    }
103
104    /// Start tracking a new download
105    pub async fn start_download(
106        &self,
107        model_id: String,
108        model_name: String,
109        total_bytes: Option<u64>,
110    ) -> String {
111        let download_id = Uuid::new_v4().to_string();
112        let progress = DownloadProgress {
113            download_id: download_id.clone(),
114            model_id,
115            model_name,
116            status: DownloadStatus::Queued,
117            bytes_downloaded: 0,
118            total_bytes,
119            percentage: 0.0,
120            speed_bps: 0.0,
121            elapsed_time: std::time::Duration::from_secs(0),
122            estimated_time_remaining: None,
123            error_message: None,
124        };
125
126        {
127            let mut downloads = self.downloads.write().await;
128            downloads.insert(download_id.clone(), progress);
129        }
130
131        info!("Started tracking download: {}", download_id);
132        download_id
133    }
134
135    /// Update download progress
136    pub async fn update_progress(
137        &self,
138        download_id: &str,
139        bytes_downloaded: u64,
140        status: DownloadStatus,
141        error_message: Option<String>,
142    ) {
143        let mut downloads = self.downloads.write().await;
144        
145        if let Some(progress) = downloads.get_mut(download_id) {
146            let old_bytes = progress.bytes_downloaded;
147            progress.bytes_downloaded = bytes_downloaded;
148            progress.status = status;
149            progress.error_message = error_message;
150
151            // Calculate percentage (guard against division by zero)
152            if let Some(total) = progress.total_bytes {
153                if total > 0 {
154                    progress.percentage = (bytes_downloaded as f32 / total as f32) * 100.0;
155                }
156            }
157
158            // Calculate speed and ETA
159            if bytes_downloaded > old_bytes && progress.elapsed_time.as_secs() > 0 {
160                let time_elapsed_secs = progress.elapsed_time.as_secs_f64();
161                progress.speed_bps = bytes_downloaded as f64 / time_elapsed_secs;
162
163                if let Some(total) = progress.total_bytes {
164                    // Use saturating_sub to prevent overflow if bytes_downloaded > total
165                    let remaining_bytes = total.saturating_sub(bytes_downloaded);
166                    if progress.speed_bps > 0.0 {
167                        let eta_secs = remaining_bytes as f64 / progress.speed_bps;
168                        progress.estimated_time_remaining = Some(std::time::Duration::from_secs_f64(eta_secs));
169                    }
170                }
171            }
172
173            // Notify watchers
174            self.notify_watchers(download_id, progress.clone()).await;
175            
176            debug!("Updated progress for {}: {:.1}%", download_id, progress.percentage);
177        }
178    }
179
180    /// Update elapsed time for a download
181    pub async fn update_elapsed_time(&self, download_id: &str, elapsed: std::time::Duration) {
182        let mut downloads = self.downloads.write().await;
183        if let Some(progress) = downloads.get_mut(download_id) {
184            progress.elapsed_time = elapsed;
185        }
186    }
187
188    /// Update total bytes for a download (e.g., when Content-Length is received from HTTP response)
189    pub async fn update_total_bytes(&self, download_id: &str, total_bytes: u64) {
190        let mut downloads = self.downloads.write().await;
191        if let Some(progress) = downloads.get_mut(download_id) {
192            // Only update if we didn't already have a total size or the new size is different
193            if progress.total_bytes.is_none() || progress.total_bytes != Some(total_bytes) {
194                progress.total_bytes = Some(total_bytes);
195                
196                // Recalculate percentage if we have bytes downloaded
197                if total_bytes > 0 {
198                    progress.percentage = (progress.bytes_downloaded as f32 / total_bytes as f32) * 100.0;
199                    
200                    // Recalculate ETA
201                    if progress.speed_bps > 0.0 {
202                        let remaining_bytes = total_bytes.saturating_sub(progress.bytes_downloaded);
203                        let eta_secs = remaining_bytes as f64 / progress.speed_bps;
204                        progress.estimated_time_remaining = Some(std::time::Duration::from_secs_f64(eta_secs));
205                    }
206                }
207                
208                // Notify watchers of the update
209                self.notify_watchers(download_id, progress.clone()).await;
210                debug!("Updated total_bytes for {}: {} bytes", download_id, total_bytes);
211            }
212        }
213    }
214
215    /// Get current progress for a download
216    pub async fn get_progress(&self, download_id: &str) -> Option<DownloadProgress> {
217        let downloads = self.downloads.read().await;
218        downloads.get(download_id).cloned()
219    }
220
221    /// Get all active downloads
222    pub async fn get_active_downloads(&self) -> Vec<DownloadProgress> {
223        let downloads = self.downloads.read().await;
224        downloads.values()
225            .filter(|p| p.status.is_active())
226            .cloned()
227            .collect()
228    }
229
230    /// Get all downloads (active and completed)
231    pub async fn get_all_downloads(&self) -> Vec<DownloadProgress> {
232        let downloads = self.downloads.read().await;
233        downloads.values().cloned().collect()
234    }
235
236    /// Subscribe to progress updates for a specific download
237    pub async fn subscribe(&self, download_id: &str) -> Option<watch::Receiver<DownloadProgress>> {
238        let mut watchers = self.watchers.write().await;
239        let (tx, rx) = watch::channel(DownloadProgress {
240            download_id: download_id.to_string(),
241            model_id: String::new(),
242            model_name: String::new(),
243            status: DownloadStatus::Queued,
244            bytes_downloaded: 0,
245            total_bytes: None,
246            percentage: 0.0,
247            speed_bps: 0.0,
248            elapsed_time: std::time::Duration::from_secs(0),
249            estimated_time_remaining: None,
250            error_message: None,
251        });
252        
253        watchers.insert(download_id.to_string(), tx);
254        Some(rx)
255    }
256
257    /// Notify watchers of progress update
258    async fn notify_watchers(&self, download_id: &str, progress: DownloadProgress) {
259        let watchers = self.watchers.read().await;
260        if let Some(tx) = watchers.get(download_id) {
261            let _ = tx.send(progress);
262        }
263    }
264
265    /// Remove a download from tracking
266    pub async fn remove_download(&self, download_id: &str) {
267        {
268            let mut downloads = self.downloads.write().await;
269            downloads.remove(download_id);
270        }
271        {
272            let mut watchers = self.watchers.write().await;
273            watchers.remove(download_id);
274        }
275        info!("Removed download tracking: {}", download_id);
276    }
277
278    /// Cancel a download
279    pub async fn cancel_download(&self, download_id: &str) -> bool {
280        let mut downloads = self.downloads.write().await;
281        if let Some(progress) = downloads.get_mut(download_id) {
282            if progress.status.is_active() {
283                progress.status = DownloadStatus::Cancelled;
284                self.notify_watchers(download_id, progress.clone()).await;
285                true
286            } else {
287                false
288            }
289        } else {
290            false
291        }
292    }
293
294    /// Get overall download statistics
295    pub async fn get_statistics(&self) -> DownloadStatistics {
296        let downloads = self.downloads.read().await;
297        let mut stats = DownloadStatistics::default();
298        
299        for progress in downloads.values() {
300            match progress.status {
301                DownloadStatus::Queued => stats.queued += 1,
302                DownloadStatus::Starting => stats.starting += 1,
303                DownloadStatus::Downloading => {
304                    stats.downloading += 1;
305                    stats.total_speed_bps += progress.speed_bps;
306                },
307                DownloadStatus::Paused => stats.paused += 1,
308                DownloadStatus::Completed => stats.completed += 1,
309                DownloadStatus::Failed => stats.failed += 1,
310                DownloadStatus::Cancelled => stats.cancelled += 1,
311            }
312            
313            if let Some(total) = progress.total_bytes {
314                stats.total_data_bytes += total;
315            }
316            stats.downloaded_bytes += progress.bytes_downloaded;
317        }
318        
319        stats
320    }
321}
322
/// Aggregate statistics over a set of downloads: one counter per status plus
/// combined speed and byte totals. `Default` yields all-zero statistics.
#[derive(Default, Debug)]
pub struct DownloadStatistics {
    /// Number of downloads in the `Queued` state.
    pub queued: usize,
    /// Number of downloads in the `Starting` state.
    pub starting: usize,
    /// Number of downloads in the `Downloading` state.
    pub downloading: usize,
    /// Number of downloads in the `Paused` state.
    pub paused: usize,
    /// Number of downloads in the `Completed` state.
    pub completed: usize,
    /// Number of downloads in the `Failed` state.
    pub failed: usize,
    /// Number of downloads in the `Cancelled` state.
    pub cancelled: usize,
    /// Combined speed in bytes per second.
    pub total_speed_bps: f64,
    /// Combined number of bytes downloaded.
    pub downloaded_bytes: u64,
    /// Combined size in bytes of the downloads whose total is known.
    pub total_data_bytes: u64,
}
337
338impl DownloadStatistics {
339    pub fn active_downloads(&self) -> usize {
340        self.queued + self.starting + self.downloading + self.paused
341    }
342
343    pub fn finished_downloads(&self) -> usize {
344        self.completed + self.failed + self.cancelled
345    }
346
347    pub fn overall_percentage(&self) -> f32 {
348        if self.total_data_bytes > 0 {
349            (self.downloaded_bytes as f32 / self.total_data_bytes as f32) * 100.0
350        } else {
351            0.0
352        }
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// A progress update is reflected in the snapshot returned by `get_progress`.
    #[tokio::test]
    async fn test_progress_tracking() {
        let tracker = ProgressTracker::new();
        let id = tracker
            .start_download("test-model".to_string(), "Test Model".to_string(), Some(1000))
            .await;

        // Report half of the advertised size.
        tracker
            .update_progress(&id, 500, DownloadStatus::Downloading, None)
            .await;

        let snapshot = tracker.get_progress(&id).await.unwrap();
        assert_eq!(snapshot.bytes_downloaded, 500);
        assert_eq!(snapshot.percentage, 50.0);
        assert_eq!(snapshot.status, DownloadStatus::Downloading);
    }

    /// A subscriber observes the snapshot published by a later progress update.
    #[tokio::test]
    async fn test_subscription() {
        let tracker = ProgressTracker::new();
        let id = tracker
            .start_download("test-model".to_string(), "Test Model".to_string(), Some(1000))
            .await;

        let mut updates = tracker.subscribe(&id).await.unwrap();

        // Publish an update after subscribing.
        tracker
            .update_progress(&id, 250, DownloadStatus::Downloading, None)
            .await;

        // The watch channel must now hold the updated snapshot.
        let latest = updates.borrow_and_update().clone();
        assert_eq!(latest.bytes_downloaded, 250);
    }
}