oxi_test/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Ident, Span};
3use quote::quote;
4use syn::{parse_macro_input, Error};
5
6/// Tests a piece of code inside a Neovim session.
7///
8/// # Examples
9///
10/// ```ignore
11/// use nvim_oxi::{self as nvim, api};
12///
13/// #[nvim::test]
14/// fn set_get_del_var() {
15///     api::set_var("foo", 42).unwrap();
16///     assert_eq!(Ok(42), api::get_var("foo"));
17///     assert_eq!(Ok(()), api::del_var("foo"));
18/// }
19/// ```
20#[proc_macro_attribute]
21pub fn oxi_test(attr: TokenStream, item: TokenStream) -> TokenStream {
22    let args = parse_macro_input!(attr as syn::AttributeArgs);
23
24    if !args.is_empty() {
25        return Error::new(Span::call_site(), "no attributes are supported")
26            .to_compile_error()
27            .into();
28    }
29
30    let item = parse_macro_input!(item as syn::ItemFn);
31
32    let syn::ItemFn { sig, block, .. } = item;
33
34    // TODO: here we'd need to append something like the module path of the
35    // call site to `test_name` to avoid collisions between equally named tests
36    // across different modules. Unfortunately that doesn't seem to be possible
37    // yet?
38    // See https://www.reddit.com/r/rust/comments/a3fgp6/procmacro_determining_the_callers_module_path/
39    let test_name = sig.ident;
40    let test_body = block;
41
42    let module_name = Ident::new(&format!("__{test_name}"), Span::call_site());
43
44    quote! {
45        #[test]
46        fn #test_name() {
47            let mut library_filename = String::new();
48            library_filename.push_str(::std::env::consts::DLL_PREFIX);
49            library_filename.push_str(env!("CARGO_CRATE_NAME"));
50            library_filename.push_str(::std::env::consts::DLL_SUFFIX);
51
52            let mut target_filename = String::from("__");
53            target_filename.push_str(stringify!(#test_name));
54
55            #[cfg(not(target_os = "macos"))]
56            target_filename.push_str(::std::env::consts::DLL_SUFFIX);
57
58            #[cfg(target_os = "macos")]
59            target_filename.push_str(".so");
60
61            let target_dir =
62                ::std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
63                    .join("target")
64                    .join("debug");
65
66            let library_filepath = target_dir.join(library_filename);
67
68            if !library_filepath.exists() {
69                panic!(
70                    "Compiled library not found in '{}'. Please run `cargo \
71                     build` before running the tests.",
72                    library_filepath.display()
73                )
74            }
75
76            let target_filepath =
77                target_dir.join("oxi-test").join("lua").join(target_filename);
78
79            if !target_filepath.parent().unwrap().exists() {
80                if let Err(err) = ::std::fs::create_dir_all(
81                    target_filepath.parent().unwrap(),
82                ) {
83                    // It might happen that another test created the `lua`
84                    // directory between the first if and the `create_dir_all`.
85                    if !matches!(
86                        err.kind(),
87                        ::std::io::ErrorKind::AlreadyExists
88                    ) {
89                        panic!("{}", err)
90                    }
91                }
92            }
93
94            #[cfg(unix)]
95            let res = ::std::os::unix::fs::symlink(
96                &library_filepath,
97                &target_filepath,
98            );
99
100            #[cfg(windows)]
101            let res = ::std::os::windows::fs::symlink_file(
102                &library_filepath,
103                &target_filepath,
104            );
105
106            if let Err(err) = res {
107                if !matches!(err.kind(), ::std::io::ErrorKind::AlreadyExists) {
108                    panic!("{}", err)
109                }
110            }
111
112            let out = ::std::process::Command::new("nvim")
113                .args(["-u", "NONE", "--headless"])
114                .args(["-c", "set noswapfile"])
115                .args([
116                    "-c",
117                    &format!(
118                        "set rtp+={}",
119                        target_dir.join("oxi-test").display()
120                    ),
121                ])
122                .args([
123                    "-c",
124                    &format!("lua require('__{}')", stringify!(#test_name)),
125                ])
126                .args(["+quit"])
127                .output()
128                .expect("Couldn't find `nvim` binary in $PATH!");
129
130            let stderr = String::from_utf8_lossy(&out.stderr);
131
132            if !stderr.is_empty() {
133                // Remove the last 2 lines from stderr for a cleaner error msg.
134                let stderr = {
135                    let lines = stderr.lines().collect::<Vec<_>>();
136                    let len = lines.len();
137                    lines[..lines.len() - 2].join("\n")
138                };
139
140                // The first 31 bytes are `thread '<unnamed>' panicked at `.
141                let (_, stderr) = stderr.split_at(31);
142
143                panic!("{}", stderr)
144            }
145        }
146
147        #[::nvim_oxi::module]
148        fn #module_name() -> ::nvim_oxi::Result<()> {
149            let result = ::std::panic::catch_unwind(|| {
150                #test_body
151            });
152
153            ::std::process::exit(match result {
154                Ok(_) => 0,
155
156                Err(err) => {
157                    eprintln!("{:?}", err);
158                    1
159                },
160            })
161        }
162    }
163    .into()
164}