1use proc_macro2::{Ident, Span, TokenStream};
2use proc_macro_error::*;
3use quote::quote;
4use syn::punctuated::Punctuated;
5use syn::{
6 parse_quote, Attribute, Expr, ExprField, ExprPath, ExprTuple, FnArg, ForeignItemFn,
7 GenericArgument, GenericParam, Index, ItemFn, Member, Pat, PatIdent, PatType, Path,
8 PathArguments, PathSegment, ReturnType, Signature, Token, Type, TypePath, TypeReference,
9 TypeTuple, Visibility, WhereClause,
10};
11
12fn fn_arg_to_type(arg: &FnArg) -> &Type {
13 match arg {
14 FnArg::Receiver(_) => unimplemented!(),
15 FnArg::Typed(arg) => arg.ty.as_ref(),
16 }
17}
18
19fn build_type_tuple(types: impl Iterator<Item = Type>) -> Type {
20 let mut elems = types.collect::<Punctuated<_, Token![,]>>();
21 if !elems.is_empty() {
22 elems.push_punct(Default::default());
23 }
24
25 Type::Tuple(TypeTuple {
26 paren_token: Default::default(),
27 elems,
28 })
29}
30
31fn build_unit_tuple() -> Type {
32 build_type_tuple([].into_iter())
33}
34
35fn arg_name(arg: &FnArg) -> Option<Ident> {
36 match arg {
37 FnArg::Receiver(_) => Some(Ident::new("self", Span::call_site())),
38 FnArg::Typed(pat_type) => {
39 if let Pat::Ident(name) = pat_type.pat.as_ref() {
40 Some(name.ident.clone())
41 } else {
42 None
43 }
44 }
45 }
46}
47
48fn arg_names<'a>(args: impl Iterator<Item = &'a FnArg>) -> Vec<Ident> {
49 args.enumerate()
50 .map(|(n, arg)| {
51 arg_name(arg).unwrap_or_else(|| Ident::new(&format!("arg{n}"), Span::mixed_site()))
52 })
53 .collect()
54}
55
56fn calling_tuple_args(idents: impl Iterator<Item = (Ident, Type)>) -> Punctuated<FnArg, Token![,]> {
57 idents
58 .map(|(name, typ)| {
59 FnArg::Typed(PatType {
60 attrs: Default::default(),
61 pat: Box::new(Pat::Ident(PatIdent {
62 attrs: Default::default(),
63 by_ref: None,
64 mutability: None,
65 subpat: None,
66 ident: name,
67 })),
68 colon_token: Default::default(),
69 ty: Box::new(typ),
70 })
71 })
72 .collect()
73}
74
75fn build_ident_tuple(idents: impl Iterator<Item = Ident>) -> Expr {
76 let mut elems = idents
77 .map(ident_to_expr)
78 .collect::<Punctuated<_, Token![,]>>();
79 if !elems.is_empty() {
80 elems.push_punct(Default::default());
81 }
82
83 ExprTuple {
84 attrs: Default::default(),
85 paren_token: Default::default(),
86 elems,
87 }
88 .into()
89}
90
91fn ident_to_expr(id: Ident) -> Expr {
92 ExprPath {
93 attrs: Default::default(),
94 qself: Default::default(),
95 path: Path::from(id),
96 }
97 .into()
98}
99
100fn use_generic_args(
102 generics: &Punctuated<GenericParam, Token![,]>,
103) -> Punctuated<GenericArgument, Token![,]> {
104 generics
105 .iter()
106 .map(|p| match p {
107 GenericParam::Type(t) => GenericArgument::Type(
108 TypePath {
109 qself: None,
110 path: t.ident.clone().into(),
111 }
112 .into(),
113 ),
114 GenericParam::Lifetime(l) => GenericArgument::Lifetime(l.lifetime.clone()),
115 GenericParam::Const(c) => GenericArgument::Const(ident_to_expr(c.ident.clone())),
116 })
117 .collect()
118}
119
120fn generic_args_phantom(generics: &Punctuated<GenericArgument, Token![,]>) -> Type {
121 build_type_tuple(generics.iter().filter_map(|a| {
122 match a {
123 GenericArgument::Binding(_) => unreachable!(),
124 GenericArgument::Constraint(_) => unreachable!(),
125 GenericArgument::Type(t) => Some(t.clone()),
126 GenericArgument::Const(_) => None,
127 GenericArgument::Lifetime(lt) => Some(
128 TypeReference {
129 lifetime: Some(lt.clone()),
130 mutability: None,
131 and_token: Default::default(),
132 elem: Box::new(build_unit_tuple()),
133 }
134 .into(),
135 ),
136 }
137 }))
138}
139
140#[proc_macro_error]
141#[proc_macro_attribute]
142pub fn query(
143 attr: proc_macro::TokenStream,
144 item: proc_macro::TokenStream,
145) -> proc_macro::TokenStream {
146 if !attr.is_empty() {
147 emit_error!(
148 TokenStream::from(attr),
149 "#[yeter::query] doesn't expect any attributes"
150 );
151 }
152
153 let mut function_no_impl;
154 let mut function_impl;
155 let function = {
156 if let Ok(f) = syn::parse::<ForeignItemFn>(item.clone()) {
157 function_no_impl = f;
158 &mut function_no_impl as &mut dyn FunctionItem
159 } else if let Ok(f) = syn::parse::<ItemFn>(item.clone()) {
160 function_impl = f;
161 &mut function_impl as &mut dyn FunctionItem
162 } else {
163 let item = TokenStream::from(item);
164 return (quote! { compile_error!("expected fn item"); #item }).into();
165 }
166 };
167
168 let query_attrs = function.take_attrs();
169 let fn_args = &function.sig().inputs;
170 let query_args = fn_args
171 .iter()
172 .skip(1)
173 .map(fn_arg_to_type)
174 .cloned()
175 .collect::<Vec<_>>();
176
177 let db_ident_fallback = Ident::new("db", Span::call_site());
178 match fn_args.first() {
179 Some(receiver @ FnArg::Receiver(_)) => {
181 emit_error!(
182 receiver,
183 "#[yeter::query] can't be used on instance methods";
184 hint = "did you mean `db: &yeter::Database`?";
185 );
186
187 &db_ident_fallback
188 }
189 Some(FnArg::Typed(pat_type)) => match pat_type.pat.as_ref() {
190 Pat::Ident(ident) => &ident.ident,
191 _ => {
192 emit_error!(
193 pat_type.pat,
194 "simple database argument pattern expected";
195 help = "use a simple argument declaration such as `db: &yeter::Database`";
196 );
197
198 &db_ident_fallback
199 }
200 },
201 None => {
202 emit_error!(
203 function.sig(), "a query must take a database as its first argument";
204 note = "no arguments were specified";
205 );
206
207 &db_ident_fallback
208 }
209 };
210
211 let fn_arg_count = fn_args.len() as u32;
212 let query_arg_count = if fn_arg_count == 0 {
213 0
214 } else {
215 fn_arg_count - 1
216 };
217
218 let unit_type;
219
220 let query_vis = &function.vis();
221 let query_name = &function.sig().ident;
222 let generics = &function.sig().generics;
223 let generics_params = &generics.params;
224 let generics_where = &generics.where_clause;
225 let generics_args = use_generic_args(generics_params);
226 let generics_phantom = generic_args_phantom(&generics_args);
227
228 let input_type = build_type_tuple(query_args.iter().cloned());
229 let output_type = match &function.sig().output {
230 ReturnType::Default => {
231 unit_type = build_unit_tuple();
232 &unit_type
233 }
234 ReturnType::Type(_, typ) => typ.as_ref(),
235 };
236
237 let calling_arg_names = arg_names(fn_args.iter().skip(1));
238
239 let calling_tuple_args = calling_tuple_args(calling_arg_names.iter().cloned().zip(query_args));
240 let calling_tuple = build_ident_tuple(calling_arg_names.into_iter());
241
242 let call_ident_span = Span::call_site().located_at(query_name.span());
243 let call_ident = Ident::new(&format!("__yeter_{query_name}"), call_ident_span);
245
246 let to_function_impl = function.to_function_impl(&call_ident, generics_params, output_type);
247 let to_function_call = function.to_function_call(&call_ident, query_arg_count);
248 let to_additional_impl = function.to_additional_impl(
249 query_name,
250 generics_params,
251 &generics_args,
252 generics_where,
253 output_type,
254 );
255
256 let expanded = quote! {
257 #(#query_attrs)*
258 #query_vis fn #query_name<#generics_params>(db: &::yeter::Database, #calling_tuple_args) -> ::std::rc::Rc<#output_type>
259 #generics_where
260 {
261 #to_function_impl
262 db.run::<_, #query_name::<#generics_args>>(#to_function_call, #calling_tuple)
263 }
264
265 #[allow(non_camel_case_types)]
266 #[doc(hidden)]
267 #query_vis enum #query_name<#generics_params> {
268 Phantom(std::convert::Infallible, std::marker::PhantomData<#generics_phantom>),
269 }
270
271 impl<#generics_params> ::yeter::QueryDef for #query_name<#generics_args> #generics_where {
272 type Input = #input_type;
273 type Output = #output_type;
274 }
275
276 #to_additional_impl
277 };
278
279 set_dummy(expanded.clone()); expanded.into()
281}
282
283trait FunctionItem {
284 fn take_attrs(&mut self) -> Vec<Attribute>;
285 fn vis(&self) -> &Visibility;
286 fn sig(&self) -> &Signature;
287
288 fn to_function_impl(
289 &self,
290 _call_ident: &Ident,
291 _generics_params: &Punctuated<GenericParam, Token![,]>,
292 _output_type: &Type,
293 ) -> TokenStream {
294 quote! {}
295 }
296
297 fn to_function_call(&self, _call_ident: &Ident, _query_arg_count: u32) -> TokenStream;
298
299 fn to_additional_impl(
300 &self,
301 _query_name: &Ident,
302 _generics_params: &Punctuated<GenericParam, Token![,]>,
303 _generics_args: &Punctuated<GenericArgument, Token![,]>,
304 _generics_where: &Option<WhereClause>,
305 _output_type: &Type,
306 ) -> TokenStream {
307 quote! {}
308 }
309}
310
311fn guess_option_inner_type(option: &Type) -> Type {
312 if let Type::Path(path) = option {
313 match path.path.segments.last() {
314 Some(seg) if seg.ident != "Option" => {
315 emit_error!(seg.ident, "expected `Option` type",);
316 }
317 Some(PathSegment {
318 arguments: PathArguments::AngleBracketed(angle),
319 ..
320 }) if angle.args.len() == 1 => match angle.args.first().unwrap() {
321 GenericArgument::Type(t) => return t.clone(),
322 o => {
323 emit_error!(o, "unexpected generic argument for Option type",);
324 }
325 },
326 Some(seg) => {
327 emit_error!(seg, "expected Option<T> return type",);
328 }
329 None => {
330 emit_error!(path, "expected Option<T> return type",);
331 }
332 }
333 };
334
335 parse_quote! { Option<()> }
336}
337
338impl FunctionItem for ForeignItemFn {
339 fn take_attrs(&mut self) -> Vec<Attribute> {
340 std::mem::take(&mut self.attrs)
341 }
342
343 fn vis(&self) -> &Visibility {
344 &self.vis
345 }
346
347 fn sig(&self) -> &Signature {
348 &self.sig
349 }
350
351 fn to_function_call(&self, _call_ident: &Ident, _query_arg_count: u32) -> TokenStream {
352 quote! {
353 |_db, _input| None
354 }
355 }
356
357 fn to_additional_impl(
358 &self,
359 query_name: &Ident,
360 generics_params: &Punctuated<GenericParam, Token![,]>,
361 generics_args: &Punctuated<GenericArgument, Token![,]>,
362 generics_where: &Option<WhereClause>,
363 output_type: &Type,
364 ) -> TokenStream {
365 let output_type = guess_option_inner_type(output_type);
367
368 quote! {
369 impl<#generics_params> ::yeter::InputQueryDef for #query_name<#generics_args> #generics_where {
370 type OptionalOutput = #output_type;
371 }
372 }
373 }
374}
375
376impl FunctionItem for ItemFn {
377 fn take_attrs(&mut self) -> Vec<Attribute> {
378 std::mem::take(&mut self.attrs)
379 }
380
381 fn vis(&self) -> &Visibility {
382 &self.vis
383 }
384
385 fn sig(&self) -> &Signature {
386 &self.sig
387 }
388
389 fn to_function_impl(
390 &self,
391 call_ident: &Ident,
392 _generics_params: &Punctuated<GenericParam, Token![,]>,
393 _output_type: &Type,
394 ) -> TokenStream {
395 let mut s = self.clone();
396 s.sig.ident = call_ident.clone();
397
398 quote! {
399 #[allow(clippy::needless_lifetimes)]
400 #s
401 }
402 }
403
404 fn to_function_call(&self, call_ident: &Ident, query_arg_count: u32) -> TokenStream {
405 let db_ident = Ident::new("db", Span::mixed_site());
406 let input_ident = Ident::new("input", Span::mixed_site());
407 let input_ident_expr = Box::new(ident_to_expr(input_ident.clone()));
408 let calling_args = (0..query_arg_count)
409 .map(|n| {
410 Expr::Field(ExprField {
411 attrs: Default::default(),
412 base: input_ident_expr.clone(),
413 dot_token: Default::default(),
414 member: Member::Unnamed(Index {
415 index: n,
416 span: Span::mixed_site(),
417 }),
418 })
419 })
420 .collect::<Punctuated<_, Token![,]>>();
421
422 quote! {
423 |#db_ident, #input_ident| #call_ident(#db_ident, #calling_args)
424 }
425 }
426}