mod api;
mod builtins;
mod events;
mod handlers;
mod persistence;
mod spawning;
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{RwLock, mpsc};
use tokio::task::JoinHandle;
use crate::sdk::{LogLine, ServiceState};
pub use events::{SupervisorEvent, TimeoutKind};
use crate::server::error::SupervisorResult;
use crate::server::graph::{Service, ServiceGraph, ServiceId};
use crate::server::log::{self, LogBuffers};
use crate::server::process::process_exists;
use crate::server::state::{
PersistedState, PidStatus, cleanup_state_files, now_millis, try_load_restore_fds,
try_load_restore_state, validate_pid,
};
/// Directory from which system services are loaded.
///
/// Only consulted by the supervisor constructors when `pid1_mode` is enabled.
pub const SYSTEM_CONFIG_DIR: &str = "/etc/zinit/system";
/// Central supervisor state: the service graph, the event channel, and the
/// bookkeeping for background tasks spawned on behalf of services.
pub struct Supervisor {
/// Shared, lock-protected service graph; handed out to callers via `graph()`.
pub(crate) graph: Arc<RwLock<ServiceGraph>>,
/// Log buffers; a buffer is initialised per service when reconnecting to restored services.
pub(crate) log_buffers: LogBuffers,
/// Sending half of the event channel; cloned into spawned monitor and timer tasks.
pub(crate) event_tx: mpsc::Sender<SupervisorEvent>,
/// Receiving half of the event channel, drained by `run`.
event_rx: mpsc::Receiver<SupervisorEvent>,
/// Always `None` in the constructors in this file.
/// NOTE(review): presumably reserved for a log-shipping channel — confirm.
_log_shipper_tx: Option<mpsc::Sender<LogLine>>,
/// Timer tasks keyed by `(service, timeout kind)`; inserted by
/// `schedule_timeout` and aborted by `cancel_timeout`.
pub(crate) timers: HashMap<(ServiceId, TimeoutKind), JoinHandle<()>>,
/// Handles of per-service monitoring tasks (e.g. the PID pollers spawned for
/// restored services).
pub(crate) process_tasks: HashMap<ServiceId, JoinHandle<()>>,
/// Starts empty. NOTE(review): presumably a per-service health-check attempt
/// counter maintained by the handlers — confirm.
pub(crate) health_attempts: HashMap<ServiceId, u32>,
/// Starts empty. NOTE(review): presumably services queued for a restart —
/// confirm against the handlers.
pub(crate) pending_restarts: HashSet<ServiceId>,
/// Directory from which user services are loaded.
pub(crate) config_dir: PathBuf,
/// Socket path supplied at construction; not read anywhere in this file.
socket_path: PathBuf,
/// `now_millis()` on a fresh start, or the persisted boot time after a restore.
pub(crate) boot_time: u64,
/// Set by `request_shutdown`; consulted by the `run` event loop.
shutdown: bool,
/// When true, system services are loaded from `SYSTEM_CONFIG_DIR` as well.
pub(crate) pid1_mode: bool,
/// Names returned by the system-directory load; passed to the user-directory
/// load. Empty when `pid1_mode` is off or the system load failed.
pub(crate) system_service_names: HashSet<String>,
}
impl Supervisor {
/// Creates a supervisor, resuming from saved restore state when available.
///
/// Both the saved state and the saved file descriptors are always probed.
/// If a saved state exists the supervisor is rebuilt from it (a warning is
/// logged when the file descriptors are missing); otherwise a fresh
/// supervisor is built from the configuration directories.
///
/// # Errors
///
/// Propagates any error from `restore_from_state` or `fresh_start`.
pub async fn new(
    config_dir: PathBuf,
    socket_path: PathBuf,
    pid1_mode: bool,
) -> SupervisorResult<Self> {
    let restore_state = try_load_restore_state();
    // Probed unconditionally, exactly like the state file.
    // NOTE(review): the loaded FDs are only checked for presence here; they
    // are kept alive until this function returns but never handed on.
    let restore_fds = try_load_restore_fds();

    match restore_state {
        Some(state) => {
            if restore_fds.is_some() {
                tracing::info!(
                    services = state.services.len(),
                    "restoring from saved state"
                );
            } else {
                tracing::warn!("restore state found but no FDs, doing partial restore");
            }
            // The state is moved rather than cloned: nothing needs it afterwards.
            Self::restore_from_state(state, config_dir, socket_path, pid1_mode).await
        }
        None => Self::fresh_start(config_dir, socket_path, pid1_mode).await,
    }
}
/// Builds a supervisor purely from on-disk configuration.
///
/// System services are loaded only in `pid1_mode`, and a failure to load
/// them is logged and tolerated. User services are then loaded, dependencies
/// linked and the graph validated; services with missing dependencies are
/// marked failed and reported.
///
/// # Errors
///
/// Fails if loading user services, linking dependencies, or validating the
/// graph fails.
async fn fresh_start(
    config_dir: PathBuf,
    socket_path: PathBuf,
    pid1_mode: bool,
) -> SupervisorResult<Self> {
    let (event_tx, event_rx) = mpsc::channel(256);
    let mut graph = ServiceGraph::new();

    let system_service_names: HashSet<String> = if pid1_mode {
        let system_dir = std::path::Path::new(SYSTEM_CONFIG_DIR);
        match graph.load_from_system_directory(system_dir) {
            Err(e) => {
                tracing::warn!(
                    system_dir = %system_dir.display(),
                    error = %e,
                    "failed to load system services (continuing with user services only)"
                );
                HashSet::new()
            }
            Ok(names) => {
                let loaded: HashSet<String> = names.into_iter().collect();
                tracing::info!(
                    system_dir = %system_dir.display(),
                    count = loaded.len(),
                    "loaded system services"
                );
                loaded
            }
        }
    } else {
        HashSet::new()
    };

    graph.load_from_user_directory(&config_dir, &system_service_names)?;
    let missing_deps = graph.link_dependencies()?;
    graph.validate()?;

    // Number of services marked failed because a dependency is missing.
    let config_error_count = if missing_deps.is_empty() {
        0
    } else {
        let failed = graph.mark_missing_deps_failed(&missing_deps);
        tracing::error!(
            count = failed,
            "services have missing dependencies and will not start"
        );
        for (service, dep) in &missing_deps {
            tracing::error!(
                service = %service,
                missing_dependency = %dep,
                "service has missing required dependency"
            );
        }
        failed
    };

    tracing::info!(
        config_dir = %config_dir.display(),
        pid1_mode = pid1_mode,
        system_services = system_service_names.len(),
        total_services = graph.len(),
        config_errors = config_error_count,
        "loaded service graph"
    );

    Ok(Self {
        graph: Arc::new(RwLock::new(graph)),
        log_buffers: log::new_log_buffers(),
        event_tx,
        event_rx,
        _log_shipper_tx: None,
        timers: HashMap::new(),
        process_tasks: HashMap::new(),
        health_attempts: HashMap::new(),
        pending_restarts: HashSet::new(),
        config_dir,
        socket_path,
        boot_time: now_millis(),
        shutdown: false,
        pid1_mode,
        system_service_names,
    })
}
/// Rebuilds a supervisor from a previously persisted state.
///
/// The service graph is first reloaded from disk exactly as on a fresh
/// start. Each persisted service is then reconciled with the graph:
///
/// * a service still present on disk gets its runtime metadata restored and,
///   if it had a PID, that PID is validated against the configured `exec`;
/// * an ephemeral service (not on disk) is recreated from its saved config;
/// * anything else is skipped with a warning.
///
/// The saved state files are removed once reconciliation is done, and the
/// persisted boot time is carried over.
///
/// # Errors
///
/// Fails if loading user services, linking dependencies, or validating the
/// graph fails.
async fn restore_from_state(
    state: PersistedState,
    config_dir: PathBuf,
    socket_path: PathBuf,
    pid1_mode: bool,
) -> SupervisorResult<Self> {
    let (event_tx, event_rx) = mpsc::channel(256);
    let mut graph = ServiceGraph::new();
    let mut system_service_names = HashSet::new();

    if pid1_mode {
        let system_dir = std::path::Path::new(SYSTEM_CONFIG_DIR);
        match graph.load_from_system_directory(system_dir) {
            Ok(names) => {
                system_service_names = names.into_iter().collect();
            }
            // Tolerated, but reported the same way as on a fresh start
            // instead of being dropped silently.
            Err(e) => {
                tracing::warn!(
                    system_dir = %system_dir.display(),
                    error = %e,
                    "failed to load system services (continuing with user services only)"
                );
            }
        }
    }

    graph.load_from_user_directory(&config_dir, &system_service_names)?;
    let missing_deps = graph.link_dependencies()?;
    graph.validate()?;

    if !missing_deps.is_empty() {
        let count = graph.mark_missing_deps_failed(&missing_deps);
        tracing::error!(
            count = count,
            "services have missing dependencies and will not start"
        );
        // Name each offender, matching the fresh-start diagnostics.
        for (service, dep) in &missing_deps {
            tracing::error!(
                service = %service,
                missing_dependency = %dep,
                "service has missing required dependency"
            );
        }
    }

    for (name, persisted) in &state.services {
        if let Some(id) = graph.get_by_name(name) {
            if let Some(service) = graph.get_mut(id) {
                service.restart_count = persisted.restart_count;
                service.current_restart_delay_ms = persisted.current_restart_delay_ms;
                service.started_at = persisted.started_at;
                service.last_state_change = persisted.last_state_change;
                service.last_exit_code = persisted.last_exit_code;
                service.last_exit_signal = persisted.last_exit_signal;

                if let Some(pid) = persisted.pid {
                    // An empty expected exec is used when the service has
                    // no service config to compare against.
                    let expected_exec = service
                        .service_config()
                        .map(|c| c.service.exec.as_str())
                        .unwrap_or("");
                    match validate_pid(pid, expected_exec) {
                        PidStatus::Alive => {
                            service.state = persisted.state.into_service_state(Some(pid));
                            tracing::info!(
                                service = %name,
                                pid = pid,
                                "restored running service"
                            );
                        }
                        PidStatus::Dead => {
                            service.state = ServiceState::Exited { exit_code: None };
                            tracing::warn!(
                                service = %name,
                                pid = pid,
                                "restored service PID is dead, marking as exited"
                            );
                        }
                        PidStatus::WrongProcess => {
                            service.state = ServiceState::Exited { exit_code: None };
                            tracing::warn!(
                                service = %name,
                                pid = pid,
                                "restored service PID is a different process, marking as exited"
                            );
                        }
                    }
                } else {
                    service.state = persisted.state.into_service_state(None);
                }
            }
        } else if persisted.ephemeral {
            match &persisted.config {
                Some(config) => {
                    let mut service = Service::from_service_ephemeral(config.clone());
                    service.restart_count = persisted.restart_count;
                    service.current_restart_delay_ms = persisted.current_restart_delay_ms;
                    service.started_at = persisted.started_at;
                    service.last_state_change = persisted.last_state_change;
                    service.last_exit_code = persisted.last_exit_code;
                    service.last_exit_signal = persisted.last_exit_signal;

                    // Decide the state once: keep the persisted state only
                    // when there is no PID or the PID still belongs to
                    // this service.
                    service.state = match persisted.pid {
                        Some(pid) => match validate_pid(pid, &config.service.exec) {
                            PidStatus::Alive => persisted.state.into_service_state(Some(pid)),
                            PidStatus::Dead | PidStatus::WrongProcess => {
                                ServiceState::Exited { exit_code: None }
                            }
                        },
                        None => persisted.state.into_service_state(None),
                    };

                    if let Err(e) = graph.add_service(service) {
                        tracing::warn!(
                            service = %name,
                            error = %e,
                            "failed to restore ephemeral service"
                        );
                    } else {
                        tracing::info!(service = %name, "restored ephemeral service");
                    }
                }
                None => {
                    // Without a saved config there is nothing to rebuild
                    // the service from.
                    tracing::warn!(
                        service = %name,
                        "ephemeral service in restore state has no saved config, skipping"
                    );
                }
            }
        } else {
            tracing::warn!(
                service = %name,
                "service in restore state but not found on disk, skipping"
            );
        }
    }

    cleanup_state_files();

    tracing::info!(
        services = graph.len(),
        boot_time = state.boot_time,
        "restore complete"
    );

    Ok(Self {
        graph: Arc::new(RwLock::new(graph)),
        log_buffers: log::new_log_buffers(),
        event_tx,
        event_rx,
        _log_shipper_tx: None,
        timers: HashMap::new(),
        process_tasks: HashMap::new(),
        health_attempts: HashMap::new(),
        pending_restarts: HashSet::new(),
        config_dir,
        socket_path,
        // Preserve the original boot time across the restore.
        boot_time: state.boot_time,
        shutdown: false,
        pid1_mode,
        system_service_names,
    })
}
/// Returns another handle to the shared service graph.
pub fn graph(&self) -> Arc<RwLock<ServiceGraph>> {
    self.graph.clone()
}
pub fn log_buffers(&self) -> LogBuffers {
Arc::clone(&self.log_buffers)
}
/// Returns a new sender for the supervisor's event channel.
pub fn event_tx(&self) -> mpsc::Sender<SupervisorEvent> {
    mpsc::Sender::clone(&self.event_tx)
}
/// Re-attaches monitoring to services whose state already carries a PID.
///
/// For every service with a non-zero PID and no existing process task, a
/// background task polls the PID every 500 ms and emits
/// `SupervisorEvent::ProcessExited` (with no exit code or signal) once the
/// process is gone. Services in the `Running` state that define a health
/// check also get their first health-check timeout scheduled.
async fn reconnect_restored_services(&mut self) {
    // Snapshot the candidates under the read lock; the lock is released
    // before any task is spawned.
    let restored = {
        let graph = self.graph.read().await;
        let mut found: Vec<(ServiceId, String, u32, Option<crate::sdk::HealthDef>)> =
            Vec::new();
        for id in graph.all_services() {
            let service = match graph.get(id) {
                Some(service) => service,
                None => continue,
            };
            let pid = match service.state.pid() {
                Some(pid) if pid != 0 => pid,
                _ => continue,
            };
            if self.process_tasks.contains_key(&id) {
                continue;
            }
            // Health checks are only resumed for services that are running.
            let health = match service.state {
                ServiceState::Running { .. } => {
                    service.service_config().and_then(|c| c.health.clone())
                }
                _ => None,
            };
            found.push((id, service.name.clone(), pid, health));
        }
        found
    };

    if restored.is_empty() {
        return;
    }

    tracing::info!(
        count = restored.len(),
        "reconnecting to restored services"
    );

    let poll_interval = Duration::from_millis(500);
    for (id, name, pid, health) in restored {
        log::init_buffer(&self.log_buffers, id, 100).await;

        let exit_tx = self.event_tx.clone();
        let task_name = name.clone();
        let watcher = tokio::spawn(async move {
            loop {
                tokio::time::sleep(poll_interval).await;
                if process_exists(pid) {
                    continue;
                }
                tracing::info!(
                    service = %task_name,
                    pid = pid,
                    "restored service process exited"
                );
                let exited = SupervisorEvent::ProcessExited {
                    service_id: id,
                    exit_code: None,
                    signal: None,
                };
                let _ = exit_tx.send(exited).await;
                break;
            }
        });
        self.process_tasks.insert(id, watcher);

        // Every health-check variant carries the same `common` settings.
        let health_interval = health.as_ref().map(|def| match def {
            crate::sdk::HealthDef::Tcp { common, .. }
            | crate::sdk::HealthDef::Http { common, .. }
            | crate::sdk::HealthDef::Exec { common, .. } => common.interval_ms,
        });
        if let Some(interval) = health_interval {
            self.schedule_timeout(id, TimeoutKind::HealthCheck, interval);
        }

        tracing::debug!(
            service = %name,
            pid = pid,
            has_health = health.is_some(),
            "reconnected to restored service"
        );
    }
}
/// Runs the supervisor until shutdown is requested or the event channel closes.
///
/// Monitoring is first re-attached to restored services and all services
/// are started in graph order; events are then handled one at a time.
///
/// # Errors
///
/// Propagates any error from `start_all`.
pub async fn run(&mut self) -> SupervisorResult<()> {
    self.reconnect_restored_services().await;
    self.start_all().await?;

    // The flag is tested *before* waiting for the next event. A shutdown
    // requested while handling an event therefore ends the loop right away,
    // instead of blocking in `recv()` until some unrelated event arrives and
    // then discarding that event unhandled.
    while !self.shutdown {
        match self.event_rx.recv().await {
            Some(event) => {
                self.handle_event(event).await;
            }
            None => break,
        }
    }

    tracing::info!("supervisor event loop ended");
    Ok(())
}
/// Sets the shutdown flag; `run` consults this flag to decide when to leave
/// its event loop.
pub fn request_shutdown(&mut self) {
self.shutdown = true;
}
/// Attempts to start every service in the order given by the graph.
///
/// The start order is computed under the read lock, which is released
/// before any service is started. Currently always returns `Ok(())`.
async fn start_all(&mut self) -> SupervisorResult<()> {
    let start_sequence = self.graph.read().await.start_order();
    for service_id in start_sequence {
        self.try_start_service(service_id).await;
    }
    Ok(())
}
/// Looks up the name of the service with the given id, or `None` if the
/// graph has no such service.
pub(crate) async fn service_name(&self, id: ServiceId) -> Option<String> {
    let guard = self.graph.read().await;
    let service = guard.get(id)?;
    Some(service.name.clone())
}
/// Schedules a `SupervisorEvent::Timeout` for `(id, kind)` after `delay_ms`
/// milliseconds, replacing any timer already registered under that key.
pub(crate) fn schedule_timeout(&mut self, id: ServiceId, kind: TimeoutKind, delay_ms: u64) {
    // At most one timer per (service, kind): drop the previous one first.
    self.cancel_timeout(id, kind);

    let delay = Duration::from_millis(delay_ms);
    let sender = self.event_tx.clone();
    let timer = tokio::spawn(async move {
        tokio::time::sleep(delay).await;
        let fired = SupervisorEvent::Timeout {
            service_id: id,
            kind,
        };
        // A send failure is deliberately ignored.
        let _ = sender.send(fired).await;
    });

    self.timers.insert((id, kind), timer);
}
/// Removes and aborts the timer registered for `(id, kind)`, if there is one.
pub(crate) fn cancel_timeout(&mut self, id: ServiceId, kind: TimeoutKind) {
    let key = (id, kind);
    if let Some(timer) = self.timers.remove(&key) {
        timer.abort();
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A supervisor can be constructed over an empty configuration directory.
    #[tokio::test]
    async fn test_supervisor_creation() {
        let dir = std::env::temp_dir().join("zinit-test-empty");
        let sock = dir.join("zinit.sock");
        let _ = std::fs::create_dir_all(&dir);

        let result = Supervisor::new(dir.clone(), sock, false).await;
        assert!(result.is_ok());

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// An empty configuration directory yields an empty service list.
    #[tokio::test]
    async fn test_list_services_empty() {
        let dir = std::env::temp_dir().join("zinit-test-list");
        let sock = dir.join("zinit.sock");
        let _ = std::fs::create_dir_all(&dir);

        let supervisor = Supervisor::new(dir.clone(), sock, false).await.unwrap();
        let listed = supervisor.list_services().await;
        assert!(listed.is_empty());

        let _ = std::fs::remove_dir_all(&dir);
    }
}