1use crate::job_manager::SendFuture;
2use async_trait::async_trait;
3use std::error::Error;
4use std::fmt::Display;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
/// Decision returned by a pre-job hook: whether the main job should run.
///
/// Returning [`ProceedWithExecution::False`] from the pre hook makes
/// `ExecutableJob::execute` stop early and report success without running the
/// job, the post hook or the catch handler.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProceedWithExecution {
    /// Run the job (and, on success, the post hook).
    True,
    /// Skip the job entirely.
    False,
}
13
/// Error produced by any stage of a job.
///
/// It carries only a human-readable message. Anything convertible into a
/// `String` (e.g. `&str`, `String`) converts into a `JobError` via `From`/`Into`.
#[derive(Debug)]
pub struct JobError {
    /// Human-readable explanation of the failure; this is also the `Display` output.
    pub reason: String,
}

impl<T: Into<String>> From<T> for JobError {
    fn from(value: T) -> Self {
        let reason: String = value.into();
        JobError { reason }
    }
}

impl Display for JobError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emit the message verbatim; formatter width/fill flags are not applied.
        f.write_str(&self.reason)
    }
}

impl Error for JobError {}
34
35#[async_trait]
36pub trait ExecutableJob: Send + 'static {
37 async fn pre_job_hook(&mut self) -> Result<ProceedWithExecution, JobError>;
38 async fn job(&mut self) -> Result<(), JobError>;
39 async fn post_job_hook(&mut self) -> Result<(), JobError>;
40 async fn catch(&mut self);
41
42 async fn execute(&mut self) -> Result<(), JobError> {
43 match self.pre_job_hook().await? {
44 ProceedWithExecution::True => match self.job().await {
45 Ok(_) => match self.post_job_hook().await {
46 Ok(_) => Ok(()),
47 Err(err) => {
48 self.catch().await;
49 Err(err)
50 }
51 },
52 Err(err) => {
53 self.catch().await;
54 Err(err)
55 }
56 },
57 ProceedWithExecution::False => Ok(()),
58 }
59 }
60}
61
/// Adapts four stage futures into a job whose stages each await one stored future.
///
/// Every stage is kept as a pinned, heap-allocated future so it can be awaited
/// in place through `&mut self`. The type parameters may be unsized (e.g. `dyn`
/// futures), which allows fully type-erased wrappers.
pub struct ExecutableJobWrapper<Pre, Protocol, Post, Catch>
where
    Pre: ?Sized,
    Protocol: ?Sized,
    Post: ?Sized,
    Catch: ?Sized,
{
    /// Future awaited by the pre-job hook stage.
    pre: Pin<Box<Pre>>,
    /// Future awaited by the main job stage.
    protocol: Pin<Box<Protocol>>,
    /// Future awaited by the post-job hook stage.
    post: Pin<Box<Post>>,
    /// Future awaited by the error-handler stage.
    catch: Pin<Box<Catch>>,
}
68
69#[async_trait]
70impl<Pre: ?Sized, Protocol: ?Sized, Post: ?Sized, Catch: ?Sized> ExecutableJob
71 for ExecutableJobWrapper<Pre, Protocol, Post, Catch>
72where
73 Pre: SendFuture<'static, Result<ProceedWithExecution, JobError>>,
74 Protocol: SendFuture<'static, Result<(), JobError>>,
75 Post: SendFuture<'static, Result<(), JobError>>,
76 Catch: SendFuture<'static, ()>,
77{
78 async fn pre_job_hook(&mut self) -> Result<ProceedWithExecution, JobError> {
79 self.pre.as_mut().await
80 }
81
82 async fn job(&mut self) -> Result<(), JobError> {
83 self.protocol.as_mut().await
84 }
85
86 async fn post_job_hook(&mut self) -> Result<(), JobError> {
87 self.post.as_mut().await
88 }
89
90 async fn catch(&mut self) {
91 self.catch.as_mut().await
92 }
93}
94
95impl<Pre, Protocol, Post, Catch> ExecutableJobWrapper<Pre, Protocol, Post, Catch>
96where
97 Pre: SendFuture<'static, Result<ProceedWithExecution, JobError>>,
98 Protocol: SendFuture<'static, Result<(), JobError>>,
99 Post: SendFuture<'static, Result<(), JobError>>,
100 Catch: SendFuture<'static, ()>,
101{
102 pub fn new(pre: Pre, protocol: Protocol, post: Post, catch: Catch) -> Self {
103 Self {
104 pre: Box::pin(pre),
105 protocol: Box::pin(protocol),
106 post: Box::pin(post),
107 catch: Box::pin(catch),
108 }
109 }
110}
111
112#[derive(Default)]
113pub struct JobBuilder {
114 pre: Option<Pin<Box<PreJobHook>>>,
115 protocol: Option<Pin<Box<ProtocolJobHook>>>,
116 post: Option<Pin<Box<PostJobHook>>>,
117 catch: Option<Pin<Box<CatchJobHook>>>,
118}
119
120pub type PreJobHook = dyn SendFuture<'static, Result<ProceedWithExecution, JobError>>;
121pub type PostJobHook = dyn SendFuture<'static, Result<(), JobError>>;
122pub type ProtocolJobHook = dyn SendFuture<'static, Result<(), JobError>>;
123pub type CatchJobHook = dyn SendFuture<'static, ()>;
124
125pub struct DefaultPreJobHook;
126impl Future for DefaultPreJobHook {
127 type Output = Result<ProceedWithExecution, JobError>;
128
129 fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
130 Poll::Ready(Ok(ProceedWithExecution::True))
131 }
132}
133
134pub struct DefaultPostJobHook;
135impl Future for DefaultPostJobHook {
136 type Output = Result<(), JobError>;
137
138 fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
139 Poll::Ready(Ok(()))
140 }
141}
142
/// Default error handler: does nothing and completes on its first poll.
///
/// Public like the other default hooks, so callers can name it explicitly.
/// It holds no state, so polling it again after completion is harmless.
pub struct DefaultCatchJobHook;

impl Future for DefaultCatchJobHook {
    type Output = ();

    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        Poll::Ready(())
    }
}
152
153pub type BuiltExecutableJobWrapper = ExecutableJobWrapper<
154 dyn SendFuture<'static, Result<ProceedWithExecution, JobError>>,
155 dyn SendFuture<'static, Result<(), JobError>>,
156 dyn SendFuture<'static, Result<(), JobError>>,
157 dyn SendFuture<'static, ()>,
158>;
159
160impl JobBuilder {
161 pub fn new() -> Self {
162 Self::default()
163 }
164
165 pub fn pre<Pre>(mut self, pre: Pre) -> Self
166 where
167 Pre: SendFuture<'static, Result<ProceedWithExecution, JobError>>,
168 {
169 self.pre = Some(Box::pin(pre));
170 self
171 }
172
173 pub fn protocol<Protocol>(mut self, protocol: Protocol) -> Self
174 where
175 Protocol: SendFuture<'static, Result<(), JobError>>,
176 {
177 self.protocol = Some(Box::pin(protocol));
178 self
179 }
180
181 pub fn post<Post>(mut self, post: Post) -> Self
182 where
183 Post: SendFuture<'static, Result<(), JobError>>,
184 {
185 self.post = Some(Box::pin(post));
186 self
187 }
188
189 pub fn catch<Catch>(mut self, catch: Catch) -> Self
190 where
191 Catch: SendFuture<'static, ()>,
192 {
193 self.catch = Some(Box::pin(catch));
194 self
195 }
196
197 pub fn build(self) -> BuiltExecutableJobWrapper {
198 let pre = if let Some(pre) = self.pre {
199 pre
200 } else {
201 Box::pin(DefaultPreJobHook)
202 };
203
204 let post = if let Some(post) = self.post {
205 post
206 } else {
207 Box::pin(DefaultPostJobHook)
208 };
209
210 let catch = if let Some(catch) = self.catch {
211 catch
212 } else {
213 Box::pin(DefaultCatchJobHook)
214 };
215
216 let protocol = Box::pin(self.protocol.expect("Must specify protocol"));
217
218 ExecutableJobWrapper {
219 pre,
220 protocol,
221 post,
222 catch,
223 }
224 }
225}
226
#[cfg(test)]
#[cfg(not(target_family = "wasm"))]
mod tests {
    // Each test threads a shared counter through the stage futures and checks
    // how many stages actually ran.
    use crate::job::ExecutableJob;
    // NOTE(review): presumably required by the `#[gadget_io::tokio::test]`
    // expansion, which refers to `tokio` by name; confirm before removing.
    use gadget_io::tokio;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[gadget_io::tokio::test]
    async fn test_executable_job_wrapper_proceed() {
        let counter = Arc::new(AtomicUsize::new(0));
        let (pre_counter, protocol_counter, post_counter) =
            (counter.clone(), counter.clone(), counter.clone());

        let pre = async move {
            pre_counter.fetch_add(1, Ordering::SeqCst);
            Ok(super::ProceedWithExecution::True)
        };
        let protocol = async move {
            protocol_counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        };
        let post = async move {
            post_counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        };
        let catch = async move {};

        let mut job = super::ExecutableJobWrapper::new(pre, protocol, post, catch);
        job.execute().await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[gadget_io::tokio::test]
    async fn test_executable_job_wrapper_no_proceed() {
        let counter = Arc::new(AtomicUsize::new(0));
        let (pre_counter, protocol_counter, post_counter) =
            (counter.clone(), counter.clone(), counter.clone());

        let pre = async move {
            pre_counter.fetch_add(1, Ordering::SeqCst);
            Ok(super::ProceedWithExecution::False)
        };
        let protocol = async move {
            protocol_counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        };
        let post = async move {
            post_counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        };
        let catch = async move {};

        let mut job = super::ExecutableJobWrapper::new(pre, protocol, post, catch);
        job.execute().await.unwrap();
        // Only the pre hook ran.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[gadget_io::tokio::test]
    async fn test_job_builder() {
        let counter = Arc::new(AtomicUsize::new(0));
        let (pre_counter, protocol_counter, post_counter) =
            (counter.clone(), counter.clone(), counter.clone());

        let mut job = super::JobBuilder::new()
            .pre(async move {
                pre_counter.fetch_add(1, Ordering::SeqCst);
                Ok(super::ProceedWithExecution::True)
            })
            .protocol(async move {
                protocol_counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .post(async move {
                post_counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .build();

        job.execute().await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[gadget_io::tokio::test]
    async fn test_job_builder_no_pre() {
        let counter = Arc::new(AtomicUsize::new(0));
        let (protocol_counter, post_counter) = (counter.clone(), counter.clone());

        let mut job = super::JobBuilder::default()
            .protocol(async move {
                protocol_counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .post(async move {
                post_counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .build();

        job.execute().await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[gadget_io::tokio::test]
    async fn test_job_builder_no_post() {
        let counter = Arc::new(AtomicUsize::new(0));
        let (pre_counter, protocol_counter) = (counter.clone(), counter.clone());

        let mut job = super::JobBuilder::default()
            .pre(async move {
                pre_counter.fetch_add(1, Ordering::SeqCst);
                Ok(super::ProceedWithExecution::True)
            })
            .protocol(async move {
                protocol_counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .build();

        job.execute().await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[gadget_io::tokio::test]
    async fn test_job_builder_no_pre_no_post() {
        let counter = Arc::new(AtomicUsize::new(0));
        let protocol_counter = counter.clone();

        let mut job = super::JobBuilder::default()
            .protocol(async move {
                protocol_counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .build();

        job.execute().await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[gadget_io::tokio::test]
    async fn test_protocol_err_catch_performs_increment() {
        let counter = Arc::new(AtomicUsize::new(0));
        let (pre_counter, protocol_counter, catch_counter) =
            (counter.clone(), counter.clone(), counter.clone());

        let pre = async move {
            pre_counter.fetch_add(1, Ordering::SeqCst);
            Ok(super::ProceedWithExecution::True)
        };
        let protocol = async move {
            protocol_counter.fetch_add(1, Ordering::SeqCst);
            Err(super::JobError::from("Protocol error"))
        };
        let post = async move { unreachable!("Post should not be called") };
        let catch = async move {
            catch_counter.fetch_add(1, Ordering::SeqCst);
        };

        let mut job = super::ExecutableJobWrapper::new(pre, protocol, post, catch);
        let err = job.execute().await.unwrap_err();
        assert_eq!(err.reason, "Protocol error");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[gadget_io::tokio::test]
    async fn test_post_err_catch_performs_increment() {
        let counter = Arc::new(AtomicUsize::new(0));
        let (pre_counter, protocol_counter, catch_counter) =
            (counter.clone(), counter.clone(), counter.clone());

        let pre = async move {
            pre_counter.fetch_add(1, Ordering::SeqCst);
            Ok(super::ProceedWithExecution::True)
        };
        let protocol = async move {
            protocol_counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        };
        let post = async move { Err(super::JobError::from("Post error")) };
        let catch = async move {
            catch_counter.fetch_add(1, Ordering::SeqCst);
        };

        let mut job = super::ExecutableJobWrapper::new(pre, protocol, post, catch);
        let err = job.execute().await.unwrap_err();
        assert_eq!(err.reason, "Post error");
        // pre + protocol + catch
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[gadget_io::tokio::test]
    async fn test_pre_err_skips_protocol_and_catch() {
        let counter = Arc::new(AtomicUsize::new(0));
        let catch_counter = counter.clone();

        let pre = async move { Err(super::JobError::from("Pre error")) };
        let protocol = async move { unreachable!("Protocol should not be called") };
        let post = async move { unreachable!("Post should not be called") };
        let catch = async move {
            catch_counter.fetch_add(1, Ordering::SeqCst);
        };

        let mut job = super::ExecutableJobWrapper::new(pre, protocol, post, catch);
        let err = job.execute().await.unwrap_err();
        assert_eq!(err.reason, "Pre error");
        // A failing pre hook is propagated without invoking `catch`.
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[gadget_io::tokio::test]
    async fn test_job_builder_catch_on_protocol_err() {
        let counter = Arc::new(AtomicUsize::new(0));
        let catch_counter = counter.clone();

        let mut job = super::JobBuilder::new()
            .protocol(async move { Err(super::JobError::from("Protocol error")) })
            .catch(async move {
                catch_counter.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let err = job.execute().await.unwrap_err();
        assert_eq!(err.reason, "Protocol error");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    #[should_panic(expected = "Must specify protocol")]
    fn test_job_builder_requires_protocol() {
        let _ = super::JobBuilder::new().build();
    }
}