conservator_macro 0.1.6

conservator macro
Documentation
use std::str::FromStr;

use itertools::Itertools;
use proc_macro2::Span;
use quote::{format_ident, quote};
use strum::EnumString;
use syn::spanned::Spanned;
use syn::{parse2, Expr, FnArg, ItemFn, Lit, Pat, Stmt};

#[derive(Debug, EnumString)]
#[strum(serialize_all = "snake_case")]
enum Action {
    Fetch,
    Exists,
    Find,
    FetchAll,
    Execute,
}

impl Action {
    fn build_sqlx_query(&self, fields: &[String], sql: String) -> proc_macro2::TokenStream {
        let fields = fields
            .iter()
            .filter(|&field| !field.eq("executor"))
            .map(|field| format_ident!("{}", field))
            .collect_vec();
        match self {
            Action::Fetch => {
                quote! {
                    ::sqlx::query_as(#sql)
                    #(.bind(#fields))*
                    .fetch_one(executor)
                    .await
                }
            }
            Action::Exists => {
                quote! {
                    Ok(::sqlx::query_as::<_, (bool, )>(#sql)
                    #(.bind(#fields))*
                    .fetch_one(executor)
                    .await?.0)
                }
            }
            Action::Find => {
                quote! {
                    ::sqlx::query_as(#sql)
                    #(.bind(#fields))*
                    .fetch_optional(executor)
                    .await
                }
            }
            Action::FetchAll => {
                quote! {
                    ::sqlx::query_as(#sql)
                    #(.bind(#fields))*
                    .fetch_all(executor)
                    .await
                }
            }
            Action::Execute => quote! {
                ::sqlx::query(#sql)
                #(.bind(#fields))*
                .execute(executor)
                .await?;
                Ok(())
            },
        }
    }
}

pub(crate) fn handler(args: proc_macro2::TokenStream, input: proc_macro2::TokenStream) -> Result<proc_macro2::TokenStream, (Span, &'static str)> {
    let arg = args.to_string();
    let action = match Action::from_str(&arg) {
        Ok(action) => action,
        Err(_) => return Err((args.span(), "unknown action type")),
    };

    let input_span = input.span().clone();
    let method = match parse2::<ItemFn>(input) {
        Ok(func) => func,
        Err(_) => return Err((input_span, "unknown action type")),
    };

    let vis = &method.vis;
    let ident = &method.sig.ident;
    let inputs = &method.sig.inputs;
    let fields: Vec<String> = inputs
        .iter()
        .filter_map(|it| match it {
            FnArg::Receiver(_) => None,
            FnArg::Typed(typed) => match &*typed.pat {
                Pat::Ident(ident) => Some(ident.ident.to_string()),
                _ => None,
            },
        })
        .collect();
    let output = &method.sig.output;
    let body = &method.block;
    let body: Vec<proc_macro2::TokenStream> = body
        .stmts
        .iter()
        .cloned()
        .map(|stmt| match &stmt {
            Stmt::Expr(Expr::Lit(expr_lit)) => match &expr_lit.lit {
                Lit::Str(lit_str) => {
                    let mut sql = lit_str.value();
                    fields.iter().enumerate().for_each(|(idx, field)| {
                        sql = sql.replace(&format!(":{}", field), &format!("${}", idx + 1));
                    });
                    let query_stmt = action.build_sqlx_query(&fields[..], sql);
                    quote!( #query_stmt)
                }
                _ => {
                    quote!( #stmt )
                }
            },
            _ => quote!( #stmt ),
        })
        .collect();

    Ok(quote! {
        #vis async fn #ident<'e, 'c: 'e, E: 'e + ::sqlx::Executor<'c, Database=::sqlx::Postgres>>(#inputs) #output {
            #(#body )*
        }
    })
}

#[cfg(test)]
mod test {
    use crate::sql::handler;

    #[test]
    fn should_generate_fetch_sql_function() {
        use quote::quote;
        let args = quote! { fetch };
        let input = quote! {
            pub async fn find_user<E>(email: &str, executor: E) -> Result<Option<UserEntity>, ::sqlx::Error> {
                "select * from users where email = :email"
            }
        };

        let expected = quote! {
            pub async fn find_user<'e, 'c: 'e, E: 'e + ::sqlx::Executor<'c, Database = ::sqlx::Postgres>>(
                email: &str,
                executor: E
            ) -> Result<Option<UserEntity>, ::sqlx::Error> {
                ::sqlx::query_as("select * from users where email = $1")
                    .bind(email)
                    .fetch_one(executor)
                    .await
            }
        };
        assert_eq!(expected.to_string(), handler(args, input).unwrap().to_string());
    }
}