use crate::PositioningRef;
use proc_macro2::{Ident, Span, TokenStream, TokenTree};
use quote::{ToTokens, TokenStreamExt, format_ident, quote};
use std::collections::HashSet;
/// One argument recognised on an extern function declaration.
///
/// `PartialOrd`/`Ord` are derived, so comparisons follow the declaration order of
/// the variants below; do not reorder them casually.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum ExternArgs {
    /// Displayed as `CREATE OR REPLACE`; parsed from `create_or_replace`.
    CreateOrReplace,
    /// Displayed as `IMMUTABLE`; parsed from `immutable`.
    Immutable,
    /// Displayed as `STRICT`; parsed from `strict`.
    Strict,
    /// Displayed as `STABLE`; parsed from `stable`.
    Stable,
    /// Displayed as `VOLATILE`; parsed from `volatile`.
    Volatile,
    /// Parsed from `raw`; contributes no SQL text when displayed.
    Raw,
    /// Parsed from `no_guard`; contributes no SQL text when displayed.
    NoGuard,
    /// Displayed as `SECURITY DEFINER`; parsed from `security_definer`.
    SecurityDefiner,
    /// Displayed as `SECURITY INVOKER`; parsed from `security_invoker`.
    SecurityInvoker,
    /// Displayed as `PARALLEL SAFE`; parsed from `parallel_safe`.
    ParallelSafe,
    /// Displayed as `PARALLEL UNSAFE`; parsed from `parallel_unsafe`.
    ParallelUnsafe,
    /// Displayed as `PARALLEL RESTRICTED`; parsed from `parallel_restricted`.
    ParallelRestricted,
    /// Expected error message, parsed from `error = "..."` or `expected = "..."`;
    /// contributes no SQL text when displayed.
    ShouldPanic(String),
    /// Schema name parsed from `schema = "..."`; contributes no SQL text when displayed.
    Schema(String),
    /// Reference displayed through the `PositioningRef`'s own `Display` impl.
    Support(PositioningRef),
    /// Name parsed from `name = "..."`; contributes no SQL text when displayed.
    Name(String),
    /// Cost value, displayed as `COST <value>`.
    Cost(String),
    /// List of references; contributes no SQL text when displayed.
    Requires(Vec<PositioningRef>),
}
impl core::fmt::Display for ExternArgs {
    /// Writes the SQL fragment for this argument.
    ///
    /// Keyword-only variants write their fixed keyword text. `Support` delegates to
    /// the wrapped reference's `Display`, and `Cost` writes `COST <value>`. `Raw`,
    /// `NoGuard`, `ShouldPanic`, `Schema`, `Name` and `Requires` write nothing.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let keyword = match self {
            ExternArgs::CreateOrReplace => "CREATE OR REPLACE",
            ExternArgs::Immutable => "IMMUTABLE",
            ExternArgs::Strict => "STRICT",
            ExternArgs::Stable => "STABLE",
            ExternArgs::Volatile => "VOLATILE",
            ExternArgs::SecurityDefiner => "SECURITY DEFINER",
            ExternArgs::SecurityInvoker => "SECURITY INVOKER",
            ExternArgs::ParallelSafe => "PARALLEL SAFE",
            ExternArgs::ParallelUnsafe => "PARALLEL UNSAFE",
            ExternArgs::ParallelRestricted => "PARALLEL RESTRICTED",
            // Variants carrying data format themselves.
            ExternArgs::Support(item) => return write!(f, "{item}"),
            ExternArgs::Cost(cost) => return write!(f, "COST {cost}"),
            // Variants with no SQL representation.
            ExternArgs::Raw
            | ExternArgs::NoGuard
            | ExternArgs::ShouldPanic(_)
            | ExternArgs::Schema(_)
            | ExternArgs::Name(_)
            | ExternArgs::Requires(_) => return Ok(()),
        };
        f.write_str(keyword)
    }
}
impl ExternArgs {
pub fn section_len_tokens(&self) -> TokenStream {
match self {
ExternArgs::CreateOrReplace
| ExternArgs::Immutable
| ExternArgs::Strict
| ExternArgs::Stable
| ExternArgs::Volatile
| ExternArgs::Raw
| ExternArgs::NoGuard
| ExternArgs::SecurityDefiner
| ExternArgs::SecurityInvoker
| ExternArgs::ParallelSafe
| ExternArgs::ParallelUnsafe
| ExternArgs::ParallelRestricted => {
quote! { ::pgrx::pgrx_sql_entity_graph::section::u8_len() }
}
ExternArgs::ShouldPanic(value)
| ExternArgs::Schema(value)
| ExternArgs::Name(value)
| ExternArgs::Cost(value) => quote! {
::pgrx::pgrx_sql_entity_graph::section::u8_len()
+ ::pgrx::pgrx_sql_entity_graph::section::str_len(#value)
},
ExternArgs::Support(item) => {
let item_len = item.section_len_tokens();
quote! {
::pgrx::pgrx_sql_entity_graph::section::u8_len() + (#item_len)
}
}
ExternArgs::Requires(items) => {
let item_lens = items.iter().map(PositioningRef::section_len_tokens);
quote! {
::pgrx::pgrx_sql_entity_graph::section::u8_len()
+ ::pgrx::pgrx_sql_entity_graph::section::list_len(&[
#( #item_lens ),*
])
}
}
}
}
pub fn section_writer_tokens(&self, writer: TokenStream) -> TokenStream {
match self {
ExternArgs::CreateOrReplace => {
quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_CREATE_OR_REPLACE) }
}
ExternArgs::Immutable => {
quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_IMMUTABLE) }
}
ExternArgs::Strict => {
quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_STRICT) }
}
ExternArgs::Stable => {
quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_STABLE) }
}
ExternArgs::Volatile => {
quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_VOLATILE) }
}
ExternArgs::Raw => {
quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_RAW) }
}
ExternArgs::NoGuard => {
quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_NO_GUARD) }
}
ExternArgs::SecurityDefiner => quote! {
#writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_SECURITY_DEFINER)
},
ExternArgs::SecurityInvoker => quote! {
#writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_SECURITY_INVOKER)
},
ExternArgs::ParallelSafe => quote! {
#writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_PARALLEL_SAFE)
},
ExternArgs::ParallelUnsafe => quote! {
#writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_PARALLEL_UNSAFE)
},
ExternArgs::ParallelRestricted => quote! {
#writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_PARALLEL_RESTRICTED)
},
ExternArgs::ShouldPanic(value) => quote! {
#writer
.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_SHOULD_PANIC)
.str(#value)
},
ExternArgs::Schema(value) => quote! {
#writer
.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_SCHEMA)
.str(#value)
},
ExternArgs::Support(item) => item.section_writer_tokens(quote! {
#writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_SUPPORT)
}),
ExternArgs::Name(value) => quote! {
#writer
.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_NAME)
.str(#value)
},
ExternArgs::Cost(value) => quote! {
#writer
.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_COST)
.str(#value)
},
ExternArgs::Requires(items) => {
let writer_ident = Ident::new("__pgrx_schema_writer", Span::mixed_site());
let item_writers =
items.iter().map(|item| item.section_writer_tokens(quote! { #writer_ident }));
let count = items.len();
quote! {
{
let #writer_ident = #writer
.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_REQUIRES)
.u32(#count as u32);
#( let #writer_ident = { #item_writers }; )*
#writer_ident
}
}
}
}
}
}
impl ToTokens for ExternArgs {
fn to_tokens(&self, tokens: &mut TokenStream) {
match self {
ExternArgs::CreateOrReplace => tokens.append(format_ident!("CreateOrReplace")),
ExternArgs::Immutable => tokens.append(format_ident!("Immutable")),
ExternArgs::Strict => tokens.append(format_ident!("Strict")),
ExternArgs::Stable => tokens.append(format_ident!("Stable")),
ExternArgs::Volatile => tokens.append(format_ident!("Volatile")),
ExternArgs::Raw => tokens.append(format_ident!("Raw")),
ExternArgs::NoGuard => tokens.append(format_ident!("NoGuard")),
ExternArgs::SecurityDefiner => tokens.append(format_ident!("SecurityDefiner")),
ExternArgs::SecurityInvoker => tokens.append(format_ident!("SecurityInvoker")),
ExternArgs::ParallelSafe => tokens.append(format_ident!("ParallelSafe")),
ExternArgs::ParallelUnsafe => tokens.append(format_ident!("ParallelUnsafe")),
ExternArgs::ParallelRestricted => tokens.append(format_ident!("ParallelRestricted")),
ExternArgs::ShouldPanic(_s) => tokens.append_all(quote! { Error(String::from("#_s")) }),
ExternArgs::Schema(_s) => tokens.append_all(quote! { Schema(String::from("#_s")) }),
ExternArgs::Support(item) => tokens.append_all(quote! { Support(#item) }),
ExternArgs::Name(_s) => tokens.append_all(quote! { Name(String::from("#_s")) }),
ExternArgs::Cost(_s) => tokens.append_all(quote! { Cost(String::from("#_s")) }),
ExternArgs::Requires(items) => {
tokens.append_all(quote! { Requires(vec![#(#items),*]) })
}
}
}
}
#[track_caller]
pub fn parse_extern_attributes(attr: TokenStream) -> HashSet<ExternArgs> {
let mut args = HashSet::<ExternArgs>::new();
let mut itr = attr.into_iter();
while let Some(t) = itr.next() {
match t {
TokenTree::Group(g) => {
for arg in parse_extern_attributes(g.stream()).into_iter() {
args.insert(arg);
}
}
TokenTree::Ident(i) => {
let name = i.to_string();
match name.as_str() {
"create_or_replace" => args.insert(ExternArgs::CreateOrReplace),
"immutable" => args.insert(ExternArgs::Immutable),
"strict" => args.insert(ExternArgs::Strict),
"stable" => args.insert(ExternArgs::Stable),
"volatile" => args.insert(ExternArgs::Volatile),
"raw" => args.insert(ExternArgs::Raw),
"no_guard" => args.insert(ExternArgs::NoGuard),
"security_invoker" => args.insert(ExternArgs::SecurityInvoker),
"security_definer" => args.insert(ExternArgs::SecurityDefiner),
"parallel_safe" => args.insert(ExternArgs::ParallelSafe),
"parallel_unsafe" => args.insert(ExternArgs::ParallelUnsafe),
"parallel_restricted" => args.insert(ExternArgs::ParallelRestricted),
"error" | "expected" => {
let _punc = itr.next().unwrap();
let literal = itr.next().unwrap();
let message = literal.to_string();
let message = unescape::unescape(&message).expect("failed to unescape");
let message = message[1..message.len() - 1].to_string();
args.insert(ExternArgs::ShouldPanic(message.to_string()))
}
"schema" => {
let _punc = itr.next().unwrap();
let literal = itr.next().unwrap();
let schema = literal.to_string();
let schema = unescape::unescape(&schema).expect("failed to unescape");
let schema = schema[1..schema.len() - 1].to_string();
args.insert(ExternArgs::Schema(schema.to_string()))
}
"name" => {
let _punc = itr.next().unwrap();
let literal = itr.next().unwrap();
let name = literal.to_string();
let name = unescape::unescape(&name).expect("failed to unescape");
let name = name[1..name.len() - 1].to_string();
args.insert(ExternArgs::Name(name.to_string()))
}
"sql" => {
let _punc = itr.next().unwrap();
let _value = itr.next().unwrap();
false
}
_ => false,
};
}
TokenTree::Punct(_) => {}
TokenTree::Literal(_) => {}
}
}
args
}
#[cfg(test)]
mod tests {
    use crate::{ExternArgs, parse_extern_attributes};
    use proc_macro2::TokenStream;

    /// Escaped quotes inside an `error = "..."` value survive unescaping.
    #[test]
    fn parse_args() {
        let input: TokenStream = r#"error = "syntax error at or near \"THIS\"""#.parse().unwrap();
        let expected = ExternArgs::ShouldPanic(r#"syntax error at or near "THIS""#.to_string());
        let parsed = parse_extern_attributes(input);
        assert!(parsed.contains(&expected));
    }
}