Skip to main content

pgrx_macros/
lib.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10extern crate proc_macro;
11
12use proc_macro::TokenStream;
13use std::collections::HashSet;
14use std::ffi::CString;
15
16use proc_macro2::Ident;
17use quote::{ToTokens, format_ident, quote};
18use syn::parse::Parser;
19use syn::spanned::Spanned;
20use syn::{Attribute, Data, DeriveInput, Item, ItemImpl, parse_macro_input};
21
22use operators::{deriving_postgres_eq, deriving_postgres_hash, deriving_postgres_ord};
23use pgrx_sql_entity_graph as sql_gen;
24use sql_gen::{
25    Attribute as SqlGenAttribute, CodeEnrichment, ExtensionSql, ExtensionSqlFile, PgAggregate,
26    PgCast, PgExtern, PostgresEnum, Schema,
27};
28
29mod operators;
30mod pg_bench;
31mod rewriter;
32
33/// Declare a function as `#[pg_guard]` to indicate that it is called from a Postgres `extern "C-unwind"`
34/// function so that Rust `panic!()`s (and Postgres `elog(ERROR)`s) will be properly handled by `pgrx`
35#[proc_macro_attribute]
36pub fn pg_guard(attr: TokenStream, item: TokenStream) -> TokenStream {
37    // get a usable token stream
38    let ast = parse_macro_input!(item as syn::Item);
39
40    let res = match ast {
41        // this is for processing the members of extern "C-unwind" { } blocks
42        // functions inside the block get wrapped as public, top-level unsafe functions that are not "extern"
43        Item::ForeignMod(block) => Ok(rewriter::extern_block(block)),
44
45        // process top-level functions
46        Item::Fn(func) => rewriter::item_fn_without_rewrite(func, attr),
47        unknown => Err(syn::Error::new(
48            unknown.span(),
49            "#[pg_guard] can only be applied to extern \"C-unwind\" blocks and top-level functions",
50        )),
51    };
52    res.unwrap_or_else(|e| e.into_compile_error()).into()
53}
54
55/// `#[pg_test]` functions are test functions (akin to `#[test]`), but they run in-process inside
56/// Postgres during `cargo pgrx test`.
57///
58/// This can be combined with test attributes like [`#[should_panic(expected = "..")]`][expected].
59///
60/// [expected]: https://doc.rust-lang.org/reference/attributes/testing.html#the-should_panic-attribute
61#[proc_macro_attribute]
62pub fn pg_test(attr: TokenStream, item: TokenStream) -> TokenStream {
63    let mut stream = proc_macro2::TokenStream::new();
64
65    let parsed_attrs =
66        match syn::punctuated::Punctuated::<SqlGenAttribute, syn::Token![,]>::parse_terminated
67            .parse(attr.clone())
68        {
69            Ok(p) => p,
70            Err(e) => return e.into_compile_error().into(),
71        };
72
73    let expected_error = parsed_attrs.iter().find_map(|a| match a {
74        SqlGenAttribute::ShouldPanic(lit) => Some(lit.value()),
75        _ => None,
76    });
77
78    let ast = parse_macro_input!(item as syn::Item);
79
80    match ast {
81        Item::Fn(mut func) => {
82            // Here we need to break out attributes into test and non-test attributes,
83            // so the generated #[test] attributes are in the appropriate place.
84            let (test_attributes, non_test_attributes) =
85                func.attrs.into_iter().partition::<Vec<Attribute>, _>(|attr| {
86                    attr.path()
87                        .get_ident()
88                        .is_some_and(|ident| ident == "ignore" || ident == "should_panic")
89                });
90
91            func.attrs = non_test_attributes;
92
93            // Save the original ident -- the Rust #[test] function keeps the full
94            // name (Rust has no identifier length limit) so test output is readable.
95            let original_ident = func.sig.ident.clone();
96            maybe_shorten_pg_test_ident(&mut func.sig.ident);
97
98            stream.extend(proc_macro2::TokenStream::from(pg_extern(
99                attr,
100                Item::Fn(func.clone()).to_token_stream().into(),
101            )));
102
103            let expected_error = match expected_error {
104                Some(msg) => quote! {Some(#msg)},
105                None => quote! {None},
106            };
107
108            let sql_funcname = func.sig.ident.to_string();
109            let test_func_name = format_ident!("pg_{}", original_ident);
110
111            let attributes = func.attrs;
112            let mut att_stream = proc_macro2::TokenStream::new();
113
114            for a in attributes.iter() {
115                let as_str = a.to_token_stream().to_string();
116                att_stream.extend(quote! {
117                    options.push(#as_str);
118                });
119            }
120
121            stream.extend(quote! {
122                #[test]
123                #(#test_attributes)*
124                fn #test_func_name() {
125                    let mut options = Vec::new();
126                    #att_stream
127
128                    crate::pg_test::setup(options);
129                    let res = pgrx_tests::run_test(#sql_funcname, #expected_error, crate::pg_test::postgresql_conf_options());
130                    match res {
131                        Ok(()) => (),
132                        Err(e) => panic!("{e:?}")
133                    }
134                }
135            });
136        }
137
138        thing => {
139            return syn::Error::new(
140                thing.span(),
141                "#[pg_test] can only be applied to top-level functions",
142            )
143            .into_compile_error()
144            .into();
145        }
146    }
147
148    stream.into()
149}
150
151/// If `ident` is too long for PostgreSQL's NAMEDATALEN (64 bytes, so 63 usable characters),
152/// replace it with a shortened form: `t{N}_{truncated_original}` where N is a monotonic
153/// counter that guarantees uniqueness within a compilation.
154fn maybe_shorten_pg_test_ident(ident: &mut syn::Ident) {
155    use std::sync::atomic::{AtomicUsize, Ordering};
156    static COUNTER: AtomicUsize = AtomicUsize::new(0);
157
158    // Matches pgrx_sql_entity_graph::ident_is_acceptable_to_postgres
159    const POSTGRES_IDENTIFIER_MAX_LEN: usize = 64;
160
161    let original = ident.to_string();
162    if original.len() < POSTGRES_IDENTIFIER_MAX_LEN {
163        return;
164    }
165
166    let n = COUNTER.fetch_add(1, Ordering::Relaxed);
167    let prefix = format!("t{n}_");
168    let name_budget = (POSTGRES_IDENTIFIER_MAX_LEN - 1) - prefix.len();
169
170    // Truncate at a UTF-8 char boundary (idents are ASCII in practice, but be correct).
171    let mut byte_end = name_budget.min(original.len());
172    while !original.is_char_boundary(byte_end) {
173        byte_end -= 1;
174    }
175    let shortened = format!("{prefix}{}", &original[..byte_end]);
176    *ident = syn::Ident::new(&shortened, ident.span());
177}
178
179/// `#[pg_bench]` functions are in-process Criterion-driven benchmarks that run inside Postgres
180/// during `cargo pgrx bench`.
181#[proc_macro_attribute]
182pub fn pg_bench(attr: TokenStream, item: TokenStream) -> TokenStream {
183    pg_bench::pg_bench(attr, item)
184}
185
186/// Associated macro for `#[pg_test]` to provide context back to your test framework to indicate
187/// that the test system is being initialized
188#[proc_macro_attribute]
189pub fn initialize(_attr: TokenStream, item: TokenStream) -> TokenStream {
190    item
191}
192
193/**
194Declare a function as `#[pg_cast]` to indicate that it represents a Postgres [cast](https://www.postgresql.org/docs/current/sql-createcast.html).
195
196* `assignment`: Corresponds to [`AS ASSIGNMENT`](https://www.postgresql.org/docs/current/sql-createcast.html).
197* `implicit`: Corresponds to [`AS IMPLICIT`](https://www.postgresql.org/docs/current/sql-createcast.html).
198
199By default if no attribute is specified, the cast function can only be used in an explicit cast.
200
201Functions MUST accept and return exactly one value whose type MUST be a `pgrx` supported type. `pgrx` supports many PostgreSQL types by default.
202New types can be defined via [`macro@PostgresType`] or [`macro@PostgresEnum`].
203
204`#[pg_cast]` also supports all the attributes supported by the [`macro@pg_extern]` macro, which are
205passed down to the underlying function.
206
207Example usage:
208```rust,ignore
209use pgrx::*;
210#[pg_cast(implicit)]
211fn cast_json_to_int(input: Json) -> i32 { todo!() }
212*/
213#[proc_macro_attribute]
214pub fn pg_cast(attr: TokenStream, item: TokenStream) -> TokenStream {
215    fn wrapped(attr: TokenStream, item: TokenStream) -> Result<TokenStream, syn::Error> {
216        use syn::parse::Parser;
217        use syn::punctuated::Punctuated;
218
219        let mut cast = None;
220        let mut pg_extern_attrs = proc_macro2::TokenStream::new();
221
222        // look for the attributes `#[pg_cast]` directly understands
223        match Punctuated::<syn::Path, syn::Token![,]>::parse_terminated.parse(attr) {
224            Ok(paths) => {
225                let mut new_paths = Punctuated::<syn::Path, syn::Token![,]>::new();
226                for path in paths {
227                    match (PgCast::try_from(path), &cast) {
228                        (Ok(style), None) => cast = Some(style),
229                        (Ok(_), Some(cast)) => {
230                            panic!("The cast type has already been set to `{cast:?}`")
231                        }
232
233                        // ... and anything it doesn't understand is blindly passed through to the
234                        // underlying `#[pg_extern]` function that gets created, which will ultimately
235                        // decide what's naughty and what's nice
236                        (Err(unknown), _) => {
237                            new_paths.push(unknown);
238                        }
239                    }
240                }
241
242                pg_extern_attrs.extend(new_paths.into_token_stream());
243            }
244            Err(err) => {
245                panic!("Failed to parse attribute to pg_cast: {err}")
246            }
247        }
248
249        let pg_extern = PgExtern::new(pg_extern_attrs, item.clone().into())?.0;
250        Ok(CodeEnrichment(pg_extern.as_cast(cast.unwrap_or_default())).to_token_stream().into())
251    }
252
253    wrapped(attr, item).unwrap_or_else(|e: syn::Error| e.into_compile_error().into())
254}
255
256/// Declare a function as `#[pg_operator]` to indicate that it represents a Postgres operator
257/// `cargo pgrx schema` will automatically generate the underlying SQL
258#[proc_macro_attribute]
259pub fn pg_operator(attr: TokenStream, item: TokenStream) -> TokenStream {
260    pg_extern(attr, item)
261}
262
263/// Used with `#[pg_operator]`.  1 value which is the operator name itself
264#[proc_macro_attribute]
265pub fn opname(_attr: TokenStream, item: TokenStream) -> TokenStream {
266    item
267}
268
269/// Used with `#[pg_operator]`.  1 value which is the function name
270#[proc_macro_attribute]
271pub fn commutator(_attr: TokenStream, item: TokenStream) -> TokenStream {
272    item
273}
274
275/// Used with `#[pg_operator]`.  1 value which is the function name
276#[proc_macro_attribute]
277pub fn negator(_attr: TokenStream, item: TokenStream) -> TokenStream {
278    item
279}
280
281/// Used with `#[pg_operator]`.  1 value which is the function name
282#[proc_macro_attribute]
283pub fn restrict(_attr: TokenStream, item: TokenStream) -> TokenStream {
284    item
285}
286
287/// Used with `#[pg_operator]`.  1 value which is the function name
288#[proc_macro_attribute]
289pub fn join(_attr: TokenStream, item: TokenStream) -> TokenStream {
290    item
291}
292
293/// Used with `#[pg_operator]`.  no values
294#[proc_macro_attribute]
295pub fn hashes(_attr: TokenStream, item: TokenStream) -> TokenStream {
296    item
297}
298
299/// Used with `#[pg_operator]`.  no values
300#[proc_macro_attribute]
301pub fn merges(_attr: TokenStream, item: TokenStream) -> TokenStream {
302    item
303}
304
305/**
306Declare a Rust module and its contents to be in a schema.
307
308The schema name will always be the `mod`'s identifier. So `mod flop` will create a `flop` schema.
309
310If there is a schema inside a schema, the most specific schema is chosen.
311
312In this example, the created `example` function is in the `dsl_filters` schema.
313
314```rust,ignore
315use pgrx::*;
316
317#[pg_schema]
318mod dsl {
319    use pgrx::*;
320    #[pg_schema]
321    mod dsl_filters {
322        use pgrx::*;
323        #[pg_extern]
324        fn example() { todo!() }
325    }
326}
327```
328
329File modules (like `mod name;`) aren't able to be supported due to [`rust/#54725`](https://github.com/rust-lang/rust/issues/54725).
330
331*/
332#[proc_macro_attribute]
333pub fn pg_schema(_attr: TokenStream, input: TokenStream) -> TokenStream {
334    fn wrapped(input: TokenStream) -> Result<TokenStream, syn::Error> {
335        let pgrx_schema: Schema = syn::parse(input)?;
336        Ok(pgrx_schema.to_token_stream().into())
337    }
338
339    wrapped(input).unwrap_or_else(|e: syn::Error| e.into_compile_error().into())
340}
341
342/**
343Declare SQL to be included in generated extension script.
344
345Accepts a String literal, a `name` attribute, and optionally others:
346
347* `name = "item"`: Set the unique identifier to `"item"` for use in `requires` declarations.
348* `requires = [item, item_two]`: References to other `name`s or Rust items which this SQL should be present after.
349* `creates = [ Type(submod::Cust), Enum(Pre), Function(defined)]`: Communicates that this SQL block creates certain entities.
350  Please note it **does not** create matching Rust types.
351* `bootstrap` (**Unique**): Communicates that this is SQL intended to go before all other generated SQL.
352* `finalize` (**Unique**): Communicates that this is SQL intended to go after all other generated SQL.
353
354You can declare some SQL without any positioning information, meaning it can end up anywhere in the generated SQL:
355
356```rust,ignore
357use pgrx_macros::extension_sql;
358
359extension_sql!(
360    r#"
361    -- SQL statements
362    "#,
363    name = "demo",
364);
365```
366
367To cause the SQL to be output at the start of the generated SQL:
368
369```rust,ignore
370use pgrx_macros::extension_sql;
371
372extension_sql!(
373    r#"
374    -- SQL statements
375    "#,
376    name = "demo",
377    bootstrap,
378);
379```
380
381To cause the SQL to be output at the end of the generated SQL:
382
383```rust,ignore
384use pgrx_macros::extension_sql;
385
386extension_sql!(
387    r#"
388    -- SQL statements
389    "#,
390    name = "demo",
391    finalize,
392);
393```
394
395To declare the SQL dependent, or a dependency of, other items:
396
397```rust,ignore
398use pgrx_macros::extension_sql;
399
400struct Treat;
401
402mod dog_characteristics {
403    enum DogAlignment {
404        Good
405    }
406}
407
408extension_sql!(r#"
409    -- SQL statements
410    "#,
411    name = "named_one",
412);
413
414extension_sql!(r#"
415    -- SQL statements
416    "#,
417    name = "demo",
418    requires = [ "named_one", dog_characteristics::DogAlignment ],
419);
420```
421
422To declare the SQL defines some entity (**Caution:** This is not recommended usage):
423
424```rust,ignore
425use pgrx::stringinfo::StringInfo;
426use pgrx::*;
427use pgrx_utils::get_named_capture;
428
429#[derive(Debug)]
430#[repr(C)]
431struct Complex {
432    x: f64,
433    y: f64,
434}
435
436extension_sql!(r#"\
437        CREATE TYPE complex;\
438    "#,
439    name = "create_complex_type",
440    creates = [Type(Complex)],
441);
442
443#[pg_extern(immutable)]
444fn complex_in(input: &core::ffi::CStr) -> PgBox<Complex> {
445    todo!()
446}
447
448#[pg_extern(immutable)]
449fn complex_out(complex: PgBox<Complex>) -> &'static ::core::ffi::CStr {
450    todo!()
451}
452
453extension_sql!(r#"\
454        CREATE TYPE complex (
455            internallength = 16,
456            input = complex_in,
457            output = complex_out,
458            alignment = double
459        );\
460    "#,
461    name = "demo",
462    requires = ["create_complex_type", complex_in, complex_out],
463);
464
465```
466*/
467#[proc_macro]
468pub fn extension_sql(input: TokenStream) -> TokenStream {
469    fn wrapped(input: TokenStream) -> Result<TokenStream, syn::Error> {
470        let ext_sql: CodeEnrichment<ExtensionSql> = syn::parse(input)?;
471        Ok(ext_sql.to_token_stream().into())
472    }
473
474    wrapped(input).unwrap_or_else(|e: syn::Error| e.into_compile_error().into())
475}
476
477/**
478Declare SQL (from a file) to be included in generated extension script.
479
480Accepts the same options as [`macro@extension_sql`]. `name` is automatically set to the file name (not the full path).
481
482You can declare some SQL without any positioning information, meaning it can end up anywhere in the generated SQL:
483
484```rust,ignore
485use pgrx_macros::extension_sql_file;
486extension_sql_file!(
487    "../static/demo.sql",
488);
489```
490
491To override the default name:
492
493```rust,ignore
494use pgrx_macros::extension_sql_file;
495
496extension_sql_file!(
497    "../static/demo.sql",
498    name = "singular",
499);
500```
501
502For all other options, and examples of them, see [`macro@extension_sql`].
503*/
504#[proc_macro]
505pub fn extension_sql_file(input: TokenStream) -> TokenStream {
506    fn wrapped(input: TokenStream) -> Result<TokenStream, syn::Error> {
507        let ext_sql: CodeEnrichment<ExtensionSqlFile> = syn::parse(input)?;
508        Ok(ext_sql.to_token_stream().into())
509    }
510
511    wrapped(input).unwrap_or_else(|e: syn::Error| e.into_compile_error().into())
512}
513
514/// Associated macro for `#[pg_extern]` or `#[macro@pg_operator]`.  Used to set the `SEARCH_PATH` option
515/// on the `CREATE FUNCTION` statement.
516#[proc_macro_attribute]
517pub fn search_path(_attr: TokenStream, item: TokenStream) -> TokenStream {
518    item
519}
520
521/**
522Declare a function as `#[pg_extern]` to indicate that it can be used by Postgres as a UDF.
523
524Optionally accepts the following attributes:
525
526* `immutable`: Corresponds to [`IMMUTABLE`](https://www.postgresql.org/docs/current/sql-createfunction.html).
527* `strict`: Corresponds to [`STRICT`](https://www.postgresql.org/docs/current/sql-createfunction.html).
528  + In most cases, `#[pg_extern]` can detect when no `Option<T>`s are used, and automatically set this.
529* `stable`: Corresponds to [`STABLE`](https://www.postgresql.org/docs/current/sql-createfunction.html).
530* `volatile`: Corresponds to [`VOLATILE`](https://www.postgresql.org/docs/current/sql-createfunction.html).
531* `raw`: Corresponds to [`RAW`](https://www.postgresql.org/docs/current/sql-createfunction.html).
532* `support`: Corresponds to [`SUPPORT`](https://www.postgresql.org/docs/current/sql-createfunction.html) and is the Rust path to a function to act as the SUPPORT function
533* `security_definer`: Corresponds to [`SECURITY DEFINER`](https://www.postgresql.org/docs/current/sql-createfunction.html)
534* `security_invoker`: Corresponds to [`SECURITY INVOKER`](https://www.postgresql.org/docs/current/sql-createfunction.html)
535* `parallel_safe`: Corresponds to [`PARALLEL SAFE`](https://www.postgresql.org/docs/current/sql-createfunction.html).
536* `parallel_unsafe`: Corresponds to [`PARALLEL UNSAFE`](https://www.postgresql.org/docs/current/sql-createfunction.html).
537* `parallel_restricted`: Corresponds to [`PARALLEL RESTRICTED`](https://www.postgresql.org/docs/current/sql-createfunction.html).
538* `no_guard`: Do not use `#[pg_guard]` with the function.
539* `sql`: Same arguments as [`#[pgrx(sql = ..)]`](macro@pgrx).
540* `name`: Specifies target function name. Defaults to Rust function name.
541
542Functions can accept and return any type which `pgrx` supports. `pgrx` supports many PostgreSQL types by default.
543New types can be defined via [`macro@PostgresType`] or [`macro@PostgresEnum`].
544
545
546Without any arguments or returns:
547```rust,ignore
548use pgrx::*;
549#[pg_extern]
550fn foo() { todo!() }
551```
552
553# Arguments
554It's possible to pass even complex arguments:
555
556```rust,ignore
557use pgrx::*;
558#[pg_extern]
559fn boop(
560    a: i32,
561    b: Option<i32>,
562    c: Vec<i32>,
563    d: Option<Vec<Option<i32>>>
564) { todo!() }
565```
566
567It's possible to set argument defaults, set by PostgreSQL when the function is invoked:
568
569```rust,ignore
570use pgrx::*;
571#[pg_extern]
572fn boop(a: default!(i32, 11111)) { todo!() }
573#[pg_extern]
574fn doop(
575    a: default!(Vec<Option<&str>>, "ARRAY[]::text[]"),
576    b: default!(String, "'note the inner quotes!'")
577) { todo!() }
578```
579
580The `default!()` macro may only be used in argument position.
581
582It accepts 2 arguments:
583
584* A type
585* A `bool`, numeric, or SQL string to represent the default. `"NULL"` is a possible value, as is `"'string'"`
586
587**If the default SQL entity created by the extension:** ensure it is added to `requires` as a dependency:
588
589```rust,ignore
590use pgrx::*;
591#[pg_extern]
592fn default_value() -> i32 { todo!() }
593
594#[pg_extern(
595    requires = [ default_value, ],
596)]
597fn do_it(
598    a: default!(i32, "default_value()"),
599) { todo!() }
600```
601
602# Returns
603
604It's possible to return even complex values, as well:
605
606```rust,ignore
607use pgrx::*;
608#[pg_extern]
609fn boop() -> i32 { todo!() }
610#[pg_extern]
611fn doop() -> Option<i32> { todo!() }
612#[pg_extern]
613fn swoop() -> Option<Vec<Option<i32>>> { todo!() }
614#[pg_extern]
615fn floop() -> (i32, i32) { todo!() }
616```
617
618Like in PostgreSQL, it's possible to return tables using iterators and the `name!()` macro:
619
620```rust,ignore
621use pgrx::*;
622#[pg_extern]
623fn floop<'a>() -> TableIterator<'a, (name!(a, i32), name!(b, i32))> {
624    TableIterator::new(None.into_iter())
625}
626
627#[pg_extern]
628fn singular_floop() -> (name!(a, i32), name!(b, i32)) {
629    todo!()
630}
631```
632
633The `name!()` macro may only be used in return position inside the `T` of a `TableIterator<'a, T>`.
634
635It accepts 2 arguments:
636
637* A name, such as `example`
638* A type
639
640# Special Cases
641
642`pg_sys::Oid` is a special cased type alias, in order to use it as an argument or return it must be
643passed with it's full module path (`pg_sys::Oid`) in order to be resolved.
644
645```rust,ignore
646use pgrx::*;
647
648#[pg_extern]
649fn example_arg(animals: pg_sys::Oid) {
650    todo!()
651}
652
653#[pg_extern]
654fn example_return() -> pg_sys::Oid {
655    todo!()
656}
657```
658
659*/
660#[proc_macro_attribute]
661#[track_caller]
662pub fn pg_extern(attr: TokenStream, item: TokenStream) -> TokenStream {
663    fn wrapped(attr: TokenStream, item: TokenStream) -> Result<TokenStream, syn::Error> {
664        let pg_extern_item = PgExtern::new(attr.into(), item.into())?;
665        Ok(pg_extern_item.to_token_stream().into())
666    }
667
668    wrapped(attr, item).unwrap_or_else(|e: syn::Error| e.into_compile_error().into())
669}
670
671/**
672Generate necessary bindings for using the enum with PostgreSQL.
673
674```rust,ignore
675# use pgrx_pg_sys as pg_sys;
676use pgrx::*;
677use serde::{Deserialize, Serialize};
678#[derive(Debug, Serialize, Deserialize, PostgresEnum)]
679enum DogNames {
680    Nami,
681    Brandy,
682}
683```
684
685*/
686#[proc_macro_derive(PostgresEnum, attributes(requires, pgrx))]
687pub fn postgres_enum(input: TokenStream) -> TokenStream {
688    let ast = parse_macro_input!(input as syn::DeriveInput);
689
690    impl_postgres_enum(ast).unwrap_or_else(|e| e.into_compile_error()).into()
691}
692
693fn impl_postgres_enum(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
694    let mut stream = proc_macro2::TokenStream::new();
695    let sql_graph_entity_ast = ast.clone();
696    let generics = &ast.generics.clone();
697    let enum_ident = &ast.ident;
698    let enum_name = enum_ident.to_string();
699
700    // validate that we're only operating on an enum
701    let Data::Enum(enum_data) = ast.data else {
702        return Err(syn::Error::new(
703            ast.span(),
704            "#[derive(PostgresEnum)] can only be applied to enums",
705        ));
706    };
707
708    let mut from_datum = proc_macro2::TokenStream::new();
709    let mut into_datum = proc_macro2::TokenStream::new();
710
711    for d in enum_data.variants.clone() {
712        let label_ident = &d.ident;
713        let label_string = label_ident.to_string();
714
715        from_datum.extend(quote! { #label_string => Some(#enum_ident::#label_ident), });
716        into_datum.extend(quote! { #enum_ident::#label_ident => Some(::pgrx::enum_helper::lookup_enum_by_label(#enum_name, #label_string)), });
717    }
718
719    // We need another variant of the params for the ArgAbi impl
720    let fcx_lt = syn::Lifetime::new("'fcx", proc_macro2::Span::mixed_site());
721    let mut generics_with_fcx = generics.clone();
722    // so that we can bound on Self: 'fcx
723    generics_with_fcx.make_where_clause().predicates.push(syn::WherePredicate::Type(
724        syn::PredicateType {
725            lifetimes: None,
726            bounded_ty: syn::parse_quote! { Self },
727            colon_token: syn::Token![:](proc_macro2::Span::mixed_site()),
728            bounds: syn::parse_quote! { #fcx_lt },
729        },
730    ));
731    let (impl_gens, ty_gens, where_clause) = generics_with_fcx.split_for_impl();
732    let mut impl_gens: syn::Generics = syn::parse_quote! { #impl_gens };
733    impl_gens
734        .params
735        .insert(0, syn::GenericParam::Lifetime(syn::LifetimeParam::new(fcx_lt.clone())));
736
737    stream.extend(quote! {
738        impl ::pgrx::datum::FromDatum for #enum_ident {
739            #[inline]
740            unsafe fn from_polymorphic_datum(datum: ::pgrx::pg_sys::Datum, is_null: bool, _typeoid: ::pgrx::pg_sys::Oid) -> Option<#enum_ident> {
741                if is_null {
742                    None
743                } else {
744                    // GREPME: non-primitive cast u64 as Oid
745                    let (name, _, _) = ::pgrx::enum_helper::lookup_enum_by_oid(unsafe { ::pgrx::pg_sys::Oid::from_datum(datum, is_null)? } );
746                    match name.as_str() {
747                        #from_datum
748                        _ => panic!("invalid enum value: {name}")
749                    }
750                }
751            }
752        }
753
754        unsafe impl #impl_gens ::pgrx::callconv::ArgAbi<#fcx_lt> for #enum_ident #ty_gens #where_clause {
755            unsafe fn unbox_arg_unchecked(arg: ::pgrx::callconv::Arg<'_, #fcx_lt>) -> Self {
756                let index = arg.index();
757                unsafe { arg.unbox_arg_using_from_datum().unwrap_or_else(|| panic!("argument {index} must not be null")) }
758            }
759
760        }
761
762        unsafe impl #generics ::pgrx::datum::UnboxDatum for #enum_ident #generics {
763            type As<'dat> = #enum_ident #generics where Self: 'dat;
764            #[inline]
765            unsafe fn unbox<'dat>(d: ::pgrx::datum::Datum<'dat>) -> Self::As<'dat> where Self: 'dat {
766                <Self as ::pgrx::datum::FromDatum>::from_datum(::core::mem::transmute(d), false).unwrap()
767            }
768        }
769
770        impl ::pgrx::datum::IntoDatum for #enum_ident {
771            #[inline]
772            fn into_datum(self) -> Option<::pgrx::pg_sys::Datum> {
773                match self {
774                    #into_datum
775                }
776            }
777
778            fn type_oid() -> ::pgrx::pg_sys::Oid {
779                ::pgrx::wrappers::regtypein(#enum_name)
780            }
781
782        }
783
784        unsafe impl ::pgrx::callconv::BoxRet for #enum_ident {
785            unsafe fn box_into<'fcx>(self, fcinfo: &mut ::pgrx::callconv::FcInfo<'fcx>) -> ::pgrx::datum::Datum<'fcx> {
786                match ::pgrx::datum::IntoDatum::into_datum(self) {
787                    None => fcinfo.return_null(),
788                    Some(datum) => unsafe { fcinfo.return_raw_datum(datum) },
789                }
790            }
791        }
792    });
793
794    let sql_graph_entity_item = PostgresEnum::from_derive_input(sql_graph_entity_ast)?;
795    sql_graph_entity_item.to_tokens(&mut stream);
796
797    Ok(stream)
798}
799
800/**
801Generate necessary bindings for using the type with PostgreSQL.
802
803```rust,ignore
804# use pgrx_pg_sys as pg_sys;
805use pgrx::*;
806use serde::{Deserialize, Serialize};
807#[derive(Debug, Serialize, Deserialize, PostgresType)]
808struct Dog {
809    treats_received: i64,
810    pets_gotten: i64,
811}
812
813#[derive(Debug, Serialize, Deserialize, PostgresType)]
814enum Animal {
815    Dog(Dog),
816}
817```
818
819Optionally accepts the following attributes:
820
821* `inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the type.
822* `pgvarlena_inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the `PgVarlena` of this type.
823* `pg_binary_protocol`: Use the binary protocol for this type.
824* `pgrx(alignment = "<align>")`: Derive Postgres alignment from Rust type. One of `"on"`, or `"off"`.
825* `sql`: Same arguments as [`#[pgrx(sql = ..)]`](macro@pgrx).
826*/
827#[proc_macro_derive(
828    PostgresType,
829    attributes(
830        inoutfuncs,
831        pgvarlena_inoutfuncs,
832        pg_binary_protocol,
833        bikeshed_postgres_type_manually_impl_from_into_datum,
834        requires,
835        pgrx
836    )
837)]
838pub fn postgres_type(input: TokenStream) -> TokenStream {
839    let ast = parse_macro_input!(input as syn::DeriveInput);
840
841    impl_postgres_type(ast).unwrap_or_else(|e| e.into_compile_error()).into()
842}
843
844fn impl_postgres_type(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
845    let name = &ast.ident;
846    let generics = &ast.generics.clone();
847    let has_lifetimes = generics.lifetimes().next();
848    let funcname_in = Ident::new(&format!("{name}_in").to_lowercase(), name.span());
849    let funcname_out = Ident::new(&format!("{name}_out").to_lowercase(), name.span());
850    let funcname_recv = Ident::new(&format!("{name}_recv").to_lowercase(), name.span());
851    let funcname_send = Ident::new(&format!("{name}_send").to_lowercase(), name.span());
852
853    let mut args = parse_postgres_type_args(&ast.attrs);
854    let mut stream = proc_macro2::TokenStream::new();
855
856    // validate that we're only operating on a struct
857    match ast.data {
858        Data::Struct(_) => { /* this is okay */ }
859        Data::Enum(_) => {
860            // this is okay and if there's an attempt to implement PostgresEnum,
861            // it will result in compile-time error of conflicting implementation
862            // of traits (IntoDatum, inout, etc.)
863        }
864        _ => {
865            return Err(syn::Error::new(
866                ast.span(),
867                "#[derive(PostgresType)] can only be applied to structs or enums",
868            ));
869        }
870    }
871
872    if !args.contains(&PostgresTypeAttribute::InOutFuncs)
873        && !args.contains(&PostgresTypeAttribute::PgVarlenaInOutFuncs)
874    {
875        // assume the user wants us to implement the InOutFuncs
876        args.insert(PostgresTypeAttribute::Default);
877    }
878
879    let lifetime = match has_lifetimes {
880        Some(lifetime) => quote! {#lifetime},
881        None => quote! {'_},
882    };
883
884    // We need another variant of the params for the ArgAbi impl
885    let fcx_lt = syn::Lifetime::new("'fcx", proc_macro2::Span::mixed_site());
886    let mut generics_with_fcx = generics.clone();
887    // so that we can bound on Self: 'fcx
888    generics_with_fcx.make_where_clause().predicates.push(syn::WherePredicate::Type(
889        syn::PredicateType {
890            lifetimes: None,
891            bounded_ty: syn::parse_quote! { Self },
892            colon_token: syn::Token![:](proc_macro2::Span::mixed_site()),
893            bounds: syn::parse_quote! { #fcx_lt },
894        },
895    ));
896    let (impl_gens, ty_gens, where_clause) = generics_with_fcx.split_for_impl();
897    let mut impl_gens: syn::Generics = syn::parse_quote! { #impl_gens };
898    impl_gens
899        .params
900        .insert(0, syn::GenericParam::Lifetime(syn::LifetimeParam::new(fcx_lt.clone())));
901
902    // all #[derive(PostgresType)] need to implement that trait
903    // and also the FromDatum and IntoDatum
904    stream.extend(quote! {
905        impl #generics ::pgrx::datum::PostgresType for #name #generics { }
906    });
907
908    if !args.contains(&PostgresTypeAttribute::ManualFromIntoDatum) {
909        stream.extend(
910            quote! {
911                impl #generics ::pgrx::datum::IntoDatum for #name #generics {
912                    fn into_datum(self) -> Option<::pgrx::pg_sys::Datum> {
913                        #[allow(deprecated)]
914                        Some(unsafe { ::pgrx::datum::cbor_encode(&self) }.into())
915                    }
916
917                    fn type_oid() -> ::pgrx::pg_sys::Oid {
918                        ::pgrx::wrappers::rust_regtypein::<Self>()
919                    }
920                }
921
922                unsafe impl #generics ::pgrx::callconv::BoxRet for #name #generics {
923                    unsafe fn box_into<'fcx>(self, fcinfo: &mut ::pgrx::callconv::FcInfo<'fcx>) -> ::pgrx::datum::Datum<'fcx> {
924                        match ::pgrx::datum::IntoDatum::into_datum(self) {
925                            None => fcinfo.return_null(),
926                            Some(datum) => unsafe { fcinfo.return_raw_datum(datum) },
927                        }
928                    }
929                }
930
931                impl #generics ::pgrx::datum::FromDatum for #name #generics {
932                    unsafe fn from_polymorphic_datum(
933                        datum: ::pgrx::pg_sys::Datum,
934                        is_null: bool,
935                        _typoid: ::pgrx::pg_sys::Oid,
936                    ) -> Option<Self> {
937                        if is_null {
938                            None
939                        } else {
940                            #[allow(deprecated)]
941                            ::pgrx::datum::cbor_decode(datum.cast_mut_ptr())
942                        }
943                    }
944
945                    unsafe fn from_datum_in_memory_context(
946                        mut memory_context: ::pgrx::memcxt::PgMemoryContexts,
947                        datum: ::pgrx::pg_sys::Datum,
948                        is_null: bool,
949                        _typoid: ::pgrx::pg_sys::Oid,
950                    ) -> Option<Self> {
951                        if is_null {
952                            None
953                        } else {
954                            memory_context.switch_to(|_| {
955                                // this gets the varlena Datum copied into this memory context
956                                let varlena = ::pgrx::pg_sys::pg_detoast_datum_copy(datum.cast_mut_ptr());
957                                <Self as ::pgrx::datum::FromDatum>::from_datum(varlena.into(), is_null)
958                            })
959                        }
960                    }
961                }
962
963                unsafe impl #generics ::pgrx::datum::UnboxDatum for #name #generics {
964                    type As<'dat> = Self where Self: 'dat;
965                    unsafe fn unbox<'dat>(datum: ::pgrx::datum::Datum<'dat>) -> Self::As<'dat> where Self: 'dat {
966                        <Self as ::pgrx::datum::FromDatum>::from_datum(::core::mem::transmute(datum), false).unwrap()
967                    }
968                }
969
970                unsafe impl #impl_gens ::pgrx::callconv::ArgAbi<#fcx_lt> for #name #ty_gens #where_clause
971                {
972                        unsafe fn unbox_arg_unchecked(arg: ::pgrx::callconv::Arg<'_, #fcx_lt>) -> Self {
973                        let index = arg.index();
974                        unsafe { arg.unbox_arg_using_from_datum().unwrap_or_else(|| panic!("argument {index} must not be null")) }
975                    }
976                }
977            }
978        )
979    }
980
981    // and if we don't have custom inout/funcs, we use the JsonInOutFuncs trait
982    // which implements _in and _out #[pg_extern] functions that just return the type itself
983    if args.contains(&PostgresTypeAttribute::Default) {
984        stream.extend(quote! {
985            #[doc(hidden)]
986            #[::pgrx::pgrx_macros::pg_extern(immutable, parallel_safe)]
987            pub fn #funcname_in #generics(input: Option<&#lifetime ::core::ffi::CStr>) -> Option<#name #generics> {
988                use ::pgrx::inoutfuncs::json_from_slice;
989                input.map(|cstr| json_from_slice(cstr.to_bytes()).ok()).flatten()
990            }
991
992            #[doc(hidden)]
993            #[::pgrx::pgrx_macros::pg_extern (immutable, parallel_safe)]
994            pub fn #funcname_out #generics(input: #name #generics) -> ::pgrx::ffi::CString {
995                use ::pgrx::inoutfuncs::json_to_vec;
996                let mut bytes = json_to_vec(&input).unwrap();
997                bytes.push(0); // terminate
998                ::pgrx::ffi::CString::from_vec_with_nul(bytes).unwrap()
999            }
1000        });
1001    } else if args.contains(&PostgresTypeAttribute::InOutFuncs) {
1002        // otherwise if it's InOutFuncs our _in/_out functions use an owned type instance
1003        stream.extend(quote! {
1004            #[doc(hidden)]
1005            #[::pgrx::pgrx_macros::pg_extern(immutable,parallel_safe)]
1006            pub fn #funcname_in #generics(input: Option<&::core::ffi::CStr>) -> Option<#name #generics> {
1007                input.map_or_else(|| {
1008                    if let Some(m) = <#name as ::pgrx::inoutfuncs::InOutFuncs>::NULL_ERROR_MESSAGE {
1009                        ::pgrx::pg_sys::error!("{m}");
1010                    }
1011                    None
1012                }, |i| Some(<#name as ::pgrx::inoutfuncs::InOutFuncs>::input(i)))
1013            }
1014
1015            #[doc(hidden)]
1016            #[::pgrx::pgrx_macros::pg_extern(immutable,parallel_safe)]
1017            pub fn #funcname_out #generics(input: #name #generics) -> ::pgrx::ffi::CString {
1018                let mut buffer = ::pgrx::stringinfo::StringInfo::new();
1019                ::pgrx::inoutfuncs::InOutFuncs::output(&input, &mut buffer);
1020                // SAFETY: We just constructed this StringInfo ourselves
1021                unsafe { buffer.leak_cstr().to_owned() }
1022            }
1023        });
1024    } else if args.contains(&PostgresTypeAttribute::PgVarlenaInOutFuncs) {
1025        // otherwise if it's PgVarlenaInOutFuncs our _in/_out functions use a PgVarlena
1026        stream.extend(quote! {
1027            #[doc(hidden)]
1028            #[::pgrx::pgrx_macros::pg_extern(immutable,parallel_safe)]
1029            pub fn #funcname_in #generics(input: Option<&::core::ffi::CStr>) -> Option<::pgrx::datum::PgVarlena<#name #generics>> {
1030                input.map_or_else(|| {
1031                    if let Some(m) = <#name as ::pgrx::inoutfuncs::PgVarlenaInOutFuncs>::NULL_ERROR_MESSAGE {
1032                        ::pgrx::pg_sys::error!("{m}");
1033                    }
1034                    None
1035                }, |i| Some(<#name as ::pgrx::inoutfuncs::PgVarlenaInOutFuncs>::input(i)))
1036            }
1037
1038            #[doc(hidden)]
1039            #[::pgrx::pgrx_macros::pg_extern(immutable,parallel_safe)]
1040            pub fn #funcname_out #generics(input: ::pgrx::datum::PgVarlena<#name #generics>) -> ::pgrx::ffi::CString {
1041                let mut buffer = ::pgrx::stringinfo::StringInfo::new();
1042                ::pgrx::inoutfuncs::PgVarlenaInOutFuncs::output(&*input, &mut buffer);
1043                // SAFETY: We just constructed this StringInfo ourselves
1044                unsafe { buffer.leak_cstr().to_owned() }
1045            }
1046        });
1047    }
1048
1049    if args.contains(&PostgresTypeAttribute::PgBinaryProtocol) {
1050        // At this time, the `PostgresTypeAttribute` does not impact the way we generate
1051        // the `recv` and `send` functions.
1052        stream.extend(quote! {
1053            #[doc(hidden)]
1054            #[::pgrx::pgrx_macros::pg_extern(immutable, strict, parallel_safe)]
1055            pub fn #funcname_recv #generics(
1056                mut internal: ::pgrx::datum::Internal,
1057            ) -> #name #generics {
1058                let buf = unsafe { internal.get_mut::<::pgrx::pg_sys::StringInfoData>().unwrap() };
1059
1060                let mut serialized = ::pgrx::StringInfo::new();
1061
1062                serialized.push_bytes(&[0u8; ::pgrx::pg_sys::VARHDRSZ]); // reserve space for the header
1063                serialized.push_bytes(unsafe {
1064                    core::slice::from_raw_parts(
1065                        buf.data as *const u8,
1066                        buf.len as usize
1067                    )
1068                });
1069
1070                let size = serialized.len();
1071                let varlena = serialized.into_char_ptr();
1072
1073                unsafe{
1074                    ::pgrx::set_varsize_4b(varlena as *mut ::pgrx::pg_sys::varlena, size as i32);
1075                    buf.cursor = buf.len;
1076                    ::pgrx::datum::cbor_decode(varlena as *mut ::pgrx::pg_sys::varlena)
1077                }
1078            }
1079            #[doc(hidden)]
1080            #[::pgrx::pgrx_macros::pg_extern(immutable, strict, parallel_safe)]
1081            pub fn #funcname_send #generics(input: #name #generics) -> Vec<u8> {
1082                use ::pgrx::datum::{FromDatum, IntoDatum};
1083                let Some(datum): Option<::pgrx::pg_sys::Datum> = input.into_datum() else {
1084                    ::pgrx::error!("Datum of type `{}` is unexpectedly NULL.", stringify!(#name));
1085                };
1086                unsafe {
1087                    let Some(serialized): Option<Vec<u8>> = FromDatum::from_datum(datum, false) else {
1088                        ::pgrx::error!("Failed to CBOR-serialize Datum to type `{}`.", stringify!(#name));
1089                    };
1090                    serialized
1091                }
1092            }
1093        });
1094    }
1095
1096    let sql_graph_entity_item = sql_gen::PostgresTypeDerive::from_derive_input(
1097        ast,
1098        args.contains(&PostgresTypeAttribute::PgBinaryProtocol),
1099    )?;
1100    sql_graph_entity_item.to_tokens(&mut stream);
1101
1102    Ok(stream)
1103}
1104
1105/// Derives the `GucEnum` trait, so that normal Rust enums can be used as a GUC.
1106#[proc_macro_derive(PostgresGucEnum, attributes(name, hidden))]
1107pub fn postgres_guc_enum(input: TokenStream) -> TokenStream {
1108    let ast = parse_macro_input!(input as syn::DeriveInput);
1109
1110    impl_guc_enum(ast).unwrap_or_else(|e| e.into_compile_error()).into()
1111}
1112
1113fn impl_guc_enum(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
1114    use std::str::FromStr;
1115    use syn::parse::Parse;
1116
1117    enum GucEnumAttribute {
1118        Name(CString),
1119        Hidden(bool),
1120    }
1121
1122    impl GucEnumAttribute {
1123        fn is_guc_enum_attribute(attribute: &str) -> bool {
1124            matches!(attribute, "name" | "hidden")
1125        }
1126    }
1127
1128    impl Parse for GucEnumAttribute {
1129        fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1130            let ident: Ident = input.parse()?;
1131            let _: syn::token::Eq = input.parse()?;
1132            match ident.to_string().as_str() {
1133                "name" => input.parse::<syn::LitCStr>().map(|val| Self::Name(val.value())),
1134                "hidden" => input.parse::<syn::LitBool>().map(|val| Self::Hidden(val.value())),
1135                x => Err(syn::Error::new(input.span(), format!("unknown attribute {x}"))),
1136            }
1137        }
1138    }
1139
1140    // validate that we're only operating on an enum
1141    let Data::Enum(data) = ast.data.clone() else {
1142        return Err(syn::Error::new(
1143            ast.span(),
1144            "#[derive(PostgresGucEnum)] can only be applied to enums",
1145        ));
1146    };
1147    let ident = ast.ident.clone();
1148    let mut config = Vec::new();
1149    for (index, variant) in data.variants.iter().enumerate() {
1150        let default_name = CString::from_str(&variant.ident.to_string())
1151            .expect("the identifier contains a null character.");
1152        let default_val = index as i32;
1153        let default_hidden = false;
1154        let mut name = None;
1155        let mut hidden = None;
1156
1157        for attr in variant.attrs.iter() {
1158            if let Some(ident) = attr.path().get_ident()
1159                && GucEnumAttribute::is_guc_enum_attribute(&ident.to_string())
1160            {
1161                let pair: GucEnumAttribute = syn::parse2(attr.meta.to_token_stream())?;
1162                match pair {
1163                    GucEnumAttribute::Name(value) => {
1164                        if name.replace(value).is_some() {
1165                            return Err(syn::Error::new(ast.span(), "too many #[name] attributes"));
1166                        }
1167                    }
1168                    GucEnumAttribute::Hidden(value) => {
1169                        if hidden.replace(value).is_some() {
1170                            return Err(syn::Error::new(
1171                                ast.span(),
1172                                "too many #[hidden] attributes",
1173                            ));
1174                        }
1175                    }
1176                }
1177            }
1178        }
1179        let ident = variant.ident.clone();
1180        let name = name.unwrap_or(default_name);
1181        let val = default_val;
1182        let hidden = hidden.unwrap_or(default_hidden);
1183        config.push((ident, name, val, hidden));
1184    }
1185    let config_idents = config.iter().map(|x| &x.0).collect::<Vec<_>>();
1186    let config_names = config.iter().map(|x| &x.1).collect::<Vec<_>>();
1187    let config_vals = config.iter().map(|x| &x.2).collect::<Vec<_>>();
1188    let config_hiddens = config.iter().map(|x| &x.3).collect::<Vec<_>>();
1189
1190    Ok(quote! {
1191        unsafe impl ::pgrx::guc::GucEnum for #ident {
1192            fn from_ordinal(ordinal: i32) -> Self {
1193                match ordinal {
1194                    #(#config_vals => Self::#config_idents,)*
1195                    _ => panic!("Unrecognized ordinal"),
1196                }
1197            }
1198
1199            fn to_ordinal(&self) -> i32 {
1200                match self {
1201                    #(Self::#config_idents => #config_vals,)*
1202                }
1203            }
1204
1205            const CONFIG_ENUM_ENTRY: *const ::pgrx::pg_sys::config_enum_entry = [
1206                #(
1207                    ::pgrx::pg_sys::config_enum_entry {
1208                        name: #config_names.as_ptr(),
1209                        val: #config_vals,
1210                        hidden: #config_hiddens,
1211                    },
1212                )*
1213                ::pgrx::pg_sys::config_enum_entry {
1214                    name: core::ptr::null(),
1215                    val: 0,
1216                    hidden: false,
1217                },
1218            ].as_ptr();
1219        }
1220    })
1221}
1222
/// Code-generation modes selected for `#[derive(PostgresType)]`.
///
/// All variants except `Default` come from helper attributes on the type; `Default` is added
/// when neither in/out attribute is present. Variant order determines the derived ordering.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum PostgresTypeAttribute {
    /// `#[inoutfuncs]`: text I/O goes through the type's `InOutFuncs` impl.
    InOutFuncs,
    /// `#[pg_binary_protocol]`: also emit binary `_recv` / `_send` functions.
    PgBinaryProtocol,
    /// `#[pgvarlena_inoutfuncs]`: text I/O goes through `PgVarlenaInOutFuncs` on a `PgVarlena`.
    PgVarlenaInOutFuncs,
    /// No in/out attribute was given; JSON-based `_in` / `_out` functions are generated.
    Default,
    /// `#[bikeshed_postgres_type_manually_impl_from_into_datum]`: skip the generated datum impls.
    ManualFromIntoDatum,
}
1231
1232fn parse_postgres_type_args(attributes: &[Attribute]) -> HashSet<PostgresTypeAttribute> {
1233    let mut categorized_attributes = HashSet::new();
1234
1235    for a in attributes {
1236        let path = &a.path();
1237        let path = quote! {#path}.to_string();
1238        match path.as_str() {
1239            "inoutfuncs" => {
1240                categorized_attributes.insert(PostgresTypeAttribute::InOutFuncs);
1241            }
1242            "pg_binary_protocol" => {
1243                categorized_attributes.insert(PostgresTypeAttribute::PgBinaryProtocol);
1244            }
1245            "pgvarlena_inoutfuncs" => {
1246                categorized_attributes.insert(PostgresTypeAttribute::PgVarlenaInOutFuncs);
1247            }
1248            "bikeshed_postgres_type_manually_impl_from_into_datum" => {
1249                categorized_attributes.insert(PostgresTypeAttribute::ManualFromIntoDatum);
1250            }
1251            _ => {
1252                // we can just ignore attributes we don't understand
1253            }
1254        };
1255    }
1256
1257    categorized_attributes
1258}
1259
1260/**
1261Generate necessary code using the type in operators like `==` and `!=`.
1262
1263```rust,ignore
1264# use pgrx_pg_sys as pg_sys;
1265use pgrx::*;
1266use serde::{Deserialize, Serialize};
1267#[derive(Debug, Serialize, Deserialize, PostgresEnum, PartialEq, Eq, PostgresEq)]
1268enum DogNames {
1269    Nami,
1270    Brandy,
1271}
1272```
1273Optionally accepts the following attributes:
1274
1275* `sql`: Same arguments as [`#[pgrx(sql = ..)]`](macro@pgrx).
1276
1277# No bounds?
1278Unlike some derives, this does not implement a "real" Rust trait, thus
1279PostgresEq cannot be used in trait bounds, nor can it be manually implemented.
1280*/
1281#[proc_macro_derive(PostgresEq, attributes(pgrx))]
1282pub fn derive_postgres_eq(input: TokenStream) -> TokenStream {
1283    let ast = parse_macro_input!(input as syn::DeriveInput);
1284    deriving_postgres_eq(ast).unwrap_or_else(syn::Error::into_compile_error).into()
1285}
1286
1287/**
1288Generate necessary code using the type in operators like `>`, `<`, `<=`, and `>=`.
1289
1290```rust,ignore
1291# use pgrx_pg_sys as pg_sys;
1292use pgrx::*;
1293use serde::{Deserialize, Serialize};
1294#[derive(
1295    Debug, Serialize, Deserialize, PartialEq, Eq,
1296     PartialOrd, Ord, PostgresEnum, PostgresOrd
1297)]
1298enum DogNames {
1299    Nami,
1300    Brandy,
1301}
1302```
1303Optionally accepts the following attributes:
1304
1305* `sql`: Same arguments as [`#[pgrx(sql = ..)]`](macro@pgrx).
1306
1307# No bounds?
1308Unlike some derives, this does not implement a "real" Rust trait, thus
1309PostgresOrd cannot be used in trait bounds, nor can it be manually implemented.
1310*/
1311#[proc_macro_derive(PostgresOrd, attributes(pgrx))]
1312pub fn derive_postgres_ord(input: TokenStream) -> TokenStream {
1313    let ast = parse_macro_input!(input as syn::DeriveInput);
1314    deriving_postgres_ord(ast).unwrap_or_else(syn::Error::into_compile_error).into()
1315}
1316
1317/**
1318Generate necessary code for stable hashing the type so it can be used with `USING hash` indexes.
1319
1320```rust,ignore
1321# use pgrx_pg_sys as pg_sys;
1322use pgrx::*;
1323use serde::{Deserialize, Serialize};
1324#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, PostgresEnum, PostgresHash)]
1325enum DogNames {
1326    Nami,
1327    Brandy,
1328}
1329```
1330Optionally accepts the following attributes:
1331
1332* `sql`: Same arguments as [`#[pgrx(sql = ..)]`](macro@pgrx).
1333
1334# No bounds?
1335Unlike some derives, this does not implement a "real" Rust trait, thus
1336PostgresHash cannot be used in trait bounds, nor can it be manually implemented.
1337*/
1338#[proc_macro_derive(PostgresHash, attributes(pgrx))]
1339pub fn derive_postgres_hash(input: TokenStream) -> TokenStream {
1340    let ast = parse_macro_input!(input as syn::DeriveInput);
1341    deriving_postgres_hash(ast).unwrap_or_else(syn::Error::into_compile_error).into()
1342}
1343
1344/// Derives the `ToAggregateName` trait.
1345#[proc_macro_derive(AggregateName, attributes(aggregate_name))]
1346pub fn derive_aggregate_name(input: TokenStream) -> TokenStream {
1347    let ast = parse_macro_input!(input as syn::DeriveInput);
1348
1349    impl_aggregate_name(ast).unwrap_or_else(|e| e.into_compile_error()).into()
1350}
1351
1352fn impl_aggregate_name(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
1353    let name = &ast.ident;
1354
1355    let mut custom_name_value: Option<String> = None;
1356
1357    for attr in &ast.attrs {
1358        if attr.path().is_ident("aggregate_name") {
1359            let meta = &attr.meta;
1360            match meta {
1361                syn::Meta::NameValue(syn::MetaNameValue {
1362                    value: syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(s), .. }),
1363                    ..
1364                }) => {
1365                    custom_name_value = Some(s.value());
1366                    break;
1367                }
1368                _ => {
1369                    return Err(syn::Error::new_spanned(
1370                        attr,
1371                        "#[aggregate_name] must be in the form `#[aggregate_name = \"string_literal\"]`",
1372                    ));
1373                }
1374            }
1375        }
1376    }
1377
1378    let name_str = custom_name_value.unwrap_or(name.to_string());
1379
1380    let expanded = quote! {
1381        impl ::pgrx::aggregate::ToAggregateName for #name {
1382            const NAME: &'static str = #name_str;
1383        }
1384    };
1385
1386    Ok(expanded)
1387}
1388
1389/**
1390Declare a `pgrx::Aggregate` implementation on a type as able to used by Postgres as an aggregate.
1391
1392Functions inside the `impl` may use the [`#[pgrx]`](macro@pgrx) attribute.
1393*/
1394#[proc_macro_attribute]
1395pub fn pg_aggregate(_attr: TokenStream, item: TokenStream) -> TokenStream {
1396    // We don't care about `_attr` as we can find it in the `ItemMod`.
1397    fn wrapped(item_impl: ItemImpl) -> Result<TokenStream, syn::Error> {
1398        let sql_graph_entity_item = PgAggregate::new(item_impl)?;
1399
1400        Ok(sql_graph_entity_item.to_token_stream().into())
1401    }
1402
1403    let parsed_base = parse_macro_input!(item as syn::ItemImpl);
1404    wrapped(parsed_base).unwrap_or_else(|e| e.into_compile_error().into())
1405}
1406
1407/**
1408A helper attribute for various contexts.
1409
1410## Usage with [`#[pg_aggregate]`](macro@pg_aggregate).
1411
1412It can be decorated on functions inside a [`#[pg_aggregate]`](macro@pg_aggregate) implementation.
1413In this position, it takes the same args as [`#[pg_extern]`](macro@pg_extern), and those args have the same effect.
1414
1415## Usage for configuring SQL generation
1416
1417This attribute can be used to control the behavior of the SQL generator on a decorated item,
1418e.g. `#[pgrx(sql = false)]`
1419
1420Currently `sql` can be provided one of the following:
1421
1422* Disable SQL generation with `#[pgrx(sql = false)]`
1423* Call custom SQL generator function with `#[pgrx(sql = path::to_function)]`
1424* Render a specific fragment of SQL with a string `#[pgrx(sql = "CREATE FUNCTION ...")]`
1425
1426*/
1427#[proc_macro_attribute]
1428pub fn pgrx(_attr: TokenStream, item: TokenStream) -> TokenStream {
1429    item
1430}
1431
1432/**
1433Declare a function as a GUC hook. This takes one argument: `show`, `check`, or `assign`.
1434
1435The first parameter of `check` and `assign` hooks must implement `pgrx::guc::GucValue`.
1436
1437Examples:
1438```rust,ignore
1439#[pg_guc_hook(show)]
1440fn my_show_hook() -> String {
1441    "CUSTOM_VALUE".to_string()
1442}
1443
1444#[pg_guc_hook(check)]
1445fn my_check_hook(newval: i32) -> Result<(), GucCheckError> {
1446    // accept or reject newval
1447}
1448
1449#[pg_guc_hook(assign)]
1450fn my_assign_hook(newval: i32) {
1451    // do more now that every change is accepted
1452}
1453```
1454
1455# Check Hooks
1456
1457This macro adapts check hooks of multiple forms.
1458
1459The simplest of these takes a single argument and returns a `bool`.
1460- `true` means the value is valid, and Postgres will store that value in the GUC variable.
1461- `false` means the value is invalid, and Postgres will produce an error message automatically.
1462
1463```rust,ignore
1464#[pg_guc_hook(check)]
1465fn my_check_hook(newval: i32) -> bool { newval > 0 }
1466```
1467
1468To apply your own error message for an invalid value, return a `Result<(), GucCheckError>`.
1469- `Ok` means the value is valid.
1470- `Err` means the value is invalid, and `pgrx` will pass the contents of `GucCheckError` to Postgres.
1471
1472```rust,ignore
1473#[pg_guc_hook(check)]
1474fn my_check_hook(newval: i32) -> Result<(), GucCheckError> {
1475    if newval > 0 {
1476        Ok(())
1477    } else {
1478        Err(
1479            GucCheckError::new("value cannot be negative")
1480            .with_hint("to configure the opposite behavior, set other_param"),
1481        )
1482    }
1483}
1484```
1485
1486A check hook can also accept a second argument that indicates when the parameter is changing.
1487Like the examples above, this can return a bool or a Result.
1488
1489```rust,ignore
1490#[pg_guc_hook(check)]
1491fn my_check_hook(newval: i32, source: pg_sys::GucSource::Type) -> bool { true }
1492```
1493*/
1494#[proc_macro_attribute]
1495pub fn pg_guc_hook(attr: TokenStream, item: TokenStream) -> TokenStream {
1496    let hook_type = parse_macro_input!(attr as syn::Ident);
1497    let mut func = parse_macro_input!(item as syn::ItemFn);
1498
1499    // Rename the original function and use its name.
1500    let original_ident = func.sig.ident.clone();
1501    let inner_ident = format_ident!("{}_INNER", original_ident);
1502    func.sig.ident = inner_ident.clone();
1503
1504    // Remove the original visibility and use it.
1505    let vis = &func.vis;
1506    let mut inner_func = func.clone();
1507    inner_func.vis = syn::Visibility::Inherited;
1508
1509    let wrapper = match hook_type.to_string().as_str() {
1510        "show" => {
1511            quote! {
1512                #[::pgrx::pg_guard]
1513                #vis unsafe extern "C-unwind" fn #original_ident() -> *const ::core::ffi::c_char {
1514                    #[inline(always)]
1515                    #inner_func
1516
1517                    let show = #inner_ident();
1518                    ::pgrx::pg_sys::AsPgCStr::as_pg_cstr(&show)
1519                }
1520            }
1521        }
1522        "check" => {
1523            let invoke_inner = match func.sig.inputs.len() {
1524                1 => quote! { #inner_ident(value) },
1525                2 => quote! { #inner_ident(value, source) },
1526                _ => {
1527                    return syn::Error::new(
1528                        func.sig.span(),
1529                        "check hook must have one or two arguments",
1530                    )
1531                    .into_compile_error()
1532                    .into();
1533                }
1534            };
1535
1536            let arg_type = match func.sig.inputs.first() {
1537                Some(syn::FnArg::Typed(pat_type)) => &pat_type.ty,
1538                _ => {
1539                    return syn::Error::new(
1540                        func.sig.span(),
1541                        "check hook first argument must implement GucValue",
1542                    )
1543                    .into_compile_error()
1544                    .into();
1545                }
1546            };
1547
1548            // The original function may return bool or Result. Pass bool through directly and apply
1549            // any details present in GucCheckError.
1550            let invoke_to_bool = if let syn::ReturnType::Type(_, return_type) = &func.sig.output
1551                && let syn::Type::Path(type_path) = &**return_type
1552                && type_path.path.segments.last().map_or(false, |seg| seg.ident == "bool")
1553            {
1554                quote! { #invoke_inner }
1555            } else {
1556                quote! {
1557                    match #invoke_inner {
1558                        Ok(()) => true,
1559                        Err(err) => {
1560                            unsafe { ::pgrx::guc::GucCheckError::apply(err) };
1561                            false
1562                        }
1563                    }
1564                }
1565            };
1566
1567            quote! {
1568                #[::pgrx::pg_guard]
1569                #vis unsafe extern "C-unwind" fn #original_ident(
1570                    newval: *mut <#arg_type as ::pgrx::guc::GucValue>::Raw,
1571                    extra: *mut *mut ::core::ffi::c_void,
1572                    source: ::pgrx::pg_sys::GucSource::Type,
1573                ) -> bool {
1574                    #[inline(always)]
1575                    #inner_func
1576
1577                    debug_assert!(!newval.is_null());
1578                    let value = unsafe { <#arg_type as ::pgrx::guc::GucValue>::from_raw(*newval) };
1579                    #invoke_to_bool
1580                }
1581            }
1582        }
1583        "assign" => {
1584            let arg_type = match func.sig.inputs.first() {
1585                Some(syn::FnArg::Typed(pat_type)) => &pat_type.ty,
1586                _ => {
1587                    return syn::Error::new(
1588                        func.sig.span(),
1589                        "assign hook first argument must implement GucValue",
1590                    )
1591                    .into_compile_error()
1592                    .into();
1593                }
1594            };
1595
1596            quote! {
1597                #[::pgrx::pg_guard]
1598                #vis unsafe extern "C-unwind" fn #original_ident(
1599                    newval: <#arg_type as ::pgrx::guc::GucValue>::Raw,
1600                    extra: *mut ::core::ffi::c_void,
1601                ) {
1602                    #[inline(always)]
1603                    #inner_func
1604
1605                    let value = unsafe { <#arg_type as ::pgrx::guc::GucValue>::from_raw(newval) };
1606                    #inner_ident(value);
1607                }
1608            }
1609        }
1610        _ => {
1611            return syn::Error::new(
1612                hook_type.span(),
1613                "Unknown GUC hook type. Expected 'show', 'check', or 'assign'",
1614            )
1615            .into_compile_error()
1616            .into();
1617        }
1618    };
1619
1620    wrapper.into()
1621}
1622
1623/**
1624Create a [PostgreSQL trigger function](https://www.postgresql.org/docs/current/plpgsql-trigger.html)
1625
1626Review the `pgrx::trigger_support::PgTrigger` documentation for use.
1627
1628 */
1629#[proc_macro_attribute]
1630pub fn pg_trigger(attrs: TokenStream, input: TokenStream) -> TokenStream {
1631    fn wrapped(attrs: TokenStream, input: TokenStream) -> Result<TokenStream, syn::Error> {
1632        use pgrx_sql_entity_graph::{PgTrigger, PgTriggerAttribute};
1633        use syn::Token;
1634        use syn::parse::Parser;
1635        use syn::punctuated::Punctuated;
1636
1637        let attributes =
1638            Punctuated::<PgTriggerAttribute, Token![,]>::parse_terminated.parse(attrs)?;
1639        let item_fn: syn::ItemFn = syn::parse(input)?;
1640        let trigger_item = PgTrigger::new(item_fn, attributes)?;
1641        let trigger_tokens = trigger_item.to_token_stream();
1642
1643        Ok(trigger_tokens.into())
1644    }
1645
1646    wrapped(attrs, input).unwrap_or_else(|e| e.into_compile_error().into())
1647}
1648
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an identifier from `name`, runs it through `maybe_shorten_pg_test_ident`,
    /// and returns the resulting identifier text.
    fn shortened(name: &str) -> String {
        let mut ident = syn::Ident::new(name, proc_macro2::Span::call_site());
        maybe_shorten_pg_test_ident(&mut ident);
        ident.to_string()
    }

    #[test]
    fn short_name_unchanged() {
        let input = "test_foo";
        assert_eq!(shortened(input), input);
    }

    #[test]
    fn exactly_63_chars_unchanged() {
        let input = "a".repeat(63);
        assert_eq!(shortened(&input), input);
    }

    #[test]
    fn exactly_64_chars_is_shortened() {
        let result = shortened(&"a".repeat(64));
        let len = result.len();
        assert!(len <= 63, "shortened name is {len} chars: {result}");
        assert!(result.starts_with('t'), "shortened name should start with 't': {result}");
    }

    #[test]
    fn very_long_name_fits_in_63() {
        let input = "test_that_something_really_important_works_correctly_when_given_a_very_long_input_name";
        assert!(input.len() > 63);
        let result = shortened(input);
        assert_eq!(result.len(), 63, "shortened name should be exactly 63 chars: {result}");
    }

    #[test]
    fn different_long_names_get_different_shortened_names() {
        // Shared 60-char prefix; the names differ only in their final four characters.
        let prefix = "a".repeat(60);
        let first = shortened(&format!("{prefix}xxxx"));
        let second = shortened(&format!("{prefix}yyyy"));
        assert_ne!(
            first,
            second,
            "names differing only in the tail should still get different shortened forms"
        );
    }
}