use ractor::concurrency::{Duration, Instant};
use ractor::{ActorCell, ActorId, ActorProcessingErr, ActorRef, Message, SpawnErr};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use thiserror::Error;
/// Errors raised by the supervisor bookkeeping code.
///
/// The `Display` text of every variant comes from its `#[error(...)]`
/// attribute.
#[derive(Error, Debug, Clone)]
pub enum SupervisorError {
/// No entry was found for `child_id`. In this file it is returned by
/// `SupervisorCore::schedule_restart` when the child has no recorded
/// failure state.
#[error("Child '{child_id}' not found in specs")]
ChildNotFound { child_id: String },
/// The actor identified by `pid` has no name set.
#[error("Child '{pid}' does not have a name set")]
ChildNameNotSet { pid: ActorId },
/// Spawning the child `child_id` failed; `reason` carries the failure as text.
#[error("Spawn error '{child_id}': {reason}")]
ChildSpawnError { child_id: String, reason: String },
/// The restart budget was exceeded. In this file it is returned by
/// `SupervisorCore::track_global_restart` with the reason
/// "max_restarts exceeded".
#[error("Meltdown: {reason}")]
Meltdown { reason: String },
}
/// Signature of a per-child backoff callback.
///
/// Arguments, in order: the child id, the child's current restart count, the
/// instant recorded for the child, and the child's optional reset threshold.
/// The result is the delay to wait before restarting; `None` means no delay.
type BackoffFn = dyn Fn(&str, usize, Instant, Option<Duration>) -> Option<Duration> + Send + Sync;

/// Shareable handle to a backoff callback.
///
/// Cloning only bumps the `Arc` reference count; all clones invoke the same
/// underlying closure.
#[derive(Clone)]
pub struct ChildBackoffFn(pub Arc<BackoffFn>);

impl ChildBackoffFn {
    /// Wraps `func` so it can be stored in a spec and shared across clones.
    pub fn new<F>(func: F) -> Self
    where
        F: Fn(&str, usize, Instant, Option<Duration>) -> Option<Duration> + Send + Sync + 'static,
    {
        let shared: Arc<BackoffFn> = Arc::new(func);
        ChildBackoffFn(shared)
    }

    /// Invokes the wrapped callback, forwarding every argument unchanged, and
    /// returns whatever delay the callback produced.
    pub fn call(
        &self,
        child_id: &str,
        restart_count: usize,
        last_restart: Instant,
        reset_after: Option<Duration>,
    ) -> Option<Duration> {
        let callback: &BackoffFn = &*self.0;
        callback(child_id, restart_count, last_restart, reset_after)
    }
}
pub type SpawnFuture = Pin<Box<dyn Future<Output = Result<ActorCell, SpawnErr>> + Send>>;
#[derive(Clone)]
pub struct SpawnFn(pub Arc<dyn Fn(ActorCell, String) -> SpawnFuture + Send + Sync>);
impl SpawnFn {
pub fn new<F, Fut>(func: F) -> Self
where
F: Fn(ActorCell, String) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<ActorCell, SpawnErr>> + Send + 'static,
{
Self(Arc::new(move |cell, id| Box::pin(func(cell, id))))
}
pub fn call(&self, cell: ActorCell, id: String) -> SpawnFuture {
(self.0)(cell, id)
}
}
/// Restart policy of a child, evaluated by `SupervisorCore::handle_child_exit`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Restart {
/// Always restart, whether the exit was normal or abnormal.
Permanent,
/// Restart only when the exit was abnormal.
Transient,
/// Never restart.
Temporary,
}
/// Static description of a supervised child.
#[derive(Clone)]
pub struct ChildSpec {
/// Identifier of the child; used as the key of the per-child failure state
/// and printed in restart log messages.
pub id: String,
/// Policy deciding whether an exit of this child leads to a restart.
pub restart: Restart,
/// Callback used to spawn the child (not invoked anywhere in this file).
pub spawn_fn: SpawnFn,
/// Optional backoff callback. When it is absent, or when it returns `None`,
/// the restart message is sent immediately instead of after a delay.
pub backoff_fn: Option<ChildBackoffFn>,
/// If set, the child's restart count is reset once at least this much time
/// has passed since its previously recorded failure.
pub reset_after: Option<Duration>,
}
impl std::fmt::Debug for ChildSpec {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ChildSpec")
.field("id", &self.id)
.field("restart", &self.restart)
.field("reset_after", &self.reset_after)
.finish()
}
}
/// Per-child failure bookkeeping, maintained by
/// `SupervisorCore::prepare_child_failure`.
#[derive(Debug, Clone)]
pub struct ChildFailureState {
/// Number of failures counted since the counter was last reset.
pub restart_count: usize,
/// Instant at which the most recent failure was recorded.
pub last_fail_instant: Instant,
}
/// One entry of the supervisor-wide restart log, appended by
/// `SupervisorCore::track_global_restart`.
#[derive(Clone, Debug)]
pub struct RestartLog {
/// Id of the child the entry was recorded for.
pub child_id: String,
/// Instant at which the entry was recorded.
pub timestamp: Instant,
}
/// Supervisor-wide settings consumed by the default methods of `SupervisorCore`.
pub trait CoreSupervisorOptions<Strategy> {
/// Maximum number of restart-log entries tolerated inside `max_window`;
/// `track_global_restart` reports a meltdown once the log holds more.
fn max_restarts(&self) -> usize;
/// Length of the sliding window; log entries at least this old are dropped
/// before the meltdown check.
fn max_window(&self) -> Duration;
/// If set, the whole restart log is cleared when its latest entry is at
/// least this old at the time the next restart is tracked.
fn reset_after(&self) -> Option<Duration>;
/// Strategy value handed to `SupervisorCore::restart_msg` when a restart
/// is scheduled.
fn strategy(&self) -> Strategy;
}
/// Why a child exited. In this file it is only used to choose the wording of
/// the restart log line.
#[derive(Debug)]
pub enum ExitReason {
/// No further detail is attached to the exit.
Normal,
/// Exit with an optional textual reason.
Reason(Option<String>),
/// Exit carrying a boxed error value, logged with its `Debug` formatting.
Error(Box<dyn std::error::Error + Send + Sync>),
}
/// Restart bookkeeping shared by supervisor implementations.
///
/// Implementors provide access to their state (`child_failure_state`,
/// `restart_log`, `options`) and build the restart message (`restart_msg`);
/// the default methods implement per-child failure counting, the
/// supervisor-wide restart budget, and (optionally delayed) scheduling of the
/// restart message.
pub trait SupervisorCore {
    /// Message type of the supervisor actor.
    type Message: Message;
    /// Strategy value forwarded to [`SupervisorCore::restart_msg`].
    type Strategy;
    /// Options type supplying limits and the strategy.
    type Options: CoreSupervisorOptions<Self::Strategy>;

    /// Mutable access to the per-child failure state, keyed by child id.
    fn child_failure_state(&mut self) -> &mut HashMap<String, ChildFailureState>;

    /// Mutable access to the supervisor-wide restart log.
    fn restart_log(&mut self) -> &mut Vec<RestartLog>;

    /// The supervisor's options.
    fn options(&self) -> &Self::Options;

    /// Builds the message that, once delivered to `myself`, triggers the
    /// restart of the child described by `child_spec`.
    fn restart_msg(
        &self,
        child_spec: &ChildSpec,
        strategy: Self::Strategy,
        myself: ActorRef<Self::Message>,
    ) -> Self::Message;

    /// Records a failure of the child described by `child_spec`.
    ///
    /// Creates the child's failure state on first use, resets the counter if
    /// the child's `reset_after` threshold has elapsed since the previously
    /// recorded failure, then increments the counter and stamps the current
    /// instant.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok(())`.
    fn prepare_child_failure(&mut self, child_spec: &ChildSpec) -> Result<(), ActorProcessingErr> {
        let now = Instant::now();
        let entry = self
            .child_failure_state()
            .entry(child_spec.id.clone())
            .or_insert_with(|| ChildFailureState {
                restart_count: 0,
                last_fail_instant: now,
            });

        // A long enough quiet period forgives earlier failures.
        if let Some(threshold) = child_spec.reset_after {
            if now.duration_since(entry.last_fail_instant) >= threshold {
                entry.restart_count = 0;
            }
        }

        entry.restart_count += 1;
        entry.last_fail_instant = now;
        Ok(())
    }

    /// Decides whether the exited child must be restarted according to its
    /// restart policy and, if so, records the failure.
    ///
    /// Returns `Ok(true)` when a restart is required.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`SupervisorCore::prepare_child_failure`].
    fn handle_child_exit(
        &mut self,
        child_spec: &ChildSpec,
        abnormal: bool,
    ) -> Result<bool, ActorProcessingErr> {
        let should_restart = match child_spec.restart {
            Restart::Permanent => true,
            Restart::Transient => abnormal,
            Restart::Temporary => false,
        };
        if should_restart {
            self.prepare_child_failure(child_spec)?;
        }
        Ok(should_restart)
    }

    /// Handles a child exit end to end: applies the restart policy, logs the
    /// restart, and schedules the restart message on `myself`.
    ///
    /// Nothing is logged or scheduled when the policy says the child must not
    /// be restarted.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`SupervisorCore::handle_child_exit`] and
    /// [`SupervisorCore::schedule_restart`].
    fn handle_child_restart(
        &mut self,
        child_spec: &ChildSpec,
        abnormal: bool,
        myself: ActorRef<Self::Message>,
        reason: &ExitReason,
    ) -> Result<(), ActorProcessingErr> {
        if self.handle_child_exit(child_spec, abnormal)? {
            log_child_restart(child_spec, abnormal, reason);
            // `myself` is owned and not used afterwards, so it is moved
            // rather than cloned.
            self.schedule_restart(child_spec, self.options().strategy(), myself)?;
        }
        Ok(())
    }

    /// Records a restart of `child_id` in the supervisor-wide log and checks
    /// the restart budget.
    ///
    /// Steps, in order: clear the log if the options' `reset_after` has
    /// elapsed since the latest entry; append the new entry; drop entries
    /// that are at least `max_window` old; compare the remaining count with
    /// `max_restarts`.
    ///
    /// # Errors
    ///
    /// Returns [`SupervisorError::Meltdown`] when more than `max_restarts`
    /// entries remain inside the window.
    fn track_global_restart(&mut self, child_id: &str) -> Result<(), ActorProcessingErr> {
        let now = Instant::now();
        // Read every option before taking the mutable borrow of the log.
        let max_restarts = self.options().max_restarts();
        let max_window = self.options().max_window();
        let reset_after = self.options().reset_after();
        let restart_log = self.restart_log();

        let quiet_period_elapsed = match (reset_after, restart_log.last()) {
            (Some(thresh), Some(latest)) => now.duration_since(latest.timestamp) >= thresh,
            _ => false,
        };
        if quiet_period_elapsed {
            restart_log.clear();
        }

        restart_log.push(RestartLog {
            child_id: child_id.to_string(),
            timestamp: now,
        });
        restart_log.retain(|t| now.duration_since(t.timestamp) < max_window);

        if restart_log.len() > max_restarts {
            Err(SupervisorError::Meltdown {
                reason: "max_restarts exceeded".to_string(),
            }
            .into())
        } else {
            Ok(())
        }
    }

    /// Sends the restart message for `child_spec` to `myself`, delayed by the
    /// child's backoff callback when that callback yields a delay.
    ///
    /// Without a backoff callback, or when the callback returns `None`, the
    /// message is sent immediately.
    ///
    /// # Errors
    ///
    /// Returns [`SupervisorError::ChildNotFound`] when no failure state has
    /// been recorded for the child, and propagates the messaging error when
    /// the immediate send fails. The outcome of a delayed send is not
    /// observed here.
    fn schedule_restart(
        &mut self,
        child_spec: &ChildSpec,
        strategy: Self::Strategy,
        myself: ActorRef<Self::Message>,
    ) -> Result<(), ActorProcessingErr> {
        let child_id = &child_spec.id;

        // Copy the two values out so the mutable borrow of `self` ends here.
        let (restart_count, last_fail_instant) = {
            let st = self
                .child_failure_state()
                .get(child_id)
                // Lazy: the id is only cloned on the error path.
                .ok_or_else(|| SupervisorError::ChildNotFound {
                    child_id: child_id.clone(),
                })?;
            (st.restart_count, st.last_fail_instant)
        };

        let msg = self.restart_msg(child_spec, strategy, myself.clone());
        let delay = child_spec.backoff_fn.as_ref().and_then(|backoff| {
            backoff.call(
                child_id,
                restart_count,
                last_fail_instant,
                child_spec.reset_after,
            )
        });

        match delay {
            Some(delay) => {
                // The returned handle is intentionally dropped.
                myself.send_after(delay, move || msg);
            }
            None => {
                myself.send_message(msg)?;
            }
        }
        Ok(())
    }
}
fn log_child_restart(child_spec: &ChildSpec, abnormal: bool, reason: &ExitReason) {
match (abnormal, reason) {
(true, ExitReason::Error(err)) => log::error!(
"Restarting child: {}, exit: abnormal, error: {:?}",
child_spec.id,
err
),
(false, ExitReason::Error(err)) => log::warn!(
"Restarting child: {}, exit: normal, error: {:?}",
child_spec.id,
err
),
(true, ExitReason::Reason(Some(reason))) => log::error!(
"Restarting child: {}, exit: abnormal, reason: {}",
child_spec.id,
reason
),
(false, ExitReason::Reason(Some(reason))) => log::warn!(
"Restarting child: {}, exit: normal, reason: {}",
child_spec.id,
reason
),
(true, ExitReason::Reason(None)) => {
log::error!("Restarting child: {}, exit: abnormal", child_spec.id)
}
(false, ExitReason::Reason(None)) => {
log::warn!("Restarting child: {}, exit: normal", child_spec.id)
}
(true, ExitReason::Normal) => {
log::error!("Restarting child: {}, exit: abnormal", child_spec.id)
}
(false, ExitReason::Normal) => {
log::warn!("Restarting child: {}, exit: normal", child_spec.id)
}
}
}