// pgx_sql_entity_graph/pg_extern/mod.rs

/*
Portions Copyright 2019-2021 ZomboDB, LLC.
Portions Copyright 2021-2022 Technology Concepts & Design, Inc. <support@tcdi.com>

All rights reserved.

Use of this source code is governed by the MIT license that can be found in the LICENSE file.
*/

/*!

`#[pg_extern]`-related macro expansion for Rust-to-SQL translation.

> Like all of the [`sql_entity_graph`][crate::pgx_sql_entity_graph] APIs, this is considered **internal**
to the `pgx` framework and is highly subject to change between versions. While you may use this, please do so with caution.

*/
mod argument;
mod attribute;
pub mod entity;
mod operator;
mod returning;
mod search_path;

pub use argument::PgExternArgument;
pub use operator::PgOperator;
pub use returning::NameMacro;

use crate::ToSqlConfig;
use attribute::Attribute;
use operator::{PgxOperatorAttributeWithIdent, PgxOperatorOpName};
use search_path::SearchPathList;

use crate::enrich::CodeEnrichment;
use crate::enrich::ToEntityGraphTokens;
use crate::enrich::ToRustCodeTokens;
use crate::lifetimes::staticize_lifetimes;
use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
use quote::{quote, quote_spanned, ToTokens};
use syn::parse::{Parse, ParseStream, Parser};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{Meta, Token};

use self::returning::Returning;

use super::UsedType;

49/// A parsed `#[pg_extern]` item.
50///
51/// It should be used with [`syn::parse::Parse`] functions.
52///
53/// Using [`quote::ToTokens`] will output the declaration for a [`PgExternEntity`][crate::PgExternEntity].
54///
55/// ```rust
56/// use syn::{Macro, parse::Parse, parse_quote, parse};
57/// use quote::{quote, ToTokens};
58/// use pgx_sql_entity_graph::PgExtern;
59///
60/// # fn main() -> eyre::Result<()> {
61/// use pgx_sql_entity_graph::CodeEnrichment;
62/// let parsed: CodeEnrichment<PgExtern> = parse_quote! {
63///     fn example(x: Option<str>) -> Option<&'a str> {
64///         unimplemented!()
65///     }
66/// };
67/// let sql_graph_entity_tokens = parsed.to_token_stream();
68/// # Ok(())
69/// # }
70/// ```
71#[derive(Debug, Clone)]
72pub struct PgExtern {
73    attrs: Vec<Attribute>,
74    func: syn::ItemFn,
75    to_sql_config: ToSqlConfig,
76    operator: Option<PgOperator>,
77    search_path: Option<SearchPathList>,
78    inputs: Vec<PgExternArgument>,
79    input_types: Vec<syn::Type>,
80    returns: Returning,
81}
82
83impl PgExtern {
84    pub fn new(attr: TokenStream2, item: TokenStream2) -> Result<CodeEnrichment<Self>, syn::Error> {
85        let mut attrs = Vec::new();
86        let mut to_sql_config: Option<ToSqlConfig> = None;
87
88        let parser = Punctuated::<Attribute, Token![,]>::parse_terminated;
89        let punctuated_attrs = parser.parse2(attr)?;
90        for pair in punctuated_attrs.into_pairs() {
91            match pair.into_value() {
92                Attribute::Sql(config) => {
93                    to_sql_config.get_or_insert(config);
94                }
95                attr => {
96                    attrs.push(attr);
97                }
98            }
99        }
100
101        let mut to_sql_config = to_sql_config.unwrap_or_default();
102
103        let func = syn::parse2::<syn::ItemFn>(item)?;
104
105        if let Some(ref mut content) = to_sql_config.content {
106            let value = content.value();
107            let updated_value = value
108                .replace("@FUNCTION_NAME@", &*(func.sig.ident.to_string() + "_wrapper"))
109                + "\n";
110            *content = syn::LitStr::new(&updated_value, Span::call_site());
111        }
112
113        if !to_sql_config.overrides_default() {
114            crate::ident_is_acceptable_to_postgres(&func.sig.ident)?;
115        }
116        let operator = Self::operator(&func)?;
117        let search_path = Self::search_path(&func)?;
118        let inputs = Self::inputs(&func)?;
119        let input_types = Self::input_types(&func)?;
120        let returns = Returning::try_from(&func.sig.output)?;
121        Ok(CodeEnrichment(Self {
122            attrs,
123            func,
124            to_sql_config,
125            operator,
126            search_path,
127            inputs,
128            input_types,
129            returns,
130        }))
131    }
132
133    fn input_types(func: &syn::ItemFn) -> syn::Result<Vec<syn::Type>> {
134        func.sig
135            .inputs
136            .iter()
137            .filter_map(|v| -> Option<syn::Result<syn::Type>> {
138                match v {
139                    syn::FnArg::Receiver(_) => None,
140                    syn::FnArg::Typed(pat_ty) => {
141                        let static_ty = pat_ty.ty.clone();
142                        let mut static_ty = match UsedType::new(*static_ty) {
143                            Ok(v) => v.resolved_ty,
144                            Err(e) => return Some(Err(e)),
145                        };
146                        staticize_lifetimes(&mut static_ty);
147                        Some(Ok(static_ty))
148                    }
149                }
150            })
151            .collect()
152    }
153
154    fn name(&self) -> String {
155        self.attrs
156            .iter()
157            .find_map(|a| match a {
158                Attribute::Name(name) => Some(name.value()),
159                _ => None,
160            })
161            .unwrap_or_else(|| self.func.sig.ident.to_string())
162    }
163
164    fn schema(&self) -> Option<String> {
165        self.attrs.iter().find_map(|a| match a {
166            Attribute::Schema(name) => Some(name.value()),
167            _ => None,
168        })
169    }
170
171    pub fn extern_attrs(&self) -> &[Attribute] {
172        self.attrs.as_slice()
173    }
174
175    fn overridden(&self) -> Option<syn::LitStr> {
176        let mut span = None;
177        let mut retval = None;
178        let mut in_commented_sql_block = false;
179        for attr in &self.func.attrs {
180            let meta = attr.parse_meta().ok();
181            if let Some(meta) = meta {
182                if meta.path().is_ident("doc") {
183                    let content = match meta {
184                        Meta::Path(_) | Meta::List(_) => continue,
185                        Meta::NameValue(mnv) => mnv,
186                    };
187                    if let syn::Lit::Str(ref inner) = content.lit {
188                        span.get_or_insert(content.lit.span());
189                        if !in_commented_sql_block && inner.value().trim() == "```pgxsql" {
190                            in_commented_sql_block = true;
191                        } else if in_commented_sql_block && inner.value().trim() == "```" {
192                            in_commented_sql_block = false;
193                        } else if in_commented_sql_block {
194                            let sql = retval.get_or_insert_with(String::default);
195                            let line = inner.value().trim_start().replace(
196                                "@FUNCTION_NAME@",
197                                &*(self.func.sig.ident.to_string() + "_wrapper"),
198                            ) + "\n";
199                            sql.push_str(&*line);
200                        }
201                    }
202                }
203            }
204        }
205        retval.map(|s| syn::LitStr::new(s.as_ref(), span.unwrap()))
206    }
207
208    fn operator(func: &syn::ItemFn) -> syn::Result<Option<PgOperator>> {
209        let mut skel = Option::<PgOperator>::default();
210        for attr in &func.attrs {
211            let last_segment = attr.path.segments.last().unwrap();
212            match last_segment.ident.to_string().as_str() {
213                "opname" => {
214                    let attr: PgxOperatorOpName = syn::parse2(attr.tokens.clone())?;
215                    skel.get_or_insert_with(Default::default).opname.get_or_insert(attr);
216                }
217                "commutator" => {
218                    let attr: PgxOperatorAttributeWithIdent = syn::parse2(attr.tokens.clone())?;
219                    skel.get_or_insert_with(Default::default).commutator.get_or_insert(attr);
220                }
221                "negator" => {
222                    let attr: PgxOperatorAttributeWithIdent = syn::parse2(attr.tokens.clone())?;
223                    skel.get_or_insert_with(Default::default).negator.get_or_insert(attr);
224                }
225                "join" => {
226                    let attr: PgxOperatorAttributeWithIdent = syn::parse2(attr.tokens.clone())?;
227                    skel.get_or_insert_with(Default::default).join.get_or_insert(attr);
228                }
229                "restrict" => {
230                    let attr: PgxOperatorAttributeWithIdent = syn::parse2(attr.tokens.clone())?;
231                    skel.get_or_insert_with(Default::default).restrict.get_or_insert(attr);
232                }
233                "hashes" => {
234                    skel.get_or_insert_with(Default::default).hashes = true;
235                }
236                "merges" => {
237                    skel.get_or_insert_with(Default::default).merges = true;
238                }
239                _ => (),
240            }
241        }
242        Ok(skel)
243    }
244
245    fn search_path(func: &syn::ItemFn) -> syn::Result<Option<SearchPathList>> {
246        func.attrs
247            .iter()
248            .find(|f| {
249                f.path
250                    .segments
251                    .first()
252                    .map(|f| f.ident == Ident::new("search_path", Span::call_site()))
253                    .unwrap_or_default()
254            })
255            .map(|attr| attr.parse_args::<SearchPathList>())
256            .transpose()
257    }
258
259    fn inputs(func: &syn::ItemFn) -> syn::Result<Vec<PgExternArgument>> {
260        let mut args = Vec::default();
261        for input in &func.sig.inputs {
262            let arg = PgExternArgument::build(input.clone())?;
263            args.push(arg);
264        }
265        Ok(args)
266    }
267
268    fn entity_tokens(&self) -> TokenStream2 {
269        let ident = &self.func.sig.ident;
270        let name = self.name();
271        let unsafety = &self.func.sig.unsafety;
272        let schema = self.schema();
273        let schema_iter = schema.iter();
274        let extern_attrs = self
275            .attrs
276            .iter()
277            .map(|attr| attr.to_sql_entity_graph_tokens())
278            .collect::<Punctuated<_, Token![,]>>();
279        let search_path = self.search_path.clone().into_iter();
280        let inputs = &self.inputs;
281        let inputs_iter = inputs.iter().map(|v| v.entity_tokens());
282
283        let input_types = self.input_types.iter().cloned();
284
285        let returns = &self.returns;
286
287        let return_type = match &self.func.sig.output {
288            syn::ReturnType::Default => None,
289            syn::ReturnType::Type(arrow, ty) => {
290                let mut static_ty = ty.clone();
291                staticize_lifetimes(&mut static_ty);
292                Some(syn::ReturnType::Type(*arrow, static_ty))
293            }
294        };
295
296        let operator = self.operator.clone().into_iter();
297        let to_sql_config = match self.overridden() {
298            None => self.to_sql_config.clone(),
299            Some(content) => {
300                let mut config = self.to_sql_config.clone();
301                config.content = Some(content);
302                config
303            }
304        };
305
306        let sql_graph_entity_fn_name =
307            syn::Ident::new(&format!("__pgx_internals_fn_{}", ident), Span::call_site());
308        quote_spanned! { self.func.sig.span() =>
309            #[no_mangle]
310            #[doc(hidden)]
311            pub extern "Rust" fn  #sql_graph_entity_fn_name() -> ::pgx::pgx_sql_entity_graph::SqlGraphEntity {
312                extern crate alloc;
313                #[allow(unused_imports)]
314                use alloc::{vec, vec::Vec};
315                type FunctionPointer = #unsafety fn(#( #input_types ),*) #return_type;
316                let metadata: FunctionPointer = #ident;
317                let submission = ::pgx::pgx_sql_entity_graph::PgExternEntity {
318                    name: #name,
319                    unaliased_name: stringify!(#ident),
320                    module_path: core::module_path!(),
321                    full_path: concat!(core::module_path!(), "::", stringify!(#ident)),
322                    metadata: ::pgx::pgx_sql_entity_graph::metadata::FunctionMetadata::entity(&metadata),
323                    fn_args: vec![#(#inputs_iter),*],
324                    fn_return: #returns,
325                    #[allow(clippy::or_fun_call)]
326                    schema: None #( .unwrap_or_else(|| Some(#schema_iter)) )*,
327                    file: file!(),
328                    line: line!(),
329                    extern_attrs: vec![#extern_attrs],
330                    #[allow(clippy::or_fun_call)]
331                    search_path: None #( .unwrap_or_else(|| Some(vec![#search_path])) )*,
332                    #[allow(clippy::or_fun_call)]
333                    operator: None #( .unwrap_or_else(|| Some(#operator)) )*,
334                    to_sql_config: #to_sql_config,
335                };
336                ::pgx::pgx_sql_entity_graph::SqlGraphEntity::Function(submission)
337            }
338        }
339    }
340
341    fn finfo_tokens(&self) -> TokenStream2 {
342        let finfo_name = syn::Ident::new(
343            &format!("pg_finfo_{}_wrapper", self.func.sig.ident),
344            Span::call_site(),
345        );
346        quote_spanned! { self.func.sig.span() =>
347            #[no_mangle]
348            #[doc(hidden)]
349            pub extern "C" fn #finfo_name() -> &'static ::pgx::pg_sys::Pg_finfo_record {
350                const V1_API: ::pgx::pg_sys::Pg_finfo_record = ::pgx::pg_sys::Pg_finfo_record { api_version: 1 };
351                &V1_API
352            }
353        }
354    }
355
356    pub fn wrapper_func(&self) -> TokenStream2 {
357        let func_name = &self.func.sig.ident;
358        let func_name_wrapper = Ident::new(
359            &format!("{}_wrapper", &self.func.sig.ident.to_string()),
360            self.func.sig.ident.span(),
361        );
362        let func_generics = &self.func.sig.generics;
363        let is_raw = self.extern_attrs().contains(&Attribute::Raw);
364        // We use a `_` prefix to make functions with no args more satisfied during linting.
365        let fcinfo_ident = syn::Ident::new("_fcinfo", self.func.sig.ident.span());
366
367        let args = &self.inputs;
368        let arg_pats = args
369            .iter()
370            .map(|v| syn::Ident::new(&format!("{}_", &v.pat), self.func.sig.span()))
371            .collect::<Vec<_>>();
372        let arg_fetches = args.iter().enumerate().map(|(idx, arg)| {
373            let pat = &arg_pats[idx];
374            let resolved_ty = &arg.used_ty.resolved_ty;
375            if arg.used_ty.resolved_ty.to_token_stream().to_string() == quote!(pgx::pg_sys::FunctionCallInfo).to_token_stream().to_string()
376                || arg.used_ty.resolved_ty.to_token_stream().to_string() == quote!(pg_sys::FunctionCallInfo).to_token_stream().to_string()
377                || arg.used_ty.resolved_ty.to_token_stream().to_string() == quote!(::pgx::pg_sys::FunctionCallInfo).to_token_stream().to_string()
378            {
379                quote_spanned! {pat.span()=>
380                    let #pat = #fcinfo_ident;
381                }
382            } else if arg.used_ty.resolved_ty.to_token_stream().to_string() == quote!(()).to_token_stream().to_string() {
383                quote_spanned! {pat.span()=>
384                    debug_assert!(unsafe { ::pgx::fcinfo::pg_getarg::<()>(#fcinfo_ident, #idx).is_none() }, "A `()` argument should always receive `NULL`");
385                    let #pat = ();
386                }
387            } else {
388                match (is_raw, &arg.used_ty.optional) {
389                    (true, None) | (true, Some(_)) => quote_spanned! { pat.span() =>
390                        let #pat = unsafe { ::pgx::fcinfo::pg_getarg_datum_raw(#fcinfo_ident, #idx) as #resolved_ty };
391                    },
392                    (false, None) => quote_spanned! { pat.span() =>
393                        let #pat = unsafe { ::pgx::fcinfo::pg_getarg::<#resolved_ty>(#fcinfo_ident, #idx).unwrap_or_else(|| panic!("{} is null", stringify!{#pat})) };
394                    },
395                    (false, Some(inner)) => quote_spanned! { pat.span() =>
396                        let #pat = unsafe { ::pgx::fcinfo::pg_getarg::<#inner>(#fcinfo_ident, #idx) };
397                    },
398                }
399            }
400        });
401
402        match &self.returns {
403            Returning::None => quote_spanned! { self.func.sig.span() =>
404                  #[no_mangle]
405                  #[doc(hidden)]
406                  #[::pgx::pgx_macros::pg_guard]
407                  pub unsafe extern "C" fn #func_name_wrapper #func_generics(#fcinfo_ident: ::pgx::pg_sys::FunctionCallInfo) {
408                      #(
409                          #arg_fetches
410                      )*
411
412                    #[allow(unused_unsafe)] // unwrapped fn might be unsafe
413                    unsafe { #func_name(#(#arg_pats),*) }
414                }
415            },
416            Returning::Type(retval_ty) => {
417                let result_ident = syn::Ident::new("result", self.func.sig.span());
418                let retval_transform = if retval_ty.resolved_ty == syn::parse_quote!(()) {
419                    quote_spanned! { self.func.sig.output.span() =>
420                       unsafe { ::pgx::fcinfo::pg_return_void() }
421                    }
422                } else if retval_ty.result {
423                    if retval_ty.optional.is_some() {
424                        // returning `Result<Option<T>>`
425                        quote_spanned! {
426                            self.func.sig.output.span() =>
427                                match ::pgx::datum::IntoDatum::into_datum(#result_ident) {
428                                    Some(datum) => datum,
429                                    None => unsafe { ::pgx::fcinfo::pg_return_null(#fcinfo_ident) },
430                                }
431                        }
432                    } else {
433                        // returning Result<T>
434                        quote_spanned! {
435                            self.func.sig.output.span() =>
436                                ::pgx::datum::IntoDatum::into_datum(#result_ident).unwrap_or_else(|| panic!("returned Datum was NULL"))
437                        }
438                    }
439                } else if retval_ty.resolved_ty == syn::parse_quote!(pg_sys::Datum)
440                    || retval_ty.resolved_ty == syn::parse_quote!(pgx::pg_sys::Datum)
441                    || retval_ty.resolved_ty == syn::parse_quote!(::pgx::pg_sys::Datum)
442                {
443                    quote_spanned! { self.func.sig.output.span() =>
444                       #result_ident
445                    }
446                } else if retval_ty.optional.is_some() {
447                    quote_spanned! { self.func.sig.output.span() =>
448                        match #result_ident {
449                            Some(result) => {
450                                ::pgx::datum::IntoDatum::into_datum(result).unwrap_or_else(|| panic!("returned Option<T> was NULL"))
451                            },
452                            None => unsafe { ::pgx::fcinfo::pg_return_null(#fcinfo_ident) }
453                        }
454                    }
455                } else {
456                    quote_spanned! { self.func.sig.output.span() =>
457                        ::pgx::datum::IntoDatum::into_datum(#result_ident).unwrap_or_else(|| panic!("returned Datum was NULL"))
458                    }
459                };
460
461                quote_spanned! { self.func.sig.span() =>
462                    #[no_mangle]
463                    #[doc(hidden)]
464                    #[::pgx::pgx_macros::pg_guard]
465                    pub unsafe extern "C" fn #func_name_wrapper #func_generics(#fcinfo_ident: ::pgx::pg_sys::FunctionCallInfo) -> ::pgx::pg_sys::Datum {
466                        #(
467                            #arg_fetches
468                        )*
469
470                        #[allow(unused_unsafe)] // unwrapped fn might be unsafe
471                        let #result_ident = unsafe { #func_name(#(#arg_pats),*) };
472
473                        #retval_transform
474                    }
475                }
476            }
477            Returning::SetOf { ty: _retval_ty, optional, result } => {
478                let result_handler = if *optional && !*result {
479                    // don't need unsafe annotations because of the larger unsafe block coming up
480                    quote_spanned! { self.func.sig.span() =>
481                        #func_name(#(#arg_pats),*)
482                    }
483                } else if *result {
484                    if *optional {
485                        quote_spanned! { self.func.sig.span() =>
486                            use ::pgx::pg_sys::panic::ErrorReportable;
487                            #func_name(#(#arg_pats),*).report()
488                        }
489                    } else {
490                        quote_spanned! { self.func.sig.span() =>
491                            use ::pgx::pg_sys::panic::ErrorReportable;
492                            Some(#func_name(#(#arg_pats),*).report())
493                        }
494                    }
495                } else {
496                    quote_spanned! { self.func.sig.span() =>
497                        Some(#func_name(#(#arg_pats),*))
498                    }
499                };
500
501                quote_spanned! { self.func.sig.span() =>
502                    #[no_mangle]
503                    #[doc(hidden)]
504                    #[::pgx::pgx_macros::pg_guard]
505                    pub unsafe extern "C" fn #func_name_wrapper #func_generics(#fcinfo_ident: ::pgx::pg_sys::FunctionCallInfo) -> ::pgx::pg_sys::Datum {
506                        #[allow(unused_unsafe)]
507                        unsafe {
508                            // SAFETY: the caller has asserted that `fcinfo` is a valid FunctionCallInfo pointer, allocated by Postgres
509                            // with all its fields properly setup.  Unless the user is calling this wrapper function directly, this
510                            // will always be the case
511                            ::pgx::iter::SetOfIterator::srf_next(#fcinfo_ident, || {
512                                #( #arg_fetches )*
513                                #result_handler
514                            })
515                        }
516                    }
517                }
518            }
519            Returning::Iterated { tys: _retval_tys, optional, result } => {
520                let result_handler = if *optional {
521                    // don't need unsafe annotations because of the larger unsafe block coming up
522                    quote_spanned! { self.func.sig.span() =>
523                        #func_name(#(#arg_pats),*)
524                    }
525                } else if *result {
526                    quote_spanned! { self.func.sig.span() =>
527                        {
528                            use ::pgx::pg_sys::panic::ErrorReportable;
529                            Some(#func_name(#(#arg_pats),*).report())
530                        }
531                    }
532                } else {
533                    quote_spanned! { self.func.sig.span() =>
534                        Some(#func_name(#(#arg_pats),*))
535                    }
536                };
537
538                quote_spanned! { self.func.sig.span() =>
539                    #[no_mangle]
540                    #[doc(hidden)]
541                    #[::pgx::pgx_macros::pg_guard]
542                    pub unsafe extern "C" fn #func_name_wrapper #func_generics(#fcinfo_ident: ::pgx::pg_sys::FunctionCallInfo) -> ::pgx::pg_sys::Datum {
543                        #[allow(unused_unsafe)]
544                        unsafe {
545                            // SAFETY: the caller has asserted that `fcinfo` is a valid FunctionCallInfo pointer, allocated by Postgres
546                            // with all its fields properly setup.  Unless the user is calling this wrapper function directly, this
547                            // will always be the case
548                            ::pgx::iter::TableIterator::srf_next(#fcinfo_ident, || {
549                                #( #arg_fetches )*
550                                #result_handler
551                            })
552                        }
553                    }
554                }
555            }
556        }
557    }
558}
559
560impl ToEntityGraphTokens for PgExtern {
561    fn to_entity_graph_tokens(&self) -> TokenStream2 {
562        self.entity_tokens()
563    }
564}
565
566impl ToRustCodeTokens for PgExtern {
567    fn to_rust_code_tokens(&self) -> TokenStream2 {
568        let original_func = &self.func;
569        let wrapper_func = self.wrapper_func();
570        let finfo_tokens = self.finfo_tokens();
571
572        quote_spanned! { self.func.sig.span() =>
573            #original_func
574            #wrapper_func
575            #finfo_tokens
576        }
577    }
578}
579
580impl Parse for CodeEnrichment<PgExtern> {
581    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
582        let mut attrs = Vec::new();
583
584        let parser = Punctuated::<Attribute, Token![,]>::parse_terminated;
585        let punctuated_attrs = input.call(parser).ok().unwrap_or_default();
586        for pair in punctuated_attrs.into_pairs() {
587            attrs.push(pair.into_value())
588        }
589        PgExtern::new(quote! {#(#attrs)*}, input.parse()?)
590    }
591}