use std::ffi::{OsStr, OsString};
use std::sync::Mutex;
use crate::command::Command;
use crate::error::Result;
use crate::result::ProcessResult;
use crate::runner::ProcessRunner;
/// A canned outcome that a [`ScriptedRunner`] hands back for a matching command.
///
/// Build one with [`Reply::ok`], [`Reply::fail`], [`Reply::timeout`] or
/// [`Reply::lines`] (or `Reply::pending` with the `cancellation` feature), then
/// refine it with the `with_*` builders.
#[derive(Debug, Clone)]
pub struct Reply {
    /// Captured standard output, replayed verbatim.
    stdout: String,
    /// Captured standard error, replayed verbatim.
    stderr: String,
    /// Exit code; reported as "no code" when `timed_out` is set.
    code: i32,
    /// Whether the run is reported as having timed out.
    timed_out: bool,
    /// Whether the run never completes on its own.
    pending: bool,
    /// Optional per-line delay handed to the scripted process on the `start` path.
    line_delay: Option<std::time::Duration>,
}

impl Reply {
    /// Shared constructor: a finished, non-timed-out, non-pending run with the
    /// given streams and exit code, and no line delay.
    fn from_parts(stdout: String, stderr: String, code: i32) -> Self {
        Self {
            stdout,
            stderr,
            code,
            timed_out: false,
            pending: false,
            line_delay: None,
        }
    }

    /// A run that exited with code 0, printed `stdout`, and wrote nothing to stderr.
    pub fn ok(stdout: impl Into<String>) -> Self {
        Self::from_parts(stdout.into(), String::new(), 0)
    }

    /// A run that exited with `code`, wrote `stderr`, and printed nothing to stdout.
    pub fn fail(code: i32, stderr: impl Into<String>) -> Self {
        Self::from_parts(String::new(), stderr.into(), code)
    }

    /// A run that timed out: empty streams and no exit code.
    pub fn timeout() -> Self {
        Self {
            timed_out: true,
            ..Self::from_parts(String::new(), String::new(), 0)
        }
    }

    /// A run that never finishes by itself.
    ///
    /// Through `ScriptedRunner::output` it parks until the command's
    /// cancellation token fires (then fails with `Error::Cancelled`), or
    /// forever when no token is attached. Through `start` the scripted process
    /// is created without a lifetime.
    #[cfg(feature = "cancellation")]
    pub fn pending() -> Self {
        Self {
            pending: true,
            ..Self::from_parts(String::new(), String::new(), 0)
        }
    }

    /// A successful run whose stdout is `lines`, each terminated by `\n`.
    ///
    /// An empty iterator produces empty stdout (no lone newline).
    pub fn lines<I, S>(lines: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let mut text = lines
            .into_iter()
            .map(Into::into)
            .collect::<Vec<_>>()
            .join("\n");
        if !text.is_empty() {
            text.push('\n');
        }
        Self::ok(text)
    }

    /// Spaces stdout lines out by `delay` when the reply is consumed via `start`.
    ///
    /// The `output` path does not use this delay.
    pub fn with_line_delay(mut self, delay: std::time::Duration) -> Self {
        self.line_delay = Some(delay);
        self
    }

    /// Replaces the canned stdout, keeping every other setting.
    pub fn with_stdout(mut self, stdout: impl Into<String>) -> Self {
        self.stdout = stdout.into();
        self
    }

    /// Converts this reply into a scripted running process for `command`.
    ///
    /// Non-pending replies get a lifetime of `line_delay` × (number of stdout
    /// lines). Pending replies get none.
    fn into_running(self, command: &Command) -> crate::RunningProcess {
        let lifetime = if self.pending {
            None
        } else {
            let per_line = self.line_delay.unwrap_or_default();
            // Clamp instead of a bare `as u32`, which would silently wrap.
            let line_count = self
                .stdout
                .split_inclusive('\n')
                .count()
                .min(u32::MAX as usize) as u32;
            // `Duration * u32` panics on overflow (e.g. a huge "never" delay
            // spread over several lines); saturate at `Duration::MAX` instead.
            Some(per_line.saturating_mul(line_count))
        };
        let scripted = crate::running::ScriptedProc::new(
            self.stdout,
            self.stderr,
            // A timed-out run has no exit code.
            (!self.timed_out).then_some(self.code),
            self.timed_out,
            lifetime,
            self.line_delay,
        );
        crate::RunningProcess::from_scripted(command, scripted)
    }

    /// Converts this reply into a captured result for `program`.
    ///
    /// The exit code is `None` when the reply timed out. `timeout` is the
    /// command's configured timeout, forwarded unchanged.
    fn into_result(
        self,
        program: String,
        timeout: Option<std::time::Duration>,
    ) -> ProcessResult<String> {
        let code = (!self.timed_out).then_some(self.code);
        ProcessResult::new(
            program,
            self.stdout,
            self.stderr,
            code,
            self.timed_out,
            timeout,
        )
    }
}
/// Boxed matcher closure registered through `ScriptedRunner::when`.
type Predicate = Box<dyn Fn(&Command) -> bool + Send + Sync>;

/// The ways a scripted rule can select a command.
enum Rule {
    /// Selects commands whose argument list (program name excluded) starts
    /// with exactly these elements, compared element by element.
    Prefix(Vec<OsString>),
    /// Selects commands for which the closure returns `true`.
    Predicate(Predicate),
}

impl Rule {
    /// Reports whether this rule selects `command`.
    fn matches(&self, command: &Command) -> bool {
        match self {
            Self::Predicate(accepts) => accepts(command),
            Self::Prefix(expected) => command.arguments().starts_with(expected.as_slice()),
        }
    }
}
#[derive(Default)]
pub struct ScriptedRunner {
rules: Vec<(Rule, Reply)>,
fallback: Option<Reply>,
}
impl std::fmt::Debug for ScriptedRunner {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ScriptedRunner")
.field("rules", &self.rules.len())
.field("has_fallback", &self.fallback.is_some())
.finish_non_exhaustive()
}
}
impl ScriptedRunner {
pub fn new() -> Self {
Self::default()
}
pub fn on<I, S>(mut self, prefix: I, reply: Reply) -> Self
where
I: IntoIterator<Item = S>,
S: AsRef<OsStr>,
{
let prefix = prefix
.into_iter()
.map(|s| s.as_ref().to_os_string())
.collect();
self.rules.push((Rule::Prefix(prefix), reply));
self
}
pub fn when<F>(mut self, predicate: F, reply: Reply) -> Self
where
F: Fn(&Command) -> bool + Send + Sync + 'static,
{
self.rules
.push((Rule::Predicate(Box::new(predicate)), reply));
self
}
pub fn fallback(mut self, reply: Reply) -> Self {
self.fallback = Some(reply);
self
}
fn matched_reply(&self, command: &Command, program: &str) -> Result<&Reply> {
for (rule, reply) in &self.rules {
if rule.matches(command) {
return Ok(reply);
}
}
self.fallback
.as_ref()
.ok_or_else(|| crate::error::Error::Spawn {
program: program.to_owned(),
source: std::io::Error::new(
std::io::ErrorKind::NotFound,
"ScriptedRunner: no rule matched and no fallback set",
),
})
}
}
/// Feeds a reply's canned output to the command's line handlers.
///
/// Each handler receives one call per line, with line terminators stripped by
/// `str::lines`. All stdout lines are delivered first, then all stderr lines.
/// A stream with no registered handler is skipped.
fn replay_line_handlers(command: &Command, reply: &Reply) {
    if let Some(on_stdout) = command.stdout_handler() {
        reply.stdout.lines().for_each(|line| on_stdout(line));
    }
    if let Some(on_stderr) = command.stderr_handler() {
        reply.stderr.lines().for_each(|line| on_stderr(line));
    }
}
#[async_trait::async_trait]
impl ProcessRunner for ScriptedRunner {
    /// Resolves the matched reply as a captured result.
    ///
    /// Line handlers are replayed before returning. A pending reply instead
    /// parks until the command is cancelled, or forever.
    async fn output(&self, command: &Command) -> Result<ProcessResult<String>> {
        let program = command.program().to_string_lossy().into_owned();
        let timeout = command.configured_timeout();
        let reply = self.matched_reply(command, &program)?;
        if !reply.pending {
            replay_line_handlers(command, reply);
            return Ok(reply.clone().into_result(program, timeout));
        }
        park_until_cancelled(command, program).await
    }

    /// Starts a scripted process that streams the matched reply.
    async fn start(&self, command: &Command) -> Result<crate::RunningProcess> {
        let program = command.program().to_string_lossy().into_owned();
        self.matched_reply(command, &program)
            .map(|reply| reply.clone().into_running(command))
    }
}
/// Never resolves successfully: models a command that runs until stopped.
///
/// With the `cancellation` feature and a token attached to `command`, this
/// waits for the token and then fails with `Error::Cancelled` for `program`.
/// In every other case it stays pending forever.
async fn park_until_cancelled(command: &Command, program: String) -> Result<ProcessResult<String>> {
    #[cfg(feature = "cancellation")]
    {
        if let Some(token) = command.cancel_token() {
            token.cancelled().await;
            return Err(crate::error::Error::Cancelled { program });
        }
    }
    #[cfg(not(feature = "cancellation"))]
    {
        // Both arguments are unused without the feature; consume them to keep
        // the build warning-free.
        let _ = (command, program);
    }
    std::future::pending().await
}
/// A snapshot of one command as seen by [`RecordingRunner`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Invocation {
    /// The program that was invoked.
    pub program: OsString,
    /// The arguments, excluding the program itself.
    pub args: Vec<OsString>,
    /// The working directory, if the command set one.
    pub cwd: Option<OsString>,
    /// Environment overrides as `(key, value)` pairs. A `None` value
    /// presumably means "remove the variable"; confirm against
    /// `Command::env_overrides`.
    pub envs: Vec<(OsString, Option<OsString>)>,
    /// `true` only when a stdin source was attached and it is non-empty.
    pub has_stdin: bool,
}

impl Invocation {
    /// Copies the observable parts of `command` into an owned snapshot.
    pub(crate) fn from_command(command: &Command) -> Self {
        let cwd = command
            .working_dir()
            .map(|dir| dir.as_os_str().to_os_string());
        let has_stdin = matches!(command.stdin_source(), Some(stdin) if !stdin.is_empty());
        Self {
            program: command.program().to_os_string(),
            args: command.arguments().to_vec(),
            cwd,
            envs: command.env_overrides().to_vec(),
            has_stdin,
        }
    }

    /// Whether some argument equals `flag` exactly.
    ///
    /// This compares whole elements, so `--title=T` does not count as `--title`.
    pub fn has_flag(&self, flag: impl AsRef<OsStr>) -> bool {
        let wanted = flag.as_ref();
        self.args.iter().any(|arg| arg.as_os_str() == wanted)
    }

    /// The arguments as `String`s, with invalid Unicode replaced lossily.
    pub fn args_str(&self) -> Vec<String> {
        self.args
            .iter()
            .map(|arg| String::from(arg.to_string_lossy()))
            .collect()
    }
}
/// A [`ProcessRunner`] decorator that records every command passed to
/// `output` or `start`, then delegates to `inner`.
///
/// It defaults to wrapping a [`ScriptedRunner`], so a test can both script
/// replies and assert on exactly what was invoked.
pub struct RecordingRunner<R: ProcessRunner = ScriptedRunner> {
    inner: R,
    calls: Mutex<Vec<Invocation>>,
}

impl<R: ProcessRunner> std::fmt::Debug for RecordingRunner<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report the real count even if the lock is poisoned. The log is only
        // ever appended to, so its length stays meaningful, and `Debug` should
        // neither panic nor claim zero calls.
        let calls = self
            .calls
            .lock()
            .map_or_else(|poisoned| poisoned.into_inner().len(), |guard| guard.len());
        f.debug_struct("RecordingRunner")
            .field("calls", &calls)
            .finish_non_exhaustive()
    }
}

impl RecordingRunner<ScriptedRunner> {
    /// Records every call and answers all of them with `reply`, which is
    /// installed as the wrapped [`ScriptedRunner`]'s fallback.
    pub fn replying(reply: Reply) -> Self {
        Self::new(ScriptedRunner::new().fallback(reply))
    }
}

impl<R: ProcessRunner> RecordingRunner<R> {
    /// Wraps `inner` with an empty call log.
    pub fn new(inner: R) -> Self {
        Self {
            inner,
            calls: Mutex::new(Vec::new()),
        }
    }

    /// Returns a snapshot of every recorded invocation, oldest first.
    ///
    /// # Panics
    /// If the internal lock is poisoned.
    pub fn calls(&self) -> Vec<Invocation> {
        self.calls.lock().expect("recorder lock poisoned").clone()
    }

    /// Returns the single recorded invocation.
    ///
    /// # Panics
    /// Unless exactly one call was recorded.
    pub fn only_call(&self) -> Invocation {
        let calls = self.calls();
        assert_eq!(
            calls.len(),
            1,
            "expected exactly one call, got {}",
            calls.len()
        );
        calls.into_iter().next().expect("length checked above")
    }

    /// Appends a snapshot of `command` to the call log.
    fn record_call(&self, command: &Command) {
        // Build the snapshot before locking so the guard is held only for the
        // push, never while reading the command.
        let invocation = Invocation::from_command(command);
        self.calls
            .lock()
            .expect("recorder lock poisoned")
            .push(invocation);
    }
}

#[async_trait::async_trait]
impl<R: ProcessRunner> ProcessRunner for RecordingRunner<R> {
    /// Records `command`, then delegates to the inner runner's `output`.
    async fn output(&self, command: &Command) -> Result<ProcessResult<String>> {
        self.record_call(command);
        self.inner.output(command).await
    }

    /// Records `command` at start time, then delegates to the inner runner's `start`.
    async fn start(&self, command: &Command) -> Result<crate::RunningProcess> {
        self.record_call(command);
        self.inner.start(command).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runner::ProcessRunnerExt;

    // `start` on a scripted reply feeds canned stdout through the same line
    // pumps a real child uses, then finishes with the scripted status.
    #[tokio::test]
    async fn scripted_start_streams_canned_lines_through_real_pumps() {
        use tokio_stream::StreamExt;
        let runner = ScriptedRunner::new().on(["log"], Reply::lines(["first", "second", "third"]));
        let cmd = Command::new("git").arg("log");
        let mut run = runner.start(&cmd).await.expect("scripted start");
        assert_eq!(run.pid(), None, "a scripted child has no OS identity");
        let mut lines = run.stdout_lines();
        let mut seen = Vec::new();
        while let Some(line) = lines.next().await {
            seen.push(line);
        }
        assert_eq!(seen, ["first", "second", "third"]);
        let (code, stderr) = run.finish_streamed().await.expect("finish");
        assert_eq!(code, Some(0));
        assert_eq!(stderr, "");
    }

    // A readiness probe sees the canned stdout banner. The scripted failure
    // then surfaces from `finish_streamed` without the trailing newline.
    #[tokio::test]
    async fn scripted_start_supports_probes_and_failing_finish() {
        let runner = ScriptedRunner::new().fallback(
            Reply::fail(7, "boom: detail\n").with_stdout("starting up\nready to serve\n"),
        );
        let cmd = Command::new("server");
        let mut run = runner.start(&cmd).await.expect("scripted start");
        run.wait_for_line(|l| l.contains("ready"), std::time::Duration::from_secs(5))
            .await
            .expect("the canned banner satisfies the probe");
        let (code, stderr) = run.finish_streamed().await.expect("finish");
        assert_eq!(code, Some(7));
        assert_eq!(stderr, "boom: detail");
    }

    // A started scripted run can be consumed whole via `output_string`.
    #[tokio::test]
    async fn scripted_start_consumed_by_output_string() {
        let runner = ScriptedRunner::new().fallback(Reply::lines(["a", "b"]));
        let run = runner.start(&Command::new("x")).await.expect("start");
        let result = run.output_string().await.expect("consume");
        assert!(result.is_success());
        assert_eq!(result.stdout(), "a\nb");
    }

    // With the clock paused, no line arrives before its scripted delay has
    // elapsed; afterwards lines arrive in order and the stream ends.
    #[tokio::test(start_paused = true)]
    async fn scripted_line_delay_delivers_incrementally() {
        use tokio_stream::StreamExt;
        let runner = ScriptedRunner::new().fallback(
            Reply::lines(["tick", "tock"]).with_line_delay(std::time::Duration::from_secs(10)),
        );
        let mut run = runner
            .start(&Command::new("clock"))
            .await
            .expect("scripted start");
        let mut lines = run.stdout_lines();
        assert!(
            tokio::time::timeout(std::time::Duration::from_secs(5), lines.next())
                .await
                .is_err(),
            "no line may arrive before its scripted delay"
        );
        assert_eq!(lines.next().await.as_deref(), Some("tick"));
        assert_eq!(lines.next().await.as_deref(), Some("tock"));
        assert_eq!(lines.next().await, None);
    }

    // A timeout reply consumed through `start` is captured as a timed-out,
    // unsuccessful result rather than an error.
    #[tokio::test]
    async fn scripted_timeout_reply_surfaces_through_start() {
        let runner = ScriptedRunner::new().fallback(Reply::timeout());
        let cmd = Command::new("slow").timeout(std::time::Duration::from_secs(9));
        let run = runner.start(&cmd).await.expect("start");
        let result = run.output_string().await.expect("a timeout is captured");
        assert!(result.timed_out());
        assert!(!result.is_success());
    }

    // `output` replays canned stdout lines through the stdout handler. The
    // stderr handler sees nothing because the reply has empty stderr.
    #[tokio::test]
    async fn output_replays_canned_lines_through_handlers() {
        use std::sync::{Arc, Mutex};
        let seen = Arc::new(Mutex::new(Vec::new()));
        let errs = Arc::new(Mutex::new(Vec::new()));
        let runner = ScriptedRunner::new().on(["fetch"], Reply::ok("a\nb\n"));
        let cmd = Command::new("git")
            .arg("fetch")
            .on_stdout_line({
                let seen = seen.clone();
                move |l| seen.lock().unwrap().push(l.to_owned())
            })
            .on_stderr_line({
                let errs = errs.clone();
                move |l| errs.lock().unwrap().push(l.to_owned())
            });
        let result = runner.output(&cmd).await.expect("scripted run");
        assert!(result.is_success());
        assert_eq!(*seen.lock().unwrap(), ["a", "b"]);
        assert!(errs.lock().unwrap().is_empty());
    }

    // Every handler invocation on the `start` path completes before the
    // consuming verb resolves.
    #[tokio::test]
    async fn handler_calls_happen_before_the_consuming_verb_resolves() {
        use std::sync::{Arc, Mutex};
        let seen = Arc::new(Mutex::new(0usize));
        let lines: Vec<String> = (1..=100).map(|n| format!("line {n}")).collect();
        let runner = ScriptedRunner::new().fallback(Reply::lines(lines));
        let cmd = Command::new("x").on_stdout_line({
            let seen = seen.clone();
            move |_| *seen.lock().unwrap() += 1
        });
        let run = runner.start(&cmd).await.expect("scripted start");
        let result = run.output_string().await.expect("consume");
        assert!(result.is_success());
        assert_eq!(
            *seen.lock().unwrap(),
            100,
            "all handler calls happen-before the verb resolves"
        );
    }

    // `start` calls are recorded too, at start time, so the log is complete
    // even if the running handle is dropped unconsumed.
    #[tokio::test]
    async fn recording_runner_records_start_invocations() {
        let rec = RecordingRunner::new(ScriptedRunner::new().fallback(Reply::lines(["x"])));
        let run = rec
            .start(&Command::new("gh").args(["run", "watch"]))
            .await
            .expect("recorded start");
        drop(run);
        assert_eq!(rec.only_call().args_str(), ["run", "watch"]);
    }

    // A pending reply on the `start` path stays unresolved until the
    // command's token fires, then fails with `Error::Cancelled`.
    #[cfg(feature = "cancellation")]
    #[tokio::test(start_paused = true)]
    async fn scripted_pending_start_is_cancellable() {
        let token = crate::CancellationToken::new();
        let runner = ScriptedRunner::new().fallback(Reply::pending());
        let cmd = Command::new("watch").cancel_on(token.clone());
        let run = runner.start(&cmd).await.expect("start");
        let consume = run.output_string();
        tokio::pin!(consume);
        assert!(
            tokio::time::timeout(std::time::Duration::from_secs(3600), &mut consume)
                .await
                .is_err(),
            "a pending scripted run must not resolve before cancellation"
        );
        token.cancel();
        let err = tokio::time::timeout(std::time::Duration::from_secs(3600), consume)
            .await
            .expect("the token resolves the run")
            .expect_err("cancellation is always an error");
        assert!(
            matches!(err, crate::error::Error::Cancelled { .. }),
            "got {err:?}"
        );
    }

    // A prefix rule selects on arguments only, not the program name.
    #[tokio::test]
    async fn prefix_rule_matches_and_replies() {
        let runner = ScriptedRunner::new().on(["status"], Reply::ok("clean"));
        let out = runner
            .output(&Command::new("git").arg("status"))
            .await
            .unwrap();
        assert_eq!(out.stdout(), "clean");
        assert!(out.is_success());
    }

    // A predicate rule answers when it returns true; otherwise the fallback
    // is used.
    #[tokio::test]
    async fn predicate_rule_and_fallback() {
        let runner = ScriptedRunner::new()
            .when(
                |c| c.arguments().iter().any(|a| a == "--version"),
                Reply::ok("v1"),
            )
            .fallback(Reply::fail(1, "unknown"));
        assert_eq!(
            runner
                .output(&Command::new("tool").arg("--version"))
                .await
                .unwrap()
                .stdout(),
            "v1"
        );
        let miss = runner.output(&Command::new("tool").arg("x")).await.unwrap();
        assert_eq!(miss.code(), Some(1));
        assert!(!miss.is_success());
    }

    // With no match and no fallback, the runner reports a NotFound spawn error
    // naming the program, just as a missing binary would.
    #[tokio::test]
    async fn no_match_without_fallback_is_a_not_found_spawn_error() {
        let runner = ScriptedRunner::new().on(["status"], Reply::ok("clean"));
        let err = runner
            .output(&Command::new("git").arg("log"))
            .await
            .expect_err("an unmatched command with no fallback must error");
        match err {
            crate::error::Error::Spawn { program, source } => {
                assert_eq!(program, "git");
                assert_eq!(source.kind(), std::io::ErrorKind::NotFound);
            }
            other => panic!("expected Error::Spawn, got {other:?}"),
        }
    }

    // Prefixes compare whole argument elements, never substrings.
    #[tokio::test]
    async fn prefix_matches_whole_elements_not_substrings() {
        let runner = ScriptedRunner::new().on(["foo"], Reply::ok("hit"));
        assert!(
            runner
                .output(&Command::new("tool").args(["foo", "bar"]))
                .await
                .is_ok()
        );
        assert!(
            runner
                .output(&Command::new("tool").arg("foobar"))
                .await
                .is_err(),
            "substring of an element is not a prefix match"
        );
    }

    // A timeout reply is a captured result for `output`, but an
    // `Error::Timeout` for the checking verbs. That error carries the
    // command's configured timeout.
    #[tokio::test]
    async fn timeout_reply_surfaces_as_timeout_error() {
        use crate::error::Error;
        let runner = ScriptedRunner::new().fallback(Reply::timeout());
        let out = runner.output(&Command::new("git")).await.unwrap();
        assert!(out.timed_out());
        assert!(matches!(
            runner.run(&Command::new("git")).await.unwrap_err(),
            Error::Timeout { .. }
        ));
        assert!(matches!(
            runner.exit_code(&Command::new("git")).await.unwrap_err(),
            Error::Timeout { .. }
        ));
        let cmd = Command::new("git").timeout(std::time::Duration::from_secs(7));
        match runner.run(&cmd).await.unwrap_err() {
            Error::Timeout { timeout, .. } => {
                assert_eq!(timeout, std::time::Duration::from_secs(7))
            }
            other => panic!("expected Timeout, got {other:?}"),
        }
    }

    // On the `output` path, a pending reply parks until the token fires, then
    // fails with `Error::Cancelled` naming the program.
    #[cfg(feature = "cancellation")]
    #[tokio::test(start_paused = true)]
    async fn pending_parks_until_the_token_fires_then_cancels() {
        use crate::error::Error;
        let token = crate::CancellationToken::new();
        let runner = ScriptedRunner::new().on(["run", "watch"], Reply::pending());
        let cmd = Command::new("gh")
            .args(["run", "watch"])
            .cancel_on(token.clone());
        let call = runner.output(&cmd);
        tokio::pin!(call);
        assert!(
            tokio::time::timeout(std::time::Duration::from_secs(3600), &mut call)
                .await
                .is_err(),
            "a pending reply must not resolve before cancellation"
        );
        token.cancel();
        match call.await {
            Err(Error::Cancelled { program }) => assert_eq!(program, "gh"),
            other => panic!("expected Error::Cancelled, got {other:?}"),
        }
    }

    // Without a token, a pending reply does not resolve within an hour of
    // paused time. This is the test's proxy for "parks forever".
    #[cfg(feature = "cancellation")]
    #[tokio::test(start_paused = true)]
    async fn pending_without_a_token_parks_forever() {
        let runner = ScriptedRunner::new().fallback(Reply::pending());
        let cmd = Command::new("gh");
        let call = runner.output(&cmd);
        tokio::pin!(call);
        assert!(
            tokio::time::timeout(std::time::Duration::from_secs(3600), &mut call)
                .await
                .is_err()
        );
    }

    // `probe` maps exit 0 to true and exit 1 to false. Any other exit code or
    // a timeout is an error.
    #[tokio::test]
    async fn probe_reads_exit_code_as_bool() {
        use crate::error::Error;
        let runner = ScriptedRunner::new()
            .on(["yes"], Reply::ok(""))
            .on(["no"], Reply::fail(1, ""))
            .on(["boom"], Reply::fail(2, "bad"))
            .fallback(Reply::timeout());
        assert!(runner.probe(&Command::new("t").arg("yes")).await.unwrap());
        assert!(!runner.probe(&Command::new("t").arg("no")).await.unwrap());
        assert!(matches!(
            runner
                .probe(&Command::new("t").arg("boom"))
                .await
                .unwrap_err(),
            Error::Exit { code: 2, .. }
        ));
        assert!(matches!(
            runner
                .probe(&Command::new("t").arg("other"))
                .await
                .unwrap_err(),
            Error::Timeout { .. }
        ));
    }

    // `run` strips trailing whitespace (leading whitespace is kept) and
    // requires success.
    #[tokio::test]
    async fn run_ext_trims_and_checks_success() {
        let runner = ScriptedRunner::new().fallback(Reply::ok(" hello \n"));
        let trimmed = runner.run(&Command::new("echo")).await.unwrap();
        assert_eq!(trimmed, " hello");
    }

    // The recorder captures program, cwd and arguments. It can also show that
    // a flag was NOT passed.
    #[tokio::test]
    async fn recording_captures_args_cwd_and_absence() {
        let recorder = RecordingRunner::replying(Reply::ok("ok"));
        recorder
            .output(
                &Command::new("gh")
                    .current_dir("/repo")
                    .args(["pr", "create", "--title", "T"]),
            )
            .await
            .unwrap();
        let call = recorder.only_call();
        assert_eq!(call.program, OsString::from("gh"));
        assert_eq!(call.cwd, Some(OsString::from("/repo")));
        assert!(call.has_flag("--title"));
        assert!(!call.has_flag("--base"), "no --base flag was passed");
    }
}