use std::collections::HashMap;
use std::os::fd::BorrowedFd;
use std::os::unix::io::RawFd;
use std::time::{SystemTime, UNIX_EPOCH};
use nix::fcntl::{FcntlArg, FdFlag, fcntl};
use nix::unistd::Pid;
use serde::{Deserialize, Serialize};
use crate::sdk::{FailureReason, ServiceConfig, ServiceState};
/// Schema version stamped into the persisted state; snapshots carrying a
/// newer version are rejected when loading.
pub const STATE_VERSION: u32 = 1;
/// Schema version stamped into the persisted file-descriptor table.
pub const FDS_VERSION: u32 = 1;
/// Location of the serialized supervisor state snapshot.
pub const STATE_PATH: &str = "/run/zinit/state.json";
/// Location of the serialized file-descriptor table.
pub const FDS_PATH: &str = "/run/zinit/fds.json";
/// Oldest state snapshot (in milliseconds) still accepted for restore: five minutes.
pub const STATE_MAX_AGE_MS: u64 = 300_000;
/// Snapshot of supervisor state, serialized as JSON to `STATE_PATH` by
/// `save_state` and read back by `try_load_restore_state`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistedState {
    /// Schema version; `try_load_restore_state` rejects values newer than `STATE_VERSION`.
    pub version: u32,
    /// Time the snapshot was taken, in milliseconds since the Unix epoch
    /// (compared against `now_millis()` to enforce `STATE_MAX_AGE_MS`).
    pub saved_at: u64,
    /// NOTE(review): presumably the supervisor's boot time in ms since the
    /// epoch (the tests fill it from `now_millis()`) — confirm with the writer.
    pub boot_time: u64,
    /// Per-service records; presumably keyed by service name — TODO confirm.
    pub services: HashMap<String, PersistedService>,
    /// Config directory path (the tests fill it from `sdk::socket::system_config_dir()`).
    pub config_dir: String,
    /// Socket path (the tests fill it from `sdk::socket::system_path()`).
    pub socket_path: String,
}
/// Persisted record of a single service's runtime state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistedService {
    /// Service name.
    pub name: String,
    /// Payload-free lifecycle tag; turned back into a live state with
    /// `PersistedServiceState::into_service_state`, together with `pid`.
    pub state: PersistedServiceState,
    /// PID of the service process, if one was tracked.
    pub pid: Option<u32>,
    /// NOTE(review): presumably the number of restarts performed so far — confirm.
    pub restart_count: u32,
    /// Current restart back-off delay in milliseconds (per the `_ms` suffix).
    pub current_restart_delay_ms: u64,
    /// Exit code of the last run, if known.
    pub last_exit_code: Option<i32>,
    /// NOTE(review): presumably the signal number that ended the last run — confirm.
    pub last_exit_signal: Option<i32>,
    /// Start timestamp; presumably ms since the epoch (tests use `now_millis()`).
    pub started_at: Option<u64>,
    /// Timestamp of the last state transition; presumably ms since the epoch.
    pub last_state_change: u64,
    /// NOTE(review): presumably marks services not backed by a config file — confirm.
    pub ephemeral: bool,
    /// Optional service configuration; the JSON key is omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub config: Option<ServiceConfig>,
}
/// Payload-free mirror of `ServiceState`, serialized as a snake_case string
/// (e.g. `"running"`). PIDs and other variant data are not kept here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PersistedServiceState {
    /// Mirrors `ServiceState::Inactive`.
    Inactive,
    /// Mirrors `ServiceState::Blocked` (the waiting-on list is not persisted).
    Blocked,
    /// Mirrors `ServiceState::Starting`.
    Starting,
    /// Mirrors `ServiceState::Running`.
    Running,
    /// Mirrors `ServiceState::Stopping`.
    Stopping,
    /// Mirrors `ServiceState::Exited` (the exit code is not persisted here).
    Exited,
    /// Mirrors `ServiceState::Failed` (the failure reason is not persisted).
    Failed,
}
impl From<&ServiceState> for PersistedServiceState {
fn from(state: &ServiceState) -> Self {
match state {
ServiceState::Inactive => Self::Inactive,
ServiceState::Blocked { .. } => Self::Blocked,
ServiceState::Starting { .. } => Self::Starting,
ServiceState::Running { .. } => Self::Running,
ServiceState::Stopping { .. } => Self::Stopping,
ServiceState::Exited { .. } => Self::Exited,
ServiceState::Failed { .. } => Self::Failed,
}
}
}
impl PersistedServiceState {
    /// Rebuilds a live [`ServiceState`] from this tag plus the persisted PID.
    ///
    /// Mapping:
    /// * `Inactive` → `Inactive`; `Blocked` → `Blocked` with an empty waiting-on list.
    /// * `Starting` → `Starting { pid }`, or `Inactive` when no PID was saved.
    /// * `Running` / `Stopping` → the same state with the PID, or
    ///   `Exited { exit_code: None }` when no PID was saved.
    /// * `Exited` → `Exited { exit_code: None }` (any PID is ignored).
    /// * `Failed` → `Failed` with a placeholder `SpawnError`, since the
    ///   original reason is not persisted.
    pub fn into_service_state(self, pid: Option<u32>) -> ServiceState {
        match (self, pid) {
            (Self::Inactive, _) => ServiceState::Inactive,
            (Self::Blocked, _) => ServiceState::Blocked { waiting_on: vec![] },
            (Self::Starting, Some(pid)) => ServiceState::Starting { pid },
            (Self::Starting, None) => ServiceState::Inactive,
            (Self::Running, Some(pid)) => ServiceState::Running { pid },
            (Self::Stopping, Some(pid)) => ServiceState::Stopping { pid },
            // No process to follow (or already exited): the exit status is unknown.
            (Self::Running | Self::Stopping, None) | (Self::Exited, _) => {
                ServiceState::Exited { exit_code: None }
            }
            (Self::Failed, _) => ServiceState::Failed {
                reason: FailureReason::SpawnError {
                    message: "restored from failed state".into(),
                },
            },
        }
    }
}
/// Table of raw file-descriptor numbers, written by `save_fds` and read from
/// the `ZINIT_FDS` environment variable by `try_load_restore_fds`.
///
/// NOTE(review): these are plain integers with no ownership tracking; they are
/// presumably only meaningful in a process that inherited the same open
/// descriptors (hence `clear_cloexec`) — confirm with the caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistedFds {
    /// Schema version; compare against `FDS_VERSION`.
    pub version: u32,
    /// Descriptor of the RPC socket, if one was saved.
    pub rpc_socket: Option<RawFd>,
    /// Per-service output pipes; presumably keyed by service name — TODO confirm.
    pub services: HashMap<String, ServiceFds>,
    /// NOTE(review): presumably socket-activation listener fds keyed by name — confirm.
    pub socket_activated: HashMap<String, RawFd>,
}
/// Raw descriptors of a single service's captured output pipes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceFds {
    /// Descriptor of the stdout pipe, if any.
    pub stdout_pipe: Option<RawFd>,
    /// Descriptor of the stderr pipe, if any.
    pub stderr_pipe: Option<RawFd>,
}
/// Outcome of `validate_pid`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PidStatus {
    /// The PID is live and appears to run the expected executable.
    Alive,
    /// No live process could be found for the PID.
    Dead,
    /// The PID is live but does not appear to run the expected executable
    /// (presumably a PID reused by an unrelated process).
    WrongProcess,
}
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` rather than an error.
pub fn now_millis() -> u64 {
    let Ok(since_epoch) = SystemTime::now().duration_since(UNIX_EPOCH) else {
        return 0;
    };
    // u128 -> u64 only truncates hundreds of millions of years from now.
    since_epoch.as_millis() as u64
}
pub fn clear_cloexec(fd: RawFd) -> Result<(), std::io::Error> {
let borrowed = unsafe { BorrowedFd::borrow_raw(fd) };
let flags = fcntl(borrowed, FcntlArg::F_GETFD).map_err(std::io::Error::other)?;
let new_flags = FdFlag::from_bits_truncate(flags) - FdFlag::FD_CLOEXEC;
let borrowed = unsafe { BorrowedFd::borrow_raw(fd) };
fcntl(borrowed, FcntlArg::F_SETFD(new_flags)).map_err(std::io::Error::other)?;
Ok(())
}
/// Reports whether `fd` refers to a descriptor that is open in this process.
///
/// Negative values are never valid and return `false` without a syscall.
/// Otherwise `fcntl(fd, F_GETFD)` is issued, which fails (e.g. `EBADF`) for
/// descriptors that are not open.
pub fn is_fd_valid(fd: RawFd) -> bool {
    // `BorrowedFd::borrow_raw`'s safety contract forbids -1, so a predicate
    // must answer "no" here instead of panicking or invoking UB.
    if fd < 0 {
        return false;
    }
    // SAFETY: `fd` is non-negative (hence not -1), and the borrow lives only for
    // the single F_GETFD query below, which neither closes nor takes ownership
    // of the descriptor. A closed `fd` just makes fcntl return an error.
    let borrowed = unsafe { BorrowedFd::borrow_raw(fd) };
    fcntl(borrowed, FcntlArg::F_GETFD).is_ok()
}
/// Checks whether `pid` is alive and appears to be running `expected_exec`.
///
/// Only the first whitespace-separated word of `expected_exec` (the binary) is
/// compared. Matching is deliberately loose (substring based) so that relative
/// commands and interpreter-run scripts are still recognized:
///
/// 1. the target of `/proc/<pid>/exe` is compared with the binary in both
///    directions (either contains the other);
/// 2. otherwise `/proc/<pid>/cmdline` (NUL separators replaced by spaces) is
///    searched for the binary.
///
/// Returns [`PidStatus::Dead`] when no process with that PID exists. This
/// includes values that cannot be a process ID (`0`, or anything above
/// `i32::MAX`). Returns [`PidStatus::Alive`] when either check matches, and
/// [`PidStatus::WrongProcess`] otherwise.
pub fn validate_pid(pid: u32, expected_exec: &str) -> PidStatus {
    use std::fs;

    // kill(2) interprets 0 and negative PIDs as process-group selectors, so an
    // out-of-range value (0, or a u32 that wraps negative) could "succeed"
    // against a whole group. Such values can never name a process.
    let raw_pid = match i32::try_from(pid) {
        Ok(raw) if raw > 0 => raw,
        _ => return PidStatus::Dead,
    };

    // Signal 0 only performs the existence/permission check. ESRCH means "no
    // such process"; EPERM means it exists but we may not signal it, so keep
    // going and inspect its identity.
    if matches!(
        nix::sys::signal::kill(Pid::from_raw(raw_pid), None),
        Err(nix::errno::Errno::ESRCH)
    ) {
        return PidStatus::Dead;
    }

    let expected_binary = expected_exec
        .split_whitespace()
        .next()
        .unwrap_or(expected_exec);

    if let Ok(exe) = fs::read_link(format!("/proc/{pid}/exe")) {
        let exe_str = exe.to_string_lossy();
        if exe_str.contains(expected_binary) || expected_binary.contains(&*exe_str) {
            return PidStatus::Alive;
        }
    }

    // argv may hold arbitrary bytes; decode lossily so one non-UTF-8 argument
    // does not make the whole cmdline check fail.
    if let Ok(raw_cmdline) = fs::read(format!("/proc/{pid}/cmdline")) {
        let cmdline = String::from_utf8_lossy(&raw_cmdline).replace('\0', " ");
        if cmdline.contains(expected_binary) {
            return PidStatus::Alive;
        }
    }

    PidStatus::WrongProcess
}
/// Loads the persisted supervisor state from `STATE_PATH`, if it is usable.
///
/// Returns `None` in any of these cases:
/// * the file does not exist or cannot be read;
/// * it does not parse as [`PersistedState`];
/// * its `version` is newer than `STATE_VERSION`;
/// * it is older than `STATE_MAX_AGE_MS`.
///
/// Every rejection except a missing file is logged, so a corrupt or unreadable
/// snapshot is not discarded silently.
pub fn try_load_restore_state() -> Option<PersistedState> {
    let json = match std::fs::read_to_string(STATE_PATH) {
        Ok(json) => json,
        // No snapshot is the ordinary cold-start case; stay quiet.
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return None,
        Err(err) => {
            tracing::warn!("failed to read restore state {}: {}", STATE_PATH, err);
            return None;
        }
    };
    let state: PersistedState = match serde_json::from_str(&json) {
        Ok(state) => state,
        Err(err) => {
            tracing::warn!("failed to parse restore state {}: {}", STATE_PATH, err);
            return None;
        }
    };
    if state.version > STATE_VERSION {
        tracing::warn!(
            "state version {} > supported {}, ignoring",
            state.version,
            STATE_VERSION
        );
        return None;
    }
    // saturating_sub: a snapshot "from the future" (clock stepped back) counts as fresh.
    let age_ms = now_millis().saturating_sub(state.saved_at);
    if age_ms > STATE_MAX_AGE_MS {
        tracing::warn!("restore state is {} seconds old, ignoring", age_ms / 1000);
        return None;
    }
    Some(state)
}
pub fn try_load_restore_fds() -> Option<PersistedFds> {
let fd_json = std::env::var("ZINIT_FDS").ok()?;
serde_json::from_str(&fd_json).ok()
}
/// Best-effort removal of the persisted state and fd files.
///
/// Failures (including a file that does not exist) are deliberately ignored.
pub fn cleanup_state_files() {
    for path in [STATE_PATH, FDS_PATH] {
        let _ = std::fs::remove_file(path);
    }
}
/// Writes `state` as pretty-printed JSON to `STATE_PATH`.
///
/// `/run/zinit` is created if needed. The JSON first goes to a `.tmp` sibling,
/// which is then renamed over the target, so the final path never holds a
/// partially written file.
///
/// # Errors
///
/// Propagates directory-creation, serialization, write and rename failures.
pub fn save_state(state: &PersistedState) -> Result<(), std::io::Error> {
    std::fs::create_dir_all("/run/zinit")?;
    let serialized = serde_json::to_string_pretty(state)?;
    let staging = format!("{STATE_PATH}.tmp");
    std::fs::write(&staging, &serialized)?;
    std::fs::rename(&staging, STATE_PATH)
}
/// Writes `fds` as pretty-printed JSON to `FDS_PATH`.
///
/// The parent directory of `FDS_PATH` is created first, mirroring
/// `save_state`. The JSON is written to a `.tmp` sibling and then renamed
/// over the target, so the final path never holds a partially written file.
///
/// # Errors
///
/// Propagates directory-creation, serialization, write and rename failures.
pub fn save_fds(fds: &PersistedFds) -> Result<(), std::io::Error> {
    // Without this, saving fds before the state snapshot (or on a fresh /run)
    // fails with NotFound because the directory does not exist yet.
    if let Some(dir) = std::path::Path::new(FDS_PATH).parent() {
        std::fs::create_dir_all(dir)?;
    }
    let json = serde_json::to_string_pretty(fds)?;
    let tmp_path = format!("{}.tmp", FDS_PATH);
    std::fs::write(&tmp_path, &json)?;
    std::fs::rename(&tmp_path, FDS_PATH)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_persisted_state_roundtrip() {
        let state = PersistedState {
            version: STATE_VERSION,
            saved_at: now_millis(),
            boot_time: now_millis() - 1000,
            services: HashMap::new(),
            config_dir: crate::sdk::socket::system_config_dir()
                .to_string_lossy()
                .to_string(),
            socket_path: crate::sdk::socket::system_path()
                .to_string_lossy()
                .to_string(),
        };
        let json = serde_json::to_string(&state).unwrap();
        let restored: PersistedState = serde_json::from_str(&json).unwrap();
        assert_eq!(state.version, restored.version);
        assert_eq!(state.saved_at, restored.saved_at);
        assert_eq!(state.config_dir, restored.config_dir);
    }

    #[test]
    fn test_persisted_service_roundtrip() {
        let service = PersistedService {
            name: "test".to_string(),
            state: PersistedServiceState::Running,
            pid: Some(12345),
            restart_count: 2,
            current_restart_delay_ms: 2000,
            last_exit_code: Some(0),
            last_exit_signal: None,
            started_at: Some(now_millis() - 5000),
            last_state_change: now_millis(),
            ephemeral: false,
            config: None,
        };
        let json = serde_json::to_string(&service).unwrap();
        let restored: PersistedService = serde_json::from_str(&json).unwrap();
        assert_eq!(service.name, restored.name);
        assert_eq!(service.pid, restored.pid);
        assert_eq!(service.state, restored.state);
    }

    #[test]
    fn test_persisted_service_state_conversion() {
        assert_eq!(
            PersistedServiceState::from(&ServiceState::Inactive),
            PersistedServiceState::Inactive
        );
        assert_eq!(
            PersistedServiceState::from(&ServiceState::Running { pid: 123 }),
            PersistedServiceState::Running
        );
        assert_eq!(
            PersistedServiceState::from(&ServiceState::Starting { pid: 123 }),
            PersistedServiceState::Starting
        );
        let running = PersistedServiceState::Running.into_service_state(Some(456));
        assert!(matches!(running, ServiceState::Running { pid: 456 }));
        let running_no_pid = PersistedServiceState::Running.into_service_state(None);
        assert!(matches!(running_no_pid, ServiceState::Exited { .. }));
    }

    #[test]
    fn test_into_service_state_without_pid() {
        assert!(matches!(
            PersistedServiceState::Starting.into_service_state(None),
            ServiceState::Inactive
        ));
        assert!(matches!(
            PersistedServiceState::Stopping.into_service_state(None),
            ServiceState::Exited { exit_code: None }
        ));
        assert!(matches!(
            PersistedServiceState::Exited.into_service_state(Some(7)),
            ServiceState::Exited { exit_code: None }
        ));
    }

    #[test]
    fn test_pid_validation_dead() {
        // Linux caps pid_max at 2^22 (4_194_304), so i32::MAX can never be a
        // live PID. 999_999 was reachable on hosts with a raised pid_max.
        let impossible_pid = i32::MAX as u32;
        assert_eq!(validate_pid(impossible_pid, "nonexistent"), PidStatus::Dead);
    }

    #[test]
    fn test_pid_validation_alive_self() {
        // current_exe() relies on /proc on Linux; skip if it is unavailable.
        let Ok(exe) = std::env::current_exe() else {
            return;
        };
        let exe = exe.to_string_lossy();
        assert_eq!(validate_pid(std::process::id(), &exe), PidStatus::Alive);
    }

    #[test]
    fn test_persisted_fds_roundtrip() {
        let fds = PersistedFds {
            version: FDS_VERSION,
            rpc_socket: Some(5),
            services: HashMap::from([(
                "app".to_string(),
                ServiceFds {
                    stdout_pipe: Some(7),
                    stderr_pipe: Some(8),
                },
            )]),
            socket_activated: HashMap::new(),
        };
        let json = serde_json::to_string(&fds).unwrap();
        let restored: PersistedFds = serde_json::from_str(&json).unwrap();
        assert_eq!(fds.version, restored.version);
        assert_eq!(fds.rpc_socket, restored.rpc_socket);
        assert_eq!(fds.services.len(), restored.services.len());
    }
}