use std::collections::HashSet;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, Data, DeriveInput, Error};
#[proc_macro_derive(Deinterleave)]
pub fn derive_deinterleave(tokens: TokenStream) -> TokenStream {
let derive_input = parse_macro_input!(tokens as DeriveInput);
process_derive_input(derive_input).into()
}
fn process_derive_input(derive_input: DeriveInput) -> proc_macro2::TokenStream {
let Data::Struct(data_struct) = derive_input.data else {
return Error::new_spanned(derive_input, "expected a struct").into_compile_error();
};
let struct_ident = derive_input.ident;
let field_count = data_struct.fields.len();
let mut seen_field_ty = HashSet::new();
let mut has_fields = Vec::default();
let mut push_to = Vec::default();
let mut remove_from = Vec::default();
let mut swap_remove_from = Vec::default();
let mut iter_field_ty = Vec::default();
let mut iter_field = Vec::default();
let mut iter_mut_field_ty = Vec::default();
let mut iter_mut_field = Vec::default();
for (i, f) in data_struct.fields.into_iter().enumerate() {
let field_ty = f.ty;
let field_ident = f.ident;
if seen_field_ty.contains(&field_ty) {
return Error::new_spanned(field_ty.clone(), "type seen more than once (you can only use a type once in the fields of a deinterleaved struct)").to_compile_error();
}
seen_field_ty.insert(field_ty.clone());
iter_field_ty.push(quote! {
#field_ident: std::slice::Iter<'a, #field_ty>
});
iter_field.push(quote! {
#field_ident: unsafe { std::mem::transmute::<&Vec<u8>, &Vec<#field_ty>>(&data[#i]) }.iter()
});
iter_mut_field_ty.push(quote! {
#field_ident: std::slice::IterMut<'a, #field_ty>
});
iter_mut_field.push(quote! {
#field_ident: unsafe { std::mem::transmute::<&mut Vec<u8>, &mut Vec<#field_ty>>(&mut data[#i]) }.iter_mut()
});
has_fields.push(quote! {
impl deinterleave::HasField<#field_ty> for #struct_ident {
const INDEX: usize = #i;
}
});
push_to.push(quote! {
unsafe { std::mem::transmute::<&mut Vec<u8>, &mut Vec<#field_ty>>(&mut data[#i]) }.push(value.#field_ident)
});
remove_from.push(quote! {
#field_ident : unsafe { std::mem::transmute::<&mut Vec<u8>, &mut Vec<#field_ty>>(&mut data[#i]) }.remove(i)
});
swap_remove_from.push(quote! {
#field_ident : unsafe { std::mem::transmute::<&mut Vec<u8>, &mut Vec<#field_ty>>(&mut data[#i]) }.swap_remove(i)
});
}
let iter_ident = format_ident!("{struct_ident}Iter");
let iter_mut_ident = format_ident!("{struct_ident}IterMut");
quote! {
#(#has_fields)*
#[derive(Debug)]
pub struct #iter_ident<'a> {
_phantom: std::marker::PhantomData<&'a ()>,
#(pub #iter_field_ty),*
}
#[derive(Debug)]
pub struct #iter_mut_ident<'a> {
_phantom: std::marker::PhantomData<&'a mut ()>,
#(pub #iter_mut_field_ty),*
}
impl deinterleave::Deinterleave for #struct_ident {
type Iter<'a> = #iter_ident<'a>;
type IterMut<'a> = #iter_mut_ident<'a>;
fn new_data() -> Vec<Vec<u8>> {
vec![Vec::default(); #field_count]
}
fn iter(data: &Vec<Vec<u8>>) -> Self::Iter<'_> {
#iter_ident {
_phantom: std::marker::PhantomData,
#(#iter_field),*
}
}
fn iter_mut(data: &mut Vec<Vec<u8>>) -> Self::IterMut<'_> {
#iter_mut_ident {
_phantom: std::marker::PhantomData,
#(#iter_mut_field),*
}
}
fn push_to(data: &mut Vec<Vec<u8>>, value: Self) {
#(#push_to);*
}
fn remove_from(data: &mut Vec<Vec<u8>>, i: usize) -> Self {
Self {
#(#remove_from),*
}
}
fn swap_remove_from(data: &mut Vec<Vec<u8>>, i: usize) -> Self {
Self {
#(#swap_remove_from),*
}
}
}
}
}