Skip to main content

pgrx_sql_entity_graph/pg_extern/
returning.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
12`#[pg_extern]` return value related macro expansion for Rust to SQL translation
13
14> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
> to the `pgrx` framework and highly subject to change between versions. While you may use this, please do so with caution.
16
17*/
18use super::LastIdent;
19use crate::UsedType;
20
21use proc_macro2::TokenStream as TokenStream2;
22use quote::{ToTokens, TokenStreamExt, quote};
23use syn::parse::{Parse, ParseStream};
24
25use syn::spanned::Spanned;
26use syn::{Error, GenericArgument, PathArguments, Token, Type};
27
28#[derive(Debug, Clone)]
29pub struct ReturningIteratedItem {
30    pub used_ty: UsedType,
31    pub name: Option<String>,
32}
33
34#[derive(Debug, Clone)]
35pub enum Returning {
36    None,
37    Type(UsedType),
38    SetOf { ty: UsedType },
39    Iterated { tys: Vec<ReturningIteratedItem> },
40    // /// Technically we don't ever create this, single triggers have their own macro.
41    // Trigger,
42}
43
44impl Returning {
45    fn parse_type_macro(type_macro: &mut syn::TypeMacro) -> Result<Returning, syn::Error> {
46        let mac = &type_macro.mac;
47        let opt_archetype = mac.path.segments.last().map(|archetype| archetype.ident.to_string());
48        match opt_archetype.as_deref() {
49            Some("composite_type") => {
50                Ok(Returning::Type(UsedType::new(syn::Type::Macro(type_macro.clone()))?))
51            }
52            _ => Err(syn::Error::new(
53                type_macro.span(),
54                "type macros other than `composite_type!` are not yet implemented",
55            )),
56        }
57    }
58
59    fn match_type(ty: &Type) -> Result<Returning, Error> {
60        let mut ty = Box::new(ty.clone());
61
62        match &mut *ty {
63            syn::Type::Path(typepath) => {
64                let is_option = typepath.last_ident_is("Option");
65                let is_result = typepath.last_ident_is("Result");
66                let mut is_setof_iter = typepath.last_ident_is("SetOfIterator");
67                let mut is_table_iter = typepath.last_ident_is("TableIterator");
68                let path = &mut typepath.path;
69
70                if is_option || is_result || is_setof_iter || is_table_iter {
71                    let option_inner_path = if is_option || is_result {
72                        match path.segments.last_mut().map(|s| &mut s.arguments) {
73                            Some(syn::PathArguments::AngleBracketed(args)) => {
74                                let args_span = args.span();
75                                match args.args.first_mut() {
76                                    Some(syn::GenericArgument::Type(syn::Type::Path(
77                                        syn::TypePath { qself: _, path },
78                                    ))) => path.clone(),
79                                    Some(syn::GenericArgument::Type(_)) => {
80                                        let used_ty =
81                                            UsedType::new(syn::Type::Path(typepath.clone()))?;
82                                        return Ok(Returning::Type(used_ty));
83                                    }
84                                    other => {
85                                        return Err(syn::Error::new(
86                                            other.as_ref().map(|s| s.span()).unwrap_or(args_span),
87                                            format!(
88                                                "Got unexpected generic argument for Option inner: {other:?}"
89                                            ),
90                                        ));
91                                    }
92                                }
93                            }
94                            other => {
95                                return Err(syn::Error::new(
96                                    other.span(),
97                                    format!(
98                                        "Got unexpected path argument for Option inner: {other:?}"
99                                    ),
100                                ));
101                            }
102                        }
103                    } else {
104                        path.clone()
105                    };
106
107                    let mut segments = option_inner_path.segments.clone();
108
109                    loop {
110                        if let Some(segment) = segments.filter_last_ident("Option") {
111                            let PathArguments::AngleBracketed(generics) = &segment.arguments else {
112                                unreachable!()
113                            };
114                            let Some(GenericArgument::Type(Type::Path(this_path))) =
115                                generics.args.last()
116                            else {
117                                return Err(syn::Error::new_spanned(
118                                    generics,
119                                    "where's the generic args?",
120                                ));
121                            };
122                            segments = this_path.path.segments.clone(); // recurse deeper
123                        } else {
124                            if segments.last_ident_is("SetOfIterator") {
125                                is_setof_iter = true;
126                            } else if segments.last_ident_is("TableIterator") {
127                                is_table_iter = true;
128                            }
129                            break;
130                        }
131                    }
132
133                    if is_setof_iter {
134                        let last_path_segment = option_inner_path.segments.last();
135                        let used_ty = match &last_path_segment.map(|ps| &ps.arguments) {
136                            Some(syn::PathArguments::AngleBracketed(args)) => {
137                                match args.args.last().expect("should have one arg?") {
138                                    syn::GenericArgument::Type(ty) => match ty {
139                                        Type::Path(_) | Type::Macro(_) | Type::Reference(_) => {
140                                            UsedType::new(ty.clone())?
141                                        }
142                                        ty => {
143                                            return Err(syn::Error::new(
144                                                ty.span(),
145                                                "SetOf Iterator must have an item",
146                                            ));
147                                        }
148                                    },
149                                    other => {
150                                        return Err(syn::Error::new(
151                                            other.span(),
152                                            format!(
153                                                "Got unexpected generic argument for SetOfIterator: {other:?}"
154                                            ),
155                                        ));
156                                    }
157                                }
158                            }
159                            other => {
160                                return Err(syn::Error::new(
161                                    other
162                                        .map(|s| s.span())
163                                        .unwrap_or_else(proc_macro2::Span::call_site),
164                                    format!(
165                                        "Got unexpected path argument for SetOfIterator: {other:?}"
166                                    ),
167                                ));
168                            }
169                        };
170                        Ok(Returning::SetOf { ty: used_ty })
171                    } else if is_table_iter {
172                        let last_path_segment = segments.last_mut().unwrap();
173                        let mut iterated_items = vec![];
174
175                        match &mut last_path_segment.arguments {
176                            syn::PathArguments::AngleBracketed(args) => {
177                                match args.args.last_mut().unwrap() {
178                                    syn::GenericArgument::Type(syn::Type::Tuple(type_tuple)) => {
179                                        for elem in &type_tuple.elems {
180                                            match &elem {
181                                                syn::Type::Path(path) => {
182                                                    let iterated_item = ReturningIteratedItem {
183                                                        name: None,
184                                                        used_ty: UsedType::new(syn::Type::Path(
185                                                            path.clone(),
186                                                        ))?,
187                                                    };
188                                                    iterated_items.push(iterated_item);
189                                                }
190                                                syn::Type::Macro(type_macro) => {
191                                                    let mac = &type_macro.mac;
192                                                    let archetype =
193                                                        mac.path.segments.last().unwrap();
194                                                    match archetype.ident.to_string().as_str() {
195                                                        "name" => {
196                                                            let out: NameMacro =
197                                                                mac.parse_body()?;
198                                                            let iterated_item =
199                                                                ReturningIteratedItem {
200                                                                    name: Some(out.ident),
201                                                                    used_ty: out.used_ty,
202                                                                };
203                                                            iterated_items.push(iterated_item)
204                                                        }
205                                                        _ => {
206                                                            let iterated_item =
207                                                                ReturningIteratedItem {
208                                                                    name: None,
209                                                                    used_ty: UsedType::new(
210                                                                        syn::Type::Macro(
211                                                                            type_macro.clone(),
212                                                                        ),
213                                                                    )?,
214                                                                };
215                                                            iterated_items.push(iterated_item);
216                                                        }
217                                                    }
218                                                }
219                                                reference @ syn::Type::Reference(_) => {
220                                                    let iterated_item = ReturningIteratedItem {
221                                                        name: None,
222                                                        used_ty: UsedType::new(
223                                                            (*reference).clone(),
224                                                        )?,
225                                                    };
226                                                    iterated_items.push(iterated_item);
227                                                }
228                                                ty => {
229                                                    return Err(syn::Error::new(
230                                                        ty.span(),
231                                                        "Table Iterator must have an item",
232                                                    ));
233                                                }
234                                            };
235                                        }
236                                    }
237                                    syn::GenericArgument::Lifetime(_) => (),
238                                    other => {
239                                        return Err(syn::Error::new(
240                                            other.span(),
241                                            format!("Got unexpected generic argument: {other:?}"),
242                                        ));
243                                    }
244                                };
245                            }
246                            other => {
247                                return Err(syn::Error::new(
248                                    other.span(),
249                                    format!("Got unexpected path argument: {other:?}"),
250                                ));
251                            }
252                        };
253                        Ok(Returning::Iterated { tys: iterated_items })
254                    } else {
255                        let used_ty = UsedType::new(syn::Type::Path(typepath.clone()))?;
256                        Ok(Returning::Type(used_ty))
257                    }
258                } else {
259                    let used_ty = UsedType::new(syn::Type::Path(typepath.clone()))?;
260                    Ok(Returning::Type(used_ty))
261                }
262            }
263            syn::Type::Reference(ty_ref) => {
264                let used_ty = UsedType::new(syn::Type::Reference(ty_ref.clone()))?;
265                Ok(Returning::Type(used_ty))
266            }
267            syn::Type::Macro(type_macro) => Self::parse_type_macro(type_macro),
268            syn::Type::Paren(type_paren) => match &mut *type_paren.elem {
269                syn::Type::Macro(type_macro) => Self::parse_type_macro(type_macro),
270                other => Err(syn::Error::new(
271                    other.span(),
272                    format!("Got unknown return type (type_paren): {type_paren:?}"),
273                )),
274            },
275            syn::Type::Group(tg) => Self::match_type(&tg.elem),
276            other => Err(syn::Error::new(
277                other.span(),
278                format!("Got unknown return type (other): {other:?}"),
279            )),
280        }
281    }
282}
283
284impl TryFrom<&syn::ReturnType> for Returning {
285    type Error = syn::Error;
286
287    fn try_from(value: &syn::ReturnType) -> Result<Self, Self::Error> {
288        match &value {
289            syn::ReturnType::Default => Ok(Returning::None),
290            syn::ReturnType::Type(_, ty) => Self::match_type(ty),
291        }
292    }
293}
294
295impl ToTokens for Returning {
296    fn to_tokens(&self, tokens: &mut TokenStream2) {
297        let quoted = match self {
298            Returning::None => quote! {
299                ::pgrx::pgrx_sql_entity_graph::PgExternReturnEntity::None
300            },
301            Returning::Type(used_ty) => {
302                let used_ty_entity_tokens = used_ty.entity_tokens();
303                quote! {
304                    ::pgrx::pgrx_sql_entity_graph::PgExternReturnEntity::Type {
305                        ty: #used_ty_entity_tokens,
306                    }
307                }
308            }
309            Returning::SetOf { ty: used_ty } => {
310                let used_ty_entity_tokens = used_ty.entity_tokens();
311                quote! {
312                    ::pgrx::pgrx_sql_entity_graph::PgExternReturnEntity::SetOf {
313                        ty: #used_ty_entity_tokens,
314                                                                  }
315                }
316            }
317            Returning::Iterated { tys: items } => {
318                let quoted_items = items
319                    .iter()
320                    .map(|ReturningIteratedItem { used_ty, name }| {
321                        let name_iter = name.iter();
322                        let used_ty_entity_tokens = used_ty.entity_tokens();
323                        quote! {
324                            ::pgrx::pgrx_sql_entity_graph::PgExternReturnEntityIteratedItem {
325                                ty: #used_ty_entity_tokens,
326                                name: None #( .unwrap_or(Some(stringify!(#name_iter))) )*,
327                            }
328                        }
329                    })
330                    .collect::<Vec<_>>();
331                quote! {
332                    ::pgrx::pgrx_sql_entity_graph::PgExternReturnEntity::Iterated {
333                        tys: vec![
334                            #(#quoted_items),*
335                        ],
336                    }
337                }
338            }
339        };
340        tokens.append_all(quoted);
341    }
342}
343
344#[derive(Debug, Clone)]
345pub struct NameMacro {
346    pub ident: String,
347    pub used_ty: UsedType,
348}
349
350impl Parse for NameMacro {
351    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
352        let ident = input
353            .parse::<syn::Ident>()
354            .map(|v| v.to_string())
355            // Avoid making folks unable to use rust keywords.
356            .or_else(|_| input.parse::<syn::Token![type]>().map(|_| String::from("type")))
357            .or_else(|_| input.parse::<syn::Token![mod]>().map(|_| String::from("mod")))
358            .or_else(|_| input.parse::<syn::Token![extern]>().map(|_| String::from("extern")))
359            .or_else(|_| input.parse::<syn::Token![async]>().map(|_| String::from("async")))
360            .or_else(|_| input.parse::<syn::Token![crate]>().map(|_| String::from("crate")))
361            .or_else(|_| input.parse::<syn::Token![use]>().map(|_| String::from("use")))?;
362        let _comma: Token![,] = input.parse()?;
363        let ty: syn::Type = input.parse()?;
364
365        let used_ty = UsedType::new(ty)?;
366
367        Ok(Self { ident, used_ty })
368    }
369}