1use core::fmt;
2use std::{sync::Arc, time::Duration};
3
4use anyhow::Result;
5use tracing::{debug, error};
6
7use crate::{
8 ActionTask, Node, TickStatus,
9 task::{RegisterTask, TaskHandle, TaskStatus},
10};
11
/// User-supplied logic of an [`Action`] leaf node.
///
/// An implementation builds the task that the action registers and is notified, through the
/// `on_*` hooks, of the status the task reports. All hooks except [`Behavior::task`] have no-op
/// default implementations.
#[cfg_attr(any(test, feature = "mock"), mockall::automock)]
pub trait Behavior {
    /// Builds the task to register when the owning [`Action`] is ticked while idle.
    ///
    /// # Errors
    /// An error is logged by the action, no task is registered, and that tick reports `Failure`.
    fn task(&mut self) -> Result<ActionTask>;

    /// Resets internal state. Invoked by the action's `Node::reset`, which only allows
    /// resetting while no task is running. No-op by default.
    fn reset(&mut self) {}

    /// Called each time the task is queried and reports `Running`.
    ///
    /// # Errors
    /// An error is logged and that tick reports `Failure` instead of `Running`.
    fn on_running(&mut self) -> Result<()> {
        Ok(())
    }
    /// Called when the queried task reports `Success`.
    ///
    /// # Errors
    /// An error is logged and that tick reports `Failure` instead of `Success`.
    fn on_success(&mut self) -> Result<()> {
        Ok(())
    }

    /// Called when the queried task reports `Failure`.
    ///
    /// # Errors
    /// An error is logged; the tick still reports `Failure`.
    fn on_failure(&mut self) -> Result<()> {
        Ok(())
    }

    /// Called when the queried task reports `Aborted`, and also once per `Node::abort` on the
    /// owning action (whether or not a task was running).
    ///
    /// # Errors
    /// During a tick an error turns the result into `Failure`; during `abort` it is only logged.
    fn on_aborted(&mut self) -> Result<()> {
        Ok(())
    }
}
46
47pub type BoxBehavior = Box<dyn Behavior>;
48impl Behavior for BoxBehavior {
49 fn task(&mut self) -> Result<ActionTask> {
50 (**self).task()
51 }
52 fn reset(&mut self) {
53 (**self).reset();
54 }
55 fn on_running(&mut self) -> Result<()> {
56 (**self).on_running()
57 }
58 fn on_success(&mut self) -> Result<()> {
59 (**self).on_success()
60 }
61 fn on_failure(&mut self) -> Result<()> {
62 (**self).on_failure()
63 }
64 fn on_aborted(&mut self) -> Result<()> {
65 (**self).on_aborted()
66 }
67}
68
69#[must_use]
70fn dispatch_hooks(behavior: &mut impl Behavior, status: TaskStatus) -> TaskStatus {
71 match status {
72 TaskStatus::Success => behavior.on_success(),
73 TaskStatus::Running => behavior.on_running(),
74 TaskStatus::Failure => behavior.on_failure(),
75 TaskStatus::Aborted => behavior.on_aborted(),
76 }
77 .inspect_err(|e| error!("error during action hook invocation: {e}"))
78 .map(|()| status)
79 .unwrap_or(TaskStatus::Failure)
80}
81
/// Behavior-tree leaf node that runs a [`Behavior`]'s task through a task registry.
///
/// Generic parameters:
/// - `R`: registry that registers an [`ActionTask`] and hands back a handle of type `TH`.
/// - `TH`: handle used to query and abort the registered task.
/// - `B`: the user-supplied [`Behavior`] that builds tasks and receives status hooks.
pub struct Action<R, TH, B>
where
    R: RegisterTask<TH>,
    TH: TaskHandle,
    B: Behavior,
{
    // Builds tasks and receives the status hooks.
    behavior: B,
    // Registry the behavior's tasks are registered with (shared through an `Arc`).
    registry: Arc<R>,
    // Sleep between status polls while `abort` waits for the task to terminate.
    abort_poll_interval: Duration,
    // Idle, or Running with the handle of the registered task.
    state: State<TH>,
}
103
104impl<R, TH, B> Action<R, TH, B>
105where
106 R: RegisterTask<TH>,
107 TH: TaskHandle,
108 B: Behavior,
109{
110 pub fn new(behavior: B, registry: Arc<R>, abort_poll_interval: Duration) -> Self {
111 Self {
112 behavior,
113 registry,
114 abort_poll_interval,
115 state: State::Idle,
116 }
117 }
118}
119
120impl<R, TH, B> Node for Action<R, TH, B>
121where
122 R: RegisterTask<TH>,
123 TH: TaskHandle,
124 B: Behavior,
125{
126 fn tick(&mut self) -> TickStatus {
127 match &mut self.state {
128 State::Idle => match self.behavior.task() {
129 Ok(task) => match self.registry.register(task) {
130 Ok(handle) => {
131 self.state = State::Running(handle);
132 TickStatus::Running
133 }
134 Err(e) => {
135 error!("task registration failed: {e}");
136 TickStatus::Failure
137 }
138 },
139 Err(e) => {
140 error!("creating task failed: {e}");
141 TickStatus::Failure
142 }
143 },
144 State::Running(handle) => {
145 let task_status = handle.query();
146 let task_status = dispatch_hooks(&mut self.behavior, task_status);
148
149 let status: TickStatus = task_status.try_into().unwrap();
150 if status.is_terminal() {
151 self.state = State::Idle;
152 }
153
154 status
155 }
156 }
157 }
158
159 fn reset(&mut self) {
160 assert!(
161 matches!(self.state, State::Idle),
162 "requested action reset during task execution"
163 );
164 self.behavior.reset();
165 }
166
167 fn abort(&mut self) {
168 let mut on_aborted = || {
169 if let Err(e) = self.behavior.on_aborted() {
170 error!("on aborted hook failed: {e}");
171 }
172 };
173
174 match &mut self.state {
175 State::Idle => {
176 on_aborted();
177 }
178 State::Running(task_handle) => {
179 task_handle.abort();
180 loop {
181 let status = task_handle.query();
182 if status.is_terminal() {
183 debug!("aborted task terminal status: {status:?}");
184 break;
185 }
186 std::thread::sleep(self.abort_poll_interval);
187 }
188 on_aborted();
189 debug!("switching state to idle");
190 self.state = State::Idle;
191 }
192 }
193 }
194}
195
/// Execution state of an [`Action`].
enum State<TH>
where
    TH: TaskHandle,
{
    /// No task registered; the next tick builds and registers one.
    Idle,
    /// A task is registered; holds the handle used to query or abort it.
    Running(TH),
}
203
204impl<TH> fmt::Display for State<TH>
205where
206 TH: TaskHandle,
207{
208 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 match self {
210 Self::Idle => write!(f, "Idle"),
211 Self::Running(_) => write!(f, "Running"),
212 }
213 }
214}
215
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use bon::builder;
    use mockall::mock;

    use super::*;
    use crate::{
        Task, TaskDescription,
        task::{AbortTask, MockRegisterTask, QueryTask},
    };

    /// Poll interval handed to `Action::new`; short so the abort tests stay fast.
    const DEFAULT_ABORT_INTERVAL: Duration = Duration::from_millis(10);

    mock! {
        TaskHandle {}

        impl QueryTask for TaskHandle {
            fn query(&mut self) -> TaskStatus;
        }

        impl AbortTask for TaskHandle {
            fn abort(&mut self);
        }
    }

    /// Minimal task; the mocked registry ignores it, so it never actually runs in these tests.
    struct TaskStub;

    impl TaskStub {
        fn new() -> Self {
            Self {}
        }
    }

    impl Task for TaskStub {
        async fn run(self) -> TickStatus {
            TickStatus::Success
        }
        fn task_desc(&self) -> TaskDescription {
            TaskDescription::from_str("TaskStub").unwrap()
        }
    }

    /// Builds a mock handle whose `query` returns `statuses` in order and must be called
    /// exactly `query_times` times. If `abort_times` is given, `abort` must be called exactly
    /// that many times.
    #[builder]
    fn task_handle(
        query_times: usize,
        statuses: Vec<TaskStatus>,
        abort_times: Option<usize>,
    ) -> MockTaskHandle {
        let mut m = MockTaskHandle::new();
        let mut it = statuses.into_iter();
        m.expect_query()
            .returning(move || it.next().unwrap())
            .times(query_times);

        if let Some(abort_times) = abort_times {
            m.expect_abort().times(abort_times).return_const(());
        }
        m
    }

    #[test]
    fn action_success() {
        let mut registry = MockRegisterTask::<MockTaskHandle>::new();
        registry
            .expect_register()
            .returning(|_| {
                Ok(task_handle()
                    .query_times(1)
                    .statuses(vec![TaskStatus::Success])
                    .call())
            })
            .once();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_task()
            .returning(|| Ok(ActionTask::new(TaskStub::new())));
        behavior
            .expect_on_success()
            .once()
            .returning(|| Result::Ok(()));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);
        assert_eq!(action.tick(), TickStatus::Running);
        assert_eq!(action.tick(), TickStatus::Success);
    }

    #[test]
    fn action_running() {
        let mut registry = MockRegisterTask::<MockTaskHandle>::new();
        registry
            .expect_register()
            .returning(|_| {
                Ok(task_handle()
                    .query_times(1)
                    .statuses(vec![TaskStatus::Running])
                    .call())
            })
            .once();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_task()
            .returning(|| Ok(ActionTask::new(TaskStub::new())));
        behavior
            .expect_on_running()
            .once()
            .returning(|| Result::Ok(()));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);
        assert_eq!(action.tick(), TickStatus::Running);
        assert_eq!(action.tick(), TickStatus::Running);
    }

    #[test]
    fn action_failure() {
        let mut registry = MockRegisterTask::<MockTaskHandle>::new();
        registry
            .expect_register()
            .returning(|_| {
                Ok(task_handle()
                    .query_times(1)
                    .statuses(vec![TaskStatus::Failure])
                    .call())
            })
            .once();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_task()
            .returning(|| Ok(ActionTask::new(TaskStub::new())));
        behavior
            .expect_on_failure()
            .once()
            .returning(|| Result::Ok(()));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);

        assert_eq!(action.tick(), TickStatus::Running);
        assert_eq!(action.tick(), TickStatus::Failure);
    }

    /// A hook that returns an error must turn an otherwise successful tick into `Failure`.
    #[test]
    fn failing_hook_turns_status_into_failure() {
        let mut registry = MockRegisterTask::<MockTaskHandle>::new();
        registry
            .expect_register()
            .returning(|_| {
                Ok(task_handle()
                    .query_times(1)
                    .statuses(vec![TaskStatus::Success])
                    .call())
            })
            .once();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_task()
            .returning(|| Ok(ActionTask::new(TaskStub::new())));
        behavior
            .expect_on_success()
            .once()
            .returning(|| Err(anyhow::anyhow!("on_success hook failed")));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);

        assert_eq!(action.tick(), TickStatus::Running);
        assert_eq!(action.tick(), TickStatus::Failure);
    }

    #[test]
    fn task_creation_failure() {
        let registry = MockRegisterTask::<MockTaskHandle>::new();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_task()
            .returning(|| Err(anyhow::anyhow!("task creation failed")));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);
        assert_eq!(action.tick(), TickStatus::Failure);
    }

    #[test]
    fn task_registration_failure() {
        let mut registry = MockRegisterTask::<MockTaskHandle>::new();
        registry
            .expect_register()
            .returning(|_| Err(anyhow::anyhow!("registration failed")))
            .once();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_task()
            .returning(|| Ok(ActionTask::new(TaskStub::new())));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);
        assert_eq!(action.tick(), TickStatus::Failure);
    }

    #[test]
    fn action_abort_when_running() {
        let mut registry = MockRegisterTask::<MockTaskHandle>::new();
        registry
            .expect_register()
            .returning(|_| {
                Ok(task_handle()
                    .query_times(3)
                    .statuses(vec![
                        TaskStatus::Running,
                        TaskStatus::Running,
                        TaskStatus::Aborted,
                    ])
                    .abort_times(1)
                    .call())
            })
            .once();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_task()
            .returning(|| Ok(ActionTask::new(TaskStub::new())));
        behavior
            .expect_on_aborted()
            .once()
            .returning(|| Result::Ok(()));
        behavior
            .expect_on_running()
            .once()
            .returning(|| Result::Ok(()));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);

        assert_eq!(action.tick(), TickStatus::Running);
        assert_eq!(action.tick(), TickStatus::Running);
        action.abort();
    }

    #[test]
    fn action_abort_when_idle() {
        let registry = MockRegisterTask::<MockTaskHandle>::new();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_on_aborted()
            .once()
            .returning(|| Result::Ok(()));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);

        action.abort();
    }

    #[test]
    fn action_reset() {
        let registry = MockRegisterTask::<MockTaskHandle>::new();

        let mut behavior = MockBehavior::new();
        behavior.expect_reset().once().return_const(());

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);
        action.reset();
    }

    /// Resetting while a task is registered violates `reset`'s precondition and must panic
    /// before the behavior's own `reset` is reached.
    #[test]
    #[should_panic(expected = "requested action reset during task execution")]
    fn reset_while_running_panics() {
        let mut registry = MockRegisterTask::<MockTaskHandle>::new();
        registry
            .expect_register()
            .returning(|_| Ok(task_handle().query_times(0).statuses(vec![]).call()))
            .once();

        let mut behavior = MockBehavior::new();
        behavior
            .expect_task()
            .returning(|| Ok(ActionTask::new(TaskStub::new())));

        let mut action = Action::new(behavior, Arc::new(registry), DEFAULT_ABORT_INTERVAL);
        assert_eq!(action.tick(), TickStatus::Running);
        action.reset();
    }
}