1extern crate proc_macro;
11
12use proc_macro2::Ident;
13use quote::{quote, quote_spanned};
14use std::ops::Deref;
15use std::str::FromStr;
16use syn::punctuated::Punctuated;
17use syn::spanned::Spanned;
18use syn::{
19 FnArg, ForeignItem, ForeignItemFn, GenericParam, ItemFn, ItemForeignMod, Pat, ReturnType,
20 Signature, Token, Visibility,
21};
22
/// Rewriter behind `#[pg_guard]`.
///
/// It wraps declarations from `extern` blocks and top-level functions so that calls run inside
/// pgx's guard helpers. The rewriter carries no state; everything it needs is passed as method
/// arguments.
#[derive(Debug, Default, Clone, Copy)]
pub struct PgGuardRewriter();
24
25impl PgGuardRewriter {
26 pub fn new() -> Self {
27 PgGuardRewriter()
28 }
29
30 pub fn extern_block(&self, block: ItemForeignMod) -> proc_macro2::TokenStream {
31 let mut stream = proc_macro2::TokenStream::new();
32
33 for item in block.items.into_iter() {
34 stream.extend(self.foreign_item(item, &block.abi));
35 }
36
37 stream
38 }
39
40 pub fn item_fn_without_rewrite(
41 &self,
42 mut func: ItemFn,
43 ) -> eyre::Result<proc_macro2::TokenStream> {
44 let input_func_name = func.sig.ident.to_string();
47 let sig = func.sig.clone();
48 let vis = func.vis.clone();
49 let attrs = func.attrs.clone();
50
51 let generics = func.sig.generics.clone();
52
53 if attrs.iter().any(|attr| attr.path.is_ident("no_mangle"))
54 && generics.params.iter().any(|p| match p {
55 GenericParam::Type(_) => true,
56 GenericParam::Lifetime(_) => false,
57 GenericParam::Const(_) => true,
58 })
59 {
60 panic!("#[pg_guard] for function with generic parameters must not be combined with #[no_mangle]");
61 }
62
63 func.sig.abi = None;
66 func.attrs.clear();
67
68 func.vis = Visibility::Inherited;
70
71 func.sig.ident =
72 Ident::new(&format!("{}_inner", func.sig.ident.to_string()), func.sig.ident.span());
73
74 let arg_list = PgGuardRewriter::build_arg_list(&sig, false)?;
75 let func_name = PgGuardRewriter::build_func_name(&func.sig);
76
77 let prolog = if input_func_name == "__pgx_private_shmem_hook"
78 || input_func_name == "__pgx_private_shmem_request_hook"
79 {
80 quote! {}
82 } else if input_func_name == "_PG_init" || input_func_name == "_PG_fini" {
83 quote! {
84 #[allow(non_snake_case)]
85 #[no_mangle]
86 }
87 } else {
88 quote! {}
89 };
90
91 let body = if generics.params.is_empty() {
92 quote! { #func_name(#arg_list) }
93 } else {
94 let ty = generics
95 .params
96 .into_iter()
97 .filter_map(|p| match p {
98 GenericParam::Type(ty) => Some(ty.ident),
99 GenericParam::Const(c) => Some(c.ident),
100 GenericParam::Lifetime(_) => None,
101 })
102 .collect::<Punctuated<_, Token![,]>>();
103 quote! { #func_name::<#ty>(#arg_list) }
104 };
105
106 Ok(quote_spanned! {func.span()=>
107 #prolog
108 #(#attrs)*
109 #vis #sig {
110 #[allow(non_snake_case)]
111 #func
112
113 #[allow(unused_unsafe)]
114 unsafe {
115 pg_sys::panic::pgx_extern_c_guard( || #body )
116 }
117 }
118 })
119 }
120
121 pub fn foreign_item(
122 &self,
123 item: ForeignItem,
124 abi: &syn::Abi,
125 ) -> eyre::Result<proc_macro2::TokenStream> {
126 match item {
127 ForeignItem::Fn(func) => {
128 if func.sig.variadic.is_some() {
129 return Ok(quote! { #abi { #func } });
130 }
131
132 self.foreign_item_fn(&func, abi)
133 }
134 _ => Ok(quote! { #abi { #item } }),
135 }
136 }
137
138 pub fn foreign_item_fn(
139 &self,
140 func: &ForeignItemFn,
141 abi: &syn::Abi,
142 ) -> eyre::Result<proc_macro2::TokenStream> {
143 let func_name = PgGuardRewriter::build_func_name(&func.sig);
144 let arg_list = PgGuardRewriter::rename_arg_list(&func.sig)?;
145 let arg_list_with_types = PgGuardRewriter::rename_arg_list_with_types(&func.sig)?;
146 let return_type = PgGuardRewriter::get_return_type(&func.sig);
147
148 Ok(quote! {
149 #[track_caller]
150 pub unsafe fn #func_name ( #arg_list_with_types ) #return_type {
151 crate::ffi::pg_guard_ffi_boundary(move || {
152 #abi { #func }
153 #func_name(#arg_list)
154 })
155 }
156 })
157 }
158
159 pub fn build_func_name(sig: &Signature) -> Ident {
160 sig.ident.clone()
161 }
162
163 #[allow(clippy::cmp_owned)]
164 pub fn build_arg_list(
165 sig: &Signature,
166 suffix_arg_name: bool,
167 ) -> eyre::Result<proc_macro2::TokenStream> {
168 let mut arg_list = proc_macro2::TokenStream::new();
169
170 for arg in &sig.inputs {
171 match arg {
172 FnArg::Typed(ty) => {
173 if let Pat::Ident(ident) = ty.pat.deref() {
174 if suffix_arg_name && ident.ident.to_string() != "fcinfo" {
175 let ident = Ident::new(&format!("{}_", ident.ident), ident.span());
176 arg_list.extend(quote! { #ident, });
177 } else {
178 arg_list.extend(quote! { #ident, });
179 }
180 } else {
181 eyre::bail!(
182 "Unknown argument pattern in `#[pg_guard]` function: `{:?}`",
183 ty.pat,
184 );
185 }
186 }
187 FnArg::Receiver(_) => panic!(
188 "#[pg_guard] doesn't support external functions with 'self' as the argument"
189 ),
190 }
191 }
192
193 Ok(arg_list)
194 }
195
196 pub fn rename_arg_list(sig: &Signature) -> eyre::Result<proc_macro2::TokenStream> {
197 let mut arg_list = proc_macro2::TokenStream::new();
198
199 for arg in &sig.inputs {
200 match arg {
201 FnArg::Typed(ty) => {
202 if let Pat::Ident(ident) = ty.pat.deref() {
203 let name = Ident::new(&format!("arg_{}", ident.ident), ident.ident.span());
205 arg_list.extend(quote! { #name, });
206 } else {
207 eyre::bail!(
208 "Unknown argument pattern in `#[pg_guard]` function: `{:?}`",
209 ty.pat,
210 );
211 }
212 }
213 FnArg::Receiver(_) => panic!(
214 "#[pg_guard] doesn't support external functions with 'self' as the argument"
215 ),
216 }
217 }
218
219 Ok(arg_list)
220 }
221
222 pub fn rename_arg_list_with_types(sig: &Signature) -> eyre::Result<proc_macro2::TokenStream> {
223 let mut arg_list = proc_macro2::TokenStream::new();
224
225 for arg in &sig.inputs {
226 match arg {
227 FnArg::Typed(ty) => {
228 if let Pat::Ident(_) = ty.pat.deref() {
229 let arg =
231 proc_macro2::TokenStream::from_str(&format!("arg_{}", quote! {#ty}))
232 .unwrap();
233 arg_list.extend(quote! { #arg, });
234 } else {
235 eyre::bail!(
236 "Unknown argument pattern in `#[pg_guard]` function: `{:?}`",
237 ty.pat,
238 );
239 }
240 }
241 FnArg::Receiver(_) => panic!(
242 "#[pg_guard] doesn't support external functions with 'self' as the argument"
243 ),
244 }
245 }
246
247 Ok(arg_list)
248 }
249
250 pub fn get_return_type(sig: &Signature) -> ReturnType {
251 sig.output.clone()
252 }
253}