incremental_query_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use proc_macro2::TokenStream as TokenStream2;
4use proc_macro_error::{
5    abort, abort_call_site, emit_call_site_warning, emit_error, emit_warning, proc_macro_error,
6};
7use quote::{quote, ToTokens};
8use syn::parse_quote;
9use syn::parse_quote_spanned;
10use syn::{
11    parse::Parser, parse_macro_input, spanned::Spanned, Attribute, Block, Expr, FnArg,
12    Ident, ItemFn, Lifetime, LifetimeParam, MetaNameValue, Pat, PatType, Path, PathArguments,
13    ReturnType, Signature, Token, Type, Visibility,
14};
15
/// The validated parameter list of a query function, as produced by `validate_inputs`.
struct QueryInputs {
    /// Every parameter of the query. The context parameter is kept as written;
    /// all other parameters have the outer `&` removed from their type.
    inputs_dereffed: Vec<PatType>,
    /// Every parameter except the context one, with its type exactly as written.
    inputs_without_context: Vec<PatType>,
    /// The types of `inputs_without_context` with the outer `&` removed.
    input_dereffed_types_without_context: Vec<Type>,
    /// Name of the parameter that holds the query context.
    context: Ident,
    /// The type argument given to `Context<..>`, if one was written.
    context_typaram: Option<Type>,
}
23
/// The parts of a function signature that a query keeps, as produced by `validate_sig`.
struct QuerySignature {
    /// `const` qualifier of the original function, if any.
    constness: Option<Token![const]>,
    /// `async` qualifier of the original function, if any.
    asyncness: Option<Token![async]>,
    /// The `fn` keyword token of the original function.
    fn_token: Token![fn],
    /// Name of the query function.
    ident: Ident,
    /// The lifetime marked with `#[lt]`, or otherwise the one named `'cx`.
    query_lifetime: Lifetime,
    /// The validated parameter list.
    inputs: QueryInputs,
    /// The declared return type, or `()` when none was written.
    output: Type,
    /// `output` behind a reference with `query_lifetime`.
    output_ref: Type,
}
34
/// A fully validated query function, as produced by `validate`.
struct Query {
    /// Query mode selected through a `rerun` attribute (`QueryMode::Cache` when absent).
    mode: QueryMode,
    /// All attributes of the original function; they are re-emitted on the generated function.
    attrs: Vec<Attribute>,
    /// Visibility of the original function.
    vis: Visibility,
    /// The validated signature.
    sig: QuerySignature,
    /// The original function body.
    block: Box<Block>,
}
42
43fn has_lt_attr(lifetime: &&LifetimeParam) -> bool {
44    lifetime
45        .attrs
46        .iter()
47        .filter_map(get_path)
48        .any(|i| i.is_ident("lt"))
49}
50
51fn pat_as_ident(pat: &Pat) -> Ident {
52    match try_pat_as_ident(pat) {
53        Ok(i) => i,
54        Err(e) => abort!(e, "expected identifier"),
55    }
56}
57
58fn try_pat_as_ident(pat: &Pat) -> Result<Ident, &Pat> {
59    match pat {
60        Pat::Ident(i) => Ok(i.ident.clone()),
61        x => Err(x),
62    }
63}
64
65fn is_context(inner: &Type, lifetime: &Lifetime) -> bool {
66    match inner {
67        Type::Paren(i) => is_context(&i.elem, lifetime),
68        Type::Path(p) => {
69            if let Some(last) = p.path.segments.last() {
70                if last.ident == "Context" {
71                    p.qself.is_none()
72                } else {
73                    false
74                }
75            } else {
76                false
77            }
78        }
79        Type::Ptr(i) => {
80            if is_context(&i.elem, lifetime) {
81                emit_warning!(
82                    i,
83                    "did you mean this to be a reference to a context (`&Context<{}>`)",
84                    quote! {#lifetime}
85                );
86            }
87            false
88        }
89        Type::Reference(i) => {
90            if is_context(&i.elem, lifetime) {
91                emit_warning!(
92                    i,
93                    "did you mean this to be a reference to a context (`&Context<{}>`)",
94                    quote! {#lifetime}
95                );
96            }
97            false
98        }
99        _ => false,
100    }
101}
102
103fn is_context_ref(ty: &Type, lifetime: &Lifetime) -> bool {
104    match ty {
105        Type::Paren(i) => is_context_ref(&i.elem, lifetime),
106        Type::Path(i) => {
107            if i.path
108                .segments
109                .last()
110                .map(|i| i.ident == "Context")
111                .unwrap_or(false)
112            {
113                emit_warning!(
114                    i.path.segments.last().unwrap(),
115                    "did you mean &{}<{}>",
116                    quote! {#i},
117                    quote! {#lifetime}
118                );
119                false
120            } else {
121                false
122            }
123        }
124        Type::Ptr(p) => {
125            if is_context(&p.elem, lifetime) {
126                emit_warning!(
127                    p,
128                    "did you mean this to be a reference to a context (`&Context<{}>`)",
129                    quote! {#lifetime}
130                );
131            }
132
133            false
134        }
135        Type::Reference(r) => {
136            if is_context(&r.elem, lifetime) {
137                r.mutability.is_none()
138            } else {
139                false
140            }
141        }
142        _ => false,
143    }
144}
145
146fn get_ty_path(ty: &Type) -> Path {
147    match ty {
148        Type::Reference(x) => get_ty_path(&x.elem),
149        Type::Paren(x) => get_ty_path(&x.elem),
150        Type::Path(p) => p.clone().path,
151        ty => abort!(ty, "not a path"),
152    }
153}
154
155fn deref_type(t: &Type) -> Type {
156    match t {
157        Type::Reference(r) => {
158            r.elem.as_ref().clone()
159        }
160        x => {
161            abort!(
162                x,
163                "expected this input to be a reference `&{}`",
164                quote! {#x}
165            );
166        }
167    }
168}
169
170fn validate_inputs(inputs: impl IntoIterator<Item = FnArg>, lifetime: &Lifetime) -> QueryInputs {
171    let mut new_inputs = Vec::new();
172
173    let mut context = None;
174    let mut first_nonself_arg = None;
175    let mut idx = 0;
176
177    for i in inputs {
178        if let FnArg::Typed(PatType { pat, ty, .. }) = &i {
179            if first_nonself_arg.is_none() {
180                first_nonself_arg = Some(try_pat_as_ident(pat).map_err(|e| e.clone()));
181            }
182            if is_context_ref(ty, lifetime) {
183                context = Some((pat_as_ident(pat), idx))
184            }
185            idx += 1;
186        } else {
187            abort!(i, "queries may not have a receiver type");
188        }
189
190        if let FnArg::Typed(x) = i {
191            new_inputs.push(x);
192        }
193    }
194
195    if new_inputs.is_empty() {
196        abort_call_site!("queries must have at least one parameter which is `cx: &Context<{}>`")
197    }
198    if !new_inputs.is_empty() && first_nonself_arg.is_none() {
199        abort_call_site!("queries must have at least one parameter which is `cx: &Context<{}>`")
200    }
201
202    let (context, idx) = match context {
203        Some((context, idx)) => (context, idx),
204        _ => {
205            emit_call_site_warning!(
206                "queries must start with one parameter `cx: &Context<{}>`",
207                quote! {#lifetime}
208            );
209            match first_nonself_arg.unwrap() {
210                Ok(i) => (i, 0),
211                Err(e) => {
212                    abort!(e, "expected identifier for the first parameter of a query (which must have type `Context<{}>`)", quote! {#lifetime});
213                }
214            }
215        }
216    };
217
218    if idx != 0 {
219        emit_warning!(context, "expected context to be the first argument")
220    }
221
222    let inputs_without_context: Vec<_> = new_inputs
223        .iter()
224        .filter(|p| {
225            if let Ok(i) = try_pat_as_ident(&p.pat) {
226                i != context
227            } else {
228                true
229            }
230        })
231        .cloned()
232        .collect();
233
234    let input_types_without_context: Vec<_> = inputs_without_context
235        .iter()
236        .map(|i| i.ty.as_ref().clone())
237        .collect();
238    let mut input_dereffed_types_without_context = Vec::new();
239
240    // validate that they're all references
241    for i in &input_types_without_context {
242        input_dereffed_types_without_context.push(deref_type(i))
243    }
244
245    let context_ty = get_ty_path(
246        &new_inputs
247            .iter()
248            .find(|i| try_pat_as_ident(&i.pat).as_ref() == Ok(&context))
249            .expect("context")
250            .ty,
251    );
252    let PathArguments::AngleBracketed(arguments) = context_ty
253        .segments
254        .last()
255        .expect("path segment")
256        .arguments
257        .clone()
258    else {
259        abort!(
260            context_ty.segments.last().expect("path segment").arguments,
261            "unexpected path segment"
262        );
263    };
264    let generics = arguments.args;
265
266    let mut had_lifetime = false;
267    let mut had_generic = None;
268
269    for i in generics {
270        match i {
271            syn::GenericArgument::Lifetime(l) => {
272                if &l != lifetime {
273                    abort!(
274                        l,
275                        "expected `Context<{}>` but found `Context<{}>`",
276                        quote! {#lifetime},
277                        quote! {#l}
278                    );
279                }
280
281                if had_lifetime {
282                    abort!(
283                        l,
284                        "expected `Context<{}>` to have only one lifetime argument",
285                        quote! {#lifetime}
286                    );
287                }
288
289                had_lifetime = true;
290            }
291            syn::GenericArgument::Type(ref t) => {
292                if let Some(old) = had_generic.replace(t.clone()) {
293                    abort!(t, "expected at most one type argument `Contex<{}>` but found `Context<{}, ..., {}>`", quote!{#old}, quote!{#old}, quote!{#t});
294                }
295            }
296            syn::GenericArgument::Const(c) => abort!(c, "unexpected const argument on `Context`"),
297            syn::GenericArgument::AssocType(a) => {
298                abort!(a, "unexpected associated type on `Context`")
299            }
300            syn::GenericArgument::AssocConst(c) => {
301                abort!(c, "unexpected associated const on `Context`")
302            }
303            syn::GenericArgument::Constraint(c) => abort!(c, "unexpected constraint on `Context`"),
304            g => abort!(g, "unexpected generic argument on `Context`"),
305        }
306    }
307
308    let inputs_dereffed = new_inputs
309        .iter()
310        .map(|i@PatType { attrs, pat, colon_token, ty }| {
311            if try_pat_as_ident(&i.pat).as_ref() == Ok(&context) {
312                return i.clone();
313            }
314
315            let dereffed_ty = deref_type(ty);
316            parse_quote!(
317                #(#attrs)* #pat #colon_token #dereffed_ty
318            )
319        })
320        .collect();
321
322    QueryInputs {
323        inputs_dereffed,
324        inputs_without_context,
325        input_dereffed_types_without_context,
326        context,
327        context_typaram: had_generic.clone(),
328    }
329}
330
331fn validate_sig(
332    Signature {
333        constness,
334        asyncness,
335        unsafety,
336        abi,
337        fn_token,
338        ident,
339        generics,
340        paren_token: _,
341        inputs,
342        variadic,
343        output,
344    }: Signature,
345) -> QuerySignature {
346    let marked_lifetime = generics.lifetimes().find(has_lt_attr);
347    let cx_lifetime = generics.lifetimes().find(|i| i.lifetime.ident == "cx");
348    let Some(query_lifetime) = marked_lifetime
349        .or(cx_lifetime)
350        .map(|i| &i.lifetime)
351        .cloned()
352    else {
353        abort!(
354            generics,
355            "expected `'cx` lifetime or lifetime marked with #[lt] in the generics list"
356        )
357    };
358
359    if let Some(i) = unsafety {
360        abort!(i, "queries can't be unsafe");
361    }
362    if let Some(i) = abi {
363        abort!(i, "queries can't have an explicit abi");
364    }
365    if let Some(i) = variadic {
366        abort!(i, "queries can't be variadic");
367    }
368
369    QuerySignature {
370        constness,
371        asyncness,
372        fn_token,
373        ident,
374        inputs: validate_inputs(inputs, &query_lifetime),
375        output: match &output {
376            ReturnType::Default => parse_quote!(()),
377            ReturnType::Type(_, ty) => *ty.clone(),
378        },
379        output_ref: match output {
380            ReturnType::Default => parse_quote_spanned! {output.span() => & #query_lifetime ()},
381            ReturnType::Type(_, ty) => parse_quote!{& #query_lifetime #ty},
382        },
383        query_lifetime,
384    }
385}
386
/// An attribute on a query function that this macro understands (see `parse_attr`).
enum QueryAttr {
    /// A `rerun` attribute that selects the query mode.
    Mode(QueryMode),
}
390
391fn get_string(e: &Expr) -> String {
392    match e {
393        Expr::Lit(l) => match &l.lit {
394            syn::Lit::Str(s) => s.value(),
395            l => abort!(l, "expected string literal"),
396        },
397        e => abort!(e, "expected string literal"),
398    }
399}
400
401fn parse_rerun(s: &str, span: Span) -> QueryMode {
402    match s {
403        "always" => QueryMode::Always,
404        "generation" => QueryMode::Generation,
405        _ => abort!(
406            span,
407            "unknown query mode, expected `always` or `generation`"
408        ),
409    }
410}
411
412fn parse_attr(attr: &Attribute) -> Option<QueryAttr> {
413    match &attr.meta {
414        syn::Meta::Path(_) => None,
415        syn::Meta::List(ml) => if ml.path.is_ident("rerun") {
416            match ml.parse_args::<Ident>() {
417                Err(e) => abort!(ml, "{}", e),
418                Ok(i) => Some(QueryAttr::Mode(parse_rerun(&i.to_string(), i.span()))),
419            }
420        } else {
421            None
422        },
423        syn::Meta::NameValue(MetaNameValue { path, value, .. }) => {
424            if path.is_ident("rerun") {
425                Some(QueryAttr::Mode(parse_rerun(
426                    &get_string(value),
427                    value.span(),
428                )))
429            } else {
430                None
431            }
432        }
433    }
434}
435
436fn validate(
437    ItemFn {
438        attrs,
439        vis,
440        sig,
441        block,
442    }: ItemFn,
443) -> Query {
444    let mut mode = QueryMode::Cache;
445
446    for attr in attrs.iter().filter_map(parse_attr) {
447        match attr {
448            QueryAttr::Mode(m) => mode = m,
449        }
450    }
451
452    Query {
453        mode,
454        attrs,
455        vis,
456        sig: validate_sig(sig),
457        block,
458    }
459}
460
461fn get_path(attr: &Attribute) -> Option<&Path> {
462    match &attr.meta {
463        // ok :)
464        syn::Meta::Path(p) => Some(p),
465        // uhhh
466        syn::Meta::List(_) => None,
467        syn::Meta::NameValue(_) => None,
468    }
469}
470
471fn assert_simple_attr(attr: TokenStream, expected: &str) -> Result<(), syn::Error> {
472    let attrs = Parser::parse(Attribute::parse_outer, attr)?;
473    for i in attrs {
474        match i.meta {
475            // ok :)
476            syn::Meta::Path(_) => {}
477            // uhhh
478            syn::Meta::List(ml) if ml.path.is_ident(expected) => {
479                emit_error!(ml, "expected an attribute without parameters")
480            }
481            syn::Meta::NameValue(mnv) if mnv.path.is_ident(expected) => {
482                emit_error!(mnv, "expected an attribute without this value")
483            }
484            // we don't care
485            _ => {}
486        }
487    }
488    Ok(())
489}
490
491fn tuple_from_types<T: ToTokens>(types: &[T]) -> TokenStream2 {
492    match types {
493        [] => quote! {()},
494        [x] => quote! {(#x,)},
495        x => quote! {(#(#x),*)},
496    }
497}
498
499fn or_unit(ty: &Option<Type>) -> TokenStream2 {
500    match ty {
501        Some(ty) => quote! {#ty},
502        None => quote! {()},
503    }
504}
505
/// The mode of a query, chosen with a `rerun` attribute on the query function.
///
/// Each variant is emitted as the same-named variant of `incremental_query::QueryMode`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum QueryMode {
    /// Selected by `#[rerun(always)]` or `#[rerun = "always"]`.
    Always,
    /// Selected by `#[rerun(generation)]` or `#[rerun = "generation"]`.
    Generation,
    /// The default when no `rerun` attribute is present.
    Cache,
}
512
513#[proc_macro_error]
514#[proc_macro_attribute]
515pub fn query(attr: TokenStream, item: TokenStream) -> TokenStream {
516    let input = parse_macro_input!(item as ItemFn);
517    if let Err(e) = assert_simple_attr(attr, "query") {
518        return e.to_compile_error().into();
519    }
520
521    let Query {
522        mode,
523        attrs,
524        vis,
525        sig:
526            QuerySignature {
527                constness,
528                asyncness,
529                fn_token,
530                ident,
531                inputs:
532                    QueryInputs {
533                        inputs_dereffed,
534                        context,
535                        inputs_without_context,
536                        input_dereffed_types_without_context,
537                        context_typaram,
538                    },
539                output,
540                output_ref,
541                query_lifetime,
542            },
543        block,
544    } = validate(input);
545
546    let query = quote! {incremental_query::Query};
547    let erased_query_run = quote! {incremental_query::ErasedQueryRun};
548    let input_type = tuple_from_types(&input_dereffed_types_without_context);
549    let input_type_dereffed = tuple_from_types(&input_dereffed_types_without_context);
550
551    let string_ident = ident.to_string();
552    let data_ty = or_unit(&context_typaram);
553    let context_ty = quote! {incremental_query::Context};
554    let type_erased_query_param = quote! {incremental_query::TypeErasedQueryParam};
555    let mode_ident = quote! {incremental_query::QueryMode};
556
557    let param_names = tuple_from_types(
558        &inputs_without_context
559            .iter()
560            .map(|i| &i.pat)
561            .collect::<Vec<_>>(),
562    );
563
564    let mode_fn = match mode {
565        QueryMode::Always => quote! {
566            fn mode(&self) -> #mode_ident {
567                #mode_ident::Always
568            }
569        },
570        QueryMode::Generation => quote! {
571            fn mode(&self) -> #mode_ident {
572                #mode_ident::Generation
573            }
574        },
575        QueryMode::Cache => quote! {
576            fn mode(&self) -> #mode_ident {
577                #mode_ident::Cache
578            }
579        },
580    };
581
582    quote! {
583        #(#attrs)*
584        #vis #constness #asyncness #fn_token #ident <#query_lifetime> (#(#inputs_dereffed),*) -> #output_ref {
585            #[derive(Copy, Clone)]
586            struct Q;
587
588            impl<#query_lifetime> #query<#query_lifetime, #data_ty> for Q {
589                type Input = #input_type;
590                type Output = #output;
591
592                const NAME: &'static str = #string_ident;
593
594                fn get_run_fn() -> #erased_query_run<#data_ty> {
595                    fn run<'cx>(
596                        cx: &#context_ty<'cx, #data_ty>,
597                        input: #type_erased_query_param<'cx>,
598                        should_alloc: &dyn Fn(u128) -> bool,
599                    ) -> (Option<#type_erased_query_param<'cx>>, u128) 
600                    {
601                        let input: &#input_type_dereffed = unsafe{input.get_ref()};
602                        let output = <Q as #query<'cx, #data_ty>>::run(cx, input);
603
604                        let output_hash = cx.hash(Q, &output);
605                        if should_alloc(output_hash) {
606                            (Some(#type_erased_query_param::new(cx.storage.alloc(output))), output_hash)
607                        } else {
608                            (None, output_hash)
609                        }
610                    }
611
612                    run
613                }
614
615                #mode_fn
616
617                fn run(#context: &#context_ty<#query_lifetime, #data_ty>, #param_names: &Self::Input) -> Self::Output #block
618            }
619
620            #context.query(Q, #param_names)
621        }
622
623    }
624    .into()
625}
626
/// Marker for the query mode. Either `#[rerun(always)]` or `#[rerun(generation)]`.
///
/// Alternatives: `#[rerun = "always"]` and `#[rerun = "generation"]`.
///
/// Defaults to a cached query mode, which you mark by not giving a mode.
///
/// This attribute expands to the annotated item unchanged. The mode is read by
/// the `query` macro from the function's attributes, and `query` re-emits those
/// attributes on the function it generates.
#[proc_macro_attribute]
#[proc_macro_error]
pub fn rerun(_attr: TokenStream, item: TokenStream) -> TokenStream {
    // Pass-through: nothing is inspected or rewritten here.
    item
}
636}