use proc_macro::TokenStream;
use quote::quote;
use std::collections::{HashMap, HashSet, VecDeque};
use syn::{
custom_punctuation, parse::ParseStream, parse_macro_input, Data, DeriveInput, Fields,
Ident, Result, Token, Type, Variant,
};
// Custom `<->` token for `#[transitive(...)]` chains: declares conversions in both
// directions between two variants. (`->` needs no custom type; `Token![->]` is used.)
custom_punctuation!(LeftRightArrow, <->);
// Custom `<-` token: the variant on the right converts into the variant on the left.
// It is a prefix of `<->`, so parsers must peek for `LeftRightArrow` before this one.
custom_punctuation!(LeftArrow, <-);
/// Direction of one link in a `#[transitive(...)]` chain.
///
/// Each parsed link is a `(left, arrow, right)` triple; the variant documentation below
/// describes which payload converts into which for that triple.
enum Arrow {
    /// `left -> right`: the payload of `left` converts into the payload of `right`.
    Right,
    /// `left <- right`: the payload of `right` converts into the payload of `left`.
    Left,
    /// `left <-> right`: conversions are declared in both directions.
    LeftRight,
}
/// Parses a chain such as `A -> B <- C <-> D` into consecutive links:
/// `[(A, Right, B), (B, Left, C), (C, LeftRight, D)]`.
///
/// Every identifier except the first and last appears in two links: as the right side
/// of one and the left side of the next. A lone identifier yields no links.
///
/// # Errors
/// Fails if the input does not start with an identifier, if an arrow is anything other
/// than `->`, `<-` or `<->`, or if an arrow is not followed by an identifier.
fn parse_arrows(input: ParseStream) -> Result<Vec<(Ident, Arrow, Ident)>> {
    let mut links = Vec::new();
    let mut previous: Ident = input.parse()?;
    while !input.is_empty() {
        // `<-` is a prefix of `<->`, so the longer token has to be tried first; anything
        // that is neither falls through to `->`, whose parse reports the error.
        let arrow = if input.peek(LeftRightArrow) {
            input.parse::<LeftRightArrow>()?;
            Arrow::LeftRight
        } else if input.peek(LeftArrow) {
            input.parse::<LeftArrow>()?;
            Arrow::Left
        } else {
            input.parse::<Token![->]>()?;
            Arrow::Right
        };
        let right: Ident = input.parse()?;
        // The right side of this link becomes the left side of the next one.
        let left = std::mem::replace(&mut previous, right.clone());
        links.push((left, arrow, right));
    }
    Ok(links)
}
pub fn derive_try_into_any_inner(input: TokenStream) -> TokenStream {
let DeriveInput { ident: enum_name, data, attrs, .. } = parse_macro_input!(input);
let mut try_into_impls: HashMap<Ident, HashSet<Ident>> = HashMap::new();
let mut nodes: HashSet<Ident> = HashSet::new();
for attr in attrs {
if attr.path().is_ident("transitive") {
let list = attr.meta.require_list().unwrap();
match list.parse_args_with(parse_arrows) {
Ok(pairs) => {
for (l, a, r) in pairs {
nodes.insert(l.clone());
nodes.insert(r.clone());
match a {
Arrow::Left => {
try_into_impls.entry(r).or_default().insert(l);
}
Arrow::Right => {
try_into_impls.entry(l).or_default().insert(r);
}
Arrow::LeftRight => {
let (l2, r2) = (l.clone(), r.clone());
try_into_impls.entry(l).or_default().insert(r);
try_into_impls.entry(r2).or_default().insert(l2);
}
}
}
}
Err(e) => {
panic!(
"expected pairs of variant idents separated by any of: <-, ->, <->; instead: {e}"
);
}
}
}
}
let mut try_into_paths: HashMap<(Ident, Ident), Vec<Ident>> = HashMap::new();
for target_node in nodes.drain() {
let mut visited: HashSet<Ident> = HashSet::new();
let mut next: VecDeque<(Ident, Vec<Ident>)> = VecDeque::new();
next.push_back((target_node.clone(), vec![]));
while !next.is_empty() {
let (node, path) = next.pop_front().unwrap();
if let Some(neighbors) = try_into_impls.get(&node) {
for from_node in neighbors {
if visited.contains(from_node) {
continue;
}
visited.insert(from_node.clone());
let mut path = path.clone();
path.push(from_node.clone());
next.push_back((from_node.clone(), path.clone()));
if from_node != &target_node {
try_into_paths
.insert((from_node.clone(), target_node.clone()), path);
}
}
}
}
}
let data_enum = match data {
Data::Enum(data_enum) => data_enum,
_ => panic!("TryIntoAnyVariant can only be derived for enums"),
};
let unwrap_unit = |variant: &Ident, fields: &Fields| match fields {
Fields::Unnamed(fields_unnamed) if fields_unnamed.unnamed.len() == 1 => {
fields_unnamed.unnamed.first().unwrap().clone()
}
_ => panic!("variant `{}` is not a single tuple variant", variant),
};
let mut inner_types: HashMap<Ident, Type> = HashMap::new();
for Variant { ident, fields, .. } in data_enum.variants.iter() {
let field = unwrap_unit(ident, fields);
inner_types.insert(ident.clone(), field.ty.clone());
}
let mut impls: Vec<_> = vec![];
for Variant { ident: variant, .. } in data_enum.variants.iter() {
let ty = inner_types.get(variant).unwrap();
let mut conversion_arms: Vec<_> = vec![];
for Variant { ident: inner_variant, .. } in data_enum.variants.iter() {
if variant == inner_variant {
conversion_arms.push(quote! {
#enum_name::#inner_variant(inner) => {
Ok(MaybeSplit::Just(inner))
}
});
} else if let Some(path) =
try_into_paths.get(&(variant.clone(), inner_variant.clone()))
{
let mut steps: Vec<_> = vec![];
let mut first_step = true;
for ident in path.iter() {
let ty = inner_types.get(ident).unwrap();
if first_step {
first_step = false;
steps.push(quote! {
let step: #ty = (step).try_into().map_err(|_| ())?;
});
} else {
steps.push(quote! {
let step: #ty = (&step).try_into().map_err(|_| ())?;
});
}
}
conversion_arms.push(quote! {
#enum_name::#inner_variant(inner) => {
let step = &inner;
#(#steps)*
Ok(MaybeSplit::Split(
#enum_name::#inner_variant(inner),
step,
))
}
});
} else {
conversion_arms.push(quote! {
#enum_name::#inner_variant(_) => Err(()),
});
}
}
impls.push(quote! {
impl TryInto<MaybeSplit<#enum_name, #ty>> for #enum_name {
type Error = ();
fn try_into(self) -> Result<MaybeSplit<#enum_name, #ty>, ()> {
match self {
#(#conversion_arms)*
}
}
}
});
}
let all = quote! {
#(#impls)*
};
TokenStream::from(all)
}