delayqueue/
lib.rs

1use std::{cmp::Reverse, collections::BinaryHeap, sync::Arc, thread::ThreadId, time};
2
3use parking_lot::{Condvar, Mutex};
4
/// An element that can be stored in a [`DelayQueue`].
///
/// The queue keeps its elements ordered by their `Ord` implementation (smallest at the
/// head) and only hands the head out once [`Delayed::delayed`] reports it has expired.
pub trait Delayed
where
    Self: Ord,
{
    /// Remaining delay before this element may be taken.
    ///
    /// A value `<= 0` means the element has expired and can be taken immediately; a
    /// positive value is treated by [`DelayQueue::take`] as a number of nanoseconds to wait.
    fn delayed(&self) -> i64;
}
8
9#[derive(Default)]
10pub struct DelayQueue<T: Delayed> {
11    queue: Arc<Mutex<DelayQueueInner<T>>>,
12    available: Arc<Condvar>,
13}
14
15impl<T: Delayed> Clone for DelayQueue<T> {
16    fn clone(&self) -> Self {
17        Self {
18            queue: Arc::clone(&self.queue),
19            available: Arc::clone(&self.available),
20        }
21    }
22}
23
24#[derive(Default, Clone)]
25struct DelayQueueInner<T: Delayed> {
26    queue: BinaryHeap<Reverse<Arc<T>>>,
27    current_thread: Option<ThreadId>,
28}
29
30impl<T: Delayed> DelayQueueInner<T> {
31    fn peek(&self) -> Option<&T> {
32        let result = self.queue.peek()?;
33        Some(&result.0)
34    }
35}
36
37impl<T> DelayQueue<T>
38where
39    T: Delayed + Sync + Send,
40{
41    pub fn put(&mut self, t: T) {
42        let queue = self.queue.clone();
43        let queue = &mut queue.lock().queue;
44        let t = Reverse(Arc::new(t));
45        queue.push(t.clone());
46        if queue.peek() == Some(&t) {
47            self.available.notify_one();
48        }
49    }
50
51    pub fn take(&mut self) -> Arc<T> {
52        let queue = self.queue.clone();
53        let avaliable = self.available.clone();
54        let mut guard = queue.lock();
55        loop {
56            match guard.peek() {
57                None => {
58                    avaliable.wait(&mut guard);
59                }
60                Some(first) => {
61                    let delayed = first.delayed();
62                    if delayed <= 0 {
63                        let result = guard.queue.pop().unwrap();
64                        if guard.current_thread.is_none() && guard.peek().is_some() {
65                            avaliable.notify_one();
66                        }
67                        return result.0;
68                    }
69                    let _ = first;
70                    match guard.current_thread {
71                        Some(_) => {
72                            avaliable.wait(&mut guard);
73                        }
74                        None => {
75                            let thread_id = std::thread::current().id();
76                            guard.current_thread = Some(thread_id);
77                            avaliable
78                                .wait_for(&mut guard, time::Duration::from_nanos(delayed as u64));
79                            if guard.current_thread == Some(thread_id) {
80                                guard.current_thread = None
81                            }
82                        }
83                    }
84                }
85            }
86        }
87    }
88}
89
#[cfg(test)]
mod test {
    use std::collections::{BTreeMap, HashMap};

    use chrono::{DateTime, Duration, Local};

    use super::*;

    /// Stress test: one producer enqueues `TOTAL_COUNT` tasks with random delays of up to
    /// 10 s while `THREAD_COUNT` consumers drain them. It then prints a histogram of how
    /// late (in microseconds) each task was taken relative to its deadline.
    #[test]
    fn test() {
        #[derive(Default, Debug, PartialEq, Eq)]
        struct Task {
            /// Absolute deadline as a Unix timestamp in nanoseconds.
            deadline: i64,

            message: String,
        }

        impl Task {
            fn new<S: Into<String>>(deadline: i64, message: S) -> Task {
                let message = message.into();
                Task { deadline, message }
            }
        }

        impl Delayed for Task {
            fn delayed(&self) -> i64 {
                self.deadline - chrono::Local::now().timestamp_nanos()
            }
        }

        impl Ord for Task {
            fn cmp(&self, other: &Self) -> std::cmp::Ordering {
                self.deadline.cmp(&other.deadline)
            }
        }

        impl PartialOrd for Task {
            fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
                Some(self.cmp(other))
            }
        }

        const TOTAL_COUNT: usize = 1000;
        const THREAD_COUNT: usize = 8;
        // Each consumer takes TOTAL_COUNT / THREAD_COUNT tasks; a remainder would be left
        // in the queue and the final count check would fail for the wrong reason.
        assert_eq!(TOTAL_COUNT % THREAD_COUNT, 0);

        let queue = DelayQueue::<Task>::default();

        let producer = {
            let mut queue = queue.clone();
            std::thread::spawn(move || {
                for index in 0..TOTAL_COUNT {
                    let v = rand::random::<u64>() % 10000;
                    queue.put(Task::new(
                        after(Duration::milliseconds(v as i64)).timestamp_nanos(),
                        format!("index: {}. delay for {}ms", index, v),
                    ));
                }
            })
        };

        // Spawn every consumer before joining any, so they drain the queue concurrently.
        let consumers: Vec<_> = (0..THREAD_COUNT)
            .map(|_| {
                let mut queue = queue.clone();
                std::thread::spawn(move || {
                    let mut map = HashMap::<i32, i32>::new();
                    for _ in 0..TOTAL_COUNT / THREAD_COUNT {
                        let task = queue.take();
                        let now = chrono::Local::now();
                        let diff_us = (now.timestamp_nanos() - task.deadline) / 1000;
                        *map.entry(latency_bucket(diff_us)).or_default() += 1;
                    }
                    map
                })
            })
            .collect();

        // BTreeMap so the histogram prints in ascending bucket order.
        let mut results = BTreeMap::<i32, i32>::new();
        for consumer in consumers {
            for (bucket, count) in consumer.join().unwrap() {
                *results.entry(bucket).or_default() += count;
            }
        }
        // Surface a producer panic instead of silently dropping its handle.
        producer.join().unwrap();

        println!("Response Latency\tCount\tPercent");
        for (k, v) in &results {
            println!(
                "{:14}us\t{:5}\t{:.2}%",
                k,
                v,
                (*v as f64 / TOTAL_COUNT as f64) * 100f64
            );
        }

        let result_count: i32 = results.values().sum();
        assert_eq!(TOTAL_COUNT as i32, result_count);
    }

    /// Maps a lateness in microseconds to the upper bound of its histogram bucket.
    ///
    /// Anything above 600us lands in the catch-all `1000` bucket; values `<= 100`
    /// (including negative ones) land in the `100` bucket.
    fn latency_bucket(diff_us: i64) -> i32 {
        const BOUNDS: [i64; 6] = [100, 200, 300, 400, 500, 600];
        BOUNDS
            .iter()
            .find(|&&bound| diff_us <= bound)
            .map_or(1000, |&bound| bound as i32)
    }

    /// Returns the local time `du` from now.
    fn after(du: Duration) -> DateTime<Local> {
        chrono::Local::now() + du
    }
}