drop_queue/
lib.rs

1use std::sync::{Arc, atomic::AtomicBool};
2
3use tokio::sync::{
4    Mutex,
5    mpsc::{self, UnboundedReceiver, UnboundedSender},
6};
7
8#[cfg(feature = "notmad")]
9use tokio_util::sync::CancellationToken;
10
11type ThreadSafeQueueItem = Arc<Mutex<dyn QueueItem + Send + Sync + 'static>>;
12
13#[derive(Clone)]
14pub struct DropQueue {
15    draining: Arc<AtomicBool>,
16    input: UnboundedSender<ThreadSafeQueueItem>,
17    receiver: Arc<Mutex<UnboundedReceiver<ThreadSafeQueueItem>>>,
18}
19
20impl Default for DropQueue {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl DropQueue {
27    pub fn new() -> Self {
28        let (tx, rx) = mpsc::unbounded_channel();
29
30        Self {
31            draining: Arc::new(AtomicBool::new(false)),
32            input: tx,
33            receiver: Arc::new(Mutex::new(rx)),
34        }
35    }
36
37    pub fn assign<F, Fut>(&self, f: F) -> anyhow::Result<()>
38    where
39        F: FnOnce() -> Fut + Send + Sync + 'static,
40        Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
41    {
42        if self.draining.load(std::sync::atomic::Ordering::Relaxed) {
43            panic!("trying to put an item on a draining queue. This is not allowed");
44        }
45
46        self.input
47            .send(Arc::new(Mutex::new(ClosureComponent {
48                inner: Box::new(Some(f)),
49            })))
50            .expect("unbounded channel should never be full");
51
52        Ok(())
53    }
54
55    pub async fn process_next(&self) -> anyhow::Result<()> {
56        let item = {
57            let mut queue = self.receiver.lock().await;
58            queue.recv().await
59        };
60
61        if let Some(item) = item {
62            let mut item = item.try_lock().expect("should always be unlockable");
63            item.execute().await?;
64        }
65
66        Ok(())
67    }
68
69    pub async fn process(&self) -> anyhow::Result<()> {
70        loop {
71            if self.draining.load(std::sync::atomic::Ordering::Relaxed) {
72                return Ok(());
73            }
74
75            self.process_next().await?;
76        }
77    }
78
79    pub async fn try_process_next(&self) -> anyhow::Result<Option<()>> {
80        let item = {
81            let mut queue = self.receiver.lock().await;
82            match queue.try_recv() {
83                Ok(o) => o,
84                Err(_) => return Ok(None),
85            }
86        };
87
88        let mut item = item
89            .try_lock()
90            .expect("we should always be able to unlock item");
91        item.execute().await?;
92
93        Ok(Some(()))
94    }
95
96    pub async fn drain(&self) -> anyhow::Result<()> {
97        self.draining
98            .store(true, std::sync::atomic::Ordering::Release);
99
100        while self.try_process_next().await?.is_some() {}
101
102        Ok(())
103    }
104
105    #[cfg(feature = "notmad")]
106    async fn process_all(&self, cancellation_token: CancellationToken) -> anyhow::Result<()> {
107        loop {
108            tokio::select! {
109                _ = cancellation_token.cancelled() => {
110                    break;
111                },
112                res = self.process_next() => {
113                    res?;
114                }
115            }
116        }
117
118        self.drain().await?;
119
120        Ok(())
121    }
122}
123
124struct ClosureComponent<F, Fut>
125where
126    F: FnOnce() -> Fut + Send + Sync + 'static,
127    Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
128{
129    inner: Box<Option<F>>,
130}
131
132#[async_trait::async_trait]
133trait QueueItem {
134    async fn execute(&mut self) -> anyhow::Result<()>;
135}
136
137#[async_trait::async_trait]
138impl<F, Fut> QueueItem for ClosureComponent<F, Fut>
139where
140    F: FnOnce() -> Fut + Send + Sync + 'static,
141    Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
142{
143    async fn execute(&mut self) -> Result<(), anyhow::Error> {
144        let item = self.inner.take().expect("to only be called once");
145
146        item().await?;
147
148        Ok(())
149    }
150}
151
#[cfg(feature = "notmad")]
mod notmad {
    use async_trait::async_trait;
    use notmad::MadError;
    use tokio_util::sync::CancellationToken;

    use crate::DropQueue;

    /// Lets a [`DropQueue`] run as a `notmad` component: `run` processes jobs until the
    /// cancellation token fires and then drains whatever is still queued.
    #[async_trait]
    impl notmad::Component for DropQueue {
        fn name(&self) -> Option<String> {
            Some(String::from("drop-queue/drop-queue"))
        }

        async fn run(&self, cancellation_token: CancellationToken) -> Result<(), MadError> {
            // Job failures are surfaced to notmad wrapped in `MadError::Inner`.
            match self.process_all(cancellation_token).await {
                Ok(()) => Ok(()),
                Err(err) => Err(MadError::Inner(err)),
            }
        }
    }
}
175#[cfg(feature = "notmad")]
176#[allow(unused_imports)]
177pub use notmad::*;
178
#[cfg(test)]
mod test {
    use tokio::sync::oneshot;

    use crate::DropQueue;

    /// A single assigned job runs when `process_next` is called.
    #[tokio::test]
    async fn can_drop_item() -> anyhow::Result<()> {
        let queue = DropQueue::new();
        let (done_tx, done_rx) = oneshot::channel();

        queue.assign(move || async move {
            tracing::info!("was called");
            done_tx.send(()).unwrap();
            Ok(())
        })?;

        queue.process_next().await?;
        done_rx.await?;

        Ok(())
    }

    /// Jobs assigned from other tasks through cloned handles all get processed.
    #[tokio::test]
    async fn can_drop_multiple_items() -> anyhow::Result<()> {
        let queue = DropQueue::new();

        let (first_tx, first_rx) = oneshot::channel();
        let (second_tx, second_rx) = oneshot::channel();

        for done_tx in [first_tx, second_tx] {
            let producer = queue.clone();
            tokio::spawn(async move {
                producer
                    .assign(move || async move {
                        tracing::info!("was called");
                        done_tx.send(()).unwrap();
                        Ok(())
                    })
                    .unwrap();
            });
        }

        for _ in 0..2 {
            queue.process_next().await?;
        }

        first_rx.await?;
        second_rx.await?;

        Ok(())
    }
}