use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use crate::actor::ActorError;
use crate::message::Message;
use crate::node::{ActorId, NodeId};
/// Message describing the termination of a child actor.
///
/// NOTE(review): the recipient is presumably the child's supervisor — confirm with the sender.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ChildTerminated {
    /// Identifier of the child that terminated.
    pub child_id: ActorId,
    /// Name associated with the terminated child.
    pub child_name: String,
    /// Optional description of why the child terminated; `None` when no reason was attached.
    pub reason: Option<String>,
}

impl Message for ChildTerminated {
    /// No data is sent back in reply to this message.
    type Reply = ();
}
/// Decision a [`SupervisionStrategy`] returns for a failed child.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SupervisionAction {
    /// Restart after the failure.
    ///
    /// Group strategies such as `RestForOne` may cover siblings as well; see
    /// `RestForOne::children_to_restart`.
    Restart,
    /// Do not restart the failed child.
    ///
    /// The built-in strategies return this once their restart budget is spent.
    Stop,
    /// Presumably: hand the failure up to the supervisor's own parent (confirm with the caller).
    ///
    /// None of the built-in strategies return this variant.
    Escalate,
}
/// Policy that decides how a supervisor reacts to the failure of one of its children.
///
/// Strategies must be `Send + Sync + 'static`, so one value can be shared across threads,
/// for example as a `Box<dyn SupervisionStrategy>`. The method takes `&self`. Stateful
/// strategies therefore need interior mutability; the built-in ones keep their restart budget
/// behind a `Mutex`.
pub trait SupervisionStrategy: Send + Sync + 'static {
    /// Chooses the [`SupervisionAction`] for a failed child.
    ///
    /// * `child_id` – identifier of the child that failed.
    /// * `child_name` – name associated with that child.
    /// * `error` – the error the child failed with.
    fn on_child_failed(
        &self,
        child_id: &ActorId,
        child_name: &str,
        error: &ActorError,
    ) -> SupervisionAction;
}
/// Sliding-window restart budget shared by the built-in strategies.
///
/// For each key it remembers the instants at which a restart was granted. It refuses a new
/// restart once `max_restarts` of those instants fall inside the trailing `within` window.
struct RestartTracker {
    /// Maximum number of restarts granted per key inside one window.
    max_restarts: u32,
    /// Length of the trailing window that restarts are counted over.
    within: Duration,
    /// Instants of granted restarts, keyed by child id (or by the group sentinel).
    timestamps: Mutex<HashMap<ActorId, Vec<Instant>>>,
}

impl RestartTracker {
    /// Creates a tracker allowing `max_restarts` restarts per key within `within`.
    fn new(max_restarts: u32, within: Duration) -> Self {
        Self {
            max_restarts,
            within,
            timestamps: Mutex::new(HashMap::new()),
        }
    }

    /// Asks for a restart of `child_id`.
    ///
    /// Returns `true` and records the restart if fewer than `max_restarts` restarts were granted
    /// to `child_id` within the window. Otherwise returns `false` and records nothing.
    ///
    /// Every call also discards out-of-window timestamps for *all* keys and forgets keys with
    /// none left. The map therefore holds only recently-restarted children instead of growing
    /// with every child that has ever failed. Each vector holds at most `max_restarts` entries,
    /// so the sweep is bounded by `keys × max_restarts`.
    fn record(&self, child_id: &ActorId) -> bool {
        let now = Instant::now();
        let within = self.within;
        // Any partially-updated map is still a valid set of timestamps (at worst a restart is
        // over- or under-counted). So recover from poisoning instead of propagating another
        // thread's panic into the supervisor.
        let mut map = self
            .timestamps
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        map.retain(|_, stamps| {
            stamps.retain(|t| now.duration_since(*t) <= within);
            !stamps.is_empty()
        });
        let recent = map.get(child_id).map_or(0, Vec::len);
        // Widen the u32 limit instead of narrowing the usize length.
        if recent >= self.max_restarts as usize {
            return false;
        }
        map.entry(child_id.clone()).or_default().push(now);
        true
    }

    /// Like [`record`](Self::record), but against a single budget shared by the whole group.
    ///
    /// The group budget is stored under a fixed sentinel id. NOTE(review): a real actor with
    /// node `"__supervision_group__"` and local id `0` would share this budget. That is harmless
    /// as long as one tracker is used either per-child or group-wide, as the strategies here do.
    fn record_global(&self) -> bool {
        let sentinel = ActorId {
            node: NodeId("__supervision_group__".into()),
            local: 0,
        };
        self.record(&sentinel)
    }
}
/// Strategy that restarts only the failed child, with an independent budget per child.
///
/// Each child may be restarted at most `max_restarts` times within any trailing `within`
/// window. Once its budget is spent, further failures of that child yield
/// [`SupervisionAction::Stop`].
pub struct OneForOne {
    tracker: RestartTracker,
}

impl OneForOne {
    /// Creates a strategy allowing `max_restarts` restarts per child within `within`.
    pub fn new(max_restarts: u32, within: Duration) -> Self {
        let tracker = RestartTracker::new(max_restarts, within);
        Self { tracker }
    }
}

impl SupervisionStrategy for OneForOne {
    fn on_child_failed(
        &self,
        child_id: &ActorId,
        _child_name: &str,
        _error: &ActorError,
    ) -> SupervisionAction {
        // The budget is keyed by child id, so one flapping child cannot use up its siblings'.
        if !self.tracker.record(child_id) {
            return SupervisionAction::Stop;
        }
        SupervisionAction::Restart
    }
}
/// Strategy whose restart budget is shared by every child of the supervisor.
///
/// Any child's failure draws from one group-wide budget of `max_restarts` within `within`.
/// Once that budget is spent, failures yield [`SupervisionAction::Stop`]. This type only
/// decides *whether* to restart; by its name, the caller is presumably expected to restart all
/// children on [`SupervisionAction::Restart`] — confirm at the call site.
pub struct AllForOne {
    tracker: RestartTracker,
}

impl AllForOne {
    /// Creates a strategy allowing `max_restarts` group-wide restarts within `within`.
    pub fn new(max_restarts: u32, within: Duration) -> Self {
        let tracker = RestartTracker::new(max_restarts, within);
        Self { tracker }
    }
}

impl SupervisionStrategy for AllForOne {
    fn on_child_failed(
        &self,
        _child_id: &ActorId,
        _child_name: &str,
        _error: &ActorError,
    ) -> SupervisionAction {
        // Which child failed is irrelevant: every failure counts against the shared budget.
        let within_budget = self.tracker.record_global();
        if within_budget {
            SupervisionAction::Restart
        } else {
            SupervisionAction::Stop
        }
    }
}
/// Strategy that restarts the failed child together with every sibling registered after it.
///
/// Children are kept in registration order (see [`register_child`](Self::register_child)).
/// [`children_to_restart`](Self::children_to_restart) reports which of them a given failure
/// affects. All children share one group-wide restart budget of `max_restarts` within
/// `within`. Once it is spent, failures yield [`SupervisionAction::Stop`].
pub struct RestForOne {
    tracker: RestartTracker,
    /// Registered children, oldest first, without duplicates.
    children: Mutex<Vec<ActorId>>,
}

impl RestForOne {
    /// Creates a strategy allowing `max_restarts` group-wide restarts within `within`.
    pub fn new(max_restarts: u32, within: Duration) -> Self {
        Self {
            tracker: RestartTracker::new(max_restarts, within),
            children: Mutex::new(Vec::new()),
        }
    }

    /// Appends `child_id` to the end of the start order.
    ///
    /// Registering an id that is already present is a no-op. Without this guard,
    /// [`children_to_restart`](Self::children_to_restart) would list that child (and so restart
    /// it) more than once.
    pub fn register_child(&self, child_id: ActorId) {
        let mut children = self.lock_children();
        if !children.contains(&child_id) {
            children.push(child_id);
        }
    }

    /// Removes `child_id` from the start order; ids that are not registered are ignored.
    ///
    /// Use this when a child is stopped for good, so that a later failure of an earlier sibling
    /// does not report it as needing a restart.
    pub fn unregister_child(&self, child_id: &ActorId) {
        self.lock_children().retain(|id| id != child_id);
    }

    /// Returns the failed child followed by every child registered after it, in start order.
    ///
    /// If `failed_id` was never registered, only `failed_id` itself is returned.
    pub fn children_to_restart(&self, failed_id: &ActorId) -> Vec<ActorId> {
        let children = self.lock_children();
        match children.iter().position(|id| id == failed_id) {
            Some(pos) => children[pos..].to_vec(),
            None => vec![failed_id.clone()],
        }
    }

    /// Locks the child list, recovering it if a previous holder panicked.
    ///
    /// The list is always a valid `Vec`, so a panic elsewhere should not also take down
    /// supervision.
    fn lock_children(&self) -> std::sync::MutexGuard<'_, Vec<ActorId>> {
        self.children
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl SupervisionStrategy for RestForOne {
    /// Grants a restart while the shared group budget allows it, otherwise stops.
    ///
    /// The budget is group-wide, so the failing child's identity is not consulted here. Which
    /// children to actually restart is presumably determined by the caller through
    /// [`RestForOne::children_to_restart`].
    fn on_child_failed(
        &self,
        _child_id: &ActorId,
        _child_name: &str,
        _error: &ActorError,
    ) -> SupervisionAction {
        if self.tracker.record_global() {
            SupervisionAction::Restart
        } else {
            SupervisionAction::Stop
        }
    }
}
#[cfg(test)]
mod tests {
    use super::SupervisionAction::{Restart, Stop};
    use super::*;
    use crate::actor::ActorError;
    use crate::node::{ActorId, NodeId};

    /// Builds an actor id on a fixed test node.
    fn test_id(local: u64) -> ActorId {
        let node = NodeId("test".into());
        ActorId { node, local }
    }

    /// A representative failure; the strategies under test ignore its contents.
    fn test_error() -> ActorError {
        ActorError::internal("test failure")
    }

    /// Reports one failure of `id` per entry in `expected` and checks each decision in order.
    fn assert_decisions(
        strategy: &dyn SupervisionStrategy,
        id: &ActorId,
        name: &str,
        expected: &[SupervisionAction],
    ) {
        for &want in expected {
            assert_eq!(strategy.on_child_failed(id, name, &test_error()), want);
        }
    }

    #[test]
    fn one_for_one_returns_restart() {
        let strategy = OneForOne::new(5, Duration::from_secs(60));
        assert_decisions(&strategy, &test_id(1), "child-1", &[Restart]);
    }

    #[test]
    fn one_for_one_max_exceeded_returns_stop() {
        let strategy = OneForOne::new(3, Duration::from_secs(60));
        assert_decisions(
            &strategy,
            &test_id(1),
            "child-1",
            &[Restart, Restart, Restart, Stop],
        );
    }

    #[test]
    fn one_for_one_tracks_per_child() {
        let strategy = OneForOne::new(2, Duration::from_secs(60));
        assert_decisions(&strategy, &test_id(1), "a", &[Restart, Restart, Stop]);
        assert_decisions(&strategy, &test_id(2), "b", &[Restart]);
    }

    #[test]
    fn all_for_one_returns_restart() {
        let strategy = AllForOne::new(5, Duration::from_secs(60));
        assert_decisions(&strategy, &test_id(1), "child-1", &[Restart]);
    }

    #[test]
    fn all_for_one_max_exceeded_returns_stop() {
        let strategy = AllForOne::new(2, Duration::from_secs(60));
        assert_decisions(&strategy, &test_id(1), "child-1", &[Restart, Restart, Stop]);
    }

    #[test]
    fn rest_for_one_returns_restart() {
        let strategy = RestForOne::new(5, Duration::from_secs(60));
        let id = test_id(1);
        strategy.register_child(id.clone());
        assert_decisions(&strategy, &id, "child-1", &[Restart]);
    }

    #[test]
    fn rest_for_one_children_to_restart() {
        let strategy = RestForOne::new(5, Duration::from_secs(60));
        let ids: Vec<ActorId> = (1..=3).map(test_id).collect();
        for id in &ids {
            strategy.register_child(id.clone());
        }
        assert_eq!(strategy.children_to_restart(&ids[1]), ids[1..].to_vec());
        assert_eq!(strategy.children_to_restart(&ids[0]), ids);
        assert_eq!(strategy.children_to_restart(&ids[2]), ids[2..].to_vec());
    }

    #[test]
    fn rest_for_one_max_exceeded_returns_stop() {
        let strategy = RestForOne::new(2, Duration::from_secs(60));
        let id = test_id(1);
        strategy.register_child(id.clone());
        assert_decisions(&strategy, &id, "child-1", &[Restart, Restart, Stop]);
    }
}