use crate::error::{AgentError, Result};
use crate::runtime::{ContainerId, ContainerState, Runtime};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, Notify, RwLock};
use zlayer_spec::{PanicAction, ServiceSpec};
pub type IsolateCallback = Arc<dyn Fn(&ContainerId) + Send + Sync>;
const DEFAULT_MAX_RESTARTS: u32 = 5;
const DEFAULT_RESTART_WINDOW: Duration = Duration::from_secs(300); const DEFAULT_POLL_INTERVAL: Duration = Duration::from_secs(5);
const CRASH_LOOP_BACKOFF_DELAY: Duration = Duration::from_secs(30);
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SupervisedState {
Running,
Restarting,
CrashLoopBackOff,
Isolated,
Shutdown,
Completed,
}
#[derive(Debug, Clone)]
pub struct SupervisedContainer {
pub id: ContainerId,
pub service_name: String,
pub state: SupervisedState,
pub panic_action: PanicAction,
pub restart_times: Vec<Instant>,
pub total_restarts: u32,
pub last_exit_code: Option<i32>,
pub supervised_since: Instant,
}
impl SupervisedContainer {
#[must_use]
pub fn new(id: ContainerId, service_name: String, panic_action: PanicAction) -> Self {
Self {
id,
service_name,
state: SupervisedState::Running,
panic_action,
restart_times: Vec::new(),
total_restarts: 0,
last_exit_code: None,
supervised_since: Instant::now(),
}
}
pub fn record_restart(&mut self, window: Duration, max_restarts: u32) -> bool {
let now = Instant::now();
self.restart_times.push(now);
self.total_restarts += 1;
self.restart_times
.retain(|&t| now.duration_since(t) < window);
#[allow(clippy::cast_possible_truncation)]
let count = self.restart_times.len() as u32;
count > max_restarts
}
#[must_use]
pub fn should_monitor(&self) -> bool {
matches!(
self.state,
SupervisedState::Running | SupervisedState::CrashLoopBackOff
)
}
}
#[derive(Debug, Clone)]
pub enum SupervisorEvent {
ContainerRestarted {
id: ContainerId,
service_name: String,
exit_code: i32,
restart_count: u32,
},
CrashLoopBackOff {
id: ContainerId,
service_name: String,
restart_count: u32,
},
ContainerIsolated {
id: ContainerId,
service_name: String,
exit_code: i32,
},
ServiceShutdown {
id: ContainerId,
service_name: String,
exit_code: i32,
},
ContainerCompleted {
id: ContainerId,
service_name: String,
},
}
#[derive(Debug, Clone)]
pub struct SupervisorConfig {
pub max_restarts: u32,
pub restart_window: Duration,
pub poll_interval: Duration,
}
impl Default for SupervisorConfig {
fn default() -> Self {
Self {
max_restarts: DEFAULT_MAX_RESTARTS,
restart_window: DEFAULT_RESTART_WINDOW,
poll_interval: DEFAULT_POLL_INTERVAL,
}
}
}
pub struct ContainerSupervisor {
runtime: Arc<dyn Runtime + Send + Sync>,
containers: Arc<RwLock<HashMap<ContainerId, SupervisedContainer>>>,
config: SupervisorConfig,
event_tx: mpsc::Sender<SupervisorEvent>,
event_rx: Arc<RwLock<mpsc::Receiver<SupervisorEvent>>>,
running: Arc<AtomicBool>,
shutdown: Arc<Notify>,
on_isolate: Option<IsolateCallback>,
}
impl ContainerSupervisor {
pub fn new(runtime: Arc<dyn Runtime + Send + Sync>) -> Self {
Self::with_config(runtime, SupervisorConfig::default())
}
pub fn with_config(runtime: Arc<dyn Runtime + Send + Sync>, config: SupervisorConfig) -> Self {
let (event_tx, event_rx) = mpsc::channel(100);
Self {
runtime,
containers: Arc::new(RwLock::new(HashMap::new())),
config,
event_tx,
event_rx: Arc::new(RwLock::new(event_rx)),
running: Arc::new(AtomicBool::new(false)),
shutdown: Arc::new(Notify::new()),
on_isolate: None,
}
}
pub fn set_isolate_callback<F>(&mut self, callback: F)
where
F: Fn(&ContainerId) + Send + Sync + 'static,
{
self.on_isolate = Some(Arc::new(callback));
}
pub async fn supervise(&self, container_id: &ContainerId, spec: &ServiceSpec) {
let supervised = SupervisedContainer::new(
container_id.clone(),
container_id.service.clone(),
spec.errors.on_panic.action,
);
let mut containers = self.containers.write().await;
containers.insert(container_id.clone(), supervised);
tracing::info!(
container = %container_id,
panic_action = ?spec.errors.on_panic.action,
"Container registered for supervision"
);
}
pub async fn unsupervise(&self, container_id: &ContainerId) {
let mut containers = self.containers.write().await;
if containers.remove(container_id).is_some() {
tracing::debug!(container = %container_id, "Container removed from supervision");
}
}
pub async fn get_state(&self, container_id: &ContainerId) -> Option<SupervisedState> {
let containers = self.containers.read().await;
containers.get(container_id).map(|c| c.state.clone())
}
pub async fn get_container_info(
&self,
container_id: &ContainerId,
) -> Option<SupervisedContainer> {
let containers = self.containers.read().await;
containers.get(container_id).cloned()
}
pub async fn list_supervised(&self) -> Vec<SupervisedContainer> {
let containers = self.containers.read().await;
containers.values().cloned().collect()
}
pub async fn take_event_receiver(&self) -> Option<mpsc::Receiver<SupervisorEvent>> {
let mut rx_guard = self.event_rx.write().await;
let (_, dummy_rx) = mpsc::channel(1);
let old_rx = std::mem::replace(&mut *rx_guard, dummy_rx);
Some(old_rx)
}
pub async fn run_loop(&self) {
self.running.store(true, Ordering::SeqCst);
tracing::info!(
poll_interval_ms = self.config.poll_interval.as_millis(),
"Container supervisor started"
);
loop {
tokio::select! {
() = self.shutdown.notified() => {
tracing::info!("Container supervisor shutting down");
break;
}
() = tokio::time::sleep(self.config.poll_interval) => {
if let Err(e) = self.check_all_containers().await {
tracing::error!(error = %e, "Error during container health check");
}
}
}
}
self.running.store(false, Ordering::SeqCst);
}
pub fn shutdown(&self) {
self.shutdown.notify_one();
}
#[must_use]
pub fn is_running(&self) -> bool {
self.running.load(Ordering::SeqCst)
}
async fn check_all_containers(&self) -> Result<()> {
let containers_to_check: Vec<_> = {
let containers = self.containers.read().await;
containers
.iter()
.filter(|(_, c)| c.should_monitor())
.map(|(id, c)| (id.clone(), c.panic_action))
.collect()
};
for (container_id, panic_action) in containers_to_check {
self.check_container(&container_id, panic_action).await?;
}
Ok(())
}
async fn check_container(
&self,
container_id: &ContainerId,
panic_action: PanicAction,
) -> Result<()> {
let state = self.runtime.container_state(container_id).await?;
match state {
ContainerState::Running
| ContainerState::Pending
| ContainerState::Initializing
| ContainerState::Stopping => {
}
ContainerState::Exited { code } => {
self.handle_container_exit(container_id, code, panic_action)
.await?;
}
ContainerState::Failed { reason } => {
tracing::warn!(
container = %container_id,
reason = %reason,
"Container reported as failed"
);
self.handle_container_exit(container_id, -1, panic_action)
.await?;
}
}
Ok(())
}
async fn handle_container_exit(
&self,
container_id: &ContainerId,
exit_code: i32,
panic_action: PanicAction,
) -> Result<()> {
let (service_name, _should_restart, in_crash_loop) = {
let mut containers = self.containers.write().await;
let Some(container) = containers.get_mut(container_id) else {
return Ok(()); };
container.last_exit_code = Some(exit_code);
if exit_code == 0 {
container.state = SupervisedState::Completed;
let _ = self
.event_tx
.send(SupervisorEvent::ContainerCompleted {
id: container_id.clone(),
service_name: container.service_name.clone(),
})
.await;
return Ok(());
}
let service_name = container.service_name.clone();
let in_crash_loop =
container.record_restart(self.config.restart_window, self.config.max_restarts);
let should_restart = match panic_action {
PanicAction::Restart => !in_crash_loop,
PanicAction::Shutdown | PanicAction::Isolate => false,
};
if in_crash_loop && matches!(panic_action, PanicAction::Restart) {
container.state = SupervisedState::CrashLoopBackOff;
} else if should_restart {
container.state = SupervisedState::Restarting;
}
(service_name, should_restart, in_crash_loop)
};
match panic_action {
PanicAction::Restart => {
if in_crash_loop {
self.handle_crash_loop_backoff(container_id, &service_name)
.await?;
} else {
self.restart_container(container_id, &service_name, exit_code)
.await?;
}
}
PanicAction::Shutdown => {
self.shutdown_container(container_id, &service_name, exit_code)
.await?;
}
PanicAction::Isolate => {
self.isolate_container(container_id, &service_name, exit_code)
.await?;
}
}
Ok(())
}
async fn restart_container(
&self,
container_id: &ContainerId,
service_name: &str,
exit_code: i32,
) -> Result<()> {
let restart_count = {
let containers = self.containers.read().await;
containers.get(container_id).map_or(0, |c| c.total_restarts)
};
tracing::info!(
container = %container_id,
service = %service_name,
exit_code = exit_code,
restart_count = restart_count,
"Restarting crashed container"
);
self.runtime
.start_container(container_id)
.await
.map_err(|e| AgentError::StartFailed {
id: container_id.to_string(),
reason: e.to_string(),
})?;
{
let mut containers = self.containers.write().await;
if let Some(container) = containers.get_mut(container_id) {
container.state = SupervisedState::Running;
}
}
let _ = self
.event_tx
.send(SupervisorEvent::ContainerRestarted {
id: container_id.clone(),
service_name: service_name.to_string(),
exit_code,
restart_count,
})
.await;
Ok(())
}
async fn handle_crash_loop_backoff(
&self,
container_id: &ContainerId,
service_name: &str,
) -> Result<()> {
let restart_count = {
let containers = self.containers.read().await;
containers.get(container_id).map_or(0, |c| c.total_restarts)
};
tracing::warn!(
container = %container_id,
service = %service_name,
restart_count = restart_count,
backoff_delay_secs = CRASH_LOOP_BACKOFF_DELAY.as_secs(),
"Container in CrashLoopBackOff, delaying restart"
);
let _ = self
.event_tx
.send(SupervisorEvent::CrashLoopBackOff {
id: container_id.clone(),
service_name: service_name.to_string(),
restart_count,
})
.await;
let runtime = Arc::clone(&self.runtime);
let container_id = container_id.clone();
let containers = Arc::clone(&self.containers);
tokio::spawn(async move {
tokio::time::sleep(CRASH_LOOP_BACKOFF_DELAY).await;
if let Err(e) = runtime.start_container(&container_id).await {
tracing::error!(
container = %container_id,
error = %e,
"Failed to restart container after CrashLoopBackOff delay"
);
return;
}
let mut containers_guard = containers.write().await;
if let Some(container) = containers_guard.get_mut(&container_id) {
container.state = SupervisedState::Running;
}
});
Ok(())
}
async fn shutdown_container(
&self,
container_id: &ContainerId,
service_name: &str,
exit_code: i32,
) -> Result<()> {
tracing::warn!(
container = %container_id,
service = %service_name,
exit_code = exit_code,
"Shutting down service due to panic policy"
);
{
let mut containers = self.containers.write().await;
if let Some(container) = containers.get_mut(container_id) {
container.state = SupervisedState::Shutdown;
}
}
let _ = self
.event_tx
.send(SupervisorEvent::ServiceShutdown {
id: container_id.clone(),
service_name: service_name.to_string(),
exit_code,
})
.await;
Ok(())
}
async fn isolate_container(
&self,
container_id: &ContainerId,
service_name: &str,
exit_code: i32,
) -> Result<()> {
tracing::info!(
container = %container_id,
service = %service_name,
exit_code = exit_code,
"Isolating container (removed from load balancer for debugging)"
);
if let Some(ref callback) = self.on_isolate {
callback(container_id);
}
{
let mut containers = self.containers.write().await;
if let Some(container) = containers.get_mut(container_id) {
container.state = SupervisedState::Isolated;
}
}
let _ = self
.event_tx
.send(SupervisorEvent::ContainerIsolated {
id: container_id.clone(),
service_name: service_name.to_string(),
exit_code,
})
.await;
Ok(())
}
pub async fn supervised_count(&self) -> usize {
self.containers.read().await.len()
}
pub async fn count_by_state(&self, state: SupervisedState) -> usize {
self.containers
.read()
.await
.values()
.filter(|c| c.state == state)
.count()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::runtime::MockRuntime;
fn mock_container_id(service: &str, replica: u32) -> ContainerId {
ContainerId {
service: service.to_string(),
replica,
}
}
fn mock_service_spec(panic_action: PanicAction) -> ServiceSpec {
let mut spec: ServiceSpec = serde_yaml::from_str::<zlayer_spec::DeploymentSpec>(
r"
version: v1
deployment: test
services:
test:
rtype: service
image:
name: test:latest
",
)
.unwrap()
.services
.remove("test")
.unwrap();
spec.errors.on_panic.action = panic_action;
spec
}
#[tokio::test]
async fn test_supervisor_creation() {
let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
let supervisor = ContainerSupervisor::new(runtime);
assert!(!supervisor.is_running());
assert_eq!(supervisor.supervised_count().await, 0);
}
#[tokio::test]
async fn test_supervisor_with_config() {
let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
let config = SupervisorConfig {
max_restarts: 10,
restart_window: Duration::from_secs(600),
poll_interval: Duration::from_secs(1),
};
let supervisor = ContainerSupervisor::with_config(runtime, config);
assert_eq!(supervisor.config.max_restarts, 10);
assert_eq!(supervisor.config.restart_window, Duration::from_secs(600));
}
#[tokio::test]
async fn test_supervise_container() {
let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
let supervisor = ContainerSupervisor::new(runtime);
let container_id = mock_container_id("api", 1);
let spec = mock_service_spec(PanicAction::Restart);
supervisor.supervise(&container_id, &spec).await;
assert_eq!(supervisor.supervised_count().await, 1);
let state = supervisor.get_state(&container_id).await;
assert_eq!(state, Some(SupervisedState::Running));
}
#[tokio::test]
async fn test_unsupervise_container() {
let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
let supervisor = ContainerSupervisor::new(runtime);
let container_id = mock_container_id("api", 1);
let spec = mock_service_spec(PanicAction::Restart);
supervisor.supervise(&container_id, &spec).await;
assert_eq!(supervisor.supervised_count().await, 1);
supervisor.unsupervise(&container_id).await;
assert_eq!(supervisor.supervised_count().await, 0);
}
#[tokio::test]
async fn test_list_supervised() {
let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
let supervisor = ContainerSupervisor::new(runtime);
let spec = mock_service_spec(PanicAction::Restart);
supervisor
.supervise(&mock_container_id("api", 1), &spec)
.await;
supervisor
.supervise(&mock_container_id("api", 2), &spec)
.await;
supervisor
.supervise(&mock_container_id("web", 1), &spec)
.await;
let containers = supervisor.list_supervised().await;
assert_eq!(containers.len(), 3);
}
#[tokio::test]
async fn test_supervised_container_record_restart() {
let mut container = SupervisedContainer::new(
mock_container_id("api", 1),
"api".to_string(),
PanicAction::Restart,
);
for _ in 0..5 {
let in_loop = container.record_restart(Duration::from_secs(300), 5);
assert!(!in_loop);
}
let in_loop = container.record_restart(Duration::from_secs(300), 5);
assert!(in_loop);
}
#[tokio::test]
async fn test_supervised_container_restart_window() {
let mut container = SupervisedContainer::new(
mock_container_id("api", 1),
"api".to_string(),
PanicAction::Restart,
);
for _ in 0..5 {
container.record_restart(Duration::from_millis(100), 5);
}
tokio::time::sleep(Duration::from_millis(150)).await;
let in_loop = container.record_restart(Duration::from_millis(100), 5);
assert!(!in_loop);
}
#[tokio::test]
async fn test_get_container_info() {
let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
let supervisor = ContainerSupervisor::new(runtime);
let container_id = mock_container_id("api", 1);
let spec = mock_service_spec(PanicAction::Isolate);
supervisor.supervise(&container_id, &spec).await;
let info = supervisor.get_container_info(&container_id).await;
assert!(info.is_some());
let info = info.unwrap();
assert_eq!(info.id, container_id);
assert_eq!(info.service_name, "api");
assert_eq!(info.panic_action, PanicAction::Isolate);
assert_eq!(info.state, SupervisedState::Running);
assert_eq!(info.total_restarts, 0);
}
#[tokio::test]
async fn test_count_by_state() {
let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
let supervisor = ContainerSupervisor::new(runtime);
let spec = mock_service_spec(PanicAction::Restart);
supervisor
.supervise(&mock_container_id("api", 1), &spec)
.await;
supervisor
.supervise(&mock_container_id("api", 2), &spec)
.await;
assert_eq!(supervisor.count_by_state(SupervisedState::Running).await, 2);
assert_eq!(
supervisor
.count_by_state(SupervisedState::CrashLoopBackOff)
.await,
0
);
}
#[test]
fn test_supervisor_config_default() {
let config = SupervisorConfig::default();
assert_eq!(config.max_restarts, DEFAULT_MAX_RESTARTS);
assert_eq!(config.restart_window, DEFAULT_RESTART_WINDOW);
assert_eq!(config.poll_interval, DEFAULT_POLL_INTERVAL);
}
#[test]
fn test_supervised_state_should_monitor() {
let container = SupervisedContainer {
state: SupervisedState::Running,
..SupervisedContainer::new(
mock_container_id("api", 1),
"api".to_string(),
PanicAction::Restart,
)
};
assert!(container.should_monitor());
let container = SupervisedContainer {
state: SupervisedState::CrashLoopBackOff,
..SupervisedContainer::new(
mock_container_id("api", 1),
"api".to_string(),
PanicAction::Restart,
)
};
assert!(container.should_monitor());
let container = SupervisedContainer {
state: SupervisedState::Shutdown,
..SupervisedContainer::new(
mock_container_id("api", 1),
"api".to_string(),
PanicAction::Restart,
)
};
assert!(!container.should_monitor());
let container = SupervisedContainer {
state: SupervisedState::Isolated,
..SupervisedContainer::new(
mock_container_id("api", 1),
"api".to_string(),
PanicAction::Restart,
)
};
assert!(!container.should_monitor());
let container = SupervisedContainer {
state: SupervisedState::Completed,
..SupervisedContainer::new(
mock_container_id("api", 1),
"api".to_string(),
PanicAction::Restart,
)
};
assert!(!container.should_monitor());
}
}