1use ractor::concurrency::Duration;
94use ractor::concurrency::JoinHandle;
95use ractor::{Actor, ActorCell, ActorName, ActorProcessingErr, ActorRef, SpawnErr};
96use std::future::Future;
97use std::pin::Pin;
98use std::sync::Arc;
99use uuid::Uuid;
100
101use crate::core::ChildSpec;
102use crate::{
103 ChildBackoffFn, DynamicSupervisor, DynamicSupervisorMsg, DynamicSupervisorOptions, Restart,
104 SpawnFn,
105};
106
/// Actor that runs a single [`TaskFn`] invocation once it has started.
///
/// It carries no data of its own; the task factory is passed as the actor's
/// startup argument and kept as its state.
#[derive(Debug, Clone, Copy, Default)]
pub struct TaskActor;
109
/// Message type of [`TaskActor`].
pub enum TaskActorMessage {
    /// Carries a task factory.
    ///
    /// NOTE(review): the `Actor` impl for `TaskActor` in this file does not
    /// override `handle`, so this variant is not acted upon there. The task
    /// that actually runs is the one supplied as the actor's startup argument.
    Run { task: TaskFn },
}
115
116type TaskFuture = Pin<Box<dyn Future<Output = Result<(), ActorProcessingErr>> + Send>>;
118
119#[derive(Clone)]
121pub struct TaskFn(Arc<dyn Fn() -> TaskFuture + Send + Sync>);
122
123impl TaskFn {
124 pub fn new<F, Fut>(factory: F) -> Self
126 where
127 F: Fn() -> Fut + Send + Sync + 'static,
128 Fut: Future<Output = Result<(), ActorProcessingErr>> + Send + 'static,
129 {
130 TaskFn(Arc::new(move || Box::pin(factory())))
131 }
132}
133
134#[cfg_attr(feature = "async-trait", ractor::async_trait)]
135impl Actor for TaskActor {
136 type Msg = TaskActorMessage;
137 type State = TaskFn;
138 type Arguments = TaskFn;
139
140 async fn pre_start(
141 &self,
142 _myself: ActorRef<Self::Msg>,
143 task: Self::Arguments,
144 ) -> Result<Self::State, ActorProcessingErr> {
145 Ok(task)
146 }
147
148 async fn post_start(
149 &self,
150 myself: ActorRef<Self::Msg>,
151 task: &mut Self::State,
152 ) -> Result<(), ActorProcessingErr> {
153 (task.0)().await?;
154 myself.stop(None);
155 Ok(())
156 }
157}
158
/// Message type of a task supervisor; identical to [`DynamicSupervisorMsg`].
pub type TaskSupervisorMsg = DynamicSupervisorMsg;
/// Startup options of a task supervisor; identical to [`DynamicSupervisorOptions`].
pub type TaskSupervisorOptions = DynamicSupervisorOptions;
161
/// Namespace for spawning a [`DynamicSupervisor`] and running async closures
/// as its supervised children.
///
/// The type holds no data; all functionality lives in associated functions,
/// and the supervisor actor they create is a `DynamicSupervisor`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TaskSupervisor;
163
164pub struct TaskOptions {
166 pub name: ActorName,
167 pub restart: Restart,
168 pub backoff_fn: Option<ChildBackoffFn>,
169 pub reset_after: Option<Duration>,
172}
173
174impl Default for TaskOptions {
175 fn default() -> Self {
176 Self {
177 name: Uuid::new_v4().to_string(),
178 restart: Restart::Temporary,
179 backoff_fn: None,
180 reset_after: None,
181 }
182 }
183}
184
185impl TaskOptions {
186 pub fn new() -> Self {
187 Self::default()
188 }
189
190 pub fn name(mut self, name: String) -> Self {
191 self.name = name;
192 self
193 }
194
195 pub fn restart_policy(mut self, restart: Restart) -> Self {
196 self.restart = restart;
197 self
198 }
199
200 pub fn backoff_fn(mut self, backoff_fn: ChildBackoffFn) -> Self {
201 self.backoff_fn = Some(backoff_fn);
202 self
203 }
204
205 pub fn reset_after(mut self, duration: Duration) -> Self {
207 self.reset_after = Some(duration);
208 self
209 }
210}
211
212impl TaskSupervisor {
213 pub async fn spawn(
214 name: ActorName,
215 options: TaskSupervisorOptions,
216 ) -> Result<(ActorRef<TaskSupervisorMsg>, JoinHandle<()>), SpawnErr> {
217 DynamicSupervisor::spawn(name, options).await
218 }
219
220 pub async fn spawn_linked(
221 name: ActorName,
222 startup_args: TaskSupervisorOptions,
223 supervisor: ActorCell,
224 ) -> Result<(ActorRef<TaskSupervisorMsg>, JoinHandle<()>), SpawnErr> {
225 Actor::spawn_linked(Some(name), DynamicSupervisor, startup_args, supervisor).await
226 }
227
228 pub async fn spawn_task<F, Fut>(
229 supervisor: ActorRef<TaskSupervisorMsg>,
230 task: F,
231 options: TaskOptions,
232 ) -> Result<String, ActorProcessingErr>
233 where
234 F: Fn() -> Fut + Send + Sync + 'static,
235 Fut: Future<Output = Result<(), ActorProcessingErr>> + Send + 'static,
236 {
237 let child_id = options.name;
238 let task_wrapper = TaskFn::new(task);
239
240 let spec = ChildSpec {
241 id: child_id.clone(),
242 spawn_fn: SpawnFn::new({
243 let task_wrapper = task_wrapper.clone();
244 move |sup, id| spawn_task_actor(id, task_wrapper.clone(), sup)
245 }),
246 restart: options.restart,
247 backoff_fn: options.backoff_fn,
248 reset_after: options.reset_after,
249 };
250
251 DynamicSupervisor::spawn_child(supervisor, spec).await?;
252 Ok(child_id)
253 }
254
255 pub async fn terminate_task(
256 supervisor: ActorRef<TaskSupervisorMsg>,
257 task_id: String,
258 ) -> Result<(), ActorProcessingErr> {
259 DynamicSupervisor::terminate_child(supervisor, task_id).await
260 }
261}
262
263async fn spawn_task_actor(id: String, task: TaskFn, sup: ActorCell) -> Result<ActorCell, SpawnErr> {
264 let (child_ref, _join) = DynamicSupervisor::spawn_linked(id, TaskActor, task, sup).await?;
265 Ok(child_ref.get_cell())
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use ractor::{
        call,
        concurrency::{sleep, Duration},
        ActorStatus,
    };
    use serial_test::serial;
    use tokio::sync::mpsc;

    // Short pause at the start of every test. NOTE(review): presumably gives
    // the previous (serialised) test's "test-supervisor" time to finish
    // shutting down so the name can be reused — confirm.
    async fn before_each() {
        sleep(Duration::from_millis(10)).await;
    }

    // A task with default (Temporary) options runs, signals once through the
    // channel, and is no longer in the supervisor's active children afterwards.
    #[ractor::concurrency::test]
    #[serial]
    async fn test_basic_task_execution() {
        before_each().await;

        let (supervisor, handle) = TaskSupervisor::spawn(
            "test-supervisor".into(),
            TaskSupervisorOptions {
                max_children: Some(10),
                max_restarts: 3,
                max_window: Duration::from_secs(10),
                reset_after: Some(Duration::from_secs(30)),
            },
        )
        .await
        .unwrap();

        let (tx, mut rx) = mpsc::channel(1);

        let task_id = TaskSupervisor::spawn_task(
            supervisor.clone(),
            move || {
                let tx = tx.clone();
                async move {
                    tx.send(()).await.unwrap();
                    Ok(())
                }
            },
            TaskOptions::new().name("background-task".into()),
        )
        .await
        .unwrap();

        rx.recv().await.expect("Task should have executed");
        // Allow time for the actor to stop itself and for the supervisor to
        // observe the exit before its state is inspected.
        sleep(Duration::from_millis(100)).await;
        let state = call!(supervisor, DynamicSupervisorMsg::InspectState).unwrap();

        assert!(!state.active_children.contains_key(&task_id));

        supervisor.stop(None);
        let _ = handle.await;
    }

    // Terminating a Permanent task should prevent it from ever signalling.
    #[ractor::concurrency::test]
    #[serial]
    async fn test_task_termination() {
        before_each().await;

        let (supervisor, handle) = TaskSupervisor::spawn(
            "test-supervisor".into(),
            TaskSupervisorOptions {
                max_children: Some(10),
                max_restarts: 3,
                max_window: Duration::from_secs(1),
                reset_after: Some(Duration::from_secs(1000)),
            },
        )
        .await
        .unwrap();

        let (tx, mut rx) = mpsc::channel(1);
        let task_id = TaskSupervisor::spawn_task(
            supervisor.clone(),
            move || {
                let tx = tx.clone();
                async move {
                    sleep(Duration::from_secs(10)).await;
                    tx.send(()).await.unwrap();
                    Ok(())
                }
            },
            TaskOptions::new().restart_policy(Restart::Permanent),
        )
        .await
        .unwrap();

        TaskSupervisor::terminate_task(supervisor.clone(), task_id.clone())
            .await
            .unwrap();

        // NOTE(review): the task sleeps 10 s before sending, so nothing can
        // arrive within this 100 ms window even if termination had no effect;
        // this assertion alone does not prove the task was stopped. Also, if
        // every sender were dropped, `recv` would resolve to `None` before the
        // timeout and the assertion would fail.
        let result = tokio::time::timeout(Duration::from_millis(100), rx.recv()).await;
        assert!(result.is_err(), "Task should have been terminated");

        supervisor.stop(None);
        let _ = handle.await;
    }

    // A Transient task that panics on every run is restarted until the
    // supervisor gives up, after which no child is left running.
    #[ractor::concurrency::test]
    #[serial]
    async fn test_restart_policy() {
        before_each().await;

        let (supervisor, handle) = TaskSupervisor::spawn(
            "test-supervisor".into(),
            TaskSupervisorOptions {
                max_children: Some(10),
                max_restarts: 3,
                max_window: Duration::from_secs(1),
                reset_after: Some(Duration::from_secs(1000)),
            },
        )
        .await
        .unwrap();

        let (tx, mut rx) = mpsc::channel(3);
        let _task_id = TaskSupervisor::spawn_task(
            supervisor.clone(),
            move || {
                let tx = tx.clone();
                async move {
                    tx.send(()).await.unwrap();
                    panic!("Simulated failure");
                }
            },
            TaskOptions::new()
                .restart_policy(Restart::Transient)
                .name("restart-test".into()),
        )
        .await
        .unwrap();

        // Four executions are expected. NOTE(review): presumably the initial
        // run plus `max_restarts` (3) restarts — confirm against
        // DynamicSupervisor's restart-limit semantics.
        for _ in 0..4 {
            rx.recv().await.expect("Task should have restarted");
        }

        sleep(Duration::from_millis(100)).await;
        assert!(!supervisor
            .get_children()
            .iter()
            .any(|cell| cell.get_status() == ActorStatus::Running));

        supervisor.stop(None);
        let _ = handle.await;
    }
}