gnostr_asyncgit/asyncjob/
mod.rs1#![deny(clippy::expect_used)]
4
5use std::sync::{Arc, Mutex, RwLock};
6
7use crossbeam_channel::Sender;
8
9use crate::error::Result;
10
11pub struct RunParams<T: Copy + Send, P: Clone + Send + Sync + PartialEq> {
14 sender: Sender<T>,
15 progress: Arc<RwLock<P>>,
16}
17
18impl<T: Copy + Send, P: Clone + Send + Sync + PartialEq> RunParams<T, P> {
19 pub fn send(&self, notification: T) -> Result<()> {
25 self.sender.send(notification)?;
26 Ok(())
27 }
28
29 pub fn set_progress(&self, p: P) -> Result<bool> {
31 Ok(if *self.progress.read()? == p {
32 false
33 } else {
34 *(self.progress.write()?) = p;
35 true
36 })
37 }
38}
39
40pub trait AsyncJob: Send + Sync + Clone {
42 type Notification: Copy + Send;
44 type Progress: Clone + Default + Send + Sync + PartialEq;
46
47 fn run(
53 &mut self,
54 params: RunParams<Self::Notification, Self::Progress>,
55 ) -> Result<Self::Notification>;
56
57 fn get_progress(&self) -> Self::Progress {
61 Self::Progress::default()
62 }
63}
64
65#[derive(Debug, Clone)]
69pub struct AsyncSingleJob<J: AsyncJob> {
70 next: Arc<Mutex<Option<J>>>,
71 last: Arc<Mutex<Option<J>>>,
72 progress: Arc<RwLock<J::Progress>>,
73 sender: Sender<J::Notification>,
74 pending: Arc<Mutex<()>>,
75}
76
77impl<J: 'static + AsyncJob> AsyncSingleJob<J> {
78 pub fn new(sender: Sender<J::Notification>) -> Self {
80 Self {
81 next: Arc::new(Mutex::new(None)),
82 last: Arc::new(Mutex::new(None)),
83 pending: Arc::new(Mutex::new(())),
84 progress: Arc::new(RwLock::new(J::Progress::default())),
85 sender,
86 }
87 }
88
89 pub fn is_pending(&self) -> bool {
91 self.pending.try_lock().is_err()
92 }
93
94 pub fn cancel(&mut self) -> bool {
97 if let Ok(mut next) = self.next.lock() {
98 if next.is_some() {
99 *next = None;
100 return true;
101 }
102 }
103
104 false
105 }
106
107 pub fn take_last(&self) -> Option<J> {
109 self.last.lock().map_or(None, |mut last| last.take())
110 }
111
112 pub fn spawn(&mut self, task: J) -> bool {
117 self.schedule_next(task);
118 self.check_for_job()
119 }
120
121 pub fn progress(&self) -> Option<J::Progress> {
123 self.progress.read().ok().map(|d| (*d).clone())
124 }
125
126 fn check_for_job(&self) -> bool {
127 if self.is_pending() {
128 return false;
129 }
130
131 if let Some(task) = self.take_next() {
132 let self_clone = (*self).clone();
133 rayon_core::spawn(move || {
134 if let Err(e) = self_clone.run_job(task) {
135 log::error!("async job error: {}", e);
136 }
137 });
138
139 return true;
140 }
141
142 false
143 }
144
145 fn run_job(&self, mut task: J) -> Result<()> {
146 {
148 let _pending = self.pending.lock()?;
149
150 let notification = task.run(RunParams {
151 progress: self.progress.clone(),
152 sender: self.sender.clone(),
153 })?;
154
155 if let Ok(mut last) = self.last.lock() {
156 *last = Some(task);
157 }
158
159 self.sender.send(notification)?;
160 }
161
162 self.check_for_job();
163
164 Ok(())
165 }
166
167 fn schedule_next(&mut self, task: J) {
168 if let Ok(mut next) = self.next.lock() {
169 *next = Some(task);
170 }
171 }
172
173 fn take_next(&self) -> Option<J> {
174 self.next.lock().map_or(None, |mut next| next.take())
175 }
176}
177
#[cfg(test)]
mod test {
    use std::{
        sync::atomic::{AtomicBool, AtomicU32, Ordering},
        thread,
        time::Duration,
    };

    use crossbeam_channel::unbounded;
    use pretty_assertions::assert_eq;

    use super::*;

    /// Upper bound for waiting on a single notification, so a broken job
    /// fails the test instead of hanging the whole test run.
    const RECV_TIMEOUT: Duration = Duration::from_secs(5);

    /// Job used by the tests.
    ///
    /// `run` spins until `finish` is set, sleeps 100ms, sends
    /// `notification_value` via `RunParams::send`, adds `value_to_add` to the
    /// shared counter `v` and returns `notification_value`. `AsyncSingleJob`
    /// then sends that return value as a second notification.
    #[derive(Clone)]
    struct TestJob {
        v: Arc<AtomicU32>,
        finish: Arc<AtomicBool>,
        value_to_add: u32,
        notification_value: u32,
    }

    impl AsyncJob for TestJob {
        type Notification = u32;
        type Progress = ();

        fn run(
            &mut self,
            params: RunParams<Self::Notification, Self::Progress>,
        ) -> Result<Self::Notification> {
            println!("[job] wait");

            while !self.finish.load(Ordering::SeqCst) {
                thread::yield_now();
            }

            println!("[job] sleep");

            thread::sleep(Duration::from_millis(100));

            println!("[job] done sleeping");

            params.send(self.notification_value)?;

            let res = self.v.fetch_add(self.value_to_add, Ordering::SeqCst);

            println!("[job] value: {res}");

            Ok(self.notification_value)
        }
    }

    /// Waits until a job has acquired the `pending` lock.
    ///
    /// `spawn` returns as soon as the job is handed to the rayon pool, which
    /// is before the worker locks `pending`. Without this wait, a following
    /// `spawn` could still see the runner as idle and start another job.
    fn wait_until_pending(job: &AsyncSingleJob<TestJob>) {
        while !job.is_pending() {
            thread::sleep(Duration::from_millis(1));
        }
    }

    /// Waits until no job holds the `pending` lock.
    fn wait_for_job(job: &AsyncSingleJob<TestJob>) {
        while job.is_pending() {
            thread::sleep(Duration::from_millis(10));
        }
    }

    #[test]
    fn test_overwrite() {
        let (sender, receiver) = unbounded();

        let mut job: AsyncSingleJob<TestJob> = AsyncSingleJob::new(sender);

        let shared_v = Arc::new(AtomicU32::new(1));
        let shared_finish = Arc::new(AtomicBool::new(false));

        let task_template = TestJob {
            v: Arc::clone(&shared_v),
            finish: Arc::clone(&shared_finish),
            value_to_add: 1,
            notification_value: 10,
        };

        assert!(job.spawn(task_template.clone()));

        // The first job blocks on `finish`, so once it holds `pending` every
        // following spawn is guaranteed to be queued instead of started.
        wait_until_pending(&job);

        for _ in 0..5 {
            println!("spawn");
            assert!(!job.spawn(task_template.clone()));
        }

        shared_finish.store(true, Ordering::SeqCst);

        println!("recv");
        // Exactly two jobs run: the first one and the single queued job that
        // survived the overwrites. Each sends two notifications (one through
        // `RunParams::send`, one as its return value).
        for _ in 0..4 {
            assert_eq!(receiver.recv_timeout(RECV_TIMEOUT).unwrap(), 10);
        }
        assert!(receiver.is_empty());

        // The last notification is sent after the second job's `fetch_add`,
        // so both increments are visible here.
        assert_eq!(shared_v.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_cancel() {
        let (sender, receiver) = unbounded();

        let mut job: AsyncSingleJob<TestJob> = AsyncSingleJob::new(sender);

        let initial_task = TestJob {
            v: Arc::new(AtomicU32::new(1)),
            finish: Arc::new(AtomicBool::new(false)),
            value_to_add: 1,
            notification_value: 20,
        };

        assert!(job.spawn(initial_task.clone()));

        // Keep the first job blocked on `finish` while queueing and
        // cancelling, so it is deterministically still running.
        wait_until_pending(&job);

        for _ in 0..5 {
            println!("spawn");
            assert!(!job.spawn(initial_task.clone()));
        }

        println!("cancel");
        assert!(job.cancel());
        // Nothing is left to cancel.
        assert!(!job.cancel());

        initial_task.finish.store(true, Ordering::SeqCst);
        wait_for_job(&job);

        println!("recv");
        // Only the first job ran; both of its notifications were sent before
        // it released `pending`.
        assert_eq!(receiver.recv_timeout(RECV_TIMEOUT).unwrap(), 20);
        assert_eq!(receiver.recv_timeout(RECV_TIMEOUT).unwrap(), 20);
        assert!(receiver.is_empty());
        println!("received");

        let completed_job = job.take_last().unwrap();

        assert_eq!(completed_job.v.load(Ordering::SeqCst), 2);
    }
}