queries_derive/
lib.rs

1use quote::ToTokens;
2
3#[proc_macro_attribute]
4pub fn queries(
5    attr: proc_macro::TokenStream,
6    item: proc_macro::TokenStream,
7) -> proc_macro::TokenStream {
8    let args = syn::parse_macro_input!(attr as syn::MetaNameValue);
9    let input = syn::parse_macro_input!(item as syn::ItemTrait);
10
11    expand(args, input)
12        .unwrap_or_else(syn::Error::into_compile_error)
13        .into()
14}
15
16fn expand(
17    args: syn::MetaNameValue,
18    input: syn::ItemTrait,
19) -> syn::Result<proc_macro2::TokenStream> {
20    if !args.path.is_ident("database") {
21        return Err(syn::Error::new_spanned(
22            args,
23            "The only permitted argument is database.",
24        ));
25    }
26    let database = args.value;
27
28    if input.unsafety.is_some()
29        || input.auto_token.is_some()
30        || input.restriction.is_some()
31        || !input.generics.params.is_empty()
32        || input.generics.where_clause.is_some()
33        || !input.supertraits.is_empty()
34    {
35        return Err(syn::Error::new_spanned(
36            input,
37            "Used an unsupported feature in trait definition",
38        ));
39    }
40
41    let mut pool_method_impls = vec![];
42    let mut tx_method_impls = vec![];
43
44    for item in input.items {
45        let syn::TraitItem::Fn(fn_def) = item else {
46            return Err(syn::Error::new_spanned(
47                item,
48                "Only methods are allowed in the trait definition",
49            ));
50        };
51        pool_method_impls.push(expand_pool_method_impl(&database, fn_def.clone())?);
52        tx_method_impls.push(expand_tx_method_impl(&database, fn_def)?);
53    }
54
55    let name = input.ident;
56    let vis = input.vis;
57
58    let result = quote::quote! {
59        #vis struct #name<E> {
60            executor: E,
61        }
62
63        impl #name<sqlx::Pool<#database>> {
64            pub fn from_pool(pool: sqlx::Pool<#database>) -> Self {
65                Self { executor: pool }
66            }
67
68            #(#pool_method_impls)*
69        }
70
71        impl<'a> #name<sqlx::Transaction<'a, #database>> {
72            pub fn from_tx(tx: sqlx::Transaction<'a, #database>) -> Self {
73                Self { executor: tx }
74            }
75
76            pub async fn commit(self) -> sqlx::Result<()> {
77                self.executor.commit().await
78            }
79
80            pub async fn rollback(self) -> sqlx::Result<()> {
81                self.executor.rollback().await
82            }
83
84            #(#tx_method_impls)*
85        }
86    };
87    Ok(result)
88}
89
90fn expand_method_impl_with_self(
91    database: &syn::Expr,
92    fn_def: syn::TraitItemFn,
93    self_param: proc_macro2::TokenStream,
94    executor_ref: proc_macro2::TokenStream,
95) -> syn::Result<proc_macro2::TokenStream> {
96    if fn_def.default.is_some() {
97        return Err(syn::Error::new_spanned(
98            fn_def,
99            "Default implementations are not allowed",
100        ));
101    }
102
103    if fn_def.sig.asyncness.is_none() {
104        return Err(syn::Error::new_spanned(fn_def.sig, "Method must be async"));
105    }
106
107    for attr in &fn_def.attrs {
108        if !attr.path().is_ident("query") {
109            return Err(syn::Error::new_spanned(
110                attr,
111                "Only #[query] attributes are allowed",
112            ));
113        }
114    }
115
116    let query = &fn_def.attrs[0].meta.require_name_value()?.value;
117    let name = &fn_def.sig.ident;
118    let args = &fn_def.sig.inputs;
119    let arg_names = args
120        .iter()
121        .map(|p| {
122            let syn::FnArg::Typed(pat) = p else {
123                return Err(syn::Error::new_spanned(p, "weird arg"));
124            };
125            let syn::Pat::Ident(i) = &*pat.pat else {
126                return Err(syn::Error::new_spanned(pat, "weird arg"));
127            };
128            Ok(&i.ident)
129        })
130        .collect::<Result<Vec<_>, _>>()?;
131    let return_type = match &fn_def.sig.output {
132        syn::ReturnType::Default => quote::quote! { () },
133        syn::ReturnType::Type(_, ty) => ty.into_token_stream(),
134    };
135
136    let result = quote::quote! {
137        async fn #name(#self_param, #args) -> Result<#return_type, sqlx::Error>
138        {
139            use queries::Probe;
140
141            let q = sqlx::query(#query);
142            #(let q = q.bind(#arg_names);)*
143            <
144                #return_type as queries::FromRows<
145                    #database,
146                    { queries::FromRowsCategory::<#return_type>::VALUE }
147                >
148            >::from_rows(q.fetch(#executor_ref)).await
149        }
150    };
151    Ok(result)
152}
153
154fn expand_pool_method_impl(
155    database: &syn::Expr,
156    fn_def: syn::TraitItemFn,
157) -> syn::Result<proc_macro2::TokenStream> {
158    expand_method_impl_with_self(
159        database,
160        fn_def,
161        quote::quote! { &self },
162        quote::quote! { &self.executor },
163    )
164}
165
166fn expand_tx_method_impl(
167    database: &syn::Expr,
168    fn_def: syn::TraitItemFn,
169) -> syn::Result<proc_macro2::TokenStream> {
170    expand_method_impl_with_self(
171        database,
172        fn_def,
173        quote::quote! { &mut self },
174        quote::quote! { &mut *self.executor },
175    )
176}