// conservator_macro 0.2.0
//
// conservator macro
// Documentation
use std::collections::HashSet;
use std::str::FromStr;

use itertools::Itertools;
use proc_macro2::Span;
use quote::{format_ident, quote};
use regex::Regex;
use strum::EnumString;
use syn::spanned::Spanned;
use syn::{
    parse2, AngleBracketedGenericArguments, Expr, ItemFn, Lit, PathArguments, ReturnType, Stmt,
    Type,
};

/// Returns `T` when `ty` is a plain path type whose last segment is `wrapper<T, ...>`.
///
/// Only the first generic argument is considered, and only if it is a type. Qualified paths
/// (`<X as Y>::Z`), non-angle-bracketed arguments, or a differently named last segment yield
/// `None`.
fn extract_inner_type<'a>(ty: &'a Type, wrapper: &'a str) -> Option<&'a Type> {
    let Type::Path(syn::TypePath { qself: None, path }) = ty else {
        return None;
    };
    let last = path.segments.last()?;
    if last.ident != wrapper {
        return None;
    }
    let PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) =
        &last.arguments
    else {
        return None;
    };
    match args.first()? {
        syn::GenericArgument::Type(inner) => Some(inner),
        _ => None,
    }
}

/// The kind of query a generated function performs.
///
/// Parsed from the macro's argument tokens using the snake_case variant name
/// (`fetch`, `exists`, `find`, `fetch_all`, `execute`).
#[derive(EnumString, Debug)]
#[strum(serialize_all = "snake_case")]
enum Action {
    /// Runs the query with `fetch_one`; the declared return type is the row type.
    Fetch,
    /// Wraps the SQL in `select exists(...)` and returns a `bool`.
    Exists,
    /// Runs the query with `fetch_optional`; the declared return type must be `Option<T>`.
    Find,
    /// Runs the query with `fetch_all`; the declared return type must be `Vec<T>`.
    FetchAll,
    /// Runs the query with `execute` and returns `()`.
    Execute,
}

impl Action {
    /// Builds the expression that runs `sql` for this action against `executor`.
    ///
    /// `fields` are placeholder names in `$n` order. Each is emitted as an identifier argument,
    /// except `executor`, which is skipped. `fetch_model` is the row type chosen by
    /// `extract_and_build_ret_type`.
    ///
    /// Debug builds emit the `::sqlx::query_as!` macro. Release builds emit the runtime
    /// `::sqlx::query`/`::sqlx::query_as` functions with one `.bind(..)` per argument. The
    /// generated code expects a variable named `executor` to be in scope.
    fn build_sqlx_query(
        &self,
        fields: &[String],
        fetch_model: &proc_macro2::TokenStream,
        sql: String,
    ) -> proc_macro2::TokenStream {
        let args = fields
            .iter()
            .filter(|name| name.as_str() != "executor")
            .map(|name| format_ident!("{}", name))
            .collect_vec();

        // `exists` asks the database for a single boolean instead of the rows themselves.
        let sql = match self {
            Action::Exists => format!("select exists({})", sql),
            Action::Fetch | Action::Find | Action::FetchAll | Action::Execute => sql,
        };

        let query = if cfg!(debug_assertions) {
            quote! { ::sqlx::query_as!(#fetch_model, #sql, #(#args,)*) }
        } else {
            let constructor = match self {
                Action::Exists => quote! { ::sqlx::query_as::<_, #fetch_model> },
                Action::Execute => quote! { ::sqlx::query },
                Action::Fetch | Action::Find | Action::FetchAll => quote! { ::sqlx::query_as },
            };
            quote! { #constructor(#sql) #(.bind(#args))* }
        };

        let terminal = format_ident!(
            "{}",
            match self {
                Action::Fetch | Action::Exists => "fetch_one",
                Action::Find => "fetch_optional",
                Action::FetchAll => "fetch_all",
                Action::Execute => "execute",
            }
        );
        let call = quote! { #query.#terminal(executor).await };

        match self {
            Action::Exists => quote! { Ok(#call?.exists.unwrap_or(false)) },
            Action::Execute => quote! { #call?; Ok(()) },
            Action::Fetch | Action::Find | Action::FetchAll => call,
        }
    }

    /// Splits the declared return type into `(row model, generated return type)`.
    ///
    /// - `fetch`: both are the declared type.
    /// - `exists`: `::conservator::ExistsRow` / `bool`.
    /// - `find`: the declared type must be `Option<T>`; the model is `T`.
    /// - `fetch_all`: the declared type must be `Vec<T>`; the model is `T`.
    /// - `execute`: `::conservator::SingleNumberRow` / `()`.
    ///
    /// # Errors
    /// Returns the span of the return type and a message when no return type is declared, or
    /// when `find` / `fetch_all` do not declare the required wrapper type.
    fn extract_and_build_ret_type(
        &self,
        ident: &ReturnType,
    ) -> Result<(proc_macro2::TokenStream, proc_macro2::TokenStream), (Span, &'static str)> {
        let span = ident.span();
        let ReturnType::Type(_, declared) = ident else {
            return Err((span, "default return type does not support"));
        };

        // Row model for actions whose declared type wraps it (`Option<T>`, `Vec<T>`).
        let unwrap_model = |wrapper: &'static str, message: &'static str| {
            extract_inner_type(declared, wrapper)
                .map(|model| quote! { #model })
                .ok_or((span, message))
        };
        let declared_tokens = quote! { #declared };

        match self {
            Action::Fetch => Ok((declared_tokens.clone(), declared_tokens)),
            Action::Exists => Ok((quote! { ::conservator::ExistsRow }, quote! { bool })),
            Action::Find => Ok((
                unwrap_model("Option", "find method need a option type")?,
                declared_tokens,
            )),
            Action::FetchAll => Ok((
                unwrap_model("Vec", "fetchall method need a vec type")?,
                declared_tokens,
            )),
            Action::Execute => Ok((quote! { ::conservator::SingleNumberRow }, quote! { () })),
        }
    }
}

pub(crate) fn handler(
    args: proc_macro2::TokenStream,
    input: proc_macro2::TokenStream,
) -> Result<proc_macro2::TokenStream, (Span, &'static str)> {
    let arg = args.to_string();
    let action = match Action::from_str(&arg) {
        Ok(action) => action,
        Err(_) => return Err((args.span(), "unknown action type")),
    };

    let input_span = input.span();
    let method = match parse2::<ItemFn>(input) {
        Ok(func) => func,
        Err(_) => return Err((input_span, "unknown action type")),
    };

    let vis = &method.vis;
    let ident = &method.sig.ident;
    let inputs = &method.sig.inputs;

    let output = &method.sig.output;

    let (fetch_model, return_type) = action.extract_and_build_ret_type(output)?;
    let body = &method.block;
    let body: Vec<proc_macro2::TokenStream> = body
        .stmts
        .iter()
        .cloned()
        .map(|stmt| match &stmt {
            Stmt::Expr(Expr::Lit(expr_lit)) => match &expr_lit.lit {
                Lit::Str(lit_str) => {
                    let mut sql = lit_str.value();
                    let re = Regex::new(r"[^:]:(\w+)").unwrap();
                    let matched: HashSet<String> = re
                        .captures_iter(&sql)
                        .map(|mat| mat[1].to_string())
                        .collect();
                    let matched_fields = matched.into_iter().collect_vec();

                    matched_fields.iter().enumerate().for_each(|(idx, field)| {
                        sql = sql.replace(&format!(":{}", field), &format!("${}", idx + 1));
                    });
                    let query_stmt =
                        action.build_sqlx_query(&matched_fields[..], &fetch_model, sql);
                    quote!( #query_stmt)
                }
                _ => {
                    quote!( #stmt )
                }
            },
            _ => quote!( #stmt ),
        })
        .collect();

    let inputs = if inputs.is_empty() {
        quote! {}
    } else if inputs.trailing_punct() {
        quote! { #inputs}
    } else {
        quote! { #inputs,}
    };
    let ret = quote! {
        #vis async fn #ident<'e, 'c: 'e, E: 'e + ::sqlx::Executor<'c, Database=::sqlx::Postgres>>(#inputs executor: E) -> Result<#return_type, ::sqlx::Error> {
            #(#body )*
        }
    };
    Ok(ret)
}

#[cfg(test)]
mod test {
    use proc_macro2::TokenStream;
    use quote::quote;

    use crate::sql::handler;

    /// Expands `input` with the action in `args` and checks it renders exactly like `expected`.
    fn assert_expands_to(args: TokenStream, input: TokenStream, expected: TokenStream) {
        let generated = handler(args, input).unwrap();
        assert_eq!(expected.to_string(), generated.to_string());
    }

    #[test]
    fn should_generate_fetch_sql_function() {
        assert_expands_to(
            quote! { find },
            quote! {
                pub async fn find_user(email: &str) -> Option<UserEntity> {
                    "select * from users where email = :email"
                }
            },
            quote! {
                pub async fn find_user<'e, 'c: 'e, E: 'e + ::sqlx::Executor<'c, Database = ::sqlx::Postgres>>(
                    email: &str,
                    executor: E
                ) -> Result<Option<UserEntity>, ::sqlx::Error> {
                    ::sqlx::query_as!(UserEntity, "select * from users where email = $1", email,)
                        .fetch_optional(executor)
                        .await
                }
            },
        );
    }

    #[test]
    fn should_generate_for_linked_domain() {
        assert_expands_to(
            quote! { find },
            quote! {
                pub async fn find_user(&self) -> Option<UserEntity> {
                    let id = self.id;
                    "select * from users where email = :id"
                }
            },
            quote! {
                pub async fn find_user<'e, 'c: 'e, E: 'e + ::sqlx::Executor<'c, Database = ::sqlx::Postgres>>(
                    &self,
                    executor: E
                ) -> Result<Option<UserEntity>, ::sqlx::Error> {
                    let id = self.id;
                    ::sqlx::query_as!(UserEntity, "select * from users where email = $1", id,)
                        .fetch_optional(executor)
                        .await
                }
            },
        );
    }

    #[test]
    fn args_with_tailing_comma() {
        assert_expands_to(
            quote! { find },
            quote! {
                pub async fn find_user(id: i32, ) -> Option<UserEntity> {
                    "select * from users where email = :id"
                }
            },
            quote! {
                pub async fn find_user<'e, 'c: 'e, E: 'e + ::sqlx::Executor<'c, Database = ::sqlx::Postgres>>(
                    id: i32,
                    executor: E
                ) -> Result<Option<UserEntity>, ::sqlx::Error> {
                    ::sqlx::query_as!(UserEntity, "select * from users where email = $1", id,)
                        .fetch_optional(executor)
                        .await
                }
            },
        );
    }

    #[test]
    fn args_without_tailing_comma() {
        assert_expands_to(
            quote! { find },
            quote! {
                pub async fn find_user(id: i32 ) -> Option<UserEntity> {
                    "select * from users where email = :id"
                }
            },
            quote! {
                pub async fn find_user<'e, 'c: 'e, E: 'e + ::sqlx::Executor<'c, Database = ::sqlx::Postgres>>(
                    id: i32,
                    executor: E
                ) -> Result<Option<UserEntity>, ::sqlx::Error> {
                    ::sqlx::query_as!(UserEntity, "select * from users where email = $1", id,)
                        .fetch_optional(executor)
                        .await
                }
            },
        );
    }

    #[test]
    fn should_work_with_pg_double_mark() {
        assert_expands_to(
            quote! { find },
            quote! {
                pub async fn find_user() -> Option<UserEntity> {
                    "select * from users where datetime + '14 days'::interval > now()"
                }
            },
            quote! {
                pub async fn find_user<'e, 'c: 'e, E: 'e + ::sqlx::Executor<'c, Database = ::sqlx::Postgres>>(
                    executor: E
                ) -> Result<Option<UserEntity>, ::sqlx::Error> {
                    ::sqlx::query_as!(UserEntity, "select * from users where datetime + '14 days'::interval > now()",)
                        .fetch_optional(executor)
                        .await
                }
            },
        );
    }
}