use super::*;
/// Returns the derive target's name with its type parameters spelled out,
/// together with the list of those type-parameter names.
///
/// For `enum Foo<T: Debug, H>` this yields `("Foo<T,H>", vec!["T", "H"])`; for
/// a non-generic `enum Foo` it yields `("Foo", vec![])`. Only the parameter
/// identifiers are kept — bounds are dropped.
///
/// # Panics
///
/// * if any lifetime parameter is not `'static`;
/// * if the type declares const generic parameters.
pub fn fetch_name_with_generic_params(ast: &DeriveInput) -> (String, Vec<String>) {
    let params: Vec<String> = ast
        .generics
        .params
        .iter()
        .filter_map(|param| match param {
            GenericParam::Lifetime(inner) => {
                // Lifetimes never appear in the generated name; only `'static`
                // is tolerated, anything else is rejected outright.
                if inner.lifetime.ident != "static" {
                    panic!("VariantAccess can only be derived for types with static lifetimes");
                }
                None
            }
            GenericParam::Type(inner) => Some(inner.ident.to_string()),
            GenericParam::Const(_) => {
                panic!("VariantAccess does not currently support const generics")
            }
        })
        .collect();
    let name = if params.is_empty() {
        ast.ident.to_string()
    } else {
        format!("{}<{}>", ast.ident, params.join(","))
    };
    (name, params)
}
/// Renders a variant's field type as a compact string key.
///
/// Arrays, tuples and path types are supported (recursively). Explicit
/// parentheses (`(T)`) and invisible grouping are unwrapped, because they do not
/// change the type they enclose. syn produces invisible grouping
/// (`Type::Group`) when the enum was generated by a `macro_rules!` macro from a
/// `$t:ty` fragment.
///
/// # Panics
///
/// For any other kind of type (references, slices, trait objects, ...).
fn parse_type(ty: &syn::Type) -> String {
    match ty {
        syn::Type::Array(array) => parse_array(array),
        syn::Type::Tuple(tuple) => parse_tuple(tuple),
        syn::Type::Path(path) => parse_path(path),
        syn::Type::Group(group) => parse_type(&group.elem),
        syn::Type::Paren(paren) => parse_type(&paren.elem),
        // Display (not Debug) so the user sees the type as written, not a
        // token-tree dump.
        other => panic!(
            "VariantAccess cannot be derived for enums with a field of type: {}",
            other.to_token_stream()
        ),
    }
}
/// Renders a path type (e.g. `u8`, `std::vec::Vec<T>`, `Box<dyn Error>`) as a
/// compact string.
///
/// The whole `TypePath` is rendered, including a leading `::` and any
/// qualified-self prefix such as `<T as Trait>::Assoc`. Whitespace between
/// tokens is then removed, except where it separates two word-like tokens
/// (`dyn Error`, `'static str`). Dropping it there would fuse the tokens into a
/// different, invalid type.
fn parse_path(path: &syn::TypePath) -> String {
    /// Drops inter-token whitespace unless it sits between a word character and
    /// a word character or lifetime quote.
    fn compact_tokens(rendered: &str) -> String {
        let is_word = |c: char| c.is_alphanumeric() || c == '_';
        let mut compact = String::with_capacity(rendered.len());
        let mut pending_space = false;
        for c in rendered.chars() {
            if c.is_whitespace() {
                pending_space = true;
                continue;
            }
            if pending_space && compact.ends_with(is_word) && (is_word(c) || c == '\'') {
                compact.push(' ');
            }
            pending_space = false;
            compact.push(c);
        }
        compact
    }

    compact_tokens(&path.to_token_stream().to_string())
}
/// Renders an array type such as `[u8; 32]` as `[u8;32]`, recursing into the
/// element type via `parse_type`.
///
/// The length is written from the integer literal's base-10 digits, so a
/// suffixed literal like `32usize` becomes `32`.
///
/// # Panics
///
/// If the element type is unsupported by `parse_type`, or if the length is not
/// a plain integer literal (e.g. a named constant).
fn parse_array(array: &syn::TypeArray) -> String {
    // Resolve the element first so an unsupported element type is reported
    // before an unsupported length.
    let elem = parse_type(&array.elem);
    let len = match &array.len {
        syn::Expr::Lit(syn::ExprLit {
            lit: syn::Lit::Int(int),
            ..
        }) => int.base10_digits(),
        _ => panic!(
            "VariantAccess can't be derived on array \
            types whose length is not expressed in terms of an \
            integer literal"
        ),
    };
    // No whitespace stripping here: `elem` is already compact, and stripping it
    // again could only fuse tokens that were deliberately kept apart.
    format!("[{};{}]", elem, len)
}
fn parse_tuple(tuple: &syn::TypeTuple) -> String {
let mut fullname = String::from("(");
let _ = tuple
.elems
.pairs()
.map(|segment| fullname.push_str(&format!("{},", parse_type(segment.value()))))
.collect::<()>();
let mut fullname = String::from(&fullname[..fullname.len() - 1]);
fullname.push_str(")");
fullname.retain(|c| c != ' ');
fullname
}
/// Maps the compact type string of each variant's single tuple field to that
/// variant's identifier.
///
/// Tuple variants with no fields (`A()`) contribute nothing.
///
/// # Panics
///
/// * if `ast` is not an enum;
/// * if a variant has named fields, is a unit variant, or has more than one
///   tuple field;
/// * if two variants wrap the same type;
/// * if a field type is not supported by `parse_type`.
pub fn fetch_types_from_enum(ast: &DeriveInput) -> HashMap<String, &Ident> {
    let data = match &ast.data {
        Data::Enum(data) => data,
        _ => panic!("Can only derive VariantAccess for enums."),
    };
    let mut types: HashMap<String, &Ident> = HashMap::new();
    for var in data.variants.iter() {
        let fields = match &var.fields {
            syn::Fields::Unnamed(fields) => fields,
            syn::Fields::Named(_) => {
                panic!("Cannot derive VariantAccess for enums whose types have named fields.")
            }
            // Previously reported as "named fields", which misled users.
            syn::Fields::Unit => {
                panic!("Cannot derive VariantAccess for enums with unit variants.")
            }
        };
        if fields.unnamed.len() > 1 {
            panic!("Can only derive for enums whose types do not contain multiple fields.");
        }
        for field in fields.unnamed.iter() {
            // The type string is the key, so a repeated type cannot be mapped
            // back to a single variant.
            if types.insert(parse_type(&field.ty), &var.ident).is_some() {
                panic!("Cannot derive VariantAccess for enum with multiple fields of same type");
            }
        }
    }
    types
}
/// Generates `mod variant_access_<name>` containing one empty `pub(crate)`
/// marker struct per variant identifier in `types`.
///
/// The identifiers are sorted first, so the emitted tokens do not depend on
/// `HashMap` iteration order, which is randomized per process. This keeps the
/// macro's output deterministic across compilations.
///
/// # Panics
///
/// If the generated source fails to tokenize, e.g. because `name` is not
/// valid in an identifier.
pub fn create_marker_structs(name: &str, types: &HashMap<String, &Ident>) -> TokenStream {
    let mut markers: Vec<String> = types.values().map(|ident| ident.to_string()).collect();
    markers.sort();
    let mut piece = format!("#[allow(non_snake_case)]\n mod variant_access_{}", name);
    piece.push_str("{ ");
    for marker in &markers {
        piece.push_str(&format!("pub (crate) struct {};", marker));
    }
    piece.push_str("} ");
    piece
        .parse()
        .expect("generated marker module should always be valid Rust")
}
#[cfg(test)]
mod test_parsers {
    use super::*;

    /// Parses `src` as derive input and returns every type key that
    /// `fetch_types_from_enum` extracts from it.
    fn type_keys(src: &str) -> Vec<String> {
        let ast: DeriveInput = syn::parse_str(src).unwrap();
        fetch_types_from_enum(&ast).keys().cloned().collect()
    }

    #[test]
    fn test_parse_tuple() {
        let keys = type_keys("enum TupleTest { F1((i64, bool)), }");
        assert_eq!(keys, vec!["(i64,bool)"]);
    }

    #[test]
    fn test_parse_array() {
        let keys = type_keys("enum TupleTest { F1([u8; 32]), }");
        assert_eq!(keys, vec!["[u8;32]"]);
    }

    #[test]
    fn test_parse_tuple_in_array() {
        let keys = type_keys("enum TupleTest { F1([(i32, bool); 32]), }");
        assert_eq!(keys, vec!["[(i32,bool);32]"]);
    }

    #[test]
    fn test_parse_array_in_tuple() {
        let keys = type_keys("enum TupleTest { F1((i32, [u8; 32])), }");
        assert_eq!(keys, vec!["(i32,[u8;32])"]);
    }

    #[test]
    fn test_parse_generics_arrays_and_tuples() {
        let keys = type_keys("enum TupleTest<T: Debug> { F1([(T, [u8; 32]); 12]), }");
        assert_eq!(keys, vec!["[(T,[u8;32]);12]"]);
    }

    #[test]
    fn test_nested_generics() {
        let keys = type_keys("enum TupleTest<T: Debug, H> { F1((Box<T>, PhantomData<H>)), }");
        assert_eq!(keys, vec!["(Box<T>,PhantomData<H>)"]);
    }
}