1use super::QoSClass;
50use std::collections::HashMap;
51use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
52use std::sync::Arc;
53use std::time::Instant;
54use tokio::sync::RwLock;
55use uuid::Uuid;
56
/// Unique identifier of an [`ActiveTransfer`]; a random v4 UUID generated by
/// [`ActiveTransfer::new`].
pub type TransferId = Uuid;
59
#[derive(Debug)]
/// A single in-flight transfer tracked for QoS preemption.
///
/// Progress and pause state live in atomics (plus an async lock for the pause
/// timestamp), so a transfer shared through `Arc` can be updated via `&self`.
pub struct ActiveTransfer {
    /// Identifier assigned at construction.
    pub id: TransferId,

    /// QoS class the transfer was created with; compared via `can_preempt`
    /// when deciding whether another class may preempt it.
    pub class: QoSClass,

    /// Total payload size in bytes.
    pub bytes_total: usize,

    /// Bytes recorded as sent so far via `record_sent`; not capped, so it may
    /// exceed `bytes_total`.
    pub bytes_sent: AtomicUsize,

    /// Whether the transfer may be paused (and therefore preempted) at all.
    pub can_pause: bool,

    /// Whether the transfer is currently paused.
    pub is_paused: AtomicBool,

    /// Instant at which the transfer was created.
    pub started_at: Instant,

    /// Instant of the current pause; `None` while the transfer is running.
    pub paused_at: RwLock<Option<Instant>>,
}
87
88impl ActiveTransfer {
89 pub fn new(class: QoSClass, bytes_total: usize, can_pause: bool) -> Self {
91 Self {
92 id: Uuid::new_v4(),
93 class,
94 bytes_total,
95 bytes_sent: AtomicUsize::new(0),
96 can_pause,
97 is_paused: AtomicBool::new(false),
98 started_at: Instant::now(),
99 paused_at: RwLock::new(None),
100 }
101 }
102
103 pub fn progress(&self) -> f64 {
105 if self.bytes_total == 0 {
106 1.0
107 } else {
108 self.bytes_sent.load(Ordering::Relaxed) as f64 / self.bytes_total as f64
109 }
110 }
111
112 pub fn bytes_remaining(&self) -> usize {
114 let sent = self.bytes_sent.load(Ordering::Relaxed);
115 self.bytes_total.saturating_sub(sent)
116 }
117
118 pub fn record_sent(&self, bytes: usize) {
120 self.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
121 }
122
123 pub fn is_complete(&self) -> bool {
125 self.bytes_sent.load(Ordering::Relaxed) >= self.bytes_total
126 }
127
128 pub fn can_be_preempted_by(&self, class: QoSClass) -> bool {
130 self.can_pause && class.can_preempt(&self.class)
131 }
132
133 pub async fn pause(&self) {
135 if self.can_pause && !self.is_paused.load(Ordering::Relaxed) {
136 self.is_paused.store(true, Ordering::Relaxed);
137 *self.paused_at.write().await = Some(Instant::now());
138 }
139 }
140
141 pub async fn resume(&self) {
143 self.is_paused.store(false, Ordering::Relaxed);
144 *self.paused_at.write().await = None;
145 }
146
147 pub async fn paused_duration(&self) -> Option<std::time::Duration> {
149 if self.is_paused.load(Ordering::Relaxed) {
150 self.paused_at.read().await.map(|t| t.elapsed())
151 } else {
152 None
153 }
154 }
155}
156
#[derive(Debug)]
/// Registry of active transfers that pauses lower-priority ones when a
/// higher-priority class needs to preempt them.
///
/// The transfer map sits behind an async `RwLock`. The two counters are plain
/// atomics so they can be read from synchronous code.
pub struct PreemptionController {
    /// All registered transfers, keyed by their id.
    active_transfers: RwLock<HashMap<TransferId, Arc<ActiveTransfer>>>,

    /// Number of `pause_transfers_below` calls that paused at least one transfer.
    preemption_count: AtomicUsize,

    /// Number of transfers currently paused by this controller. It is maintained
    /// incrementally as transfers are paused, resumed and unregistered.
    paused_count: AtomicUsize,
}
172
173impl PreemptionController {
174 pub fn new() -> Self {
176 Self {
177 active_transfers: RwLock::new(HashMap::new()),
178 preemption_count: AtomicUsize::new(0),
179 paused_count: AtomicUsize::new(0),
180 }
181 }
182
183 pub async fn register_transfer(
187 &self,
188 class: QoSClass,
189 bytes_total: usize,
190 can_pause: bool,
191 ) -> TransferId {
192 let transfer = Arc::new(ActiveTransfer::new(class, bytes_total, can_pause));
193 let id = transfer.id;
194
195 let mut transfers = self.active_transfers.write().await;
196 transfers.insert(id, transfer);
197
198 id
199 }
200
201 pub async fn unregister_transfer(&self, id: TransferId) {
203 let mut transfers = self.active_transfers.write().await;
204 if let Some(transfer) = transfers.remove(&id) {
205 if transfer.is_paused.load(Ordering::Relaxed) {
206 self.paused_count.fetch_sub(1, Ordering::Relaxed);
207 }
208 }
209 }
210
211 pub async fn get_transfer(&self, id: TransferId) -> Option<Arc<ActiveTransfer>> {
213 let transfers = self.active_transfers.read().await;
214 transfers.get(&id).cloned()
215 }
216
217 pub async fn should_preempt(&self, incoming_class: QoSClass) -> bool {
221 if !matches!(incoming_class, QoSClass::Critical | QoSClass::High) {
223 return false;
224 }
225
226 let transfers = self.active_transfers.read().await;
227 for transfer in transfers.values() {
228 if transfer.can_be_preempted_by(incoming_class)
229 && !transfer.is_paused.load(Ordering::Relaxed)
230 {
231 return true;
232 }
233 }
234 false
235 }
236
237 pub async fn pause_transfers_below(&self, class: QoSClass) -> Vec<TransferId> {
241 let transfers = self.active_transfers.read().await;
242 let mut paused = Vec::new();
243
244 for transfer in transfers.values() {
245 if transfer.can_be_preempted_by(class) {
246 transfer.pause().await;
247 paused.push(transfer.id);
248 self.paused_count.fetch_add(1, Ordering::Relaxed);
249 }
250 }
251
252 if !paused.is_empty() {
253 self.preemption_count.fetch_add(1, Ordering::Relaxed);
254 }
255
256 paused
257 }
258
259 pub async fn resume_transfers(&self, transfers_to_resume: Vec<TransferId>) {
261 let transfers = self.active_transfers.read().await;
262
263 for id in transfers_to_resume {
264 if let Some(transfer) = transfers.get(&id) {
265 if transfer.is_paused.load(Ordering::Relaxed) {
266 transfer.resume().await;
267 self.paused_count.fetch_sub(1, Ordering::Relaxed);
268 }
269 }
270 }
271 }
272
273 pub async fn resume_all(&self) {
275 let transfers = self.active_transfers.read().await;
276
277 for transfer in transfers.values() {
278 if transfer.is_paused.load(Ordering::Relaxed) {
279 transfer.resume().await;
280 }
281 }
282
283 self.paused_count.store(0, Ordering::Relaxed);
284 }
285
286 pub async fn active_count(&self) -> usize {
288 self.active_transfers.read().await.len()
289 }
290
291 pub fn paused_count(&self) -> usize {
293 self.paused_count.load(Ordering::Relaxed)
294 }
295
296 pub fn preemption_count(&self) -> usize {
298 self.preemption_count.load(Ordering::Relaxed)
299 }
300
301 pub async fn transfers_by_class(&self, class: QoSClass) -> Vec<Arc<ActiveTransfer>> {
303 let transfers = self.active_transfers.read().await;
304 transfers
305 .values()
306 .filter(|t| t.class == class)
307 .cloned()
308 .collect()
309 }
310
311 pub async fn preemptable_transfers(&self, by_class: QoSClass) -> Vec<Arc<ActiveTransfer>> {
313 let transfers = self.active_transfers.read().await;
314 transfers
315 .values()
316 .filter(|t| t.can_be_preempted_by(by_class))
317 .cloned()
318 .collect()
319 }
320
321 pub async fn bandwidth_used_below(&self, class: QoSClass) -> usize {
323 let transfers = self.active_transfers.read().await;
324 transfers
325 .values()
326 .filter(|t| class.can_preempt(&t.class) && !t.is_paused.load(Ordering::Relaxed))
327 .map(|t| t.bytes_remaining())
328 .sum()
329 }
330
331 pub async fn cleanup_completed(&self) -> usize {
333 let mut transfers = self.active_transfers.write().await;
334 let initial_len = transfers.len();
335
336 transfers.retain(|_, t| !t.is_complete());
337
338 initial_len - transfers.len()
339 }
340
341 pub async fn stats(&self) -> PreemptionStats {
343 let transfers = self.active_transfers.read().await;
344
345 let mut by_class = HashMap::new();
346 for class in QoSClass::all_by_priority() {
347 by_class.insert(*class, 0usize);
348 }
349
350 for transfer in transfers.values() {
351 *by_class.entry(transfer.class).or_insert(0) += 1;
352 }
353
354 PreemptionStats {
355 active_transfers: transfers.len(),
356 paused_transfers: self.paused_count.load(Ordering::Relaxed),
357 preemption_events: self.preemption_count.load(Ordering::Relaxed),
358 transfers_by_class: by_class,
359 }
360 }
361}
362
363impl Default for PreemptionController {
364 fn default() -> Self {
365 Self::new()
366 }
367}
368
369#[derive(Debug, Clone)]
371pub struct PreemptionStats {
372 pub active_transfers: usize,
374
375 pub paused_transfers: usize,
377
378 pub preemption_events: usize,
380
381 pub transfers_by_class: HashMap<QoSClass, usize>,
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers a pausable transfer of `class` carrying `bytes` bytes.
    async fn register_pausable(
        controller: &PreemptionController,
        class: QoSClass,
        bytes: usize,
    ) -> TransferId {
        controller.register_transfer(class, bytes, true).await
    }

    #[test]
    fn test_active_transfer_creation() {
        let t = ActiveTransfer::new(QoSClass::Normal, 1000, true);

        assert_eq!(t.class, QoSClass::Normal);
        assert_eq!(t.bytes_total, 1000);
        assert!(t.can_pause);
        assert!(!t.is_paused.load(Ordering::Relaxed));
        assert_eq!(t.bytes_remaining(), 1000);
    }

    #[test]
    fn test_transfer_progress() {
        let t = ActiveTransfer::new(QoSClass::Normal, 1000, true);
        let near = |expected: f64| (t.progress() - expected).abs() < 0.001;

        assert_eq!(t.progress(), 0.0);

        t.record_sent(500);
        assert!(near(0.5));

        t.record_sent(500);
        assert!(near(1.0));
        assert!(t.is_complete());
    }

    #[test]
    fn test_preemption_eligibility() {
        let low = ActiveTransfer::new(QoSClass::Low, 1000, true);
        let critical = ActiveTransfer::new(QoSClass::Critical, 1000, true);

        // Critical may preempt Low...
        assert!(low.can_be_preempted_by(QoSClass::Critical));
        // ...Low may not preempt Critical...
        assert!(!critical.can_be_preempted_by(QoSClass::Low));
        // ...and Low does not preempt itself.
        assert!(!low.can_be_preempted_by(QoSClass::Low));
    }

    #[test]
    fn test_non_pausable_transfer() {
        let pinned = ActiveTransfer::new(QoSClass::Low, 1000, false);
        assert!(!pinned.can_be_preempted_by(QoSClass::Critical));
    }

    #[tokio::test]
    async fn test_controller_register_unregister() {
        let controller = PreemptionController::new();

        let id = register_pausable(&controller, QoSClass::Normal, 1000).await;
        assert_eq!(controller.active_count().await, 1);

        controller.unregister_transfer(id).await;
        assert_eq!(controller.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_should_preempt() {
        let controller = PreemptionController::new();

        // Nothing to preempt while the controller is empty.
        assert!(!controller.should_preempt(QoSClass::Critical).await);

        register_pausable(&controller, QoSClass::Low, 1000).await;

        assert!(controller.should_preempt(QoSClass::Critical).await);
        assert!(!controller.should_preempt(QoSClass::Bulk).await);
    }

    #[tokio::test]
    async fn test_pause_resume() {
        let controller = PreemptionController::new();

        let low = register_pausable(&controller, QoSClass::Low, 1000).await;
        let bulk = register_pausable(&controller, QoSClass::Bulk, 1000).await;
        let _critical = register_pausable(&controller, QoSClass::Critical, 1000).await;

        let paused = controller.pause_transfers_below(QoSClass::Critical).await;

        assert_eq!(paused.len(), 2);
        for id in [low, bulk] {
            assert!(paused.contains(&id));
        }
        assert_eq!(controller.paused_count(), 2);

        controller.resume_transfers(paused).await;
        assert_eq!(controller.paused_count(), 0);
    }

    #[tokio::test]
    async fn test_preemption_count() {
        let controller = PreemptionController::new();
        register_pausable(&controller, QoSClass::Bulk, 1000).await;

        assert_eq!(controller.preemption_count(), 0);

        controller.pause_transfers_below(QoSClass::Critical).await;
        assert_eq!(controller.preemption_count(), 1);
    }

    #[tokio::test]
    async fn test_transfers_by_class() {
        let controller = PreemptionController::new();
        for (class, bytes) in [
            (QoSClass::Normal, 1000),
            (QoSClass::Normal, 2000),
            (QoSClass::High, 3000),
        ] {
            register_pausable(&controller, class, bytes).await;
        }

        assert_eq!(controller.transfers_by_class(QoSClass::Normal).await.len(), 2);
        assert_eq!(controller.transfers_by_class(QoSClass::High).await.len(), 1);
    }

    #[tokio::test]
    async fn test_cleanup_completed() {
        let controller = PreemptionController::new();
        let id = register_pausable(&controller, QoSClass::Normal, 100).await;

        let transfer = controller
            .get_transfer(id)
            .await
            .expect("transfer was just registered");
        transfer.record_sent(100);
        assert!(transfer.is_complete());

        assert_eq!(controller.cleanup_completed().await, 1);
        assert_eq!(controller.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_bandwidth_used_below() {
        let controller = PreemptionController::new();
        let low = register_pausable(&controller, QoSClass::Low, 1000).await;
        register_pausable(&controller, QoSClass::Bulk, 2000).await;
        register_pausable(&controller, QoSClass::Critical, 3000).await;

        // Only the Low (1000) and Bulk (2000) transfers count below Critical.
        assert_eq!(controller.bandwidth_used_below(QoSClass::Critical).await, 3000);

        controller.pause_transfers_below(QoSClass::High).await;

        let transfer = controller
            .get_transfer(low)
            .await
            .expect("transfer was just registered");
        assert!(transfer.is_paused.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn test_stats() {
        let controller = PreemptionController::new();
        let classes = [QoSClass::Critical, QoSClass::Normal, QoSClass::Bulk];
        for class in classes {
            register_pausable(&controller, class, 1000).await;
        }

        let stats = controller.stats().await;

        assert_eq!(stats.active_transfers, 3);
        assert_eq!(stats.paused_transfers, 0);
        for class in classes {
            assert_eq!(stats.transfers_by_class.get(&class), Some(&1));
        }
    }

    #[tokio::test]
    async fn test_resume_all() {
        let controller = PreemptionController::new();
        for class in [QoSClass::Low, QoSClass::Bulk] {
            register_pausable(&controller, class, 1000).await;
        }

        controller.pause_transfers_below(QoSClass::Critical).await;
        assert_eq!(controller.paused_count(), 2);

        controller.resume_all().await;
        assert_eq!(controller.paused_count(), 0);
    }
}