1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
extern crate proc_macro;
use proc_macro::TokenStream;
use quote::*;
use ron::de::from_str;
use serde::Deserialize;
use syn::Index;

/// Abort macro expansion early: format the given message and `return`
/// a `compile_error!` token stream from the enclosing function, so user
/// mistakes surface as ordinary compile errors instead of panics.
macro_rules! bail {
    ($($args: tt)*) => {{
        let msg = format!($($args)*);
        return quote! {compile_error!(#msg);}.into()
    }}
}

/// Deserialized form of the RON literal passed to the `model!` macro;
/// describes the network the macro expands into.
#[derive(Deserialize)]
struct Layout {
    /// Name of the generated model struct.
    name: String,
    /// Optional trait names to place in `#[derive(...)]` on the struct.
    derive: Option<Vec<String>>,
    /// One entry per layer: (type path, e.g. `"exotic::dense::DenseLayer"`,
    /// const-generic arguments for that type). The first generic of the
    /// first layer is used as the input width, the last generic of the
    /// last layer as the output width.
    layers: Vec<(String, Vec<usize>)>,
    /// Whether to generate the cached variant (only `true` is supported).
    cached: bool,
}

/// Procedural macro entry point: expands a RON [`Layout`] literal into a
/// model struct (fields `l0..lN` plus an output cache tuple), cache
/// helpers, and `LayerTy`/`Layer` impls that chain the listed layers.
///
/// All user-facing errors are reported through `bail!` (which emits
/// `compile_error!`) instead of panicking the compiler.
#[proc_macro]
pub fn model(input: TokenStream) -> TokenStream {
    let layout: Layout = match from_str(&input.to_string()) {
        Err(e) => bail!("Error parsing macro: {}", e),
        Ok(ok) => ok,
    };
    if !layout.cached {
        bail!("Non-Cached networks not implemented yet")
    }
    // Validate the layout up front so the index/`unwrap` accesses and the
    // `len() - 1` arithmetic below can never panic inside the compiler.
    if layout.layers.is_empty() {
        bail!("Model `{}` needs at least one layer", layout.name)
    }
    if layout.layers.iter().any(|(_, generics)| generics.is_empty()) {
        bail!(
            "Every layer of `{}` needs at least one generic size argument",
            layout.name
        )
    }

    let model_name = format_ident!("{}", layout.name);
    // Output width = last generic of the last layer; input width = first
    // generic of the first layer (both guaranteed present by the checks above).
    let o_size = Index::from(*layout.layers.last().unwrap().1.last().unwrap());
    let i_size = Index::from(layout.layers[0].1[0]);
    let derive = layout
        .derive
        .unwrap_or_default()
        .iter()
        .map(|d| format_ident!("{}", d))
        .collect::<Vec<_>>();

    // Turn each ("path::to::Type", [sizes]) pair into the token stream
    // `path::to::Type<size, ...>`.
    let layers = layout
        .layers
        .iter()
        .map(|l| {
            let ty = l.0.split("::").map(|p| format_ident!("{}", p));
            let generics = l.1.iter();
            quote! {
                #(#ty)::*<#(#generics),*>
            }
        })
        .collect::<Vec<_>>();

    // One struct field per layer: `l0: Type0<..>, l1: Type1<..>, ...`.
    let fields = layers.iter().enumerate().map(|(n, l)| {
        let field = format_ident!("l{}", n);
        quote!(#field : #l)
    });

    // Tuple type holding every layer's cached output.
    let cache_ty = quote!((#(<#layers as exotic::deep::LayerTy>::Output),*));

    // Forward pass: nest predict() calls innermost-first, so layer N wraps
    // the output of layer N-1 and `i` sits at the core.
    let mut predict_body = quote!(i);
    for l in 0..layout.layers.len() {
        let l = format_ident!("l{}", l);
        predict_body = quote!(self.#l.predict(#predict_body).context(format!("Error in Layer {} while predicting", stringify!(#l)))?);
    }

    // Cached forward pass: layer 0 consumes `i`; every later layer reads
    // the previous layer's cache slot and writes its own.
    let write_cache = (0..layout.layers.len())
        .map(|l| {
            if l == 0 {
                quote!(cache.0 = self.l0.predict(i)?;)
            } else {
                let prev_l = Index::from(l - 1);
                let ll = format_ident!("l{}", l);
                let l = Index::from(l);
                quote! {
                    cache.#l = self.#ll.predict(cache.#prev_l)?;
                }
            }
        })
        .collect::<Vec<_>>();

    // Backward pass: nest backpropagate() from the last layer down to
    // layer 1 (each reads the cached input to its own layer, i.e. slot
    // l-1), then finish with layer 0 which consumes the original input.
    let mut backprop_body = quote!(delta);
    for l in 0..layout.layers.len() - 1 {
        let l = layout.layers.len() - l - 1;
        let ll = format_ident!("l{}", l);
        let l = Index::from(l - 1);
        backprop_body = quote!(self.#ll.backpropagate(cache.#l, #backprop_body).context(format!("Error in Layer {} while backpropagating", stringify!(#ll)))?);
    }
    backprop_body = quote!(self.l0.backpropagate(i, #backprop_body)?);

    // Tuple index of the final layer's cache slot, and one index per slot
    // for the debug printer.
    let o_layer = Index::from(layout.layers.len() - 1);
    let cache = (0..layout.layers.len()).map(Index::from);

    TokenStream::from(quote! {
        #[derive(#(#derive),*)]
        struct #model_name {
            #(#fields,)*
            cache: #cache_ty,
        }

        impl #model_name {
            pub fn new_cache() -> #cache_ty {
                (#(<#layers as exotic::deep::LayerTy>::Output::zero()),*)
            }
            pub fn cache(&mut self, i: impl deep::Axon<#i_size>) -> Result<deep::AxonRef<<Self as deep::LayerTy>::Output, #o_size>> {
                let cache = &mut self.cache;
                #(#write_cache)*
                Ok(deep::AxonRef::Ref(&self.cache.#o_layer))
            }
            pub fn backpropagate_from_cache(&mut self, i: impl deep::Axon<#i_size>, delta: impl Axon<#o_size>) -> Result<<Self as deep::LayerTy>::Gradient> {
                let cache = &self.cache;
                Ok(#backprop_body)
            }
            pub fn printcache(&self) {
                #(
                    println!("cache {} : {}", #cache, self.cache.#cache.to_ref());
                )*
            }
        }

        impl LayerTy for #model_name {
            type Output = [config::FLOAT; #o_size];
            type Gradient = [config::FLOAT; #i_size];
        }

        impl Layer<#i_size, #o_size> for #model_name {
            fn backpropagate(&mut self, i: impl deep::Axon<#i_size>, delta: impl Axon<#o_size>) -> Result<Self::Gradient> {
                let cache = &mut self.cache;
                #(#write_cache)*
                Ok(#backprop_body)
            }

            fn predict(&self, i: impl deep::Axon<#i_size>) -> Result<Self::Output> {
                Ok(#predict_body)
            }
        }
    })
}

//            let a = self.0.predict(i)?;
//            let b = self.1.predict(&a)?;
//            drop(a);
//            let a = self.2.predict(&b)?;
//            drop(b);
//            let b = self.3.predict(&a)?;
//            drop(a);
//            self.4.predict(&b)