gnostr_asyncgit/asyncjob/
mod.rs

1//! provides `AsyncJob` trait and `AsyncSingleJob` struct
2
3#![deny(clippy::expect_used)]
4
5use std::sync::{Arc, Mutex, RwLock};
6
7use crossbeam_channel::Sender;
8
9use crate::error::Result;
10
11/// Passed to `AsyncJob::run` allowing sending intermediate progress
12/// notifications
13pub struct RunParams<T: Copy + Send, P: Clone + Send + Sync + PartialEq> {
14    sender: Sender<T>,
15    progress: Arc<RwLock<P>>,
16}
17
18impl<T: Copy + Send, P: Clone + Send + Sync + PartialEq> RunParams<T, P> {
19    /// send an intermediate update notification.
20    /// do not confuse this with the return value of `run`.
21    /// `send` should only be used about progress notifications
22    /// and not for the final notification indicating the end of the
23    /// async job. see `run` for more info
24    pub fn send(&self, notification: T) -> Result<()> {
25        self.sender.send(notification)?;
26        Ok(())
27    }
28
29    /// set the current progress
30    pub fn set_progress(&self, p: P) -> Result<bool> {
31        Ok(if *self.progress.read()? == p {
32            false
33        } else {
34            *(self.progress.write()?) = p;
35            true
36        })
37    }
38}
39
40/// trait that defines an async task we can run on a threadpool
41pub trait AsyncJob: Send + Sync + Clone {
42    /// defines what notification type is used to communicate outside
43    type Notification: Copy + Send;
44    /// type of progress
45    type Progress: Clone + Default + Send + Sync + PartialEq;
46
47    /// can run a synchronous time intensive task.
48    /// the returned notification is used to tell interested parties
49    /// that the job finished and the job can be access via
50    /// `take_last`. prior to this final notification it is not safe
51    /// to assume `take_last` will already return the correct job
52    fn run(
53        &mut self,
54        params: RunParams<Self::Notification, Self::Progress>,
55    ) -> Result<Self::Notification>;
56
57    /// allows observers to get intermediate progress status if the
58    /// job customizes it by default this will be returning
59    /// `Self::Progress::default()`
60    fn get_progress(&self) -> Self::Progress {
61        Self::Progress::default()
62    }
63}
64
65/// Abstraction for a FIFO task queue that will only queue up **one**
66/// `next` job. It keeps overwriting the next job until it is actually
67/// taken to be processed
68#[derive(Debug, Clone)]
69pub struct AsyncSingleJob<J: AsyncJob> {
70    next: Arc<Mutex<Option<J>>>,
71    last: Arc<Mutex<Option<J>>>,
72    progress: Arc<RwLock<J::Progress>>,
73    sender: Sender<J::Notification>,
74    pending: Arc<Mutex<()>>,
75}
76
77impl<J: 'static + AsyncJob> AsyncSingleJob<J> {
78    ///
79    pub fn new(sender: Sender<J::Notification>) -> Self {
80        Self {
81            next: Arc::new(Mutex::new(None)),
82            last: Arc::new(Mutex::new(None)),
83            pending: Arc::new(Mutex::new(())),
84            progress: Arc::new(RwLock::new(J::Progress::default())),
85            sender,
86        }
87    }
88
89    ///
90    pub fn is_pending(&self) -> bool {
91        self.pending.try_lock().is_err()
92    }
93
94    /// makes sure `next` is cleared and returns `true` if it actually
95    /// canceled something
96    pub fn cancel(&mut self) -> bool {
97        if let Ok(mut next) = self.next.lock() {
98            if next.is_some() {
99                *next = None;
100                return true;
101            }
102        }
103
104        false
105    }
106
107    /// take out last finished job
108    pub fn take_last(&self) -> Option<J> {
109        self.last.lock().map_or(None, |mut last| last.take())
110    }
111
112    /// spawns `task` if nothing is running currently,
113    /// otherwise schedules as `next` overwriting if `next` was set
114    /// before. return `true` if the new task gets started right
115    /// away.
116    pub fn spawn(&mut self, task: J) -> bool {
117        self.schedule_next(task);
118        self.check_for_job()
119    }
120
121    ///
122    pub fn progress(&self) -> Option<J::Progress> {
123        self.progress.read().ok().map(|d| (*d).clone())
124    }
125
126    fn check_for_job(&self) -> bool {
127        if self.is_pending() {
128            return false;
129        }
130
131        if let Some(task) = self.take_next() {
132            let self_clone = (*self).clone();
133            rayon_core::spawn(move || {
134                if let Err(e) = self_clone.run_job(task) {
135                    log::error!("async job error: {}", e);
136                }
137            });
138
139            return true;
140        }
141
142        false
143    }
144
145    fn run_job(&self, mut task: J) -> Result<()> {
146        //limit the pending scope
147        {
148            let _pending = self.pending.lock()?;
149
150            let notification = task.run(RunParams {
151                progress: self.progress.clone(),
152                sender: self.sender.clone(),
153            })?;
154
155            if let Ok(mut last) = self.last.lock() {
156                *last = Some(task);
157            }
158
159            self.sender.send(notification)?;
160        }
161
162        self.check_for_job();
163
164        Ok(())
165    }
166
167    fn schedule_next(&mut self, task: J) {
168        if let Ok(mut next) = self.next.lock() {
169            *next = Some(task);
170        }
171    }
172
173    fn take_next(&self) -> Option<J> {
174        self.next.lock().map_or(None, |mut next| next.take())
175    }
176}
177
#[cfg(test)]
mod test {
    use std::{
        sync::atomic::{AtomicBool, AtomicU32, Ordering},
        thread,
        time::Duration,
    };

    use crossbeam_channel::unbounded;
    use pretty_assertions::assert_eq;

    use super::*;

    /// Job that blocks until `finish` is set, then sends one intermediate
    /// notification, adds `value_to_add` to the shared counter `v` and
    /// returns `notification_value` as its final notification.
    #[derive(Clone)]
    struct TestJob {
        v: Arc<AtomicU32>,
        finish: Arc<AtomicBool>,
        value_to_add: u32,
        notification_value: u32,
    }

    impl AsyncJob for TestJob {
        type Notification = u32;
        type Progress = ();

        fn run(
            &mut self,
            params: RunParams<Self::Notification, Self::Progress>,
        ) -> Result<Self::Notification> {
            println!("[job] wait");

            while !self.finish.load(Ordering::SeqCst) {
                std::thread::yield_now();
            }

            println!("[job] sleep");

            thread::sleep(Duration::from_millis(100));

            println!("[job] done sleeping");

            // intermediate notification, sent before the value is added
            params.send(self.notification_value)?;

            let res = self.v.fetch_add(self.value_to_add, Ordering::SeqCst);

            println!("[job] value: {res}");

            // final notification
            Ok(self.notification_value)
        }
    }

    /// Blocks until the spawned job holds the `pending` lock. Must be
    /// called while the job's `finish` flag is still `false`, so the job
    /// cannot complete before we observe it.
    fn wait_until_started(job: &AsyncSingleJob<TestJob>) {
        while !job.is_pending() {
            thread::sleep(Duration::from_millis(1));
        }
    }

    /// Blocks until no job is pending anymore.
    fn wait_for_job(job: &AsyncSingleJob<TestJob>) {
        while job.is_pending() {
            thread::sleep(Duration::from_millis(10));
        }
    }

    #[test]
    fn test_overwrite() {
        let (sender, receiver) = unbounded();

        let mut job: AsyncSingleJob<TestJob> = AsyncSingleJob::new(sender);

        let shared_v = Arc::new(AtomicU32::new(1));
        let shared_finish = Arc::new(AtomicBool::new(false));

        let task_template = TestJob {
            v: Arc::clone(&shared_v),
            finish: Arc::clone(&shared_finish),
            value_to_add: 1,
            notification_value: 10,
        };

        // First job starts right away
        assert!(job.spawn(task_template.clone()));

        // Without this the follow-up spawns could race the worker thread
        // and find nothing pending yet
        wait_until_started(&job);

        // Subsequent jobs overwrite each other (only the last one will run)
        for _ in 0..5 {
            println!("spawn");
            assert!(!job.spawn(task_template.clone()));
        }

        shared_finish.store(true, Ordering::SeqCst);

        println!("recv");
        // Two jobs run in total (the first one and the last queued one),
        // each sending one intermediate and one final notification
        for _ in 0..4 {
            assert_eq!(receiver.recv().unwrap(), 10);
        }
        assert!(receiver.is_empty());

        // Initial value 1, incremented once by each of the two jobs
        assert_eq!(shared_v.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_cancel() {
        let (sender, receiver) = unbounded();

        let mut job: AsyncSingleJob<TestJob> = AsyncSingleJob::new(sender);

        let initial_task = TestJob {
            v: Arc::new(AtomicU32::new(1)),
            finish: Arc::new(AtomicBool::new(false)),
            value_to_add: 1,
            notification_value: 20,
        };

        // Spawn the first job and make sure it is actually running
        assert!(job.spawn(initial_task.clone()));
        wait_until_started(&job);

        // Schedule subsequent jobs (these will be cancelled)
        for _ in 0..5 {
            println!("spawn");
            assert!(!job.spawn(initial_task.clone()));
        }

        println!("cancel");
        assert!(job.cancel());

        // Only now let the first job finish, so it cannot pick up a
        // queued job before the cancel happened
        initial_task.finish.store(true, Ordering::SeqCst);

        wait_for_job(&job);

        println!("recv");
        // Intermediate and final notification of the first job
        assert_eq!(receiver.recv().unwrap(), 20);
        assert_eq!(receiver.recv().unwrap(), 20);
        println!("received");

        // The cancelled jobs never ran, so nothing else was sent
        assert!(receiver.is_empty());

        // `unwrap` instead of `expect`: this file denies
        // `clippy::expect_used`
        let completed_job = job.take_last().unwrap();

        // Initial value was 1, incremented once by the only job that ran
        assert_eq!(completed_job.v.load(Ordering::SeqCst), 2);
    }
}