pgx_utils/
lib.rs

1/*
2Portions Copyright 2019-2021 ZomboDB, LLC.
3Portions Copyright 2021-2022 Technology Concepts & Design, Inc. <support@tcdi.com>
4
5All rights reserved.
6
7Use of this source code is governed by the MIT license that can be found in the LICENSE file.
8*/
9
10use crate::sql_entity_graph::{NameMacro, PositioningRef};
11use proc_macro2::{TokenStream, TokenTree};
12use quote::{format_ident, quote, ToTokens, TokenStreamExt};
13use std::collections::HashSet;
14use syn::{GenericArgument, PathArguments, Type, TypeParamBound};
15
16pub mod rewriter;
17pub mod sql_entity_graph;
18
/// Re-exports that are not part of the public API (hidden from rustdoc via `#[doc(hidden)]`).
///
/// NOTE(review): presumably referenced by macro-generated code through a stable
/// `__reexports::...` path — confirm at the generating call sites.
#[doc(hidden)]
pub mod __reexports {
    pub use eyre;
    // For `#[no_std]` based `pgx` extensions we use `HashSet` for type mappings.
    // The nested `std::collections` modules mirror the standard-library path, so the set type
    // is reachable as `__reexports::std::collections::HashSet`.
    pub mod std {
        pub mod collections {
            pub use std::collections::HashSet;
        }
    }
}
29
/// A single argument of an extern-function attribute.
///
/// Flag-style variants are produced from bare identifiers (e.g. `immutable`), and the
/// `String`-carrying variants from `key = "value"` pairs by `parse_extern_attributes`.
/// The `Display` impl renders the SQL keyword for a variant (or nothing), and the `ToTokens`
/// impl re-emits the variant as a bare constructor expression.
///
/// Variant order is significant: it determines the derived `PartialOrd`/`Ord` ordering.
#[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord)]
pub enum ExternArgs {
    /// `create_or_replace`; renders as `CREATE OR REPLACE`.
    CreateOrReplace,
    /// `immutable`; renders as `IMMUTABLE`.
    Immutable,
    /// `strict`; renders as `STRICT`.
    Strict,
    /// `stable`; renders as `STABLE`.
    Stable,
    /// `volatile`; renders as `VOLATILE`.
    Volatile,
    /// `raw`; has no SQL rendering.
    Raw,
    /// `no_guard`; has no SQL rendering.
    NoGuard,
    /// `parallel_safe`; renders as `PARALLEL SAFE`.
    ParallelSafe,
    /// `parallel_unsafe`; renders as `PARALLEL UNSAFE`.
    ParallelUnsafe,
    /// `parallel_restricted`; renders as `PARALLEL RESTRICTED`.
    ParallelRestricted,
    /// `error = "..."`: the string value of the literal. Has no SQL rendering.
    Error(String),
    /// `schema = "..."`: the string value of the literal. Has no SQL rendering.
    Schema(String),
    /// `name = "..."`: the string value of the literal. Has no SQL rendering.
    Name(String),
    /// Renders as `COST <value>`.
    /// NOTE(review): not produced by `parse_extern_attributes`; presumably constructed
    /// elsewhere — confirm.
    Cost(String),
    /// Positioning requirements; has no SQL rendering.
    Requires(Vec<PositioningRef>),
}
48
49impl core::fmt::Display for ExternArgs {
50    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
51        match self {
52            ExternArgs::CreateOrReplace => write!(f, "CREATE OR REPLACE"),
53            ExternArgs::Immutable => write!(f, "IMMUTABLE"),
54            ExternArgs::Strict => write!(f, "STRICT"),
55            ExternArgs::Stable => write!(f, "STABLE"),
56            ExternArgs::Volatile => write!(f, "VOLATILE"),
57            ExternArgs::Raw => Ok(()),
58            ExternArgs::ParallelSafe => write!(f, "PARALLEL SAFE"),
59            ExternArgs::ParallelUnsafe => write!(f, "PARALLEL UNSAFE"),
60            ExternArgs::ParallelRestricted => write!(f, "PARALLEL RESTRICTED"),
61            ExternArgs::Error(_) => Ok(()),
62            ExternArgs::NoGuard => Ok(()),
63            ExternArgs::Schema(_) => Ok(()),
64            ExternArgs::Name(_) => Ok(()),
65            ExternArgs::Cost(cost) => write!(f, "COST {}", cost),
66            ExternArgs::Requires(_) => Ok(()),
67        }
68    }
69}
70
71impl ToTokens for ExternArgs {
72    fn to_tokens(&self, tokens: &mut TokenStream) {
73        match self {
74            ExternArgs::CreateOrReplace => tokens.append(format_ident!("CreateOrReplace")),
75            ExternArgs::Immutable => tokens.append(format_ident!("Immutable")),
76            ExternArgs::Strict => tokens.append(format_ident!("Strict")),
77            ExternArgs::Stable => tokens.append(format_ident!("Stable")),
78            ExternArgs::Volatile => tokens.append(format_ident!("Volatile")),
79            ExternArgs::Raw => tokens.append(format_ident!("Raw")),
80            ExternArgs::NoGuard => tokens.append(format_ident!("NoGuard")),
81            ExternArgs::ParallelSafe => tokens.append(format_ident!("ParallelSafe")),
82            ExternArgs::ParallelUnsafe => tokens.append(format_ident!("ParallelUnsafe")),
83            ExternArgs::ParallelRestricted => tokens.append(format_ident!("ParallelRestricted")),
84            ExternArgs::Error(_s) => {
85                tokens.append_all(
86                    quote! {
87                        Error(String::from("#_s"))
88                    }
89                    .to_token_stream(),
90                );
91            }
92            ExternArgs::Schema(_s) => {
93                tokens.append_all(
94                    quote! {
95                        Schema(String::from("#_s"))
96                    }
97                    .to_token_stream(),
98                );
99            }
100            ExternArgs::Name(_s) => {
101                tokens.append_all(
102                    quote! {
103                        Name(String::from("#_s"))
104                    }
105                    .to_token_stream(),
106                );
107            }
108            ExternArgs::Cost(_s) => {
109                tokens.append_all(
110                    quote! {
111                        Cost(String::from("#_s"))
112                    }
113                    .to_token_stream(),
114                );
115            }
116            ExternArgs::Requires(items) => {
117                tokens.append_all(
118                    quote! {
119                        Requires(vec![#(#items),*])
120                    }
121                    .to_token_stream(),
122                );
123            }
124        }
125    }
126}
127
/// Function-level arguments.
///
/// NOTE(review): not referenced elsewhere in this file; presumably consumed by other
/// modules — confirm.
#[derive(Debug, Hash, Ord, PartialOrd, Eq, PartialEq)]
pub enum FunctionArgs {
    /// Presumably a `search_path` value for the function — confirm format at construction sites.
    SearchPath(String),
}
132
/// The shape of a function return type, as computed by `categorize_type` /
/// `categorize_trait_bound`.
#[derive(Debug)]
pub enum CategorizedType {
    /// An `impl Iterator<Item = T>` / `dyn Iterator<Item = T>`: the stringified item type, or
    /// each stringified element when `T` is a tuple.
    Iterator(Vec<String>),
    /// An iterator wrapped in `Option<...>`; same payload as `Iterator`.
    OptionalIterator(Vec<String>),
    /// A non-empty tuple: the stringified element types.
    Tuple(Vec<String>),
    /// Any other type, including the unit type `()`.
    Default,
}
140
141pub fn parse_extern_attributes(attr: TokenStream) -> HashSet<ExternArgs> {
142    let mut args = HashSet::<ExternArgs>::new();
143    let mut itr = attr.into_iter();
144    while let Some(t) = itr.next() {
145        match t {
146            TokenTree::Group(g) => {
147                for arg in parse_extern_attributes(g.stream()).into_iter() {
148                    args.insert(arg);
149                }
150            }
151            TokenTree::Ident(i) => {
152                let name = i.to_string();
153                match name.as_str() {
154                    "create_or_replace" => args.insert(ExternArgs::CreateOrReplace),
155                    "immutable" => args.insert(ExternArgs::Immutable),
156                    "strict" => args.insert(ExternArgs::Strict),
157                    "stable" => args.insert(ExternArgs::Stable),
158                    "volatile" => args.insert(ExternArgs::Volatile),
159                    "raw" => args.insert(ExternArgs::Raw),
160                    "no_guard" => args.insert(ExternArgs::NoGuard),
161                    "parallel_safe" => args.insert(ExternArgs::ParallelSafe),
162                    "parallel_unsafe" => args.insert(ExternArgs::ParallelUnsafe),
163                    "parallel_restricted" => args.insert(ExternArgs::ParallelRestricted),
164                    "error" => {
165                        let _punc = itr.next().unwrap();
166                        let literal = itr.next().unwrap();
167                        let message = literal.to_string();
168                        let message = unescape::unescape(&message).expect("failed to unescape");
169
170                        // trim leading/trailing quotes around the literal
171                        let message = message[1..message.len() - 1].to_string();
172                        args.insert(ExternArgs::Error(message.to_string()))
173                    }
174                    "schema" => {
175                        let _punc = itr.next().unwrap();
176                        let literal = itr.next().unwrap();
177                        let schema = literal.to_string();
178                        let schema = unescape::unescape(&schema).expect("failed to unescape");
179
180                        // trim leading/trailing quotes around the literal
181                        let schema = schema[1..schema.len() - 1].to_string();
182                        args.insert(ExternArgs::Schema(schema.to_string()))
183                    }
184                    "name" => {
185                        let _punc = itr.next().unwrap();
186                        let literal = itr.next().unwrap();
187                        let name = literal.to_string();
188                        let name = unescape::unescape(&name).expect("failed to unescape");
189
190                        // trim leading/trailing quotes around the literal
191                        let name = name[1..name.len() - 1].to_string();
192                        args.insert(ExternArgs::Name(name.to_string()))
193                    }
194                    // Recognized, but not handled as an extern argument
195                    "sql" => {
196                        let _punc = itr.next().unwrap();
197                        let _value = itr.next().unwrap();
198                        false
199                    }
200                    _ => false,
201                };
202            }
203            TokenTree::Punct(_) => {}
204            TokenTree::Literal(_) => {}
205        }
206    }
207    args
208}
209
210pub fn categorize_type(ty: &Type) -> CategorizedType {
211    match ty {
212        Type::Path(ty) => {
213            let segments = &ty.path.segments;
214            for segment in segments {
215                let segment_ident = segment.ident.to_string();
216                if segment_ident == "Option" {
217                    match &segment.arguments {
218                        PathArguments::AngleBracketed(a) => match a.args.first().unwrap() {
219                            GenericArgument::Type(ty) => {
220                                let result = categorize_type(ty);
221
222                                return match result {
223                                    CategorizedType::Iterator(i) => {
224                                        CategorizedType::OptionalIterator(i)
225                                    }
226
227                                    _ => result,
228                                };
229                            }
230                            _ => {
231                                break;
232                            }
233                        },
234                        _ => {
235                            break;
236                        }
237                    }
238                }
239                if segment_ident == "Box" {
240                    match &segment.arguments {
241                        PathArguments::AngleBracketed(a) => match a.args.first().unwrap() {
242                            GenericArgument::Type(ty) => return categorize_type(ty),
243                            _ => {
244                                break;
245                            }
246                        },
247                        _ => {
248                            break;
249                        }
250                    }
251                }
252            }
253            CategorizedType::Default
254        }
255        Type::TraitObject(trait_object) => {
256            for bound in &trait_object.bounds {
257                return categorize_trait_bound(bound);
258            }
259
260            panic!("Unsupported trait return type");
261        }
262        Type::ImplTrait(ty) => {
263            for bound in &ty.bounds {
264                return categorize_trait_bound(bound);
265            }
266
267            panic!("Unsupported trait return type");
268        }
269        Type::Tuple(tuple) => {
270            if tuple.elems.len() == 0 {
271                CategorizedType::Default
272            } else {
273                let mut types = Vec::new();
274                for ty in &tuple.elems {
275                    types.push(quote! {#ty}.to_string())
276                }
277                CategorizedType::Tuple(types)
278            }
279        }
280        _ => CategorizedType::Default,
281    }
282}
283
284pub fn categorize_trait_bound(bound: &TypeParamBound) -> CategorizedType {
285    match bound {
286        TypeParamBound::Trait(trait_bound) => {
287            let segments = &trait_bound.path.segments;
288
289            let mut ident = String::new();
290            for segment in segments {
291                if !ident.is_empty() {
292                    ident.push_str("::")
293                }
294                ident.push_str(segment.ident.to_string().as_str());
295            }
296
297            match ident.as_str() {
298                "Iterator" | "std::iter::Iterator" => {
299                    let segment = segments.last().unwrap();
300                    match &segment.arguments {
301                        PathArguments::None => {
302                            panic!("Iterator must have at least one generic type")
303                        }
304                        PathArguments::Parenthesized(_) => {
305                            panic!("Unsupported arguments to Iterator")
306                        }
307                        PathArguments::AngleBracketed(a) => {
308                            let args = &a.args;
309                            if args.len() > 1 {
310                                panic!(
311                                    "Only one generic type is supported when returning an Iterator"
312                                )
313                            }
314
315                            match args.first().unwrap() {
316                                GenericArgument::Binding(b) => {
317                                    let mut types = Vec::new();
318                                    let ty = &b.ty;
319                                    match ty {
320                                        Type::Tuple(tuple) => {
321                                            for e in &tuple.elems {
322                                                types.push(quote! {#e}.to_string());
323                                            }
324                                        },
325                                        _ => {
326                                            types.push(quote! {#ty}.to_string())
327                                        }
328                                    }
329
330                                    return CategorizedType::Iterator(types);
331                                }
332                                _ => panic!("Only binding type arguments are supported when returning an Iterator")
333                            }
334                        }
335                    }
336                }
337                _ => panic!("Unsupported trait return type"),
338            }
339        }
340        TypeParamBound::Lifetime(_) => {
341            panic!("Functions can't return traits with lifetime bounds")
342        }
343    }
344}
345
346pub fn staticize_lifetimes_in_type_path(value: syn::TypePath) -> syn::TypePath {
347    let mut ty = syn::Type::Path(value);
348    staticize_lifetimes(&mut ty);
349    match ty {
350        syn::Type::Path(type_path) => type_path,
351
352        // shouldn't happen
353        _ => panic!("not a TypePath"),
354    }
355}
356
357pub fn staticize_lifetimes(value: &mut syn::Type) {
358    match value {
359        syn::Type::Path(type_path) => {
360            for segment in &mut type_path.path.segments {
361                match &mut segment.arguments {
362                    syn::PathArguments::AngleBracketed(bracketed) => {
363                        for arg in &mut bracketed.args {
364                            match arg {
365                                // rename lifetimes to the static lifetime so the TypeIds match.
366                                syn::GenericArgument::Lifetime(lifetime) => {
367                                    lifetime.ident =
368                                        syn::Ident::new("static", lifetime.ident.span());
369                                }
370
371                                // recurse
372                                syn::GenericArgument::Type(ty) => staticize_lifetimes(ty),
373                                syn::GenericArgument::Binding(binding) => {
374                                    staticize_lifetimes(&mut binding.ty)
375                                }
376                                syn::GenericArgument::Constraint(constraint) => {
377                                    for bound in constraint.bounds.iter_mut() {
378                                        match bound {
379                                            syn::TypeParamBound::Lifetime(lifetime) => {
380                                                lifetime.ident =
381                                                    syn::Ident::new("static", lifetime.ident.span())
382                                            }
383                                            _ => {}
384                                        }
385                                    }
386                                }
387
388                                // nothing to do otherwise
389                                _ => {}
390                            }
391                        }
392                    }
393                    _ => {}
394                }
395            }
396        }
397
398        syn::Type::Reference(type_ref) => match &mut type_ref.lifetime {
399            Some(ref mut lifetime) => {
400                lifetime.ident = syn::Ident::new("static", lifetime.ident.span());
401            }
402            this @ None => *this = Some(syn::parse_quote!('static)),
403        },
404
405        syn::Type::Tuple(type_tuple) => {
406            for elem in &mut type_tuple.elems {
407                staticize_lifetimes(elem);
408            }
409        }
410
411        syn::Type::Macro(type_macro) => {
412            let mac = &type_macro.mac;
413            if let Some(archetype) = mac.path.segments.last() {
414                match archetype.ident.to_string().as_str() {
415                    "name" => {
416                        if let Ok(out) = mac.parse_body::<NameMacro>() {
417                            // We don't particularly care what the identifier is, so we parse a
418                            // raw TokenStream.  Specifically, it's okay for the identifier String,
419                            // which we end up using as a Postgres column name, to be nearly any
420                            // string, which can include Rust reserved words such as "type" or "match"
421                            if let Ok(ident) = syn::parse_str::<TokenStream>(&out.ident) {
422                                let mut ty = out.used_ty.resolved_ty;
423
424                                // rewrite the name!() macro's type so that it has a static lifetime, if any
425                                staticize_lifetimes(&mut ty);
426                                type_macro.mac = syn::parse_quote! {name!(#ident, #ty)};
427                            }
428                        }
429                    }
430                    _ => {}
431                }
432            }
433        }
434        _ => {}
435    }
436}
437
438pub fn anonymize_lifetimes_in_type_path(value: syn::TypePath) -> syn::TypePath {
439    let mut ty = syn::Type::Path(value);
440    anonymize_lifetimes(&mut ty);
441    match ty {
442        syn::Type::Path(type_path) => type_path,
443
444        // shouldn't happen
445        _ => panic!("not a TypePath"),
446    }
447}
448
449pub fn anonymize_lifetimes(value: &mut syn::Type) {
450    match value {
451        syn::Type::Path(type_path) => {
452            for segment in &mut type_path.path.segments {
453                match &mut segment.arguments {
454                    syn::PathArguments::AngleBracketed(bracketed) => {
455                        for arg in &mut bracketed.args {
456                            match arg {
457                                // rename lifetimes to the anonymous lifetime
458                                syn::GenericArgument::Lifetime(lifetime) => {
459                                    lifetime.ident = syn::Ident::new("_", lifetime.ident.span());
460                                }
461
462                                // recurse
463                                syn::GenericArgument::Type(ty) => anonymize_lifetimes(ty),
464                                syn::GenericArgument::Binding(binding) => {
465                                    anonymize_lifetimes(&mut binding.ty)
466                                }
467                                syn::GenericArgument::Constraint(constraint) => {
468                                    for bound in constraint.bounds.iter_mut() {
469                                        match bound {
470                                            syn::TypeParamBound::Lifetime(lifetime) => {
471                                                lifetime.ident =
472                                                    syn::Ident::new("_", lifetime.ident.span())
473                                            }
474                                            _ => {}
475                                        }
476                                    }
477                                }
478
479                                // nothing to do otherwise
480                                _ => {}
481                            }
482                        }
483                    }
484                    _ => {}
485                }
486            }
487        }
488
489        syn::Type::Reference(type_ref) => {
490            if let Some(lifetime) = type_ref.lifetime.as_mut() {
491                lifetime.ident = syn::Ident::new("_", lifetime.ident.span());
492            }
493        }
494
495        syn::Type::Tuple(type_tuple) => {
496            for elem in &mut type_tuple.elems {
497                anonymize_lifetimes(elem);
498            }
499        }
500
501        _ => {}
502    }
503}
504
505/// Roughly `pgx::pg_sys::NAMEDATALEN`
506///
507/// Technically it **should** be that exactly, however this is `pgx-utils` and a this data is used at macro time.
508const POSTGRES_IDENTIFIER_MAX_LEN: usize = 64;
509
510/// Validate that a given ident is acceptable to PostgreSQL
511///
512/// PostgreSQL places some restrictions on identifiers for things like functions.
513///
514/// Namely:
515///
516/// * It must be less than 64 characters
517///
518// This list is incomplete, you could expand it!
519pub fn ident_is_acceptable_to_postgres(ident: &syn::Ident) -> Result<(), syn::Error> {
520    let ident_string = ident.to_string();
521    if ident_string.len() >= POSTGRES_IDENTIFIER_MAX_LEN {
522        return Err(syn::Error::new(
523            ident.span(),
524            &format!(
525                "Identifier `{}` was {} characters long, PostgreSQL will truncate identifiers with less than {POSTGRES_IDENTIFIER_MAX_LEN} characters, opt for an identifier which Postgres won't truncate",
526                ident,
527                ident_string.len(),
528            )
529        ));
530    }
531
532    Ok(())
533}
534
#[cfg(test)]
mod tests {
    use crate::{parse_extern_attributes, ExternArgs};
    use std::str::FromStr;

    /// An `error = "..."` argument keeps its message, with the escaped inner quotes resolved.
    #[test]
    fn parse_args() {
        let input = r#"error = "syntax error at or near \"THIS\"""#;
        let expected = ExternArgs::Error(r#"syntax error at or near "THIS""#.to_string());

        let tokens = proc_macro2::TokenStream::from_str(input).unwrap();
        assert!(parse_extern_attributes(tokens).contains(&expected));
    }
}