1use quote::ToTokens;
2
3#[proc_macro_attribute]
4pub fn queries(
5 attr: proc_macro::TokenStream,
6 item: proc_macro::TokenStream,
7) -> proc_macro::TokenStream {
8 let args = syn::parse_macro_input!(attr as syn::MetaNameValue);
9 let input = syn::parse_macro_input!(item as syn::ItemTrait);
10
11 expand(args, input)
12 .unwrap_or_else(syn::Error::into_compile_error)
13 .into()
14}
15
16fn expand(
17 args: syn::MetaNameValue,
18 input: syn::ItemTrait,
19) -> syn::Result<proc_macro2::TokenStream> {
20 if !args.path.is_ident("database") {
21 return Err(syn::Error::new_spanned(
22 args,
23 "The only permitted argument is database.",
24 ));
25 }
26 let database = args.value;
27
28 if input.unsafety.is_some()
29 || input.auto_token.is_some()
30 || input.restriction.is_some()
31 || !input.generics.params.is_empty()
32 || input.generics.where_clause.is_some()
33 || !input.supertraits.is_empty()
34 {
35 return Err(syn::Error::new_spanned(
36 input,
37 "Used an unsupported feature in trait definition",
38 ));
39 }
40
41 let mut method_impls = vec![];
42 for item in input.items {
43 let syn::TraitItem::Fn(fn_def) = item else {
44 return Err(syn::Error::new_spanned(
45 item,
46 "Only methods are allowed in the trait definition",
47 ));
48 };
49 method_impls.push(expand_method_impl(&database, fn_def)?);
50 }
51
52 let name = input.ident;
53 let vis = input.vis;
54 let result = quote::quote! {
55 #vis struct #name {
56 pool: sqlx::Pool<#database>,
57 }
58
59 impl #name {
60 pub fn new(pool: sqlx::Pool<#database>) -> Self {
61 Self { pool }
62 }
63 }
64
65 impl #name {
66 #(#method_impls)*
67 }
68 };
69 Ok(result)
70}
71
72fn expand_method_impl(
73 database: &syn::Expr,
74 fn_def: syn::TraitItemFn,
75) -> syn::Result<proc_macro2::TokenStream> {
76 if fn_def.default.is_some() {
77 return Err(syn::Error::new_spanned(
78 fn_def,
79 "Default implementations are not allowed",
80 ));
81 }
82
83 if fn_def.sig.asyncness.is_none() {
84 return Err(syn::Error::new_spanned(fn_def.sig, "Method must be async"));
85 }
86
87 for attr in &fn_def.attrs {
88 if !attr.path().is_ident("query") {
89 return Err(syn::Error::new_spanned(
90 attr,
91 "Only #[query] attributes are allowed",
92 ));
93 }
94 }
95
96 let query = &fn_def.attrs[0].meta.require_name_value()?.value;
97 let name = &fn_def.sig.ident;
98 let args = &fn_def.sig.inputs;
99 let arg_names = args
100 .iter()
101 .map(|p| {
102 let syn::FnArg::Typed(pat) = p else {
103 return Err(syn::Error::new_spanned(p, "weird arg"));
104 };
105 let syn::Pat::Ident(i) = &*pat.pat else {
106 return Err(syn::Error::new_spanned(pat, "weird arg"));
107 };
108 Ok(&i.ident)
109 })
110 .collect::<Result<Vec<_>, _>>()?;
111 let return_type = match &fn_def.sig.output {
112 syn::ReturnType::Default => quote::quote! { () },
113 syn::ReturnType::Type(_, ty) => ty.into_token_stream(),
114 };
115
116 let result = quote::quote! {
117 async fn #name(&self, #args) -> Result<#return_type, sqlx::Error>
118 {
119 use queries::Probe;
120
121 let q = sqlx::query(#query);
122 #(let q = q.bind(#arg_names);)*
123 <
124 #return_type as queries::FromRows<
125 #database,
126 { queries::FromRowsCategory::<#return_type>::VALUE }
127 >
128 >::from_rows(q.fetch(&self.pool)).await
129 }
130 };
131 Ok(result)
132}