ra_ap_query_group_macro/
lib.rs

1//! A macro that mimics the old Salsa-style `#[query_group]` macro.
2
3use core::fmt;
4use std::vec;
5
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use queries::{
9    GeneratedInputStruct, InputQuery, InputSetter, InputSetterWithDurability, Intern, Lookup,
10    Queries, SetterKind, TrackedQuery, Transparent,
11};
12use quote::{ToTokens, format_ident, quote};
13use syn::spanned::Spanned;
14use syn::visit_mut::VisitMut;
15use syn::{
16    Attribute, FnArg, ItemTrait, Path, TraitItem, TraitItemFn, parse_quote, parse_quote_spanned,
17};
18
19mod queries;
20
21#[proc_macro_attribute]
22pub fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
23    match query_group_impl(args, input.clone()) {
24        Ok(tokens) => tokens,
25        Err(e) => token_stream_with_error(input, e),
26    }
27}
28
29#[derive(Debug)]
30struct InputStructField {
31    name: proc_macro2::TokenStream,
32    ty: proc_macro2::TokenStream,
33}
34
35impl fmt::Display for InputStructField {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        write!(f, "{}", self.name)
38    }
39}
40
41struct SalsaAttr {
42    name: String,
43    tts: TokenStream,
44    span: Span,
45}
46
47impl std::fmt::Debug for SalsaAttr {
48    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        write!(fmt, "{:?}", self.name)
50    }
51}
52
53impl TryFrom<syn::Attribute> for SalsaAttr {
54    type Error = syn::Attribute;
55
56    fn try_from(attr: syn::Attribute) -> Result<SalsaAttr, syn::Attribute> {
57        if is_not_salsa_attr_path(attr.path()) {
58            return Err(attr);
59        }
60
61        let span = attr.span();
62
63        let name = attr.path().segments[1].ident.to_string();
64        let tts = match attr.meta {
65            syn::Meta::Path(path) => path.into_token_stream(),
66            syn::Meta::List(ref list) => {
67                let tts = list
68                    .into_token_stream()
69                    .into_iter()
70                    .skip(attr.path().to_token_stream().into_iter().count());
71                proc_macro2::TokenStream::from_iter(tts)
72            }
73            syn::Meta::NameValue(nv) => nv.into_token_stream(),
74        }
75        .into();
76
77        Ok(SalsaAttr { name, tts, span })
78    }
79}
80
81fn is_not_salsa_attr_path(path: &syn::Path) -> bool {
82    path.segments.first().map(|s| s.ident != "salsa").unwrap_or(true) || path.segments.len() != 2
83}
84
85fn filter_attrs(attrs: Vec<Attribute>) -> (Vec<Attribute>, Vec<SalsaAttr>) {
86    let mut other = vec![];
87    let mut salsa = vec![];
88    // Leave non-salsa attributes untouched. These are
89    // attributes that don't start with `salsa::` or don't have
90    // exactly two segments in their path.
91    for attr in attrs {
92        match SalsaAttr::try_from(attr) {
93            Ok(it) => salsa.push(it),
94            Err(it) => other.push(it),
95        }
96    }
97    (other, salsa)
98}
99
/// The flavour of query a trait method expands to, as chosen by its `#[salsa::*]` attributes.
#[derive(Debug, Clone, PartialEq, Eq)]
enum QueryKind {
    /// Selected by `#[salsa::input]`; the method becomes an input getter plus setters.
    Input,
    /// Selected by `#[salsa::invoke_interned(..)]`; a tracked query that is given the
    /// generated input struct.
    Tracked,
    /// The default kind (also selected by `#[salsa::invoke(..)]` and by `#[salsa::tracked]`
    /// on a method with a default body); a tracked query without the generated input struct.
    TrackedWithSalsaStruct,
    /// Selected by `#[salsa::transparent]`.
    Transparent,
    /// Selected by `#[salsa::interned]`; produces an intern method and a matching lookup.
    Interned,
}
108
109pub(crate) fn query_group_impl(
110    _args: proc_macro::TokenStream,
111    input: proc_macro::TokenStream,
112) -> Result<proc_macro::TokenStream, syn::Error> {
113    let mut item_trait = syn::parse::<ItemTrait>(input)?;
114
115    let supertraits = &item_trait.supertraits;
116
117    let db_attr: Attribute = parse_quote! {
118        #[salsa::db]
119    };
120    item_trait.attrs.push(db_attr);
121
122    let trait_name_ident = &item_trait.ident.clone();
123    let input_struct_name = format_ident!("{}Data", trait_name_ident);
124    let create_data_ident = format_ident!("create_data_{}", trait_name_ident);
125
126    let mut input_struct_fields: Vec<InputStructField> = vec![];
127    let mut trait_methods = vec![];
128    let mut setter_trait_methods = vec![];
129    let mut lookup_signatures = vec![];
130    let mut lookup_methods = vec![];
131
132    for item in &mut item_trait.items {
133        if let syn::TraitItem::Fn(method) = item {
134            let method_name = &method.sig.ident;
135            let signature = &method.sig;
136
137            let (_attrs, salsa_attrs) = filter_attrs(method.attrs.clone());
138
139            let mut query_kind = QueryKind::TrackedWithSalsaStruct;
140            let mut invoke = None;
141            let mut cycle = None;
142            let mut interned_struct_path = None;
143            let mut lru = None;
144
145            let params: Vec<FnArg> = signature.inputs.clone().into_iter().collect();
146            let pat_and_tys = params
147                .into_iter()
148                .filter(|fn_arg| matches!(fn_arg, FnArg::Typed(_)))
149                .map(|fn_arg| match fn_arg {
150                    FnArg::Typed(pat_type) => pat_type.clone(),
151                    FnArg::Receiver(_) => unreachable!("this should have been filtered out"),
152                })
153                .collect::<Vec<syn::PatType>>();
154
155            for SalsaAttr { name, tts, span } in salsa_attrs {
156                match name.as_str() {
157                    "cycle" => {
158                        let path = syn::parse::<Parenthesized<Path>>(tts)?;
159                        cycle = Some(path.0.clone())
160                    }
161                    "input" => {
162                        if !pat_and_tys.is_empty() {
163                            return Err(syn::Error::new(
164                                span,
165                                "input methods cannot have a parameter",
166                            ));
167                        }
168                        query_kind = QueryKind::Input;
169                    }
170                    "interned" => {
171                        let syn::ReturnType::Type(_, ty) = &signature.output else {
172                            return Err(syn::Error::new(
173                                span,
174                                "interned queries must have return type",
175                            ));
176                        };
177                        let syn::Type::Path(path) = &**ty else {
178                            return Err(syn::Error::new(
179                                span,
180                                "interned queries must have return type",
181                            ));
182                        };
183                        interned_struct_path = Some(path.path.clone());
184                        query_kind = QueryKind::Interned;
185                    }
186                    "invoke_interned" => {
187                        let path = syn::parse::<Parenthesized<Path>>(tts)?;
188                        invoke = Some(path.0.clone());
189                        query_kind = QueryKind::Tracked;
190                    }
191                    "invoke" => {
192                        let path = syn::parse::<Parenthesized<Path>>(tts)?;
193                        invoke = Some(path.0.clone());
194                        if query_kind != QueryKind::Transparent {
195                            query_kind = QueryKind::TrackedWithSalsaStruct;
196                        }
197                    }
198                    "tracked" if method.default.is_some() => {
199                        query_kind = QueryKind::TrackedWithSalsaStruct;
200                    }
201                    "lru" => {
202                        let lru_count = syn::parse::<Parenthesized<syn::LitInt>>(tts)?;
203                        let lru_count = lru_count.0.base10_parse::<u32>()?;
204
205                        lru = Some(lru_count);
206                    }
207                    "transparent" => {
208                        query_kind = QueryKind::Transparent;
209                    }
210                    _ => return Err(syn::Error::new(span, format!("unknown attribute `{name}`"))),
211                }
212            }
213
214            let syn::ReturnType::Type(_, return_ty) = signature.output.clone() else {
215                return Err(syn::Error::new(signature.span(), "Queries must have a return type"));
216            };
217
218            if let syn::Type::Path(ref ty_path) = *return_ty {
219                if matches!(query_kind, QueryKind::Input) {
220                    let field = InputStructField {
221                        name: method_name.to_token_stream(),
222                        ty: ty_path.path.to_token_stream(),
223                    };
224
225                    input_struct_fields.push(field);
226                }
227            }
228
229            if let Some(block) = &mut method.default {
230                SelfToDbRewriter.visit_block_mut(block);
231            }
232
233            match (query_kind, invoke) {
234                // input
235                (QueryKind::Input, None) => {
236                    let query = InputQuery {
237                        signature: method.sig.clone(),
238                        create_data_ident: create_data_ident.clone(),
239                    };
240                    let value = Queries::InputQuery(query);
241                    trait_methods.push(value);
242
243                    let setter = InputSetter {
244                        signature: method.sig.clone(),
245                        return_type: *return_ty.clone(),
246                        create_data_ident: create_data_ident.clone(),
247                    };
248                    setter_trait_methods.push(SetterKind::Plain(setter));
249
250                    let setter = InputSetterWithDurability {
251                        signature: method.sig.clone(),
252                        return_type: *return_ty.clone(),
253                        create_data_ident: create_data_ident.clone(),
254                    };
255                    setter_trait_methods.push(SetterKind::WithDurability(setter));
256                }
257                (QueryKind::Interned, None) => {
258                    let interned_struct_path = interned_struct_path.unwrap();
259                    let method = Intern {
260                        signature: signature.clone(),
261                        pat_and_tys: pat_and_tys.clone(),
262                        interned_struct_path: interned_struct_path.clone(),
263                    };
264
265                    trait_methods.push(Queries::Intern(method));
266
267                    let mut method = Lookup {
268                        signature: signature.clone(),
269                        pat_and_tys: pat_and_tys.clone(),
270                        return_ty: *return_ty,
271                        interned_struct_path,
272                    };
273                    method.prepare_signature();
274
275                    lookup_signatures
276                        .push(TraitItem::Fn(make_trait_method(method.signature.clone())));
277                    lookup_methods.push(method);
278                }
279                // tracked function. it might have an invoke, or might not.
280                (QueryKind::Tracked, invoke) => {
281                    let method = TrackedQuery {
282                        trait_name: trait_name_ident.clone(),
283                        generated_struct: Some(GeneratedInputStruct {
284                            input_struct_name: input_struct_name.clone(),
285                            create_data_ident: create_data_ident.clone(),
286                        }),
287                        signature: signature.clone(),
288                        pat_and_tys: pat_and_tys.clone(),
289                        invoke,
290                        cycle,
291                        lru,
292                        default: method.default.take(),
293                    };
294
295                    trait_methods.push(Queries::TrackedQuery(method));
296                }
297                (QueryKind::TrackedWithSalsaStruct, invoke) => {
298                    let method = TrackedQuery {
299                        trait_name: trait_name_ident.clone(),
300                        generated_struct: None,
301                        signature: signature.clone(),
302                        pat_and_tys: pat_and_tys.clone(),
303                        invoke,
304                        cycle,
305                        lru,
306                        default: method.default.take(),
307                    };
308
309                    trait_methods.push(Queries::TrackedQuery(method))
310                }
311                (QueryKind::Transparent, invoke) => {
312                    let method = Transparent {
313                        signature: method.sig.clone(),
314                        pat_and_tys: pat_and_tys.clone(),
315                        invoke,
316                        default: method.default.take(),
317                    };
318                    trait_methods.push(Queries::Transparent(method));
319                }
320                // error/invalid constructions
321                (QueryKind::Interned, Some(path)) => {
322                    return Err(syn::Error::new(
323                        path.span(),
324                        "Interned queries cannot be used with an `#[invoke]`".to_string(),
325                    ));
326                }
327                (QueryKind::Input, Some(path)) => {
328                    return Err(syn::Error::new(
329                        path.span(),
330                        "Inputs cannot be used with an `#[invoke]`".to_string(),
331                    ));
332                }
333            }
334        }
335    }
336
337    let fields = input_struct_fields
338        .into_iter()
339        .map(|input| {
340            let name = input.name;
341            let ret = input.ty;
342            quote! { #name: Option<#ret> }
343        })
344        .collect::<Vec<proc_macro2::TokenStream>>();
345
346    let input_struct = quote! {
347        #[salsa::input]
348        pub(crate) struct #input_struct_name {
349            #(#fields),*
350        }
351    };
352
353    let field_params = std::iter::repeat_n(quote! { None }, fields.len())
354        .collect::<Vec<proc_macro2::TokenStream>>();
355
356    let create_data_method = quote! {
357        #[allow(non_snake_case)]
358        #[salsa::tracked]
359        fn #create_data_ident(db: &dyn #trait_name_ident) -> #input_struct_name {
360            #input_struct_name::new(db, #(#field_params),*)
361        }
362    };
363
364    let mut setter_signatures = vec![];
365    let mut setter_methods = vec![];
366    for trait_item in setter_trait_methods
367        .iter()
368        .map(|method| method.to_token_stream())
369        .map(|tokens| syn::parse2::<syn::TraitItemFn>(tokens).unwrap())
370    {
371        let mut methods_sans_body = trait_item.clone();
372        methods_sans_body.default = None;
373        methods_sans_body.semi_token = Some(syn::Token![;](trait_item.span()));
374
375        setter_signatures.push(TraitItem::Fn(methods_sans_body));
376        setter_methods.push(TraitItem::Fn(trait_item));
377    }
378
379    item_trait.items.append(&mut setter_signatures);
380    item_trait.items.append(&mut lookup_signatures);
381
382    let trait_impl = quote! {
383        #[salsa::db]
384        impl<DB> #trait_name_ident for DB
385        where
386            DB: #supertraits,
387        {
388            #(#trait_methods)*
389
390            #(#setter_methods)*
391
392            #(#lookup_methods)*
393        }
394    };
395    RemoveAttrsFromTraitMethods.visit_item_trait_mut(&mut item_trait);
396
397    let out = quote! {
398        #item_trait
399
400        #trait_impl
401
402        #input_struct
403
404        #create_data_method
405    }
406    .into();
407
408    Ok(out)
409}
410
411/// Parenthesis helper
412pub(crate) struct Parenthesized<T>(pub(crate) T);
413
414impl<T> syn::parse::Parse for Parenthesized<T>
415where
416    T: syn::parse::Parse,
417{
418    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
419        let content;
420        syn::parenthesized!(content in input);
421        content.parse::<T>().map(Parenthesized)
422    }
423}
424
425fn make_trait_method(sig: syn::Signature) -> TraitItemFn {
426    TraitItemFn {
427        attrs: vec![],
428        sig: sig.clone(),
429        semi_token: Some(syn::Token![;](sig.span())),
430        default: None,
431    }
432}
433
434struct RemoveAttrsFromTraitMethods;
435
436impl VisitMut for RemoveAttrsFromTraitMethods {
437    fn visit_item_trait_mut(&mut self, i: &mut syn::ItemTrait) {
438        for item in &mut i.items {
439            if let TraitItem::Fn(trait_item_fn) = item {
440                trait_item_fn.attrs = vec![];
441            }
442        }
443    }
444}
445
446pub(crate) fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
447    tokens.extend(TokenStream::from(error.into_compile_error()));
448    tokens
449}
450
451struct SelfToDbRewriter;
452
453impl VisitMut for SelfToDbRewriter {
454    fn visit_expr_path_mut(&mut self, i: &mut syn::ExprPath) {
455        if i.path.is_ident("self") {
456            i.path = parse_quote_spanned!(i.path.span() => db);
457        }
458    }
459}