1use core::fmt;
4use std::vec;
5
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use queries::{
9 GeneratedInputStruct, InputQuery, InputSetter, InputSetterWithDurability, Intern, Lookup,
10 Queries, SetterKind, TrackedQuery, Transparent,
11};
12use quote::{ToTokens, format_ident, quote};
13use syn::spanned::Spanned;
14use syn::visit_mut::VisitMut;
15use syn::{
16 Attribute, FnArg, ItemTrait, Path, TraitItem, TraitItemFn, parse_quote, parse_quote_spanned,
17};
18
19mod queries;
20
21#[proc_macro_attribute]
22pub fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
23 match query_group_impl(args, input.clone()) {
24 Ok(tokens) => tokens,
25 Err(e) => token_stream_with_error(input, e),
26 }
27}
28
29#[derive(Debug)]
30struct InputStructField {
31 name: proc_macro2::TokenStream,
32 ty: proc_macro2::TokenStream,
33}
34
35impl fmt::Display for InputStructField {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 write!(f, "{}", self.name)
38 }
39}
40
41struct SalsaAttr {
42 name: String,
43 tts: TokenStream,
44 span: Span,
45}
46
47impl std::fmt::Debug for SalsaAttr {
48 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 write!(fmt, "{:?}", self.name)
50 }
51}
52
53impl TryFrom<syn::Attribute> for SalsaAttr {
54 type Error = syn::Attribute;
55
56 fn try_from(attr: syn::Attribute) -> Result<SalsaAttr, syn::Attribute> {
57 if is_not_salsa_attr_path(attr.path()) {
58 return Err(attr);
59 }
60
61 let span = attr.span();
62
63 let name = attr.path().segments[1].ident.to_string();
64 let tts = match attr.meta {
65 syn::Meta::Path(path) => path.into_token_stream(),
66 syn::Meta::List(ref list) => {
67 let tts = list
68 .into_token_stream()
69 .into_iter()
70 .skip(attr.path().to_token_stream().into_iter().count());
71 proc_macro2::TokenStream::from_iter(tts)
72 }
73 syn::Meta::NameValue(nv) => nv.into_token_stream(),
74 }
75 .into();
76
77 Ok(SalsaAttr { name, tts, span })
78 }
79}
80
81fn is_not_salsa_attr_path(path: &syn::Path) -> bool {
82 path.segments.first().map(|s| s.ident != "salsa").unwrap_or(true) || path.segments.len() != 2
83}
84
85fn filter_attrs(attrs: Vec<Attribute>) -> (Vec<Attribute>, Vec<SalsaAttr>) {
86 let mut other = vec![];
87 let mut salsa = vec![];
88 for attr in attrs {
92 match SalsaAttr::try_from(attr) {
93 Ok(it) => salsa.push(it),
94 Err(it) => other.push(it),
95 }
96 }
97 (other, salsa)
98}
99
/// How a query method of a `#[query_group]` trait is lowered, chosen from its
/// `#[salsa::...]` attributes.
#[derive(Clone, Debug, Eq, PartialEq)]
enum QueryKind {
    /// `#[salsa::input]`: a stored value with generated setters. It may not take parameters.
    Input,
    /// `#[salsa::invoke_interned(path)]`: a tracked query that also receives the
    /// generated `{Trait}Data` input struct.
    Tracked,
    /// The default. Also selected by `#[salsa::invoke(path)]`, or by `#[salsa::tracked]`
    /// on a method with a default body. A tracked query without the generated struct.
    TrackedWithSalsaStruct,
    /// `#[salsa::transparent]`: emitted through `queries::Transparent` rather than as a
    /// tracked query.
    Transparent,
    /// `#[salsa::interned]`: intern and lookup methods for the struct named by the return type.
    Interned,
}
108
109pub(crate) fn query_group_impl(
110 _args: proc_macro::TokenStream,
111 input: proc_macro::TokenStream,
112) -> Result<proc_macro::TokenStream, syn::Error> {
113 let mut item_trait = syn::parse::<ItemTrait>(input)?;
114
115 let supertraits = &item_trait.supertraits;
116
117 let db_attr: Attribute = parse_quote! {
118 #[salsa::db]
119 };
120 item_trait.attrs.push(db_attr);
121
122 let trait_name_ident = &item_trait.ident.clone();
123 let input_struct_name = format_ident!("{}Data", trait_name_ident);
124 let create_data_ident = format_ident!("create_data_{}", trait_name_ident);
125
126 let mut input_struct_fields: Vec<InputStructField> = vec![];
127 let mut trait_methods = vec![];
128 let mut setter_trait_methods = vec![];
129 let mut lookup_signatures = vec![];
130 let mut lookup_methods = vec![];
131
132 for item in &mut item_trait.items {
133 if let syn::TraitItem::Fn(method) = item {
134 let method_name = &method.sig.ident;
135 let signature = &method.sig;
136
137 let (_attrs, salsa_attrs) = filter_attrs(method.attrs.clone());
138
139 let mut query_kind = QueryKind::TrackedWithSalsaStruct;
140 let mut invoke = None;
141 let mut cycle = None;
142 let mut interned_struct_path = None;
143 let mut lru = None;
144
145 let params: Vec<FnArg> = signature.inputs.clone().into_iter().collect();
146 let pat_and_tys = params
147 .into_iter()
148 .filter(|fn_arg| matches!(fn_arg, FnArg::Typed(_)))
149 .map(|fn_arg| match fn_arg {
150 FnArg::Typed(pat_type) => pat_type.clone(),
151 FnArg::Receiver(_) => unreachable!("this should have been filtered out"),
152 })
153 .collect::<Vec<syn::PatType>>();
154
155 for SalsaAttr { name, tts, span } in salsa_attrs {
156 match name.as_str() {
157 "cycle" => {
158 let path = syn::parse::<Parenthesized<Path>>(tts)?;
159 cycle = Some(path.0.clone())
160 }
161 "input" => {
162 if !pat_and_tys.is_empty() {
163 return Err(syn::Error::new(
164 span,
165 "input methods cannot have a parameter",
166 ));
167 }
168 query_kind = QueryKind::Input;
169 }
170 "interned" => {
171 let syn::ReturnType::Type(_, ty) = &signature.output else {
172 return Err(syn::Error::new(
173 span,
174 "interned queries must have return type",
175 ));
176 };
177 let syn::Type::Path(path) = &**ty else {
178 return Err(syn::Error::new(
179 span,
180 "interned queries must have return type",
181 ));
182 };
183 interned_struct_path = Some(path.path.clone());
184 query_kind = QueryKind::Interned;
185 }
186 "invoke_interned" => {
187 let path = syn::parse::<Parenthesized<Path>>(tts)?;
188 invoke = Some(path.0.clone());
189 query_kind = QueryKind::Tracked;
190 }
191 "invoke" => {
192 let path = syn::parse::<Parenthesized<Path>>(tts)?;
193 invoke = Some(path.0.clone());
194 if query_kind != QueryKind::Transparent {
195 query_kind = QueryKind::TrackedWithSalsaStruct;
196 }
197 }
198 "tracked" if method.default.is_some() => {
199 query_kind = QueryKind::TrackedWithSalsaStruct;
200 }
201 "lru" => {
202 let lru_count = syn::parse::<Parenthesized<syn::LitInt>>(tts)?;
203 let lru_count = lru_count.0.base10_parse::<u32>()?;
204
205 lru = Some(lru_count);
206 }
207 "transparent" => {
208 query_kind = QueryKind::Transparent;
209 }
210 _ => return Err(syn::Error::new(span, format!("unknown attribute `{name}`"))),
211 }
212 }
213
214 let syn::ReturnType::Type(_, return_ty) = signature.output.clone() else {
215 return Err(syn::Error::new(signature.span(), "Queries must have a return type"));
216 };
217
218 if let syn::Type::Path(ref ty_path) = *return_ty {
219 if matches!(query_kind, QueryKind::Input) {
220 let field = InputStructField {
221 name: method_name.to_token_stream(),
222 ty: ty_path.path.to_token_stream(),
223 };
224
225 input_struct_fields.push(field);
226 }
227 }
228
229 if let Some(block) = &mut method.default {
230 SelfToDbRewriter.visit_block_mut(block);
231 }
232
233 match (query_kind, invoke) {
234 (QueryKind::Input, None) => {
236 let query = InputQuery {
237 signature: method.sig.clone(),
238 create_data_ident: create_data_ident.clone(),
239 };
240 let value = Queries::InputQuery(query);
241 trait_methods.push(value);
242
243 let setter = InputSetter {
244 signature: method.sig.clone(),
245 return_type: *return_ty.clone(),
246 create_data_ident: create_data_ident.clone(),
247 };
248 setter_trait_methods.push(SetterKind::Plain(setter));
249
250 let setter = InputSetterWithDurability {
251 signature: method.sig.clone(),
252 return_type: *return_ty.clone(),
253 create_data_ident: create_data_ident.clone(),
254 };
255 setter_trait_methods.push(SetterKind::WithDurability(setter));
256 }
257 (QueryKind::Interned, None) => {
258 let interned_struct_path = interned_struct_path.unwrap();
259 let method = Intern {
260 signature: signature.clone(),
261 pat_and_tys: pat_and_tys.clone(),
262 interned_struct_path: interned_struct_path.clone(),
263 };
264
265 trait_methods.push(Queries::Intern(method));
266
267 let mut method = Lookup {
268 signature: signature.clone(),
269 pat_and_tys: pat_and_tys.clone(),
270 return_ty: *return_ty,
271 interned_struct_path,
272 };
273 method.prepare_signature();
274
275 lookup_signatures
276 .push(TraitItem::Fn(make_trait_method(method.signature.clone())));
277 lookup_methods.push(method);
278 }
279 (QueryKind::Tracked, invoke) => {
281 let method = TrackedQuery {
282 trait_name: trait_name_ident.clone(),
283 generated_struct: Some(GeneratedInputStruct {
284 input_struct_name: input_struct_name.clone(),
285 create_data_ident: create_data_ident.clone(),
286 }),
287 signature: signature.clone(),
288 pat_and_tys: pat_and_tys.clone(),
289 invoke,
290 cycle,
291 lru,
292 default: method.default.take(),
293 };
294
295 trait_methods.push(Queries::TrackedQuery(method));
296 }
297 (QueryKind::TrackedWithSalsaStruct, invoke) => {
298 let method = TrackedQuery {
299 trait_name: trait_name_ident.clone(),
300 generated_struct: None,
301 signature: signature.clone(),
302 pat_and_tys: pat_and_tys.clone(),
303 invoke,
304 cycle,
305 lru,
306 default: method.default.take(),
307 };
308
309 trait_methods.push(Queries::TrackedQuery(method))
310 }
311 (QueryKind::Transparent, invoke) => {
312 let method = Transparent {
313 signature: method.sig.clone(),
314 pat_and_tys: pat_and_tys.clone(),
315 invoke,
316 default: method.default.take(),
317 };
318 trait_methods.push(Queries::Transparent(method));
319 }
320 (QueryKind::Interned, Some(path)) => {
322 return Err(syn::Error::new(
323 path.span(),
324 "Interned queries cannot be used with an `#[invoke]`".to_string(),
325 ));
326 }
327 (QueryKind::Input, Some(path)) => {
328 return Err(syn::Error::new(
329 path.span(),
330 "Inputs cannot be used with an `#[invoke]`".to_string(),
331 ));
332 }
333 }
334 }
335 }
336
337 let fields = input_struct_fields
338 .into_iter()
339 .map(|input| {
340 let name = input.name;
341 let ret = input.ty;
342 quote! { #name: Option<#ret> }
343 })
344 .collect::<Vec<proc_macro2::TokenStream>>();
345
346 let input_struct = quote! {
347 #[salsa::input]
348 pub(crate) struct #input_struct_name {
349 #(#fields),*
350 }
351 };
352
353 let field_params = std::iter::repeat_n(quote! { None }, fields.len())
354 .collect::<Vec<proc_macro2::TokenStream>>();
355
356 let create_data_method = quote! {
357 #[allow(non_snake_case)]
358 #[salsa::tracked]
359 fn #create_data_ident(db: &dyn #trait_name_ident) -> #input_struct_name {
360 #input_struct_name::new(db, #(#field_params),*)
361 }
362 };
363
364 let mut setter_signatures = vec![];
365 let mut setter_methods = vec![];
366 for trait_item in setter_trait_methods
367 .iter()
368 .map(|method| method.to_token_stream())
369 .map(|tokens| syn::parse2::<syn::TraitItemFn>(tokens).unwrap())
370 {
371 let mut methods_sans_body = trait_item.clone();
372 methods_sans_body.default = None;
373 methods_sans_body.semi_token = Some(syn::Token));
374
375 setter_signatures.push(TraitItem::Fn(methods_sans_body));
376 setter_methods.push(TraitItem::Fn(trait_item));
377 }
378
379 item_trait.items.append(&mut setter_signatures);
380 item_trait.items.append(&mut lookup_signatures);
381
382 let trait_impl = quote! {
383 #[salsa::db]
384 impl<DB> #trait_name_ident for DB
385 where
386 DB: #supertraits,
387 {
388 #(#trait_methods)*
389
390 #(#setter_methods)*
391
392 #(#lookup_methods)*
393 }
394 };
395 RemoveAttrsFromTraitMethods.visit_item_trait_mut(&mut item_trait);
396
397 let out = quote! {
398 #item_trait
399
400 #trait_impl
401
402 #input_struct
403
404 #create_data_method
405 }
406 .into();
407
408 Ok(out)
409}
410
411pub(crate) struct Parenthesized<T>(pub(crate) T);
413
414impl<T> syn::parse::Parse for Parenthesized<T>
415where
416 T: syn::parse::Parse,
417{
418 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
419 let content;
420 syn::parenthesized!(content in input);
421 content.parse::<T>().map(Parenthesized)
422 }
423}
424
425fn make_trait_method(sig: syn::Signature) -> TraitItemFn {
426 TraitItemFn {
427 attrs: vec![],
428 sig: sig.clone(),
429 semi_token: Some(syn::Token)),
430 default: None,
431 }
432}
433
434struct RemoveAttrsFromTraitMethods;
435
436impl VisitMut for RemoveAttrsFromTraitMethods {
437 fn visit_item_trait_mut(&mut self, i: &mut syn::ItemTrait) {
438 for item in &mut i.items {
439 if let TraitItem::Fn(trait_item_fn) = item {
440 trait_item_fn.attrs = vec![];
441 }
442 }
443 }
444}
445
446pub(crate) fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
447 tokens.extend(TokenStream::from(error.into_compile_error()));
448 tokens
449}
450
451struct SelfToDbRewriter;
452
453impl VisitMut for SelfToDbRewriter {
454 fn visit_expr_path_mut(&mut self, i: &mut syn::ExprPath) {
455 if i.path.is_ident("self") {
456 i.path = parse_quote_spanned!(i.path.span() => db);
457 }
458 }
459}