pgx_sql_entity_graph/pg_extern/
mod.rs1mod argument;
19mod attribute;
20pub mod entity;
21mod operator;
22mod returning;
23mod search_path;
24
25pub use argument::PgExternArgument;
26pub use operator::PgOperator;
27pub use returning::NameMacro;
28
29use crate::ToSqlConfig;
30use attribute::Attribute;
31use operator::{PgxOperatorAttributeWithIdent, PgxOperatorOpName};
32use search_path::SearchPathList;
33
34use crate::enrich::CodeEnrichment;
35use crate::enrich::ToEntityGraphTokens;
36use crate::enrich::ToRustCodeTokens;
37use crate::lifetimes::staticize_lifetimes;
38use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
39use quote::{quote, quote_spanned, ToTokens};
40use syn::parse::{Parse, ParseStream, Parser};
41use syn::punctuated::Punctuated;
42use syn::spanned::Spanned;
43use syn::{Meta, Token};
44
45use self::returning::Returning;
46
47use super::UsedType;
48
49#[derive(Debug, Clone)]
72pub struct PgExtern {
73 attrs: Vec<Attribute>,
74 func: syn::ItemFn,
75 to_sql_config: ToSqlConfig,
76 operator: Option<PgOperator>,
77 search_path: Option<SearchPathList>,
78 inputs: Vec<PgExternArgument>,
79 input_types: Vec<syn::Type>,
80 returns: Returning,
81}
82
83impl PgExtern {
84 pub fn new(attr: TokenStream2, item: TokenStream2) -> Result<CodeEnrichment<Self>, syn::Error> {
85 let mut attrs = Vec::new();
86 let mut to_sql_config: Option<ToSqlConfig> = None;
87
88 let parser = Punctuated::<Attribute, Token![,]>::parse_terminated;
89 let punctuated_attrs = parser.parse2(attr)?;
90 for pair in punctuated_attrs.into_pairs() {
91 match pair.into_value() {
92 Attribute::Sql(config) => {
93 to_sql_config.get_or_insert(config);
94 }
95 attr => {
96 attrs.push(attr);
97 }
98 }
99 }
100
101 let mut to_sql_config = to_sql_config.unwrap_or_default();
102
103 let func = syn::parse2::<syn::ItemFn>(item)?;
104
105 if let Some(ref mut content) = to_sql_config.content {
106 let value = content.value();
107 let updated_value = value
108 .replace("@FUNCTION_NAME@", &*(func.sig.ident.to_string() + "_wrapper"))
109 + "\n";
110 *content = syn::LitStr::new(&updated_value, Span::call_site());
111 }
112
113 if !to_sql_config.overrides_default() {
114 crate::ident_is_acceptable_to_postgres(&func.sig.ident)?;
115 }
116 let operator = Self::operator(&func)?;
117 let search_path = Self::search_path(&func)?;
118 let inputs = Self::inputs(&func)?;
119 let input_types = Self::input_types(&func)?;
120 let returns = Returning::try_from(&func.sig.output)?;
121 Ok(CodeEnrichment(Self {
122 attrs,
123 func,
124 to_sql_config,
125 operator,
126 search_path,
127 inputs,
128 input_types,
129 returns,
130 }))
131 }
132
133 fn input_types(func: &syn::ItemFn) -> syn::Result<Vec<syn::Type>> {
134 func.sig
135 .inputs
136 .iter()
137 .filter_map(|v| -> Option<syn::Result<syn::Type>> {
138 match v {
139 syn::FnArg::Receiver(_) => None,
140 syn::FnArg::Typed(pat_ty) => {
141 let static_ty = pat_ty.ty.clone();
142 let mut static_ty = match UsedType::new(*static_ty) {
143 Ok(v) => v.resolved_ty,
144 Err(e) => return Some(Err(e)),
145 };
146 staticize_lifetimes(&mut static_ty);
147 Some(Ok(static_ty))
148 }
149 }
150 })
151 .collect()
152 }
153
154 fn name(&self) -> String {
155 self.attrs
156 .iter()
157 .find_map(|a| match a {
158 Attribute::Name(name) => Some(name.value()),
159 _ => None,
160 })
161 .unwrap_or_else(|| self.func.sig.ident.to_string())
162 }
163
164 fn schema(&self) -> Option<String> {
165 self.attrs.iter().find_map(|a| match a {
166 Attribute::Schema(name) => Some(name.value()),
167 _ => None,
168 })
169 }
170
171 pub fn extern_attrs(&self) -> &[Attribute] {
172 self.attrs.as_slice()
173 }
174
175 fn overridden(&self) -> Option<syn::LitStr> {
176 let mut span = None;
177 let mut retval = None;
178 let mut in_commented_sql_block = false;
179 for attr in &self.func.attrs {
180 let meta = attr.parse_meta().ok();
181 if let Some(meta) = meta {
182 if meta.path().is_ident("doc") {
183 let content = match meta {
184 Meta::Path(_) | Meta::List(_) => continue,
185 Meta::NameValue(mnv) => mnv,
186 };
187 if let syn::Lit::Str(ref inner) = content.lit {
188 span.get_or_insert(content.lit.span());
189 if !in_commented_sql_block && inner.value().trim() == "```pgxsql" {
190 in_commented_sql_block = true;
191 } else if in_commented_sql_block && inner.value().trim() == "```" {
192 in_commented_sql_block = false;
193 } else if in_commented_sql_block {
194 let sql = retval.get_or_insert_with(String::default);
195 let line = inner.value().trim_start().replace(
196 "@FUNCTION_NAME@",
197 &*(self.func.sig.ident.to_string() + "_wrapper"),
198 ) + "\n";
199 sql.push_str(&*line);
200 }
201 }
202 }
203 }
204 }
205 retval.map(|s| syn::LitStr::new(s.as_ref(), span.unwrap()))
206 }
207
208 fn operator(func: &syn::ItemFn) -> syn::Result<Option<PgOperator>> {
209 let mut skel = Option::<PgOperator>::default();
210 for attr in &func.attrs {
211 let last_segment = attr.path.segments.last().unwrap();
212 match last_segment.ident.to_string().as_str() {
213 "opname" => {
214 let attr: PgxOperatorOpName = syn::parse2(attr.tokens.clone())?;
215 skel.get_or_insert_with(Default::default).opname.get_or_insert(attr);
216 }
217 "commutator" => {
218 let attr: PgxOperatorAttributeWithIdent = syn::parse2(attr.tokens.clone())?;
219 skel.get_or_insert_with(Default::default).commutator.get_or_insert(attr);
220 }
221 "negator" => {
222 let attr: PgxOperatorAttributeWithIdent = syn::parse2(attr.tokens.clone())?;
223 skel.get_or_insert_with(Default::default).negator.get_or_insert(attr);
224 }
225 "join" => {
226 let attr: PgxOperatorAttributeWithIdent = syn::parse2(attr.tokens.clone())?;
227 skel.get_or_insert_with(Default::default).join.get_or_insert(attr);
228 }
229 "restrict" => {
230 let attr: PgxOperatorAttributeWithIdent = syn::parse2(attr.tokens.clone())?;
231 skel.get_or_insert_with(Default::default).restrict.get_or_insert(attr);
232 }
233 "hashes" => {
234 skel.get_or_insert_with(Default::default).hashes = true;
235 }
236 "merges" => {
237 skel.get_or_insert_with(Default::default).merges = true;
238 }
239 _ => (),
240 }
241 }
242 Ok(skel)
243 }
244
245 fn search_path(func: &syn::ItemFn) -> syn::Result<Option<SearchPathList>> {
246 func.attrs
247 .iter()
248 .find(|f| {
249 f.path
250 .segments
251 .first()
252 .map(|f| f.ident == Ident::new("search_path", Span::call_site()))
253 .unwrap_or_default()
254 })
255 .map(|attr| attr.parse_args::<SearchPathList>())
256 .transpose()
257 }
258
259 fn inputs(func: &syn::ItemFn) -> syn::Result<Vec<PgExternArgument>> {
260 let mut args = Vec::default();
261 for input in &func.sig.inputs {
262 let arg = PgExternArgument::build(input.clone())?;
263 args.push(arg);
264 }
265 Ok(args)
266 }
267
268 fn entity_tokens(&self) -> TokenStream2 {
269 let ident = &self.func.sig.ident;
270 let name = self.name();
271 let unsafety = &self.func.sig.unsafety;
272 let schema = self.schema();
273 let schema_iter = schema.iter();
274 let extern_attrs = self
275 .attrs
276 .iter()
277 .map(|attr| attr.to_sql_entity_graph_tokens())
278 .collect::<Punctuated<_, Token![,]>>();
279 let search_path = self.search_path.clone().into_iter();
280 let inputs = &self.inputs;
281 let inputs_iter = inputs.iter().map(|v| v.entity_tokens());
282
283 let input_types = self.input_types.iter().cloned();
284
285 let returns = &self.returns;
286
287 let return_type = match &self.func.sig.output {
288 syn::ReturnType::Default => None,
289 syn::ReturnType::Type(arrow, ty) => {
290 let mut static_ty = ty.clone();
291 staticize_lifetimes(&mut static_ty);
292 Some(syn::ReturnType::Type(*arrow, static_ty))
293 }
294 };
295
296 let operator = self.operator.clone().into_iter();
297 let to_sql_config = match self.overridden() {
298 None => self.to_sql_config.clone(),
299 Some(content) => {
300 let mut config = self.to_sql_config.clone();
301 config.content = Some(content);
302 config
303 }
304 };
305
306 let sql_graph_entity_fn_name =
307 syn::Ident::new(&format!("__pgx_internals_fn_{}", ident), Span::call_site());
308 quote_spanned! { self.func.sig.span() =>
309 #[no_mangle]
310 #[doc(hidden)]
311 pub extern "Rust" fn #sql_graph_entity_fn_name() -> ::pgx::pgx_sql_entity_graph::SqlGraphEntity {
312 extern crate alloc;
313 #[allow(unused_imports)]
314 use alloc::{vec, vec::Vec};
315 type FunctionPointer = #unsafety fn(#( #input_types ),*) #return_type;
316 let metadata: FunctionPointer = #ident;
317 let submission = ::pgx::pgx_sql_entity_graph::PgExternEntity {
318 name: #name,
319 unaliased_name: stringify!(#ident),
320 module_path: core::module_path!(),
321 full_path: concat!(core::module_path!(), "::", stringify!(#ident)),
322 metadata: ::pgx::pgx_sql_entity_graph::metadata::FunctionMetadata::entity(&metadata),
323 fn_args: vec![#(#inputs_iter),*],
324 fn_return: #returns,
325 #[allow(clippy::or_fun_call)]
326 schema: None #( .unwrap_or_else(|| Some(#schema_iter)) )*,
327 file: file!(),
328 line: line!(),
329 extern_attrs: vec![#extern_attrs],
330 #[allow(clippy::or_fun_call)]
331 search_path: None #( .unwrap_or_else(|| Some(vec![#search_path])) )*,
332 #[allow(clippy::or_fun_call)]
333 operator: None #( .unwrap_or_else(|| Some(#operator)) )*,
334 to_sql_config: #to_sql_config,
335 };
336 ::pgx::pgx_sql_entity_graph::SqlGraphEntity::Function(submission)
337 }
338 }
339 }
340
341 fn finfo_tokens(&self) -> TokenStream2 {
342 let finfo_name = syn::Ident::new(
343 &format!("pg_finfo_{}_wrapper", self.func.sig.ident),
344 Span::call_site(),
345 );
346 quote_spanned! { self.func.sig.span() =>
347 #[no_mangle]
348 #[doc(hidden)]
349 pub extern "C" fn #finfo_name() -> &'static ::pgx::pg_sys::Pg_finfo_record {
350 const V1_API: ::pgx::pg_sys::Pg_finfo_record = ::pgx::pg_sys::Pg_finfo_record { api_version: 1 };
351 &V1_API
352 }
353 }
354 }
355
356 pub fn wrapper_func(&self) -> TokenStream2 {
357 let func_name = &self.func.sig.ident;
358 let func_name_wrapper = Ident::new(
359 &format!("{}_wrapper", &self.func.sig.ident.to_string()),
360 self.func.sig.ident.span(),
361 );
362 let func_generics = &self.func.sig.generics;
363 let is_raw = self.extern_attrs().contains(&Attribute::Raw);
364 let fcinfo_ident = syn::Ident::new("_fcinfo", self.func.sig.ident.span());
366
367 let args = &self.inputs;
368 let arg_pats = args
369 .iter()
370 .map(|v| syn::Ident::new(&format!("{}_", &v.pat), self.func.sig.span()))
371 .collect::<Vec<_>>();
372 let arg_fetches = args.iter().enumerate().map(|(idx, arg)| {
373 let pat = &arg_pats[idx];
374 let resolved_ty = &arg.used_ty.resolved_ty;
375 if arg.used_ty.resolved_ty.to_token_stream().to_string() == quote!(pgx::pg_sys::FunctionCallInfo).to_token_stream().to_string()
376 || arg.used_ty.resolved_ty.to_token_stream().to_string() == quote!(pg_sys::FunctionCallInfo).to_token_stream().to_string()
377 || arg.used_ty.resolved_ty.to_token_stream().to_string() == quote!(::pgx::pg_sys::FunctionCallInfo).to_token_stream().to_string()
378 {
379 quote_spanned! {pat.span()=>
380 let #pat = #fcinfo_ident;
381 }
382 } else if arg.used_ty.resolved_ty.to_token_stream().to_string() == quote!(()).to_token_stream().to_string() {
383 quote_spanned! {pat.span()=>
384 debug_assert!(unsafe { ::pgx::fcinfo::pg_getarg::<()>(#fcinfo_ident, #idx).is_none() }, "A `()` argument should always receive `NULL`");
385 let #pat = ();
386 }
387 } else {
388 match (is_raw, &arg.used_ty.optional) {
389 (true, None) | (true, Some(_)) => quote_spanned! { pat.span() =>
390 let #pat = unsafe { ::pgx::fcinfo::pg_getarg_datum_raw(#fcinfo_ident, #idx) as #resolved_ty };
391 },
392 (false, None) => quote_spanned! { pat.span() =>
393 let #pat = unsafe { ::pgx::fcinfo::pg_getarg::<#resolved_ty>(#fcinfo_ident, #idx).unwrap_or_else(|| panic!("{} is null", stringify!{#pat})) };
394 },
395 (false, Some(inner)) => quote_spanned! { pat.span() =>
396 let #pat = unsafe { ::pgx::fcinfo::pg_getarg::<#inner>(#fcinfo_ident, #idx) };
397 },
398 }
399 }
400 });
401
402 match &self.returns {
403 Returning::None => quote_spanned! { self.func.sig.span() =>
404 #[no_mangle]
405 #[doc(hidden)]
406 #[::pgx::pgx_macros::pg_guard]
407 pub unsafe extern "C" fn #func_name_wrapper #func_generics(#fcinfo_ident: ::pgx::pg_sys::FunctionCallInfo) {
408 #(
409 #arg_fetches
410 )*
411
412 #[allow(unused_unsafe)] unsafe { #func_name(#(#arg_pats),*) }
414 }
415 },
416 Returning::Type(retval_ty) => {
417 let result_ident = syn::Ident::new("result", self.func.sig.span());
418 let retval_transform = if retval_ty.resolved_ty == syn::parse_quote!(()) {
419 quote_spanned! { self.func.sig.output.span() =>
420 unsafe { ::pgx::fcinfo::pg_return_void() }
421 }
422 } else if retval_ty.result {
423 if retval_ty.optional.is_some() {
424 quote_spanned! {
426 self.func.sig.output.span() =>
427 match ::pgx::datum::IntoDatum::into_datum(#result_ident) {
428 Some(datum) => datum,
429 None => unsafe { ::pgx::fcinfo::pg_return_null(#fcinfo_ident) },
430 }
431 }
432 } else {
433 quote_spanned! {
435 self.func.sig.output.span() =>
436 ::pgx::datum::IntoDatum::into_datum(#result_ident).unwrap_or_else(|| panic!("returned Datum was NULL"))
437 }
438 }
439 } else if retval_ty.resolved_ty == syn::parse_quote!(pg_sys::Datum)
440 || retval_ty.resolved_ty == syn::parse_quote!(pgx::pg_sys::Datum)
441 || retval_ty.resolved_ty == syn::parse_quote!(::pgx::pg_sys::Datum)
442 {
443 quote_spanned! { self.func.sig.output.span() =>
444 #result_ident
445 }
446 } else if retval_ty.optional.is_some() {
447 quote_spanned! { self.func.sig.output.span() =>
448 match #result_ident {
449 Some(result) => {
450 ::pgx::datum::IntoDatum::into_datum(result).unwrap_or_else(|| panic!("returned Option<T> was NULL"))
451 },
452 None => unsafe { ::pgx::fcinfo::pg_return_null(#fcinfo_ident) }
453 }
454 }
455 } else {
456 quote_spanned! { self.func.sig.output.span() =>
457 ::pgx::datum::IntoDatum::into_datum(#result_ident).unwrap_or_else(|| panic!("returned Datum was NULL"))
458 }
459 };
460
461 quote_spanned! { self.func.sig.span() =>
462 #[no_mangle]
463 #[doc(hidden)]
464 #[::pgx::pgx_macros::pg_guard]
465 pub unsafe extern "C" fn #func_name_wrapper #func_generics(#fcinfo_ident: ::pgx::pg_sys::FunctionCallInfo) -> ::pgx::pg_sys::Datum {
466 #(
467 #arg_fetches
468 )*
469
470 #[allow(unused_unsafe)] let #result_ident = unsafe { #func_name(#(#arg_pats),*) };
472
473 #retval_transform
474 }
475 }
476 }
477 Returning::SetOf { ty: _retval_ty, optional, result } => {
478 let result_handler = if *optional && !*result {
479 quote_spanned! { self.func.sig.span() =>
481 #func_name(#(#arg_pats),*)
482 }
483 } else if *result {
484 if *optional {
485 quote_spanned! { self.func.sig.span() =>
486 use ::pgx::pg_sys::panic::ErrorReportable;
487 #func_name(#(#arg_pats),*).report()
488 }
489 } else {
490 quote_spanned! { self.func.sig.span() =>
491 use ::pgx::pg_sys::panic::ErrorReportable;
492 Some(#func_name(#(#arg_pats),*).report())
493 }
494 }
495 } else {
496 quote_spanned! { self.func.sig.span() =>
497 Some(#func_name(#(#arg_pats),*))
498 }
499 };
500
501 quote_spanned! { self.func.sig.span() =>
502 #[no_mangle]
503 #[doc(hidden)]
504 #[::pgx::pgx_macros::pg_guard]
505 pub unsafe extern "C" fn #func_name_wrapper #func_generics(#fcinfo_ident: ::pgx::pg_sys::FunctionCallInfo) -> ::pgx::pg_sys::Datum {
506 #[allow(unused_unsafe)]
507 unsafe {
508 ::pgx::iter::SetOfIterator::srf_next(#fcinfo_ident, || {
512 #( #arg_fetches )*
513 #result_handler
514 })
515 }
516 }
517 }
518 }
519 Returning::Iterated { tys: _retval_tys, optional, result } => {
520 let result_handler = if *optional {
521 quote_spanned! { self.func.sig.span() =>
523 #func_name(#(#arg_pats),*)
524 }
525 } else if *result {
526 quote_spanned! { self.func.sig.span() =>
527 {
528 use ::pgx::pg_sys::panic::ErrorReportable;
529 Some(#func_name(#(#arg_pats),*).report())
530 }
531 }
532 } else {
533 quote_spanned! { self.func.sig.span() =>
534 Some(#func_name(#(#arg_pats),*))
535 }
536 };
537
538 quote_spanned! { self.func.sig.span() =>
539 #[no_mangle]
540 #[doc(hidden)]
541 #[::pgx::pgx_macros::pg_guard]
542 pub unsafe extern "C" fn #func_name_wrapper #func_generics(#fcinfo_ident: ::pgx::pg_sys::FunctionCallInfo) -> ::pgx::pg_sys::Datum {
543 #[allow(unused_unsafe)]
544 unsafe {
545 ::pgx::iter::TableIterator::srf_next(#fcinfo_ident, || {
549 #( #arg_fetches )*
550 #result_handler
551 })
552 }
553 }
554 }
555 }
556 }
557 }
558}
559
560impl ToEntityGraphTokens for PgExtern {
561 fn to_entity_graph_tokens(&self) -> TokenStream2 {
562 self.entity_tokens()
563 }
564}
565
566impl ToRustCodeTokens for PgExtern {
567 fn to_rust_code_tokens(&self) -> TokenStream2 {
568 let original_func = &self.func;
569 let wrapper_func = self.wrapper_func();
570 let finfo_tokens = self.finfo_tokens();
571
572 quote_spanned! { self.func.sig.span() =>
573 #original_func
574 #wrapper_func
575 #finfo_tokens
576 }
577 }
578}
579
580impl Parse for CodeEnrichment<PgExtern> {
581 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
582 let mut attrs = Vec::new();
583
584 let parser = Punctuated::<Attribute, Token![,]>::parse_terminated;
585 let punctuated_attrs = input.call(parser).ok().unwrap_or_default();
586 for pair in punctuated_attrs.into_pairs() {
587 attrs.push(pair.into_value())
588 }
589 PgExtern::new(quote! {#(#attrs)*}, input.parse()?)
590 }
591}