use std::sync::Mutex;
use crate::cmd::{Cmd, RunOutput};
use crate::cmd_display::CmdDisplay;
use crate::error::RunError;
use crate::runner_trait::Runner;
/// A scripted [`Runner`] for tests: every command is answered from
/// expectations registered up front via the `expect*` builder methods.
pub struct MockRunner {
    /// `true` (the value set by `MockRunner::new`) makes an unmatched command
    /// panic; `false` makes it return an error instead.
    panic_on_no_match: bool,
    /// Registered expectations, scanned in insertion order. Wrapped in a
    /// `Mutex` so `run(&self, ..)` can bump call counts and consume responders.
    expectations: Mutex<Vec<Expectation>>,
}
/// One registered expectation: what it matches, what it answers with, and
/// bookkeeping used by `verify` and the no-match diagnostic.
struct Expectation {
    /// Human-readable description of the matcher, shown in diagnostics.
    matcher_desc: String,
    /// Decides whether a command belongs to this expectation.
    matcher: Matcher,
    /// Produces the result(s) handed back for matching commands.
    responder: Responder,
    /// Number of commands this expectation has answered so far.
    calls: usize,
}
/// Decides whether a [`Cmd`] belongs to an expectation.
enum Matcher {
    /// Matches when the command's rendered display text equals this string exactly.
    Display(String),
    /// Matches when the closure returns `true` for the command.
    Predicate(Box<dyn Fn(&Cmd) -> bool + Send + Sync>),
}

impl Matcher {
    /// Returns `true` if `cmd` satisfies this matcher.
    ///
    /// The `Display` arm renders `cmd` to a fresh `String` on every call.
    fn matches(&self, cmd: &Cmd) -> bool {
        match self {
            Matcher::Predicate(pred) => pred(cmd),
            Matcher::Display(want) => cmd.display().to_string() == *want,
        }
    }
}
/// How an expectation produces results, and how many times it may do so.
enum Responder {
    /// Yields one pre-built result, after which it is exhausted.
    Once { result: Option<MockResult> },
    /// Calls `factory` once per hit until `remaining` reaches zero.
    Bounded {
        factory: Box<dyn FnMut() -> MockResult + Send>,
        remaining: usize,
    },
    /// Calls `factory` for every hit; never exhausted.
    Unlimited {
        factory: Box<dyn FnMut() -> MockResult + Send>,
    },
}

impl Responder {
    /// Produces the next result, or `None` once nothing is left.
    ///
    /// For `Bounded`, the counter is decremented before `factory` runs.
    fn take(&mut self) -> Option<MockResult> {
        match self {
            Responder::Unlimited { factory } => Some(factory()),
            Responder::Bounded { remaining: 0, .. } => None,
            Responder::Bounded { factory, remaining } => {
                *remaining -= 1;
                Some(factory())
            }
            Responder::Once { result } => result.take(),
        }
    }

    /// Returns `true` exactly when [`Responder::take`] would yield `None`.
    fn exhausted(&self) -> bool {
        match self {
            Responder::Unlimited { .. } => false,
            Responder::Bounded { remaining, .. } => *remaining == 0,
            Responder::Once { result } => result.is_none(),
        }
    }
}
impl MockRunner {
    /// Creates a runner with no expectations.
    ///
    /// By default an unmatched command panics; call
    /// [`MockRunner::error_on_no_match`] to get an error instead.
    #[must_use]
    pub fn new() -> Self {
        Self {
            expectations: Mutex::new(Vec::new()),
            panic_on_no_match: true,
        }
    }

    /// Makes unmatched commands return `Err(RunError::Spawn { .. })` instead
    /// of panicking.
    ///
    /// The error's source message names the command and lists every
    /// registered expectation with its call count.
    #[must_use]
    pub fn error_on_no_match(mut self) -> Self {
        self.panic_on_no_match = false;
        self
    }

    /// Expects one command whose display text equals `display`, answered
    /// with `result`.
    ///
    /// After it has been used once, later identical commands fall through to
    /// any expectation registered after this one.
    #[must_use]
    pub fn expect(self, display: impl Into<String>, result: MockResult) -> Self {
        let display = display.into();
        let desc = format!("display = {display:?}");
        self.push(Matcher::Display(display), Responder::Once { result: Some(result) }, desc)
    }

    /// Expects one command for which `matcher` returns `true`, answered with
    /// `result`.
    #[must_use]
    pub fn expect_when<F>(self, matcher: F, result: MockResult) -> Self
    where
        F: Fn(&Cmd) -> bool + Send + Sync + 'static,
    {
        self.push(
            Matcher::Predicate(Box::new(matcher)),
            Responder::Once { result: Some(result) },
            "<predicate>".to_string(),
        )
    }

    /// Expects up to `times` commands whose display text equals `display`.
    ///
    /// `factory` is called once per matching command.
    ///
    /// [`MockRunner::verify`] treats this expectation as met after its first
    /// hit, even if fewer than `times` calls happen. With `times == 0` it can
    /// never be hit, so `verify` always reports it as unmet.
    #[must_use]
    pub fn expect_repeated<F>(self, display: impl Into<String>, times: usize, factory: F) -> Self
    where
        F: FnMut() -> MockResult + Send + 'static,
    {
        let display = display.into();
        let desc = format!("display = {display:?} (×{times})");
        self.push(
            Matcher::Display(display),
            Responder::Bounded {
                factory: Box::new(factory),
                remaining: times,
            },
            desc,
        )
    }

    /// Predicate-based counterpart of [`MockRunner::expect_repeated`].
    #[must_use]
    pub fn expect_when_repeated<M, F>(self, matcher: M, times: usize, factory: F) -> Self
    where
        M: Fn(&Cmd) -> bool + Send + Sync + 'static,
        F: FnMut() -> MockResult + Send + 'static,
    {
        self.push(
            Matcher::Predicate(Box::new(matcher)),
            Responder::Bounded {
                factory: Box::new(factory),
                remaining: times,
            },
            format!("<predicate> (×{times})"),
        )
    }

    /// Answers every command whose display text equals `display`, calling
    /// `factory` each time.
    ///
    /// Expectations are tried in registration order and this one is never
    /// exhausted. Any expectation registered later for the same display text
    /// is therefore unreachable.
    #[must_use]
    pub fn expect_always<F>(self, display: impl Into<String>, factory: F) -> Self
    where
        F: FnMut() -> MockResult + Send + 'static,
    {
        let display = display.into();
        let desc = format!("display = {display:?} (∞)");
        self.push(
            Matcher::Display(display),
            Responder::Unlimited { factory: Box::new(factory) },
            desc,
        )
    }

    /// Predicate-based counterpart of [`MockRunner::expect_always`].
    #[must_use]
    pub fn expect_when_always<M, F>(self, matcher: M, factory: F) -> Self
    where
        M: Fn(&Cmd) -> bool + Send + Sync + 'static,
        F: FnMut() -> MockResult + Send + 'static,
    {
        self.push(
            Matcher::Predicate(Box::new(matcher)),
            Responder::Unlimited { factory: Box::new(factory) },
            "<predicate> (∞)".to_string(),
        )
    }

    /// Appends a fresh (never-called) expectation and returns the builder.
    fn push(mut self, matcher: Matcher, responder: Responder, matcher_desc: String) -> Self {
        // `self` is owned here, so nobody else can hold the lock: `get_mut`
        // reaches the data without locking (it still reports poisoning).
        self.expectations
            .get_mut()
            .expect("MockRunner mutex poisoned")
            .push(Expectation {
                matcher,
                responder,
                calls: 0,
                matcher_desc,
            });
        self
    }

    /// Checks that every registered expectation answered at least one command.
    ///
    /// # Errors
    ///
    /// Returns a message listing the description of each expectation whose
    /// call count is still zero.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned, for example because a
    /// matcher or factory closure panicked inside `run`.
    pub fn verify(&self) -> Result<(), String> {
        let exps = self
            .expectations
            .lock()
            .expect("MockRunner mutex poisoned");
        // Borrow the descriptions; they only need to live until the join.
        let unmet: Vec<&str> = exps
            .iter()
            .filter(|e| e.calls == 0)
            .map(|e| e.matcher_desc.as_str())
            .collect();
        if unmet.is_empty() {
            Ok(())
        } else {
            Err(format!(
                "{} unmet MockRunner expectation(s):\n - {}",
                unmet.len(),
                unmet.join("\n - "),
            ))
        }
    }
}
impl Default for MockRunner {
fn default() -> Self {
Self::new()
}
}
impl Runner for MockRunner {
    /// Answers `cmd` from the first live expectation that matches it.
    ///
    /// Expectations are scanned in registration order, and exhausted ones are
    /// skipped. The first hit has its call count bumped and its result
    /// resolved against the command's display.
    ///
    /// With no hit, the outcome depends on `panic_on_no_match`. It either
    /// panics, or returns `RunError::Spawn` whose source message names the
    /// command and lists every registered expectation.
    fn run(&self, cmd: Cmd) -> Result<RunOutput, RunError> {
        let display = cmd.display();
        let mut exps = self
            .expectations
            .lock()
            .expect("MockRunner mutex poisoned");

        let hit = exps.iter_mut().find_map(|exp| {
            if exp.responder.exhausted() || !exp.matcher.matches(&cmd) {
                return None;
            }
            let mock_result = exp.responder.take()?;
            exp.calls += 1;
            Some(mock_result)
        });

        if let Some(mock_result) = hit {
            drop(exps);
            return mock_result.resolve(&display);
        }

        let registered = exps
            .iter()
            .map(|e| {
                format!(
                    "{} (calls={}, exhausted={})",
                    e.matcher_desc,
                    e.calls,
                    e.responder.exhausted()
                )
            })
            .collect::<Vec<_>>()
            .join("\n - ");
        // Release the lock before a possible panic so the mutex is not
        // poisoned and `verify` keeps working afterwards.
        drop(exps);

        let no_match_msg = format!(
            "MockRunner: no matching expectation for command:\n {}\nregistered:\n - {}",
            display, registered,
        );
        if self.panic_on_no_match {
            panic!("{no_match_msg}");
        }
        Err(RunError::Spawn {
            command: display,
            source: std::io::Error::other(no_match_msg),
        })
    }
}
/// A canned outcome for a mocked command, converted into a real
/// `Result<RunOutput, RunError>` by [`MockResult::resolve`].
#[derive(Debug)]
pub enum MockResult {
    /// The command succeeded with this output.
    Ok {
        stdout: Vec<u8>,
        stderr: String,
    },
    /// The command ran but exited with `code`.
    NonZeroExit {
        code: i32,
        stdout: Vec<u8>,
        stderr: String,
    },
    /// The command could not be started.
    Spawn {
        source: std::io::Error,
    },
    /// The command was stopped after `elapsed`.
    Timeout {
        elapsed: std::time::Duration,
        stdout: Vec<u8>,
        stderr: String,
    },
}
impl MockResult {
    /// Converts this canned outcome into what a real run of `command` would
    /// have returned.
    ///
    /// `Ok` becomes `Ok(RunOutput)`. Every other variant becomes the
    /// same-named [`RunError`] variant, tagged with a clone of `command`.
    pub fn resolve(self, command: &CmdDisplay) -> Result<RunOutput, RunError> {
        let failure = match self {
            Self::Ok { stdout, stderr } => return Ok(RunOutput { stdout, stderr }),
            Self::Spawn { source } => RunError::Spawn {
                command: command.clone(),
                source,
            },
            Self::NonZeroExit {
                code,
                stdout,
                stderr,
            } => RunError::NonZeroExit {
                command: command.clone(),
                status: build_exit_status(code),
                stdout,
                stderr,
            },
            Self::Timeout {
                elapsed,
                stdout,
                stderr,
            } => RunError::Timeout {
                command: command.clone(),
                elapsed,
                stdout,
                stderr,
            },
        };
        Err(failure)
    }
}
/// Builds a successful [`MockResult`] with the given stdout bytes and an
/// empty stderr.
///
/// Accepts anything convertible into `Vec<u8>`, such as `&str`, `String`,
/// `&[u8]` or `Vec<u8>`.
pub fn ok(stdout: impl Into<Vec<u8>>) -> MockResult {
    let bytes: Vec<u8> = stdout.into();
    MockResult::Ok {
        stderr: String::new(),
        stdout: bytes,
    }
}
/// Builds a successful [`MockResult`] whose stdout is the UTF-8 bytes of
/// `stdout`, with an empty stderr.
pub fn ok_str(stdout: impl Into<String>) -> MockResult {
    let text: String = stdout.into();
    let bytes = text.into_bytes();
    MockResult::Ok {
        stderr: String::new(),
        stdout: bytes,
    }
}
/// Builds a [`MockResult`] for a command that exited with `code`, empty
/// stdout, and the given stderr text.
pub fn nonzero(code: i32, stderr: impl Into<String>) -> MockResult {
    let stderr: String = stderr.into();
    MockResult::NonZeroExit {
        code,
        stderr,
        stdout: Vec::new(),
    }
}
/// Builds a [`MockResult`] for a command that could not be started.
///
/// The source is an `std::io::Error` of kind `Other` carrying `message`.
pub fn spawn_error(message: impl Into<String>) -> MockResult {
    let text: String = message.into();
    let source = std::io::Error::other(text);
    MockResult::Spawn { source }
}
/// Builds a [`MockResult`] for a command that timed out after `elapsed`,
/// with empty stdout and the given stderr text.
pub fn timeout(elapsed: std::time::Duration, stderr: impl Into<String>) -> MockResult {
    let stderr: String = stderr.into();
    MockResult::Timeout {
        elapsed,
        stderr,
        stdout: Vec::new(),
    }
}
/// Fabricates an `ExitStatus` that reports a normal exit with `code`.
///
/// `from_raw` takes a raw `wait(2)` status word. In the conventional layout,
/// the exit code sits in bits 8..16 and zero low bits mean "exited normally".
/// Only the low 8 bits of `code` survive `ExitStatus::code()`, so values
/// outside `0..=255` are truncated. For example, 256 reads back as success.
#[cfg(unix)]
fn build_exit_status(code: i32) -> std::process::ExitStatus {
    use std::os::unix::process::ExitStatusExt as _;
    let wait_status = code << 8;
    std::process::ExitStatus::from_raw(wait_status)
}
/// Fabricates an `ExitStatus` carrying `code` as the process exit code.
///
/// Windows exit codes are 32-bit unsigned values. The `i32` is
/// reinterpreted bit-for-bit, so negative codes map to their two's-complement
/// `u32`, for example -1 becomes `0xFFFF_FFFF`.
#[cfg(windows)]
fn build_exit_status(code: i32) -> std::process::ExitStatus {
    use std::os::windows::process::ExitStatusExt as _;
    let raw = u32::from_ne_bytes(code.to_ne_bytes());
    std::process::ExitStatus::from_raw(raw)
}