extern crate proc_macro;
use {
proc_macro2::{Span, TokenStream},
quote::{quote, quote_spanned},
std::{
env,
fs::File,
io::prelude::*,
path::{Path, PathBuf},
},
syn::parse::Parse,
};
/// Produces a `compile_error! { .. }` invocation whose message is `s` and whose tokens are
/// spanned at `span`.
fn compile_error(s: &str, span: Span) -> TokenStream {
    let message = s;
    quote_spanned! { span => compile_error! { #message } }
}
/// Arguments accepted by the `tests` attribute, as produced by the `Parse` impl below.
struct AttrArgs {
/// Path of the serializer; the generated code calls it as `ser(&value)?`.
ser: syn::ExprPath,
/// Path of the deserializer; the generated code calls it as `de::<T>(expected)?`.
de: syn::ExprPath,
/// Type the expected output is deserialized into. When `None`, the return type of the
/// annotated function is used instead.
value: Option<syn::Type>,
/// Path of the test definition file; it is joined onto `$CARGO_MANIFEST_DIR`.
file: syn::LitStr,
}
impl Parse for AttrArgs {
fn parse(input: &syn::parse::ParseBuffer<'_>) -> syn::parse::Result<Self> {
mod kw {
syn::custom_keyword!(exact);
syn::custom_keyword!(file);
syn::custom_keyword!(ser);
syn::custom_keyword!(de);
syn::custom_keyword!(value);
syn::custom_keyword!(serde);
}
let _: kw::exact = input.parse()?;
let _: syn::Token![,] = input.parse()?;
let la = input.lookahead1();
let (ser, de, value) = if la.peek(kw::serde) {
let _: kw::serde = input.parse()?;
let _: syn::Token![=] = input.parse()?;
let format: syn::ExprPath = input.parse()?;
let _: syn::Token![,] = input.parse()?;
let la = input.lookahead1();
if !(la.peek(kw::ser) || la.peek(kw::de) || la.peek(kw::value) | la.peek(kw::file)) {
return Err(la.error())?;
}
let ser: syn::ExprPath = if input.peek(kw::ser) {
let _: kw::ser = input.parse()?;
let _: syn::Token![=] = input.parse()?;
let ser: syn::ExprPath = input.parse()?;
let _: syn::Token![,] = input.parse()?;
let la = input.lookahead1();
if !(la.peek(kw::de) || la.peek(kw::value) | la.peek(kw::file)) {
return Err(la.error())?;
}
ser
} else {
syn::parse_quote!(#format::to_string)
};
let de: syn::ExprPath = if input.peek(kw::de) {
let _: kw::de = input.parse()?;
let _: syn::Token![=] = input.parse()?;
let de: syn::ExprPath = input.parse()?;
let _: syn::Token![,] = input.parse()?;
let la = input.lookahead1();
if !(la.peek(kw::value) | la.peek(kw::file)) {
return Err(la.error())?;
}
de
} else {
syn::parse_quote!(#format::from_str)
};
let value: syn::Type = if input.peek(kw::value) {
let _: kw::value = input.parse()?;
let _: syn::Token![=] = input.parse()?;
let value: syn::Type = input.parse()?;
let _: syn::Token![,] = input.parse()?;
value
} else {
syn::parse_quote!(#format::Value)
};
(ser, de, Some(value))
} else if la.peek(kw::ser) {
let _: kw::ser = input.parse()?;
let _: syn::Token![=] = input.parse()?;
let ser: syn::ExprPath = input.parse()?;
let _: syn::Token![,] = input.parse()?;
let _: kw::de = input.parse()?;
let _: syn::Token![=] = input.parse()?;
let de: syn::ExprPath = input.parse()?;
let _: syn::Token![,] = input.parse()?;
let value = if input.peek(kw::value) {
let _: kw::value = input.parse()?;
let _: syn::Token![=] = input.parse()?;
let value: syn::Type = input.parse()?;
let _: syn::Token![,] = input.parse()?;
Some(value)
} else {
None
};
(ser, de, value)
} else {
return Err(la.error())?;
};
let _: kw::file = input.parse()?;
let _: syn::Token![=] = input.parse()?;
let file: syn::LitStr = input.parse()?;
Ok(AttrArgs {
ser,
de,
value,
file,
})
}
}
/// One test case read from a test definition file by `read_tests`.
struct Test {
/// Identifier derived from the test's name: trimmed, spaces replaced by `_`, prefixed with `_`.
name: syn::Ident,
/// Trimmed text found between the `===` line and the last `---` line of the test.
input: String,
/// Trimmed text found after the last `---` line of the test.
output: String,
}
/// Loads the test cases stored in the file at `file_path`.
///
/// The file is a sequence of records, each closed by a line containing only `...`:
///
/// ```text
/// test name
/// ===
/// input
/// ---
/// expected output
/// ...
/// ```
///
/// The name is trimmed, has its spaces replaced by `_` and is prefixed with `_` before being
/// parsed as an identifier. Input and output are trimmed; they are separated at the *last*
/// `---` line of the record.
///
/// # Errors
///
/// Every failure is reported as `compile_error!` tokens spanned at `span`. Problems with the
/// file as a whole (cannot open/read, no trailing newline, content after the final `...`)
/// return immediately; problems with individual records are accumulated and returned together.
fn read_tests(file_path: &Path, span: Span) -> Result<Vec<Test>, TokenStream> {
    const RECORD_END: &str = "\n...\n";
    const NAME_END: &str = "\n===\n";
    const INPUT_END: &str = "\n---\n";

    // Read the whole file; the handle is scoped so it is closed before parsing starts.
    let source = {
        let mut file = match File::open(file_path) {
            Ok(file) => file,
            Err(e) => {
                return Err(compile_error(
                    &format!("failed to open file: {}", e),
                    span,
                ))
            }
        };
        let capacity = match file.metadata() {
            Ok(meta) => meta.len() as usize + 1,
            Err(_) => 0,
        };
        let mut contents = String::with_capacity(capacity);
        if let Err(e) = file.read_to_string(&mut contents) {
            return Err(compile_error(
                &format!("failed to read file: {}", e),
                span,
            ));
        }
        contents
    };

    if !source.ends_with('\n') {
        return Err(compile_error("file needs to have trailing newline", span));
    }

    // Everything up to and including the final record terminator is parsed; whatever follows
    // may only be whitespace.
    let records_len = match source.rfind(RECORD_END) {
        Some(at) => at + RECORD_END.len(),
        None => 0,
    };
    let (records, leftover) = source.split_at(records_len);
    if !leftover.trim().is_empty() {
        return Err(compile_error(
            "file has disallowed content after final `...`",
            span,
        ));
    }

    let mut tests = Vec::new();
    let mut errs = TokenStream::new();
    for (index, record) in records.split_terminator(RECORD_END).enumerate() {
        let name_len = match record.find(NAME_END) {
            Some(at) => at,
            None => {
                errs.extend(compile_error(
                    &format!("test {} does not have `===` after name", index),
                    span,
                ));
                continue;
            }
        };
        let name = record[..name_len].trim().replace(' ', "_");
        let body = &record[name_len + NAME_END.len()..];

        let input_len = match body.rfind(INPUT_END) {
            Some(at) => at,
            None => {
                errs.extend(compile_error(
                    &format!("test `{}` does not have `---` after input", name),
                    span,
                ));
                continue;
            }
        };
        let input = body[..input_len].trim().to_string();
        let output = body[input_len + INPUT_END.len()..].trim().to_string();

        // The leading underscore lets names that start with a digit still form an identifier.
        match syn::parse_str::<syn::Ident>(&format!("_{}", name)) {
            Ok(ident) => tests.push(Test {
                name: ident,
                input,
                output,
            }),
            Err(_) => errs.extend(compile_error(
                &format!("`{}` is not a valid test name identifier", name),
                span,
            )),
        }
    }

    if !errs.is_empty() {
        return Err(errs);
    }
    Ok(tests)
}
/// Attribute macro that generates `#[test]` functions from a test definition file.
///
/// The annotated item is always re-emitted unchanged. When the attribute arguments parse, the
/// item parses as a function and `$CARGO_MANIFEST_DIR` is set, the output of `build_tests` is
/// appended; otherwise every error that occurred is appended as `compile_error!` tokens.
#[proc_macro_attribute]
pub fn tests(
    attr: proc_macro::TokenStream,
    item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let mut output: TokenStream = item.clone().into();

    let manifest_dir = match env::var("CARGO_MANIFEST_DIR") {
        Ok(dir) => Ok(PathBuf::from(dir)),
        Err(e) => Err(compile_error(
            &format!("expected $CARGO_MANIFEST_DIR; {}", e),
            Span::call_site(),
        )),
    };
    let args = match syn::parse::<AttrArgs>(attr) {
        Ok(args) => Ok(args),
        Err(e) => Err(e.to_compile_error()),
    };
    let fun = match syn::parse::<syn::ItemFn>(item) {
        Ok(fun) => Ok(fun),
        Err(e) => Err(e.to_compile_error()),
    };

    match (args, fun, manifest_dir) {
        (Ok(args), Ok(fun), Ok(manifest_dir)) => {
            output.extend(build_tests(args, fun, manifest_dir));
        }
        // At least one step failed: append each error that is present, in the order
        // arguments, item, manifest directory.
        (args, fun, manifest_dir) => {
            output.extend(args.err());
            output.extend(fun.err());
            output.extend(manifest_dir.err());
        }
    }

    output.into()
}
fn build_tests(args: AttrArgs, fun: syn::ItemFn, manifest_dir: PathBuf) -> TokenStream {
let AttrArgs {
ser,
de,
value,
file,
} = args;
let fn_name = &fun.sig.ident;
let tested_type = match &fun.sig.output {
syn::ReturnType::Type(_, r#type) => (**r#type).clone(),
syn::ReturnType::Default => syn::parse_str("()").unwrap(),
};
let de_type = value.unwrap_or(tested_type);
let tests_path = manifest_dir.join(file.value());
let tests = match read_tests(&tests_path, file.span()) {
Ok(it) => it,
Err(e) => return e,
};
let filepath = tests_path.to_string_lossy().to_string();
let filename = tests_path
.file_stem()
.unwrap()
.to_string_lossy()
.replace('.', "_");
let testing_fn = syn::Ident::new(&filename, Span::call_site());
let mut tts = quote! {
fn #testing_fn(expected: &str, actual: &str) -> Result<(), Box<dyn ::std::error::Error>> {
const _: &str = include_str!(#filepath);
let actual = #ser(&#fn_name(actual))?;
let expected = #ser(&#de::<#de_type>(expected)?)?;
assert_eq!(actual, expected);
Ok(())
}
};
for test in tests {
let Test {
name,
input,
output,
} = test;
let test_name = quote::format_ident!("{}{}", filename, name);
tts.extend(quote! {
#[test]
fn #test_name() -> Result<(), Box<dyn ::std::error::Error>> {
#testing_fn(#output, #input)
}
})
}
tts.into()
}