use std::collections::HashMap;
use std::sync::{Arc, Mutex, mpsc};
use std::thread;
use std::time::{Duration, Instant};
use std::sync::atomic::{AtomicBool, Ordering};
/// Policy deciding which children are restarted after one child terminates.
///
/// The variant names mirror the Erlang/OTP supervisor restart strategies.
/// `PartialEq`/`Eq` are derived so configurations can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RestartStrategy {
    /// Restart only the child that terminated.
    OneForOne,
    /// Restart every child registered with the supervisor.
    OneForAll,
    /// Restart the terminated child plus every child that declared a
    /// dependency on it through `Supervisor::add_dependency`.
    RestForOne,
}
/// Lifecycle state the supervisor tracks for each child.
///
/// `Eq` is derived alongside `PartialEq` because equality on these unit
/// variants is total.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProcessState {
    /// The child's factory has been invoked and its thread handle is held.
    Running,
    /// The monitor observed that the child's thread has finished.
    Failed,
    /// Set transiently while the factory is being invoked.
    Restarting,
    /// The child was stopped explicitly, or it is not eligible for a restart.
    Stopped,
    /// The child is registered but its factory has not been invoked yet.
    Unstarted,
}
/// Restart class of a supervised child.
///
/// `PartialEq`/`Eq` are derived so callers can compare child types with `==`
/// instead of pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChildType {
    /// Eligible for restart whenever its thread finishes, subject to the
    /// supervisor's restart limits.
    Permanent,
    /// Never restarted once its thread has finished.
    Temporary,
    /// Eligible for restart, subject to the supervisor's restart limits.
    /// NOTE(review): Erlang/OTP restarts a transient child only after an
    /// abnormal exit; confirm the monitor loop applies the same rule.
    Transient,
}
/// How `Supervisor::stop_process` lets go of a child's thread.
///
/// `PartialEq`/`Eq` are derived so strategies (including their timeout) can
/// be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShutdownStrategy {
    /// Drop the thread handle immediately without waiting. A Rust thread
    /// cannot be killed, so dropping the handle only detaches it.
    BrutalKill,
    /// Raise the child's shutdown signal, then wait up to the given duration
    /// for the thread to finish before dropping the handle.
    Shutdown(Duration),
}
/// Tunable settings for a [`Supervisor`].
#[derive(Clone, Debug)]
pub struct SupervisorConfig {
    /// A failed child is restarted only while fewer than this many
    /// timestamps remain in its restart history inside `max_time`.
    pub max_restarts: usize,
    /// Length of the sliding window; history entries older than this are
    /// discarded before `max_restarts` is checked.
    pub max_time: Duration,
    /// Which children are restarted when one of them fails.
    pub restart_strategy: RestartStrategy,
    /// Shutdown strategy given to children registered through
    /// `Supervisor::add_process`.
    pub shutdown_strategy: ShutdownStrategy,
}
impl Default for SupervisorConfig {
    /// Defaults: at most 3 restarts per 5-second window, `OneForOne`
    /// restarts, and a graceful shutdown with a 5-second timeout.
    fn default() -> Self {
        let five_seconds = Duration::from_secs(5);
        Self {
            max_restarts: 3,
            max_time: five_seconds,
            restart_strategy: RestartStrategy::OneForOne,
            shutdown_strategy: ShutdownStrategy::Shutdown(five_seconds),
        }
    }
}
/// Observer hooks for supervisor lifecycle events.
///
/// Every hook has an empty default body, so implementors override only the
/// events they care about. Hooks may be invoked while the supervisor holds
/// its internal process lock, so an implementation should not call back into
/// the supervisor.
pub trait EventCallback: Send + Sync {
    /// Called after a child has been launched for the first time.
    fn on_process_started(&self, _process_name: &str) {}

    /// Called when the monitor notices that a child's thread has finished.
    fn on_process_failed(&self, _process_name: &str) {}

    /// Called after a child has been relaunched; `_restart_count` is the
    /// child's running total of restarts.
    fn on_process_restarted(&self, _process_name: &str, _restart_count: usize) {}

    /// Called after a child has been stopped.
    fn on_process_stopped(&self, _process_name: &str) {}
}
/// [`EventCallback`] implementation that ignores every event.
///
/// `Supervisor::new` installs it when the caller supplies no callback.
pub struct NoOpCallback;

// Every hook keeps the trait's empty default body.
impl EventCallback for NoOpCallback {}
/// Immutable description of how to run and stop one child.
struct ChildSpec {
    /// Restart class of the child.
    child_type: ChildType,
    /// Spawns the child's thread; invoked for the initial start and again
    /// for every restart.
    factory: Box<dyn Fn() -> thread::JoinHandle<()> + Send + 'static>,
    /// How `Supervisor::stop_process` releases the child's thread.
    shutdown_strategy: ShutdownStrategy,
    /// Flag set to `true` when the child is asked to stop; handed out via
    /// `Supervisor::get_shutdown_signal`.
    shutdown_signal: Arc<AtomicBool>,
}
/// Mutable bookkeeping the supervisor keeps for one child.
struct ProcessInfo {
    /// Handle of the child's current thread; `None` while no thread is
    /// being tracked.
    handle: Option<thread::JoinHandle<()>>,
    /// Timestamps consulted by the monitor's `max_restarts` / `max_time`
    /// check.
    restart_times: Vec<Instant>,
    /// Current lifecycle state.
    state: ProcessState,
    /// Number of times the monitor has relaunched this child.
    restart_count: usize,
    /// Static description of the child.
    spec: ChildSpec,
}
/// Supervises a set of named worker threads and restarts them according to
/// a [`SupervisorConfig`].
pub struct Supervisor {
    /// Process table: child name -> bookkeeping entry.
    processes: Arc<Mutex<HashMap<String, ProcessInfo>>>,
    /// Restart limits, restart strategy and default shutdown strategy.
    config: SupervisorConfig,
    /// Dependency map: child name -> names of the children it depends on.
    dependencies: Arc<Mutex<HashMap<String, Vec<String>>>>,
    /// Receiver of lifecycle events.
    event_callback: Arc<dyn EventCallback>,
    /// Join handle of the monitor thread, once it has been started.
    monitor_handle: Arc<Mutex<Option<thread::JoinHandle<()>>>>,
    /// Set to `true` by `shutdown` to make the monitor loop exit.
    shutdown_flag: Arc<AtomicBool>,
    /// Sender half of a channel created in `start_monitoring`; nothing in
    /// this file sends on it.
    signal_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
}
impl Supervisor {
pub fn new(config: SupervisorConfig) -> Self {
Supervisor::with_callback(config, Arc::new(NoOpCallback))
}
pub fn with_callback(config: SupervisorConfig, callback: Arc<dyn EventCallback>) -> Self {
Supervisor {
processes: Arc::new(Mutex::new(HashMap::new())),
config,
dependencies: Arc::new(Mutex::new(HashMap::new())),
event_callback: callback,
monitor_handle: Arc::new(Mutex::new(None)),
shutdown_flag: Arc::new(AtomicBool::new(false)),
signal_tx: Arc::new(Mutex::new(None)),
}
}
/// Registers a child that uses the supervisor-wide shutdown strategy from
/// the configuration.
///
/// `factory` must spawn the child's thread and return its join handle; it
/// is not invoked here.
pub fn add_process<F>(&mut self, name: &str, child_type: ChildType, factory: F)
where
    F: Fn() -> thread::JoinHandle<()> + Send + 'static,
{
    let default_strategy = self.config.shutdown_strategy.clone();
    self.add_process_with_shutdown(name, child_type, factory, default_strategy);
}
/// Registers a child with an explicit shutdown strategy.
///
/// The child is stored in the `Unstarted` state with a fresh, cleared
/// shutdown signal; `factory` is not invoked here. Registering a name that
/// already exists replaces the earlier entry.
///
/// # Panics
///
/// Panics if the process-table mutex is poisoned.
pub fn add_process_with_shutdown<F>(
    &mut self,
    name: &str,
    child_type: ChildType,
    factory: F,
    shutdown_strategy: ShutdownStrategy,
)
where
    F: Fn() -> thread::JoinHandle<()> + Send + 'static,
{
    let entry = ProcessInfo {
        handle: None,
        restart_times: Vec::new(),
        state: ProcessState::Unstarted,
        restart_count: 0,
        spec: ChildSpec {
            child_type,
            factory: Box::new(factory),
            shutdown_strategy,
            shutdown_signal: Arc::new(AtomicBool::new(false)),
        },
    };

    let mut table = self.processes.lock().unwrap();
    table.insert(name.to_string(), entry);
}
/// Records that `process` depends on `depends_on`.
///
/// Neither name is checked against the process table, and repeated calls
/// append duplicate entries.
///
/// # Panics
///
/// Panics if the dependency mutex is poisoned.
pub fn add_dependency(&self, process: &str, depends_on: &str) {
    let mut graph = self.dependencies.lock().unwrap();
    let prerequisites = graph.entry(process.to_string()).or_default();
    prerequisites.push(depends_on.to_string());
}
/// Wraps the supervisor in an [`Arc`], launches the monitor thread and
/// starts every registered child.
///
/// For each child the factory is invoked, the returned handle is stored,
/// the state becomes `Running` and `on_process_started` is reported.
///
/// The initial launch is deliberately not recorded in `restart_times`. That
/// history feeds the monitor's `max_restarts` check, so counting the first
/// start as a restart would allow one restart fewer than configured.
///
/// # Panics
///
/// Panics if one of the supervisor's mutexes is poisoned.
pub fn start_monitoring(self) -> Arc<Self>
where
    Self: Sized,
{
    let supervisor = Arc::new(self);

    let already_monitoring = supervisor.monitor_handle.lock().unwrap().is_some();
    if already_monitoring {
        return supervisor;
    }

    {
        // Only the sender is kept; the receiver is dropped at the end of
        // this scope and nothing in this file sends on the channel.
        let (tx, _rx) = mpsc::channel();
        *supervisor.signal_tx.lock().unwrap() = Some(tx);

        let monitor_ref = Arc::clone(&supervisor);
        let monitor_thread = thread::spawn(move || monitor_ref.monitor_loop());
        *supervisor.monitor_handle.lock().unwrap() = Some(monitor_thread);
    }

    {
        let mut processes = supervisor.processes.lock().unwrap();
        for (name, info) in processes.iter_mut() {
            info.handle = Some((info.spec.factory)());
            info.state = ProcessState::Running;
            supervisor.event_callback.on_process_started(name);
        }
    }

    supervisor
}
/// Body of the monitor thread.
///
/// Until the shutdown flag is set, the loop wakes up about every 100 ms and
/// 1. reaps every child whose thread has finished and decides whether it
///    may be restarted (`reap_finished_children`);
/// 2. for each restartable failure, relaunches the children selected by the
///    configured restart strategy (`restart_targets`).
///
/// Event callbacks are invoked while the process table is locked.
///
/// # Panics
///
/// Panics if one of the supervisor's mutexes is poisoned.
fn monitor_loop(&self) {
    while !self.shutdown_flag.load(Ordering::Relaxed) {
        thread::sleep(Duration::from_millis(100));

        let failed_processes = self.reap_finished_children();

        for failed_process in failed_processes {
            // A shutdown requested during this pass must not bring
            // children back to life.
            if self.shutdown_flag.load(Ordering::Relaxed) {
                return;
            }

            // An earlier group restart in this pass may already have
            // relaunched this child; restarting its group again would
            // restart every member twice.
            let still_failed = self
                .processes
                .lock()
                .unwrap()
                .get(&failed_process)
                .map_or(false, |info| info.state == ProcessState::Failed);
            if !still_failed {
                continue;
            }

            let processes_to_restart = self.restart_targets(&failed_process);
            let now = Instant::now();

            for proc_name in processes_to_restart {
                let mut processes = self.processes.lock().unwrap();
                if let Some(proc_info) = processes.get_mut(&proc_name) {
                    let not_restartable =
                        matches!(proc_info.spec.child_type, ChildType::Temporary)
                            || proc_info.state == ProcessState::Stopped;
                    if not_restartable {
                        continue;
                    }

                    proc_info.state = ProcessState::Restarting;
                    proc_info.restart_count += 1;
                    // NOTE: storing a new handle drops the previous one.
                    // For a sibling that is still running this only
                    // detaches the old thread; it does not stop it.
                    proc_info.handle = Some((proc_info.spec.factory)());
                    proc_info.restart_times.push(now);
                    proc_info.state = ProcessState::Running;
                    self.event_callback
                        .on_process_restarted(&proc_name, proc_info.restart_count);
                }
            }
        }
    }
}

/// Reaps every child whose thread has finished and returns the names of
/// those that are allowed to be restarted.
///
/// A finished thread is joined to learn whether it panicked:
/// * a `Transient` child that returned normally is marked `Stopped` and
///   reported through `on_process_stopped`; it is not restarted;
/// * every other finished child is marked `Failed` and reported through
///   `on_process_failed`. `Temporary` children then become `Stopped`.
///   `Permanent` and (panicked) `Transient` children are returned for
///   restart while their restart history inside `max_time` holds fewer than
///   `max_restarts` entries, and become `Stopped` otherwise.
fn reap_finished_children(&self) -> Vec<String> {
    let mut failed_processes = Vec::new();
    let mut processes = self.processes.lock().unwrap();

    for (name, info) in processes.iter_mut() {
        if info.state == ProcessState::Unstarted {
            continue;
        }
        let finished = info
            .handle
            .as_ref()
            .map_or(false, |handle| handle.is_finished());
        if !finished {
            continue;
        }

        // The thread has already finished, so `join` returns promptly;
        // `Err` means the thread panicked.
        let panicked = info
            .handle
            .take()
            .map_or(false, |handle| handle.join().is_err());

        if matches!(info.spec.child_type, ChildType::Transient) && !panicked {
            info.state = ProcessState::Stopped;
            self.event_callback.on_process_stopped(name);
            continue;
        }

        info.state = ProcessState::Failed;
        self.event_callback.on_process_failed(name);

        let restart_eligible = match info.spec.child_type {
            ChildType::Permanent | ChildType::Transient => true,
            ChildType::Temporary => false,
        };
        if !restart_eligible {
            info.state = ProcessState::Stopped;
            continue;
        }

        let now = Instant::now();
        info.restart_times
            .retain(|time| now.duration_since(*time) < self.config.max_time);
        if info.restart_times.len() < self.config.max_restarts {
            failed_processes.push(name.clone());
        } else {
            info.state = ProcessState::Stopped;
        }
    }

    failed_processes
}

/// Returns the names of the children the restart strategy selects after
/// `failed_process` failed.
fn restart_targets(&self, failed_process: &str) -> Vec<String> {
    // Lock order: process table first, then dependencies.
    let processes = self.processes.lock().unwrap();
    let dependencies = self.dependencies.lock().unwrap();

    match self.config.restart_strategy {
        RestartStrategy::OneForOne => vec![failed_process.to_string()],
        RestartStrategy::OneForAll => processes.keys().cloned().collect(),
        RestartStrategy::RestForOne => {
            let mut to_restart = vec![failed_process.to_string()];
            to_restart.extend(
                dependencies
                    .iter()
                    .filter(|(_, deps)| deps.iter().any(|dep| dep.as_str() == failed_process))
                    .map(|(proc_name, _)| proc_name.clone()),
            );
            to_restart
        }
    }
}
/// Stops the child called `name`.
///
/// Returns `true` when the child had a thread handle and the stop was
/// carried out. Returns `false` when the name is unknown or no handle is
/// held, for example because the child was never started or was already
/// stopped.
///
/// The stop happens in two phases so the process table is not locked while
/// waiting:
/// 1. under the lock, the handle is taken, the shutdown signal is raised
///    and the state becomes `Stopped`, which keeps the monitor from
///    restarting the child;
/// 2. with the lock released, a `Shutdown(timeout)` strategy polls every
///    10 ms until the thread finishes or the timeout elapses, and
///    `BrutalKill` does not wait.
///
/// The handle is then dropped. A thread that is still running at that point
/// is detached, not terminated. `on_process_stopped` is reported last.
///
/// # Panics
///
/// Panics if the process-table mutex is poisoned.
pub fn stop_process(&self, name: &str) -> bool {
    let (handle, strategy) = {
        let mut processes = self.processes.lock().unwrap();
        let info = match processes.get_mut(name) {
            Some(info) => info,
            None => return false,
        };
        let handle = match info.handle.take() {
            Some(handle) => handle,
            None => return false,
        };
        info.spec.shutdown_signal.store(true, Ordering::Relaxed);
        info.state = ProcessState::Stopped;
        (handle, info.spec.shutdown_strategy.clone())
    };

    if let ShutdownStrategy::Shutdown(timeout) = strategy {
        let start = Instant::now();
        while !handle.is_finished() && start.elapsed() < timeout {
            thread::sleep(Duration::from_millis(10));
        }
    }
    drop(handle);

    self.event_callback.on_process_stopped(name);
    true
}
/// Shuts the supervisor down: stops the monitor thread, then stops every
/// child.
///
/// The monitor is joined before any child is stopped. A monitor pass that
/// is already in flight could otherwise restart a child after it had been
/// stopped, leaving it running once this method returns.
///
/// Calling this again is harmless: the monitor handle is already gone and
/// `stop_process` returns `false` for children without a handle.
///
/// # Panics
///
/// Panics if the process-table mutex is poisoned. A poisoned monitor-handle
/// mutex is tolerated and the join is skipped.
pub fn shutdown(&self) {
    self.shutdown_flag.store(true, Ordering::Relaxed);

    // Take the handle out first so the mutex is not held across `join`.
    let monitor = self
        .monitor_handle
        .lock()
        .ok()
        .and_then(|mut slot| slot.take());
    if let Some(thread) = monitor {
        let _ = thread.join();
    }

    let process_names: Vec<String> = self.processes.lock().unwrap().keys().cloned().collect();
    for name in process_names {
        self.stop_process(&name);
    }
}
/// Returns a copy of the state of the child called `name`, or `None` when
/// no such child is registered.
///
/// # Panics
///
/// Panics if the process-table mutex is poisoned.
pub fn get_process_state(&self, name: &str) -> Option<ProcessState> {
    let table = self.processes.lock().unwrap();
    let info = table.get(name)?;
    Some(info.state.clone())
}
/// Returns how many times the child called `name` has been restarted, or
/// `None` when no such child is registered.
///
/// # Panics
///
/// Panics if the process-table mutex is poisoned.
pub fn get_restart_count(&self, name: &str) -> Option<usize> {
    let table = self.processes.lock().unwrap();
    let info = table.get(name)?;
    Some(info.restart_count)
}
/// Returns a snapshot mapping every child name to its
/// `(state, restart_count)` pair.
///
/// # Panics
///
/// Panics if the process-table mutex is poisoned.
pub fn get_all_states(&self) -> HashMap<String, (ProcessState, usize)> {
    let table = self.processes.lock().unwrap();
    let mut snapshot = HashMap::with_capacity(table.len());
    for (name, info) in table.iter() {
        snapshot.insert(name.clone(), (info.state.clone(), info.restart_count));
    }
    snapshot
}
/// Returns a shared handle to the shutdown flag of the child called `name`,
/// or `None` when no such child is registered.
///
/// `stop_process` sets this flag to `true`.
///
/// # Panics
///
/// Panics if the process-table mutex is poisoned.
pub fn get_shutdown_signal(&self, name: &str) -> Option<Arc<AtomicBool>> {
    let table = self.processes.lock().unwrap();
    table
        .get(name)
        .map(|info| &info.spec.shutdown_signal)
        .cloned()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
/// A freshly constructed supervisor has no registered children.
#[test]
fn test_supervisor_creation() {
    let supervisor = Supervisor::new(SupervisorConfig::default());
    let states = supervisor.get_all_states();
    assert_eq!(states.len(), 0);
}
/// Registering a child stores it in the `Unstarted` state without running it.
#[test]
fn test_add_process() {
    let mut supervisor = Supervisor::new(SupervisorConfig::default());
    let long_running = || thread::spawn(|| thread::sleep(Duration::from_secs(10)));
    supervisor.add_process("worker1", ChildType::Permanent, long_running);

    assert_eq!(supervisor.get_all_states().len(), 1);
    let state = supervisor.get_process_state("worker1");
    assert_eq!(state, Some(ProcessState::Unstarted));
}
/// `start_monitoring` launches registered children and marks them `Running`.
#[test]
fn test_process_starts_on_monitoring() {
    let mut supervisor = Supervisor::new(SupervisorConfig::default());
    let long_running = || thread::spawn(|| thread::sleep(Duration::from_secs(10)));
    supervisor.add_process("worker1", ChildType::Permanent, long_running);

    let supervisor = supervisor.start_monitoring();
    thread::sleep(Duration::from_millis(200));

    let state = supervisor.get_process_state("worker1");
    assert_eq!(state, Some(ProcessState::Running));
    supervisor.shutdown();
}
/// A permanent child that panics is launched again by the monitor.
#[test]
fn test_permanent_process_restart() {
    let runs = Arc::new(AtomicUsize::new(0));
    let runs_for_factory = Arc::clone(&runs);

    let mut supervisor = Supervisor::new(SupervisorConfig::default());
    supervisor.add_process("failing_worker", ChildType::Permanent, move || {
        let runs = Arc::clone(&runs_for_factory);
        thread::spawn(move || {
            runs.fetch_add(1, Ordering::Relaxed);
            panic!("Intentional failure");
        })
    });

    let supervisor = supervisor.start_monitoring();
    thread::sleep(Duration::from_millis(500));

    let observed_runs = runs.load(Ordering::Relaxed);
    assert!(observed_runs > 1);
    supervisor.shutdown();
}
/// A temporary child that panics runs exactly once and ends up `Stopped`.
#[test]
fn test_temporary_process_no_restart() {
    let runs = Arc::new(AtomicUsize::new(0));
    let runs_for_factory = Arc::clone(&runs);

    let mut supervisor = Supervisor::new(SupervisorConfig::default());
    supervisor.add_process("temp_worker", ChildType::Temporary, move || {
        let runs = Arc::clone(&runs_for_factory);
        thread::spawn(move || {
            runs.fetch_add(1, Ordering::Relaxed);
            panic!("Intentional failure");
        })
    });

    let supervisor = supervisor.start_monitoring();
    thread::sleep(Duration::from_millis(500));

    assert_eq!(runs.load(Ordering::Relaxed), 1);
    let state = supervisor.get_process_state("temp_worker");
    assert_eq!(state, Some(ProcessState::Stopped));
    supervisor.shutdown();
}
/// `stop_process` reports success for a running child and leaves it `Stopped`.
#[test]
fn test_stop_process() {
    let mut supervisor = Supervisor::new(SupervisorConfig::default());
    let long_running = || thread::spawn(|| thread::sleep(Duration::from_secs(10)));
    supervisor.add_process("worker1", ChildType::Permanent, long_running);

    let supervisor = supervisor.start_monitoring();
    thread::sleep(Duration::from_millis(200));

    let stopped = supervisor.stop_process("worker1");
    assert!(stopped);
    thread::sleep(Duration::from_millis(100));

    let state = supervisor.get_process_state("worker1");
    assert_eq!(state, Some(ProcessState::Stopped));
    supervisor.shutdown();
}
/// The restart counter of a repeatedly panicking child becomes non-zero.
#[test]
fn test_restart_count() {
    let mut supervisor = Supervisor::new(SupervisorConfig::default());
    supervisor.add_process("failing_worker", ChildType::Permanent, || {
        thread::spawn(|| {
            panic!("Intentional failure");
        })
    });

    let supervisor = supervisor.start_monitoring();
    thread::sleep(Duration::from_millis(500));

    let restarts = supervisor
        .get_restart_count("failing_worker")
        .unwrap_or(0);
    assert!(restarts > 0);
    supervisor.shutdown();
}
/// Under `OneForOne` only the failing child is relaunched, so it runs more
/// often than its stable sibling.
#[test]
fn test_restart_strategy_one_for_one() {
    let config = SupervisorConfig {
        restart_strategy: RestartStrategy::OneForOne,
        ..SupervisorConfig::default()
    };

    let failing_runs = Arc::new(AtomicUsize::new(0));
    let failing_runs_for_factory = Arc::clone(&failing_runs);
    let stable_runs = Arc::new(AtomicUsize::new(0));
    let stable_runs_for_factory = Arc::clone(&stable_runs);

    let mut supervisor = Supervisor::new(config);
    supervisor.add_process("failing_worker", ChildType::Permanent, move || {
        let runs = Arc::clone(&failing_runs_for_factory);
        thread::spawn(move || {
            runs.fetch_add(1, Ordering::Relaxed);
            panic!("Intentional failure");
        })
    });
    supervisor.add_process("stable_worker", ChildType::Permanent, move || {
        let runs = Arc::clone(&stable_runs_for_factory);
        thread::spawn(move || {
            runs.fetch_add(1, Ordering::Relaxed);
            thread::sleep(Duration::from_secs(10));
        })
    });

    let supervisor = supervisor.start_monitoring();
    thread::sleep(Duration::from_millis(500));

    let failing_total = failing_runs.load(Ordering::Relaxed);
    let stable_total = stable_runs.load(Ordering::Relaxed);
    assert!(failing_total > stable_total);
    supervisor.shutdown();
}
/// Declaring a dependency does not prevent either child from starting.
#[test]
fn test_process_dependencies() {
    let mut supervisor = Supervisor::new(SupervisorConfig::default());
    for worker in ["base_worker", "dependent_worker"] {
        supervisor.add_process(worker, ChildType::Permanent, || {
            thread::spawn(|| {
                thread::sleep(Duration::from_secs(10));
            })
        });
    }
    supervisor.add_dependency("dependent_worker", "base_worker");

    let supervisor = supervisor.start_monitoring();
    thread::sleep(Duration::from_millis(200));

    let base_state = supervisor.get_process_state("base_worker");
    assert_eq!(base_state, Some(ProcessState::Running));
    let dependent_state = supervisor.get_process_state("dependent_worker");
    assert_eq!(dependent_state, Some(ProcessState::Running));
    supervisor.shutdown();
}
/// Once the restart limit is exhausted the child is left `Stopped`, and it
/// has run at most `max_restarts + 1` times (initial start plus restarts).
#[test]
fn test_max_restarts_limit() {
    let config = SupervisorConfig {
        max_restarts: 2,
        max_time: Duration::from_secs(5),
        ..SupervisorConfig::default()
    };
    let max_restarts = config.max_restarts;

    let runs = Arc::new(AtomicUsize::new(0));
    let runs_for_factory = Arc::clone(&runs);

    let mut supervisor = Supervisor::new(config);
    supervisor.add_process("failing_worker", ChildType::Permanent, move || {
        let runs = Arc::clone(&runs_for_factory);
        thread::spawn(move || {
            runs.fetch_add(1, Ordering::Relaxed);
            panic!("Intentional failure");
        })
    });

    let supervisor = supervisor.start_monitoring();
    thread::sleep(Duration::from_millis(1000));

    let state = supervisor.get_process_state("failing_worker");
    assert_eq!(state, Some(ProcessState::Stopped));
    assert!(runs.load(Ordering::Relaxed) <= max_restarts + 1);
    supervisor.shutdown();
}
/// `shutdown` leaves every running child in the `Stopped` state.
#[test]
fn test_supervisor_shutdown() {
    let workers = ["worker1", "worker2"];

    let mut supervisor = Supervisor::new(SupervisorConfig::default());
    for worker in workers {
        supervisor.add_process(worker, ChildType::Permanent, || {
            thread::spawn(|| {
                thread::sleep(Duration::from_secs(10));
            })
        });
    }

    let supervisor = supervisor.start_monitoring();
    thread::sleep(Duration::from_millis(200));
    supervisor.shutdown();
    thread::sleep(Duration::from_millis(200));

    for worker in workers {
        let state = supervisor.get_process_state(worker);
        assert_eq!(state, Some(ProcessState::Stopped));
    }
}
/// A panicking permanent child triggers the started, failed and restarted
/// callbacks at least once each.
///
/// The test keeps its own typed `Arc<TestCallback>` and hands the supervisor
/// a coerced `Arc<dyn EventCallback>` clone, so the counters can be read
/// without an `unsafe` pointer cast back to the concrete type.
#[test]
fn test_event_callback() {
    struct TestCallback {
        started: AtomicUsize,
        failed: AtomicUsize,
        restarted: AtomicUsize,
    }
    impl EventCallback for TestCallback {
        fn on_process_started(&self, _process_name: &str) {
            self.started.fetch_add(1, Ordering::Relaxed);
        }
        fn on_process_failed(&self, _process_name: &str) {
            self.failed.fetch_add(1, Ordering::Relaxed);
        }
        fn on_process_restarted(&self, _process_name: &str, _restart_count: usize) {
            self.restarted.fetch_add(1, Ordering::Relaxed);
        }
    }

    let callback = Arc::new(TestCallback {
        started: AtomicUsize::new(0),
        failed: AtomicUsize::new(0),
        restarted: AtomicUsize::new(0),
    });
    let callback_for_supervisor: Arc<dyn EventCallback> = callback.clone();

    let mut supervisor =
        Supervisor::with_callback(SupervisorConfig::default(), callback_for_supervisor);
    supervisor.add_process("failing_worker", ChildType::Permanent, || {
        thread::spawn(|| {
            panic!("Intentional failure");
        })
    });

    let supervisor = supervisor.start_monitoring();
    thread::sleep(Duration::from_millis(500));

    assert!(callback.started.load(Ordering::Relaxed) > 0);
    assert!(callback.failed.load(Ordering::Relaxed) > 0);
    assert!(callback.restarted.load(Ordering::Relaxed) > 0);
    supervisor.shutdown();
}
}