glow_plug_macros/
lib.rs

1extern crate proc_macro;
2use proc_macro::TokenStream;
3use quote::quote;
4use syn::{parse_macro_input, punctuated::Punctuated, ItemFn, Signature, Type};
5
6/// A testing macro similar to the standard `[#test]` but allows for
7/// function to take a connection to a database as a parameter.
8///
9/// A clean new database will be automatically created for the test
10/// and destroyed after the test has run.
11#[proc_macro_attribute]
12pub fn test(_attr: TokenStream, item: TokenStream) -> TokenStream {
13    let func = parse_macro_input!(item as ItemFn);
14
15    let conn_type = get_conn_type(&func);
16    let db_name = test_db_name(&func);
17    let (sig, sig_inner) = get_signatures(&func);
18    let inner_ident = &sig_inner.ident;
19    let inner_block = &func.block;
20
21    let run_test = if func.sig.asyncness.is_some() {
22        quote! {
23            use glow_plug::FutureExt;
24            let result = async move {
25                std::panic::AssertUnwindSafe(#inner_ident(conn)).catch_unwind().await
26            }.await;
27        }
28    } else {
29        quote! {
30            let result = std::panic::catch_unwind(|| {
31                #inner_ident(conn)
32            });
33        }
34    };
35
36    let test_macro = if func.sig.asyncness.is_some() {
37        quote! { #[glow_plug::tokio::test] }
38    } else {
39        quote! { #[test] }
40    };
41
42    quote! {
43        #test_macro
44        #sig {
45            #sig_inner #inner_block
46
47            use glow_plug::{Connection, RunQueryDsl};
48            use glow_plug::MigrationHarness;
49
50            glow_plug::dotenvy::dotenv().ok();
51
52            let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
53
54            let mut main_conn = <#conn_type>::establish(db_url.as_str())
55                .expect("Failed to establish connection.");
56
57            diesel::sql_query(format!("CREATE DATABASE {};", #db_name))
58                .execute(&mut main_conn)
59                .expect("Failed to create database.");
60
61            // We cannot just change the `conn` to point to the new database since not
62            // all databases don't support a `USE database` command. So just connect again
63            let mut conn = <#conn_type>::establish(format!("{}/{}", db_url, #db_name).as_str())
64                .expect("Failed to establish connection.");
65
66            // Run the migrations, there may be a better way to do this. Right now it currently
67            // requires the migrations to be a const variable in the crate root also the user
68            // needs to manually add the `diesel_migrations` dependency with the correct features.
69            conn.run_pending_migrations(crate::MIGRATIONS)
70                .expect("Failed to run migrations");
71
72            // Make sure to catch the panic, so we can drop the database even on failure.
73            #run_test
74
75            diesel::sql_query(format!("DROP DATABASE {};", #db_name))
76                .execute(&mut main_conn)
77                .expect("Failed to drop database.");
78
79            match result {
80                Ok(val) => val,
81                Err(err) => std::panic::resume_unwind(err),
82            }
83        }
84    }
85    .into()
86}
87
88fn get_signatures(func: &ItemFn) -> (Signature, Signature) {
89    let mut sig = func.sig.clone();
90    let mut sig_inner = func.sig.clone();
91
92    // Remove the inputs of the outer function
93    sig.inputs = Punctuated::new();
94
95    // Rename the inner function
96    sig_inner.ident = syn::Ident::new(&format!("{}_inner", sig.ident), sig.ident.span());
97
98    (sig, sig_inner)
99}
100
101// Gets the type of connection, e.g. `PgConnection` or `MysqlConnection`
102fn get_conn_type(func: &ItemFn) -> &Box<Type> {
103    let args = &func.sig.inputs;
104
105    match args.first() {
106        Some(syn::FnArg::Typed(syn::PatType { ty, .. })) => ty,
107        _ => panic!("Expected at least one input argument"),
108    }
109}
110
111fn test_db_name(func: &ItemFn) -> String {
112    let name = func.sig.ident.to_string();
113    format!("{}_{}", name, rand::random::<u32>())
114}