1use crate::PositioningRef;
11use proc_macro2::{Ident, Span, TokenStream, TokenTree};
12use quote::{ToTokens, TokenStreamExt, format_ident, quote};
13use std::collections::HashSet;
14
/// One argument attached to an extern function declaration.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare
/// variants in declaration order, so reordering them changes sort results.
#[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord)]
pub enum ExternArgs {
    /// Parsed from `create_or_replace`; displayed as `CREATE OR REPLACE`.
    CreateOrReplace,
    /// Parsed from `immutable`; displayed as `IMMUTABLE`.
    Immutable,
    /// Parsed from `strict`; displayed as `STRICT`.
    Strict,
    /// Parsed from `stable`; displayed as `STABLE`.
    Stable,
    /// Parsed from `volatile`; displayed as `VOLATILE`.
    Volatile,
    /// Parsed from `raw`; displays as nothing.
    Raw,
    /// Parsed from `no_guard`; displays as nothing.
    NoGuard,
    /// Parsed from `security_definer`; displayed as `SECURITY DEFINER`.
    SecurityDefiner,
    /// Parsed from `security_invoker`; displayed as `SECURITY INVOKER`.
    SecurityInvoker,
    /// Parsed from `parallel_safe`; displayed as `PARALLEL SAFE`.
    ParallelSafe,
    /// Parsed from `parallel_unsafe`; displayed as `PARALLEL UNSAFE`.
    ParallelUnsafe,
    /// Parsed from `parallel_restricted`; displayed as `PARALLEL RESTRICTED`.
    ParallelRestricted,
    /// Message parsed from `error = "..."` or `expected = "..."`; displays as
    /// nothing.
    ShouldPanic(String),
    /// Value parsed from `schema = "..."`; displays as nothing.
    Schema(String),
    /// Displayed by delegating to the wrapped reference's own `Display`.
    /// Not produced by `parse_extern_attributes` in this file.
    Support(PositioningRef),
    /// Value parsed from `name = "..."`; displays as nothing.
    Name(String),
    /// Displayed as `COST <value>`. Not produced by `parse_extern_attributes`
    /// in this file.
    Cost(String),
    /// A list of references; displays as nothing. Not produced by
    /// `parse_extern_attributes` in this file.
    Requires(Vec<PositioningRef>),
}
36
37impl core::fmt::Display for ExternArgs {
38 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
39 match self {
40 ExternArgs::CreateOrReplace => write!(f, "CREATE OR REPLACE"),
41 ExternArgs::Immutable => write!(f, "IMMUTABLE"),
42 ExternArgs::Strict => write!(f, "STRICT"),
43 ExternArgs::Stable => write!(f, "STABLE"),
44 ExternArgs::Volatile => write!(f, "VOLATILE"),
45 ExternArgs::Raw => Ok(()),
46 ExternArgs::ParallelSafe => write!(f, "PARALLEL SAFE"),
47 ExternArgs::ParallelUnsafe => write!(f, "PARALLEL UNSAFE"),
48 ExternArgs::SecurityDefiner => write!(f, "SECURITY DEFINER"),
49 ExternArgs::SecurityInvoker => write!(f, "SECURITY INVOKER"),
50 ExternArgs::ParallelRestricted => write!(f, "PARALLEL RESTRICTED"),
51 ExternArgs::Support(item) => write!(f, "{item}"),
52 ExternArgs::ShouldPanic(_) => Ok(()),
53 ExternArgs::NoGuard => Ok(()),
54 ExternArgs::Schema(_) => Ok(()),
55 ExternArgs::Name(_) => Ok(()),
56 ExternArgs::Cost(cost) => write!(f, "COST {cost}"),
57 ExternArgs::Requires(_) => Ok(()),
58 }
59 }
60}
61
62impl ExternArgs {
63 pub fn section_len_tokens(&self) -> TokenStream {
64 match self {
65 ExternArgs::CreateOrReplace
66 | ExternArgs::Immutable
67 | ExternArgs::Strict
68 | ExternArgs::Stable
69 | ExternArgs::Volatile
70 | ExternArgs::Raw
71 | ExternArgs::NoGuard
72 | ExternArgs::SecurityDefiner
73 | ExternArgs::SecurityInvoker
74 | ExternArgs::ParallelSafe
75 | ExternArgs::ParallelUnsafe
76 | ExternArgs::ParallelRestricted => {
77 quote! { ::pgrx::pgrx_sql_entity_graph::section::u8_len() }
78 }
79 ExternArgs::ShouldPanic(value)
80 | ExternArgs::Schema(value)
81 | ExternArgs::Name(value)
82 | ExternArgs::Cost(value) => quote! {
83 ::pgrx::pgrx_sql_entity_graph::section::u8_len()
84 + ::pgrx::pgrx_sql_entity_graph::section::str_len(#value)
85 },
86 ExternArgs::Support(item) => {
87 let item_len = item.section_len_tokens();
88 quote! {
89 ::pgrx::pgrx_sql_entity_graph::section::u8_len() + (#item_len)
90 }
91 }
92 ExternArgs::Requires(items) => {
93 let item_lens = items.iter().map(PositioningRef::section_len_tokens);
94 quote! {
95 ::pgrx::pgrx_sql_entity_graph::section::u8_len()
96 + ::pgrx::pgrx_sql_entity_graph::section::list_len(&[
97 #( #item_lens ),*
98 ])
99 }
100 }
101 }
102 }
103
104 pub fn section_writer_tokens(&self, writer: TokenStream) -> TokenStream {
105 match self {
106 ExternArgs::CreateOrReplace => {
107 quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_CREATE_OR_REPLACE) }
108 }
109 ExternArgs::Immutable => {
110 quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_IMMUTABLE) }
111 }
112 ExternArgs::Strict => {
113 quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_STRICT) }
114 }
115 ExternArgs::Stable => {
116 quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_STABLE) }
117 }
118 ExternArgs::Volatile => {
119 quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_VOLATILE) }
120 }
121 ExternArgs::Raw => {
122 quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_RAW) }
123 }
124 ExternArgs::NoGuard => {
125 quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_NO_GUARD) }
126 }
127 ExternArgs::SecurityDefiner => quote! {
128 #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_SECURITY_DEFINER)
129 },
130 ExternArgs::SecurityInvoker => quote! {
131 #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_SECURITY_INVOKER)
132 },
133 ExternArgs::ParallelSafe => quote! {
134 #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_PARALLEL_SAFE)
135 },
136 ExternArgs::ParallelUnsafe => quote! {
137 #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_PARALLEL_UNSAFE)
138 },
139 ExternArgs::ParallelRestricted => quote! {
140 #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_PARALLEL_RESTRICTED)
141 },
142 ExternArgs::ShouldPanic(value) => quote! {
143 #writer
144 .u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_SHOULD_PANIC)
145 .str(#value)
146 },
147 ExternArgs::Schema(value) => quote! {
148 #writer
149 .u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_SCHEMA)
150 .str(#value)
151 },
152 ExternArgs::Support(item) => item.section_writer_tokens(quote! {
153 #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_SUPPORT)
154 }),
155 ExternArgs::Name(value) => quote! {
156 #writer
157 .u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_NAME)
158 .str(#value)
159 },
160 ExternArgs::Cost(value) => quote! {
161 #writer
162 .u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_COST)
163 .str(#value)
164 },
165 ExternArgs::Requires(items) => {
166 let writer_ident = Ident::new("__pgrx_schema_writer", Span::mixed_site());
167 let item_writers =
168 items.iter().map(|item| item.section_writer_tokens(quote! { #writer_ident }));
169 let count = items.len();
170 quote! {
171 {
172 let #writer_ident = #writer
173 .u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_REQUIRES)
174 .u32(#count as u32);
175 #( let #writer_ident = { #item_writers }; )*
176 #writer_ident
177 }
178 }
179 }
180 }
181 }
182}
183
184impl ToTokens for ExternArgs {
185 fn to_tokens(&self, tokens: &mut TokenStream) {
186 match self {
187 ExternArgs::CreateOrReplace => tokens.append(format_ident!("CreateOrReplace")),
188 ExternArgs::Immutable => tokens.append(format_ident!("Immutable")),
189 ExternArgs::Strict => tokens.append(format_ident!("Strict")),
190 ExternArgs::Stable => tokens.append(format_ident!("Stable")),
191 ExternArgs::Volatile => tokens.append(format_ident!("Volatile")),
192 ExternArgs::Raw => tokens.append(format_ident!("Raw")),
193 ExternArgs::NoGuard => tokens.append(format_ident!("NoGuard")),
194 ExternArgs::SecurityDefiner => tokens.append(format_ident!("SecurityDefiner")),
195 ExternArgs::SecurityInvoker => tokens.append(format_ident!("SecurityInvoker")),
196 ExternArgs::ParallelSafe => tokens.append(format_ident!("ParallelSafe")),
197 ExternArgs::ParallelUnsafe => tokens.append(format_ident!("ParallelUnsafe")),
198 ExternArgs::ParallelRestricted => tokens.append(format_ident!("ParallelRestricted")),
199 ExternArgs::ShouldPanic(_s) => tokens.append_all(quote! { Error(String::from("#_s")) }),
200 ExternArgs::Schema(_s) => tokens.append_all(quote! { Schema(String::from("#_s")) }),
201 ExternArgs::Support(item) => tokens.append_all(quote! { Support(#item) }),
202 ExternArgs::Name(_s) => tokens.append_all(quote! { Name(String::from("#_s")) }),
203 ExternArgs::Cost(_s) => tokens.append_all(quote! { Cost(String::from("#_s")) }),
204 ExternArgs::Requires(items) => {
205 tokens.append_all(quote! { Requires(vec![#(#items),*]) })
206 }
207 }
208 }
209}
210
211#[track_caller]
213pub fn parse_extern_attributes(attr: TokenStream) -> HashSet<ExternArgs> {
214 let mut args = HashSet::<ExternArgs>::new();
215 let mut itr = attr.into_iter();
216 while let Some(t) = itr.next() {
217 match t {
218 TokenTree::Group(g) => {
219 for arg in parse_extern_attributes(g.stream()).into_iter() {
220 args.insert(arg);
221 }
222 }
223 TokenTree::Ident(i) => {
224 let name = i.to_string();
225 match name.as_str() {
226 "create_or_replace" => args.insert(ExternArgs::CreateOrReplace),
227 "immutable" => args.insert(ExternArgs::Immutable),
228 "strict" => args.insert(ExternArgs::Strict),
229 "stable" => args.insert(ExternArgs::Stable),
230 "volatile" => args.insert(ExternArgs::Volatile),
231 "raw" => args.insert(ExternArgs::Raw),
232 "no_guard" => args.insert(ExternArgs::NoGuard),
233 "security_invoker" => args.insert(ExternArgs::SecurityInvoker),
234 "security_definer" => args.insert(ExternArgs::SecurityDefiner),
235 "parallel_safe" => args.insert(ExternArgs::ParallelSafe),
236 "parallel_unsafe" => args.insert(ExternArgs::ParallelUnsafe),
237 "parallel_restricted" => args.insert(ExternArgs::ParallelRestricted),
238 "error" | "expected" => {
239 let _punc = itr.next().unwrap();
240 let literal = itr.next().unwrap();
241 let message = literal.to_string();
242 let message = unescape::unescape(&message).expect("failed to unescape");
243
244 let message = message[1..message.len() - 1].to_string();
246 args.insert(ExternArgs::ShouldPanic(message.to_string()))
247 }
248 "schema" => {
249 let _punc = itr.next().unwrap();
250 let literal = itr.next().unwrap();
251 let schema = literal.to_string();
252 let schema = unescape::unescape(&schema).expect("failed to unescape");
253
254 let schema = schema[1..schema.len() - 1].to_string();
256 args.insert(ExternArgs::Schema(schema.to_string()))
257 }
258 "name" => {
259 let _punc = itr.next().unwrap();
260 let literal = itr.next().unwrap();
261 let name = literal.to_string();
262 let name = unescape::unescape(&name).expect("failed to unescape");
263
264 let name = name[1..name.len() - 1].to_string();
266 args.insert(ExternArgs::Name(name.to_string()))
267 }
268 "sql" => {
270 let _punc = itr.next().unwrap();
271 let _value = itr.next().unwrap();
272 false
273 }
274 _ => false,
275 };
276 }
277 TokenTree::Punct(_) => {}
278 TokenTree::Literal(_) => {}
279 }
280 }
281 args
282}
283
#[cfg(test)]
mod tests {
    use crate::{ExternArgs, parse_extern_attributes};

    /// An `error = "..."` argument keeps its message, with the escaped quotes
    /// inside the literal resolved to plain quotes.
    #[test]
    fn parse_args() {
        let attr: proc_macro2::TokenStream =
            r#"error = "syntax error at or near \"THIS\"""#.parse().unwrap();
        let expected = ExternArgs::ShouldPanic(String::from(r#"syntax error at or near "THIS""#));

        let parsed = parse_extern_attributes(attr);

        assert!(parsed.contains(&expected));
    }
}