pgx_sql_entity_graph/
extern_args.rs

1use crate::PositioningRef;
2use proc_macro2::{TokenStream, TokenTree};
3use quote::{format_ident, quote, ToTokens, TokenStreamExt};
4use std::collections::HashSet;
5
6#[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord)]
7pub enum ExternArgs {
8    CreateOrReplace,
9    Immutable,
10    Strict,
11    Stable,
12    Volatile,
13    Raw,
14    NoGuard,
15    ParallelSafe,
16    ParallelUnsafe,
17    ParallelRestricted,
18    Error(String),
19    Schema(String),
20    Name(String),
21    Cost(String),
22    Requires(Vec<PositioningRef>),
23}
24
25impl core::fmt::Display for ExternArgs {
26    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
27        match self {
28            ExternArgs::CreateOrReplace => write!(f, "CREATE OR REPLACE"),
29            ExternArgs::Immutable => write!(f, "IMMUTABLE"),
30            ExternArgs::Strict => write!(f, "STRICT"),
31            ExternArgs::Stable => write!(f, "STABLE"),
32            ExternArgs::Volatile => write!(f, "VOLATILE"),
33            ExternArgs::Raw => Ok(()),
34            ExternArgs::ParallelSafe => write!(f, "PARALLEL SAFE"),
35            ExternArgs::ParallelUnsafe => write!(f, "PARALLEL UNSAFE"),
36            ExternArgs::ParallelRestricted => write!(f, "PARALLEL RESTRICTED"),
37            ExternArgs::Error(_) => Ok(()),
38            ExternArgs::NoGuard => Ok(()),
39            ExternArgs::Schema(_) => Ok(()),
40            ExternArgs::Name(_) => Ok(()),
41            ExternArgs::Cost(cost) => write!(f, "COST {}", cost),
42            ExternArgs::Requires(_) => Ok(()),
43        }
44    }
45}
46
47impl ToTokens for ExternArgs {
48    fn to_tokens(&self, tokens: &mut TokenStream) {
49        match self {
50            ExternArgs::CreateOrReplace => tokens.append(format_ident!("CreateOrReplace")),
51            ExternArgs::Immutable => tokens.append(format_ident!("Immutable")),
52            ExternArgs::Strict => tokens.append(format_ident!("Strict")),
53            ExternArgs::Stable => tokens.append(format_ident!("Stable")),
54            ExternArgs::Volatile => tokens.append(format_ident!("Volatile")),
55            ExternArgs::Raw => tokens.append(format_ident!("Raw")),
56            ExternArgs::NoGuard => tokens.append(format_ident!("NoGuard")),
57            ExternArgs::ParallelSafe => tokens.append(format_ident!("ParallelSafe")),
58            ExternArgs::ParallelUnsafe => tokens.append(format_ident!("ParallelUnsafe")),
59            ExternArgs::ParallelRestricted => tokens.append(format_ident!("ParallelRestricted")),
60            ExternArgs::Error(_s) => {
61                tokens.append_all(
62                    quote! {
63                        Error(String::from("#_s"))
64                    }
65                    .to_token_stream(),
66                );
67            }
68            ExternArgs::Schema(_s) => {
69                tokens.append_all(
70                    quote! {
71                        Schema(String::from("#_s"))
72                    }
73                    .to_token_stream(),
74                );
75            }
76            ExternArgs::Name(_s) => {
77                tokens.append_all(
78                    quote! {
79                        Name(String::from("#_s"))
80                    }
81                    .to_token_stream(),
82                );
83            }
84            ExternArgs::Cost(_s) => {
85                tokens.append_all(
86                    quote! {
87                        Cost(String::from("#_s"))
88                    }
89                    .to_token_stream(),
90                );
91            }
92            ExternArgs::Requires(items) => {
93                tokens.append_all(
94                    quote! {
95                        Requires(vec![#(#items),*])
96                    }
97                    .to_token_stream(),
98                );
99            }
100        }
101    }
102}
103
104pub fn parse_extern_attributes(attr: TokenStream) -> HashSet<ExternArgs> {
105    let mut args = HashSet::<ExternArgs>::new();
106    let mut itr = attr.into_iter();
107    while let Some(t) = itr.next() {
108        match t {
109            TokenTree::Group(g) => {
110                for arg in parse_extern_attributes(g.stream()).into_iter() {
111                    args.insert(arg);
112                }
113            }
114            TokenTree::Ident(i) => {
115                let name = i.to_string();
116                match name.as_str() {
117                    "create_or_replace" => args.insert(ExternArgs::CreateOrReplace),
118                    "immutable" => args.insert(ExternArgs::Immutable),
119                    "strict" => args.insert(ExternArgs::Strict),
120                    "stable" => args.insert(ExternArgs::Stable),
121                    "volatile" => args.insert(ExternArgs::Volatile),
122                    "raw" => args.insert(ExternArgs::Raw),
123                    "no_guard" => args.insert(ExternArgs::NoGuard),
124                    "parallel_safe" => args.insert(ExternArgs::ParallelSafe),
125                    "parallel_unsafe" => args.insert(ExternArgs::ParallelUnsafe),
126                    "parallel_restricted" => args.insert(ExternArgs::ParallelRestricted),
127                    "error" => {
128                        let _punc = itr.next().unwrap();
129                        let literal = itr.next().unwrap();
130                        let message = literal.to_string();
131                        let message = unescape::unescape(&message).expect("failed to unescape");
132
133                        // trim leading/trailing quotes around the literal
134                        let message = message[1..message.len() - 1].to_string();
135                        args.insert(ExternArgs::Error(message.to_string()))
136                    }
137                    "schema" => {
138                        let _punc = itr.next().unwrap();
139                        let literal = itr.next().unwrap();
140                        let schema = literal.to_string();
141                        let schema = unescape::unescape(&schema).expect("failed to unescape");
142
143                        // trim leading/trailing quotes around the literal
144                        let schema = schema[1..schema.len() - 1].to_string();
145                        args.insert(ExternArgs::Schema(schema.to_string()))
146                    }
147                    "name" => {
148                        let _punc = itr.next().unwrap();
149                        let literal = itr.next().unwrap();
150                        let name = literal.to_string();
151                        let name = unescape::unescape(&name).expect("failed to unescape");
152
153                        // trim leading/trailing quotes around the literal
154                        let name = name[1..name.len() - 1].to_string();
155                        args.insert(ExternArgs::Name(name.to_string()))
156                    }
157                    // Recognized, but not handled as an extern argument
158                    "sql" => {
159                        let _punc = itr.next().unwrap();
160                        let _value = itr.next().unwrap();
161                        false
162                    }
163                    _ => false,
164                };
165            }
166            TokenTree::Punct(_) => {}
167            TokenTree::Literal(_) => {}
168        }
169    }
170    args
171}
172
#[cfg(test)]
mod tests {
    use crate::{parse_extern_attributes, ExternArgs};

    /// An `error = "..."` option is parsed into `ExternArgs::Error`, with the surrounding
    /// quotes removed and the escaped inner quotes unescaped.
    #[test]
    fn parse_args() {
        let expected = ExternArgs::Error("syntax error at or near \"THIS\"".to_string());

        let tokens = "error = \"syntax error at or near \\\"THIS\\\"\""
            .parse::<proc_macro2::TokenStream>()
            .unwrap();
        let parsed = parse_extern_attributes(tokens);

        assert!(parsed.contains(&expected));
    }
}