extern crate proc_macro;
use proc_macro::TokenStream;
use std::collections::HashSet;
use std::ffi::CString;
use proc_macro2::Ident;
use quote::{ToTokens, format_ident, quote};
use syn::spanned::Spanned;
use syn::{Attribute, Data, DeriveInput, Item, ItemImpl, parse_macro_input};
use operators::{deriving_postgres_eq, deriving_postgres_hash, deriving_postgres_ord};
use pgrx_sql_entity_graph as sql_gen;
use sql_gen::{
CodeEnrichment, ExtensionSql, ExtensionSqlFile, ExternArgs, PgAggregate, PgCast, PgExtern,
PostgresEnum, Schema, parse_extern_attributes,
};
mod operators;
mod pg_bench;
mod rewriter;
#[proc_macro_attribute]
pub fn pg_guard(attr: TokenStream, item: TokenStream) -> TokenStream {
let ast = parse_macro_input!(item as syn::Item);
let res = match ast {
Item::ForeignMod(block) => Ok(rewriter::extern_block(block)),
Item::Fn(func) => rewriter::item_fn_without_rewrite(func, attr),
unknown => Err(syn::Error::new(
unknown.span(),
"#[pg_guard] can only be applied to extern \"C-unwind\" blocks and top-level functions",
)),
};
res.unwrap_or_else(|e| e.into_compile_error()).into()
}
/// Expands a `#[pg_test]` function into two pieces:
/// 1. a `#[pg_extern]` SQL function that runs inside Postgres, and
/// 2. a host-side `#[test]` wrapper (named `pg_<original>`) that drives it
///    through `pgrx_tests::run_test`.
#[proc_macro_attribute]
pub fn pg_test(attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut stream = proc_macro2::TokenStream::new();
    let args = parse_extern_attributes(proc_macro2::TokenStream::from(attr.clone()));
    // `expected = "..."` / should-panic message, if the user supplied one.
    let mut expected_error = None;
    args.into_iter().for_each(|v| {
        if let ExternArgs::ShouldPanic(message) = v {
            expected_error = Some(message)
        }
    });
    let ast = parse_macro_input!(item as syn::Item);
    match ast {
        Item::Fn(mut func) => {
            // `#[ignore]` / `#[should_panic]` belong on the host-side #[test]
            // wrapper, not on the SQL function, so split them off here.
            let (test_attributes, non_test_attributes) =
                func.attrs.into_iter().partition::<Vec<Attribute>, _>(|attr| {
                    attr.path()
                        .get_ident()
                        .is_some_and(|ident| ident == "ignore" || ident == "should_panic")
                });
            func.attrs = non_test_attributes;
            // The wrapper keeps the original name; the SQL function may be
            // shortened to fit Postgres' 63-byte identifier limit.
            let original_ident = func.sig.ident.clone();
            maybe_shorten_pg_test_ident(&mut func.sig.ident);
            // Emit the in-database function via #[pg_extern].
            stream.extend(proc_macro2::TokenStream::from(pg_extern(
                attr,
                Item::Fn(func.clone()).to_token_stream().into(),
            )));
            let expected_error = match expected_error {
                Some(msg) => quote! {Some(#msg)},
                None => quote! {None},
            };
            // Note: uses the (possibly shortened) ident — this is the name the
            // runner invokes via SQL.
            let sql_funcname = func.sig.ident.to_string();
            let test_func_name = format_ident!("pg_{}", original_ident);
            // Forward the remaining attributes to the test runner as strings.
            let attributes = func.attrs;
            let mut att_stream = proc_macro2::TokenStream::new();
            for a in attributes.iter() {
                let as_str = a.to_token_stream().to_string();
                att_stream.extend(quote! {
                    options.push(#as_str);
                });
            }
            stream.extend(quote! {
                #[test]
                #(#test_attributes)*
                fn #test_func_name() {
                    let mut options = Vec::new();
                    #att_stream
                    crate::pg_test::setup(options);
                    let res = pgrx_tests::run_test(#sql_funcname, #expected_error, crate::pg_test::postgresql_conf_options());
                    match res {
                        Ok(()) => (),
                        Err(e) => panic!("{e:?}")
                    }
                }
            });
        }
        thing => {
            return syn::Error::new(
                thing.span(),
                "#[pg_test] can only be applied to top-level functions",
            )
            .into_compile_error()
            .into();
        }
    }
    stream.into()
}
/// Shortens an identifier so it fits within Postgres' 63-byte identifier limit,
/// prefixing it with a process-global counter (`tN_`) to keep shortened names
/// unique even when the truncated tails would collide.
fn maybe_shorten_pg_test_ident(ident: &mut syn::Ident) {
    use std::sync::atomic::{AtomicUsize, Ordering};
    // NAMEDATALEN: identifiers of 64 bytes or more get truncated by Postgres.
    const POSTGRES_IDENTIFIER_MAX_LEN: usize = 64;
    static COUNTER: AtomicUsize = AtomicUsize::new(0);

    let name = ident.to_string();
    if name.len() < POSTGRES_IDENTIFIER_MAX_LEN {
        // Already fits (<= 63 bytes); the counter is not consumed.
        return;
    }
    let unique_prefix = format!("t{}_", COUNTER.fetch_add(1, Ordering::Relaxed));
    // How many bytes of the original name we can keep: 63 minus the prefix.
    let budget = (POSTGRES_IDENTIFIER_MAX_LEN - 1) - unique_prefix.len();
    let mut cut = budget.min(name.len());
    // Back the cut point up so it never lands inside a multi-byte character.
    while !name.is_char_boundary(cut) {
        cut -= 1;
    }
    let shortened = format!("{unique_prefix}{}", &name[..cut]);
    *ident = syn::Ident::new(&shortened, ident.span());
}
/// `#[pg_bench]` attribute macro; the expansion is delegated entirely to the
/// `pg_bench` module.
#[proc_macro_attribute]
pub fn pg_bench(attr: TokenStream, item: TokenStream) -> TokenStream {
    pg_bench::pg_bench(attr, item)
}
/// No-op marker attribute: the annotated item is passed through unchanged.
#[proc_macro_attribute]
pub fn initialize(_attr: TokenStream, item: TokenStream) -> TokenStream {
    item
}
#[proc_macro_attribute]
pub fn pg_cast(attr: TokenStream, item: TokenStream) -> TokenStream {
fn wrapped(attr: TokenStream, item: TokenStream) -> Result<TokenStream, syn::Error> {
use syn::parse::Parser;
use syn::punctuated::Punctuated;
let mut cast = None;
let mut pg_extern_attrs = proc_macro2::TokenStream::new();
match Punctuated::<syn::Path, syn::Token![,]>::parse_terminated.parse(attr) {
Ok(paths) => {
let mut new_paths = Punctuated::<syn::Path, syn::Token![,]>::new();
for path in paths {
match (PgCast::try_from(path), &cast) {
(Ok(style), None) => cast = Some(style),
(Ok(_), Some(cast)) => {
panic!("The cast type has already been set to `{cast:?}`")
}
(Err(unknown), _) => {
new_paths.push(unknown);
}
}
}
pg_extern_attrs.extend(new_paths.into_token_stream());
}
Err(err) => {
panic!("Failed to parse attribute to pg_cast: {err}")
}
}
let pg_extern = PgExtern::new(pg_extern_attrs, item.clone().into())?.0;
Ok(CodeEnrichment(pg_extern.as_cast(cast.unwrap_or_default())).to_token_stream().into())
}
wrapped(attr, item).unwrap_or_else(|e: syn::Error| e.into_compile_error().into())
}
/// `#[pg_operator]` attribute macro: currently an alias for `#[pg_extern]`.
/// The operator helper attributes below (`#[opname]`, `#[commutator]`, ...)
/// are pass-throughs here — presumably consumed during that expansion; verify
/// in the `operators`/`PgExtern` machinery.
#[proc_macro_attribute]
pub fn pg_operator(attr: TokenStream, item: TokenStream) -> TokenStream {
    pg_extern(attr, item)
}
/// Pass-through marker attribute used with `#[pg_operator]`; no expansion here.
#[proc_macro_attribute]
pub fn opname(_attr: TokenStream, item: TokenStream) -> TokenStream {
    item
}
/// Pass-through marker attribute used with `#[pg_operator]`; no expansion here.
#[proc_macro_attribute]
pub fn commutator(_attr: TokenStream, item: TokenStream) -> TokenStream {
    item
}
/// Pass-through marker attribute used with `#[pg_operator]`; no expansion here.
#[proc_macro_attribute]
pub fn negator(_attr: TokenStream, item: TokenStream) -> TokenStream {
    item
}
/// Pass-through marker attribute used with `#[pg_operator]`; no expansion here.
#[proc_macro_attribute]
pub fn restrict(_attr: TokenStream, item: TokenStream) -> TokenStream {
    item
}
/// Pass-through marker attribute used with `#[pg_operator]`; no expansion here.
#[proc_macro_attribute]
pub fn join(_attr: TokenStream, item: TokenStream) -> TokenStream {
    item
}
/// Pass-through marker attribute used with `#[pg_operator]`; no expansion here.
#[proc_macro_attribute]
pub fn hashes(_attr: TokenStream, item: TokenStream) -> TokenStream {
    item
}
/// Pass-through marker attribute used with `#[pg_operator]`; no expansion here.
#[proc_macro_attribute]
pub fn merges(_attr: TokenStream, item: TokenStream) -> TokenStream {
    item
}
/// `#[pg_schema]` attribute macro: parses the annotated module as a pgrx
/// `Schema` and emits its expanded token stream.
#[proc_macro_attribute]
pub fn pg_schema(_attr: TokenStream, input: TokenStream) -> TokenStream {
    match syn::parse::<Schema>(input) {
        Ok(pgrx_schema) => pgrx_schema.to_token_stream().into(),
        Err(err) => err.into_compile_error().into(),
    }
}
/// `extension_sql!` function-like macro: parses its input as an inline
/// `ExtensionSql` entity and emits the enriched token stream.
#[proc_macro]
pub fn extension_sql(input: TokenStream) -> TokenStream {
    match syn::parse::<CodeEnrichment<ExtensionSql>>(input) {
        Ok(ext_sql) => ext_sql.to_token_stream().into(),
        Err(err) => err.into_compile_error().into(),
    }
}
/// `extension_sql_file!` function-like macro: parses its input as a
/// file-backed `ExtensionSqlFile` entity and emits the enriched token stream.
#[proc_macro]
pub fn extension_sql_file(input: TokenStream) -> TokenStream {
    match syn::parse::<CodeEnrichment<ExtensionSqlFile>>(input) {
        Ok(ext_sql) => ext_sql.to_token_stream().into(),
        Err(err) => err.into_compile_error().into(),
    }
}
/// Pass-through marker attribute; the annotated item is returned unchanged.
#[proc_macro_attribute]
pub fn search_path(_attr: TokenStream, item: TokenStream) -> TokenStream {
    item
}
/// `#[pg_extern]` attribute macro: wraps a Rust function as a
/// Postgres-callable function via `PgExtern`, turning parse failures into
/// compile errors.
#[proc_macro_attribute]
#[track_caller]
pub fn pg_extern(attr: TokenStream, item: TokenStream) -> TokenStream {
    match PgExtern::new(attr.into(), item.into()) {
        Ok(pg_extern_item) => pg_extern_item.to_token_stream().into(),
        Err(err) => err.into_compile_error().into(),
    }
}
/// `#[derive(PostgresEnum)]` entry point; the real work happens in
/// `impl_postgres_enum`.
#[proc_macro_derive(PostgresEnum, attributes(requires, pgrx))]
pub fn postgres_enum(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as syn::DeriveInput);
    match impl_postgres_enum(ast) {
        Ok(tokens) => tokens.into(),
        Err(err) => err.into_compile_error().into(),
    }
}
fn impl_postgres_enum(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let mut stream = proc_macro2::TokenStream::new();
let sql_graph_entity_ast = ast.clone();
let generics = &ast.generics.clone();
let enum_ident = &ast.ident;
let enum_name = enum_ident.to_string();
let Data::Enum(enum_data) = ast.data else {
return Err(syn::Error::new(
ast.span(),
"#[derive(PostgresEnum)] can only be applied to enums",
));
};
let mut from_datum = proc_macro2::TokenStream::new();
let mut into_datum = proc_macro2::TokenStream::new();
for d in enum_data.variants.clone() {
let label_ident = &d.ident;
let label_string = label_ident.to_string();
from_datum.extend(quote! { #label_string => Some(#enum_ident::#label_ident), });
into_datum.extend(quote! { #enum_ident::#label_ident => Some(::pgrx::enum_helper::lookup_enum_by_label(#enum_name, #label_string)), });
}
let fcx_lt = syn::Lifetime::new("'fcx", proc_macro2::Span::mixed_site());
let mut generics_with_fcx = generics.clone();
generics_with_fcx.make_where_clause().predicates.push(syn::WherePredicate::Type(
syn::PredicateType {
lifetimes: None,
bounded_ty: syn::parse_quote! { Self },
colon_token: syn::Token),
bounds: syn::parse_quote! { #fcx_lt },
},
));
let (impl_gens, ty_gens, where_clause) = generics_with_fcx.split_for_impl();
let mut impl_gens: syn::Generics = syn::parse_quote! { #impl_gens };
impl_gens
.params
.insert(0, syn::GenericParam::Lifetime(syn::LifetimeParam::new(fcx_lt.clone())));
stream.extend(quote! {
impl ::pgrx::datum::FromDatum for #enum_ident {
#[inline]
unsafe fn from_polymorphic_datum(datum: ::pgrx::pg_sys::Datum, is_null: bool, _typeoid: ::pgrx::pg_sys::Oid) -> Option<#enum_ident> {
if is_null {
None
} else {
let (name, _, _) = ::pgrx::enum_helper::lookup_enum_by_oid(unsafe { ::pgrx::pg_sys::Oid::from_datum(datum, is_null)? } );
match name.as_str() {
#from_datum
_ => panic!("invalid enum value: {name}")
}
}
}
}
unsafe impl #impl_gens ::pgrx::callconv::ArgAbi<#fcx_lt> for #enum_ident #ty_gens #where_clause {
unsafe fn unbox_arg_unchecked(arg: ::pgrx::callconv::Arg<'_, #fcx_lt>) -> Self {
let index = arg.index();
unsafe { arg.unbox_arg_using_from_datum().unwrap_or_else(|| panic!("argument {index} must not be null")) }
}
}
unsafe impl #generics ::pgrx::datum::UnboxDatum for #enum_ident #generics {
type As<'dat> = #enum_ident #generics where Self: 'dat;
#[inline]
unsafe fn unbox<'dat>(d: ::pgrx::datum::Datum<'dat>) -> Self::As<'dat> where Self: 'dat {
<Self as ::pgrx::datum::FromDatum>::from_datum(::core::mem::transmute(d), false).unwrap()
}
}
impl ::pgrx::datum::IntoDatum for #enum_ident {
#[inline]
fn into_datum(self) -> Option<::pgrx::pg_sys::Datum> {
match self {
#into_datum
}
}
fn type_oid() -> ::pgrx::pg_sys::Oid {
::pgrx::wrappers::regtypein(#enum_name)
}
}
unsafe impl ::pgrx::callconv::BoxRet for #enum_ident {
unsafe fn box_into<'fcx>(self, fcinfo: &mut ::pgrx::callconv::FcInfo<'fcx>) -> ::pgrx::datum::Datum<'fcx> {
match ::pgrx::datum::IntoDatum::into_datum(self) {
None => fcinfo.return_null(),
Some(datum) => unsafe { fcinfo.return_raw_datum(datum) },
}
}
}
});
let sql_graph_entity_item = PostgresEnum::from_derive_input(sql_graph_entity_ast)?;
sql_graph_entity_item.to_tokens(&mut stream);
Ok(stream)
}
/// `#[derive(PostgresType)]` entry point; the real work happens in
/// `impl_postgres_type`.
#[proc_macro_derive(
    PostgresType,
    attributes(
        inoutfuncs,
        pgvarlena_inoutfuncs,
        pg_binary_protocol,
        bikeshed_postgres_type_manually_impl_from_into_datum,
        requires,
        pgrx
    )
)]
pub fn postgres_type(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as syn::DeriveInput);
    match impl_postgres_type(ast) {
        Ok(tokens) => tokens.into(),
        Err(err) => err.into_compile_error().into(),
    }
}
fn impl_postgres_type(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let name = &ast.ident;
let generics = &ast.generics.clone();
let has_lifetimes = generics.lifetimes().next();
let funcname_in = Ident::new(&format!("{name}_in").to_lowercase(), name.span());
let funcname_out = Ident::new(&format!("{name}_out").to_lowercase(), name.span());
let funcname_recv = Ident::new(&format!("{name}_recv").to_lowercase(), name.span());
let funcname_send = Ident::new(&format!("{name}_send").to_lowercase(), name.span());
let mut args = parse_postgres_type_args(&ast.attrs);
let mut stream = proc_macro2::TokenStream::new();
match ast.data {
Data::Struct(_) => { }
Data::Enum(_) => {
}
_ => {
return Err(syn::Error::new(
ast.span(),
"#[derive(PostgresType)] can only be applied to structs or enums",
));
}
}
if !args.contains(&PostgresTypeAttribute::InOutFuncs)
&& !args.contains(&PostgresTypeAttribute::PgVarlenaInOutFuncs)
{
args.insert(PostgresTypeAttribute::Default);
}
let lifetime = match has_lifetimes {
Some(lifetime) => quote! {#lifetime},
None => quote! {'_},
};
let fcx_lt = syn::Lifetime::new("'fcx", proc_macro2::Span::mixed_site());
let mut generics_with_fcx = generics.clone();
generics_with_fcx.make_where_clause().predicates.push(syn::WherePredicate::Type(
syn::PredicateType {
lifetimes: None,
bounded_ty: syn::parse_quote! { Self },
colon_token: syn::Token),
bounds: syn::parse_quote! { #fcx_lt },
},
));
let (impl_gens, ty_gens, where_clause) = generics_with_fcx.split_for_impl();
let mut impl_gens: syn::Generics = syn::parse_quote! { #impl_gens };
impl_gens
.params
.insert(0, syn::GenericParam::Lifetime(syn::LifetimeParam::new(fcx_lt.clone())));
stream.extend(quote! {
impl #generics ::pgrx::datum::PostgresType for #name #generics { }
});
if !args.contains(&PostgresTypeAttribute::ManualFromIntoDatum) {
stream.extend(
quote! {
impl #generics ::pgrx::datum::IntoDatum for #name #generics {
fn into_datum(self) -> Option<::pgrx::pg_sys::Datum> {
#[allow(deprecated)]
Some(unsafe { ::pgrx::datum::cbor_encode(&self) }.into())
}
fn type_oid() -> ::pgrx::pg_sys::Oid {
::pgrx::wrappers::rust_regtypein::<Self>()
}
}
unsafe impl #generics ::pgrx::callconv::BoxRet for #name #generics {
unsafe fn box_into<'fcx>(self, fcinfo: &mut ::pgrx::callconv::FcInfo<'fcx>) -> ::pgrx::datum::Datum<'fcx> {
match ::pgrx::datum::IntoDatum::into_datum(self) {
None => fcinfo.return_null(),
Some(datum) => unsafe { fcinfo.return_raw_datum(datum) },
}
}
}
impl #generics ::pgrx::datum::FromDatum for #name #generics {
unsafe fn from_polymorphic_datum(
datum: ::pgrx::pg_sys::Datum,
is_null: bool,
_typoid: ::pgrx::pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
#[allow(deprecated)]
::pgrx::datum::cbor_decode(datum.cast_mut_ptr())
}
}
unsafe fn from_datum_in_memory_context(
mut memory_context: ::pgrx::memcxt::PgMemoryContexts,
datum: ::pgrx::pg_sys::Datum,
is_null: bool,
_typoid: ::pgrx::pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
memory_context.switch_to(|_| {
let varlena = ::pgrx::pg_sys::pg_detoast_datum_copy(datum.cast_mut_ptr());
<Self as ::pgrx::datum::FromDatum>::from_datum(varlena.into(), is_null)
})
}
}
}
unsafe impl #generics ::pgrx::datum::UnboxDatum for #name #generics {
type As<'dat> = Self where Self: 'dat;
unsafe fn unbox<'dat>(datum: ::pgrx::datum::Datum<'dat>) -> Self::As<'dat> where Self: 'dat {
<Self as ::pgrx::datum::FromDatum>::from_datum(::core::mem::transmute(datum), false).unwrap()
}
}
unsafe impl #impl_gens ::pgrx::callconv::ArgAbi<#fcx_lt> for #name #ty_gens #where_clause
{
unsafe fn unbox_arg_unchecked(arg: ::pgrx::callconv::Arg<'_, #fcx_lt>) -> Self {
let index = arg.index();
unsafe { arg.unbox_arg_using_from_datum().unwrap_or_else(|| panic!("argument {index} must not be null")) }
}
}
}
)
}
if args.contains(&PostgresTypeAttribute::Default) {
stream.extend(quote! {
#[doc(hidden)]
#[::pgrx::pgrx_macros::pg_extern(immutable, parallel_safe)]
pub fn #funcname_in #generics(input: Option<&#lifetime ::core::ffi::CStr>) -> Option<#name #generics> {
use ::pgrx::inoutfuncs::json_from_slice;
input.map(|cstr| json_from_slice(cstr.to_bytes()).ok()).flatten()
}
#[doc(hidden)]
#[::pgrx::pgrx_macros::pg_extern (immutable, parallel_safe)]
pub fn #funcname_out #generics(input: #name #generics) -> ::pgrx::ffi::CString {
use ::pgrx::inoutfuncs::json_to_vec;
let mut bytes = json_to_vec(&input).unwrap();
bytes.push(0); ::pgrx::ffi::CString::from_vec_with_nul(bytes).unwrap()
}
});
} else if args.contains(&PostgresTypeAttribute::InOutFuncs) {
stream.extend(quote! {
#[doc(hidden)]
#[::pgrx::pgrx_macros::pg_extern(immutable,parallel_safe)]
pub fn #funcname_in #generics(input: Option<&::core::ffi::CStr>) -> Option<#name #generics> {
input.map_or_else(|| {
if let Some(m) = <#name as ::pgrx::inoutfuncs::InOutFuncs>::NULL_ERROR_MESSAGE {
::pgrx::pg_sys::error!("{m}");
}
None
}, |i| Some(<#name as ::pgrx::inoutfuncs::InOutFuncs>::input(i)))
}
#[doc(hidden)]
#[::pgrx::pgrx_macros::pg_extern(immutable,parallel_safe)]
pub fn #funcname_out #generics(input: #name #generics) -> ::pgrx::ffi::CString {
let mut buffer = ::pgrx::stringinfo::StringInfo::new();
::pgrx::inoutfuncs::InOutFuncs::output(&input, &mut buffer);
unsafe { buffer.leak_cstr().to_owned() }
}
});
} else if args.contains(&PostgresTypeAttribute::PgVarlenaInOutFuncs) {
stream.extend(quote! {
#[doc(hidden)]
#[::pgrx::pgrx_macros::pg_extern(immutable,parallel_safe)]
pub fn #funcname_in #generics(input: Option<&::core::ffi::CStr>) -> Option<::pgrx::datum::PgVarlena<#name #generics>> {
input.map_or_else(|| {
if let Some(m) = <#name as ::pgrx::inoutfuncs::PgVarlenaInOutFuncs>::NULL_ERROR_MESSAGE {
::pgrx::pg_sys::error!("{m}");
}
None
}, |i| Some(<#name as ::pgrx::inoutfuncs::PgVarlenaInOutFuncs>::input(i)))
}
#[doc(hidden)]
#[::pgrx::pgrx_macros::pg_extern(immutable,parallel_safe)]
pub fn #funcname_out #generics(input: ::pgrx::datum::PgVarlena<#name #generics>) -> ::pgrx::ffi::CString {
let mut buffer = ::pgrx::stringinfo::StringInfo::new();
::pgrx::inoutfuncs::PgVarlenaInOutFuncs::output(&*input, &mut buffer);
unsafe { buffer.leak_cstr().to_owned() }
}
});
}
if args.contains(&PostgresTypeAttribute::PgBinaryProtocol) {
stream.extend(quote! {
#[doc(hidden)]
#[::pgrx::pgrx_macros::pg_extern(immutable, strict, parallel_safe)]
pub fn #funcname_recv #generics(
mut internal: ::pgrx::datum::Internal,
) -> #name #generics {
let buf = unsafe { internal.get_mut::<::pgrx::pg_sys::StringInfoData>().unwrap() };
let mut serialized = ::pgrx::StringInfo::new();
serialized.push_bytes(&[0u8; ::pgrx::pg_sys::VARHDRSZ]); serialized.push_bytes(unsafe {
core::slice::from_raw_parts(
buf.data as *const u8,
buf.len as usize
)
});
let size = serialized.len();
let varlena = serialized.into_char_ptr();
unsafe{
::pgrx::set_varsize_4b(varlena as *mut ::pgrx::pg_sys::varlena, size as i32);
buf.cursor = buf.len;
::pgrx::datum::cbor_decode(varlena as *mut ::pgrx::pg_sys::varlena)
}
}
#[doc(hidden)]
#[::pgrx::pgrx_macros::pg_extern(immutable, strict, parallel_safe)]
pub fn #funcname_send #generics(input: #name #generics) -> Vec<u8> {
use ::pgrx::datum::{FromDatum, IntoDatum};
let Some(datum): Option<::pgrx::pg_sys::Datum> = input.into_datum() else {
::pgrx::error!("Datum of type `{}` is unexpectedly NULL.", stringify!(#name));
};
unsafe {
let Some(serialized): Option<Vec<u8>> = FromDatum::from_datum(datum, false) else {
::pgrx::error!("Failed to CBOR-serialize Datum to type `{}`.", stringify!(#name));
};
serialized
}
}
});
}
let sql_graph_entity_item = sql_gen::PostgresTypeDerive::from_derive_input(
ast,
args.contains(&PostgresTypeAttribute::PgBinaryProtocol),
)?;
sql_graph_entity_item.to_tokens(&mut stream);
Ok(stream)
}
/// `#[derive(PostgresGucEnum)]` entry point; the real work happens in
/// `impl_guc_enum`.
#[proc_macro_derive(PostgresGucEnum, attributes(name, hidden))]
pub fn postgres_guc_enum(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as syn::DeriveInput);
    match impl_guc_enum(ast) {
        Ok(tokens) => tokens.into(),
        Err(err) => err.into_compile_error().into(),
    }
}
/// Expands `#[derive(PostgresGucEnum)]`: implements `GucEnum` with
/// ordinal<->variant conversions and a NUL-terminated `config_enum_entry`
/// table for Postgres GUC registration. Per-variant `#[name = c"..."]` and
/// `#[hidden = ...]` attributes override the defaults.
fn impl_guc_enum(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    use std::str::FromStr;
    use syn::parse::Parse;
    // One parsed `key = value` override from a variant attribute.
    enum GucEnumAttribute {
        Name(CString),
        Hidden(bool),
    }
    impl GucEnumAttribute {
        fn is_guc_enum_attribute(attribute: &str) -> bool {
            matches!(attribute, "name" | "hidden")
        }
    }
    impl Parse for GucEnumAttribute {
        fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
            let ident: Ident = input.parse()?;
            let _: syn::token::Eq = input.parse()?;
            match ident.to_string().as_str() {
                // `name` takes a C-string literal (c"...") since Postgres needs
                // a NUL-terminated name.
                "name" => input.parse::<syn::LitCStr>().map(|val| Self::Name(val.value())),
                "hidden" => input.parse::<syn::LitBool>().map(|val| Self::Hidden(val.value())),
                x => Err(syn::Error::new(input.span(), format!("unknown attribute {x}"))),
            }
        }
    }
    let Data::Enum(data) = ast.data.clone() else {
        return Err(syn::Error::new(
            ast.span(),
            "#[derive(PostgresGucEnum)] can only be applied to enums",
        ));
    };
    let ident = ast.ident.clone();
    // (variant ident, C name, ordinal, hidden) per variant.
    let mut config = Vec::new();
    for (index, variant) in data.variants.iter().enumerate() {
        // Defaults: the variant's own name, its declaration order as ordinal,
        // and not hidden.
        let default_name = CString::from_str(&variant.ident.to_string())
            .expect("the identifier contains a null character.");
        let default_val = index as i32;
        let default_hidden = false;
        let mut name = None;
        let mut hidden = None;
        for attr in variant.attrs.iter() {
            if let Some(ident) = attr.path().get_ident()
                && GucEnumAttribute::is_guc_enum_attribute(&ident.to_string())
            {
                let pair: GucEnumAttribute = syn::parse2(attr.meta.to_token_stream())?;
                match pair {
                    GucEnumAttribute::Name(value) => {
                        // `replace` returning Some means the attribute repeated.
                        if name.replace(value).is_some() {
                            return Err(syn::Error::new(ast.span(), "too many #[name] attributes"));
                        }
                    }
                    GucEnumAttribute::Hidden(value) => {
                        if hidden.replace(value).is_some() {
                            return Err(syn::Error::new(
                                ast.span(),
                                "too many #[hidden] attributes",
                            ));
                        }
                    }
                }
            }
        }
        let ident = variant.ident.clone();
        let name = name.unwrap_or(default_name);
        let val = default_val;
        let hidden = hidden.unwrap_or(default_hidden);
        config.push((ident, name, val, hidden));
    }
    let config_idents = config.iter().map(|x| &x.0).collect::<Vec<_>>();
    let config_names = config.iter().map(|x| &x.1).collect::<Vec<_>>();
    let config_vals = config.iter().map(|x| &x.2).collect::<Vec<_>>();
    let config_hiddens = config.iter().map(|x| &x.3).collect::<Vec<_>>();
    Ok(quote! {
        unsafe impl ::pgrx::guc::GucEnum for #ident {
            fn from_ordinal(ordinal: i32) -> Self {
                match ordinal {
                    #(#config_vals => Self::#config_idents,)*
                    _ => panic!("Unrecognized ordinal"),
                }
            }
            fn to_ordinal(&self) -> i32 {
                match self {
                    #(Self::#config_idents => #config_vals,)*
                }
            }
            // Postgres expects this array to end with an all-NULL sentinel entry.
            const CONFIG_ENUM_ENTRY: *const ::pgrx::pg_sys::config_enum_entry = [
                #(
                    ::pgrx::pg_sys::config_enum_entry {
                        name: #config_names.as_ptr(),
                        val: #config_vals,
                        hidden: #config_hiddens,
                    },
                )*
                ::pgrx::pg_sys::config_enum_entry {
                    name: core::ptr::null(),
                    val: 0,
                    hidden: false,
                },
            ].as_ptr();
        }
    })
}
/// Marker attributes recognized on `#[derive(PostgresType)]` types; collected
/// by `parse_postgres_type_args`.
#[derive(Debug, Hash, Ord, PartialOrd, Eq, PartialEq)]
enum PostgresTypeAttribute {
    // `#[inoutfuncs]`: the type supplies its own `InOutFuncs` impl
    InOutFuncs,
    // `#[pg_binary_protocol]`: also generate `_recv`/`_send` functions
    PgBinaryProtocol,
    // `#[pgvarlena_inoutfuncs]`: the type uses `PgVarlenaInOutFuncs`
    PgVarlenaInOutFuncs,
    // no in/out attribute present: use the JSON-based default functions
    Default,
    // `#[bikeshed_postgres_type_manually_impl_from_into_datum]`: skip the
    // generated From/IntoDatum impls
    ManualFromIntoDatum,
}
/// Scans a type's attributes and collects the `PostgresType`-related marker
/// attributes it carries; unrelated attributes are ignored.
fn parse_postgres_type_args(attributes: &[Attribute]) -> HashSet<PostgresTypeAttribute> {
    let mut categorized_attributes = HashSet::new();
    for attribute in attributes {
        let path = &attribute.path();
        let rendered = quote! {#path}.to_string();
        let recognized = match rendered.as_str() {
            "inoutfuncs" => Some(PostgresTypeAttribute::InOutFuncs),
            "pg_binary_protocol" => Some(PostgresTypeAttribute::PgBinaryProtocol),
            "pgvarlena_inoutfuncs" => Some(PostgresTypeAttribute::PgVarlenaInOutFuncs),
            "bikeshed_postgres_type_manually_impl_from_into_datum" => {
                Some(PostgresTypeAttribute::ManualFromIntoDatum)
            }
            _ => None,
        };
        if let Some(marker) = recognized {
            categorized_attributes.insert(marker);
        }
    }
    categorized_attributes
}
/// `#[derive(PostgresEq)]` entry point; delegates to `operators`.
#[proc_macro_derive(PostgresEq, attributes(pgrx))]
pub fn derive_postgres_eq(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as syn::DeriveInput);
    match deriving_postgres_eq(ast) {
        Ok(tokens) => tokens.into(),
        Err(err) => err.into_compile_error().into(),
    }
}
/// `#[derive(PostgresOrd)]` entry point; delegates to `operators`.
#[proc_macro_derive(PostgresOrd, attributes(pgrx))]
pub fn derive_postgres_ord(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as syn::DeriveInput);
    match deriving_postgres_ord(ast) {
        Ok(tokens) => tokens.into(),
        Err(err) => err.into_compile_error().into(),
    }
}
/// `#[derive(PostgresHash)]` entry point; delegates to `operators`.
#[proc_macro_derive(PostgresHash, attributes(pgrx))]
pub fn derive_postgres_hash(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as syn::DeriveInput);
    match deriving_postgres_hash(ast) {
        Ok(tokens) => tokens.into(),
        Err(err) => err.into_compile_error().into(),
    }
}
/// `#[derive(AggregateName)]` entry point; the real work happens in
/// `impl_aggregate_name`.
#[proc_macro_derive(AggregateName, attributes(aggregate_name))]
pub fn derive_aggregate_name(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as syn::DeriveInput);
    match impl_aggregate_name(ast) {
        Ok(tokens) => tokens.into(),
        Err(err) => err.into_compile_error().into(),
    }
}
/// Expands `#[derive(AggregateName)]`: implements `ToAggregateName` using the
/// `#[aggregate_name = "..."]` override when present (first one wins), or the
/// type's own identifier otherwise.
///
/// Fixed: the fallback now uses `unwrap_or_else` so the identifier string is
/// only allocated when no override was supplied (clippy `or_fun_call`).
fn impl_aggregate_name(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    let name = &ast.ident;
    let mut custom_name_value: Option<String> = None;
    for attr in &ast.attrs {
        if attr.path().is_ident("aggregate_name") {
            let meta = &attr.meta;
            match meta {
                // Only the `#[aggregate_name = "literal"]` form is accepted.
                syn::Meta::NameValue(syn::MetaNameValue {
                    value: syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(s), .. }),
                    ..
                }) => {
                    custom_name_value = Some(s.value());
                    break; // first matching attribute wins
                }
                _ => {
                    return Err(syn::Error::new_spanned(
                        attr,
                        "#[aggregate_name] must be in the form `#[aggregate_name = \"string_literal\"]`",
                    ));
                }
            }
        }
    }
    let name_str = custom_name_value.unwrap_or_else(|| name.to_string());
    let expanded = quote! {
        impl ::pgrx::aggregate::ToAggregateName for #name {
            const NAME: &'static str = #name_str;
        }
    };
    Ok(expanded)
}
/// `#[pg_aggregate]` attribute macro: expands an `impl Aggregate for ...`
/// block into a Postgres aggregate via `PgAggregate`.
#[proc_macro_attribute]
pub fn pg_aggregate(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let item_impl = parse_macro_input!(item as syn::ItemImpl);
    match PgAggregate::new(item_impl) {
        Ok(entity) => entity.to_token_stream().into(),
        Err(err) => err.into_compile_error().into(),
    }
}
/// `#[pgrx(...)]` helper attribute: a no-op pass-through here (its arguments
/// are read by the derives above that declare `attributes(pgrx)`).
#[proc_macro_attribute]
pub fn pgrx(_attr: TokenStream, item: TokenStream) -> TokenStream {
    item
}
/// `#[pg_trigger]` attribute macro: wraps a Rust function as a Postgres
/// trigger function, parsing its comma-separated trigger attributes first.
#[proc_macro_attribute]
pub fn pg_trigger(attrs: TokenStream, input: TokenStream) -> TokenStream {
    let expanded: syn::Result<TokenStream> = (|| {
        use pgrx_sql_entity_graph::{PgTrigger, PgTriggerAttribute};
        use syn::Token;
        use syn::parse::Parser;
        use syn::punctuated::Punctuated;
        let attributes =
            Punctuated::<PgTriggerAttribute, Token![,]>::parse_terminated.parse(attrs)?;
        let item_fn: syn::ItemFn = syn::parse(input)?;
        Ok(PgTrigger::new(item_fn, attributes)?.to_token_stream().into())
    })();
    expanded.unwrap_or_else(|e| e.into_compile_error().into())
}
#[cfg(test)]
mod tests {
    use super::*;
    // Names under the 64-byte threshold must pass through untouched.
    #[test]
    fn short_name_unchanged() {
        let mut ident = syn::Ident::new("test_foo", proc_macro2::Span::call_site());
        let original = ident.to_string();
        maybe_shorten_pg_test_ident(&mut ident);
        assert_eq!(ident.to_string(), original);
    }
    // 63 bytes is the longest identifier Postgres keeps, so it stays as-is.
    #[test]
    fn exactly_63_chars_unchanged() {
        let name = "a".repeat(63);
        let mut ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
        maybe_shorten_pg_test_ident(&mut ident);
        assert_eq!(ident.to_string(), name);
    }
    // 64 bytes is the first length that triggers shortening (and the `tN_` prefix).
    #[test]
    fn exactly_64_chars_is_shortened() {
        let name = "a".repeat(64);
        let mut ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
        maybe_shorten_pg_test_ident(&mut ident);
        let result = ident.to_string();
        assert!(result.len() <= 63, "shortened name is {len} chars: {result}", len = result.len());
        assert!(result.starts_with('t'), "shortened name should start with 't': {result}");
    }
    // A long ASCII name should use the entire 63-byte budget.
    #[test]
    fn very_long_name_fits_in_63() {
        let name = "test_that_something_really_important_works_correctly_when_given_a_very_long_input_name";
        assert!(name.len() > 63);
        let mut ident = syn::Ident::new(name, proc_macro2::Span::call_site());
        maybe_shorten_pg_test_ident(&mut ident);
        let result = ident.to_string();
        assert_eq!(result.len(), 63, "shortened name should be exactly 63 chars: {result}");
    }
    // The global counter prefix must disambiguate names whose kept portion is
    // identical (they differ only past the truncation point).
    #[test]
    fn different_long_names_get_different_shortened_names() {
        let name_a = format!("{}{}", "a".repeat(60), "xxxx");
        let name_b = format!("{}{}", "a".repeat(60), "yyyy");
        let mut id_a = syn::Ident::new(&name_a, proc_macro2::Span::call_site());
        let mut id_b = syn::Ident::new(&name_b, proc_macro2::Span::call_site());
        maybe_shorten_pg_test_ident(&mut id_a);
        maybe_shorten_pg_test_ident(&mut id_b);
        assert_ne!(
            id_a.to_string(),
            id_b.to_string(),
            "names differing only in the tail should still get different shortened forms"
        );
    }
}