use crate::Pid;
use crate::actor::{Actor, ActorContext};
use crate::message::{ExitReason, Message, Signal};
use crate::system::{ActorRef, ActorSystem};
use crate::telemetry::SupervisorMetrics;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// How a supervisor reacts when one of its children exits abnormally.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RestartStrategy {
    /// Restart only the child that exited.
    OneForOne,
    /// Restart every child of the supervisor.
    OneForAll,
    /// Restart the child that exited and every child declared after it.
    RestForOne,
}
/// Restart budget for a supervised child.
///
/// When a child already has `max_restarts` restarts recorded within the last
/// `within_seconds` seconds, the supervisor stops itself instead of restarting
/// that child again.
#[derive(Clone, Copy, Debug)]
pub struct RestartIntensity {
    /// Maximum number of restarts tolerated inside the sliding window.
    pub max_restarts: usize,
    /// Length of the sliding window, in seconds.
    pub within_seconds: u64,
}

/// Defaults to at most 3 restarts within 5 seconds.
impl Default for RestartIntensity {
    fn default() -> Self {
        RestartIntensity {
            within_seconds: 5,
            max_restarts: 3,
        }
    }
}
/// Describes one supervised child: an id and a factory that builds a fresh
/// actor instance.
///
/// The supervisor calls `start` once when it first spawns the child and again
/// on every restart, so the factory must be callable repeatedly.
pub struct ChildSpec {
    /// Key under which the supervisor tracks this child; should be unique
    /// within one supervisor.
    pub id: String,
    /// Factory producing a new actor instance for this child.
    pub start: Box<dyn FnMut() -> Box<dyn Actor> + Send>,
}

impl ChildSpec {
    /// Creates a spec from any `String`-convertible id and a repeatable
    /// actor factory.
    pub fn new(
        id: impl Into<String>,
        start: impl FnMut() -> Box<dyn Actor> + Send + 'static,
    ) -> Self {
        let id: String = id.into();
        let boxed_start: Box<dyn FnMut() -> Box<dyn Actor> + Send> = Box::new(start);
        ChildSpec {
            id,
            start: boxed_start,
        }
    }
}
pub struct SupervisorSpec {
pub strategy: RestartStrategy,
pub intensity: RestartIntensity,
pub children: Vec<ChildSpec>,
}
impl SupervisorSpec {
pub fn new(strategy: RestartStrategy) -> Self {
Self {
strategy,
intensity: RestartIntensity::default(),
children: Vec::new(),
}
}
pub fn child(mut self, spec: ChildSpec) -> Self {
self.children.push(spec);
self
}
pub fn intensity(mut self, intensity: RestartIntensity) -> Self {
self.intensity = intensity;
self
}
}
/// Runtime bookkeeping for one supervised child.
struct Child {
    // Never read in this file (the `children` map key already holds the id);
    // presumably kept for debugging/inspection, hence the lint allowance.
    #[allow(dead_code)]
    id: String,
    /// Pid of the currently running instance; replaced on every restart.
    pid: Pid,
    /// Factory taken from the child's spec, re-invoked for each new instance.
    start_factory: Box<dyn FnMut() -> Box<dyn Actor> + Send>,
    /// Timestamps of recent restarts, checked against the restart intensity.
    restart_times: Vec<Instant>,
}
/// Actor that spawns a set of children and restarts them according to a
/// [`RestartStrategy`] and a [`RestartIntensity`].
pub struct Supervisor {
    /// Reaction to an abnormal child exit.
    strategy: RestartStrategy,
    /// Restart budget applied to each child.
    intensity: RestartIntensity,
    /// Tracked children keyed by their spec id.
    children: HashMap<String, Child>,
    /// Child ids in declaration order; drives `OneForAll` / `RestForOne` restarts.
    child_order: Vec<String>,
    /// Specs not yet spawned; taken (leaving `None`) when the actor starts.
    child_specs: Option<Vec<ChildSpec>>,
    /// Actor system used to spawn and re-spawn children.
    system: Arc<ActorSystem>,
}
impl Supervisor {
pub fn from_spec(spec: SupervisorSpec, system: Arc<ActorSystem>) -> Self {
Self {
strategy: spec.strategy,
intensity: spec.intensity,
children: HashMap::new(),
child_order: Vec::new(),
child_specs: Some(spec.children),
system,
}
}
fn start_children(&mut self, mut specs: Vec<ChildSpec>, ctx: &mut ActorContext) {
for mut spec in specs.drain(..) {
let child_actor = (spec.start)();
let actor_ref = self.system.spawn_boxed(child_actor);
let pid = actor_ref.pid();
let _ = actor_ref.monitor(ctx.pid());
let child = Child {
id: spec.id.clone(),
pid,
start_factory: spec.start,
restart_times: Vec::new(),
};
self.children.insert(spec.id.clone(), child);
self.child_order.push(spec.id);
}
}
async fn handle_child_exit(
&mut self,
child_pid: Pid,
reason: &ExitReason,
ctx: &mut ActorContext,
) {
let child_id = self
.children
.iter()
.find(|(_, child)| child.pid == child_pid)
.map(|(id, _)| id.clone());
if let Some(child_id) = child_id {
tracing::warn!("Child {} (pid {}) exited: {}", child_id, child_pid, reason);
if reason.is_normal() {
return;
}
match self.strategy {
RestartStrategy::OneForOne => {
self.restart_child(&child_id, ctx).await;
}
RestartStrategy::OneForAll => {
self.restart_all_children(ctx).await;
}
RestartStrategy::RestForOne => {
self.restart_from_child(&child_id, ctx).await;
}
}
}
}
async fn restart_child(&mut self, child_id: &str, ctx: &mut ActorContext) {
let _span = SupervisorMetrics::restart_span();
if let Some(child) = self.children.get_mut(child_id) {
let now = Instant::now();
let cutoff = now - Duration::from_secs(self.intensity.within_seconds);
child.restart_times.retain(|&t| t > cutoff);
if child.restart_times.len() >= self.intensity.max_restarts {
tracing::error!(
"Child {} exceeded restart intensity, stopping supervisor",
child_id
);
SupervisorMetrics::restart_intensity_exceeded();
ctx.stop(ExitReason::Custom(format!(
"restart intensity exceeded for {}",
child_id
)));
return;
}
child.restart_times.push(now);
let child_actor = (child.start_factory)();
let actor_ref = self.system.spawn_boxed(child_actor);
let new_pid = actor_ref.pid();
let _ = actor_ref.monitor(ctx.pid());
child.pid = new_pid;
let strategy_str = match self.strategy {
RestartStrategy::OneForOne => "one_for_one",
RestartStrategy::OneForAll => "one_for_all",
RestartStrategy::RestForOne => "rest_for_one",
};
SupervisorMetrics::child_restarted(strategy_str);
tracing::info!("Restarted child {} with new pid {}", child_id, new_pid);
}
}
async fn restart_all_children(&mut self, ctx: &mut ActorContext) {
for child_id in self.child_order.clone() {
self.restart_child(&child_id, ctx).await;
}
}
async fn restart_from_child(&mut self, from_id: &str, ctx: &mut ActorContext) {
let mut should_restart = false;
for child_id in self.child_order.clone() {
if child_id == from_id {
should_restart = true;
}
if should_restart {
self.restart_child(&child_id, ctx).await;
}
}
}
}
#[async_trait]
impl Actor for Supervisor {
    /// Enables exit trapping, then spawns the children from the spec. The specs
    /// are taken out of `child_specs`, so they are spawned at most once.
    async fn started(&mut self, ctx: &mut ActorContext) {
        ctx.trap_exit(true);
        tracing::info!("Supervisor {} started", ctx.pid());
        let pending = self.child_specs.take();
        if let Some(specs) = pending {
            self.start_children(specs, ctx);
        }
    }

    /// The supervisor has no message protocol of its own; messages are dropped.
    async fn handle_message(&mut self, _msg: Message, _ctx: &mut ActorContext) {}

    /// Child `Down`/`Exit` signals go to the restart logic; `Stop` and `Kill`
    /// stop the supervisor with `Shutdown` and `Killed` respectively.
    async fn handle_signal(&mut self, signal: Signal, ctx: &mut ActorContext) {
        match signal {
            Signal::Down { pid, reason, .. } => {
                self.handle_child_exit(pid, &reason, ctx).await
            }
            Signal::Exit { from, reason } => {
                self.handle_child_exit(from, &reason, ctx).await
            }
            Signal::Stop => {
                ctx.stop(ExitReason::Shutdown);
            }
            Signal::Kill => {
                ctx.stop(ExitReason::Killed);
            }
        }
    }

    /// Logs the shutdown.
    ///
    /// NOTE(review): despite the log text, nothing here terminates the children;
    /// confirm whether the actor system stops them or add explicit shutdown.
    async fn stopped(&mut self, _reason: &ExitReason, ctx: &mut ActorContext) {
        tracing::info!("Supervisor {} stopping, terminating children", ctx.pid());
    }
}
/// Builds a [`Supervisor`] from `spec` and spawns it on `system`.
///
/// The supervisor keeps its own `Arc` handle to `system` so it can spawn and
/// restart its children later.
pub fn spawn_supervisor(system: &Arc<ActorSystem>, spec: SupervisorSpec) -> ActorRef {
    let supervisor = Supervisor::from_spec(spec, Arc::clone(system));
    system.spawn(supervisor)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal actor used to exercise child factories.
    struct TestActor;

    #[async_trait]
    impl Actor for TestActor {
        async fn handle_message(&mut self, _msg: Message, _ctx: &mut ActorContext) {}
    }

    fn test_factory() -> Box<dyn Actor> {
        Box::new(TestActor)
    }

    #[test]
    fn test_restart_strategy() {
        assert_eq!(RestartStrategy::OneForOne, RestartStrategy::OneForOne);
        assert_ne!(RestartStrategy::OneForOne, RestartStrategy::OneForAll);
        assert_ne!(RestartStrategy::OneForAll, RestartStrategy::RestForOne);
        assert_ne!(RestartStrategy::OneForOne, RestartStrategy::RestForOne);
    }

    #[test]
    fn test_restart_intensity_default() {
        let intensity = RestartIntensity::default();
        assert_eq!(intensity.max_restarts, 3);
        assert_eq!(intensity.within_seconds, 5);
    }

    #[test]
    fn test_supervisor_spec() {
        let spec = SupervisorSpec::new(RestartStrategy::OneForOne).intensity(RestartIntensity {
            max_restarts: 5,
            within_seconds: 10,
        });
        assert_eq!(spec.strategy, RestartStrategy::OneForOne);
        assert_eq!(spec.intensity.max_restarts, 5);
        assert_eq!(spec.intensity.within_seconds, 10);
        assert!(spec.children.is_empty());
    }

    #[test]
    fn test_supervisor_spec_defaults_and_child_order() {
        let spec = SupervisorSpec::new(RestartStrategy::RestForOne)
            .child(ChildSpec::new("first", test_factory))
            .child(ChildSpec::new(String::from("second"), test_factory));
        assert_eq!(spec.strategy, RestartStrategy::RestForOne);
        assert_eq!(
            spec.intensity.max_restarts,
            RestartIntensity::default().max_restarts
        );
        // Declaration order matters for RestForOne, so `child` must append.
        let ids: Vec<&str> = spec.children.iter().map(|c| c.id.as_str()).collect();
        assert_eq!(ids, ["first", "second"]);
    }

    #[tokio::test]
    async fn test_child_spec() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let mut spec = ChildSpec::new("test", move || {
            counter.fetch_add(1, Ordering::SeqCst);
            Box::new(TestActor) as Box<dyn Actor>
        });
        assert_eq!(spec.id, "test");
        // The factory is re-invoked on every restart, so it must be callable
        // more than once.
        let _first = (spec.start)();
        let _second = (spec.start)();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}