pgrx-sql-entity-graph 0.18.0

SQL Entity Graph for `pgrx`
Documentation
//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
//LICENSE
//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
//LICENSE
//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
//LICENSE
//LICENSE All rights reserved.
//LICENSE
//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
use crate::PositioningRef;
use proc_macro2::{Ident, Span, TokenStream, TokenTree};
use quote::{ToTokens, TokenStreamExt, format_ident, quote};
use std::collections::HashSet;

/// A single argument recognized on an extern-style attribute, e.g. the flags and
/// `key = "value"` pairs understood by `parse_extern_attributes` in this module.
///
/// NOTE: variant order is significant — `PartialOrd`/`Ord` are derived, so
/// reordering variants changes how `ExternArgs` values compare and sort.
#[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord)]
pub enum ExternArgs {
    /// Parsed from `create_or_replace`; displays as `CREATE OR REPLACE`.
    CreateOrReplace,
    /// Parsed from `immutable`; displays as `IMMUTABLE`.
    Immutable,
    /// Parsed from `strict`; displays as `STRICT`.
    Strict,
    /// Parsed from `stable`; displays as `STABLE`.
    Stable,
    /// Parsed from `volatile`; displays as `VOLATILE`.
    Volatile,
    /// Parsed from `raw`; displays as nothing.
    Raw,
    /// Parsed from `no_guard`; displays as nothing.
    NoGuard,
    /// Parsed from `security_definer`; displays as `SECURITY DEFINER`.
    SecurityDefiner,
    /// Parsed from `security_invoker`; displays as `SECURITY INVOKER`.
    SecurityInvoker,
    /// Parsed from `parallel_safe`; displays as `PARALLEL SAFE`.
    ParallelSafe,
    /// Parsed from `parallel_unsafe`; displays as `PARALLEL UNSAFE`.
    ParallelUnsafe,
    /// Parsed from `parallel_restricted`; displays as `PARALLEL RESTRICTED`.
    ParallelRestricted,
    /// Parsed from `error = "..."` or `expected = "..."`; holds the unescaped
    /// message text. Displays as nothing.
    ShouldPanic(String),
    /// Parsed from `schema = "..."`; displays as nothing.
    Schema(String),
    /// Displays as the referenced item. Presumably names a SQL `SUPPORT`
    /// function — confirm. Not produced by `parse_extern_attributes`.
    Support(PositioningRef),
    /// Parsed from `name = "..."`; displays as nothing.
    Name(String),
    /// Displays as `COST <value>`. Not produced by `parse_extern_attributes`.
    Cost(String),
    /// A list of referenced items; displays as nothing. Not produced by
    /// `parse_extern_attributes`.
    Requires(Vec<PositioningRef>),
}

impl core::fmt::Display for ExternArgs {
    /// Writes the SQL spelling of this argument (e.g. `IMMUTABLE`, `COST 100`).
    ///
    /// Variants without a SQL spelling here write nothing. `Support` defers
    /// entirely to the `Display` of the item it references.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let keyword = match self {
            ExternArgs::CreateOrReplace => "CREATE OR REPLACE",
            ExternArgs::Immutable => "IMMUTABLE",
            ExternArgs::Strict => "STRICT",
            ExternArgs::Stable => "STABLE",
            ExternArgs::Volatile => "VOLATILE",
            ExternArgs::SecurityDefiner => "SECURITY DEFINER",
            ExternArgs::SecurityInvoker => "SECURITY INVOKER",
            ExternArgs::ParallelSafe => "PARALLEL SAFE",
            ExternArgs::ParallelUnsafe => "PARALLEL UNSAFE",
            ExternArgs::ParallelRestricted => "PARALLEL RESTRICTED",
            // Payload-carrying variants format their own output.
            ExternArgs::Support(item) => return write!(f, "{item}"),
            ExternArgs::Cost(cost) => return write!(f, "COST {cost}"),
            // No SQL output for these.
            ExternArgs::Raw
            | ExternArgs::NoGuard
            | ExternArgs::ShouldPanic(_)
            | ExternArgs::Schema(_)
            | ExternArgs::Name(_)
            | ExternArgs::Requires(_) => return Ok(()),
        };
        f.write_str(keyword)
    }
}

impl ExternArgs {
    /// Builds an expression computing the encoded length of this argument, using the
    /// `::pgrx::pgrx_sql_entity_graph::section` length helpers.
    ///
    /// Every argument contributes a one-byte tag (`u8_len()`). String-carrying
    /// variants add `str_len(value)`, `Support` adds its item's length, and
    /// `Requires` adds the `list_len` of its items. This is intended to mirror what
    /// [`ExternArgs::section_writer_tokens`] writes.
    pub fn section_len_tokens(&self) -> TokenStream {
        let tag_len = quote! { ::pgrx::pgrx_sql_entity_graph::section::u8_len() };

        let payload_len = match self {
            // Flag-only variants are just the tag.
            ExternArgs::CreateOrReplace
            | ExternArgs::Immutable
            | ExternArgs::Strict
            | ExternArgs::Stable
            | ExternArgs::Volatile
            | ExternArgs::Raw
            | ExternArgs::NoGuard
            | ExternArgs::SecurityDefiner
            | ExternArgs::SecurityInvoker
            | ExternArgs::ParallelSafe
            | ExternArgs::ParallelUnsafe
            | ExternArgs::ParallelRestricted => return tag_len,
            ExternArgs::ShouldPanic(text)
            | ExternArgs::Schema(text)
            | ExternArgs::Name(text)
            | ExternArgs::Cost(text) => {
                quote! { ::pgrx::pgrx_sql_entity_graph::section::str_len(#text) }
            }
            ExternArgs::Support(target) => {
                let target_len = target.section_len_tokens();
                quote! { (#target_len) }
            }
            ExternArgs::Requires(targets) => {
                let target_lens = targets.iter().map(PositioningRef::section_len_tokens);
                quote! {
                    ::pgrx::pgrx_sql_entity_graph::section::list_len(&[
                        #( #target_lens ),*
                    ])
                }
            }
        };

        quote! { #tag_len + #payload_len }
    }

    /// Builds an expression that appends this argument's encoding to `writer`.
    ///
    /// The encoding starts with a `u8` tag constant from
    /// `::pgrx::pgrx_sql_entity_graph::section`, followed by:
    /// - a `.str(..)` payload for string-carrying variants;
    /// - the referenced item for `Support`;
    /// - a `u32` count and then each item for `Requires`.
    ///
    /// The writer calls appear to return the writer; the `Requires` arm relies on
    /// this by rebinding it after each item.
    pub fn section_writer_tokens(&self, writer: TokenStream) -> TokenStream {
        // `<writer>.u8(::pgrx::pgrx_sql_entity_graph::section::<TAG>)`
        let tagged = |tag: &str| {
            let tag = Ident::new(tag, Span::call_site());
            quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::#tag) }
        };
        // Tag followed by a single string payload.
        let tagged_str = |tag: &str, text: &str| {
            let head = tagged(tag);
            quote! { #head.str(#text) }
        };

        match self {
            ExternArgs::CreateOrReplace => tagged("EXTERN_ARG_CREATE_OR_REPLACE"),
            ExternArgs::Immutable => tagged("EXTERN_ARG_IMMUTABLE"),
            ExternArgs::Strict => tagged("EXTERN_ARG_STRICT"),
            ExternArgs::Stable => tagged("EXTERN_ARG_STABLE"),
            ExternArgs::Volatile => tagged("EXTERN_ARG_VOLATILE"),
            ExternArgs::Raw => tagged("EXTERN_ARG_RAW"),
            ExternArgs::NoGuard => tagged("EXTERN_ARG_NO_GUARD"),
            ExternArgs::SecurityDefiner => tagged("EXTERN_ARG_SECURITY_DEFINER"),
            ExternArgs::SecurityInvoker => tagged("EXTERN_ARG_SECURITY_INVOKER"),
            ExternArgs::ParallelSafe => tagged("EXTERN_ARG_PARALLEL_SAFE"),
            ExternArgs::ParallelUnsafe => tagged("EXTERN_ARG_PARALLEL_UNSAFE"),
            ExternArgs::ParallelRestricted => tagged("EXTERN_ARG_PARALLEL_RESTRICTED"),
            ExternArgs::ShouldPanic(text) => tagged_str("EXTERN_ARG_SHOULD_PANIC", text),
            ExternArgs::Schema(text) => tagged_str("EXTERN_ARG_SCHEMA", text),
            ExternArgs::Name(text) => tagged_str("EXTERN_ARG_NAME", text),
            ExternArgs::Cost(text) => tagged_str("EXTERN_ARG_COST", text),
            ExternArgs::Support(target) => {
                target.section_writer_tokens(tagged("EXTERN_ARG_SUPPORT"))
            }
            ExternArgs::Requires(targets) => {
                // Mixed-site hygiene keeps this binding from clashing with caller names.
                let chained = Ident::new("__pgrx_schema_writer", Span::mixed_site());
                let head = tagged("EXTERN_ARG_REQUIRES");
                let total = targets.len();
                let target_writers = targets
                    .iter()
                    .map(|target| target.section_writer_tokens(quote! { #chained }));
                quote! {
                    {
                        let #chained = #head.u32(#total as u32);
                        #( let #chained = { #target_writers }; )*
                        #chained
                    }
                }
            }
        }
    }
}

impl ToTokens for ExternArgs {
    /// Emits this argument as a bare variant expression, e.g. `Immutable` or
    /// `Schema(String::from("public"))`.
    ///
    /// No `ExternArgs::` path prefix is emitted; presumably the surrounding `quote!`
    /// supplies it — confirm at call sites.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        match self {
            ExternArgs::CreateOrReplace => tokens.append(format_ident!("CreateOrReplace")),
            ExternArgs::Immutable => tokens.append(format_ident!("Immutable")),
            ExternArgs::Strict => tokens.append(format_ident!("Strict")),
            ExternArgs::Stable => tokens.append(format_ident!("Stable")),
            ExternArgs::Volatile => tokens.append(format_ident!("Volatile")),
            ExternArgs::Raw => tokens.append(format_ident!("Raw")),
            ExternArgs::NoGuard => tokens.append(format_ident!("NoGuard")),
            ExternArgs::SecurityDefiner => tokens.append(format_ident!("SecurityDefiner")),
            ExternArgs::SecurityInvoker => tokens.append(format_ident!("SecurityInvoker")),
            ExternArgs::ParallelSafe => tokens.append(format_ident!("ParallelSafe")),
            ExternArgs::ParallelUnsafe => tokens.append(format_ident!("ParallelUnsafe")),
            ExternArgs::ParallelRestricted => tokens.append(format_ident!("ParallelRestricted")),
            // String payloads are interpolated as `#s`, which emits a string literal.
            // `quote!` does not look inside string literals, so the previous `"#_s"`
            // emitted the literal text "#_s" instead of the value.
            // `ShouldPanic` was also emitted as `Error`, which is not a variant of
            // this enum.
            ExternArgs::ShouldPanic(s) => {
                tokens.append_all(quote! { ShouldPanic(String::from(#s)) })
            }
            ExternArgs::Schema(s) => tokens.append_all(quote! { Schema(String::from(#s)) }),
            ExternArgs::Support(item) => tokens.append_all(quote! { Support(#item) }),
            ExternArgs::Name(s) => tokens.append_all(quote! { Name(String::from(#s)) }),
            ExternArgs::Cost(s) => tokens.append_all(quote! { Cost(String::from(#s)) }),
            ExternArgs::Requires(items) => {
                tokens.append_all(quote! { Requires(vec![#(#items),*]) })
            }
        }
    }
}

// This horror-story should be returning result
#[track_caller]
pub fn parse_extern_attributes(attr: TokenStream) -> HashSet<ExternArgs> {
    let mut args = HashSet::<ExternArgs>::new();
    let mut itr = attr.into_iter();
    while let Some(t) = itr.next() {
        match t {
            TokenTree::Group(g) => {
                for arg in parse_extern_attributes(g.stream()).into_iter() {
                    args.insert(arg);
                }
            }
            TokenTree::Ident(i) => {
                let name = i.to_string();
                match name.as_str() {
                    "create_or_replace" => args.insert(ExternArgs::CreateOrReplace),
                    "immutable" => args.insert(ExternArgs::Immutable),
                    "strict" => args.insert(ExternArgs::Strict),
                    "stable" => args.insert(ExternArgs::Stable),
                    "volatile" => args.insert(ExternArgs::Volatile),
                    "raw" => args.insert(ExternArgs::Raw),
                    "no_guard" => args.insert(ExternArgs::NoGuard),
                    "security_invoker" => args.insert(ExternArgs::SecurityInvoker),
                    "security_definer" => args.insert(ExternArgs::SecurityDefiner),
                    "parallel_safe" => args.insert(ExternArgs::ParallelSafe),
                    "parallel_unsafe" => args.insert(ExternArgs::ParallelUnsafe),
                    "parallel_restricted" => args.insert(ExternArgs::ParallelRestricted),
                    "error" | "expected" => {
                        let _punc = itr.next().unwrap();
                        let literal = itr.next().unwrap();
                        let message = literal.to_string();
                        let message = unescape::unescape(&message).expect("failed to unescape");

                        // trim leading/trailing quotes around the literal
                        let message = message[1..message.len() - 1].to_string();
                        args.insert(ExternArgs::ShouldPanic(message.to_string()))
                    }
                    "schema" => {
                        let _punc = itr.next().unwrap();
                        let literal = itr.next().unwrap();
                        let schema = literal.to_string();
                        let schema = unescape::unescape(&schema).expect("failed to unescape");

                        // trim leading/trailing quotes around the literal
                        let schema = schema[1..schema.len() - 1].to_string();
                        args.insert(ExternArgs::Schema(schema.to_string()))
                    }
                    "name" => {
                        let _punc = itr.next().unwrap();
                        let literal = itr.next().unwrap();
                        let name = literal.to_string();
                        let name = unescape::unescape(&name).expect("failed to unescape");

                        // trim leading/trailing quotes around the literal
                        let name = name[1..name.len() - 1].to_string();
                        args.insert(ExternArgs::Name(name.to_string()))
                    }
                    // Recognized, but not handled as an extern argument
                    "sql" => {
                        let _punc = itr.next().unwrap();
                        let _value = itr.next().unwrap();
                        false
                    }
                    _ => false,
                };
            }
            TokenTree::Punct(_) => {}
            TokenTree::Literal(_) => {}
        }
    }
    args
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::str::FromStr;

    use crate::{ExternArgs, parse_extern_attributes};

    /// Lexes `s` as a token stream and runs it through `parse_extern_attributes`.
    fn parse(s: &str) -> HashSet<ExternArgs> {
        let ts = proc_macro2::TokenStream::from_str(s).unwrap();
        parse_extern_attributes(ts)
    }

    #[test]
    fn parse_args() {
        let args = parse("error = \"syntax error at or near \\\"THIS\\\"\"");
        assert!(
            args.contains(&ExternArgs::ShouldPanic("syntax error at or near \"THIS\"".to_string()))
        );
    }

    #[test]
    fn parse_expected_alias() {
        let args = parse("expected = \"boom\"");
        assert_eq!(args, HashSet::from([ExternArgs::ShouldPanic("boom".to_string())]));
    }

    #[test]
    fn parse_schema_and_name() {
        let args = parse("schema = \"my_schema\", name = \"my_func\"");
        assert_eq!(
            args,
            HashSet::from([
                ExternArgs::Schema("my_schema".to_string()),
                ExternArgs::Name("my_func".to_string()),
            ])
        );
    }

    #[test]
    fn parse_flags() {
        let args = parse("immutable, strict, parallel_safe, security_definer");
        assert_eq!(
            args,
            HashSet::from([
                ExternArgs::Immutable,
                ExternArgs::Strict,
                ExternArgs::ParallelSafe,
                ExternArgs::SecurityDefiner,
            ])
        );
    }

    #[test]
    fn sql_value_is_skipped_and_unknown_idents_ignored() {
        let args = parse("sql = \"SELECT 1\", bogus, stable");
        assert_eq!(args, HashSet::from([ExternArgs::Stable]));
    }

    #[test]
    fn nested_groups_are_searched() {
        let args = parse("(volatile, (no_guard))");
        assert_eq!(args, HashSet::from([ExternArgs::Volatile, ExternArgs::NoGuard]));
    }
}