architect-derive 0.4.0

Architect.xyz Trading Platform API, proc-macros
Documentation
use proc_macro::TokenStream;
use quote::quote;
use std::collections::{HashMap, HashSet, VecDeque};
use syn::{
    custom_punctuation, parse::ParseStream, parse_macro_input, Data, DeriveInput, Fields,
    Ident, Result, Token, Type, Variant,
};

// Custom `<->` token accepted inside `#[transitive(...)]` (see `parse_arrows`)
// to declare an edge in both directions.
custom_punctuation!(LeftRightArrow, <->);
// Custom `<-` token accepted inside `#[transitive(...)]`. The `->` direction
// needs no custom type because syn already provides `Token![->]`.
custom_punctuation!(LeftArrow, <-);

/// Direction of one edge in a `#[transitive(...)]` chain, relating the
/// identifier on its left (`a`) to the identifier on its right (`b`).
enum Arrow {
    /// `a -> b`: `a`'s inner type is converted into `b`'s.
    Right,
    /// `a <- b`: `b`'s inner type is converted into `a`'s.
    Left,
    /// `a <-> b`: conversion is declared in both directions.
    LeftRight,
}

/// Parses a chain such as `A -> B <- C <-> D` into consecutive edge triples:
/// `[(A, Right, B), (B, Left, C), (C, LeftRight, D)]`.
///
/// Each identifier after the first serves as the right end of one edge and
/// the left end of the next. A lone identifier yields an empty list. `<->` is
/// checked before `<-` so that the shorter `<-` is not matched as a prefix of
/// `<->`; anything that is neither must be `->`, otherwise a parse error is
/// returned.
fn parse_arrows(input: ParseStream) -> Result<Vec<(Ident, Arrow, Ident)>> {
    let mut prev: Ident = input.parse()?;
    let mut edges = Vec::new();
    while !input.is_empty() {
        let arrow = if input.peek(LeftRightArrow) {
            input.parse::<LeftRightArrow>()?;
            Arrow::LeftRight
        } else if input.peek(LeftArrow) {
            input.parse::<LeftArrow>()?;
            Arrow::Left
        } else {
            input.parse::<Token![->]>()?;
            Arrow::Right
        };
        let next: Ident = input.parse()?;
        edges.push((prev, arrow, next.clone()));
        prev = next;
    }
    Ok(edges)
}

pub fn derive_try_into_any_inner(input: TokenStream) -> TokenStream {
    let DeriveInput { ident: enum_name, data, attrs, .. } = parse_macro_input!(input);
    let mut try_into_impls: HashMap<Ident, HashSet<Ident>> = HashMap::new();
    let mut nodes: HashSet<Ident> = HashSet::new();
    for attr in attrs {
        if attr.path().is_ident("transitive") {
            let list = attr.meta.require_list().unwrap();
            match list.parse_args_with(parse_arrows) {
                Ok(pairs) => {
                    for (l, a, r) in pairs {
                        nodes.insert(l.clone());
                        nodes.insert(r.clone());
                        match a {
                            Arrow::Left => {
                                try_into_impls.entry(r).or_default().insert(l);
                            }
                            Arrow::Right => {
                                try_into_impls.entry(l).or_default().insert(r);
                            }
                            Arrow::LeftRight => {
                                let (l2, r2) = (l.clone(), r.clone());
                                try_into_impls.entry(l).or_default().insert(r);
                                try_into_impls.entry(r2).or_default().insert(l2);
                            }
                        }
                    }
                }
                Err(e) => {
                    panic!(
                        "expected pairs of variant idents separated by any of: <-, ->, <->; instead: {e}"
                    );
                }
            }
        }
    }
    // enumerate the transitive closure of try_into conversions via BFS
    // (Z, A) => [<- B <- C <- ... <- Z] exists iff Z can try_into A via the closed path
    let mut try_into_paths: HashMap<(Ident, Ident), Vec<Ident>> = HashMap::new();
    for target_node in nodes.drain() {
        let mut visited: HashSet<Ident> = HashSet::new();
        let mut next: VecDeque<(Ident, Vec<Ident>)> = VecDeque::new();
        next.push_back((target_node.clone(), vec![]));
        while !next.is_empty() {
            let (node, path) = next.pop_front().unwrap();
            if let Some(neighbors) = try_into_impls.get(&node) {
                for from_node in neighbors {
                    if visited.contains(from_node) {
                        continue;
                    }
                    visited.insert(from_node.clone());
                    let mut path = path.clone();
                    path.push(from_node.clone());
                    next.push_back((from_node.clone(), path.clone()));
                    if from_node != &target_node {
                        try_into_paths
                            .insert((from_node.clone(), target_node.clone()), path);
                    }
                }
            }
        }
    }
    let data_enum = match data {
        Data::Enum(data_enum) => data_enum,
        _ => panic!("TryIntoAnyVariant can only be derived for enums"),
    };
    let unwrap_unit = |variant: &Ident, fields: &Fields| match fields {
        Fields::Unnamed(fields_unnamed) if fields_unnamed.unnamed.len() == 1 => {
            fields_unnamed.unnamed.first().unwrap().clone()
        }
        _ => panic!("variant `{}` is not a single tuple variant", variant),
    };
    let mut inner_types: HashMap<Ident, Type> = HashMap::new();
    for Variant { ident, fields, .. } in data_enum.variants.iter() {
        let field = unwrap_unit(ident, fields);
        inner_types.insert(ident.clone(), field.ty.clone());
    }
    let mut impls: Vec<_> = vec![];
    for Variant { ident: variant, .. } in data_enum.variants.iter() {
        let ty = inner_types.get(variant).unwrap();
        let mut conversion_arms: Vec<_> = vec![];
        for Variant { ident: inner_variant, .. } in data_enum.variants.iter() {
            if variant == inner_variant {
                conversion_arms.push(quote! {
                    #enum_name::#inner_variant(inner) => {
                        Ok(MaybeSplit::Just(inner))
                    }
                });
            } else if let Some(path) =
                try_into_paths.get(&(variant.clone(), inner_variant.clone()))
            {
                let mut steps: Vec<_> = vec![];
                let mut first_step = true;
                for ident in path.iter() {
                    let ty = inner_types.get(ident).unwrap();
                    if first_step {
                        first_step = false;
                        steps.push(quote! {
                            let step: #ty = (step).try_into().map_err(|_| ())?;
                        });
                    } else {
                        steps.push(quote! {
                            let step: #ty = (&step).try_into().map_err(|_| ())?;
                        });
                    }
                }
                conversion_arms.push(quote! {
                    #enum_name::#inner_variant(inner) => {
                        let step = &inner;
                        #(#steps)*
                        Ok(MaybeSplit::Split(
                            #enum_name::#inner_variant(inner),
                            step,
                        ))
                    }
                });
            } else {
                conversion_arms.push(quote! {
                    #enum_name::#inner_variant(_) => Err(()),
                });
            }
        }
        impls.push(quote! {
            impl TryInto<MaybeSplit<#enum_name, #ty>> for #enum_name {
                type Error = ();

                fn try_into(self) -> Result<MaybeSplit<#enum_name, #ty>, ()> {
                    match self {
                        #(#conversion_arms)*
                    }
                }
            }
        });
    }
    // concatenate all generated impls
    let all = quote! {
        #(#impls)*
    };
    TokenStream::from(all)
}