use pgrx_pg_sys::PgTryBuilder;
use pgrx_pg_sys::errcodes::PgSqlErrorCode;
use pgrx_pg_sys::ffi::pg_guard_ffi_boundary;
use std::panic::AssertUnwindSafe;
use crate::memcx;
use crate::pg_catalog::pg_proc::{PgProc, ProArgMode, ProKind};
use crate::seal::Sealed;
use crate::{
Array, FromDatum, IntoDatum, direct_function_call, is_a, list::List, pg_sys, pg_sys::AsPgCStr,
};
/// An argument that can be passed to a dynamically-called SQL function via [`fn_call`] /
/// [`fn_call_with_collation`].
///
/// The `Sealed` supertrait restricts which types may implement this trait (presumably only types
/// inside this crate, such as [`Arg`]).
///
/// # Safety
///
/// NOTE(review): the contract is not written down in this file. From how the trait is used, the
/// Oid returned by `type_oid` is used to resolve the target function, and the Datum returned by
/// `as_datum` is handed to that function without further checks. An implementation must therefore
/// produce a Datum whose representation matches `type_oid()`.
pub unsafe trait FnCallArg: Sealed {
    /// Convert this argument into the Datum passed at zero-based position `argnum` of `pg_proc`.
    /// `Ok(None)` stands for SQL NULL.
    fn as_datum(&self, pg_proc: &PgProc, argnum: usize) -> Result<Option<pg_sys::Datum>>;
    /// The Postgres type Oid of this argument, used when looking the function up by signature.
    fn type_oid(&self) -> pg_sys::Oid;
}
/// A single argument for a dynamic function call.
///
/// `T` supplies the argument's type Oid for function lookup, even for the `Null` and `Default`
/// variants.
pub enum Arg<T> {
    /// Pass SQL NULL at this position.
    Null,
    /// Use the function's declared DEFAULT expression for this position.
    Default,
    /// Pass this value, converted with `IntoDatum`.
    Value(T),
}
impl<T> Sealed for Arg<T> {}

// SAFETY: `type_oid` reports `T`'s own `IntoDatum::type_oid()`, and `as_datum` only produces
// Datums via `T::into_datum` (or NULL / the function's own DEFAULT), so the two agree.
unsafe impl<T: IntoDatum + Clone> FnCallArg for Arg<T> {
    /// Produce the Datum for this argument.
    ///
    /// `Value` converts a clone of the wrapped value, so `self` stays usable. `Null` yields SQL
    /// NULL. `Default` evaluates the DEFAULT expression declared for position `argnum`.
    fn as_datum(&self, pg_proc: &PgProc, argnum: usize) -> Result<Option<pg_sys::Datum>> {
        match self {
            Arg::Value(value) => Ok(T::clone(value).into_datum()),
            Arg::Default => create_default_value(pg_proc, argnum),
            Arg::Null => Ok(None),
        }
    }

    /// The type Oid of `T`, regardless of which variant this is.
    #[inline]
    fn type_oid(&self) -> pg_sys::Oid {
        <T as IntoDatum>::type_oid()
    }
}
/// Errors that can occur while resolving or dynamically calling a SQL function.
#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)]
pub enum FnCallError {
    /// The function name could not be parsed as a (possibly schema-qualified) SQL identifier.
    #[error("Invalid identifier: `{0}`")]
    InvalidIdentifier(String),
    /// No function matching the name (and argument types) could be found.
    #[error("The specified function does not exist")]
    UndefinedFunction,
    /// Postgres reported that the name resolves to more than one candidate function.
    #[error(
        "The specified function exists, but has overloaded versions which are ambiguous given the argument types provided"
    )]
    AmbiguousFunction,
    /// The resolved object is not a plain function (its `prokind` is not `Function`).
    #[error("Can only dynamically call plain functions")]
    UnsupportedFunctionType,
    /// At least one of the function's argument modes is something other than `IN`.
    #[error("Functions with OUT/IN_OUT/TABLE arguments are not supported")]
    UnsupportedArgumentModes,
    /// The return type or an argument type is `internal`.
    #[error("Functions with argument or return types of `internal` are not supported")]
    InternalTypeNotSupported,
    /// The requested Rust return type cannot represent the function's return type.
    /// Fields: (requested type Oid, actual return type Oid).
    #[error(
        "The requested return type OID `{0:?}` is not compatible with the actual return type OID `{1:?}`"
    )]
    IncompatibleReturnType(pg_sys::Oid, pg_sys::Oid),
    /// More arguments were supplied than can be passed (for example, more than fit in an `i16`
    /// argument count).
    #[error("Function call has more arguments than are supported")]
    TooManyArguments,
    /// NOTE(review): not constructed anywhere in this section; kept for API compatibility.
    #[error("Did not provide enough non-default arguments")]
    NotEnoughArguments,
    /// A DEFAULT value was required but the function declares no default expressions.
    #[error("Function has no default arguments")]
    NoDefaultArguments,
    /// The given zero-based argument position has no DEFAULT value.
    #[error("Argument #{0} does not have a DEFAULT value")]
    NotDefaultArgument(usize),
    /// The DEFAULT expression did not fold to a constant via `eval_const_expressions`.
    #[error("Argument's default value is not a constant expression")]
    DefaultNotConstantExpression,
}
/// Result type used throughout the dynamic function-call API.
pub type Result<T> = std::result::Result<T, FnCallError>;
/// Dynamically call the SQL function `fname` with `args`, using the default collation
/// (`pg_sys::DEFAULT_COLLATION_OID`).
///
/// This is shorthand for [`fn_call_with_collation`]; see it for how the function is resolved,
/// how DEFAULT and NULL arguments are handled, and which errors can be returned.
pub fn fn_call<R: FromDatum + IntoDatum>(
    fname: &str,
    args: &[&dyn FnCallArg],
) -> Result<Option<R>> {
    let collation = pg_sys::DEFAULT_COLLATION_OID;
    fn_call_with_collation::<R>(fname, collation, args)
}
/// Dynamically call the SQL function `fname` with `args`, using `collation` as the call's
/// collation.
///
/// `fname` may be schema-qualified and is resolved by `lookup_fn` from the argument type Oids.
/// Trailing parameters not covered by `args` are filled in from the function's DEFAULT
/// expressions. If the function is `STRICT` and any argument (explicit or defaulted) is NULL, the
/// function is not called and `Ok(None)` is returned. A NULL result is also returned as
/// `Ok(None)`.
///
/// # Errors
///
/// - Any error from `lookup_fn` (invalid identifier, undefined/ambiguous function, argument count
///   not representable).
/// - [`FnCallError::UnsupportedFunctionType`] if the target is not a plain function.
/// - [`FnCallError::UnsupportedArgumentModes`] if any argument mode is not `IN`.
/// - [`FnCallError::InternalTypeNotSupported`] if the return type or any argument type is
///   `internal`.
/// - [`FnCallError::IncompatibleReturnType`] if `R` cannot represent the function's return type.
/// - [`FnCallError::TooManyArguments`] if more arguments were supplied than the resolved function
///   declares.
/// - Errors from building DEFAULT values (see `create_default_value`).
pub fn fn_call_with_collation<R: FromDatum + IntoDatum>(
    fname: &str,
    collation: pg_sys::Oid,
    args: &[&dyn FnCallArg],
) -> Result<Option<R>> {
    let func_oid = lookup_fn(fname, args)?;
    let pg_proc = PgProc::new(func_oid).ok_or(FnCallError::UndefinedFunction)?;
    let retoid = pg_proc.prorettype();

    // Reject the functions we cannot (or should not) call directly.
    if !matches!(pg_proc.prokind(), ProKind::Function) {
        return Err(FnCallError::UnsupportedFunctionType);
    } else if pg_proc.proargmodes().iter().any(|mode| *mode != ProArgMode::In) {
        return Err(FnCallError::UnsupportedArgumentModes);
    } else if retoid == pg_sys::INTERNALOID || pg_proc.proargtypes().contains(&pg_sys::INTERNALOID)
    {
        return Err(FnCallError::InternalTypeNotSupported);
    } else if !R::is_compatible_with(retoid) {
        return Err(FnCallError::IncompatibleReturnType(R::type_oid(), retoid));
    }

    // `lookup_fn` falls back to a name-only lookup (nargs = -1) so that functions with DEFAULT
    // arguments can be found. That fallback can also resolve to a function declaring fewer
    // parameters than were supplied. Report that as an error instead of tripping the
    // argument-count assertion below.
    let pronargs = pg_proc.pronargs();
    if args.len() > pronargs {
        return Err(FnCallError::TooManyArguments);
    }
    // NOTE(review): the name-only fallback also does not verify that the supplied argument types
    // match `proargtypes`. A mismatch would hand the callee a wrongly-typed Datum.

    // Convert the explicit arguments, then append DEFAULT values for any trailing parameters.
    // Along the way, remember whether any resulting value is NULL.
    let mut null = false;
    let arg_datums = args
        .iter()
        .enumerate()
        .map(|(i, a)| a.as_datum(&pg_proc, i))
        .chain((args.len()..pronargs).map(|i| create_default_value(&pg_proc, i)))
        .map(|datum| {
            null |= matches!(datum, Ok(None));
            datum
        })
        .collect::<Result<Vec<_>>>()?;
    let nargs = arg_datums.len();

    // A STRICT function must not be invoked with a NULL argument; its result is NULL by
    // definition.
    let isstrict = pg_proc.proisstrict();
    if null && isstrict {
        return Ok(None);
    }

    // SAFETY: `flinfo` is initialized by `fmgr_info` for `func_oid` and outlives the call.
    // `fcinfo` is zero-allocated with room for exactly `nargs` `NullableDatum` slots, and every
    // slot is filled below. `nargs == pronargs` because `args.len() <= pronargs` was checked above
    // and the defaults fill the rest.
    unsafe {
        let mut flinfo = pg_sys::FmgrInfo::default();
        pg_sys::fmgr_info(func_oid, &mut flinfo);
        assert_eq!(nargs, pronargs);
        let fcinfo = pg_sys::palloc0(
            std::mem::size_of::<pg_sys::FunctionCallInfoBaseData>()
                + std::mem::size_of::<pg_sys::NullableDatum>() * nargs,
        ) as *mut pg_sys::FunctionCallInfoBaseData;
        let fcinfo_ref = &mut *fcinfo;
        fcinfo_ref.flinfo = &mut flinfo;
        fcinfo_ref.fncollation = collation;
        fcinfo_ref.context = std::ptr::null_mut();
        fcinfo_ref.resultinfo = std::ptr::null_mut();
        fcinfo_ref.isnull = false;
        fcinfo_ref.nargs = nargs as _;
        let args_slice = fcinfo_ref.args.as_mut_slice(nargs);
        for (slot, datum) in args_slice.iter_mut().zip(arg_datums) {
            // STRICT functions with a NULL argument returned early above.
            assert!(!isstrict || datum.is_some());
            (slot.value, slot.isnull) = match datum {
                Some(d) => (d, false),
                None => (pg_sys::Datum::from(0), true),
            };
        }
        let func = *(*fcinfo_ref.flinfo)
            .fn_addr
            .as_ref()
            .expect("function initialization problem: fn_addr not set");
        let result_datum = pg_guard_ffi_boundary(|| func(fcinfo));
        // The callee reports a NULL result by writing `isnull` through `fcinfo`. Read it back via
        // that raw pointer rather than the `&mut` reborrow created before the call.
        let isnull = (*fcinfo).isnull;
        let result = R::from_datum(result_datum, isnull);
        pg_sys::pfree(fcinfo.cast());
        Ok(result)
    }
}
/// Resolve `fname` (optionally schema-qualified) to a function Oid using the type Oids of `args`.
///
/// The name is split with `parse_sql_ident`. `LookupFuncName` is first asked for an exact match
/// on the argument count and type Oids. If that finds nothing, the lookup is retried with
/// `nargs = -1`, which searches by name alone so that calls relying on DEFAULT arguments can
/// resolve.
///
/// # Errors
///
/// - [`FnCallError::TooManyArguments`] if `args.len()` does not fit in an `i16`.
/// - [`FnCallError::InvalidIdentifier`] if `fname` cannot be parsed, either because Postgres
///   raises `ERRCODE_INVALID_PARAMETER_VALUE` or because `parse_sql_ident` yields no result.
/// - [`FnCallError::AmbiguousFunction`] / [`FnCallError::UndefinedFunction`] when a lookup raises
///   the corresponding Postgres error.
fn lookup_fn(fname: &str, args: &[&dyn FnCallArg]) -> Result<pg_sys::Oid> {
    memcx::current_context(|mcx| {
        // Name parts are pushed here inside the try-block and cleaned up after it, whether or not
        // the lookup succeeded.
        let mut parts_list = List::<*mut std::ffi::c_void>::default();
        let result = PgTryBuilder::new(AssertUnwindSafe(|| unsafe {
            // The signature we search for: each supplied argument's type Oid, in order.
            let arg_types = args.iter().map(|a| a.type_oid()).collect::<Vec<_>>();
            // The argument count must fit in an i16; otherwise report TooManyArguments.
            let nargs: i16 =
                arg_types.len().try_into().map_err(|_| FnCallError::TooManyArguments)?;
            let ident_parts = parse_sql_ident(fname)?;
            ident_parts
                .iter_deny_null()
                .map(|part| {
                    // The C string handed to `makeString` is pfree'd in the cleanup below.
                    pg_sys::makeString(part.as_pg_cstr())
                })
                .for_each(|part| {
                    parts_list.unstable_push_in_context(part.cast(), mcx);
                });
            // First attempt: exact match on argument count and type Oids. missing_ok = true, so a
            // miss returns InvalidOid instead of raising an error.
            let mut fnoid = pg_sys::LookupFuncName(
                parts_list.as_mut_ptr(),
                nargs as _,
                arg_types.as_ptr(),
                true,
            );
            if fnoid == pg_sys::InvalidOid {
                // Second attempt: nargs = -1 looks the function up by name only, which lets calls
                // that rely on DEFAULT arguments resolve. NOTE(review): this can also match a
                // function whose argument count or types differ from `args`.
                fnoid = pg_sys::LookupFuncName(
                    parts_list.as_mut_ptr(),
                    -1,
                    arg_types.as_ptr(),
                    // missing_ok = false: let Postgres raise the error, translated by the
                    // `catch_when` handlers below.
                    false, );
            }
            Ok(fnoid)
        }))
        // Postgres errors raised inside the try-block are mapped to `FnCallError` variants.
        .catch_when(PgSqlErrorCode::ERRCODE_INVALID_PARAMETER_VALUE, |_| {
            Err(FnCallError::InvalidIdentifier(fname.to_string()))
        })
        .catch_when(PgSqlErrorCode::ERRCODE_AMBIGUOUS_FUNCTION, |_| {
            Err(FnCallError::AmbiguousFunction)
        })
        .catch_when(PgSqlErrorCode::ERRCODE_UNDEFINED_FUNCTION, |_| {
            Err(FnCallError::UndefinedFunction)
        })
        .execute();
        // Cleanup: free the C string buffer inside each name node. The node layout differs by
        // Postgres version (`Value` on pg13/14, `String` on pg15+). The nodes themselves are not
        // freed here.
        unsafe {
            parts_list.drain(..).for_each(|s| {
                #[cfg(any(feature = "pg13", feature = "pg14"))]
                {
                    let s = s.cast::<pg_sys::Value>();
                    pg_sys::pfree((*s).val.str_.cast());
                }
                #[cfg(any(feature = "pg15", feature = "pg16", feature = "pg17", feature = "pg18"))]
                {
                    let s = s.cast::<pg_sys::String>();
                    pg_sys::pfree((*s).sval.cast());
                }
            });
        }
        result
    })
}
/// Split a possibly schema-qualified SQL identifier into its parts.
///
/// Postgres' `parse_ident(text, strict := true)` does the splitting. In strict mode, malformed
/// input raises a Postgres error, which the caller is expected to catch.
///
/// # Errors
///
/// Returns [`FnCallError::InvalidIdentifier`] if the call produced no array.
fn parse_sql_ident(ident: &str) -> Result<Array<'_, &str>> {
    let call_args = [ident.into_datum(), true.into_datum()];
    // SAFETY: `parse_ident` takes (text, bool) arguments, which is exactly what `call_args` holds.
    let parsed = unsafe { direct_function_call::<Array<&str>>(pg_sys::parse_ident, &call_args) };
    parsed.ok_or_else(|| FnCallError::InvalidIdentifier(ident.to_string()))
}
/// Produce the value of the DEFAULT expression declared for zero-based parameter `argnum` of
/// `pg_proc`.
///
/// Only the trailing `pronargdefaults()` parameters carry defaults. `argnum` is therefore mapped
/// to an index into the default-expression list by subtracting the number of leading non-default
/// parameters. The expression is folded with `eval_const_expressions`, and a folded NULL constant
/// yields `Ok(None)`.
///
/// # Errors
///
/// - [`FnCallError::NotDefaultArgument`] if `argnum` has no default, either because it precedes
///   the defaulted parameters or because it lies past the end of the default list.
/// - [`FnCallError::NoDefaultArguments`] if the function has no default expressions at all.
/// - [`FnCallError::DefaultNotConstantExpression`] if the default does not fold to a constant.
fn create_default_value(pg_proc: &PgProc, argnum: usize) -> Result<Option<pg_sys::Datum>> {
    let leading_required = pg_proc.pronargs() - pg_proc.pronargdefaults();
    let default_idx = match argnum.checked_sub(leading_required) {
        Some(idx) => idx,
        None => return Err(FnCallError::NotDefaultArgument(argnum)),
    };

    let default_expr = memcx::current_context(|mcx| {
        let defaults = pg_proc.proargdefaults(mcx).ok_or(FnCallError::NoDefaultArguments)?;
        defaults.get(default_idx).copied().ok_or(FnCallError::NotDefaultArgument(argnum))
    })?;

    // SAFETY: `default_expr` is a node taken from the function's default-expression list, and
    // `eval_const_expressions` accepts a NULL planner root. The result is only dereferenced as a
    // `Const` after `is_a` confirms its node tag.
    unsafe {
        let folded = pg_sys::eval_const_expressions(std::ptr::null_mut(), default_expr.cast());
        if !is_a(folded.cast(), pg_sys::NodeTag::T_Const) {
            return Err(FnCallError::DefaultNotConstantExpression);
        }
        let constant = &*folded.cast::<pg_sys::Const>();
        Ok(if constant.constisnull { None } else { Some(constant.constvalue) })
    }
}