queries_derive/
lib.rs

1use quote::ToTokens;
2
3#[proc_macro_attribute]
4pub fn queries(
5    attr: proc_macro::TokenStream,
6    item: proc_macro::TokenStream,
7) -> proc_macro::TokenStream {
8    let args = syn::parse_macro_input!(attr as syn::MetaNameValue);
9    let input = syn::parse_macro_input!(item as syn::ItemTrait);
10
11    expand(args, input)
12        .unwrap_or_else(syn::Error::into_compile_error)
13        .into()
14}
15
16fn expand(
17    args: syn::MetaNameValue,
18    input: syn::ItemTrait,
19) -> syn::Result<proc_macro2::TokenStream> {
20    if !args.path.is_ident("database") {
21        return Err(syn::Error::new_spanned(
22            args,
23            "The only permitted argument is database.",
24        ));
25    }
26    let database = args.value;
27
28    if input.unsafety.is_some()
29        || input.auto_token.is_some()
30        || input.restriction.is_some()
31        || !input.generics.params.is_empty()
32        || input.generics.where_clause.is_some()
33        || !input.supertraits.is_empty()
34    {
35        return Err(syn::Error::new_spanned(
36            input,
37            "Used an unsupported feature in trait definition",
38        ));
39    }
40
41    let mut method_impls = vec![];
42    for item in input.items {
43        let syn::TraitItem::Fn(fn_def) = item else {
44            return Err(syn::Error::new_spanned(
45                item,
46                "Only methods are allowed in the trait definition",
47            ));
48        };
49        method_impls.push(expand_method_impl(&database, fn_def)?);
50    }
51
52    let name = input.ident;
53    let vis = input.vis;
54    let result = quote::quote! {
55        #vis struct #name {
56            pool: sqlx::Pool<#database>,
57        }
58
59        impl #name {
60            pub fn new(pool: sqlx::Pool<#database>) -> Self {
61                Self { pool }
62            }
63        }
64
65        impl #name {
66            #(#method_impls)*
67        }
68    };
69    Ok(result)
70}
71
72fn expand_method_impl(
73    database: &syn::Expr,
74    fn_def: syn::TraitItemFn,
75) -> syn::Result<proc_macro2::TokenStream> {
76    if fn_def.default.is_some() {
77        return Err(syn::Error::new_spanned(
78            fn_def,
79            "Default implementations are not allowed",
80        ));
81    }
82
83    if fn_def.sig.asyncness.is_none() {
84        return Err(syn::Error::new_spanned(fn_def.sig, "Method must be async"));
85    }
86
87    for attr in &fn_def.attrs {
88        if !attr.path().is_ident("query") {
89            return Err(syn::Error::new_spanned(
90                attr,
91                "Only #[query] attributes are allowed",
92            ));
93        }
94    }
95
96    let query = &fn_def.attrs[0].meta.require_name_value()?.value;
97    let name = &fn_def.sig.ident;
98    let args = &fn_def.sig.inputs;
99    let arg_names = args
100        .iter()
101        .map(|p| {
102            let syn::FnArg::Typed(pat) = p else {
103                return Err(syn::Error::new_spanned(p, "weird arg"));
104            };
105            let syn::Pat::Ident(i) = &*pat.pat else {
106                return Err(syn::Error::new_spanned(pat, "weird arg"));
107            };
108            Ok(&i.ident)
109        })
110        .collect::<Result<Vec<_>, _>>()?;
111    let return_type = match &fn_def.sig.output {
112        syn::ReturnType::Default => quote::quote! { () },
113        syn::ReturnType::Type(_, ty) => ty.into_token_stream(),
114    };
115
116    let result = quote::quote! {
117        async fn #name(&self, #args) -> Result<#return_type, sqlx::Error>
118        {
119            use queries::Probe;
120
121            let q = sqlx::query(#query);
122            #(let q = q.bind(#arg_names);)*
123            <
124                #return_type as queries::FromRows<
125                    #database,
126                    { queries::FromRowsCategory::<#return_type>::VALUE }
127                >
128            >::from_rows(q.fetch(&self.pool)).await
129        }
130    };
131    Ok(result)
132}