Skip to main content

pgrx_sql_entity_graph/extension_sql/
mod.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
Macro expansion for `pgrx::extension_sql!()` and `pgrx::extension_sql_file!()`, used for Rust-to-SQL translation.

> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
> to the `pgrx` framework and is highly subject to change between versions. While you may use this, please do so with caution.
16
17
18*/
19pub mod entity;
20
21use crate::positioning_ref::PositioningRef;
22
23use crate::enrich::{CodeEnrichment, ToEntityGraphTokens, ToRustCodeTokens};
24use proc_macro2::{Ident, TokenStream as TokenStream2};
25use quote::{ToTokens, TokenStreamExt, format_ident, quote};
26use syn::parse::{Parse, ParseStream};
27use syn::punctuated::Punctuated;
28use syn::{LitStr, Token};
29
/// A parsed `extension_sql_file!()` item.
///
/// It should be used with [`syn::parse::Parse`] functions, wrapped in
/// [`CodeEnrichment`][crate::CodeEnrichment] as shown in the example below.
///
/// Using [`quote::ToTokens`] on that `CodeEnrichment` wrapper will output the declaration for an
/// [`ExtensionSqlEntity`][crate::ExtensionSqlEntity].
///
/// If no `name = "..."` attribute is given, the file stem of `path` is used as the entity name
/// (for example `sql/example.sql` becomes `example`).
///
/// ```rust
/// use syn::{Macro, parse::Parse, parse_quote, parse};
/// use quote::{quote, ToTokens};
/// use pgrx_sql_entity_graph::ExtensionSqlFile;
///
/// # fn main() -> eyre::Result<()> {
/// use pgrx_sql_entity_graph::CodeEnrichment;
/// let parsed: Macro = parse_quote! {
///     extension_sql_file!("sql/example.sql", name = "example", bootstrap)
/// };
/// let inner_tokens = parsed.tokens;
/// let inner: CodeEnrichment<ExtensionSqlFile> = parse_quote! {
///     #inner_tokens
/// };
/// let sql_graph_entity_tokens = inner.to_token_stream();
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct ExtensionSqlFile {
    /// The SQL file path literal. The generated code reads it with `include_str!`, so it is
    /// resolved relative to the source file that contains the macro invocation.
    pub path: LitStr,
    /// The comma-separated attributes that follow the path (`name`, `bootstrap`, `finalize`,
    /// `requires`, `creates`).
    pub attrs: Punctuated<ExtensionSqlAttribute, Token![,]>,
}
59
60impl ToEntityGraphTokens for ExtensionSqlFile {
61    fn to_entity_graph_tokens(&self) -> TokenStream2 {
62        let path = &self.path;
63        let mut name = None;
64        let mut bootstrap = false;
65        let mut finalize = false;
66        let mut requires: Vec<PositioningRef> = vec![];
67        let mut creates: Vec<SqlDeclared> = vec![];
68        for attr in &self.attrs {
69            match attr {
70                ExtensionSqlAttribute::Creates(items) => {
71                    creates.extend(items.iter().cloned());
72                }
73                ExtensionSqlAttribute::Requires(items) => {
74                    requires.extend(items.iter().cloned());
75                }
76                ExtensionSqlAttribute::Bootstrap => {
77                    bootstrap = true;
78                }
79                ExtensionSqlAttribute::Finalize => {
80                    finalize = true;
81                }
82                ExtensionSqlAttribute::Name(found_name) => {
83                    name = Some(found_name.value());
84                }
85            }
86        }
87        let name = name.unwrap_or(
88            std::path::PathBuf::from(path.value())
89                .file_stem()
90                .expect("No file name for extension_sql_file!()")
91                .to_str()
92                .expect("No UTF-8 file name for extension_sql_file!()")
93                .to_string(),
94        );
95        let require_lens = requires.iter().map(PositioningRef::section_len_tokens);
96        let create_lens = creates.iter().map(SqlDeclared::section_len_tokens);
97        let require_writers =
98            requires.iter().map(|item| item.section_writer_tokens(quote! { writer }));
99        let create_writers =
100            creates.iter().map(|item| item.section_writer_tokens(quote! { writer }));
101        let require_count = requires.len();
102        let create_count = creates.len();
103        let sql_graph_entity_fn_name = format_ident!("__pgrx_schema_sql_{}", name.clone());
104        let payload_len = quote! {
105            ::pgrx::pgrx_sql_entity_graph::section::u8_len()
106                + ::pgrx::pgrx_sql_entity_graph::section::str_len(include_str!(#path))
107                + ::pgrx::pgrx_sql_entity_graph::section::str_len(module_path!())
108                + ::pgrx::pgrx_sql_entity_graph::section::str_len(concat!(file!(), ':', line!()))
109                + ::pgrx::pgrx_sql_entity_graph::section::str_len(file!())
110                + ::pgrx::pgrx_sql_entity_graph::section::u32_len()
111                + ::pgrx::pgrx_sql_entity_graph::section::str_len(#name)
112                + ::pgrx::pgrx_sql_entity_graph::section::bool_len()
113                + ::pgrx::pgrx_sql_entity_graph::section::bool_len()
114                + ::pgrx::pgrx_sql_entity_graph::section::list_len(&[
115                    #( #require_lens ),*
116                ])
117                + ::pgrx::pgrx_sql_entity_graph::section::list_len(&[
118                    #( #create_lens ),*
119                ])
120        };
121        let total_len = quote! {
122            ::pgrx::pgrx_sql_entity_graph::section::u32_len() + (#payload_len)
123        };
124        quote! {
125            ::pgrx::pgrx_sql_entity_graph::__pgrx_schema_entry!(
126                #sql_graph_entity_fn_name,
127                #total_len,
128                {
129                    let writer = ::pgrx::pgrx_sql_entity_graph::section::EntryWriter::<{ #total_len }>::new()
130                        .u32((#payload_len) as u32)
131                        .u8(::pgrx::pgrx_sql_entity_graph::section::ENTITY_CUSTOM_SQL)
132                        .str(include_str!(#path))
133                        .str(module_path!())
134                        .str(concat!(file!(), ':', line!()))
135                        .str(file!())
136                        .u32(line!())
137                        .str(#name)
138                        .bool(#bootstrap)
139                        .bool(#finalize)
140                        .u32(#require_count as u32);
141                    #( let writer = { #require_writers }; )*
142                    let writer = writer.u32(#create_count as u32);
143                    #( let writer = { #create_writers }; )*
144                    writer.finish()
145                }
146            );
147        }
148    }
149}
150
// Relies entirely on the trait's default method(s). NOTE(review): presumably this means no
// Rust code beyond the entity-graph tokens is emitted; confirm against `crate::enrich`.
impl ToRustCodeTokens for ExtensionSqlFile {}
152
153impl Parse for CodeEnrichment<ExtensionSqlFile> {
154    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
155        let path = input.parse()?;
156        let _after_sql_comma: Option<Token![,]> = input.parse()?;
157        let attrs = input.parse_terminated(ExtensionSqlAttribute::parse, Token![,])?;
158        Ok(CodeEnrichment(ExtensionSqlFile { path, attrs }))
159    }
160}
161
/// A parsed `extension_sql!()` item.
///
/// It should be used with [`syn::parse::Parse`] functions, wrapped in
/// [`CodeEnrichment`][crate::CodeEnrichment]. Parsing fails unless a `name = "..."` attribute
/// is present.
///
/// Using [`quote::ToTokens`] will output the declaration for a
/// `pgrx::pgrx_sql_entity_graph::ExtensionSqlEntity`.
///
/// ```rust
/// use syn::{Macro, parse::Parse, parse_quote, parse};
/// use quote::{quote, ToTokens};
/// use pgrx_sql_entity_graph::ExtensionSql;
///
/// # fn main() -> eyre::Result<()> {
/// use pgrx_sql_entity_graph::CodeEnrichment;
/// let parsed: Macro = parse_quote! {
///     extension_sql!("-- Example content", name = "example", bootstrap)
/// };
/// let inner_tokens = parsed.tokens;
/// let inner: CodeEnrichment<ExtensionSql> = parse_quote! {
///     #inner_tokens
/// };
/// let sql_graph_entity_tokens = inner.to_token_stream();
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct ExtensionSql {
    /// The inline SQL string literal.
    pub sql: LitStr,
    /// The entity name, taken from the last `name = "..."` attribute in `attrs`.
    pub name: LitStr,
    /// All attributes following the SQL literal, including the `name` attribute(s).
    pub attrs: Punctuated<ExtensionSqlAttribute, Token![,]>,
}
192
193impl ToEntityGraphTokens for ExtensionSql {
194    fn to_entity_graph_tokens(&self) -> TokenStream2 {
195        let sql = &self.sql;
196        let mut bootstrap = false;
197        let mut finalize = false;
198        let mut creates: Vec<SqlDeclared> = vec![];
199        let mut requires: Vec<PositioningRef> = vec![];
200        for attr in &self.attrs {
201            match attr {
202                ExtensionSqlAttribute::Requires(items) => {
203                    requires.extend(items.iter().cloned());
204                }
205                ExtensionSqlAttribute::Creates(items) => {
206                    creates.extend(items.iter().cloned());
207                }
208                ExtensionSqlAttribute::Bootstrap => {
209                    bootstrap = true;
210                }
211                ExtensionSqlAttribute::Finalize => {
212                    finalize = true;
213                }
214                ExtensionSqlAttribute::Name(_found_name) => (), // Already done
215            }
216        }
217        let name = &self.name;
218        let require_lens = requires.iter().map(PositioningRef::section_len_tokens);
219        let create_lens = creates.iter().map(SqlDeclared::section_len_tokens);
220        let require_writers =
221            requires.iter().map(|item| item.section_writer_tokens(quote! { writer }));
222        let create_writers =
223            creates.iter().map(|item| item.section_writer_tokens(quote! { writer }));
224        let require_count = requires.len();
225        let create_count = creates.len();
226        let sql_graph_entity_fn_name = format_ident!("__pgrx_schema_sql_{}", name.value());
227        let payload_len = quote! {
228            ::pgrx::pgrx_sql_entity_graph::section::u8_len()
229                + ::pgrx::pgrx_sql_entity_graph::section::str_len(#sql)
230                + ::pgrx::pgrx_sql_entity_graph::section::str_len(module_path!())
231                + ::pgrx::pgrx_sql_entity_graph::section::str_len(concat!(file!(), ':', line!()))
232                + ::pgrx::pgrx_sql_entity_graph::section::str_len(file!())
233                + ::pgrx::pgrx_sql_entity_graph::section::u32_len()
234                + ::pgrx::pgrx_sql_entity_graph::section::str_len(#name)
235                + ::pgrx::pgrx_sql_entity_graph::section::bool_len()
236                + ::pgrx::pgrx_sql_entity_graph::section::bool_len()
237                + ::pgrx::pgrx_sql_entity_graph::section::list_len(&[
238                    #( #require_lens ),*
239                ])
240                + ::pgrx::pgrx_sql_entity_graph::section::list_len(&[
241                    #( #create_lens ),*
242                ])
243        };
244        let total_len = quote! {
245            ::pgrx::pgrx_sql_entity_graph::section::u32_len() + (#payload_len)
246        };
247        quote! {
248            ::pgrx::pgrx_sql_entity_graph::__pgrx_schema_entry!(
249                #sql_graph_entity_fn_name,
250                #total_len,
251                {
252                    let writer = ::pgrx::pgrx_sql_entity_graph::section::EntryWriter::<{ #total_len }>::new()
253                        .u32((#payload_len) as u32)
254                        .u8(::pgrx::pgrx_sql_entity_graph::section::ENTITY_CUSTOM_SQL)
255                        .str(#sql)
256                        .str(module_path!())
257                        .str(concat!(file!(), ':', line!()))
258                        .str(file!())
259                        .u32(line!())
260                        .str(#name)
261                        .bool(#bootstrap)
262                        .bool(#finalize)
263                        .u32(#require_count as u32);
264                    #( let writer = { #require_writers }; )*
265                    let writer = writer.u32(#create_count as u32);
266                    #( let writer = { #create_writers }; )*
267                    writer.finish()
268                }
269            );
270        }
271    }
272}
273
// Relies entirely on the trait's default method(s). NOTE(review): presumably this means no
// Rust code beyond the entity-graph tokens is emitted; confirm against `crate::enrich`.
impl ToRustCodeTokens for ExtensionSql {}
275
276impl Parse for CodeEnrichment<ExtensionSql> {
277    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
278        let sql = input.parse()?;
279        let _after_sql_comma: Option<Token![,]> = input.parse()?;
280        let attrs = input.parse_terminated(ExtensionSqlAttribute::parse, Token![,])?;
281        let name = attrs.iter().rev().find_map(|attr| match attr {
282            ExtensionSqlAttribute::Name(found_name) => Some(found_name.clone()),
283            _ => None,
284        });
285        let name =
286            name.ok_or_else(|| syn::Error::new(input.span(), "expected `name` to be set"))?;
287        Ok(CodeEnrichment(ExtensionSql { sql, attrs, name }))
288    }
289}
290
291impl ToTokens for ExtensionSql {
292    fn to_tokens(&self, tokens: &mut TokenStream2) {
293        tokens.append_all(self.to_entity_graph_tokens())
294    }
295}
296
/// One attribute accepted after the SQL of `extension_sql!()` / `extension_sql_file!()`.
#[derive(Debug, Clone)]
pub enum ExtensionSqlAttribute {
    /// `requires = [ ... ]`: positioning references listed as requirements of this SQL.
    /// NOTE(review): presumably used to order this SQL after those entities; confirm in the
    /// entity graph.
    Requires(Punctuated<PositioningRef, Token![,]>),
    /// `creates = [Type(..), Enum(..), Function(..)]`: SQL entities this SQL declares.
    Creates(Punctuated<SqlDeclared, Token![,]>),
    /// Bare `bootstrap` flag.
    /// NOTE(review): presumably emits this SQL at the start of the schema; confirm.
    Bootstrap,
    /// Bare `finalize` flag.
    /// NOTE(review): presumably emits this SQL at the end of the schema; confirm.
    Finalize,
    /// `name = "..."`: the entity name.
    Name(LitStr),
}
305
306impl Parse for ExtensionSqlAttribute {
307    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
308        let ident: Ident = input.parse()?;
309        let found = match ident.to_string().as_str() {
310            "creates" => {
311                let _eq: syn::token::Eq = input.parse()?;
312                let content;
313                let _bracket = syn::bracketed!(content in input);
314                Self::Creates(content.parse_terminated(SqlDeclared::parse, Token![,])?)
315            }
316            "requires" => {
317                let _eq: syn::token::Eq = input.parse()?;
318                let content;
319                let _bracket = syn::bracketed!(content in input);
320                Self::Requires(content.parse_terminated(PositioningRef::parse, Token![,])?)
321            }
322            "bootstrap" => Self::Bootstrap,
323            "finalize" => Self::Finalize,
324            "name" => {
325                let _eq: syn::token::Eq = input.parse()?;
326                Self::Name(input.parse()?)
327            }
328            other => {
329                return Err(syn::Error::new(
330                    ident.span(),
331                    format!("Unknown extension_sql attribute: {other}"),
332                ));
333            }
334        };
335        Ok(found)
336    }
337}
338
/// An SQL entity declared by an `extension_sql!()` / `extension_sql_file!()` block via
/// `creates = [...]`.
///
/// Each variant holds the declared Rust path's segment identifiers joined by `::`. The `Parse`
/// impl does not keep generic arguments or a leading `::`. A name without `::` is later
/// qualified with the invoking module's `module_path!()` when the section entry is generated.
///
/// The derived `Ord` compares the variant first (`Type` < `Enum` < `Function`), then the
/// identifier, so reordering the variants changes the ordering.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub enum SqlDeclared {
    /// `Type(path)`: a type created by the SQL.
    Type(String),
    /// `Enum(path)`: an enum created by the SQL.
    Enum(String),
    /// `Function(path)`: a function created by the SQL.
    Function(String),
}
345
346impl ToEntityGraphTokens for SqlDeclared {
347    fn to_entity_graph_tokens(&self) -> TokenStream2 {
348        let (variant, identifier) = match &self {
349            SqlDeclared::Type(val) => ("Type", val),
350            SqlDeclared::Enum(val) => ("Enum", val),
351            SqlDeclared::Function(val) => ("Function", val),
352        };
353        let identifier_expr = self.section_identifier_tokens();
354        match self {
355            SqlDeclared::Type(_) | SqlDeclared::Enum(_) => {
356                let identifier_path: syn::Path =
357                    syn::parse_str(identifier).expect("type declaration path should parse");
358                quote! {
359                    ::pgrx::pgrx_sql_entity_graph::SqlDeclaredEntity::build_type::<#identifier_path>(#variant, #identifier_expr).unwrap()
360                }
361            }
362            SqlDeclared::Function(_) => quote! {
363                ::pgrx::pgrx_sql_entity_graph::SqlDeclaredEntity::build(#variant, #identifier_expr).unwrap()
364            },
365        }
366    }
367}
368
// Relies entirely on the trait's default method(s). NOTE(review): presumably this means no
// Rust code beyond the entity-graph tokens is emitted; confirm against `crate::enrich`.
impl ToRustCodeTokens for SqlDeclared {}
370
371impl SqlDeclared {
372    fn section_identifier_tokens(&self) -> TokenStream2 {
373        let identifier = match self {
374            SqlDeclared::Type(value) | SqlDeclared::Enum(value) | SqlDeclared::Function(value) => {
375                value
376            }
377        };
378        let identifier_split = identifier.split("::").collect::<Vec<_>>();
379        if identifier_split.len() == 1 {
380            let identifier_infer =
381                Ident::new(identifier_split.last().unwrap(), proc_macro2::Span::call_site());
382            quote! { concat!(module_path!(), "::", stringify!(#identifier_infer)) }
383        } else {
384            quote! { #identifier }
385        }
386    }
387
388    pub fn section_len_tokens(&self) -> TokenStream2 {
389        let identifier_expr = self.section_identifier_tokens();
390        match self {
391            SqlDeclared::Type(identifier) | SqlDeclared::Enum(identifier) => {
392                let identifier_path: syn::Path =
393                    syn::parse_str(identifier).expect("type declaration path should parse");
394                quote! {
395                    ::pgrx::pgrx_sql_entity_graph::section::u8_len()
396                        + ::pgrx::pgrx_sql_entity_graph::section::str_len(#identifier_expr)
397                        + ::pgrx::pgrx_sql_entity_graph::section::str_len(
398                            <#identifier_path as ::pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable>::TYPE_IDENT
399                        )
400                        + ::pgrx::pgrx_sql_entity_graph::section::argument_sql_len(
401                            <#identifier_path as ::pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable>::ARGUMENT_SQL
402                        )
403                }
404            }
405            SqlDeclared::Function(_) => quote! {
406                ::pgrx::pgrx_sql_entity_graph::section::u8_len()
407                    + ::pgrx::pgrx_sql_entity_graph::section::str_len(#identifier_expr)
408            },
409        }
410    }
411
412    pub fn section_writer_tokens(&self, writer: TokenStream2) -> TokenStream2 {
413        let identifier_expr = self.section_identifier_tokens();
414        match self {
415            SqlDeclared::Type(identifier) => {
416                let identifier_path: syn::Path =
417                    syn::parse_str(identifier).expect("type declaration path should parse");
418                quote! {
419                    #writer
420                        .u8(::pgrx::pgrx_sql_entity_graph::section::SQL_DECLARED_TYPE)
421                        .str(#identifier_expr)
422                        .str(<#identifier_path as ::pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable>::TYPE_IDENT)
423                        .argument_sql(<#identifier_path as ::pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable>::ARGUMENT_SQL)
424                }
425            }
426            SqlDeclared::Enum(identifier) => {
427                let identifier_path: syn::Path =
428                    syn::parse_str(identifier).expect("type declaration path should parse");
429                quote! {
430                    #writer
431                        .u8(::pgrx::pgrx_sql_entity_graph::section::SQL_DECLARED_ENUM)
432                        .str(#identifier_expr)
433                        .str(<#identifier_path as ::pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable>::TYPE_IDENT)
434                        .argument_sql(<#identifier_path as ::pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable>::ARGUMENT_SQL)
435                }
436            }
437            SqlDeclared::Function(_) => quote! {
438                #writer
439                    .u8(::pgrx::pgrx_sql_entity_graph::section::SQL_DECLARED_FUNCTION)
440                    .str(#identifier_expr)
441            },
442        }
443    }
444}
445
446impl Parse for SqlDeclared {
447    fn parse(input: ParseStream) -> syn::Result<Self> {
448        let variant: Ident = input.parse()?;
449        let content;
450        let _bracket: syn::token::Paren = syn::parenthesized!(content in input);
451        let identifier_path: syn::Path = content.parse()?;
452        let identifier_str = {
453            let mut identifier_segments = Vec::new();
454            for segment in identifier_path.segments {
455                identifier_segments.push(segment.ident.to_string())
456            }
457            identifier_segments.join("::")
458        };
459        let this = match variant.to_string().as_str() {
460            "Type" => SqlDeclared::Type(identifier_str),
461            "Enum" => SqlDeclared::Enum(identifier_str),
462            "Function" => SqlDeclared::Function(identifier_str),
463            _ => {
464                return Err(syn::Error::new(
465                    variant.span(),
466                    "SQL declared entities must be `Type(ident)`, `Enum(ident)`, or `Function(ident)`",
467                ));
468            }
469        };
470        Ok(this)
471    }
472}
473
474impl ToTokens for SqlDeclared {
475    fn to_tokens(&self, tokens: &mut TokenStream2) {
476        tokens.append_all(self.to_entity_graph_tokens())
477    }
478}