bird-protocol-derive 0.1.3

Packet, PacketReadable, PacketWritable derives for bird-protocol crate
Documentation
use std::collections::HashMap;
use proc_macro2::{Ident, Span, TokenStream};
use proc_macro_crate::{crate_name, FoundCrate};
use quote::{quote, ToTokens};
use syn::{Attribute, Data, DataEnum, DataStruct, DeriveInput, Expr, ExprAssign, Field, Fields, GenericParam, Generics, LifetimeDef, Lit, parse_quote, Path, PathArguments, PathSegment, Token};
use syn::parse::{Parse, ParseStream};
use syn::parse_quote::ParseQuote;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::token::Colon2;

pub const FIELD_ATTRIBUTES: &[&str] = &["variant", "var", "order"];
pub const DATA_ATTRIBUTES: &[&str] = &["lifetime", "enum_type", "enum_variant"];
pub const VARIANT_ATTRIBUTES: &[&str] = &["value"];

#[derive(Debug, Clone)]
pub struct FieldAttributes {
    pub order: Option<usize>,
    pub variant: Option<TokenStream>,
}

#[derive(Debug, Clone)]
pub struct DataAttributes {
    pub enum_type: Option<TokenStream>,
    pub enum_variant: Option<TokenStream>,
    pub lead_lifetime: Option<TokenStream>,
}

#[derive(Debug, Clone)]
pub struct VariantAttributes {
    pub value: Option<TokenStream>,
}

#[derive(Debug, Clone)]
pub struct PacketAttributes {
    pub bound: TokenStream,
    pub state: TokenStream,
    pub id: TokenStream,
}

pub trait FieldVisitor {
    fn visit(&mut self, ident: Ident, field: &Field, attributes: FieldAttributes) -> syn::Result<()>;
}

pub trait VariantVisitor {
    fn visit(&mut self, ident: Path, fields: &Fields, value: Option<TokenStream>, attributes: VariantAttributes) -> syn::Result<()>;
}

pub fn visit_fields(fields: &Fields, visitor: &mut impl FieldVisitor) -> syn::Result<()> {
    match fields {
        Fields::Named(named) => for field in &named.named {
            visitor.visit(
                field.ident.clone().unwrap(),
                field,
                get_attributes(FIELD_ATTRIBUTES, &field.attrs)?.try_into()?,
            )?;
        },
        Fields::Unnamed(unnamed) => {
            let mut counter = 0usize;
            for field in &unnamed.unnamed {
                visitor.visit(
                    Ident::new(format!("__{}", counter).as_str(), field.span()),
                    field,
                    get_attributes(FIELD_ATTRIBUTES, &field.attrs)?.try_into()?,
                )?;
                if counter != usize::MAX {
                    counter += 1;
                }
            }
        }
        Fields::Unit => {}
    }
    Ok(())
}

pub fn visit_derive_input(derive_input: &DeriveInput, visitor: &mut impl VariantVisitor) -> syn::Result<()> {
    match derive_input.data {
        Data::Struct(ref data_struct) =>
            visit_struct(&derive_input.ident, &derive_input.attrs, data_struct, visitor),
        Data::Enum(ref data_enum) =>
            visit_enum(&derive_input.ident, &derive_input.attrs, data_enum, visitor),
        Data::Union(ref data_union) => Err(syn::Error::new(
            data_union.union_token.span, "Union types is not supported, yet. Use enum instead",
        ))
    }
}

pub fn visit_enum(
    enum_ident: &Ident,
    _attributes: &Vec<Attribute>,
    data_enum: &DataEnum,
    visitor: &mut impl VariantVisitor,
) -> syn::Result<()> {
    let mut start = quote! { 0 };
    let mut counter = -1isize;
    for variant in &data_enum.variants {
        let variant_attributes: VariantAttributes =
            get_attributes(VARIANT_ATTRIBUTES, &variant.attrs)?.try_into()?;
        match variant_attributes.value.as_ref().or(
            variant.discriminant.as_ref().map(|(_, expr)| expr.to_token_stream()).as_ref()
        ) {
            Some(value) => {
                start = value.clone();
                counter = 0;
            }
            None => {
                counter += 1;
            }
        }
        visitor.visit(
            Path {
                leading_colon: None,
                segments: {
                    let mut res = Punctuated::new();
                    res.push_value(PathSegment {
                        ident: enum_ident.clone(),
                        arguments: PathArguments::None,
                    });
                    res.push_punct(Colon2::default());
                    res.push_value(PathSegment {
                        ident: variant.ident.clone(),
                        arguments: PathArguments::None,
                    });
                    res
                },
            },
            &variant.fields,
            Some(quote! { #start + #counter }),
            variant_attributes,
        )?
    }
    Ok(())
}

pub fn visit_struct(
    struct_ident: &Ident,
    attributes: &Vec<Attribute>,
    data_struct: &DataStruct,
    visitor: &mut impl VariantVisitor,
) -> syn::Result<()> {
    visitor.visit(
        Path {
            leading_colon: None,
            segments: {
                let mut res = Punctuated::new();
                res.push_value(PathSegment {
                    ident: struct_ident.clone(),
                    arguments: PathArguments::None,
                });
                res
            },
        },
        &data_struct.fields,
        None,
        get_attributes(DATA_ATTRIBUTES, attributes)?.try_into()?,
    )
}

pub fn get_attributes<'a>(names: &[&'a str], attributes: &Vec<Attribute>) -> syn::Result<HashMap<&'a str, Expr>> {
    let mut res = HashMap::new();
    for attribute in attributes {
        for name in names {
            if attribute.path.is_ident(name) {
                res.insert(*name, attribute.parse_args()?);
                break;
            }
        }
    }
    Ok(res)
}

pub fn expr_to_usize(expr: &Expr) -> syn::Result<usize> {
    if let Expr::Lit(ref lit) = expr {
        if let Lit::Int(ref int) = lit.lit {
            return int.base10_parse();
        }
    }
    Err(syn::Error::new(expr.span(), "Must be positive integer"))
}

impl TryFrom<HashMap<&str, Expr>> for FieldAttributes {
    type Error = syn::Error;

    fn try_from(value: HashMap<&str, Expr>) -> Result<Self, Self::Error> {
        Ok(FieldAttributes {
            order: match value.get("order") {
                Some(expr) => Some(expr_to_usize(expr)?),
                None => None,
            },
            variant: value.get("variant")
                .or(value.get("var"))
                .map(|expr| expr.to_token_stream()),
        })
    }
}

impl TryFrom<HashMap<&str, Expr>> for DataAttributes {
    type Error = syn::Error;

    fn try_from(value: HashMap<&str, Expr>) -> Result<Self, Self::Error> {
        Ok(DataAttributes {
            enum_type: value.get("enum_type")
                .map(|expr| expr.to_token_stream()),
            enum_variant: value.get("enum_variant")
                .map(|expr| expr.to_token_stream()),
            lead_lifetime: value.get("lifetime")
                .map(|expr| expr.to_token_stream()),
        })
    }
}

impl TryFrom<HashMap<&str, Expr>> for VariantAttributes {
    type Error = syn::Error;

    fn try_from(attrs: HashMap<&str, Expr>) -> Result<Self, Self::Error> {
        Ok(VariantAttributes {
            value: attrs.get("value")
                .map(|expr| expr.to_token_stream())
        })
    }
}

pub fn get_bird_protocol_crate() -> TokenStream {
    let found_crate = crate_name("bird-protocol").unwrap();
    match found_crate {
        FoundCrate::Itself => quote! {crate},
        FoundCrate::Name(name) => {
            let ident = Ident::new(name.as_str(), Span::call_site());
            quote! {#ident}
        }
    }
}

pub fn add_trait_lifetime(generics: &mut Generics, lifetime: TokenStream) {
    generics.params.push(parse_quote! { #lifetime })
}

pub fn get_lifetimes(generics: &Generics) -> Vec<LifetimeDef> {
    let mut result = Vec::new();
    for param in &generics.params {
        if let GenericParam::Lifetime(ref lifetime) = param {
            result.push(lifetime.clone())
        }
    }
    result
}

fn collect_map_attribute(input: ParseStream) -> syn::Result<HashMap<String, Expr>> {
    let punctuated: Punctuated<ExprAssign, Token![,]> = Punctuated::parse(input)?;
    let mut result = HashMap::new();
    for value in punctuated {
        let name = match value.left.as_ref() {
            Expr::Path(ref expr_path) => expr_path.path
                .segments
                .iter()
                .map(|segment| segment.ident.to_string())
                .collect::<Vec<String>>()
                .join("::"),
            Expr::Lit(ref lit) => match lit.lit {
                Lit::Str(ref lit) => lit.value(),
                _ => return Err(syn::Error::new(
                    value.left.span(), "Only string lit supported or path")
                )
            },
            _ => return Err(syn::Error::new(
                value.left.span(), "Only string lit supported or path")
            )
        };
        result.insert(name, value.right.as_ref().clone());
    }
    Ok(result)
}

impl Parse for PacketAttributes {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let attrs = collect_map_attribute(input)?;
        Ok(Self {
            // TODO handling
            id: attrs.get("id").unwrap().to_token_stream(),
            bound: attrs.get("bound").unwrap().to_token_stream(),
            state: attrs.get("state").unwrap().to_token_stream()
        })
    }
}