yeter_macros/
lib.rs

1use proc_macro2::{Ident, Span, TokenStream};
2use proc_macro_error::*;
3use quote::quote;
4use syn::punctuated::Punctuated;
5use syn::{
6    parse_quote, Attribute, Expr, ExprField, ExprPath, ExprTuple, FnArg, ForeignItemFn,
7    GenericArgument, GenericParam, Index, ItemFn, Member, Pat, PatIdent, PatType, Path,
8    PathArguments, PathSegment, ReturnType, Signature, Token, Type, TypePath, TypeReference,
9    TypeTuple, Visibility, WhereClause,
10};
11
12fn fn_arg_to_type(arg: &FnArg) -> &Type {
13    match arg {
14        FnArg::Receiver(_) => unimplemented!(),
15        FnArg::Typed(arg) => arg.ty.as_ref(),
16    }
17}
18
19fn build_type_tuple(types: impl Iterator<Item = Type>) -> Type {
20    let mut elems = types.collect::<Punctuated<_, Token![,]>>();
21    if !elems.is_empty() {
22        elems.push_punct(Default::default());
23    }
24
25    Type::Tuple(TypeTuple {
26        paren_token: Default::default(),
27        elems,
28    })
29}
30
31fn build_unit_tuple() -> Type {
32    build_type_tuple([].into_iter())
33}
34
35fn arg_name(arg: &FnArg) -> Option<Ident> {
36    match arg {
37        FnArg::Receiver(_) => Some(Ident::new("self", Span::call_site())),
38        FnArg::Typed(pat_type) => {
39            if let Pat::Ident(name) = pat_type.pat.as_ref() {
40                Some(name.ident.clone())
41            } else {
42                None
43            }
44        }
45    }
46}
47
48fn arg_names<'a>(args: impl Iterator<Item = &'a FnArg>) -> Vec<Ident> {
49    args.enumerate()
50        .map(|(n, arg)| {
51            arg_name(arg).unwrap_or_else(|| Ident::new(&format!("arg{n}"), Span::mixed_site()))
52        })
53        .collect()
54}
55
56fn calling_tuple_args(idents: impl Iterator<Item = (Ident, Type)>) -> Punctuated<FnArg, Token![,]> {
57    idents
58        .map(|(name, typ)| {
59            FnArg::Typed(PatType {
60                attrs: Default::default(),
61                pat: Box::new(Pat::Ident(PatIdent {
62                    attrs: Default::default(),
63                    by_ref: None,
64                    mutability: None,
65                    subpat: None,
66                    ident: name,
67                })),
68                colon_token: Default::default(),
69                ty: Box::new(typ),
70            })
71        })
72        .collect()
73}
74
75fn build_ident_tuple(idents: impl Iterator<Item = Ident>) -> Expr {
76    let mut elems = idents
77        .map(ident_to_expr)
78        .collect::<Punctuated<_, Token![,]>>();
79    if !elems.is_empty() {
80        elems.push_punct(Default::default());
81    }
82
83    ExprTuple {
84        attrs: Default::default(),
85        paren_token: Default::default(),
86        elems,
87    }
88    .into()
89}
90
91fn ident_to_expr(id: Ident) -> Expr {
92    ExprPath {
93        attrs: Default::default(),
94        qself: Default::default(),
95        path: Path::from(id),
96    }
97    .into()
98}
99
100/// Converts generic arguments to generic params (effectively dismissing all ": ???" bounds)
101fn use_generic_args(
102    generics: &Punctuated<GenericParam, Token![,]>,
103) -> Punctuated<GenericArgument, Token![,]> {
104    generics
105        .iter()
106        .map(|p| match p {
107            GenericParam::Type(t) => GenericArgument::Type(
108                TypePath {
109                    qself: None,
110                    path: t.ident.clone().into(),
111                }
112                .into(),
113            ),
114            GenericParam::Lifetime(l) => GenericArgument::Lifetime(l.lifetime.clone()),
115            GenericParam::Const(c) => GenericArgument::Const(ident_to_expr(c.ident.clone())),
116        })
117        .collect()
118}
119
120fn generic_args_phantom(generics: &Punctuated<GenericArgument, Token![,]>) -> Type {
121    build_type_tuple(generics.iter().filter_map(|a| {
122        match a {
123            GenericArgument::Binding(_) => unreachable!(),
124            GenericArgument::Constraint(_) => unreachable!(),
125            GenericArgument::Type(t) => Some(t.clone()),
126            GenericArgument::Const(_) => None,
127            GenericArgument::Lifetime(lt) => Some(
128                TypeReference {
129                    lifetime: Some(lt.clone()),
130                    mutability: None,
131                    and_token: Default::default(),
132                    elem: Box::new(build_unit_tuple()),
133                }
134                .into(),
135            ),
136        }
137    }))
138}
139
140#[proc_macro_error]
141#[proc_macro_attribute]
142pub fn query(
143    attr: proc_macro::TokenStream,
144    item: proc_macro::TokenStream,
145) -> proc_macro::TokenStream {
146    if !attr.is_empty() {
147        emit_error!(
148            TokenStream::from(attr),
149            "#[yeter::query] doesn't expect any attributes"
150        );
151    }
152
153    let mut function_no_impl;
154    let mut function_impl;
155    let function = {
156        if let Ok(f) = syn::parse::<ForeignItemFn>(item.clone()) {
157            function_no_impl = f;
158            &mut function_no_impl as &mut dyn FunctionItem
159        } else if let Ok(f) = syn::parse::<ItemFn>(item.clone()) {
160            function_impl = f;
161            &mut function_impl as &mut dyn FunctionItem
162        } else {
163            let item = TokenStream::from(item);
164            return (quote! { compile_error!("expected fn item"); #item }).into();
165        }
166    };
167
168    let query_attrs = function.take_attrs();
169    let fn_args = &function.sig().inputs;
170    let query_args = fn_args
171        .iter()
172        .skip(1)
173        .map(fn_arg_to_type)
174        .cloned()
175        .collect::<Vec<_>>();
176
177    let db_ident_fallback = Ident::new("db", Span::call_site());
178    match fn_args.first() {
179        // self, &self, &mut self
180        Some(receiver @ FnArg::Receiver(_)) => {
181            emit_error!(
182                receiver,
183                "#[yeter::query] can't be used on instance methods";
184                hint = "did you mean `db: &yeter::Database`?";
185            );
186
187            &db_ident_fallback
188        }
189        Some(FnArg::Typed(pat_type)) => match pat_type.pat.as_ref() {
190            Pat::Ident(ident) => &ident.ident,
191            _ => {
192                emit_error!(
193                    pat_type.pat,
194                    "simple database argument pattern expected";
195                    help = "use a simple argument declaration such as `db: &yeter::Database`";
196                );
197
198                &db_ident_fallback
199            }
200        },
201        None => {
202            emit_error!(
203                function.sig(), "a query must take a database as its first argument";
204                note = "no arguments were specified";
205            );
206
207            &db_ident_fallback
208        }
209    };
210
211    let fn_arg_count = fn_args.len() as u32;
212    let query_arg_count = if fn_arg_count == 0 {
213        0
214    } else {
215        fn_arg_count - 1
216    };
217
218    let unit_type;
219
220    let query_vis = &function.vis();
221    let query_name = &function.sig().ident;
222    let generics = &function.sig().generics;
223    let generics_params = &generics.params;
224    let generics_where = &generics.where_clause;
225    let generics_args = use_generic_args(generics_params);
226    let generics_phantom = generic_args_phantom(&generics_args);
227
228    let input_type = build_type_tuple(query_args.iter().cloned());
229    let output_type = match &function.sig().output {
230        ReturnType::Default => {
231            unit_type = build_unit_tuple();
232            &unit_type
233        }
234        ReturnType::Type(_, typ) => typ.as_ref(),
235    };
236
237    let calling_arg_names = arg_names(fn_args.iter().skip(1));
238
239    let calling_tuple_args = calling_tuple_args(calling_arg_names.iter().cloned().zip(query_args));
240    let calling_tuple = build_ident_tuple(calling_arg_names.into_iter());
241
242    let call_ident_span = Span::call_site().located_at(query_name.span());
243    // When Span::def_site is stable, we will be able to properly create hygienic idents
244    let call_ident = Ident::new(&format!("__yeter_{query_name}"), call_ident_span);
245
246    let to_function_impl = function.to_function_impl(&call_ident, generics_params, output_type);
247    let to_function_call = function.to_function_call(&call_ident, query_arg_count);
248    let to_additional_impl = function.to_additional_impl(
249        query_name,
250        generics_params,
251        &generics_args,
252        generics_where,
253        output_type,
254    );
255
256    let expanded = quote! {
257        #(#query_attrs)*
258        #query_vis fn #query_name<#generics_params>(db: &::yeter::Database, #calling_tuple_args) -> ::std::rc::Rc<#output_type>
259            #generics_where
260        {
261            #to_function_impl
262            db.run::<_, #query_name::<#generics_args>>(#to_function_call, #calling_tuple)
263        }
264
265        #[allow(non_camel_case_types)]
266        #[doc(hidden)]
267        #query_vis enum #query_name<#generics_params> {
268            Phantom(std::convert::Infallible, std::marker::PhantomData<#generics_phantom>),
269        }
270
271        impl<#generics_params> ::yeter::QueryDef for #query_name<#generics_args> #generics_where {
272            type Input = #input_type;
273            type Output = #output_type;
274        }
275
276        #to_additional_impl
277    };
278
279    set_dummy(expanded.clone()); // Still produce these tokens if an error was emitted
280    expanded.into()
281}
282
283trait FunctionItem {
284    fn take_attrs(&mut self) -> Vec<Attribute>;
285    fn vis(&self) -> &Visibility;
286    fn sig(&self) -> &Signature;
287
288    fn to_function_impl(
289        &self,
290        _call_ident: &Ident,
291        _generics_params: &Punctuated<GenericParam, Token![,]>,
292        _output_type: &Type,
293    ) -> TokenStream {
294        quote! {}
295    }
296
297    fn to_function_call(&self, _call_ident: &Ident, _query_arg_count: u32) -> TokenStream;
298
299    fn to_additional_impl(
300        &self,
301        _query_name: &Ident,
302        _generics_params: &Punctuated<GenericParam, Token![,]>,
303        _generics_args: &Punctuated<GenericArgument, Token![,]>,
304        _generics_where: &Option<WhereClause>,
305        _output_type: &Type,
306    ) -> TokenStream {
307        quote! {}
308    }
309}
310
311fn guess_option_inner_type(option: &Type) -> Type {
312    if let Type::Path(path) = option {
313        match path.path.segments.last() {
314            Some(seg) if seg.ident != "Option" => {
315                emit_error!(seg.ident, "expected `Option` type",);
316            }
317            Some(PathSegment {
318                arguments: PathArguments::AngleBracketed(angle),
319                ..
320            }) if angle.args.len() == 1 => match angle.args.first().unwrap() {
321                GenericArgument::Type(t) => return t.clone(),
322                o => {
323                    emit_error!(o, "unexpected generic argument for Option type",);
324                }
325            },
326            Some(seg) => {
327                emit_error!(seg, "expected Option<T> return type",);
328            }
329            None => {
330                emit_error!(path, "expected Option<T> return type",);
331            }
332        }
333    };
334
335    parse_quote! { Option<()> }
336}
337
338impl FunctionItem for ForeignItemFn {
339    fn take_attrs(&mut self) -> Vec<Attribute> {
340        std::mem::take(&mut self.attrs)
341    }
342
343    fn vis(&self) -> &Visibility {
344        &self.vis
345    }
346
347    fn sig(&self) -> &Signature {
348        &self.sig
349    }
350
351    fn to_function_call(&self, _call_ident: &Ident, _query_arg_count: u32) -> TokenStream {
352        quote! {
353            |_db, _input| None
354        }
355    }
356
357    fn to_additional_impl(
358        &self,
359        query_name: &Ident,
360        generics_params: &Punctuated<GenericParam, Token![,]>,
361        generics_args: &Punctuated<GenericArgument, Token![,]>,
362        generics_where: &Option<WhereClause>,
363        output_type: &Type,
364    ) -> TokenStream {
365        // Output should be an option; we try to guess what could be inside
366        let output_type = guess_option_inner_type(output_type);
367
368        quote! {
369            impl<#generics_params> ::yeter::InputQueryDef for #query_name<#generics_args> #generics_where {
370                type OptionalOutput = #output_type;
371            }
372        }
373    }
374}
375
376impl FunctionItem for ItemFn {
377    fn take_attrs(&mut self) -> Vec<Attribute> {
378        std::mem::take(&mut self.attrs)
379    }
380
381    fn vis(&self) -> &Visibility {
382        &self.vis
383    }
384
385    fn sig(&self) -> &Signature {
386        &self.sig
387    }
388
389    fn to_function_impl(
390        &self,
391        call_ident: &Ident,
392        _generics_params: &Punctuated<GenericParam, Token![,]>,
393        _output_type: &Type,
394    ) -> TokenStream {
395        let mut s = self.clone();
396        s.sig.ident = call_ident.clone();
397
398        quote! {
399            #[allow(clippy::needless_lifetimes)]
400            #s
401        }
402    }
403
404    fn to_function_call(&self, call_ident: &Ident, query_arg_count: u32) -> TokenStream {
405        let db_ident = Ident::new("db", Span::mixed_site());
406        let input_ident = Ident::new("input", Span::mixed_site());
407        let input_ident_expr = Box::new(ident_to_expr(input_ident.clone()));
408        let calling_args = (0..query_arg_count)
409            .map(|n| {
410                Expr::Field(ExprField {
411                    attrs: Default::default(),
412                    base: input_ident_expr.clone(),
413                    dot_token: Default::default(),
414                    member: Member::Unnamed(Index {
415                        index: n,
416                        span: Span::mixed_site(),
417                    }),
418                })
419            })
420            .collect::<Punctuated<_, Token![,]>>();
421
422        quote! {
423            |#db_ident, #input_ident| #call_ident(#db_ident, #calling_args)
424        }
425    }
426}