ractor_supervisor/
task.rs

1//! Task supervisor for managing supervised async tasks.
2//!
3//! The `TaskSupervisor` is a specialized version of [`DynamicSupervisor`](crate::DynamicSupervisor) that makes it easy
4//! to run async tasks (futures) under supervision. It wraps each task in a lightweight actor that can be monitored
5//! and restarted according to the configured policy.
6//!
7//! ## Use Cases
8//!
9//! The `TaskSupervisor` is ideal for:
10//! - Background jobs that need supervision
11//! - Periodic tasks that should be restarted on failure
12//! - Long-running async operations that need monitoring
13//! - Any async work that should be part of your supervision tree
14//!
15//! ## Key Features
16//!
17//! 1. **Simple API**: Wrap any async task in supervision with minimal boilerplate
18//! 2. **Full Supervision**: Tasks get all the benefits of actor supervision
19//! 3. **Flexible Policies**: Control restart behavior via [`TaskOptions`]
20//! 4. **Resource Control**: Inherit `max_children` and other limits from `DynamicSupervisor`
21//!
22//! ## Example
23//!
24//! ```rust
25//! use ractor_supervisor::*;
26//! use ractor::concurrency::Duration;
27//! use tokio::time::sleep;
28//!
29//! #[tokio::main]
30//! async fn main() {
31//!     // Configure the task supervisor
32//!     let options = TaskSupervisorOptions {
33//!         max_children: Some(10),
34//!         max_restarts: 3,
35//!         max_window: Duration::from_secs(10),
36//!         reset_after: Some(Duration::from_secs(30)),
37//!     };
38//!
39//!     // Spawn the supervisor
40//!     let (sup_ref, _) = TaskSupervisor::spawn(
41//!         "task-sup".into(),
42//!         options
43//!     ).await.unwrap();
44//!
45//!     // Define a task that might fail
46//!     let task_id = TaskSupervisor::spawn_task(
47//!         sup_ref.clone(),
48//!         || async {
49//!             // Simulate some work
50//!             sleep(Duration::from_secs(1)).await;
51//!             
52//!             // Maybe fail sometimes...
53//!             if rand::random() {
54//!                 panic!("Random failure!");
55//!             }
56//!
57//!             Ok(())
58//!         },
59//!         TaskOptions::new()
60//!             .name("periodic-job".into())
61//!             .restart_policy(Restart::Permanent)
62//!     ).await.unwrap();
63//!
64//!     // Later, stop the task if needed
65//!     TaskSupervisor::terminate_task(sup_ref, task_id).await.unwrap();
66//!
67//!     ()
68//! }
69//! ```
70//!
71//! ## Task Lifecycle
72//!
73//! 1. When you call [`TaskSupervisor::spawn_task`]:
74//!    - Your async task is wrapped in a [`TaskActor`]
75//!    - The actor is spawned under the supervisor
76//!    - The task starts executing immediately
77//!
78//! 2. If the task completes normally:
79//!    - The actor stops normally
80//!    - If policy is [`Restart::Permanent`], it's restarted
81//!    - If policy is [`Restart::Transient`] or [`Restart::Temporary`], it's not restarted
82//!
//! 3. If the task panics or returns an `Err`:
84//!    - The actor fails abnormally
85//!    - If policy is [`Restart::Permanent`] or [`Restart::Transient`], it's restarted
86//!    - If policy is [`Restart::Temporary`], it's not restarted
87//!
88//! 4. Restart behavior is controlled by:
89//!    - The [`TaskOptions::restart_policy`]
90//!    - The supervisor's meltdown settings
91//!    - Any configured backoff delays
92
93use ractor::concurrency::Duration;
94use ractor::concurrency::JoinHandle;
95use ractor::{Actor, ActorCell, ActorName, ActorProcessingErr, ActorRef, SpawnErr};
96use std::future::Future;
97use std::pin::Pin;
98use std::sync::Arc;
99use uuid::Uuid;
100
101use crate::core::ChildSpec;
102use crate::{
103    ChildBackoffFn, DynamicSupervisor, DynamicSupervisorMsg, DynamicSupervisorOptions, Restart,
104    SpawnFn,
105};
106
/// Zero-sized actor type used to run a single [`TaskFn`] under supervision.
///
/// Its [`Actor`] implementation receives the task as its spawn argument, runs it
/// once from `post_start`, and then asks itself to stop.
pub struct TaskActor;
109
110/// Messages that can be sent to a [`TaskActor`].
111pub enum TaskActorMessage {
112    /// Execute the wrapped task.
113    Run { task: TaskFn },
114}
115
116/// The result of a task execution.
117type TaskFuture = Pin<Box<dyn Future<Output = Result<(), ActorProcessingErr>> + Send>>;
118
119/// A wrapped async task that can be executed by a [`TaskActor`].
120#[derive(Clone)]
121pub struct TaskFn(Arc<dyn Fn() -> TaskFuture + Send + Sync>);
122
123impl TaskFn {
124    /// Create a new task wrapper from an async function.
125    pub fn new<F, Fut>(factory: F) -> Self
126    where
127        F: Fn() -> Fut + Send + Sync + 'static,
128        Fut: Future<Output = Result<(), ActorProcessingErr>> + Send + 'static,
129    {
130        TaskFn(Arc::new(move || Box::pin(factory())))
131    }
132}
133
134#[cfg_attr(feature = "async-trait", ractor::async_trait)]
135impl Actor for TaskActor {
136    type Msg = TaskActorMessage;
137    type State = TaskFn;
138    type Arguments = TaskFn;
139
140    async fn pre_start(
141        &self,
142        _myself: ActorRef<Self::Msg>,
143        task: Self::Arguments,
144    ) -> Result<Self::State, ActorProcessingErr> {
145        Ok(task)
146    }
147
148    async fn post_start(
149        &self,
150        myself: ActorRef<Self::Msg>,
151        task: &mut Self::State,
152    ) -> Result<(), ActorProcessingErr> {
153        (task.0)().await?;
154        myself.stop(None);
155        Ok(())
156    }
157}
158
159pub type TaskSupervisorMsg = DynamicSupervisorMsg;
160pub type TaskSupervisorOptions = DynamicSupervisorOptions;
161
/// Namespace type grouping the associated functions for spawning a task supervisor
/// and for starting and terminating tasks under it.
///
/// The supervisor actor itself is a [`DynamicSupervisor`]; this type carries no data.
pub struct TaskSupervisor;
163
164/// Options for configuring a task to be supervised.
165pub struct TaskOptions {
166    pub name: ActorName,
167    pub restart: Restart,
168    pub backoff_fn: Option<ChildBackoffFn>,
169    /// Per-task "reset" duration: if a task has not failed for the given period,
170    /// its failure count is reset.
171    pub reset_after: Option<Duration>,
172}
173
174impl Default for TaskOptions {
175    fn default() -> Self {
176        Self {
177            name: Uuid::new_v4().to_string(),
178            restart: Restart::Temporary,
179            backoff_fn: None,
180            reset_after: None,
181        }
182    }
183}
184
185impl TaskOptions {
186    pub fn new() -> Self {
187        Self::default()
188    }
189
190    pub fn name(mut self, name: String) -> Self {
191        self.name = name;
192        self
193    }
194
195    pub fn restart_policy(mut self, restart: Restart) -> Self {
196        self.restart = restart;
197        self
198    }
199
200    pub fn backoff_fn(mut self, backoff_fn: ChildBackoffFn) -> Self {
201        self.backoff_fn = Some(backoff_fn);
202        self
203    }
204
205    /// Set the per-task reset duration.
206    pub fn reset_after(mut self, duration: Duration) -> Self {
207        self.reset_after = Some(duration);
208        self
209    }
210}
211
212impl TaskSupervisor {
213    pub async fn spawn(
214        name: ActorName,
215        options: TaskSupervisorOptions,
216    ) -> Result<(ActorRef<TaskSupervisorMsg>, JoinHandle<()>), SpawnErr> {
217        DynamicSupervisor::spawn(name, options).await
218    }
219
220    pub async fn spawn_linked(
221        name: ActorName,
222        startup_args: TaskSupervisorOptions,
223        supervisor: ActorCell,
224    ) -> Result<(ActorRef<TaskSupervisorMsg>, JoinHandle<()>), SpawnErr> {
225        Actor::spawn_linked(Some(name), DynamicSupervisor, startup_args, supervisor).await
226    }
227
228    pub async fn spawn_task<F, Fut>(
229        supervisor: ActorRef<TaskSupervisorMsg>,
230        task: F,
231        options: TaskOptions,
232    ) -> Result<String, ActorProcessingErr>
233    where
234        F: Fn() -> Fut + Send + Sync + 'static,
235        Fut: Future<Output = Result<(), ActorProcessingErr>> + Send + 'static,
236    {
237        let child_id = options.name;
238        let task_wrapper = TaskFn::new(task);
239
240        let spec = ChildSpec {
241            id: child_id.clone(),
242            spawn_fn: SpawnFn::new({
243                let task_wrapper = task_wrapper.clone();
244                move |sup, id| spawn_task_actor(id, task_wrapper.clone(), sup)
245            }),
246            restart: options.restart,
247            backoff_fn: options.backoff_fn,
248            reset_after: options.reset_after,
249        };
250
251        DynamicSupervisor::spawn_child(supervisor, spec).await?;
252        Ok(child_id)
253    }
254
255    pub async fn terminate_task(
256        supervisor: ActorRef<TaskSupervisorMsg>,
257        task_id: String,
258    ) -> Result<(), ActorProcessingErr> {
259        DynamicSupervisor::terminate_child(supervisor, task_id).await
260    }
261}
262
263async fn spawn_task_actor(id: String, task: TaskFn, sup: ActorCell) -> Result<ActorCell, SpawnErr> {
264    let (child_ref, _join) = DynamicSupervisor::spawn_linked(id, TaskActor, task, sup).await?;
265    Ok(child_ref.get_cell())
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use ractor::{
        call,
        concurrency::{sleep, Duration},
        ActorStatus,
    };
    use serial_test::serial;
    use tokio::sync::mpsc;

    /// Short pause executed at the start of every test.
    async fn before_each() {
        let pause = Duration::from_millis(10);
        sleep(pause).await;
    }

    /// Spawns the supervisor shared by all tests: at most 10 children and 3 restarts,
    /// with the meltdown window and reset period chosen by the caller.
    async fn start_supervisor(
        max_window: Duration,
        reset_after: Duration,
    ) -> (ActorRef<TaskSupervisorMsg>, JoinHandle<()>) {
        let options = TaskSupervisorOptions {
            max_children: Some(10),
            max_restarts: 3,
            max_window,
            reset_after: Some(reset_after),
        };
        TaskSupervisor::spawn("test-supervisor".into(), options)
            .await
            .unwrap()
    }

    /// A temporary task runs once and is no longer tracked as active afterwards.
    #[ractor::concurrency::test]
    #[serial]
    async fn test_basic_task_execution() {
        before_each().await;

        let (supervisor, handle) =
            start_supervisor(Duration::from_secs(10), Duration::from_secs(30)).await;

        let (done_tx, mut done_rx) = mpsc::channel(1);

        let task_id = TaskSupervisor::spawn_task(
            supervisor.clone(),
            move || {
                let done_tx = done_tx.clone();
                async move {
                    done_tx.send(()).await.unwrap();
                    Ok(())
                }
            },
            TaskOptions::new().name("background-task".into()),
        )
        .await
        .unwrap();

        done_rx.recv().await.expect("Task should have executed");

        // Give the supervisor time to observe the child's exit before inspecting it.
        sleep(Duration::from_millis(100)).await;
        let state = call!(supervisor, DynamicSupervisorMsg::InspectState).unwrap();
        let still_active = state.active_children.contains_key(&task_id);

        assert!(!still_active);

        supervisor.stop(None);
        let _ = handle.await;
    }

    /// A long-running task that is terminated early never reports completion.
    #[ractor::concurrency::test]
    #[serial]
    async fn test_task_termination() {
        before_each().await;

        let (supervisor, handle) =
            start_supervisor(Duration::from_secs(1), Duration::from_secs(1000)).await;

        let (done_tx, mut done_rx) = mpsc::channel(1);
        let task_id = TaskSupervisor::spawn_task(
            supervisor.clone(),
            move || {
                let done_tx = done_tx.clone();
                async move {
                    sleep(Duration::from_secs(10)).await;
                    done_tx.send(()).await.unwrap();
                    Ok(())
                }
            },
            TaskOptions::new().restart_policy(Restart::Permanent),
        )
        .await
        .unwrap();

        // Terminate before the task's 10 second sleep can finish.
        TaskSupervisor::terminate_task(supervisor.clone(), task_id.clone())
            .await
            .unwrap();

        // No completion signal may arrive within the timeout.
        let wait = Duration::from_millis(100);
        let result = tokio::time::timeout(wait, done_rx.recv()).await;
        assert!(result.is_err(), "Task should have been terminated");

        supervisor.stop(None);
        let _ = handle.await;
    }

    /// A transient task that always panics is run four times in total
    /// (the initial run plus three restarts) and then nothing is left running.
    #[ractor::concurrency::test]
    #[serial]
    async fn test_restart_policy() {
        before_each().await;

        let (supervisor, handle) =
            start_supervisor(Duration::from_secs(1), Duration::from_secs(1000)).await;

        let (run_tx, mut run_rx) = mpsc::channel(3);
        let _task_id = TaskSupervisor::spawn_task(
            supervisor.clone(),
            move || {
                let run_tx = run_tx.clone();
                async move {
                    run_tx.send(()).await.unwrap();
                    panic!("Simulated failure");
                }
            },
            TaskOptions::new()
                .restart_policy(Restart::Transient)
                .name("restart-test".into()),
        )
        .await
        .unwrap();

        // One signal per run: the initial start followed by three restarts.
        for _run in 0..4 {
            run_rx.recv().await.expect("Task should have restarted");
        }

        // After the restart budget is exhausted no child may still be running.
        sleep(Duration::from_millis(100)).await;
        let none_running = supervisor
            .get_children()
            .iter()
            .all(|cell| cell.get_status() != ActorStatus::Running);
        assert!(none_running);

        supervisor.stop(None);
        let _ = handle.await;
    }
}
421}