juniper-eager-loading-code-gen 0.4.2

Eliminate N+1 query bugs when using Juniper
Documentation
use proc_macro2::TokenStream;
use quote::quote;
use syn::{
    braced, parenthesized,
    parse::{Parse, ParseStream},
    punctuated::Punctuated,
    Ident, Token, Type,
};

/// Entry point of the code generator.
///
/// Parses the macro `input` as an [`Input`] and lets every listed impl emit
/// its tokens for the chosen `backend`. If parsing fails, the parse error is
/// returned as a compile error token stream instead.
pub fn go(input: proc_macro::TokenStream, backend: Backend) -> proc_macro::TokenStream {
    let parsed: Input = match syn::parse(input) {
        Ok(parsed) => parsed,
        Err(parse_error) => return parse_error.to_compile_error().into(),
    };

    // Every impl appends to the same output stream, in declaration order.
    let mut generated = TokenStream::new();
    parsed
        .impls
        .iter()
        .for_each(|entry| entry.gen_tokens(&parsed, &backend, &mut generated));

    generated.into()
}

/// Database backend the generated code targets.
///
/// The backend only influences which filter expression the generated `load`
/// functions use: `Pg` wraps the ids in `diesel::pg::expression::dsl::any`,
/// while `Mysql` and `Sqlite` use `eq_any`.
#[derive(Debug)]
pub enum Backend {
    /// Emit the `eq(diesel::pg::expression::dsl::any(..))` filter form.
    Pg,
    /// Emit the `eq_any(..)` filter form.
    Mysql,
    /// Emit the `eq_any(..)` filter form.
    Sqlite,
}

// Custom keywords recognised in the parenthesised prelude of the macro input
// (`error = ...` and `connection = ...`).
mod kw {
    // Introduces the error type: `error = <Type>`.
    syn::custom_keyword!(error);
    // Introduces the connection type: `connection = <Type>`.
    syn::custom_keyword!(connection);
}

/// The fully parsed macro input.
///
/// Expected shape:
/// `(error = <Type>, connection = <Type>[,]) => { <impl>, <impl>, ... }`
#[derive(Debug)]
struct Input {
    /// Type given after `error =`; used as `type Error` in every generated impl.
    error_ty: Type,
    /// Type given after `connection =`; used as `type Connection` in every
    /// generated impl.
    connection_ty: Type,
    /// The comma-separated impl descriptions found inside the braces.
    impls: Punctuated<InputImpl, Token![,]>,
}

impl Parse for Input {
    fn parse(input: ParseStream) -> syn::parse::Result<Self> {
        let prelude;
        parenthesized!(prelude in input);

        prelude.parse::<kw::error>()?;
        prelude.parse::<Token![=]>()?;
        let error_ty = prelude.parse::<Type>()?;

        prelude.parse::<Token![,]>()?;

        prelude.parse::<kw::connection>()?;
        prelude.parse::<Token![=]>()?;
        let connection_ty = prelude.parse::<Type>()?;

        if prelude.peek(Token![,]) {
            prelude.parse::<Token![,]>()?;
        }

        input.parse::<Token![=>]>()?;

        let content;
        braced!(content in input);
        let impls = Punctuated::parse_terminated(&content)?;

        Ok(Self {
            error_ty,
            connection_ty,
            impls,
        })
    }
}

/// One entry of the impl list inside the braces of the macro input.
#[derive(Debug)]
enum InputImpl {
    /// `<IdType> -> (<table>, <SelfType>)`
    HasOne(HasOne),
    /// `<JoinType>.<join_from> -> (<table>.<join_to>, <SelfType>)`
    HasMany(HasMany),
}

/// Description of an impl written as `<IdType> -> (<table>, <SelfType>)`.
///
/// Generates `impl juniper_eager_loading::LoadFrom<IdType> for SelfType`.
#[derive(Debug)]
struct HasOne {
    /// Type before the arrow; element type of the `ids` slice passed to the
    /// generated `load`.
    id_ty: Type,
    /// Module identifier the generated query goes through (`<table>::table`,
    /// `<table>::id`).
    table: Ident,
    /// Type the trait is implemented for and that the query loads.
    self_ty: Type,
}

/// Description of an impl written as
/// `<JoinType>.<join_from> -> (<table>.<join_to>, <SelfType>)`.
///
/// Generates `impl juniper_eager_loading::LoadFrom<JoinType> for SelfType`.
#[derive(Debug)]
struct HasMany {
    /// Type before the dot; element type of the `froms` slice passed to the
    /// generated `load`.
    join_ty: Type,
    /// Field read from every `join_ty` value to collect the ids to filter by.
    join_from: Ident,
    /// Module identifier the generated query goes through (`<table>::table`).
    table: Ident,
    /// Column on `table` that is compared against the collected ids.
    join_to: Ident,
    /// Type the trait is implemented for.
    self_ty: Type,
}

impl Parse for InputImpl {
    /// Parses a single impl description.
    ///
    /// Both forms start with a type. If a `.` follows, the entry is a
    /// `HasMany` (`<Type>.<field> -> (<table>.<column>, <Type>)`); otherwise
    /// it is a `HasOne` (`<Type> -> (<table>, <Type>)`).
    fn parse(input: ParseStream) -> syn::parse::Result<Self> {
        let leading_ty: Type = input.parse()?;

        // No dot after the leading type: this is the has-one form.
        if !input.peek(Token![.]) {
            input.parse::<Token![->]>()?;

            let args;
            parenthesized!(args in input);

            let table: Ident = args.parse()?;
            args.parse::<Token![,]>()?;
            let self_ty: Type = args.parse()?;

            return Ok(InputImpl::HasOne(HasOne {
                id_ty: leading_ty,
                table,
                self_ty,
            }));
        }

        // Has-many form: the leading type is followed by `.<join_from>`.
        input.parse::<Token![.]>()?;
        let join_from: Ident = input.parse()?;

        input.parse::<Token![->]>()?;

        let args;
        parenthesized!(args in input);

        let table: Ident = args.parse()?;
        args.parse::<Token![.]>()?;
        let join_to: Ident = args.parse()?;
        args.parse::<Token![,]>()?;
        let self_ty: Type = args.parse()?;

        Ok(InputImpl::HasMany(HasMany {
            join_ty: leading_ty,
            join_from,
            table,
            join_to,
            self_ty,
        }))
    }
}

impl InputImpl {
    /// Forwards token generation to the wrapped `HasOne` / `HasMany`
    /// description, appending the result to `out`.
    fn gen_tokens(&self, input: &Input, backend: &Backend, out: &mut TokenStream) {
        match *self {
            InputImpl::HasOne(ref relation) => relation.gen_tokens(input, backend, out),
            InputImpl::HasMany(ref relation) => relation.gen_tokens(input, backend, out),
        }
    }
}

impl HasOne {
    fn gen_tokens(&self, input: &Input, backend: &Backend, out: &mut TokenStream) {
        let error_ty = &input.error_ty;
        let connection_ty = &input.connection_ty;

        let id_ty = &self.id_ty;
        let self_ty = &self.self_ty;
        let table = &self.table;

        let filter = match backend {
            Backend::Pg => {
                quote! {
                    #table::id.eq(diesel::pg::expression::dsl::any(ids))
                }
            }
            Backend::Mysql | Backend::Sqlite => {
                quote! {
                    #table::id.eq_any(ids)
                }
            }
        };

        out.extend(quote! {
            impl juniper_eager_loading::LoadFrom<#id_ty> for #self_ty {
                type Error = #error_ty;
                type Connection = #connection_ty;

                fn load(
                    ids: &[#id_ty],
                    _field_args: &(),
                    db: &Self::Connection,
                ) -> Result<Vec<Self>, Self::Error> {
                    #table::table
                    .filter(#filter)
                        .load::<#self_ty>(db)
                        .map_err(From::from)
                }
            }
        });
    }
}

impl HasMany {
    fn gen_tokens(&self, input: &Input, backend: &Backend, out: &mut TokenStream) {
        let error_ty = &input.error_ty;
        let connection_ty = &input.connection_ty;

        let join_ty = &self.join_ty;
        let join_from = &self.join_from;
        let table = &self.table;
        let join_to = &self.join_to;
        let self_ty = &self.self_ty;

        let filter = match backend {
            Backend::Pg => {
                quote! {
                    #table::#join_to.eq(diesel::pg::expression::dsl::any(from_ids))
                }
            }
            Backend::Mysql | Backend::Sqlite => {
                quote! {
                    #table::#join_to.eq_any(from_ids)
                }
            }
        };

        out.extend(quote! {
            impl juniper_eager_loading::LoadFrom<#join_ty> for #self_ty {
                type Error = #error_ty;
                type Connection = #connection_ty;

                fn load(
                    froms: &[#join_ty],
                    _field_args: &(),
                    db: &Self::Connection,
                ) -> Result<Vec<Self>, Self::Error> {
                    let from_ids = froms
                        .iter()
                        .map(|other| other.#join_from)
                        .collect::<Vec<_>>();

                    #table::table
                        .filter(#filter)
                        .load(db)
                        .map_err(From::from)
                }
            }
        })
    }
}