use ractor::concurrency::{Duration, Instant};
use ractor::{ActorCell, ActorId, ActorProcessingErr, ActorRef, Message, SpawnErr};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use thiserror::Error;

/// Errors produced by the supervision machinery.
#[derive(Error, Debug, Clone)]
pub enum SupervisorError {
    #[error("Child '{child_id}' not found in specs")]
    ChildNotFound { child_id: String },

    #[error("Child '{pid}' does not have a name set")]
    ChildNameNotSet { pid: ActorId },

    #[error("Spawn error '{child_id}': {reason}")]
    ChildSpawnError { child_id: String, reason: String },

    #[error("Meltdown: {reason}")]
    Meltdown { reason: String },
}

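/// Callback that computes an optional delay before a child is restarted.
///
/// It receives the child id, the number of failures recorded so far, the
/// instant of the last failure, and the child's `reset_after` setting, and
/// returns `Some(delay)` to postpone the restart or `None` to restart
/// immediately.
///
/// Illustrative sketch of a capped exponential backoff (the policy itself is
/// an example, not something the supervisor prescribes):
///
/// ```ignore
/// let backoff = ChildBackoffFn::new(|_id, restart_count, _last_fail, _reset_after| {
///     if restart_count <= 1 {
///         None // first failure: restart immediately
///     } else {
///         // 100ms, 200ms, 400ms, ... capped at 5s
///         let shift = (restart_count - 1).min(16) as u32;
///         Some((Duration::from_millis(100) * 2u32.saturating_pow(shift)).min(Duration::from_secs(5)))
///     }
/// });
/// ```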
#[derive(Clone)]
pub struct ChildBackoffFn(pub Arc<BackoffFn>);

type BackoffFn = dyn Fn(&str, usize, Instant, Option<Duration>) -> Option<Duration> + Send + Sync;

impl ChildBackoffFn {
    pub fn new<F>(func: F) -> Self
    where
        F: Fn(&str, usize, Instant, Option<Duration>) -> Option<Duration> + Send + Sync + 'static,
    {
        Self(Arc::new(func))
    }

    pub fn call(
        &self,
        child_id: &str,
        restart_count: usize,
        last_restart: Instant,
        reset_after: Option<Duration>,
    ) -> Option<Duration> {
        (self.0)(child_id, restart_count, last_restart, reset_after)
    }
}

pub type SpawnFuture = Pin<Box<dyn Future<Output = Result<ActorCell, SpawnErr>> + Send>>;

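/// Async factory the supervisor uses to (re)spawn a child, given the
/// supervisor's own `ActorCell` and the child's id.
///
/// Illustrative sketch, assuming a hypothetical `MyWorker` actor and ractor's
/// `Actor::spawn_linked`; the exact spawn call is up to the caller:
///
/// ```ignore
/// let spawn_fn = SpawnFn::new(|supervisor_cell, child_id| async move {
///     let (child_ref, _join_handle) =
///         MyWorker::spawn_linked(Some(child_id), MyWorker, (), supervisor_cell).await?;
///     Ok(child_ref.get_cell())
/// });
/// ```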
#[derive(Clone)]
pub struct SpawnFn(pub Arc<dyn Fn(ActorCell, String) -> SpawnFuture + Send + Sync>);

impl SpawnFn {
    pub fn new<F, Fut>(func: F) -> Self
    where
        F: Fn(ActorCell, String) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<ActorCell, SpawnErr>> + Send + 'static,
    {
        Self(Arc::new(move |cell, id| Box::pin(func(cell, id))))
    }

    pub fn call(&self, cell: ActorCell, id: String) -> SpawnFuture {
        (self.0)(cell, id)
    }
}

/// Restart policy for a supervised child, as applied in `handle_child_exit`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Restart {
    /// Always restart the child, regardless of how it exited.
    Permanent,
    /// Restart the child only if it exited abnormally.
    Transient,
    /// Never restart the child.
    Temporary,
}

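/// Static description of one supervised child: how to (re)start it and how to
/// treat its exits.
///
/// Illustrative sketch, assuming `spawn_fn` and `backoff` were built as in the
/// examples above:
///
/// ```ignore
/// let spec = ChildSpec {
///     id: "worker-1".to_string(),
///     restart: Restart::Transient,
///     spawn_fn,
///     backoff_fn: Some(backoff),
///     reset_after: Some(Duration::from_secs(30)),
/// };
/// ```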
#[derive(Clone)]
pub struct ChildSpec {
    /// Unique identifier of the child within its supervisor.
    pub id: String,

    /// Restart policy applied when the child exits.
    pub restart: Restart,

    /// Async factory used to spawn (and respawn) the child.
    pub spawn_fn: SpawnFn,

    /// Optional backoff callback that delays restarts.
    pub backoff_fn: Option<ChildBackoffFn>,

    /// If at least this much time has passed since the child's previous
    /// failure, its restart counter is reset before the new failure is counted.
    pub reset_after: Option<Duration>,
}

// Manual Debug impl: the function fields (`spawn_fn`, `backoff_fn`) are not
// `Debug`, so only the data fields are shown.
impl std::fmt::Debug for ChildSpec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ChildSpec")
            .field("id", &self.id)
            .field("restart", &self.restart)
            .field("reset_after", &self.reset_after)
            .finish()
    }
}

/// Per-child failure bookkeeping used for backoff and counter resets.
#[derive(Debug, Clone)]
pub struct ChildFailureState {
    pub restart_count: usize,
    pub last_fail_instant: Instant,
}

/// One entry in the supervisor-wide restart log used for meltdown detection.
#[derive(Clone, Debug)]
pub struct RestartLog {
    pub child_id: String,
    pub timestamp: Instant,
}

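/// Supervisor-wide limits and strategy accessors consumed by `SupervisorCore`.
///
/// Illustrative sketch of an implementation, assuming a hypothetical unit
/// `OneForOne` strategy marker and a plain options struct:
///
/// ```ignore
/// #[derive(Clone, Copy)]
/// struct OneForOne;
///
/// struct Options {
///     max_restarts: usize,
///     max_window: Duration,
///     reset_after: Option<Duration>,
/// }
///
/// impl CoreSupervisorOptions<OneForOne> for Options {
///     fn max_restarts(&self) -> usize { self.max_restarts }
///     fn max_window(&self) -> Duration { self.max_window }
///     fn reset_after(&self) -> Option<Duration> { self.reset_after }
///     fn strategy(&self) -> OneForOne { OneForOne }
/// }
/// ```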
pub trait CoreSupervisorOptions<Strategy> {
    fn max_restarts(&self) -> usize;
    fn max_window(&self) -> Duration;
    fn reset_after(&self) -> Option<Duration>;
    fn strategy(&self) -> Strategy;
}

#[derive(Debug)]
pub enum ExitReason {
    Normal,
    Reason(Option<String>),
    Error(Box<dyn std::error::Error + Send + Sync>),
}

pub trait SupervisorCore {
    type Message: Message;
    type Strategy;
    type Options: CoreSupervisorOptions<Self::Strategy>;

    /// Mutable access to the per-child failure bookkeeping.
    fn child_failure_state(&mut self) -> &mut HashMap<String, ChildFailureState>;
    /// Mutable access to the supervisor-wide restart log.
    fn restart_log(&mut self) -> &mut Vec<RestartLog>;
    /// Supervisor-wide limits and strategy.
    fn options(&self) -> &Self::Options;
    /// Builds the message the supervisor sends itself to perform the actual
    /// restart of `child_spec`.
    fn restart_msg(
        &self,
        child_spec: &ChildSpec,
        strategy: Self::Strategy,
        myself: ActorRef<Self::Message>,
    ) -> Self::Message;

    /// Records a failure for `child_spec`: resets its counter if `reset_after`
    /// has elapsed since the previous failure, then increments the counter and
    /// stamps the failure time.
    fn prepare_child_failure(&mut self, child_spec: &ChildSpec) -> Result<(), ActorProcessingErr> {
        let child_id = &child_spec.id;
        let now = Instant::now();
        let entry = self
            .child_failure_state()
            .entry(child_id.clone())
            .or_insert_with(|| ChildFailureState {
                restart_count: 0,
                last_fail_instant: now,
            });

        if let Some(threshold) = child_spec.reset_after {
            if now.duration_since(entry.last_fail_instant) >= threshold {
                entry.restart_count = 0;
            }
        }

        entry.restart_count += 1;
        entry.last_fail_instant = now;
        Ok(())
    }

    /// Applies the child's restart policy to an exit. Returns `Ok(true)` if
    /// the child should be restarted, in which case its failure state has
    /// already been updated.
    fn handle_child_exit(
        &mut self,
        child_spec: &ChildSpec,
        abnormal: bool,
    ) -> Result<bool, ActorProcessingErr> {
        let policy = child_spec.restart;

        let should_restart = match policy {
            Restart::Permanent => true,
            Restart::Transient => abnormal,
            Restart::Temporary => false,
        };

        if should_restart {
            self.prepare_child_failure(child_spec)?;
        }

        Ok(should_restart)
    }

    /// Logs the exit and schedules a restart when the child's policy calls for
    /// one.
    fn handle_child_restart(
        &mut self,
        child_spec: &ChildSpec,
        abnormal: bool,
        myself: ActorRef<Self::Message>,
        reason: &ExitReason,
    ) -> Result<(), ActorProcessingErr> {
        if self.handle_child_exit(child_spec, abnormal)? {
            log_child_restart(child_spec, abnormal, reason);
            self.schedule_restart(child_spec, self.options().strategy(), myself.clone())?;
        }

        Ok(())
    }

    /// Records a restart in the rolling window and returns a `Meltdown` error
    /// once more than `max_restarts` restarts have occurred within `max_window`.
    fn track_global_restart(&mut self, child_id: &str) -> Result<(), ActorProcessingErr> {
        let now = Instant::now();

        let max_restarts = self.options().max_restarts();
        let max_window = self.options().max_window();
        let reset_after = self.options().reset_after();

        let restart_log = self.restart_log();

        // Drop the whole log if it has been quiet for longer than `reset_after`.
        if let (Some(thresh), Some(latest)) = (reset_after, restart_log.last()) {
            if now.duration_since(latest.timestamp) >= thresh {
                restart_log.clear();
            }
        }

        restart_log.push(RestartLog {
            child_id: child_id.to_string(),
            timestamp: now,
        });

        // Keep only the restarts that fall within the rolling window.
        restart_log.retain(|t| now.duration_since(t.timestamp) < max_window);

        if restart_log.len() > max_restarts {
            Err(SupervisorError::Meltdown {
                reason: "max_restarts exceeded".to_string(),
            }
            .into())
        } else {
            Ok(())
        }
    }

    /// Builds the restart message and delivers it to the supervisor, delayed
    /// by the child's backoff function when one is configured.
    fn schedule_restart(
        &mut self,
        child_spec: &ChildSpec,
        strategy: Self::Strategy,
        myself: ActorRef<Self::Message>,
    ) -> Result<(), ActorProcessingErr> {
        let child_id = &child_spec.id;

        let (restart_count, last_fail_instant) = {
            let failure_state = self.child_failure_state();
            let st = failure_state
                .get(child_id)
                .ok_or_else(|| SupervisorError::ChildNotFound {
                    child_id: child_id.clone(),
                })?;
            (st.restart_count, st.last_fail_instant)
        };
        let msg = self.restart_msg(child_spec, strategy, myself.clone());

        let delay = child_spec.backoff_fn.as_ref().and_then(|cb| {
            cb.call(
                child_id,
                restart_count,
                last_fail_instant,
                child_spec.reset_after,
            )
        });

        match delay {
            // Backoff requested: deliver the restart message after the delay.
            Some(delay) => {
                myself.send_after(delay, move || msg);
            }
            // No backoff: restart immediately.
            None => {
                myself.send_message(msg)?;
            }
        }

        Ok(())
    }
}

/// Emits a restart log line whose level reflects whether the exit was abnormal.
fn log_child_restart(child_spec: &ChildSpec, abnormal: bool, reason: &ExitReason) {
    match (abnormal, reason) {
        (true, ExitReason::Error(err)) => log::error!(
            "Restarting child: {}, exit: abnormal, error: {:?}",
            child_spec.id,
            err
        ),
        (false, ExitReason::Error(err)) => log::warn!(
            "Restarting child: {}, exit: normal, error: {:?}",
            child_spec.id,
            err
        ),
        (true, ExitReason::Reason(Some(reason))) => log::error!(
            "Restarting child: {}, exit: abnormal, reason: {}",
            child_spec.id,
            reason
        ),
        (false, ExitReason::Reason(Some(reason))) => log::warn!(
            "Restarting child: {}, exit: normal, reason: {}",
            child_spec.id,
            reason
        ),
        (true, ExitReason::Reason(None)) => {
            log::error!("Restarting child: {}, exit: abnormal", child_spec.id)
        }
        (false, ExitReason::Reason(None)) => {
            log::warn!("Restarting child: {}, exit: normal", child_spec.id)
        }
        (true, ExitReason::Normal) => {
            log::error!("Restarting child: {}, exit: abnormal", child_spec.id)
        }
        (false, ExitReason::Normal) => {
            log::warn!("Restarting child: {}, exit: normal", child_spec.id)
        }
    }
}
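
#[cfg(test)]
mod tests {
    use super::*;

    // Illustrative sketch: exercises only `ChildBackoffFn`, which is pure and
    // needs no running actor system. The backoff policy below is an example,
    // not something prescribed by the supervisor itself.
    #[test]
    fn child_backoff_fn_delays_after_first_failure() {
        let backoff = ChildBackoffFn::new(|_id, restart_count, _last_fail, _reset_after| {
            if restart_count <= 1 {
                None
            } else {
                Some(Duration::from_millis(100 * restart_count as u64))
            }
        });

        let now = Instant::now();
        // First failure restarts immediately (no delay).
        assert_eq!(backoff.call("worker-1", 1, now, None), None);
        // Later failures are delayed proportionally to the failure count.
        assert_eq!(
            backoff.call("worker-1", 3, now, Some(Duration::from_secs(30))),
            Some(Duration::from_millis(300))
        );
    }
}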