Skip to main content

pgrx_sql_entity_graph/
extern_args.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10use crate::PositioningRef;
11use proc_macro2::{Ident, Span, TokenStream, TokenTree};
12use quote::{ToTokens, TokenStreamExt, format_ident, quote};
13use std::collections::HashSet;
14
/// A single option attached to an extern function definition.
///
/// The `Display` impl in this file renders the SQL-facing variants as their SQL
/// keywords and renders the remaining variants as nothing at all.
/// `parse_extern_attributes` in this file produces every unit variant plus
/// `ShouldPanic`, `Schema` and `Name`; it never produces `Support`, `Cost` or
/// `Requires`.
// `Hash + Eq` are required because parsed arguments are collected in a `HashSet`.
// The derived `PartialOrd`/`Ord` follow declaration order, so reordering the
// variants changes how values compare.
#[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord)]
pub enum ExternArgs {
    /// Parsed from `create_or_replace`; displays as `CREATE OR REPLACE`.
    CreateOrReplace,
    /// Parsed from `immutable`; displays as `IMMUTABLE`.
    Immutable,
    /// Parsed from `strict`; displays as `STRICT`.
    Strict,
    /// Parsed from `stable`; displays as `STABLE`.
    Stable,
    /// Parsed from `volatile`; displays as `VOLATILE`.
    Volatile,
    /// Parsed from `raw`; displays as nothing.
    Raw,
    /// Parsed from `no_guard`; displays as nothing.
    NoGuard,
    /// Parsed from `security_definer`; displays as `SECURITY DEFINER`.
    SecurityDefiner,
    /// Parsed from `security_invoker`; displays as `SECURITY INVOKER`.
    SecurityInvoker,
    /// Parsed from `parallel_safe`; displays as `PARALLEL SAFE`.
    ParallelSafe,
    /// Parsed from `parallel_unsafe`; displays as `PARALLEL UNSAFE`.
    ParallelUnsafe,
    /// Parsed from `parallel_restricted`; displays as `PARALLEL RESTRICTED`.
    ParallelRestricted,
    /// Parsed from `error = "..."` or `expected = "..."`; holds the unescaped
    /// message text. Displays as nothing.
    ShouldPanic(String),
    /// Parsed from `schema = "..."`; holds the unescaped text. Displays as nothing.
    Schema(String),
    /// Displays as whatever the wrapped `PositioningRef` displays as.
    Support(PositioningRef),
    /// Parsed from `name = "..."`; holds the unescaped text. Displays as nothing.
    Name(String),
    /// Displays as `COST ` followed by the stored text.
    Cost(String),
    /// A list of `PositioningRef`s. Displays as nothing.
    Requires(Vec<PositioningRef>),
}
36
37impl core::fmt::Display for ExternArgs {
38    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
39        match self {
40            ExternArgs::CreateOrReplace => write!(f, "CREATE OR REPLACE"),
41            ExternArgs::Immutable => write!(f, "IMMUTABLE"),
42            ExternArgs::Strict => write!(f, "STRICT"),
43            ExternArgs::Stable => write!(f, "STABLE"),
44            ExternArgs::Volatile => write!(f, "VOLATILE"),
45            ExternArgs::Raw => Ok(()),
46            ExternArgs::ParallelSafe => write!(f, "PARALLEL SAFE"),
47            ExternArgs::ParallelUnsafe => write!(f, "PARALLEL UNSAFE"),
48            ExternArgs::SecurityDefiner => write!(f, "SECURITY DEFINER"),
49            ExternArgs::SecurityInvoker => write!(f, "SECURITY INVOKER"),
50            ExternArgs::ParallelRestricted => write!(f, "PARALLEL RESTRICTED"),
51            ExternArgs::Support(item) => write!(f, "{item}"),
52            ExternArgs::ShouldPanic(_) => Ok(()),
53            ExternArgs::NoGuard => Ok(()),
54            ExternArgs::Schema(_) => Ok(()),
55            ExternArgs::Name(_) => Ok(()),
56            ExternArgs::Cost(cost) => write!(f, "COST {cost}"),
57            ExternArgs::Requires(_) => Ok(()),
58        }
59    }
60}
61
62impl ExternArgs {
63    pub fn section_len_tokens(&self) -> TokenStream {
64        match self {
65            ExternArgs::CreateOrReplace
66            | ExternArgs::Immutable
67            | ExternArgs::Strict
68            | ExternArgs::Stable
69            | ExternArgs::Volatile
70            | ExternArgs::Raw
71            | ExternArgs::NoGuard
72            | ExternArgs::SecurityDefiner
73            | ExternArgs::SecurityInvoker
74            | ExternArgs::ParallelSafe
75            | ExternArgs::ParallelUnsafe
76            | ExternArgs::ParallelRestricted => {
77                quote! { ::pgrx::pgrx_sql_entity_graph::section::u8_len() }
78            }
79            ExternArgs::ShouldPanic(value)
80            | ExternArgs::Schema(value)
81            | ExternArgs::Name(value)
82            | ExternArgs::Cost(value) => quote! {
83                ::pgrx::pgrx_sql_entity_graph::section::u8_len()
84                    + ::pgrx::pgrx_sql_entity_graph::section::str_len(#value)
85            },
86            ExternArgs::Support(item) => {
87                let item_len = item.section_len_tokens();
88                quote! {
89                    ::pgrx::pgrx_sql_entity_graph::section::u8_len() + (#item_len)
90                }
91            }
92            ExternArgs::Requires(items) => {
93                let item_lens = items.iter().map(PositioningRef::section_len_tokens);
94                quote! {
95                    ::pgrx::pgrx_sql_entity_graph::section::u8_len()
96                        + ::pgrx::pgrx_sql_entity_graph::section::list_len(&[
97                            #( #item_lens ),*
98                        ])
99                }
100            }
101        }
102    }
103
104    pub fn section_writer_tokens(&self, writer: TokenStream) -> TokenStream {
105        match self {
106            ExternArgs::CreateOrReplace => {
107                quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_CREATE_OR_REPLACE) }
108            }
109            ExternArgs::Immutable => {
110                quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_IMMUTABLE) }
111            }
112            ExternArgs::Strict => {
113                quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_STRICT) }
114            }
115            ExternArgs::Stable => {
116                quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_STABLE) }
117            }
118            ExternArgs::Volatile => {
119                quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_VOLATILE) }
120            }
121            ExternArgs::Raw => {
122                quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_RAW) }
123            }
124            ExternArgs::NoGuard => {
125                quote! { #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_NO_GUARD) }
126            }
127            ExternArgs::SecurityDefiner => quote! {
128                #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_SECURITY_DEFINER)
129            },
130            ExternArgs::SecurityInvoker => quote! {
131                #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_SECURITY_INVOKER)
132            },
133            ExternArgs::ParallelSafe => quote! {
134                #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_PARALLEL_SAFE)
135            },
136            ExternArgs::ParallelUnsafe => quote! {
137                #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_PARALLEL_UNSAFE)
138            },
139            ExternArgs::ParallelRestricted => quote! {
140                #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_PARALLEL_RESTRICTED)
141            },
142            ExternArgs::ShouldPanic(value) => quote! {
143                #writer
144                    .u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_SHOULD_PANIC)
145                    .str(#value)
146            },
147            ExternArgs::Schema(value) => quote! {
148                #writer
149                    .u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_SCHEMA)
150                    .str(#value)
151            },
152            ExternArgs::Support(item) => item.section_writer_tokens(quote! {
153                #writer.u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_SUPPORT)
154            }),
155            ExternArgs::Name(value) => quote! {
156                #writer
157                    .u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_NAME)
158                    .str(#value)
159            },
160            ExternArgs::Cost(value) => quote! {
161                #writer
162                    .u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_COST)
163                    .str(#value)
164            },
165            ExternArgs::Requires(items) => {
166                let writer_ident = Ident::new("__pgrx_schema_writer", Span::mixed_site());
167                let item_writers =
168                    items.iter().map(|item| item.section_writer_tokens(quote! { #writer_ident }));
169                let count = items.len();
170                quote! {
171                    {
172                        let #writer_ident = #writer
173                            .u8(::pgrx::pgrx_sql_entity_graph::section::EXTERN_ARG_REQUIRES)
174                            .u32(#count as u32);
175                        #( let #writer_ident = { #item_writers }; )*
176                        #writer_ident
177                    }
178                }
179            }
180        }
181    }
182}
183
184impl ToTokens for ExternArgs {
185    fn to_tokens(&self, tokens: &mut TokenStream) {
186        match self {
187            ExternArgs::CreateOrReplace => tokens.append(format_ident!("CreateOrReplace")),
188            ExternArgs::Immutable => tokens.append(format_ident!("Immutable")),
189            ExternArgs::Strict => tokens.append(format_ident!("Strict")),
190            ExternArgs::Stable => tokens.append(format_ident!("Stable")),
191            ExternArgs::Volatile => tokens.append(format_ident!("Volatile")),
192            ExternArgs::Raw => tokens.append(format_ident!("Raw")),
193            ExternArgs::NoGuard => tokens.append(format_ident!("NoGuard")),
194            ExternArgs::SecurityDefiner => tokens.append(format_ident!("SecurityDefiner")),
195            ExternArgs::SecurityInvoker => tokens.append(format_ident!("SecurityInvoker")),
196            ExternArgs::ParallelSafe => tokens.append(format_ident!("ParallelSafe")),
197            ExternArgs::ParallelUnsafe => tokens.append(format_ident!("ParallelUnsafe")),
198            ExternArgs::ParallelRestricted => tokens.append(format_ident!("ParallelRestricted")),
199            ExternArgs::ShouldPanic(_s) => tokens.append_all(quote! { Error(String::from("#_s")) }),
200            ExternArgs::Schema(_s) => tokens.append_all(quote! { Schema(String::from("#_s")) }),
201            ExternArgs::Support(item) => tokens.append_all(quote! { Support(#item) }),
202            ExternArgs::Name(_s) => tokens.append_all(quote! { Name(String::from("#_s")) }),
203            ExternArgs::Cost(_s) => tokens.append_all(quote! { Cost(String::from("#_s")) }),
204            ExternArgs::Requires(items) => {
205                tokens.append_all(quote! { Requires(vec![#(#items),*]) })
206            }
207        }
208    }
209}
210
211// This horror-story should be returning result
212#[track_caller]
213pub fn parse_extern_attributes(attr: TokenStream) -> HashSet<ExternArgs> {
214    let mut args = HashSet::<ExternArgs>::new();
215    let mut itr = attr.into_iter();
216    while let Some(t) = itr.next() {
217        match t {
218            TokenTree::Group(g) => {
219                for arg in parse_extern_attributes(g.stream()).into_iter() {
220                    args.insert(arg);
221                }
222            }
223            TokenTree::Ident(i) => {
224                let name = i.to_string();
225                match name.as_str() {
226                    "create_or_replace" => args.insert(ExternArgs::CreateOrReplace),
227                    "immutable" => args.insert(ExternArgs::Immutable),
228                    "strict" => args.insert(ExternArgs::Strict),
229                    "stable" => args.insert(ExternArgs::Stable),
230                    "volatile" => args.insert(ExternArgs::Volatile),
231                    "raw" => args.insert(ExternArgs::Raw),
232                    "no_guard" => args.insert(ExternArgs::NoGuard),
233                    "security_invoker" => args.insert(ExternArgs::SecurityInvoker),
234                    "security_definer" => args.insert(ExternArgs::SecurityDefiner),
235                    "parallel_safe" => args.insert(ExternArgs::ParallelSafe),
236                    "parallel_unsafe" => args.insert(ExternArgs::ParallelUnsafe),
237                    "parallel_restricted" => args.insert(ExternArgs::ParallelRestricted),
238                    "error" | "expected" => {
239                        let _punc = itr.next().unwrap();
240                        let literal = itr.next().unwrap();
241                        let message = literal.to_string();
242                        let message = unescape::unescape(&message).expect("failed to unescape");
243
244                        // trim leading/trailing quotes around the literal
245                        let message = message[1..message.len() - 1].to_string();
246                        args.insert(ExternArgs::ShouldPanic(message.to_string()))
247                    }
248                    "schema" => {
249                        let _punc = itr.next().unwrap();
250                        let literal = itr.next().unwrap();
251                        let schema = literal.to_string();
252                        let schema = unescape::unescape(&schema).expect("failed to unescape");
253
254                        // trim leading/trailing quotes around the literal
255                        let schema = schema[1..schema.len() - 1].to_string();
256                        args.insert(ExternArgs::Schema(schema.to_string()))
257                    }
258                    "name" => {
259                        let _punc = itr.next().unwrap();
260                        let literal = itr.next().unwrap();
261                        let name = literal.to_string();
262                        let name = unescape::unescape(&name).expect("failed to unescape");
263
264                        // trim leading/trailing quotes around the literal
265                        let name = name[1..name.len() - 1].to_string();
266                        args.insert(ExternArgs::Name(name.to_string()))
267                    }
268                    // Recognized, but not handled as an extern argument
269                    "sql" => {
270                        let _punc = itr.next().unwrap();
271                        let _value = itr.next().unwrap();
272                        false
273                    }
274                    _ => false,
275                };
276            }
277            TokenTree::Punct(_) => {}
278            TokenTree::Literal(_) => {}
279        }
280    }
281    args
282}
283
#[cfg(test)]
mod tests {
    use crate::{ExternArgs, parse_extern_attributes};

    /// An `error = "..."` attribute containing escaped quotes must be stored
    /// unescaped and without the surrounding quotes.
    #[test]
    fn parse_args() {
        let input = r#"error = "syntax error at or near \"THIS\"""#;
        let tokens: proc_macro2::TokenStream = input.parse().unwrap();
        let expected = ExternArgs::ShouldPanic(r#"syntax error at or near "THIS""#.to_string());

        let parsed = parse_extern_attributes(tokens);
        assert!(parsed.contains(&expected));
    }
}