//! lightning-wire-msgs-derive-base 0.2.6
//!
//! Derive macros for defining serialization and deserialization of Lightning wire messages.
use proc_macro2::Span;
use proc_macro2::TokenStream;
use quote::quote;

/// Selects which trait impl `impl_trait_struct` emits for the derived struct.
enum Subset {
    /// Emit `lightning_wire_msgs::WireMessageReader` (decode only).
    Reader,
    /// Emit `lightning_wire_msgs::WireMessageWriter` (encode only).
    Writer,
    /// Emit `lightning_wire_msgs::WireMessage` (encode and decode).
    Both,
}

pub fn impl_trait(ast: &syn::DeriveInput) -> TokenStream {
    use syn::Data::*;
    match &ast.data {
        Struct(ref s) => impl_trait_struct(&ast.ident, &ast.attrs, s, &ast.generics, Subset::Both),
        _ => unimplemented!(),
    }
}

pub fn impl_writer(ast: &syn::DeriveInput) -> TokenStream {
    use syn::Data::*;
    match &ast.data {
        Struct(ref s) => {
            impl_trait_struct(&ast.ident, &ast.attrs, s, &ast.generics, Subset::Writer)
        }
        _ => unimplemented!(),
    }
}

pub fn impl_reader(ast: &syn::DeriveInput) -> TokenStream {
    use syn::Data::*;
    match &ast.data {
        Struct(ref s) => {
            impl_trait_struct(&ast.ident, &ast.attrs, s, &ast.generics, Subset::Reader)
        }
        _ => unimplemented!(),
    }
}

/// Builds the token stream for the generated `decode` method.
///
/// When `check_type` is true, the generated method first reads a big-endian
/// `u16`. It returns `InvalidData` unless that value equals `Self::MSG_TYPE`.
/// It then wraps the reader in `lightning_wire_msgs::PeekReader` and reads the
/// fields in order:
/// - fields with a `tlv_type` literal use `TLVWireItemReader::decode_tlv`;
/// - all other fields use `WireItemReader::decode`.
///
/// `field` and `tlv_type` are parallel slices with one entry per struct field.
fn def_decode(
    name: &syn::Ident,
    field: &[syn::Member],
    tlv_type: &[Option<syn::Lit>],
) -> proc_macro2::TokenStream {
    // One read expression per field, chosen by whether the field is a TLV record.
    let read_exprs: Vec<proc_macro2::TokenStream> = tlv_type
        .iter()
        .map(|maybe_tlv| match maybe_tlv {
            Some(tlv_lit) => quote! {
                lightning_wire_msgs::TLVWireItemReader::decode_tlv(&mut peek_reader, #tlv_lit)?
            },
            None => quote! {
                lightning_wire_msgs::WireItemReader::decode(&mut peek_reader)?
            },
        })
        .collect();
    let members = field;

    quote! {
        fn decode<R: std::io::Read>(reader: &mut R, check_type: bool) -> std::io::Result<Self> {
            if check_type {
                let mut msg_type = [0_u8; 2];
                reader.read_exact(&mut msg_type)?;
                let msg_type = u16::from_be_bytes(msg_type);
                if msg_type != Self::MSG_TYPE {
                    return Err(std::io::Error::from(std::io::ErrorKind::InvalidData));
                }
            }
            let mut peek_reader = lightning_wire_msgs::PeekReader::from(reader);

            Ok(#name {
                #(
                    #members: #read_exprs,
                )*
            })
        }
    }
}

/// Builds the token stream for the generated `encode` method.
///
/// The generated method does the following, in order:
/// - writes `Self::MSG_TYPE` as a big-endian `u16`;
/// - writes each field in order. Fields with a `tlv_type` literal go through
///   `TLVWireItemWriter::encode_tlv`, and only when `Some`. All other fields go
///   through `WireItemWriter::encode`;
/// - flushes the writer.
///
/// It returns the 2 header bytes plus the counts reported by each item writer.
///
/// `field` and `tlv_type` are parallel slices with one entry per struct field.
fn def_encode(field: &[syn::Member], tlv_type: &[Option<syn::Lit>]) -> proc_macro2::TokenStream {
    let wire_item_write_expr = field.iter().zip(tlv_type).map(|(member, tlv)| match tlv {
        Some(t) => quote! {
            if let Some(ref field) = &self.#member {
                count += lightning_wire_msgs::TLVWireItemWriter::encode_tlv(field, w, #t)?;
            }
        },
        None => quote! {
            count += lightning_wire_msgs::WireItemWriter::encode(&self.#member, w)?;
        },
    });
    quote! {
        fn encode<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<usize> {
            // `Write::write` may accept only part of the buffer; `write_all`
            // guarantees the full 2-byte type header is written or errors out.
            let msg_type = u16::to_be_bytes(Self::MSG_TYPE);
            let mut count: usize = 0;
            w.write_all(&msg_type)?;
            count += msg_type.len();
            #(
                #wire_item_write_expr
            )*
            w.flush()?;
            Ok(count)
        }
    }
}

fn impl_trait_struct(
    name: &syn::Ident,
    attrs: &Vec<syn::Attribute>,
    struct_data: &syn::DataStruct,
    generics: &syn::Generics,
    subset: Subset,
) -> TokenStream {
    let num = attrs
        .iter()
        .filter_map(|a| match a.parse_meta() {
            Ok(m) => match m {
                syn::Meta::NameValue(nv) => {
                    if nv.path.is_ident("msg_type") {
                        Some(nv.lit)
                    } else {
                        None
                    }
                }
                _ => None,
            },
            Err(_) => None,
        })
        .next()
        .expect("missing attribute \"msg_type\"\n\nhelp: add #[msg_type = ...]");
    let mut tlv = None;
    let field_mapper = |(i, f): (usize, &syn::Field)| -> (syn::Member, Option<syn::Lit>) {
        let mut new_tlv = None;
        let res = (
            f.ident
                .as_ref()
                .map(|id| syn::Member::Named(id.clone()))
                .unwrap_or_else(|| {
                    syn::Member::Unnamed(syn::Index {
                        index: i as u32,
                        span: Span::call_site(),
                    })
                }),
            f.attrs
                .iter()
                .filter_map(|a| match a.parse_meta() {
                    Ok(m) => match m {
                        syn::Meta::NameValue(nv) => {
                            if nv.path.is_ident("tlv_type") {
                                if let syn::Lit::Int(ref lit) = nv.lit {
                                    new_tlv = Some(
                                        lit.base10_parse::<u64>().expect("tlv_type must be a u64"),
                                    );
                                    Some(nv.lit)
                                } else {
                                    panic!("tlv_type must be a u64")
                                }
                            } else {
                                None
                            }
                        }
                        _ => None,
                    },
                    Err(_) => None,
                })
                .next(),
        );
        match (tlv, new_tlv) {
            (Some(_), None) => panic!("tlv stream must occur after expected fields"),
            (Some(old), Some(new)) if old > new => {
                panic!("tlv stream must be monotonically increasing by type")
            }
            (_, Some(_)) => match &f.ty {
                syn::Type::Path(ref p)
                    if p.path.segments.last().expect("missing type").ident == "Option" =>
                {
                    match &p.path.segments.last().unwrap().arguments {
                        syn::PathArguments::AngleBracketed(a) => {
                            (match a.args.first().expect("tlv value must be Option") {
                                syn::GenericArgument::Type(t) => t.clone(),
                                _ => panic!("tlv value must be Option"),
                            })
                        }
                        _ => panic!("tlv value must be Option"),
                    };
                }
                _ => panic!("tlv value must be Option"),
            },
            _ => (),
        };
        tlv = new_tlv;
        res
    };
    let punc = syn::punctuated::Punctuated::<syn::Field, ()>::new();
    let (field, tlv_type): (Vec<syn::Member>, Vec<Option<syn::Lit>>) = match &struct_data.fields {
        syn::Fields::Named(n) => n.named.iter(),
        syn::Fields::Unnamed(n) => n.unnamed.iter(),
        syn::Fields::Unit => punc.iter(),
    }
    .enumerate()
    .map(field_mapper)
    .unzip();

    let type_params: Vec<syn::GenericParam> = generics
        .params
        .iter()
        .map(|gparam| match gparam {
            syn::GenericParam::Type(tp) => {
                let mut tp = tp.clone();
                tp.bounds = syn::punctuated::Punctuated::new();
                syn::GenericParam::Type(tp)
            }
            syn::GenericParam::Lifetime(ltp) => {
                let mut ltp = ltp.clone();
                ltp.bounds = syn::punctuated::Punctuated::new();
                syn::GenericParam::Lifetime(ltp)
            }
            a => a.clone(),
        })
        .collect();
    let generics_stripped = {
        let mut generics = generics.clone();
        generics.params = syn::punctuated::Punctuated::new();
        for param in type_params.iter() {
            generics.params.push(param.clone());
        }
        generics.where_clause = None;
        generics
    };
    let generics_params = &generics.params;
    let generics_where_clause = &generics.where_clause;
    let decode = def_decode(name, &field, &tlv_type);
    let encode = def_encode(&field, &tlv_type);
    let gen = match subset {
        Subset::Both => quote! {
            impl<#generics_params> lightning_wire_msgs::WireMessage for #name#generics_stripped #generics_where_clause {
                const MSG_TYPE: u16 = #num;
                #encode
                #decode
            }
        },
        Subset::Writer => quote! {
            impl<#generics_params> lightning_wire_msgs::WireMessageWriter for #name#generics_stripped #generics_where_clause {
                const MSG_TYPE: u16 = #num;
                #encode
            }
        },
        Subset::Reader => quote! {
            impl<#generics_params> lightning_wire_msgs::WireMessageReader for #name#generics_stripped #generics_where_clause {
                const MSG_TYPE: u16 = #num;
                #decode
            }
        },
    };
    gen.into()
}