extern crate proc_macro;
use pgx_utils::{categorize_return_type, CategorizedType};
use proc_macro2::Ident;
use quote::{quote, quote_spanned};
use std::ops::Deref;
use std::str::FromStr;
use syn::export::{Span, TokenStream2};
use syn::spanned::Spanned;
use syn::{
FnArg, ForeignItem, ForeignItemFn, Generics, ItemFn, ItemForeignMod, Pat, ReturnType,
Signature, Type, Visibility,
};
/// Field-less type whose methods rewrite functions and `extern` blocks into
/// token streams that route calls through guard/wrapper code.
///
/// It carries no state; the methods only use `self` to reach sibling methods.
pub struct PgGuardRewriter();
impl PgGuardRewriter {
pub fn new() -> Self {
PgGuardRewriter()
}
/// Rewrites every item of a foreign (`extern`) block with
/// [`foreign_item`](Self::foreign_item) and concatenates the results in
/// declaration order.
pub fn extern_block(&self, block: ItemForeignMod) -> proc_macro2::TokenStream {
    block
        .items
        .into_iter()
        .map(|item| self.foreign_item(item))
        .collect()
}
/// Entry point for rewriting a single function item.
///
/// * `rewrite_args` - when `true`, a separate wrapper that fetches its
///   arguments from `fcinfo` is generated; when `false`, the function keeps
///   its signature and only its body is wrapped in the guard.
/// * `is_raw` / `no_guard` - forwarded to the argument-rewriting path only.
///
/// Returns the generated tokens plus a flag. The flag is `false` only when the
/// rewriting path re-emits the function as a fresh `#[pg_extern]` item (tuple
/// return types).
// NOTE(review): presumably the flag tells the caller whether the output is
// final or needs another macro pass -- confirm against the call site.
pub fn item_fn(
    &self,
    func: ItemFn,
    rewrite_args: bool,
    is_raw: bool,
    no_guard: bool,
) -> (proc_macro2::TokenStream, bool) {
    if !rewrite_args {
        return (self.item_fn_without_rewrite(func), true);
    }
    self.item_fn_with_rewrite(func, is_raw, no_guard)
}
fn item_fn_with_rewrite(
&self,
mut func: ItemFn,
is_raw: bool,
no_guard: bool,
) -> (proc_macro2::TokenStream, bool) {
let vis = func.vis.clone();
func.sig.abi = None;
let arg_list = PgGuardRewriter::build_arg_list(&func.sig, true);
let func_name = &func.sig.ident;
let func_span = func.span();
let rewritten_args = self.rewrite_args(func.clone(), is_raw);
let rewritten_return_type = self.rewrite_return_type(func.clone());
let generics = &func.sig.generics;
let func_name_wrapper = Ident::new(
&format!("{}_wrapper", &func.sig.ident.to_string()),
func_span,
);
let returns_void = rewritten_return_type
.to_string()
.contains("pg_return_void()");
let result_var_name = if returns_void {
Ident::new("_", Span::call_site())
} else {
Ident::new("result", Span::call_site())
};
let func_call = if no_guard {
quote! {
let result = {
#rewritten_args
#func_name(#arg_list)
};
}
} else {
quote! {
let #result_var_name = pg_sys::guard::guard( || {
#rewritten_args
#func_name(#arg_list)
} );
}
};
let prolog = quote! {
#[inline(never)]
#func
#[no_mangle]
#[allow(unused_variables)]
};
match categorize_return_type(&func) {
CategorizedType::Default => (
PgGuardRewriter::impl_standard_udf(
func_span,
prolog,
vis,
func_name_wrapper,
generics,
func_call,
rewritten_return_type,
),
true,
),
CategorizedType::Tuple(_types) => (PgGuardRewriter::impl_tuple_udf(func), false),
CategorizedType::Iterator(types) if types.len() == 1 => (
PgGuardRewriter::impl_setof_srf(
types,
func_span,
prolog,
vis,
func_name_wrapper,
generics,
func_call,
false,
),
true,
),
CategorizedType::OptionalIterator(types) if types.len() == 1 => (
PgGuardRewriter::impl_setof_srf(
types,
func_span,
prolog,
vis,
func_name_wrapper,
generics,
func_call,
true,
),
true,
),
CategorizedType::Iterator(types) => (
PgGuardRewriter::impl_table_srf(
types,
func_span,
prolog,
vis,
func_name_wrapper,
generics,
func_call,
false,
),
true,
),
CategorizedType::OptionalIterator(types) => (
PgGuardRewriter::impl_table_srf(
types,
func_span,
prolog,
vis,
func_name_wrapper,
generics,
func_call,
true,
),
true,
),
}
}
/// Emits the wrapper for a function with an ordinary (non-tuple,
/// non-iterator) return type.
///
/// The output is `prolog` followed by an
/// `unsafe fn <func_name_wrapper>(fcinfo: pg_sys::FunctionCallInfo) -> pg_sys::Datum`
/// whose body is `func_call` (which performs the call) followed by
/// `rewritten_return_type` (the expression that becomes the returned datum).
/// The wrapper itself is annotated with `#[pg_guard]`, and all generated
/// tokens are spanned to `func_span`.
fn impl_standard_udf(
func_span: Span,
prolog: proc_macro2::TokenStream,
vis: Visibility,
func_name_wrapper: Ident,
generics: &Generics,
func_call: proc_macro2::TokenStream,
rewritten_return_type: proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
quote_spanned! {func_span=>
#prolog
#[inline(never)]
#[allow(clippy::missing_safety_doc)]
#[allow(clippy::redundant_closure)]
#[pg_guard]
#vis unsafe fn #func_name_wrapper #generics(fcinfo: pg_sys::FunctionCallInfo) -> pg_sys::Datum {
#func_call
#rewritten_return_type
}
}
}
fn impl_tuple_udf(mut func: ItemFn) -> proc_macro2::TokenStream {
let func_span = func.span();
let return_type = func.sig.output;
let return_type = format!("{}", quote! {#return_type});
let return_type = TokenStream2::from_str(return_type.trim_start_matches("->")).unwrap();
let return_type = quote! {impl std::iter::Iterator<Item = #return_type>};
func.sig.output = ReturnType::Default;
let sig = func.sig;
let body = func.block;
quote_spanned! {func_span=>
#[pg_extern]
#sig -> #return_type {
Some(#body).into_iter()
}
}
}
/// Emits the wrapper for a set-returning function whose iterator yields a
/// single value per row (`types` holds exactly one element type; only the
/// first entry is used).
///
/// `func_call` is the token stream that calls the user's function and binds
/// its result to `result`. When `optional` is `true` the user's function
/// returns an `Option` of the iterator, and a `None` ends the result set
/// immediately with a NULL datum.
///
/// The generated wrapper, on the first call, runs the user's function inside
/// `funcctx.multi_call_memory_ctx` and stores a raw pointer to the resulting
/// iterator in an `IteratorHolder` hung off `funcctx.user_fctx`. On every call
/// it then advances that iterator by one element.
fn impl_setof_srf(
types: Vec<String>,
func_span: Span,
prolog: proc_macro2::TokenStream,
vis: Visibility,
func_name_wrapper: Ident,
generics: &Generics,
func_call: proc_macro2::TokenStream,
optional: bool,
) -> proc_macro2::TokenStream {
// Panics if `types` is empty or its first entry does not lex as Rust tokens.
let generic_type = proc_macro2::TokenStream::from_str(types.first().unwrap()).unwrap();
// Produces the `let result = ...;` statement that yields the iterator itself.
let result_handler = if optional {
quote! {
let result = match pgx::PgMemoryContexts::For(funcctx.multi_call_memory_ctx).switch_to(|_| { #func_call result }) {
Some(result) => result,
None => {
pgx::srf_return_done(fcinfo, &mut funcctx);
return pgx::pg_return_null(fcinfo)
}
};
}
} else {
quote! {
let result = pgx::PgMemoryContexts::For(funcctx.multi_call_memory_ctx).switch_to(|_| { #func_call result });
}
};
quote_spanned! {func_span=>
#prolog
#[pg_guard]
#vis fn #func_name_wrapper #generics(fcinfo: pg_sys::FunctionCallInfo) -> pg_sys::Datum {
struct IteratorHolder<T> {
iter: *mut dyn Iterator<Item=T>,
}
let mut funcctx: pgx::PgBox<pg_sys::FuncCallContext>;
let mut iterator_holder: pgx::PgBox<IteratorHolder<#generic_type>>;
// NOTE(review): `srf_is_first_call` and `void_mut_ptr` are emitted unqualified,
// so they must be in scope at the expansion site -- confirm.
if srf_is_first_call(fcinfo) {
funcctx = pgx::srf_first_call_init(fcinfo);
funcctx.user_fctx = pgx::PgMemoryContexts::For(funcctx.multi_call_memory_ctx).palloc_struct::<IteratorHolder<#generic_type>>() as void_mut_ptr;
iterator_holder = pgx::PgBox::from_pg(funcctx.user_fctx as *mut IteratorHolder<#generic_type>);
#result_handler
// presumably hands ownership of the iterator to the memory context so it is
// dropped when that context is deleted -- verify against pgx.
iterator_holder.iter = pgx::PgMemoryContexts::For(funcctx.multi_call_memory_ctx).leak_and_drop_on_delete(result);
}
funcctx = pgx::srf_per_call_setup(fcinfo);
iterator_holder = pgx::PgBox::from_pg(funcctx.user_fctx as *mut IteratorHolder<#generic_type>);
// The Box is rebuilt only to call `next()`; both arms below leak it again so
// that dropping `iter` here does not free the iterator.
let mut iter = unsafe { Box::from_raw(iterator_holder.iter) };
match iter.next() {
Some(result) => {
Box::leak(iter);
pgx::srf_return_next(fcinfo, &mut funcctx);
match result.into_datum() {
Some(datum) => datum,
None => pgx::pg_return_null(fcinfo),
}
},
None => {
Box::leak(iter);
pgx::srf_return_done(fcinfo, &mut funcctx);
pgx::pg_return_null(fcinfo)
},
}
}
}
}
/// Emits the wrapper for a set-returning function whose iterator yields a
/// tuple per row; `types` lists the element types of that tuple, in order.
///
/// `func_call` is the token stream that calls the user's function and binds
/// its result to `result`. When `optional` is `true` the user's function
/// returns an `Option` of the iterator, and a `None` ends the result set
/// immediately with a NULL datum.
///
/// The generated wrapper, on the first call, resolves the tuple descriptor of
/// the call's result type (raising an error unless it is a composite type),
/// runs the user's function inside `funcctx.multi_call_memory_ctx`, and stores
/// a raw pointer to the iterator in an `IteratorHolder` hung off
/// `funcctx.user_fctx`. On every call it advances the iterator and turns the
/// yielded tuple into a heap tuple datum.
fn impl_table_srf(
types: Vec<String>,
func_span: Span,
prolog: proc_macro2::TokenStream,
vis: Visibility,
func_name_wrapper: Ident,
generics: &Generics,
func_call: proc_macro2::TokenStream,
optional: bool,
) -> proc_macro2::TokenStream {
let numtypes = types.len();
// One index per tuple field; reused for `result.N`, `datums[N]` and `nulls[N]`.
let i = (0..numtypes).map(syn::Index::from);
// Converts each tuple field to a datum (or marks it null) and forms the heap tuple.
let create_heap_tuple = quote! {
let mut datums: [usize; #numtypes] = [0; #numtypes];
let mut nulls: [bool; #numtypes] = [false; #numtypes];
#(
let datum = result.#i.into_datum();
match datum {
Some(datum) => { datums[#i] = datum as usize; },
None => { nulls[#i] = true; }
}
)*
let heap_tuple = unsafe { pgx::pg_sys::heap_form_tuple(funcctx.tuple_desc, datums.as_mut_ptr(), nulls.as_mut_ptr()) };
};
// The iterator's item type: the element types joined into a tuple type.
let composite_type = format!("({})", types.join(","));
let generic_type = proc_macro2::TokenStream::from_str(&composite_type).unwrap();
// Produces the `let result = ...;` statement that yields the iterator itself.
let result_handler = if optional {
quote! {
let result = match pgx::PgMemoryContexts::For(funcctx.multi_call_memory_ctx).switch_to(|_| { #func_call result }) {
Some(result) => result,
None => {
pgx::srf_return_done(fcinfo, &mut funcctx);
return pgx::pg_return_null(fcinfo)
}
};
}
} else {
quote! {
let result = pgx::PgMemoryContexts::For(funcctx.multi_call_memory_ctx).switch_to(|_| { #func_call result });
}
};
quote_spanned! {func_span=>
#prolog
#[pg_guard]
#vis fn #func_name_wrapper #generics(fcinfo: pg_sys::FunctionCallInfo) -> pg_sys::Datum {
struct IteratorHolder<T> {
iter: *mut dyn Iterator<Item=T>,
}
let mut funcctx: pgx::PgBox<pg_sys::FuncCallContext>;
let mut iterator_holder: pgx::PgBox<IteratorHolder<#generic_type>>;
// NOTE(review): `srf_is_first_call` and `void_mut_ptr` are emitted unqualified,
// so they must be in scope at the expansion site -- confirm.
if srf_is_first_call(fcinfo) {
funcctx = pgx::srf_first_call_init(fcinfo);
funcctx.user_fctx = pgx::PgMemoryContexts::For(funcctx.multi_call_memory_ctx).palloc_struct::<IteratorHolder<#generic_type>>() as void_mut_ptr;
funcctx.tuple_desc = pgx::PgMemoryContexts::For(funcctx.multi_call_memory_ctx).switch_to(|_| {
let mut tupdesc: *mut pgx::pg_sys::TupleDescData = std::ptr::null_mut();
unsafe {
if pgx::pg_sys::get_call_result_type(fcinfo, std::ptr::null_mut(), &mut tupdesc) != pgx::pg_sys::TypeFuncClass_TYPEFUNC_COMPOSITE {
pgx::error!("return type must be a row type");
}
pgx::pg_sys::BlessTupleDesc(tupdesc)
}
});
iterator_holder = pgx::PgBox::from_pg(funcctx.user_fctx as *mut IteratorHolder<#generic_type>);
#result_handler
// presumably hands ownership of the iterator to the memory context so it is
// dropped when that context is deleted -- verify against pgx.
iterator_holder.iter = pgx::PgMemoryContexts::For(funcctx.multi_call_memory_ctx).leak_and_drop_on_delete(result);
}
funcctx = pgx::srf_per_call_setup(fcinfo);
iterator_holder = pgx::PgBox::from_pg(funcctx.user_fctx as *mut IteratorHolder<#generic_type>);
// The Box is rebuilt only to call `next()`; both arms below leak it again so
// that dropping `iter` here does not free the iterator.
let mut iter = unsafe { Box::from_raw(iterator_holder.iter) };
match iter.next() {
Some(result) => {
Box::leak(iter);
#create_heap_tuple
let datum = pgx::heap_tuple_get_datum(heap_tuple);
pgx::srf_return_next(fcinfo, &mut funcctx);
datum as pgx::pg_sys::Datum
},
None => {
Box::leak(iter);
pgx::srf_return_done(fcinfo, &mut funcctx);
pgx::pg_return_null(fcinfo)
},
}
}
}
}
/// Wraps a function in the guard without touching its signature.
///
/// The emitted item keeps the original visibility and signature and is marked
/// `#[no_mangle]`. Its body contains the original function, renamed to
/// `<name>_inner` with its ABI removed and visibility reset to private, and a
/// single `pg_sys::guard::guard` call that forwards the arguments to it.
fn item_fn_without_rewrite(&self, mut func: ItemFn) -> proc_macro2::TokenStream {
    // Snapshot what the outer function needs before `func` is turned into the
    // nested inner function.
    let outer_sig = func.sig.clone();
    let outer_vis = std::mem::replace(&mut func.vis, Visibility::Inherited);

    func.sig.abi = None;
    let inner_name = format!("{}_inner", func.sig.ident);
    func.sig.ident = Ident::new(&inner_name, func.sig.ident.span());

    let forwarded_args = PgGuardRewriter::build_arg_list(&outer_sig, false);
    let inner_ident = PgGuardRewriter::build_func_name(&func.sig);

    quote_spanned! {func.span()=>
        #[no_mangle]
        #outer_vis #outer_sig {
            #func
            pg_sys::guard::guard( || #inner_ident(#forwarded_args) )
        }
    }
}
/// Rewrites one item of an `extern` block.
///
/// Non-variadic functions get a guarded wrapper from
/// [`foreign_item_fn`](Self::foreign_item_fn). Variadic functions and every
/// other kind of foreign item are passed through unchanged inside their own
/// `extern "C"` block.
pub fn foreign_item(&self, item: ForeignItem) -> proc_macro2::TokenStream {
    match item {
        ForeignItem::Fn(func) if func.sig.variadic.is_none() => self.foreign_item_fn(func),
        passthrough => quote! { extern "C" { #passthrough } },
    }
}
/// Emits a `pub unsafe fn` with the same name and return type as the foreign
/// function `func`, with every parameter renamed to `arg_<name>`.
///
/// The generated function declares the real symbol in a nested `extern "C"`
/// block and calls it between a `sigsetjmp` and the restoration of the saved
/// `pg_sys::PG_exception_stack` / `pg_sys::error_context_stack` values. If
/// `sigsetjmp` returns non-zero, both saved values are restored and the
/// function panics with a `pg_sys::JumpContext` payload.
///
/// Panics at expansion time if the signature has a `self` receiver.
pub fn foreign_item_fn(&self, func: ForeignItemFn) -> proc_macro2::TokenStream {
let func_name = PgGuardRewriter::build_func_name(&func.sig);
let arg_list = PgGuardRewriter::rename_arg_list(&func.sig);
let arg_list_with_types = PgGuardRewriter::rename_arg_list_with_types(&func.sig);
let return_type = PgGuardRewriter::get_return_type(&func.sig);
quote! {
#[inline(never)]
#[allow(clippy::missing_safety_doc)]
#[allow(clippy::redundant_closure)]
pub unsafe fn #func_name ( #arg_list_with_types ) #return_type {
// The nested declaration shadows the wrapper's own name inside this body,
// so the call below reaches the foreign symbol rather than recursing.
extern "C" {
pub fn #func_name( #arg_list_with_types ) #return_type ;
}
let prev_exception_stack = pg_sys::PG_exception_stack;
let prev_error_context_stack = pg_sys::error_context_stack;
let mut jmp_buff = std::mem::MaybeUninit::uninit();
let jump_value = pg_sys::sigsetjmp(jmp_buff.as_mut_ptr(), 0);
let result = if jump_value == 0 {
// Direct return from sigsetjmp: install our buffer and make the call.
pg_sys::PG_exception_stack = jmp_buff.as_mut_ptr();
#func_name(#arg_list)
} else {
// Non-zero return: control arrived here through the jump buffer.
pg_sys::PG_exception_stack = prev_exception_stack;
pg_sys::error_context_stack = prev_error_context_stack;
panic!(pg_sys::JumpContext { });
};
pg_sys::PG_exception_stack = prev_exception_stack;
pg_sys::error_context_stack = prev_error_context_stack;
result
}
}
}
pub fn build_func_name(sig: &Signature) -> Ident {
sig.ident.clone()
}
/// Builds the comma-terminated list of argument names used to call a function
/// with signature `sig`.
///
/// When `suffix_arg_name` is `true`, every name except `fcinfo` gets a
/// trailing `_`; otherwise names are emitted as declared. Arguments whose
/// pattern is not a plain identifier are skipped.
///
/// Only the bare identifier is emitted, never the binding modifiers of the
/// pattern, so a parameter declared as `mut x: T` yields `x,` and not the
/// invalid call argument `mut x,`.
///
/// # Panics
///
/// Panics if the signature has a `self` receiver.
pub fn build_arg_list(sig: &Signature, suffix_arg_name: bool) -> proc_macro2::TokenStream {
    let mut arg_list = proc_macro2::TokenStream::new();
    for arg in &sig.inputs {
        match arg {
            FnArg::Typed(ty) => {
                if let Pat::Ident(pat) = ty.pat.deref() {
                    if suffix_arg_name && pat.ident != "fcinfo" {
                        let suffixed = Ident::new(&format!("{}_", pat.ident), pat.span());
                        arg_list.extend(quote! { #suffixed, });
                    } else {
                        let name = &pat.ident;
                        arg_list.extend(quote! { #name, });
                    }
                }
            }
            FnArg::Receiver(_) => panic!(
                "#[pg_guard] doesn't support external functions with 'self' as the argument"
            ),
        }
    }
    arg_list
}
/// Builds the comma-terminated list of argument names for `sig`, each
/// prefixed with `arg_`. Arguments whose pattern is not a plain identifier
/// contribute nothing.
///
/// # Panics
///
/// Panics if the signature has a `self` receiver.
pub fn rename_arg_list(sig: &Signature) -> proc_macro2::TokenStream {
    sig.inputs
        .iter()
        .map(|arg| {
            let typed = match arg {
                FnArg::Typed(typed) => typed,
                FnArg::Receiver(_) => panic!(
                    "#[pg_guard] doesn't support external functions with 'self' as the argument"
                ),
            };
            match typed.pat.deref() {
                Pat::Ident(pat) => {
                    let renamed = Ident::new(&format!("arg_{}", pat.ident), pat.ident.span());
                    quote! { #renamed, }
                }
                _ => proc_macro2::TokenStream::new(),
            }
        })
        .collect()
}
pub fn rename_arg_list_with_types(sig: &Signature) -> proc_macro2::TokenStream {
let mut arg_list = proc_macro2::TokenStream::new();
for arg in &sig.inputs {
match arg {
FnArg::Typed(ty) => {
if let Pat::Ident(_) = ty.pat.deref() {
let arg =
proc_macro2::TokenStream::from_str(&format!("arg_{}", quote! {#ty}))
.unwrap();
arg_list.extend(quote! { #arg, });
}
}
FnArg::Receiver(_) => panic!(
"#[pg_guard] doesn't support external functions with 'self' as the argument"
),
}
}
arg_list
}
pub fn get_return_type(sig: &Signature) -> ReturnType {
sig.output.clone()
}
/// Produces the `let <name>_ = ...;` statements that fetch each argument of
/// `func` from `fcinfo`; see `FunctionSignatureRewriter::args`.
pub fn rewrite_args(&self, func: ItemFn, is_raw: bool) -> proc_macro2::TokenStream {
    FunctionSignatureRewriter::new(func).args(is_raw)
}
/// Produces the expression that turns the `result` of calling `func` into the
/// wrapper's return value; see `FunctionSignatureRewriter::return_type`.
pub fn rewrite_return_type(&self, func: ItemFn) -> proc_macro2::TokenStream {
    FunctionSignatureRewriter::new(func).return_type()
}
}
/// Derives wrapper code (argument fetching and return-value conversion) from
/// the signature of a single function.
struct FunctionSignatureRewriter {
// The function whose signature is inspected; only `sig` is read by the methods.
func: ItemFn,
}
impl FunctionSignatureRewriter {
fn new(func: ItemFn) -> Self {
FunctionSignatureRewriter { func }
}
/// Returns the expression the wrapper evaluates last, chosen from the declared
/// return type:
///
/// * no return type, or `()` - `pgx::pg_return_void()`,
/// * `Option...` - the inner value converted with `into_datum()`, or a NULL
///   return for `None`,
/// * `pg_sys::Datum` - `result` unchanged,
/// * anything else - `result.into_datum()`, panicking if that yields `None`.
///
/// Every non-void form refers to a local named `result`.
fn return_type(&self) -> proc_macro2::TokenStream {
    let declared = match &self.func.sig.output {
        ReturnType::Type(_, declared) => declared,
        ReturnType::Default => {
            return quote! {
                pgx::pg_return_void()
            }
        }
    };

    if type_matches(declared, "Option") {
        quote! {
            match result {
                Some(result) => {
                    result.into_datum().unwrap_or_else(|| panic!("returned Option<T> was NULL"))
                },
                None => pgx::pg_return_null(fcinfo)
            }
        }
    } else if type_matches(declared, "pg_sys :: Datum") {
        quote! {
            result
        }
    } else if type_matches(declared, "()") {
        quote! {
            pgx::pg_return_void()
        }
    } else {
        quote! {
            result.into_datum().unwrap_or_else(|| panic!("returned Datum was NULL"))
        }
    }
}
/// Returns one `let <name>_ = ...;` statement per argument, each reading the
/// argument at its position from `fcinfo`:
///
/// * `Option...` arguments keep the `Option` returned by `pgx::pg_getarg`,
/// * a `pg_sys::FunctionCallInfo` argument is bound to `fcinfo` itself,
/// * with `is_raw`, the raw datum is cast to the declared type,
/// * otherwise the value is unwrapped, panicking with `"<arg> is null"`.
///
/// A function whose only argument is `pg_sys::FunctionCallInfo` and which
/// returns `pg_sys::Datum` gets no statements at all.
///
/// # Panics
///
/// Panics on a `self` receiver, on a non-identifier argument pattern, and
/// when any argument follows a `pg_sys::FunctionCallInfo` argument.
fn args(&self, is_raw: bool) -> proc_macro2::TokenStream {
    let inputs = &self.func.sig.inputs;

    if inputs.len() == 1 && self.return_type_is_datum() {
        if let Some(FnArg::Typed(only)) = inputs.first() {
            if type_matches(&only.ty, "pg_sys :: FunctionCallInfo") {
                return proc_macro2::TokenStream::new();
            }
        }
    }

    let mut bindings = proc_macro2::TokenStream::new();
    let mut seen_fcinfo = false;
    for (index, input) in inputs.iter().enumerate() {
        let typed = match input {
            FnArg::Typed(typed) => typed,
            FnArg::Receiver(_) => panic!("Functions that take self are not supported"),
        };
        let ident = match typed.pat.deref() {
            Pat::Ident(ident) => ident,
            _ => panic!("Unrecognized function arg type"),
        };

        let name = Ident::new(&format!("{}_", ident.ident), ident.span());
        let type_ = &typed.ty;
        let is_option = type_matches(type_, "Option");

        if seen_fcinfo {
            panic!("When using `pg_sys::FunctionCallInfo` as an argument it must be the last argument")
        }

        let binding = if is_option {
            let option_type = extract_option_type(type_);
            quote_spanned! {ident.span()=>
                let #name = pgx::pg_getarg::<#option_type>(fcinfo, #index);
            }
        } else if type_matches(type_, "pg_sys :: FunctionCallInfo") {
            seen_fcinfo = true;
            quote_spanned! {ident.span()=>
                let #name = fcinfo;
            }
        } else if is_raw {
            quote_spanned! {ident.span()=>
                let #name = pgx::pg_getarg_datum_raw(fcinfo, #index) as #type_;
            }
        } else {
            quote_spanned! {ident.span()=>
                let #name = pgx::pg_getarg::<#type_>(fcinfo, #index).unwrap_or_else(|| panic!("{} is null", stringify!{#ident}));
            }
        };
        bindings.extend(binding);
    }
    bindings
}
/// Reports whether the function declares a return type that
/// [`type_matches`] recognises as `pg_sys::Datum`.
fn return_type_is_datum(&self) -> bool {
    if let ReturnType::Type(_, ty) = &self.func.sig.output {
        type_matches(ty, "pg_sys :: Datum")
    } else {
        false
    }
}
}
/// Reports whether the printed form of `ty` begins with `pattern`.
///
/// Two refinements over a raw prefix test:
///
/// * Whitespace is ignored on both sides. The spacing of a printed token
///   stream is not guaranteed, so `"pg_sys :: Datum"` must match whether the
///   type prints as `pg_sys :: Datum` or `pg_sys::Datum`, and `"()"` whether
///   it prints as `()` or `( )`.
/// * When `pattern` ends in an identifier character, the match must end at an
///   identifier boundary. `"Option"` matches `Option<T>` but not `Optional<T>`.
fn type_matches(ty: &Type, pattern: &str) -> bool {
    fn without_whitespace(text: &str) -> String {
        text.chars().filter(|c| !c.is_whitespace()).collect()
    }
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    let type_string = without_whitespace(&quote!(#ty).to_string());
    let pattern = without_whitespace(pattern);
    if !type_string.starts_with(pattern.as_str()) {
        return false;
    }

    // Reject a match that stops in the middle of a longer identifier.
    let pattern_ends_in_ident = pattern.chars().last().map_or(false, is_ident_char);
    let continues_ident = type_string[pattern.len()..]
        .chars()
        .next()
        .map_or(false, is_ident_char);
    !(pattern_ends_in_ident && continues_ident)
}
fn extract_option_type(ty: &Type) -> proc_macro2::TokenStream {
match ty {
Type::Path(path) => {
let mut stream = proc_macro2::TokenStream::new();
for segment in &path.path.segments {
let arguments = &segment.arguments;
stream.extend(quote! { #arguments });
}
let string = stream.to_string();
let string = string.trim().trim_start_matches('<');
let string = string.trim().trim_end_matches('>');
proc_macro2::TokenStream::from_str(string.trim()).unwrap()
}
_ => panic!("No type found inside Option"),
}
}