use std::ffi::{OsStr, OsString};
use std::path::Path;
use std::time::Duration;
use crate::command::Command;
use crate::error::Result;
use crate::result::ProcessResult;
use crate::runner::{JobRunner, ProcessRunner, ProcessRunnerExt};
/// A reusable handle for invoking one external program through a [`ProcessRunner`].
///
/// The client stores per-client defaults and applies them to every [`Command`] built with
/// [`CliClient::command`] or [`CliClient::command_in`]:
/// - a timeout,
/// - environment overrides,
/// - with the `cancellation` feature, a cancellation token.
///
/// The runner type defaults to [`JobRunner`]; supply another one with [`CliClient::with_runner`].
pub struct CliClient<R: ProcessRunner = JobRunner> {
    // Program passed to `Command::new` for every command this client builds.
    program: OsString,
    // Executes the commands handed to `run`, `output`, `probe`, etc.
    runner: R,
    // Default timeout; when `Some`, it is set on every built command.
    timeout: Option<Duration>,
    // Environment overrides, replayed onto each built command in insertion order:
    // `Some(value)` sets the variable, `None` removes it.
    envs: Vec<(OsString, Option<OsString>)>,
    // Default cancellation token; a clone is attached to every built command.
    #[cfg(feature = "cancellation")]
    cancel: Option<tokio_util::sync::CancellationToken>,
}
// Prints the client's configuration. The runner itself is not printed, so this impl needs no
// `Debug` bound on `R`; `finish_non_exhaustive` emits `..` to mark the omitted field.
impl<R: ProcessRunner> std::fmt::Debug for CliClient<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CliClient");
        out.field("program", &self.program);
        out.field("timeout", &self.timeout);
        out.field("envs", &self.envs);
        // Only the presence of a default token is reported, not the token itself.
        #[cfg(feature = "cancellation")]
        out.field("has_default_cancel", &self.cancel.is_some());
        out.finish_non_exhaustive()
    }
}
impl CliClient<JobRunner> {
    /// Creates a client for `program` that executes its commands through [`JobRunner`].
    ///
    /// No defaults are configured; chain the `default_*` builder methods to add them.
    pub fn new(program: impl AsRef<OsStr>) -> Self {
        Self::with_runner(program, JobRunner)
    }
}
impl<R: ProcessRunner> CliClient<R> {
    /// Creates a client for `program` that executes its commands through `runner`.
    ///
    /// No defaults are configured; chain the `default_*` builder methods to add them.
    #[must_use]
    pub fn with_runner(program: impl AsRef<OsStr>, runner: R) -> Self {
        Self {
            program: program.as_ref().to_os_string(),
            runner,
            timeout: None,
            envs: Vec::new(),
            #[cfg(feature = "cancellation")]
            cancel: None,
        }
    }

    /// Sets the timeout applied to every command this client builds.
    ///
    /// This replaces any earlier default timeout.
    #[must_use]
    pub fn default_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Adds an environment variable that is set on every command this client builds.
    ///
    /// Overrides are replayed onto each command in the order they were added. This includes
    /// removals from [`default_env_remove`](Self::default_env_remove). For a repeated `key`,
    /// the most recent call is therefore applied last.
    #[must_use]
    pub fn default_env(mut self, key: impl AsRef<OsStr>, value: impl AsRef<OsStr>) -> Self {
        self.envs.push((
            key.as_ref().to_os_string(),
            Some(value.as_ref().to_os_string()),
        ));
        self
    }

    /// Marks an environment variable for removal on every command this client builds.
    ///
    /// Its ordering relative to [`default_env`](Self::default_env) calls follows insertion
    /// order.
    #[must_use]
    pub fn default_env_remove(mut self, key: impl AsRef<OsStr>) -> Self {
        self.envs.push((key.as_ref().to_os_string(), None));
        self
    }

    /// Sets a cancellation token that is attached to every command this client builds.
    ///
    /// This replaces any earlier default token. A built command can still swap in its own
    /// token with `Command::cancel_on`.
    #[cfg(feature = "cancellation")]
    #[must_use]
    pub fn default_cancel_on(mut self, token: tokio_util::sync::CancellationToken) -> Self {
        self.cancel = Some(token);
        self
    }

    /// Returns the runner that executes this client's commands.
    #[must_use]
    pub fn runner(&self) -> &R {
        &self.runner
    }

    /// Returns the default timeout.
    ///
    /// This is `None` if [`default_timeout`](Self::default_timeout) was never called.
    #[must_use]
    pub fn timeout(&self) -> Option<Duration> {
        self.timeout
    }

    /// Builds a [`Command`] that invokes this client's program with `args`.
    ///
    /// Every client default is applied. Nothing is executed here; pass the command to
    /// [`run`](Self::run) or one of the other execution methods.
    #[must_use = "a Command does nothing until it is executed"]
    pub fn command<I, S>(&self, args: I) -> Command
    where
        I: IntoIterator<Item = S>,
        S: AsRef<OsStr>,
    {
        self.apply_defaults(Command::new(&self.program).args(args))
    }

    /// Like [`command`](Self::command), but also sets `dir` as the command's working
    /// directory.
    #[must_use = "a Command does nothing until it is executed"]
    pub fn command_in<I, S>(&self, dir: &Path, args: I) -> Command
    where
        I: IntoIterator<Item = S>,
        S: AsRef<OsStr>,
    {
        self.apply_defaults(Command::new(&self.program).current_dir(dir).args(args))
    }

    /// Applies this client's defaults to a freshly built `command`.
    ///
    /// The defaults are applied in this order:
    /// 1. the timeout, if any;
    /// 2. the environment overrides, in insertion order;
    /// 3. with the `cancellation` feature, a clone of the default token.
    fn apply_defaults(&self, command: Command) -> Command {
        let mut command = match self.timeout {
            Some(timeout) => command.timeout(timeout),
            None => command,
        };
        for (key, value) in &self.envs {
            command = match value {
                Some(value) => command.env(key, value),
                None => command.env_remove(key),
            };
        }
        #[cfg(feature = "cancellation")]
        if let Some(token) = &self.cancel {
            command = command.cancel_on(token.clone());
        }
        command
    }

    /// Runs `command` through the runner's `checked` execution and returns its stdout.
    ///
    /// Trailing whitespace, such as the final newline, is removed. Leading whitespace is
    /// kept.
    ///
    /// # Errors
    ///
    /// Propagates any error produced by `checked`. This presumably includes an unsuccessful
    /// exit; confirm against `ProcessRunnerExt::checked`.
    pub async fn run(&self, command: Command) -> Result<String> {
        let mut stdout = self.runner.checked(&command).await?.into_stdout();
        // Trim in place: truncating the owned buffer avoids copying the trimmed slice into a
        // second allocation. `trim_end` returns a prefix, so its length is a char boundary.
        let trimmed_len = stdout.trim_end().len();
        stdout.truncate(trimmed_len);
        Ok(stdout)
    }

    /// Runs `command` and returns the runner's complete [`ProcessResult`], with stdout as a
    /// `String`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the runner's `output`. For example, with the `cancellation`
    /// feature a cancelled command yields `Error::Cancelled`.
    pub async fn output(&self, command: Command) -> Result<ProcessResult<String>> {
        self.runner.output(&command).await
    }

    /// Runs `command` through the runner's `run_unit` and discards its output.
    ///
    /// # Errors
    ///
    /// Propagates any error from the runner's `run_unit`.
    pub async fn run_unit(&self, command: Command) -> Result<()> {
        self.runner.run_unit(&command).await
    }

    /// Runs `command` and returns its exit code.
    ///
    /// A non-zero exit is returned as `Ok(code)`, not as an error.
    ///
    /// # Errors
    ///
    /// Failures to obtain an exit code, such as a timeout (`Error::Timeout`), are returned as
    /// errors.
    pub async fn exit_code(&self, command: Command) -> Result<i32> {
        self.runner.exit_code(&command).await
    }

    /// Runs `command` and reports whether it succeeded.
    ///
    /// Returns `Ok(true)` for a successful exit and `Ok(false)` for a failing one.
    ///
    /// # Errors
    ///
    /// Propagates any error from the runner's `probe`.
    pub async fn probe(&self, command: Command) -> Result<bool> {
        self.runner.probe(&command).await
    }

    /// Runs `command` through `checked` and passes its stdout to `parse`.
    ///
    /// Unlike [`run`](Self::run), the stdout is not trimmed.
    ///
    /// # Errors
    ///
    /// Propagates any error from `checked`; `parse` itself cannot fail.
    pub async fn parse<T>(&self, command: Command, parse: impl FnOnce(&str) -> T) -> Result<T> {
        let out = self.runner.checked(&command).await?;
        Ok(parse(out.stdout()))
    }

    /// Like [`parse`](Self::parse), but the parser may fail.
    ///
    /// # Errors
    ///
    /// Propagates any error from `checked`. Any error returned by `parse` is passed through
    /// unchanged.
    pub async fn try_parse<T>(
        &self,
        command: Command,
        parse: impl FnOnce(&str) -> Result<T>,
    ) -> Result<T> {
        let out = self.runner.checked(&command).await?;
        parse(out.stdout())
    }
}
/// Declares a named wrapper around [`CliClient`] that is bound to one fixed program.
///
/// `cli_client!(pub struct Git => "git");` generates a struct `Git<R = JobRunner>` with a
/// private `core: CliClient<R>` field. It also generates:
///
/// - `new()` and a [`Default`] impl, both using [`JobRunner`];
/// - `with_runner(runner)` for any [`ProcessRunner`], for example a test double;
/// - the forwarding builder methods `default_timeout`, `default_env` and
///   `default_env_remove`;
/// - `default_cancel_on`, when this crate's `cancellation` feature is enabled.
///
/// Attributes written before `struct`, including doc comments, are applied to the generated
/// struct.
///
/// Because `core` is private, put program-specific methods in an
/// `impl<R: ProcessRunner> Git<R>` block in the module that invoked the macro. Those methods
/// can build and run commands through `self.core`.
#[macro_export]
macro_rules! cli_client {
    ($(#[$meta:meta])* $vis:vis struct $name:ident => $binary:expr) => {
        $(#[$meta])*
        $vis struct $name<R: $crate::ProcessRunner = $crate::JobRunner> {
            core: $crate::CliClient<R>,
        }

        impl $name<$crate::JobRunner> {
            /// Creates a client that executes commands through the default runner.
            pub fn new() -> Self {
                Self { core: $crate::CliClient::new($binary) }
            }
        }

        impl ::core::default::Default for $name<$crate::JobRunner> {
            fn default() -> Self {
                Self::new()
            }
        }

        impl<R: $crate::ProcessRunner> $name<R> {
            /// Creates a client that executes commands through `runner`.
            pub fn with_runner(runner: R) -> Self {
                Self { core: $crate::CliClient::with_runner($binary, runner) }
            }

            /// Sets the timeout applied to every command the wrapped client builds.
            #[must_use]
            pub fn default_timeout(mut self, timeout: ::core::time::Duration) -> Self {
                self.core = self.core.default_timeout(timeout);
                self
            }

            /// Adds an environment variable that is set on every command the wrapped client
            /// builds.
            #[must_use]
            pub fn default_env(
                mut self,
                key: impl ::core::convert::AsRef<::std::ffi::OsStr>,
                value: impl ::core::convert::AsRef<::std::ffi::OsStr>,
            ) -> Self {
                self.core = self.core.default_env(key, value);
                self
            }

            /// Removes an environment variable from every command the wrapped client builds.
            #[must_use]
            pub fn default_env_remove(
                mut self,
                key: impl ::core::convert::AsRef<::std::ffi::OsStr>,
            ) -> Self {
                self.core = self.core.default_env_remove(key);
                self
            }
        }

        // Emits `default_cancel_on` only when this crate was built with `cancellation`.
        $crate::__cli_client_cancellation!($name);
    };
}
// Helper invoked by `cli_client!` to forward `default_cancel_on`.
// `#[cfg]` is evaluated when *this* crate is compiled. The generated wrapper therefore gets the
// method exactly when `CliClient::default_cancel_on` exists, whichever crate invokes
// `cli_client!`.
#[cfg(feature = "cancellation")]
#[doc(hidden)]
#[macro_export]
macro_rules! __cli_client_cancellation {
    ($name:ident) => {
        impl<R: $crate::ProcessRunner> $name<R> {
            /// Sets a cancellation token that is attached to every command the wrapped client
            /// builds.
            #[must_use]
            pub fn default_cancel_on(mut self, token: $crate::CancellationToken) -> Self {
                self.core = self.core.default_cancel_on(token);
                self
            }
        }
    };
}
// Feature-off stand-in. Without `cancellation`, `CliClient::default_cancel_on` is not compiled,
// so `cli_client!` wrappers must not get a forwarder and this expands to nothing.
#[cfg(not(feature = "cancellation"))]
#[doc(hidden)]
#[macro_export]
macro_rules! __cli_client_cancellation {
    ($_wrapper:ident) => ();
}
// Unit tests for `CliClient` and `cli_client!`.
// Commands are executed only through the scripted/recording test runners (`ScriptedRunner`,
// `RecordingRunner`). Clients created with `CliClient::new`/`Demo::new` are used only to build
// or inspect commands; none of them executes a command through `JobRunner`.
#[cfg(test)]
mod tests {
    use std::path::Path;
    use std::time::Duration;
    use super::*;
    use crate::{Error, RecordingRunner, Reply, ScriptedRunner};

    // A `cli_client!` wrapper for `git`, used to exercise the macro-generated API.
    crate::cli_client!(struct Demo => "git");

    // Program-specific helpers written the way downstream wrappers use the macro: build a
    // command with `self.core` and hand it to one of the `CliClient` run methods.
    impl<R: ProcessRunner> Demo<R> {
        // Trailing whitespace is trimmed by `CliClient::run`.
        async fn head(&self, dir: &Path) -> Result<String> {
            self.core
                .run(self.core.command_in(dir, ["rev-parse", "HEAD"]))
                .await
        }

        // "Clean" means `diff --quiet` exited with code 0. A non-zero code comes back as a
        // value from `exit_code`, not as an error.
        async fn is_clean(&self, dir: &Path) -> Result<bool> {
            Ok(self
                .core
                .exit_code(self.core.command_in(dir, ["diff", "--quiet"]))
                .await?
                == 0)
        }

        // `parse` receives untrimmed stdout, so each line is trimmed here.
        async fn branches(&self, dir: &Path) -> Result<Vec<String>> {
            self.core
                .parse(self.core.command_in(dir, ["branch"]), |s| {
                    s.lines().map(|l| l.trim().to_owned()).collect()
                })
                .await
        }
    }

    // `run` strips the trailing " \n" but keeps the leading space.
    #[tokio::test]
    async fn run_trims_trailing_whitespace_only() {
        let demo =
            Demo::with_runner(ScriptedRunner::new().on(["rev-parse"], Reply::ok(" abc123 \n")));
        assert_eq!(demo.head(Path::new(".")).await.unwrap(), " abc123");
    }

    // A failing exit (code 1) is surfaced as `Ok(1)` by `exit_code`, so `is_clean` is false.
    #[tokio::test]
    async fn exit_code_maps_exit_status() {
        let demo = Demo::with_runner(ScriptedRunner::new().on(["diff"], Reply::fail(1, "")));
        assert!(!demo.is_clean(Path::new(".")).await.unwrap());
    }

    #[tokio::test]
    async fn parse_builds_a_typed_value() {
        let demo =
            Demo::with_runner(ScriptedRunner::new().on(["branch"], Reply::ok("main\nfeature\n")));
        assert_eq!(
            demo.branches(Path::new(".")).await.unwrap(),
            vec!["main", "feature"]
        );
    }

    // The parser's own error is returned unchanged by `try_parse`.
    #[tokio::test]
    async fn try_parse_maps_failure_to_parse_error() {
        let client = CliClient::with_runner(
            "gh",
            ScriptedRunner::new().fallback(Reply::ok("not a number")),
        );
        let err = client
            .try_parse::<u32>(client.command(["x"]), |s| {
                s.trim().parse::<u32>().map_err(|e| Error::Parse {
                    program: "gh".into(),
                    message: e.to_string(),
                })
            })
            .await
            .unwrap_err();
        assert!(matches!(err, Error::Parse { .. }), "got {err:?}");
    }

    // The runner's predicate sees the working directory set by `command_in`; commands built
    // with plain `command` fall through to the fallback reply.
    #[tokio::test]
    async fn when_predicate_reads_public_command_accessors() {
        let runner = ScriptedRunner::new()
            .when(
                |c| c.working_dir() == Some(Path::new("/repo")),
                Reply::ok("in-repo"),
            )
            .fallback(Reply::ok("elsewhere"));
        let client = CliClient::with_runner("git", runner);
        assert_eq!(
            client
                .run(client.command_in(Path::new("/repo"), ["status"]))
                .await
                .unwrap(),
            "in-repo"
        );
        assert_eq!(
            client.run(client.command(["status"])).await.unwrap(),
            "elsewhere"
        );
    }

    // Passing `&rec` as the runner lets the test inspect the recorded call afterwards.
    #[tokio::test]
    async fn recording_runner_captures_args_cwd_and_absence() {
        let rec = RecordingRunner::replying(Reply::ok("https://gh/pr/2\n"));
        let client = CliClient::with_runner("gh", &rec);
        let _ = client
            .run(client.command_in(Path::new("/repo"), ["pr", "create", "--title", "T"]))
            .await
            .unwrap();
        let call = rec.only_call();
        assert_eq!(call.cwd.as_deref(), Some(std::ffi::OsStr::new("/repo")));
        assert_eq!(call.args_str(), ["pr", "create", "--title", "T"]);
        assert!(!call.has_flag("--base"), "no --base flag was passed");
    }

    // Unlike a non-zero exit, a timeout is an error for `exit_code`.
    #[tokio::test]
    async fn exit_code_errors_on_timeout() {
        let client = CliClient::with_runner("gh", ScriptedRunner::new().fallback(Reply::timeout()));
        assert!(matches!(
            client
                .exit_code(client.command(["auth", "status"]))
                .await
                .unwrap_err(),
            Error::Timeout { .. }
        ));
    }

    // Only inspects the built command; nothing is executed.
    #[tokio::test]
    async fn default_timeout_is_applied() {
        let client = CliClient::new("git").default_timeout(Duration::from_secs(7));
        assert_eq!(
            client.command(["status"]).configured_timeout(),
            Some(Duration::from_secs(7))
        );
    }

    #[tokio::test]
    async fn probe_maps_exit_code_to_bool() {
        let client = CliClient::with_runner(
            "git",
            ScriptedRunner::new()
                .on(["diff"], Reply::fail(1, ""))
                .fallback(Reply::ok("")),
        );
        assert!(
            !client
                .probe(client.command(["diff", "--quiet"]))
                .await
                .unwrap()
        );
        assert!(client.probe(client.command(["status"])).await.unwrap());
    }

    // Checks both builders (`command` and `command_in`) carry the default env override.
    #[tokio::test]
    async fn default_env_is_applied_to_every_command() {
        use std::ffi::OsString;
        let client = CliClient::new("git").default_env("GIT_TERMINAL_PROMPT", "0");
        for cmd in [
            client.command(["status"]),
            client.command_in(Path::new("."), ["fetch"]),
        ] {
            assert!(
                cmd.env_overrides()
                    .iter()
                    .any(|(k, v)| k == "GIT_TERMINAL_PROMPT"
                        && v.as_deref() == Some(OsString::from("0").as_os_str())),
                "default env missing on built command",
            );
        }
    }

    // End-to-end counterpart of the test above: the override is visible to the runner.
    #[tokio::test]
    async fn default_env_reaches_the_invocation() {
        let rec = RecordingRunner::replying(Reply::ok("ok\n"));
        let client = CliClient::with_runner("git", &rec).default_env("GIT_TERMINAL_PROMPT", "0");
        let _ = client.run(client.command(["status"])).await.unwrap();
        let call = rec.only_call();
        assert!(
            call.envs
                .iter()
                .any(|(k, v)| k == "GIT_TERMINAL_PROMPT" && v.is_some()),
            "env override did not reach the runner: {:?}",
            call.envs
        );
    }

    // Also pins the `has_default_cancel` field name in the `Debug` output.
    #[cfg(feature = "cancellation")]
    #[tokio::test]
    async fn default_cancel_on_is_applied_to_every_command() {
        let token = crate::CancellationToken::new();
        let client = CliClient::new("git").default_cancel_on(token);
        for cmd in [
            client.command(["status"]),
            client.command_in(Path::new("."), ["fetch"]),
        ] {
            assert!(
                cmd.cancel_token().is_some(),
                "default token missing on built command"
            );
        }
        assert!(format!("{client:?}").contains("has_default_cancel: true"));
    }

    // A per-command `cancel_on` replaces the client default: cancelling the default token
    // leaves the call pending, and only the explicit token resolves it.
    // NOTE(review): presumably relies on tokio's paused clock (`start_paused`) auto-advancing,
    // so the 3600 s timeouts complete without real waiting; confirm.
    #[cfg(feature = "cancellation")]
    #[tokio::test(start_paused = true)]
    async fn per_command_cancel_on_overrides_the_default() {
        use crate::CancellationToken;
        let default_token = CancellationToken::new();
        let explicit = CancellationToken::new();
        let client = CliClient::with_runner("gh", ScriptedRunner::new().fallback(Reply::pending()))
            .default_cancel_on(default_token.clone());
        let cmd = client.command(["run", "watch"]).cancel_on(explicit.clone());
        let call = client.output(cmd);
        tokio::pin!(call);
        default_token.cancel();
        assert!(
            tokio::time::timeout(Duration::from_secs(3600), &mut call)
                .await
                .is_err(),
            "the replaced default token must not cancel the call"
        );
        explicit.cancel();
        let err = tokio::time::timeout(Duration::from_secs(3600), call)
            .await
            .expect("the explicit token must resolve the call")
            .expect_err("explicit token cancels");
        assert!(matches!(err, Error::Cancelled { .. }), "got {err:?}");
    }

    // A pending reply stays pending until the client-level default token fires. After that,
    // the call resolves to `Error::Cancelled` naming the program, and the recorded arguments
    // are intact.
    #[cfg(feature = "cancellation")]
    #[tokio::test(start_paused = true)]
    async fn acceptance_pending_reply_with_client_default_cancel() {
        use crate::CancellationToken;
        let token = CancellationToken::new();
        let rec =
            RecordingRunner::new(ScriptedRunner::new().on(["run", "watch"], Reply::pending()));
        let client = CliClient::with_runner("gh", &rec).default_cancel_on(token.clone());
        let call = client.output(client.command(["run", "watch", "123"]));
        tokio::pin!(call);
        assert!(
            tokio::time::timeout(Duration::from_secs(3600), &mut call)
                .await
                .is_err(),
            "must not resolve before the token fires"
        );
        token.cancel();
        match tokio::time::timeout(Duration::from_secs(3600), call)
            .await
            .expect("the cancelled token must resolve the call")
        {
            Err(Error::Cancelled { program }) => assert_eq!(program, "gh"),
            other => panic!("expected Error::Cancelled, got {other:?}"),
        }
        assert_eq!(rec.only_call().args_str(), ["run", "watch", "123"]);
    }

    // Compile-time check that the feature-gated forwarder is generated for `cli_client!` types.
    #[cfg(feature = "cancellation")]
    #[test]
    fn macro_emits_default_cancel_on() {
        let _client = Demo::with_runner(ScriptedRunner::new())
            .default_cancel_on(crate::CancellationToken::new());
    }

    // Compile-time check that every constructor and forwarding builder is generated.
    #[test]
    fn macro_generates_all_constructors() {
        let _real = Demo::new();
        let _default = Demo::default();
        let _fake = Demo::with_runner(ScriptedRunner::new())
            .default_timeout(Duration::from_secs(1))
            .default_env("GIT_TERMINAL_PROMPT", "0")
            .default_env_remove("GIT_PAGER");
    }
}