commandeer_test/
lib.rs

1use anyhow::{Result, anyhow};
2use escargot::CargoBuild;
3use serde::{Deserialize, Serialize};
4use std::{
5    collections::HashMap,
6    env, fmt, fs,
7    path::{Path, PathBuf},
8};
9use tempfile::TempDir;
10use tokio::{
11    fs::{DirBuilder, try_exists},
12    io::AsyncReadExt as _,
13    process::Command,
14};
15
16pub use commandeer_macros::commandeer;
17
/// One captured execution of an external command.
///
/// Values of this type are what gets serialized into (and read back from) the
/// JSON fixture file.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CommandInvocation {
    /// Command name as passed to `Command::new` when the invocation was recorded.
    pub binary_name: String,
    /// Arguments passed to the command, in order.
    pub args: Vec<String>,
    /// Captured standard output, decoded lossily as UTF-8.
    pub stdout: String,
    /// Captured standard error, decoded lossily as UTF-8.
    pub stderr: String,
    /// Process exit code, or `-1` when the process reported none
    /// (e.g. it was terminated by a signal).
    pub exit_code: i32,
}
26
/// Every invocation stored in one fixture file, grouped by lookup key.
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct RecordedCommands {
    // Key format: "<binary>:<args joined with single spaces>". Each key maps to
    // all invocations recorded under it, in the order they were added. The map
    // is serialized as-is, so the key format is part of the fixture format.
    commands: HashMap<String, Vec<CommandInvocation>>,
}
31
32impl RecordedCommands {
33    fn generate_key(binary_name: &str, args: &[String]) -> String {
34        format!("{binary_name}:{}", args.join(" "))
35    }
36
37    pub fn add_invocation(&mut self, invocation: CommandInvocation) {
38        let key = Self::generate_key(&invocation.binary_name, &invocation.args);
39
40        self.commands.entry(key).or_default().push(invocation);
41    }
42
43    pub fn find_invocation(
44        &self,
45        binary_name: &str,
46        args: &[String],
47    ) -> Option<&CommandInvocation> {
48        let key = Self::generate_key(binary_name, args);
49
50        self.commands.get(&key)?.first()
51    }
52}
53
54pub async fn load_recordings(file_path: &PathBuf) -> Result<RecordedCommands> {
55    let mut f = tokio::fs::File::options();
56
57    let mut contents = String::new();
58    f.create(true)
59        .write(true)
60        .read(true)
61        .open(file_path)
62        .await?
63        .read_to_string(&mut contents)
64        .await?;
65
66    if contents.trim().is_empty() {
67        return Ok(RecordedCommands::default());
68    }
69
70    let recordings: RecordedCommands = serde_json::from_str(&contents)?;
71
72    Ok(recordings)
73}
74
75pub async fn save_recordings(file_path: &PathBuf, recordings: &RecordedCommands) -> Result<()> {
76    let json = serde_json::to_string_pretty(recordings)?;
77
78    tokio::fs::write(file_path, json.as_bytes()).await?;
79
80    Ok(())
81}
82
83pub async fn record_command(
84    truncate: bool,
85    file_path: PathBuf,
86    command: String,
87    args: Vec<String>,
88) -> Result<CommandInvocation> {
89    let recording_dir = file_path
90        .parent()
91        .ok_or_else(|| anyhow!("Couldn't get parent of recording {}", file_path.display()))?;
92
93    DirBuilder::new()
94        .recursive(true)
95        .create(recording_dir)
96        .await?;
97
98    let mut recordings = if truncate {
99        if try_exists(&file_path).await? {
100            tokio::fs::remove_file(&file_path).await?;
101        }
102
103        RecordedCommands::default()
104    } else {
105        load_recordings(&file_path).await?
106    };
107
108    let output = Command::new(&command).args(&args).output().await?;
109
110    let invocation = CommandInvocation {
111        binary_name: command,
112        args,
113        stdout: String::from_utf8_lossy(&output.stdout).to_string(),
114        stderr: String::from_utf8_lossy(&output.stderr).to_string(),
115        exit_code: output.status.code().unwrap_or(-1),
116    };
117
118    recordings.add_invocation(invocation.clone());
119    save_recordings(&file_path, &recordings).await?;
120
121    Ok(invocation)
122}
123
124pub async fn replay_command(
125    file_path: PathBuf,
126    command: String,
127    args: Vec<String>,
128) -> Result<Option<CommandInvocation>> {
129    let recordings = load_recordings(&file_path).await?;
130
131    Ok(recordings.find_invocation(&command, &args).cloned())
132}
133
134pub fn output_invocation(invocation: &CommandInvocation) {
135    print!("{}", invocation.stdout);
136    eprint!("{}", invocation.stderr);
137}
138
/// Terminates the current process immediately with exit status `code`.
///
/// Destructors of live values are not run; see [`std::process::exit`].
pub fn exit_with_code(code: i32) -> ! {
    use std::process;

    process::exit(code)
}
142
/// Selects whether mocked commands are recorded or replayed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Mode {
    /// Forwarded to the mock runner as `record`; `Commandeer::new` deletes any
    /// existing fixture first.
    Record,
    /// Forwarded to the mock runner as `replay`.
    Replay,
}
148
/// Test harness that shadows external commands with wrapper scripts which
/// forward to the `commandeer` mock runner in record or replay mode.
///
/// Field declaration order is drop order: the `Drop` impl restores `PATH`
/// first, and only afterwards is `temp_dir` (holding the wrappers) dropped.
pub struct Commandeer {
    // Handle to the built `commandeer` mock binary.
    mock_runner: escargot::CargoRun,
    // Temporary directory prepended to `PATH`; wrapper scripts are written here.
    temp_dir: TempDir,
    // Path of this test's JSON fixture file.
    fixture: PathBuf,
    // Record or replay; passed to the mock runner as its first argument.
    mode: Mode,
    // `PATH` as it was before `Commandeer::new` changed it (empty if it was
    // unset or not valid UTF-8); restored on drop and used inside wrappers.
    original_path: String,
}
156
157impl fmt::Display for Mode {
158    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159        match self {
160            Mode::Record => write!(f, "record"),
161            Mode::Replay => write!(f, "replay"),
162        }
163    }
164}
165
166impl Commandeer {
167    pub fn new(test_name: impl AsRef<Path>, mode: Mode) -> Self {
168        let dir = PathBuf::from(
169            std::env::var("CARGO_MANIFEST_DIR").expect("Failed to get crate directory."),
170        );
171
172        std::fs::DirBuilder::new()
173            .recursive(true)
174            .create(&dir)
175            .expect("Failed to create testcmds dir");
176
177        let fixture = dir.join("testcmds").join(test_name);
178
179        if fixture.exists() && mode == Mode::Record {
180            std::fs::remove_file(&fixture).expect("Failed to remove existing fixture file");
181        }
182
183        let mock_runner = CargoBuild::new()
184            .manifest_path(PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("Cargo.toml"))
185            .package("commandeer-test")
186            .bin("commandeer")
187            .run()
188            .expect("Failed to build mock binary");
189
190        let temp_dir = TempDir::new().expect("Failed to create temp dir");
191
192        let original_path = std::env::var("PATH").unwrap_or_default();
193        let new_path = format!("{}:{original_path}", temp_dir.path().display());
194
195        unsafe {
196            std::env::set_var("PATH", new_path);
197        }
198
199        Self {
200            mock_runner,
201            temp_dir,
202            fixture,
203            mode,
204            original_path,
205        }
206    }
207    pub fn mock_command(&self, command_name: &str) -> PathBuf {
208        let mock_path = self.temp_dir.path().join(command_name);
209
210        let wrapper = format!(
211            r#"#!/usr/bin/env bash
212exec env PATH="{}" {} {} --file {} --command {command_name} "$@"
213"#,
214            self.original_path,
215            self.mock_runner.path().display(),
216            self.mode,
217            self.fixture.display(),
218        );
219
220        fs::write(&mock_path, wrapper).expect("Failed to write mock wrapper script");
221
222        #[cfg(unix)]
223        {
224            use std::os::unix::fs::PermissionsExt as _;
225
226            let mut perms = fs::metadata(&mock_path)
227                .expect("Could not get permissions")
228                .permissions();
229
230            perms.set_mode(0o755);
231
232            fs::set_permissions(&mock_path, perms).expect("Could not set permissions");
233        }
234
235        mock_path
236    }
237}
238
239impl Drop for Commandeer {
240    fn drop(&mut self) {
241        unsafe {
242            std::env::set_var("PATH", &self.original_path);
243        }
244    }
245}
246
// Tests for the harness. `Commandeer::new` mutates the process-wide `PATH`,
// so every test here is marked `#[serial_test::serial]` to avoid overlap.
#[cfg(test)]
mod tests {
    // NOTE(review): presumably required because the `#[commandeer]` macro
    // expands to `commandeer_test::...` paths — confirm against the macro crate.
    use crate as commandeer_test;
    use crate::{Commandeer, Mode, commandeer};

    // Uses `Commandeer` directly. It shadows `echo` with a wrapper in replay
    // mode against `testcmds/test_recordings.json`, runs `echo foo` via
    // `PATH`, and checks that the wrapper file exists.
    // NOTE(review): there is no `#[test]` attribute here (`serial` is not
    // expected to add one — TODO confirm), so the harness never runs this
    // function. Adding `#[test]` needs a recorded `echo foo` entry in the fixture.
    #[serial_test::serial]
    fn test_mock_cmd() {
        let commandeer = Commandeer::new("test_recordings.json", Mode::Replay);
        let mock_path = commandeer.mock_command("echo");

        let status = std::process::Command::new("echo")
            .arg("foo")
            .status()
            .unwrap();

        assert!(status.success());

        assert!(mock_path.exists());
    }

    // Replays a recorded `echo hello` through the `#[commandeer]` macro.
    // NOTE(review): unlike `test_flag_args`, this has no explicit `#[test]`;
    // unless `#[commandeer]` adds one, this function never runs — confirm.
    #[commandeer(Replay, "echo")]
    #[serial_test::serial]
    fn my_test() {
        let output = std::process::Command::new("echo")
            .arg("hello")
            .output()
            .unwrap();

        assert!(output.status.success());
    }

    // Replays a recorded `date` invocation inside a tokio test. The inline
    // snapshot pins the recorded output; its `unix_wait_status` Debug form is
    // Unix-specific. The blocking `std::process::Command` is used even though
    // the test is async.
    #[commandeer(Replay, "date")]
    #[tokio::test]
    #[serial_test::serial]
    async fn async_replay() {
        let output = std::process::Command::new("date").output().unwrap();

        insta::assert_debug_snapshot!(output, @r#"
        Output {
            status: ExitStatus(
                unix_wait_status(
                    0,
                ),
            ),
            stdout: "Wed Aug 20 12:46:19 EDT 2025\n",
            stderr: "",
        }
        "#);
    }

    // Mocks both `git` and `date` in one test. It checks that a flag-style
    // argument (`--version`) is forwarded and replayed, and that a second
    // mocked command in the same test replays its own recording.
    #[commandeer(Replay, "git", "date")]
    #[test]
    #[serial_test::serial]
    fn test_flag_args() {
        let output = std::process::Command::new("git")
            .arg("--version")
            .output()
            .unwrap();

        insta::assert_debug_snapshot!(output, @r#"
        Output {
            status: ExitStatus(
                unix_wait_status(
                    0,
                ),
            ),
            stdout: "git version 2.51.0\n",
            stderr: "",
        }
        "#);

        let output = std::process::Command::new("date").output().unwrap();

        insta::assert_debug_snapshot!(output, @r#"
        Output {
            status: ExitStatus(
                unix_wait_status(
                    0,
                ),
            ),
            stdout: "Thu Aug 21 14:54:45 EDT 2025\n",
            stderr: "",
        }
        "#);
    }
}