1use std::sync::{Arc, atomic::AtomicBool};
2
3use tokio::sync::{
4 Mutex,
5 mpsc::{self, UnboundedReceiver, UnboundedSender},
6};
7
8#[cfg(feature = "notmad")]
9use tokio_util::sync::CancellationToken;
10
/// A type-erased queue entry.
///
/// Erasing to `dyn QueueItem` lets `ClosureComponent`s built from different closure
/// types travel over the same channel; the `Send + Sync` bounds let entries move
/// between tasks.
type ThreadSafeQueueItem = Arc<Mutex<dyn QueueItem + Send + Sync + 'static>>;
12
13#[derive(Clone)]
14pub struct DropQueue {
15 draining: Arc<AtomicBool>,
16 input: UnboundedSender<ThreadSafeQueueItem>,
17 receiver: Arc<Mutex<UnboundedReceiver<ThreadSafeQueueItem>>>,
18}
19
20impl Default for DropQueue {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26impl DropQueue {
27 pub fn new() -> Self {
28 let (tx, rx) = mpsc::unbounded_channel();
29
30 Self {
31 draining: Arc::new(AtomicBool::new(false)),
32 input: tx,
33 receiver: Arc::new(Mutex::new(rx)),
34 }
35 }
36
37 pub fn assign<F, Fut>(&self, f: F) -> anyhow::Result<()>
38 where
39 F: FnOnce() -> Fut + Send + Sync + 'static,
40 Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
41 {
42 if self.draining.load(std::sync::atomic::Ordering::Relaxed) {
43 panic!("trying to put an item on a draining queue. This is not allowed");
44 }
45
46 self.input
47 .send(Arc::new(Mutex::new(ClosureComponent {
48 inner: Box::new(Some(f)),
49 })))
50 .expect("unbounded channel should never be full");
51
52 Ok(())
53 }
54
55 pub async fn process_next(&self) -> anyhow::Result<()> {
56 let item = {
57 let mut queue = self.receiver.lock().await;
58 queue.recv().await
59 };
60
61 if let Some(item) = item {
62 let mut item = item.try_lock().expect("should always be unlockable");
63 item.execute().await?;
64 }
65
66 Ok(())
67 }
68
69 pub async fn process(&self) -> anyhow::Result<()> {
70 loop {
71 if self.draining.load(std::sync::atomic::Ordering::Relaxed) {
72 return Ok(());
73 }
74
75 self.process_next().await?;
76 }
77 }
78
79 pub async fn try_process_next(&self) -> anyhow::Result<Option<()>> {
80 let item = {
81 let mut queue = self.receiver.lock().await;
82 match queue.try_recv() {
83 Ok(o) => o,
84 Err(_) => return Ok(None),
85 }
86 };
87
88 let mut item = item
89 .try_lock()
90 .expect("we should always be able to unlock item");
91 item.execute().await?;
92
93 Ok(Some(()))
94 }
95
96 pub async fn drain(&self) -> anyhow::Result<()> {
97 self.draining
98 .store(true, std::sync::atomic::Ordering::Release);
99
100 while self.try_process_next().await?.is_some() {}
101
102 Ok(())
103 }
104
105 #[cfg(feature = "notmad")]
106 async fn process_all(&self, cancellation_token: CancellationToken) -> anyhow::Result<()> {
107 loop {
108 tokio::select! {
109 _ = cancellation_token.cancelled() => {
110 break;
111 },
112 res = self.process_next() => {
113 res?;
114 }
115 }
116 }
117
118 self.drain().await?;
119
120 Ok(())
121 }
122}
123
/// Adapts a one-shot async closure into a [`QueueItem`].
///
/// The bounds mirror those on `DropQueue::assign`, which is where these are built.
struct ClosureComponent<F, Fut>
where
    F: FnOnce() -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
{
    /// The closure to run. `execute` takes it out (leaving `None`), so it runs at most once.
    ///
    /// NOTE(review): the `Box` around the `Option` adds an allocation with no apparent
    /// need, since the whole component already lives behind an `Arc`; confirm before
    /// simplifying to `Option<F>` (the construction site in `assign` would change too).
    inner: Box<Option<F>>,
}
131
/// A type-erased unit of work carried through the [`DropQueue`] channel.
#[async_trait::async_trait]
trait QueueItem {
    /// Runs the item and returns its result.
    ///
    /// NOTE(review): the only implementation in this file (`ClosureComponent`) panics
    /// if this is called a second time, so callers must run each item once.
    async fn execute(&mut self) -> anyhow::Result<()>;
}
136
137#[async_trait::async_trait]
138impl<F, Fut> QueueItem for ClosureComponent<F, Fut>
139where
140 F: FnOnce() -> Fut + Send + Sync + 'static,
141 Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
142{
143 async fn execute(&mut self) -> Result<(), anyhow::Error> {
144 let item = self.inner.take().expect("to only be called once");
145
146 item().await?;
147
148 Ok(())
149 }
150}
151
#[cfg(feature = "notmad")]
mod notmad {
    use async_trait::async_trait;
    use notmad::MadError;
    use tokio_util::sync::CancellationToken;

    use crate::DropQueue;

    /// Lets a [`DropQueue`] run as a `notmad` component, driven by `process_all`
    /// under the component's cancellation token.
    #[async_trait]
    impl notmad::Component for DropQueue {
        fn name(&self) -> Option<String> {
            Some(String::from("drop-queue/drop-queue"))
        }

        async fn run(&self, cancellation_token: CancellationToken) -> Result<(), MadError> {
            // Queue errors are surfaced to notmad wrapped in its `Inner` variant.
            self.process_all(cancellation_token)
                .await
                .map_err(MadError::Inner)
        }
    }
}
// NOTE(review): `notmad` names both the module declared just above and the external
// `notmad` crate; confirm which one this glob resolves to and whether the re-export
// is still wanted. The local module declares no public items of its own (only a
// trait impl), which would explain the `unused_imports` allow.
#[cfg(feature = "notmad")]
#[allow(unused_imports)]
pub use notmad::*;
178
#[cfg(test)]
mod test {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use tokio::sync::oneshot;

    use crate::DropQueue;

    #[tokio::test]
    async fn can_drop_item() -> anyhow::Result<()> {
        let drop_queue = DropQueue::new();

        let (called_tx, called_rx) = oneshot::channel();

        drop_queue.assign(|| async move {
            tracing::info!("was called");

            called_tx.send(()).unwrap();

            Ok(())
        })?;

        drop_queue.process_next().await?;

        called_rx.await?;

        Ok(())
    }

    #[tokio::test]
    async fn can_drop_multiple_items() -> anyhow::Result<()> {
        let drop_queue = DropQueue::new();

        let (called_tx, called_rx) = oneshot::channel();
        let queue = drop_queue.clone();
        let first_producer = tokio::spawn(async move {
            queue
                .assign(|| async move {
                    tracing::info!("was called");

                    called_tx.send(()).unwrap();

                    Ok(())
                })
                .unwrap();
        });

        let (called_tx2, called_rx2) = oneshot::channel();
        let queue = drop_queue.clone();
        let second_producer = tokio::spawn(async move {
            queue
                .assign(|| async move {
                    tracing::info!("was called");

                    called_tx2.send(()).unwrap();

                    Ok(())
                })
                .unwrap();
        });

        // Join the producers so a panicking `assign` fails the test instead of leaving
        // `process_next` waiting forever for an item that never arrives.
        first_producer.await?;
        second_producer.await?;

        drop_queue.process_next().await?;
        drop_queue.process_next().await?;

        called_rx.await?;
        called_rx2.await?;

        Ok(())
    }

    #[tokio::test]
    async fn try_process_next_returns_none_on_empty_queue() -> anyhow::Result<()> {
        let drop_queue = DropQueue::new();

        assert!(drop_queue.try_process_next().await?.is_none());

        Ok(())
    }

    #[tokio::test]
    async fn drain_runs_every_pending_item() -> anyhow::Result<()> {
        let drop_queue = DropQueue::new();
        let calls = Arc::new(AtomicUsize::new(0));

        for _ in 0..3 {
            let calls = Arc::clone(&calls);
            drop_queue.assign(move || async move {
                calls.fetch_add(1, Ordering::SeqCst);

                Ok(())
            })?;
        }

        drop_queue.drain().await?;

        assert_eq!(calls.load(Ordering::SeqCst), 3);
        assert!(drop_queue.try_process_next().await?.is_none());

        Ok(())
    }
}