use crate::{
binrw::{
codegen::generate_impl,
parser::{Enum, EnumVariant, Input, ParseResult, Struct, StructField},
Options,
},
combine_error,
};
use quote::quote;
use std::collections::HashSet;
use syn::{spanned::Spanned, DeriveInput};
/// Expands a `#[binrw]`-annotated item into the item itself followed by a
/// generated read implementation and a generated write implementation.
///
/// The item is parsed twice through `Input::from_input`, once with
/// `write: false` and once with `write: true`. Temporary-field markings are
/// then exchanged between the two parses by `apply_temp_crossover`. If that
/// exchange reports an error, the write-side result is rewrapped as
/// `ParseResult::Partial` carrying the error, and code generation still runs
/// for both sides.
pub(crate) fn derive(mut derive_input: DeriveInput) -> proc_macro2::TokenStream {
    // Both parses differ only in the `write` flag.
    let parse_for = |write: bool| {
        Input::from_input(
            &derive_input,
            Options {
                derive: false,
                write,
            },
        )
    };
    let mut read_parse = parse_for(false);
    let mut write_parse = parse_for(true);

    if let Some(crossover_error) = apply_temp_crossover(&mut read_parse, &mut write_parse) {
        let (write_input, ..) = write_parse.unwrap_tuple();
        write_parse = ParseResult::Partial(write_input, crossover_error);
    }

    let read_impl = generate_impl::<false>(&derive_input, &read_parse);
    let write_impl = generate_impl::<true>(&derive_input, &write_parse);

    // Attribute cleanup is driven by the read-side parse (only if it has an
    // `ok()` value).
    super::clean_attr(&mut derive_input, read_parse.ok().as_ref());

    quote! {
        #derive_input
        #read_impl
        #write_impl
    }
}
/// Exchanges temporary-field markings between the read-side and write-side
/// parses of the same item.
///
/// Returns `None` without touching anything unless both parses are
/// `ParseResult::Ok`. Structs are handled by `apply_temp_crossover_struct`,
/// enums by `apply_temp_crossover_enum`. Pairs of unit structs or unit-only
/// enums return `None` unchanged.
///
/// # Panics
///
/// Panics if the two successful parses are not the same kind of `Input`.
fn apply_temp_crossover(
    binread_result: &mut ParseResult<Input>,
    binwrite_result: &mut ParseResult<Input>,
) -> Option<syn::Error> {
    let (read_input, write_input) = match (binread_result, binwrite_result) {
        (ParseResult::Ok(read_input), ParseResult::Ok(write_input)) => (read_input, write_input),
        _ => return None,
    };

    match (read_input, write_input) {
        (Input::Struct(read_struct), Input::Struct(write_struct)) => {
            apply_temp_crossover_struct(read_struct, write_struct)
        }
        (Input::Enum(read_enum), Input::Enum(write_enum)) => {
            apply_temp_crossover_enum(read_enum, write_enum)
        }
        (Input::UnitStruct(_), Input::UnitStruct(_)) => None,
        (Input::UnitOnlyEnum(_), Input::UnitOnlyEnum(_)) => None,
        _ => unreachable!("read and write input should always be the same kind"),
    }
}
fn apply_temp_crossover_enum(
binread_enum: &mut Enum,
binwrite_enum: &mut Enum,
) -> Option<syn::Error> {
let mut all_errors = None::<syn::Error>;
for (read_variant, write_variant) in binread_enum
.variants
.iter_mut()
.zip(binwrite_enum.variants.iter_mut())
{
match (read_variant, write_variant) {
(
EnumVariant::Variant {
options: read_struct,
..
},
EnumVariant::Variant {
options: write_struct,
..
},
) => {
if let Some(error) = apply_temp_crossover_struct(read_struct, write_struct) {
combine_error(&mut all_errors, error);
}
}
(EnumVariant::Unit(_), EnumVariant::Unit(_)) => {}
_ => unreachable!("read and write input should always be the same kind"),
}
}
all_errors
}
/// Exchanges temporary-field markings between the read-side and write-side
/// views of one struct (or one enum variant's fields).
///
/// First, every field named as temporary on the read side is checked on the
/// write side by `validate_fields_temporary`. If that check reports an error,
/// it is returned and neither field list is modified. Otherwise the fields
/// named as temporary on the write side are force-marked temporary on the read
/// side, the reverse is applied too, and `None` is returned.
fn apply_temp_crossover_struct(
    binread_struct: &mut Struct,
    binwrite_struct: &mut Struct,
) -> Option<syn::Error> {
    let read_temps = extract_temporary_field_names(&binread_struct.fields, false);
    let write_temps = extract_temporary_field_names(&binwrite_struct.fields, true);

    let validation = validate_fields_temporary(&binwrite_struct.fields, &read_temps);
    if validation.is_none() {
        set_fields_temporary(&mut binread_struct.fields, &write_temps);
        set_fields_temporary(&mut binwrite_struct.fields, &read_temps);
    }
    validation
}
fn validate_fields_temporary(
fields: &[StructField],
read_temporary: &HashSet<syn::Ident>,
) -> Option<syn::Error> {
let mut all_errors = None::<syn::Error>;
for field in fields {
if read_temporary.contains(&field.ident) && !field.generated_value() {
combine_error(
&mut all_errors,
syn::Error::new(
field.field.span(),
"`#[br(temp)]` is invalid without a corresponding `#[bw(ignore)]`, `#[bw(calc)]`, or `#[bw(try_calc)]`",
),
);
}
}
all_errors
}
/// Collects the identifiers of all fields for which `is_temp(for_write)`
/// returns true.
fn extract_temporary_field_names(fields: &[StructField], for_write: bool) -> HashSet<syn::Ident> {
    let mut names = HashSet::new();
    for field in fields {
        if field.is_temp(for_write) {
            names.insert(field.ident.clone());
        }
    }
    names
}
/// Calls `force_temp` on every field whose identifier is in `temporary_names`.
fn set_fields_temporary(fields: &mut [StructField], temporary_names: &HashSet<syn::Ident>) {
    let named_temp = |candidate: &&mut StructField| temporary_names.contains(&candidate.ident);
    for field in fields.iter_mut().filter(named_temp) {
        field.force_temp();
    }
}