// gnostr_asyncgit/asyncjob/mod.rs

//! Provides the `AsyncJob` trait and the `AsyncSingleJob` struct.
2
3#![deny(clippy::expect_used)]
4
5use std::sync::{Arc, Mutex, RwLock};
6
7use crossbeam_channel::Sender;
8
9use crate::error::Result;
10
11/// Passed to `AsyncJob::run` allowing sending intermediate progress
12/// notifications
13pub struct RunParams<T: Copy + Send, P: Clone + Send + Sync + PartialEq> {
14    sender: Sender<T>,
15    progress: Arc<RwLock<P>>,
16}
17
18impl<T: Copy + Send, P: Clone + Send + Sync + PartialEq> RunParams<T, P> {
19    /// send an intermediate update notification.
20    /// do not confuse this with the return value of `run`.
21    /// `send` should only be used about progress notifications
22    /// and not for the final notification indicating the end of the
23    /// async job. see `run` for more info
24    pub fn send(&self, notification: T) -> Result<()> {
25        self.sender.send(notification)?;
26        Ok(())
27    }
28
29    /// set the current progress
30    pub fn set_progress(&self, p: P) -> Result<bool> {
31        Ok(if *self.progress.read()? == p {
32            false
33        } else {
34            *(self.progress.write()?) = p;
35            true
36        })
37    }
38}
39
40/// trait that defines an async task we can run on a threadpool
41pub trait AsyncJob: Send + Sync + Clone {
42    /// defines what notification type is used to communicate outside
43    type Notification: Copy + Send;
44    /// type of progress
45    type Progress: Clone + Default + Send + Sync + PartialEq;
46
47    /// can run a synchronous time intensive task.
48    /// the returned notification is used to tell interested parties
49    /// that the job finished and the job can be access via
50    /// `take_last`. prior to this final notification it is not safe
51    /// to assume `take_last` will already return the correct job
52    fn run(
53        &mut self,
54        params: RunParams<Self::Notification, Self::Progress>,
55    ) -> Result<Self::Notification>;
56
57    /// allows observers to get intermediate progress status if the
58    /// job customizes it by default this will be returning
59    /// `Self::Progress::default()`
60    fn get_progress(&self) -> Self::Progress {
61        Self::Progress::default()
62    }
63}
64
65/// Abstraction for a FIFO task queue that will only queue up **one**
66/// `next` job. It keeps overwriting the next job until it is actually
67/// taken to be processed
68#[derive(Debug, Clone)]
69pub struct AsyncSingleJob<J: AsyncJob> {
70    next: Arc<Mutex<Option<J>>>,
71    last: Arc<Mutex<Option<J>>>,
72    progress: Arc<RwLock<J::Progress>>,
73    sender: Sender<J::Notification>,
74    pending: Arc<Mutex<()>>,
75}
76
77impl<J: 'static + AsyncJob> AsyncSingleJob<J> {
78    ///
79    pub fn new(sender: Sender<J::Notification>) -> Self {
80        Self {
81            next: Arc::new(Mutex::new(None)),
82            last: Arc::new(Mutex::new(None)),
83            pending: Arc::new(Mutex::new(())),
84            progress: Arc::new(RwLock::new(J::Progress::default())),
85            sender,
86        }
87    }
88
89    ///
90    pub fn is_pending(&self) -> bool {
91        self.pending.try_lock().is_err()
92    }
93
94    /// makes sure `next` is cleared and returns `true` if it actually
95    /// canceled something
96    pub fn cancel(&mut self) -> bool {
97        if let Ok(mut next) = self.next.lock() {
98            if next.is_some() {
99                *next = None;
100                return true;
101            }
102        }
103
104        false
105    }
106
107    /// take out last finished job
108    pub fn take_last(&self) -> Option<J> {
109        self.last.lock().map_or(None, |mut last| last.take())
110    }
111
112    /// spawns `task` if nothing is running currently,
113    /// otherwise schedules as `next` overwriting if `next` was set
114    /// before. return `true` if the new task gets started right
115    /// away.
116    pub fn spawn(&mut self, task: J) -> bool {
117        self.schedule_next(task);
118        self.check_for_job()
119    }
120
121    ///
122    pub fn progress(&self) -> Option<J::Progress> {
123        self.progress.read().ok().map(|d| (*d).clone())
124    }
125
126    fn check_for_job(&self) -> bool {
127        if self.is_pending() {
128            return false;
129        }
130
131        if let Some(task) = self.take_next() {
132            let self_clone = (*self).clone();
133            rayon_core::spawn(move || {
134                if let Err(e) = self_clone.run_job(task) {
135                    log::error!("async job error: {}", e);
136                }
137            });
138
139            return true;
140        }
141
142        false
143    }
144
145    fn run_job(&self, mut task: J) -> Result<()> {
146        //limit the pending scope
147        {
148            let _pending = self.pending.lock()?;
149
150            let notification = task.run(RunParams {
151                progress: self.progress.clone(),
152                sender: self.sender.clone(),
153            })?;
154
155            if let Ok(mut last) = self.last.lock() {
156                *last = Some(task);
157            }
158
159            self.sender.send(notification)?;
160        }
161
162        self.check_for_job();
163
164        Ok(())
165    }
166
167    fn schedule_next(&mut self, task: J) {
168        if let Ok(mut next) = self.next.lock() {
169            *next = Some(task);
170        }
171    }
172
173    fn take_next(&self) -> Option<J> {
174        self.next.lock().map_or(None, |mut next| next.take())
175    }
176}
177
#[cfg(test)]
mod test {
    use std::{
        sync::atomic::{AtomicBool, AtomicU32, Ordering},
        thread,
        time::Duration,
    };

    use crossbeam_channel::unbounded;
    use pretty_assertions::assert_eq;

    use super::*;

    /// Job that blocks until `finish` is set, then sends
    /// `notification_value` once as an intermediate notification, adds
    /// `value_to_add` to `v` and returns `notification_value` as its final
    /// notification.
    #[derive(Clone)]
    struct TestJob {
        v: Arc<AtomicU32>,
        finish: Arc<AtomicBool>,
        value_to_add: u32,
        notification_value: u32,
    }

    impl AsyncJob for TestJob {
        type Notification = u32;
        type Progress = ();

        fn run(
            &mut self,
            params: RunParams<Self::Notification, Self::Progress>,
        ) -> Result<Self::Notification> {
            println!("[job] wait");

            while !self.finish.load(Ordering::SeqCst) {
                std::thread::yield_now();
            }

            println!("[job] sleep");

            thread::sleep(Duration::from_millis(100));

            println!("[job] done sleeping");

            // intermediate notification, sent before the value is updated
            params.send(self.notification_value)?;

            let res = self.v.fetch_add(self.value_to_add, Ordering::SeqCst);

            println!("[job] value: {res}");

            // final notification, forwarded by `AsyncSingleJob`
            Ok(self.notification_value)
        }
    }

    /// Blocks until the started job is pending. Callers keep the job blocked
    /// on `finish`, so from then on further `spawn`s can only queue.
    fn wait_for_start(job: &AsyncSingleJob<TestJob>) {
        while !job.is_pending() {
            thread::yield_now();
        }
    }

    fn wait_for_job(job: &AsyncSingleJob<TestJob>) {
        while job.is_pending() {
            thread::sleep(Duration::from_millis(10));
        }
    }

    #[test]
    fn test_overwrite() {
        let (sender, receiver) = unbounded();

        let mut job: AsyncSingleJob<TestJob> = AsyncSingleJob::new(sender);

        let shared_v = Arc::new(AtomicU32::new(1));
        let shared_finish = Arc::new(AtomicBool::new(false));

        let task_template = TestJob {
            v: Arc::clone(&shared_v),
            finish: Arc::clone(&shared_finish),
            value_to_add: 1,
            notification_value: 10,
        };

        // the first job starts right away and stays pending until `finish`
        assert!(job.spawn(task_template.clone()));
        wait_for_start(&job);

        // subsequent jobs only overwrite `next`: only the last one will run
        for _ in 0..5 {
            println!("spawn");
            assert!(!job.spawn(task_template.clone()));
        }

        shared_finish.store(true, Ordering::SeqCst);

        println!("recv");
        // two jobs run (the first and the last queued one), each emitting an
        // intermediate and a final notification
        for _ in 0..4 {
            assert_eq!(receiver.recv().unwrap(), 10);
        }
        assert!(receiver.is_empty());

        assert_eq!(shared_v.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_cancel() {
        let (sender, receiver) = unbounded();

        let mut job: AsyncSingleJob<TestJob> = AsyncSingleJob::new(sender);

        let initial_task = TestJob {
            v: Arc::new(AtomicU32::new(1)),
            finish: Arc::new(AtomicBool::new(false)),
            value_to_add: 1,
            notification_value: 20,
        };

        // spawn the first job; it stays pending until `finish` is set
        assert!(job.spawn(initial_task.clone()));
        wait_for_start(&job);

        // subsequent jobs only get queued (and are cancelled below)
        for _ in 0..5 {
            println!("spawn");
            assert!(!job.spawn(initial_task.clone()));
        }

        println!("cancel");
        assert!(job.cancel());

        // let the first job finish and wait for it
        initial_task.finish.store(true, Ordering::SeqCst);
        wait_for_job(&job);

        println!("recv");
        // intermediate and final notification of the first job only
        assert_eq!(receiver.recv().unwrap(), 20);
        assert_eq!(receiver.recv().unwrap(), 20);
        println!("received");
        assert!(receiver.is_empty());

        // `expect` is denied in this module (`clippy::expect_used`)
        let completed_job = job.take_last().unwrap();

        // initial value 1, incremented once by `value_to_add`
        assert_eq!(completed_job.v.load(Ordering::SeqCst), 2);
    }
}