1use crate::{NetworkConfig, PrioritizedFragment, Priority, PriorityQueue};
4use serde::{Deserialize, Serialize};
5use std::collections::VecDeque;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::{Mutex, Semaphore};
10use tracing::warn;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SchedulerConfig {
15 pub max_concurrent: usize,
17 pub sample_window: Duration,
19 pub sample_count: usize,
21 pub min_bandwidth: u64,
23 pub max_queue_size: usize,
25}
26
27impl Default for SchedulerConfig {
28 fn default() -> Self {
29 Self {
30 max_concurrent: 4,
31 sample_window: Duration::from_secs(5),
32 sample_count: 10,
33 min_bandwidth: 1024 * 1024, max_queue_size: 1000,
35 }
36 }
37}
38
39impl From<&NetworkConfig> for SchedulerConfig {
40 fn from(config: &NetworkConfig) -> Self {
41 Self {
42 max_concurrent: config.max_concurrent,
43 min_bandwidth: config.min_bandwidth,
44 ..Default::default()
45 }
46 }
47}
48
/// One measured transfer: `bytes` moved over `duration`, recorded at `timestamp`.
#[derive(Debug, Clone, Copy)]
struct BandwidthSample {
    bytes: u64,
    duration: Duration,
    timestamp: Instant,
}

impl BandwidthSample {
    /// Throughput of this sample in bytes per second.
    ///
    /// A zero-length `duration` has no meaningful rate. Dividing would produce `inf`
    /// (or `NaN` for zero bytes), which would poison any average built from it.
    /// Such samples therefore report `0.0`.
    fn bytes_per_second(&self) -> f64 {
        let secs = self.duration.as_secs_f64();
        if secs > 0.0 {
            self.bytes as f64 / secs
        } else {
            0.0
        }
    }
}
62
63pub struct BandwidthMonitor {
65 samples: Mutex<VecDeque<BandwidthSample>>,
66 sample_window: Duration,
67 max_samples: usize,
68 total_bytes: AtomicU64,
69 start_time: Instant,
70}
71
72impl BandwidthMonitor {
73 pub fn new(sample_window: Duration, max_samples: usize) -> Self {
75 Self {
76 samples: Mutex::new(VecDeque::with_capacity(max_samples)),
77 sample_window,
78 max_samples,
79 total_bytes: AtomicU64::new(0),
80 start_time: Instant::now(),
81 }
82 }
83
84 pub async fn record(&self, bytes: u64, duration: Duration) {
86 let sample = BandwidthSample {
87 bytes,
88 duration,
89 timestamp: Instant::now(),
90 };
91
92 let mut samples = self.samples.lock().await;
93
94 let cutoff = Instant::now() - self.sample_window;
96 while samples.front().is_some_and(|s| s.timestamp < cutoff) {
97 samples.pop_front();
98 }
99
100 if samples.len() >= self.max_samples {
102 samples.pop_front();
103 }
104 samples.push_back(sample);
105
106 self.total_bytes.fetch_add(bytes, Ordering::Relaxed);
107 }
108
109 pub async fn current_bandwidth(&self) -> f64 {
111 let samples = self.samples.lock().await;
112
113 if samples.is_empty() {
114 return 0.0;
115 }
116
117 let mut total_weight = 0.0;
119 let mut weighted_sum = 0.0;
120
121 for (i, sample) in samples.iter().enumerate() {
122 let weight = (i + 1) as f64;
123 weighted_sum += sample.bytes_per_second() * weight;
124 total_weight += weight;
125 }
126
127 if total_weight > 0.0 {
128 weighted_sum / total_weight
129 } else {
130 0.0
131 }
132 }
133
134 pub fn average_bandwidth(&self) -> f64 {
136 let bytes = self.total_bytes.load(Ordering::Relaxed);
137 let duration = self.start_time.elapsed();
138
139 if duration.as_secs_f64() > 0.0 {
140 bytes as f64 / duration.as_secs_f64()
141 } else {
142 0.0
143 }
144 }
145
146 pub fn total_bytes(&self) -> u64 {
148 self.total_bytes.load(Ordering::Relaxed)
149 }
150
151 pub async fn estimate_time(&self, bytes: u64) -> Duration {
153 let bandwidth = self.current_bandwidth().await;
154 if bandwidth > 0.0 {
155 Duration::from_secs_f64(bytes as f64 / bandwidth)
156 } else {
157 Duration::from_secs(u64::MAX)
158 }
159 }
160}
161
162pub struct Scheduler {
164 config: SchedulerConfig,
165 queue: PriorityQueue,
166 bandwidth: Arc<BandwidthMonitor>,
167 semaphore: Arc<Semaphore>,
168 active: AtomicU64,
169 completed: AtomicU64,
170 failed: AtomicU64,
171}
172
173impl Scheduler {
174 pub fn new(config: SchedulerConfig) -> Self {
176 let bandwidth = Arc::new(BandwidthMonitor::new(
177 config.sample_window,
178 config.sample_count,
179 ));
180
181 Self {
182 semaphore: Arc::new(Semaphore::new(config.max_concurrent)),
183 queue: PriorityQueue::new(),
184 bandwidth,
185 config,
186 active: AtomicU64::new(0),
187 completed: AtomicU64::new(0),
188 failed: AtomicU64::new(0),
189 }
190 }
191
192 pub fn enqueue(&self, fragment: PrioritizedFragment) {
194 if self.queue.len() >= self.config.max_queue_size {
195 warn!("Queue full, dropping fragment {:?}", fragment.fragment_id);
196 return;
197 }
198 self.queue.push(fragment);
199 }
200
201 pub fn enqueue_many(&self, fragments: impl IntoIterator<Item = PrioritizedFragment>) {
203 for fragment in fragments {
204 self.enqueue(fragment);
205 }
206 }
207
208 pub async fn next(&self) -> Option<(PrioritizedFragment, SchedulerPermit<'_>)> {
210 let fragment = self.queue.pop()?;
211
212 let permit = self.semaphore.clone().acquire_owned().await.ok()?;
214 self.active.fetch_add(1, Ordering::Relaxed);
215
216 Some((
217 fragment,
218 SchedulerPermit {
219 _permit: permit,
220 scheduler: self,
221 },
222 ))
223 }
224
225 pub async fn record_success(&self, bytes: u64, duration: Duration) {
227 self.bandwidth.record(bytes, duration).await;
228 self.completed.fetch_add(1, Ordering::Relaxed);
229 }
230
231 pub fn record_failure(&self) {
233 self.failed.fetch_add(1, Ordering::Relaxed);
234 }
235
236 pub fn bandwidth(&self) -> &BandwidthMonitor {
238 &self.bandwidth
239 }
240
241 pub fn queue_len(&self) -> usize {
243 self.queue.len()
244 }
245
246 pub fn active(&self) -> u64 {
248 self.active.load(Ordering::Relaxed)
249 }
250
251 pub fn completed(&self) -> u64 {
253 self.completed.load(Ordering::Relaxed)
254 }
255
256 pub fn failed(&self) -> u64 {
258 self.failed.load(Ordering::Relaxed)
259 }
260
261 pub async fn should_throttle(&self) -> bool {
263 let current = self.bandwidth.current_bandwidth().await;
264 current > 0.0 && current < self.config.min_bandwidth as f64
265 }
266
267 pub async fn stats(&self) -> SchedulerStats {
269 SchedulerStats {
270 queue_len: self.queue.len(),
271 active: self.active(),
272 completed: self.completed(),
273 failed: self.failed(),
274 current_bandwidth: self.bandwidth.current_bandwidth().await,
275 average_bandwidth: self.bandwidth.average_bandwidth(),
276 total_bytes: self.bandwidth.total_bytes(),
277 }
278 }
279
280 pub fn clear(&self) {
282 self.queue.clear();
283 }
284
285 pub fn update_priority(
287 &self,
288 fragment_id: &haagenti_fragments::FragmentId,
289 priority: Priority,
290 ) {
291 self.queue.update_priority(fragment_id, priority);
292 }
293}
294
295pub struct SchedulerPermit<'a> {
297 _permit: tokio::sync::OwnedSemaphorePermit,
298 scheduler: &'a Scheduler,
299}
300
301impl Drop for SchedulerPermit<'_> {
302 fn drop(&mut self) {
303 self.scheduler.active.fetch_sub(1, Ordering::Relaxed);
304 }
305}
306
/// Point-in-time snapshot of a scheduler's queue, counters and throughput.
#[derive(Debug, Clone)]
pub struct SchedulerStats {
    /// Fragments waiting in the queue.
    pub queue_len: usize,
    /// Fragments currently holding a concurrency permit.
    pub active: u64,
    /// Successful transfers recorded.
    pub completed: u64,
    /// Failed transfers recorded.
    pub failed: u64,
    /// Recency-weighted throughput of recent samples, in bytes per second.
    pub current_bandwidth: f64,
    /// Lifetime throughput, in bytes per second.
    pub average_bandwidth: f64,
    /// Bytes recorded since the monitor was created.
    pub total_bytes: u64,
}

impl SchedulerStats {
    /// Fraction of finished transfers that succeeded, in `[0.0, 1.0]`.
    /// Reports `1.0` when nothing has finished yet.
    pub fn success_rate(&self) -> f64 {
        match self.completed + self.failed {
            0 => 1.0,
            finished => self.completed as f64 / finished as f64,
        }
    }

    /// `current_bandwidth` rendered for humans, e.g. `"12.3 MB/s"`.
    pub fn bandwidth_human(&self) -> String {
        format_bytes_per_second(self.current_bandwidth)
    }
}

/// Formats a byte rate using decimal (powers of 1000) units.
///
/// Scaled units get one decimal place; plain bytes are rounded to a whole number.
fn format_bytes_per_second(bps: f64) -> String {
    // Checked largest first, so the first threshold reached picks the unit.
    const UNITS: [(f64, &str); 3] = [
        (1_000_000_000.0, "GB/s"),
        (1_000_000.0, "MB/s"),
        (1_000.0, "KB/s"),
    ];

    match UNITS.iter().find(|&&(scale, _)| bps >= scale) {
        Some(&(scale, unit)) => format!("{:.1} {unit}", bps / scale),
        None => format!("{:.0} B/s", bps),
    }
}
354
#[cfg(test)]
mod tests {
    use super::*;
    use haagenti_fragments::FragmentId;

    #[tokio::test]
    async fn test_bandwidth_monitor() {
        const MIB: u64 = 1024 * 1024;
        let monitor = BandwidthMonitor::new(Duration::from_secs(5), 10);

        // One MiB, then two MiB, each taking one second.
        monitor.record(MIB, Duration::from_secs(1)).await;
        monitor.record(2 * MIB, Duration::from_secs(1)).await;

        let rate = monitor.current_bandwidth().await;
        assert!(rate > 1_000_000.0);
        assert_eq!(monitor.total_bytes(), 3 * MIB);
    }

    #[tokio::test]
    async fn test_scheduler_priority() {
        let scheduler = Scheduler::new(SchedulerConfig {
            max_concurrent: 2,
            ..Default::default()
        });

        // Enqueue out of priority order; the most urgent fragment must come out first.
        for (id, priority) in [
            (FragmentId::new([1; 16]), Priority::Low),
            (FragmentId::new([2; 16]), Priority::Critical),
            (FragmentId::new([3; 16]), Priority::Normal),
        ] {
            scheduler.enqueue(PrioritizedFragment::new(id, priority));
        }

        let (first, _permit) = scheduler.next().await.unwrap();
        assert_eq!(first.priority, Priority::Critical);
    }
}