1use std::{cmp::Reverse, collections::BinaryHeap, sync::Arc, thread::ThreadId, time};
2
3use parking_lot::{Condvar, Mutex};
4
/// An element that may only be removed from a [`DelayQueue`] once its delay
/// has elapsed.
///
/// The `Ord` bound determines heap order: the queue always examines its
/// smallest element first. The ordering should therefore agree with
/// `delayed`, so that an element that expires sooner compares as smaller.
pub trait Delayed
where
    Self: Ord,
{
    /// Remaining delay in nanoseconds.
    ///
    /// A value `<= 0` means the element is ready to be taken. A positive value
    /// is used as the timeout while a taker waits for this element.
    fn delayed(&self) -> i64;
}
8
9#[derive(Default)]
10pub struct DelayQueue<T: Delayed> {
11 queue: Arc<Mutex<DelayQueueInner<T>>>,
12 available: Arc<Condvar>,
13}
14
15impl<T: Delayed> Clone for DelayQueue<T> {
16 fn clone(&self) -> Self {
17 Self {
18 queue: Arc::clone(&self.queue),
19 available: Arc::clone(&self.available),
20 }
21 }
22}
23
24#[derive(Default, Clone)]
25struct DelayQueueInner<T: Delayed> {
26 queue: BinaryHeap<Reverse<Arc<T>>>,
27 current_thread: Option<ThreadId>,
28}
29
30impl<T: Delayed> DelayQueueInner<T> {
31 fn peek(&self) -> Option<&T> {
32 let result = self.queue.peek()?;
33 Some(&result.0)
34 }
35}
36
37impl<T> DelayQueue<T>
38where
39 T: Delayed + Sync + Send,
40{
41 pub fn put(&mut self, t: T) {
42 let queue = self.queue.clone();
43 let queue = &mut queue.lock().queue;
44 let t = Reverse(Arc::new(t));
45 queue.push(t.clone());
46 if queue.peek() == Some(&t) {
47 self.available.notify_one();
48 }
49 }
50
51 pub fn take(&mut self) -> Arc<T> {
52 let queue = self.queue.clone();
53 let avaliable = self.available.clone();
54 let mut guard = queue.lock();
55 loop {
56 match guard.peek() {
57 None => {
58 avaliable.wait(&mut guard);
59 }
60 Some(first) => {
61 let delayed = first.delayed();
62 if delayed <= 0 {
63 let result = guard.queue.pop().unwrap();
64 if guard.current_thread.is_none() && guard.peek().is_some() {
65 avaliable.notify_one();
66 }
67 return result.0;
68 }
69 let _ = first;
70 match guard.current_thread {
71 Some(_) => {
72 avaliable.wait(&mut guard);
73 }
74 None => {
75 let thread_id = std::thread::current().id();
76 guard.current_thread = Some(thread_id);
77 avaliable
78 .wait_for(&mut guard, time::Duration::from_nanos(delayed as u64));
79 if guard.current_thread == Some(thread_id) {
80 guard.current_thread = None
81 }
82 }
83 }
84 }
85 }
86 }
87 }
88}
89
#[cfg(test)]
mod test {
    use std::collections::{BTreeMap, HashMap};

    use chrono::{DateTime, Duration, Local};

    use super::*;

    /// Stress test for the queue.
    ///
    /// One producer enqueues `TOTAL_COUNT` tasks whose deadlines are 0–9999 ms
    /// in the future. Meanwhile `THREAD_COUNT` consumers drain the queue
    /// concurrently.
    ///
    /// The test prints a histogram of how late (in microseconds) each task was
    /// handed out, then checks that every task was consumed.
    #[test]
    fn test() {
        #[derive(Default, Debug, PartialEq, Eq)]
        struct Task {
            deadline: i64,

            message: String,
        }

        impl Task {
            fn new<S: Into<String>>(deadline: i64, message: S) -> Task {
                let message = message.into();
                Task { deadline, message }
            }
        }

        impl Delayed for Task {
            /// Nanoseconds until `deadline` (negative once it has passed).
            fn delayed(&self) -> i64 {
                self.deadline - chrono::Local::now().timestamp_nanos()
            }
        }

        // Ordered by deadline first, so the earliest deadline is taken first.
        // The tie-break on `message` keeps `Ord` consistent with the derived
        // `Eq`, which compares every field.
        impl Ord for Task {
            fn cmp(&self, other: &Self) -> std::cmp::Ordering {
                self.deadline
                    .cmp(&other.deadline)
                    .then_with(|| self.message.cmp(&other.message))
            }
        }

        impl PartialOrd for Task {
            fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
                Some(self.cmp(other))
            }
        }

        const TOTAL_COUNT: usize = 1000;
        const THREAD_COUNT: usize = 8;
        // Every consumer takes exactly TOTAL_COUNT / THREAD_COUNT tasks. The
        // split must be exact, or some tasks would never be taken.
        assert_eq!(TOTAL_COUNT % THREAD_COUNT, 0);

        let queue = DelayQueue::<Task>::default();
        {
            let mut queue = queue.clone();
            std::thread::spawn(move || {
                for index in 0..TOTAL_COUNT {
                    let v = rand::random::<u64>() % 10000;
                    queue.put(Task::new(
                        after(Duration::milliseconds(v as i64)).timestamp_nanos(),
                        format!("index: {}. delay for {}ms", index, v),
                    ));
                }
            });
        }

        let consumers: Vec<_> = (0..THREAD_COUNT)
            .map(|_| {
                let mut queue = queue.clone();
                std::thread::spawn(move || {
                    let mut histogram = HashMap::<i32, usize>::new();
                    for _ in 0..TOTAL_COUNT / THREAD_COUNT {
                        let task = queue.take();
                        let now = chrono::Local::now();
                        let late_us = (now.timestamp_nanos() - task.deadline) / 1000;
                        *histogram.entry(latency_bucket(late_us)).or_default() += 1;
                    }
                    histogram
                })
            })
            // Spawn every consumer before joining any of them.
            .collect();

        // BTreeMap so the report is printed in ascending bucket order.
        let mut results = BTreeMap::<i32, usize>::new();
        for consumer in consumers {
            for (bucket, count) in consumer.join().unwrap() {
                *results.entry(bucket).or_default() += count;
            }
        }

        let mut result_count = 0;
        println!("Response Latency\tCount\tPercent");
        for (k, v) in results {
            result_count += v;
            println!(
                "{:14}us\t{:5}\t{:.2}%",
                k,
                v,
                (v as f64 / TOTAL_COUNT as f64) * 100f64
            );
        }
        assert_eq!(TOTAL_COUNT, result_count);
    }

    /// Returns the upper bound (µs) of the histogram bucket for a task taken
    /// `late_us` microseconds after its deadline.
    ///
    /// Buckets are 100, 200, …, 600; anything later than 600 µs maps to 1000.
    fn latency_bucket(late_us: i64) -> i32 {
        match late_us {
            i64::MIN..=100 => 100,
            101..=200 => 200,
            201..=300 => 300,
            301..=400 => 400,
            401..=500 => 500,
            501..=600 => 600,
            _ => 1000,
        }
    }

    #[test]
    fn latency_bucket_boundaries() {
        assert_eq!(latency_bucket(-5), 100);
        assert_eq!(latency_bucket(100), 100);
        assert_eq!(latency_bucket(101), 200);
        assert_eq!(latency_bucket(350), 400);
        assert_eq!(latency_bucket(600), 600);
        assert_eq!(latency_bucket(601), 1000);
    }

    /// Local wall-clock time `du` from now.
    fn after(du: Duration) -> DateTime<Local> {
        chrono::Local::now() + du
    }
}