use super::get_pgrx_attr_macro;
use crate::UsedType;
use crate::pg_extern::NameMacro;
use proc_macro2::TokenStream as TokenStream2;
use quote::{ToTokens, quote};
use syn::parse::{Parse, ParseStream};
use syn::{Expr, Type, parse_quote};
/// A parsed aggregate type list: either a tuple of types, or a single
/// non-tuple type treated as a one-element list.
#[derive(Debug, Clone)]
pub struct AggregateTypeList {
    /// One entry per tuple element (or a single entry for a non-tuple type).
    pub found: Vec<AggregateType>,
    /// The type exactly as the user wrote it; re-emitted verbatim by `ToTokens`.
    pub original: syn::Type,
}
impl AggregateTypeList {
    /// Parse a `syn::Type` into an [`AggregateTypeList`].
    ///
    /// A tuple type produces one [`AggregateType`] per element; any other
    /// type becomes a single-element list. The original type is kept so it
    /// can be emitted back verbatim via `ToTokens`.
    pub fn new(maybe_type_list: syn::Type) -> Result<Self, syn::Error> {
        let found = match &maybe_type_list {
            Type::Tuple(tuple) => tuple
                .elems
                .iter()
                .cloned()
                .map(AggregateType::new)
                .collect::<Result<Vec<_>, _>>()?,
            ty => vec![AggregateType::new(ty.clone())?],
        };
        Ok(Self { found, original: maybe_type_list })
    }

    /// Build a `vec![...]` expression holding the entity representation of
    /// every element in the list.
    pub fn entity_tokens(&self) -> Expr {
        let entities = self.found.iter().map(AggregateType::entity_tokens);
        parse_quote! {
            vec![#(#entities),*]
        }
    }

    /// Tokens that compute the total serialized section length of the list.
    pub fn section_len_tokens(&self) -> TokenStream2 {
        let elem_lens = self.found.iter().map(AggregateType::section_len_tokens);
        quote! {
            ::pgrx::pgrx_sql_entity_graph::section::list_len(&[
                #( #elem_lens ),*
            ])
        }
    }

    /// Tokens that write the element count (as a `u32`) followed by each
    /// element's payload, threading `writer` through every step.
    pub fn section_writer_tokens(&self, writer: TokenStream2) -> TokenStream2 {
        let count = self.found.len();
        let elem_writers =
            self.found.iter().map(|item| item.section_writer_tokens(quote! { writer }));
        quote! {
            {
                let writer = #writer.u32(#count as u32);
                #( let writer = { #elem_writers }; )*
                writer
            }
        }
    }
}
impl Parse for AggregateTypeList {
    /// Parse a type from the input stream, then classify it as a type list.
    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
        let ty: syn::Type = input.parse()?;
        Self::new(ty)
    }
}
impl ToTokens for AggregateTypeList {
    /// Emit the original type exactly as the user wrote it.
    fn to_tokens(&self, tokens: &mut TokenStream2) {
        tokens.extend(self.original.to_token_stream());
    }
}
/// A single element of an aggregate's type list, optionally carrying a SQL
/// name supplied through a `name!(...)` attribute macro.
#[derive(Debug, Clone)]
pub struct AggregateType {
    /// The usable form of the type (taken from the `name!` macro when present).
    pub used_ty: UsedType,
    /// SQL name from `name!(ident, Ty)`, if the type was written that way.
    pub name: Option<String>,
}
impl AggregateType {
    /// Parse a `syn::Type` into an [`AggregateType`].
    ///
    /// If the type is a `name!(ident, Ty)` attribute macro invocation, both
    /// the SQL name and the inner type come from the macro; otherwise the
    /// type is used directly and has no name.
    ///
    /// # Errors
    /// Returns a `syn::Error` if the `name!` macro contents fail to parse,
    /// or if a bare type is rejected by `UsedType::new`.
    pub fn new(ty: syn::Type) -> Result<Self, syn::Error> {
        // Only construct `UsedType::new(ty)` when there is no `name!` macro.
        // The previous `name_macro.map(|v| v.used_ty).unwrap_or(UsedType::new(ty)?)`
        // form evaluated `UsedType::new(ty)?` eagerly: wasted work on the raw
        // `name!(...)` macro type, and a spurious error propagated from a
        // computation whose result was about to be discarded.
        let (used_ty, name) = if let Some(name_macro) = get_pgrx_attr_macro("name", &ty) {
            let name_macro = syn::parse2::<NameMacro>(name_macro)?;
            (name_macro.used_ty, Some(name_macro.ident))
        } else {
            (UsedType::new(ty)?, None)
        };
        Ok(Self { used_ty, name })
    }

    /// Build the `AggregateTypeEntity` expression for SQL entity graph
    /// generation.
    pub fn entity_tokens(&self) -> Expr {
        let used_ty_entity_tokens = self.used_ty.entity_tokens();
        // Interpolate `Some("name")` / `None` directly rather than the old
        // `None #( .unwrap_or(Some(#name)) )*` trick, which produced the same
        // value through a needless runtime `Option::unwrap_or` call.
        let name = match &self.name {
            Some(name) => quote! { Some(#name) },
            None => quote! { None },
        };
        parse_quote! {
            ::pgrx::pgrx_sql_entity_graph::AggregateTypeEntity {
                used_ty: #used_ty_entity_tokens,
                name: #name,
            }
        }
    }

    /// Tokens computing this element's serialized length: the inner type's
    /// length plus a presence flag, plus the name string when one is present.
    pub fn section_len_tokens(&self) -> TokenStream2 {
        let used_ty_len = self.used_ty.section_len_tokens();
        let name_len = match &self.name {
            Some(name) => quote! {
                ::pgrx::pgrx_sql_entity_graph::section::bool_len()
                + ::pgrx::pgrx_sql_entity_graph::section::str_len(#name)
            },
            None => quote! { ::pgrx::pgrx_sql_entity_graph::section::bool_len() },
        };
        quote! {
            (#used_ty_len) + (#name_len)
        }
    }

    /// Tokens writing this element through `writer`: a presence flag, the
    /// name (if any), and then the inner type's payload.
    pub fn section_writer_tokens(&self, writer: TokenStream2) -> TokenStream2 {
        let name_writer = match &self.name {
            Some(name) => quote! { .bool(true).str(#name) },
            None => quote! { .bool(false) },
        };
        self.used_ty.section_writer_tokens(quote! {
            #writer #name_writer
        })
    }
}
impl ToTokens for AggregateType {
    /// An aggregate type renders as its resolved inner type.
    fn to_tokens(&self, tokens: &mut TokenStream2) {
        tokens.extend(self.used_ty.resolved_ty.to_token_stream());
    }
}
impl Parse for AggregateType {
    /// Parse a type from the input stream and wrap it as an aggregate type.
    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
        input.parse().and_then(Self::new)
    }
}
#[cfg(test)]
mod tests {
    use super::AggregateTypeList;
    use eyre::{Result, eyre as eyre_err};
    use syn::parse_quote;

    /// A bare non-tuple type parses as a one-element list.
    #[test]
    fn solo() -> Result<()> {
        let list = AggregateTypeList::new(parse_quote! { i32 });
        assert!(list.is_ok());
        let list = list.unwrap();
        let found = &list.found[0];
        let found_string = match &found.used_ty.resolved_ty {
            syn::Type::Path(ty_path) => ty_path.path.segments.last().unwrap().ident.to_string(),
            _ => return Err(eyre_err!("Wrong found.used_ty.resolved_ty")),
        };
        assert_eq!(found_string, "i32");
        Ok(())
    }

    /// A tuple type parses element-by-element, preserving order.
    #[test]
    fn list() -> Result<()> {
        let list = AggregateTypeList::new(parse_quote! { (i32, i8) });
        assert!(list.is_ok());
        let list = list.unwrap();

        let first = &list.found[0];
        let first_string = match &first.used_ty.resolved_ty {
            syn::Type::Path(ty_path) => ty_path.path.segments.last().unwrap().ident.to_string(),
            _ => return Err(eyre_err!("Wrong first.used_ty.resolved_ty: {:?}", first)),
        };
        assert_eq!(first_string, "i32");

        let second = &list.found[1];
        let second_string = match &second.used_ty.resolved_ty {
            syn::Type::Path(ty_path) => ty_path.path.segments.last().unwrap().ident.to_string(),
            _ => return Err(eyre_err!("Wrong second.used_ty.resolved_ty: {:?}", second)),
        };
        assert_eq!(second_string, "i8");
        Ok(())
    }
}