use super::*;
use async_trait::async_trait;
use std::path::PathBuf;
use std::process::Output;
use std::sync::Mutex;
/// One call captured by a recording test double.
///
/// For shell-based calls (`run_shell_interactive`, `run_streaming`) the runners in
/// this module store `shell_program()` as `program` and `[shell_flag(), command]`
/// as `args`. `PartialEq`/`Eq` are derived so tests can compare whole invocations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Invocation {
    /// Program that was run (the shell itself for shell-based calls).
    pub program: String,
    /// Arguments exactly as passed, converted to owned strings.
    pub args: Vec<String>,
    /// Working directory passed with the call.
    pub cwd: PathBuf,
    /// `true` when the call went through the shell.
    pub is_shell: bool,
    /// `true` for `run_interactive` / `run_shell_interactive`.
    pub is_interactive: bool,
    /// `true` for `run_streaming`.
    pub is_streaming: bool,
}
/// Test double for `CommandRunner` that does not spawn anything: every call is
/// appended to an in-memory log and answered with the same canned exit code,
/// stdout and stderr.
#[derive(Debug)]
pub struct RecordingCommandRunner {
    /// Every call made through the runner, in call order. The mutex lets the
    /// log be appended to through `&self`.
    invocations: Mutex<Vec<Invocation>>,
    /// Exit code reported for every call.
    exit_code: i32,
    /// Bytes returned as stdout by the `Output`-returning methods.
    stdout: Vec<u8>,
    /// Bytes returned as stderr by the `Output`-returning methods.
    stderr: Vec<u8>,
}
impl RecordingCommandRunner {
    /// Creates a runner that answers every call with `exit_code` and empty
    /// stdout/stderr, with an empty invocation log.
    pub fn new(exit_code: i32) -> Self {
        Self {
            invocations: Mutex::default(),
            exit_code,
            stdout: Vec::new(),
            stderr: Vec::new(),
        }
    }

    /// Builder: sets the bytes every call returns as stdout.
    pub fn with_stdout(self, stdout: Vec<u8>) -> Self {
        Self { stdout, ..self }
    }

    /// Builder: sets the bytes every call returns as stderr.
    pub fn with_stderr(self, stderr: Vec<u8>) -> Self {
        Self { stderr, ..self }
    }

    /// Returns a copy of every invocation recorded so far, oldest first.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex was poisoned.
    pub fn invocations(&self) -> Vec<Invocation> {
        let recorded = self.invocations.lock().expect("mutex poisoned");
        recorded.to_vec()
    }

    /// Builds the canned `Output`: a status derived from `exit_code` plus copies
    /// of the configured stdout/stderr bytes.
    fn make_output(&self) -> Output {
        // Unix wait statuses keep a normal exit code in bits 8..16, hence the
        // shift. NOTE(review): codes outside 0..=255 do not round-trip through
        // `ExitStatus::code()` on Unix (e.g. 256 reads back as 0).
        #[cfg(unix)]
        fn status_for(code: i32) -> std::process::ExitStatus {
            use std::os::unix::process::ExitStatusExt;
            std::process::ExitStatus::from_raw(code << 8)
        }
        // Windows stores the exit code directly; `as u32` reinterprets the bits.
        #[cfg(windows)]
        fn status_for(code: i32) -> std::process::ExitStatus {
            use std::os::windows::process::ExitStatusExt;
            std::process::ExitStatus::from_raw(code as u32)
        }

        Output {
            status: status_for(self.exit_code),
            stdout: self.stdout.clone(),
            stderr: self.stderr.clone(),
        }
    }

    /// Appends one `Invocation` to the log. The three flags are stored verbatim;
    /// each `CommandRunner` method chooses which of them apply to its call.
    fn record(
        &self,
        program: &str,
        args: Vec<String>,
        cwd: &Path,
        is_shell: bool,
        is_interactive: bool,
        is_streaming: bool,
    ) {
        let entry = Invocation {
            program: program.to_owned(),
            args,
            cwd: cwd.to_owned(),
            is_shell,
            is_interactive,
            is_streaming,
        };
        self.invocations
            .lock()
            .expect("mutex poisoned")
            .push(entry);
    }
}
#[async_trait]
impl CommandRunner for RecordingCommandRunner {
    /// Records a plain (non-shell, non-interactive, non-streaming) call and
    /// returns the canned output.
    async fn run(&self, program: &str, args: &[&str], cwd: &Path) -> anyhow::Result<Output> {
        let owned_args: Vec<String> = args.iter().map(ToString::to_string).collect();
        self.record(program, owned_args, cwd, false, false, false);
        Ok(self.make_output())
    }

    /// Same as `run`: this fake does not distinguish `run_mut` from `run`.
    async fn run_mut(&self, program: &str, args: &[&str], cwd: &Path) -> anyhow::Result<Output> {
        self.run(program, args, cwd).await
    }

    /// Records an interactive, non-shell call and returns only the canned status.
    async fn run_interactive(
        &self,
        program: &str,
        args: &[&str],
        cwd: &Path,
    ) -> anyhow::Result<std::process::ExitStatus> {
        let owned_args = args.iter().map(ToString::to_string).collect();
        self.record(program, owned_args, cwd, false, true, false);
        let Output { status, .. } = self.make_output();
        Ok(status)
    }

    /// Records `command` as an interactive shell call (`shell_program()` with
    /// `[shell_flag(), command]`) and returns the canned status.
    async fn run_shell_interactive(
        &self,
        command: &str,
        cwd: &Path,
    ) -> anyhow::Result<std::process::ExitStatus> {
        let program = shell_program();
        let shell_args = vec![shell_flag().to_string(), command.to_string()];
        self.record(program, shell_args, cwd, true, true, false);
        let Output { status, .. } = self.make_output();
        Ok(status)
    }

    /// Records `command` as a streaming shell call (`shell_program()` with
    /// `[shell_flag(), command]`) and returns the canned status.
    async fn run_streaming(
        &self,
        command: &str,
        cwd: &Path,
    ) -> anyhow::Result<std::process::ExitStatus> {
        let program = shell_program();
        let shell_args = vec![shell_flag().to_string(), command.to_string()];
        self.record(program, shell_args, cwd, true, false, true);
        let Output { status, .. } = self.make_output();
        Ok(status)
    }
}
/// One canned response for [`DispatchingCommandRunner`].
///
/// A rule matches a call when `program` equals the invoked program exactly and,
/// if `args` is `Some(prefix)`, the call's arguments start with `prefix`.
/// `Default` yields an empty program, no argument filter, exit code 0 and empty
/// output, so only the interesting fields need to be spelled out:
/// `DispatchRule { program: "npm".into(), exit_code: 1, ..Default::default() }`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DispatchRule {
    /// Program name that must match exactly.
    pub program: String,
    /// Optional argument prefix. `None` (or an empty prefix) matches any arguments.
    pub args: Option<Vec<String>>,
    /// Exit code reported when this rule matches.
    pub exit_code: i32,
    /// Bytes returned as stdout when this rule matches.
    pub stdout: Vec<u8>,
    /// Bytes returned as stderr when this rule matches.
    pub stderr: Vec<u8>,
}
/// Test double for `CommandRunner` that chooses its response per call.
///
/// The first [`DispatchRule`] whose program (and optional argument prefix)
/// matches supplies the exit code and output. Calls no rule matches get
/// `default_exit_code` with empty stdout/stderr. Every call is also recorded
/// as an [`Invocation`].
#[derive(Debug)]
pub struct DispatchingCommandRunner {
    /// Rules in registration order; the first match wins.
    rules: Vec<DispatchRule>,
    /// Exit code used when no rule matches.
    default_exit_code: i32,
    /// Every call made through the runner, in call order.
    invocations: Mutex<Vec<Invocation>>,
}
impl DispatchingCommandRunner {
    /// Creates a runner with no rules. Until rules are added, every call is
    /// answered with `default_exit_code` and empty stdout/stderr.
    pub fn new(default_exit_code: i32) -> Self {
        Self {
            rules: Vec::new(),
            default_exit_code,
            invocations: Mutex::default(),
        }
    }

    /// Appends a fully specified rule.
    ///
    /// Rules are tried in the order they were added and the first match wins,
    /// so register specific rules before catch-all ones.
    pub fn on_rule(mut self, rule: DispatchRule) -> Self {
        self.rules.push(rule);
        self
    }

    /// Shared implementation of the `on*` shorthands: packs the parts into a
    /// [`DispatchRule`] and appends it via [`Self::on_rule`].
    fn push_rule(
        self,
        program: String,
        args: Option<Vec<String>>,
        exit_code: i32,
        stdout: Vec<u8>,
        stderr: Vec<u8>,
    ) -> Self {
        self.on_rule(DispatchRule {
            program,
            args,
            exit_code,
            stdout,
            stderr,
        })
    }

    /// Matches every call to `program`, whatever its arguments; empty output.
    pub fn on(self, program: impl Into<String>, exit_code: i32) -> Self {
        self.push_rule(program.into(), None, exit_code, Vec::new(), Vec::new())
    }

    /// Matches calls to `program` whose arguments start with `args`; empty output.
    pub fn on_with_args(
        self,
        program: impl Into<String>,
        args: Vec<String>,
        exit_code: i32,
    ) -> Self {
        self.push_rule(program.into(), Some(args), exit_code, Vec::new(), Vec::new())
    }

    /// Matches every call to `program` and answers with `stdout`.
    pub fn on_stdout(self, program: impl Into<String>, exit_code: i32, stdout: Vec<u8>) -> Self {
        self.push_rule(program.into(), None, exit_code, stdout, Vec::new())
    }

    /// Matches calls to `program` whose arguments start with `args` and
    /// answers with `stdout`.
    pub fn on_with_args_stdout(
        self,
        program: impl Into<String>,
        args: Vec<String>,
        exit_code: i32,
        stdout: Vec<u8>,
    ) -> Self {
        self.push_rule(program.into(), Some(args), exit_code, stdout, Vec::new())
    }

    /// Matches every call to `program` and answers with `stderr`.
    pub fn on_stderr(self, program: impl Into<String>, exit_code: i32, stderr: Vec<u8>) -> Self {
        self.push_rule(program.into(), None, exit_code, Vec::new(), stderr)
    }

    /// Matches calls to `program` whose arguments start with `args` and
    /// answers with `stderr`.
    pub fn on_with_args_stderr(
        self,
        program: impl Into<String>,
        args: Vec<String>,
        exit_code: i32,
        stderr: Vec<u8>,
    ) -> Self {
        self.push_rule(program.into(), Some(args), exit_code, Vec::new(), stderr)
    }

    /// Returns a copy of every invocation recorded so far, oldest first.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex was poisoned.
    pub fn invocations(&self) -> Vec<Invocation> {
        let recorded = self.invocations.lock().expect("mutex poisoned");
        recorded.to_vec()
    }

    /// Returns `(exit_code, stdout, stderr)` from the first rule matching
    /// `program`/`args`, or `(default_exit_code, [], [])` when none matches.
    fn find_match(&self, program: &str, args: &[&str]) -> (i32, Vec<u8>, Vec<u8>) {
        let matches = |rule: &&DispatchRule| {
            if rule.program != program {
                return false;
            }
            match &rule.args {
                None => true,
                // Prefix match: every expected arg must line up with the call's
                // leading args, and the call must have at least that many.
                Some(prefix) => {
                    prefix.len() <= args.len()
                        && prefix.iter().zip(args).all(|(want, got)| want == got)
                }
            }
        };
        match self.rules.iter().find(matches) {
            Some(rule) => (rule.exit_code, rule.stdout.clone(), rule.stderr.clone()),
            None => (self.default_exit_code, Vec::new(), Vec::new()),
        }
    }

    /// Builds the `Output` selected by [`Self::find_match`].
    fn make_output_for(&self, program: &str, args: &[&str]) -> Output {
        // Unix wait statuses keep a normal exit code in bits 8..16, hence the
        // shift. NOTE(review): codes outside 0..=255 do not round-trip through
        // `ExitStatus::code()` on Unix (e.g. 256 reads back as 0).
        #[cfg(unix)]
        fn status_for(code: i32) -> std::process::ExitStatus {
            use std::os::unix::process::ExitStatusExt;
            std::process::ExitStatus::from_raw(code << 8)
        }
        // Windows stores the exit code directly; `as u32` reinterprets the bits.
        #[cfg(windows)]
        fn status_for(code: i32) -> std::process::ExitStatus {
            use std::os::windows::process::ExitStatusExt;
            std::process::ExitStatus::from_raw(code as u32)
        }

        let (code, stdout, stderr) = self.find_match(program, args);
        Output {
            status: status_for(code),
            stdout,
            stderr,
        }
    }

    /// Appends one `Invocation` to the log. The three flags are stored verbatim;
    /// each `CommandRunner` method chooses which of them apply to its call.
    fn record(
        &self,
        program: &str,
        args: Vec<String>,
        cwd: &Path,
        is_shell: bool,
        is_interactive: bool,
        is_streaming: bool,
    ) {
        let entry = Invocation {
            program: program.to_owned(),
            args,
            cwd: cwd.to_owned(),
            is_shell,
            is_interactive,
            is_streaming,
        };
        self.invocations
            .lock()
            .expect("mutex poisoned")
            .push(entry);
    }
}
#[async_trait]
impl CommandRunner for DispatchingCommandRunner {
    /// Records a plain (non-shell, non-interactive, non-streaming) call and
    /// returns the output chosen by the rules.
    async fn run(&self, program: &str, args: &[&str], cwd: &Path) -> anyhow::Result<Output> {
        let owned_args: Vec<String> = args.iter().map(ToString::to_string).collect();
        self.record(program, owned_args, cwd, false, false, false);
        let output = self.make_output_for(program, args);
        Ok(output)
    }

    /// Same as `run`: this fake does not distinguish `run_mut` from `run`.
    async fn run_mut(&self, program: &str, args: &[&str], cwd: &Path) -> anyhow::Result<Output> {
        self.run(program, args, cwd).await
    }

    /// Records an interactive, non-shell call and returns only the status
    /// chosen by the rules.
    async fn run_interactive(
        &self,
        program: &str,
        args: &[&str],
        cwd: &Path,
    ) -> anyhow::Result<std::process::ExitStatus> {
        let owned_args = args.iter().map(ToString::to_string).collect();
        self.record(program, owned_args, cwd, false, true, false);
        let Output { status, .. } = self.make_output_for(program, args);
        Ok(status)
    }

    /// Records `command` as an interactive shell call and dispatches on
    /// `shell_program()` with `[shell_flag(), command]` as the arguments.
    async fn run_shell_interactive(
        &self,
        command: &str,
        cwd: &Path,
    ) -> anyhow::Result<std::process::ExitStatus> {
        let program = shell_program();
        let recorded_args = vec![shell_flag().to_string(), command.to_string()];
        self.record(program, recorded_args, cwd, true, true, false);
        let Output { status, .. } = self.make_output_for(shell_program(), &[shell_flag(), command]);
        Ok(status)
    }

    /// Records `command` as a streaming shell call and dispatches on
    /// `shell_program()` with `[shell_flag(), command]` as the arguments.
    async fn run_streaming(
        &self,
        command: &str,
        cwd: &Path,
    ) -> anyhow::Result<std::process::ExitStatus> {
        let program = shell_program();
        let recorded_args = vec![shell_flag().to_string(), command.to_string()];
        self.record(program, recorded_args, cwd, true, false, true);
        let Output { status, .. } = self.make_output_for(shell_program(), &[shell_flag(), command]);
        Ok(status)
    }
}
#[cfg(test)]
mod dispatching_tests {
    use std::path::Path;

    use super::*;

    #[tokio::test]
    async fn dispatching_runner_returns_default_when_no_rule_matches() {
        let runner = DispatchingCommandRunner::new(1);
        let cwd = Path::new("/tmp");
        let output = runner.run("unknown", &[], cwd).await.unwrap();
        assert!(!output.status.success());
    }

    #[tokio::test]
    async fn dispatching_runner_matches_program_name() {
        let runner = DispatchingCommandRunner::new(1).on("git", 0);
        let cwd = Path::new("/tmp");
        let output = runner.run("git", &["status"], cwd).await.unwrap();
        assert!(output.status.success());
    }

    #[tokio::test]
    async fn dispatching_runner_first_matching_rule_wins() {
        let runner = DispatchingCommandRunner::new(1).on("git", 0).on("git", 2);
        let cwd = Path::new("/tmp");
        let output = runner.run("git", &[], cwd).await.unwrap();
        assert!(output.status.success());
    }

    #[tokio::test]
    async fn dispatching_runner_matches_args_prefix() {
        let runner =
            DispatchingCommandRunner::new(0).on_with_args("git", vec!["push".to_string()], 42);
        let cwd = Path::new("/tmp");
        let output = runner
            .run("git", &["push", "origin", "HEAD"], cwd)
            .await
            .unwrap();
        // `ExitStatus::code` is portable, so this check runs on Unix and Windows alike.
        assert_eq!(output.status.code(), Some(42));
    }

    #[tokio::test]
    async fn dispatching_runner_falls_through_when_args_prefix_does_not_match() {
        let runner =
            DispatchingCommandRunner::new(0).on_with_args("git", vec!["push".to_string()], 42);
        let cwd = Path::new("/tmp");
        let output = runner.run("git", &["fetch"], cwd).await.unwrap();
        assert!(output.status.success());
    }

    #[tokio::test]
    async fn dispatching_runner_falls_through_when_prefix_is_longer_than_args() {
        let runner = DispatchingCommandRunner::new(0).on_with_args(
            "git",
            vec!["push".to_string(), "origin".to_string()],
            42,
        );
        let cwd = Path::new("/tmp");
        let output = runner.run("git", &["push"], cwd).await.unwrap();
        assert!(output.status.success());
    }

    #[tokio::test]
    async fn dispatching_runner_returns_configured_stdout() {
        let runner = DispatchingCommandRunner::new(0).on_stdout("npm", 0, b"test-user\n".to_vec());
        let cwd = Path::new("/tmp");
        let output = runner.run("npm", &["whoami"], cwd).await.unwrap();
        assert_eq!(output.stdout, b"test-user\n");
    }

    #[tokio::test]
    async fn dispatching_runner_returns_configured_stdout_for_args_prefix() {
        let runner = DispatchingCommandRunner::new(1).on_with_args_stdout(
            "npm",
            vec!["view".to_string()],
            0,
            b"1.2.3\n".to_vec(),
        );
        let cwd = Path::new("/tmp");
        let output = runner
            .run("npm", &["view", "pkg", "version"], cwd)
            .await
            .unwrap();
        assert!(output.status.success());
        assert_eq!(output.stdout, b"1.2.3\n");
        assert!(output.stderr.is_empty());
    }

    #[tokio::test]
    async fn dispatching_runner_returns_configured_stderr() {
        let runner = DispatchingCommandRunner::new(0).on_stderr(
            "cargo",
            1,
            b"error: not logged in\n".to_vec(),
        );
        let cwd = Path::new("/tmp");
        let output = runner.run("cargo", &[], cwd).await.unwrap();
        assert_eq!(output.stderr, b"error: not logged in\n");
    }

    #[tokio::test]
    async fn dispatching_runner_returns_configured_stderr_for_args_prefix() {
        let runner = DispatchingCommandRunner::new(0).on_with_args_stderr(
            "cargo",
            vec!["publish".to_string()],
            101,
            b"error: denied\n".to_vec(),
        );
        let cwd = Path::new("/tmp");
        let output = runner
            .run("cargo", &["publish", "--dry-run"], cwd)
            .await
            .unwrap();
        assert_eq!(output.status.code(), Some(101));
        assert_eq!(output.stderr, b"error: denied\n");
        assert!(output.stdout.is_empty());
    }

    #[tokio::test]
    async fn dispatching_runner_on_rule_accepts_full_dispatch_rule() {
        let rule = DispatchRule {
            program: "npm".to_string(),
            args: Some(vec!["whoami".to_string()]),
            exit_code: 0,
            stdout: b"alice\n".to_vec(),
            stderr: Vec::new(),
        };
        let runner = DispatchingCommandRunner::new(1).on_rule(rule);
        let cwd = Path::new("/tmp");
        let output = runner.run("npm", &["whoami"], cwd).await.unwrap();
        assert_eq!(output.stdout, b"alice\n");
        assert!(output.status.success());
    }

    #[tokio::test]
    async fn dispatching_runner_records_invocations() {
        let runner = DispatchingCommandRunner::new(0).on("git", 0);
        let cwd = Path::new("/tmp");
        let _ = runner.run("git", &["status"], cwd).await.unwrap();
        let _ = runner
            .run_mut("git", &["commit", "-m", "msg"], cwd)
            .await
            .unwrap();
        let invocations = runner.invocations();
        assert_eq!(invocations.len(), 2);
        assert_eq!(invocations[0].args, vec!["status"]);
        assert_eq!(invocations[1].args, vec!["commit", "-m", "msg"]);
    }

    #[tokio::test]
    async fn dispatching_runner_records_streaming_invocations() {
        let runner = DispatchingCommandRunner::new(0);
        let cwd = Path::new("/tmp");
        let _ = runner.run_streaming("npm install", cwd).await.unwrap();
        let invocations = runner.invocations();
        assert_eq!(invocations.len(), 1);
        assert!(invocations[0].is_shell);
        assert!(invocations[0].is_streaming);
        assert!(!invocations[0].is_interactive);
        assert_eq!(invocations[0].program, shell_program());
    }

    #[tokio::test]
    async fn dispatching_runner_dispatches_streaming_on_shell_program() {
        let runner = DispatchingCommandRunner::new(0).on(shell_program(), 3);
        let cwd = Path::new("/tmp");
        let status = runner.run_streaming("npm test", cwd).await.unwrap();
        assert_eq!(status.code(), Some(3));
    }

    #[tokio::test]
    async fn dispatching_runner_records_interactive_invocations() {
        let runner = DispatchingCommandRunner::new(0);
        let cwd = Path::new("/tmp");
        let _ = runner
            .run_interactive("vim", &["file.txt"], cwd)
            .await
            .unwrap();
        let invocations = runner.invocations();
        assert_eq!(invocations.len(), 1);
        assert!(invocations[0].is_interactive);
        assert_eq!(invocations[0].program, "vim");
    }

    #[tokio::test]
    async fn dispatching_runner_records_shell_interactive_invocations() {
        let runner = DispatchingCommandRunner::new(0);
        let cwd = Path::new("/tmp");
        let _ = runner.run_shell_interactive("echo hi", cwd).await.unwrap();
        let invocations = runner.invocations();
        assert_eq!(invocations.len(), 1);
        assert!(invocations[0].is_shell);
        assert!(invocations[0].is_interactive);
        assert!(!invocations[0].is_streaming);
        assert_eq!(invocations[0].program, shell_program());
        assert_eq!(
            invocations[0].args,
            vec![shell_flag().to_string(), "echo hi".to_string()]
        );
        assert_eq!(invocations[0].cwd.as_path(), cwd);
    }
}

#[cfg(test)]
mod recording_tests {
    use std::path::Path;

    use super::*;

    #[tokio::test]
    async fn recording_runner_returns_configured_output() {
        let runner = RecordingCommandRunner::new(3)
            .with_stdout(b"out\n".to_vec())
            .with_stderr(b"err\n".to_vec());
        let output = runner
            .run("git", &["status"], Path::new("/tmp"))
            .await
            .unwrap();
        assert_eq!(output.status.code(), Some(3));
        assert_eq!(output.stdout, b"out\n");
        assert_eq!(output.stderr, b"err\n");
    }

    #[tokio::test]
    async fn recording_runner_flags_each_invocation_kind() {
        let runner = RecordingCommandRunner::new(0);
        let cwd = Path::new("/tmp");
        let _ = runner.run("git", &["status"], cwd).await.unwrap();
        let _ = runner.run_interactive("vim", &["file.txt"], cwd).await.unwrap();
        let _ = runner.run_shell_interactive("echo hi", cwd).await.unwrap();
        let _ = runner.run_streaming("npm install", cwd).await.unwrap();
        let flags: Vec<(bool, bool, bool)> = runner
            .invocations()
            .iter()
            .map(|inv| (inv.is_shell, inv.is_interactive, inv.is_streaming))
            .collect();
        assert_eq!(
            flags,
            vec![
                (false, false, false),
                (false, true, false),
                (true, true, false),
                (true, false, true),
            ]
        );
    }
}