1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
extern crate proc_macro;

mod attrs;

use crate::attrs::*;
use quote::*;
use std::collections::HashSet;
use syn::punctuated::*;
use syn::*;

#[proc_macro_derive(Query, attributes(query))]
pub fn define_query(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = syn::parse_macro_input!(input as DeriveInput);

    let attrs = match Attrs::from_input(&input) {
        Ok(attrs) => attrs,
        Err(err) => return err.to_compile_error().into(),
    };

    let (sql, params) = match attrs.parse_sql_literal() {
        Ok(sql) => sql,
        Err(err) => return err.to_compile_error().into(),
    };

    let mut sorted_params = params.iter().collect::<Vec<_>>();
    sorted_params.sort_by_key(|(_, index)| *index);

    let sorted_params = sorted_params
        .into_iter()
        .map(|(ident, _)| ident)
        .collect::<Vec<_>>();

    match input.data {
        Data::Struct(DataStruct {
            fields,
            ..
        }) => {
            let fields = match fields {
                Fields::Named(fields) => fields.named,
                Fields::Unit => Default::default(),
                Fields::Unnamed(fields) => return Error::new_spanned(fields, "Cannot derive `Query` for tuple struct")
                    .to_compile_error().into()
            };

            let references_trait_object = fields
                .iter()
                .filter(|field| match &field.ty {
                    Type::Reference(reference) => match *reference.elem {
                        Type::TraitObject(_) => true,
                        _ => false,
                    },
                    _ => false,
                })
                .filter_map(|field| field.ident.as_ref())
                .collect::<HashSet<_>>();

            let param_borrow = sorted_params.iter().map(|ident| {
                if references_trait_object.contains(ident) {
                    None
                } else {
                    Some(quote!(&))
                }
            });

            let ident = input.ident;
            let param_count = params.len();

            let where_clause = input.generics.where_clause;
            let generic_params = input.generics.params.into_iter().collect::<Vec<_>>();
            let generic_tokens = generic_params
                .iter()
                .map(|param| match param {
                    GenericParam::Type(TypeParam { ident, .. }) => quote!(#ident),
                    GenericParam::Lifetime(LifetimeDef { lifetime, .. }) => quote!(#lifetime),
                    GenericParam::Const(ConstParam { ident, .. }) => quote!(#ident),
                })
                .collect::<Punctuated<_, Token![,]>>();

            let output = quote! {
                impl<'__query_params #(, #generic_params)*> ::postgres_query::Query<'__query_params> for #ident <#generic_tokens> #where_clause {
                    type Sql = &'static str;
                    type Params = [&'__query_params dyn ::postgres::types::ToSql; #param_count];

                    fn sql(&self) -> Self::Sql {
                        #sql
                    }

                    fn params(&'__query_params self) -> Self::Params {
                        [
                            #( #param_borrow self.#sorted_params, )*
                        ]
                    }
                }
            };

            output
        }

        Data::Enum(data) => Error::new_spanned(
            data.enum_token,
            "Cannot derive `Query` for enum",
        )
        .to_compile_error(),

        Data::Union(data) => Error::new_spanned(
            data.union_token,
            "Cannot derive `Query` for union",
        )
        .to_compile_error(),
    }.into()
}