Skip to main content

reinhardt_tasks/
priority_queue.rs

1//! Priority task queue with weighted scheduling
2
3use crate::{Task, TaskResult};
4use std::collections::{BTreeMap, HashMap, VecDeque};
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7use tokio::sync::RwLock;
8
/// Map from each priority level to the FIFO queue of boxed tasks waiting at
/// that level.
///
/// A `BTreeMap` keeps its keys sorted, and `Priority` orders by weight, so
/// iterating the map visits priorities from the lowest weight to the highest.
type PriorityQueueMap = BTreeMap<Priority, VecDeque<Box<dyn Task>>>;
11
/// Priority level for tasks
///
/// Ordering is based on weight values: `Low` (10) < `Normal` (50) < `High` (100).
/// `Custom(w)` is ordered by its weight value relative to the standard priorities.
///
/// Equality and hashing are weight-based as well, so a `Custom` priority whose
/// weight matches a standard level is the *same* priority as that level
/// (for example `Priority::Custom(100) == Priority::High`).
///
/// # Example
///
/// ```rust
/// use reinhardt_tasks::Priority;
///
/// let high = Priority::High;
/// let normal = Priority::Normal;
/// let low = Priority::Low;
/// assert!(high > normal);
/// assert!(normal > low);
///
/// // Custom priority is ordered by weight value
/// let custom_75 = Priority::Custom(75);
/// assert!(custom_75 > normal);  // 75 > 50
/// assert!(custom_75 < high);    // 75 < 100
///
/// let custom_200 = Priority::Custom(200);
/// assert!(custom_200 > high);   // 200 > 100
///
/// // A custom weight equal to a standard weight is the same priority
/// assert_eq!(Priority::Custom(100), high);
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub enum Priority {
	/// Low priority (weight: 10)
	Low,
	/// Normal priority (weight: 50); this is the `Default` variant
	#[default]
	Normal,
	/// High priority (weight: 100)
	High,
	/// Custom priority with specified weight
	Custom(u32),
}
48
49impl PartialEq for Priority {
50	fn eq(&self, other: &Self) -> bool {
51		self.default_weight() == other.default_weight()
52	}
53}
54
55impl Eq for Priority {}
56
57impl std::hash::Hash for Priority {
58	fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
59		self.default_weight().hash(state);
60	}
61}
62
63impl PartialOrd for Priority {
64	fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
65		Some(self.cmp(other))
66	}
67}
68
69impl Ord for Priority {
70	fn cmp(&self, other: &Self) -> std::cmp::Ordering {
71		self.default_weight().cmp(&other.default_weight())
72	}
73}
74
75impl Priority {
76	/// Get the default weight for this priority
77	///
78	/// # Example
79	///
80	/// ```rust
81	/// use reinhardt_tasks::Priority;
82	///
83	/// assert_eq!(Priority::High.default_weight(), 100);
84	/// assert_eq!(Priority::Normal.default_weight(), 50);
85	/// assert_eq!(Priority::Low.default_weight(), 10);
86	/// assert_eq!(Priority::Custom(75).default_weight(), 75);
87	/// ```
88	pub fn default_weight(&self) -> u32 {
89		match self {
90			Priority::High => 100,
91			Priority::Normal => 50,
92			Priority::Low => 10,
93			Priority::Custom(weight) => *weight,
94		}
95	}
96}
97
/// Priority task queue with weighted scheduling
///
/// Tasks are kept in one FIFO queue per priority. On each dequeue a
/// per-instance counter, taken modulo the combined weight of the priorities
/// that currently hold tasks, decides which priority is served. A priority
/// with pending tasks therefore receives dequeues in proportion to its weight,
/// so lower priority tasks are not starved by higher priority ones.
///
/// # Example
///
/// ```rust
/// use reinhardt_tasks::{Priority, PriorityTaskQueue};
///
/// # async fn example() -> reinhardt_tasks::TaskResult<()> {
/// let queue = PriorityTaskQueue::new();
///
/// // High priority tasks receive the largest share of dequeues
/// // but low priority tasks will also be processed
/// # Ok(())
/// # }
/// ```
// Fixes #785: counter is per-instance instead of global static
pub struct PriorityTaskQueue {
	/// Pending tasks grouped by priority; each group is consumed front-first.
	queues: Arc<RwLock<PriorityQueueMap>>,
	/// Scheduling weight per priority. A priority with no entry here falls
	/// back to `Priority::default_weight` when a task is dequeued.
	weights: HashMap<Priority, u32>,
	/// Advanced by one on every weighted selection; its value modulo the total
	/// weight picks the priority to serve.
	counter: AtomicU64,
}
123
124impl PriorityTaskQueue {
125	/// Create a new priority task queue with default weights
126	///
127	/// Default weights:
128	/// - High: 100
129	/// - Normal: 50
130	/// - Low: 10
131	///
132	/// # Example
133	///
134	/// ```rust
135	/// use reinhardt_tasks::PriorityTaskQueue;
136	///
137	/// let queue = PriorityTaskQueue::new();
138	/// ```
139	pub fn new() -> Self {
140		let mut weights = HashMap::new();
141		weights.insert(Priority::High, 100);
142		weights.insert(Priority::Normal, 50);
143		weights.insert(Priority::Low, 10);
144
145		Self {
146			queues: Arc::new(RwLock::new(BTreeMap::new())),
147			weights,
148			counter: AtomicU64::new(0),
149		}
150	}
151
152	/// Create a new priority task queue with custom weights
153	///
154	/// # Example
155	///
156	/// ```rust
157	/// use reinhardt_tasks::{Priority, PriorityTaskQueue};
158	/// use std::collections::HashMap;
159	///
160	/// let mut weights = HashMap::new();
161	/// weights.insert(Priority::High, 200);
162	/// weights.insert(Priority::Normal, 100);
163	/// weights.insert(Priority::Low, 20);
164	///
165	/// let queue = PriorityTaskQueue::with_weights(weights);
166	/// ```
167	pub fn with_weights(weights: HashMap<Priority, u32>) -> Self {
168		Self {
169			queues: Arc::new(RwLock::new(BTreeMap::new())),
170			weights,
171			counter: AtomicU64::new(0),
172		}
173	}
174
175	/// Enqueue a task with the specified priority
176	///
177	/// # Example
178	///
179	/// ```rust,no_run
180	/// use reinhardt_tasks::{Priority, PriorityTaskQueue};
181	///
182	/// # async fn example() -> reinhardt_tasks::TaskResult<()> {
183	/// # struct MyTask;
184	/// # impl MyTask { fn new() -> Self { MyTask } }
185	/// let queue = PriorityTaskQueue::new();
186	/// let task = MyTask::new();
187	///
188	/// // queue.enqueue(Box::new(task), Priority::High).await?;
189	/// # Ok(())
190	/// # }
191	/// ```
192	pub async fn enqueue(&self, task: Box<dyn Task>, priority: Priority) -> TaskResult<()> {
193		let mut queues = self.queues.write().await;
194		queues.entry(priority).or_default().push_back(task);
195		Ok(())
196	}
197
198	/// Dequeue a task using weighted scheduling
199	///
200	/// Tasks are selected based on their priority weights. Higher priority
201	/// tasks have a higher probability of being selected, but lower priority
202	/// tasks are not starved.
203	///
204	/// Returns `None` if the queue is empty.
205	///
206	/// # Example
207	///
208	/// ```rust,no_run
209	/// use reinhardt_tasks::PriorityTaskQueue;
210	///
211	/// # async fn example() -> reinhardt_tasks::TaskResult<()> {
212	/// let queue = PriorityTaskQueue::new();
213	///
214	/// if let Some(task) = queue.dequeue().await? {
215	///     // Process task
216	/// }
217	/// # Ok(())
218	/// # }
219	/// ```
220	pub async fn dequeue(&self) -> TaskResult<Option<Box<dyn Task>>> {
221		let mut queues = self.queues.write().await;
222
223		if queues.is_empty() {
224			return Ok(None);
225		}
226
227		// Calculate total weight of non-empty queues
228		let mut total_weight = 0u32;
229		let mut priorities_with_weight = Vec::new();
230
231		for (priority, queue) in queues.iter() {
232			if !queue.is_empty() {
233				let weight = self.weights.get(priority).copied().unwrap_or_else(|| {
234					if let Priority::Custom(w) = priority {
235						*w
236					} else {
237						priority.default_weight()
238					}
239				});
240				total_weight += weight;
241				priorities_with_weight.push((*priority, weight));
242			}
243		}
244
245		if total_weight == 0 {
246			return Ok(None);
247		}
248
249		// Select a priority based on weights
250		// Use a simple counter-based approach for deterministic weighted round-robin
251		let selected_priority =
252			self.select_priority_weighted(&priorities_with_weight, total_weight);
253
254		// Dequeue from the selected priority
255		if let Some(queue) = queues.get_mut(&selected_priority)
256			&& let Some(task) = queue.pop_front()
257		{
258			return Ok(Some(task));
259		}
260
261		Ok(None)
262	}
263
264	/// Select a priority using weighted round-robin
265	fn select_priority_weighted(
266		&self,
267		priorities: &[(Priority, u32)],
268		total_weight: u32,
269	) -> Priority {
270		// Simple weighted selection: iterate through priorities in order
271		// and select based on accumulated weights
272		// This ensures FIFO within same priority and fair distribution
273
274		// Fixes #785: use instance counter instead of global static to avoid
275		// cross-instance interference between independent queue instances
276		let counter = self.counter.fetch_add(1, Ordering::Relaxed);
277		let target = (counter % total_weight as u64) as u32;
278
279		let mut accumulated = 0;
280		for (priority, weight) in priorities {
281			accumulated += weight;
282			if target < accumulated {
283				return *priority;
284			}
285		}
286
287		// Fallback to highest priority
288		priorities
289			.first()
290			.map(|(p, _)| *p)
291			.unwrap_or(Priority::Normal)
292	}
293
294	/// Get the total number of tasks in all queues
295	///
296	/// # Example
297	///
298	/// ```rust
299	/// use reinhardt_tasks::PriorityTaskQueue;
300	///
301	/// # async fn example() {
302	/// let queue = PriorityTaskQueue::new();
303	/// assert_eq!(queue.len().await, 0);
304	/// # }
305	/// ```
306	pub async fn len(&self) -> usize {
307		let queues = self.queues.read().await;
308		queues.values().map(|q| q.len()).sum()
309	}
310
311	/// Check if the queue is empty
312	///
313	/// # Example
314	///
315	/// ```rust
316	/// use reinhardt_tasks::PriorityTaskQueue;
317	///
318	/// # async fn example() {
319	/// let queue = PriorityTaskQueue::new();
320	/// assert!(queue.is_empty().await);
321	/// # }
322	/// ```
323	pub async fn is_empty(&self) -> bool {
324		let queues = self.queues.read().await;
325		queues.values().all(|q| q.is_empty())
326	}
327
328	/// Get the number of tasks for a specific priority
329	///
330	/// # Example
331	///
332	/// ```rust
333	/// use reinhardt_tasks::{Priority, PriorityTaskQueue};
334	///
335	/// # async fn example() {
336	/// let queue = PriorityTaskQueue::new();
337	/// assert_eq!(queue.len_for_priority(Priority::High).await, 0);
338	/// # }
339	/// ```
340	pub async fn len_for_priority(&self, priority: Priority) -> usize {
341		let queues = self.queues.read().await;
342		queues.get(&priority).map(|q| q.len()).unwrap_or(0)
343	}
344}
345
346impl Default for PriorityTaskQueue {
347	fn default() -> Self {
348		Self::new()
349	}
350}
351
352#[cfg(test)]
353mod tests {
354	use super::*;
355	use crate::TaskId;
356
357	#[derive(Debug)]
358	struct TestTask {
359		id: TaskId,
360		name: String,
361	}
362
363	impl TestTask {
364		fn new(name: &str) -> Self {
365			Self {
366				id: TaskId::new(),
367				name: name.to_string(),
368			}
369		}
370	}
371
372	impl Task for TestTask {
373		fn id(&self) -> TaskId {
374			self.id
375		}
376
377		fn name(&self) -> &str {
378			&self.name
379		}
380	}
381
382	#[tokio::test]
383	async fn test_priority_ordering() {
384		let queue = PriorityTaskQueue::new();
385
386		// Enqueue tasks with different priorities
387		queue
388			.enqueue(Box::new(TestTask::new("low1")), Priority::Low)
389			.await
390			.unwrap();
391		queue
392			.enqueue(Box::new(TestTask::new("high1")), Priority::High)
393			.await
394			.unwrap();
395		queue
396			.enqueue(Box::new(TestTask::new("normal1")), Priority::Normal)
397			.await
398			.unwrap();
399		queue
400			.enqueue(Box::new(TestTask::new("high2")), Priority::High)
401			.await
402			.unwrap();
403
404		assert_eq!(queue.len().await, 4);
405
406		// High priority tasks should be more likely to be dequeued first
407		let mut high_count = 0;
408		let mut dequeued = Vec::new();
409
410		for _ in 0..4 {
411			if let Some(task) = queue.dequeue().await.unwrap() {
412				dequeued.push(task.name().to_string());
413				if task.name().starts_with("high") {
414					high_count += 1;
415				}
416			}
417		}
418
419		// Should have dequeued at least one high priority task
420		assert!(high_count > 0);
421		assert_eq!(queue.len().await, 0);
422	}
423
424	#[tokio::test]
425	async fn test_weighted_scheduling() {
426		let mut weights = HashMap::new();
427		weights.insert(Priority::High, 90);
428		weights.insert(Priority::Normal, 9);
429		weights.insert(Priority::Low, 1);
430
431		let queue = PriorityTaskQueue::with_weights(weights);
432
433		// Enqueue more high priority tasks to match the weight ratio
434		// This ensures we can observe the weighted scheduling behavior
435		for i in 0..30 {
436			queue
437				.enqueue(
438					Box::new(TestTask::new(&format!("high{}", i))),
439					Priority::High,
440				)
441				.await
442				.unwrap();
443		}
444		for i in 0..10 {
445			queue
446				.enqueue(
447					Box::new(TestTask::new(&format!("normal{}", i))),
448					Priority::Normal,
449				)
450				.await
451				.unwrap();
452		}
453		for i in 0..5 {
454			queue
455				.enqueue(Box::new(TestTask::new(&format!("low{}", i))), Priority::Low)
456				.await
457				.unwrap();
458		}
459
460		let mut high_count = 0;
461		let mut normal_count = 0;
462		let mut low_count = 0;
463
464		// Dequeue all tasks and count by priority
465		while let Some(task) = queue.dequeue().await.unwrap() {
466			if task.name().starts_with("high") {
467				high_count += 1;
468			} else if task.name().starts_with("normal") {
469				normal_count += 1;
470			} else if task.name().starts_with("low") {
471				low_count += 1;
472			}
473		}
474
475		// Verify all tasks were dequeued
476		assert_eq!(high_count + normal_count + low_count, 45);
477
478		// All priorities should get at least some tasks (no starvation)
479		assert!(high_count > 0, "High priority tasks should be dequeued");
480		assert!(normal_count > 0, "Normal priority tasks should be dequeued");
481		assert!(low_count > 0, "Low priority tasks should be dequeued");
482
483		// High priority should get more tasks than normal
484		assert!(
485			high_count > normal_count,
486			"High count {} should be greater than normal count {}",
487			high_count,
488			normal_count
489		);
490
491		// Normal priority should get more tasks than low
492		assert!(
493			normal_count > low_count,
494			"Normal count {} should be greater than low count {}",
495			normal_count,
496			low_count
497		);
498	}
499
500	#[tokio::test]
501	async fn test_fifo_within_priority() {
502		let queue = PriorityTaskQueue::new();
503
504		// Enqueue multiple tasks with the same priority
505		queue
506			.enqueue(Box::new(TestTask::new("task1")), Priority::Normal)
507			.await
508			.unwrap();
509		queue
510			.enqueue(Box::new(TestTask::new("task2")), Priority::Normal)
511			.await
512			.unwrap();
513		queue
514			.enqueue(Box::new(TestTask::new("task3")), Priority::Normal)
515			.await
516			.unwrap();
517
518		// Tasks should be dequeued in FIFO order for the same priority
519		let task1 = queue.dequeue().await.unwrap().unwrap();
520		let task2 = queue.dequeue().await.unwrap().unwrap();
521		let task3 = queue.dequeue().await.unwrap().unwrap();
522
523		assert_eq!(task1.name(), "task1");
524		assert_eq!(task2.name(), "task2");
525		assert_eq!(task3.name(), "task3");
526	}
527
528	#[tokio::test]
529	async fn test_concurrent_access() {
530		let queue = Arc::new(PriorityTaskQueue::new());
531
532		// Spawn multiple tasks that enqueue
533		let mut handles = vec![];
534		for i in 0..10 {
535			let queue_clone = queue.clone();
536			handles.push(tokio::spawn(async move {
537				queue_clone
538					.enqueue(
539						Box::new(TestTask::new(&format!("task{}", i))),
540						Priority::Normal,
541					)
542					.await
543					.unwrap();
544			}));
545		}
546
547		// Wait for all enqueues to complete
548		for handle in handles {
549			handle.await.unwrap();
550		}
551
552		assert_eq!(queue.len().await, 10);
553
554		// Spawn multiple tasks that dequeue
555		let mut handles = vec![];
556		for _ in 0..10 {
557			let queue_clone = queue.clone();
558			handles.push(tokio::spawn(
559				async move { queue_clone.dequeue().await.unwrap() },
560			));
561		}
562
563		let mut count = 0;
564		for handle in handles {
565			if handle.await.unwrap().is_some() {
566				count += 1;
567			}
568		}
569
570		assert_eq!(count, 10);
571		assert!(queue.is_empty().await);
572	}
573
574	#[tokio::test]
575	async fn test_custom_priority() {
576		let queue = PriorityTaskQueue::new();
577
578		queue
579			.enqueue(Box::new(TestTask::new("custom75")), Priority::Custom(75))
580			.await
581			.unwrap();
582		queue
583			.enqueue(Box::new(TestTask::new("high")), Priority::High)
584			.await
585			.unwrap();
586		queue
587			.enqueue(Box::new(TestTask::new("normal")), Priority::Normal)
588			.await
589			.unwrap();
590
591		assert_eq!(queue.len().await, 3);
592
593		// Custom(75) should be between Normal(50) and High(100)
594		// Just verify all tasks are dequeued correctly
595		for _ in 0..3 {
596			let task = queue.dequeue().await.unwrap();
597			assert!(task.is_some());
598		}
599
600		// All tasks should have been dequeued
601		assert!(queue.is_empty().await);
602	}
603
604	#[tokio::test]
605	async fn test_empty_queue() {
606		let queue = PriorityTaskQueue::new();
607
608		assert!(queue.is_empty().await);
609		assert_eq!(queue.len().await, 0);
610
611		let task = queue.dequeue().await.unwrap();
612		assert!(task.is_none());
613	}
614
615	#[tokio::test]
616	async fn test_len_for_priority() {
617		let queue = PriorityTaskQueue::new();
618
619		queue
620			.enqueue(Box::new(TestTask::new("high1")), Priority::High)
621			.await
622			.unwrap();
623		queue
624			.enqueue(Box::new(TestTask::new("high2")), Priority::High)
625			.await
626			.unwrap();
627		queue
628			.enqueue(Box::new(TestTask::new("normal1")), Priority::Normal)
629			.await
630			.unwrap();
631
632		assert_eq!(queue.len_for_priority(Priority::High).await, 2);
633		assert_eq!(queue.len_for_priority(Priority::Normal).await, 1);
634		assert_eq!(queue.len_for_priority(Priority::Low).await, 0);
635	}
636
637	#[test]
638	fn test_priority_default_weights() {
639		assert_eq!(Priority::High.default_weight(), 100);
640		assert_eq!(Priority::Normal.default_weight(), 50);
641		assert_eq!(Priority::Low.default_weight(), 10);
642		assert_eq!(Priority::Custom(75).default_weight(), 75);
643	}
644
645	#[test]
646	fn test_priority_comparison() {
647		// Ordering is based on weight values
648		assert!(Priority::High > Priority::Normal);
649		assert!(Priority::Normal > Priority::Low);
650
651		// Custom priorities are ordered by their weight value
652		assert!(Priority::Custom(75) > Priority::Normal); // 75 > 50
653		assert!(Priority::Custom(75) < Priority::High); // 75 < 100
654		assert!(Priority::Custom(200) > Priority::High); // 200 > 100
655		assert!(Priority::Custom(0) < Priority::Low); // 0 < 10
656
657		// Custom with same weight as standard priority is equal
658		assert_eq!(Priority::Custom(100), Priority::High);
659		assert_eq!(Priority::Custom(50), Priority::Normal);
660		assert_eq!(Priority::Custom(10), Priority::Low);
661	}
662
663	#[test]
664	fn test_priority_default() {
665		assert_eq!(Priority::default(), Priority::Normal);
666	}
667}