use proc_macro2::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream, Result};
use syn::punctuated::Punctuated;
use syn::{parse_quote, ItemTrait, Token, Type, TypeParamBound};
/// Parsed arguments of the abstract-factory attribute: `FactoryTrait, Type1, Type2, ...`.
///
/// [`AbstractFactoryAttribute::expand`] uses these to add one `FactoryTrait<TypeN>`
/// supertrait per listed type to the annotated trait.
#[cfg_attr(test, derive(Eq, PartialEq, Debug))]
pub struct AbstractFactoryAttribute {
    /// The generic factory trait, e.g. `Factory` in `Factory, u32, i64`.
    factory_trait: Type,
    /// The `,` that separates the factory trait from the type list.
    // Stored only so that parsing consumes the separator; it is never read
    // afterwards, which is why the dead-code lint is silenced here.
    #[allow(dead_code)]
    sep: Token![,],
    /// The types to instantiate `factory_trait` with, one supertrait per type.
    types: Punctuated<Type, Token![,]>,
}
impl Parse for AbstractFactoryAttribute {
    /// Parses `FactoryTrait, Type1, Type2, ...`; a trailing comma after the last
    /// type is accepted.
    ///
    /// # Errors
    ///
    /// Returns an error when any of the following holds:
    /// - the factory trait is not a valid type;
    /// - the `,` after the factory trait is missing;
    /// - any listed element is not a valid type;
    /// - no type follows the separator. Without this check the attribute would
    ///   parse but add no supertraits, silently doing nothing.
    fn parse(input: ParseStream) -> Result<Self> {
        let factory_trait = input.parse()?;
        let sep = input.parse()?;
        let types = input.parse_terminated(Type::parse, Token![,])?;

        // `parse_terminated` happily yields an empty list for input like
        // `Factory,`; report that instead of producing a no-op expansion.
        if types.is_empty() {
            return Err(input.error("expected at least one type after `,`"));
        }

        Ok(AbstractFactoryAttribute {
            factory_trait,
            sep,
            types,
        })
    }
}
impl AbstractFactoryAttribute {
    /// Appends a `FactoryTrait<T>` supertrait to `input_trait` for every type `T`
    /// listed in the attribute, then returns the modified trait as tokens.
    ///
    /// Supertraits the trait already had are kept. The new bounds follow them,
    /// in the order the types were listed.
    ///
    /// # Panics
    ///
    /// Panics (via `parse_quote!`) if the generated bounds do not parse as a
    /// `+`-separated list of trait bounds.
    pub fn expand(&self, input_trait: &mut ItemTrait) -> TokenStream {
        let factory = &self.factory_trait;

        // Render `Factory<T>` for each type, join the pieces with `+`, and parse
        // the whole sequence at once as a supertrait list.
        let per_type = self.types.iter().map(|ty| quote!(#factory<#ty>));
        let extra_bounds: Punctuated<TypeParamBound, Token![+]> =
            parse_quote!(#(#per_type)+*);

        input_trait.supertraits.extend(extra_bounds);

        quote!(#input_trait)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use despatma_test_helpers::reformat;
    use syn::parse_str;

    /// Test return type so `?` can be used on syn parse results.
    type Result = std::result::Result<(), Box<dyn std::error::Error>>;

    mod abstract_factory {
        use super::*;
        use pretty_assertions::assert_eq;

        /// A factory trait followed by a comma-separated type list parses into its parts.
        #[test]
        fn parse() -> Result {
            let actual: AbstractFactoryAttribute = parse_str("Factory, u32, i64")?;
            let mut expected_types = Punctuated::new();
            expected_types.push(parse_str("u32")?);
            expected_types.push(parse_str("i64")?);
            assert_eq!(
                actual,
                AbstractFactoryAttribute {
                    factory_trait: parse_str("Factory")?,
                    sep: Default::default(),
                    types: expected_types,
                }
            );
            Ok(())
        }

        /// Omitting the separator after the factory trait is a parse error.
        #[test]
        #[should_panic(expected = "expected `,`")]
        fn missing_types() {
            parse_str::<AbstractFactoryAttribute>("Factory").unwrap();
        }

        /// The factory bounds are appended after the trait's existing supertraits.
        #[test]
        fn expand() -> Result {
            let mut t = parse_str::<ItemTrait>("pub trait Abstraction<T>: Display + Extend<T> {}")?;
            let mut input_types = Punctuated::new();
            input_types.push(parse_str("u32")?);
            input_types.push(parse_str("i64")?);
            let actual = &AbstractFactoryAttribute {
                factory_trait: parse_str("Factory")?,
                sep: Default::default(),
                types: input_types,
            }
            .expand(&mut t);
            assert_eq!(
                reformat(&actual),
                "pub trait Abstraction<T>: Display + Extend<T> + Factory<u32> + Factory<i64> {}\n"
            );
            Ok(())
        }

        /// A trait with no supertraits must gain the `:` together with the new bounds.
        #[test]
        fn expand_without_supertraits() -> Result {
            let mut t = parse_str::<ItemTrait>("pub trait Abstraction {}")?;
            let mut input_types = Punctuated::new();
            input_types.push(parse_str("u32")?);
            input_types.push(parse_str("i64")?);
            let actual = &AbstractFactoryAttribute {
                factory_trait: parse_str("Factory")?,
                sep: Default::default(),
                types: input_types,
            }
            .expand(&mut t);
            assert_eq!(
                reformat(&actual),
                "pub trait Abstraction: Factory<u32> + Factory<i64> {}\n"
            );
            Ok(())
        }
    }
}