1use core::fmt;
4use std::vec;
5
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use queries::{
9 GeneratedInputStruct, InputQuery, InputSetter, InputSetterWithDurability, Intern, Lookup,
10 Queries, SetterKind, TrackedQuery, Transparent,
11};
12use quote::{ToTokens, format_ident, quote};
13use syn::spanned::Spanned;
14use syn::visit_mut::VisitMut;
15use syn::{Attribute, FnArg, ItemTrait, Path, TraitItem, TraitItemFn, parse_quote};
16
17mod queries;
18
19#[proc_macro_attribute]
20pub fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
21 match query_group_impl(args, input.clone()) {
22 Ok(tokens) => tokens,
23 Err(e) => token_stream_with_error(input, e),
24 }
25}
26
27#[derive(Debug)]
28struct InputStructField {
29 name: proc_macro2::TokenStream,
30 ty: proc_macro2::TokenStream,
31}
32
33impl fmt::Display for InputStructField {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 write!(f, "{}", self.name)
36 }
37}
38
39struct SalsaAttr {
40 name: String,
41 tts: TokenStream,
42 span: Span,
43}
44
45impl std::fmt::Debug for SalsaAttr {
46 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 write!(fmt, "{:?}", self.name)
48 }
49}
50
51impl TryFrom<syn::Attribute> for SalsaAttr {
52 type Error = syn::Attribute;
53
54 fn try_from(attr: syn::Attribute) -> Result<SalsaAttr, syn::Attribute> {
55 if is_not_salsa_attr_path(attr.path()) {
56 return Err(attr);
57 }
58
59 let span = attr.span();
60
61 let name = attr.path().segments[1].ident.to_string();
62 let tts = match attr.meta {
63 syn::Meta::Path(path) => path.into_token_stream(),
64 syn::Meta::List(ref list) => {
65 let tts = list
66 .into_token_stream()
67 .into_iter()
68 .skip(attr.path().to_token_stream().into_iter().count());
69 proc_macro2::TokenStream::from_iter(tts)
70 }
71 syn::Meta::NameValue(nv) => nv.into_token_stream(),
72 }
73 .into();
74
75 Ok(SalsaAttr { name, tts, span })
76 }
77}
78
79fn is_not_salsa_attr_path(path: &syn::Path) -> bool {
80 path.segments.first().map(|s| s.ident != "salsa").unwrap_or(true) || path.segments.len() != 2
81}
82
83fn filter_attrs(attrs: Vec<Attribute>) -> (Vec<Attribute>, Vec<SalsaAttr>) {
84 let mut other = vec![];
85 let mut salsa = vec![];
86 for attr in attrs {
90 match SalsaAttr::try_from(attr) {
91 Ok(it) => salsa.push(it),
92 Err(it) => other.push(it),
93 }
94 }
95 (other, salsa)
96}
97
/// How a query method in a `#[query_group]` trait is implemented.
#[derive(Debug, Clone, PartialEq, Eq)]
enum QueryKind {
    /// `#[salsa::input]`: value stored in the generated data struct, with setters.
    Input,
    /// The default when no kind-changing attribute is present.
    Tracked,
    /// `#[salsa::invoke_actual(path)]`: tracked query without the generated data struct.
    TrackedWithSalsaStruct,
    /// `#[salsa::transparent]`.
    Transparent,
    /// `#[salsa::interned]`: generates an intern method plus a lookup method.
    Interned,
}
106
107pub(crate) fn query_group_impl(
108 _args: proc_macro::TokenStream,
109 input: proc_macro::TokenStream,
110) -> Result<proc_macro::TokenStream, syn::Error> {
111 let mut item_trait = syn::parse::<ItemTrait>(input)?;
112
113 let supertraits = &item_trait.supertraits;
114
115 let db_attr: Attribute = parse_quote! {
116 #[salsa::db]
117 };
118 item_trait.attrs.push(db_attr);
119
120 let trait_name_ident = &item_trait.ident.clone();
121 let input_struct_name = format_ident!("{}Data", trait_name_ident);
122 let create_data_ident = format_ident!("create_data_{}", trait_name_ident);
123
124 let mut input_struct_fields: Vec<InputStructField> = vec![];
125 let mut trait_methods = vec![];
126 let mut setter_trait_methods = vec![];
127 let mut lookup_signatures = vec![];
128 let mut lookup_methods = vec![];
129
130 for item in item_trait.clone().items {
131 if let syn::TraitItem::Fn(method) = item {
132 let method_name = &method.sig.ident;
133 let signature = &method.sig.clone();
134
135 let (_attrs, salsa_attrs) = filter_attrs(method.attrs);
136
137 let mut query_kind = QueryKind::Tracked;
138 let mut invoke = None;
139 let mut cycle = None;
140 let mut interned_struct_path = None;
141 let mut lru = None;
142
143 let params: Vec<FnArg> = signature.inputs.clone().into_iter().collect();
144 let pat_and_tys = params
145 .into_iter()
146 .filter(|fn_arg| matches!(fn_arg, FnArg::Typed(_)))
147 .map(|fn_arg| match fn_arg {
148 FnArg::Typed(pat_type) => pat_type.clone(),
149 FnArg::Receiver(_) => unreachable!("this should have been filtered out"),
150 })
151 .collect::<Vec<syn::PatType>>();
152
153 for SalsaAttr { name, tts, span } in salsa_attrs {
154 match name.as_str() {
155 "cycle" => {
156 let path = syn::parse::<Parenthesized<Path>>(tts)?;
157 cycle = Some(path.0.clone())
158 }
159 "input" => {
160 if !pat_and_tys.is_empty() {
161 return Err(syn::Error::new(
162 span,
163 "input methods cannot have a parameter",
164 ));
165 }
166 query_kind = QueryKind::Input;
167 }
168 "interned" => {
169 let syn::ReturnType::Type(_, ty) = &signature.output else {
170 return Err(syn::Error::new(
171 span,
172 "interned queries must have return type",
173 ));
174 };
175 let syn::Type::Path(path) = &**ty else {
176 return Err(syn::Error::new(
177 span,
178 "interned queries must have return type",
179 ));
180 };
181 interned_struct_path = Some(path.path.clone());
182 query_kind = QueryKind::Interned;
183 }
184 "invoke" => {
185 let path = syn::parse::<Parenthesized<Path>>(tts)?;
186 invoke = Some(path.0.clone());
187 }
188 "invoke_actual" => {
189 let path = syn::parse::<Parenthesized<Path>>(tts)?;
190 invoke = Some(path.0.clone());
191 query_kind = QueryKind::TrackedWithSalsaStruct;
192 }
193 "lru" => {
194 let lru_count = syn::parse::<Parenthesized<syn::LitInt>>(tts)?;
195 let lru_count = lru_count.0.base10_parse::<u32>()?;
196
197 lru = Some(lru_count);
198 }
199 "transparent" => {
200 query_kind = QueryKind::Transparent;
201 }
202 _ => return Err(syn::Error::new(span, format!("unknown attribute `{name}`"))),
203 }
204 }
205
206 let syn::ReturnType::Type(_, return_ty) = signature.output.clone() else {
207 return Err(syn::Error::new(signature.span(), "Queries must have a return type"));
208 };
209
210 if let syn::Type::Path(ref ty_path) = *return_ty {
211 if matches!(query_kind, QueryKind::Input) {
212 let field = InputStructField {
213 name: method_name.to_token_stream(),
214 ty: ty_path.path.to_token_stream(),
215 };
216
217 input_struct_fields.push(field);
218 }
219 }
220
221 match (query_kind, invoke) {
222 (QueryKind::Input, None) => {
224 let query = InputQuery {
225 signature: method.sig.clone(),
226 create_data_ident: create_data_ident.clone(),
227 };
228 let value = Queries::InputQuery(query);
229 trait_methods.push(value);
230
231 let setter = InputSetter {
232 signature: method.sig.clone(),
233 return_type: *return_ty.clone(),
234 create_data_ident: create_data_ident.clone(),
235 };
236 setter_trait_methods.push(SetterKind::Plain(setter));
237
238 let setter = InputSetterWithDurability {
239 signature: method.sig.clone(),
240 return_type: *return_ty.clone(),
241 create_data_ident: create_data_ident.clone(),
242 };
243 setter_trait_methods.push(SetterKind::WithDurability(setter));
244 }
245 (QueryKind::Interned, None) => {
246 let interned_struct_path = interned_struct_path.unwrap();
247 let method = Intern {
248 signature: signature.clone(),
249 pat_and_tys: pat_and_tys.clone(),
250 interned_struct_path: interned_struct_path.clone(),
251 };
252
253 trait_methods.push(Queries::Intern(method));
254
255 let mut method = Lookup {
256 signature: signature.clone(),
257 pat_and_tys: pat_and_tys.clone(),
258 return_ty: *return_ty,
259 interned_struct_path,
260 };
261 method.prepare_signature();
262
263 lookup_signatures
264 .push(TraitItem::Fn(make_trait_method(method.signature.clone())));
265 lookup_methods.push(method);
266 }
267 (QueryKind::Tracked, invoke) => {
269 let method = TrackedQuery {
270 trait_name: trait_name_ident.clone(),
271 generated_struct: Some(GeneratedInputStruct {
272 input_struct_name: input_struct_name.clone(),
273 create_data_ident: create_data_ident.clone(),
274 }),
275 signature: signature.clone(),
276 pat_and_tys: pat_and_tys.clone(),
277 invoke,
278 cycle,
279 lru,
280 };
281
282 trait_methods.push(Queries::TrackedQuery(method));
283 }
284 (QueryKind::TrackedWithSalsaStruct, Some(invoke)) => {
285 let method = TrackedQuery {
286 trait_name: trait_name_ident.clone(),
287 generated_struct: None,
288 signature: signature.clone(),
289 pat_and_tys: pat_and_tys.clone(),
290 invoke: Some(invoke),
291 cycle,
292 lru,
293 };
294
295 trait_methods.push(Queries::TrackedQuery(method))
296 }
297 (QueryKind::TrackedWithSalsaStruct, None) => unreachable!(),
300 (QueryKind::Transparent, invoke) => {
301 let method = Transparent {
302 signature: method.sig.clone(),
303 pat_and_tys: pat_and_tys.clone(),
304 invoke,
305 };
306 trait_methods.push(Queries::Transparent(method));
307 }
308 (QueryKind::Interned, Some(path)) => {
310 return Err(syn::Error::new(
311 path.span(),
312 "Interned queries cannot be used with an `#[invoke]`".to_string(),
313 ));
314 }
315 (QueryKind::Input, Some(path)) => {
316 return Err(syn::Error::new(
317 path.span(),
318 "Inputs cannot be used with an `#[invoke]`".to_string(),
319 ));
320 }
321 }
322 }
323 }
324
325 let fields = input_struct_fields
326 .into_iter()
327 .map(|input| {
328 let name = input.name;
329 let ret = input.ty;
330 quote! { #name: Option<#ret> }
331 })
332 .collect::<Vec<proc_macro2::TokenStream>>();
333
334 let input_struct = quote! {
335 #[salsa::input]
336 pub(crate) struct #input_struct_name {
337 #(#fields),*
338 }
339 };
340
341 let field_params = std::iter::repeat_n(quote! { None }, fields.len())
342 .collect::<Vec<proc_macro2::TokenStream>>();
343
344 let create_data_method = quote! {
345 #[allow(non_snake_case)]
346 #[salsa::tracked]
347 fn #create_data_ident(db: &dyn #trait_name_ident) -> #input_struct_name {
348 #input_struct_name::new(db, #(#field_params),*)
349 }
350 };
351
352 let mut setter_signatures = vec![];
353 let mut setter_methods = vec![];
354 for trait_item in setter_trait_methods
355 .iter()
356 .map(|method| method.to_token_stream())
357 .map(|tokens| syn::parse2::<syn::TraitItemFn>(tokens).unwrap())
358 {
359 let mut methods_sans_body = trait_item.clone();
360 methods_sans_body.default = None;
361 methods_sans_body.semi_token = Some(syn::Token));
362
363 setter_signatures.push(TraitItem::Fn(methods_sans_body));
364 setter_methods.push(TraitItem::Fn(trait_item));
365 }
366
367 item_trait.items.append(&mut setter_signatures);
368 item_trait.items.append(&mut lookup_signatures);
369
370 let trait_impl = quote! {
371 #[salsa::db]
372 impl<DB> #trait_name_ident for DB
373 where
374 DB: #supertraits,
375 {
376 #(#trait_methods)*
377
378 #(#setter_methods)*
379
380 #(#lookup_methods)*
381 }
382 };
383 RemoveAttrsFromTraitMethods.visit_item_trait_mut(&mut item_trait);
384
385 let out = quote! {
386 #item_trait
387
388 #trait_impl
389
390 #input_struct
391
392 #create_data_method
393 }
394 .into();
395
396 Ok(out)
397}
398
399pub(crate) struct Parenthesized<T>(pub(crate) T);
401
402impl<T> syn::parse::Parse for Parenthesized<T>
403where
404 T: syn::parse::Parse,
405{
406 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
407 let content;
408 syn::parenthesized!(content in input);
409 content.parse::<T>().map(Parenthesized)
410 }
411}
412
413fn make_trait_method(sig: syn::Signature) -> TraitItemFn {
414 TraitItemFn {
415 attrs: vec![],
416 sig: sig.clone(),
417 semi_token: Some(syn::Token)),
418 default: None,
419 }
420}
421
422struct RemoveAttrsFromTraitMethods;
423
424impl VisitMut for RemoveAttrsFromTraitMethods {
425 fn visit_item_trait_mut(&mut self, i: &mut syn::ItemTrait) {
426 for item in &mut i.items {
427 if let TraitItem::Fn(trait_item_fn) = item {
428 trait_item_fn.attrs = vec![];
429 }
430 }
431 }
432}
433
434pub(crate) fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
435 tokens.extend(TokenStream::from(error.into_compile_error()));
436 tokens
437}