Skip to main content

haagenti_network/
scheduler.rs

1//! Download scheduler with bandwidth monitoring
2
3use crate::{NetworkConfig, PrioritizedFragment, Priority, PriorityQueue};
4use serde::{Deserialize, Serialize};
5use std::collections::VecDeque;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::{Mutex, Semaphore};
10use tracing::warn;
11
12/// Scheduler configuration
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SchedulerConfig {
15    /// Maximum concurrent downloads
16    pub max_concurrent: usize,
17    /// Bandwidth sample window
18    pub sample_window: Duration,
19    /// Number of samples to keep
20    pub sample_count: usize,
21    /// Minimum acceptable bandwidth (bytes/sec)
22    pub min_bandwidth: u64,
23    /// Maximum queue size
24    pub max_queue_size: usize,
25}
26
27impl Default for SchedulerConfig {
28    fn default() -> Self {
29        Self {
30            max_concurrent: 4,
31            sample_window: Duration::from_secs(5),
32            sample_count: 10,
33            min_bandwidth: 1024 * 1024, // 1MB/s
34            max_queue_size: 1000,
35        }
36    }
37}
38
39impl From<&NetworkConfig> for SchedulerConfig {
40    fn from(config: &NetworkConfig) -> Self {
41        Self {
42            max_concurrent: config.max_concurrent,
43            min_bandwidth: config.min_bandwidth,
44            ..Default::default()
45        }
46    }
47}
48
/// One completed transfer as seen by the bandwidth monitor.
#[derive(Debug, Clone, Copy)]
struct BandwidthSample {
    /// Payload size of the transfer, in bytes.
    bytes: u64,
    /// How long the transfer took.
    duration: Duration,
    /// When the sample was recorded; used to age samples out of the window.
    timestamp: Instant,
}

impl BandwidthSample {
    /// Throughput of this single transfer in bytes per second.
    ///
    /// There is no guard against a zero `duration`: the result is then
    /// `f64::INFINITY` (or `NaN` when `bytes` is also zero).
    fn bytes_per_second(&self) -> f64 {
        let seconds = self.duration.as_secs_f64();
        let payload = self.bytes as f64;
        payload / seconds
    }
}
62
63/// Bandwidth monitor
64pub struct BandwidthMonitor {
65    samples: Mutex<VecDeque<BandwidthSample>>,
66    sample_window: Duration,
67    max_samples: usize,
68    total_bytes: AtomicU64,
69    start_time: Instant,
70}
71
72impl BandwidthMonitor {
73    /// Create a new bandwidth monitor
74    pub fn new(sample_window: Duration, max_samples: usize) -> Self {
75        Self {
76            samples: Mutex::new(VecDeque::with_capacity(max_samples)),
77            sample_window,
78            max_samples,
79            total_bytes: AtomicU64::new(0),
80            start_time: Instant::now(),
81        }
82    }
83
84    /// Record a completed download
85    pub async fn record(&self, bytes: u64, duration: Duration) {
86        let sample = BandwidthSample {
87            bytes,
88            duration,
89            timestamp: Instant::now(),
90        };
91
92        let mut samples = self.samples.lock().await;
93
94        // Remove old samples
95        let cutoff = Instant::now() - self.sample_window;
96        while samples.front().is_some_and(|s| s.timestamp < cutoff) {
97            samples.pop_front();
98        }
99
100        // Add new sample
101        if samples.len() >= self.max_samples {
102            samples.pop_front();
103        }
104        samples.push_back(sample);
105
106        self.total_bytes.fetch_add(bytes, Ordering::Relaxed);
107    }
108
109    /// Get current bandwidth estimate (bytes/sec)
110    pub async fn current_bandwidth(&self) -> f64 {
111        let samples = self.samples.lock().await;
112
113        if samples.is_empty() {
114            return 0.0;
115        }
116
117        // Weighted moving average (more recent samples weighted higher)
118        let mut total_weight = 0.0;
119        let mut weighted_sum = 0.0;
120
121        for (i, sample) in samples.iter().enumerate() {
122            let weight = (i + 1) as f64;
123            weighted_sum += sample.bytes_per_second() * weight;
124            total_weight += weight;
125        }
126
127        if total_weight > 0.0 {
128            weighted_sum / total_weight
129        } else {
130            0.0
131        }
132    }
133
134    /// Get average bandwidth over entire session
135    pub fn average_bandwidth(&self) -> f64 {
136        let bytes = self.total_bytes.load(Ordering::Relaxed);
137        let duration = self.start_time.elapsed();
138
139        if duration.as_secs_f64() > 0.0 {
140            bytes as f64 / duration.as_secs_f64()
141        } else {
142            0.0
143        }
144    }
145
146    /// Get total bytes transferred
147    pub fn total_bytes(&self) -> u64 {
148        self.total_bytes.load(Ordering::Relaxed)
149    }
150
151    /// Estimate time to download given bytes
152    pub async fn estimate_time(&self, bytes: u64) -> Duration {
153        let bandwidth = self.current_bandwidth().await;
154        if bandwidth > 0.0 {
155            Duration::from_secs_f64(bytes as f64 / bandwidth)
156        } else {
157            Duration::from_secs(u64::MAX)
158        }
159    }
160}
161
162/// Download scheduler
163pub struct Scheduler {
164    config: SchedulerConfig,
165    queue: PriorityQueue,
166    bandwidth: Arc<BandwidthMonitor>,
167    semaphore: Arc<Semaphore>,
168    active: AtomicU64,
169    completed: AtomicU64,
170    failed: AtomicU64,
171}
172
173impl Scheduler {
174    /// Create a new scheduler
175    pub fn new(config: SchedulerConfig) -> Self {
176        let bandwidth = Arc::new(BandwidthMonitor::new(
177            config.sample_window,
178            config.sample_count,
179        ));
180
181        Self {
182            semaphore: Arc::new(Semaphore::new(config.max_concurrent)),
183            queue: PriorityQueue::new(),
184            bandwidth,
185            config,
186            active: AtomicU64::new(0),
187            completed: AtomicU64::new(0),
188            failed: AtomicU64::new(0),
189        }
190    }
191
192    /// Enqueue a fragment for download
193    pub fn enqueue(&self, fragment: PrioritizedFragment) {
194        if self.queue.len() >= self.config.max_queue_size {
195            warn!("Queue full, dropping fragment {:?}", fragment.fragment_id);
196            return;
197        }
198        self.queue.push(fragment);
199    }
200
201    /// Enqueue multiple fragments
202    pub fn enqueue_many(&self, fragments: impl IntoIterator<Item = PrioritizedFragment>) {
203        for fragment in fragments {
204            self.enqueue(fragment);
205        }
206    }
207
208    /// Get next fragment to download
209    pub async fn next(&self) -> Option<(PrioritizedFragment, SchedulerPermit<'_>)> {
210        let fragment = self.queue.pop()?;
211
212        // Wait for download slot
213        let permit = self.semaphore.clone().acquire_owned().await.ok()?;
214        self.active.fetch_add(1, Ordering::Relaxed);
215
216        Some((
217            fragment,
218            SchedulerPermit {
219                _permit: permit,
220                scheduler: self,
221            },
222        ))
223    }
224
225    /// Record completed download
226    pub async fn record_success(&self, bytes: u64, duration: Duration) {
227        self.bandwidth.record(bytes, duration).await;
228        self.completed.fetch_add(1, Ordering::Relaxed);
229    }
230
231    /// Record failed download
232    pub fn record_failure(&self) {
233        self.failed.fetch_add(1, Ordering::Relaxed);
234    }
235
236    /// Get bandwidth monitor
237    pub fn bandwidth(&self) -> &BandwidthMonitor {
238        &self.bandwidth
239    }
240
241    /// Get queue length
242    pub fn queue_len(&self) -> usize {
243        self.queue.len()
244    }
245
246    /// Get active downloads
247    pub fn active(&self) -> u64 {
248        self.active.load(Ordering::Relaxed)
249    }
250
251    /// Get completed downloads
252    pub fn completed(&self) -> u64 {
253        self.completed.load(Ordering::Relaxed)
254    }
255
256    /// Get failed downloads
257    pub fn failed(&self) -> u64 {
258        self.failed.load(Ordering::Relaxed)
259    }
260
261    /// Check if should reduce concurrency (bandwidth dropping)
262    pub async fn should_throttle(&self) -> bool {
263        let current = self.bandwidth.current_bandwidth().await;
264        current > 0.0 && current < self.config.min_bandwidth as f64
265    }
266
267    /// Get scheduler statistics
268    pub async fn stats(&self) -> SchedulerStats {
269        SchedulerStats {
270            queue_len: self.queue.len(),
271            active: self.active(),
272            completed: self.completed(),
273            failed: self.failed(),
274            current_bandwidth: self.bandwidth.current_bandwidth().await,
275            average_bandwidth: self.bandwidth.average_bandwidth(),
276            total_bytes: self.bandwidth.total_bytes(),
277        }
278    }
279
280    /// Clear the queue
281    pub fn clear(&self) {
282        self.queue.clear();
283    }
284
285    /// Update priority of queued fragment
286    pub fn update_priority(
287        &self,
288        fragment_id: &haagenti_fragments::FragmentId,
289        priority: Priority,
290    ) {
291        self.queue.update_priority(fragment_id, priority);
292    }
293}
294
295/// Permit for an active download
296pub struct SchedulerPermit<'a> {
297    _permit: tokio::sync::OwnedSemaphorePermit,
298    scheduler: &'a Scheduler,
299}
300
301impl Drop for SchedulerPermit<'_> {
302    fn drop(&mut self) {
303        self.scheduler.active.fetch_sub(1, Ordering::Relaxed);
304    }
305}
306
/// Point-in-time snapshot of scheduler counters and bandwidth figures.
#[derive(Debug, Clone)]
pub struct SchedulerStats {
    /// Fragments waiting in the queue.
    pub queue_len: usize,
    /// Downloads currently in flight.
    pub active: u64,
    /// Downloads that finished successfully.
    pub completed: u64,
    /// Downloads that failed.
    pub failed: u64,
    /// Recency-weighted bandwidth estimate (bytes/sec).
    pub current_bandwidth: f64,
    /// Session-wide average bandwidth (bytes/sec).
    pub average_bandwidth: f64,
    /// Bytes transferred so far.
    pub total_bytes: u64,
}

impl SchedulerStats {
    /// Fraction of finished downloads that succeeded, in `[0.0, 1.0]`.
    ///
    /// Reports `1.0` when nothing has finished yet.
    pub fn success_rate(&self) -> f64 {
        match self.completed + self.failed {
            0 => 1.0,
            finished => self.completed as f64 / finished as f64,
        }
    }

    /// Current bandwidth rendered for humans, e.g. `"12.5 MB/s"`.
    pub fn bandwidth_human(&self) -> String {
        format_bytes_per_second(self.current_bandwidth)
    }
}

/// Render a bytes-per-second figure with a decimal (powers of 1000) unit.
///
/// Values of 1000 and above get one decimal place and the largest fitting unit.
/// Smaller values, and values that compare false against every threshold
/// (e.g. `NaN`), are shown as whole bytes.
fn format_bytes_per_second(bps: f64) -> String {
    const UNITS: [(f64, &str); 3] = [
        (1_000_000_000.0, "GB/s"),
        (1_000_000.0, "MB/s"),
        (1_000.0, "KB/s"),
    ];

    UNITS
        .iter()
        .find(|&&(scale, _)| bps >= scale)
        .map(|&(scale, unit)| format!("{:.1} {}", bps / scale, unit))
        .unwrap_or_else(|| format!("{:.0} B/s", bps))
}
354
#[cfg(test)]
mod tests {
    use super::*;
    use haagenti_fragments::FragmentId;

    const MIB: u64 = 1024 * 1024;

    /// Two one-second transfers (1 MiB and 2 MiB) should yield a rate above
    /// 1 MB/s and a byte total equal to their sum.
    #[tokio::test]
    async fn test_bandwidth_monitor() {
        let monitor = BandwidthMonitor::new(Duration::from_secs(5), 10);
        let one_second = Duration::from_secs(1);

        for size in [MIB, 2 * MIB] {
            monitor.record(size, one_second).await;
        }

        assert!(monitor.current_bandwidth().await > 1_000_000.0);
        assert_eq!(monitor.total_bytes(), 3 * MIB);
    }

    /// The highest-priority fragment is handed out first, regardless of
    /// enqueue order.
    #[tokio::test]
    async fn test_scheduler_priority() {
        let scheduler = Scheduler::new(SchedulerConfig {
            max_concurrent: 2,
            ..Default::default()
        });

        let fragments = [
            ([1; 16], Priority::Low),
            ([2; 16], Priority::Critical),
            ([3; 16], Priority::Normal),
        ];
        for (id, priority) in fragments {
            scheduler.enqueue(PrioritizedFragment::new(FragmentId::new(id), priority));
        }

        let (first, _permit) = scheduler.next().await.unwrap();
        assert_eq!(first.priority, Priority::Critical);
    }
}