// pgx_sql_entity_graph/aggregate/aggregate_type.rs

/*
Portions Copyright 2019-2021 ZomboDB, LLC.
Portions Copyright 2021-2022 Technology Concepts & Design, Inc. <support@tcdi.com>

All rights reserved.

Use of this source code is governed by the MIT license that can be found in the LICENSE file.
*/
/*!

`#[pg_aggregate]`-related type metadata for Rust-to-SQL translation.

> Like all of the [`sql_entity_graph`][crate::pgx_sql_entity_graph] APIs, this is considered **internal**
> to the `pgx` framework and highly subject to change between versions. While you may use this, please do so with caution.

*/
17use super::get_pgx_attr_macro;
18use crate::pg_extern::NameMacro;
19use crate::UsedType;
20
21use proc_macro2::TokenStream as TokenStream2;
22use quote::ToTokens;
23use syn::parse::{Parse, ParseStream};
24use syn::{parse_quote, Expr, Type};
25
26#[derive(Debug, Clone)]
27pub struct AggregateTypeList {
28    pub found: Vec<AggregateType>,
29    pub original: syn::Type,
30}
31
32impl AggregateTypeList {
33    pub fn new(maybe_type_list: syn::Type) -> Result<Self, syn::Error> {
34        match &maybe_type_list {
35            Type::Tuple(tuple) => {
36                let mut coll = Vec::new();
37                for elem in &tuple.elems {
38                    let parsed_elem = AggregateType::new(elem.clone())?;
39                    coll.push(parsed_elem);
40                }
41                Ok(Self { found: coll, original: maybe_type_list })
42            }
43            ty => {
44                Ok(Self { found: vec![AggregateType::new(ty.clone())?], original: maybe_type_list })
45            }
46        }
47    }
48
49    pub fn entity_tokens(&self) -> Expr {
50        let found = self.found.iter().map(|x| x.entity_tokens());
51        parse_quote! {
52            vec![#(#found),*]
53        }
54    }
55}
56
57impl Parse for AggregateTypeList {
58    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
59        Self::new(input.parse()?)
60    }
61}
62
63impl ToTokens for AggregateTypeList {
64    fn to_tokens(&self, tokens: &mut TokenStream2) {
65        self.original.to_tokens(tokens)
66    }
67}
68
69#[derive(Debug, Clone)]
70pub struct AggregateType {
71    pub used_ty: UsedType,
72    /// The name, if it exists.
73    pub name: Option<String>,
74}
75
76impl AggregateType {
77    pub fn new(ty: syn::Type) -> Result<Self, syn::Error> {
78        let (name_macro, name) = if let Some(name_macro) = get_pgx_attr_macro("name", &ty) {
79            let name_macro = syn::parse2::<NameMacro>(name_macro)?;
80            let name = Some(name_macro.ident.clone());
81            (Some(name_macro), name)
82        } else {
83            (None, None)
84        };
85
86        let used_ty = name_macro.map(|v| v.used_ty).unwrap_or(UsedType::new(ty)?);
87
88        let retval = Self { used_ty, name };
89        Ok(retval)
90    }
91
92    pub fn entity_tokens(&self) -> Expr {
93        let used_ty_entity_tokens = self.used_ty.entity_tokens();
94        let name = self.name.iter();
95        parse_quote! {
96            ::pgx::pgx_sql_entity_graph::AggregateTypeEntity {
97                used_ty: #used_ty_entity_tokens,
98                name: None #( .unwrap_or(Some(#name)) )*,
99            }
100        }
101    }
102}
103
104impl ToTokens for AggregateType {
105    fn to_tokens(&self, tokens: &mut TokenStream2) {
106        self.used_ty.resolved_ty.to_tokens(tokens)
107    }
108}
109
110impl Parse for AggregateType {
111    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
112        Self::new(input.parse()?)
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::{AggregateType, AggregateTypeList};
    use eyre::{eyre as eyre_err, Result};
    use syn::parse_quote;

    /// Parses `tokens` into an [`AggregateTypeList`].
    ///
    /// A parse failure becomes a test error carrying syn's message, instead of a bare
    /// `assert!(is_ok())` failure.
    fn parse_list(tokens: syn::Type) -> Result<AggregateTypeList> {
        AggregateTypeList::new(tokens)
            .map_err(|e| eyre_err!("Failed to parse aggregate type list: {}", e))
    }

    /// Returns the final path segment of `found`'s resolved type (e.g. `i32`).
    ///
    /// Errors if the resolved type is not a path or the path has no segments.
    fn last_segment(found: &AggregateType) -> Result<String> {
        match &found.used_ty.resolved_ty {
            syn::Type::Path(ty_path) => ty_path
                .path
                .segments
                .last()
                .map(|segment| segment.ident.to_string())
                .ok_or_else(|| eyre_err!("Empty path in used_ty.resolved_ty: {:?}", found)),
            _ => Err(eyre_err!("Wrong used_ty.resolved_ty: {:?}", found)),
        }
    }

    /// Parses `tokens` and asserts that the resolved types' last path segments equal
    /// `expected`, in order and with no extra elements.
    fn assert_resolves_to(tokens: syn::Type, expected: &[&str]) -> Result<()> {
        let list = parse_list(tokens)?;
        let found = list.found.iter().map(last_segment).collect::<Result<Vec<_>>>()?;
        assert_eq!(found, expected);
        Ok(())
    }

    #[test]
    fn solo() -> Result<()> {
        // A bare, non-tuple type is treated as a one-element list.
        assert_resolves_to(parse_quote! { i32 }, &["i32"])
    }

    #[test]
    fn list() -> Result<()> {
        assert_resolves_to(parse_quote! { (i32, i8) }, &["i32", "i8"])
    }

    #[test]
    fn list_variadic_with_path() -> Result<()> {
        // A path-qualified `variadic!` resolves to `VariadicArray`.
        assert_resolves_to(
            parse_quote! { (i32, pgx::variadic!(i8)) },
            &["i32", "VariadicArray"],
        )
    }

    #[test]
    fn list_variadic() -> Result<()> {
        // An unqualified `variadic!` resolves to `VariadicArray` as well.
        assert_resolves_to(parse_quote! { (i32, variadic!(i8)) }, &["i32", "VariadicArray"])
    }
}