ra_ap_query_group_macro/
lib.rs

1//! A macro that mimics the old Salsa-style `#[query_group]` macro.
2
3use core::fmt;
4use std::vec;
5
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use queries::{
9    GeneratedInputStruct, InputQuery, InputSetter, InputSetterWithDurability, Intern, Lookup,
10    Queries, SetterKind, TrackedQuery, Transparent,
11};
12use quote::{ToTokens, format_ident, quote};
13use syn::spanned::Spanned;
14use syn::visit_mut::VisitMut;
15use syn::{Attribute, FnArg, ItemTrait, Path, TraitItem, TraitItemFn, parse_quote};
16
17mod queries;
18
19#[proc_macro_attribute]
20pub fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
21    match query_group_impl(args, input.clone()) {
22        Ok(tokens) => tokens,
23        Err(e) => token_stream_with_error(input, e),
24    }
25}
26
27#[derive(Debug)]
28struct InputStructField {
29    name: proc_macro2::TokenStream,
30    ty: proc_macro2::TokenStream,
31}
32
33impl fmt::Display for InputStructField {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        write!(f, "{}", self.name)
36    }
37}
38
39struct SalsaAttr {
40    name: String,
41    tts: TokenStream,
42    span: Span,
43}
44
45impl std::fmt::Debug for SalsaAttr {
46    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        write!(fmt, "{:?}", self.name)
48    }
49}
50
51impl TryFrom<syn::Attribute> for SalsaAttr {
52    type Error = syn::Attribute;
53
54    fn try_from(attr: syn::Attribute) -> Result<SalsaAttr, syn::Attribute> {
55        if is_not_salsa_attr_path(attr.path()) {
56            return Err(attr);
57        }
58
59        let span = attr.span();
60
61        let name = attr.path().segments[1].ident.to_string();
62        let tts = match attr.meta {
63            syn::Meta::Path(path) => path.into_token_stream(),
64            syn::Meta::List(ref list) => {
65                let tts = list
66                    .into_token_stream()
67                    .into_iter()
68                    .skip(attr.path().to_token_stream().into_iter().count());
69                proc_macro2::TokenStream::from_iter(tts)
70            }
71            syn::Meta::NameValue(nv) => nv.into_token_stream(),
72        }
73        .into();
74
75        Ok(SalsaAttr { name, tts, span })
76    }
77}
78
79fn is_not_salsa_attr_path(path: &syn::Path) -> bool {
80    path.segments.first().map(|s| s.ident != "salsa").unwrap_or(true) || path.segments.len() != 2
81}
82
83fn filter_attrs(attrs: Vec<Attribute>) -> (Vec<Attribute>, Vec<SalsaAttr>) {
84    let mut other = vec![];
85    let mut salsa = vec![];
86    // Leave non-salsa attributes untouched. These are
87    // attributes that don't start with `salsa::` or don't have
88    // exactly two segments in their path.
89    for attr in attrs {
90        match SalsaAttr::try_from(attr) {
91            Ok(it) => salsa.push(it),
92            Err(it) => other.push(it),
93        }
94    }
95    (other, salsa)
96}
97
/// How a query method is lowered.
///
/// The kind is selected from the method's `#[salsa::...]` attributes in `query_group_impl`,
/// which processes them in order, so the last kind-selecting attribute wins.
#[derive(Clone, Debug, Eq, PartialEq)]
enum QueryKind {
    /// `#[salsa::input]`: stored value with generated setters.
    Input,
    /// `#[salsa::interned]`: intern + lookup pair for the returned struct path.
    Interned,
    /// `#[salsa::transparent]`.
    Transparent,
    /// Default when no kind-selecting attribute is present.
    Tracked,
    /// `#[salsa::invoke_actual(path)]`: tracked query lowered without the generated input
    /// struct.
    TrackedWithSalsaStruct,
}
106
107pub(crate) fn query_group_impl(
108    _args: proc_macro::TokenStream,
109    input: proc_macro::TokenStream,
110) -> Result<proc_macro::TokenStream, syn::Error> {
111    let mut item_trait = syn::parse::<ItemTrait>(input)?;
112
113    let supertraits = &item_trait.supertraits;
114
115    let db_attr: Attribute = parse_quote! {
116        #[salsa::db]
117    };
118    item_trait.attrs.push(db_attr);
119
120    let trait_name_ident = &item_trait.ident.clone();
121    let input_struct_name = format_ident!("{}Data", trait_name_ident);
122    let create_data_ident = format_ident!("create_data_{}", trait_name_ident);
123
124    let mut input_struct_fields: Vec<InputStructField> = vec![];
125    let mut trait_methods = vec![];
126    let mut setter_trait_methods = vec![];
127    let mut lookup_signatures = vec![];
128    let mut lookup_methods = vec![];
129
130    for item in item_trait.clone().items {
131        if let syn::TraitItem::Fn(method) = item {
132            let method_name = &method.sig.ident;
133            let signature = &method.sig.clone();
134
135            let (_attrs, salsa_attrs) = filter_attrs(method.attrs);
136
137            let mut query_kind = QueryKind::Tracked;
138            let mut invoke = None;
139            let mut cycle = None;
140            let mut interned_struct_path = None;
141            let mut lru = None;
142
143            let params: Vec<FnArg> = signature.inputs.clone().into_iter().collect();
144            let pat_and_tys = params
145                .into_iter()
146                .filter(|fn_arg| matches!(fn_arg, FnArg::Typed(_)))
147                .map(|fn_arg| match fn_arg {
148                    FnArg::Typed(pat_type) => pat_type.clone(),
149                    FnArg::Receiver(_) => unreachable!("this should have been filtered out"),
150                })
151                .collect::<Vec<syn::PatType>>();
152
153            for SalsaAttr { name, tts, span } in salsa_attrs {
154                match name.as_str() {
155                    "cycle" => {
156                        let path = syn::parse::<Parenthesized<Path>>(tts)?;
157                        cycle = Some(path.0.clone())
158                    }
159                    "input" => {
160                        if !pat_and_tys.is_empty() {
161                            return Err(syn::Error::new(
162                                span,
163                                "input methods cannot have a parameter",
164                            ));
165                        }
166                        query_kind = QueryKind::Input;
167                    }
168                    "interned" => {
169                        let syn::ReturnType::Type(_, ty) = &signature.output else {
170                            return Err(syn::Error::new(
171                                span,
172                                "interned queries must have return type",
173                            ));
174                        };
175                        let syn::Type::Path(path) = &**ty else {
176                            return Err(syn::Error::new(
177                                span,
178                                "interned queries must have return type",
179                            ));
180                        };
181                        interned_struct_path = Some(path.path.clone());
182                        query_kind = QueryKind::Interned;
183                    }
184                    "invoke" => {
185                        let path = syn::parse::<Parenthesized<Path>>(tts)?;
186                        invoke = Some(path.0.clone());
187                    }
188                    "invoke_actual" => {
189                        let path = syn::parse::<Parenthesized<Path>>(tts)?;
190                        invoke = Some(path.0.clone());
191                        query_kind = QueryKind::TrackedWithSalsaStruct;
192                    }
193                    "lru" => {
194                        let lru_count = syn::parse::<Parenthesized<syn::LitInt>>(tts)?;
195                        let lru_count = lru_count.0.base10_parse::<u32>()?;
196
197                        lru = Some(lru_count);
198                    }
199                    "transparent" => {
200                        query_kind = QueryKind::Transparent;
201                    }
202                    _ => return Err(syn::Error::new(span, format!("unknown attribute `{name}`"))),
203                }
204            }
205
206            let syn::ReturnType::Type(_, return_ty) = signature.output.clone() else {
207                return Err(syn::Error::new(signature.span(), "Queries must have a return type"));
208            };
209
210            if let syn::Type::Path(ref ty_path) = *return_ty {
211                if matches!(query_kind, QueryKind::Input) {
212                    let field = InputStructField {
213                        name: method_name.to_token_stream(),
214                        ty: ty_path.path.to_token_stream(),
215                    };
216
217                    input_struct_fields.push(field);
218                }
219            }
220
221            match (query_kind, invoke) {
222                // input
223                (QueryKind::Input, None) => {
224                    let query = InputQuery {
225                        signature: method.sig.clone(),
226                        create_data_ident: create_data_ident.clone(),
227                    };
228                    let value = Queries::InputQuery(query);
229                    trait_methods.push(value);
230
231                    let setter = InputSetter {
232                        signature: method.sig.clone(),
233                        return_type: *return_ty.clone(),
234                        create_data_ident: create_data_ident.clone(),
235                    };
236                    setter_trait_methods.push(SetterKind::Plain(setter));
237
238                    let setter = InputSetterWithDurability {
239                        signature: method.sig.clone(),
240                        return_type: *return_ty.clone(),
241                        create_data_ident: create_data_ident.clone(),
242                    };
243                    setter_trait_methods.push(SetterKind::WithDurability(setter));
244                }
245                (QueryKind::Interned, None) => {
246                    let interned_struct_path = interned_struct_path.unwrap();
247                    let method = Intern {
248                        signature: signature.clone(),
249                        pat_and_tys: pat_and_tys.clone(),
250                        interned_struct_path: interned_struct_path.clone(),
251                    };
252
253                    trait_methods.push(Queries::Intern(method));
254
255                    let mut method = Lookup {
256                        signature: signature.clone(),
257                        pat_and_tys: pat_and_tys.clone(),
258                        return_ty: *return_ty,
259                        interned_struct_path,
260                    };
261                    method.prepare_signature();
262
263                    lookup_signatures
264                        .push(TraitItem::Fn(make_trait_method(method.signature.clone())));
265                    lookup_methods.push(method);
266                }
267                // tracked function. it might have an invoke, or might not.
268                (QueryKind::Tracked, invoke) => {
269                    let method = TrackedQuery {
270                        trait_name: trait_name_ident.clone(),
271                        generated_struct: Some(GeneratedInputStruct {
272                            input_struct_name: input_struct_name.clone(),
273                            create_data_ident: create_data_ident.clone(),
274                        }),
275                        signature: signature.clone(),
276                        pat_and_tys: pat_and_tys.clone(),
277                        invoke,
278                        cycle,
279                        lru,
280                    };
281
282                    trait_methods.push(Queries::TrackedQuery(method));
283                }
284                (QueryKind::TrackedWithSalsaStruct, Some(invoke)) => {
285                    let method = TrackedQuery {
286                        trait_name: trait_name_ident.clone(),
287                        generated_struct: None,
288                        signature: signature.clone(),
289                        pat_and_tys: pat_and_tys.clone(),
290                        invoke: Some(invoke),
291                        cycle,
292                        lru,
293                    };
294
295                    trait_methods.push(Queries::TrackedQuery(method))
296                }
297                // while it is possible to make this reachable, it's not really worthwhile for a migration aid.
298                // doing this would require attaching an attribute to the salsa struct parameter in the query.
299                (QueryKind::TrackedWithSalsaStruct, None) => unreachable!(),
300                (QueryKind::Transparent, invoke) => {
301                    let method = Transparent {
302                        signature: method.sig.clone(),
303                        pat_and_tys: pat_and_tys.clone(),
304                        invoke,
305                    };
306                    trait_methods.push(Queries::Transparent(method));
307                }
308                // error/invalid constructions
309                (QueryKind::Interned, Some(path)) => {
310                    return Err(syn::Error::new(
311                        path.span(),
312                        "Interned queries cannot be used with an `#[invoke]`".to_string(),
313                    ));
314                }
315                (QueryKind::Input, Some(path)) => {
316                    return Err(syn::Error::new(
317                        path.span(),
318                        "Inputs cannot be used with an `#[invoke]`".to_string(),
319                    ));
320                }
321            }
322        }
323    }
324
325    let fields = input_struct_fields
326        .into_iter()
327        .map(|input| {
328            let name = input.name;
329            let ret = input.ty;
330            quote! { #name: Option<#ret> }
331        })
332        .collect::<Vec<proc_macro2::TokenStream>>();
333
334    let input_struct = quote! {
335        #[salsa::input]
336        pub(crate) struct #input_struct_name {
337            #(#fields),*
338        }
339    };
340
341    let field_params = std::iter::repeat_n(quote! { None }, fields.len())
342        .collect::<Vec<proc_macro2::TokenStream>>();
343
344    let create_data_method = quote! {
345        #[allow(non_snake_case)]
346        #[salsa::tracked]
347        fn #create_data_ident(db: &dyn #trait_name_ident) -> #input_struct_name {
348            #input_struct_name::new(db, #(#field_params),*)
349        }
350    };
351
352    let mut setter_signatures = vec![];
353    let mut setter_methods = vec![];
354    for trait_item in setter_trait_methods
355        .iter()
356        .map(|method| method.to_token_stream())
357        .map(|tokens| syn::parse2::<syn::TraitItemFn>(tokens).unwrap())
358    {
359        let mut methods_sans_body = trait_item.clone();
360        methods_sans_body.default = None;
361        methods_sans_body.semi_token = Some(syn::Token![;](trait_item.span()));
362
363        setter_signatures.push(TraitItem::Fn(methods_sans_body));
364        setter_methods.push(TraitItem::Fn(trait_item));
365    }
366
367    item_trait.items.append(&mut setter_signatures);
368    item_trait.items.append(&mut lookup_signatures);
369
370    let trait_impl = quote! {
371        #[salsa::db]
372        impl<DB> #trait_name_ident for DB
373        where
374            DB: #supertraits,
375        {
376            #(#trait_methods)*
377
378            #(#setter_methods)*
379
380            #(#lookup_methods)*
381        }
382    };
383    RemoveAttrsFromTraitMethods.visit_item_trait_mut(&mut item_trait);
384
385    let out = quote! {
386        #item_trait
387
388        #trait_impl
389
390        #input_struct
391
392        #create_data_method
393    }
394    .into();
395
396    Ok(out)
397}
398
399/// Parenthesis helper
400pub(crate) struct Parenthesized<T>(pub(crate) T);
401
402impl<T> syn::parse::Parse for Parenthesized<T>
403where
404    T: syn::parse::Parse,
405{
406    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
407        let content;
408        syn::parenthesized!(content in input);
409        content.parse::<T>().map(Parenthesized)
410    }
411}
412
413fn make_trait_method(sig: syn::Signature) -> TraitItemFn {
414    TraitItemFn {
415        attrs: vec![],
416        sig: sig.clone(),
417        semi_token: Some(syn::Token![;](sig.span())),
418        default: None,
419    }
420}
421
422struct RemoveAttrsFromTraitMethods;
423
424impl VisitMut for RemoveAttrsFromTraitMethods {
425    fn visit_item_trait_mut(&mut self, i: &mut syn::ItemTrait) {
426        for item in &mut i.items {
427            if let TraitItem::Fn(trait_item_fn) = item {
428                trait_item_fn.attrs = vec![];
429            }
430        }
431    }
432}
433
434pub(crate) fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
435    tokens.extend(TokenStream::from(error.into_compile_error()));
436    tokens
437}