packetrs-impl 0.5.0

Macro-based struct serialization/deserialization
Documentation
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};

use crate::{
    model_types::{
        are_fields_named, GetParameterValue, PacketRsEnum, PacketRsEnumVariant, PacketRsField,
        PacketRsStruct,
    },
    syn_helpers::{get_ctx_type, get_ident_of_inner_type, is_collection},
};

/// Build the identifier used to refer to the `packetrs` crate in generated code.
///
/// The name is looked up via `proc_macro_crate::crate_name`: when the lookup reports that the
/// crate being compiled is `packetrs` itself, the literal name `packetrs` is used; otherwise the
/// name reported by the lookup is used as-is.
///
/// # Panics
///
/// Panics if the lookup of `packetrs` fails.
pub(crate) fn get_crate_name() -> syn::Ident {
    let resolved_name = match proc_macro_crate::crate_name("packetrs")
        .expect("packetrs is present in Cargo.toml")
    {
        proc_macro_crate::FoundCrate::Itself => String::from("packetrs"),
        proc_macro_crate::FoundCrate::Name(reported) => reported,
    };

    syn::Ident::new(&resolved_name, Span::call_site())
}

/// Generate the expression that reads one value of the field's 'inner' type (i.e. the type that
/// will actually be read from the buffer) by calling `<inner type>::read(buf, (<context>))`.
///
/// `read_context` holds the expressions passed as the context argument.  They are emitted
/// comma-separated inside a single pair of parentheses, so no expressions yield `()`, one
/// expression yields a parenthesized expression and several yield a tuple.
///
/// # Panics
///
/// Panics if `get_ident_of_inner_type` cannot produce an identifier for the field's type.
fn generate_read_call(field: &PacketRsField, read_context: &[syn::Expr]) -> TokenStream {
    let inner_type = get_ident_of_inner_type(&field.ty).expect(&format!(
        "Unable to get ident of inner type from: {:#?}",
        &field.ty
    ));
    quote! {
        #inner_type::read(buf, (#(#read_context),*))
    }
}

/// Generate the statements that read a single field into a local variable named after the field.
///
/// The generated code consists of:
/// 1. `let <field> = <read call>.context("<field>")?;` where the read call is, in order of
///    precedence: the field's custom reader, a `count`-driven collection of repeated reads, or a
///    single read of the field's inner type.
/// 2. An optional fixed-value check that `bail!`s when the value read differs from the expected
///    expression.
/// 3. An optional assertion check that `bail!`s when the user-supplied assertion function
///    returns false for the value read.
///
/// `context` and `bail!` are emitted unqualified, so they are resolved at the expansion site.
///
/// # Panics
///
/// Panics if the field has no name, if the field is a collection with neither a count nor a
/// custom reader, or if the fixed-value parameter can't be parsed as an expression.  Panics from
/// `get_crate_name` and `generate_read_call` propagate as well.
fn generate_field_read(field: &PacketRsField) -> TokenStream {
    let crate_name = get_crate_name();
    let field_name = &field.name;
    let field_ty = &field.ty;
    // The stringified field name serves both as the error context of the read and as the label
    // in the assertion failure messages.  The panic message is only built when the name is
    // actually missing.
    let field_name_str = field_name
        .as_ref()
        .unwrap_or_else(|| panic!("Unable to get name of field for error_context {:#?}", field))
        .to_string();

    // Gather the context expressions to forward to the reader, if there are any.
    // TODO: the values are cloned so that an owned (possibly empty) Vec is available in both
    // cases; is there a better way?
    let read_context = field
        .get_caller_context_param_value()
        .map_or_else(Vec::new, |c| c.clone());

    let custom_reader = field.get_custom_reader();
    let count_param = field.get_count_param_value();

    let read_call = if let Some(custom_reader_value) = &custom_reader {
        // A custom reader takes precedence; the default read call is not generated at all.
        quote! {
            #custom_reader_value(buf, (#(#read_context),*))
        }
    } else {
        let field_read_call = generate_read_call(field, &read_context);
        if let Some(count_param_value) = &count_param {
            // Read the inner type `count` times and collect into the field's declared type,
            // stopping at the first error.
            quote! {
                (0..#count_param_value)
                    .map(|_| #field_read_call)
                    .map(|r| r.map_err(|e| e.into()))
                    .collect::<::#crate_name::error::PacketRsResult<#field_ty>>()
            }
        } else {
            if is_collection(field_ty) {
                panic!(
                    "Field {:?} is a collection: either a count or custom_reader param is required",
                    field_name
                );
            }
            quote! {
                #field_read_call
            }
        }
    };

    // If there is a fixed value param, generate the assertion
    let fixed_value_assertion = if let Some(fixed_value) = field.get_fixed_value() {
        let fixed_value_src = fixed_value.value();
        // The fixed value is supplied by the user as a string; report what failed to parse
        // instead of panicking with a bare unwrap.
        let fixed_value = syn::parse_str::<syn::Expr>(&fixed_value_src).unwrap_or_else(|err| {
            panic!(
                "Unable to parse fixed value '{}' of field '{}' as an expression: {}",
                fixed_value_src, field_name_str, err
            )
        });
        quote! {
            if #field_name != #fixed_value {
                bail!("{} value didn't match: expected {}, got {}", #field_name_str, #fixed_value, #field_name);
            }
        }
    } else {
        TokenStream::new()
    };
    // If there is an assert expression, generate the assertion
    let assertion = if let Some(assertion) = field.get_assert() {
        // Keep the source text of the assertion so it can be shown in the failure message.
        let assertion_str = quote! { #assertion }.to_string();
        quote! {
            let assert_func = #assertion;
            if !assert_func(#field_name) {
                bail!("value of field '{}' ({}) didn't pass assertion: {}", #field_name_str, #field_name, #assertion_str);
            }
        }
    } else {
        TokenStream::new()
    };

    quote! {
        let #field_name = #read_call.context(#field_name_str)?;
        #fixed_value_assertion
        #assertion
    }
}

/// Return a proc_macro2::TokenStream that includes local assignments for the read value of each of
/// the given fields.
fn generate_field_reads(fields: &Vec<PacketRsField>) -> TokenStream {
    let field_reads = fields
        .iter()
        .map(|f| generate_field_read(&f))
        .collect::<Vec<TokenStream>>();

    quote! {
        #(#field_reads)*
    }
}

/// Build the `let` statements that unpack the `ctx` argument into the given context arguments,
/// using each FnArg as the binding and its position as the index into `ctx`, e.g.:
/// let foo = ctx.0;
/// let bar = ctx.1;
/// When there is exactly one context argument it is bound to `ctx` directly (`let foo = ctx;`),
/// since a lone context value isn't wrapped in a tuple.
/// NOTE: returning a Vec<syn::Local> (by running each generated statement through syn::parse)
/// was attempted, but parse isn't implemented for syn::Local, so a TokenStream is returned
/// instead.
fn generate_context_assignments(context: &Vec<syn::FnArg>) -> TokenStream {
    match context.as_slice() {
        // A single argument receives the whole context value.
        [only_arg] => quote! {
            let #only_arg = ctx;
        },
        // Zero or several arguments: one indexed assignment per argument.
        all_args => {
            let assignments = all_args.iter().enumerate().map(|(position, arg)| {
                let tuple_index = syn::Index::from(position);
                quote! {
                    let #arg = ctx.#tuple_index;
                }
            });

            quote! {
                #(#assignments)*
            }
        }
    }
}

/// Generate the body of the `read` function for a struct: the context assignments (if the struct
/// requires a context), the reads of every field, and finally the `Ok(Self ...)` expression that
/// builds the struct from the locals holding the values read.
///
/// # Panics
///
/// Panics if a field that is expected to be named has no name, or if generating the read of any
/// field panics.
fn generate_struct_read_body(rs_struct: &PacketRsStruct) -> proc_macro2::TokenStream {
    let context_assignments = match rs_struct.get_required_context_param_value() {
        Some(required_ctx) => generate_context_assignments(&required_ctx),
        None => proc_macro2::TokenStream::new(),
    };

    // Named fields are used as they are.  Unnamed fields get a synthetic `field_<idx>` name, and
    // for convenience they inherit the attributes of the struct itself.
    // TODO: way to avoid the clone here?
    let has_named_fields = are_fields_named(&rs_struct.fields);
    let fields = if has_named_fields {
        rs_struct.fields.clone()
    } else {
        rs_struct
            .fields
            .iter()
            .enumerate()
            .map(|(position, unnamed)| PacketRsField {
                name: Some(format_ident!("field_{}", position)),
                ty: unnamed.ty,
                parameters: rs_struct.parameters.clone(),
            })
            .collect()
    };

    let reads = generate_field_reads(&fields);
    let field_names = fields.iter().map(|field| {
        field
            .name
            .as_ref()
            .expect("Unable to get name of named field")
    });

    // Braces for a struct with named fields, parentheses for a tuple struct.
    let creation = if has_named_fields {
        quote! {
            Ok(Self { #(#field_names),* })
        }
    } else {
        quote! {
            Ok(Self(#(#field_names),*))
        }
    };

    quote! {
        #context_assignments
        #reads
        #creation
    }
}

pub(crate) fn generate_struct(packetrs_struct: &PacketRsStruct) -> TokenStream {
    let crate_name = get_crate_name();
    let expected_context = packetrs_struct.get_required_context_param_value();
    let ctx_type = get_ctx_type(&expected_context).expect("Error getting ctx type");
    let struct_name = &packetrs_struct.name;
    let read_body = generate_struct_read_body(&packetrs_struct);
    quote! {
        impl ::#crate_name::packetrs_read::PacketrsRead<#ctx_type> for #struct_name {
            fn read(buf: &mut ::#crate_name::bitcursor::BitCursor, ctx: #ctx_type) -> ::#crate_name::error::PacketRsResult<Self> {
                #read_body
            }
        }
    }.into()
}

/// Generate the match arm that reads a single enum variant.
///
/// The arm's pattern is the variant's 'id' attribute parsed as a pattern.  Its body is an
/// immediately-invoked closure that reads the variant's fields (if any) and returns
/// `Ok(<enum>::<variant> ...)`, with the variant's name attached to the result via `.context`.
///
/// # Panics
///
/// Panics if the variant has no 'id' attribute, if the id can't be parsed as a pattern, if a
/// field unexpectedly has no name, or if generating the read of any field panics.
fn generate_match_arm(enum_name: &syn::Ident, variant: &PacketRsEnumVariant) -> TokenStream {
    let variant_name = variant.name;
    let variant_name_str = variant_name.to_string();
    let id_source = variant
        .get_enum_id()
        .expect(&format!(
            "Enum variant {} is missing 'id' attribute",
            variant_name
        ))
        .value();
    // TODO: this won't cover everything (like a guard on a match arm), but it's probably
    // good enough?  See https://docs.rs/syn/latest/syn/struct.Arm.html
    let pattern = syn::parse_str::<syn::Pat>(&id_source).expect("Unable to parse match pattern");

    // Unnamed fields get a synthetic `field_<idx>` name and inherit the variant's attributes.
    let has_named_fields = are_fields_named(&variant.fields);
    let fields = if has_named_fields {
        variant.fields.clone()
    } else {
        variant
            .fields
            .iter()
            .enumerate()
            .map(|(position, unnamed)| PacketRsField {
                name: Some(format_ident!("field_{}", position)),
                ty: unnamed.ty,
                parameters: variant.parameters.clone(),
            })
            .collect()
    };

    // For a variant without fields this is an empty token stream.
    let reads = generate_field_reads(&fields);
    let field_names = fields.iter().map(|field| {
        field.name.as_ref().unwrap_or_else(|| {
            panic!("Found unnamed fields amongst named fields: {:#?}", field)
        })
    });

    // Only the expression constructing the variant differs between the three variant shapes.
    let construction = if variant.fields.is_empty() {
        quote! { #enum_name::#variant_name }
    } else if has_named_fields {
        quote! { #enum_name::#variant_name { #(#field_names),* } }
    } else {
        quote! { #enum_name::#variant_name(#(#field_names),*) }
    };

    quote! {
        #pattern => {
            (|| {
                #reads
                Ok(#construction)
            })().context(#variant_name_str)
        }
    }
}

/// Generate the `PacketrsRead<ctx type>` impl for the given enum.
///
/// The generated `read(buf, ctx)` function first unpacks the context (if the enum requires one)
/// and then matches on the enum's 'key' attribute, parsed as an expression, with one arm per
/// variant as produced by `generate_match_arm`.  A final catch-all arm handles any key value
/// that no variant claims.
///
/// # Panics
///
/// Panics if the context type can't be determined, if the enum has no 'key' attribute, if the
/// key can't be parsed as an expression, or if any of the helpers it calls panics.
pub(crate) fn generate_enum(packetrs_enum: &PacketRsEnum) -> TokenStream {
    let crate_name = get_crate_name();
    let expected_context = packetrs_enum.get_required_context_param_value();
    let context_assignments = if let Some(required_ctx) = expected_context {
        generate_context_assignments(&required_ctx)
    } else {
        TokenStream::new()
    };
    let ctx_type = get_ctx_type(&expected_context).expect("Error getting ctx type");
    let enum_name = &packetrs_enum.name;
    let enum_variant_key = packetrs_enum
        .get_enum_key()
        .expect(&format!("Enum {} is missing 'key' attribute", enum_name))
        .value();

    // TODO: without this, we get quotes around the variant key in the match statement below.  is
    // there a better way?
    let enum_variant_key = syn::parse_str::<syn::Expr>(&enum_variant_key).unwrap_or_else(|err| {
        panic!(
            "Unable to parse enum key as an expression: {}: {:?}",
            enum_variant_key, err
        )
    });

    let match_arms = packetrs_enum
        .variants
        .iter()
        .map(|v| generate_match_arm(enum_name, v));

    // Each arm carries its own trailing comma.  Separating the arms with commas and then adding
    // one more would emit a lone leading `,` (a syntax error in the generated match) for an enum
    // without variants.
    // NOTE(review): an unrecognized key value reaches `todo!` and therefore panics at runtime
    // instead of returning an error from `read` -- confirm whether that is intended.
    quote! {
        impl ::#crate_name::packetrs_read::PacketrsRead<#ctx_type> for #enum_name {
            fn read(buf: &mut ::#crate_name::bitcursor::BitCursor, ctx: #ctx_type) -> ::#crate_name::error::PacketRsResult<Self> {
                #context_assignments
                match #enum_variant_key {
                    #(#match_arms,)*
                    v => {
                        todo!("Value of {} is not implemented", v);
                    }
                }
            }
        }
    }
}