shiv-macro-impl 0.1.0-alpha.3

Macro impl for shiv
Documentation
use proc_macro2::TokenStream;
use quote::quote;
use syn::{
    parse_quote, punctuated::Punctuated, token::Comma, Data, DeriveInput, GenericParam, Generics,
    Ident, Index, Path, Type,
};

/// Generates the `SystemParam` implementation for a derive input.
///
/// `shiv` is the path that prefixes every shiv item referenced by the generated code
/// (for example `#shiv::system::SystemParam`).
///
/// The expansion is wrapped in an anonymous `const _: () = { ... };` item and contains:
///
/// - an impl of `SystemParam` for the input type, with `Fetch = FetchState<..>`;
/// - the `FetchState` struct, which stores a `__TSystemParamState` value plus a
///   `PhantomData` marker over the input's type parameters;
/// - impls of `ReadOnlySystemParamFetch`, `SystemParamState` and
///   `SystemParamFetch<'w, 's>` for `FetchState`, all of which delegate to the
///   wrapped state.
///
/// # Panics
///
/// - If the input declares a lifetime other than `'w` or `'s` (see `validate_lifetimes`).
/// - If the input is not a struct, or is a struct with unnamed fields (see `fields` and
///   `field_idents`).
pub fn derive_system_param(input: DeriveInput, shiv: Path) -> proc_macro2::TokenStream {
    validate_lifetimes(&input.generics);

    let fields = fields(&input.data);
    let field_idents = field_idents(&input.data);

    let state_generics = state_generics(&input.generics, &shiv);
    let fetch_generics = fetch_generics(&input.generics);
    let read_only_generics = read_only_generics(&input.generics, &shiv);

    let (state_impl_generics, state_ty_generics, state_where_clause) =
        state_generics.split_for_impl();
    let (fetch_impl_generics, _, _) = fetch_generics.split_for_impl();
    let (_, _, read_only_where_clause) = read_only_generics.split_for_impl();

    let marker_generics = marker_generics(&input.generics);
    let fetch_ty_generics = fetch_ty_generics(&input.generics, &fields, &shiv);

    // One tuple index per field: `get_param` on the wrapped state yields a tuple whose
    // elements are moved into the struct's named fields in declaration order.
    let indices = (0..fields.len()).map(Index::from);

    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    let vis = input.vis;
    let name = input.ident;

    quote! {
        const _: () = {
            #[automatically_derived]
            impl #impl_generics #shiv::system::SystemParam for #name #ty_generics #where_clause {
                type Fetch = FetchState<#fetch_ty_generics>;
            }

            #vis struct FetchState #state_ty_generics #state_where_clause {
                state: __TSystemParamState,
                marker: ::std::marker::PhantomData<fn() -> (#marker_generics)>,
            }

            #[automatically_derived]
            unsafe impl #state_impl_generics #shiv::system::ReadOnlySystemParamFetch for
                FetchState #state_ty_generics #read_only_where_clause
            {
            }

            #[automatically_derived]
            unsafe impl #state_impl_generics #shiv::system::SystemParamState for FetchState
                #state_ty_generics #state_where_clause
            {
                #[inline]
                fn init(
                    world: &mut #shiv::world::World,
                    meta: &mut #shiv::system::SystemMeta,
                ) -> Self {
                    Self {
                        state: __TSystemParamState::init(world, meta),
                        marker: ::std::marker::PhantomData,
                    }
                }

                #[inline]
                fn apply(&mut self, world: &mut #shiv::world::World) {
                    self.state.apply(world);
                }
            }

            #[automatically_derived]
            impl #fetch_impl_generics #shiv::system::SystemParamFetch<'w, 's> for
                FetchState<#fetch_ty_generics> #state_where_clause
            {
                type Item = #name #ty_generics;

                #[inline]
                #[allow(dead_code)]
                unsafe fn get_param(
                    &'s mut self,
                    meta: &#shiv::system::SystemMeta,
                    world: &'w #shiv::world::World,
                    change_ticks: ::std::primitive::u32,
                ) -> Self::Item {
                    let param = #shiv::system::SystemParamFetch::get_param(
                        &mut self.state,
                        meta,
                        world,
                        change_ticks
                    );

                    #name {#(#field_idents: param.#indices,)*}
                }
            }
        };
    }
}

/// Rejects any lifetime parameter that is not named `'w` or `'s`.
///
/// # Panics
///
/// Panics on the first lifetime parameter whose identifier is neither `w` nor `s`.
fn validate_lifetimes(generics: &Generics) {
    for param in generics.lifetimes() {
        let name = &param.lifetime.ident;

        // Accepted names: nothing to report, move on to the next parameter.
        if name == "w" || name == "s" {
            continue;
        }

        panic!(
            "Invalid lifetime: {}, only valid lifetimes are 'w and 's",
            name
        );
    }
}

/// Returns `true` when `generics` declares a lifetime parameter whose identifier
/// (without the leading `'`) equals `lifetime`.
fn has_lifetime(generics: &Generics, lifetime: &str) -> bool {
    generics
        .lifetimes()
        .any(|param| param.lifetime.ident == lifetime)
}

/// Returns a copy of `generics` that declares both `'w` and `'s`.
///
/// Whichever of the two lifetimes is missing is appended to the parameter list,
/// `'w` first and `'s` second.
fn fetch_generics(generics: &Generics) -> Generics {
    let mut extended = generics.clone();

    // Appending `'w` cannot change whether `'s` is present, so both checks can be
    // evaluated before anything is pushed.
    let needs_w = !has_lifetime(&extended, "w");
    let needs_s = !has_lifetime(&extended, "s");

    if needs_w {
        extended.params.push(parse_quote!('w));
    }
    if needs_s {
        extended.params.push(parse_quote!('s));
    }

    extended
}

/// Builds the generics used for the generated `FetchState` struct and its state impls.
///
/// Starting from a copy of `generics`, this:
///
/// 1. drops every lifetime parameter, keeping all other parameters in order;
/// 2. appends a `__TSystemParamState: #shiv::system::SystemParamState` type parameter;
/// 3. adds a `Self: Send + Sync + 'static` predicate to the where clause.
fn state_generics(generics: &Generics, shiv: &Path) -> Generics {
    let mut generics = generics.clone();

    // Move the parameter list out rather than cloning it a second time; it is
    // replaced by the filtered list right away.
    let params = std::mem::take(&mut generics.params);
    generics.params = params
        .into_pairs()
        .filter(|param| !matches!(param.value(), syn::GenericParam::Lifetime(_)))
        .collect();

    generics.params.push(parse_quote!(
        __TSystemParamState: #shiv::system::SystemParamState
    ));

    generics.make_where_clause().predicates.push(parse_quote!(
        Self: ::std::marker::Send + ::std::marker::Sync + 'static
    ));

    generics
}

/// Returns a copy of `generics` whose where clause carries the extra bounds used by the
/// generated `ReadOnlySystemParamFetch` impl.
///
/// Two predicates are appended, in this order:
///
/// - `__TSystemParamState: #shiv::system::ReadOnlySystemParamFetch`
/// - `Self: for<'w, 's> #shiv::system::SystemParamFetch<'w, 's>`
fn read_only_generics(generics: &Generics, shiv: &Path) -> Generics {
    let mut read_only = generics.clone();

    {
        // `make_where_clause` creates the where clause if the input had none.
        let predicates = &mut read_only.make_where_clause().predicates;

        predicates.push(parse_quote!(
            __TSystemParamState: #shiv::system::ReadOnlySystemParamFetch
        ));
        predicates.push(parse_quote!(
            Self: for<'w, 's> #shiv::system::SystemParamFetch<'w, 's>
        ));
    }

    read_only
}

/// Collects the identifiers of all type parameters in `generics` as a comma-separated
/// token list, skipping lifetime and const parameters.
fn marker_generics(generics: &Generics) -> Punctuated<TokenStream, Comma> {
    generics
        .params
        .iter()
        .filter_map(|param| match param {
            GenericParam::Type(type_param) => {
                let name = &type_param.ident;
                // Annotated so `collect` builds the list from plain values.
                let tokens: TokenStream = parse_quote!(#name);
                Some(tokens)
            }
            _ => None,
        })
        .collect()
}

fn fetch_ty_generics(
    generics: &Generics,
    fields: &[Type],
    shiv: &Path,
) -> Punctuated<TokenStream, Comma> {
    let mut fetch_ty_generics = Punctuated::<TokenStream, Comma>::new();
    for generic in generics.params.iter() {
        if let GenericParam::Type(ty) = generic {
            let ident = &ty.ident;
            fetch_ty_generics.push(parse_quote!(#ident));
        }
    }

    fetch_ty_generics.push(quote!((#(<#fields as #shiv::system::SystemParam>::Fetch,)*)));

    fetch_ty_generics
}

/// Returns the types of a struct's named fields in declaration order.
///
/// A unit struct yields an empty vector.
///
/// # Panics
///
/// Panics if `data` is not a struct, or if the struct has unnamed (tuple) fields.
fn fields(data: &Data) -> Vec<Type> {
    let item = match data {
        Data::Struct(item) => item,
        _ => unimplemented!("Only structs are supported"),
    };

    match &item.fields {
        syn::Fields::Unit => Vec::new(),
        syn::Fields::Unnamed(_) => unimplemented!("Unnamed fields are not supported"),
        syn::Fields::Named(named) => named
            .named
            .iter()
            .map(|member| member.ty.clone())
            .collect(),
    }
}

/// Returns the identifiers of a struct's named fields in declaration order.
///
/// A unit struct yields an empty vector.
///
/// # Panics
///
/// Panics if `data` is not a struct, or if the struct has unnamed (tuple) fields.
fn field_idents(data: &Data) -> Vec<Ident> {
    let item = match data {
        Data::Struct(item) => item,
        _ => unimplemented!("Only structs are supported"),
    };

    match &item.fields {
        syn::Fields::Unit => Vec::new(),
        syn::Fields::Unnamed(_) => unimplemented!("Unnamed fields are not supported"),
        syn::Fields::Named(named) => named
            .named
            .iter()
            // Members of a named field list are expected to carry an identifier.
            .map(|member| member.ident.clone().unwrap())
            .collect(),
    }
}