pgx_utils/sql_entity_graph/extension_sql/
mod.rs1pub mod entity;
19
20use crate::sql_entity_graph::positioning_ref::PositioningRef;
21
22use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
23use quote::{quote, ToTokens, TokenStreamExt};
24use syn::parse::{Parse, ParseStream};
25use syn::punctuated::Punctuated;
26use syn::{LitStr, Token};
27
28#[derive(Debug, Clone)]
52pub struct ExtensionSqlFile {
53 pub path: LitStr,
54 pub attrs: Punctuated<ExtensionSqlAttribute, Token![,]>,
55}
56
57impl Parse for ExtensionSqlFile {
58 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
59 let path = input.parse()?;
60 let _after_sql_comma: Option<Token![,]> = input.parse()?;
61 let attrs = input.parse_terminated(ExtensionSqlAttribute::parse)?;
62 Ok(Self { path, attrs })
63 }
64}
65
66impl ToTokens for ExtensionSqlFile {
67 fn to_tokens(&self, tokens: &mut TokenStream2) {
68 let path = &self.path;
69 let mut name = None;
70 let mut bootstrap = false;
71 let mut finalize = false;
72 let mut requires = vec![];
73 let mut creates = vec![];
74 for attr in &self.attrs {
75 match attr {
76 ExtensionSqlAttribute::Creates(items) => {
77 creates.append(&mut items.iter().map(|x| x.to_token_stream()).collect());
78 }
79 ExtensionSqlAttribute::Requires(items) => {
80 requires.append(&mut items.iter().map(|x| x.to_token_stream()).collect());
81 }
82 ExtensionSqlAttribute::Bootstrap => {
83 bootstrap = true;
84 }
85 ExtensionSqlAttribute::Finalize => {
86 finalize = true;
87 }
88 ExtensionSqlAttribute::Name(found_name) => {
89 name = Some(found_name.value());
90 }
91 }
92 }
93 let name = name.unwrap_or(
94 std::path::PathBuf::from(path.value())
95 .file_stem()
96 .expect("No file name for extension_sql_file!()")
97 .to_str()
98 .expect("No UTF-8 file name for extension_sql_file!()")
99 .to_string(),
100 );
101 let requires_iter = requires.iter();
102 let creates_iter = creates.iter();
103 let sql_graph_entity_fn_name =
104 syn::Ident::new(&format!("__pgx_internals_sql_{}", name.clone()), Span::call_site());
105 let inv = quote! {
106 #[no_mangle]
107 #[doc(hidden)]
108 pub extern "Rust" fn #sql_graph_entity_fn_name() -> ::pgx::utils::sql_entity_graph::SqlGraphEntity {
109 extern crate alloc;
110 use alloc::vec::Vec;
111 use alloc::vec;
112 let submission = ::pgx::utils::sql_entity_graph::ExtensionSqlEntity {
113 sql: include_str!(#path),
114 module_path: module_path!(),
115 full_path: concat!(file!(), ':', line!()),
116 file: file!(),
117 line: line!(),
118 name: #name,
119 bootstrap: #bootstrap,
120 finalize: #finalize,
121 requires: vec![#(#requires_iter),*],
122 creates: vec![#(#creates_iter),*],
123 };
124 ::pgx::utils::sql_entity_graph::SqlGraphEntity::CustomSql(submission)
125 }
126 };
127 tokens.append_all(inv);
128 }
129}
130
131#[derive(Debug, Clone)]
155pub struct ExtensionSql {
156 pub sql: LitStr,
157 pub name: LitStr,
158 pub attrs: Punctuated<ExtensionSqlAttribute, Token![,]>,
159}
160
161impl Parse for ExtensionSql {
162 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
163 let sql = input.parse()?;
164 let _after_sql_comma: Option<Token![,]> = input.parse()?;
165 let attrs = input.parse_terminated(ExtensionSqlAttribute::parse)?;
166 let mut name = None;
167 for attr in &attrs {
168 match attr {
169 ExtensionSqlAttribute::Name(found_name) => {
170 name = Some(found_name.clone());
171 }
172 _ => (),
173 }
174 }
175 let name =
176 name.ok_or_else(|| syn::Error::new(input.span(), "expected `name` to be set"))?;
177 Ok(Self { sql, attrs, name })
178 }
179}
180
181impl ToTokens for ExtensionSql {
182 fn to_tokens(&self, tokens: &mut TokenStream2) {
183 let sql = &self.sql;
184 let mut bootstrap = false;
185 let mut finalize = false;
186 let mut creates = vec![];
187 let mut requires = vec![];
188 for attr in &self.attrs {
189 match attr {
190 ExtensionSqlAttribute::Requires(items) => {
191 requires.append(&mut items.iter().map(|x| x.to_token_stream()).collect());
192 }
193 ExtensionSqlAttribute::Creates(items) => {
194 creates.append(&mut items.iter().map(|x| x.to_token_stream()).collect());
195 }
196 ExtensionSqlAttribute::Bootstrap => {
197 bootstrap = true;
198 }
199 ExtensionSqlAttribute::Finalize => {
200 finalize = true;
201 }
202 ExtensionSqlAttribute::Name(_found_name) => (), }
204 }
205 let requires_iter = requires.iter();
206 let creates_iter = creates.iter();
207 let name = &self.name;
208
209 let sql_graph_entity_fn_name =
210 syn::Ident::new(&format!("__pgx_internals_sql_{}", name.value()), Span::call_site());
211 let inv = quote! {
212 #[no_mangle]
213 pub extern "Rust" fn #sql_graph_entity_fn_name() -> ::pgx::utils::sql_entity_graph::SqlGraphEntity {
214 extern crate alloc;
215 use alloc::vec::Vec;
216 use alloc::vec;
217 let submission = ::pgx::utils::sql_entity_graph::ExtensionSqlEntity {
218 sql: #sql,
219 module_path: module_path!(),
220 full_path: concat!(file!(), ':', line!()),
221 file: file!(),
222 line: line!(),
223 name: #name,
224 bootstrap: #bootstrap,
225 finalize: #finalize,
226 requires: vec![#(#requires_iter),*],
227 creates: vec![#(#creates_iter),*],
228 };
229 ::pgx::utils::sql_entity_graph::SqlGraphEntity::CustomSql(submission)
230 }
231 };
232 tokens.append_all(inv);
233 }
234}
235
236#[derive(Debug, Clone)]
237pub enum ExtensionSqlAttribute {
238 Requires(Punctuated<PositioningRef, Token![,]>),
239 Creates(Punctuated<SqlDeclared, Token![,]>),
240 Bootstrap,
241 Finalize,
242 Name(LitStr),
243}
244
245impl Parse for ExtensionSqlAttribute {
246 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
247 let ident: Ident = input.parse()?;
248 let found = match ident.to_string().as_str() {
249 "creates" => {
250 let _eq: syn::token::Eq = input.parse()?;
251 let content;
252 let _bracket = syn::bracketed!(content in input);
253 Self::Creates(content.parse_terminated(SqlDeclared::parse)?)
254 }
255 "requires" => {
256 let _eq: syn::token::Eq = input.parse()?;
257 let content;
258 let _bracket = syn::bracketed!(content in input);
259 Self::Requires(content.parse_terminated(PositioningRef::parse)?)
260 }
261 "bootstrap" => Self::Bootstrap,
262 "finalize" => Self::Finalize,
263 "name" => {
264 let _eq: syn::token::Eq = input.parse()?;
265 Self::Name(input.parse()?)
266 }
267 other => {
268 return Err(syn::Error::new(
269 ident.span(),
270 &format!("Unknown extension_sql attribute: {}", other),
271 ))
272 }
273 };
274 Ok(found)
275 }
276}
277
/// An entity that a custom SQL block declares it creates, as listed in `creates = [...]`.
///
/// The payload is the Rust path written inside the parentheses, with its segment identifiers
/// joined by `::` (as produced by this type's `Parse` impl).
#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub enum SqlDeclared {
    /// `Type(path)`
    Type(String),
    /// `Enum(path)`
    Enum(String),
    /// `Function(path)`
    Function(String),
}
284
285impl Parse for SqlDeclared {
286 fn parse(input: ParseStream) -> syn::Result<Self> {
287 let variant: Ident = input.parse()?;
288 let content;
289 let _bracket: syn::token::Paren = syn::parenthesized!(content in input);
290 let identifier_path: syn::Path = content.parse()?;
291 let identifier_str = {
292 let mut identifier_segments = Vec::new();
293 for segment in identifier_path.segments {
294 identifier_segments.push(segment.ident.to_string())
295 }
296 identifier_segments.join("::")
297 };
298 let this = match variant.to_string().as_str() {
299 "Type" => SqlDeclared::Type(identifier_str),
300 "Enum" => SqlDeclared::Enum(identifier_str),
301 "Function" => SqlDeclared::Function(identifier_str),
302 _ => return Err(syn::Error::new(
303 variant.span(),
304 "SQL declared entities must be `Type(ident)`, `Enum(ident)`, or `Function(ident)`",
305 )),
306 };
307 Ok(this)
308 }
309}
310
311impl ToTokens for SqlDeclared {
312 fn to_tokens(&self, tokens: &mut TokenStream2) {
313 let (variant, identifier) = match &self {
314 SqlDeclared::Type(val) => ("Type", val),
315 SqlDeclared::Enum(val) => ("Enum", val),
316 SqlDeclared::Function(val) => ("Function", val),
317 };
318 let identifier_split = identifier.split("::").collect::<Vec<_>>();
319 let identifier = if identifier_split.len() == 1 {
320 let identifier_infer =
321 Ident::new(identifier_split.last().unwrap(), proc_macro2::Span::call_site());
322 quote! { concat!(module_path!(), "::", stringify!(#identifier_infer)) }
323 } else {
324 quote! { stringify!(#identifier) }
325 };
326 let inv = quote! {
327 ::pgx::utils::sql_entity_graph::SqlDeclaredEntity::build(#variant, #identifier).unwrap()
328 };
329 tokens.append_all(inv);
330 }
331}