// offline_intelligence/model_management/progress.rs

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use tokio::sync::{RwLock, watch};
use tracing::{debug, info};
use uuid::Uuid;

13mod duration_as_secs_f64 {
15 use serde::{self, Deserialize, Deserializer, Serializer};
16 use std::time::Duration;
17
18 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
19 where S: Serializer {
20 serializer.serialize_f64(duration.as_secs_f64())
21 }
22
23 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
24 where D: Deserializer<'de> {
25 let secs = f64::deserialize(deserializer)?;
26 Ok(Duration::from_secs_f64(secs))
27 }
28}
29
30mod option_duration_as_secs_f64 {
32 use serde::{self, Deserialize, Deserializer, Serializer};
33 use std::time::Duration;
34
35 pub fn serialize<S>(duration: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
36 where S: Serializer {
37 match duration {
38 Some(d) => serializer.serialize_some(&d.as_secs_f64()),
39 None => serializer.serialize_none(),
40 }
41 }
42
43 pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
44 where D: Deserializer<'de> {
45 let opt = Option::<f64>::deserialize(deserializer)?;
46 Ok(opt.map(Duration::from_secs_f64))
47 }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct DownloadProgress {
53 pub download_id: String,
54 pub model_id: String,
55 pub model_name: String,
56 pub status: DownloadStatus,
57 pub bytes_downloaded: u64,
58 pub total_bytes: Option<u64>,
59 pub percentage: f32,
60 pub speed_bps: f64, #[serde(with = "duration_as_secs_f64")]
62 pub elapsed_time: std::time::Duration,
63 #[serde(with = "option_duration_as_secs_f64")]
64 pub estimated_time_remaining: Option<std::time::Duration>,
65 pub error_message: Option<String>,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
70pub enum DownloadStatus {
71 Queued,
72 Starting,
73 Downloading,
74 Paused,
75 Completed,
76 Failed,
77 Cancelled,
78}
79
80impl DownloadStatus {
81 pub fn is_active(&self) -> bool {
82 matches!(self, DownloadStatus::Queued | DownloadStatus::Starting | DownloadStatus::Downloading)
83 }
84
85 pub fn is_finished(&self) -> bool {
86 matches!(self, DownloadStatus::Completed | DownloadStatus::Failed | DownloadStatus::Cancelled)
87 }
88}
89
90pub struct ProgressTracker {
92 downloads: Arc<RwLock<HashMap<String, DownloadProgress>>>,
93 watchers: Arc<RwLock<HashMap<String, watch::Sender<DownloadProgress>>>>,
94}
95
96impl ProgressTracker {
97 pub fn new() -> Self {
98 Self {
99 downloads: Arc::new(RwLock::new(HashMap::new())),
100 watchers: Arc::new(RwLock::new(HashMap::new())),
101 }
102 }
103
104 pub async fn start_download(
106 &self,
107 model_id: String,
108 model_name: String,
109 total_bytes: Option<u64>,
110 ) -> String {
111 let download_id = Uuid::new_v4().to_string();
112 let progress = DownloadProgress {
113 download_id: download_id.clone(),
114 model_id,
115 model_name,
116 status: DownloadStatus::Queued,
117 bytes_downloaded: 0,
118 total_bytes,
119 percentage: 0.0,
120 speed_bps: 0.0,
121 elapsed_time: std::time::Duration::from_secs(0),
122 estimated_time_remaining: None,
123 error_message: None,
124 };
125
126 {
127 let mut downloads = self.downloads.write().await;
128 downloads.insert(download_id.clone(), progress);
129 }
130
131 info!("Started tracking download: {}", download_id);
132 download_id
133 }
134
135 pub async fn update_progress(
137 &self,
138 download_id: &str,
139 bytes_downloaded: u64,
140 status: DownloadStatus,
141 error_message: Option<String>,
142 ) {
143 let mut downloads = self.downloads.write().await;
144
145 if let Some(progress) = downloads.get_mut(download_id) {
146 let old_bytes = progress.bytes_downloaded;
147 progress.bytes_downloaded = bytes_downloaded;
148 progress.status = status;
149 progress.error_message = error_message;
150
151 if let Some(total) = progress.total_bytes {
153 if total > 0 {
154 progress.percentage = (bytes_downloaded as f32 / total as f32) * 100.0;
155 }
156 }
157
158 if bytes_downloaded > old_bytes && progress.elapsed_time.as_secs() > 0 {
160 let time_elapsed_secs = progress.elapsed_time.as_secs_f64();
161 progress.speed_bps = bytes_downloaded as f64 / time_elapsed_secs;
162
163 if let Some(total) = progress.total_bytes {
164 let remaining_bytes = total.saturating_sub(bytes_downloaded);
166 if progress.speed_bps > 0.0 {
167 let eta_secs = remaining_bytes as f64 / progress.speed_bps;
168 progress.estimated_time_remaining = Some(std::time::Duration::from_secs_f64(eta_secs));
169 }
170 }
171 }
172
173 self.notify_watchers(download_id, progress.clone()).await;
175
176 debug!("Updated progress for {}: {:.1}%", download_id, progress.percentage);
177 }
178 }
179
180 pub async fn update_elapsed_time(&self, download_id: &str, elapsed: std::time::Duration) {
182 let mut downloads = self.downloads.write().await;
183 if let Some(progress) = downloads.get_mut(download_id) {
184 progress.elapsed_time = elapsed;
185 }
186 }
187
188 pub async fn update_total_bytes(&self, download_id: &str, total_bytes: u64) {
190 let mut downloads = self.downloads.write().await;
191 if let Some(progress) = downloads.get_mut(download_id) {
192 if progress.total_bytes.is_none() || progress.total_bytes != Some(total_bytes) {
194 progress.total_bytes = Some(total_bytes);
195
196 if total_bytes > 0 {
198 progress.percentage = (progress.bytes_downloaded as f32 / total_bytes as f32) * 100.0;
199
200 if progress.speed_bps > 0.0 {
202 let remaining_bytes = total_bytes.saturating_sub(progress.bytes_downloaded);
203 let eta_secs = remaining_bytes as f64 / progress.speed_bps;
204 progress.estimated_time_remaining = Some(std::time::Duration::from_secs_f64(eta_secs));
205 }
206 }
207
208 self.notify_watchers(download_id, progress.clone()).await;
210 debug!("Updated total_bytes for {}: {} bytes", download_id, total_bytes);
211 }
212 }
213 }
214
215 pub async fn get_progress(&self, download_id: &str) -> Option<DownloadProgress> {
217 let downloads = self.downloads.read().await;
218 downloads.get(download_id).cloned()
219 }
220
221 pub async fn get_active_downloads(&self) -> Vec<DownloadProgress> {
223 let downloads = self.downloads.read().await;
224 downloads.values()
225 .filter(|p| p.status.is_active())
226 .cloned()
227 .collect()
228 }
229
230 pub async fn get_all_downloads(&self) -> Vec<DownloadProgress> {
232 let downloads = self.downloads.read().await;
233 downloads.values().cloned().collect()
234 }
235
236 pub async fn subscribe(&self, download_id: &str) -> Option<watch::Receiver<DownloadProgress>> {
238 let mut watchers = self.watchers.write().await;
239 let (tx, rx) = watch::channel(DownloadProgress {
240 download_id: download_id.to_string(),
241 model_id: String::new(),
242 model_name: String::new(),
243 status: DownloadStatus::Queued,
244 bytes_downloaded: 0,
245 total_bytes: None,
246 percentage: 0.0,
247 speed_bps: 0.0,
248 elapsed_time: std::time::Duration::from_secs(0),
249 estimated_time_remaining: None,
250 error_message: None,
251 });
252
253 watchers.insert(download_id.to_string(), tx);
254 Some(rx)
255 }
256
257 async fn notify_watchers(&self, download_id: &str, progress: DownloadProgress) {
259 let watchers = self.watchers.read().await;
260 if let Some(tx) = watchers.get(download_id) {
261 let _ = tx.send(progress);
262 }
263 }
264
265 pub async fn remove_download(&self, download_id: &str) {
267 {
268 let mut downloads = self.downloads.write().await;
269 downloads.remove(download_id);
270 }
271 {
272 let mut watchers = self.watchers.write().await;
273 watchers.remove(download_id);
274 }
275 info!("Removed download tracking: {}", download_id);
276 }
277
278 pub async fn cancel_download(&self, download_id: &str) -> bool {
280 let mut downloads = self.downloads.write().await;
281 if let Some(progress) = downloads.get_mut(download_id) {
282 if progress.status.is_active() {
283 progress.status = DownloadStatus::Cancelled;
284 self.notify_watchers(download_id, progress.clone()).await;
285 true
286 } else {
287 false
288 }
289 } else {
290 false
291 }
292 }
293
294 pub async fn get_statistics(&self) -> DownloadStatistics {
296 let downloads = self.downloads.read().await;
297 let mut stats = DownloadStatistics::default();
298
299 for progress in downloads.values() {
300 match progress.status {
301 DownloadStatus::Queued => stats.queued += 1,
302 DownloadStatus::Starting => stats.starting += 1,
303 DownloadStatus::Downloading => {
304 stats.downloading += 1;
305 stats.total_speed_bps += progress.speed_bps;
306 },
307 DownloadStatus::Paused => stats.paused += 1,
308 DownloadStatus::Completed => stats.completed += 1,
309 DownloadStatus::Failed => stats.failed += 1,
310 DownloadStatus::Cancelled => stats.cancelled += 1,
311 }
312
313 if let Some(total) = progress.total_bytes {
314 stats.total_data_bytes += total;
315 }
316 stats.downloaded_bytes += progress.bytes_downloaded;
317 }
318
319 stats
320 }
321}
322
/// Aggregate counts and byte totals across tracked downloads.
#[derive(Debug, Default)]
pub struct DownloadStatistics {
    pub queued: usize,
    pub starting: usize,
    pub downloading: usize,
    pub paused: usize,
    pub completed: usize,
    pub failed: usize,
    pub cancelled: usize,
    /// Sum of the per-download speeds, in bytes per second.
    pub total_speed_bps: f64,
    pub downloaded_bytes: u64,
    pub total_data_bytes: u64,
}

impl DownloadStatistics {
    /// Downloads not yet in a terminal state: queued + starting + downloading + paused.
    pub fn active_downloads(&self) -> usize {
        [self.queued, self.starting, self.downloading, self.paused]
            .iter()
            .sum()
    }

    /// Downloads in a terminal state: completed + failed + cancelled.
    pub fn finished_downloads(&self) -> usize {
        [self.completed, self.failed, self.cancelled].iter().sum()
    }

    /// `downloaded_bytes` as a percentage of `total_data_bytes`, or `0.0` when the total
    /// is zero.
    ///
    /// The result is not clamped. It can exceed 100 when `downloaded_bytes` counts
    /// downloads whose size is not included in `total_data_bytes`.
    pub fn overall_percentage(&self) -> f32 {
        if self.total_data_bytes == 0 {
            return 0.0;
        }
        let done = self.downloaded_bytes as f32;
        let total = self.total_data_bytes as f32;
        (done / total) * 100.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Starts tracking a 1000-byte download of "test-model" and returns its id.
    async fn start_sample_download(tracker: &ProgressTracker) -> String {
        tracker
            .start_download("test-model".to_string(), "Test Model".to_string(), Some(1000))
            .await
    }

    /// Halfway through a 1000-byte download, the snapshot reports 50%.
    #[tokio::test]
    async fn test_progress_tracking() {
        let tracker = ProgressTracker::new();
        let id = start_sample_download(&tracker).await;

        tracker
            .update_progress(&id, 500, DownloadStatus::Downloading, None)
            .await;

        let snapshot = tracker
            .get_progress(&id)
            .await
            .expect("download should still be tracked");
        assert_eq!(snapshot.bytes_downloaded, 500);
        assert_eq!(snapshot.percentage, 50.0);
        assert_eq!(snapshot.status, DownloadStatus::Downloading);
    }

    /// A subscriber observes the value pushed by a later progress update.
    #[tokio::test]
    async fn test_subscription() {
        let tracker = ProgressTracker::new();
        let id = start_sample_download(&tracker).await;

        let mut rx = tracker
            .subscribe(&id)
            .await
            .expect("subscribing to a tracked download should succeed");

        tracker
            .update_progress(&id, 250, DownloadStatus::Downloading, None)
            .await;

        let latest = rx.borrow_and_update().clone();
        assert_eq!(latest.bytes_downloaded, 250);
    }
}