Skip to main content

duvet_macros/
lib.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4#![allow(clippy::manual_unwrap_or_default)] // `FromMeta` currently generates clippy warnings
5
6use darling::FromMeta;
7use proc_macro::TokenStream;
8use quote::{quote, ToTokens};
9use syn::{parse_macro_input, ItemFn};
10
11#[derive(Debug, FromMeta)]
12struct QueryArgs {
13    #[darling(default)]
14    cache: bool,
15    #[darling(default)]
16    delegate: bool,
17    #[darling(default)]
18    spawn: bool,
19}
20
21#[proc_macro_attribute]
22pub fn query(args: TokenStream, input: TokenStream) -> TokenStream {
23    let attr_args = match darling::ast::NestedMeta::parse_meta_list(args.into()) {
24        Ok(v) => v,
25        Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
26    };
27    let mut fun = parse_macro_input!(input as ItemFn);
28
29    let args = match QueryArgs::from_list(&attr_args) {
30        Ok(v) => v,
31        Err(e) => return TokenStream::from(e.write_errors()),
32    };
33
34    if !args.cache {
35        basic_query(&mut fun, args.delegate, args.spawn);
36    } else if fun.sig.inputs.is_empty() {
37        global_query(&mut fun, args.delegate, args.spawn);
38    } else {
39        cache_query(&mut fun, args.delegate, args.spawn);
40    }
41
42    quote!(#fun).into()
43}
44
45fn basic_query(fun: &mut ItemFn, delegate: bool, spawn: bool) {
46    // TODO
47    let _ = spawn;
48
49    let is_async = fun.sig.asyncness.is_some();
50    fun.sig.asyncness = None;
51    replace_output(fun);
52
53    let new = if delegate {
54        quote!(delegate)
55    } else {
56        quote!(new)
57    };
58
59    let block = &fun.block;
60    let block = if is_async {
61        quote!(#new(async move #block))
62    } else {
63        quote!(from(#block))
64    };
65
66    *fun.block = syn::parse_quote!({
67        ::duvet_core::Query::#block
68    });
69}
70
71fn global_query(fun: &mut ItemFn, delegate: bool, spawn: bool) {
72    // TODO
73    let _ = spawn;
74
75    let is_async = fun.sig.asyncness.is_some();
76    fun.sig.asyncness = None;
77    replace_output(fun);
78
79    if !fun.sig.inputs.is_empty() {
80        panic!("global query arguments must be empty");
81    }
82
83    let new = if delegate {
84        quote!(delegate)
85    } else {
86        quote!(new)
87    };
88    let block = &fun.block;
89    let block = if is_async {
90        quote!(#new(async #block))
91    } else {
92        quote!(from(#block))
93    };
94    *fun.block = syn::parse_quote!({
95        #[derive(Copy, Clone, Hash, PartialEq, Eq)]
96        struct Query;
97        ::duvet_core::Cache::current().get_or_init_global(Query, move || {
98            ::duvet_core::Query::#block
99        })
100    });
101}
102
103fn cache_query(fun: &mut ItemFn, delegate: bool, spawn: bool) {
104    // TODO
105    let _ = spawn;
106
107    let is_async = fun.sig.asyncness.is_some();
108    fun.sig.asyncness = None;
109    replace_output(fun);
110
111    let mut inject_tokens = quote!();
112    let mut join_alias = quote!();
113    let mut join_args = quote!();
114    let mut hash = quote!();
115
116    for input in core::mem::take(&mut fun.sig.inputs).into_pairs() {
117        let (mut input, punc) = input.into_tuple();
118
119        let mut should_push = true;
120
121        if let syn::FnArg::Typed(ref mut input) = input {
122            let mut is_ignored = false;
123            let mut inject = None;
124
125            // TODO add custom hasher attribute
126
127            input.attrs.retain(|attr| {
128                if attr.path().is_ident("skip") {
129                    is_ignored = true;
130                    false
131                } else if attr.path().is_ident("inject") {
132                    if let syn::Meta::List(meta_list) = &attr.meta {
133                        inject = Some(meta_list.tokens.clone());
134                    }
135                    should_push = false;
136                    false
137                } else {
138                    true
139                }
140            });
141
142            if !is_ignored {
143                let pat = &input.pat;
144
145                if let Some(inject) = inject {
146                    quote!(#[allow(unused_parens)] let #pat = #inject;)
147                        .to_tokens(&mut inject_tokens);
148                }
149
150                if is_query_arg(&input.ty) {
151                    quote!(let #pat = #pat.get();).to_tokens(&mut join_alias);
152                    quote!(#pat,).to_tokens(&mut join_args);
153                }
154
155                quote!(::core::hash::Hash::hash(&#pat, &mut hasher);).to_tokens(&mut hash);
156            }
157        }
158
159        if should_push {
160            fun.sig.inputs.push(input);
161            if let Some(punc) = punc {
162                fun.sig.inputs.push_punct(punc);
163            }
164        }
165    }
166
167    let new = if delegate {
168        quote!(delegate)
169    } else {
170        quote!(new)
171    };
172    let block = &fun.block;
173    let block = if is_async {
174        quote!(#new(async move #block))
175    } else {
176        quote!(from(#block))
177    };
178    *fun.block = syn::parse_quote!({
179        ::duvet_core::Query::delegate(async move {
180            #inject_tokens
181
182            let key = {
183                use ::duvet_core::macro_support::tokio;
184
185                #join_alias
186                let (#join_args) = tokio::join!(#join_args);
187
188                let mut hasher = ::duvet_core::hash::Hasher::default();
189
190                #hash
191
192                hasher.finish()
193            };
194
195            ::duvet_core::Cache::current().get_or_init(key, move || {
196                ::duvet_core::Query::#block
197            })
198        })
199    });
200}
201
202fn is_query_arg(ty: &syn::Type) -> bool {
203    if let syn::Type::Path(path) = ty {
204        if path.qself.is_some() {
205            return false;
206        }
207
208        if path.path.leading_colon.is_some() {
209            return false;
210        }
211
212        if path.path.segments.len() != 1 {
213            return false;
214        }
215
216        let seg = &path.path.segments[0];
217
218        if seg.ident != "Query" {
219            return false;
220        }
221
222        if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
223            args.args.len() == 1
224        } else {
225            false
226        }
227    } else {
228        false
229    }
230}
231
232fn replace_output(fun: &mut ItemFn) -> Box<syn::Type> {
233    let output = core::mem::replace(&mut fun.sig.output, syn::ReturnType::Default);
234    match output {
235        syn::ReturnType::Default => {
236            todo!("cannot return an empty query");
237        }
238        syn::ReturnType::Type(arrow, ty) => {
239            fun.sig.output =
240                syn::ReturnType::Type(arrow, Box::new(syn::parse_quote!(::duvet_core::Query<#ty>)));
241            ty
242        }
243    }
244}