1use crate::{Task, TaskResult};
4use std::collections::{BTreeMap, HashMap, VecDeque};
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7use tokio::sync::RwLock;
8
/// One FIFO (`VecDeque`) of boxed tasks per [`Priority`], kept in a `BTreeMap`
/// so iteration follows `Priority`'s `Ord` implementation.
type PriorityQueueMap = BTreeMap<Priority, VecDeque<Box<dyn Task>>>;
11
/// Scheduling priority attached to a queued task.
///
/// Each variant maps to a numeric weight (see [`Priority::default_weight`]).
/// Equality, hashing and ordering are all defined on that weight rather than
/// on the variant, so `Priority::Custom(100)` and `Priority::High` compare
/// equal, hash identically and occupy the same slot when used as a map key.
#[derive(Debug, Clone, Copy, Default)]
pub enum Priority {
    /// Weight 10.
    Low,
    /// Weight 50; this is the value produced by `Priority::default()`.
    #[default]
    Normal,
    /// Weight 100.
    High,
    /// Caller-supplied weight.
    Custom(u32),
}

impl Priority {
    /// Returns the numeric weight this priority carries when no override is
    /// configured: 10 / 50 / 100 for the named levels, or the wrapped value
    /// for [`Priority::Custom`].
    pub fn default_weight(&self) -> u32 {
        match *self {
            Self::Low => 10,
            Self::Normal => 50,
            Self::High => 100,
            Self::Custom(value) => value,
        }
    }
}

impl Ord for Priority {
    /// Orders priorities by ascending weight.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let lhs = self.default_weight();
        let rhs = other.default_weight();
        lhs.cmp(&rhs)
    }
}

impl PartialOrd for Priority {
    /// Total order; always `Some`, delegating to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl PartialEq for Priority {
    /// Two priorities are equal exactly when their weights are equal.
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.default_weight();
        let rhs = other.default_weight();
        lhs == rhs
    }
}

impl Eq for Priority {}

impl std::hash::Hash for Priority {
    /// Hashes only the weight, keeping `Hash` consistent with `Eq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let weight = self.default_weight();
        std::hash::Hash::hash(&weight, state);
    }
}
97
/// Task queue that keeps a separate FIFO per [`Priority`] and chooses which
/// FIFO to serve on each dequeue according to per-priority weights.
pub struct PriorityTaskQueue {
    /// Pending tasks grouped by priority, guarded by an async read/write lock.
    queues: Arc<RwLock<PriorityQueueMap>>,
    /// Weight per priority used during dequeue; a priority without an entry
    /// falls back to `Priority::default_weight`.
    weights: HashMap<Priority, u32>,
    /// Counter bumped on every weighted selection; its value modulo the total
    /// weight picks the queue to serve.
    counter: AtomicU64,
}
123
124impl PriorityTaskQueue {
125 pub fn new() -> Self {
140 let mut weights = HashMap::new();
141 weights.insert(Priority::High, 100);
142 weights.insert(Priority::Normal, 50);
143 weights.insert(Priority::Low, 10);
144
145 Self {
146 queues: Arc::new(RwLock::new(BTreeMap::new())),
147 weights,
148 counter: AtomicU64::new(0),
149 }
150 }
151
152 pub fn with_weights(weights: HashMap<Priority, u32>) -> Self {
168 Self {
169 queues: Arc::new(RwLock::new(BTreeMap::new())),
170 weights,
171 counter: AtomicU64::new(0),
172 }
173 }
174
175 pub async fn enqueue(&self, task: Box<dyn Task>, priority: Priority) -> TaskResult<()> {
193 let mut queues = self.queues.write().await;
194 queues.entry(priority).or_default().push_back(task);
195 Ok(())
196 }
197
198 pub async fn dequeue(&self) -> TaskResult<Option<Box<dyn Task>>> {
221 let mut queues = self.queues.write().await;
222
223 if queues.is_empty() {
224 return Ok(None);
225 }
226
227 let mut total_weight = 0u32;
229 let mut priorities_with_weight = Vec::new();
230
231 for (priority, queue) in queues.iter() {
232 if !queue.is_empty() {
233 let weight = self.weights.get(priority).copied().unwrap_or_else(|| {
234 if let Priority::Custom(w) = priority {
235 *w
236 } else {
237 priority.default_weight()
238 }
239 });
240 total_weight += weight;
241 priorities_with_weight.push((*priority, weight));
242 }
243 }
244
245 if total_weight == 0 {
246 return Ok(None);
247 }
248
249 let selected_priority =
252 self.select_priority_weighted(&priorities_with_weight, total_weight);
253
254 if let Some(queue) = queues.get_mut(&selected_priority)
256 && let Some(task) = queue.pop_front()
257 {
258 return Ok(Some(task));
259 }
260
261 Ok(None)
262 }
263
264 fn select_priority_weighted(
266 &self,
267 priorities: &[(Priority, u32)],
268 total_weight: u32,
269 ) -> Priority {
270 let counter = self.counter.fetch_add(1, Ordering::Relaxed);
277 let target = (counter % total_weight as u64) as u32;
278
279 let mut accumulated = 0;
280 for (priority, weight) in priorities {
281 accumulated += weight;
282 if target < accumulated {
283 return *priority;
284 }
285 }
286
287 priorities
289 .first()
290 .map(|(p, _)| *p)
291 .unwrap_or(Priority::Normal)
292 }
293
294 pub async fn len(&self) -> usize {
307 let queues = self.queues.read().await;
308 queues.values().map(|q| q.len()).sum()
309 }
310
311 pub async fn is_empty(&self) -> bool {
324 let queues = self.queues.read().await;
325 queues.values().all(|q| q.is_empty())
326 }
327
328 pub async fn len_for_priority(&self, priority: Priority) -> usize {
341 let queues = self.queues.read().await;
342 queues.get(&priority).map(|q| q.len()).unwrap_or(0)
343 }
344}
345
346impl Default for PriorityTaskQueue {
347 fn default() -> Self {
348 Self::new()
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TaskId;

    /// Minimal [`Task`] implementation that is identified by its name.
    #[derive(Debug)]
    struct TestTask {
        id: TaskId,
        name: String,
    }

    impl TestTask {
        fn new(name: &str) -> Self {
            Self {
                id: TaskId::new(),
                name: name.to_string(),
            }
        }
    }

    impl Task for TestTask {
        fn id(&self) -> TaskId {
            self.id
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// Enqueues a fresh [`TestTask`] called `name`, panicking if the queue
    /// reports an error.
    async fn push(queue: &PriorityTaskQueue, name: &str, priority: Priority) {
        queue
            .enqueue(Box::new(TestTask::new(name)), priority)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_priority_ordering() {
        let queue = PriorityTaskQueue::new();

        push(&queue, "low1", Priority::Low).await;
        push(&queue, "high1", Priority::High).await;
        push(&queue, "normal1", Priority::Normal).await;
        push(&queue, "high2", Priority::High).await;

        assert_eq!(queue.len().await, 4);

        let mut high_count = 0;
        for _ in 0..4 {
            // Four tasks were enqueued, so each of the four dequeues must yield one.
            let task = queue
                .dequeue()
                .await
                .unwrap()
                .expect("queue should still hold a task");
            if task.name().starts_with("high") {
                high_count += 1;
            }
        }

        assert!(high_count > 0);
        assert_eq!(queue.len().await, 0);
    }

    #[tokio::test]
    async fn test_weighted_scheduling() {
        let weights = HashMap::from([
            (Priority::High, 90),
            (Priority::Normal, 9),
            (Priority::Low, 1),
        ]);
        let queue = PriorityTaskQueue::with_weights(weights);

        for i in 0..30 {
            push(&queue, &format!("high{i}"), Priority::High).await;
        }
        for i in 0..10 {
            push(&queue, &format!("normal{i}"), Priority::Normal).await;
        }
        for i in 0..5 {
            push(&queue, &format!("low{i}"), Priority::Low).await;
        }

        let mut high_count = 0;
        let mut normal_count = 0;
        let mut low_count = 0;

        while let Some(task) = queue.dequeue().await.unwrap() {
            let name = task.name();
            if name.starts_with("high") {
                high_count += 1;
            } else if name.starts_with("normal") {
                normal_count += 1;
            } else if name.starts_with("low") {
                low_count += 1;
            }
        }

        assert_eq!(high_count + normal_count + low_count, 45);

        // NOTE: the queue is drained completely, so these counts mirror the
        // number of tasks enqueued per priority rather than the dequeue order.
        assert!(high_count > 0, "High priority tasks should be dequeued");
        assert!(normal_count > 0, "Normal priority tasks should be dequeued");
        assert!(low_count > 0, "Low priority tasks should be dequeued");

        assert!(
            high_count > normal_count,
            "High count {} should be greater than normal count {}",
            high_count,
            normal_count
        );

        assert!(
            normal_count > low_count,
            "Normal count {} should be greater than low count {}",
            normal_count,
            low_count
        );
    }

    #[tokio::test]
    async fn test_fifo_within_priority() {
        let queue = PriorityTaskQueue::new();

        for name in ["task1", "task2", "task3"] {
            push(&queue, name, Priority::Normal).await;
        }

        // Tasks of a single priority must come back in insertion order.
        for expected in ["task1", "task2", "task3"] {
            let task = queue.dequeue().await.unwrap().unwrap();
            assert_eq!(task.name(), expected);
        }
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        let queue = Arc::new(PriorityTaskQueue::new());

        let enqueuers: Vec<_> = (0..10)
            .map(|i| {
                let queue = Arc::clone(&queue);
                tokio::spawn(async move {
                    push(&queue, &format!("task{i}"), Priority::Normal).await;
                })
            })
            .collect();

        for handle in enqueuers {
            handle.await.unwrap();
        }

        assert_eq!(queue.len().await, 10);

        let dequeuers: Vec<_> = (0..10)
            .map(|_| {
                let queue = Arc::clone(&queue);
                tokio::spawn(async move { queue.dequeue().await.unwrap() })
            })
            .collect();

        let mut count = 0;
        for handle in dequeuers {
            if handle.await.unwrap().is_some() {
                count += 1;
            }
        }

        assert_eq!(count, 10);
        assert!(queue.is_empty().await);
    }

    #[tokio::test]
    async fn test_custom_priority() {
        let queue = PriorityTaskQueue::new();

        push(&queue, "custom75", Priority::Custom(75)).await;
        push(&queue, "high", Priority::High).await;
        push(&queue, "normal", Priority::Normal).await;

        assert_eq!(queue.len().await, 3);

        for _ in 0..3 {
            let task = queue.dequeue().await.unwrap();
            assert!(task.is_some());
        }

        assert!(queue.is_empty().await);
    }

    #[tokio::test]
    async fn test_empty_queue() {
        let queue = PriorityTaskQueue::new();

        assert!(queue.is_empty().await);
        assert_eq!(queue.len().await, 0);

        let task = queue.dequeue().await.unwrap();
        assert!(task.is_none());
    }

    #[tokio::test]
    async fn test_len_for_priority() {
        let queue = PriorityTaskQueue::new();

        push(&queue, "high1", Priority::High).await;
        push(&queue, "high2", Priority::High).await;
        push(&queue, "normal1", Priority::Normal).await;

        assert_eq!(queue.len_for_priority(Priority::High).await, 2);
        assert_eq!(queue.len_for_priority(Priority::Normal).await, 1);
        assert_eq!(queue.len_for_priority(Priority::Low).await, 0);
    }

    #[test]
    fn test_priority_default_weights() {
        assert_eq!(Priority::High.default_weight(), 100);
        assert_eq!(Priority::Normal.default_weight(), 50);
        assert_eq!(Priority::Low.default_weight(), 10);
        assert_eq!(Priority::Custom(75).default_weight(), 75);
    }

    #[test]
    fn test_priority_comparison() {
        assert!(Priority::High > Priority::Normal);
        assert!(Priority::Normal > Priority::Low);

        // Custom priorities slot in by weight relative to the named levels.
        assert!(Priority::Custom(75) > Priority::Normal);
        assert!(Priority::Custom(75) < Priority::High);
        assert!(Priority::Custom(200) > Priority::High);
        assert!(Priority::Custom(0) < Priority::Low);

        // A custom weight equal to a named level's weight compares equal to it.
        assert_eq!(Priority::Custom(100), Priority::High);
        assert_eq!(Priority::Custom(50), Priority::Normal);
        assert_eq!(Priority::Custom(10), Priority::Low);
    }

    #[test]
    fn test_priority_default() {
        assert_eq!(Priority::default(), Priority::Normal);
    }
}