use std::{
io::{Read, Write},
time,
};
use anyhow::{anyhow, Context};
use nix::{poll, poll::PollFlags};
use tracing::{debug, error, info, instrument, warn};
use crate::{
consts::{SENTINEL_FLAG_VAR, STARTUP_SENTINEL},
daemon::trie::{Trie, TrieCursor},
exe, test_hooks,
};
/// Timeout, in milliseconds, passed to each individual `poll` call while
/// waiting for the startup sentinel.
const SENTINEL_POLL_MS: u16 = 500;
/// Total time allowed for the startup sentinel to show up before giving up.
const SENTINEL_POLL_TIMEOUT: time::Duration = time::Duration::from_millis(90 * 1000);
/// Shells for which a prompt-prefix injection script is available.
///
/// This is a plain fieldless enum, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum KnownShell {
    /// bash (process name ends with `bash`).
    Bash,
    /// zsh (process name ends with `zsh`).
    Zsh,
    /// fish (process name ends with `fish`).
    Fish,
}
#[instrument(skip_all)]
pub fn maybe_inject_prefix(
pty_master: &mut shpool_pty::fork::Fork,
prompt_prefix: &str,
session_name: &str,
) -> anyhow::Result<()> {
let shell_pid = pty_master.child_pid().ok_or(anyhow!("no child pid"))?;
let mut pty_master = pty_master.is_parent().context("expected parent")?;
wait_for_startup(&mut pty_master)?;
let shell_type = sniff_shell(shell_pid);
debug!("sniffed shell type: {:?}", shell_type);
let prompt_prefix = prompt_prefix.replace("$SHPOOL_SESSION_NAME", session_name);
let mut script = match (prompt_prefix.as_str(), shell_type) {
(_, Ok(KnownShell::Bash)) => format!(
r#"
if [[ -z "${{PROMPT_COMMAND+x}}" ]]; then
PS1="{prompt_prefix}${{PS1}}"
else
SHPOOL__OLD_PROMPT_COMMAND=("${{PROMPT_COMMAND[@]}}")
SHPOOL__OLD_PS1="${{PS1}}"
function __shpool__prompt_command() {{
PS1="${{SHPOOL__OLD_PS1}}"
for prompt_hook in "${{SHPOOL__OLD_PROMPT_COMMAND[@]}}"
do
eval "${{prompt_hook}}"
done
PS1="{prompt_prefix}${{PS1}}"
}}
PROMPT_COMMAND=__shpool__prompt_command
fi
"#
),
(_, Ok(KnownShell::Zsh)) => format!(
r#"
typeset -a precmd_functions
SHPOOL__OLD_PROMPT="${{PROMPT}}"
function __shpool__reset_rprompt() {{
PROMPT="${{SHPOOL__OLD_PROMPT}}"
}}
precmd_functions[1,0]=(__shpool__reset_rprompt)
function __shpool__prompt_command() {{
PROMPT="{prompt_prefix}${{PROMPT}}"
}}
precmd_functions+=(__shpool__prompt_command)
"#
),
(_, Ok(KnownShell::Fish)) => format!(
r#"
functions --copy fish_prompt shpool__old_prompt
function fish_prompt; echo -n "{prompt_prefix}"; shpool__old_prompt; end
"#
),
(_, Err(e)) => {
warn!("could not sniff shell: {}", e);
String::new()
}
};
let exe_path =
exe::current().context("getting current exe path")?.to_string_lossy().into_owned();
let sentinel_cmd = format!("\n {}=prompt {} daemon\n", SENTINEL_FLAG_VAR, exe_path);
script.push_str(sentinel_cmd.as_str());
debug!("injecting prefix script '{}'", script);
pty_master.write_all(script.as_bytes()).context("running prefix script")?;
Ok(())
}
#[instrument(skip_all)]
fn wait_for_startup(pty_master: &mut shpool_pty::fork::Master) -> anyhow::Result<()> {
test_hooks::emit("wait-for-startup-enter");
let mut startup_sentinel_scanner = SentinelScanner::new(STARTUP_SENTINEL);
let exe_path =
exe::current().context("getting current exe path")?.to_string_lossy().into_owned();
let startup_sentinel_cmd = format!("\n {}=startup {} daemon\n", SENTINEL_FLAG_VAR, exe_path);
pty_master
.write_all(startup_sentinel_cmd.as_bytes())
.context("running startup sentinel script")?;
let watchable_master = pty_master.clone();
let mut poll_fds = [poll::PollFd::new(
watchable_master.borrow_fd(),
PollFlags::POLLIN | PollFlags::POLLHUP | PollFlags::POLLERR,
)];
let deadline = time::Instant::now() + SENTINEL_POLL_TIMEOUT;
let mut buf: [u8; 2048] = [0; 2048];
loop {
if time::Instant::now() > deadline {
return Err(anyhow!("timed out waiting for shell startup"));
}
let nready = match poll::poll(&mut poll_fds, SENTINEL_POLL_MS) {
Ok(n) => n,
Err(e) => {
error!("polling pty master: {:?}", e);
return Err(e)?;
}
};
if nready == 0 {
continue;
}
if nready != 1 {
return Err(anyhow!("sentinal scan: expected exactly 1 ready fd"));
}
let len = pty_master.read(&mut buf).context("reading chunk to scan for startup")?;
if len == 0 {
return Err(anyhow!("EOF during shell startup"));
}
let buf = &buf[..len];
debug!("buf='{}'", String::from_utf8_lossy(buf));
for byte in buf.iter() {
if startup_sentinel_scanner.transition(*byte) {
return Ok(());
}
}
}
}
/// Identifies the shell running as process `pid` from its process name.
///
/// The name is looked up via libproc. Only its suffix is inspected, so any
/// name ending in `bash`, `zsh` or `fish` maps to the corresponding
/// `KnownShell`. Any other name is reported as an error, as is a failed
/// lookup.
#[instrument(skip_all)]
fn sniff_shell(pid: libc::pid_t) -> anyhow::Result<KnownShell> {
    let proc_name = match libproc::proc_pid::name(pid) {
        Ok(name) => name,
        Err(e) => return Err(anyhow!("determining subproc name: {:?}", e)),
    };
    info!("shell_proc_name: {}", proc_name);

    match proc_name.as_str() {
        n if n.ends_with("bash") => Ok(KnownShell::Bash),
        n if n.ends_with("zsh") => Ok(KnownShell::Zsh),
        n if n.ends_with("fish") => Ok(KnownShell::Fish),
        _ => Err(anyhow!("unknown shell: {:?}", proc_name)),
    }
}
/// Incrementally scans a byte stream for a fixed sentinel string.
///
/// Bytes are fed one at a time through [`SentinelScanner::transition`].
/// Only the first complete occurrence of the sentinel reports `true`. Later
/// occurrences are counted and logged but report `false`.
pub struct SentinelScanner {
    /// Trie holding the single sentinel byte sequence.
    scanner: Trie<u8, (), Vec<Option<usize>>>,
    /// Current position in the trie; `Start` when no match is in progress.
    cursor: TrieCursor,
    /// Number of complete sentinel occurrences seen so far.
    num_matches: usize,
}

impl SentinelScanner {
    /// Builds a scanner that looks for the bytes of `sentinel`.
    pub fn new(sentinel: &str) -> Self {
        let mut scanner = Trie::new();
        scanner.insert(sentinel.bytes(), ());
        SentinelScanner { scanner, cursor: TrieCursor::Start, num_matches: 0 }
    }

    /// Feeds one byte to the scanner.
    ///
    /// Returns `true` only when `byte` completes the first occurrence of the
    /// sentinel seen by this scanner. Otherwise returns `false`.
    pub fn transition(&mut self, byte: u8) -> bool {
        let prev = self.cursor;
        self.cursor = self.scanner.advance(self.cursor, byte);
        if matches!(self.cursor, TrieCursor::NoMatch) && !matches!(prev, TrieCursor::Start) {
            // The byte that broke an in-progress partial match may itself be
            // the first byte of a new occurrence (e.g. sentinel "ab" in the
            // stream "aab"), so retry it from the start instead of dropping
            // it. NOTE: this is not a full KMP failure function. A sentinel
            // whose prefix repeats (e.g. "aab" in "aaab") can still be missed.
            self.cursor = self.scanner.advance(TrieCursor::Start, byte);
        }
        match self.cursor {
            TrieCursor::NoMatch => {
                self.cursor = TrieCursor::Start;
                false
            }
            TrieCursor::Match { is_partial, .. } if !is_partial => {
                self.cursor = TrieCursor::Start;
                self.num_matches += 1;
                debug!("got prompt sentinel match #{}", self.num_matches);
                self.num_matches == 1
            }
            _ => false,
        }
    }
}