pgx_macros/
rewriter.rs

1/*
2Portions Copyright 2019-2021 ZomboDB, LLC.
3Portions Copyright 2021-2022 Technology Concepts & Design, Inc. <support@tcdi.com>
4
5All rights reserved.
6
7Use of this source code is governed by the MIT license that can be found in the LICENSE file.
8*/
9
10extern crate proc_macro;
11
12use proc_macro2::Ident;
13use quote::{quote, quote_spanned};
14use std::ops::Deref;
15use std::str::FromStr;
16use syn::punctuated::Punctuated;
17use syn::spanned::Spanned;
18use syn::{
19    FnArg, ForeignItem, ForeignItemFn, GenericParam, ItemFn, ItemForeignMod, Pat, ReturnType,
20    Signature, Token, Visibility,
21};
22
/// Rewrites items annotated with `#[pg_guard]`: top-level functions are wrapped so
/// their body runs inside pgx's extern-"C" guard, and members of `extern` blocks are
/// re-exposed as `pub unsafe fn` wrappers that call through the FFI boundary guard.
///
/// The rewriter carries no state. `Default` is derived so it can be built without
/// calling `new()` (this also satisfies `clippy::new_without_default`).
#[derive(Debug, Default, Clone, Copy)]
pub struct PgGuardRewriter();
24
25impl PgGuardRewriter {
26    pub fn new() -> Self {
27        PgGuardRewriter()
28    }
29
30    pub fn extern_block(&self, block: ItemForeignMod) -> proc_macro2::TokenStream {
31        let mut stream = proc_macro2::TokenStream::new();
32
33        for item in block.items.into_iter() {
34            stream.extend(self.foreign_item(item, &block.abi));
35        }
36
37        stream
38    }
39
40    pub fn item_fn_without_rewrite(
41        &self,
42        mut func: ItemFn,
43    ) -> syn::Result<proc_macro2::TokenStream> {
44        // remember the original visibility and signature classifications as we want
45        // to use those for the outer function
46        let input_func_name = func.sig.ident.to_string();
47        let sig = func.sig.clone();
48        let vis = func.vis.clone();
49        let attrs = func.attrs.clone();
50
51        let generics = func.sig.generics.clone();
52
53        if attrs.iter().any(|attr| attr.path.is_ident("no_mangle"))
54            && generics.params.iter().any(|p| match p {
55                GenericParam::Type(_) => true,
56                GenericParam::Lifetime(_) => false,
57                GenericParam::Const(_) => true,
58            })
59        {
60            panic!("#[pg_guard] for function with generic parameters must not be combined with #[no_mangle]");
61        }
62
63        // but for the inner function (the one we're wrapping) we don't need any kind of
64        // abi classification
65        func.sig.abi = None;
66        func.attrs.clear();
67
68        // nor do we need a visibility beyond "private"
69        func.vis = Visibility::Inherited;
70
71        func.sig.ident =
72            Ident::new(&format!("{}_inner", func.sig.ident.to_string()), func.sig.ident.span());
73
74        let arg_list = PgGuardRewriter::build_arg_list(&sig, false)?;
75        let func_name = PgGuardRewriter::build_func_name(&func.sig);
76
77        let prolog = if input_func_name == "__pgx_private_shmem_hook"
78            || input_func_name == "__pgx_private_shmem_request_hook"
79        {
80            // we do not want "no_mangle" on these functions
81            quote! {}
82        } else if input_func_name == "_PG_init" || input_func_name == "_PG_fini" {
83            quote! {
84                #[allow(non_snake_case)]
85                #[no_mangle]
86            }
87        } else {
88            quote! {}
89        };
90
91        let body = if generics.params.is_empty() {
92            quote! { #func_name(#arg_list) }
93        } else {
94            let ty = generics
95                .params
96                .into_iter()
97                .filter_map(|p| match p {
98                    GenericParam::Type(ty) => Some(ty.ident),
99                    GenericParam::Const(c) => Some(c.ident),
100                    GenericParam::Lifetime(_) => None,
101                })
102                .collect::<Punctuated<_, Token![,]>>();
103            quote! { #func_name::<#ty>(#arg_list) }
104        };
105
106        Ok(quote_spanned! {func.span()=>
107            #prolog
108            #(#attrs)*
109            #vis #sig {
110                #[allow(non_snake_case)]
111                #func
112
113                #[allow(unused_unsafe)]
114                unsafe {
115                    // NB: this is purposely not spelled `::pgx` as pgx itself uses #[pg_guard]
116                    pgx::pg_sys::submodules::panic::pgx_extern_c_guard( || #body )
117                }
118            }
119        })
120    }
121
122    pub fn foreign_item(
123        &self,
124        item: ForeignItem,
125        abi: &syn::Abi,
126    ) -> syn::Result<proc_macro2::TokenStream> {
127        match item {
128            ForeignItem::Fn(func) => {
129                if func.sig.variadic.is_some() {
130                    return Ok(quote! { #abi { #func } });
131                }
132
133                self.foreign_item_fn(&func, abi)
134            }
135            _ => Ok(quote! { #abi { #item } }),
136        }
137    }
138
139    pub fn foreign_item_fn(
140        &self,
141        func: &ForeignItemFn,
142        abi: &syn::Abi,
143    ) -> syn::Result<proc_macro2::TokenStream> {
144        let func_name = PgGuardRewriter::build_func_name(&func.sig);
145        let arg_list = PgGuardRewriter::rename_arg_list(&func.sig)?;
146        let arg_list_with_types = PgGuardRewriter::rename_arg_list_with_types(&func.sig)?;
147        let return_type = PgGuardRewriter::get_return_type(&func.sig);
148
149        Ok(quote! {
150            #[track_caller]
151            pub unsafe fn #func_name ( #arg_list_with_types ) #return_type {
152                crate::ffi::pg_guard_ffi_boundary(move || {
153                    #abi { #func }
154                    #func_name(#arg_list)
155                })
156            }
157        })
158    }
159
160    pub fn build_func_name(sig: &Signature) -> Ident {
161        sig.ident.clone()
162    }
163
164    #[allow(clippy::cmp_owned)]
165    pub fn build_arg_list(
166        sig: &Signature,
167        suffix_arg_name: bool,
168    ) -> syn::Result<proc_macro2::TokenStream> {
169        let mut arg_list = proc_macro2::TokenStream::new();
170
171        for arg in &sig.inputs {
172            match arg {
173                FnArg::Typed(ty) => {
174                    if let Pat::Ident(ident) = ty.pat.deref() {
175                        if suffix_arg_name && ident.ident.to_string() != "fcinfo" {
176                            let ident = Ident::new(&format!("{}_", ident.ident), ident.span());
177                            arg_list.extend(quote! { #ident, });
178                        } else {
179                            arg_list.extend(quote! { #ident, });
180                        }
181                    } else {
182                        return Err(syn::Error::new(
183                            ty.pat.span(),
184                            "Unknown argument pattern in `#[pg_guard]` function",
185                        ));
186                    }
187                }
188                a @ FnArg::Receiver(_) => return Err(syn::Error::new(
189                    a.span(),
190                    "#[pg_guard] doesn't support external functions with 'self' as the argument",
191                )),
192            }
193        }
194
195        Ok(arg_list)
196    }
197
198    pub fn rename_arg_list(sig: &Signature) -> syn::Result<proc_macro2::TokenStream> {
199        let mut arg_list = proc_macro2::TokenStream::new();
200
201        for arg in &sig.inputs {
202            match arg {
203                FnArg::Typed(ty) => {
204                    if let Pat::Ident(ident) = ty.pat.deref() {
205                        // prefix argument name with "arg_""
206                        let name = Ident::new(&format!("arg_{}", ident.ident), ident.ident.span());
207                        arg_list.extend(quote! { #name, });
208                    } else {
209                        return Err(syn::Error::new(
210                            ty.pat.span(),
211                            "Unknown argument pattern in `#[pg_guard]` function",
212                        ));
213                    }
214                }
215                a @ FnArg::Receiver(_) => return Err(syn::Error::new(
216                    a.span(),
217                    "#[pg_guard] doesn't support external functions with 'self' as the argument",
218                )),
219            }
220        }
221
222        Ok(arg_list)
223    }
224
225    pub fn rename_arg_list_with_types(sig: &Signature) -> syn::Result<proc_macro2::TokenStream> {
226        let mut arg_list = proc_macro2::TokenStream::new();
227
228        for arg in &sig.inputs {
229            match arg {
230                FnArg::Typed(ty) => {
231                    if let Pat::Ident(_) = ty.pat.deref() {
232                        // prefix argument name with a "arg_"
233                        let arg =
234                            proc_macro2::TokenStream::from_str(&format!("arg_{}", quote! {#ty}))
235                                .unwrap();
236                        arg_list.extend(quote! { #arg, });
237                    } else {
238                        return Err(syn::Error::new(
239                            ty.pat.span(),
240                            "Unknown argument pattern in `#[pg_guard]` function",
241                        ));
242                    }
243                }
244                a @ FnArg::Receiver(_) => return Err(syn::Error::new(
245                    a.span(),
246                    "#[pg_guard] doesn't support external functions with 'self' as the argument",
247                )),
248            }
249        }
250
251        Ok(arg_list)
252    }
253
254    pub fn get_return_type(sig: &Signature) -> ReturnType {
255        sig.output.clone()
256    }
257}