use std::io::{BufRead, BufReader, Read, Write};
use std::path::Path;
use std::process::{Child, Stdio};
use std::sync::mpsc::{self, Sender};
use std::thread;
use std::time::Instant;
use anyhow::Context;
use worktrunk::command_log::log_command;
use worktrunk::git::WorktrunkError;
use worktrunk::shell_exec::{
DIRECTIVE_CD_FILE_ENV_VAR, DIRECTIVE_FILE_ENV_VAR, ShellConfig, scrub_directive_env_vars,
};
use worktrunk::styling::stderr;
use super::handlers::DirectivePassthrough;
/// One command to run as part of a batch handled by `run_concurrent_commands`.
pub struct ConcurrentCommand<'a> {
    /// Name shown in the colored output prefix and in spawn/wait error messages.
    pub label: &'a str,
    /// Fully expanded command string handed to the shell.
    pub expanded: &'a str,
    /// Directory the command runs in.
    pub working_dir: &'a Path,
    /// Text written to the child's stdin after it is spawned.
    pub context_json: &'a str,
    /// When set, the command's exit code and duration are passed to `log_command`
    /// under this label.
    pub log_label: Option<&'a str>,
    /// Directive file paths exported to the child through environment variables.
    pub directives: &'a DirectivePassthrough,
}
pub fn run_concurrent_commands(
cmds: &[ConcurrentCommand<'_>],
) -> anyhow::Result<Vec<anyhow::Result<()>>> {
let prefix_width = cmds.iter().map(|c| c.label.len()).max().unwrap_or(0);
let shell = ShellConfig::get()?;
if std::env::var_os("WORKTRUNK_TEST_SERIAL_CONCURRENT").is_some() {
return Ok(run_serial_with_prefix(shell, cmds, prefix_width));
}
#[cfg(unix)]
let signals = {
use signal_hook::consts::{SIGINT, SIGTERM};
signal_hook::iterator::Signals::new([SIGINT, SIGTERM])?
};
let mut children: Vec<SpawnedChild> = Vec::with_capacity(cmds.len());
for (i, cmd) in cmds.iter().enumerate() {
match spawn_child(shell, i, cmd) {
Ok(spawned) => children.push(spawned),
Err(e) => {
for mut prior in children {
let _ = prior.child.kill();
let _ = prior.child.wait();
}
return Err(e);
}
}
}
let (line_tx, line_rx) = mpsc::channel::<LabeledLine>();
let mut readers: Vec<thread::JoinHandle<()>> = Vec::new();
for (i, spawned) in children.iter_mut().enumerate() {
let label = cmds[i].label.to_string();
if let Some(stdout) = spawned.child.stdout.take() {
readers.push(spawn_reader(i, label.clone(), stdout, line_tx.clone()));
}
if let Some(stderr) = spawned.child.stderr.take() {
readers.push(spawn_reader(i, label, stderr, line_tx.clone()));
}
}
drop(line_tx);
#[cfg(unix)]
let signal_thread = spawn_signal_forwarder(
signals,
children
.iter()
.map(|c| c.child.id() as i32)
.collect::<Vec<_>>(),
);
{
let mut stderr = stderr().lock();
for labeled in line_rx {
let prefix = render_prefix(labeled.index, &labeled.label, prefix_width);
writeln!(stderr, "{}{}", prefix, labeled.line).ok();
}
}
for r in readers {
let _ = r.join();
}
let mut outcomes = Vec::with_capacity(children.len());
for (spawned, cmd) in children.into_iter().zip(cmds) {
outcomes.push(collect_outcome(spawned, cmd));
}
#[cfg(unix)]
{
signal_thread.stop();
}
Ok(outcomes)
}
/// Serial counterpart of [`run_concurrent_commands`]. Each command runs to
/// completion before the next starts, using the same prefixed output format.
///
/// A spawn failure becomes that command's outcome, and the remaining commands
/// still run.
fn run_serial_with_prefix(
    shell: &ShellConfig,
    cmds: &[ConcurrentCommand<'_>],
    prefix_width: usize,
) -> Vec<anyhow::Result<()>> {
    // `map` is lazy, so `collect` drives the commands strictly one at a time, in order.
    cmds.iter()
        .enumerate()
        .map(|(position, spec)| {
            spawn_child(shell, position, spec)
                .and_then(|running| drain_and_wait_single(running, spec, position, prefix_width))
        })
        .collect()
}
/// Streams one child's stdout and stderr to this process's stderr with its prefix,
/// then waits for the child and returns its outcome.
///
/// # Errors
///
/// Propagates the result of `collect_outcome`.
fn drain_and_wait_single(
    mut spawned: SpawnedChild,
    cmd: &ConcurrentCommand<'_>,
    index: usize,
    prefix_width: usize,
) -> anyhow::Result<()> {
    let (sender, receiver) = mpsc::channel::<LabeledLine>();
    let owned_label = cmd.label.to_string();

    let mut handles: Vec<thread::JoinHandle<()>> = Vec::new();
    if let Some(pipe) = spawned.child.stdout.take() {
        handles.push(spawn_reader(index, owned_label.clone(), pipe, sender.clone()));
    }
    if let Some(pipe) = spawned.child.stderr.take() {
        handles.push(spawn_reader(index, owned_label, pipe, sender.clone()));
    }
    // Our own sender must go away, or the receive loop would never terminate.
    drop(sender);

    let mut sink = stderr().lock();
    for message in receiver {
        let prefix = render_prefix(message.index, &message.label, prefix_width);
        writeln!(sink, "{prefix}{}", message.line).ok();
    }
    drop(sink);

    for handle in handles {
        let _ = handle.join();
    }
    collect_outcome(spawned, cmd)
}
/// A spawned child process plus the data `collect_outcome` needs to log its result.
struct SpawnedChild {
    /// Handle to the spawned shell process.
    child: Child,
    /// The expanded command string, passed to `log_command`.
    cmd_str: String,
    /// Owned copy of `ConcurrentCommand::log_label`; logging happens only when set.
    log_label: Option<String>,
    /// Reference point for the duration passed to `log_command`.
    started_at: Instant,
}
fn spawn_child(
shell: &ShellConfig,
index: usize,
cmd: &ConcurrentCommand<'_>,
) -> anyhow::Result<SpawnedChild> {
let mut command = shell.command(cmd.expanded);
command
.current_dir(cmd.working_dir)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped());
scrub_directive_env_vars(&mut command);
if let Some(path) = &cmd.directives.cd_file {
command.env(DIRECTIVE_CD_FILE_ENV_VAR, path);
}
if let Some(path) = &cmd.directives.legacy_file {
command.env(DIRECTIVE_FILE_ENV_VAR, path);
}
#[cfg(unix)]
{
use std::os::unix::process::CommandExt;
command.process_group(0);
}
log::debug!(
"$ {} (concurrent #{index}, shell: {})",
cmd.expanded,
shell.name
);
let mut child = command
.spawn()
.with_context(|| format!("failed to spawn concurrent command '{}'", cmd.label))?;
if let Some(mut stdin) = child.stdin.take() {
let _ = stdin.write_all(cmd.context_json.as_bytes());
}
Ok(SpawnedChild {
child,
cmd_str: cmd.expanded.to_string(),
log_label: cmd.log_label.map(str::to_string),
started_at: Instant::now(),
})
}
/// Spawns a thread that reads `stream` line by line and sends each line over `tx`,
/// tagged with `index` and `label`.
///
/// A trailing `\n` or `\r\n` is removed, and invalid UTF-8 is replaced lossily.
/// A final line with no newline is still delivered. The thread exits on EOF, on a
/// read error, or once the receiving end of `tx` has been dropped.
fn spawn_reader<R: Read + Send + 'static>(
    index: usize,
    label: String,
    stream: R,
    tx: Sender<LabeledLine>,
) -> thread::JoinHandle<()> {
    thread::spawn(move || {
        let mut source = BufReader::new(stream);
        // Reused across iterations so each line doesn't need a fresh byte buffer.
        let mut raw: Vec<u8> = Vec::with_capacity(256);
        loop {
            raw.clear();
            match source.read_until(b'\n', &mut raw) {
                Ok(0) | Err(_) => break,
                Ok(_) => {}
            }

            let mut end = raw.len();
            if raw.ends_with(b"\n") {
                end -= 1;
                if raw[..end].ends_with(b"\r") {
                    end -= 1;
                }
            }

            let line = String::from_utf8_lossy(&raw[..end]).into_owned();
            let message = LabeledLine {
                index,
                label: label.clone(),
                line,
            };
            if tx.send(message).is_err() {
                // Receiver is gone; nobody will print further output.
                break;
            }
        }
    })
}
/// Waits for the child to exit and logs the result when a log label was provided.
///
/// The exit status maps to `Ok(())` on success, otherwise to a
/// `WorktrunkError::ChildProcessExited`. A death by signal is reported with code
/// `128 + signal`. A failure with neither exit code nor signal is reported as
/// code 1.
///
/// # Errors
///
/// Fails if waiting on the child fails, or if the child did not exit successfully.
fn collect_outcome(spawned: SpawnedChild, cmd: &ConcurrentCommand<'_>) -> anyhow::Result<()> {
    let SpawnedChild {
        mut child,
        cmd_str,
        log_label,
        started_at,
    } = spawned;

    let status = child
        .wait()
        .with_context(|| format!("failed to wait for concurrent command '{}'", cmd.label))?;
    let elapsed = started_at.elapsed();

    #[cfg(unix)]
    let signal = std::os::unix::process::ExitStatusExt::signal(&status);
    #[cfg(not(unix))]
    let signal: Option<i32> = None;
    let exit_code = status.code();

    if let Some(label) = log_label {
        // Fold a signal death into one numeric code for the log entry.
        let logged_code = exit_code.or_else(|| signal.map(|s| 128 + s));
        log_command(&label, &cmd_str, logged_code, Some(elapsed));
    }

    match (status.success(), signal) {
        (true, _) => Ok(()),
        (false, Some(sig)) => Err(WorktrunkError::ChildProcessExited {
            code: 128 + sig,
            message: format!("terminated by signal {sig}"),
            signal: Some(sig),
        }
        .into()),
        (false, None) => {
            let code = exit_code.unwrap_or(1);
            Err(WorktrunkError::ChildProcessExited {
                code,
                message: format!("exit status: {code}"),
                signal: None,
            }
            .into())
        }
    }
}
/// One line of child output, tagged with the command it came from.
struct LabeledLine {
    /// Position of the command in the batch; `render_prefix` uses it to pick a color.
    index: usize,
    /// Command label printed as the line prefix.
    label: String,
    /// Line text with the trailing `\n` / `\r\n` removed, decoded lossily as UTF-8.
    line: String,
}
/// Builds the bold, colored `label │ ` prefix for a line of output.
///
/// The label is left-aligned and padded to `width`. The color is chosen from a
/// fixed 8-entry palette by `index`, cycling when there are more commands than
/// colors.
fn render_prefix(index: usize, label: &str, width: usize) -> String {
    use anstyle::{AnsiColor, Color, Style};

    const PALETTE: [AnsiColor; 8] = [
        AnsiColor::Cyan,
        AnsiColor::Magenta,
        AnsiColor::Yellow,
        AnsiColor::Green,
        AnsiColor::Blue,
        AnsiColor::BrightCyan,
        AnsiColor::BrightMagenta,
        AnsiColor::BrightYellow,
    ];

    let color = PALETTE[index % PALETTE.len()];
    let style = Style::new().fg_color(Some(Color::Ansi(color))).bold();
    format!("{style}{label:<width$}{style:#} │ ")
}
/// Handle to the background thread started by `spawn_signal_forwarder`.
#[cfg(unix)]
struct SignalForwarder {
    /// Setting this to `true` asks the polling loop to exit.
    stop: std::sync::Arc<std::sync::atomic::AtomicBool>,
    /// The polling thread; joined by [`SignalForwarder::stop`].
    handle: thread::JoinHandle<()>,
}

#[cfg(unix)]
impl SignalForwarder {
    /// Asks the forwarder thread to exit, then waits for it to finish.
    fn stop(self) {
        let Self { stop, handle } = self;
        stop.store(true, std::sync::atomic::Ordering::Relaxed);
        // A panic in the forwarder thread is not worth surfacing at shutdown.
        let _ = handle.join();
    }
}
/// Starts a thread that polls `signals` every 25 ms until stopped.
///
/// The first signal received is passed to
/// `worktrunk::shell_exec::forward_signal_with_escalation` for every entry in
/// `pgids`. Any later signal sends SIGKILL to each process group.
#[cfg(unix)]
fn spawn_signal_forwarder(
    mut signals: signal_hook::iterator::Signals,
    pgids: Vec<i32>,
) -> SignalForwarder {
    use nix::sys::signal::{Signal, killpg};
    use nix::unistd::Pid;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    let stop = Arc::new(AtomicBool::new(false));
    let stop_requested = Arc::clone(&stop);

    let handle = thread::spawn(move || {
        let mut already_forwarded = false;
        loop {
            if stop_requested.load(Ordering::Relaxed) {
                break;
            }
            for sig in signals.pending() {
                if already_forwarded {
                    // Repeated signal: stop being polite and kill the groups outright.
                    for &pgid in &pgids {
                        let _ = killpg(Pid::from_raw(pgid), Signal::SIGKILL);
                    }
                } else {
                    already_forwarded = true;
                    for &pgid in &pgids {
                        worktrunk::shell_exec::forward_signal_with_escalation(pgid, sig);
                    }
                }
            }
            thread::sleep(std::time::Duration::from_millis(25));
        }
    });

    SignalForwarder { stop, handle }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a single command through `run_concurrent_commands` in the system temp
    /// directory, with `{}` as its stdin context.
    fn run_one_with_directives(
        label: &str,
        script: &str,
        log_label: Option<&str>,
        directives: &DirectivePassthrough,
    ) -> Vec<anyhow::Result<()>> {
        let wd = std::env::temp_dir();
        let specs = [ConcurrentCommand {
            label,
            expanded: script,
            working_dir: &wd,
            context_json: "{}",
            log_label,
            directives,
        }];
        run_concurrent_commands(&specs).expect("spawn failed")
    }

    /// Smoke test for the `log_label` path: a labelled command still succeeds.
    /// Only the exit status is checked; the recorded log entry is not inspected.
    #[test]
    fn test_log_label_is_recorded() {
        let outcomes = run_one_with_directives(
            "job",
            "true",
            Some("test-label"),
            &DirectivePassthrough::none(),
        );
        assert_eq!(outcomes.len(), 1);
        assert!(outcomes[0].is_ok(), "`true` should exit 0");
    }

    /// The child must learn the directive file paths from its environment, so the
    /// script writes through `$VAR`. Writing to the literal temp paths would pass
    /// even if the variables were never exported.
    #[test]
    #[cfg(unix)]
    fn test_directive_env_vars_passed_through() {
        use tempfile::NamedTempFile;
        let cd = NamedTempFile::new().unwrap();
        let legacy = NamedTempFile::new().unwrap();
        let directives = DirectivePassthrough {
            cd_file: Some(cd.path().to_path_buf()),
            legacy_file: Some(legacy.path().to_path_buf()),
        };
        let script = format!(
            r#"printf CD > "${cd_var}" && printf LEGACY > "${legacy_var}""#,
            cd_var = DIRECTIVE_CD_FILE_ENV_VAR,
            legacy_var = DIRECTIVE_FILE_ENV_VAR,
        );
        let outcomes = run_one_with_directives("job", &script, None, &directives);
        assert!(outcomes[0].is_ok(), "child should exit 0");
        let cd_contents = std::fs::read_to_string(cd.path()).unwrap();
        let legacy_contents = std::fs::read_to_string(legacy.path()).unwrap();
        assert_eq!(cd_contents, "CD");
        assert_eq!(legacy_contents, "LEGACY");
    }

    /// Outcomes come back in input order, and a non-zero exit is reported as
    /// `ChildProcessExited` carrying the child's exit code.
    #[test]
    #[cfg(unix)]
    fn test_outcomes_keep_input_order_and_exit_code() {
        let wd = std::env::temp_dir();
        let directives = DirectivePassthrough::none();
        let specs = [
            ConcurrentCommand {
                label: "ok",
                expanded: "true",
                working_dir: &wd,
                context_json: "{}",
                log_label: None,
                directives: &directives,
            },
            ConcurrentCommand {
                label: "fails",
                expanded: "exit 3",
                working_dir: &wd,
                context_json: "{}",
                log_label: None,
                directives: &directives,
            },
        ];
        let outcomes = run_concurrent_commands(&specs).expect("spawn failed");
        assert_eq!(outcomes.len(), 2);
        assert!(outcomes[0].is_ok(), "`true` should exit 0");
        let err = outcomes[1].as_ref().expect_err("`exit 3` should fail");
        assert!(matches!(
            err.downcast_ref::<WorktrunkError>(),
            Some(WorktrunkError::ChildProcessExited { code: 3, .. })
        ));
    }

    #[test]
    fn test_empty_cmds_returns_empty() {
        let outcomes = run_concurrent_commands(&[]).expect("no spawn should happen");
        assert!(outcomes.is_empty());
    }
}