debugger_test/
lib.rs

1mod debugger;
2mod debugger_script;
3
4use std::str::FromStr;
5
6use debugger::DebuggerType;
7use proc_macro::TokenStream;
8use quote::{format_ident, quote, ToTokens};
9use syn::{parse::Parse, Token};
10
11use crate::debugger_script::create_debugger_script;
12
13struct DebuggerTest {
14    debugger: String,
15    commands: String,
16    expected_statements: String,
17}
18
19impl Parse for DebuggerTest {
20    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
21        let debugger_meta = input.parse::<syn::MetaNameValue>()?;
22        let debugger = if debugger_meta.path.is_ident("debugger") {
23            match debugger_meta.lit {
24                syn::Lit::Str(lit_str) => lit_str.value(),
25                _ => {
26                    return Err(input.error("Expected a literal string for the value of `debugger`"))
27                }
28            }
29        } else {
30            return Err(input.error("Expected value `debugger`"));
31        };
32
33        input.parse::<Token![,]>()?;
34
35        let commands_meta = input.parse::<syn::MetaNameValue>()?;
36        let commands = if commands_meta.path.is_ident("commands") {
37            match commands_meta.lit {
38                syn::Lit::Str(lit_str) => lit_str.value(),
39                _ => {
40                    return Err(input.error("Expected a literal string for the value of `commands`"))
41                }
42            }
43        } else {
44            return Err(input.error("Expected value `commands`"));
45        };
46
47        input.parse::<Token![,]>()?;
48
49        let expected_statements_meta = input.parse::<syn::MetaNameValue>()?;
50        let expected_statements = if expected_statements_meta
51            .path
52            .is_ident("expected_statements")
53        {
54            match expected_statements_meta.lit {
55                syn::Lit::Str(lit_str) => lit_str.value(),
56                _ => {
57                    return Err(input
58                        .error("Expected a literal string for the value of `expected_statements`"))
59                }
60            }
61        } else {
62            return Err(input.error("Expected value `expected_statements`"));
63        };
64
65        Ok(DebuggerTest {
66            debugger,
67            commands,
68            expected_statements,
69        })
70    }
71}
72
73#[proc_macro_attribute]
74pub fn debugger_test(attr: TokenStream, item: TokenStream) -> TokenStream {
75    let invoc = match syn::parse::<DebuggerTest>(attr) {
76        Ok(s) => s,
77        Err(e) => return e.to_compile_error().into(),
78    };
79
80    let item = match syn::parse::<syn::Item>(item) {
81        Ok(s) => s,
82        Err(e) => return e.to_compile_error().into(),
83    };
84
85    let func = match item {
86        syn::Item::Fn(ref f) => f,
87        _ => panic!("must be attached to a function"),
88    };
89
90    let debugger_commands = &invoc
91        .commands
92        .trim()
93        .lines()
94        .into_iter()
95        .map(|line| line.trim())
96        .collect::<Vec<&str>>();
97
98    let debugger_type = DebuggerType::from_str(invoc.debugger.as_str()).expect(
99        format!(
100            "debugger `{}` must be a valid debugger option.",
101            invoc.debugger.as_str()
102        )
103        .as_str(),
104    );
105    let debugger_executable_path = debugger::get_debugger(&debugger_type);
106
107    let fn_name = func.sig.ident.to_string();
108    let fn_ident = format_ident!("{}", fn_name);
109    let test_fn_name = format!("{}__{}", fn_name, debugger_type.to_string());
110    let test_fn_ident = format_ident!("{}", test_fn_name);
111
112    let debugger_script_contents = create_debugger_script(&fn_name, debugger_commands);
113
114    // Trim all whitespace and remove any empty lines.
115    let expected_statements = &invoc
116        .expected_statements
117        .trim()
118        .lines()
119        .collect::<Vec<&str>>();
120
121    // Create the cli for the given debugger.
122    let (debugger_command_line, cfg_attr) = match debugger_type {
123        DebuggerType::Cdb => {
124            let debugger_path = debugger_executable_path.to_string_lossy().to_string();
125            let command_line = quote!(
126                match std::process::Command::new(#debugger_path)
127                    .stdout(std::process::Stdio::from(debugger_stdout_file))
128                    .stderr(std::process::Stdio::from(debugger_stderr_file))
129                    .arg("-pd")
130                    .arg("-p")
131                    .arg(pid.to_string())
132                    .arg("-cf")
133                    .arg(&debugger_script_path)
134                    .spawn() {
135                        Ok(child) => child,
136                        Err(error) => {
137                            return Err(std::boxed::Box::from(format!("Failed to launch CDB: {}\n", error.to_string())));
138                        }
139                }
140            );
141
142            // cdb is only supported on Windows.
143            let cfg_attr = quote!(
144                #[cfg_attr(not(target_os = "windows"), ignore = "test only runs on windows platforms.")]
145            );
146
147            (command_line, cfg_attr)
148        }
149    };
150
151    // Create the test function that will launch the debugger and run debugger commands.
152    let mut debugger_test_fn = proc_macro::TokenStream::from(quote!(
153        #[test]
154        #cfg_attr
155        fn #test_fn_ident() -> std::result::Result<(), Box<dyn std::error::Error>> {
156            use std::io::Read;
157            use std::io::Write;
158
159            let pid = std::process::id();
160            let current_exe_filename = std::env::current_exe()?.file_stem().expect("must have a valid file name").to_string_lossy().to_string();
161
162            // Create a temporary file to store the debugger script to run.
163            let debugger_script_filename = format!("{}_{}.debugger_script", current_exe_filename, #test_fn_name);
164            let debugger_script_path = std::env::temp_dir().join(debugger_script_filename);
165
166            // Write the contents of the debugger script to a new file.
167            let mut debugger_script = std::fs::File::create(&debugger_script_path)?;
168            writeln!(debugger_script, #debugger_script_contents)?;
169
170            // Create a temporary file to store the stdout and stderr from the debugger output.
171            let debugger_stdout_path = debugger_script_path.with_extension("debugger_out");
172            let debugger_stderr_path = debugger_script_path.with_extension("debugger_err");
173
174            let debugger_stdout_file = std::fs::File::create(&debugger_stdout_path)?;
175            let debugger_stderr_file = std::fs::File::create(&debugger_stderr_path)?;
176
177            // Start the debugger and run the debugger commands.
178            let mut child = #debugger_command_line;
179
180            // Wait for the debugger to launch
181            // On Windows, use the IsDebuggerPresent API to check if a debugger is present
182            // for the current process. https://docs.microsoft.com/en-us/windows/win32/api/debugapi/nf-debugapi-isdebuggerpresent
183            #[cfg(windows)]
184            extern "stdcall" {
185                fn IsDebuggerPresent() -> i32;
186            };
187            #[cfg(windows)]
188            unsafe {
189                while IsDebuggerPresent() == 0 {
190                    std::thread::sleep(std::time::Duration::from_secs(1));
191                }
192            }
193
194            // Wait 3 seconds to ensure the debugger is in control of the process.
195            std::thread::sleep(std::time::Duration::from_secs(3));
196
197            // Call the test function.
198            #fn_ident();
199
200            // Wait for the debugger to exit.
201            std::thread::sleep(std::time::Duration::from_secs(3));
202
203            // If debugger has not already quit, force quit the debugger.
204            let mut debugger_stdout = String::new();
205            match child.try_wait()? {
206                Some(status) => {
207                    // Bail early if the debugger process didn't execute successfully.
208                    let mut debugger_stdout_file = std::fs::File::open(&debugger_stdout_path)?;
209                    debugger_stdout_file.read_to_string(&mut debugger_stdout)?;
210
211                    if !status.success() {
212                        let mut debugger_stderr = String::new();
213                        let mut debugger_stderr_file = std::fs::File::open(&debugger_stderr_path)?;
214                        debugger_stderr_file.read_to_string(&mut debugger_stderr)?;
215                        return Err(std::boxed::Box::from(format!("Debugger failed with {}.\n{}\n{}\n", status, debugger_stderr, debugger_stdout)));
216                    }
217
218                    println!("Debugger stdout:\n{}\n", &debugger_stdout);
219                },
220                None => {
221                    // Force kill the debugger process if it has not exited yet.
222                    println!("killing debugger process.");
223                    child.kill().expect("debugger has been running for too long");
224
225                    let mut debugger_stdout_file = std::fs::File::open(&debugger_stdout_path)?;
226                    debugger_stdout_file.read_to_string(&mut debugger_stdout)?;
227                    println!("Debugger stdout:\n{}\n", &debugger_stdout);
228                }
229            }
230
231            // Verify the expected contents of the debugger output.
232            let expected_statements = vec![#(#expected_statements),*];
233            debugger_test_parser::parse(debugger_stdout, expected_statements)?;
234
235            #[cfg(windows)]
236            unsafe {
237                while IsDebuggerPresent() == 1 {
238                    std::thread::sleep(std::time::Duration::from_secs(1));
239                }
240            }
241
242            #[cfg(not(windows))]
243            std::thread::sleep(std::time::Duration::from_secs(3));
244
245            Ok(())
246        }
247    ));
248
249    debugger_test_fn.extend(proc_macro::TokenStream::from(item.to_token_stream()).into_iter());
250    debugger_test_fn
251}