1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 Expr, Ident, ItemFn, Result,
5 parse::{Parse, ParseStream},
6 parse_macro_input, parse_quote,
7};
8
9struct CommandeerArgs {
10 mode: Ident,
11 commands: Vec<String>,
12}
13
14const RECORD: &str = "Record";
15const REPLAY: &str = "Replay";
16
17impl Parse for CommandeerArgs {
18 fn parse(input: ParseStream) -> Result<Self> {
19 let mut commands = vec![];
20
21 let ident: Ident = input.parse()?;
22
23 let mode = match ident.to_string().as_str() {
24 x if [RECORD, REPLAY].contains(&x) => Ident::new(x, proc_macro2::Span::call_site()),
25 _ => {
26 return Err(syn::Error::new(
27 ident.span(),
28 format!("Expected '{RECORD}' or '{REPLAY}'"),
29 ));
30 }
31 };
32
33 input.parse::<syn::Token![,]>()?;
34
35 while !input.is_empty() {
36 if input.peek(syn::LitStr) {
37 let lit: syn::LitStr = input.parse()?;
38
39 commands.push(lit.value());
40 } else {
41 return Err(input.error("Expected a command string"));
42 }
43
44 if input.peek(syn::Token![,]) {
45 input.parse::<syn::Token![,]>()?;
46 }
47 }
48
49 if commands.is_empty() {
50 return Err(syn::Error::new(
51 input.span(),
52 "Expected at least one command string",
53 ));
54 }
55
56 Ok(CommandeerArgs { mode, commands })
57 }
58}
59
60#[proc_macro_attribute]
66pub fn commandeer(args: TokenStream, input: TokenStream) -> TokenStream {
67 let args = parse_macro_input!(args as CommandeerArgs);
68 let mut input_fn = parse_macro_input!(input as ItemFn);
69
70 let fn_name = &input_fn.sig.ident;
71
72 let test_file_name = format!("cmds_{fn_name}.json");
73
74 let mock_commands: Vec<Expr> = args
77 .commands
78 .iter()
79 .map(|cmd| {
80 parse_quote! {
81 commandeer.mock_command(#cmd)
82 }
83 })
84 .collect();
85
86 let mode = args.mode;
87
88 let setup_stmts: Vec<syn::Stmt> = vec![parse_quote! {
90 let commandeer = commandeer_test::Commandeer::new(#test_file_name, commandeer_test::Mode::#mode);
91 }];
92
93 let mock_stmts: Vec<syn::Stmt> = mock_commands
94 .iter()
95 .map(|expr| {
96 parse_quote! {
97 #expr;
98 }
99 })
100 .collect();
101
102 let mut new_stmts = setup_stmts;
104 new_stmts.extend(mock_stmts);
105 new_stmts.extend(input_fn.block.stmts);
106
107 input_fn.block.stmts = new_stmts;
108
109 let body_str = quote!(#input_fn).to_string();
110
111 if body_str.contains("local_serial_core") {
112 return syn::Error::new_spanned(
113 input_fn.sig.fn_token,
114 "Out of order error. `commandeer` macro must be above the `serial_test` macro.",
115 )
116 .to_compile_error()
117 .into();
118 }
119
120 TokenStream::from(quote! { #input_fn })
121}