use proc_macro::{TokenStream, TokenTree};
use proc_macro::TokenTree::*;
use core::iter::FromIterator;
use core::str::FromStr;
use proc_macro::{Group};
use proc_macro::Delimiter;
use itertools::izip;
/// Attribute macro that rewrites a function into a generative function.
///
/// The emitted function is the input function with two changes:
///
/// * its parameter list is replaced by the result of `get_new_args`, and
/// * its body (the trailing brace-delimited group) is replaced by the result
///   of `update_body`.
///
/// Everything else is copied through untouched: tokens preceding the `fn`
/// keyword (attributes, doc comments, visibility, qualifiers), the function
/// name, and any tokens between the parameter list and the body (return type,
/// `where` clause). If the item does not end in a brace-delimited group, the
/// tokens after the parameter list are emitted unchanged.
///
/// # Panics
///
/// * if no top-level `fn` keyword is found in the item,
/// * if `fn` is not immediately followed by an identifier,
/// * if the name is not immediately followed by a delimited group (so
///   generic parameter lists are not supported).
#[proc_macro_attribute]
pub fn r_gen(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut it = item.into_iter();
    let mut out: Vec<TokenTree> = Vec::new();

    // Copy attributes / visibility / qualifiers through until the `fn` keyword.
    let mut saw_fn = false;
    for tok in it.by_ref() {
        let is_fn = match &tok {
            TokenTree::Ident(kw) => kw.to_string() == "fn",
            _ => false,
        };
        out.push(tok);
        if is_fn {
            saw_fn = true;
            break;
        }
    }
    if !saw_fn {
        panic!("The #[r_gen] macro can only be applied to functions.")
    }

    match it.next() {
        Some(TokenTree::Ident(name)) => out.push(TokenTree::Ident(name)),
        _ => panic!("Generative functions require a name."),
    }

    match it.next() {
        Some(TokenTree::Group(args)) => out.push(TokenTree::Group(get_new_args(args))),
        _ => panic!("Malformed generative function. Could not identify function arguments."),
    }

    // A function item ends with its body, so the body is the *last* token.
    // Anything in front of it (e.g. `-> T`, a `where` clause) must survive.
    let mut rest: Vec<TokenTree> = it.collect();
    let body = match rest.pop() {
        Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Brace => Some(g),
        Some(other) => {
            // Not a body: put the token back so it is emitted unchanged.
            rest.push(other);
            None
        }
        None => None,
    };
    out.extend(rest);
    if let Some(body) = body {
        out.push(TokenTree::Group(update_body(body)));
    }

    out.into_iter().collect()
}
/// Builds the parameter list of the rewritten function.
///
/// The returned parenthesised group starts with the injected `_sample` and
/// `_trace` parameters, followed by the tokens of `old_args` in their
/// original order.
fn get_new_args(old_args: Group) -> Group {
    let mut params: TokenStream =
        "mut _sample : Rc<dyn FnMut(&String, Distribution, &mut Trace) -> Value>, _trace : &mut Trace, "
            .parse()
            .unwrap();
    // The injected prefix ends with a comma, so the user's parameters can be
    // appended directly.
    params.extend(old_args.stream());
    Group::new(Delimiter::Parenthesis, params)
}
/// Rewrites a function body: runs the body's tokens through
/// `update_tok_stream` and wraps the result in a new brace-delimited group.
fn update_body(body: Group) -> Group {
    let rewritten = update_tok_stream(body.stream());
    Group::new(Delimiter::Brace, rewritten)
}
fn update_tok_stream(tok_stream : TokenStream) -> TokenStream {
let mut res = TokenStream::new();
let tracking_stream =
izip!(
tok_stream.clone().into_iter(),
tok_stream.clone().into_iter().skip(1),
tok_stream.clone().into_iter().skip(2));
let mut ti = tok_stream.clone().into_iter();
if let Some(t) = ti.next() {
res.extend(TokenStream::from(t));
} else {
return tok_stream;
}
if let Some(t) = ti.next() {
res.extend(TokenStream::from(t));
} else {
return tok_stream;
}
for (prev_prev, prev, tok) in tracking_stream {
match &tok {
Group(g) => {
match (prev, prev_prev) {
(Punct(p), Ident(i)) => {
if p.as_char() == '!' && i.to_string() == "sample" {
res.extend(update_sample_params(g.clone()));
} else {
res.extend(TokenStream::from(tok));
}
},
_ => {
res.extend(TokenStream::from(TokenTree::Group(Group::new(g.delimiter(), update_tok_stream(g.stream())))));
}
}
}
_ => {
res.extend(TokenStream::from(tok));
}
}
}
res
}
/// Rewrites the argument group of a `sample!` invocation so that the tokens
/// `_sample _trace` precede the caller's original arguments.
///
/// The returned stream holds a single group. It keeps the delimiter and span
/// of `group`, so a brace- or bracket-delimited invocation stays in the form
/// the caller wrote. A `Delimiter::None` group is emitted with parentheses,
/// since a macro invocation needs visible delimiters.
fn update_sample_params(group: Group) -> TokenStream {
    let delimiter = match group.delimiter() {
        Delimiter::None => Delimiter::Parenthesis,
        other => other,
    };
    let mut new_params = TokenStream::from_str("_sample _trace ")
        .expect("literal identifier prefix always tokenizes");
    new_params.extend(group.stream());
    let mut rewritten = Group::new(delimiter, new_params);
    rewritten.set_span(group.span());
    TokenStream::from(TokenTree::Group(rewritten))
}