1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use proc_macro2::TokenStream as TokenStream2;
4use proc_macro_error::{
5 abort, abort_call_site, emit_call_site_warning, emit_error, emit_warning, proc_macro_error,
6};
7use quote::{quote, ToTokens};
8use syn::parse_quote;
9use syn::parse_quote_spanned;
10use syn::{
11 parse::Parser, parse_macro_input, spanned::Spanned, Attribute, Block, Expr, FnArg,
12 Ident, ItemFn, Lifetime, LifetimeParam, MetaNameValue, Pat, PatType, Path, PathArguments,
13 ReturnType, Signature, Token, Type, Visibility,
14};
15
16struct QueryInputs {
17 inputs_dereffed: Vec<PatType>,
18 inputs_without_context: Vec<PatType>,
19 input_dereffed_types_without_context: Vec<Type>,
20 context: Ident,
21 context_typaram: Option<Type>,
22}
23
24struct QuerySignature {
25 constness: Option<Token![const]>,
26 asyncness: Option<Token![async]>,
27 fn_token: Token![fn],
28 ident: Ident,
29 query_lifetime: Lifetime,
30 inputs: QueryInputs,
31 output: Type,
32 output_ref: Type,
33}
34
35struct Query {
36 mode: QueryMode,
37 attrs: Vec<Attribute>,
38 vis: Visibility,
39 sig: QuerySignature,
40 block: Box<Block>,
41}
42
43fn has_lt_attr(lifetime: &&LifetimeParam) -> bool {
44 lifetime
45 .attrs
46 .iter()
47 .filter_map(get_path)
48 .any(|i| i.is_ident("lt"))
49}
50
51fn pat_as_ident(pat: &Pat) -> Ident {
52 match try_pat_as_ident(pat) {
53 Ok(i) => i,
54 Err(e) => abort!(e, "expected identifier"),
55 }
56}
57
58fn try_pat_as_ident(pat: &Pat) -> Result<Ident, &Pat> {
59 match pat {
60 Pat::Ident(i) => Ok(i.ident.clone()),
61 x => Err(x),
62 }
63}
64
65fn is_context(inner: &Type, lifetime: &Lifetime) -> bool {
66 match inner {
67 Type::Paren(i) => is_context(&i.elem, lifetime),
68 Type::Path(p) => {
69 if let Some(last) = p.path.segments.last() {
70 if last.ident == "Context" {
71 p.qself.is_none()
72 } else {
73 false
74 }
75 } else {
76 false
77 }
78 }
79 Type::Ptr(i) => {
80 if is_context(&i.elem, lifetime) {
81 emit_warning!(
82 i,
83 "did you mean this to be a reference to a context (`&Context<{}>`)",
84 quote! {#lifetime}
85 );
86 }
87 false
88 }
89 Type::Reference(i) => {
90 if is_context(&i.elem, lifetime) {
91 emit_warning!(
92 i,
93 "did you mean this to be a reference to a context (`&Context<{}>`)",
94 quote! {#lifetime}
95 );
96 }
97 false
98 }
99 _ => false,
100 }
101}
102
103fn is_context_ref(ty: &Type, lifetime: &Lifetime) -> bool {
104 match ty {
105 Type::Paren(i) => is_context_ref(&i.elem, lifetime),
106 Type::Path(i) => {
107 if i.path
108 .segments
109 .last()
110 .map(|i| i.ident == "Context")
111 .unwrap_or(false)
112 {
113 emit_warning!(
114 i.path.segments.last().unwrap(),
115 "did you mean &{}<{}>",
116 quote! {#i},
117 quote! {#lifetime}
118 );
119 false
120 } else {
121 false
122 }
123 }
124 Type::Ptr(p) => {
125 if is_context(&p.elem, lifetime) {
126 emit_warning!(
127 p,
128 "did you mean this to be a reference to a context (`&Context<{}>`)",
129 quote! {#lifetime}
130 );
131 }
132
133 false
134 }
135 Type::Reference(r) => {
136 if is_context(&r.elem, lifetime) {
137 r.mutability.is_none()
138 } else {
139 false
140 }
141 }
142 _ => false,
143 }
144}
145
146fn get_ty_path(ty: &Type) -> Path {
147 match ty {
148 Type::Reference(x) => get_ty_path(&x.elem),
149 Type::Paren(x) => get_ty_path(&x.elem),
150 Type::Path(p) => p.clone().path,
151 ty => abort!(ty, "not a path"),
152 }
153}
154
155fn deref_type(t: &Type) -> Type {
156 match t {
157 Type::Reference(r) => {
158 r.elem.as_ref().clone()
159 }
160 x => {
161 abort!(
162 x,
163 "expected this input to be a reference `&{}`",
164 quote! {#x}
165 );
166 }
167 }
168}
169
170fn validate_inputs(inputs: impl IntoIterator<Item = FnArg>, lifetime: &Lifetime) -> QueryInputs {
171 let mut new_inputs = Vec::new();
172
173 let mut context = None;
174 let mut first_nonself_arg = None;
175 let mut idx = 0;
176
177 for i in inputs {
178 if let FnArg::Typed(PatType { pat, ty, .. }) = &i {
179 if first_nonself_arg.is_none() {
180 first_nonself_arg = Some(try_pat_as_ident(pat).map_err(|e| e.clone()));
181 }
182 if is_context_ref(ty, lifetime) {
183 context = Some((pat_as_ident(pat), idx))
184 }
185 idx += 1;
186 } else {
187 abort!(i, "queries may not have a receiver type");
188 }
189
190 if let FnArg::Typed(x) = i {
191 new_inputs.push(x);
192 }
193 }
194
195 if new_inputs.is_empty() {
196 abort_call_site!("queries must have at least one parameter which is `cx: &Context<{}>`")
197 }
198 if !new_inputs.is_empty() && first_nonself_arg.is_none() {
199 abort_call_site!("queries must have at least one parameter which is `cx: &Context<{}>`")
200 }
201
202 let (context, idx) = match context {
203 Some((context, idx)) => (context, idx),
204 _ => {
205 emit_call_site_warning!(
206 "queries must start with one parameter `cx: &Context<{}>`",
207 quote! {#lifetime}
208 );
209 match first_nonself_arg.unwrap() {
210 Ok(i) => (i, 0),
211 Err(e) => {
212 abort!(e, "expected identifier for the first parameter of a query (which must have type `Context<{}>`)", quote! {#lifetime});
213 }
214 }
215 }
216 };
217
218 if idx != 0 {
219 emit_warning!(context, "expected context to be the first argument")
220 }
221
222 let inputs_without_context: Vec<_> = new_inputs
223 .iter()
224 .filter(|p| {
225 if let Ok(i) = try_pat_as_ident(&p.pat) {
226 i != context
227 } else {
228 true
229 }
230 })
231 .cloned()
232 .collect();
233
234 let input_types_without_context: Vec<_> = inputs_without_context
235 .iter()
236 .map(|i| i.ty.as_ref().clone())
237 .collect();
238 let mut input_dereffed_types_without_context = Vec::new();
239
240 for i in &input_types_without_context {
242 input_dereffed_types_without_context.push(deref_type(i))
243 }
244
245 let context_ty = get_ty_path(
246 &new_inputs
247 .iter()
248 .find(|i| try_pat_as_ident(&i.pat).as_ref() == Ok(&context))
249 .expect("context")
250 .ty,
251 );
252 let PathArguments::AngleBracketed(arguments) = context_ty
253 .segments
254 .last()
255 .expect("path segment")
256 .arguments
257 .clone()
258 else {
259 abort!(
260 context_ty.segments.last().expect("path segment").arguments,
261 "unexpected path segment"
262 );
263 };
264 let generics = arguments.args;
265
266 let mut had_lifetime = false;
267 let mut had_generic = None;
268
269 for i in generics {
270 match i {
271 syn::GenericArgument::Lifetime(l) => {
272 if &l != lifetime {
273 abort!(
274 l,
275 "expected `Context<{}>` but found `Context<{}>`",
276 quote! {#lifetime},
277 quote! {#l}
278 );
279 }
280
281 if had_lifetime {
282 abort!(
283 l,
284 "expected `Context<{}>` to have only one lifetime argument",
285 quote! {#lifetime}
286 );
287 }
288
289 had_lifetime = true;
290 }
291 syn::GenericArgument::Type(ref t) => {
292 if let Some(old) = had_generic.replace(t.clone()) {
293 abort!(t, "expected at most one type argument `Contex<{}>` but found `Context<{}, ..., {}>`", quote!{#old}, quote!{#old}, quote!{#t});
294 }
295 }
296 syn::GenericArgument::Const(c) => abort!(c, "unexpected const argument on `Context`"),
297 syn::GenericArgument::AssocType(a) => {
298 abort!(a, "unexpected associated type on `Context`")
299 }
300 syn::GenericArgument::AssocConst(c) => {
301 abort!(c, "unexpected associated const on `Context`")
302 }
303 syn::GenericArgument::Constraint(c) => abort!(c, "unexpected constraint on `Context`"),
304 g => abort!(g, "unexpected generic argument on `Context`"),
305 }
306 }
307
308 let inputs_dereffed = new_inputs
309 .iter()
310 .map(|i@PatType { attrs, pat, colon_token, ty }| {
311 if try_pat_as_ident(&i.pat).as_ref() == Ok(&context) {
312 return i.clone();
313 }
314
315 let dereffed_ty = deref_type(ty);
316 parse_quote!(
317 #(#attrs)* #pat #colon_token #dereffed_ty
318 )
319 })
320 .collect();
321
322 QueryInputs {
323 inputs_dereffed,
324 inputs_without_context,
325 input_dereffed_types_without_context,
326 context,
327 context_typaram: had_generic.clone(),
328 }
329}
330
331fn validate_sig(
332 Signature {
333 constness,
334 asyncness,
335 unsafety,
336 abi,
337 fn_token,
338 ident,
339 generics,
340 paren_token: _,
341 inputs,
342 variadic,
343 output,
344 }: Signature,
345) -> QuerySignature {
346 let marked_lifetime = generics.lifetimes().find(has_lt_attr);
347 let cx_lifetime = generics.lifetimes().find(|i| i.lifetime.ident == "cx");
348 let Some(query_lifetime) = marked_lifetime
349 .or(cx_lifetime)
350 .map(|i| &i.lifetime)
351 .cloned()
352 else {
353 abort!(
354 generics,
355 "expected `'cx` lifetime or lifetime marked with #[lt] in the generics list"
356 )
357 };
358
359 if let Some(i) = unsafety {
360 abort!(i, "queries can't be unsafe");
361 }
362 if let Some(i) = abi {
363 abort!(i, "queries can't have an explicit abi");
364 }
365 if let Some(i) = variadic {
366 abort!(i, "queries can't be variadic");
367 }
368
369 QuerySignature {
370 constness,
371 asyncness,
372 fn_token,
373 ident,
374 inputs: validate_inputs(inputs, &query_lifetime),
375 output: match &output {
376 ReturnType::Default => parse_quote!(()),
377 ReturnType::Type(_, ty) => *ty.clone(),
378 },
379 output_ref: match output {
380 ReturnType::Default => parse_quote_spanned! {output.span() => & #query_lifetime ()},
381 ReturnType::Type(_, ty) => parse_quote!{& #query_lifetime #ty},
382 },
383 query_lifetime,
384 }
385}
386
387enum QueryAttr {
388 Mode(QueryMode),
389}
390
391fn get_string(e: &Expr) -> String {
392 match e {
393 Expr::Lit(l) => match &l.lit {
394 syn::Lit::Str(s) => s.value(),
395 l => abort!(l, "expected string literal"),
396 },
397 e => abort!(e, "expected string literal"),
398 }
399}
400
401fn parse_rerun(s: &str, span: Span) -> QueryMode {
402 match s {
403 "always" => QueryMode::Always,
404 "generation" => QueryMode::Generation,
405 _ => abort!(
406 span,
407 "unknown query mode, expected `always` or `generation`"
408 ),
409 }
410}
411
412fn parse_attr(attr: &Attribute) -> Option<QueryAttr> {
413 match &attr.meta {
414 syn::Meta::Path(_) => None,
415 syn::Meta::List(ml) => if ml.path.is_ident("rerun") {
416 match ml.parse_args::<Ident>() {
417 Err(e) => abort!(ml, "{}", e),
418 Ok(i) => Some(QueryAttr::Mode(parse_rerun(&i.to_string(), i.span()))),
419 }
420 } else {
421 None
422 },
423 syn::Meta::NameValue(MetaNameValue { path, value, .. }) => {
424 if path.is_ident("rerun") {
425 Some(QueryAttr::Mode(parse_rerun(
426 &get_string(value),
427 value.span(),
428 )))
429 } else {
430 None
431 }
432 }
433 }
434}
435
436fn validate(
437 ItemFn {
438 attrs,
439 vis,
440 sig,
441 block,
442 }: ItemFn,
443) -> Query {
444 let mut mode = QueryMode::Cache;
445
446 for attr in attrs.iter().filter_map(parse_attr) {
447 match attr {
448 QueryAttr::Mode(m) => mode = m,
449 }
450 }
451
452 Query {
453 mode,
454 attrs,
455 vis,
456 sig: validate_sig(sig),
457 block,
458 }
459}
460
461fn get_path(attr: &Attribute) -> Option<&Path> {
462 match &attr.meta {
463 syn::Meta::Path(p) => Some(p),
465 syn::Meta::List(_) => None,
467 syn::Meta::NameValue(_) => None,
468 }
469}
470
471fn assert_simple_attr(attr: TokenStream, expected: &str) -> Result<(), syn::Error> {
472 let attrs = Parser::parse(Attribute::parse_outer, attr)?;
473 for i in attrs {
474 match i.meta {
475 syn::Meta::Path(_) => {}
477 syn::Meta::List(ml) if ml.path.is_ident(expected) => {
479 emit_error!(ml, "expected an attribute without parameters")
480 }
481 syn::Meta::NameValue(mnv) if mnv.path.is_ident(expected) => {
482 emit_error!(mnv, "expected an attribute without this value")
483 }
484 _ => {}
486 }
487 }
488 Ok(())
489}
490
491fn tuple_from_types<T: ToTokens>(types: &[T]) -> TokenStream2 {
492 match types {
493 [] => quote! {()},
494 [x] => quote! {(#x,)},
495 x => quote! {(#(#x),*)},
496 }
497}
498
499fn or_unit(ty: &Option<Type>) -> TokenStream2 {
500 match ty {
501 Some(ty) => quote! {#ty},
502 None => quote! {()},
503 }
504}
505
/// Rerun mode of a query.
///
/// It is chosen with `#[rerun(always)]` or `#[rerun(generation)]` and is
/// `Cache` when no `rerun` attribute is present. Each variant is emitted as the
/// same-named variant of `incremental_query::QueryMode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum QueryMode {
    /// Selected by `rerun(always)` / `rerun = "always"`.
    Always,
    /// Selected by `rerun(generation)` / `rerun = "generation"`.
    Generation,
    /// The default when no `rerun` attribute is given.
    Cache,
}
512
513#[proc_macro_error]
514#[proc_macro_attribute]
515pub fn query(attr: TokenStream, item: TokenStream) -> TokenStream {
516 let input = parse_macro_input!(item as ItemFn);
517 if let Err(e) = assert_simple_attr(attr, "query") {
518 return e.to_compile_error().into();
519 }
520
521 let Query {
522 mode,
523 attrs,
524 vis,
525 sig:
526 QuerySignature {
527 constness,
528 asyncness,
529 fn_token,
530 ident,
531 inputs:
532 QueryInputs {
533 inputs_dereffed,
534 context,
535 inputs_without_context,
536 input_dereffed_types_without_context,
537 context_typaram,
538 },
539 output,
540 output_ref,
541 query_lifetime,
542 },
543 block,
544 } = validate(input);
545
546 let query = quote! {incremental_query::Query};
547 let erased_query_run = quote! {incremental_query::ErasedQueryRun};
548 let input_type = tuple_from_types(&input_dereffed_types_without_context);
549 let input_type_dereffed = tuple_from_types(&input_dereffed_types_without_context);
550
551 let string_ident = ident.to_string();
552 let data_ty = or_unit(&context_typaram);
553 let context_ty = quote! {incremental_query::Context};
554 let type_erased_query_param = quote! {incremental_query::TypeErasedQueryParam};
555 let mode_ident = quote! {incremental_query::QueryMode};
556
557 let param_names = tuple_from_types(
558 &inputs_without_context
559 .iter()
560 .map(|i| &i.pat)
561 .collect::<Vec<_>>(),
562 );
563
564 let mode_fn = match mode {
565 QueryMode::Always => quote! {
566 fn mode(&self) -> #mode_ident {
567 #mode_ident::Always
568 }
569 },
570 QueryMode::Generation => quote! {
571 fn mode(&self) -> #mode_ident {
572 #mode_ident::Generation
573 }
574 },
575 QueryMode::Cache => quote! {
576 fn mode(&self) -> #mode_ident {
577 #mode_ident::Cache
578 }
579 },
580 };
581
582 quote! {
583 #(#attrs)*
584 #vis #constness #asyncness #fn_token #ident <#query_lifetime> (#(#inputs_dereffed),*) -> #output_ref {
585 #[derive(Copy, Clone)]
586 struct Q;
587
588 impl<#query_lifetime> #query<#query_lifetime, #data_ty> for Q {
589 type Input = #input_type;
590 type Output = #output;
591
592 const NAME: &'static str = #string_ident;
593
594 fn get_run_fn() -> #erased_query_run<#data_ty> {
595 fn run<'cx>(
596 cx: &#context_ty<'cx, #data_ty>,
597 input: #type_erased_query_param<'cx>,
598 should_alloc: &dyn Fn(u128) -> bool,
599 ) -> (Option<#type_erased_query_param<'cx>>, u128)
600 {
601 let input: &#input_type_dereffed = unsafe{input.get_ref()};
602 let output = <Q as #query<'cx, #data_ty>>::run(cx, input);
603
604 let output_hash = cx.hash(Q, &output);
605 if should_alloc(output_hash) {
606 (Some(#type_erased_query_param::new(cx.storage.alloc(output))), output_hash)
607 } else {
608 (None, output_hash)
609 }
610 }
611
612 run
613 }
614
615 #mode_fn
616
617 fn run(#context: &#context_ty<#query_lifetime, #data_ty>, #param_names: &Self::Input) -> Self::Output #block
618 }
619
620 #context.query(Q, #param_names)
621 }
622
623 }
624 .into()
625}
626
627#[proc_macro_attribute]
633#[proc_macro_error]
634pub fn rerun(_attr: TokenStream, item: TokenStream) -> TokenStream {
635 item
636}