Skip to main content

pgrx_sql_entity_graph/extension_sql/
mod.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
12`pgrx::extension_sql!()` related macro expansion for Rust to SQL translation
13
14> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
15> to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution.
16
17
18*/
19pub mod entity;
20
21use crate::positioning_ref::PositioningRef;
22
23use crate::enrich::{CodeEnrichment, ToEntityGraphTokens, ToRustCodeTokens};
24use proc_macro2::{Ident, TokenStream as TokenStream2};
25use quote::{ToTokens, TokenStreamExt, format_ident, quote};
26use syn::parse::{Parse, ParseStream};
27use syn::punctuated::Punctuated;
28use syn::{LitStr, Token};
29
30/// A parsed `extension_sql_file!()` item.
31///
32/// It should be used with [`syn::parse::Parse`] functions.
33///
34/// Using [`quote::ToTokens`] will output the declaration for a [`ExtensionSqlEntity`][crate::ExtensionSqlEntity].
35///
36/// ```rust
37/// use syn::{Macro, parse::Parse, parse_quote, parse};
38/// use quote::{quote, ToTokens};
39/// use pgrx_sql_entity_graph::ExtensionSqlFile;
40///
41/// # fn main() -> eyre::Result<()> {
42/// use pgrx_sql_entity_graph::CodeEnrichment;
43/// let parsed: Macro = parse_quote! {
44///     extension_sql_file!("sql/example.sql", name = "example", bootstrap)
45/// };
46/// let inner_tokens = parsed.tokens;
47/// let inner: CodeEnrichment<ExtensionSqlFile> = parse_quote! {
48///     #inner_tokens
49/// };
50/// let sql_graph_entity_tokens = inner.to_token_stream();
51/// # Ok(())
52/// # }
53/// ```
54#[derive(Debug, Clone)]
55pub struct ExtensionSqlFile {
56    pub path: LitStr,
57    pub attrs: Punctuated<ExtensionSqlAttribute, Token![,]>,
58}
59
60impl ToEntityGraphTokens for ExtensionSqlFile {
61    fn to_entity_graph_tokens(&self) -> TokenStream2 {
62        let path = &self.path;
63        let mut name = None;
64        let mut bootstrap = false;
65        let mut finalize = false;
66        let mut requires: Vec<PositioningRef> = vec![];
67        let mut creates: Vec<SqlDeclared> = vec![];
68        for attr in &self.attrs {
69            match attr {
70                ExtensionSqlAttribute::Creates(items) => {
71                    creates.extend(items.iter().cloned());
72                }
73                ExtensionSqlAttribute::Requires(items) => {
74                    requires.extend(items.iter().cloned());
75                }
76                ExtensionSqlAttribute::Bootstrap => {
77                    bootstrap = true;
78                }
79                ExtensionSqlAttribute::Finalize => {
80                    finalize = true;
81                }
82                ExtensionSqlAttribute::Name(found_name) => {
83                    name = Some(found_name.value());
84                }
85            }
86        }
87        let name = name.unwrap_or(
88            std::path::PathBuf::from(path.value())
89                .file_stem()
90                .expect("No file name for extension_sql_file!()")
91                .to_str()
92                .expect("No UTF-8 file name for extension_sql_file!()")
93                .to_string(),
94        );
95        let require_lens = requires.iter().map(PositioningRef::section_len_tokens);
96        let create_lens = creates.iter().map(SqlDeclared::section_len_tokens);
97        let require_writers =
98            requires.iter().map(|item| item.section_writer_tokens(quote! { writer }));
99        let create_writers =
100            creates.iter().map(|item| item.section_writer_tokens(quote! { writer }));
101        let require_count = requires.len();
102        let create_count = creates.len();
103        let sql_graph_entity_fn_name = format_ident!("__pgrx_schema_sql_{}", name.clone());
104        let payload_len = quote! {
105            ::pgrx::pgrx_sql_entity_graph::section::u8_len()
106                + ::pgrx::pgrx_sql_entity_graph::section::str_len(include_str!(#path))
107                + ::pgrx::pgrx_sql_entity_graph::section::str_len(module_path!())
108                + ::pgrx::pgrx_sql_entity_graph::section::str_len(concat!(file!(), ':', line!()))
109                + ::pgrx::pgrx_sql_entity_graph::section::str_len(file!())
110                + ::pgrx::pgrx_sql_entity_graph::section::u32_len()
111                + ::pgrx::pgrx_sql_entity_graph::section::str_len(#name)
112                + ::pgrx::pgrx_sql_entity_graph::section::bool_len()
113                + ::pgrx::pgrx_sql_entity_graph::section::bool_len()
114                + ::pgrx::pgrx_sql_entity_graph::section::list_len(&[
115                    #( #require_lens ),*
116                ])
117                + ::pgrx::pgrx_sql_entity_graph::section::list_len(&[
118                    #( #create_lens ),*
119                ])
120        };
121        let total_len = quote! {
122            ::pgrx::pgrx_sql_entity_graph::section::u32_len() + (#payload_len)
123        };
124        quote! {
125            ::pgrx::pgrx_sql_entity_graph::__pgrx_schema_entry!(
126                #sql_graph_entity_fn_name,
127                #total_len,
128                {
129                    let writer = ::pgrx::pgrx_sql_entity_graph::section::EntryWriter::<{ #total_len }>::new()
130                        .u32((#payload_len) as u32)
131                        .u8(::pgrx::pgrx_sql_entity_graph::section::ENTITY_CUSTOM_SQL)
132                        .str(include_str!(#path))
133                        .str(module_path!())
134                        .str(concat!(file!(), ':', line!()))
135                        .str(file!())
136                        .u32(line!())
137                        .str(#name)
138                        .bool(#bootstrap)
139                        .bool(#finalize)
140                        .u32(#require_count as u32);
141                    #( let writer = { #require_writers }; )*
142                    let writer = writer.u32(#create_count as u32);
143                    #( let writer = { #create_writers }; )*
144                    writer.finish()
145                }
146            );
147        }
148    }
149}
150
151impl ToRustCodeTokens for ExtensionSqlFile {}
152
153impl Parse for CodeEnrichment<ExtensionSqlFile> {
154    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
155        let path = input.parse()?;
156        let _after_sql_comma: Option<Token![,]> = input.parse()?;
157        let attrs = input.parse_terminated(ExtensionSqlAttribute::parse, Token![,])?;
158        Ok(Self(ExtensionSqlFile { path, attrs }))
159    }
160}
161
162/// A parsed `extension_sql!()` item.
163///
164/// It should be used with [`syn::parse::Parse`] functions.
165///
166/// Using [`quote::ToTokens`] will output the declaration for a `pgrx::pgrx_sql_entity_graph::ExtensionSqlEntity`.
167///
168/// ```rust
169/// use syn::{Macro, parse::Parse, parse_quote, parse};
170/// use quote::{quote, ToTokens};
171/// use pgrx_sql_entity_graph::ExtensionSql;
172///
173/// # fn main() -> eyre::Result<()> {
174/// use pgrx_sql_entity_graph::CodeEnrichment;
175/// let parsed: Macro = parse_quote! {
176///     extension_sql!("-- Example content", name = "example", bootstrap)
177/// };
178/// let inner_tokens = parsed.tokens;
179/// let inner: CodeEnrichment<ExtensionSql> = parse_quote! {
180///     #inner_tokens
181/// };
182/// let sql_graph_entity_tokens = inner.to_token_stream();
183/// # Ok(())
184/// # }
185/// ```
186#[derive(Debug, Clone)]
187pub struct ExtensionSql {
188    pub sql: LitStr,
189    pub name: LitStr,
190    pub attrs: Punctuated<ExtensionSqlAttribute, Token![,]>,
191}
192
193impl ToEntityGraphTokens for ExtensionSql {
194    fn to_entity_graph_tokens(&self) -> TokenStream2 {
195        let sql = &self.sql;
196        let mut bootstrap = false;
197        let mut finalize = false;
198        let mut creates: Vec<SqlDeclared> = vec![];
199        let mut requires: Vec<PositioningRef> = vec![];
200        for attr in &self.attrs {
201            match attr {
202                ExtensionSqlAttribute::Requires(items) => {
203                    requires.extend(items.iter().cloned());
204                }
205                ExtensionSqlAttribute::Creates(items) => {
206                    creates.extend(items.iter().cloned());
207                }
208                ExtensionSqlAttribute::Bootstrap => {
209                    bootstrap = true;
210                }
211                ExtensionSqlAttribute::Finalize => {
212                    finalize = true;
213                }
214                ExtensionSqlAttribute::Name(_found_name) => (), // Already done
215            }
216        }
217        let name = &self.name;
218        let require_lens = requires.iter().map(PositioningRef::section_len_tokens);
219        let create_lens = creates.iter().map(SqlDeclared::section_len_tokens);
220        let require_writers =
221            requires.iter().map(|item| item.section_writer_tokens(quote! { writer }));
222        let create_writers =
223            creates.iter().map(|item| item.section_writer_tokens(quote! { writer }));
224        let require_count = requires.len();
225        let create_count = creates.len();
226        let sql_graph_entity_fn_name = format_ident!("__pgrx_schema_sql_{}", name.value());
227        let payload_len = quote! {
228            ::pgrx::pgrx_sql_entity_graph::section::u8_len()
229                + ::pgrx::pgrx_sql_entity_graph::section::str_len(#sql)
230                + ::pgrx::pgrx_sql_entity_graph::section::str_len(module_path!())
231                + ::pgrx::pgrx_sql_entity_graph::section::str_len(concat!(file!(), ':', line!()))
232                + ::pgrx::pgrx_sql_entity_graph::section::str_len(file!())
233                + ::pgrx::pgrx_sql_entity_graph::section::u32_len()
234                + ::pgrx::pgrx_sql_entity_graph::section::str_len(#name)
235                + ::pgrx::pgrx_sql_entity_graph::section::bool_len()
236                + ::pgrx::pgrx_sql_entity_graph::section::bool_len()
237                + ::pgrx::pgrx_sql_entity_graph::section::list_len(&[
238                    #( #require_lens ),*
239                ])
240                + ::pgrx::pgrx_sql_entity_graph::section::list_len(&[
241                    #( #create_lens ),*
242                ])
243        };
244        let total_len = quote! {
245            ::pgrx::pgrx_sql_entity_graph::section::u32_len() + (#payload_len)
246        };
247        quote! {
248            ::pgrx::pgrx_sql_entity_graph::__pgrx_schema_entry!(
249                #sql_graph_entity_fn_name,
250                #total_len,
251                {
252                    let writer = ::pgrx::pgrx_sql_entity_graph::section::EntryWriter::<{ #total_len }>::new()
253                        .u32((#payload_len) as u32)
254                        .u8(::pgrx::pgrx_sql_entity_graph::section::ENTITY_CUSTOM_SQL)
255                        .str(#sql)
256                        .str(module_path!())
257                        .str(concat!(file!(), ':', line!()))
258                        .str(file!())
259                        .u32(line!())
260                        .str(#name)
261                        .bool(#bootstrap)
262                        .bool(#finalize)
263                        .u32(#require_count as u32);
264                    #( let writer = { #require_writers }; )*
265                    let writer = writer.u32(#create_count as u32);
266                    #( let writer = { #create_writers }; )*
267                    writer.finish()
268                }
269            );
270        }
271    }
272}
273
274impl ToRustCodeTokens for ExtensionSql {}
275
276impl Parse for CodeEnrichment<ExtensionSql> {
277    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
278        let sql = input.parse()?;
279        let _after_sql_comma: Option<Token![,]> = input.parse()?;
280        let attrs = input.parse_terminated(ExtensionSqlAttribute::parse, Token![,])?;
281        let name = attrs.iter().rev().find_map(|attr| match attr {
282            ExtensionSqlAttribute::Name(found_name) => Some(found_name.clone()),
283            _ => None,
284        });
285        let name =
286            name.ok_or_else(|| syn::Error::new(input.span(), "expected `name` to be set"))?;
287        Ok(Self(ExtensionSql { sql, attrs, name }))
288    }
289}
290
291impl ToTokens for ExtensionSql {
292    fn to_tokens(&self, tokens: &mut TokenStream2) {
293        tokens.append_all(self.to_entity_graph_tokens())
294    }
295}
296
297#[derive(Debug, Clone)]
298pub enum ExtensionSqlAttribute {
299    Requires(Punctuated<PositioningRef, Token![,]>),
300    Creates(Punctuated<SqlDeclared, Token![,]>),
301    Bootstrap,
302    Finalize,
303    Name(LitStr),
304}
305
306impl Parse for ExtensionSqlAttribute {
307    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
308        let ident: Ident = input.parse()?;
309        let found = match ident.to_string().as_str() {
310            "creates" => {
311                let _eq: syn::token::Eq = input.parse()?;
312                let content;
313                let _bracket = syn::bracketed!(content in input);
314                Self::Creates(content.parse_terminated(SqlDeclared::parse, Token![,])?)
315            }
316            "requires" => {
317                let _eq: syn::token::Eq = input.parse()?;
318                let content;
319                let _bracket = syn::bracketed!(content in input);
320                Self::Requires(content.parse_terminated(PositioningRef::parse, Token![,])?)
321            }
322            "bootstrap" => Self::Bootstrap,
323            "finalize" => Self::Finalize,
324            "name" => {
325                let _eq: syn::token::Eq = input.parse()?;
326                Self::Name(input.parse()?)
327            }
328            other => {
329                return Err(syn::Error::new(
330                    ident.span(),
331                    format!("Unknown extension_sql attribute: {other}"),
332                ));
333            }
334        };
335        Ok(found)
336    }
337}
338
/// An item listed in a `creates = [...]` attribute.
///
/// Each variant stores the path as written, reduced to its segment identifiers joined by `::`.
/// The parser drops a leading `::` and any generic arguments.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum SqlDeclared {
    /// `Type(path)`
    Type(String),
    /// `Enum(path)`
    Enum(String),
    /// `Function(path)`
    Function(String),
}
345
346impl ToEntityGraphTokens for SqlDeclared {
347    fn to_entity_graph_tokens(&self) -> TokenStream2 {
348        let (variant, identifier) = match &self {
349            Self::Type(val) => ("Type", val),
350            Self::Enum(val) => ("Enum", val),
351            Self::Function(val) => ("Function", val),
352        };
353        let identifier_expr = self.section_identifier_tokens();
354        match self {
355            Self::Type(_) | Self::Enum(_) => {
356                let identifier_path: syn::Path =
357                    syn::parse_str(identifier).expect("type declaration path should parse");
358                quote! {
359                    ::pgrx::pgrx_sql_entity_graph::SqlDeclaredEntity::build_type::<#identifier_path>(#variant, #identifier_expr).unwrap()
360                }
361            }
362            Self::Function(_) => quote! {
363                ::pgrx::pgrx_sql_entity_graph::SqlDeclaredEntity::build(#variant, #identifier_expr).unwrap()
364            },
365        }
366    }
367}
368
369impl ToRustCodeTokens for SqlDeclared {}
370
371impl SqlDeclared {
372    fn section_identifier_tokens(&self) -> TokenStream2 {
373        let identifier = match self {
374            Self::Type(value) | Self::Enum(value) | Self::Function(value) => value,
375        };
376        let identifier_split = identifier.split("::").collect::<Vec<_>>();
377        if identifier_split.len() == 1 {
378            let identifier_infer =
379                Ident::new(identifier_split.last().unwrap(), proc_macro2::Span::call_site());
380            quote! { concat!(module_path!(), "::", stringify!(#identifier_infer)) }
381        } else {
382            quote! { #identifier }
383        }
384    }
385
386    pub fn section_len_tokens(&self) -> TokenStream2 {
387        let identifier_expr = self.section_identifier_tokens();
388        match self {
389            Self::Type(identifier) | Self::Enum(identifier) => {
390                let identifier_path: syn::Path =
391                    syn::parse_str(identifier).expect("type declaration path should parse");
392                quote! {
393                    ::pgrx::pgrx_sql_entity_graph::section::u8_len()
394                        + ::pgrx::pgrx_sql_entity_graph::section::str_len(#identifier_expr)
395                        + ::pgrx::pgrx_sql_entity_graph::section::str_len(
396                            <#identifier_path as ::pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable>::TYPE_IDENT
397                        )
398                        + ::pgrx::pgrx_sql_entity_graph::section::argument_sql_len(
399                            <#identifier_path as ::pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable>::ARGUMENT_SQL
400                        )
401                }
402            }
403            Self::Function(_) => quote! {
404                ::pgrx::pgrx_sql_entity_graph::section::u8_len()
405                    + ::pgrx::pgrx_sql_entity_graph::section::str_len(#identifier_expr)
406            },
407        }
408    }
409
410    pub fn section_writer_tokens(&self, writer: TokenStream2) -> TokenStream2 {
411        let identifier_expr = self.section_identifier_tokens();
412        match self {
413            Self::Type(identifier) => {
414                let identifier_path: syn::Path =
415                    syn::parse_str(identifier).expect("type declaration path should parse");
416                quote! {
417                    #writer
418                        .u8(::pgrx::pgrx_sql_entity_graph::section::SQL_DECLARED_TYPE)
419                        .str(#identifier_expr)
420                        .str(<#identifier_path as ::pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable>::TYPE_IDENT)
421                        .argument_sql(<#identifier_path as ::pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable>::ARGUMENT_SQL)
422                }
423            }
424            Self::Enum(identifier) => {
425                let identifier_path: syn::Path =
426                    syn::parse_str(identifier).expect("type declaration path should parse");
427                quote! {
428                    #writer
429                        .u8(::pgrx::pgrx_sql_entity_graph::section::SQL_DECLARED_ENUM)
430                        .str(#identifier_expr)
431                        .str(<#identifier_path as ::pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable>::TYPE_IDENT)
432                        .argument_sql(<#identifier_path as ::pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable>::ARGUMENT_SQL)
433                }
434            }
435            Self::Function(_) => quote! {
436                #writer
437                    .u8(::pgrx::pgrx_sql_entity_graph::section::SQL_DECLARED_FUNCTION)
438                    .str(#identifier_expr)
439            },
440        }
441    }
442}
443
444impl Parse for SqlDeclared {
445    fn parse(input: ParseStream) -> syn::Result<Self> {
446        let variant: Ident = input.parse()?;
447        let content;
448        let _bracket: syn::token::Paren = syn::parenthesized!(content in input);
449        let identifier_path: syn::Path = content.parse()?;
450        let identifier_str = {
451            let mut identifier_segments = Vec::new();
452            for segment in identifier_path.segments {
453                identifier_segments.push(segment.ident.to_string())
454            }
455            identifier_segments.join("::")
456        };
457        let this = match variant.to_string().as_str() {
458            "Type" => Self::Type(identifier_str),
459            "Enum" => Self::Enum(identifier_str),
460            "Function" => Self::Function(identifier_str),
461            _ => {
462                return Err(syn::Error::new(
463                    variant.span(),
464                    "SQL declared entities must be `Type(ident)`, `Enum(ident)`, or `Function(ident)`",
465                ));
466            }
467        };
468        Ok(this)
469    }
470}
471
472impl ToTokens for SqlDeclared {
473    fn to_tokens(&self, tokens: &mut TokenStream2) {
474        tokens.append_all(self.to_entity_graph_tokens())
475    }
476}