Skip to main content

pgrx_sql_entity_graph/pg_extern/
attribute.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
12`#[pg_extern]` related attributes for Rust to SQL translation
13
14> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
15> to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution.
16
17*/
18use crate::extern_args::ExternArgs;
19use crate::positioning_ref::PositioningRef;
20use crate::to_sql::ToSqlConfig;
21use proc_macro2::TokenStream as TokenStream2;
22use quote::{ToTokens, TokenStreamExt, quote};
23use syn::Token;
24use syn::parse::{Parse, ParseStream};
25use syn::punctuated::Punctuated;
26use syn::spanned::Spanned;
27
/// One option written inside `#[pg_extern(...)]`.
///
/// Bare flags (e.g. `strict`) carry no payload; `key = value` options carry the
/// parsed value. The accepted spellings are defined by this type's `Parse` impl.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum Attribute {
    /// Bare `immutable` flag.
    Immutable,
    /// Bare `strict` flag.
    Strict,
    /// Bare `stable` flag.
    Stable,
    /// Bare `volatile` flag.
    Volatile,
    /// Bare `raw` flag.
    Raw,
    /// Bare `no_guard` flag.
    NoGuard,
    /// Bare `create_or_replace` flag.
    CreateOrReplace,
    /// Bare `security_definer` flag.
    SecurityDefiner,
    /// Bare `security_invoker` flag.
    SecurityInvoker,
    /// Bare `parallel_safe` flag.
    ParallelSafe,
    /// Bare `parallel_unsafe` flag.
    ParallelUnsafe,
    /// Bare `parallel_restricted` flag.
    ParallelRestricted,
    /// `expected = "..."`; `error = "..."` is accepted as an alias and is
    /// re-emitted as `expected` when converted back to tokens.
    ShouldPanic(syn::LitStr),
    /// `schema = "..."`.
    Schema(syn::LitStr),
    /// `support = <PositioningRef>`.
    Support(PositioningRef),
    /// `name = "..."`.
    Name(syn::LitStr),
    /// `cost = <expr>`; any Rust expression is accepted at parse time and is
    /// forwarded to `ExternArgs::Cost` as its token text.
    Cost(Box<syn::Expr>),
    /// `requires = [a, b, ...]`, a bracketed, comma-separated list.
    Requires(Punctuated<PositioningRef, Token![,]>),
    /// `sql = true`, `sql = false`, or `sql = "..."`. Not forwarded to
    /// `ExternArgs` (`as_extern_arg` returns `None` for it).
    Sql(ToSqlConfig),
}
50
51impl ToTokens for Attribute {
52    fn to_tokens(&self, tokens: &mut TokenStream2) {
53        let quoted = match self {
54            Self::Immutable => quote! { immutable },
55            Self::Strict => quote! { strict },
56            Self::Stable => quote! { stable },
57            Self::Volatile => quote! { volatile },
58            Self::Raw => quote! { raw },
59            Self::NoGuard => quote! { no_guard },
60            Self::CreateOrReplace => quote! { create_or_replace },
61            Self::SecurityDefiner => {
62                quote! {security_definer}
63            }
64            Self::SecurityInvoker => {
65                quote! {security_invoker}
66            }
67            Self::ParallelSafe => {
68                quote! { parallel_safe }
69            }
70            Self::ParallelUnsafe => {
71                quote! { parallel_unsafe }
72            }
73            Self::ParallelRestricted => {
74                quote! { parallel_restricted }
75            }
76            Self::ShouldPanic(s) => {
77                quote! { expected = #s }
78            }
79            Self::Schema(s) => {
80                quote! { schema = #s }
81            }
82            Self::Support(item) => {
83                quote! { support = #item }
84            }
85            Self::Name(s) => {
86                quote! { name = #s }
87            }
88            Self::Cost(s) => {
89                quote! { cost = #s }
90            }
91            Self::Requires(items) => {
92                let items_iter = items.iter().map(|x| x.to_token_stream()).collect::<Vec<_>>();
93                quote! { requires = [#(#items_iter),*] }
94            }
95            // This attribute is handled separately
96            Self::Sql(to_sql_config) => {
97                quote! { sql = #to_sql_config }
98            }
99        };
100        tokens.append_all(quoted);
101    }
102}
103
104impl Attribute {
105    /// Convert this attribute into an [`ExternArgs`] for SQL emission.
106    ///
107    /// Returns `None` for attributes (currently only [`Attribute::Sql`]) that are handled outside the extern-args pipeline.
108    pub fn as_extern_arg(&self) -> Option<ExternArgs> {
109        Some(match self {
110            Self::CreateOrReplace => ExternArgs::CreateOrReplace,
111            Self::Immutable => ExternArgs::Immutable,
112            Self::Strict => ExternArgs::Strict,
113            Self::Stable => ExternArgs::Stable,
114            Self::Volatile => ExternArgs::Volatile,
115            Self::Raw => ExternArgs::Raw,
116            Self::NoGuard => ExternArgs::NoGuard,
117            Self::SecurityDefiner => ExternArgs::SecurityDefiner,
118            Self::SecurityInvoker => ExternArgs::SecurityInvoker,
119            Self::ParallelSafe => ExternArgs::ParallelSafe,
120            Self::ParallelUnsafe => ExternArgs::ParallelUnsafe,
121            Self::ParallelRestricted => ExternArgs::ParallelRestricted,
122            Self::ShouldPanic(v) => ExternArgs::ShouldPanic(v.value()),
123            Self::Schema(v) => ExternArgs::Schema(v.value()),
124            Self::Support(v) => ExternArgs::Support(v.clone()),
125            Self::Name(v) => ExternArgs::Name(v.value()),
126            Self::Cost(v) => ExternArgs::Cost(v.to_token_stream().to_string()),
127            Self::Requires(items) => ExternArgs::Requires(items.iter().cloned().collect()),
128            Self::Sql(_) => return None,
129        })
130    }
131}
132
133impl Parse for Attribute {
134    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
135        let ident: syn::Ident = input.parse()?;
136        let found = match ident.to_string().as_str() {
137            "immutable" => Self::Immutable,
138            "strict" => Self::Strict,
139            "stable" => Self::Stable,
140            "volatile" => Self::Volatile,
141            "raw" => Self::Raw,
142            "no_guard" => Self::NoGuard,
143            "create_or_replace" => Self::CreateOrReplace,
144            "security_definer" => Self::SecurityDefiner,
145            "security_invoker" => Self::SecurityInvoker,
146            "parallel_safe" => Self::ParallelSafe,
147            "parallel_unsafe" => Self::ParallelUnsafe,
148            "parallel_restricted" => Self::ParallelRestricted,
149            "error" | "expected" => {
150                let _eq: Token![=] = input.parse()?;
151                let literal: syn::LitStr = input.parse()?;
152                Self::ShouldPanic(literal)
153            }
154            "schema" => {
155                let _eq: Token![=] = input.parse()?;
156                let literal: syn::LitStr = input.parse()?;
157                Self::Schema(literal)
158            }
159            "support" => {
160                let _eq: Token![=] = input.parse()?;
161                let item: PositioningRef = input.parse()?;
162                Self::Support(item)
163            }
164            "name" => {
165                let _eq: Token![=] = input.parse()?;
166                let literal: syn::LitStr = input.parse()?;
167                Self::Name(literal)
168            }
169            "cost" => {
170                let _eq: Token![=] = input.parse()?;
171                let literal: syn::Expr = input.parse()?;
172                Self::Cost(Box::new(literal))
173            }
174            "requires" => {
175                let _eq: syn::token::Eq = input.parse()?;
176                let content;
177                let _bracket = syn::bracketed!(content in input);
178                Self::Requires(content.parse_terminated(PositioningRef::parse, Token![,])?)
179            }
180            "sql" => {
181                use crate::pgrx_attribute::ArgValue;
182                use syn::Lit;
183
184                let _eq: Token![=] = input.parse()?;
185                match input.parse::<ArgValue>()? {
186                    ArgValue::Path(path) => {
187                        return Err(syn::Error::new(
188                            path.span(),
189                            "expected boolean or string literal",
190                        ));
191                    }
192                    ArgValue::Lit(Lit::Bool(b)) => Self::Sql(ToSqlConfig::from(b.value)),
193                    ArgValue::Lit(Lit::Str(s)) => Self::Sql(ToSqlConfig::from(s)),
194                    ArgValue::Lit(other) => {
195                        // FIXME: add a ui test for this
196                        return Err(syn::Error::new(
197                            other.span(),
198                            "expected boolean or string literal",
199                        ));
200                    }
201                }
202            }
203            e => {
204                // FIXME: add a UI test for this
205                return Err(syn::Error::new(
206                    ident.span(),
207                    format!("Invalid option `{e}` inside `{ident} {input}`"),
208                ));
209            }
210        };
211        Ok(found)
212    }
213}
214
#[cfg(test)]
mod tests {

    use super::Attribute;
    use quote::ToTokens;
    use std::str::FromStr;
    use syn::parse::Parser;
    use syn::punctuated::Punctuated;

    /// A comma-separated list of `#[pg_extern(...)]` options.
    type Attrs = Punctuated<Attribute, syn::Token![,]>;

    /// Tokenizes `src` and parses it as a comma-separated attribute list,
    /// returning the parse error instead of panicking.
    fn try_parse(src: &str) -> syn::Result<Attrs> {
        let ts = proc_macro2::TokenStream::from_str(src).expect("tokenize");
        Attrs::parse_terminated.parse2(ts)
    }

    /// Like `try_parse`, but panics if the input does not parse.
    fn parse(src: &str) -> Attrs {
        try_parse(src).expect("parse")
    }

    /// Returns the (unescaped) value of the first `expected`/`error` option, if any.
    fn expected_value(attrs: &Attrs) -> Option<String> {
        attrs.iter().find_map(|a| match a {
            Attribute::ShouldPanic(lit) => Some(lit.value()),
            _ => None,
        })
    }

    #[test]
    fn plain_string_expected() {
        let attrs = parse(r#"expected = "syntax error""#);
        assert_eq!(expected_value(&attrs).as_deref(), Some("syntax error"));
    }

    #[test]
    fn escaped_quotes_in_plain_string() {
        let attrs = parse(r#"expected = "syntax error at or near \"THIS\"""#);
        assert_eq!(expected_value(&attrs).as_deref(), Some(r#"syntax error at or near "THIS""#),);
    }

    #[test]
    fn raw_string_with_embedded_quotes() {
        // The bug we are pinning: the old walker would have produced `#"foo "bar""#` (raw-string delimiters leaking into the value).
        let attrs = parse(r###"expected = r#"foo "bar""#"###);
        assert_eq!(expected_value(&attrs).as_deref(), Some(r#"foo "bar""#));
    }

    #[test]
    fn raw_string_with_nested_hashes() {
        let attrs = parse(r####"expected = r##"weird"#text"##"####);
        assert_eq!(expected_value(&attrs).as_deref(), Some(r##"weird"#text"##));
    }

    #[test]
    fn error_alias_works_like_expected() {
        let attrs = parse(r#"error = "boom""#);
        assert_eq!(expected_value(&attrs).as_deref(), Some("boom"));
    }

    #[test]
    fn other_attrs_alongside_expected_do_not_interfere() {
        let attrs = parse(r#"immutable, expected = "ok", strict"#);
        assert_eq!(expected_value(&attrs).as_deref(), Some("ok"));
        assert!(attrs.iter().any(|a| matches!(a, Attribute::Immutable)));
        assert!(attrs.iter().any(|a| matches!(a, Attribute::Strict)));
    }

    #[test]
    fn malformed_input_is_a_syn_error_not_a_panic() {
        assert!(try_parse("expected").is_err(), "expected = is required");
    }

    #[test]
    fn unknown_option_is_rejected_and_named_in_the_error() {
        let msg = match try_parse("bogus") {
            Ok(_) => panic!("unknown options must not parse"),
            Err(err) => err.to_string(),
        };
        assert!(msg.contains("Invalid option `bogus`"), "unexpected message: {msg}");
    }

    #[test]
    fn sql_rejects_literals_that_are_not_bool_or_string() {
        assert!(try_parse("sql = 42").is_err(), "sql only accepts bool or string literals");
    }

    #[test]
    fn to_tokens_output_reparses_to_the_same_attributes() {
        // `error` is re-emitted as `expected`, so the alias must survive the round trip too.
        let original =
            parse(r#"immutable, strict, parallel_safe, error = "boom", name = "f", cost = 10"#);
        let reparsed = parse(&original.to_token_stream().to_string());

        let original: Vec<Attribute> = original.into_iter().collect();
        let reparsed: Vec<Attribute> = reparsed.into_iter().collect();
        assert_eq!(original, reparsed);
    }
}