1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
/// A derive macro for the state-inheritance traits defined in layer.rs: it lifts a layer's embedded
/// `LayerState` to the top level by generating immutable and mutable accessor implementations.
//pub struct T_layer(LayerState);
//impl InheritState for T_layer{
// fn get_mut_layer_state(&mut self) -> &mut LayerState {
// &mut self.0
// }
// fn get_layer_state(&self) -> &LayerState {
// &self.0
// }
//}
// The derive also generates the following:
//impl InitializeLayer for T_layer{
// fn init() -> Self {
// T_layer(LayerState {
// input: None,
// input_size: 0,
// output_size: 0,
// width: 0,
// activation: None,
// dtype: DataType::Float,
// })
// }
//}
// A derive macro (not a trait) that automates implementing InheritState for a struct that holds its
// LayerState either as self.0 (tuple struct) or in a field named `layer_state` (named-field struct).
use proc_macro::TokenStream;
extern crate proc_macro;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};
#[proc_macro_derive(InheritState)]
pub fn derive_inherit_state(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);
let name = &ast.ident;
//allow non inheritence (if self.0 is self.layer_state)
let is_inherited = match ast.data {
syn::Data::Struct(ref data) => match data.fields {
syn::Fields::Named(ref fields) => fields
.named
.iter()
.any(|field| field.ident.as_ref().unwrap() == "0"),
_ => true,
},
_ => false,
};
if is_inherited {
let gen = quote! {
impl InheritState for #name{
fn get_mut_layer_state(&mut self) -> &mut LayerState {
&mut self.0
}
fn get_layer_state(&self) -> &LayerState {
&self.0
}
}
impl InitializeLayer for #name {
fn init() -> Self {
#name(LayerState {
input: None,
input_size: 0,
output_size: 0,
width: 0,
activation: None,
dtype: DataType::Float,
})
}
}
};
gen.into()
} else {
//get all fields except self.layer_state in a list
let fields = match ast.data {
syn::Data::Struct(ref data) => match data.fields {
syn::Fields::Named(ref fields) => fields
.named
.iter()
.filter(|field| field.ident.as_ref().unwrap() != "layer_state")
.map(|field| field.ident.as_ref().unwrap())
.collect::<Vec<_>>(),
_ => vec![],
},
_ => vec![],
};
let gen = quote! {
impl InheritState for #name{
fn get_mut_layer_state(&mut self) -> &mut LayerState {
&mut self.layer_state
}
fn get_layer_state(&self) -> &LayerState {
&self.layer_state
}
}
//TODO: this needs to add all the custom local variables (not in LayerState) to the struct
impl InitializeLayer for #name {
fn init() -> Self {
#name {
layer_state: LayerState {
input: None,
input_size: 0,
output_size: 0,
width: 0,
activation: None,
dtype: DataType::Float,
}
//also initialize any other fields as long as they have
//default implementation
#(, #fields: Default::default())*
}
}
}
};
gen.into()
}
}