1use anyhow::{Result, anyhow};
2use escargot::CargoBuild;
3use serde::{Deserialize, Serialize};
4use std::{
5 collections::HashMap,
6 env, fmt, fs,
7 path::{Path, PathBuf},
8};
9use tempfile::TempDir;
10use tokio::{
11 fs::{DirBuilder, try_exists},
12 io::AsyncReadExt as _,
13 process::Command,
14};
15
16pub use commandeer_macros::commandeer;
17
/// One captured execution of an external command, as stored in a recording fixture.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CommandInvocation {
    /// Program that was run; `record_command` passes this to `Command::new` as-is.
    pub binary_name: String,
    /// Arguments passed to the program, in order.
    pub args: Vec<String>,
    /// Captured standard output; `record_command` decodes it lossily as UTF-8.
    pub stdout: String,
    /// Captured standard error; decoded the same way as `stdout`.
    pub stderr: String,
    /// Exit code of the process. `record_command` stores `-1` when the OS reported no exit
    /// code (e.g. the process was terminated by a signal on Unix).
    pub exit_code: i32,
}
26
/// All invocations recorded in one fixture file (serialized as JSON).
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct RecordedCommands {
    // Keyed by `RecordedCommands::generate_key(binary_name, args)`; each entry holds the
    // invocations recorded under that key in the order they were added.
    commands: HashMap<String, Vec<CommandInvocation>>,
}
31
32impl RecordedCommands {
33 fn generate_key(binary_name: &str, args: &[String]) -> String {
34 format!("{binary_name}:{}", args.join(" "))
35 }
36
37 pub fn add_invocation(&mut self, invocation: CommandInvocation) {
38 let key = Self::generate_key(&invocation.binary_name, &invocation.args);
39
40 self.commands.entry(key).or_default().push(invocation);
41 }
42
43 pub fn find_invocation(
44 &self,
45 binary_name: &str,
46 args: &[String],
47 ) -> Option<&CommandInvocation> {
48 let key = Self::generate_key(binary_name, args);
49
50 self.commands.get(&key)?.first()
51 }
52}
53
54pub async fn load_recordings(file_path: &PathBuf) -> Result<RecordedCommands> {
55 let mut f = tokio::fs::File::options();
56
57 let mut contents = String::new();
58 f.create(true)
59 .write(true)
60 .read(true)
61 .open(file_path)
62 .await?
63 .read_to_string(&mut contents)
64 .await?;
65
66 if contents.trim().is_empty() {
67 return Ok(RecordedCommands::default());
68 }
69
70 let recordings: RecordedCommands = serde_json::from_str(&contents)?;
71
72 Ok(recordings)
73}
74
75pub async fn save_recordings(file_path: &PathBuf, recordings: &RecordedCommands) -> Result<()> {
76 let json = serde_json::to_string_pretty(recordings)?;
77
78 tokio::fs::write(file_path, json.as_bytes()).await?;
79
80 Ok(())
81}
82
83pub async fn record_command(
84 truncate: bool,
85 file_path: PathBuf,
86 command: String,
87 args: Vec<String>,
88) -> Result<CommandInvocation> {
89 let recording_dir = file_path
90 .parent()
91 .ok_or_else(|| anyhow!("Couldn't get parent of recording {}", file_path.display()))?;
92
93 DirBuilder::new()
94 .recursive(true)
95 .create(recording_dir)
96 .await?;
97
98 let mut recordings = if truncate {
99 if try_exists(&file_path).await? {
100 tokio::fs::remove_file(&file_path).await?;
101 }
102
103 RecordedCommands::default()
104 } else {
105 load_recordings(&file_path).await?
106 };
107
108 let output = Command::new(&command).args(&args).output().await?;
109
110 let invocation = CommandInvocation {
111 binary_name: command,
112 args,
113 stdout: String::from_utf8_lossy(&output.stdout).to_string(),
114 stderr: String::from_utf8_lossy(&output.stderr).to_string(),
115 exit_code: output.status.code().unwrap_or(-1),
116 };
117
118 recordings.add_invocation(invocation.clone());
119 save_recordings(&file_path, &recordings).await?;
120
121 Ok(invocation)
122}
123
124pub async fn replay_command(
125 file_path: PathBuf,
126 command: String,
127 args: Vec<String>,
128) -> Result<Option<CommandInvocation>> {
129 let recordings = load_recordings(&file_path).await?;
130
131 Ok(recordings.find_invocation(&command, &args).cloned())
132}
133
134pub fn output_invocation(invocation: &CommandInvocation) {
135 print!("{}", invocation.stdout);
136 eprint!("{}", invocation.stderr);
137}
138
/// Terminates the current process immediately, using `code` as its exit status.
///
/// As with [`std::process::exit`], destructors of live values are not run.
pub fn exit_with_code(code: i32) -> ! {
    use std::process;

    process::exit(code)
}
142
/// Whether a [`Commandeer`] session records real command output or replays recorded output.
///
/// Wrapper scripts pass it to the mock runner as a subcommand, spelled `record` / `replay`
/// by its `Display` impl.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Mode {
    /// Run the real command and save its output to the fixture.
    Record,
    /// Answer from the fixture instead of running the real command.
    Replay,
}
148
/// Test harness that shadows selected commands on `PATH` with wrapper scripts. Each wrapper
/// invokes the `commandeer` mock binary against a per-test fixture file.
///
/// Constructing one modifies the process-wide `PATH`; dropping it restores the value captured
/// at construction.
pub struct Commandeer {
    // Built `commandeer` binary from the `commandeer-test` package; wrappers `exec` it.
    mock_runner: escargot::CargoRun,
    // Directory prepended to `PATH` that holds the generated wrapper scripts.
    temp_dir: TempDir,
    // Fixture file the mock runner is pointed at via `--file`.
    fixture: PathBuf,
    // Passed to the mock runner as its subcommand.
    mode: Mode,
    // `PATH` before construction; restored on drop and handed to the mock runner.
    original_path: String,
}
156
157impl fmt::Display for Mode {
158 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159 match self {
160 Mode::Record => write!(f, "record"),
161 Mode::Replay => write!(f, "replay"),
162 }
163 }
164}
165
166impl Commandeer {
167 pub fn new(test_name: impl AsRef<Path>, mode: Mode) -> Self {
168 let dir = PathBuf::from(
169 std::env::var("CARGO_MANIFEST_DIR").expect("Failed to get crate directory."),
170 );
171
172 std::fs::DirBuilder::new()
173 .recursive(true)
174 .create(&dir)
175 .expect("Failed to create testcmds dir");
176
177 let fixture = dir.join("testcmds").join(test_name);
178
179 if fixture.exists() && mode == Mode::Record {
180 std::fs::remove_file(&fixture).expect("Failed to remove existing fixture file");
181 }
182
183 let mock_runner = CargoBuild::new()
184 .manifest_path(PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("Cargo.toml"))
185 .package("commandeer-test")
186 .bin("commandeer")
187 .run()
188 .expect("Failed to build mock binary");
189
190 let temp_dir = TempDir::new().expect("Failed to create temp dir");
191
192 let original_path = std::env::var("PATH").unwrap_or_default();
193 let new_path = format!("{}:{original_path}", temp_dir.path().display());
194
195 unsafe {
196 std::env::set_var("PATH", new_path);
197 }
198
199 Self {
200 mock_runner,
201 temp_dir,
202 fixture,
203 mode,
204 original_path,
205 }
206 }
207 pub fn mock_command(&self, command_name: &str) -> PathBuf {
208 let mock_path = self.temp_dir.path().join(command_name);
209
210 let wrapper = format!(
211 r#"#!/usr/bin/env bash
212exec env PATH="{}" {} {} --file {} --command {command_name} "$@"
213"#,
214 self.original_path,
215 self.mock_runner.path().display(),
216 self.mode,
217 self.fixture.display(),
218 );
219
220 fs::write(&mock_path, wrapper).expect("Failed to write mock wrapper script");
221
222 #[cfg(unix)]
223 {
224 use std::os::unix::fs::PermissionsExt as _;
225
226 let mut perms = fs::metadata(&mock_path)
227 .expect("Could not get permissions")
228 .permissions();
229
230 perms.set_mode(0o755);
231
232 fs::set_permissions(&mock_path, perms).expect("Could not set permissions");
233 }
234
235 mock_path
236 }
237}
238
impl Drop for Commandeer {
    /// Restores `PATH` to the value captured when this `Commandeer` was created.
    fn drop(&mut self) {
        // SAFETY: `set_var` must not race with other threads touching the environment. This
        // relies on tests that use `Commandeer` running serially (the tests in this file are
        // marked `#[serial_test::serial]`). NOTE(review): nothing here enforces that.
        unsafe {
            std::env::set_var("PATH", &self.original_path);
        }
    }
}
246
#[cfg(test)]
mod tests {
    // NOTE(review): presumably needed so code generated by `#[commandeer]`, which names
    // `commandeer_test`, resolves when the macro is used inside this crate.
    use crate as commandeer_test;
    use crate::{Commandeer, Mode, commandeer};

    // Checks that `mock_command` writes the wrapper and that `echo` still succeeds with the
    // patched `PATH`.
    #[test]
    #[serial_test::serial]
    fn test_mock_cmd() {
        let commandeer = Commandeer::new("test_recordings.json", Mode::Replay);
        let mock_path = commandeer.mock_command("echo");

        let status = std::process::Command::new("echo")
            .arg("foo")
            .status()
            .unwrap();

        assert!(status.success());

        assert!(mock_path.exists());
    }

    // Replays `echo hello` through the `#[commandeer]` attribute.
    #[commandeer(Replay, "echo")]
    #[test]
    #[serial_test::serial]
    fn my_test() {
        let output = std::process::Command::new("echo")
            .arg("hello")
            .output()
            .unwrap();

        assert!(output.status.success());
    }

    // The attribute also works on async tests.
    #[commandeer(Replay, "date")]
    #[tokio::test]
    #[serial_test::serial]
    async fn async_replay() {
        let output = std::process::Command::new("date").output().unwrap();

        insta::assert_debug_snapshot!(output, @r#"
        Output {
            status: ExitStatus(
                unix_wait_status(
                    0,
                ),
            ),
            stdout: "Wed Aug 20 12:46:19 EDT 2025\n",
            stderr: "",
        }
        "#);
    }

    // Several commands can be mocked at once, and flag-style arguments are replayed.
    #[commandeer(Replay, "git", "date")]
    #[test]
    #[serial_test::serial]
    fn test_flag_args() {
        let output = std::process::Command::new("git")
            .arg("--version")
            .output()
            .unwrap();

        insta::assert_debug_snapshot!(output, @r#"
        Output {
            status: ExitStatus(
                unix_wait_status(
                    0,
                ),
            ),
            stdout: "git version 2.51.0\n",
            stderr: "",
        }
        "#);

        let output = std::process::Command::new("date").output().unwrap();

        insta::assert_debug_snapshot!(output, @r#"
        Output {
            status: ExitStatus(
                unix_wait_status(
                    0,
                ),
            ),
            stdout: "Thu Aug 21 14:54:45 EDT 2025\n",
            stderr: "",
        }
        "#);
    }
}