//! juniper-eager-loading-code-gen 0.4.2
//!
//! Eliminate N+1 query bugs when using Juniper.
mod field_args;

use darling::{FromDeriveInput, FromMeta};
use field_args::{DeriveArgs, FieldArgs, HasMany, HasManyThrough, HasOne, OptionHasOne};
use heck::{CamelCase, SnakeCase};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
    parse_macro_input, DeriveInput, GenericArgument, Ident, NestedMeta, PathArguments, Type,
};

/// Entry point of the derive: parses the annotated struct, reads its struct-level
/// arguments and returns the generated impls.
///
/// Errors in the struct-level attributes are returned as `compile_error!` tokens
/// (via `darling::Error::write_errors`) rather than panicking, so the compiler can
/// report them with whatever span darling attached instead of a generic
/// "proc-macro derive panicked" message.
pub fn gen_tokens(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    let args = match DeriveArgs::from_derive_input(&ast) {
        Ok(args) => args,
        Err(err) => return err.write_errors().into(),
    };

    DeriveData::new(ast, args).build_derive_output().into()
}

/// Working state for a single expansion of the derive.
struct DeriveData {
    /// Generated code accumulated so far; returned by `build_derive_output`.
    tokens: TokenStream,
    /// The struct the derive was applied to.
    input: DeriveInput,
    /// Parsed struct-level attribute arguments.
    args: DeriveArgs,
}

impl DeriveData {
    /// Wraps the parsed derive input and its struct-level arguments, starting with an
    /// empty output token stream.
    fn new(input: DeriveInput, args: DeriveArgs) -> Self {
        Self {
            input,
            args,
            tokens: quote! {},
        }
    }

    /// Runs every generation step and returns the accumulated output.
    ///
    /// When the struct-level `print` option is set, the output is written to stderr
    /// after the `GraphqlNodeForModel` and `EagerLoadAllChildren` impls are generated
    /// but before the per-field `EagerLoadChildrenOfType` impls are appended; those are
    /// printed individually through the field-level `print` option.
    fn build_derive_output(mut self) -> TokenStream {
        self.gen_graphql_node_for_model();
        self.gen_eager_load_all_children();

        if self.args.print() {
            eprintln!("{}", self.tokens);
        }

        self.gen_eager_load_children_of_type();

        self.tokens
    }

    /// Generates `impl juniper_eager_loading::GraphqlNodeForModel for <Struct>`.
    ///
    /// The generated `new_from_model` initialises every association field with
    /// `Default::default()` and every other field with a clone of the model, so each
    /// non-association field must have the model type.
    fn gen_graphql_node_for_model(&mut self) {
        let struct_name = self.struct_name();
        let model = self.model();
        let id = self.id();
        let connection = self.connection();
        let error = self.error();

        let field_setters = self.struct_fields().map(|field| {
            let ident = &field.ident;

            if is_association_field(&field.ty) {
                quote! { #ident: std::default::Default::default() }
            } else {
                quote! { #ident: std::clone::Clone::clone(model) }
            }
        });

        let code = quote! {
            impl juniper_eager_loading::GraphqlNodeForModel for #struct_name {
                type Model = #model;
                type Id = #id;
                type Connection = #connection;
                type Error = #error;

                fn new_from_model(model: &Self::Model) -> Self {
                    Self {
                        #(#field_setters),*
                    }
                }
            }
        };
        self.tokens.extend(code);
    }

    /// Appends the `EagerLoadChildrenOfType` impl (and its context struct) for every
    /// association field. Non-association fields produce nothing.
    fn gen_eager_load_children_of_type(&mut self) {
        let impls = self
            .struct_fields()
            .filter_map(|field| self.gen_eager_load_children_of_type_for_field(field));

        let code = quote! { #(#impls)* };
        self.tokens.extend(code);
    }

    /// Builds the context struct plus `EagerLoadChildrenOfType` impl for one field.
    ///
    /// Returns `None` for non-association fields. When the field-level `skip` option is
    /// set an empty token stream is returned, so neither the impl nor the context struct
    /// is emitted; the field-level `print` option still prints what would have been
    /// generated.
    fn gen_eager_load_children_of_type_for_field(&self, field: &syn::Field) -> Option<TokenStream> {
        let (args, data) = self.parse_field_args(field)?;

        let inner_type = &data.inner_type;
        let struct_name = self.struct_name();
        let join_model_impl = self.join_model_impl(&data);
        let load_children_impl = self.load_children_impl(&data);
        let association_impl = self.association_impl(&data);
        let is_child_of_impl = self.is_child_of_impl(&data);

        // Unit struct named after the struct and field; used as the trait's context type
        // parameter (presumably so fields sharing a child type get distinct impls).
        let context = self.field_context_name(field);

        let full_output = quote! {
            #[allow(missing_docs, dead_code)]
            struct #context;

            impl<'a> juniper_eager_loading::EagerLoadChildrenOfType<
                'a,
                #inner_type,
                #context,
                #join_model_impl,
            > for #struct_name {
                type FieldArguments = ();

                #load_children_impl
                #is_child_of_impl
                #association_impl
            }
        };

        if args.print {
            eprintln!("{}", full_output);
        }

        if args.skip {
            Some(quote! {})
        } else {
            Some(full_output)
        }
    }

    /// Parses the association attribute on `field` and collects everything the
    /// per-field generators need.
    ///
    /// Returns `None` when the field's type is not one of the association wrappers.
    ///
    /// # Panics
    ///
    /// Panics if the field's attributes cannot be parsed for its association kind, or if
    /// the field has no name.
    fn parse_field_args(&self, field: &syn::Field) -> Option<(FieldArgs, FieldDeriveData)> {
        let inner_type = get_type_from_association(&field.ty)?;
        let association_type = association_type(&field.ty)?;

        let args = match association_type {
            AssociationType::HasOne => {
                let args = parse_field_args::<HasOne>(field)
                    .unwrap_or_else(|e| panic!("{}", e))
                    .has_one;
                FieldArgs::from(args)
            }
            AssociationType::OptionHasOne => {
                let args = parse_field_args::<OptionHasOne>(field)
                    .unwrap_or_else(|e| panic!("{}", e))
                    .option_has_one;
                FieldArgs::from(args)
            }
            AssociationType::HasMany => {
                let args = parse_field_args::<HasMany>(field)
                    .unwrap_or_else(|e| panic!("{}", e))
                    .has_many;
                FieldArgs::from(args)
            }
            AssociationType::HasManyThrough => {
                let args = parse_field_args::<HasManyThrough>(field)
                    .unwrap_or_else(|e| panic!("{}", e))
                    .has_many_through;
                FieldArgs::from(args)
            }
        };

        let field_name = field.ident.as_ref().unwrap_or_else(|| {
            panic!("Found `juniper_eager_loading::HasOne` field without a name")
        });

        // Default foreign-key name: the *parent* struct's name for "many" associations
        // (the key lives on the child), the field's own name for "one" associations
        // (the key lives on the parent model).
        let foreign_key_field_default = match association_type {
            AssociationType::HasMany | AssociationType::HasManyThrough => self.struct_name(),
            AssociationType::HasOne | AssociationType::OptionHasOne => &field_name,
        };

        let data = FieldDeriveData {
            field_name: field_name.clone(),
            inner_type: inner_type.clone(),
            root_model_field: self.root_model_field().clone(),
            join_model: args.join_model(),
            model_field: args.model_field(&inner_type),
            join_model_field: args.join_model_field(),
            foreign_key_field: args.foreign_key_field(foreign_key_field_default),
            foreign_key_optional: args.foreign_key_optional,
            field_root_model_field: args.root_model_field(&field_name),
            association_type,
            predicate_method: args.predicate_method(),
        };

        Some((args, data))
    }

    /// The join-model type parameter of `EagerLoadChildrenOfType`: the configured join
    /// model for `HasManyThrough`, `()` for every other association kind.
    fn join_model_impl(&self, data: &FieldDeriveData) -> TokenStream {
        match data.association_type {
            AssociationType::HasMany | AssociationType::HasOne | AssociationType::OptionHasOne => {
                quote! { () }
            }
            AssociationType::HasManyThrough => {
                let join_model = &data.join_model;
                quote! { #join_model }
            }
        }
    }

    /// Generates the `load_children` method.
    ///
    /// * `HasOne` / `OptionHasOne`: collect the (present) foreign keys of the parent
    ///   models, de-duplicate them with `juniper_eager_loading::unique` and load the
    ///   children from those ids.
    /// * `HasMany`: load the children from the parent models, optionally keeping only
    ///   children for which the configured predicate method returns `true`.
    /// * `HasManyThrough`: load join models from the parent models (optionally filtered by
    ///   the predicate), load children from the join models, then pair each child with
    ///   every join model whose `<model_field>_id` equals the child's `id`.
    fn load_children_impl(&self, data: &FieldDeriveData) -> TokenStream {
        use AssociationType::*;

        let foreign_key_field = &data.foreign_key_field;
        let join_model = &data.join_model;
        let model_id_field = data.model_id_field();
        let inner_type = &data.inner_type;

        let load_children_impl = match data.association_type {
            HasOne => {
                quote! {
                    let ids = models
                        .iter()
                        .map(|model| model.#foreign_key_field.clone())
                        .collect::<Vec<_>>();
                    let ids = juniper_eager_loading::unique(ids);

                    let child_models: Vec<<#inner_type as juniper_eager_loading::GraphqlNodeForModel>::Model> =
                        juniper_eager_loading::LoadFrom::load(&ids, field_args, db)?;

                    Ok(juniper_eager_loading::LoadChildrenOutput::ChildModels(child_models))
                }
            }
            OptionHasOne => {
                // Clone the `Option` rather than moving it out of the borrowed model:
                // a move only compiles when the id type is `Copy`.
                quote! {
                    let ids = models
                        .iter()
                        .filter_map(|model| model.#foreign_key_field.clone())
                        .collect::<Vec<_>>();
                    let ids = juniper_eager_loading::unique(ids);

                    let child_models: Vec<<#inner_type as juniper_eager_loading::GraphqlNodeForModel>::Model> =
                        juniper_eager_loading::LoadFrom::load(&ids, field_args, db)?;

                    Ok(juniper_eager_loading::LoadChildrenOutput::ChildModels(child_models))
                }
            }
            HasMany => {
                let filter = if let Some(predicate_method) = &data.predicate_method {
                    quote! {
                        let child_models = child_models
                            .into_iter()
                            .filter(|child_model| child_model.#predicate_method(db))
                            .collect::<Vec<_>>();
                    }
                } else {
                    quote! {}
                };

                quote! {
                    let child_models: Vec<<#inner_type as juniper_eager_loading::GraphqlNodeForModel>::Model> =
                        juniper_eager_loading::LoadFrom::load(&models, field_args, db)?;

                    #filter

                    Ok(juniper_eager_loading::LoadChildrenOutput::ChildModels(child_models))
                }
            }
            HasManyThrough => {
                let filter = if let Some(predicate_method) = &data.predicate_method {
                    quote! {
                        let join_models = join_models
                            .into_iter()
                            .filter(|join_model| join_model.#predicate_method(db))
                            .collect::<Vec<_>>();
                    }
                } else {
                    quote! {}
                };

                quote! {
                    let join_models: Vec<#join_model> =
                        juniper_eager_loading::LoadFrom::load(&models, field_args, db)?;

                    #filter

                    let child_models: Vec<<#inner_type as juniper_eager_loading::GraphqlNodeForModel>::Model> =
                        juniper_eager_loading::LoadFrom::load(&join_models, field_args, db)?;

                    let mut child_and_join_model_pairs = Vec::new();
                    for join_model in join_models {
                        for child_model in &child_models {
                            if join_model.#model_id_field == child_model.id {
                                let pair = (
                                    std::clone::Clone::clone(child_model),
                                    std::clone::Clone::clone(&join_model),
                                );
                                child_and_join_model_pairs.push(pair);
                            }
                        }
                    }

                    Ok(juniper_eager_loading::LoadChildrenOutput::ChildAndJoinModels(child_and_join_model_pairs))
                }
            }
        };

        quote! {
            #[allow(unused_variables)]
            fn load_children(
                models: &[Self::Model],
                field_args: &Self::FieldArguments,
                db: &Self::Connection,
            ) -> Result<
                juniper_eager_loading::LoadChildrenOutput<
                    <#inner_type as juniper_eager_loading::GraphqlNodeForModel>::Model,
                    #join_model
                >,
                Self::Error,
            > {
                #load_children_impl
            }
        }
    }

    /// Generates the `is_child_of` method, which decides whether `child` belongs to
    /// `node` by comparing the configured foreign key with the relevant `id`.
    ///
    /// Optional foreign keys are compared through references (`Option<&Id>`) so the
    /// generated code does not require the id type to be `Copy`.
    fn is_child_of_impl(&self, data: &FieldDeriveData) -> TokenStream {
        let root_model_field = &data.root_model_field;
        let foreign_key_field = &data.foreign_key_field;
        let field_root_model_field = &data.field_root_model_field;
        let inner_type = &data.inner_type;
        let join_model = &data.join_model;
        let model_field = &data.model_field;
        let model_id_field = data.model_id_field();

        let is_child_of_impl = match data.association_type {
            AssociationType::HasOne => {
                quote! {
                    node.#root_model_field.#foreign_key_field == child.#field_root_model_field.id
                }
            }
            AssociationType::OptionHasOne => {
                quote! {
                    std::option::Option::as_ref(&node.#root_model_field.#foreign_key_field) ==
                        std::option::Option::Some(&child.#field_root_model_field.id)
                }
            }
            AssociationType::HasMany => {
                if data.foreign_key_optional {
                    quote! {
                        std::option::Option::Some(&node.#root_model_field.id) ==
                            std::option::Option::as_ref(&child.#field_root_model_field.#foreign_key_field)
                    }
                } else {
                    quote! {
                        node.#root_model_field.id ==
                            child.#field_root_model_field.#foreign_key_field
                    }
                }
            }
            AssociationType::HasManyThrough => {
                quote! {
                    node.#root_model_field.id == join_model.#foreign_key_field &&
                        join_model.#model_id_field == child.#model_field.id
                }
            }
        };

        quote! {
            fn is_child_of(
                node: &Self,
                child: &#inner_type,
                join_model: &#join_model,
                _field_args: &Self::FieldArguments,
            ) -> bool {
                #is_child_of_impl
            }
        }
    }

    /// Generates the `association` method, which hands out a mutable reference to the
    /// association field so loaded children can be stored on the node.
    fn association_impl(&self, data: &FieldDeriveData) -> TokenStream {
        let field_name = &data.field_name;
        let inner_type = &data.inner_type;

        quote! {
            fn association(node: &mut Self) ->
                &mut dyn juniper_eager_loading::Association<#inner_type>
            {
                &mut node.#field_name
            }
        }
    }

    /// Generates `impl juniper_eager_loading::EagerLoadAllChildren for <Struct>`, whose
    /// body eager-loads every association field the query trail walks into.
    fn gen_eager_load_all_children(&mut self) {
        let struct_name = self.struct_name();

        let eager_load_children_calls = self
            .struct_fields()
            .filter_map(|field| self.gen_eager_load_all_children_for_field(field));

        let code = quote! {
            impl juniper_eager_loading::EagerLoadAllChildren for #struct_name {
                fn eager_load_all_children_for_each(
                    nodes: &mut [Self],
                    models: &[Self::Model],
                    db: &Self::Connection,
                    trail: &juniper_from_schema::QueryTrail<'_, Self, juniper_from_schema::Walked>,
                ) -> Result<(), Self::Error> {
                    #(#eager_load_children_calls)*

                    Ok(())
                }
            }
        };
        self.tokens.extend(code);
    }

    /// Generates the statement that eager-loads one association field when the query
    /// trail walks into it. Returns `None` for non-association fields.
    ///
    /// The trail method is the field's `graphql_field` option converted to snake_case,
    /// or the Rust field name when that option is absent; its arguments come from the
    /// matching `<name>_args` trail method.
    fn gen_eager_load_all_children_for_field(&self, field: &syn::Field) -> Option<TokenStream> {
        let (args, data) = self.parse_field_args(field)?;
        let inner_type = &data.inner_type;

        let field_name = args
            .graphql_field()
            .clone()
            .map(|ident| {
                let ident = ident.to_string().to_snake_case();
                Ident::new(&ident, Span::call_site())
            })
            .unwrap_or_else(|| {
                field.ident.clone().unwrap_or_else(|| {
                    panic!("Found `juniper_eager_loading::HasOne` field without a name")
                })
            });
        let field_args_name = quote::format_ident!("{}_args", field_name);

        let context = self.field_context_name(field);

        // The trait path is fully qualified so the generated code compiles even when the
        // user has not imported `EagerLoadChildrenOfType`.
        Some(quote! {
            if let Some(child_trail) = trail.#field_name().walk() {
                let field_args = trail.#field_args_name();

                juniper_eager_loading::EagerLoadChildrenOfType::<#inner_type, #context, _>::eager_load_children(
                    nodes,
                    models,
                    db,
                    &child_trail,
                    &field_args,
                )?;
            }
        })
    }

    /// Name of the struct the derive is applied to.
    fn struct_name(&self) -> &syn::Ident {
        &self.input.ident
    }

    /// Model type, as resolved by the struct-level arguments.
    fn model(&self) -> TokenStream {
        self.args.model(&self.struct_name())
    }

    /// Id type, as resolved by the struct-level arguments.
    fn id(&self) -> TokenStream {
        self.args.id()
    }

    /// Connection type, as resolved by the struct-level arguments.
    fn connection(&self) -> TokenStream {
        self.args.connection()
    }

    /// Error type, as resolved by the struct-level arguments.
    fn error(&self) -> TokenStream {
        self.args.error()
    }

    /// Field of the struct holding its model, as resolved by the struct-level arguments.
    fn root_model_field(&self) -> TokenStream {
        self.args.root_model_field(&self.struct_name())
    }

    /// Returns an iterator over the named fields of the input struct.
    ///
    /// # Panics
    ///
    /// Panics if the derive was applied to an enum, a union, or a tuple/unit struct.
    fn struct_fields(&self) -> syn::punctuated::Iter<syn::Field> {
        use syn::{Data, Fields};

        match &self.input.data {
            Data::Union(_) | Data::Enum(_) => {
                panic!("`EagerLoading` can only be derived on structs")
            }
            Data::Struct(data) => match &data.fields {
                Fields::Named(named) => named.named.iter(),
                Fields::Unit | Fields::Unnamed(_) => {
                    panic!("`EagerLoading` can only be derived on structs with named fields")
                }
            },
        }
    }

    /// Name of the generated context struct for `field`:
    /// `EagerLoadingContext<Struct>For<FieldInCamelCase>`.
    ///
    /// # Panics
    ///
    /// Panics if `field` has no name.
    fn field_context_name(&self, field: &syn::Field) -> Ident {
        let camel_name = field
            .ident
            .as_ref()
            .expect("field without name")
            .to_string()
            .to_camel_case();
        let full_name = format!("EagerLoadingContext{}For{}", self.struct_name(), camel_name);
        Ident::new(&full_name, Span::call_site())
    }
}

/// Destructures a single-field tuple variant, or makes the enclosing function
/// `return None` when the value does not match.
///
/// `if_let_or_none!(Some, opt)` yields the inner value of `opt` or bails out with
/// `None`; the enclosing function must therefore return an `Option`.
macro_rules! if_let_or_none {
    ( $path:path , $($tokens:tt)* ) => {
        match $($tokens)* {
            $path(inner) => inner,
            _ => return None,
        }
    };
}

/// Extracts the child type from an association wrapper such as `HasMany<Child>`.
///
/// Takes the last generic argument of the wrapper's last path segment and strips an
/// outer `Box<...>` from it. Returns `None` if `ty` is not an association wrapper or
/// does not have that shape.
fn get_type_from_association(ty: &syn::Type) -> Option<&syn::Type> {
    if is_association_field(ty) {
        let type_path = if_let_or_none!(Type::Path, ty);
        let wrapper_segment = if_let_or_none!(Some, type_path.path.segments.last());
        let angle_bracketed =
            if_let_or_none!(PathArguments::AngleBracketed, &wrapper_segment.arguments);
        let last_generic = if_let_or_none!(Some, angle_bracketed.args.last());
        let child_type = if_let_or_none!(GenericArgument::Type, last_generic);
        Some(remove_possible_box_wrapper(child_type))
    } else {
        None
    }
}

/// The association wrapper type used by a field, identified by the wrapper's name.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum AssociationType {
    /// `HasOne<T>`
    HasOne,
    /// `OptionHasOne<T>`
    OptionHasOne,
    /// `HasMany<T>`
    HasMany,
    /// `HasManyThrough<T>`
    HasManyThrough,
}

/// Classifies `ty` by the identifier of its last path segment.
///
/// Returns `None` when `ty` is not a path type or its last segment is not one of
/// `OptionHasOne`, `HasManyThrough`, `HasMany` or `HasOne`.
fn association_type(ty: &syn::Type) -> Option<AssociationType> {
    let wrapper_name = last_ident_in_type_segment(ty)?.to_string();

    match wrapper_name.as_str() {
        "OptionHasOne" => Some(AssociationType::OptionHasOne),
        "HasManyThrough" => Some(AssociationType::HasManyThrough),
        "HasMany" => Some(AssociationType::HasMany),
        "HasOne" => Some(AssociationType::HasOne),
        _ => None,
    }
}

/// Returns `true` when `ty`'s last path segment names an association wrapper
/// (`HasOne`, `OptionHasOne`, `HasMany` or `HasManyThrough`).
fn is_association_field(ty: &syn::Type) -> bool {
    let detected = association_type(ty);
    detected.is_some()
}

fn last_ident_in_type_segment(ty: &syn::Type) -> Option<&syn::Ident> {
    let type_path = if_let_or_none!(Type::Path, ty);
    let path = &type_path.path;
    let segments = &path.segments;
    let segment = if_let_or_none!(Some, segments.last());
    Some(&segment.ident)
}

/// Parses all of `field`'s attributes as one darling meta list into `T`.
///
/// # Panics
///
/// Panics if any attribute cannot be parsed as a `syn::Meta`.
///
/// # Errors
///
/// Returns darling's error if the collected metas do not match `T`.
fn parse_field_args<T: FromMeta>(field: &syn::Field) -> Result<T, darling::Error> {
    let mut nested = Vec::with_capacity(field.attrs.len());
    for attr in &field.attrs {
        match attr.parse_meta() {
            Ok(meta) => nested.push(NestedMeta::from(meta)),
            Err(e) => panic!("{}", e),
        }
    }
    T::from_list(&nested)
}

/// Everything the per-field code generators need for one association field.
#[derive(Debug)]
struct FieldDeriveData {
    /// Foreign-key field referenced in generated code (`model.<fk>`, `child.<..>.<fk>`,
    /// `join_model.<fk>`). Defaults to the parent struct name for `HasMany` /
    /// `HasManyThrough` and to the field name for `HasOne` / `OptionHasOne`.
    foreign_key_field: TokenStream,
    /// Whether the child's foreign key is an `Option` (consulted for `HasMany` only).
    foreign_key_optional: bool,
    /// Field on the child node holding the child's model (`child.<this>.id`).
    field_root_model_field: TokenStream,
    /// Field on the parent node holding the parent's model (`node.<this>`).
    root_model_field: TokenStream,
    /// Join model type; it becomes the trait's join-model parameter only for `HasManyThrough`.
    join_model: TokenStream,
    /// Child node type taken from the association wrapper, with any `Box` removed.
    inner_type: syn::Type,
    /// Name of the association field on the parent struct.
    field_name: Ident,
    /// Which association wrapper the field uses.
    association_type: AssociationType,
    /// Field on the child node holding its model for `HasManyThrough`; also the base
    /// of `model_id_field()` (`<model_field>_id`).
    model_field: TokenStream,
    // NOTE(review): not read anywhere in the visible code — confirm it is still needed.
    join_model_field: TokenStream,
    /// Optional method called as `model.<method>(db)` to filter loaded children (`HasMany`)
    /// or join models (`HasManyThrough`).
    predicate_method: Option<Ident>,
}

impl FieldDeriveData {
    /// Name of the field compared against the child's `id` in `HasManyThrough` code:
    /// the `model_field` tokens followed by `_id` (e.g. `country` → `country_id`).
    ///
    /// Panics (via `Ident::new`) if the resulting string is not a valid identifier.
    fn model_id_field(&self) -> Ident {
        let id_field_name = format!("{}_id", self.model_field);
        Ident::new(&id_field_name, Span::call_site())
    }
}

/// Returns `T` when `ty` is written as `Box<T>` (last path segment named `Box` whose
/// last generic argument is a type); otherwise returns `ty` unchanged.
fn remove_possible_box_wrapper(ty: &Type) -> &syn::Type {
    /// `Some(T)` for `Box<T>`, `None` for anything else.
    fn boxed_inner(ty: &Type) -> Option<&Type> {
        let type_path = match ty {
            Type::Path(type_path) => type_path,
            _ => return None,
        };

        let last_segment = type_path.path.segments.last()?;
        if last_segment.ident != "Box" {
            return None;
        }

        let angle_args = match &last_segment.arguments {
            syn::PathArguments::AngleBracketed(args) => args,
            _ => return None,
        };

        match angle_args.args.last()? {
            syn::GenericArgument::Type(inner_ty) => Some(inner_ty),
            _ => None,
        }
    }

    boxed_inner(ty).unwrap_or(ty)
}