1use crate::core::{
2    ChildFailureState, ChildSpec, CoreSupervisorOptions, RestartLog, SupervisorCore,
3    SupervisorError,
4};
5use crate::ExitReason;
6use ractor::concurrency::{sleep, Duration, JoinHandle};
7use ractor::{
8    Actor, ActorCell, ActorName, ActorProcessingErr, ActorRef, RpcReplyPort, SpawnErr,
9    SupervisionEvent,
10};
11use std::collections::HashMap;
12
/// Restart policy applied when a supervised child exits.
///
/// The chosen strategy decides which spawn message the supervisor sends to
/// itself, and therefore which children are torn down and started again.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SupervisorStrategy {
    /// Restart only the child that exited.
    OneForOne,
    /// Kill every child of the supervisor and start the full set again.
    OneForAll,
    /// Restart the exited child and every child declared after it in
    /// `child_specs`; children declared before it are left running.
    RestForOne,
}
23
24#[derive(Clone)]
33pub struct SupervisorOptions {
34    pub strategy: SupervisorStrategy,
36    pub max_restarts: usize,
38    pub max_window: Duration,
40    pub reset_after: Option<Duration>,
42}
43
44impl CoreSupervisorOptions<SupervisorStrategy> for SupervisorOptions {
45    fn max_restarts(&self) -> usize {
46        self.max_restarts
47    }
48
49    fn max_window(&self) -> Duration {
50        self.max_window
51    }
52
53    fn reset_after(&self) -> Option<Duration> {
54        self.reset_after
55    }
56
57    fn strategy(&self) -> SupervisorStrategy {
58        self.strategy
59    }
60}
61
62pub enum SupervisorMsg {
64    OneForOneSpawn { child_id: String },
66    OneForAllSpawn { child_id: String },
68    RestForOneSpawn { child_id: String },
70    InspectState(RpcReplyPort<SupervisorState>),
72}
73
74pub struct SupervisorArguments {
76    pub child_specs: Vec<ChildSpec>,
78    pub options: SupervisorOptions,
80}
81
82#[derive(Clone)]
88pub struct SupervisorState {
89    pub child_specs: Vec<ChildSpec>,
91
92    pub child_failure_state: HashMap<String, ChildFailureState>,
94
95    pub restart_log: Vec<RestartLog>,
97
98    pub options: SupervisorOptions,
100}
101
102impl SupervisorCore for SupervisorState {
103    type Message = SupervisorMsg;
104    type Options = SupervisorOptions;
105    type Strategy = SupervisorStrategy;
106
107    fn child_failure_state(&mut self) -> &mut HashMap<String, ChildFailureState> {
108        &mut self.child_failure_state
109    }
110
111    fn restart_log(&mut self) -> &mut Vec<RestartLog> {
112        &mut self.restart_log
113    }
114
115    fn options(&self) -> &SupervisorOptions {
116        &self.options
117    }
118
119    fn restart_msg(
120        &self,
121        child_spec: &ChildSpec,
122        strategy: SupervisorStrategy,
123        _myself: ActorRef<SupervisorMsg>,
124    ) -> SupervisorMsg {
125        let child_id = child_spec.id.clone();
126        match strategy {
127            SupervisorStrategy::OneForOne => SupervisorMsg::OneForOneSpawn { child_id },
128            SupervisorStrategy::OneForAll => SupervisorMsg::OneForAllSpawn { child_id },
129            SupervisorStrategy::RestForOne => SupervisorMsg::RestForOneSpawn { child_id },
130        }
131    }
132}
133
134impl SupervisorState {
135    fn new(args: SupervisorArguments) -> Self {
137        Self {
138            child_specs: args.child_specs,
139            child_failure_state: HashMap::new(),
140            restart_log: Vec::new(),
141            options: args.options,
142        }
143    }
144
145    pub async fn spawn_child(
146        &mut self,
147        child_spec: &ChildSpec,
148        myself: ActorRef<SupervisorMsg>,
149    ) -> Result<(), ActorProcessingErr> {
150        let result = child_spec
151            .spawn_fn
152            .call(myself.get_cell().clone(), child_spec.id.clone())
153            .await
154            .map_err(|e| SupervisorError::ChildSpawnError {
155                child_id: child_spec.id.clone(),
156                reason: e.to_string(),
157            });
158
159        if let Err(err) = result {
162            self.handle_child_restart(
163                child_spec,
164                true,
165                myself.clone(),
166                &ExitReason::Error(err.into()),
167            )?;
168        }
169
170        Ok(())
171    }
172
173    pub async fn spawn_all_children(
175        &mut self,
176        myself: ActorRef<SupervisorMsg>,
177    ) -> Result<(), ActorProcessingErr> {
178        let child_specs = std::mem::take(&mut self.child_specs);
180        for spec in &child_specs {
181            self.spawn_child(spec, myself.clone()).await?;
182        }
183        self.child_specs = child_specs;
185        Ok(())
186    }
187
188    pub async fn perform_one_for_one_spawn(
190        &mut self,
191        child_id: &str,
192        myself: ActorRef<SupervisorMsg>,
193    ) -> Result<(), ActorProcessingErr> {
194        self.track_global_restart(child_id)?;
195        let child_specs = std::mem::take(&mut self.child_specs);
197        if let Some(spec) = child_specs.iter().find(|s| s.id == child_id) {
198            self.spawn_child(spec, myself.clone()).await?;
199        }
200        self.child_specs = child_specs;
202        Ok(())
203    }
204
205    pub async fn perform_one_for_all_spawn(
207        &mut self,
208        child_id: &str,
209        myself: ActorRef<SupervisorMsg>,
210    ) -> Result<(), ActorProcessingErr> {
211        self.track_global_restart(child_id)?;
212        for cell in myself.get_children() {
214            cell.unlink(myself.get_cell());
215            cell.kill();
216        }
217        sleep(Duration::from_millis(10)).await;
219        self.spawn_all_children(myself).await?;
220        Ok(())
221    }
222
223    pub async fn perform_rest_for_one_spawn(
225        &mut self,
226        child_id: &str,
227        myself: ActorRef<SupervisorMsg>,
228    ) -> Result<(), ActorProcessingErr> {
229        self.track_global_restart(child_id)?;
230        let child_specs = std::mem::take(&mut self.child_specs);
232        let children = myself.get_children();
233        let child_cell_by_name: HashMap<String, &ActorCell> = children
234            .iter()
235            .filter_map(|cell| cell.get_name().map(|name| (name, cell)))
236            .collect();
237        if let Some(i) = child_specs.iter().position(|s| s.id == child_id) {
238            for spec in child_specs.iter().skip(i) {
240                if let Some(cell) = child_cell_by_name.get(&spec.id) {
241                    cell.unlink(myself.get_cell());
242                    cell.kill();
243                }
244            }
245            sleep(Duration::from_millis(10)).await;
247            for spec in child_specs.iter().skip(i) {
249                self.spawn_child(spec, myself.clone()).await?;
250            }
251        }
252        self.child_specs = child_specs;
254        Ok(())
255    }
256}
257
258pub struct Supervisor;
262
263impl Supervisor {
264    pub async fn spawn_linked<T: Actor>(
265        name: ActorName,
266        handler: T,
267        startup_args: T::Arguments,
268        supervisor: ActorCell,
269    ) -> Result<(ActorRef<T::Msg>, JoinHandle<()>), SpawnErr> {
270        Actor::spawn_linked(Some(name), handler, startup_args, supervisor).await
271    }
272
273    pub async fn spawn(
274        name: ActorName,
275        startup_args: SupervisorArguments,
276    ) -> Result<(ActorRef<SupervisorMsg>, JoinHandle<()>), SpawnErr> {
277        Actor::spawn(Some(name), Supervisor, startup_args).await
278    }
279}
280
/// Test-only map from supervisor name to the last state it recorded.
#[cfg(test)]
type FinalStateRegistry = tokio::sync::Mutex<HashMap<String, SupervisorState>>;

/// Test-only registry written by `store_final_state` and read by the tests,
/// so a supervisor's state can still be inspected after the actor has stopped.
#[cfg(test)]
static SUPERVISOR_FINAL: std::sync::OnceLock<FinalStateRegistry> = std::sync::OnceLock::new();
285
286#[cfg_attr(feature = "async-trait", ractor::async_trait)]
287impl Actor for Supervisor {
288    type Msg = SupervisorMsg;
289    type State = SupervisorState;
290    type Arguments = SupervisorArguments;
291
292    async fn pre_start(
293        &self,
294        _myself: ActorRef<Self::Msg>,
295        args: Self::Arguments,
296    ) -> Result<Self::State, ActorProcessingErr> {
297        Ok(SupervisorState::new(args))
298    }
299
300    async fn post_start(
301        &self,
302        myself: ActorRef<Self::Msg>,
303        state: &mut SupervisorState,
304    ) -> Result<(), ActorProcessingErr> {
305        state.spawn_all_children(myself).await?;
307        Ok(())
308    }
309
310    async fn handle(
313        &self,
314        myself: ActorRef<Self::Msg>,
315        msg: SupervisorMsg,
316        state: &mut SupervisorState,
317    ) -> Result<(), ActorProcessingErr> {
318        let result = match msg {
319            SupervisorMsg::OneForOneSpawn { child_id } => {
320                state
321                    .perform_one_for_one_spawn(&child_id, myself.clone())
322                    .await
323            }
324            SupervisorMsg::OneForAllSpawn { child_id } => {
325                state
326                    .perform_one_for_all_spawn(&child_id, myself.clone())
327                    .await
328            }
329            SupervisorMsg::RestForOneSpawn { child_id } => {
330                state
331                    .perform_rest_for_one_spawn(&child_id, myself.clone())
332                    .await
333            }
334            SupervisorMsg::InspectState(rpc_reply_port) => {
335                rpc_reply_port.send(state.clone())?;
336                Ok(())
337            }
338        };
339
340        #[cfg(test)]
341        {
342            store_final_state(myself, state).await;
343        }
344
345        result
347    }
348
349    async fn handle_supervisor_evt(
353        &self,
354        myself: ActorRef<Self::Msg>,
355        evt: SupervisionEvent,
356        state: &mut Self::State,
357    ) -> Result<(), ActorProcessingErr> {
358        match evt {
359            SupervisionEvent::ActorStarted(cell) => {
360                let child_id = cell
361                    .get_name()
362                    .ok_or(SupervisorError::ChildNameNotSet { pid: cell.get_id() })?;
363                log::info!("Started child: {}", child_id);
364                if state.child_specs.iter().any(|s| s.id == child_id) {
365                    state
367                        .child_failure_state
368                        .entry(child_id.clone())
369                        .or_insert_with(|| ChildFailureState {
370                            restart_count: 0,
371                            last_fail_instant: ractor::concurrency::Instant::now(),
372                        });
373                }
374            }
375            SupervisionEvent::ActorTerminated(cell, _final_state, reason) => {
376                let child_id = cell
378                    .get_name()
379                    .ok_or(SupervisorError::ChildNameNotSet { pid: cell.get_id() })?;
380                let child_specs = std::mem::take(&mut state.child_specs);
381                if let Some(spec) = child_specs.iter().find(|s| s.id == child_id) {
382                    state.handle_child_restart(
383                        spec,
384                        false,
385                        myself.clone(),
386                        &ExitReason::Reason(reason),
387                    )?;
388                }
389                state.child_specs = child_specs;
390            }
391            SupervisionEvent::ActorFailed(cell, err) => {
392                let child_id = cell
394                    .get_name()
395                    .ok_or(SupervisorError::ChildNameNotSet { pid: cell.get_id() })?;
396                let child_specs = std::mem::take(&mut state.child_specs);
397                if let Some(spec) = child_specs.iter().find(|s| s.id == child_id) {
398                    state.handle_child_restart(
399                        spec,
400                        true,
401                        myself.clone(),
402                        &ExitReason::Error(err),
403                    )?;
404                }
405                state.child_specs = child_specs;
406            }
407            SupervisionEvent::ProcessGroupChanged(_group) => {}
408        }
409        Ok(())
410    }
411
412    async fn post_stop(
415        &self,
416        _myself: ActorRef<Self::Msg>,
417        _state: &mut Self::State,
418    ) -> Result<(), ActorProcessingErr> {
419        #[cfg(test)]
420        {
421            store_final_state(_myself, _state).await;
422        }
423        Ok(())
424    }
425}
426
/// Test-only: records a clone of `state` under the supervisor's name.
///
/// Tests can then inspect the state after the actor has stopped. Each call
/// overwrites the previous entry for the same name.
#[cfg(test)]
async fn store_final_state(myself: ActorRef<SupervisorMsg>, state: &SupervisorState) {
    let registry = SUPERVISOR_FINAL.get_or_init(|| tokio::sync::Mutex::new(HashMap::new()));
    let mut by_name = registry.lock().await;
    // Unnamed supervisors are not recorded, although the lock is still taken.
    if let Some(sup_name) = myself.get_name() {
        by_name.insert(sup_name, state.clone());
    }
}
437
438#[cfg(test)]
439mod tests {
440    use super::*;
441    use crate::core::{ChildBackoffFn, Restart};
442    use crate::SpawnFn;
443    use ractor::concurrency::Instant;
444    use ractor::{call_t, Actor, ActorCell, ActorRef, ActorStatus};
445    use serial_test::serial;
446    use std::sync::atomic::{AtomicU64, Ordering};
447    use std::sync::Arc;
448
449    #[cfg(test)]
450    static ACTOR_CALL_COUNT: std::sync::OnceLock<
451        tokio::sync::Mutex<std::collections::HashMap<String, u64>>,
452    > = std::sync::OnceLock::new();
453
454    async fn before_each() {
455        if let Some(map) = SUPERVISOR_FINAL.get() {
457            let mut map = map.lock().await;
458            map.clear();
459        }
460        if let Some(map) = ACTOR_CALL_COUNT.get() {
462            let mut map = map.lock().await;
463            map.clear();
464        }
465        sleep(Duration::from_millis(10)).await;
466    }
467
468    async fn increment_actor_count(child_id: &str) {
469        let mut map = ACTOR_CALL_COUNT
470            .get_or_init(|| tokio::sync::Mutex::new(std::collections::HashMap::new()))
471            .lock()
472            .await;
473        *map.entry(child_id.to_string()).or_default() += 1;
474    }
475
476    async fn read_final_supervisor_state(sup_name: &str) -> SupervisorState {
478        let map = SUPERVISOR_FINAL
479            .get()
480            .expect("SUPERVISOR_FINAL not initialized!")
481            .lock()
482            .await;
483        map.get(sup_name)
484            .cloned()
485            .unwrap_or_else(|| panic!("No final state for supervisor '{sup_name}'"))
486    }
487
488    async fn read_actor_call_count(child_id: &str) -> u64 {
489        let map = ACTOR_CALL_COUNT
490            .get()
491            .expect("ACTOR_CALL_COUNT not initialized!")
492            .lock()
493            .await;
494        *map.get(child_id)
495            .unwrap_or_else(|| panic!("No actor call count for child '{child_id}'"))
496    }
497
498    #[derive(Clone)]
500    pub enum ChildBehavior {
501        DelayedFail {
502            ms: u64,
503        },
504        DelayedNormal {
505            ms: u64,
506        },
507        ImmediateFail,
508        ImmediateNormal,
509        CountedFails {
510            delay_ms: u64,
511            fail_count: u64,
512            current: Arc<AtomicU64>,
513        },
514        FailWaitFail {
515            initial_fails: u64,
516            wait_ms: u64,
517            final_fails: u64,
518            current: Arc<AtomicU64>,
519        },
520    }
521
522    pub struct TestChild;
523
524    #[cfg_attr(feature = "async-trait", ractor::async_trait)]
525    impl Actor for TestChild {
526        type Msg = ();
527        type State = ChildBehavior;
528        type Arguments = ChildBehavior;
529
530        async fn pre_start(
531            &self,
532            myself: ActorRef<Self::Msg>,
533            arg: Self::Arguments,
534        ) -> Result<Self::State, ractor::ActorProcessingErr> {
535            increment_actor_count(myself.get_name().unwrap().as_str()).await;
537            match arg {
538                ChildBehavior::DelayedFail { ms } => {
539                    myself.send_after(Duration::from_millis(ms), || ());
540                }
541                ChildBehavior::DelayedNormal { ms } => {
542                    myself.send_after(Duration::from_millis(ms), || ());
543                }
544                ChildBehavior::ImmediateFail => {
545                    panic!("Immediate fail => ActorFailed");
546                }
547                ChildBehavior::ImmediateNormal => {
548                    myself.stop(None);
549                }
550                ChildBehavior::CountedFails { delay_ms, .. } => {
551                    myself.send_after(Duration::from_millis(delay_ms), || ());
552                }
553                ChildBehavior::FailWaitFail { .. } => {
554                    myself.cast(())?;
556                }
557            }
558            Ok(arg)
559        }
560
561        async fn handle(
562            &self,
563            myself: ActorRef<Self::Msg>,
564            _msg: Self::Msg,
565            state: &mut Self::State,
566        ) -> Result<(), ractor::ActorProcessingErr> {
567            match state {
568                ChildBehavior::DelayedFail { .. } => {
569                    panic!("Delayed fail => ActorFailed");
570                }
571                ChildBehavior::DelayedNormal { .. } => {
572                    myself.stop(None);
573                }
574                ChildBehavior::ImmediateFail => {
575                    panic!("ImmediateFail => ActorFailed");
576                }
577                ChildBehavior::ImmediateNormal => {
578                    myself.stop(None);
579                }
580                ChildBehavior::CountedFails {
581                    fail_count,
582                    current,
583                    ..
584                } => {
585                    let old = current.fetch_add(1, Ordering::SeqCst);
586                    let newv = old + 1;
587                    if newv <= *fail_count {
588                        panic!("CountedFails => fail #{newv}");
589                    }
590                }
591                ChildBehavior::FailWaitFail {
592                    initial_fails,
593                    wait_ms,
594                    final_fails,
595                    current,
596                } => {
597                    let so_far = current.fetch_add(1, Ordering::SeqCst) + 1;
598                    if so_far <= *initial_fails {
599                        panic!("FailWaitFail => initial fail #{so_far}");
600                    } else if so_far == *initial_fails + 1 {
601                        myself.send_after(Duration::from_millis(*wait_ms), || ());
603                    } else {
604                        let n = so_far - (*initial_fails + 1);
605                        if n <= *final_fails {
606                            panic!("FailWaitFail => final fail #{n}");
607                        }
608                    }
609                }
610            }
611            Ok(())
612        }
613    }
614
615    fn get_running_children(sup_ref: &ActorRef<SupervisorMsg>) -> HashMap<String, ActorCell> {
616        sup_ref
617            .get_children()
618            .into_iter()
619            .filter_map(|c| {
620                if c.get_status() == ActorStatus::Running {
621                    c.get_name().map(|n| (n, c))
622                } else {
623                    None
624                }
625            })
626            .collect()
627    }
628
629    async fn spawn_test_child(
631        sup_cell: ActorCell,
632        id: String,
633        behavior: ChildBehavior,
634    ) -> Result<ActorCell, SpawnErr> {
635        let (ch_ref, _join) = Supervisor::spawn_linked(id, TestChild, behavior, sup_cell).await?;
636        Ok(ch_ref.get_cell())
637    }
638
639    fn make_child_spec(id: &str, restart: Restart, behavior: ChildBehavior) -> ChildSpec {
641        ChildSpec {
642            id: id.to_string(),
643            restart,
644            spawn_fn: SpawnFn::new(move |sup_cell, child_id| {
645                spawn_test_child(sup_cell, child_id, behavior.clone())
646            }),
647            backoff_fn: None, reset_after: None,
649        }
650    }
651
652    #[ractor::concurrency::test]
653    #[serial]
654    async fn test_permanent_delayed_fail() -> Result<(), Box<dyn std::error::Error>> {
655        before_each().await;
656
657        let child_spec = make_child_spec(
659            "fail-delay",
660            Restart::Permanent,
661            ChildBehavior::DelayedFail { ms: 200 },
662        );
663        let options = SupervisorOptions {
664            strategy: SupervisorStrategy::OneForOne,
665            max_restarts: 1, max_window: Duration::from_secs(2),
667            reset_after: None,
668        };
669        let args = SupervisorArguments {
670            child_specs: vec![child_spec],
671            options,
672        };
673
674        let (sup_ref, sup_handle) =
675            Supervisor::spawn("test_permanent_delayed_fail".into(), args).await?;
676
677        sleep(Duration::from_millis(100)).await;
678        let st = call_t!(sup_ref, SupervisorMsg::InspectState, 500).unwrap();
679        let mut running = get_running_children(&sup_ref);
680        assert_eq!(running.len(), 1);
681        assert_eq!(st.restart_log.len(), 0);
682
683        let _ = sup_handle.await;
685        assert_eq!(sup_ref.get_status(), ActorStatus::Stopped);
686
687        let final_st = read_final_supervisor_state("test_permanent_delayed_fail").await;
689        running = get_running_children(&sup_ref);
690        assert_eq!(running.len(), 0);
691        assert!(final_st.restart_log.len() >= 2);
692
693        assert_eq!(read_actor_call_count("fail-delay").await, 2);
695
696        Ok(())
697    }
698
699    #[ractor::concurrency::test]
700    #[serial]
701    async fn test_transient_delayed_normal() -> Result<(), Box<dyn std::error::Error>> {
702        before_each().await;
703
704        let child_spec = make_child_spec(
706            "normal-delay",
707            Restart::Transient,
708            ChildBehavior::DelayedNormal { ms: 300 },
709        );
710        let options = SupervisorOptions {
711            strategy: SupervisorStrategy::OneForOne,
712            max_restarts: 5,
713            max_window: Duration::from_secs(5),
714            reset_after: None,
715        };
716        let args = SupervisorArguments {
717            child_specs: vec![child_spec],
718            options,
719        };
720
721        let (sup_ref, sup_handle) =
722            Supervisor::spawn("test_transient_delayed_normal".into(), args).await?;
723
724        sleep(Duration::from_millis(150)).await;
725        let st1 = call_t!(sup_ref, SupervisorMsg::InspectState, 500).unwrap();
726
727        let running = get_running_children(&sup_ref);
728        assert_eq!(running.len(), 1);
729        assert_eq!(st1.restart_log.len(), 0);
730
731        sleep(Duration::from_millis(300)).await;
733        sup_ref.stop(None);
734        let _ = sup_handle.await;
735
736        let final_state = read_final_supervisor_state("test_transient_delayed_normal").await;
737        let running = get_running_children(&sup_ref);
738        assert!(!running.contains_key("normal-delay"));
739        assert_eq!(final_state.restart_log.len(), 0);
740
741        assert_eq!(read_actor_call_count("normal-delay").await, 1);
743
744        Ok(())
745    }
746
747    #[ractor::concurrency::test]
748    #[serial]
749    async fn test_temporary_delayed_fail() -> Result<(), Box<dyn std::error::Error>> {
750        before_each().await;
751
752        let child_spec = make_child_spec(
754            "temp-delay",
755            Restart::Temporary,
756            ChildBehavior::DelayedFail { ms: 200 },
757        );
758        let options = SupervisorOptions {
759            strategy: SupervisorStrategy::OneForOne,
760            max_restarts: 10,
761            max_window: Duration::from_secs(10),
762            reset_after: None,
763        };
764        let args = SupervisorArguments {
765            child_specs: vec![child_spec],
766            options,
767        };
768
769        let (sup_ref, sup_handle) =
770            Supervisor::spawn("test_temporary_delayed_fail".into(), args).await?;
771
772        sleep(Duration::from_millis(100)).await;
773        let st1 = call_t!(sup_ref, SupervisorMsg::InspectState, 500).unwrap();
774        let running = get_running_children(&sup_ref);
775        assert_eq!(running.len(), 1);
776        assert_eq!(st1.restart_log.len(), 0);
777
778        sleep(Duration::from_millis(300)).await;
780        assert_eq!(sup_ref.get_status(), ActorStatus::Running);
781
782        sup_ref.stop(None);
783        let _ = sup_handle.await;
784
785        let final_state = read_final_supervisor_state("test_temporary_delayed_fail").await;
786        let running = get_running_children(&sup_ref);
787        assert_eq!(running.len(), 0);
788        assert_eq!(final_state.restart_log.len(), 0);
789
790        assert_eq!(read_actor_call_count("temp-delay").await, 1);
792
793        Ok(())
794    }
795
796    #[ractor::concurrency::test]
797    #[serial]
798    async fn test_one_for_all_stop_all_on_failure() -> Result<(), Box<dyn std::error::Error>> {
799        before_each().await;
800
801        let child1 = make_child_spec(
803            "ofa-fail",
804            Restart::Permanent,
805            ChildBehavior::DelayedFail { ms: 200 },
806        );
807        let child2 = make_child_spec(
808            "ofa-normal",
809            Restart::Permanent,
810            ChildBehavior::DelayedNormal { ms: 9999 },
811        );
812
813        let options = SupervisorOptions {
814            strategy: SupervisorStrategy::OneForAll,
815            max_restarts: 2,
816            max_window: Duration::from_secs(2),
817            reset_after: None,
818        };
819        let args = SupervisorArguments {
820            child_specs: vec![child1, child2],
821            options,
822        };
823        let (sup_ref, sup_handle) =
824            Supervisor::spawn("test_one_for_all_stop_all_on_failure".into(), args).await?;
825
826        sleep(Duration::from_millis(100)).await;
827        let running_children = get_running_children(&sup_ref);
828        assert_eq!(running_children.len(), 2);
829
830        let _ = sup_handle.await;
831        assert_eq!(sup_ref.get_status(), ActorStatus::Stopped);
832
833        let final_state = read_final_supervisor_state("test_one_for_all_stop_all_on_failure").await;
834        assert_eq!(sup_ref.get_children().len(), 0);
835        assert_eq!(final_state.restart_log.len(), 3);
836
837        assert_eq!(read_actor_call_count("ofa-fail").await, 3);
840        assert_eq!(read_actor_call_count("ofa-normal").await, 3);
842
843        Ok(())
844    }
845
846    #[ractor::concurrency::test]
847    #[serial]
848    async fn test_rest_for_one_restart_subset() -> Result<(), Box<dyn std::error::Error>> {
849        before_each().await;
850
851        let child_a = make_child_spec(
853            "A",
854            Restart::Permanent,
855            ChildBehavior::DelayedNormal { ms: 9999 },
856        );
857        let child_b = make_child_spec(
858            "B",
859            Restart::Permanent,
860            ChildBehavior::DelayedFail { ms: 200 },
861        );
862        let child_c = make_child_spec(
863            "C",
864            Restart::Permanent,
865            ChildBehavior::DelayedNormal { ms: 9999 },
866        );
867
868        let options = SupervisorOptions {
869            strategy: SupervisorStrategy::RestForOne,
870            max_restarts: 1,
871            max_window: Duration::from_secs(2),
872            reset_after: None,
873        };
874        let args = SupervisorArguments {
875            child_specs: vec![child_a, child_b, child_c],
876            options,
877        };
878        let (sup_ref, sup_handle) =
879            Supervisor::spawn("test_rest_for_one_restart_subset".into(), args).await?;
880
881        sleep(Duration::from_millis(100)).await;
882        let running_children = get_running_children(&sup_ref);
883        assert_eq!(running_children.len(), 3);
884
885        let _ = sup_handle.await;
886        assert_eq!(sup_ref.get_status(), ActorStatus::Stopped);
887
888        let final_state = read_final_supervisor_state("test_rest_for_one_restart_subset").await;
889        assert_eq!(sup_ref.get_children().len(), 0);
890        assert_eq!(final_state.restart_log.len(), 2);
891        assert_eq!(final_state.restart_log[0].child_id, "B");
892        assert_eq!(final_state.restart_log[1].child_id, "B");
893
894        assert_eq!(read_actor_call_count("A").await, 1);
897        assert_eq!(read_actor_call_count("B").await, 2);
898        assert_eq!(read_actor_call_count("C").await, 2);
899
900        Ok(())
901    }
902
903    #[ractor::concurrency::test]
904    #[serial]
905    async fn test_max_restarts_in_time_window() -> Result<(), Box<dyn std::error::Error>> {
906        before_each().await;
907
908        let child_spec =
910            make_child_spec("fastfail", Restart::Permanent, ChildBehavior::ImmediateFail);
911
912        let options = SupervisorOptions {
913            strategy: SupervisorStrategy::OneForOne,
914            max_restarts: 2,
915            max_window: Duration::from_secs(1),
916            reset_after: None,
917        };
918        let args = SupervisorArguments {
919            child_specs: vec![child_spec],
920            options,
921        };
922
923        let (sup_ref, sup_handle) =
924            Supervisor::spawn("test_max_restarts_in_time_window".into(), args).await?;
925
926        let _ = sup_handle.await;
927        assert_eq!(sup_ref.get_status(), ActorStatus::Stopped);
928
929        let final_state = read_final_supervisor_state("test_max_restarts_in_time_window").await;
930        assert_eq!(
931            final_state.restart_log.len(),
932            3,
933            "3 fails in <1s => meltdown"
934        );
935
936        assert_eq!(read_actor_call_count("fastfail").await, 3);
938
939        Ok(())
940    }
941
942    #[ractor::concurrency::test]
943    #[serial]
944    async fn test_transient_abnormal_exit() -> Result<(), Box<dyn std::error::Error>> {
945        before_each().await;
946
947        let child_spec = make_child_spec(
949            "transient-bad",
950            Restart::Transient,
951            ChildBehavior::ImmediateFail,
952        );
953
954        let options = SupervisorOptions {
955            strategy: SupervisorStrategy::OneForOne,
956            max_restarts: 0, max_window: Duration::from_secs(5),
958            reset_after: None,
959        };
960
961        let args = SupervisorArguments {
962            child_specs: vec![child_spec],
963            options,
964        };
965        let (sup_ref, sup_handle) =
966            Supervisor::spawn("test_transient_abnormal_exit".into(), args).await?;
967
968        let _ = sup_handle.await;
969        assert_eq!(sup_ref.get_status(), ActorStatus::Stopped);
970
971        let final_state = read_final_supervisor_state("test_transient_abnormal_exit").await;
972        assert_eq!(
973            final_state.restart_log.len(),
974            1,
975            "1 fail => meltdown with max_restarts=0"
976        );
977
978        assert_eq!(read_actor_call_count("transient-bad").await, 1);
980
981        Ok(())
982    }
983
984    #[ractor::concurrency::test]
985    #[serial]
986    async fn test_backoff_fn_delays_restart() -> Result<(), Box<dyn std::error::Error>> {
987        before_each().await;
988
989        let child_backoff: ChildBackoffFn =
992            ChildBackoffFn::new(|_id, count, _last, _child_reset| {
993                if count <= 1 {
994                    None
995                } else {
996                    Some(Duration::from_secs(2))
997                }
998            });
999
1000        let mut child_spec =
1001            make_child_spec("backoff", Restart::Permanent, ChildBehavior::ImmediateFail);
1002        child_spec.backoff_fn = Some(child_backoff);
1003
1004        let options = SupervisorOptions {
1005            strategy: SupervisorStrategy::OneForOne,
1006            max_restarts: 1, max_window: Duration::from_secs(10),
1008            reset_after: None,
1009        };
1010        let args = SupervisorArguments {
1011            child_specs: vec![child_spec],
1012            options,
1013        };
1014
1015        let before = Instant::now();
1016        let (sup_ref, sup_handle) =
1017            Supervisor::spawn("test_backoff_fn_delays_restart".into(), args).await?;
1018        let _ = sup_handle.await;
1019
1020        let elapsed = before.elapsed();
1021        assert!(
1022            elapsed >= Duration::from_secs(2),
1023            "2s delay on second fail due to child-level backoff"
1024        );
1025        assert_eq!(sup_ref.get_status(), ActorStatus::Stopped);
1026
1027        let final_st = read_final_supervisor_state("test_backoff_fn_delays_restart").await;
1028        assert_eq!(
1029            final_st.restart_log.len(),
1030            2,
1031            "first fail => immediate restart => second fail => meltdown"
1032        );
1033
1034        assert_eq!(read_actor_call_count("backoff").await, 2);
1036
1037        Ok(())
1038    }
1039
1040    #[ractor::concurrency::test]
1041    #[serial]
1042    async fn test_restart_counter_reset_after() -> Result<(), Box<dyn std::error::Error>> {
1043        before_each().await;
1044
1045        let behavior = ChildBehavior::FailWaitFail {
1049            initial_fails: 2,
1050            wait_ms: 3000,
1051            final_fails: 1,
1052            current: Arc::new(AtomicU64::new(0)),
1053        };
1054
1055        let child_spec = ChildSpec {
1056            id: "reset-test".to_string(),
1057            restart: Restart::Permanent,
1058            spawn_fn: SpawnFn::new(move |sup_cell, id| {
1059                spawn_test_child(sup_cell, id, behavior.clone())
1060            }),
1061            backoff_fn: None,
1062            reset_after: None, };
1064
1065        let options = SupervisorOptions {
1068            strategy: SupervisorStrategy::OneForOne,
1069            max_restarts: 2,
1070            max_window: Duration::from_secs(10),
1071            reset_after: Some(Duration::from_secs(2)), };
1073
1074        let args = SupervisorArguments {
1075            child_specs: vec![child_spec],
1076            options,
1077        };
1078        let (sup_ref, sup_handle) =
1079            Supervisor::spawn("test_restart_counter_reset_after_improved".into(), args).await?;
1080
1081        sleep(Duration::from_secs(4)).await;
1083
1084        sup_ref.stop(None);
1086        let _ = sup_handle.await;
1087
1088        let final_st =
1089            read_final_supervisor_state("test_restart_counter_reset_after_improved").await;
1090        assert_eq!(sup_ref.get_status(), ActorStatus::Stopped);
1091        assert_eq!(
1092            final_st.restart_log.len(),
1093            1,
1094            "After clearing, we only see a single fail in meltdown log"
1095        );
1096
1097        assert_eq!(read_actor_call_count("reset-test").await, 4);
1101
1102        Ok(())
1103    }
1104
1105    #[ractor::concurrency::test]
1106    #[serial]
1107    async fn test_child_level_restart_counter_reset_after() -> Result<(), Box<dyn std::error::Error>>
1108    {
1109        before_each().await;
1110
1111        let behavior = ChildBehavior::FailWaitFail {
1116            initial_fails: 2,
1117            wait_ms: 3000,
1118            final_fails: 1,
1119            current: Arc::new(AtomicU64::new(0)),
1120        };
1121
1122        let mut child_spec = make_child_spec("child-reset", Restart::Permanent, behavior);
1123        child_spec.reset_after = Some(Duration::from_secs(2));
1125
1126        let options = SupervisorOptions {
1128            strategy: SupervisorStrategy::OneForOne,
1129            max_restarts: 5,
1130            max_window: Duration::from_secs(30),
1131            reset_after: None,
1132        };
1133        let args = SupervisorArguments {
1134            child_specs: vec![child_spec],
1135            options,
1136        };
1137
1138        let (sup_ref, sup_handle) =
1139            Supervisor::spawn("test_child_level_restart_counter_reset_after".into(), args).await?;
1140
1141        sleep(Duration::from_millis(100)).await;
1143        let st1 = call_t!(sup_ref, SupervisorMsg::InspectState, 500).unwrap();
1144        let cfs1 = st1.child_failure_state.get("child-reset").unwrap();
1145        assert_eq!(cfs1.restart_count, 2);
1146
1147        sleep(Duration::from_secs(3)).await;
1149
1150        sup_ref.stop(None);
1152        let _ = sup_handle.await;
1153
1154        let final_st =
1155            read_final_supervisor_state("test_child_level_restart_counter_reset_after").await;
1156        let cfs2 = final_st.child_failure_state.get("child-reset").unwrap();
1157        assert_eq!(
1158            cfs2.restart_count, 1,
1159            "child-level reset => next fail sees count=1"
1160        );
1161
1162        assert_eq!(read_actor_call_count("child-reset").await, 4);
1164
1165        Ok(())
1166    }
1167
1168    #[ractor::concurrency::test]
1172    #[serial]
1173    async fn test_nested_supervisors() -> Result<(), Box<dyn std::error::Error>> {
1174        before_each().await;
1175
1176        async fn spawn_subsupervisor(
1177            sup_cell: ActorCell,
1178            id: String,
1179            args: SupervisorArguments,
1180        ) -> Result<ActorCell, SpawnErr> {
1181            let (sub_sup_ref, _join) =
1182                Supervisor::spawn_linked(id, Supervisor, args, sup_cell).await?;
1183            Ok(sub_sup_ref.get_cell())
1184        }
1185
1186        let sub_sup_spec = ChildSpec {
1188            id: "sub-sup".to_string(),
1189            restart: Restart::Permanent,
1190            spawn_fn: SpawnFn::new(move |cell, id| {
1191                let leaf_child = ChildSpec {
1192                    id: "leaf-worker".to_string(),
1193                    restart: Restart::Transient,
1194                    spawn_fn: SpawnFn::new(|c, i| {
1195                        let bh = ChildBehavior::DelayedFail { ms: 300 };
1197                        spawn_test_child(c, i, bh)
1198                    }),
1199                    backoff_fn: None,
1200                    reset_after: None,
1201                };
1202
1203                let sub_sup_args = SupervisorArguments {
1204                    child_specs: vec![leaf_child],
1205                    options: SupervisorOptions {
1206                        strategy: SupervisorStrategy::OneForOne,
1207                        max_restarts: 1, max_window: Duration::from_secs(2),
1209                        reset_after: None,
1210                    },
1211                };
1212                spawn_subsupervisor(cell, id, sub_sup_args)
1213            }),
1214            backoff_fn: None,
1215            reset_after: None,
1216        };
1217
1218        let root_args = SupervisorArguments {
1220            child_specs: vec![sub_sup_spec],
1221            options: SupervisorOptions {
1222                strategy: SupervisorStrategy::OneForOne,
1223                max_restarts: 1, max_window: Duration::from_secs(5),
1225                reset_after: None,
1226            },
1227        };
1228
1229        let (root_sup_ref, root_handle) = Supervisor::spawn("root-sup".into(), root_args).await?;
1230
1231        sleep(Duration::from_millis(600)).await;
1233        assert_eq!(root_sup_ref.get_status(), ActorStatus::Running);
1234
1235        root_sup_ref.stop(None);
1237        let _ = root_handle.await;
1238
1239        let root_final = read_final_supervisor_state("root-sup").await;
1240        let sub_final = read_final_supervisor_state("sub-sup").await;
1241
1242        assert_eq!(root_final.restart_log.len(), 0);
1243        assert_eq!(sub_final.restart_log.len(), 1);
1244
1245        assert_eq!(read_actor_call_count("leaf-worker").await, 2);
1246
1247        Ok(())
1248    }
1249}