lsor-proc-macro 0.1.0

Proc macros for lsor
Documentation
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use syn::{Attribute, Data, DataEnum, DataStruct, DeriveInput, Fields, Ident};

use crate::util;

/// Entry point of the filter derive expansion.
///
/// Parses `input` as a [`DeriveInput`] and hands it to the struct or enum
/// expansion depending on the kind of item the derive was applied to.
///
/// # Panics
///
/// Panics if `input` cannot be parsed as a [`DeriveInput`], or if the item is
/// neither a struct nor an enum.
pub fn expand_derive_filter(input: TokenStream) -> TokenStream {
    let derive_input = syn::parse::<DeriveInput>(input).unwrap();
    let item_attrs = &derive_input.attrs;

    // Struct
    if let Data::Struct(struct_data) = &derive_input.data {
        return expand_derive_filter_for_struct(&derive_input, item_attrs, struct_data);
    }

    // Enum
    if let Data::Enum(enum_data) = &derive_input.data {
        return expand_derive_filter_for_enum(&derive_input, item_attrs, enum_data);
    }

    // Anything else is unsupported.
    panic!("filter can only be implemented for structs and enums")
}

/// Expands the filter derive for a struct with named fields.
///
/// Structs for which `util::has_json_attr` returns `true` are delegated to
/// `expand_derive_json_filter_for_struct`. For every other struct this emits:
///
/// * an `impl ::lsor::filter::Filterable` for the struct whose `Filter` type
///   is the generated filter enum (named by `util::concat_idents` from the
///   struct identifier and `Filter`);
/// * the filter enum itself, with `All` / `Any` combinator variants plus one
///   variant per field that is not marked as skipped;
/// * an inherent `push_to_driver_with_table_name` method that renders the
///   filter against a caller-supplied table expression;
/// * when `util::collect_table_attr` yields a table, a
///   `::lsor::driver::PushPrql` impl that renders against that table.
///
/// # Panics
///
/// Panics if the struct does not have named fields.
fn expand_derive_filter_for_struct(
    ast: &DeriveInput,
    attrs: &[Attribute],
    data: &DataStruct,
) -> TokenStream {
    if util::has_json_attr(attrs) {
        return expand_derive_json_filter_for_struct(ast, attrs, data);
    }

    let ident = &ast.ident;
    let filter_ident = util::concat_idents(ident, &Ident::new("Filter", Span::call_site()));
    let table = util::collect_table_attr(attrs);

    let fields = match &data.fields {
        Fields::Named(fields) => fields,
        _ => panic!("filter can only be implemented for structs with named fields"),
    };

    // Forward the struct's generics to the `Filterable` impl, as the JSON
    // expansion does. Without this a generic struct gets an impl that names
    // the bare type and cannot compile. For non-generic structs all three
    // pieces expand to nothing, so the generated code is unchanged.
    //
    // NOTE(review): the filter enum itself is still emitted without generics,
    // so a non-skipped field whose type mentions a generic parameter will
    // still fail to compile — confirm whether that case needs support.
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();

    // Scan the fields once: drop skipped ones and pair every remaining field
    // with its snake_case identifier and the derived enum variant identifier.
    let filterable_fields: Vec<_> = fields
        .named
        .iter()
        .filter(|field| !util::has_skip_filter_attr(&field.attrs))
        .map(|field| {
            let field_ident = field
                .ident
                .as_ref()
                .expect("fields of a named-field struct always have an identifier");
            let variant_ident = Ident::new(
                &util::snake_case_to_camel_case(field_ident.to_string().as_str()),
                Span::call_site(),
            );
            (field, field_ident, variant_ident)
        })
        .collect();

    // One enum variant per field, wrapping the field type's own filter type.
    let field_variants_decl = filterable_fields.iter().map(|(field, _, variant_ident)| {
        let field_ty = &field.ty;
        quote! { #variant_ident(<#field_ty as ::lsor::filter::Filterable>::Filter), }
    });

    // One match arm per field. Flattened fields forward the table expression
    // unchanged; all other fields render against `<table>.<column>`, where the
    // column name is the stringified field identifier.
    let field_variants_impl = filterable_fields
        .iter()
        .map(|(field, field_ident, variant_ident)| {
            if util::has_flatten_attr(&field.attrs) {
                quote! { #filter_ident::#variant_ident(filter) => filter.push_to_driver_with_table_name(tn, driver), }
            } else {
                quote! { #filter_ident::#variant_ident(filter) => {
                    filter.push_to_driver(&::lsor::table::dot(tn, ::lsor::column::col(stringify!(#field_ident))), driver);
                }}
            }
        });

    // Only emitted when a table attribute was collected; `None` expands to
    // nothing inside `quote!`.
    let push_to_drive_impl = table.map(|table| {
        quote! {
            impl ::lsor::driver::PushPrql for #filter_ident {
                fn push_to_driver(&self, driver: &mut ::lsor::driver::Driver) {
                    self.push_to_driver_with_table_name(&::lsor::table::table(#table), driver);
                }
            }
        }
    });

    let expanded = quote! {
        impl #impl_generics ::lsor::filter::Filterable for #ident #ty_generics #where_clause {
            type Filter = #filter_ident;
        }

        #[derive(::std::clone::Clone, ::std::fmt::Debug, ::async_graphql::OneofObject)]
        #[graphql(rename_fields = "snake_case")]
        pub enum #filter_ident {
            All(Vec<#filter_ident>),
            Any(Vec<#filter_ident>),
            #(#field_variants_decl)*
        }

        #push_to_drive_impl

        impl #filter_ident {
            pub fn push_to_driver_with_table_name(&self, tn: &dyn ::lsor::driver::PushPrql, driver: &mut ::lsor::driver::Driver) {
                match &self {
                    // Each operand is parenthesised and joined with `&&`.
                    // `n - 1` is only evaluated inside the loop, so it cannot
                    // underflow for an empty list.
                    #filter_ident::All(all) => {
                        let n = all.len();
                        for (i, x) in all.iter().enumerate() {
                            driver.push('(');
                            x.push_to_driver_with_table_name(tn, driver);
                            if i < n - 1 {
                                driver.push(") && ");
                            } else {
                                driver.push(')');
                            }
                        }
                    },
                    // Same shape as `All`, joined with `||`.
                    #filter_ident::Any(any) => {
                        let n = any.len();
                        for (i, x) in any.iter().enumerate() {
                            driver.push('(');
                            x.push_to_driver_with_table_name(tn, driver);
                            if i < n - 1 {
                                driver.push(") || ");
                            } else {
                                driver.push(')');
                            }
                        }
                    },
                    #(#field_variants_impl)*
                }
            }
        }
    };

    TokenStream::from(expanded)
}

/// Expands the filter derive for a struct stored as JSON.
///
/// Emits a `::lsor::filter::Filterable` impl for the struct (forwarding its
/// generics), a filter enum with one variant per non-skipped field, and an
/// inherent `push_to_driver(lhs, driver)` method on that enum. Non-flattened
/// fields are rendered against `::lsor::column::json(lhs).get_text("<field>")`;
/// flattened fields call the nested filter's `push_to_driver(driver)`.
///
/// # Panics
///
/// Panics if the struct does not have named fields.
fn expand_derive_json_filter_for_struct(
    ast: &DeriveInput,
    _attrs: &[Attribute],
    data: &DataStruct,
) -> TokenStream {
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
    let ident = &ast.ident;
    let filter_ident = util::concat_idents(ident, &Ident::new("Filter", Span::call_site()));

    let named_fields = if let Fields::Named(named) = &data.fields {
        &named.named
    } else {
        panic!("filter can only be implemented for structs with named fields")
    };

    // Enum variant declarations: `<Variant>(<FieldTy as Filterable>::Filter),`
    let field_variants_decl = named_fields
        .iter()
        .filter(|field| !util::has_skip_filter_attr(&field.attrs))
        .map(|field| {
            let name = field.ident.as_ref().unwrap();
            let variant = Ident::new(
                &util::snake_case_to_camel_case(name.to_string().as_str()),
                Span::call_site(),
            );
            let ty = &field.ty;

            quote! { #variant(<#ty as ::lsor::filter::Filterable>::Filter), }
        });

    // Match arms for the generated `push_to_driver`, one per non-skipped field.
    let field_variants_impl = named_fields
        .iter()
        .filter(|field| !util::has_skip_filter_attr(&field.attrs))
        .map(|field| {
            let name = field.ident.as_ref().unwrap();
            let variant = Ident::new(
                &util::snake_case_to_camel_case(name.to_string().as_str()),
                Span::call_site(),
            );

            if util::has_flatten_attr(&field.attrs) {
                quote! { #filter_ident::#variant(filter) => filter.push_to_driver(driver), }
            } else {
                quote! { #filter_ident::#variant(filter) => {
                    filter.push_to_driver(&::lsor::column::json(lhs).get_text(stringify!(#name)), driver);
                }}
            }
        });

    let expanded = quote! {
        impl #impl_generics ::lsor::filter::Filterable for #ident #ty_generics #where_clause {
            type Filter = #filter_ident;
        }

        #[derive(::std::clone::Clone, ::std::fmt::Debug, ::async_graphql::OneofObject)]
        #[graphql(rename_fields = "snake_case")]
        pub enum #filter_ident {
            #(#field_variants_decl)*
        }

        impl #filter_ident {
            pub fn push_to_driver(&self, lhs: &dyn ::lsor::driver::PushPrql, driver: &mut ::lsor::driver::Driver) {
                match &self {
                    #(#field_variants_impl)*
                }
            }
        }
    };

    TokenStream::from(expanded)
}

fn expand_derive_filter_for_enum(
    ast: &DeriveInput,
    attrs: &[Attribute],
    _data: &DataEnum,
) -> TokenStream {
    if util::has_json_attr(attrs) {
        panic!("filter does not support the #[lsor(json)] attribute for enums")
    }

    let ident = &ast.ident;
    let filter_ident = util::concat_idents(ident, &Ident::new("Filter", Span::call_site()));
    let filter_attrs = util::collect_filter_attrs(attrs);

    if filter_attrs.is_empty() {
        panic!("expected at least one of {}", filter_attrs_str());
    }

    let variants = filter_attrs.iter().map(|attr| match attr.as_str() {
        "==" => quote! { Eq(#ident) },
        "!=" => quote! { Ne(#ident) },
        "<" => quote! { Lt(#ident) },
        "<=" => quote! { Le(#ident) },
        ">" => quote! { Gt(#ident) },
        ">=" => quote! { Ge(#ident) },
        _ => panic!(
            "invalid filter attribute, must be one of {}",
            filter_attrs_str()
        ),
    });

    let match_arms = filter_attrs
        .iter()
        .map(|attr| match attr.as_str() {
            "==" => quote! {
                #filter_ident::Eq(x) => {
                    lhs.push_to_driver(driver);
                    driver.push(" == ");
                    driver.push_bind(x);
                }
            },
            "!=" => quote! {
                #filter_ident::Ne(x) => {
                    lhs.push_to_driver(driver);
                    driver.push(" != ");
                    driver.push_bind(x);
                }
            },
            "<" => quote! {
                #filter_ident::Lt(x) => {
                    lhs.push_to_driver(driver);
                    driver.push(" < ");
                    driver.push_bind(x);
                }
            },
            "<=" => quote! {
                #filter_ident::Le(x) => {
                    lhs.push_to_driver(driver);
                    driver.push(" <= ");
                    driver.push_bind(x);
                }
            },
            ">" => quote! {
                #filter_ident::Gt(x) => {
                    lhs.push_to_driver(driver);
                    driver.push(" > ");
                    driver.push_bind(x);
                }
            },
            ">=" => quote! {
                #filter_ident::Ge(x) => {
                    lhs.push_to_driver(driver);
                    driver.push(" >= ");
                    driver.push_bind(x);
                }
            },
            _ => panic!(
                "invalid filter attribute, must be one of {}",
                filter_attrs_str()
            ),
        })
        .collect::<Vec<_>>();

    let expanded = quote! {
        impl ::lsor::filter::Filterable for #ident {
            type Filter = #filter_ident;
        }

        #[derive(::std::clone::Clone, ::std::fmt::Debug, ::async_graphql::OneofObject)]
        #[graphql(rename_fields = "snake_case")]
        pub enum #filter_ident {
            #(#variants,)*
        }

        impl #filter_ident {
            pub fn push_to_driver(&self, lhs: &dyn ::lsor::driver::PushPrql, driver: &mut ::lsor::driver::Driver) {
                match self {
                    #(#match_arms)*
                }
            }
        }
    };

    TokenStream::from(expanded)
}

/// Returns the human-readable list of supported comparison operators that is
/// embedded in this module's panic messages.
const fn filter_attrs_str() -> &'static str {
    const SUPPORTED_OPERATORS: &str = "'==', '!=', '<', '<=', '>', or '>='";
    SUPPORTED_OPERATORS
}