use std::{
ffi::OsString,
io,
io::{Read, Write},
os::{
fd::AsFd,
unix::{net::UnixStream, process::CommandExt as _},
},
process,
sync::atomic::{AtomicBool, Ordering},
sync::Arc,
thread,
time::{Duration, Instant},
};
use anyhow::{anyhow, Context};
use nix::{poll, sys::signal, unistd};
use parking_lot::Mutex;
use shpool_protocol::{Chunk, ChunkKind, TtySize};
use tracing::{error, info, instrument, span, trace, warn, Level};
use crate::{consts, protocol::ChunkExt as _, tty::TtySizeExt as _};
// How long each poll(2) call blocks, in milliseconds, before we wake up to
// re-check for pager exit and heartbeat deadlines.
const POLL_MS: u16 = 100;
/// Control handle registered in the session's shared slot while a pager is
/// active, letting other parts of the session forward tty size changes to
/// the pager's pty. Both channels are rendezvous (capacity 0) channels, so
/// a sender blocks until the pager side has picked up the message.
#[derive(Debug)]
pub struct PagerCtl {
    /// Send a new size for the pager to apply to its pty.
    pub tty_size_change: crossbeam_channel::Sender<TtySize>,
    /// Receive an ack once the pager has applied the size change.
    pub tty_size_change_ack: crossbeam_channel::Receiver<()>,
}
/// Displays messages to the user by running an external pager binary
/// (e.g. `less`) in a freshly forked pty.
pub struct Pager {
    /// Name or path of the pager binary to exec.
    pager_bin: String,
}
impl Pager {
    /// Create a pager that will shell out to the given binary (e.g. `less`).
    pub fn new(pager_bin: String) -> Self {
        Pager { pager_bin }
    }

    /// Display `msg` to the user by running the configured pager in a fresh
    /// pty and proxying bytes between that pty and `client_stream`.
    ///
    /// While the pager runs, a `PagerCtl` is registered in `ctl_slot` so the
    /// rest of the session can forward tty size changes; only one pager per
    /// session may be active at a time (an error is returned otherwise).
    /// `shell_env` becomes the pager's entire environment (the inherited
    /// environment is cleared).
    ///
    /// Returns the last tty size observed, so the caller can pick up any
    /// resizes that arrived while the pager was in control.
    #[instrument(skip_all)]
    pub fn display(
        &self,
        client_stream: &mut UnixStream,
        ctl_slot: Arc<Mutex<Option<PagerCtl>>>,
        init_tty_size: TtySize,
        msg: &str,
        shell_env: &[(OsString, OsString)],
    ) -> anyhow::Result<TtySize> {
        // Rendezvous (capacity 0) channels: a resize sender blocks until this
        // pager has actually applied the new size and acked it.
        let (tty_size_change_tx, tty_size_change_rx) = crossbeam_channel::bounded(0);
        let (tty_size_change_ack_tx, tty_size_change_ack_rx) = crossbeam_channel::bounded(0);
        {
            let mut ctl_handle = ctl_slot.lock();
            if ctl_handle.is_some() {
                return Err(anyhow!("only one pager per session at a time allowed"));
            }
            trace!("registering PagerCtl");
            *ctl_handle = Some(PagerCtl {
                tty_size_change: tty_size_change_tx,
                tty_size_change_ack: tty_size_change_ack_rx,
            });
        }
        // Deregisters the PagerCtl on every exit path, including errors.
        let _ctl_guard = PagerCltGuard { ctl_slot };

        // Pagers take a file argument, so spill the message into a temp file,
        // stripped of ANSI escapes which would otherwise render as garbage.
        let mut msg_file = tempfile::NamedTempFile::with_prefix("shpool_pager")
            .context("creating tmp file to display msg via pager")?;
        let cleaned_msg = strip_ansi_escapes::strip(msg);
        msg_file.write_all(cleaned_msg.as_slice()).context("writing msg to tmp pager file")?;

        let mut cmd = process::Command::new(&self.pager_bin);
        cmd.env_clear().envs(shell_env.to_vec());
        cmd.arg(msg_file.path().as_os_str());

        info!("forking pager pty proc");
        let fork = shpool_pty::fork::Fork::from_ptmx().context("forking pty")?;
        if fork.is_child().is_ok() {
            // Close every inherited fd except stdio so the pager cannot pin
            // daemon resources (sockets, pipes) open.
            //
            // BUGFIX: `SysconfVar::OPEN_MAX as i32` casts the enum
            // discriminant (the _SC_OPEN_MAX *selector* constant, a small
            // number), not the actual fd limit, so most fds were left open.
            // Ask sysconf(3) for the real limit instead.
            let max_fd = nix::unistd::sysconf(nix::unistd::SysconfVar::OPEN_MAX)
                .ok()
                .flatten()
                .and_then(|lim| i32::try_from(lim).ok())
                .unwrap_or(1024); // conservative fallback if the limit is unavailable
            for fd in consts::STDERR_FD + 1..max_fd {
                let _ = nix::unistd::close(fd);
            }
            // exec only returns on failure.
            let err = cmd.exec();
            eprintln!("pager exec err: {err:?}");
            std::process::exit(1);
        }

        // Flag flipped by the reaper thread below once the pager has been
        // waited on; polled by the proxy loop and checked by the proc guard.
        let pager_exited = Arc::new(AtomicBool::new(false));
        // Kills the pager (SIGTERM, then SIGKILL) if we unwind early.
        let _proc_guard =
            PagerProcGuard { pager_proc: &fork, pager_exited: Arc::clone(&pager_exited) };
        let pager_exited_ref = Arc::clone(&pager_exited);
        let waitable_child = fork.clone();
        thread::spawn(move || {
            let _s = span!(Level::INFO, "pager_exit_monitor").entered();
            match waitable_child.wait_for_exit() {
                Ok((_, Some(exit_status))) => {
                    info!("child pager exited with status {}", exit_status);
                }
                Ok((_, None)) => {
                    info!("child pager exited without status");
                }
                Err(e) => {
                    info!("error waiting on pager child: {:?}", e);
                }
            }
            // However the wait ended, treat the pager as gone so the proxy
            // loop can wind down.
            pager_exited_ref.store(true, Ordering::Relaxed);
            info!("reaped child pager: {:?}", waitable_child);
        });

        let mut pty_master = fork.is_parent().context("getting pty_master handle")?;
        let pty_master_fd = pty_master.raw_fd();
        init_tty_size.set_fd(pty_master_fd).context("setting init tty size")?;
        // Tracks the most recently applied size; this is what we return.
        let tty_size = Arc::new(Mutex::new(init_tty_size.clone()));
        let tty_size_ref = Arc::clone(&tty_size);
        info!("spawning pager size change listener");
        thread::spawn(move || {
            let _s = span!(Level::INFO, "pager_size_change").entered();
            // Apply each size change to the pty, record it, then ack so the
            // sender (blocked on the rendezvous channel) can proceed.
            while let Ok(size) = tty_size_change_rx.recv() {
                info!("recvd new size: {:?}", size);
                if let Err(e) = size.set_fd(pty_master_fd) {
                    warn!("setting pager size: {:?}", e);
                }
                {
                    let mut tty_size = tty_size_ref.lock();
                    *tty_size = size;
                }
                if let Err(e) = tty_size_change_ack_tx.send(()) {
                    error!("could not send size change ack: {:?}", e);
                    break;
                }
            }
            info!("pager size change loop done");
        });

        let mut last_heartbeat_at = Instant::now();
        let mut buf = vec![0; consts::BUF_SIZE];
        let watchable_master = pty_master.clone();
        let watchable_client_stream =
            client_stream.try_clone().context("could not clone client stream")?;
        // Main proxy loop: shuttle bytes pty <-> client, emit heartbeats when
        // idle, and bail out once the pager process exits.
        loop {
            let mut poll_fds = [
                poll::PollFd::new(watchable_master.borrow_fd(), poll::PollFlags::POLLIN),
                poll::PollFd::new(watchable_client_stream.as_fd(), poll::PollFlags::POLLIN),
            ];
            let nready = poll::poll(&mut poll_fds, POLL_MS).context("polling both streams")?;
            // Checked after poll so an exit is noticed within POLL_MS even
            // when both streams are quiet.
            if pager_exited.load(Ordering::Relaxed) {
                let tty_size = tty_size.lock();
                return Ok(tty_size.clone());
            }
            if nready == 0 {
                // Poll timed out: emit a heartbeat chunk once
                // HEARTBEAT_DURATION has elapsed since the last one.
                let now = Instant::now();
                if now
                    .checked_duration_since(
                        last_heartbeat_at
                            .checked_add(consts::HEARTBEAT_DURATION)
                            .ok_or(anyhow!("could not add to dur"))?,
                    )
                    .is_some()
                {
                    last_heartbeat_at = now;
                    let chunk = Chunk { kind: ChunkKind::Heartbeat, buf: &[] };
                    match chunk.write_to(client_stream).and_then(|_| client_stream.flush()) {
                        Ok(_) => {
                            trace!("wrote heartbeat");
                        }
                        Err(e) if e.kind() == io::ErrorKind::BrokenPipe => {
                            trace!("client hangup writing heartbeat: {:?}", e);
                            return Err(PagerError::ClientHangup)?;
                        }
                        Err(e) => {
                            return Err(e).context("writing heartbeat")?;
                        }
                    }
                }
            } else {
                assert!(nready > 0);
                let pty_master_poll_fd = &poll_fds[0];
                let client_stream_poll_fd = &poll_fds[1];
                // Pager output -> client, framed as data chunks.
                if pty_master_poll_fd.any().unwrap_or(false) {
                    let len = pty_master.read(&mut buf).context("reading chunk from pty master")?;
                    if len == 0 {
                        return Err(anyhow!("EOF from pty while displaying pager"));
                    }
                    let chunk = Chunk { kind: ChunkKind::Data, buf: &buf[..len] };
                    match chunk.write_to(client_stream).and_then(|_| client_stream.flush()) {
                        Ok(_) => {}
                        Err(e) if e.kind() == io::ErrorKind::BrokenPipe => {
                            trace!("client hangup writing data chunk: {:?}", e);
                            return Err(PagerError::ClientHangup)?;
                        }
                        Err(e) => {
                            return Err(e).context("writing data chunk")?;
                        }
                    }
                }
                // Client keystrokes -> pager pty, as raw bytes.
                if client_stream_poll_fd.any().unwrap_or(false) {
                    let len = client_stream.read(&mut buf).context("reading client chunk")?;
                    if len == 0 {
                        info!("EOF");
                        return Err(anyhow!("EOF from client while displaying pager"));
                    }
                    trace!("user input: {}", String::from_utf8_lossy(&buf[..len]));
                    // Write errors here typically just mean the pager raced us
                    // to exit, so treat them as a normal exit.
                    if let Err(e) = pty_master.write_all(&buf[0..len]) {
                        info!("Error writing to pager pty, nbd though: {:?}", e);
                        let tty_size = tty_size.lock();
                        return Ok(tty_size.clone());
                    }
                    if let Err(e) = pty_master.flush() {
                        info!("Error flushing pager pty, nbd though: {:?}", e);
                        let tty_size = tty_size.lock();
                        return Ok(tty_size.clone());
                    }
                }
            }
        }
    }
}
/// Drop guard that deregisters the active `PagerCtl` from the session's
/// shared slot when the pager finishes, on any exit path.
/// NOTE(review): name looks like a typo for `PagerCtlGuard`; left as-is
/// since its construction site depends on this spelling.
struct PagerCltGuard {
    /// The shared slot to clear on drop.
    ctl_slot: Arc<Mutex<Option<PagerCtl>>>,
}
impl std::ops::Drop for PagerCltGuard {
    /// Clear the shared slot so a future pager invocation can register.
    fn drop(&mut self) {
        // Take (and drop) whatever PagerCtl was registered.
        self.ctl_slot.lock().take();
        trace!("deregistered PagerCtl");
    }
}
/// Drop guard that ensures the forked pager process gets killed if we
/// unwind before it has exited on its own.
struct PagerProcGuard<'pager> {
    /// Parent-side handle to the forked pty/pager process.
    pager_proc: &'pager shpool_pty::fork::Fork,
    /// Set to true by the exit-monitor thread once the pager has been reaped.
    pager_exited: Arc<AtomicBool>,
}
impl std::ops::Drop for PagerProcGuard<'_> {
    fn drop(&mut self) {
        // Nothing to clean up if the reaper thread already saw the exit.
        if !self.pager_exited.load(Ordering::Relaxed) {
            if let Err(e) = self.kill() {
                error!("Error cleaning up pager proc: {:?}", e);
            }
        }
    }
}
impl PagerProcGuard<'_> {
    /// Tear the pager down: send SIGTERM, wait with exponential backoff for
    /// the exit monitor to report it gone, then escalate to SIGKILL.
    fn kill(&self) -> anyhow::Result<()> {
        let shpool_pty::fork::Fork::Parent(raw_pid, _) = self.pager_proc else {
            return Err(anyhow!("somehow have a child pty handle in the main proc"));
        };
        let pid = unistd::Pid::from_raw(*raw_pid);
        signal::kill(pid, Some(signal::Signal::SIGTERM))
            .context("sending SIGTERM to pager proc")?;
        // Give it up to ~1.3s total (10ms doubled 7 times) to exit cleanly.
        let mut backoff = Duration::from_millis(10);
        for _ in 0..7 {
            if self.pager_exited.load(Ordering::Relaxed) {
                return Ok(());
            }
            thread::sleep(backoff);
            backoff *= 2;
        }
        signal::kill(pid, Some(signal::Signal::SIGKILL))
            .context("sending SIGKILL to pager proc")?;
        Ok(())
    }
}
/// Pager-specific errors callers may want to match on.
#[derive(Debug, Clone)]
pub enum PagerError {
    /// The client hung up while we were driving the pager.
    ClientHangup,
}

impl std::fmt::Display for PagerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        // The variant names are self-describing, so Display mirrors Debug.
        let name = match self {
            PagerError::ClientHangup => "ClientHangup",
        };
        f.write_str(name)
    }
}

impl std::error::Error for PagerError {}