1extern crate proc_macro;
2use proc_macro::TokenStream;
3use quote::quote;
4use syn::{parse_macro_input, punctuated::Punctuated, ItemFn, Signature, Type};
5
6#[proc_macro_attribute]
12pub fn test(_attr: TokenStream, item: TokenStream) -> TokenStream {
13 let func = parse_macro_input!(item as ItemFn);
14
15 let conn_type = get_conn_type(&func);
16 let db_name = test_db_name(&func);
17 let (sig, sig_inner) = get_signatures(&func);
18 let inner_ident = &sig_inner.ident;
19 let inner_block = &func.block;
20
21 let run_test = if func.sig.asyncness.is_some() {
22 quote! {
23 use glow_plug::FutureExt;
24 let result = async move {
25 std::panic::AssertUnwindSafe(#inner_ident(conn)).catch_unwind().await
26 }.await;
27 }
28 } else {
29 quote! {
30 let result = std::panic::catch_unwind(|| {
31 #inner_ident(conn)
32 });
33 }
34 };
35
36 let test_macro = if func.sig.asyncness.is_some() {
37 quote! { #[glow_plug::tokio::test] }
38 } else {
39 quote! { #[test] }
40 };
41
42 quote! {
43 #test_macro
44 #sig {
45 #sig_inner #inner_block
46
47 use glow_plug::{Connection, RunQueryDsl};
48 use glow_plug::MigrationHarness;
49
50 glow_plug::dotenvy::dotenv().ok();
51
52 let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
53
54 let mut main_conn = <#conn_type>::establish(db_url.as_str())
55 .expect("Failed to establish connection.");
56
57 diesel::sql_query(format!("CREATE DATABASE {};", #db_name))
58 .execute(&mut main_conn)
59 .expect("Failed to create database.");
60
61 let mut conn = <#conn_type>::establish(format!("{}/{}", db_url, #db_name).as_str())
64 .expect("Failed to establish connection.");
65
66 conn.run_pending_migrations(crate::MIGRATIONS)
70 .expect("Failed to run migrations");
71
72 #run_test
74
75 diesel::sql_query(format!("DROP DATABASE {};", #db_name))
76 .execute(&mut main_conn)
77 .expect("Failed to drop database.");
78
79 match result {
80 Ok(val) => val,
81 Err(err) => std::panic::resume_unwind(err),
82 }
83 }
84 }
85 .into()
86}
87
88fn get_signatures(func: &ItemFn) -> (Signature, Signature) {
89 let mut sig = func.sig.clone();
90 let mut sig_inner = func.sig.clone();
91
92 sig.inputs = Punctuated::new();
94
95 sig_inner.ident = syn::Ident::new(&format!("{}_inner", sig.ident), sig.ident.span());
97
98 (sig, sig_inner)
99}
100
101fn get_conn_type(func: &ItemFn) -> &Box<Type> {
103 let args = &func.sig.inputs;
104
105 match args.first() {
106 Some(syn::FnArg::Typed(syn::PatType { ty, .. })) => ty,
107 _ => panic!("Expected at least one input argument"),
108 }
109}
110
111fn test_db_name(func: &ItemFn) -> String {
112 let name = func.sig.ident.to_string();
113 format!("{}_{}", name, rand::random::<u32>())
114}