1use quote::ToTokens;
2
3#[proc_macro_attribute]
4pub fn queries(
5 attr: proc_macro::TokenStream,
6 item: proc_macro::TokenStream,
7) -> proc_macro::TokenStream {
8 let args = syn::parse_macro_input!(attr as syn::MetaNameValue);
9 let input = syn::parse_macro_input!(item as syn::ItemTrait);
10
11 expand(args, input)
12 .unwrap_or_else(syn::Error::into_compile_error)
13 .into()
14}
15
16fn expand(
17 args: syn::MetaNameValue,
18 input: syn::ItemTrait,
19) -> syn::Result<proc_macro2::TokenStream> {
20 if !args.path.is_ident("database") {
21 return Err(syn::Error::new_spanned(
22 args,
23 "The only permitted argument is database.",
24 ));
25 }
26 let database = args.value;
27
28 if input.unsafety.is_some()
29 || input.auto_token.is_some()
30 || input.restriction.is_some()
31 || !input.generics.params.is_empty()
32 || input.generics.where_clause.is_some()
33 || !input.supertraits.is_empty()
34 {
35 return Err(syn::Error::new_spanned(
36 input,
37 "Used an unsupported feature in trait definition",
38 ));
39 }
40
41 let mut pool_method_impls = vec![];
42 let mut tx_method_impls = vec![];
43
44 for item in input.items {
45 let syn::TraitItem::Fn(fn_def) = item else {
46 return Err(syn::Error::new_spanned(
47 item,
48 "Only methods are allowed in the trait definition",
49 ));
50 };
51 pool_method_impls.push(expand_pool_method_impl(&database, fn_def.clone())?);
52 tx_method_impls.push(expand_tx_method_impl(&database, fn_def)?);
53 }
54
55 let name = input.ident;
56 let vis = input.vis;
57
58 let result = quote::quote! {
59 #vis struct #name<E> {
60 executor: E,
61 }
62
63 impl #name<sqlx::Pool<#database>> {
64 pub fn from_pool(pool: sqlx::Pool<#database>) -> Self {
65 Self { executor: pool }
66 }
67
68 #(#pool_method_impls)*
69 }
70
71 impl<'a> #name<sqlx::Transaction<'a, #database>> {
72 pub fn from_tx(tx: sqlx::Transaction<'a, #database>) -> Self {
73 Self { executor: tx }
74 }
75
76 pub async fn commit(self) -> sqlx::Result<()> {
77 self.executor.commit().await
78 }
79
80 pub async fn rollback(self) -> sqlx::Result<()> {
81 self.executor.rollback().await
82 }
83
84 #(#tx_method_impls)*
85 }
86 };
87 Ok(result)
88}
89
90fn expand_method_impl_with_self(
91 database: &syn::Expr,
92 fn_def: syn::TraitItemFn,
93 self_param: proc_macro2::TokenStream,
94 executor_ref: proc_macro2::TokenStream,
95) -> syn::Result<proc_macro2::TokenStream> {
96 if fn_def.default.is_some() {
97 return Err(syn::Error::new_spanned(
98 fn_def,
99 "Default implementations are not allowed",
100 ));
101 }
102
103 if fn_def.sig.asyncness.is_none() {
104 return Err(syn::Error::new_spanned(fn_def.sig, "Method must be async"));
105 }
106
107 for attr in &fn_def.attrs {
108 if !attr.path().is_ident("query") {
109 return Err(syn::Error::new_spanned(
110 attr,
111 "Only #[query] attributes are allowed",
112 ));
113 }
114 }
115
116 let query = &fn_def.attrs[0].meta.require_name_value()?.value;
117 let name = &fn_def.sig.ident;
118 let args = &fn_def.sig.inputs;
119 let arg_names = args
120 .iter()
121 .map(|p| {
122 let syn::FnArg::Typed(pat) = p else {
123 return Err(syn::Error::new_spanned(p, "weird arg"));
124 };
125 let syn::Pat::Ident(i) = &*pat.pat else {
126 return Err(syn::Error::new_spanned(pat, "weird arg"));
127 };
128 Ok(&i.ident)
129 })
130 .collect::<Result<Vec<_>, _>>()?;
131 let return_type = match &fn_def.sig.output {
132 syn::ReturnType::Default => quote::quote! { () },
133 syn::ReturnType::Type(_, ty) => ty.into_token_stream(),
134 };
135
136 let result = quote::quote! {
137 async fn #name(#self_param, #args) -> Result<#return_type, sqlx::Error>
138 {
139 use queries::Probe;
140
141 let q = sqlx::query(#query);
142 #(let q = q.bind(#arg_names);)*
143 <
144 #return_type as queries::FromRows<
145 #database,
146 { queries::FromRowsCategory::<#return_type>::VALUE }
147 >
148 >::from_rows(q.fetch(#executor_ref)).await
149 }
150 };
151 Ok(result)
152}
153
154fn expand_pool_method_impl(
155 database: &syn::Expr,
156 fn_def: syn::TraitItemFn,
157) -> syn::Result<proc_macro2::TokenStream> {
158 expand_method_impl_with_self(
159 database,
160 fn_def,
161 quote::quote! { &self },
162 quote::quote! { &self.executor },
163 )
164}
165
166fn expand_tx_method_impl(
167 database: &syn::Expr,
168 fn_def: syn::TraitItemFn,
169) -> syn::Result<proc_macro2::TokenStream> {
170 expand_method_impl_with_self(
171 database,
172 fn_def,
173 quote::quote! { &mut self },
174 quote::quote! { &mut *self.executor },
175 )
176}