pgx_sql_entity_graph/
extern_args.rs1use crate::PositioningRef;
2use proc_macro2::{TokenStream, TokenTree};
3use quote::{format_ident, quote, ToTokens, TokenStreamExt};
4use std::collections::HashSet;
5
6#[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord)]
7pub enum ExternArgs {
8 CreateOrReplace,
9 Immutable,
10 Strict,
11 Stable,
12 Volatile,
13 Raw,
14 NoGuard,
15 ParallelSafe,
16 ParallelUnsafe,
17 ParallelRestricted,
18 Error(String),
19 Schema(String),
20 Name(String),
21 Cost(String),
22 Requires(Vec<PositioningRef>),
23}
24
25impl core::fmt::Display for ExternArgs {
26 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
27 match self {
28 ExternArgs::CreateOrReplace => write!(f, "CREATE OR REPLACE"),
29 ExternArgs::Immutable => write!(f, "IMMUTABLE"),
30 ExternArgs::Strict => write!(f, "STRICT"),
31 ExternArgs::Stable => write!(f, "STABLE"),
32 ExternArgs::Volatile => write!(f, "VOLATILE"),
33 ExternArgs::Raw => Ok(()),
34 ExternArgs::ParallelSafe => write!(f, "PARALLEL SAFE"),
35 ExternArgs::ParallelUnsafe => write!(f, "PARALLEL UNSAFE"),
36 ExternArgs::ParallelRestricted => write!(f, "PARALLEL RESTRICTED"),
37 ExternArgs::Error(_) => Ok(()),
38 ExternArgs::NoGuard => Ok(()),
39 ExternArgs::Schema(_) => Ok(()),
40 ExternArgs::Name(_) => Ok(()),
41 ExternArgs::Cost(cost) => write!(f, "COST {}", cost),
42 ExternArgs::Requires(_) => Ok(()),
43 }
44 }
45}
46
47impl ToTokens for ExternArgs {
48 fn to_tokens(&self, tokens: &mut TokenStream) {
49 match self {
50 ExternArgs::CreateOrReplace => tokens.append(format_ident!("CreateOrReplace")),
51 ExternArgs::Immutable => tokens.append(format_ident!("Immutable")),
52 ExternArgs::Strict => tokens.append(format_ident!("Strict")),
53 ExternArgs::Stable => tokens.append(format_ident!("Stable")),
54 ExternArgs::Volatile => tokens.append(format_ident!("Volatile")),
55 ExternArgs::Raw => tokens.append(format_ident!("Raw")),
56 ExternArgs::NoGuard => tokens.append(format_ident!("NoGuard")),
57 ExternArgs::ParallelSafe => tokens.append(format_ident!("ParallelSafe")),
58 ExternArgs::ParallelUnsafe => tokens.append(format_ident!("ParallelUnsafe")),
59 ExternArgs::ParallelRestricted => tokens.append(format_ident!("ParallelRestricted")),
60 ExternArgs::Error(_s) => {
61 tokens.append_all(
62 quote! {
63 Error(String::from("#_s"))
64 }
65 .to_token_stream(),
66 );
67 }
68 ExternArgs::Schema(_s) => {
69 tokens.append_all(
70 quote! {
71 Schema(String::from("#_s"))
72 }
73 .to_token_stream(),
74 );
75 }
76 ExternArgs::Name(_s) => {
77 tokens.append_all(
78 quote! {
79 Name(String::from("#_s"))
80 }
81 .to_token_stream(),
82 );
83 }
84 ExternArgs::Cost(_s) => {
85 tokens.append_all(
86 quote! {
87 Cost(String::from("#_s"))
88 }
89 .to_token_stream(),
90 );
91 }
92 ExternArgs::Requires(items) => {
93 tokens.append_all(
94 quote! {
95 Requires(vec![#(#items),*])
96 }
97 .to_token_stream(),
98 );
99 }
100 }
101 }
102}
103
104pub fn parse_extern_attributes(attr: TokenStream) -> HashSet<ExternArgs> {
105 let mut args = HashSet::<ExternArgs>::new();
106 let mut itr = attr.into_iter();
107 while let Some(t) = itr.next() {
108 match t {
109 TokenTree::Group(g) => {
110 for arg in parse_extern_attributes(g.stream()).into_iter() {
111 args.insert(arg);
112 }
113 }
114 TokenTree::Ident(i) => {
115 let name = i.to_string();
116 match name.as_str() {
117 "create_or_replace" => args.insert(ExternArgs::CreateOrReplace),
118 "immutable" => args.insert(ExternArgs::Immutable),
119 "strict" => args.insert(ExternArgs::Strict),
120 "stable" => args.insert(ExternArgs::Stable),
121 "volatile" => args.insert(ExternArgs::Volatile),
122 "raw" => args.insert(ExternArgs::Raw),
123 "no_guard" => args.insert(ExternArgs::NoGuard),
124 "parallel_safe" => args.insert(ExternArgs::ParallelSafe),
125 "parallel_unsafe" => args.insert(ExternArgs::ParallelUnsafe),
126 "parallel_restricted" => args.insert(ExternArgs::ParallelRestricted),
127 "error" => {
128 let _punc = itr.next().unwrap();
129 let literal = itr.next().unwrap();
130 let message = literal.to_string();
131 let message = unescape::unescape(&message).expect("failed to unescape");
132
133 let message = message[1..message.len() - 1].to_string();
135 args.insert(ExternArgs::Error(message.to_string()))
136 }
137 "schema" => {
138 let _punc = itr.next().unwrap();
139 let literal = itr.next().unwrap();
140 let schema = literal.to_string();
141 let schema = unescape::unescape(&schema).expect("failed to unescape");
142
143 let schema = schema[1..schema.len() - 1].to_string();
145 args.insert(ExternArgs::Schema(schema.to_string()))
146 }
147 "name" => {
148 let _punc = itr.next().unwrap();
149 let literal = itr.next().unwrap();
150 let name = literal.to_string();
151 let name = unescape::unescape(&name).expect("failed to unescape");
152
153 let name = name[1..name.len() - 1].to_string();
155 args.insert(ExternArgs::Name(name.to_string()))
156 }
157 "sql" => {
159 let _punc = itr.next().unwrap();
160 let _value = itr.next().unwrap();
161 false
162 }
163 _ => false,
164 };
165 }
166 TokenTree::Punct(_) => {}
167 TokenTree::Literal(_) => {}
168 }
169 }
170 args
171}
172
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::str::FromStr;

    use crate::{parse_extern_attributes, ExternArgs};

    /// Tokenizes `src` and runs it through `parse_extern_attributes`.
    fn parse(src: &str) -> HashSet<ExternArgs> {
        let ts = proc_macro2::TokenStream::from_str(src).expect("test input should tokenize");
        parse_extern_attributes(ts)
    }

    #[test]
    fn parse_args() {
        let args = parse("error = \"syntax error at or near \\\"THIS\\\"\"");
        assert!(args.contains(&ExternArgs::Error("syntax error at or near \"THIS\"".to_string())));
    }

    #[test]
    fn parse_flags_and_string_values() {
        let args = parse(
            "create_or_replace, immutable, strict, raw, parallel_safe, \
             schema = \"my_schema\", name = \"my_fn\"",
        );
        let expected: HashSet<ExternArgs> = vec![
            ExternArgs::CreateOrReplace,
            ExternArgs::Immutable,
            ExternArgs::Strict,
            ExternArgs::Raw,
            ExternArgs::ParallelSafe,
            ExternArgs::Schema("my_schema".to_string()),
            ExternArgs::Name("my_fn".to_string()),
        ]
        .into_iter()
        .collect();
        assert_eq!(args, expected);
    }

    #[test]
    fn parse_nested_groups_and_ignored_tokens() {
        // Groups are flattened; `sql = <value>`, unknown idents and stray literals add nothing.
        let args = parse("(volatile, [no_guard]) sql = false unknown 42");
        let expected: HashSet<ExternArgs> =
            vec![ExternArgs::Volatile, ExternArgs::NoGuard].into_iter().collect();
        assert_eq!(args, expected);
    }
}