Skip to main content

beetry_core/leaf/
action.rs

1use core::fmt;
2use std::{sync::Arc, time::Duration};
3
4use anyhow::Result;
5use tracing::{debug, error};
6
7use crate::{
8    ActionTask, Node, TickStatus,
9    task::{RegisterTask, TaskHandle, TaskStatus},
10};
11
12#[cfg_attr(any(test, feature = "mock"), mockall::automock)]
13/// Behavior contract for user-defined action nodes.
14///
15/// Implementors define the [`ActionTask`] to execute when the action starts,
16/// and can react to task lifecycle updates through the provided hooks.
17///
18/// For more details, see the Action Lifecycle chapter in the book:
19/// <https://beetry.pages.dev/runtime/action-lifecycle.html>.
20pub trait Behavior {
21    /// Construct the task that should be scheduled for this action.
22    fn task(&mut self) -> Result<ActionTask>;
23
24    /// Reset action state.
25    fn reset(&mut self) {}
26
27    /// Hook called on [`TaskStatus::Running`].
28    fn on_running(&mut self) -> Result<()> {
29        Ok(())
30    }
31    /// Hook called on [`TaskStatus::Success`].
32    fn on_success(&mut self) -> Result<()> {
33        Ok(())
34    }
35
36    /// Hook called on [`TaskStatus::Failure`]
37    fn on_failure(&mut self) -> Result<()> {
38        Ok(())
39    }
40
41    /// Hook called after the action is aborted.
42    fn on_aborted(&mut self) -> Result<()> {
43        Ok(())
44    }
45}
46
47pub type BoxBehavior = Box<dyn Behavior>;
48impl Behavior for BoxBehavior {
49    fn task(&mut self) -> Result<ActionTask> {
50        (**self).task()
51    }
52    fn reset(&mut self) {
53        (**self).reset();
54    }
55    fn on_running(&mut self) -> Result<()> {
56        (**self).on_running()
57    }
58    fn on_success(&mut self) -> Result<()> {
59        (**self).on_success()
60    }
61    fn on_failure(&mut self) -> Result<()> {
62        (**self).on_failure()
63    }
64    fn on_aborted(&mut self) -> Result<()> {
65        (**self).on_aborted()
66    }
67}
68
69#[must_use]
70fn dispatch_hooks(behavior: &mut impl Behavior, status: TaskStatus) -> TaskStatus {
71    match status {
72        TaskStatus::Success => behavior.on_success(),
73        TaskStatus::Running => behavior.on_running(),
74        TaskStatus::Failure => behavior.on_failure(),
75        TaskStatus::Aborted => behavior.on_aborted(),
76    }
77    .inspect_err(|e| error!("error during action hook invocation: {e}"))
78    .map(|()| status)
79    .unwrap_or(TaskStatus::Failure)
80}
81
82/// Action leaf node that bridges the synchronous `tick` interface and the
83/// executor.
84///
85/// `Action` uses the user-provided [`Behavior`] to create an [`ActionTask`],
86/// registers it through [`RegisterTask`], and reports progress through
87/// [`TickStatus`] on later ticks.
88///
89/// See the runtime execution chapter in the book for a more detailed
90/// explanation:
91/// <https://beetry.pages.dev/runtime/execution.html>.
92pub struct Action<R, TH, B>
93where
94    R: RegisterTask<TH>,
95    TH: TaskHandle,
96    B: Behavior,
97{
98    behavior: B,
99    registry: Arc<R>,
100    abort_poll_interval: Duration,
101    state: State<TH>,
102}
103
104impl<R, TH, B> Action<R, TH, B>
105where
106    R: RegisterTask<TH>,
107    TH: TaskHandle,
108    B: Behavior,
109{
110    pub fn new(behavior: B, registry: Arc<R>, abort_poll_interval: Duration) -> Self {
111        Self {
112            behavior,
113            registry,
114            abort_poll_interval,
115            state: State::Idle,
116        }
117    }
118}
119
120impl<R, TH, B> Node for Action<R, TH, B>
121where
122    R: RegisterTask<TH>,
123    TH: TaskHandle,
124    B: Behavior,
125{
126    fn tick(&mut self) -> TickStatus {
127        match &mut self.state {
128            State::Idle => match self.behavior.task() {
129                Ok(task) => match self.registry.register(task) {
130                    Ok(handle) => {
131                        self.state = State::Running(handle);
132                        TickStatus::Running
133                    }
134                    Err(e) => {
135                        error!("task registration failed: {e}");
136                        TickStatus::Failure
137                    }
138                },
139                Err(e) => {
140                    error!("creating task failed: {e}");
141                    TickStatus::Failure
142                }
143            },
144            State::Running(handle) => {
145                let task_status = handle.query();
146                // calling hooks can modify the status of the task
147                let task_status = dispatch_hooks(&mut self.behavior, task_status);
148
149                let status: TickStatus = task_status.try_into().unwrap();
150                if status.is_terminal() {
151                    self.state = State::Idle;
152                }
153
154                status
155            }
156        }
157    }
158
159    fn reset(&mut self) {
160        assert!(
161            matches!(self.state, State::Idle),
162            "requested action reset during task execution"
163        );
164        self.behavior.reset();
165    }
166
167    fn abort(&mut self) {
168        let mut on_aborted = || {
169            if let Err(e) = self.behavior.on_aborted() {
170                error!("on aborted hook failed: {e}");
171            }
172        };
173
174        match &mut self.state {
175            State::Idle => {
176                on_aborted();
177            }
178            State::Running(task_handle) => {
179                task_handle.abort();
180                loop {
181                    let status = task_handle.query();
182                    if status.is_terminal() {
183                        debug!("aborted task terminal status: {status:?}");
184                        break;
185                    }
186                    std::thread::sleep(self.abort_poll_interval);
187                }
188                on_aborted();
189                debug!("switching state to idle");
190                self.state = State::Idle;
191            }
192        }
193    }
194}
195
196enum State<TH>
197where
198    TH: TaskHandle,
199{
200    Idle,
201    Running(TH),
202}
203
204impl<TH> fmt::Display for State<TH>
205where
206    TH: TaskHandle,
207{
208    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        match self {
210            Self::Idle => write!(f, "Idle"),
211            Self::Running(_) => write!(f, "Running"),
212        }
213    }
214}
215
#[cfg(test)]
mod tests {
    //! Unit tests for `Action`. The behavior, the task registry and the task
    //! handle are all mockall mocks, so each test pins exactly which hooks and
    //! handle methods a sequence of `tick`/`abort`/`reset` calls reaches.

    use std::str::FromStr;

    use bon::builder;
    use mockall::mock;

    use super::*;
    use crate::{
        Task, TaskDescription,
        task::{AbortTask, MockRegisterTask, QueryTask},
    };

    // Poll interval handed to every `Action` under test; only `abort` sleeps on it.
    const DEFAULT_ABORT_INTERVAL: Duration = Duration::from_millis(10);

    // Mock task handle. It implements only `QueryTask` and `AbortTask`, yet is
    // used as the `TaskHandle` type parameter below, so `TaskHandle` is
    // presumably satisfied through those two traits — confirm in `crate::task`.
    mock! {
        TaskHandle {}

    impl QueryTask for TaskHandle {
        fn query(&mut self) -> TaskStatus;
    }

    impl AbortTask for TaskHandle {
        fn abort(&mut self);
    }
    }

    // Placeholder task. Every mocked `register` ignores its argument, so `run`
    // is never executed by these tests.
    struct TaskStub;

    impl TaskStub {
        fn new() -> Self {
            Self {}
        }
    }

    impl Task for TaskStub {
        async fn run(self) -> TickStatus {
            TickStatus::Success
        }
        fn task_desc(&self) -> TaskDescription {
            TaskDescription::from_str("TaskStub").unwrap()
        }
    }

    // Builds a mock handle whose `query` yields `statuses` in order and must be
    // called exactly `query_times` times. `abort` is expected `abort_times`
    // times when given; otherwise no `abort` expectation is set at all.
    // `#[builder]` (bon) exposes this as `task_handle().query_times(..)...call()`.
    #[builder]
    fn task_handle(
        query_times: usize,
        statuses: Vec<TaskStatus>,
        abort_times: Option<usize>,
    ) -> MockTaskHandle {
        let mut m = MockTaskHandle::new();
        let mut it = statuses.into_iter();
        m.expect_query()
            .returning(move || it.next().unwrap())
            .times(query_times);

        if let Some(abort_times) = abort_times {
            m.expect_abort().times(abort_times).return_const(());
        }
        m
    }

    // First tick registers the task; the second tick observes Success, runs
    // `on_success` once and reports Success.
    #[test]
    fn action_success() {
        let mut registry = MockRegisterTask::<MockTaskHandle>::new();
        registry
            .expect_register()
            .returning(|_| {
                Ok(task_handle()
                    .query_times(1)
                    .statuses(vec![TaskStatus::Success])
                    .call())
            })
            .once();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_task()
            .returning(|| Ok(ActionTask::new(TaskStub::new())));
        behavior
            .expect_on_success()
            .once()
            .returning(|| Result::Ok(()));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);
        assert_eq!(action.tick(), TickStatus::Running);
        assert_eq!(action.tick(), TickStatus::Success);
    }

    // The task is still running on the second tick: `on_running` fires once and
    // Running is reported.
    #[test]
    fn action_running() {
        let mut registry = MockRegisterTask::<MockTaskHandle>::new();
        registry
            .expect_register()
            .returning(|_| {
                Ok(task_handle()
                    .query_times(1)
                    .statuses(vec![TaskStatus::Running])
                    .call())
            })
            .once();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_task()
            .returning(|| Ok(ActionTask::new(TaskStub::new())));
        behavior
            .expect_on_running()
            .once()
            .returning(|| Result::Ok(()));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);
        assert_eq!(action.tick(), TickStatus::Running);
        assert_eq!(action.tick(), TickStatus::Running);
    }

    // The second tick observes Failure: `on_failure` fires once and Failure is
    // reported.
    #[test]
    fn action_failure() {
        let mut registry = MockRegisterTask::<MockTaskHandle>::new();
        registry
            .expect_register()
            .returning(|_| {
                Ok(task_handle()
                    .query_times(1)
                    .statuses(vec![TaskStatus::Failure])
                    .call())
            })
            .once();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_task()
            .returning(|| Ok(ActionTask::new(TaskStub::new())));
        behavior
            .expect_on_failure()
            .once()
            .returning(|| Result::Ok(()));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);

        assert_eq!(action.tick(), TickStatus::Running);
        assert_eq!(action.tick(), TickStatus::Failure);
    }

    // `Behavior::task` errors: the tick fails and the registry (which has no
    // expectations) is never called.
    #[test]
    fn task_creation_failure() {
        let registry = MockRegisterTask::<MockTaskHandle>::new();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_task()
            .returning(|| Err(anyhow::anyhow!("task creation failed")));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);
        assert_eq!(action.tick(), TickStatus::Failure);
    }

    // Registration errors: the tick fails.
    #[test]
    fn task_registration_failure() {
        let mut registry = MockRegisterTask::<MockTaskHandle>::new();
        registry
            .expect_register()
            .returning(|_| Err(anyhow::anyhow!("registration failed")))
            .once();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_task()
            .returning(|| Ok(ActionTask::new(TaskStub::new())));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);
        assert_eq!(action.tick(), TickStatus::Failure);
    }

    // Tick 1 only registers; tick 2 consumes the first Running (+ `on_running`).
    // `abort` then calls `abort` on the handle and polls: the second Running
    // triggers one sleep and Aborted ends the loop, so `query` runs 3 times in
    // total before `on_aborted` fires.
    #[test]
    fn action_abort_when_running() {
        let mut registry = MockRegisterTask::<MockTaskHandle>::new();
        registry
            .expect_register()
            .returning(|_| {
                Ok(task_handle()
                    .query_times(3)
                    .statuses(vec![
                        TaskStatus::Running,
                        TaskStatus::Running,
                        TaskStatus::Aborted,
                    ])
                    .abort_times(1)
                    .call())
            })
            .once();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_task()
            .returning(|| Ok(ActionTask::new(TaskStub::new())));
        behavior
            .expect_on_aborted()
            .once()
            .returning(|| Result::Ok(()));
        behavior
            .expect_on_running()
            .once()
            .returning(|| Result::Ok(()));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);

        assert_eq!(action.tick(), TickStatus::Running);
        assert_eq!(action.tick(), TickStatus::Running);
        action.abort();
    }

    // Abort with no task in flight only runs `on_aborted`.
    #[test]
    fn action_abort_when_idle() {
        let registry = MockRegisterTask::<MockTaskHandle>::new();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_on_aborted()
            .once()
            .returning(|| Result::Ok(()));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);

        action.abort();
    }

    // Reset on an idle action forwards to `Behavior::reset` exactly once.
    #[test]
    fn action_reset() {
        let registry = MockRegisterTask::<MockTaskHandle>::new();

        let mut behavior = MockBehavior::new();
        behavior.expect_reset().once().return_const(());

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);
        action.reset();
    }
}