duvet_macros/
lib.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4#![allow(clippy::manual_unwrap_or_default)] // `FromMeta` currently generates clippy warnings
5
6use darling::FromMeta;
7use proc_macro::TokenStream;
8use quote::{quote, ToTokens};
9use syn::{parse_macro_input, AttributeArgs, ItemFn};
10
11#[derive(Debug, FromMeta)]
12struct QueryArgs {
13    #[darling(default)]
14    cache: bool,
15    #[darling(default)]
16    delegate: bool,
17    #[darling(default)]
18    spawn: bool,
19}
20
21#[proc_macro_attribute]
22pub fn query(args: TokenStream, input: TokenStream) -> TokenStream {
23    let attr_args = parse_macro_input!(args as AttributeArgs);
24    let mut fun = parse_macro_input!(input as ItemFn);
25
26    let args = match QueryArgs::from_list(&attr_args) {
27        Ok(v) => v,
28        Err(e) => return TokenStream::from(e.write_errors()),
29    };
30
31    if !args.cache {
32        basic_query(&mut fun, args.delegate, args.spawn);
33    } else if fun.sig.inputs.is_empty() {
34        global_query(&mut fun, args.delegate, args.spawn);
35    } else {
36        cache_query(&mut fun, args.delegate, args.spawn);
37    }
38
39    quote!(#fun).into()
40}
41
42fn basic_query(fun: &mut ItemFn, delegate: bool, spawn: bool) {
43    // TODO
44    let _ = spawn;
45
46    let is_async = fun.sig.asyncness.is_some();
47    fun.sig.asyncness = None;
48    replace_output(fun);
49
50    let new = if delegate {
51        quote!(delegate)
52    } else {
53        quote!(new)
54    };
55
56    let block = &fun.block;
57    let block = if is_async {
58        quote!(#new(async move #block))
59    } else {
60        quote!(from(#block))
61    };
62
63    fun.block = Box::new(syn::parse_quote!({
64        ::duvet_core::Query::#block
65    }));
66}
67
68fn global_query(fun: &mut ItemFn, delegate: bool, spawn: bool) {
69    // TODO
70    let _ = spawn;
71
72    let is_async = fun.sig.asyncness.is_some();
73    fun.sig.asyncness = None;
74    replace_output(fun);
75
76    if !fun.sig.inputs.is_empty() {
77        panic!("global query arguments must be empty");
78    }
79
80    let new = if delegate {
81        quote!(delegate)
82    } else {
83        quote!(new)
84    };
85    let block = &fun.block;
86    let block = if is_async {
87        quote!(#new(async #block))
88    } else {
89        quote!(from(#block))
90    };
91    fun.block = Box::new(syn::parse_quote!({
92        #[derive(Copy, Clone, Hash, PartialEq, Eq)]
93        struct Query;
94        ::duvet_core::Cache::current().get_or_init_global(Query, move || {
95            ::duvet_core::Query::#block
96        })
97    }));
98}
99
100fn cache_query(fun: &mut ItemFn, delegate: bool, spawn: bool) {
101    // TODO
102    let _ = spawn;
103
104    let is_async = fun.sig.asyncness.is_some();
105    fun.sig.asyncness = None;
106    replace_output(fun);
107
108    let mut inject_tokens = quote!();
109    let mut join_alias = quote!();
110    let mut join_args = quote!();
111    let mut hash = quote!();
112
113    for input in core::mem::take(&mut fun.sig.inputs).into_pairs() {
114        let (mut input, punc) = input.into_tuple();
115
116        let mut should_push = true;
117
118        if let syn::FnArg::Typed(ref mut input) = input {
119            let mut is_ignored = false;
120            let mut inject = None;
121
122            // TODO add custom hasher attribute
123
124            input.attrs.retain(|attr| {
125                if attr.path.is_ident("skip") {
126                    is_ignored = true;
127                    false
128                } else if attr.path.is_ident("inject") {
129                    inject = Some(attr.tokens.clone());
130                    should_push = false;
131                    false
132                } else {
133                    true
134                }
135            });
136
137            if !is_ignored {
138                let pat = &input.pat;
139
140                if let Some(inject) = inject {
141                    quote!(#[allow(unused_parens)] let #pat = #inject;)
142                        .to_tokens(&mut inject_tokens);
143                }
144
145                if is_query_arg(&input.ty) {
146                    quote!(let #pat = #pat.get();).to_tokens(&mut join_alias);
147                    quote!(#pat,).to_tokens(&mut join_args);
148                }
149
150                quote!(::core::hash::Hash::hash(&#pat, &mut hasher);).to_tokens(&mut hash);
151            }
152        }
153
154        if should_push {
155            fun.sig.inputs.push(input);
156            if let Some(punc) = punc {
157                fun.sig.inputs.push_punct(punc);
158            }
159        }
160    }
161
162    let new = if delegate {
163        quote!(delegate)
164    } else {
165        quote!(new)
166    };
167    let block = &fun.block;
168    let block = if is_async {
169        quote!(#new(async move #block))
170    } else {
171        quote!(from(#block))
172    };
173    fun.block = Box::new(syn::parse_quote!({
174        ::duvet_core::Query::delegate(async move {
175            #inject_tokens
176
177            let key = {
178                use ::duvet_core::macro_support::tokio;
179
180                #join_alias
181                let (#join_args) = tokio::join!(#join_args);
182
183                let mut hasher = ::duvet_core::hash::Hasher::default();
184
185                #hash
186
187                hasher.finish()
188            };
189
190            ::duvet_core::Cache::current().get_or_init(key, move || {
191                ::duvet_core::Query::#block
192            })
193        })
194    }));
195}
196
197fn is_query_arg(ty: &syn::Type) -> bool {
198    if let syn::Type::Path(path) = ty {
199        if path.qself.is_some() {
200            return false;
201        }
202
203        if path.path.leading_colon.is_some() {
204            return false;
205        }
206
207        if path.path.segments.len() != 1 {
208            return false;
209        }
210
211        let seg = &path.path.segments[0];
212
213        if seg.ident != "Query" {
214            return false;
215        }
216
217        if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
218            args.args.len() == 1
219        } else {
220            false
221        }
222    } else {
223        false
224    }
225}
226
227fn replace_output(fun: &mut ItemFn) -> Box<syn::Type> {
228    let output = core::mem::replace(&mut fun.sig.output, syn::ReturnType::Default);
229    match output {
230        syn::ReturnType::Default => {
231            todo!("cannot return an empty query");
232        }
233        syn::ReturnType::Type(arrow, ty) => {
234            fun.sig.output =
235                syn::ReturnType::Type(arrow, Box::new(syn::parse_quote!(::duvet_core::Query<#ty>)));
236            ty
237        }
238    }
239}