pgx_sql_entity_graph/aggregate/
aggregate_type.rs1use super::get_pgx_attr_macro;
18use crate::pg_extern::NameMacro;
19use crate::UsedType;
20
21use proc_macro2::TokenStream as TokenStream2;
22use quote::ToTokens;
23use syn::parse::{Parse, ParseStream};
24use syn::{parse_quote, Expr, Type};
25
26#[derive(Debug, Clone)]
27pub struct AggregateTypeList {
28 pub found: Vec<AggregateType>,
29 pub original: syn::Type,
30}
31
32impl AggregateTypeList {
33 pub fn new(maybe_type_list: syn::Type) -> Result<Self, syn::Error> {
34 match &maybe_type_list {
35 Type::Tuple(tuple) => {
36 let mut coll = Vec::new();
37 for elem in &tuple.elems {
38 let parsed_elem = AggregateType::new(elem.clone())?;
39 coll.push(parsed_elem);
40 }
41 Ok(Self { found: coll, original: maybe_type_list })
42 }
43 ty => {
44 Ok(Self { found: vec![AggregateType::new(ty.clone())?], original: maybe_type_list })
45 }
46 }
47 }
48
49 pub fn entity_tokens(&self) -> Expr {
50 let found = self.found.iter().map(|x| x.entity_tokens());
51 parse_quote! {
52 vec![#(#found),*]
53 }
54 }
55}
56
57impl Parse for AggregateTypeList {
58 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
59 Self::new(input.parse()?)
60 }
61}
62
63impl ToTokens for AggregateTypeList {
64 fn to_tokens(&self, tokens: &mut TokenStream2) {
65 self.original.to_tokens(tokens)
66 }
67}
68
69#[derive(Debug, Clone)]
70pub struct AggregateType {
71 pub used_ty: UsedType,
72 pub name: Option<String>,
74}
75
76impl AggregateType {
77 pub fn new(ty: syn::Type) -> Result<Self, syn::Error> {
78 let (name_macro, name) = if let Some(name_macro) = get_pgx_attr_macro("name", &ty) {
79 let name_macro = syn::parse2::<NameMacro>(name_macro)?;
80 let name = Some(name_macro.ident.clone());
81 (Some(name_macro), name)
82 } else {
83 (None, None)
84 };
85
86 let used_ty = name_macro.map(|v| v.used_ty).unwrap_or(UsedType::new(ty)?);
87
88 let retval = Self { used_ty, name };
89 Ok(retval)
90 }
91
92 pub fn entity_tokens(&self) -> Expr {
93 let used_ty_entity_tokens = self.used_ty.entity_tokens();
94 let name = self.name.iter();
95 parse_quote! {
96 ::pgx::pgx_sql_entity_graph::AggregateTypeEntity {
97 used_ty: #used_ty_entity_tokens,
98 name: None #( .unwrap_or(Some(#name)) )*,
99 }
100 }
101 }
102}
103
104impl ToTokens for AggregateType {
105 fn to_tokens(&self, tokens: &mut TokenStream2) {
106 self.used_ty.resolved_ty.to_tokens(tokens)
107 }
108}
109
110impl Parse for AggregateType {
111 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
112 Self::new(input.parse()?)
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::{AggregateType, AggregateTypeList};
    use eyre::{eyre as eyre_err, Result};
    use syn::parse_quote;

    /// Returns the last path segment of `found`'s resolved type (e.g. `i32`), or an error
    /// carrying `found`'s debug output if the resolved type is not a non-empty path.
    fn resolved_ident(found: &AggregateType) -> Result<String> {
        match &found.used_ty.resolved_ty {
            syn::Type::Path(ty_path) => ty_path
                .path
                .segments
                .last()
                .map(|segment| segment.ident.to_string())
                .ok_or_else(|| eyre_err!("Empty path in used_ty.resolved_ty: {:?}", found)),
            _ => Err(eyre_err!("Wrong used_ty.resolved_ty: {:?}", found)),
        }
    }

    #[test]
    fn solo() -> Result<()> {
        let tokens: syn::Type = parse_quote! {
            i32
        };
        // `?` reports the underlying parse error instead of a bare `is_ok()` assertion failure.
        let list = AggregateTypeList::new(tokens)?;
        assert_eq!(list.found.len(), 1);
        assert_eq!(resolved_ident(&list.found[0])?, "i32");
        Ok(())
    }

    #[test]
    fn list() -> Result<()> {
        let tokens: syn::Type = parse_quote! {
            (i32, i8)
        };
        let list = AggregateTypeList::new(tokens)?;
        assert_eq!(list.found.len(), 2);
        assert_eq!(resolved_ident(&list.found[0])?, "i32");
        assert_eq!(resolved_ident(&list.found[1])?, "i8");
        Ok(())
    }

    #[test]
    fn list_variadic_with_path() -> Result<()> {
        let tokens: syn::Type = parse_quote! {
            (i32, pgx::variadic!(i8))
        };
        let list = AggregateTypeList::new(tokens)?;
        assert_eq!(list.found.len(), 2);
        assert_eq!(resolved_ident(&list.found[0])?, "i32");
        assert_eq!(resolved_ident(&list.found[1])?, "VariadicArray");
        Ok(())
    }

    #[test]
    fn list_variadic() -> Result<()> {
        let tokens: syn::Type = parse_quote! {
            (i32, variadic!(i8))
        };
        let list = AggregateTypeList::new(tokens)?;
        assert_eq!(list.found.len(), 2);
        assert_eq!(resolved_ident(&list.found[0])?, "i32");
        assert_eq!(resolved_ident(&list.found[1])?, "VariadicArray");
        Ok(())
    }
}