use std::collections::HashMap;
use std::sync::Arc;
use chrono::{DateTime, Local};
use tokio::sync::mpsc;
use crate::procserv::child::{ChildEvent, ChildHandle, ChildSpec};
use crate::procserv::client::{
ClientId, ClientMeta, InboundEvent, IncomingClient, OutboundFrame, spawn_client,
};
use crate::procserv::config::ProcServConfig;
use crate::procserv::error::{ProcServError, ProcServResult};
use crate::procserv::menu::{Action, scan as menu_scan};
use crate::procserv::restart::{RestartMode, RestartTracker};
use crate::procserv::sidecar::{
InfoSnapshot, LogFile, remove_pid_file, render_procserv_info, write_info_file, write_pid_file,
};
/// Public entry point of the supervisor.
///
/// Holds the configuration accepted by [`ProcServ::new`] until
/// [`ProcServ::run`] hands it over to the internal supervisor state.
pub struct ProcServ {
/// Configuration, wrapped in an `Arc` by [`ProcServ::new`] after validation.
config: Arc<ProcServConfig>,
}
impl ProcServ {
    /// Creates a supervisor from `config` after validating it.
    ///
    /// # Errors
    ///
    /// Returns [`ProcServError::Config`] wrapping whatever
    /// `ProcServConfig::validate` reports when the configuration is rejected.
    pub fn new(config: ProcServConfig) -> ProcServResult<Self> {
        if let Err(cause) = config.validate() {
            return Err(ProcServError::Config(cause));
        }
        let config = Arc::new(config);
        Ok(Self { config })
    }

    /// Consumes the handle, bootstraps the supervisor state and drives its
    /// event loop until that loop returns.
    ///
    /// # Errors
    ///
    /// Propagates any error from bootstrapping or from the event loop.
    pub async fn run(self) -> ProcServResult<()> {
        let Self { config } = self;
        let mut supervisor = SupervisorState::bootstrap(config).await?;
        supervisor.event_loop().await
    }
}
/// All mutable state owned by the supervisor's event loop.
///
/// Built by `bootstrap`, driven by `event_loop`, and cleaned up by the `Drop`
/// impl (PID file removal, child kill signal).
struct SupervisorState {
/// Shared configuration.
config: Arc<ProcServConfig>,
/// Sender half of the client-event channel; a clone is given to every client
/// spawned through `spawn_client`.
inbound_tx: mpsc::Sender<(ClientId, InboundEvent)>,
/// Receiver half of the client-event channel, polled by the event loop.
inbound_rx: mpsc::Receiver<(ClientId, InboundEvent)>,
/// Newly accepted connections delivered by the listener tasks.
incoming_rx: mpsc::Receiver<IncomingClient>,
/// Currently registered clients, keyed by id.
clients: HashMap<ClientId, ClientEntry>,
/// The running child, or `None` while no child is running.
child: Option<ChildSlot>,
/// Current restart policy; starts as `config.restart_mode` and can be cycled
/// at runtime by the toggle-restart menu action.
restart_mode: RestartMode,
/// Consulted (`try_record`) before an automatic restart is scheduled in
/// `RestartMode::OnExit`.
restart_tracker: RestartTracker,
/// Optional log file that receives child output and broadcast messages.
log: Option<LogFile>,
/// SIGHUP stream; each delivery triggers a log reopen.
sighup: tokio::signal::unix::Signal,
/// Set to `true` whenever a child is spawned; read in `RestartMode::OneShot`.
has_run_once: bool,
/// Restart scheduled for a future instant, if any.
pending_restart: Option<PendingRestart>,
/// Wall-clock time at which this state was constructed; shown in the
/// welcome banner.
proc_started: DateTime<Local>,
}
/// A child respawn that has been scheduled but not yet performed.
struct PendingRestart {
/// Instant at which the event loop should respawn the child.
at: tokio::time::Instant,
/// Message broadcast to clients immediately before the respawn.
banner: &'static str,
}
/// Supervisor-side record of one connected client.
struct ClientEntry {
/// Queue of outbound frames for this client (obtained from `spawn_client`).
out_tx: mpsc::Sender<OutboundFrame>,
/// Client metadata returned by `spawn_client`; this file reads its `id`,
/// `peer` and `readonly` fields.
meta: ClientMeta,
}
/// Everything the supervisor keeps about the currently running child.
struct ChildSlot {
/// Handle used to signal the child, write to its stdin and query its pid.
handle: ChildHandle,
/// Stream of events (`Output`, `Exited`) produced by the child.
rx: mpsc::Receiver<ChildEvent>,
/// Monotonic spawn time; used to compute the remaining restart hold-off.
started_at: tokio::time::Instant,
/// Wall-clock spawn time; shown in the welcome banner.
started_wall: DateTime<Local>,
}
impl SupervisorState {
async fn bootstrap(config: Arc<ProcServConfig>) -> ProcServResult<Self> {
let (inbound_tx, inbound_rx) = mpsc::channel::<(ClientId, InboundEvent)>(256);
let (incoming_tx, incoming_rx) = mpsc::channel::<IncomingClient>(8);
if let Some(p) = &config.logging.pid_path {
write_pid_file(p, std::process::id() as i32)?;
}
let log = if let Some(p) = &config.logging.log_path {
Some(
LogFile::open(
p,
config.logging.stamp_log,
config.logging.stamp_format.clone(),
)
.await?,
)
} else {
None
};
if let Some(addr) = config.listen.tcp_bind {
let tx = incoming_tx.clone();
tokio::spawn(async move {
if let Err(e) = super::listener::run_tcp(addr, false, tx).await {
tracing::error!(error = %e, "procserv-rs: TCP listener exited");
}
});
}
if let Some(path) = config.listen.unix_path.clone() {
let tx = incoming_tx.clone();
tokio::spawn(async move {
if let Err(e) = super::listener::run_unix(path, false, tx).await {
tracing::error!(error = %e, "procserv-rs: UNIX listener exited");
}
});
}
if let Some(addr) = config.listen.log_bind {
let tx = incoming_tx.clone();
tokio::spawn(async move {
if let Err(e) = super::listener::run_tcp(addr, true, tx).await {
tracing::error!(error = %e, "procserv-rs: log listener exited");
}
});
}
drop(incoming_tx);
let sighup = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
.map_err(ProcServError::Io)?;
let mut state = Self {
restart_mode: config.restart_mode,
config,
inbound_tx,
inbound_rx,
incoming_rx,
clients: HashMap::new(),
child: None,
restart_tracker: RestartTracker::new(),
log,
sighup,
has_run_once: false,
pending_restart: None,
proc_started: Local::now(),
};
if !state.config.wait_for_manual_start {
state.respawn_child().await?;
}
Ok(state)
}
/// Main supervisor loop: multiplexes client events, child events, new
/// connections, SIGHUP and the pending-restart timer.
///
/// Returns `Ok(())` when a client requested server shutdown
/// (`handle_inbound` returned `true`) or when a child exit resulted in
/// `ChildLoopOutcome::Shutdown`. Errors from `handle_inbound`,
/// `handle_child_event` and the timed `respawn_child` are propagated.
async fn event_loop(&mut self) -> ProcServResult<()> {
loop {
// Snapshot the deadline so the timer future does not borrow `self`.
let restart_at = self.pending_restart.as_ref().map(|p| p.at);
// Resolves with the next child event, or never while there is no child.
let child_event = async {
match self.child.as_mut() {
Some(slot) => slot.rx.recv().await,
None => std::future::pending().await,
}
};
// Resolves at the scheduled restart instant, or never if none is pending.
let restart_due = async {
match restart_at {
Some(at) => tokio::time::sleep_until(at).await,
None => std::future::pending().await,
}
};
tokio::select! {
// `biased` makes `select!` poll the branches in the order written, so
// client input takes priority over child events, which in turn take
// priority over new connections, SIGHUP and the restart timer.
biased;
// A `None` (channel closed) fails the `Some(..)` pattern, which disables
// the branch for this iteration instead of running the handler.
Some((peer_id, event)) = self.inbound_rx.recv() => {
if self.handle_inbound(peer_id, event).await? {
return Ok(()); }
}
ev = child_event => {
if let Some(ev) = ev {
match self.handle_child_event(ev).await? {
ChildLoopOutcome::Continue => {}
ChildLoopOutcome::Shutdown => return Ok(()),
}
} else {
// The child's event channel closed without this loop having handled an
// `Exited` event: forget the child. No restart is scheduled on this path.
self.child = None;
}
}
Some(incoming) = self.incoming_rx.recv() => {
self.handle_new_client(incoming).await;
}
_ = self.sighup.recv() => {
self.reopen_log().await;
}
_ = restart_due => {
// `take()` clears the schedule; the stored banner is broadcast and then
// the child is respawned.
if let Some(pending) = self.pending_restart.take() {
self.banner(pending.banner).await;
self.respawn_child().await?;
}
}
}
}
}
async fn reopen_log(&self) {
if let Some(log) = &self.log {
match log.reopen().await {
Ok(()) => tracing::info!("procserv-rs: reopened log file on SIGHUP"),
Err(e) => {
tracing::warn!(error = %e, "procserv-rs: log reopen on SIGHUP failed")
}
}
}
}
/// Processes one event coming from a client task.
///
/// * `TelnetReply` — the bytes are sent back to the same client as a raw IAC
///   frame, if that client is still registered.
/// * `Disconnected` — the client is removed from the registry.
/// * `Data` — the bytes are scanned for menu keys, each resulting action is
///   applied, and the bytes are then passed to `fanout_excluding` with the
///   sender excluded.
///
/// Returns `Ok(true)` when a quit-server action was seen and the event loop
/// should stop, `Ok(false)` otherwise. No path in this function currently
/// produces an `Err`.
async fn handle_inbound(
    &mut self,
    client_id: ClientId,
    event: InboundEvent,
) -> ProcServResult<bool> {
    // Deal with the non-data events up front so the rest of the function
    // only concerns itself with keyboard input.
    let bytes = match event {
        InboundEvent::Data { bytes } => bytes,
        InboundEvent::Disconnected => {
            self.clients.remove(&client_id);
            return Ok(false);
        }
        InboundEvent::TelnetReply { bytes } => {
            if let Some(entry) = self.clients.get(&client_id) {
                let _ = entry.out_tx.send(OutboundFrame::RawIac(bytes)).await;
            }
            return Ok(false);
        }
    };

    let child_alive = self.child.is_some();
    let requested = menu_scan(&bytes, &self.config.keys, child_alive);
    let mut shutdown = false;
    for action in &requested {
        match action {
            Action::QuitServer => {
                shutdown = true;
            }
            Action::LogoutClient => {
                if let Some(entry) = self.clients.remove(&client_id) {
                    let _ = entry.out_tx.send(OutboundFrame::Disconnect).await;
                }
            }
            Action::ToggleRestartMode => {
                self.restart_mode = self.restart_mode.next();
                let msg = format!(
                    "\r\n@@@ Toggled auto restart mode to {}\r\n",
                    self.restart_mode.label()
                );
                self.fanout_to_all(msg.as_bytes()).await;
            }
            Action::RestartChild => {
                // A manual restart is only honoured while no child is running.
                if self.child.is_none() {
                    self.banner("@@@ Manual restart").await;
                    if let Err(e) = self.respawn_child().await {
                        tracing::error!(error = %e, "procserv-rs: manual respawn failed");
                    }
                }
            }
            Action::KillChild => {
                // Announce first, then signal; both only while a child exists.
                if self.child.is_some() {
                    self.fanout_to_all(b"\r\n@@@ Got a kill command\r\n").await;
                    if let Some(slot) = &self.child {
                        let _ = slot.handle.signal(self.config.child.kill_signal);
                    }
                }
            }
            Action::None => {}
        }
    }

    // The raw input is forwarded even when it contained menu keys.
    self.fanout_excluding(&bytes, Some(client_id)).await;
    Ok(shutdown)
}
/// Reacts to one event from the running child.
///
/// `Output` is relayed to every client and appended to the log file (a log
/// write failure is only warned about). `Exited` clears the child slot,
/// announces the exit, and then applies the current restart mode:
///
/// * `OnExit` — records the restart with the tracker and schedules a respawn.
/// * `OneShot` — schedules one relaunch if `has_run_once` is still `false`,
///   otherwise requests shutdown.
/// * `Disabled` — requests shutdown.
///
/// # Errors
///
/// Returns [`ProcServError::RestartLimitExceeded`] when the restart tracker
/// refuses another automatic restart.
async fn handle_child_event(&mut self, event: ChildEvent) -> ProcServResult<ChildLoopOutcome> {
    let status = match event {
        ChildEvent::Exited { status } => status,
        ChildEvent::Output(bytes) => {
            self.fanout_excluding(&bytes, None).await;
            if let Some(log) = self.log.as_ref() {
                if let Err(e) = log.write_chunk(&bytes).await {
                    tracing::warn!(error = %e, "procserv-rs: log write failed");
                }
            }
            return Ok(ChildLoopOutcome::Continue);
        }
    };

    // Taking the slot drops the child's handle and event receiver; only the
    // spawn instant is kept, for the hold-off computation.
    let started_at = self.child.take().map(|slot| slot.started_at);

    let status_text = match status {
        Some(code) => code.to_string(),
        None => String::from("unknown"),
    };
    let msg = format!("\r\n@@@ Child exited (status: {:?})\r\n", status_text);
    self.fanout_to_all(msg.as_bytes()).await;

    match self.restart_mode {
        RestartMode::Disabled => {
            self.banner("@@@ Auto restart disabled — exiting").await;
            Ok(ChildLoopOutcome::Shutdown)
        }
        RestartMode::OneShot if self.has_run_once => {
            self.banner("@@@ One-shot mode: exiting").await;
            Ok(ChildLoopOutcome::Shutdown)
        }
        RestartMode::OneShot => {
            self.has_run_once = true;
            self.schedule_restart(started_at, "@@@ One-shot relaunch");
            Ok(ChildLoopOutcome::Continue)
        }
        RestartMode::OnExit => {
            if let Err((max, win)) = self.restart_tracker.try_record(&self.config.restart) {
                return Err(ProcServError::RestartLimitExceeded {
                    attempts: max,
                    window_secs: win,
                });
            }
            self.schedule_restart(started_at, "@@@ Auto restart");
            Ok(ChildLoopOutcome::Continue)
        }
    }
}
async fn respawn_child(&mut self) -> ProcServResult<()> {
let pre_spawn_info = InfoSnapshot {
procserv_pid: std::process::id() as i32,
child_pid: None,
child_exe: self.config.child.program.clone(),
child_args: self.config.child.args.clone(),
};
unsafe { std::env::set_var("PROCSERV_INFO", render_procserv_info(&pre_spawn_info)) };
let spec = ChildSpec {
program: self.config.child.program.clone(),
args: self.config.child.args.clone(),
cwd: self.config.child.cwd.clone(),
ignore_chars: self.config.child.ignore_chars.clone(),
};
let (handle, rx) = ChildHandle::spawn(&spec)?;
let post_spawn_info = InfoSnapshot {
procserv_pid: pre_spawn_info.procserv_pid,
child_pid: Some(handle.pid()),
child_exe: pre_spawn_info.child_exe.clone(),
child_args: pre_spawn_info.child_args.clone(),
};
if let Some(p) = &self.config.logging.info_path {
let _ = write_info_file(p, &post_spawn_info);
}
self.has_run_once = true;
self.pending_restart = None;
self.banner(&format!("@@@ Child started (pid {})", handle.pid()))
.await;
self.child = Some(ChildSlot {
handle,
rx,
started_at: tokio::time::Instant::now(),
started_wall: Local::now(),
});
Ok(())
}
/// Registers a freshly accepted connection.
///
/// Spawns the client task, sends it the welcome banner, logs the connection
/// at debug level and adds the client to the registry. The banner is built
/// before the client is registered, so its connection counts do not include
/// the newcomer.
async fn handle_new_client(&mut self, incoming: IncomingClient) {
    let (meta, out_tx) = spawn_client(incoming, self.inbound_tx.clone());

    let greeting = self.welcome_banner(meta.readonly).into_bytes();
    let _ = out_tx.send(OutboundFrame::Bytes(greeting)).await;

    tracing::debug!(client = meta.id.raw(), peer = ?meta.peer, readonly = meta.readonly, "procserv-rs: client connected");
    self.clients.insert(meta.id, ClientEntry { out_tx, meta });
}
/// Builds the text sent to a client right after it connects.
///
/// Every line ends in `\r\n`. Read-only clients get only the server start
/// time and, if a child is running, the child start time. Other clients
/// additionally get the greeting, the wrapped child's name with the current
/// restart mode, a hint for each configured menu key (kill, toggle restart,
/// logout), and a count of the clients already connected.
fn welcome_banner(&self, readonly: bool) -> String {
    let interactive = !readonly;
    let time_format = &self.config.logging.time_format;
    let mut lines: Vec<String> = Vec::new();

    if interactive {
        lines.push(String::from("@@@ Welcome to procserv-rs\r\n"));
        lines.push(format!(
            "@@@ Wrapping: {} (mode: {})\r\n",
            self.config.child.name,
            self.restart_mode.label()
        ));
        let keys = &self.config.keys;
        if let Some(c) = keys.kill {
            lines.push(format!(
                "@@@ Use ^{} to kill the child\r\n",
                ascii_caret(c)
            ));
        }
        if let Some(c) = keys.toggle_restart {
            lines.push(format!(
                "@@@ Use ^{} to toggle auto restart\r\n",
                ascii_caret(c)
            ));
        }
        if let Some(c) = keys.logout {
            lines.push(format!("@@@ Use ^{} to logout\r\n", ascii_caret(c)));
        }
    }

    lines.push(format!(
        "@@@ procServ server started at: {}\r\n",
        self.proc_started.format(time_format)
    ));
    if let Some(slot) = self.child.as_ref() {
        lines.push(format!(
            "@@@ Child \"{}\" started at: {}\r\n",
            self.config.child.name,
            slot.started_wall.format(time_format)
        ));
    }

    if interactive {
        // Every registered client is either a logger (read-only) or a user.
        let loggers = self.clients.values().filter(|e| e.meta.readonly).count();
        let users = self.clients.len() - loggers;
        lines.push(format!(
            "@@@ {users} user(s) and {loggers} logger(s) connected (plus you)\r\n"
        ));
    }

    lines.concat()
}
async fn fanout_to_all(&self, bytes: &[u8]) {
for entry in self.clients.values() {
let _ = entry
.out_tx
.send(OutboundFrame::Bytes(bytes.to_vec()))
.await;
}
if let Some(log) = &self.log
&& let Err(e) = log.write_chunk(bytes).await
{
tracing::warn!(error = %e, "procserv-rs: log write failed");
}
}
/// Sends `bytes` to every registered client except `exclude`.
///
/// When `exclude` is `Some(_)` — i.e. the data originated from a client —
/// the bytes are additionally written to the running child's stdin. With
/// `exclude == None` (child output) nothing is written to stdin. Nothing is
/// written to the log file here.
///
/// Send failures are ignored; a stdin write failure is logged at debug level.
async fn fanout_excluding(&self, bytes: &[u8], exclude: Option<ClientId>) {
    let recipients = self
        .clients
        .iter()
        .filter(move |(id, _)| Some(**id) != exclude);
    for (_, entry) in recipients {
        let frame = OutboundFrame::Bytes(bytes.to_vec());
        let _ = entry.out_tx.send(frame).await;
    }

    if exclude.is_none() {
        return;
    }
    let Some(slot) = self.child.as_ref() else {
        return;
    };
    if let Err(e) = slot.handle.write_stdin(bytes).await {
        tracing::debug!(error = %e, "procserv-rs: child stdin write failed");
    }
}
/// Schedules a child respawn, replacing any restart already pending.
///
/// The delay is the configured hold-off minus the time elapsed since
/// `started_at`; when `started_at` is `None` the full hold-off is used.
/// `banner` is stored alongside the deadline for the event loop to broadcast
/// when the restart fires.
fn schedule_restart(&mut self, started_at: Option<tokio::time::Instant>, banner: &'static str) {
    let holdoff = self.config.holdoff;
    let wait = started_at.map_or(holdoff, |spawned| {
        remaining_holdoff(holdoff, spawned.elapsed())
    });
    let at = tokio::time::Instant::now() + wait;
    self.pending_restart = Some(PendingRestart { at, banner });
}
/// Broadcasts `text` as a single line to all clients and the log file.
///
/// Trailing `\n` characters are stripped from `text` and a `\r\n` terminator
/// is appended before the line is passed to `fanout_to_all`.
async fn banner(&self, text: &str) {
    let line = format!("{}\r\n", text.trim_end_matches('\n'));
    self.fanout_to_all(line.as_bytes()).await;
}
}
/// Result of handling a child event, telling the event loop what to do next.
#[derive(Debug)]
enum ChildLoopOutcome {
/// Keep running the event loop.
Continue,
/// Leave the event loop; the supervisor's `run` then returns `Ok(())`.
Shutdown,
}
/// Returns how much of `holdoff` is still left after the child has been up
/// for `uptime`.
///
/// The result is `holdoff - uptime`, clamped to zero once the child has run
/// for at least the whole hold-off period.
fn remaining_holdoff(
    holdoff: std::time::Duration,
    uptime: std::time::Duration,
) -> std::time::Duration {
    match holdoff.checked_sub(uptime) {
        Some(left) => left,
        None => std::time::Duration::ZERO,
    }
}
/// Returns the character to print after `^` when describing key `c` in caret
/// notation.
///
/// Control bytes `0x00..=0x1f` map to `@`, `A`..`Z`, `[`, `\`, `]`, `^`, `_`
/// (e.g. `0x03` becomes `C`, shown as `^C`). DEL (`0x7f`) maps to `?`, its
/// conventional caret form `^?`, so the banner never emits the raw control
/// byte. Every other byte is returned as the character with that code point.
fn ascii_caret(c: u8) -> char {
    match c {
        // C0 controls: shift into the `@`..`_` column.
        0x00..=0x1f => (c + b'@') as char,
        // DEL is a control character too; caret notation writes it as `^?`.
        0x7f => '?',
        _ => c as char,
    }
}
impl Drop for SupervisorState {
fn drop(&mut self) {
if let Some(p) = &self.config.logging.pid_path {
remove_pid_file(p);
}
if let Some(slot) = self.child.as_ref() {
let _ = slot.handle.signal(self.config.child.kill_signal);
}
}
}
// Unit tests for `remaining_holdoff`, the only piece of this file that can be
// exercised without a runtime, sockets or a child process.
#[cfg(test)]
mod tests {
use super::remaining_holdoff;
use std::time::Duration;
// A child that exited immediately must wait out the whole hold-off.
#[test]
fn holdoff_zero_uptime_waits_full() {
assert_eq!(
remaining_holdoff(Duration::from_secs(15), Duration::ZERO),
Duration::from_secs(15)
);
}
// Uptime shorter than the hold-off leaves the difference to wait.
#[test]
fn holdoff_short_uptime_waits_difference() {
assert_eq!(
remaining_holdoff(Duration::from_secs(15), Duration::from_secs(4)),
Duration::from_secs(11)
);
}
// Boundary: uptime exactly equal to the hold-off means no wait.
#[test]
fn holdoff_uptime_equals_holdoff_no_wait() {
assert_eq!(
remaining_holdoff(Duration::from_secs(15), Duration::from_secs(15)),
Duration::ZERO
);
}
// Uptime longer than the hold-off clamps to zero instead of underflowing.
#[test]
fn holdoff_long_uptime_restarts_immediately() {
assert_eq!(
remaining_holdoff(Duration::from_secs(15), Duration::from_secs(3600)),
Duration::ZERO
);
}
}