use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::procserv::child::{ChildEvent, ChildHandle, ChildSpec};
use crate::procserv::client::{
ClientId, ClientMeta, InboundEvent, IncomingClient, OutboundFrame, spawn_client,
};
use crate::procserv::config::ProcServConfig;
use crate::procserv::error::{ProcServError, ProcServResult};
use crate::procserv::menu::{Action, scan as menu_scan};
use crate::procserv::restart::{RestartMode, RestartTracker};
use crate::procserv::sidecar::{
InfoSnapshot, LogFile, remove_pid_file, render_procserv_info, write_info_file, write_pid_file,
};
/// Public entry point of the supervisor.
///
/// Holds the configuration accepted by `ProcServ::new`; `ProcServ::run` consumes the
/// value and hands the configuration to the internal supervisor state.
pub struct ProcServ {
/// Configuration, wrapped in an `Arc` by `new` so it can be shared without copying.
config: Arc<ProcServConfig>,
}
impl ProcServ {
    /// Builds a server from `config` after validating it.
    ///
    /// # Errors
    /// Returns `ProcServError::Config` wrapping whatever `config.validate()` reports.
    pub fn new(config: ProcServConfig) -> ProcServResult<Self> {
        config.validate().map_err(ProcServError::Config)?;
        let shared = Arc::new(config);
        Ok(Self { config: shared })
    }

    /// Bootstraps the supervisor state and drives its event loop until it returns.
    ///
    /// # Errors
    /// Propagates any error produced while bootstrapping or by the event loop itself.
    pub async fn run(self) -> ProcServResult<()> {
        let mut supervisor = SupervisorState::bootstrap(self.config).await?;
        let outcome = supervisor.event_loop().await;
        outcome
    }
}
/// All mutable state owned by the supervisor's single event loop.
///
/// Built by `bootstrap`, driven by `event_loop`; its `Drop` impl removes the PID file
/// and signals a still-registered child.
struct SupervisorState {
/// Shared configuration.
config: Arc<ProcServConfig>,
/// Sender half of the client-event channel; cloned into every client via `spawn_client`.
inbound_tx: mpsc::Sender<(ClientId, InboundEvent)>,
/// Receiver half of the client-event channel, polled by the event loop.
inbound_rx: mpsc::Receiver<(ClientId, InboundEvent)>,
/// Newly accepted connections delivered by the listener tasks.
incoming_rx: mpsc::Receiver<IncomingClient>,
/// Currently registered clients, keyed by their id.
clients: HashMap<ClientId, ClientEntry>,
/// The running child and its event channel; `None` while no child is registered.
child: Option<ChildSlot>,
/// Current restart policy; initialised from the configuration and cycled at runtime
/// by the toggle action.
restart_mode: RestartMode,
/// Consulted through `try_record` before each automatic restart.
restart_tracker: RestartTracker,
/// Optional log file that child output is written to.
log: Option<LogFile>,
/// Set to `true` whenever a child has been spawned; read in `OneShot` mode on child exit.
has_run_once: bool,
}
/// Supervisor-side record of one connected client.
struct ClientEntry {
/// Channel used to push outbound frames to this client.
out_tx: mpsc::Sender<OutboundFrame>,
/// Metadata returned by `spawn_client` for this connection.
// The field is stored at registration but never read back in this file, hence the
// lint allowance below.
#[allow(dead_code)]
meta: ClientMeta,
}
/// The currently running child: its control handle paired with its event stream.
struct ChildSlot {
/// Handle used to signal the child, query its pid and write to its stdin.
handle: ChildHandle,
/// Events (output chunks, exit notification) produced by the child.
rx: mpsc::Receiver<ChildEvent>,
}
impl SupervisorState {
/// Builds the supervisor state: creates the channels, opens the optional log file,
/// writes the optional PID file, spawns the configured TCP/UNIX listener tasks and,
/// unless `wait_for_manual_start` is set, launches the child.
///
/// # Errors
/// Propagates failures from opening the log file, writing the PID file, or spawning
/// the initial child.
async fn bootstrap(config: Arc<ProcServConfig>) -> ProcServResult<Self> {
    let (inbound_tx, inbound_rx) = mpsc::channel::<(ClientId, InboundEvent)>(256);
    let (incoming_tx, incoming_rx) = mpsc::channel::<IncomingClient>(8);

    // Open the log BEFORE creating the PID file. The PID file is only removed by
    // `Drop for SupervisorState`, and no state exists yet at this point, so a log-open
    // failure after the PID file was written would leave a stale PID file behind.
    let log = match &config.logging.log_path {
        Some(p) => Some(LogFile::open(p, config.logging.time_format.clone()).await?),
        None => None,
    };
    if let Some(p) = &config.logging.pid_path {
        write_pid_file(p, std::process::id() as i32)?;
    }

    // Nothing between here and the construction of `state` can fail, so the PID file
    // is always covered by the `Drop` cleanup from now on.
    if let Some(addr) = config.listen.tcp_bind {
        let tx = incoming_tx.clone();
        tokio::spawn(async move {
            if let Err(e) = super::listener::run_tcp(addr, false, tx).await {
                tracing::error!(error = %e, "procserv-rs: TCP listener exited");
            }
        });
    }
    if let Some(path) = config.listen.unix_path.clone() {
        let tx = incoming_tx.clone();
        tokio::spawn(async move {
            if let Err(e) = super::listener::run_unix(path, false, tx).await {
                tracing::error!(error = %e, "procserv-rs: UNIX listener exited");
            }
        });
    }
    // Only the listener tasks keep sender clones; dropping ours lets `incoming_rx`
    // report closure once every listener has exited.
    drop(incoming_tx);

    let mut state = Self {
        restart_mode: config.restart_mode,
        config,
        inbound_tx,
        inbound_rx,
        incoming_rx,
        clients: HashMap::new(),
        child: None,
        restart_tracker: RestartTracker::new(),
        log,
        has_run_once: false,
    };
    if !state.config.wait_for_manual_start {
        // On failure `state` is dropped here, which removes the PID file again.
        state.respawn_child().await?;
    }
    Ok(state)
}
/// Runs the supervisor until a handler asks for shutdown or an error occurs.
///
/// Sources are polled in a fixed priority order (`biased`): client input first, then
/// child events, then newly accepted connections.
///
/// # Errors
/// Propagates errors from `handle_inbound` and `handle_child_event`.
async fn event_loop(&mut self) -> ProcServResult<()> {
    loop {
        // With no child registered this future never resolves, so the child branch
        // simply stays idle instead of being polled on a missing receiver.
        let next_child_event = async {
            if let Some(slot) = self.child.as_mut() {
                slot.rx.recv().await
            } else {
                std::future::pending().await
            }
        };
        tokio::select! {
            biased;
            Some((peer_id, event)) = self.inbound_rx.recv() => {
                let quit = self.handle_inbound(peer_id, event).await?;
                if quit {
                    return Ok(());
                }
            }
            maybe_event = next_child_event => {
                match maybe_event {
                    Some(child_ev) => {
                        if let ChildLoopOutcome::Shutdown =
                            self.handle_child_event(child_ev).await?
                        {
                            return Ok(());
                        }
                    }
                    // The child's event channel closed: forget the child slot.
                    None => self.child = None,
                }
            }
            Some(incoming) = self.incoming_rx.recv() => {
                self.handle_new_client(incoming).await;
            }
        }
    }
}
/// Processes one event coming from the client identified by `client_id`.
///
/// * `TelnetReply` bytes are sent back to that same client as a raw IAC frame.
/// * `Disconnected` removes the client from the registry.
/// * `Data` is scanned for menu actions, each action is applied, and the bytes are then
///   passed to `fanout_excluding` with the sender excluded.
///
/// Returns `Ok(true)` when a `QuitServer` action was seen, `Ok(false)` otherwise.
/// No path in this function currently produces an error.
async fn handle_inbound(
    &mut self,
    client_id: ClientId,
    event: InboundEvent,
) -> ProcServResult<bool> {
    // Deal with the two bookkeeping events up front; only `Data` continues below.
    let bytes = match event {
        InboundEvent::Data { bytes } => bytes,
        InboundEvent::TelnetReply { bytes } => {
            if let Some(entry) = self.clients.get(&client_id) {
                let _ = entry.out_tx.send(OutboundFrame::RawIac(bytes)).await;
            }
            return Ok(false);
        }
        InboundEvent::Disconnected => {
            self.clients.remove(&client_id);
            return Ok(false);
        }
    };

    // Liveness is sampled once, before any action of this chunk is applied.
    let child_running = self.child.is_some();
    let actions = menu_scan(&bytes, &self.config.keys, child_running);
    let mut quit_requested = false;
    for action in &actions {
        match action {
            Action::None => {}
            Action::QuitServer => quit_requested = true,
            Action::KillChild => {
                if let Some(slot) = self.child.as_ref() {
                    let _ = slot.handle.signal(self.config.child.kill_signal);
                }
            }
            // A manual restart is only honoured while no child is registered.
            Action::RestartChild if self.child.is_none() => {
                self.banner("@@@ Manual restart").await;
                if let Err(e) = self.respawn_child().await {
                    tracing::error!(error = %e, "procserv-rs: manual respawn failed");
                }
            }
            Action::RestartChild => {}
            Action::ToggleRestartMode => {
                self.restart_mode = self.restart_mode.next();
                let notice = format!(
                    "\r\n@@@ Toggled auto restart mode to {}\r\n",
                    self.restart_mode.label()
                );
                self.fanout_to_all(notice.as_bytes()).await;
            }
            Action::LogoutClient => {
                if let Some(entry) = self.clients.remove(&client_id) {
                    let _ = entry.out_tx.send(OutboundFrame::Disconnect).await;
                }
            }
        }
    }

    // The raw bytes are forwarded even when they contained menu keys.
    self.fanout_excluding(&bytes, Some(client_id)).await;
    Ok(quit_requested)
}
/// Reacts to one event from the running child.
///
/// * `Output`: the bytes are fanned out to every client and appended to the log file
///   when one is open; a failed log write is only reported as a warning.
/// * `Exited`: the child slot is cleared, the exit is announced to all clients, and the
///   current restart mode decides what happens next (respawn after the configured
///   holdoff, or shut the supervisor down).
///
/// The holdoff sleep is awaited inside this call, so the call does not return until
/// the holdoff has elapsed and the respawn attempt has finished.
///
/// # Errors
/// Returns `ProcServError::RestartLimitExceeded` when the restart tracker refuses
/// another automatic restart, and propagates any failure from `respawn_child`.
async fn handle_child_event(&mut self, event: ChildEvent) -> ProcServResult<ChildLoopOutcome> {
    match event {
        ChildEvent::Output(bytes) => {
            self.fanout_excluding(&bytes, None).await;
            if let Some(log) = &self.log
                && let Err(e) = log.write_chunk(&bytes).await
            {
                tracing::warn!(error = %e, "procserv-rs: log write failed");
            }
            Ok(ChildLoopOutcome::Continue)
        }
        ChildEvent::Exited { status } => {
            self.child = None;
            let status_text = status
                .map(|s| s.to_string())
                .unwrap_or_else(|| "unknown".into());
            // `status_text` is already a `String`; format it with `{}` so clients do
            // not see it wrapped in quotes, which the previous `{:?}` produced.
            let msg = format!("\r\n@@@ Child exited (status: {status_text})\r\n");
            self.fanout_to_all(msg.as_bytes()).await;
            match self.restart_mode {
                RestartMode::OnExit => {
                    match self.restart_tracker.try_record(&self.config.restart) {
                        Ok(()) => {
                            tokio::time::sleep(self.config.holdoff).await;
                            self.banner("@@@ Auto restart").await;
                            self.respawn_child().await?;
                            Ok(ChildLoopOutcome::Continue)
                        }
                        Err((max, win)) => Err(ProcServError::RestartLimitExceeded {
                            attempts: max,
                            window_secs: win,
                        }),
                    }
                }
                RestartMode::OneShot => {
                    // NOTE(review): `respawn_child` sets `has_run_once` before a child
                    // slot is registered, and exit events only arrive through that
                    // slot, so this first branch looks unreachable — confirm the
                    // intended one-shot semantics.
                    if !self.has_run_once {
                        self.has_run_once = true;
                        tokio::time::sleep(self.config.holdoff).await;
                        self.banner("@@@ One-shot relaunch").await;
                        self.respawn_child().await?;
                        Ok(ChildLoopOutcome::Continue)
                    } else {
                        self.banner("@@@ One-shot mode: exiting").await;
                        Ok(ChildLoopOutcome::Shutdown)
                    }
                }
                RestartMode::Disabled => {
                    self.banner("@@@ Auto restart disabled — exiting").await;
                    Ok(ChildLoopOutcome::Shutdown)
                }
            }
        }
    }
}
async fn respawn_child(&mut self) -> ProcServResult<()> {
let pre_spawn_info = InfoSnapshot {
procserv_pid: std::process::id() as i32,
child_pid: None,
child_exe: self.config.child.program.clone(),
child_args: self.config.child.args.clone(),
};
unsafe { std::env::set_var("PROCSERV_INFO", render_procserv_info(&pre_spawn_info)) };
let spec = ChildSpec {
program: self.config.child.program.clone(),
args: self.config.child.args.clone(),
cwd: self.config.child.cwd.clone(),
ignore_chars: self.config.child.ignore_chars.clone(),
};
let (handle, rx) = ChildHandle::spawn(&spec)?;
let post_spawn_info = InfoSnapshot {
procserv_pid: pre_spawn_info.procserv_pid,
child_pid: Some(handle.pid()),
child_exe: pre_spawn_info.child_exe.clone(),
child_args: pre_spawn_info.child_args.clone(),
};
if let Some(p) = &self.config.logging.info_path {
let _ = write_info_file(p, &post_spawn_info);
}
self.has_run_once = true;
self.banner(&format!("@@@ Child started (pid {})", handle.pid()))
.await;
self.child = Some(ChildSlot { handle, rx });
Ok(())
}
/// Registers a freshly accepted connection: starts its client via `spawn_client`,
/// sends it the welcome banner, and stores it in the client registry.
async fn handle_new_client(&mut self, incoming: IncomingClient) {
    let (meta, out_tx) = spawn_client(incoming, self.inbound_tx.clone());

    // Greet the client before it becomes visible to the fan-out helpers.
    let greeting = OutboundFrame::Bytes(self.welcome_banner().into_bytes());
    let _ = out_tx.send(greeting).await;

    let entry = ClientEntry {
        out_tx,
        meta: meta.clone(),
    };
    self.clients.insert(meta.id, entry);
    tracing::debug!(client = meta.id.raw(), peer = ?meta.peer, readonly = meta.readonly, "procserv-rs: client connected");
}
fn welcome_banner(&self) -> String {
let mut s = String::new();
s.push_str("@@@ Welcome to procserv-rs\r\n");
s.push_str(&format!(
"@@@ Wrapping: {} (mode: {})\r\n",
self.config.child.name,
self.restart_mode.label()
));
if let Some(c) = self.config.keys.kill {
s.push_str(&format!(
"@@@ Use ^{} to kill the child\r\n",
ascii_caret(c)
));
}
if let Some(c) = self.config.keys.toggle_restart {
s.push_str(&format!(
"@@@ Use ^{} to toggle auto restart\r\n",
ascii_caret(c)
));
}
if let Some(c) = self.config.keys.logout {
s.push_str(&format!("@@@ Use ^{} to logout\r\n", ascii_caret(c)));
}
s
}
async fn fanout_to_all(&self, bytes: &[u8]) {
for entry in self.clients.values() {
let _ = entry
.out_tx
.send(OutboundFrame::Bytes(bytes.to_vec()))
.await;
}
}
/// Sends a copy of `bytes` to every registered client except `exclude`.
///
/// When `exclude` is `Some(_)`, the bytes are additionally written to the child's
/// stdin if a child is registered; a failed stdin write is only logged at debug level.
/// With `exclude == None` nothing is written to the child.
async fn fanout_excluding(&self, bytes: &[u8], exclude: Option<ClientId>) {
    for (id, entry) in &self.clients {
        if Some(*id) != exclude {
            let frame = OutboundFrame::Bytes(bytes.to_vec());
            let _ = entry.out_tx.send(frame).await;
        }
    }

    // Only calls that name a sender also feed the child's stdin.
    if exclude.is_none() {
        return;
    }
    let Some(slot) = self.child.as_ref() else {
        return;
    };
    if let Err(e) = slot.handle.write_stdin(bytes).await {
        tracing::debug!(error = %e, "procserv-rs: child stdin write failed");
    }
}
/// Broadcasts `text` to all clients as a single line: trailing `\n` characters are
/// stripped and a `\r\n` terminator is appended.
async fn banner(&self, text: &str) {
    let line = format!("{}\r\n", text.trim_end_matches('\n'));
    self.fanout_to_all(line.as_bytes()).await;
}
}
/// Verdict returned by `handle_child_event`, telling the event loop what to do next.
#[derive(Debug)]
enum ChildLoopOutcome {
/// Keep processing events.
Continue,
/// Leave the event loop; `event_loop` then returns `Ok(())`.
Shutdown,
}
/// Returns the character that follows `^` when `c` is shown in caret notation.
///
/// Control bytes `0..=31` map to `'@'..='_'` (e.g. `0x18` becomes `'X'`), and DEL
/// (`127`) maps to `'?'` so that no raw control character is ever emitted. Every
/// other byte is returned unchanged as a `char`.
fn ascii_caret(c: u8) -> char {
    match c {
        0..=31 => (c + b'@') as char,
        // DEL is a control character too; its conventional caret form is `^?`.
        127 => '?',
        _ => c as char,
    }
}
impl Drop for SupervisorState {
fn drop(&mut self) {
if let Some(p) = &self.config.logging.pid_path {
remove_pid_file(p);
}
if let Some(slot) = self.child.as_ref() {
let _ = slot.handle.signal(self.config.child.kill_signal);
}
}
}