// rsrl_derive 0.1.0
//
// Macros for RSRL
// Documentation
use proc_macro2::TokenStream;
use quote::ToTokens;
use std::iter;
use syn::{Data, DataStruct, Field, Fields, Generics, Ident, Meta, Type};

/// Name used both for the `#[weights]` field attribute and for the field
/// name that is looked up when no field carries that attribute.
const WEIGHTS: &str = "weights";

/// Description of a generated `Parameterised` implementation.
struct Implementation {
    /// Types that must themselves implement `crate::params::Parameterised`;
    /// each one becomes a predicate in the generated `where` clause.
    /// `None` means no extra bounds are added.
    pub generics: Option<Vec<Type>>,
    /// Method bodies to emit inside the generated `impl` block.
    pub body: Body,
}

impl Implementation {
    /// Creates an implementation that needs no additional trait bounds.
    pub fn concrete(body: Body) -> Implementation {
        Implementation {
            generics: None,
            body,
        }
    }

    /// Creates an implementation whose generated `impl` additionally requires
    /// every type in `generics` to implement `crate::params::Parameterised`.
    pub fn with_generics(generics: Vec<Type>, body: Body) -> Implementation {
        Implementation {
            generics: Some(generics),
            body,
        }
    }

    /// Builds one `<type>: crate::params::Parameterised` predicate per type.
    ///
    /// Takes a slice rather than `&Vec<_>` so any contiguous collection of
    /// types can be passed; `&Vec<Type>` call sites coerce automatically.
    fn make_where_predicates(types: &[Type]) -> Vec<syn::WherePredicate> {
        types
            .iter()
            .map(|ty| parse_quote! { #ty: crate::params::Parameterised })
            .collect()
    }

    /// Returns `generics` with the predicates required by this implementation
    /// appended to its `where` clause (creating the clause if necessary).
    ///
    /// When no bounded types were recorded, `generics` is returned unchanged.
    pub fn add_trait_bounds(&self, mut generics: Generics) -> Generics {
        if let Some(ref bounded_types) = self.generics {
            let predicates = Self::make_where_predicates(bounded_types);

            generics.make_where_clause().predicates.extend(predicates);
        }

        generics
    }
}

/// Token streams forming the bodies of the generated `Parameterised` methods.
struct Body {
    /// Body of `fn weights(&self)`; the method is not emitted when `None`.
    pub weights: Option<TokenStream>,
    /// Body of `fn weights_dim(&self)`; the method is not emitted when `None`.
    pub weights_dim: Option<TokenStream>,
    /// Body of `fn weights_view(&self)`; always emitted.
    pub weights_view: TokenStream,
    /// Body of `fn weights_view_mut(&mut self)`; always emitted.
    pub weights_view_mut: TokenStream,
}

impl ToTokens for Body {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        if let Some(ref weights_fn) = self.weights {
            tokens.extend(iter::once(quote! {
                fn weights(&self) -> crate::params::Weights { #weights_fn }
            }));
        }

        if let Some(ref weights_dim_fn) = self.weights_dim {
            tokens.extend(iter::once(quote! {
                fn weights_dim(&self) -> (usize, usize) { #weights_dim_fn }
            }));
        }

        let weights_view_fn = &self.weights_view;
        let weights_view_mut_fn = &self.weights_view_mut;

        tokens.extend(
            iter::once(quote! {
                fn weights_view(&self) -> crate::params::WeightsView {
                    #weights_view_fn
                }
            })
            .chain(iter::once(quote! {
                fn weights_view_mut(&mut self) -> crate::params::WeightsViewMut {
                    #weights_view_mut_fn
                }
            })),
        );
    }
}

/// A struct field selected as the source of weights, together with the tokens
/// used to access it on `self`.
struct WeightsField<'a, I: ToTokens> {
    /// Tokens placed after `self.` in generated code to reach the field
    /// (a field name or a tuple index).
    pub ident: I,
    /// The parsed field definition this entry refers to.
    pub field: &'a Field,
}

impl<'a, I: ToTokens> WeightsField<'a, I> {
    /// Returns the identifier that names the field's type.
    ///
    /// The *last* path segment is used so that qualified paths such as
    /// `ndarray::Array2<f64>` or `std::vec::Vec<f64>` are classified by the
    /// type name (`Array2`, `Vec`) rather than by their crate or module name.
    ///
    /// # Panics
    ///
    /// Panics if the field's type is not a path type (for example a
    /// reference, tuple or array type).
    pub fn type_ident(&self) -> &'a Ident {
        match self.field.ty {
            Type::Path(ref tp) => {
                &tp.path
                    .segments
                    .last()
                    .expect("a type path always contains at least one segment")
                    .ident
            }
            _ => unimplemented!("only path types are supported for weights fields"),
        }
    }
}

/// Expands a derive input into an `impl crate::params::Parameterised` block
/// for the annotated type.
///
/// The input's own generics are reused, extended with whatever trait bounds
/// the chosen implementation requires.
pub fn expand_derive_parameterised(ast: &syn::DeriveInput) -> TokenStream {
    let derived = parameterised_impl(&ast.data);

    // Bounds are added to a copy so the caller's AST is left untouched.
    let bounded = derived.add_trait_bounds(ast.generics.clone());
    let (impl_generics, ty_generics, where_clause) = bounded.split_for_impl();

    let target = &ast.ident;
    let methods = &derived.body;

    quote! {
        #[automatically_derived]
        impl #impl_generics crate::params::Parameterised for #target #ty_generics #where_clause {
            #methods
        }
    }
}

/// Selects the implementation strategy for the derive input's data.
///
/// Only structs are handled; any other kind of item hits `unimplemented!()`.
fn parameterised_impl(data: &Data) -> Implementation {
    if let Data::Struct(ds) = data {
        return parameterised_struct_impl(ds);
    }

    unimplemented!()
}

fn parameterised_struct_impl(ds: &DataStruct) -> Implementation {
    let n_fields = ds.fields.iter().len();

    if n_fields > 1 {
        let mut annotated_fields: Vec<_> = ds
            .fields
            .iter()
            .enumerate()
            .filter(|(_, f)| has_weight_attribute(f))
            .map(|(i, f)| WeightsField {
                ident: f
                    .ident
                    .clone()
                    .map(|i| quote! { #i })
                    .unwrap_or(quote! { #i }),
                field: f,
            })
            .collect();

        if annotated_fields.is_empty() {
            let (index, iwf) = ds
                .fields
                .iter()
                .enumerate()
                .find(|(_, f)| match &f.ident {
                    Some(ident) => ident.to_string() == WEIGHTS,
                    None => false,
                })
                .expect("Couldn't infer weights field, consider annotating with #[weights].");

            parameterised_wf_impl(WeightsField {
                ident: iwf
                    .ident
                    .clone()
                    .map(|i| quote! { #i })
                    .unwrap_or(quote! { #index }),
                field: iwf,
            })
        } else if annotated_fields.len() == 1 {
            parameterised_wf_impl(annotated_fields.pop().unwrap())
        } else {
            panic!(
                "Duplicate #[weights] annotations - \
                 automatic view concatenation implementations are not currently supported."
            )
        }
    } else if n_fields == 1 {
        match ds.fields {
            Fields::Unnamed(ref fs) => parameterised_wf_impl(WeightsField {
                ident: quote! { 0 },
                field: &fs.unnamed[0],
            }),
            Fields::Named(ref fs) => parameterised_wf_impl(WeightsField {
                ident: fs.named[0].ident.clone(),
                field: &fs.named[0],
            }),
            _ => unreachable!(),
        }
    } else {
        panic!("Nothing to derive Parameterised from!")
    }
}

/// Builds the method bodies for a struct whose weights live in `wf`.
///
/// The strategy is chosen from the identifier of the field's type:
/// * `Vec` and `Array1` are exposed as an `(n, 1)` shaped value/view;
/// * `Weights` and `Array2` are cloned/viewed directly;
/// * any other type is delegated to via its own `weights*` methods, and is
///   recorded as needing a `Parameterised` bound.
fn parameterised_wf_impl<I: ToTokens>(wf: WeightsField<I>) -> Implementation {
    let field = &wf.ident;

    // Fragments shared by several of the generated bodies below.
    let target = quote! { self.#field };
    let row_count = quote! { let n_rows = #target.len(); };
    let column_dim = quote! { (#target.len(), 1) };

    let type_name = wf.type_ident().to_string();

    match type_name.as_str() {
        "Vec" => Implementation::concrete(Body {
            weights: Some(quote! {
                #row_count

                crate::params::Weights::from_shape_vec((n_rows, 1), #target.clone()).unwrap()
            }),
            weights_dim: Some(column_dim),
            weights_view: quote! {
                #row_count

                crate::params::WeightsView::from_shape((n_rows, 1), &#target).unwrap()
            },
            weights_view_mut: quote! {
                #row_count

                crate::params::WeightsViewMut::from_shape((n_rows, 1), &mut #target).unwrap()
            },
        }),
        "Array1" => Implementation::concrete(Body {
            weights: Some(quote! {
                #row_count

                #target.clone().into_shape((n_rows, 1)).unwrap()
            }),
            weights_dim: Some(column_dim),
            weights_view: quote! {
                #row_count

                #target.view().into_shape((n_rows, 1)).unwrap()
            },
            weights_view_mut: quote! {
                #row_count

                #target.view_mut().into_shape((n_rows, 1)).unwrap()
            },
        }),
        "Weights" | "Array2" => Implementation::concrete(Body {
            weights: Some(quote! { #target.clone() }),
            weights_dim: Some(quote! { #target.dim() }),
            weights_view: quote! { #target.view() },
            weights_view_mut: quote! { #target.view_mut() },
        }),
        // Unknown type: forward every call to the field and require the
        // field's type to implement `Parameterised` itself.
        _ => {
            let delegated = Body {
                weights: Some(quote! { #target.weights() }),
                weights_dim: Some(quote! { #target.weights_dim() }),
                weights_view: quote! { #target.weights_view() },
                weights_view_mut: quote! { #target.weights_view_mut() },
            };

            Implementation::with_generics(vec![wf.field.ty.clone()], delegated)
        }
    }
}

/// Reports whether `f` carries a bare `#[weights]` attribute.
///
/// Attributes that fail to parse as meta items, and attributes that are not a
/// plain path (for example list or name-value forms), are ignored.
fn has_weight_attribute(f: &Field) -> bool {
    for attr in &f.attrs {
        if let Ok(Meta::Path(path)) = attr.parse_meta() {
            if path.is_ident(WEIGHTS) {
                return true;
            }
        }
    }

    false
}