1use crate::PositioningRef;
11use proc_macro2::{TokenStream, TokenTree};
12use quote::{format_ident, quote, ToTokens, TokenStreamExt};
13use std::collections::HashSet;
14
15#[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord)]
16pub enum ExternArgs {
17 CreateOrReplace,
18 Immutable,
19 Strict,
20 Stable,
21 Volatile,
22 Raw,
23 NoGuard,
24 SecurityDefiner,
25 SecurityInvoker,
26 ParallelSafe,
27 ParallelUnsafe,
28 ParallelRestricted,
29 ShouldPanic(String),
30 Schema(String),
31 Support(PositioningRef),
32 Name(String),
33 Cost(String),
34 Requires(Vec<PositioningRef>),
35}
36
37impl core::fmt::Display for ExternArgs {
38 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
39 match self {
40 ExternArgs::CreateOrReplace => write!(f, "CREATE OR REPLACE"),
41 ExternArgs::Immutable => write!(f, "IMMUTABLE"),
42 ExternArgs::Strict => write!(f, "STRICT"),
43 ExternArgs::Stable => write!(f, "STABLE"),
44 ExternArgs::Volatile => write!(f, "VOLATILE"),
45 ExternArgs::Raw => Ok(()),
46 ExternArgs::ParallelSafe => write!(f, "PARALLEL SAFE"),
47 ExternArgs::ParallelUnsafe => write!(f, "PARALLEL UNSAFE"),
48 ExternArgs::SecurityDefiner => write!(f, "SECURITY DEFINER"),
49 ExternArgs::SecurityInvoker => write!(f, "SECURITY INVOKER"),
50 ExternArgs::ParallelRestricted => write!(f, "PARALLEL RESTRICTED"),
51 ExternArgs::Support(item) => write!(f, "{item}"),
52 ExternArgs::ShouldPanic(_) => Ok(()),
53 ExternArgs::NoGuard => Ok(()),
54 ExternArgs::Schema(_) => Ok(()),
55 ExternArgs::Name(_) => Ok(()),
56 ExternArgs::Cost(cost) => write!(f, "COST {cost}"),
57 ExternArgs::Requires(_) => Ok(()),
58 }
59 }
60}
61
62impl ToTokens for ExternArgs {
63 fn to_tokens(&self, tokens: &mut TokenStream) {
64 match self {
65 ExternArgs::CreateOrReplace => tokens.append(format_ident!("CreateOrReplace")),
66 ExternArgs::Immutable => tokens.append(format_ident!("Immutable")),
67 ExternArgs::Strict => tokens.append(format_ident!("Strict")),
68 ExternArgs::Stable => tokens.append(format_ident!("Stable")),
69 ExternArgs::Volatile => tokens.append(format_ident!("Volatile")),
70 ExternArgs::Raw => tokens.append(format_ident!("Raw")),
71 ExternArgs::NoGuard => tokens.append(format_ident!("NoGuard")),
72 ExternArgs::SecurityDefiner => tokens.append(format_ident!("SecurityDefiner")),
73 ExternArgs::SecurityInvoker => tokens.append(format_ident!("SecurityInvoker")),
74 ExternArgs::ParallelSafe => tokens.append(format_ident!("ParallelSafe")),
75 ExternArgs::ParallelUnsafe => tokens.append(format_ident!("ParallelUnsafe")),
76 ExternArgs::ParallelRestricted => tokens.append(format_ident!("ParallelRestricted")),
77 ExternArgs::ShouldPanic(_s) => tokens.append_all(quote! { Error(String::from("#_s")) }),
78 ExternArgs::Schema(_s) => tokens.append_all(quote! { Schema(String::from("#_s")) }),
79 ExternArgs::Support(item) => tokens.append_all(quote! { Support(#item) }),
80 ExternArgs::Name(_s) => tokens.append_all(quote! { Name(String::from("#_s")) }),
81 ExternArgs::Cost(_s) => tokens.append_all(quote! { Cost(String::from("#_s")) }),
82 ExternArgs::Requires(items) => {
83 tokens.append_all(quote! { Requires(vec![#(#items),*]) })
84 }
85 }
86 }
87}
88
89#[track_caller]
91pub fn parse_extern_attributes(attr: TokenStream) -> HashSet<ExternArgs> {
92 let mut args = HashSet::<ExternArgs>::new();
93 let mut itr = attr.into_iter();
94 while let Some(t) = itr.next() {
95 match t {
96 TokenTree::Group(g) => {
97 for arg in parse_extern_attributes(g.stream()).into_iter() {
98 args.insert(arg);
99 }
100 }
101 TokenTree::Ident(i) => {
102 let name = i.to_string();
103 match name.as_str() {
104 "create_or_replace" => args.insert(ExternArgs::CreateOrReplace),
105 "immutable" => args.insert(ExternArgs::Immutable),
106 "strict" => args.insert(ExternArgs::Strict),
107 "stable" => args.insert(ExternArgs::Stable),
108 "volatile" => args.insert(ExternArgs::Volatile),
109 "raw" => args.insert(ExternArgs::Raw),
110 "no_guard" => args.insert(ExternArgs::NoGuard),
111 "security_invoker" => args.insert(ExternArgs::SecurityInvoker),
112 "security_definer" => args.insert(ExternArgs::SecurityDefiner),
113 "parallel_safe" => args.insert(ExternArgs::ParallelSafe),
114 "parallel_unsafe" => args.insert(ExternArgs::ParallelUnsafe),
115 "parallel_restricted" => args.insert(ExternArgs::ParallelRestricted),
116 "error" | "expected" => {
117 let _punc = itr.next().unwrap();
118 let literal = itr.next().unwrap();
119 let message = literal.to_string();
120 let message = unescape::unescape(&message).expect("failed to unescape");
121
122 let message = message[1..message.len() - 1].to_string();
124 args.insert(ExternArgs::ShouldPanic(message.to_string()))
125 }
126 "schema" => {
127 let _punc = itr.next().unwrap();
128 let literal = itr.next().unwrap();
129 let schema = literal.to_string();
130 let schema = unescape::unescape(&schema).expect("failed to unescape");
131
132 let schema = schema[1..schema.len() - 1].to_string();
134 args.insert(ExternArgs::Schema(schema.to_string()))
135 }
136 "name" => {
137 let _punc = itr.next().unwrap();
138 let literal = itr.next().unwrap();
139 let name = literal.to_string();
140 let name = unescape::unescape(&name).expect("failed to unescape");
141
142 let name = name[1..name.len() - 1].to_string();
144 args.insert(ExternArgs::Name(name.to_string()))
145 }
146 "sql" => {
148 let _punc = itr.next().unwrap();
149 let _value = itr.next().unwrap();
150 false
151 }
152 _ => false,
153 };
154 }
155 TokenTree::Punct(_) => {}
156 TokenTree::Literal(_) => {}
157 }
158 }
159 args
160}
161
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use crate::{parse_extern_attributes, ExternArgs};

    /// An `error = "..."` argument whose literal contains escaped quotes is unescaped into a
    /// `ShouldPanic` message.
    #[test]
    fn parse_args() {
        let input = r#"error = "syntax error at or near \"THIS\"""#;
        let stream = proc_macro2::TokenStream::from_str(input).unwrap();

        let parsed = parse_extern_attributes(stream);

        let expected = ExternArgs::ShouldPanic(String::from(r#"syntax error at or near "THIS""#));
        assert!(parsed.contains(&expected));
    }
}