pgx_utils/
rewriter.rs

1/*
2Portions Copyright 2019-2021 ZomboDB, LLC.
3Portions Copyright 2021-2022 Technology Concepts & Design, Inc. <support@tcdi.com>
4
5All rights reserved.
6
7Use of this source code is governed by the MIT license that can be found in the LICENSE file.
8*/
9
10extern crate proc_macro;
11
12use proc_macro2::Ident;
13use quote::{quote, quote_spanned};
14use std::ops::Deref;
15use std::str::FromStr;
16use syn::punctuated::Punctuated;
17use syn::spanned::Spanned;
18use syn::{
19    FnArg, ForeignItem, ForeignItemFn, GenericParam, ItemFn, ItemForeignMod, Pat, ReturnType,
20    Signature, Token, Visibility,
21};
22
/// Stateless helper that rewrites functions and `extern` blocks into `#[pg_guard]`-style
/// guarded wrappers.
///
/// It carries no data, so it is cheap to create, copy and default-construct.
#[derive(Debug, Default, Clone, Copy)]
pub struct PgGuardRewriter();
24
25impl PgGuardRewriter {
26    pub fn new() -> Self {
27        PgGuardRewriter()
28    }
29
30    pub fn extern_block(&self, block: ItemForeignMod) -> proc_macro2::TokenStream {
31        let mut stream = proc_macro2::TokenStream::new();
32
33        for item in block.items.into_iter() {
34            stream.extend(self.foreign_item(item, &block.abi));
35        }
36
37        stream
38    }
39
40    pub fn item_fn_without_rewrite(
41        &self,
42        mut func: ItemFn,
43    ) -> eyre::Result<proc_macro2::TokenStream> {
44        // remember the original visibility and signature classifications as we want
45        // to use those for the outer function
46        let input_func_name = func.sig.ident.to_string();
47        let sig = func.sig.clone();
48        let vis = func.vis.clone();
49        let attrs = func.attrs.clone();
50
51        let generics = func.sig.generics.clone();
52
53        if attrs.iter().any(|attr| attr.path.is_ident("no_mangle"))
54            && generics.params.iter().any(|p| match p {
55                GenericParam::Type(_) => true,
56                GenericParam::Lifetime(_) => false,
57                GenericParam::Const(_) => true,
58            })
59        {
60            panic!("#[pg_guard] for function with generic parameters must not be combined with #[no_mangle]");
61        }
62
63        // but for the inner function (the one we're wrapping) we don't need any kind of
64        // abi classification
65        func.sig.abi = None;
66        func.attrs.clear();
67
68        // nor do we need a visibility beyond "private"
69        func.vis = Visibility::Inherited;
70
71        func.sig.ident =
72            Ident::new(&format!("{}_inner", func.sig.ident.to_string()), func.sig.ident.span());
73
74        let arg_list = PgGuardRewriter::build_arg_list(&sig, false)?;
75        let func_name = PgGuardRewriter::build_func_name(&func.sig);
76
77        let prolog = if input_func_name == "__pgx_private_shmem_hook"
78            || input_func_name == "__pgx_private_shmem_request_hook"
79        {
80            // we do not want "no_mangle" on these functions
81            quote! {}
82        } else if input_func_name == "_PG_init" || input_func_name == "_PG_fini" {
83            quote! {
84                #[allow(non_snake_case)]
85                #[no_mangle]
86            }
87        } else {
88            quote! {}
89        };
90
91        let body = if generics.params.is_empty() {
92            quote! { #func_name(#arg_list) }
93        } else {
94            let ty = generics
95                .params
96                .into_iter()
97                .filter_map(|p| match p {
98                    GenericParam::Type(ty) => Some(ty.ident),
99                    GenericParam::Const(c) => Some(c.ident),
100                    GenericParam::Lifetime(_) => None,
101                })
102                .collect::<Punctuated<_, Token![,]>>();
103            quote! { #func_name::<#ty>(#arg_list) }
104        };
105
106        Ok(quote_spanned! {func.span()=>
107            #prolog
108            #(#attrs)*
109            #vis #sig {
110                #[allow(non_snake_case)]
111                #func
112
113                #[allow(unused_unsafe)]
114                unsafe {
115                    pg_sys::panic::pgx_extern_c_guard( || #body )
116                }
117            }
118        })
119    }
120
121    pub fn foreign_item(
122        &self,
123        item: ForeignItem,
124        abi: &syn::Abi,
125    ) -> eyre::Result<proc_macro2::TokenStream> {
126        match item {
127            ForeignItem::Fn(func) => {
128                if func.sig.variadic.is_some() {
129                    return Ok(quote! { #abi { #func } });
130                }
131
132                self.foreign_item_fn(&func, abi)
133            }
134            _ => Ok(quote! { #abi { #item } }),
135        }
136    }
137
138    pub fn foreign_item_fn(
139        &self,
140        func: &ForeignItemFn,
141        abi: &syn::Abi,
142    ) -> eyre::Result<proc_macro2::TokenStream> {
143        let func_name = PgGuardRewriter::build_func_name(&func.sig);
144        let arg_list = PgGuardRewriter::rename_arg_list(&func.sig)?;
145        let arg_list_with_types = PgGuardRewriter::rename_arg_list_with_types(&func.sig)?;
146        let return_type = PgGuardRewriter::get_return_type(&func.sig);
147
148        Ok(quote! {
149            #[track_caller]
150            pub unsafe fn #func_name ( #arg_list_with_types ) #return_type {
151                crate::ffi::pg_guard_ffi_boundary(move || {
152                    #abi { #func }
153                    #func_name(#arg_list)
154                })
155            }
156        })
157    }
158
159    pub fn build_func_name(sig: &Signature) -> Ident {
160        sig.ident.clone()
161    }
162
163    #[allow(clippy::cmp_owned)]
164    pub fn build_arg_list(
165        sig: &Signature,
166        suffix_arg_name: bool,
167    ) -> eyre::Result<proc_macro2::TokenStream> {
168        let mut arg_list = proc_macro2::TokenStream::new();
169
170        for arg in &sig.inputs {
171            match arg {
172                FnArg::Typed(ty) => {
173                    if let Pat::Ident(ident) = ty.pat.deref() {
174                        if suffix_arg_name && ident.ident.to_string() != "fcinfo" {
175                            let ident = Ident::new(&format!("{}_", ident.ident), ident.span());
176                            arg_list.extend(quote! { #ident, });
177                        } else {
178                            arg_list.extend(quote! { #ident, });
179                        }
180                    } else {
181                        eyre::bail!(
182                            "Unknown argument pattern in `#[pg_guard]` function: `{:?}`",
183                            ty.pat,
184                        );
185                    }
186                }
187                FnArg::Receiver(_) => panic!(
188                    "#[pg_guard] doesn't support external functions with 'self' as the argument"
189                ),
190            }
191        }
192
193        Ok(arg_list)
194    }
195
196    pub fn rename_arg_list(sig: &Signature) -> eyre::Result<proc_macro2::TokenStream> {
197        let mut arg_list = proc_macro2::TokenStream::new();
198
199        for arg in &sig.inputs {
200            match arg {
201                FnArg::Typed(ty) => {
202                    if let Pat::Ident(ident) = ty.pat.deref() {
203                        // prefix argument name with "arg_""
204                        let name = Ident::new(&format!("arg_{}", ident.ident), ident.ident.span());
205                        arg_list.extend(quote! { #name, });
206                    } else {
207                        eyre::bail!(
208                            "Unknown argument pattern in `#[pg_guard]` function: `{:?}`",
209                            ty.pat,
210                        );
211                    }
212                }
213                FnArg::Receiver(_) => panic!(
214                    "#[pg_guard] doesn't support external functions with 'self' as the argument"
215                ),
216            }
217        }
218
219        Ok(arg_list)
220    }
221
222    pub fn rename_arg_list_with_types(sig: &Signature) -> eyre::Result<proc_macro2::TokenStream> {
223        let mut arg_list = proc_macro2::TokenStream::new();
224
225        for arg in &sig.inputs {
226            match arg {
227                FnArg::Typed(ty) => {
228                    if let Pat::Ident(_) = ty.pat.deref() {
229                        // prefix argument name with a "arg_"
230                        let arg =
231                            proc_macro2::TokenStream::from_str(&format!("arg_{}", quote! {#ty}))
232                                .unwrap();
233                        arg_list.extend(quote! { #arg, });
234                    } else {
235                        eyre::bail!(
236                            "Unknown argument pattern in `#[pg_guard]` function: `{:?}`",
237                            ty.pat,
238                        );
239                    }
240                }
241                FnArg::Receiver(_) => panic!(
242                    "#[pg_guard] doesn't support external functions with 'self' as the argument"
243                ),
244            }
245        }
246
247        Ok(arg_list)
248    }
249
250    pub fn get_return_type(sig: &Signature) -> ReturnType {
251        sig.output.clone()
252    }
253}