Skip to main content

pgrx_sql_entity_graph/postgres_enum/
mod.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
12`#[derive(PostgresEnum)]` related macro expansion for Rust to SQL translation
13
14> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
> to the `pgrx` framework and highly subject to change between versions. While you may use this, please do so with caution.
16
17*/
18pub mod entity;
19
20use crate::enrich::{ToEntityGraphTokens, ToRustCodeTokens};
21use crate::{CodeEnrichment, ToSqlConfig};
22use proc_macro2::{Span, TokenStream as TokenStream2};
23use quote::{format_ident, quote};
24use syn::parse::{Parse, ParseStream};
25use syn::punctuated::Punctuated;
26use syn::{DeriveInput, Generics, Ident, ItemEnum, Token};
27
28/// A parsed `#[derive(PostgresEnum)]` item.
29///
30/// It should be used with [`syn::parse::Parse`] functions.
31///
32/// Using [`quote::ToTokens`] will output the declaration for a `pgrx::datum::pgrx_sql_entity_graph::PostgresEnumEntity`.
33///
34/// ```rust
35/// use syn::{Macro, parse::Parse, parse_quote, parse};
36/// use quote::{quote, ToTokens};
37/// use pgrx_sql_entity_graph::PostgresEnum;
38///
39/// # fn main() -> eyre::Result<()> {
40/// use pgrx_sql_entity_graph::CodeEnrichment;
41/// let parsed: CodeEnrichment<PostgresEnum> = parse_quote! {
42///     #[derive(PostgresEnum)]
43///     enum Demo {
44///         Example,
45///     }
46/// };
47/// let sql_graph_entity_tokens = parsed.to_token_stream();
48/// # Ok(())
49/// # }
50/// ```
51#[derive(Debug, Clone)]
52pub struct PostgresEnum {
53    name: Ident,
54    generics: Generics,
55    variants: Punctuated<syn::Variant, Token![,]>,
56    to_sql_config: ToSqlConfig,
57}
58
59impl PostgresEnum {
60    pub fn new(
61        name: Ident,
62        generics: Generics,
63        variants: Punctuated<syn::Variant, Token![,]>,
64        to_sql_config: ToSqlConfig,
65    ) -> Result<CodeEnrichment<Self>, syn::Error> {
66        if !to_sql_config.overrides_default() {
67            crate::ident_is_acceptable_to_postgres(&name)?;
68        }
69
70        Ok(CodeEnrichment(Self { name, generics, variants, to_sql_config }))
71    }
72
73    pub fn from_derive_input(
74        derive_input: DeriveInput,
75    ) -> Result<CodeEnrichment<Self>, syn::Error> {
76        let to_sql_config =
77            ToSqlConfig::from_attributes(derive_input.attrs.as_slice())?.unwrap_or_default();
78        let data_enum = match derive_input.data {
79            syn::Data::Enum(data_enum) => data_enum,
80            syn::Data::Union(_) | syn::Data::Struct(_) => {
81                return Err(syn::Error::new(derive_input.ident.span(), "expected enum"));
82            }
83        };
84        Self::new(derive_input.ident, derive_input.generics, data_enum.variants, to_sql_config)
85    }
86}
87
88impl ToEntityGraphTokens for PostgresEnum {
89    fn to_entity_graph_tokens(&self) -> TokenStream2 {
90        // It's important we remap all lifetimes we spot to `'static` so they can be used during inventory submission.
91        let name = self.name.clone();
92        let mut static_generics = self.generics.clone();
93        static_generics.params = static_generics
94            .params
95            .clone()
96            .into_iter()
97            .flat_map(|param| match param {
98                item @ syn::GenericParam::Type(_) | item @ syn::GenericParam::Const(_) => {
99                    Some(item)
100                }
101                syn::GenericParam::Lifetime(mut lifetime) => {
102                    lifetime.lifetime.ident = Ident::new("static", Span::call_site());
103                    Some(syn::GenericParam::Lifetime(lifetime))
104                }
105            })
106            .collect();
107        let mut staticless_generics = self.generics.clone();
108        staticless_generics.params = static_generics
109            .params
110            .clone()
111            .into_iter()
112            .flat_map(|param| match param {
113                item @ syn::GenericParam::Type(_) | item @ syn::GenericParam::Const(_) => {
114                    Some(item)
115                }
116                syn::GenericParam::Lifetime(_) => None,
117            })
118            .collect();
119        let (staticless_impl_generics, _staticless_ty_generics, _staticless_where_clauses) =
120            staticless_generics.split_for_impl();
121        let (_static_impl_generics, static_ty_generics, static_where_clauses) =
122            static_generics.split_for_impl();
123
124        let variants =
125            self.variants.iter().map(|variant| variant.ident.clone()).collect::<Vec<_>>();
126        let sql_graph_entity_fn_name = format_ident!("__pgrx_schema_enum_{}", name);
127
128        let to_sql_config = &self.to_sql_config;
129        let to_sql_config_len = to_sql_config.section_len_tokens();
130        let variants_len = quote! {
131            ::pgrx::pgrx_sql_entity_graph::section::list_len(&[
132                #( ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#variants)) ),*
133            ])
134        };
135        let payload_len = quote! {
136            ::pgrx::pgrx_sql_entity_graph::section::u8_len()
137                + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#name))
138                + ::pgrx::pgrx_sql_entity_graph::section::str_len(file!())
139                + ::pgrx::pgrx_sql_entity_graph::section::u32_len()
140                + ::pgrx::pgrx_sql_entity_graph::section::str_len(module_path!())
141                + ::pgrx::pgrx_sql_entity_graph::section::str_len(stringify!(#name #static_ty_generics))
142                + ::pgrx::pgrx_sql_entity_graph::section::str_len(<#name #static_ty_generics as ::pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable>::TYPE_IDENT)
143                + (#variants_len)
144                + (#to_sql_config_len)
145        };
146        let total_len = quote! {
147            ::pgrx::pgrx_sql_entity_graph::section::u32_len() + (#payload_len)
148        };
149        let writer = to_sql_config.section_writer_tokens(quote! {
150            ::pgrx::pgrx_sql_entity_graph::section::EntryWriter::<{ #total_len }>::new()
151                .u32((#payload_len) as u32)
152                .u8(::pgrx::pgrx_sql_entity_graph::section::ENTITY_ENUM)
153                .str(stringify!(#name))
154                .str(file!())
155                .u32(line!())
156                .str(module_path!())
157                .str(stringify!(#name #static_ty_generics))
158                .str(<#name #static_ty_generics as ::pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable>::TYPE_IDENT)
159                .u32([ #( stringify!(#variants) ),* ].len() as u32)
160                #( .str(stringify!(#variants)) )*
161        });
162
163        quote! {
164            unsafe impl #staticless_impl_generics ::pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable for #name #static_ty_generics #static_where_clauses {
165                const TYPE_IDENT: &'static str = ::pgrx::pgrx_resolved_type!(#name #static_ty_generics);
166                const TYPE_ORIGIN: ::pgrx::pgrx_sql_entity_graph::metadata::TypeOrigin =
167                    ::pgrx::pgrx_sql_entity_graph::metadata::TypeOrigin::ThisExtension;
168                const ARGUMENT_SQL: core::result::Result<
169                    ::pgrx::pgrx_sql_entity_graph::metadata::SqlMappingRef,
170                    ::pgrx::pgrx_sql_entity_graph::metadata::ArgumentError,
171                > = Ok(::pgrx::pgrx_sql_entity_graph::metadata::SqlMappingRef::As(stringify!(#name)));
172                const RETURN_SQL: core::result::Result<
173                    ::pgrx::pgrx_sql_entity_graph::metadata::ReturnsRef,
174                    ::pgrx::pgrx_sql_entity_graph::metadata::ReturnsError,
175                > = Ok(::pgrx::pgrx_sql_entity_graph::metadata::ReturnsRef::One(
176                    ::pgrx::pgrx_sql_entity_graph::metadata::SqlMappingRef::As(stringify!(#name))
177                ));
178            }
179
180            ::pgrx::pgrx_sql_entity_graph::__pgrx_schema_entry!(
181                #sql_graph_entity_fn_name,
182                #total_len,
183                #writer.finish()
184            );
185        }
186    }
187}
188
// Empty impl: every `ToRustCodeTokens` method keeps the trait's default implementation.
// The generated items visible in this module (the `SqlTranslatable` impl and the schema
// entry) come from the `ToEntityGraphTokens` impl above. NOTE(review): presumably the
// defaults emit no extra Rust code — confirm against `crate::enrich`.
impl ToRustCodeTokens for PostgresEnum {}
190
191impl Parse for CodeEnrichment<PostgresEnum> {
192    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
193        let parsed: ItemEnum = input.parse()?;
194        let to_sql_config =
195            ToSqlConfig::from_attributes(parsed.attrs.as_slice())?.unwrap_or_default();
196        PostgresEnum::new(parsed.ident, parsed.generics, parsed.variants, to_sql_config)
197    }
198}