1extern crate proc_macro;
11
12use proc_macro2::Ident;
13use quote::{quote, quote_spanned};
14use std::ops::Deref;
15use std::str::FromStr;
16use syn::punctuated::Punctuated;
17use syn::spanned::Spanned;
18use syn::{
19 FnArg, ForeignItem, ForeignItemFn, GenericParam, ItemFn, ItemForeignMod, Pat, ReturnType,
20 Signature, Token, Visibility,
21};
22
/// Stateless helper that generates the wrapper code for items annotated with `#[pg_guard]`.
///
/// The type has no fields; it only groups the rewriting routines. It is kept as an empty
/// tuple struct so the existing `PgGuardRewriter()` constructor expression keeps compiling.
#[derive(Debug, Default, Clone, Copy)]
pub struct PgGuardRewriter();
24
25impl PgGuardRewriter {
26 pub fn new() -> Self {
27 PgGuardRewriter()
28 }
29
30 pub fn extern_block(&self, block: ItemForeignMod) -> proc_macro2::TokenStream {
31 let mut stream = proc_macro2::TokenStream::new();
32
33 for item in block.items.into_iter() {
34 stream.extend(self.foreign_item(item, &block.abi));
35 }
36
37 stream
38 }
39
40 pub fn item_fn_without_rewrite(
41 &self,
42 mut func: ItemFn,
43 ) -> syn::Result<proc_macro2::TokenStream> {
44 let input_func_name = func.sig.ident.to_string();
47 let sig = func.sig.clone();
48 let vis = func.vis.clone();
49 let attrs = func.attrs.clone();
50
51 let generics = func.sig.generics.clone();
52
53 if attrs.iter().any(|attr| attr.path.is_ident("no_mangle"))
54 && generics.params.iter().any(|p| match p {
55 GenericParam::Type(_) => true,
56 GenericParam::Lifetime(_) => false,
57 GenericParam::Const(_) => true,
58 })
59 {
60 panic!("#[pg_guard] for function with generic parameters must not be combined with #[no_mangle]");
61 }
62
63 func.sig.abi = None;
66 func.attrs.clear();
67
68 func.vis = Visibility::Inherited;
70
71 func.sig.ident =
72 Ident::new(&format!("{}_inner", func.sig.ident.to_string()), func.sig.ident.span());
73
74 let arg_list = PgGuardRewriter::build_arg_list(&sig, false)?;
75 let func_name = PgGuardRewriter::build_func_name(&func.sig);
76
77 let prolog = if input_func_name == "__pgx_private_shmem_hook"
78 || input_func_name == "__pgx_private_shmem_request_hook"
79 {
80 quote! {}
82 } else if input_func_name == "_PG_init" || input_func_name == "_PG_fini" {
83 quote! {
84 #[allow(non_snake_case)]
85 #[no_mangle]
86 }
87 } else {
88 quote! {}
89 };
90
91 let body = if generics.params.is_empty() {
92 quote! { #func_name(#arg_list) }
93 } else {
94 let ty = generics
95 .params
96 .into_iter()
97 .filter_map(|p| match p {
98 GenericParam::Type(ty) => Some(ty.ident),
99 GenericParam::Const(c) => Some(c.ident),
100 GenericParam::Lifetime(_) => None,
101 })
102 .collect::<Punctuated<_, Token![,]>>();
103 quote! { #func_name::<#ty>(#arg_list) }
104 };
105
106 Ok(quote_spanned! {func.span()=>
107 #prolog
108 #(#attrs)*
109 #vis #sig {
110 #[allow(non_snake_case)]
111 #func
112
113 #[allow(unused_unsafe)]
114 unsafe {
115 pgx::pg_sys::submodules::panic::pgx_extern_c_guard( || #body )
117 }
118 }
119 })
120 }
121
122 pub fn foreign_item(
123 &self,
124 item: ForeignItem,
125 abi: &syn::Abi,
126 ) -> syn::Result<proc_macro2::TokenStream> {
127 match item {
128 ForeignItem::Fn(func) => {
129 if func.sig.variadic.is_some() {
130 return Ok(quote! { #abi { #func } });
131 }
132
133 self.foreign_item_fn(&func, abi)
134 }
135 _ => Ok(quote! { #abi { #item } }),
136 }
137 }
138
139 pub fn foreign_item_fn(
140 &self,
141 func: &ForeignItemFn,
142 abi: &syn::Abi,
143 ) -> syn::Result<proc_macro2::TokenStream> {
144 let func_name = PgGuardRewriter::build_func_name(&func.sig);
145 let arg_list = PgGuardRewriter::rename_arg_list(&func.sig)?;
146 let arg_list_with_types = PgGuardRewriter::rename_arg_list_with_types(&func.sig)?;
147 let return_type = PgGuardRewriter::get_return_type(&func.sig);
148
149 Ok(quote! {
150 #[track_caller]
151 pub unsafe fn #func_name ( #arg_list_with_types ) #return_type {
152 crate::ffi::pg_guard_ffi_boundary(move || {
153 #abi { #func }
154 #func_name(#arg_list)
155 })
156 }
157 })
158 }
159
160 pub fn build_func_name(sig: &Signature) -> Ident {
161 sig.ident.clone()
162 }
163
164 #[allow(clippy::cmp_owned)]
165 pub fn build_arg_list(
166 sig: &Signature,
167 suffix_arg_name: bool,
168 ) -> syn::Result<proc_macro2::TokenStream> {
169 let mut arg_list = proc_macro2::TokenStream::new();
170
171 for arg in &sig.inputs {
172 match arg {
173 FnArg::Typed(ty) => {
174 if let Pat::Ident(ident) = ty.pat.deref() {
175 if suffix_arg_name && ident.ident.to_string() != "fcinfo" {
176 let ident = Ident::new(&format!("{}_", ident.ident), ident.span());
177 arg_list.extend(quote! { #ident, });
178 } else {
179 arg_list.extend(quote! { #ident, });
180 }
181 } else {
182 return Err(syn::Error::new(
183 ty.pat.span(),
184 "Unknown argument pattern in `#[pg_guard]` function",
185 ));
186 }
187 }
188 a @ FnArg::Receiver(_) => return Err(syn::Error::new(
189 a.span(),
190 "#[pg_guard] doesn't support external functions with 'self' as the argument",
191 )),
192 }
193 }
194
195 Ok(arg_list)
196 }
197
198 pub fn rename_arg_list(sig: &Signature) -> syn::Result<proc_macro2::TokenStream> {
199 let mut arg_list = proc_macro2::TokenStream::new();
200
201 for arg in &sig.inputs {
202 match arg {
203 FnArg::Typed(ty) => {
204 if let Pat::Ident(ident) = ty.pat.deref() {
205 let name = Ident::new(&format!("arg_{}", ident.ident), ident.ident.span());
207 arg_list.extend(quote! { #name, });
208 } else {
209 return Err(syn::Error::new(
210 ty.pat.span(),
211 "Unknown argument pattern in `#[pg_guard]` function",
212 ));
213 }
214 }
215 a @ FnArg::Receiver(_) => return Err(syn::Error::new(
216 a.span(),
217 "#[pg_guard] doesn't support external functions with 'self' as the argument",
218 )),
219 }
220 }
221
222 Ok(arg_list)
223 }
224
225 pub fn rename_arg_list_with_types(sig: &Signature) -> syn::Result<proc_macro2::TokenStream> {
226 let mut arg_list = proc_macro2::TokenStream::new();
227
228 for arg in &sig.inputs {
229 match arg {
230 FnArg::Typed(ty) => {
231 if let Pat::Ident(_) = ty.pat.deref() {
232 let arg =
234 proc_macro2::TokenStream::from_str(&format!("arg_{}", quote! {#ty}))
235 .unwrap();
236 arg_list.extend(quote! { #arg, });
237 } else {
238 return Err(syn::Error::new(
239 ty.pat.span(),
240 "Unknown argument pattern in `#[pg_guard]` function",
241 ));
242 }
243 }
244 a @ FnArg::Receiver(_) => return Err(syn::Error::new(
245 a.span(),
246 "#[pg_guard] doesn't support external functions with 'self' as the argument",
247 )),
248 }
249 }
250
251 Ok(arg_list)
252 }
253
254 pub fn get_return_type(sig: &Signature) -> ReturnType {
255 sig.output.clone()
256 }
257}