1use std::sync::{Arc, atomic::AtomicBool};
2
3use tokio::sync::{
4 Mutex,
5 mpsc::{self, UnboundedReceiver, UnboundedSender},
6};
7
8#[cfg(feature = "notmad")]
9use tokio_util::sync::CancellationToken;
10
11type ThreadSafeQueueItem = Arc<Mutex<dyn QueueItem + Send + Sync + 'static>>;
12
13#[derive(Clone)]
14pub struct DropQueue {
15 draining: Arc<AtomicBool>,
16 input: UnboundedSender<ThreadSafeQueueItem>,
17 receiver: Arc<Mutex<UnboundedReceiver<ThreadSafeQueueItem>>>,
18}
19
20impl Default for DropQueue {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26impl DropQueue {
27 pub fn new() -> Self {
28 let (tx, rx) = mpsc::unbounded_channel();
29
30 Self {
31 draining: Arc::new(AtomicBool::new(false)),
32 input: tx,
33 receiver: Arc::new(Mutex::new(rx)),
34 }
35 }
36
37 pub fn assign<F, Fut>(&self, f: F) -> anyhow::Result<()>
38 where
39 F: FnOnce() -> Fut + Send + Sync + 'static,
40 Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
41 {
42 if self.draining.load(std::sync::atomic::Ordering::Relaxed) {
43 panic!("trying to put an item on a draining queue. This is not allowed");
44 }
45
46 self.input
47 .send(Arc::new(Mutex::new(ClosureComponent {
48 inner: Box::new(Some(f)),
49 })))
50 .expect("unbounded channel should never be full");
51
52 Ok(())
53 }
54
55 pub async fn process_next(&self) -> anyhow::Result<()> {
56 let item = {
57 let mut queue = self.receiver.lock().await;
58 queue.recv().await
59 };
60
61 if let Some(item) = item {
62 let mut item = item.try_lock().expect("should always be unlockable");
63 item.execute().await?;
64 }
65
66 Ok(())
67 }
68
69 pub async fn process(&self) -> anyhow::Result<()> {
70 loop {
71 if self.draining.load(std::sync::atomic::Ordering::Relaxed) {
72 return Ok(());
73 }
74
75 self.process_next().await?;
76 }
77 }
78
79 pub async fn try_process_next(&self) -> anyhow::Result<Option<()>> {
80 let item = {
81 let mut queue = self.receiver.lock().await;
82 match queue.try_recv() {
83 Ok(o) => o,
84 Err(_) => return Ok(None),
85 }
86 };
87
88 let mut item = item
89 .try_lock()
90 .expect("we should always be able to unlock item");
91 item.execute().await?;
92
93 Ok(Some(()))
94 }
95
96 pub async fn drain(&self) -> anyhow::Result<()> {
97 self.draining
98 .store(true, std::sync::atomic::Ordering::Release);
99
100 while self.try_process_next().await?.is_some() {}
101
102 Ok(())
103 }
104
105 #[cfg(feature = "notmad")]
106 async fn process_all(&self, cancellation_token: CancellationToken) -> anyhow::Result<()> {
107 loop {
108 tokio::select! {
109 _ = cancellation_token.cancelled() => {
110 break;
111 },
112 res = self.process_next() => {
113 res?;
114 }
115 }
116 }
117
118 self.drain().await?;
119
120 Ok(())
121 }
122}
123
124struct ClosureComponent<F, Fut>
125where
126 F: FnOnce() -> Fut + Send + Sync + 'static,
127 Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
128{
129 inner: Box<Option<F>>,
130}
131
132#[async_trait::async_trait]
133trait QueueItem {
134 async fn execute(&mut self) -> anyhow::Result<()>;
135}
136
137#[async_trait::async_trait]
138impl<F, Fut> QueueItem for ClosureComponent<F, Fut>
139where
140 F: FnOnce() -> Fut + Send + Sync + 'static,
141 Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
142{
143 async fn execute(&mut self) -> Result<(), anyhow::Error> {
144 let item = self.inner.take().expect("to only be called once");
145
146 item().await?;
147
148 Ok(())
149 }
150}
151
#[cfg(feature = "notmad")]
mod notmad {
    use notmad::{ComponentInfo, MadError};
    use tokio_util::sync::CancellationToken;

    use crate::DropQueue;

    /// Runs a `DropQueue` as a `notmad` component.
    ///
    /// `run` processes items until the cancellation token fires, then drains whatever
    /// is still queued. Errors from the queued work surface as `MadError::Inner`.
    impl notmad::Component for DropQueue {
        fn info(&self) -> ComponentInfo {
            "drop-queue/drop-queue".into()
        }

        async fn run(&self, cancellation_token: CancellationToken) -> Result<(), MadError> {
            self.process_all(cancellation_token)
                .await
                .map_err(MadError::Inner)
        }
    }
}
173#[cfg(feature = "notmad")]
174#[allow(unused_imports)]
175pub use notmad::*;
176
#[cfg(test)]
mod test {
    use tokio::sync::oneshot;

    use crate::DropQueue;

    /// Spawns a task that assigns a closure which fires `done` when it runs.
    fn spawn_assign(queue: &DropQueue, done: oneshot::Sender<()>) {
        let queue = queue.clone();
        tokio::spawn(async move {
            queue
                .assign(|| async move {
                    tracing::info!("was called");

                    done.send(()).unwrap();

                    Ok(())
                })
                .unwrap();
        });
    }

    /// A single assigned closure runs when the queue processes its next item.
    #[tokio::test]
    async fn can_drop_item() -> anyhow::Result<()> {
        let queue = DropQueue::new();
        let (done_tx, done_rx) = oneshot::channel();

        queue.assign(|| async move {
            tracing::info!("was called");

            done_tx.send(()).unwrap();

            Ok(())
        })?;

        queue.process_next().await?;
        done_rx.await?;

        Ok(())
    }

    /// Items assigned from other tasks are each run by successive `process_next` calls.
    #[tokio::test]
    async fn can_drop_multiple_items() -> anyhow::Result<()> {
        let queue = DropQueue::new();

        let (first_tx, first_rx) = oneshot::channel();
        spawn_assign(&queue, first_tx);

        let (second_tx, second_rx) = oneshot::channel();
        spawn_assign(&queue, second_tx);

        queue.process_next().await?;
        queue.process_next().await?;

        first_rx.await?;
        second_rx.await?;

        Ok(())
    }
}