1use crate::{
8 Actor, ActorContext, ActorPath, Error, Handler, Message,
9 NotPersistentActor,
10 supervision::{RetryStrategy, Strategy},
11};
12
13use async_trait::async_trait;
14
15use std::{fmt::Debug, marker::PhantomData, time::Duration};
16use tracing::{debug, error, info_span};
17
18#[async_trait]
19trait CompletionNotifier<T>: Send + Sync
20where
21 T: Actor + Handler<T> + Clone + NotPersistentActor,
22{
23 async fn notify(&self, ctx: &ActorContext<RetryActor<T>>);
24}
25
26struct ParentMessageNotifier<T, P>
27where
28 T: Actor + Handler<T> + Clone + NotPersistentActor,
29 P: Actor + Handler<P>,
30{
31 message: P::Message,
32 _phantom: PhantomData<(T, P)>,
33}
34
35#[async_trait]
36impl<T, P> CompletionNotifier<T> for ParentMessageNotifier<T, P>
37where
38 T: Actor + Handler<T> + Clone + NotPersistentActor,
39 P: Actor + Handler<P>,
40{
41 async fn notify(&self, ctx: &ActorContext<RetryActor<T>>) {
42 if let Ok(parent) = ctx.get_parent::<P>().await {
43 let _ = parent.tell(self.message.clone()).await;
44 }
45 }
46}
47
/// Actor that repeatedly delivers a message to a child target actor,
/// following a retry [`Strategy`].
///
/// The target is spawned as the child `"target"` in `pre_start`. Each attempt
/// clones `message` and `tell`s it to that child. Follow-up attempts are
/// driven by [`RetryMessage`]s the actor sends to itself.
pub struct RetryActor<T>
where
    T: Actor + Handler<T> + Clone + NotPersistentActor,
{
    /// Actor instance that is cloned and spawned as the `"target"` child.
    target: T,
    /// Message that is cloned and sent to the target on every attempt.
    message: T::Message,
    /// Supplies the attempt limit (`max_retries`) and the delay before each
    /// follow-up attempt (`next_backoff`).
    retry_strategy: Strategy,
    /// Incremented on every attempt, including the final call that exceeds
    /// the limit and ends the cycle.
    retries: usize,
    /// Set by the first `RetryMessage::Retry`; later `Retry`s are ignored.
    started: bool,
    /// Set once the cycle is over; further attempts become no-ops.
    is_end: bool,
    /// True while a deferred `RetryMessage::Complete` is outstanding; only
    /// then does `Complete` stop the actor.
    completion_pending: bool,
    /// Ensures the completion notifier runs at most once.
    completion_notified: bool,
    /// Optional hook invoked when the cycle finishes.
    on_finished: Option<Box<dyn CompletionNotifier<T>>>,
    /// Handle of the spawned task that will send the next self-message: either
    /// `Continue` after a backoff or the deferred `Complete`. It is aborted
    /// when the cycle finishes and in `pre_stop`.
    pending_retry: Option<tokio::task::JoinHandle<()>>,
}
79
80impl<T> RetryActor<T>
81where
82 T: Actor + Handler<T> + Clone + NotPersistentActor,
83{
84 pub const fn new(
86 target: T,
87 message: T::Message,
88 retry_strategy: Strategy,
89 ) -> Self {
90 Self {
91 target,
92 message,
93 retry_strategy,
94 retries: 0,
95 started: false,
96 is_end: false,
97 completion_pending: false,
98 completion_notified: false,
99 on_finished: None,
100 pending_retry: None,
101 }
102 }
103
104 pub fn new_with_parent_message<P>(
111 target: T,
112 message: T::Message,
113 retry_strategy: Strategy,
114 completion_message: P::Message,
115 ) -> Self
116 where
117 P: Actor + Handler<P>,
118 {
119 Self {
120 target,
121 message,
122 retry_strategy,
123 retries: 0,
124 started: false,
125 is_end: false,
126 completion_pending: false,
127 completion_notified: false,
128 on_finished: Some(Box::new(ParentMessageNotifier::<T, P> {
129 message: completion_message,
130 _phantom: PhantomData,
131 })),
132 pending_retry: None,
133 }
134 }
135
136 async fn finish_retry_cycle(&mut self, ctx: &ActorContext<Self>) {
137 self.is_end = true;
138 if let Some(handle) = self.pending_retry.take() {
139 handle.abort();
140 }
141 if !self.completion_notified {
142 self.completion_notified = true;
143 } else {
144 ctx.stop(None).await;
145 return;
146 }
147
148 if let Some(notifier) = self.on_finished.as_ref() {
149 notifier.notify(ctx).await;
150 }
151
152 self.schedule_completion(ctx).await;
153 }
154
155 async fn schedule_completion(&mut self, ctx: &ActorContext<Self>) {
156 self.completion_pending = true;
157 if let Ok(actor) = ctx.reference().await {
158 self.pending_retry = Some(tokio::spawn(async move {
159 tokio::time::sleep(Duration::ZERO).await;
160 let _ = actor.tell(RetryMessage::Complete).await;
161 }));
162 } else {
163 ctx.stop(None).await;
164 }
165 }
166
167 async fn handle_retry_attempt(
168 &mut self,
169 ctx: &ActorContext<Self>,
170 ) -> Result<(), Error> {
171 if self.is_end {
172 return Ok(());
173 }
174
175 self.retries += 1;
176 if self.retries > self.retry_strategy.max_retries() {
177 self.finish_retry_cycle(ctx).await;
178 return Ok(());
179 }
180
181 debug!(
182 retry = self.retries,
183 max_retries = self.retry_strategy.max_retries(),
184 "Executing retry"
185 );
186
187 let child = match ctx.get_child::<T>("target").await {
190 Ok(child) => child,
191 Err(err) => {
192 error!(error = %err, "Retry target is not available");
193 self.finish_retry_cycle(ctx).await;
194 return Ok(());
195 }
196 };
197
198 if let Err(err) = child.tell(self.message.clone()).await {
199 error!(error = %err, "Failed to send retry message to target");
200 self.finish_retry_cycle(ctx).await;
201 return Ok(());
202 }
203
204 if let Ok(actor) = ctx.reference().await {
205 match self.retry_strategy.next_backoff() {
206 Some(duration) => {
207 self.pending_retry = Some(tokio::spawn(async move {
208 tokio::time::sleep(duration).await;
209 let _ = actor.tell(RetryMessage::Continue).await;
210 }));
211 }
212 None => {
213 let _ = actor.tell(RetryMessage::Continue).await;
214 }
215 }
216 } else {
217 debug!("Retry actor no longer registered, stopping silently");
218 self.is_end = true;
219 ctx.stop(None).await;
220 }
221
222 Ok(())
223 }
224}
/// Control messages understood by [`RetryActor`].
///
/// Equality is derived so callers and tests can compare messages directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RetryMessage {
    /// Starts the retry cycle; only the first `Retry` has an effect.
    Retry,
    /// Runs the next attempt (normally sent by the actor to itself).
    Continue,
    /// Finishes the retry cycle early.
    End,
    /// Stops the actor once a completion has been scheduled; ignored
    /// otherwise.
    Complete,
}
237
// Marker impl: `RetryMessage` is used as `RetryActor`'s `Actor::Message` type.
impl Message for RetryMessage {}
239
240impl<T> NotPersistentActor for RetryActor<T> where
241 T: Actor + Handler<T> + Clone + NotPersistentActor
242{
243}
244
245#[async_trait]
246impl<T> Actor for RetryActor<T>
247where
248 T: Actor + Handler<T> + Clone + NotPersistentActor,
249{
250 type Message = RetryMessage;
251 type Response = ();
252 type Event = ();
253
254 fn get_span(
255 id: &str,
256 _parent_span: Option<tracing::Span>,
257 ) -> tracing::Span {
258 info_span!("RetryActor", id = %id)
259 }
260
261 async fn pre_start(
262 &mut self,
263 ctx: &mut ActorContext<Self>,
264 ) -> Result<(), Error> {
265 ctx.create_child("target", self.target.clone()).await?;
266 Ok(())
267 }
268
269 async fn pre_stop(
270 &mut self,
271 _ctx: &mut ActorContext<Self>,
272 ) -> Result<(), Error> {
273 if let Some(handle) = self.pending_retry.take() {
274 handle.abort();
275 }
276 Ok(())
277 }
278}
279
280#[async_trait]
281impl<T> Handler<Self> for RetryActor<T>
282where
283 T: Actor + Handler<T> + Clone + NotPersistentActor,
284{
285 async fn handle_message(
286 &mut self,
287 _path: ActorPath,
288 message: RetryMessage,
289 ctx: &mut ActorContext<Self>,
290 ) -> Result<(), Error> {
291 match message {
292 RetryMessage::Retry => {
293 if self.started {
294 debug!(
295 "Retry cycle already started, ignoring duplicate start"
296 );
297 } else {
298 self.started = true;
299 self.handle_retry_attempt(ctx).await?;
300 }
301 }
302 RetryMessage::Continue => {
303 self.handle_retry_attempt(ctx).await?;
304 }
305 RetryMessage::End => {
306 self.finish_retry_cycle(ctx).await;
307 }
308 RetryMessage::Complete => {
309 if self.completion_pending {
310 self.completion_pending = false;
311 ctx.stop(None).await;
312 }
313 }
314 }
315 Ok(())
316 }
317}
318
#[cfg(test)]
mod tests {
    // These tests drive `RetryActor` through a real `ActorSystem` whose runner
    // is spawned on the tokio runtime. Most of them observe progress through
    // shared atomic counters that the test actors update.

    use test_log::test;
    use tokio_util::sync::CancellationToken;
    use tracing::info_span;

    use super::*;

    use crate::{ActorRef, ActorSystem, Error, FixedIntervalStrategy};

    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };
    use std::time::Duration;

    /// Root actor for `test_retry_actor`. In `pre_start` it spawns a
    /// `RetryActor<TargetActor>` child named "retry" and starts it.
    pub struct SourceActor;

    impl NotPersistentActor for SourceActor {}

    /// Reply sent from `TargetActor` back to `SourceActor`.
    #[derive(Debug, Clone)]
    pub struct SourceMessage(pub String);

    impl Message for SourceMessage {}

    #[async_trait]
    impl Actor for SourceActor {
        type Message = SourceMessage;
        type Response = ();
        type Event = ();

        fn get_span(
            id: &str,
            _parent_span: Option<tracing::Span>,
        ) -> tracing::Span {
            info_span!("SourceActor", id = %id)
        }

        async fn pre_start(
            &mut self,
            ctx: &mut ActorContext<SourceActor>,
        ) -> Result<(), Error> {
            println!("SourceActor pre_start");
            let target = TargetActor { counter: 0 };

            // Presumably (max_retries = 3, interval = 1 s); confirm against
            // `FixedIntervalStrategy::new`.
            let strategy = Strategy::FixedInterval(FixedIntervalStrategy::new(
                3,
                Duration::from_secs(1),
            ));

            // The retried message carries this actor's path so the target
            // can reply to it directly.
            let retry_actor = RetryActor::new(
                target,
                TargetMessage {
                    source: ctx.path().clone(),
                    message: "Hello from parent".to_owned(),
                },
                strategy,
            );
            let retry: ActorRef<RetryActor<TargetActor>> =
                ctx.create_child("retry", retry_actor).await.unwrap();

            retry.tell(RetryMessage::Retry).await.unwrap();
            Ok(())
        }
    }

    #[async_trait]
    impl Handler<SourceActor> for SourceActor {
        /// Receives the target's reply and ends the retry cycle early by
        /// sending `RetryMessage::End` to the "retry" child.
        async fn handle_message(
            &mut self,
            _path: ActorPath,
            message: SourceMessage,
            ctx: &mut ActorContext<SourceActor>,
        ) -> Result<(), Error> {
            println!("Message: {:?}", message);
            assert_eq!(message.0, "Hello from child");

            let retry = ctx
                .get_child::<RetryActor<TargetActor>>("retry")
                .await
                .unwrap();
            retry.tell(RetryMessage::End).await.unwrap();

            Ok(())
        }
    }

    /// Messages understood by `CompletionParent`.
    #[derive(Debug, Clone)]
    enum ParentMsg {
        /// Tells the parent to start its "retry" child.
        Start,
        /// Completion notification sent by the retry child.
        RetryFinished,
    }

    impl Message for ParentMsg {}

    /// Parent that owns a `RetryActor<PassiveTarget>` built with
    /// `new_with_parent_message` and counts the completion notifications it
    /// receives.
    #[derive(Clone)]
    struct CompletionParent {
        completions: Arc<AtomicUsize>,
    }

    impl NotPersistentActor for CompletionParent {}

    #[async_trait]
    impl Actor for CompletionParent {
        type Message = ParentMsg;
        type Response = ();
        type Event = ();

        fn get_span(
            id: &str,
            _parent_span: Option<tracing::Span>,
        ) -> tracing::Span {
            info_span!("CompletionParent", id = %id)
        }

        async fn pre_start(
            &mut self,
            ctx: &mut ActorContext<Self>,
        ) -> Result<(), Error> {
            // The retry child reports back with `ParentMsg::RetryFinished`
            // once its cycle finishes.
            let retry = RetryActor::new_with_parent_message::<CompletionParent>(
                PassiveTarget,
                PassiveMessage,
                Strategy::FixedInterval(FixedIntervalStrategy::new(
                    2,
                    Duration::from_millis(10),
                )),
                ParentMsg::RetryFinished,
            );
            let _: ActorRef<RetryActor<PassiveTarget>> =
                ctx.create_child("retry", retry).await?;
            Ok(())
        }
    }

    #[async_trait]
    impl Handler<CompletionParent> for CompletionParent {
        async fn handle_message(
            &mut self,
            _path: ActorPath,
            message: ParentMsg,
            ctx: &mut ActorContext<CompletionParent>,
        ) -> Result<(), Error> {
            match message {
                ParentMsg::Start => {
                    let retry = ctx
                        .get_child::<RetryActor<PassiveTarget>>("retry")
                        .await?;
                    retry.tell(RetryMessage::Retry).await?;
                }
                ParentMsg::RetryFinished => {
                    self.completions.fetch_add(1, Ordering::SeqCst);
                }
            }
            Ok(())
        }
    }

    /// Target that accepts every message and does nothing with it.
    #[derive(Clone)]
    struct PassiveTarget;

    impl NotPersistentActor for PassiveTarget {}

    #[derive(Debug, Clone)]
    struct PassiveMessage;

    impl Message for PassiveMessage {}

    impl Actor for PassiveTarget {
        type Message = PassiveMessage;
        type Response = ();
        type Event = ();

        fn get_span(
            id: &str,
            _parent_span: Option<tracing::Span>,
        ) -> tracing::Span {
            info_span!("PassiveTarget", id = %id)
        }
    }

    #[async_trait]
    impl Handler<PassiveTarget> for PassiveTarget {
        async fn handle_message(
            &mut self,
            _path: ActorPath,
            _message: PassiveMessage,
            _ctx: &mut ActorContext<PassiveTarget>,
        ) -> Result<(), Error> {
            Ok(())
        }
    }

    /// Target that counts every message delivered to it.
    #[derive(Clone)]
    struct CountingTarget {
        deliveries: Arc<AtomicUsize>,
    }

    impl NotPersistentActor for CountingTarget {}

    #[derive(Debug, Clone)]
    struct CountMessage;

    impl Message for CountMessage {}

    impl Actor for CountingTarget {
        type Message = CountMessage;
        type Response = ();
        type Event = ();

        fn get_span(
            id: &str,
            _parent_span: Option<tracing::Span>,
        ) -> tracing::Span {
            info_span!("CountingTarget", id = %id)
        }
    }

    #[async_trait]
    impl Handler<CountingTarget> for CountingTarget {
        async fn handle_message(
            &mut self,
            _path: ActorPath,
            _message: CountMessage,
            _ctx: &mut ActorContext<CountingTarget>,
        ) -> Result<(), Error> {
            self.deliveries.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Target used by `test_retry_actor`. On the second delivery it replies
    /// to the sender path carried in the message.
    #[derive(Clone)]
    pub struct TargetActor {
        counter: usize,
    }

    #[derive(Debug, Clone)]
    pub struct TargetMessage {
        pub source: ActorPath,
        pub message: String,
    }

    impl Message for TargetMessage {}

    impl NotPersistentActor for TargetActor {}

    impl Actor for TargetActor {
        type Message = TargetMessage;
        type Response = ();
        type Event = ();

        fn get_span(
            id: &str,
            _parent_span: Option<tracing::Span>,
        ) -> tracing::Span {
            info_span!("TargetActor", id = %id)
        }
    }

    #[async_trait]
    impl Handler<TargetActor> for TargetActor {
        async fn handle_message(
            &mut self,
            _path: ActorPath,
            message: TargetMessage,
            ctx: &mut ActorContext<TargetActor>,
        ) -> Result<(), Error> {
            assert_eq!(message.message, "Hello from parent");
            self.counter += 1;
            println!("Counter: {}", self.counter);
            // Reply only on the second delivery, so at least one real retry
            // must have happened before the source ends the cycle.
            if self.counter == 2 {
                let source = ctx
                    .system()
                    .get_actor::<SourceActor>(&message.source)
                    .await
                    .unwrap();
                source
                    .tell(SourceMessage("Hello from child".to_owned()))
                    .await?;
            }
            Ok(())
        }
    }

    // NOTE(review): this test asserts nothing from its own body. The checks
    // live in the actor handlers (`assert_eq!` / `unwrap`), and whether a
    // panic inside an actor task fails this test depends on the actor runtime
    // (TODO confirm). It also relies on a fixed 5 s sleep.
    #[test(tokio::test)]
    async fn test_retry_actor() {
        let (system, mut runner) = ActorSystem::create(
            CancellationToken::new(),
            CancellationToken::new(),
        );

        tokio::spawn(async move {
            runner.run().await;
        });

        let _: ActorRef<SourceActor> = system
            .create_root_actor("source", SourceActor)
            .await
            .unwrap();

        tokio::time::sleep(Duration::from_secs(5)).await;
    }

    /// Target that counts deliveries and stops itself after the first one,
    /// making it unavailable for later retry attempts.
    #[derive(Clone)]
    struct StopAfterFirstTarget {
        deliveries: Arc<AtomicUsize>,
    }

    impl NotPersistentActor for StopAfterFirstTarget {}

    #[derive(Debug, Clone)]
    struct StopAfterFirstMessage;

    impl Message for StopAfterFirstMessage {}

    impl Actor for StopAfterFirstTarget {
        type Message = StopAfterFirstMessage;
        type Response = ();
        type Event = ();

        fn get_span(
            id: &str,
            _parent_span: Option<tracing::Span>,
        ) -> tracing::Span {
            info_span!("StopAfterFirstTarget", id = %id)
        }
    }

    #[async_trait]
    impl Handler<StopAfterFirstTarget> for StopAfterFirstTarget {
        async fn handle_message(
            &mut self,
            _path: ActorPath,
            _message: StopAfterFirstMessage,
            ctx: &mut ActorContext<StopAfterFirstTarget>,
        ) -> Result<(), Error> {
            let count = self.deliveries.fetch_add(1, Ordering::SeqCst) + 1;
            if count == 1 {
                ctx.stop(None).await;
            }
            Ok(())
        }
    }

    // Once the target has stopped, the next attempt cannot reach it, so the
    // retry actor should end its cycle and close, with exactly one delivery.
    #[test(tokio::test)]
    async fn test_retry_actor_stops_when_target_unavailable() {
        let (system, mut runner) = ActorSystem::create(
            CancellationToken::new(),
            CancellationToken::new(),
        );

        tokio::spawn(async move {
            runner.run().await;
        });

        let deliveries = Arc::new(AtomicUsize::new(0));
        let retry_actor = RetryActor::new(
            StopAfterFirstTarget {
                deliveries: deliveries.clone(),
            },
            StopAfterFirstMessage,
            Strategy::FixedInterval(FixedIntervalStrategy::new(
                5,
                Duration::from_millis(20),
            )),
        );

        let retry_ref: ActorRef<RetryActor<StopAfterFirstTarget>> = system
            .create_root_actor("retry_stop_on_send_failure", retry_actor)
            .await
            .unwrap();

        retry_ref.tell(RetryMessage::Retry).await.unwrap();

        tokio::time::timeout(Duration::from_secs(1), retry_ref.closed())
            .await
            .expect("retry actor should stop after target becomes unavailable");

        assert_eq!(deliveries.load(Ordering::SeqCst), 1);
    }

    // Polls for up to 1 s until the parent has counted one
    // `ParentMsg::RetryFinished` notification.
    #[test(tokio::test)]
    async fn test_retry_actor_notifies_parent_when_retries_finish() {
        let (system, mut runner) = ActorSystem::create(
            CancellationToken::new(),
            CancellationToken::new(),
        );

        tokio::spawn(async move {
            runner.run().await;
        });

        let completions = Arc::new(AtomicUsize::new(0));
        let parent = CompletionParent {
            completions: completions.clone(),
        };

        let parent_ref: ActorRef<CompletionParent> = system
            .create_root_actor("completion_parent", parent)
            .await
            .unwrap();

        parent_ref.tell(ParentMsg::Start).await.unwrap();

        tokio::time::timeout(Duration::from_secs(1), async {
            loop {
                if completions.load(Ordering::SeqCst) == 1 {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("parent should receive completion notification");
    }

    // Two `Retry` messages must yield a single cycle of 3 deliveries, after
    // which the actor closes. NOTE(review): the loop only samples the counter
    // every 10 ms and checks equality with 3. It does not assert that the
    // count stays at 3 once the actor has closed.
    #[test(tokio::test)]
    async fn test_retry_actor_ignores_duplicate_retry_start() {
        let (system, mut runner) = ActorSystem::create(
            CancellationToken::new(),
            CancellationToken::new(),
        );

        tokio::spawn(async move {
            runner.run().await;
        });

        let deliveries = Arc::new(AtomicUsize::new(0));
        let retry_actor = RetryActor::new(
            CountingTarget {
                deliveries: deliveries.clone(),
            },
            CountMessage,
            Strategy::NoInterval(crate::NoIntervalStrategy::new(3)),
        );

        let retry_ref: ActorRef<RetryActor<CountingTarget>> = system
            .create_root_actor::<RetryActor<CountingTarget>, _>(
                "retry_duplicate_start",
                retry_actor,
            )
            .await
            .unwrap();

        retry_ref.tell(RetryMessage::Retry).await.unwrap();
        retry_ref.tell(RetryMessage::Retry).await.unwrap();

        tokio::time::timeout(Duration::from_secs(1), async {
            loop {
                if deliveries.load(Ordering::SeqCst) == 3 {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("retry actor should deliver exactly one retry cycle");

        tokio::time::timeout(Duration::from_secs(1), retry_ref.closed())
            .await
            .expect("retry actor should stop after exhausting retries");
    }
}