// gnostr_asyncgit/asyncjob/mod.rs

1//! provides `AsyncJob` trait and `AsyncSingleJob` struct
2
3#![deny(clippy::expect_used)]
4
5use std::sync::{Arc, Mutex, RwLock};
6
7use crossbeam_channel::Sender;
8
9use crate::error::Result;
10
11/// Passed to `AsyncJob::run` allowing sending intermediate progress
12/// notifications
13pub struct RunParams<T: Copy + Send, P: Clone + Send + Sync + PartialEq> {
14    sender: Sender<T>,
15    progress: Arc<RwLock<P>>,
16}
17
18impl<T: Copy + Send, P: Clone + Send + Sync + PartialEq> RunParams<T, P> {
19    /// send an intermediate update notification.
20    /// do not confuse this with the return value of `run`.
21    /// `send` should only be used about progress notifications
22    /// and not for the final notification indicating the end of the
23    /// async job. see `run` for more info
24    pub fn send(&self, notification: T) -> Result<()> {
25        self.sender.send(notification)?;
26        Ok(())
27    }
28
29    /// set the current progress
30    pub fn set_progress(&self, p: P) -> Result<bool> {
31        Ok(if *self.progress.read()? == p {
32            false
33        } else {
34            *(self.progress.write()?) = p;
35            true
36        })
37    }
38}
39
40/// trait that defines an async task we can run on a threadpool
41pub trait AsyncJob: Send + Sync + Clone {
42    /// defines what notification type is used to communicate outside
43    type Notification: Copy + Send;
44    /// type of progress
45    type Progress: Clone + Default + Send + Sync + PartialEq;
46
47    /// can run a synchronous time intensive task.
48    /// the returned notification is used to tell interested parties
49    /// that the job finished and the job can be access via
50    /// `take_last`. prior to this final notification it is not safe
51    /// to assume `take_last` will already return the correct job
52    fn run(
53        &mut self,
54        params: RunParams<Self::Notification, Self::Progress>,
55    ) -> Result<Self::Notification>;
56
57    /// allows observers to get intermediate progress status if the
58    /// job customizes it by default this will be returning
59    /// `Self::Progress::default()`
60    fn get_progress(&self) -> Self::Progress {
61        Self::Progress::default()
62    }
63}
64
65/// Abstraction for a FIFO task queue that will only queue up **one**
66/// `next` job. It keeps overwriting the next job until it is actually
67/// taken to be processed
68#[derive(Debug, Clone)]
69pub struct AsyncSingleJob<J: AsyncJob> {
70    next: Arc<Mutex<Option<J>>>,
71    last: Arc<Mutex<Option<J>>>,
72    progress: Arc<RwLock<J::Progress>>,
73    sender: Sender<J::Notification>,
74    pending: Arc<Mutex<()>>,
75}
76
77impl<J: 'static + AsyncJob> AsyncSingleJob<J> {
78    ///
79    pub fn new(sender: Sender<J::Notification>) -> Self {
80        Self {
81            next: Arc::new(Mutex::new(None)),
82            last: Arc::new(Mutex::new(None)),
83            pending: Arc::new(Mutex::new(())),
84            progress: Arc::new(RwLock::new(J::Progress::default())),
85            sender,
86        }
87    }
88
89    ///
90    pub fn is_pending(&self) -> bool {
91        self.pending.try_lock().is_err()
92    }
93
94    /// makes sure `next` is cleared and returns `true` if it actually
95    /// canceled something
96    pub fn cancel(&mut self) -> bool {
97        if let Ok(mut next) = self.next.lock() {
98            if next.is_some() {
99                *next = None;
100                return true;
101            }
102        }
103
104        false
105    }
106
107    /// take out last finished job
108    pub fn take_last(&self) -> Option<J> {
109        self.last.lock().map_or(None, |mut last| last.take())
110    }
111
112    /// spawns `task` if nothing is running currently,
113    /// otherwise schedules as `next` overwriting if `next` was set
114    /// before. return `true` if the new task gets started right
115    /// away.
116    pub fn spawn(&mut self, task: J) -> bool {
117        self.schedule_next(task);
118        self.check_for_job()
119    }
120
121    ///
122    pub fn progress(&self) -> Option<J::Progress> {
123        self.progress.read().ok().map(|d| (*d).clone())
124    }
125
126    fn check_for_job(&self) -> bool {
127        if self.is_pending() {
128            return false;
129        }
130
131        if let Some(task) = self.take_next() {
132            let self_clone = (*self).clone();
133            rayon_core::spawn(move || {
134                if let Err(e) = self_clone.run_job(task) {
135                    log::error!("async job error: {}", e);
136                }
137            });
138
139            return true;
140        }
141
142        false
143    }
144
145    fn run_job(&self, mut task: J) -> Result<()> {
146        //limit the pending scope
147        {
148            let _pending = self.pending.lock()?;
149
150            let notification = task.run(RunParams {
151                progress: self.progress.clone(),
152                sender: self.sender.clone(),
153            })?;
154
155            if let Ok(mut last) = self.last.lock() {
156                *last = Some(task);
157            }
158
159            self.sender.send(notification)?;
160        }
161
162        self.check_for_job();
163
164        Ok(())
165    }
166
167    fn schedule_next(&mut self, task: J) {
168        if let Ok(mut next) = self.next.lock() {
169            *next = Some(task);
170        }
171    }
172
173    fn take_next(&self) -> Option<J> {
174        self.next.lock().map_or(None, |mut next| next.take())
175    }
176}
177
#[cfg(test)]
mod test {
    use std::{
        sync::atomic::{AtomicBool, AtomicU32, Ordering},
        thread,
        time::Duration,
    };

    use crossbeam_channel::unbounded;
    use pretty_assertions::assert_eq;

    use super::*;

    /// Job that spins until `finish` is set, sleeps 100 ms and then adds
    /// `value_to_add` to `v`. Clones share `v` and `finish`.
    #[derive(Clone)]
    struct TestJob {
        v: Arc<AtomicU32>,
        finish: Arc<AtomicBool>,
        value_to_add: u32,
    }

    type TestNotification = ();

    impl AsyncJob for TestJob {
        type Notification = TestNotification;
        type Progress = ();

        fn run(
            &mut self,
            _params: RunParams<Self::Notification, Self::Progress>,
        ) -> Result<Self::Notification> {
            println!("[job] wait");

            while !self.finish.load(Ordering::SeqCst) {
                std::thread::yield_now();
            }

            println!("[job] sleep");

            thread::sleep(Duration::from_millis(100));

            println!("[job] done sleeping");

            let res = self.v.fetch_add(self.value_to_add, Ordering::SeqCst);

            println!("[job] value: {res}");

            Ok(())
        }
    }

    /// Blocks until a worker has picked up a job and holds `pending`.
    ///
    /// Only call this while the job is still blocked on `finish`. Otherwise
    /// the job could complete before it is observed and this would spin
    /// forever.
    fn wait_for_start(job: &AsyncSingleJob<TestJob>) {
        while !job.is_pending() {
            thread::yield_now();
        }
    }

    /// Blocks until no job is running anymore.
    fn wait_for_job(job: &AsyncSingleJob<TestJob>) {
        while job.is_pending() {
            thread::sleep(Duration::from_millis(10));
        }
    }

    #[test]
    fn test_overwrite() {
        let (sender, receiver) = unbounded();

        let mut job: AsyncSingleJob<TestJob> = AsyncSingleJob::new(sender);

        let task = TestJob {
            v: Arc::new(AtomicU32::new(1)),
            finish: Arc::new(AtomicBool::new(false)),
            value_to_add: 1,
        };

        assert!(job.spawn(task.clone()));
        // The first job spins until `finish` is set. Once it holds
        // `pending`, it is guaranteed to keep running while we queue more
        // jobs; no timing-dependent sleep is needed.
        wait_for_start(&job);

        for _ in 0..5 {
            println!("spawn");
            assert!(!job.spawn(task.clone()));
        }

        task.finish.store(true, Ordering::SeqCst);

        println!("recv");
        receiver.recv().unwrap();
        receiver.recv().unwrap();
        assert!(receiver.is_empty());

        assert_eq!(task.v.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_cancel() {
        let (sender, receiver) = unbounded();

        let mut job: AsyncSingleJob<TestJob> = AsyncSingleJob::new(sender);

        let task = TestJob {
            v: Arc::new(AtomicU32::new(1)),
            finish: Arc::new(AtomicBool::new(false)),
            value_to_add: 1,
        };

        assert!(job.spawn(task.clone()));
        // See `test_overwrite`: the job cannot finish before `finish` is set.
        wait_for_start(&job);

        for _ in 0..5 {
            println!("spawn");
            assert!(!job.spawn(task.clone()));
        }

        println!("cancel");
        assert!(job.cancel());

        task.finish.store(true, Ordering::SeqCst);

        wait_for_job(&job);

        println!("recv");
        receiver.recv().unwrap();
        println!("received");

        assert_eq!(task.v.load(Ordering::SeqCst), 2);
    }
}