gadget_core/
job.rs

1use crate::job_manager::SendFuture;
2use async_trait::async_trait;
3use std::error::Error;
4use std::fmt::Display;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
/// Verdict returned by a job's pre-hook: whether the main job body should run.
///
/// The common value-type traits are derived so the verdict can be compared,
/// copied and printed. `Debug` in particular is needed to call
/// `unwrap_err`/`expect_err` on a `Result<ProceedWithExecution, _>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProceedWithExecution {
    /// Continue: run the job body and then the post-hook.
    True,
    /// Stop: skip the job body and the post-hook.
    False,
}
13
/// Error type returned by every fallible stage of a job.
///
/// It carries nothing but a textual description of the failure.
#[derive(Debug)]
pub struct JobError {
    /// Human-readable description of what went wrong.
    pub reason: String,
}

/// Anything convertible into a `String` (`&str`, `String`, ...) converts into a
/// [`JobError`], with the converted text becoming the `reason`.
impl<T: Into<String>> From<T> for JobError {
    fn from(value: T) -> Self {
        let reason: String = value.into();
        JobError { reason }
    }
}

/// Displays the error as its bare `reason`, with no prefix or decoration.
impl Display for JobError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.reason)
    }
}

impl Error for JobError {}
34
35#[async_trait]
36pub trait ExecutableJob: Send + 'static {
37    async fn pre_job_hook(&mut self) -> Result<ProceedWithExecution, JobError>;
38    async fn job(&mut self) -> Result<(), JobError>;
39    async fn post_job_hook(&mut self) -> Result<(), JobError>;
40    async fn catch(&mut self);
41
42    async fn execute(&mut self) -> Result<(), JobError> {
43        match self.pre_job_hook().await? {
44            ProceedWithExecution::True => match self.job().await {
45                Ok(_) => match self.post_job_hook().await {
46                    Ok(_) => Ok(()),
47                    Err(err) => {
48                        self.catch().await;
49                        Err(err)
50                    }
51                },
52                Err(err) => {
53                    self.catch().await;
54                    Err(err)
55                }
56            },
57            ProceedWithExecution::False => Ok(()),
58        }
59    }
60}
61
/// Bundles four heap-pinned values, one per job stage, into a single object.
///
/// Every type parameter may be unsized, so the fields can hold trait objects
/// as well as concrete types.
pub struct ExecutableJobWrapper<Pre, Protocol, Post, Catch>
where
    Pre: ?Sized,
    Protocol: ?Sized,
    Post: ?Sized,
    Catch: ?Sized,
{
    /// Pre-hook stage.
    pre: Pin<Box<Pre>>,
    /// Main job body.
    protocol: Pin<Box<Protocol>>,
    /// Post-hook stage.
    post: Pin<Box<Post>>,
    /// Failure-handler stage.
    catch: Pin<Box<Catch>>,
}
68
69#[async_trait]
70impl<Pre: ?Sized, Protocol: ?Sized, Post: ?Sized, Catch: ?Sized> ExecutableJob
71    for ExecutableJobWrapper<Pre, Protocol, Post, Catch>
72where
73    Pre: SendFuture<'static, Result<ProceedWithExecution, JobError>>,
74    Protocol: SendFuture<'static, Result<(), JobError>>,
75    Post: SendFuture<'static, Result<(), JobError>>,
76    Catch: SendFuture<'static, ()>,
77{
78    async fn pre_job_hook(&mut self) -> Result<ProceedWithExecution, JobError> {
79        self.pre.as_mut().await
80    }
81
82    async fn job(&mut self) -> Result<(), JobError> {
83        self.protocol.as_mut().await
84    }
85
86    async fn post_job_hook(&mut self) -> Result<(), JobError> {
87        self.post.as_mut().await
88    }
89
90    async fn catch(&mut self) {
91        self.catch.as_mut().await
92    }
93}
94
95impl<Pre, Protocol, Post, Catch> ExecutableJobWrapper<Pre, Protocol, Post, Catch>
96where
97    Pre: SendFuture<'static, Result<ProceedWithExecution, JobError>>,
98    Protocol: SendFuture<'static, Result<(), JobError>>,
99    Post: SendFuture<'static, Result<(), JobError>>,
100    Catch: SendFuture<'static, ()>,
101{
102    pub fn new(pre: Pre, protocol: Protocol, post: Post, catch: Catch) -> Self {
103        Self {
104            pre: Box::pin(pre),
105            protocol: Box::pin(protocol),
106            post: Box::pin(post),
107            catch: Box::pin(catch),
108        }
109    }
110}
111
112#[derive(Default)]
113pub struct JobBuilder {
114    pre: Option<Pin<Box<PreJobHook>>>,
115    protocol: Option<Pin<Box<ProtocolJobHook>>>,
116    post: Option<Pin<Box<PostJobHook>>>,
117    catch: Option<Pin<Box<CatchJobHook>>>,
118}
119
120pub type PreJobHook = dyn SendFuture<'static, Result<ProceedWithExecution, JobError>>;
121pub type PostJobHook = dyn SendFuture<'static, Result<(), JobError>>;
122pub type ProtocolJobHook = dyn SendFuture<'static, Result<(), JobError>>;
123pub type CatchJobHook = dyn SendFuture<'static, ()>;
124
125pub struct DefaultPreJobHook;
126impl Future for DefaultPreJobHook {
127    type Output = Result<ProceedWithExecution, JobError>;
128
129    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
130        Poll::Ready(Ok(ProceedWithExecution::True))
131    }
132}
133
134pub struct DefaultPostJobHook;
135impl Future for DefaultPostJobHook {
136    type Output = Result<(), JobError>;
137
138    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
139        Poll::Ready(Ok(()))
140    }
141}
142
/// Stand-in failure handler: a future that is ready on every poll and
/// resolves to `()`.
struct DefaultCatchJobHook;

impl Future for DefaultCatchJobHook {
    type Output = ();

    /// Completes immediately; the task context is never consulted.
    fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
        let nothing = ();
        Poll::Ready(nothing)
    }
}
152
153pub type BuiltExecutableJobWrapper = ExecutableJobWrapper<
154    dyn SendFuture<'static, Result<ProceedWithExecution, JobError>>,
155    dyn SendFuture<'static, Result<(), JobError>>,
156    dyn SendFuture<'static, Result<(), JobError>>,
157    dyn SendFuture<'static, ()>,
158>;
159
160impl JobBuilder {
161    pub fn new() -> Self {
162        Self::default()
163    }
164
165    pub fn pre<Pre>(mut self, pre: Pre) -> Self
166    where
167        Pre: SendFuture<'static, Result<ProceedWithExecution, JobError>>,
168    {
169        self.pre = Some(Box::pin(pre));
170        self
171    }
172
173    pub fn protocol<Protocol>(mut self, protocol: Protocol) -> Self
174    where
175        Protocol: SendFuture<'static, Result<(), JobError>>,
176    {
177        self.protocol = Some(Box::pin(protocol));
178        self
179    }
180
181    pub fn post<Post>(mut self, post: Post) -> Self
182    where
183        Post: SendFuture<'static, Result<(), JobError>>,
184    {
185        self.post = Some(Box::pin(post));
186        self
187    }
188
189    pub fn catch<Catch>(mut self, catch: Catch) -> Self
190    where
191        Catch: SendFuture<'static, ()>,
192    {
193        self.catch = Some(Box::pin(catch));
194        self
195    }
196
197    pub fn build(self) -> BuiltExecutableJobWrapper {
198        let pre = if let Some(pre) = self.pre {
199            pre
200        } else {
201            Box::pin(DefaultPreJobHook)
202        };
203
204        let post = if let Some(post) = self.post {
205            post
206        } else {
207            Box::pin(DefaultPostJobHook)
208        };
209
210        let catch = if let Some(catch) = self.catch {
211            catch
212        } else {
213            Box::pin(DefaultCatchJobHook)
214        };
215
216        let protocol = Box::pin(self.protocol.expect("Must specify protocol"));
217
218        ExecutableJobWrapper {
219            pre,
220            protocol,
221            post,
222            catch,
223        }
224    }
225}
226
#[cfg(test)]
#[cfg(not(target_family = "wasm"))]
mod tests {
    //! Each test threads one shared atomic counter through the job stages and
    //! checks how many of them actually ran.

    use super::{ExecutableJobWrapper, JobBuilder, JobError, ProceedWithExecution};
    use crate::job::ExecutableJob;
    use gadget_io::tokio;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// All three stages run when the pre-hook says to proceed.
    #[gadget_io::tokio::test]
    async fn test_executable_job_wrapper_proceed() {
        let hits = Arc::new(AtomicUsize::new(0));
        let pre_hits = hits.clone();
        let protocol_hits = hits.clone();
        let post_hits = hits.clone();

        let pre = async move {
            pre_hits.fetch_add(1, Ordering::SeqCst);
            Ok(ProceedWithExecution::True)
        };

        let protocol = async move {
            protocol_hits.fetch_add(1, Ordering::SeqCst);
            Ok(())
        };

        let post = async move {
            post_hits.fetch_add(1, Ordering::SeqCst);
            Ok(())
        };

        let catch = async move {};

        let mut job = ExecutableJobWrapper::new(pre, protocol, post, catch);
        job.execute().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    /// Only the pre-hook runs when it declines to proceed.
    #[gadget_io::tokio::test]
    async fn test_executable_job_wrapper_no_proceed() {
        let hits = Arc::new(AtomicUsize::new(0));
        let pre_hits = hits.clone();
        let protocol_hits = hits.clone();
        let post_hits = hits.clone();

        let pre = async move {
            pre_hits.fetch_add(1, Ordering::SeqCst);
            Ok(ProceedWithExecution::False)
        };

        let protocol = async move {
            protocol_hits.fetch_add(1, Ordering::SeqCst);
            Ok(())
        };

        let post = async move {
            post_hits.fetch_add(1, Ordering::SeqCst);
            Ok(())
        };

        let catch = async move {};

        let mut job = ExecutableJobWrapper::new(pre, protocol, post, catch);
        job.execute().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    /// A builder with pre, protocol and post set runs all three.
    #[gadget_io::tokio::test]
    async fn test_job_builder() {
        let hits = Arc::new(AtomicUsize::new(0));
        let pre_hits = hits.clone();
        let protocol_hits = hits.clone();
        let post_hits = hits.clone();

        let mut job = JobBuilder::new()
            .pre(async move {
                pre_hits.fetch_add(1, Ordering::SeqCst);
                Ok(ProceedWithExecution::True)
            })
            .protocol(async move {
                protocol_hits.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .post(async move {
                post_hits.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .build();

        job.execute().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    /// Without a pre-hook the default lets the job proceed.
    #[gadget_io::tokio::test]
    async fn test_job_builder_no_pre() {
        let hits = Arc::new(AtomicUsize::new(0));
        let protocol_hits = hits.clone();
        let post_hits = hits.clone();

        let mut job = JobBuilder::default()
            .protocol(async move {
                protocol_hits.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .post(async move {
                post_hits.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .build();

        job.execute().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    /// Without a post-hook the default succeeds silently.
    #[gadget_io::tokio::test]
    async fn test_job_builder_no_post() {
        let hits = Arc::new(AtomicUsize::new(0));
        let pre_hits = hits.clone();
        let protocol_hits = hits.clone();

        let mut job = JobBuilder::default()
            .pre(async move {
                pre_hits.fetch_add(1, Ordering::SeqCst);
                Ok(ProceedWithExecution::True)
            })
            .protocol(async move {
                protocol_hits.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .build();

        job.execute().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    /// A protocol on its own is enough to build and run a job.
    #[gadget_io::tokio::test]
    async fn test_job_builder_no_pre_no_post() {
        let hits = Arc::new(AtomicUsize::new(0));
        let protocol_hits = hits.clone();

        let mut job = JobBuilder::default()
            .protocol(async move {
                protocol_hits.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .build();

        job.execute().await.unwrap();
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    /// A failing protocol skips the post-hook and triggers the catch handler.
    #[gadget_io::tokio::test]
    async fn test_protocol_err_catch_performs_increment() {
        let hits = Arc::new(AtomicUsize::new(0));
        let pre_hits = hits.clone();
        let protocol_hits = hits.clone();
        let catch_hits = hits.clone();

        let pre = async move {
            pre_hits.fetch_add(1, Ordering::SeqCst);
            Ok(ProceedWithExecution::True)
        };

        let protocol = async move {
            protocol_hits.fetch_add(1, Ordering::SeqCst);
            Err(JobError::from("Protocol error"))
        };

        let post = async move { unreachable!("Post should not be called") };

        let catch = async move {
            catch_hits.fetch_add(1, Ordering::SeqCst);
        };

        let mut job = ExecutableJobWrapper::new(pre, protocol, post, catch);
        job.execute().await.unwrap_err();
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }
}