Skip to main content

peat_mesh/qos/
preemption.rs

//! Preemption control for QoS-aware transfers (ADR-019 Phase 3)
//!
//! This module manages transfer preemption, allowing high-priority data
//! to pause or cancel lower-priority transfers when bandwidth is constrained.
//!
//! # Preemption Rules
//!
//! - P1 Critical: Can preempt all lower priorities (P2-P5)
//! - P2 High: Can preempt P3-P5
//! - P3-P5: Cannot preempt (must wait for bandwidth)
//!
//! # Architecture
//!
//! The [`PreemptionController`] tracks active transfers and coordinates
//! preemption decisions:
//!
//! 1. When critical data arrives, check if preemption is needed
//! 2. Identify preemptable transfers (lower priority, pausable)
//! 3. Pause transfers and release their bandwidth
//! 4. Resume paused transfers when bandwidth becomes available
//!
//! # Example
//!
//! ```
//! use peat_mesh::qos::{QoSClass, PreemptionController, ActiveTransfer};
//! use uuid::Uuid;
//!
//! # async fn example() {
//! let controller = PreemptionController::new();
//!
//! // Register an active transfer
//! let transfer_id = controller.register_transfer(
//!     QoSClass::Low,
//!     10000,  // 10KB
//!     true,   // can pause
//! ).await;
//!
//! // Check if preemption is needed for critical data
//! if controller.should_preempt(QoSClass::Critical).await {
//!     // Pause lower priority transfers
//!     let paused = controller.pause_transfers_below(QoSClass::Critical).await;
//!     // ... transmit critical data ...
//!     // Resume paused transfers
//!     controller.resume_transfers(paused).await;
//! }
//! # }
//! ```
48
49use super::QoSClass;
50use std::collections::HashMap;
51use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
52use std::sync::Arc;
53use std::time::Instant;
54use tokio::sync::RwLock;
55use uuid::Uuid;
56
57/// Unique identifier for a transfer
58pub type TransferId = Uuid;
59
60/// An active data transfer being tracked by the preemption controller
61#[derive(Debug)]
62pub struct ActiveTransfer {
63    /// Unique transfer identifier
64    pub id: TransferId,
65
66    /// QoS class of this transfer
67    pub class: QoSClass,
68
69    /// Total bytes to transfer
70    pub bytes_total: usize,
71
72    /// Bytes sent so far
73    pub bytes_sent: AtomicUsize,
74
75    /// Whether this transfer can be paused
76    pub can_pause: bool,
77
78    /// Whether this transfer is currently paused
79    pub is_paused: AtomicBool,
80
81    /// When the transfer started
82    pub started_at: Instant,
83
84    /// When the transfer was paused (if applicable)
85    pub paused_at: RwLock<Option<Instant>>,
86}
87
88impl ActiveTransfer {
89    /// Create a new active transfer
90    pub fn new(class: QoSClass, bytes_total: usize, can_pause: bool) -> Self {
91        Self {
92            id: Uuid::new_v4(),
93            class,
94            bytes_total,
95            bytes_sent: AtomicUsize::new(0),
96            can_pause,
97            is_paused: AtomicBool::new(false),
98            started_at: Instant::now(),
99            paused_at: RwLock::new(None),
100        }
101    }
102
103    /// Get progress as percentage (0.0 - 1.0)
104    pub fn progress(&self) -> f64 {
105        if self.bytes_total == 0 {
106            1.0
107        } else {
108            self.bytes_sent.load(Ordering::Relaxed) as f64 / self.bytes_total as f64
109        }
110    }
111
112    /// Get remaining bytes to transfer
113    pub fn bytes_remaining(&self) -> usize {
114        let sent = self.bytes_sent.load(Ordering::Relaxed);
115        self.bytes_total.saturating_sub(sent)
116    }
117
118    /// Record bytes sent
119    pub fn record_sent(&self, bytes: usize) {
120        self.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
121    }
122
123    /// Check if transfer is complete
124    pub fn is_complete(&self) -> bool {
125        self.bytes_sent.load(Ordering::Relaxed) >= self.bytes_total
126    }
127
128    /// Check if this transfer can be preempted by the given class
129    pub fn can_be_preempted_by(&self, class: QoSClass) -> bool {
130        self.can_pause && class.can_preempt(&self.class)
131    }
132
133    /// Pause this transfer
134    pub async fn pause(&self) {
135        if self.can_pause && !self.is_paused.load(Ordering::Relaxed) {
136            self.is_paused.store(true, Ordering::Relaxed);
137            *self.paused_at.write().await = Some(Instant::now());
138        }
139    }
140
141    /// Resume this transfer
142    pub async fn resume(&self) {
143        self.is_paused.store(false, Ordering::Relaxed);
144        *self.paused_at.write().await = None;
145    }
146
147    /// Get time spent paused (if currently paused)
148    pub async fn paused_duration(&self) -> Option<std::time::Duration> {
149        if self.is_paused.load(Ordering::Relaxed) {
150            self.paused_at.read().await.map(|t| t.elapsed())
151        } else {
152            None
153        }
154    }
155}
156
157/// Controller for managing transfer preemption
158///
159/// Tracks active transfers and coordinates preemption decisions
160/// to ensure high-priority data can preempt lower priorities.
161#[derive(Debug)]
162pub struct PreemptionController {
163    /// Active transfers indexed by ID
164    active_transfers: RwLock<HashMap<TransferId, Arc<ActiveTransfer>>>,
165
166    /// Number of preemption events
167    preemption_count: AtomicUsize,
168
169    /// Number of transfers currently paused
170    paused_count: AtomicUsize,
171}
172
173impl PreemptionController {
174    /// Create a new preemption controller
175    pub fn new() -> Self {
176        Self {
177            active_transfers: RwLock::new(HashMap::new()),
178            preemption_count: AtomicUsize::new(0),
179            paused_count: AtomicUsize::new(0),
180        }
181    }
182
183    /// Register a new active transfer
184    ///
185    /// Returns the transfer ID for tracking.
186    pub async fn register_transfer(
187        &self,
188        class: QoSClass,
189        bytes_total: usize,
190        can_pause: bool,
191    ) -> TransferId {
192        let transfer = Arc::new(ActiveTransfer::new(class, bytes_total, can_pause));
193        let id = transfer.id;
194
195        let mut transfers = self.active_transfers.write().await;
196        transfers.insert(id, transfer);
197
198        id
199    }
200
201    /// Unregister a completed or cancelled transfer
202    pub async fn unregister_transfer(&self, id: TransferId) {
203        let mut transfers = self.active_transfers.write().await;
204        if let Some(transfer) = transfers.remove(&id) {
205            if transfer.is_paused.load(Ordering::Relaxed) {
206                self.paused_count.fetch_sub(1, Ordering::Relaxed);
207            }
208        }
209    }
210
211    /// Get a transfer by ID
212    pub async fn get_transfer(&self, id: TransferId) -> Option<Arc<ActiveTransfer>> {
213        let transfers = self.active_transfers.read().await;
214        transfers.get(&id).cloned()
215    }
216
217    /// Check if preemption is needed for incoming data of given class
218    ///
219    /// Returns true if there are lower-priority transfers that can be preempted.
220    pub async fn should_preempt(&self, incoming_class: QoSClass) -> bool {
221        // Only P1 and P2 can preempt
222        if !matches!(incoming_class, QoSClass::Critical | QoSClass::High) {
223            return false;
224        }
225
226        let transfers = self.active_transfers.read().await;
227        for transfer in transfers.values() {
228            if transfer.can_be_preempted_by(incoming_class)
229                && !transfer.is_paused.load(Ordering::Relaxed)
230            {
231                return true;
232            }
233        }
234        false
235    }
236
237    /// Pause all transfers below a given priority
238    ///
239    /// Returns the IDs of paused transfers for later resumption.
240    pub async fn pause_transfers_below(&self, class: QoSClass) -> Vec<TransferId> {
241        let transfers = self.active_transfers.read().await;
242        let mut paused = Vec::new();
243
244        for transfer in transfers.values() {
245            if transfer.can_be_preempted_by(class) {
246                transfer.pause().await;
247                paused.push(transfer.id);
248                self.paused_count.fetch_add(1, Ordering::Relaxed);
249            }
250        }
251
252        if !paused.is_empty() {
253            self.preemption_count.fetch_add(1, Ordering::Relaxed);
254        }
255
256        paused
257    }
258
259    /// Resume previously paused transfers
260    pub async fn resume_transfers(&self, transfers_to_resume: Vec<TransferId>) {
261        let transfers = self.active_transfers.read().await;
262
263        for id in transfers_to_resume {
264            if let Some(transfer) = transfers.get(&id) {
265                if transfer.is_paused.load(Ordering::Relaxed) {
266                    transfer.resume().await;
267                    self.paused_count.fetch_sub(1, Ordering::Relaxed);
268                }
269            }
270        }
271    }
272
273    /// Resume all paused transfers
274    pub async fn resume_all(&self) {
275        let transfers = self.active_transfers.read().await;
276
277        for transfer in transfers.values() {
278            if transfer.is_paused.load(Ordering::Relaxed) {
279                transfer.resume().await;
280            }
281        }
282
283        self.paused_count.store(0, Ordering::Relaxed);
284    }
285
286    /// Get number of active transfers
287    pub async fn active_count(&self) -> usize {
288        self.active_transfers.read().await.len()
289    }
290
291    /// Get number of paused transfers
292    pub fn paused_count(&self) -> usize {
293        self.paused_count.load(Ordering::Relaxed)
294    }
295
296    /// Get total preemption events
297    pub fn preemption_count(&self) -> usize {
298        self.preemption_count.load(Ordering::Relaxed)
299    }
300
301    /// Get transfers by class
302    pub async fn transfers_by_class(&self, class: QoSClass) -> Vec<Arc<ActiveTransfer>> {
303        let transfers = self.active_transfers.read().await;
304        transfers
305            .values()
306            .filter(|t| t.class == class)
307            .cloned()
308            .collect()
309    }
310
311    /// Get all preemptable transfers for a given class
312    pub async fn preemptable_transfers(&self, by_class: QoSClass) -> Vec<Arc<ActiveTransfer>> {
313        let transfers = self.active_transfers.read().await;
314        transfers
315            .values()
316            .filter(|t| t.can_be_preempted_by(by_class))
317            .cloned()
318            .collect()
319    }
320
321    /// Calculate bandwidth currently used by transfers below a priority
322    pub async fn bandwidth_used_below(&self, class: QoSClass) -> usize {
323        let transfers = self.active_transfers.read().await;
324        transfers
325            .values()
326            .filter(|t| class.can_preempt(&t.class) && !t.is_paused.load(Ordering::Relaxed))
327            .map(|t| t.bytes_remaining())
328            .sum()
329    }
330
331    /// Clean up completed transfers
332    pub async fn cleanup_completed(&self) -> usize {
333        let mut transfers = self.active_transfers.write().await;
334        let initial_len = transfers.len();
335
336        transfers.retain(|_, t| !t.is_complete());
337
338        initial_len - transfers.len()
339    }
340
341    /// Get controller statistics
342    pub async fn stats(&self) -> PreemptionStats {
343        let transfers = self.active_transfers.read().await;
344
345        let mut by_class = HashMap::new();
346        for class in QoSClass::all_by_priority() {
347            by_class.insert(*class, 0usize);
348        }
349
350        for transfer in transfers.values() {
351            *by_class.entry(transfer.class).or_insert(0) += 1;
352        }
353
354        PreemptionStats {
355            active_transfers: transfers.len(),
356            paused_transfers: self.paused_count.load(Ordering::Relaxed),
357            preemption_events: self.preemption_count.load(Ordering::Relaxed),
358            transfers_by_class: by_class,
359        }
360    }
361}
362
363impl Default for PreemptionController {
364    fn default() -> Self {
365        Self::new()
366    }
367}
368
369/// Preemption controller statistics
370#[derive(Debug, Clone)]
371pub struct PreemptionStats {
372    /// Number of active transfers
373    pub active_transfers: usize,
374
375    /// Number of paused transfers
376    pub paused_transfers: usize,
377
378    /// Total preemption events
379    pub preemption_events: usize,
380
381    /// Transfers by QoS class
382    pub transfers_by_class: HashMap<QoSClass, usize>,
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Register a pausable transfer of `bytes` bytes and return its ID.
    async fn track(controller: &PreemptionController, class: QoSClass, bytes: usize) -> TransferId {
        controller.register_transfer(class, bytes, true).await
    }

    #[test]
    fn test_active_transfer_creation() {
        let t = ActiveTransfer::new(QoSClass::Normal, 1000, true);

        assert_eq!(t.class, QoSClass::Normal);
        assert_eq!(t.bytes_total, 1000);
        assert!(t.can_pause);
        assert!(!t.is_paused.load(Ordering::Relaxed));
        assert_eq!(t.bytes_remaining(), 1000);
    }

    #[test]
    fn test_transfer_progress() {
        let t = ActiveTransfer::new(QoSClass::Normal, 1000, true);
        assert_eq!(t.progress(), 0.0);

        // Two equal halves: 50% after the first, 100% after the second.
        for expected in [0.5, 1.0] {
            t.record_sent(500);
            assert!((t.progress() - expected).abs() < 0.001);
        }
        assert!(t.is_complete());
    }

    #[test]
    fn test_preemption_eligibility() {
        let low = ActiveTransfer::new(QoSClass::Low, 1000, true);
        let critical = ActiveTransfer::new(QoSClass::Critical, 1000, true);

        assert!(low.can_be_preempted_by(QoSClass::Critical), "Critical outranks Low");
        assert!(!critical.can_be_preempted_by(QoSClass::Low), "Low never outranks Critical");
        assert!(!low.can_be_preempted_by(QoSClass::Low), "equal classes do not preempt");
    }

    #[test]
    fn test_non_pausable_transfer() {
        let pinned = ActiveTransfer::new(QoSClass::Low, 1000, false);

        // Not even Critical may preempt a transfer that cannot pause.
        assert!(!pinned.can_be_preempted_by(QoSClass::Critical));
    }

    #[tokio::test]
    async fn test_controller_register_unregister() {
        let controller = PreemptionController::new();

        let id = track(&controller, QoSClass::Normal, 1000).await;
        assert_eq!(controller.active_count().await, 1);

        controller.unregister_transfer(id).await;
        assert_eq!(controller.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_should_preempt() {
        let controller = PreemptionController::new();
        assert!(
            !controller.should_preempt(QoSClass::Critical).await,
            "nothing is registered, so there is nothing to preempt"
        );

        track(&controller, QoSClass::Low, 1000).await;

        assert!(controller.should_preempt(QoSClass::Critical).await);
        assert!(!controller.should_preempt(QoSClass::Bulk).await);
    }

    #[tokio::test]
    async fn test_pause_resume() {
        let controller = PreemptionController::new();
        let low = track(&controller, QoSClass::Low, 1000).await;
        let bulk = track(&controller, QoSClass::Bulk, 1000).await;
        track(&controller, QoSClass::Critical, 1000).await;

        let paused = controller.pause_transfers_below(QoSClass::Critical).await;
        assert_eq!(paused.len(), 2);
        assert!([low, bulk].iter().all(|id| paused.contains(id)));
        assert_eq!(controller.paused_count(), 2);

        controller.resume_transfers(paused).await;
        assert_eq!(controller.paused_count(), 0);
    }

    #[tokio::test]
    async fn test_preemption_count() {
        let controller = PreemptionController::new();
        track(&controller, QoSClass::Bulk, 1000).await;
        assert_eq!(controller.preemption_count(), 0);

        controller.pause_transfers_below(QoSClass::Critical).await;
        assert_eq!(controller.preemption_count(), 1);
    }

    #[tokio::test]
    async fn test_transfers_by_class() {
        let controller = PreemptionController::new();
        let registrations = [
            (QoSClass::Normal, 1000),
            (QoSClass::Normal, 2000),
            (QoSClass::High, 3000),
        ];
        for (class, bytes) in registrations {
            track(&controller, class, bytes).await;
        }

        assert_eq!(controller.transfers_by_class(QoSClass::Normal).await.len(), 2);
        assert_eq!(controller.transfers_by_class(QoSClass::High).await.len(), 1);
    }

    #[tokio::test]
    async fn test_cleanup_completed() {
        let controller = PreemptionController::new();
        let id = track(&controller, QoSClass::Normal, 100).await;

        let transfer = controller
            .get_transfer(id)
            .await
            .expect("transfer was just registered");
        transfer.record_sent(100);
        assert!(transfer.is_complete());

        assert_eq!(controller.cleanup_completed().await, 1);
        assert_eq!(controller.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_bandwidth_used_below() {
        let controller = PreemptionController::new();
        let low = track(&controller, QoSClass::Low, 1000).await;
        track(&controller, QoSClass::Bulk, 2000).await;
        track(&controller, QoSClass::Critical, 3000).await;

        // Everything below Critical: Low (1000) + Bulk (2000).
        assert_eq!(controller.bandwidth_used_below(QoSClass::Critical).await, 3000);

        // High outranks Low, so Low must end up paused.
        controller.pause_transfers_below(QoSClass::High).await;
        let low_transfer = controller
            .get_transfer(low)
            .await
            .expect("Low transfer is still tracked");
        assert!(low_transfer.is_paused.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn test_stats() {
        let controller = PreemptionController::new();
        let classes = [QoSClass::Critical, QoSClass::Normal, QoSClass::Bulk];
        for class in classes {
            track(&controller, class, 1000).await;
        }

        let stats = controller.stats().await;

        assert_eq!(stats.active_transfers, 3);
        assert_eq!(stats.paused_transfers, 0);
        for class in classes {
            assert_eq!(stats.transfers_by_class.get(&class).copied(), Some(1));
        }
    }

    #[tokio::test]
    async fn test_resume_all() {
        let controller = PreemptionController::new();
        track(&controller, QoSClass::Low, 1000).await;
        track(&controller, QoSClass::Bulk, 1000).await;

        controller.pause_transfers_below(QoSClass::Critical).await;
        assert_eq!(controller.paused_count(), 2);

        controller.resume_all().await;
        assert_eq!(controller.paused_count(), 0);
    }
}