Skip to main content

drop_queue/
lib.rs

1use std::sync::{Arc, atomic::AtomicBool};
2
3use tokio::sync::{
4    Mutex,
5    mpsc::{self, UnboundedReceiver, UnboundedSender},
6};
7
8#[cfg(feature = "notmad")]
9use tokio_util::sync::CancellationToken;
10
11type ThreadSafeQueueItem = Arc<Mutex<dyn QueueItem + Send + Sync + 'static>>;
12
13#[derive(Clone)]
14pub struct DropQueue {
15    draining: Arc<AtomicBool>,
16    input: UnboundedSender<ThreadSafeQueueItem>,
17    receiver: Arc<Mutex<UnboundedReceiver<ThreadSafeQueueItem>>>,
18}
19
20impl Default for DropQueue {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl DropQueue {
27    pub fn new() -> Self {
28        let (tx, rx) = mpsc::unbounded_channel();
29
30        Self {
31            draining: Arc::new(AtomicBool::new(false)),
32            input: tx,
33            receiver: Arc::new(Mutex::new(rx)),
34        }
35    }
36
37    pub fn assign<F, Fut>(&self, f: F) -> anyhow::Result<()>
38    where
39        F: FnOnce() -> Fut + Send + Sync + 'static,
40        Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
41    {
42        if self.draining.load(std::sync::atomic::Ordering::Relaxed) {
43            panic!("trying to put an item on a draining queue. This is not allowed");
44        }
45
46        self.input
47            .send(Arc::new(Mutex::new(ClosureComponent {
48                inner: Box::new(Some(f)),
49            })))
50            .expect("unbounded channel should never be full");
51
52        Ok(())
53    }
54
55    pub async fn process_next(&self) -> anyhow::Result<()> {
56        let item = {
57            let mut queue = self.receiver.lock().await;
58            queue.recv().await
59        };
60
61        if let Some(item) = item {
62            let mut item = item.try_lock().expect("should always be unlockable");
63            item.execute().await?;
64        }
65
66        Ok(())
67    }
68
69    pub async fn process(&self) -> anyhow::Result<()> {
70        loop {
71            if self.draining.load(std::sync::atomic::Ordering::Relaxed) {
72                return Ok(());
73            }
74
75            self.process_next().await?;
76        }
77    }
78
79    pub async fn try_process_next(&self) -> anyhow::Result<Option<()>> {
80        let item = {
81            let mut queue = self.receiver.lock().await;
82            match queue.try_recv() {
83                Ok(o) => o,
84                Err(_) => return Ok(None),
85            }
86        };
87
88        let mut item = item
89            .try_lock()
90            .expect("we should always be able to unlock item");
91        item.execute().await?;
92
93        Ok(Some(()))
94    }
95
96    pub async fn drain(&self) -> anyhow::Result<()> {
97        self.draining
98            .store(true, std::sync::atomic::Ordering::Release);
99
100        while self.try_process_next().await?.is_some() {}
101
102        Ok(())
103    }
104
105    #[cfg(feature = "notmad")]
106    async fn process_all(&self, cancellation_token: CancellationToken) -> anyhow::Result<()> {
107        loop {
108            tokio::select! {
109                _ = cancellation_token.cancelled() => {
110                    break;
111                },
112                res = self.process_next() => {
113                    res?;
114                }
115            }
116        }
117
118        self.drain().await?;
119
120        Ok(())
121    }
122}
123
124struct ClosureComponent<F, Fut>
125where
126    F: FnOnce() -> Fut + Send + Sync + 'static,
127    Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
128{
129    inner: Box<Option<F>>,
130}
131
132#[async_trait::async_trait]
133trait QueueItem {
134    async fn execute(&mut self) -> anyhow::Result<()>;
135}
136
137#[async_trait::async_trait]
138impl<F, Fut> QueueItem for ClosureComponent<F, Fut>
139where
140    F: FnOnce() -> Fut + Send + Sync + 'static,
141    Fut: Future<Output = Result<(), anyhow::Error>> + Send + 'static,
142{
143    async fn execute(&mut self) -> Result<(), anyhow::Error> {
144        let item = self.inner.take().expect("to only be called once");
145
146        item().await?;
147
148        Ok(())
149    }
150}
151
#[cfg(feature = "notmad")]
mod notmad {
    //! Integration with `notmad`: lets a `DropQueue` run as a managed component.

    use notmad::{ComponentInfo, MadError};
    use tokio_util::sync::CancellationToken;

    use super::DropQueue;

    impl notmad::Component for DropQueue {
        fn info(&self) -> ComponentInfo {
            "drop-queue/drop-queue".into()
        }

        /// Delegates to `DropQueue::process_all`, wrapping any error in
        /// `MadError::Inner`.
        async fn run(&self, cancellation_token: CancellationToken) -> Result<(), MadError> {
            let outcome = self.process_all(cancellation_token).await;
            outcome.map_err(MadError::Inner)
        }
    }
}
173#[cfg(feature = "notmad")]
174#[allow(unused_imports)]
175pub use notmad::*;
176
#[cfg(test)]
mod test {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use tokio::sync::oneshot;

    use crate::DropQueue;

    #[tokio::test]
    async fn can_drop_item() -> anyhow::Result<()> {
        let drop_queue = DropQueue::new();

        let (called_tx, called_rx) = oneshot::channel();

        drop_queue.assign(|| async move {
            tracing::info!("was called");

            called_tx.send(()).unwrap();

            Ok(())
        })?;

        drop_queue.process_next().await?;

        called_rx.await?;

        Ok(())
    }

    #[tokio::test]
    async fn can_drop_multiple_items() -> anyhow::Result<()> {
        let drop_queue = DropQueue::new();

        let (called_tx, called_rx) = oneshot::channel();
        let _drop_queue = drop_queue.clone();
        tokio::spawn(async move {
            _drop_queue
                .assign(|| async move {
                    tracing::info!("was called");

                    called_tx.send(()).unwrap();

                    Ok(())
                })
                .unwrap();
        });

        let (called_tx2, called_rx2) = oneshot::channel();
        let _drop_queue = drop_queue.clone();
        tokio::spawn(async move {
            _drop_queue
                .assign(|| async move {
                    tracing::info!("was called");

                    called_tx2.send(()).unwrap();

                    Ok(())
                })
                .unwrap();
        });

        drop_queue.process_next().await?;
        drop_queue.process_next().await?;

        called_rx.await?;
        called_rx2.await?;

        Ok(())
    }

    /// `try_process_next` must not wait when nothing is queued.
    #[tokio::test]
    async fn try_process_next_returns_none_on_empty_queue() -> anyhow::Result<()> {
        let drop_queue = DropQueue::new();

        assert!(drop_queue.try_process_next().await?.is_none());

        Ok(())
    }

    /// `drain` runs every item that was queued before it started and leaves
    /// the queue empty.
    #[tokio::test]
    async fn drain_runs_every_pending_item() -> anyhow::Result<()> {
        let drop_queue = DropQueue::new();
        let runs = Arc::new(AtomicUsize::new(0));

        for _ in 0..3 {
            let runs = Arc::clone(&runs);
            drop_queue.assign(move || async move {
                runs.fetch_add(1, Ordering::SeqCst);

                Ok(())
            })?;
        }

        drop_queue.drain().await?;

        assert_eq!(runs.load(Ordering::SeqCst), 3);
        assert!(drop_queue.try_process_next().await?.is_none());

        Ok(())
    }
}