gnostr_asyncgit/asyncjob/
mod.rs1#![deny(clippy::expect_used)]
4
5use std::sync::{Arc, Mutex, RwLock};
6
7use crossbeam_channel::Sender;
8
9use crate::error::Result;
10
11pub struct RunParams<T: Copy + Send, P: Clone + Send + Sync + PartialEq> {
14 sender: Sender<T>,
15 progress: Arc<RwLock<P>>,
16}
17
18impl<T: Copy + Send, P: Clone + Send + Sync + PartialEq> RunParams<T, P> {
19 pub fn send(&self, notification: T) -> Result<()> {
25 self.sender.send(notification)?;
26 Ok(())
27 }
28
29 pub fn set_progress(&self, p: P) -> Result<bool> {
31 Ok(if *self.progress.read()? == p {
32 false
33 } else {
34 *(self.progress.write()?) = p;
35 true
36 })
37 }
38}
39
40pub trait AsyncJob: Send + Sync + Clone {
42 type Notification: Copy + Send;
44 type Progress: Clone + Default + Send + Sync + PartialEq;
46
47 fn run(
53 &mut self,
54 params: RunParams<Self::Notification, Self::Progress>,
55 ) -> Result<Self::Notification>;
56
57 fn get_progress(&self) -> Self::Progress {
61 Self::Progress::default()
62 }
63}
64
65#[derive(Debug, Clone)]
69pub struct AsyncSingleJob<J: AsyncJob> {
70 next: Arc<Mutex<Option<J>>>,
71 last: Arc<Mutex<Option<J>>>,
72 progress: Arc<RwLock<J::Progress>>,
73 sender: Sender<J::Notification>,
74 pending: Arc<Mutex<()>>,
75}
76
77impl<J: 'static + AsyncJob> AsyncSingleJob<J> {
78 pub fn new(sender: Sender<J::Notification>) -> Self {
80 Self {
81 next: Arc::new(Mutex::new(None)),
82 last: Arc::new(Mutex::new(None)),
83 pending: Arc::new(Mutex::new(())),
84 progress: Arc::new(RwLock::new(J::Progress::default())),
85 sender,
86 }
87 }
88
89 pub fn is_pending(&self) -> bool {
91 self.pending.try_lock().is_err()
92 }
93
94 pub fn cancel(&mut self) -> bool {
97 if let Ok(mut next) = self.next.lock() {
98 if next.is_some() {
99 *next = None;
100 return true;
101 }
102 }
103
104 false
105 }
106
107 pub fn take_last(&self) -> Option<J> {
109 self.last.lock().map_or(None, |mut last| last.take())
110 }
111
112 pub fn spawn(&mut self, task: J) -> bool {
117 self.schedule_next(task);
118 self.check_for_job()
119 }
120
121 pub fn progress(&self) -> Option<J::Progress> {
123 self.progress.read().ok().map(|d| (*d).clone())
124 }
125
126 fn check_for_job(&self) -> bool {
127 if self.is_pending() {
128 return false;
129 }
130
131 if let Some(task) = self.take_next() {
132 let self_clone = (*self).clone();
133 rayon_core::spawn(move || {
134 if let Err(e) = self_clone.run_job(task) {
135 log::error!("async job error: {}", e);
136 }
137 });
138
139 return true;
140 }
141
142 false
143 }
144
145 fn run_job(&self, mut task: J) -> Result<()> {
146 {
148 let _pending = self.pending.lock()?;
149
150 let notification = task.run(RunParams {
151 progress: self.progress.clone(),
152 sender: self.sender.clone(),
153 })?;
154
155 if let Ok(mut last) = self.last.lock() {
156 *last = Some(task);
157 }
158
159 self.sender.send(notification)?;
160 }
161
162 self.check_for_job();
163
164 Ok(())
165 }
166
167 fn schedule_next(&mut self, task: J) {
168 if let Ok(mut next) = self.next.lock() {
169 *next = Some(task);
170 }
171 }
172
173 fn take_next(&self) -> Option<J> {
174 self.next.lock().map_or(None, |mut next| next.take())
175 }
176}
177
#[cfg(test)]
mod test {
    use std::{
        sync::atomic::{AtomicBool, AtomicU32, Ordering},
        thread,
        time::Duration,
    };

    use crossbeam_channel::unbounded;
    use pretty_assertions::assert_eq;

    use super::*;

    /// Test job that spins until `finish` is raised, sleeps briefly, then adds
    /// `value_to_add` to the shared counter `v`.
    #[derive(Clone)]
    struct TestJob {
        v: Arc<AtomicU32>,
        finish: Arc<AtomicBool>,
        value_to_add: u32,
    }

    type TestNotification = ();

    impl AsyncJob for TestJob {
        type Notification = TestNotification;
        type Progress = ();

        fn run(
            &mut self,
            _params: RunParams<Self::Notification, Self::Progress>,
        ) -> Result<Self::Notification> {
            println!("[job] wait");

            // Spin (yielding the CPU) until the test releases the job.
            loop {
                if self.finish.load(Ordering::SeqCst) {
                    break;
                }
                std::thread::yield_now();
            }

            println!("[job] sleep");
            thread::sleep(Duration::from_millis(100));
            println!("[job] done sleeping");

            let previous = self.v.fetch_add(self.value_to_add, Ordering::SeqCst);
            println!("[job] value: {previous}");

            Ok(())
        }
    }

    /// Builds a job that is not yet released, with its counter at 1 and an increment of 1.
    fn new_task() -> TestJob {
        TestJob {
            v: Arc::new(AtomicU32::new(1)),
            finish: Arc::new(AtomicBool::new(false)),
            value_to_add: 1,
        }
    }

    /// Polls until `job` no longer reports a running task.
    fn wait_for_job(job: &AsyncSingleJob<TestJob>) {
        while job.is_pending() {
            thread::sleep(Duration::from_millis(10));
        }
    }

    #[test]
    fn test_overwrite() {
        let (sender, receiver) = unbounded();
        let mut job: AsyncSingleJob<TestJob> = AsyncSingleJob::new(sender);
        let task = new_task();

        // The first spawn starts right away; give the worker time to take the lock.
        assert!(job.spawn(task.clone()));
        task.finish.store(true, Ordering::SeqCst);
        thread::sleep(Duration::from_millis(10));

        // While busy, every spawn only overwrites the single queued job.
        for _ in 0..5 {
            println!("spawn");
            assert!(!job.spawn(task.clone()));
        }

        println!("recv");
        receiver.recv().unwrap();
        receiver.recv().unwrap();
        assert!(receiver.is_empty());

        // The running job plus exactly one queued job each added 1.
        assert_eq!(task.v.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_cancel() {
        let (sender, receiver) = unbounded();
        let mut job: AsyncSingleJob<TestJob> = AsyncSingleJob::new(sender);
        let task = new_task();

        // The first spawn starts right away; give the worker time to take the lock.
        assert!(job.spawn(task.clone()));
        task.finish.store(true, Ordering::SeqCst);
        thread::sleep(Duration::from_millis(10));

        // While busy, every spawn only overwrites the single queued job.
        for _ in 0..5 {
            println!("spawn");
            assert!(!job.spawn(task.clone()));
        }

        // Dropping the queued job must succeed because one is waiting.
        println!("cancel");
        assert!(job.cancel());

        // Raise the flag again (a no-op, since it is already set).
        task.finish.store(true, Ordering::SeqCst);

        wait_for_job(&job);

        println!("recv");
        receiver.recv().unwrap();
        println!("received");

        // Only the job that was already running added its increment.
        assert_eq!(task.v.load(Ordering::SeqCst), 2);
    }
}