use std::ffi::{OsStr, OsString};
use std::io::{ErrorKind, Read, Write};
use std::path::PathBuf;
use std::process::{Command, Stdio};
use std::sync::OnceLock;
use std::time::Instant;
use wait_timeout::ChildExt;
use crate::git::{GitError, WorktrunkError};
use crate::sync::Semaphore;
/// Process-wide limiter on concurrently running subprocesses; created lazily by
/// `semaphore()` with the limit from `max_concurrent_commands()`.
static CMD_SEMAPHORE: OnceLock<Semaphore> = OnceLock::new();
/// Working directory captured at startup (`None` if it could not be read). Relative
/// inherited `GIT_*` path variables are re-anchored against it.
static STARTUP_CWD: OnceLock<Option<PathBuf>> = OnceLock::new();
/// Inherited git environment variables that hold filesystem paths. A relative value is
/// rewritten to `<startup cwd>/<value>` before being passed to children (see
/// `compute_git_env_overrides`); absolute values are passed through untouched.
const INHERITED_GIT_PATH_VARS: &[&str] = &[
"GIT_DIR",
"GIT_WORK_TREE",
"GIT_COMMON_DIR",
"GIT_INDEX_FILE",
"GIT_OBJECT_DIRECTORY",
];
/// Snapshots the current working directory as the "startup cwd".
///
/// Call this early, before anything changes the process cwd. Only the first capture
/// (here or via a lazy lookup) is kept; later calls do nothing.
pub fn init_startup_cwd() {
    let _ = STARTUP_CWD.get_or_init(|| std::env::current_dir().ok());
}
/// Returns the cached startup working directory.
///
/// The directory is captured on first use if `init_startup_cwd` was never called.
/// Returns `None` when it could not be determined.
fn startup_cwd() -> Option<&'static PathBuf> {
    let cached = STARTUP_CWD.get_or_init(|| std::env::current_dir().ok());
    cached.as_ref()
}
fn compute_git_env_overrides<F>(base: &std::path::Path, lookup: F) -> Vec<(&'static str, OsString)>
where
F: Fn(&str) -> Option<OsString>,
{
let mut overrides = Vec::new();
for var in INHERITED_GIT_PATH_VARS {
let Some(value) = lookup(var) else {
continue;
};
let path = std::path::Path::new(&value);
if path.is_absolute() {
continue;
}
overrides.push((*var, base.join(path).into_os_string()));
}
overrides
}
/// Relative `GIT_*` path overrides, computed once per process.
static GIT_ENV_OVERRIDES: OnceLock<Vec<(&'static str, OsString)>> = OnceLock::new();
/// Returns the cached `GIT_*` path overrides applied to every child command.
///
/// The overrides are relative-path fixups anchored at the startup cwd, read from this
/// process's environment. The list is empty when the startup cwd is unknown.
fn inherited_git_env_overrides() -> &'static [(&'static str, OsString)] {
    GIT_ENV_OVERRIDES.get_or_init(|| match startup_cwd() {
        Some(cwd) => compute_git_env_overrides(cwd, |var| std::env::var_os(var)),
        None => Vec::new(),
    })
}
/// Concurrency limit used when `WORKTRUNK_MAX_CONCURRENT_COMMANDS` is unset or invalid.
const DEFAULT_CONCURRENT_COMMANDS: usize = 32;
/// Parses a subprocess concurrency limit.
///
/// Surrounding whitespace is ignored, so values such as `" 8"` or `"8\n"` (common when
/// the value comes from shell quoting or a file) are accepted rather than silently
/// falling back to the default.
///
/// Returns:
/// - `Some(n)` for a positive integer;
/// - `Some(usize::MAX)` for `0`, meaning "unlimited";
/// - `None` for anything that is not a non-negative integer.
fn parse_concurrent_limit(value: &str) -> Option<usize> {
    match value.trim().parse::<usize>().ok()? {
        0 => Some(usize::MAX),
        n => Some(n),
    }
}
/// Reads the subprocess concurrency limit from `WORKTRUNK_MAX_CONCURRENT_COMMANDS`.
///
/// Falls back to `DEFAULT_CONCURRENT_COMMANDS` when the variable is unset, not valid
/// UTF-8, or not parseable by `parse_concurrent_limit`.
fn max_concurrent_commands() -> usize {
    match std::env::var("WORKTRUNK_MAX_CONCURRENT_COMMANDS") {
        Ok(raw) => parse_concurrent_limit(&raw).unwrap_or(DEFAULT_CONCURRENT_COMMANDS),
        Err(_) => DEFAULT_CONCURRENT_COMMANDS,
    }
}
/// Returns the process-wide subprocess limiter.
///
/// The limiter is created on first use, sized by `max_concurrent_commands()`.
fn semaphore() -> &'static Semaphore {
    CMD_SEMAPHORE.get_or_init(|| {
        let limit = max_concurrent_commands();
        Semaphore::new(limit)
    })
}
/// Lazily detected shell for `Cmd::shell` commands. A detection failure is cached as its
/// error message and re-reported by `ShellConfig::get`.
static SHELL_CONFIG: OnceLock<Result<ShellConfig, String>> = OnceLock::new();
/// How to invoke the shell that interprets `Cmd::shell` command strings.
#[derive(Debug, Clone)]
pub struct ShellConfig {
/// Shell executable to spawn.
pub executable: PathBuf,
/// Arguments placed before the command string (e.g. `-c`).
pub args: Vec<String>,
/// Whether the shell accepts POSIX `sh` syntax.
pub is_posix: bool,
/// Human-readable shell name, shown in debug logs (e.g. "sh", "Git Bash").
pub name: String,
}
impl ShellConfig {
    /// Returns this process's shell configuration, detecting it on the first call.
    ///
    /// # Errors
    /// If detection failed, every call fails with the cached detection message.
    pub fn get() -> anyhow::Result<&'static ShellConfig> {
        match SHELL_CONFIG.get_or_init(detect_shell) {
            Ok(config) => Ok(config),
            Err(message) => Err(anyhow::anyhow!("{message}")),
        }
    }

    /// Builds a `Command` that runs `shell_command` through this shell.
    ///
    /// The result is `<executable> <args...> <shell_command>`.
    pub fn command(&self, shell_command: &str) -> Command {
        let mut invocation = Command::new(&self.executable);
        invocation.args(&self.args).arg(shell_command);
        invocation
    }

    /// Whether command strings may use POSIX shell syntax.
    pub fn is_posix(&self) -> bool {
        self.is_posix
    }
}
/// Chooses the shell used to interpret `Cmd::shell` command strings.
///
/// - Unix: POSIX `sh -c`.
/// - Windows: Git Bash, located by `detect_windows_shell`. Errors if Git for Windows is
///   missing.
/// - Any other target: no shell is supported, so an error is returned. Previously the
///   body had no tail expression on such targets and failed to compile.
fn detect_shell() -> Result<ShellConfig, String> {
    #[cfg(unix)]
    {
        Ok(ShellConfig {
            executable: PathBuf::from("sh"),
            args: vec!["-c".to_string()],
            is_posix: true,
            name: "sh".to_string(),
        })
    }
    #[cfg(windows)]
    {
        detect_windows_shell()
    }
    #[cfg(not(any(unix, windows)))]
    {
        Err("No supported shell is available on this platform \
             (expected `sh` on Unix or Git Bash on Windows)."
            .to_string())
    }
}
/// Windows shell detection.
///
/// Uses Git Bash (`bash.exe -c`) from a Git for Windows install. Errors with install
/// instructions if none is found.
#[cfg(windows)]
fn detect_windows_shell() -> Result<ShellConfig, String> {
    let Some(bash_path) = find_git_bash() else {
        return Err("Git for Windows is required but not found.\n\
             Install from https://git-scm.com/download/win"
            .to_string());
    };
    Ok(ShellConfig {
        executable: bash_path,
        args: vec!["-c".to_string()],
        is_posix: true,
        name: "Git Bash".to_string(),
    })
}
/// Locates Git Bash's `bash.exe`, checking in order:
///
/// 1. relative to the `git` found on `PATH`: `<git exe>/../../bin/bash.exe`, then
///    `<git exe>/../../usr/bin/bash.exe`;
/// 2. the default system install, `C:\Program Files\Git\bin\bash.exe`;
/// 3. the per-user install under `%LOCALAPPDATA%\Programs\Git\bin\bash.exe`.
///
/// Returns the first candidate that exists, or `None`.
#[cfg(windows)]
fn find_git_bash() -> Option<PathBuf> {
    let beside_path_git = which::which("git").ok().and_then(|git_exe| {
        let install_root = git_exe.parent()?.parent()?;
        [
            install_root.join("bin").join("bash.exe"),
            install_root.join("usr").join("bin").join("bash.exe"),
        ]
        .into_iter()
        .find(|candidate| candidate.exists())
    });
    if beside_path_git.is_some() {
        return beside_path_git;
    }

    let system_install = PathBuf::from(r"C:\Program Files\Git\bin\bash.exe");
    if system_install.exists() {
        return Some(system_install);
    }

    let local_app_data = std::env::var("LOCALAPPDATA").ok()?;
    let user_install = PathBuf::from(local_app_data)
        .join("Programs")
        .join("Git")
        .join("bin")
        .join("bash.exe");
    user_install.exists().then_some(user_install)
}
/// Env var naming the file a child writes a `cd` directive to.
pub const DIRECTIVE_CD_FILE_ENV_VAR: &str = "WORKTRUNK_DIRECTIVE_CD_FILE";
/// Env var naming the file a child writes an exec directive to.
pub const DIRECTIVE_EXEC_FILE_ENV_VAR: &str = "WORKTRUNK_DIRECTIVE_EXEC_FILE";
/// Env var naming the legacy combined directive file.
pub const DIRECTIVE_FILE_ENV_VAR: &str = "WORKTRUNK_DIRECTIVE_FILE";
/// Removes every directive-file variable from `cmd`'s environment.
///
/// This keeps a child from inheriting (and writing to) this process's directive files.
pub fn scrub_directive_env_vars(cmd: &mut std::process::Command) {
    let directive_vars = [
        DIRECTIVE_CD_FILE_ENV_VAR,
        DIRECTIVE_EXEC_FILE_ENV_VAR,
        DIRECTIVE_FILE_ENV_VAR,
    ];
    for var in directive_vars {
        cmd.env_remove(var);
    }
}
use std::cell::Cell;
use std::time::Duration;
thread_local! {
    /// Default timeout for `Cmd::run` on this thread, used when the command sets none.
    static COMMAND_TIMEOUT: Cell<Option<Duration>> = const { Cell::new(None) };
}
/// Sets (or clears with `None`) the default `Cmd::run` timeout for the calling thread.
pub fn set_command_timeout(timeout: Option<Duration>) {
    COMMAND_TIMEOUT.set(timeout);
}
pub fn trace_instant(event: &str) {
crate::trace::emit::instant(event);
}
/// Lines per stream echoed to the terminal log target before the rest is elided.
const LOG_OUTPUT_MAX_LINES: usize = 200;
/// Byte threshold per stream (64 KiB) after which terminal output is elided.
const LOG_OUTPUT_MAX_BYTES: usize = 64 * 1024;
/// Log target that receives complete subprocess output.
pub const SUBPROCESS_FULL_TARGET: &str = "worktrunk::subprocess_full";
/// Log target that receives subprocess output bounded by the caps above.
pub const SUBPROCESS_TERMINAL_TARGET: &str = "worktrunk::subprocess_terminal";
/// Whether full output is being captured to `output.log`. This selects the hint shown in
/// elision markers.
static OUTPUT_LOG_AVAILABLE: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);
/// Records whether full subprocess output is available in `output.log`.
pub fn set_output_log_available(yes: bool) {
    use std::sync::atomic::Ordering::Relaxed;
    OUTPUT_LOG_AVAILABLE.store(yes, Relaxed);
}
/// Writes a finished command's stdout/stderr to the subprocess debug log targets.
///
/// Nothing is formatted unless debug logging is enabled. Output goes out in two passes:
/// 1. full stdout, then full stderr, to `SUBPROCESS_FULL_TARGET`;
/// 2. bounded stdout, then bounded stderr, to `SUBPROCESS_TERMINAL_TARGET`.
fn log_output(output: &std::process::Output) {
    if !log::log_enabled!(log::Level::Debug) {
        return;
    }
    let streams = [(&output.stdout, " "), (&output.stderr, " ! ")];
    for (bytes, prefix) in streams {
        for line in format_stream_full(bytes, prefix) {
            log::debug!(target: SUBPROCESS_FULL_TARGET, "{}", line);
        }
    }
    for (bytes, prefix) in streams {
        for line in format_stream_bounded(bytes, prefix) {
            log::debug!(target: SUBPROCESS_TERMINAL_TARGET, "{}", line);
        }
    }
}
/// Splits `bytes` into lines and prefixes each with `prefix`.
///
/// Bytes are decoded lossily as UTF-8. Empty input yields no lines.
fn format_stream_full(bytes: &[u8], prefix: &str) -> Vec<String> {
    let text = String::from_utf8_lossy(bytes);
    text.lines().map(|line| [prefix, line].concat()).collect()
}
/// Like `format_stream_full`, but stops once `LOG_OUTPUT_MAX_LINES` lines have been
/// emitted or about `LOG_OUTPUT_MAX_BYTES` bytes have been consumed.
///
/// The byte cap is checked before each line, so one oversized line is still emitted in
/// full. Once a cap is hit, a single marker line reports how many lines and bytes were
/// left out, plus a hint on where to find the full output.
fn format_stream_bounded(bytes: &[u8], prefix: &str) -> Vec<String> {
    if bytes.is_empty() {
        return Vec::new();
    }
    let text = String::from_utf8_lossy(bytes);
    let mut remaining = text.lines();
    let mut formatted = Vec::new();
    // Approximate count of input bytes already shown (each line plus its '\n').
    let mut consumed = 0usize;

    while let Some(line) = remaining.next() {
        let capped = formatted.len() >= LOG_OUTPUT_MAX_LINES || consumed >= LOG_OUTPUT_MAX_BYTES;
        if !capped {
            formatted.push(format!("{prefix}{line}"));
            consumed += line.len() + 1;
            continue;
        }

        // `line` itself is the first elided line.
        let hidden_lines = remaining.by_ref().count() + 1;
        let hidden_bytes = bytes.len().saturating_sub(consumed);
        let hint = match OUTPUT_LOG_AVAILABLE.load(std::sync::atomic::Ordering::Relaxed) {
            true => "full output in output.log",
            false => "rerun with -vv for full output",
        };
        formatted.push(format!(
            "{prefix}… ({hidden_lines} more lines, {hidden_bytes} bytes elided — {hint})"
        ));
        break;
    }
    formatted
}
/// Reports a finished command to the tracer.
///
/// On success, also logs its captured output. On an I/O error (spawn, wait, timeout, …),
/// records the error instead.
fn log_command_result(
    context: Option<&str>,
    cmd_str: &str,
    ts: u64,
    tid: u64,
    dur_us: u64,
    result: &std::io::Result<std::process::Output>,
) {
    let output = match result {
        Err(err) => {
            crate::trace::emit::command_errored(context, cmd_str, ts, tid, dur_us, err);
            return;
        }
        Ok(output) => output,
    };
    let succeeded = output.status.success();
    crate::trace::emit::command_completed(context, cmd_str, ts, tid, dur_us, succeeded);
    log_output(output);
}
/// Runs `cmd` with stdin closed and stdout/stderr captured, giving up after `timeout`.
///
/// Both pipes are drained on background threads while waiting, so a chatty child cannot
/// stall on a full pipe. On timeout, or if waiting itself fails, the child is killed and
/// reaped before returning.
///
/// # Errors
/// - Spawn, wait, and pipe-read errors are returned as-is.
/// - `ErrorKind::TimedOut` is returned if the child is still running after `timeout`.
fn run_with_timeout_impl(
    cmd: &mut Command,
    timeout: std::time::Duration,
) -> std::io::Result<std::process::Output> {
    /// Reads `pipe` to EOF on a detached thread and yields the collected bytes.
    fn drain<R: Read + Send + 'static>(
        pipe: Option<R>,
    ) -> std::thread::JoinHandle<std::io::Result<Vec<u8>>> {
        std::thread::spawn(move || {
            let mut buf = Vec::new();
            if let Some(mut pipe) = pipe {
                pipe.read_to_end(&mut buf)?;
            }
            Ok(buf)
        })
    }

    /// Joins a reader thread, re-raising its panic if it had one.
    fn collect(
        reader: std::thread::JoinHandle<std::io::Result<Vec<u8>>>,
    ) -> std::io::Result<Vec<u8>> {
        reader
            .join()
            .unwrap_or_else(|panic| std::panic::resume_unwind(panic))
    }

    let mut child = cmd
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()?;

    // Plain (non-scoped) threads, so the timeout path can return without joining them.
    // A grandchild that inherited the pipes can keep them open after the child is
    // killed. A scoped reader would then block this function until the grandchild
    // exited, defeating the timeout. Detached readers finish once the pipes close.
    let stdout_reader = drain(child.stdout.take());
    let stderr_reader = drain(child.stderr.take());

    match child.wait_timeout(timeout) {
        Ok(Some(status)) => {
            let stdout = collect(stdout_reader)?;
            let stderr = collect(stderr_reader)?;
            Ok(std::process::Output {
                status,
                stdout,
                stderr,
            })
        }
        Ok(None) => {
            let _ = child.kill();
            let _ = child.wait();
            Err(std::io::Error::new(
                ErrorKind::TimedOut,
                "command timed out",
            ))
        }
        Err(e) => {
            // Don't leave the child running (or the readers blocked) when waiting failed.
            let _ = child.kill();
            let _ = child.wait();
            Err(e)
        }
    }
}
/// Builder for a subprocess invocation.
///
/// Every execution mode (`run`, `pipe_into`, `stream`) applies the shared settings:
/// working directory, inherited `GIT_*` path fixups, caller env edits, directive-var
/// scrubbing, and debug/trace logging. `run` and `pipe_into` also wait on the
/// process-wide command semaphore.
pub struct Cmd {
/// Program to execute, or the whole command string for `Cmd::shell` commands.
program: String,
/// Arguments passed directly to `program`; shell commands must not use them.
args: Vec<String>,
/// Working directory for the child, if overridden.
current_dir: Option<std::path::PathBuf>,
/// Optional label appended to debug log lines and passed to trace events.
context: Option<String>,
/// Bytes written to the child's stdin after it is spawned.
stdin_data: Option<Vec<u8>>,
/// Timeout for `run`; when unset, `run` falls back to the thread-local default from
/// `set_command_timeout`.
timeout: Option<std::time::Duration>,
/// Extra environment variables, applied after the inherited `GIT_*` path overrides.
envs: Vec<(OsString, OsString)>,
/// Environment variables removed from the child, applied after `envs`.
env_removes: Vec<OsString>,
/// Whether `program` is interpreted by the detected shell (`Cmd::shell`).
shell_wrap: bool,
/// stdout override for `stream` (default: inherit).
stdout_cfg: Option<std::process::Stdio>,
/// stdin override for `stream` when no `stdin_data` is set (default: null).
stdin_cfg: Option<std::process::Stdio>,
/// Unix `stream` only: run the child in its own process group and forward
/// SIGINT/SIGTERM to it.
forward_signals: bool,
/// When set, the execution is recorded through `crate::command_log::log_command`.
external_label: Option<String>,
/// Exported to the child as `DIRECTIVE_CD_FILE_ENV_VAR`, by `stream` only.
directive_cd_file: Option<std::path::PathBuf>,
/// Exported to the child as `DIRECTIVE_FILE_ENV_VAR`, by `stream` only.
directive_legacy_file: Option<std::path::PathBuf>,
}
/// Records an execution in the command log, but only when a label was supplied.
///
/// Unlabelled commands are neither timed nor recorded.
struct ExternalCommandLog {
    label: Option<String>,
    cmd_str: String,
    started_at: Option<Instant>,
}
impl ExternalCommandLog {
    /// Starts the clock (labelled commands only).
    fn new(label: Option<String>, cmd_str: String) -> Self {
        let started_at = if label.is_some() {
            Some(Instant::now())
        } else {
            None
        };
        Self {
            label,
            cmd_str,
            started_at,
        }
    }

    /// Logs the outcome, with the elapsed time since `new`.
    ///
    /// `exit_code` is `None` when no exit code is available.
    fn record(&self, exit_code: Option<i32>) {
        let Some(label) = &self.label else {
            return;
        };
        let elapsed = self.started_at.map(|start| start.elapsed());
        crate::command_log::log_command(label, &self.cmd_str, exit_code, elapsed);
    }
}
impl Cmd {
    /// Shared constructor for [`Cmd::new`] and [`Cmd::shell`]; every option starts unset.
    fn builder(program: impl Into<String>, shell_wrap: bool) -> Self {
        Self {
            program: program.into(),
            args: Vec::new(),
            current_dir: None,
            context: None,
            stdin_data: None,
            timeout: None,
            envs: Vec::new(),
            env_removes: Vec::new(),
            shell_wrap,
            stdout_cfg: None,
            stdin_cfg: None,
            forward_signals: false,
            external_label: None,
            directive_cd_file: None,
            directive_legacy_file: None,
        }
    }

    /// Creates a command that executes `program` directly, without a shell.
    pub fn new(program: impl Into<String>) -> Self {
        Self::builder(program, false)
    }

    /// Creates a command whose whole string is interpreted by the detected shell
    /// (see [`ShellConfig`]).
    ///
    /// Shell commands take no `.arg()`s and must be run with [`Cmd::stream`].
    pub fn shell(command: impl Into<String>) -> Self {
        Self::builder(command, true)
    }

    /// Renders the command for logs and traces (arguments joined by spaces, unquoted).
    fn command_string(&self) -> String {
        if self.shell_wrap || self.args.is_empty() {
            self.program.clone()
        } else {
            format!("{} {}", self.program, self.args.join(" "))
        }
    }

    /// Builds a `Command` for `program` + `args` without any shell.
    fn direct_command(&self) -> Command {
        let mut cmd = Command::new(&self.program);
        cmd.args(&self.args);
        cmd
    }

    /// Applies the settings shared by every execution mode, in this order:
    /// 1. working directory;
    /// 2. inherited `GIT_*` path fixups;
    /// 3. caller env additions, then removals;
    /// 4. removal of the directive-file variables, so a child never inherits this
    ///    process's directive files (`stream` re-adds them explicitly when requested).
    fn apply_common_settings(&self, cmd: &mut Command) {
        if let Some(dir) = &self.current_dir {
            cmd.current_dir(dir);
        }
        for (key, val) in inherited_git_env_overrides() {
            cmd.env(key, val);
        }
        for (key, val) in &self.envs {
            cmd.env(key, val);
        }
        for key in &self.env_removes {
            cmd.env_remove(key);
        }
        scrub_directive_env_vars(cmd);
    }

    /// Debug-logs the start of a captured (`run` / `pipe_into`) execution.
    fn log_run_start(&self, cmd_str: &str) {
        match &self.context {
            Some(ctx) => log::debug!("$ {} [{}]", cmd_str, ctx),
            None => log::debug!("$ {}", cmd_str),
        }
    }

    /// Debug-logs the start of a streaming execution and how it is launched.
    fn log_stream_start(&self, cmd_str: &str, exec_mode: &str) {
        match &self.context {
            Some(ctx) => log::debug!("$ {} [{}] (streaming, {})", cmd_str, ctx, exec_mode),
            None => log::debug!("$ {} (streaming, {})", cmd_str, exec_mode),
        }
    }

    /// Appends one argument (direct commands only).
    pub fn arg(mut self, arg: impl Into<String>) -> Self {
        self.args.push(arg.into());
        self
    }

    /// Appends several arguments (direct commands only).
    pub fn args<I, S>(mut self, args: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.args.extend(args.into_iter().map(Into::into));
        self
    }

    /// Sets the child's working directory.
    pub fn current_dir(mut self, dir: impl Into<std::path::PathBuf>) -> Self {
        self.current_dir = Some(dir.into());
        self
    }

    /// Attaches a label shown in debug logs and trace events.
    pub fn context(mut self, ctx: impl Into<String>) -> Self {
        self.context = Some(ctx.into());
        self
    }

    /// Supplies bytes to write to the child's stdin.
    pub fn stdin_bytes(mut self, data: impl Into<Vec<u8>>) -> Self {
        self.stdin_data = Some(data.into());
        self
    }

    /// Sets a timeout for [`Cmd::run`], overriding the thread-local default.
    pub fn timeout(mut self, duration: std::time::Duration) -> Self {
        self.timeout = Some(duration);
        self
    }

    /// Adds or overrides an environment variable for the child.
    pub fn env(mut self, key: impl AsRef<OsStr>, val: impl AsRef<OsStr>) -> Self {
        self.envs
            .push((key.as_ref().to_os_string(), val.as_ref().to_os_string()));
        self
    }

    /// Removes an environment variable from the child's environment.
    pub fn env_remove(mut self, key: impl AsRef<OsStr>) -> Self {
        self.env_removes.push(key.as_ref().to_os_string());
        self
    }

    /// Overrides stdout for [`Cmd::stream`] (default: inherit).
    pub fn stdout(mut self, cfg: std::process::Stdio) -> Self {
        self.stdout_cfg = Some(cfg);
        self
    }

    /// Overrides stdin for [`Cmd::stream`] when no stdin bytes are set (default: null).
    pub fn stdin(mut self, cfg: std::process::Stdio) -> Self {
        self.stdin_cfg = Some(cfg);
        self
    }

    /// Unix [`Cmd::stream`] only: isolate the child in its own process group and forward
    /// SIGINT/SIGTERM to it (with escalation).
    pub fn forward_signals(mut self) -> Self {
        self.forward_signals = true;
        self
    }

    /// Records this execution in the command log under `label`.
    pub fn external(mut self, label: impl Into<String>) -> Self {
        self.external_label = Some(label.into());
        self
    }

    /// Path exported as `DIRECTIVE_CD_FILE_ENV_VAR` (applied by [`Cmd::stream`] only).
    pub fn directive_cd_file(mut self, path: impl Into<std::path::PathBuf>) -> Self {
        self.directive_cd_file = Some(path.into());
        self
    }

    /// Path exported as `DIRECTIVE_FILE_ENV_VAR` (applied by [`Cmd::stream`] only).
    pub fn directive_legacy_file(mut self, path: impl Into<std::path::PathBuf>) -> Self {
        self.directive_legacy_file = Some(path.into());
        self
    }

    /// Runs the command to completion and captures its stdout and stderr.
    ///
    /// Waits for a slot in the process-wide command semaphore first.
    /// - With `stdin_bytes`, the data is fed to the child while its output is drained.
    ///   NOTE: the timeout is not applied in this case.
    /// - Otherwise stdin is null, and a timeout (`.timeout()` or the thread-local
    ///   default) is enforced when one is set.
    ///
    /// Every outcome, including spawn failures, is traced and recorded in the external
    /// log.
    ///
    /// # Panics
    /// If called on a `Cmd::shell()` command.
    ///
    /// # Errors
    /// - Spawn, I/O, and wait errors are returned.
    /// - `ErrorKind::TimedOut` is returned when the timeout elapses.
    /// - A non-zero exit status is *not* an error; inspect `Output::status`.
    pub fn run(self) -> std::io::Result<std::process::Output> {
        assert!(
            !self.shell_wrap,
            "Cmd::shell() commands must use .stream(), not .run()"
        );
        debug_assert!(
            self.directive_cd_file.is_none() && self.directive_legacy_file.is_none(),
            "directive_*_file is only applied by .stream(), not .run()"
        );
        let cmd_str = self.command_string();
        let external_log = ExternalCommandLog::new(self.external_label.clone(), cmd_str.clone());
        self.log_run_start(&cmd_str);
        let _guard = semaphore().acquire();
        let t0 = Instant::now();
        let ts = t0
            .duration_since(crate::trace::emit::trace_epoch())
            .as_micros() as u64;
        let tid = crate::trace::emit::thread_id();
        let mut cmd = self.direct_command();
        self.apply_common_settings(&mut cmd);
        let effective_timeout = self.timeout.or_else(|| COMMAND_TIMEOUT.with(|t| t.get()));
        let result = match (self.stdin_data.as_deref(), effective_timeout) {
            (Some(stdin_data), _) => Self::output_with_stdin(&mut cmd, stdin_data),
            (None, Some(timeout_duration)) => run_with_timeout_impl(&mut cmd, timeout_duration),
            (None, None) => cmd.output(),
        };
        let dur_us = t0.elapsed().as_micros() as u64;
        log_command_result(self.context.as_deref(), &cmd_str, ts, tid, dur_us, &result);
        let exit_code = result.as_ref().ok().and_then(|output| output.status.code());
        external_log.record(exit_code);
        result
    }

    /// Spawns `cmd` with all three stdio streams piped.
    ///
    /// `stdin_data` is written on a helper thread while the child's stdout/stderr are
    /// collected. Writing everything before draining the output would deadlock once the
    /// child blocks on a full stdout/stderr pipe while we block on a full stdin pipe.
    /// The child's stdin is closed once the data is written. A `BrokenPipe` on write
    /// (the child stopped reading) is ignored.
    fn output_with_stdin(
        cmd: &mut Command,
        stdin_data: &[u8],
    ) -> std::io::Result<std::process::Output> {
        cmd.stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());
        let mut child = cmd.spawn()?;
        let stdin = child.stdin.take();
        std::thread::scope(|s| {
            let writer = s.spawn(move || match stdin {
                // `pipe` is dropped at the end of this arm, closing the child's stdin (EOF).
                Some(mut pipe) => match pipe.write_all(stdin_data) {
                    Err(e) if e.kind() != ErrorKind::BrokenPipe => Err(e),
                    _ => Ok(()),
                },
                None => Ok(()),
            });
            let output = child.wait_with_output();
            writer
                .join()
                .unwrap_or_else(|panic| std::panic::resume_unwind(panic))?;
            output
        })
    }

    /// Runs `self | next`: the source's stdout feeds the sink's stdin.
    ///
    /// Both outputs are returned; the source's `stdout` is always empty because it went
    /// to the sink. Both commands share one semaphore slot and one trace start timestamp.
    ///
    /// # Panics
    /// If either side is a shell command, or uses `stdin_bytes`, a timeout, or `external()`.
    ///
    /// # Errors
    /// - Spawn, wait, and I/O errors of either process are returned (the source's first).
    /// - Non-zero exit statuses are not errors.
    pub fn pipe_into(
        self,
        next: Cmd,
    ) -> std::io::Result<(std::process::Output, std::process::Output)> {
        assert!(
            !self.shell_wrap && !next.shell_wrap,
            "Cmd::shell() commands cannot be used with pipe_into"
        );
        assert!(
            self.stdin_data.is_none(),
            "pipe_into source cannot also use stdin_bytes"
        );
        assert!(
            next.stdin_data.is_none(),
            "pipe_into sink cannot use stdin_bytes (stdin comes from source)"
        );
        assert!(
            self.timeout.is_none() && next.timeout.is_none(),
            "pipe_into does not support timeouts"
        );
        assert!(
            self.external_label.is_none() && next.external_label.is_none(),
            "pipe_into does not support external() logging"
        );
        debug_assert!(
            self.directive_cd_file.is_none()
                && self.directive_legacy_file.is_none()
                && next.directive_cd_file.is_none()
                && next.directive_legacy_file.is_none(),
            "directive_*_file is only applied by .stream(), not pipe_into"
        );
        let first_cmd_str = self.command_string();
        let second_cmd_str = next.command_string();
        self.log_run_start(&first_cmd_str);
        next.log_run_start(&second_cmd_str);
        let _guard = semaphore().acquire();
        let t0 = Instant::now();
        let ts = t0
            .duration_since(crate::trace::emit::trace_epoch())
            .as_micros() as u64;
        let tid = crate::trace::emit::thread_id();

        let mut first = self.direct_command();
        self.apply_common_settings(&mut first);
        first
            .stdin(Stdio::null())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());
        let mut first_child = first.spawn()?;
        let first_stdout = first_child
            .stdout
            .take()
            .expect("stdout was configured as piped");

        let mut second = next.direct_command();
        next.apply_common_settings(&mut second);
        second
            .stdin(Stdio::from(first_stdout))
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());
        let second_spawn = second.spawn();
        // `Command` keeps its configured `Stdio` handles open until it is dropped. Release
        // our copy of the pipe's read end now so the sink is its only reader. Otherwise,
        // if the sink exits without draining the pipe, the source never sees
        // EPIPE/SIGPIPE, blocks on a full pipe, and `first_child.wait()` below hangs.
        drop(second);
        let second_child = match second_spawn {
            Ok(child) => child,
            Err(e) => {
                let _ = first_child.kill();
                let _ = first_child.wait();
                return Err(e);
            }
        };

        let mut first_stderr_pipe = first_child
            .stderr
            .take()
            .expect("stderr was configured as piped");
        let (first_result, second_result, first_dur_us, second_dur_us) = std::thread::scope(|s| {
            // Drain the source's stderr concurrently so it can't block on a full pipe.
            let stderr_thread = s.spawn(move || {
                let mut buf = Vec::new();
                first_stderr_pipe.read_to_end(&mut buf).map(|_| buf)
            });
            let second_result = second_child.wait_with_output();
            let second_dur_us = t0.elapsed().as_micros() as u64;
            let first_status = first_child.wait();
            let first_stderr = stderr_thread.join().unwrap();
            let first_dur_us = t0.elapsed().as_micros() as u64;
            let first_result = first_status.and_then(|status| {
                first_stderr.map(|stderr| std::process::Output {
                    status,
                    stdout: Vec::new(),
                    stderr,
                })
            });
            (first_result, second_result, first_dur_us, second_dur_us)
        });
        log_command_result(
            self.context.as_deref(),
            &first_cmd_str,
            ts,
            tid,
            first_dur_us,
            &first_result,
        );
        log_command_result(
            next.context.as_deref(),
            &second_cmd_str,
            ts,
            tid,
            second_dur_us,
            &second_result,
        );
        Ok((first_result?, second_result?))
    }

    /// Runs the command with stderr (and, by default, stdout) inherited from this
    /// process, then waits for it.
    ///
    /// Stdin comes from `stdin_bytes` if set, otherwise from `.stdin()` (default: null).
    /// Directive-file env vars are exported only here. With `forward_signals` (Unix),
    /// SIGINT/SIGTERM received while waiting are forwarded to the child's process group.
    ///
    /// Outcomes:
    /// - Exit status 0, or death by SIGPIPE, is success.
    /// - A forwarded signal, any other fatal signal, or a non-zero exit becomes
    ///   `WorktrunkError::ChildProcessExited` (signals map to code `128 + signal`).
    ///
    /// # Panics
    /// If a `Cmd::shell()` command was given `.arg()`s.
    pub fn stream(mut self) -> anyhow::Result<()> {
        #[cfg(unix)]
        use {
            signal_hook::consts::{SIGINT, SIGPIPE, SIGTERM},
            signal_hook::iterator::Signals,
            std::os::unix::process::CommandExt,
        };
        assert!(
            !self.shell_wrap || self.args.is_empty(),
            "Cmd::shell() cannot use .arg() - include arguments in the shell command string"
        );
        let (mut cmd, exec_mode) = if self.shell_wrap {
            let shell = ShellConfig::get()?;
            let mode = format!("shell: {}", shell.name);
            (shell.command(&self.program), mode)
        } else {
            (self.direct_command(), "direct".to_string())
        };
        let cmd_str = self.command_string();
        let external_log = ExternalCommandLog::new(self.external_label.take(), cmd_str.clone());
        self.log_stream_start(&cmd_str, &exec_mode);
        self.apply_common_settings(&mut cmd);
        if let Some(ref path) = self.directive_cd_file {
            cmd.env(DIRECTIVE_CD_FILE_ENV_VAR, path);
        }
        if let Some(ref path) = self.directive_legacy_file {
            cmd.env(DIRECTIVE_FILE_ENV_VAR, path);
        }
        #[cfg(not(unix))]
        let _ = self.forward_signals;
        let stdout_mode = self.stdout_cfg.unwrap_or_else(std::process::Stdio::inherit);
        let stdin_mode = if self.stdin_data.is_some() {
            std::process::Stdio::piped()
        } else {
            self.stdin_cfg.unwrap_or_else(std::process::Stdio::null)
        };
        // Register handlers before spawning so no signal is missed in between.
        #[cfg(unix)]
        let mut signals = if self.forward_signals {
            Some(Signals::new([SIGINT, SIGTERM])?)
        } else {
            None
        };
        // A dedicated process group lets us signal the child and its descendants together.
        #[cfg(unix)]
        if self.forward_signals {
            cmd.process_group(0);
        }
        cmd.stdin(stdin_mode)
            .stdout(stdout_mode)
            .stderr(std::process::Stdio::inherit())
            // NOTE(review): presumably keeps a build-time `VERGEN_GIT_DESCRIBE` from leaking
            // into children; it runs after caller `.env()` settings, so it wins. Confirm.
            .env_remove("VERGEN_GIT_DESCRIBE");
        let mut child = cmd.spawn().map_err(|e| {
            anyhow::Error::from(GitError::Other {
                message: format!("Failed to execute command ({}): {}", exec_mode, e),
            })
        })?;
        if let Some(ref content) = self.stdin_data
            && let Some(mut stdin) = child.stdin.take()
            && let Err(e) = stdin.write_all(content)
            && e.kind() != std::io::ErrorKind::BrokenPipe
        {
            return Err(e.into());
        }
        #[cfg(unix)]
        let (status, seen_signal) = if self.forward_signals {
            // With `process_group(0)` the child's pid is also its process-group id.
            let child_pgid = child.id() as i32;
            let mut seen_signal: Option<i32> = None;
            loop {
                if let Some(status) = child.try_wait().map_err(|e| {
                    anyhow::Error::from(GitError::Other {
                        message: format!("Failed to wait for command: {}", e),
                    })
                })? {
                    break (status, seen_signal);
                }
                if let Some(signals) = signals.as_mut() {
                    for sig in signals.pending() {
                        // Only the first signal is forwarded; escalation handles the rest.
                        if seen_signal.is_none() {
                            seen_signal = Some(sig);
                            forward_signal_with_escalation(child_pgid, sig);
                        }
                    }
                }
                std::thread::sleep(std::time::Duration::from_millis(10));
            }
        } else {
            let status = child.wait().map_err(|e| {
                anyhow::Error::from(GitError::Other {
                    message: format!("Failed to wait for command: {}", e),
                })
            })?;
            (status, None)
        };
        #[cfg(not(unix))]
        let status = child.wait().map_err(|e| {
            anyhow::Error::from(GitError::Other {
                message: format!("Failed to wait for command: {}", e),
            })
        })?;
        #[cfg(unix)]
        if let Some(sig) = seen_signal {
            external_log.record(Some(128 + sig));
            return Err(WorktrunkError::ChildProcessExited {
                code: 128 + sig,
                message: format!("terminated by signal {}", sig),
                signal: Some(sig),
            }
            .into());
        }
        #[cfg(unix)]
        if let Some(sig) = std::os::unix::process::ExitStatusExt::signal(&status) {
            // A reader closing the pipe early (e.g. a pager quitting) is not a failure.
            if sig == SIGPIPE {
                external_log.record(Some(0));
                return Ok(());
            }
            external_log.record(Some(128 + sig));
            return Err(WorktrunkError::ChildProcessExited {
                code: 128 + sig,
                message: format!("terminated by signal {}", sig),
                signal: Some(sig),
            }
            .into());
        }
        if !status.success() {
            let code = status.code().unwrap_or(1);
            external_log.record(status.code());
            return Err(WorktrunkError::ChildProcessExited {
                code,
                message: format!("exit status: {}", code),
                signal: None,
            }
            .into());
        }
        external_log.record(Some(0));
        Ok(())
    }
}
/// Reports whether any process still belongs to process group `pgid`.
///
/// Probes with the null signal, which only performs existence and permission checks.
/// Only `ESRCH` means "gone"; `EPERM` and other errors still count as alive.
#[cfg(unix)]
fn process_group_alive(pgid: i32) -> bool {
    let probe = nix::sys::signal::killpg(nix::unistd::Pid::from_raw(pgid), None);
    !matches!(probe, Err(nix::errno::Errno::ESRCH))
}
/// Waits up to `grace` for process group `pgid` to disappear.
///
/// Returns `true` if the group is gone, or `false` if it is still alive after `grace`.
/// The group is polled every few milliseconds, so this returns as soon as it exits.
/// The previous version always slept the full grace period, even when the group had
/// already exited.
///
/// NOTE(review): an unreaped zombie group leader likely still counts as "alive" to a
/// null-signal probe on some platforms (e.g. Linux), so a caller that has not reaped
/// the child may see `false` even after it exited. Confirm.
#[cfg(unix)]
fn wait_for_exit(pgid: i32, grace: std::time::Duration) -> bool {
    const POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(10);
    let deadline = Instant::now() + grace;
    loop {
        if !process_group_alive(pgid) {
            return true;
        }
        let now = Instant::now();
        if now >= deadline {
            return false;
        }
        std::thread::sleep(POLL_INTERVAL.min(deadline - now));
    }
}
/// Forwards `sig` to process group `pgid` and escalates if the group ignores it.
///
/// Escalation ladders, with a 200 ms grace period before each step after the first:
/// - `SIGINT`: SIGINT → SIGTERM → SIGKILL;
/// - `SIGTERM`: SIGTERM → SIGKILL.
///
/// Escalation stops as soon as the group has exited. Any other signal is ignored and
/// nothing is sent. Errors from `killpg` are ignored.
#[cfg(unix)]
pub fn forward_signal_with_escalation(pgid: i32, sig: i32) {
    use nix::sys::signal::{killpg, Signal};
    let ladder: &[Signal] = match sig {
        signal_hook::consts::SIGINT => &[Signal::SIGINT, Signal::SIGTERM, Signal::SIGKILL],
        signal_hook::consts::SIGTERM => &[Signal::SIGTERM, Signal::SIGKILL],
        _ => return,
    };
    let group = nix::unistd::Pid::from_raw(pgid);
    let grace = std::time::Duration::from_millis(200);
    for (step, &signal) in ladder.iter().enumerate() {
        if step > 0 && wait_for_exit(pgid, grace) {
            return;
        }
        let _ = killpg(group, signal);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_git_env_overrides() {
        let base_buf = std::env::temp_dir().join("wt-test-startup-cwd");
        let base = base_buf.as_path();
        let abs_work = std::env::temp_dir().join("wt-test-abs-work");
        let env: std::collections::HashMap<&str, OsString> = [
            ("GIT_DIR", OsString::from(".git")),
            ("GIT_WORK_TREE", abs_work.clone().into_os_string()),
            ("GIT_INDEX_FILE", OsString::from("../index")),
            ("GIT_AUTHOR_NAME", OsString::from("Test User")),
        ]
        .into_iter()
        .collect();
        let overrides = compute_git_env_overrides(base, |var| env.get(var).cloned());
        assert_eq!(overrides.len(), 2);
        let as_map: std::collections::HashMap<_, _> = overrides.into_iter().collect();
        assert_eq!(
            as_map.get("GIT_DIR"),
            Some(&base.join(".git").into_os_string())
        );
        assert_eq!(
            as_map.get("GIT_INDEX_FILE"),
            Some(&base.join("../index").into_os_string())
        );
    }

    #[test]
    fn test_compute_git_env_overrides_all_absolute() {
        let base_buf = std::env::temp_dir().join("wt-test-startup-cwd");
        let abs_git = std::env::temp_dir().join("wt-test-abs.git");
        let env: std::collections::HashMap<&str, OsString> =
            [("GIT_DIR", abs_git.into_os_string())]
                .into_iter()
                .collect();
        let overrides = compute_git_env_overrides(base_buf.as_path(), |var| env.get(var).cloned());
        assert!(overrides.is_empty());
    }

    #[test]
    fn test_compute_git_env_overrides_all_unset() {
        let base_buf = std::env::temp_dir().join("wt-test-startup-cwd");
        let overrides = compute_git_env_overrides(base_buf.as_path(), |_| None);
        assert!(overrides.is_empty());
    }

    #[test]
    fn test_max_concurrent_commands_defaults() {
        let limit = max_concurrent_commands();
        assert!(limit >= 1, "Default should be >= 1");
        // The env var may legitimately be set in a developer's or CI environment, so
        // only demand the default when it is absent (or unreadable as UTF-8).
        match std::env::var("WORKTRUNK_MAX_CONCURRENT_COMMANDS") {
            Ok(raw) => assert_eq!(
                limit,
                parse_concurrent_limit(&raw).unwrap_or(DEFAULT_CONCURRENT_COMMANDS),
                "With env var set, should honor it (or fall back to default if invalid)"
            ),
            Err(_) => assert_eq!(
                limit, DEFAULT_CONCURRENT_COMMANDS,
                "Without env var, should use default"
            ),
        }
    }

    #[test]
    fn test_parse_concurrent_limit() {
        assert_eq!(parse_concurrent_limit("1"), Some(1));
        assert_eq!(parse_concurrent_limit("32"), Some(32));
        assert_eq!(parse_concurrent_limit("100"), Some(100));
        assert_eq!(parse_concurrent_limit("0"), Some(usize::MAX));
        assert_eq!(parse_concurrent_limit(""), None);
        assert_eq!(parse_concurrent_limit("abc"), None);
        assert_eq!(parse_concurrent_limit("-1"), None);
        assert_eq!(parse_concurrent_limit("1.5"), None);
    }

    #[test]
    fn test_shell_config_is_available() {
        let config = ShellConfig::get().unwrap();
        assert!(!config.name.is_empty());
        assert!(!config.args.is_empty());
    }

    #[test]
    #[cfg(unix)]
    fn test_unix_shell_is_posix() {
        let config = ShellConfig::get().unwrap();
        assert!(config.is_posix);
        assert_eq!(config.name, "sh");
    }

    #[test]
    fn test_command_creation() {
        let config = ShellConfig::get().unwrap();
        let cmd = config.command("echo hello");
        let _ = format!("{:?}", cmd);
    }

    #[test]
    fn test_shell_command_execution() {
        let config = ShellConfig::get().unwrap();
        let output = config
            .command("echo hello")
            .output()
            .expect("Failed to execute shell command");
        let stdout = String::from_utf8_lossy(&output.stdout);
        let stderr = String::from_utf8_lossy(&output.stderr);
        assert!(
            output.status.success(),
            "echo should succeed. Shell: {} ({:?}), exit: {:?}, stdout: '{}', stderr: '{}'",
            config.name,
            config.executable,
            output.status.code(),
            stdout.trim(),
            stderr.trim()
        );
        assert!(
            stdout.contains("hello"),
            "stdout should contain 'hello', got: '{}'",
            stdout.trim()
        );
    }

    #[test]
    #[cfg(windows)]
    fn test_windows_uses_git_bash() {
        let config = ShellConfig::get().unwrap();
        assert_eq!(config.name, "Git Bash");
        assert!(config.is_posix, "Git Bash should support POSIX syntax");
        assert!(
            config.args.contains(&"-c".to_string()),
            "Git Bash should use -c flag"
        );
    }

    #[test]
    #[cfg(windows)]
    fn test_windows_echo_command() {
        let config = ShellConfig::get().unwrap();
        let output = config
            .command("echo test_output")
            .output()
            .expect("Failed to execute echo");
        let stdout = String::from_utf8_lossy(&output.stdout);
        assert!(output.status.success());
        assert!(
            stdout.contains("test_output"),
            "stdout should contain 'test_output', got: '{}'",
            stdout.trim()
        );
    }

    #[test]
    #[cfg(windows)]
    fn test_windows_posix_redirection() {
        let config = ShellConfig::get().unwrap();
        let output = config
            .command("echo redirected 1>&2")
            .output()
            .expect("Failed to execute redirection test");
        let stderr = String::from_utf8_lossy(&output.stderr);
        assert!(output.status.success());
        assert!(
            stderr.contains("redirected"),
            "stderr should contain 'redirected' (stdout redirected to stderr), got: '{}'",
            stderr.trim()
        );
    }

    #[test]
    fn test_shell_config_clone() {
        let config = ShellConfig::get().unwrap();
        let cloned = config.clone();
        assert_eq!(config.name, cloned.name);
        assert_eq!(config.is_posix, cloned.is_posix);
        assert_eq!(config.args, cloned.args);
    }

    #[test]
    fn test_shell_is_posix_method() {
        let config = ShellConfig::get().unwrap();
        assert_eq!(config.is_posix(), config.is_posix);
    }

    #[test]
    fn test_cmd_completes_fast_command() {
        let result = Cmd::new("echo")
            .arg("hello")
            .timeout(Duration::from_secs(5))
            .run();
        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(output.status.success());
        assert!(String::from_utf8_lossy(&output.stdout).contains("hello"));
    }

    #[test]
    #[cfg(unix)]
    fn test_cmd_timeout_kills_slow_command() {
        let result = Cmd::new("sleep")
            .arg("10")
            .timeout(Duration::from_millis(50))
            .run();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), std::io::ErrorKind::TimedOut);
    }

    #[test]
    fn test_cmd_without_timeout_completes() {
        let result = Cmd::new("echo").arg("no timeout").run();
        assert!(result.is_ok());
    }

    #[test]
    fn test_cmd_with_context() {
        let result = Cmd::new("echo")
            .arg("with context")
            .context("test-context")
            .run();
        assert!(result.is_ok());
    }

    #[test]
    fn test_cmd_with_stdin() {
        let result = Cmd::new("cat").stdin_bytes("hello from stdin").run();
        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(output.status.success());
        assert!(String::from_utf8_lossy(&output.stdout).contains("hello from stdin"));
    }

    #[test]
    fn test_thread_local_timeout_setting() {
        let initial = COMMAND_TIMEOUT.with(|t| t.get());
        set_command_timeout(Some(Duration::from_millis(100)));
        let after_set = COMMAND_TIMEOUT.with(|t| t.get());
        assert_eq!(after_set, Some(Duration::from_millis(100)));
        set_command_timeout(initial);
        let after_clear = COMMAND_TIMEOUT.with(|t| t.get());
        assert_eq!(after_clear, initial);
    }

    #[test]
    fn test_cmd_uses_thread_local_timeout() {
        set_command_timeout(None);
        let result = Cmd::new("echo").arg("thread local test").run();
        assert!(result.is_ok());
        set_command_timeout(None);
    }

    #[test]
    #[cfg(unix)]
    fn test_cmd_thread_local_timeout_kills_slow_command() {
        set_command_timeout(Some(Duration::from_millis(50)));
        let result = Cmd::new("sleep").arg("10").run();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), std::io::ErrorKind::TimedOut);
        set_command_timeout(None);
    }

    #[test]
    fn test_cmd_shell_stream_succeeds() {
        let result = Cmd::shell("echo hello").stream();
        assert!(result.is_ok());
    }

    #[test]
    fn test_cmd_shell_stream_fails_on_nonzero_exit() {
        use crate::git::WorktrunkError;
        let result = Cmd::shell("exit 42").stream();
        assert!(result.is_err());
        let err = result.unwrap_err();
        let wt_err = err.downcast_ref::<WorktrunkError>().unwrap();
        match wt_err {
            WorktrunkError::ChildProcessExited { code, .. } => {
                assert_eq!(*code, 42);
            }
            _ => panic!("Expected ChildProcessExited error"),
        }
    }

    #[test]
    #[cfg(unix)]
    fn test_cmd_stream_sigpipe_is_not_an_error() {
        let result = Cmd::new("sh").args(["-c", "kill -PIPE $$"]).stream();
        assert!(
            result.is_ok(),
            "SIGPIPE should not be treated as an error: {result:?}"
        );
    }

    #[test]
    #[cfg(unix)]
    fn test_cmd_stream_other_signals_are_errors() {
        use crate::git::WorktrunkError;
        let result = Cmd::new("sh").args(["-c", "kill -TERM $$"]).stream();
        assert!(result.is_err());
        let err = result.unwrap_err();
        let wt_err = err.downcast_ref::<WorktrunkError>().unwrap();
        match wt_err {
            WorktrunkError::ChildProcessExited { code, .. } => {
                // 128 + SIGTERM (15)
                assert_eq!(*code, 128 + 15);
            }
            _ => panic!("Expected ChildProcessExited error"),
        }
    }

    #[test]
    #[cfg(unix)]
    fn test_cmd_shell_stream_with_stdin() {
        let result = Cmd::shell("cat").stdin_bytes("test content").stream();
        assert!(result.is_ok());
    }

    #[test]
    #[cfg(unix)]
    fn test_cmd_new_stream_succeeds() {
        let result = Cmd::new("echo").arg("hello").stream();
        assert!(result.is_ok());
    }

    #[test]
    #[cfg(unix)]
    fn test_cmd_shell_stream_with_stdout_redirect() {
        use std::process::Stdio;
        let result = Cmd::shell("echo redirected")
            .stdout(Stdio::from(std::io::stderr()))
            .stream();
        assert!(result.is_ok());
    }

    #[test]
    #[cfg(unix)]
    fn test_cmd_shell_stream_with_stdin_inherit() {
        use std::process::Stdio;
        let result = Cmd::shell("true").stdin(Stdio::inherit()).stream();
        assert!(result.is_ok());
    }

    #[test]
    #[cfg(unix)]
    fn test_cmd_shell_stream_with_env() {
        let result = Cmd::shell("printenv TEST_VAR")
            .env("TEST_VAR", "test_value")
            .env_remove("SOME_NONEXISTENT_VAR")
            .stream();
        assert!(result.is_ok());
    }

    #[test]
    #[cfg(unix)]
    fn test_process_group_alive_with_current_process() {
        let pgid = nix::unistd::getpgrp().as_raw();
        assert!(super::process_group_alive(pgid));
    }

    #[test]
    #[cfg(unix)]
    fn test_process_group_alive_with_nonexistent_pgid() {
        assert!(!super::process_group_alive(999_999_999));
    }

    #[test]
    #[cfg(unix)]
    fn test_forward_signal_with_escalation_unknown_signal() {
        super::forward_signal_with_escalation(1, 999);
    }

    #[test]
    fn test_format_stream_full_empty() {
        assert!(format_stream_full(b"", " ").is_empty());
    }

    #[test]
    fn test_format_stream_full_prefixes_each_line() {
        let lines = format_stream_full(b"alpha\nbeta\ngamma\n", " ");
        assert_eq!(lines, vec![" alpha", " beta", " gamma"]);
    }

    #[test]
    fn test_format_stream_full_stderr_prefix() {
        let lines = format_stream_full(b"err1\nerr2\n", " ! ");
        assert_eq!(lines, vec![" ! err1", " ! err2"]);
    }

    #[test]
    fn test_format_stream_bounded_empty() {
        assert!(format_stream_bounded(b"", " ").is_empty());
    }

    #[test]
    fn test_format_stream_bounded_below_caps_emits_all() {
        let lines = format_stream_bounded(b"one\ntwo\nthree\n", " ");
        assert_eq!(lines, vec![" one", " two", " three"]);
    }

    #[test]
    fn test_format_stream_bounded_line_cap_triggers_elision() {
        let input: String = (0..LOG_OUTPUT_MAX_LINES + 5)
            .map(|i| format!("line{i}\n"))
            .collect();
        let lines = format_stream_bounded(input.as_bytes(), " ");
        assert_eq!(lines.len(), LOG_OUTPUT_MAX_LINES + 1, "cap + 1 marker");
        let marker = lines.last().unwrap();
        assert!(
            marker.starts_with(" … (5 more lines, "),
            "marker should count the 5 lines past the cap: {marker}"
        );
        assert!(marker.contains("rerun with -vv"));
    }

    #[test]
    fn test_format_stream_bounded_byte_cap_triggers_elision() {
        let long = "x".repeat(LOG_OUTPUT_MAX_BYTES + 100);
        let input = format!("{long}\nafter1\nafter2\n");
        let lines = format_stream_bounded(input.as_bytes(), " ");
        assert_eq!(lines.len(), 2);
        assert_eq!(lines[0].len(), 2 + long.len());
        let marker = &lines[1];
        assert!(
            marker.starts_with(" … (2 more lines, "),
            "marker should count after1 + after2: {marker}"
        );
    }
}