use std::borrow::Cow;
use std::collections::HashMap;
use std::ffi::OsStr;
use std::ffi::OsString;
use std::fmt::Debug;
use std::io::Read;
use std::io::Write;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicI32;
use std::sync::atomic::Ordering;
use parking_lot::Mutex;
use soft_canonicalize::soft_canonicalize;
use thiserror::Error;
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
use crate::shell::child_process_tracker::ChildProcessTracker;
use super::commands::ShellCommand;
use super::commands::builtin_commands;
/// Result alias used throughout the shell; the error type defaults to this
/// module's [`Error`].
pub type Result<T, E = Error> = core::result::Result<T, E>;

/// Opaque shell error.
///
/// Transparently wraps an [`anyhow::Error`], so `Display` and `source()`
/// are forwarded to the wrapped error.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct Error(#[from] anyhow::Error);

impl From<std::io::Error> for Error {
  /// Wraps an I/O error so it can be propagated with `?`.
  fn from(err: std::io::Error) -> Self {
    Self(err.into())
  }
}
/// Returns early from the enclosing function with an `Err` holding this
/// module's [`Error`], built with `anyhow::anyhow!`.
///
/// Accepts the same argument forms as `anyhow::anyhow!`: a string literal,
/// a single expression, or a format string followed by arguments. The
/// expansion is a bare `return Err(..)` with no `?`-style conversion, so the
/// enclosing function's error type must be exactly `$crate::Error`.
///
/// NOTE(review): `$crate::Error` only resolves if `Error` is re-exported at
/// the crate root — confirm against the crate's lib.rs.
macro_rules! bail {
// Literal message, e.g. `bail!("something failed")`.
($msg:literal $(,)?) => {
return Err($crate::Error::from(anyhow::anyhow!($msg)))
};
// Single expression (an error value or anything `anyhow!` accepts alone).
($err:expr $(,)?) => {
return Err($crate::Error::from(anyhow::anyhow!($err)))
};
// Format string plus arguments, e.g. `bail!("bad value: {}", v)`.
($fmt:expr, $($arg:tt)*) => {
return Err($crate::Error::from(anyhow::anyhow!($fmt, $($arg)*)))
};
}
// Re-export the macro so other modules in this crate can import it by path.
pub(crate) use bail;
/// Shared cell holding the first non-zero exit code recorded through any
/// of its clones (presumably one per command tree — see its callers).
///
/// All clones share one `Arc<AtomicI32>`; the value `0` means "unset".
#[derive(Debug, Default, Clone)]
pub(crate) struct TreeExitCodeCell(Arc<AtomicI32>);

impl TreeExitCodeCell {
  /// Records `exit_code` unless a non-zero code is already stored.
  ///
  /// An earlier non-zero code is never overwritten, and passing `0` leaves
  /// the cell unchanged.
  pub fn try_set(&self, exit_code: i32) {
    // The exchange only succeeds while the cell still holds 0; a failure
    // just means a code was already recorded, which is the intended outcome.
    let _ = self
      .0
      .compare_exchange(0, exit_code, Ordering::SeqCst, Ordering::SeqCst);
  }

  /// Returns the recorded exit code, or `None` if nothing has been set.
  pub fn get(&self) -> Option<i32> {
    let stored = self.0.load(Ordering::SeqCst);
    (stored != 0).then_some(stored)
  }
}
/// Mutable state of a running shell: variables, working directory, the
/// available commands, and the kill-signal / exit-code plumbing.
///
/// Variable names are normalized before being used as map keys: on Windows
/// they are upper-cased (ASCII only) so lookups are case-insensitive; on
/// other platforms they are used verbatim.
#[derive(Clone)]
pub struct ShellState {
  /// Environment variables, keyed by normalized name. `set_cwd` keeps the
  /// `PWD` entry in sync with `cwd`.
  env_vars: HashMap<OsString, OsString>,
  /// Shell-local variables (including positional parameters `1`, `2`, …),
  /// keyed by normalized name.
  shell_vars: HashMap<OsString, OsString>,
  /// Number of positional parameters passed to [`ShellState::new`].
  positional_param_len: usize,
  /// Current working directory.
  cwd: PathBuf,
  /// Built-in commands merged with custom ones; a custom command replaces a
  /// built-in of the same name.
  commands: Arc<HashMap<String, Arc<dyn ShellCommand>>>,
  kill_signal: KillSignal,
  process_tracker: ChildProcessTracker,
  /// Shared first-non-zero exit code cell; replaced with a fresh one by
  /// [`ShellState::with_child_signal`].
  tree_exit_code_cell: TreeExitCodeCell,
}

impl ShellState {
  /// Creates the initial state for a shell invocation.
  ///
  /// Each entry of `env_vars` is applied via [`ShellState::apply_env_var`].
  /// `cwd` is then set, overriding any `PWD` found in `env_vars`, and each
  /// positional parameter is stored as shell variable `1`, `2`, ….
  ///
  /// # Panics
  ///
  /// Panics if `cwd` is not an absolute path.
  pub fn new(
    positional_params: Vec<OsString>,
    env_vars: HashMap<OsString, OsString>,
    cwd: PathBuf,
    custom_commands: HashMap<String, Arc<dyn ShellCommand>>,
    kill_signal: KillSignal,
  ) -> Self {
    assert!(cwd.is_absolute());
    let mut commands = builtin_commands();
    // `extend` overwrites existing keys, so custom commands win.
    commands.extend(custom_commands);
    let positional_param_len = positional_params.len();
    let mut result = Self {
      env_vars: Default::default(),
      shell_vars: Default::default(),
      positional_param_len,
      cwd: PathBuf::new(),
      commands: Arc::new(commands),
      kill_signal,
      process_tracker: ChildProcessTracker::new(),
      tree_exit_code_cell: Default::default(),
    };
    for (name, value) in env_vars {
      result.apply_env_var(&name, &value);
    }
    result.set_cwd(cwd);
    for (index, param) in positional_params.into_iter().enumerate() {
      result.apply_change(&EnvChange::SetShellVar(
        OsString::from((index + 1).to_string()),
        param,
      ));
    }
    result
  }

  /// Number of positional parameters the shell was created with.
  pub fn positional_param_len(&self) -> usize {
    self.positional_param_len
  }

  /// Current working directory.
  pub fn cwd(&self) -> &PathBuf {
    &self.cwd
  }

  /// Environment variables, keyed by normalized name.
  pub fn env_vars(&self) -> &HashMap<OsString, OsString> {
    &self.env_vars
  }

  /// Looks up a variable, preferring environment variables over shell
  /// variables. On Windows the lookup is ASCII case-insensitive.
  pub fn get_var(&self, name: &OsStr) -> Option<&OsString> {
    let key = Self::normalize_var_name(name);
    let key: &OsStr = &key;
    self
      .env_vars
      .get(key)
      .or_else(|| self.shell_vars.get(key))
  }

  /// Sets the working directory and mirrors it into the `PWD` variable.
  pub fn set_cwd(&mut self, cwd: PathBuf) {
    self.cwd = cwd.clone();
    self.env_vars.insert("PWD".into(), cwd.into_os_string());
  }

  /// Applies `changes` in order.
  pub fn apply_changes(&mut self, changes: &[EnvChange]) {
    for change in changes {
      self.apply_change(change);
    }
  }

  /// Applies a single change to this state.
  pub fn apply_change(&mut self, change: &EnvChange) {
    match change {
      EnvChange::SetEnvVar(name, value) => self.apply_env_var(name, value),
      EnvChange::SetShellVar(name, value) => {
        let key = Self::normalize_var_name(name);
        if self.env_vars.contains_key(&*key) {
          // Assigning to an existing environment variable updates it rather
          // than shadowing it with a shell variable.
          self.apply_env_var(name, value);
        } else {
          self.shell_vars.insert(key.into_owned(), value.to_os_string());
        }
      }
      EnvChange::UnsetVar(name) => {
        let key = Self::normalize_var_name(name);
        let key: &OsStr = &key;
        self.shell_vars.remove(key);
        self.env_vars.remove(key);
      }
      EnvChange::Cd(new_dir) => {
        self.set_cwd(new_dir.clone());
      }
    }
  }

  /// Sets an environment variable, removing any shell variable of the same
  /// (normalized) name.
  ///
  /// `PWD` is special: it changes the working directory instead, and only
  /// when `value` is absolute and can be soft-canonicalized; otherwise the
  /// assignment is ignored.
  pub fn apply_env_var(&mut self, name: &OsStr, value: &OsStr) {
    let name = Self::normalize_var_name(name).into_owned();
    if name == "PWD" {
      let cwd = Path::new(value);
      if cwd.is_absolute()
        && let Ok(cwd) = soft_canonicalize(cwd)
      {
        self.set_cwd(cwd);
      }
    } else {
      self.shell_vars.remove(&name);
      self.env_vars.insert(name, value.to_os_string());
    }
  }

  /// Returns the map key under which variable `name` is stored.
  ///
  /// Windows variable names are case-insensitive, so they are upper-cased
  /// (ASCII only) there; elsewhere the name is used as-is. Every map access
  /// goes through this so stores and lookups always agree.
  fn normalize_var_name(name: &OsStr) -> Cow<'_, OsStr> {
    if cfg!(windows) {
      Cow::Owned(name.to_ascii_uppercase())
    } else {
      Cow::Borrowed(name)
    }
  }

  /// Kill signal governing this state.
  pub fn kill_signal(&self) -> &KillSignal {
    &self.kill_signal
  }

  /// Registers a spawned child process with this state's process tracker.
  pub fn track_child_process(&self, child: &tokio::process::Child) {
    self.process_tracker.track(child);
  }

  pub(crate) fn tree_exit_code_cell(&self) -> &TreeExitCodeCell {
    &self.tree_exit_code_cell
  }

  /// Looks up a built-in or custom command by name. Names that are not
  /// valid UTF-8 never match.
  pub fn resolve_custom_command(
    &self,
    name: &OsStr,
  ) -> Option<Arc<dyn ShellCommand>> {
    name
      .to_str()
      .and_then(|name| self.commands.get(name).cloned())
  }

  /// Resolves `command_name` to an executable path relative to this
  /// state's working directory and environment.
  pub fn resolve_command_path(
    &self,
    command_name: &OsStr,
  ) -> Result<PathBuf, super::which::CommandPathResolutionError> {
    super::which::resolve_command_path(command_name, self.cwd(), self)
  }

  /// Returns a copy of this state with a child kill signal (so signals sent
  /// to `self` still reach it) and a fresh tree exit-code cell.
  pub fn with_child_signal(&self) -> ShellState {
    let mut state = self.clone();
    state.kill_signal = self.kill_signal.child_signal();
    state.tree_exit_code_cell = TreeExitCodeCell::default();
    state
  }
}
impl sys_traits::BaseEnvVar for ShellState {
fn base_env_var_os(&self, key: &OsStr) -> Option<OsString> {
self.env_vars.get(key).cloned()
}
}
/// A change to shell state, applied with `ShellState::apply_change`.
#[derive(Debug, PartialEq, Eq)]
pub enum EnvChange {
/// Set an environment variable (name, value).
SetEnvVar(OsString, OsString),
/// Set a shell variable (name, value); if an environment variable with that
/// name already exists, the environment variable is updated instead.
SetShellVar(OsString, OsString),
/// Remove the named variable from both shell and environment variables.
UnsetVar(OsString),
/// Change the working directory (which also updates `PWD`).
Cd(PathBuf),
}
/// Outcome of executing a command.
///
/// Both variants carry an exit code and a list of `JoinHandle<i32>`s
/// (presumably tasks still running on the command's behalf — confirm with
/// callers). `Exit` presumably requests that the shell stop; `Continue`
/// additionally carries [`EnvChange`]s.
#[derive(Debug)]
pub enum ExecuteResult {
  Exit(i32, Vec<JoinHandle<i32>>),
  Continue(i32, Vec<EnvChange>, Vec<JoinHandle<i32>>),
}

impl ExecuteResult {
  /// Builds a `Continue` result carrying only `exit_code`, with no
  /// environment changes and no handles.
  pub fn from_exit_code(exit_code: i32) -> ExecuteResult {
    Self::Continue(exit_code, Vec::new(), Vec::new())
  }

  /// Splits the result into its exit code and handles, discarding any
  /// environment changes.
  pub fn into_exit_code_and_handles(self) -> (i32, Vec<JoinHandle<i32>>) {
    match self {
      Self::Exit(code, handles) | Self::Continue(code, _, handles) => {
        (code, handles)
      }
    }
  }

  /// Returns only the handles, discarding the exit code and any changes.
  pub fn into_handles(self) -> Vec<JoinHandle<i32>> {
    let (_, handles) = self.into_exit_code_and_handles();
    handles
  }
}
/// Readable end of a shell pipe: an OS pipe or a regular file.
#[derive(Debug)]
pub enum ShellPipeReader {
  OsPipe(std::io::PipeReader),
  StdFile(std::fs::File),
}

impl Clone for ShellPipeReader {
  /// Duplicates the underlying OS handle.
  ///
  /// # Panics
  ///
  /// Panics if the handle cannot be duplicated.
  fn clone(&self) -> Self {
    match self {
      Self::OsPipe(pipe) => Self::OsPipe(pipe.try_clone().unwrap()),
      Self::StdFile(file) => Self::StdFile(file.try_clone().unwrap()),
    }
  }
}

impl ShellPipeReader {
  /// Returns a reader over a duplicate of this process's stdin handle.
  ///
  /// # Panics
  ///
  /// Panics if the stdin handle cannot be duplicated.
  pub fn stdin() -> ShellPipeReader {
    #[cfg(unix)]
    fn dup_stdin_as_pipe_reader() -> std::io::PipeReader {
      use std::os::fd::AsFd;
      use std::os::fd::FromRawFd;
      use std::os::fd::IntoRawFd;
      let owned = std::io::stdin().as_fd().try_clone_to_owned().unwrap();
      let raw = owned.into_raw_fd();
      // SAFETY: `raw` was just released from a freshly duplicated owned
      // descriptor via `into_raw_fd`, so the `PipeReader` is its sole owner.
      unsafe { std::io::PipeReader::from_raw_fd(raw) }
    }
    #[cfg(windows)]
    fn dup_stdin_as_pipe_reader() -> std::io::PipeReader {
      use std::os::windows::io::AsHandle;
      use std::os::windows::io::FromRawHandle;
      use std::os::windows::io::IntoRawHandle;
      let owned = std::io::stdin().as_handle().try_clone_to_owned().unwrap();
      let raw = owned.into_raw_handle();
      // SAFETY: `raw` was just released from a freshly duplicated owned
      // handle via `into_raw_handle`, so the `PipeReader` is its sole owner.
      unsafe { std::io::PipeReader::from_raw_handle(raw) }
    }
    ShellPipeReader::OsPipe(dup_stdin_as_pipe_reader())
  }

  /// Wraps the read end of an OS pipe.
  pub fn from_raw(reader: std::io::PipeReader) -> Self {
    Self::OsPipe(reader)
  }

  /// Wraps an open file.
  pub fn from_std(std_file: std::fs::File) -> Self {
    Self::StdFile(std_file)
  }

  /// Test helper: a reader that yields `data` and then EOF.
  ///
  /// NOTE: `data` is written before any reader exists, so input larger than
  /// the OS pipe buffer would block here.
  #[cfg(test)]
  // Deliberately an inherent constructor rather than `FromStr`: it cannot
  // fail and is only used by tests.
  #[allow(clippy::should_implement_trait)]
  pub fn from_str(data: &str) -> Self {
    use std::io::Write;
    let (read, mut write) = std::io::pipe().unwrap();
    write.write_all(data.as_bytes()).unwrap();
    // `write` is dropped on return, closing the pipe so readers see EOF.
    Self::OsPipe(read)
  }

  /// Converts into a `Stdio` suitable for a child process's stdin.
  pub fn into_stdio(self) -> std::process::Stdio {
    match self {
      Self::OsPipe(pipe) => pipe.into(),
      Self::StdFile(file) => file.into(),
    }
  }

  /// Copies everything until EOF into `writer`.
  pub fn pipe_to(self, writer: &mut dyn Write) -> Result<()> {
    self.pipe_to_inner(writer, false)
  }

  /// Like [`ShellPipeReader::pipe_to`], but flushes `writer` after every
  /// chunk.
  fn pipe_to_with_flushing(self, writer: &mut dyn Write) -> Result<()> {
    self.pipe_to_inner(writer, true)
  }

  /// Copies data chunk by chunk until EOF, optionally flushing after each.
  ///
  /// Reads interrupted by a signal (`ErrorKind::Interrupted`) are retried;
  /// any other read or write error is returned.
  fn pipe_to_inner(
    mut self,
    writer: &mut dyn Write,
    flush: bool,
  ) -> Result<()> {
    let mut buffer = [0; 512];
    loop {
      let read_result = match &mut self {
        ShellPipeReader::OsPipe(pipe) => pipe.read(&mut buffer),
        ShellPipeReader::StdFile(file) => file.read(&mut buffer),
      };
      let size = match read_result {
        Ok(0) => break,
        Ok(size) => size,
        // An interrupted read transferred no data; the `Read` contract says
        // to retry rather than treat it as a failure.
        Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
        Err(err) => return Err(err.into()),
      };
      writer.write_all(&buffer[..size])?;
      if flush {
        writer.flush()?;
      }
    }
    Ok(())
  }

  /// Copies everything into `sender`.
  ///
  /// Standard streams are flushed after every chunk; `Null` returns
  /// immediately without reading.
  pub fn pipe_to_sender(self, mut sender: ShellPipeWriter) -> Result<()> {
    match &mut sender {
      ShellPipeWriter::OsPipe(pipe) => self.pipe_to(pipe),
      ShellPipeWriter::StdFile(file) => self.pipe_to(file),
      ShellPipeWriter::Stdout => {
        self.pipe_to_with_flushing(&mut std::io::stdout())
      }
      ShellPipeWriter::Stderr => {
        self.pipe_to_with_flushing(&mut std::io::stderr())
      }
      ShellPipeWriter::Null => Ok(()),
    }
  }

  /// Reads everything on a blocking thread and returns it as a string,
  /// replacing invalid UTF-8 with U+FFFD.
  ///
  /// # Panics
  ///
  /// The spawned task panics if reading fails.
  pub fn pipe_to_string_handle(self) -> JoinHandle<String> {
    tokio::task::spawn_blocking(|| {
      let mut buf = Vec::new();
      self.pipe_to(&mut buf).unwrap();
      String::from_utf8_lossy(&buf).to_string()
    })
  }

  /// Reads up to `buf.len()` bytes; `Ok(0)` means EOF.
  pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
    match self {
      ShellPipeReader::OsPipe(pipe) => pipe.read(buf).map_err(|e| e.into()),
      ShellPipeReader::StdFile(file) => file.read(buf).map_err(|e| e.into()),
    }
  }
}
/// Writable end of a shell pipe.
///
/// `Stdout` and `Stderr` write to this process's standard streams; `Null`
/// discards everything.
#[derive(Debug)]
pub enum ShellPipeWriter {
  OsPipe(std::io::PipeWriter),
  StdFile(std::fs::File),
  Stdout,
  Stderr,
  Null,
}

impl Clone for ShellPipeWriter {
  /// Duplicates OS handles; the stream and null variants are copied as-is.
  ///
  /// # Panics
  ///
  /// Panics if an OS handle cannot be duplicated.
  fn clone(&self) -> Self {
    match self {
      Self::OsPipe(pipe) => Self::OsPipe(pipe.try_clone().unwrap()),
      Self::StdFile(file) => Self::StdFile(file.try_clone().unwrap()),
      Self::Stdout => Self::Stdout,
      Self::Stderr => Self::Stderr,
      Self::Null => Self::Null,
    }
  }
}

impl ShellPipeWriter {
  /// Writer targeting this process's stdout.
  pub fn stdout() -> Self {
    Self::Stdout
  }

  /// Writer targeting this process's stderr.
  pub fn stderr() -> Self {
    Self::Stderr
  }

  /// Writer that discards everything.
  pub fn null() -> Self {
    Self::Null
  }

  /// Wraps an open file.
  pub fn from_std(std_file: std::fs::File) -> Self {
    Self::StdFile(std_file)
  }

  /// Converts into a `Stdio` for a child process; both standard streams map
  /// to `inherit`.
  pub fn into_stdio(self) -> std::process::Stdio {
    match self {
      Self::OsPipe(pipe) => pipe.into(),
      Self::StdFile(file) => file.into(),
      Self::Stdout | Self::Stderr => std::process::Stdio::inherit(),
      Self::Null => std::process::Stdio::null(),
    }
  }

  /// Writes all of `bytes`.
  pub fn write_all(&mut self, bytes: &[u8]) -> Result<()> {
    self.write_all_iter(std::iter::once(bytes))
  }

  /// Writes every chunk yielded by `iter`, stopping at the first error.
  ///
  /// Standard streams are locked for the whole write and flushed at the end;
  /// `Null` does not consume `iter`.
  pub fn write_all_iter<'a>(
    &mut self,
    iter: impl Iterator<Item = &'a [u8]> + 'a,
  ) -> Result<()> {
    /// Writes each chunk in turn to `writer`, stopping at the first error.
    fn write_chunks<'c>(
      writer: &mut impl Write,
      mut chunks: impl Iterator<Item = &'c [u8]>,
    ) -> std::io::Result<()> {
      chunks.try_for_each(|chunk| writer.write_all(chunk))
    }

    match self {
      Self::OsPipe(pipe) => write_chunks(pipe, iter)?,
      Self::StdFile(file) => write_chunks(file, iter)?,
      Self::Stdout => {
        let mut out = std::io::stdout().lock();
        write_chunks(&mut out, iter)?;
        out.flush()?;
      }
      Self::Stderr => {
        let mut err_out = std::io::stderr().lock();
        write_chunks(&mut err_out, iter)?;
        err_out.flush()?;
      }
      Self::Null => {}
    }
    Ok(())
  }

  /// Writes `line` followed by `\n` as a single buffer.
  pub fn write_line(&mut self, line: &str) -> Result<()> {
    let mut text = String::with_capacity(line.len() + 1);
    text.push_str(line);
    text.push('\n');
    self.write_all(text.as_bytes())
  }
}
/// Creates an anonymous OS pipe and returns its read and write ends.
///
/// # Panics
///
/// Panics if the operating system cannot create the pipe.
pub fn pipe() -> (ShellPipeReader, ShellPipeWriter) {
  let (read_end, write_end) = std::io::pipe().unwrap();
  let reader = ShellPipeReader::OsPipe(read_end);
  let writer = ShellPipeWriter::OsPipe(write_end);
  (reader, writer)
}
/// Shared state behind a `KillSignal`.
///
/// It holds the first abort code recorded, the broadcast channel that
/// waiters subscribe to, and weak links to child nodes that receive
/// everything sent to this one.
#[derive(Debug)]
struct KillSignalInner {
  aborted_code: Mutex<Option<i32>>,
  sender: broadcast::Sender<SignalKind>,
  children: Mutex<Vec<std::sync::Weak<KillSignalInner>>>,
}

impl KillSignalInner {
  /// Delivers `signal_kind` to this node's subscribers and, recursively, to
  /// every live child; children that have been dropped are pruned.
  ///
  /// For aborting signals the abort code is recorded only if none was
  /// recorded before.
  pub fn send(&self, signal_kind: SignalKind) {
    if signal_kind.causes_abort() {
      let mut recorded = self.aborted_code.lock();
      recorded.get_or_insert_with(|| signal_kind.aborted_code());
    }
    // `send` errors only when nobody is subscribed, which is fine here.
    let _ = self.sender.send(signal_kind);
    self
      .children
      .lock()
      .retain(|weak_child| match weak_child.upgrade() {
        Some(child) => {
          child.send(signal_kind);
          true
        }
        None => false,
      });
  }
}
#[derive(Debug, Clone)]
pub struct KillSignal(Arc<KillSignalInner>);
impl Default for KillSignal {
fn default() -> Self {
let (sender, _) = broadcast::channel(100);
Self(Arc::new(KillSignalInner {
aborted_code: Mutex::new(None),
sender,
children: Mutex::new(Vec::new()),
}))
}
}
impl KillSignal {
pub fn aborted_code(&self) -> Option<i32> {
*self.0.aborted_code.lock()
}
pub fn child_signal(&self) -> Self {
let (sender, _) = broadcast::channel(100);
let child = Arc::new(KillSignalInner {
aborted_code: Mutex::new(self.aborted_code()),
sender,
children: Mutex::new(Vec::new()),
});
self.0.children.lock().push(Arc::downgrade(&child));
Self(child)
}
pub fn drop_guard(self) -> KillSignalDropGuard {
self.drop_guard_with_kind(SignalKind::SIGTERM)
}
pub fn drop_guard_with_kind(self, kind: SignalKind) -> KillSignalDropGuard {
KillSignalDropGuard {
disarmed: std::sync::atomic::AtomicBool::new(false),
kill_signal_kind: kind,
signal: self,
}
}
pub fn send(&self, signal: SignalKind) {
self.0.send(signal)
}
pub async fn wait_aborted(&self) -> SignalKind {
let mut receiver = self.0.sender.subscribe();
loop {
let signal = receiver.recv().await.unwrap();
if signal.causes_abort() {
return signal;
}
}
}
pub async fn wait_any(&self) -> SignalKind {
let mut receiver = self.0.sender.subscribe();
receiver.recv().await.unwrap()
}
}
#[derive(Debug)]
pub struct KillSignalDropGuard {
disarmed: std::sync::atomic::AtomicBool,
kill_signal_kind: SignalKind,
signal: KillSignal,
}
impl Drop for KillSignalDropGuard {
fn drop(&mut self) {
if !self.disarmed.load(std::sync::atomic::Ordering::SeqCst) {
self.signal.send(self.kill_signal_kind);
}
}
}
impl KillSignalDropGuard {
pub fn disarm(&self) {
self
.disarmed
.store(true, std::sync::atomic::Ordering::SeqCst);
}
}
/// A signal delivered through a `KillSignal`.
///
/// Every named variant aborts execution; `Other` carries any other raw
/// signal number.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SignalKind {
  SIGTERM,
  SIGKILL,
  SIGABRT,
  SIGQUIT,
  SIGINT,
  SIGSTOP,
  Other(i32),
}

impl SignalKind {
  /// Returns `true` for every named signal and `false` for `Other`.
  pub fn causes_abort(&self) -> bool {
    match self {
      SignalKind::Other(_) => false,
      SignalKind::SIGABRT
      | SignalKind::SIGINT
      | SignalKind::SIGKILL
      | SignalKind::SIGQUIT
      | SignalKind::SIGSTOP
      | SignalKind::SIGTERM => true,
    }
  }

  /// Shell-style exit code for termination by this signal:
  /// `128 + signal number`.
  pub fn aborted_code(&self) -> i32 {
    128 + i32::from(*self)
  }
}
impl From<i32> for SignalKind {
fn from(value: i32) -> Self {
#[cfg(unix)]
match value {
nix::libc::SIGINT => SignalKind::SIGINT,
nix::libc::SIGQUIT => SignalKind::SIGQUIT,
nix::libc::SIGABRT => SignalKind::SIGABRT,
nix::libc::SIGKILL => SignalKind::SIGKILL,
nix::libc::SIGTERM => SignalKind::SIGTERM,
nix::libc::SIGSTOP => SignalKind::SIGSTOP,
_ => SignalKind::Other(value),
}
#[cfg(not(unix))]
match value {
2 => SignalKind::SIGINT,
3 => SignalKind::SIGQUIT,
6 => SignalKind::SIGABRT,
9 => SignalKind::SIGKILL,
15 => SignalKind::SIGTERM,
19 => SignalKind::SIGSTOP,
_ => SignalKind::Other(value),
}
}
}
impl From<SignalKind> for i32 {
fn from(kind: SignalKind) -> i32 {
#[cfg(unix)]
match kind {
SignalKind::SIGINT => nix::libc::SIGINT,
SignalKind::SIGQUIT => nix::libc::SIGQUIT,
SignalKind::SIGABRT => nix::libc::SIGABRT,
SignalKind::SIGKILL => nix::libc::SIGKILL,
SignalKind::SIGTERM => nix::libc::SIGTERM,
SignalKind::SIGSTOP => nix::libc::SIGSTOP,
SignalKind::Other(value) => value,
}
#[cfg(not(unix))]
match kind {
SignalKind::SIGINT => 2,
SignalKind::SIGQUIT => 3,
SignalKind::SIGABRT => 6,
SignalKind::SIGKILL => 9,
SignalKind::SIGTERM => 15,
SignalKind::SIGSTOP => 19,
SignalKind::Other(value) => value,
}
}
}
// Unit tests for the KillSignal tree.
//
// Several tests spawn a sender task and then await a `wait_*` call. They
// rely on `#[tokio::test]`'s default single-threaded runtime: the spawned
// task only runs once the test task yields inside `recv`, by which point
// the receiver has already subscribed.
#[cfg(test)]
mod test {
use crate::KillSignal;
use crate::SignalKind;
// A signal sent on a handle is observed by `wait_any` on a clone of it.
#[tokio::test]
async fn test_send_and_wait_any() {
let kill_signal = KillSignal::default();
tokio::task::spawn({
let kill_signal = kill_signal.clone();
async move {
kill_signal.send(SignalKind::SIGTERM);
}
});
let signal = kill_signal.wait_any().await;
assert_eq!(signal, SignalKind::SIGTERM);
}
// Sending on the root reaches children, siblings and grandchildren, and
// records the abort code (128 + 9 for SIGKILL) on each of them.
#[tokio::test]
async fn test_signal_propagation_to_child_and_grandchild() {
let parent_signal = KillSignal::default();
let child_signal = parent_signal.child_signal();
let sibling_signal = parent_signal.child_signal();
let grandchild_signal = child_signal.child_signal();
let parent = parent_signal.clone();
tokio::task::spawn(async move {
parent.send(SignalKind::SIGKILL);
});
let signals = futures::join!(
child_signal.wait_any(),
sibling_signal.wait_any(),
grandchild_signal.wait_any()
);
for signal in [signals.0, signals.1, signals.2].into_iter() {
assert_eq!(signal, SignalKind::SIGKILL);
}
assert_eq!(child_signal.aborted_code(), Some(128 + 9));
assert_eq!(sibling_signal.aborted_code(), Some(128 + 9));
assert_eq!(grandchild_signal.aborted_code(), Some(128 + 9));
}
// Sending on an inner node affects only that node's subtree: the parent
// and the sibling stay un-aborted.
#[tokio::test]
async fn test_signal_propagation_on_sub_tree() {
let parent_signal = KillSignal::default();
let child_signal = parent_signal.child_signal();
let sibling_signal = parent_signal.child_signal();
let grandchild_signal = child_signal.child_signal();
let grandchild2_signal = child_signal.child_signal();
child_signal.send(SignalKind::SIGABRT);
assert!(parent_signal.aborted_code().is_none());
assert!(sibling_signal.aborted_code().is_none());
assert!(child_signal.aborted_code().is_some());
assert!(grandchild_signal.aborted_code().is_some());
assert!(grandchild2_signal.aborted_code().is_some());
}
// `wait_aborted` returns the aborting signal, and the abort code is set.
#[tokio::test]
async fn test_wait_aborted() {
let kill_signal = KillSignal::default();
tokio::task::spawn({
let kill_signal = kill_signal.clone();
async move {
kill_signal.send(SignalKind::SIGABRT);
}
});
let signal = kill_signal.wait_aborted().await;
assert_eq!(signal, SignalKind::SIGABRT);
assert!(kill_signal.aborted_code().is_some());
}
// Abort codes start unset and, after a parent send, are set (128 + 3 for
// SIGQUIT) on both the parent and the child.
#[tokio::test]
async fn test_propagation_and_is_aborted_flag() {
let parent_signal = KillSignal::default();
let child_signal = parent_signal.child_signal();
assert!(parent_signal.aborted_code().is_none());
assert!(child_signal.aborted_code().is_none());
tokio::task::spawn({
let parent_signal = parent_signal.clone();
async move {
parent_signal.send(SignalKind::SIGQUIT);
}
});
let signal = child_signal.wait_aborted().await;
assert_eq!(signal, SignalKind::SIGQUIT);
assert_eq!(parent_signal.aborted_code(), Some(128 + 3));
assert_eq!(child_signal.aborted_code(), Some(128 + 3));
}
// Sending on a parent whose child has already been dropped still works; the
// dead weak reference is skipped (and pruned) during the send.
#[tokio::test]
async fn test_dropped_child_signal_cleanup() {
let parent_signal = KillSignal::default();
{
let child_signal = parent_signal.child_signal();
assert!(child_signal.aborted_code().is_none());
}
tokio::task::spawn({
let parent_signal = parent_signal.clone();
async move {
parent_signal.send(SignalKind::SIGTERM);
}
});
let signal = parent_signal.wait_any().await;
assert_eq!(signal, SignalKind::SIGTERM);
}
// A disarmed guard sends nothing on drop; an armed guard sends its default
// kind (SIGTERM), which records the matching abort code.
#[tokio::test]
async fn test_drop_guard() {
let parent_signal = KillSignal::default();
{
let drop_guard = parent_signal.clone().drop_guard();
drop_guard.disarm();
}
assert_eq!(parent_signal.aborted_code(), None);
{
let drop_guard = parent_signal.clone().drop_guard();
drop(drop_guard);
}
assert_eq!(
parent_signal.aborted_code(),
Some(SignalKind::SIGTERM.aborted_code())
);
}
}