1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
use proc_macro::{TokenStream, TokenTree};
use proc_macro::TokenTree::*;
use core::iter::FromIterator;
use core::str::FromStr;
use proc_macro::{Group};
use proc_macro::Delimiter;
use itertools::izip;
#[proc_macro_attribute]
/// Attribute macro that rewrites an ordinary `fn` into a generative function:
/// it prepends the hidden `_sample` closure and `_trace` parameters to the
/// argument list and rewrites every `sample!` invocation in the body to pass
/// them along.
///
/// Panics (at macro-expansion time) when the annotated item is not a function,
/// has no name, no argument list, or no body.
pub fn r_gen(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut it = item.clone().into_iter();
    let mut out: Vec<TokenTree> = Vec::new();
    // Expect the `fn` keyword (any leading ident is accepted, mirroring the
    // original lenient check).
    if let Some(Ident(i)) = it.next() {
        out.push(Ident(i));
    } else {
        panic!("The #[r_gen] macro can only be applied to functions.")
    }
    // Function name.
    if let Some(Ident(name)) = it.next() {
        out.push(Ident(name));
    } else {
        panic!("Generative functions require a name.")
    }
    // Argument list: inject the hidden `_sample`/`_trace` parameters.
    if let Some(Group(args)) = it.next() {
        let new_args = get_new_args(args);
        out.push(Group(new_args));
    } else {
        panic!("Malformed generative function. Could not identify function arguments.")
    }
    // Pass through everything between the argument list and the body (e.g. a
    // `-> ReturnType`), then rewrite the first brace-delimited group as the
    // body. Previously a return type caused the body to be silently dropped,
    // emitting a body-less (uncompilable) function.
    let mut found_body = false;
    for tok in it {
        match tok {
            TokenTree::Group(body) if !found_body && body.delimiter() == Delimiter::Brace => {
                out.push(TokenTree::Group(update_body(body)));
                found_body = true;
            }
            other => out.push(other),
        }
    }
    if !found_body {
        panic!("Malformed generative function. Could not identify function body.")
    }
    TokenStream::from_iter(out.into_iter())
}
/// Build the rewritten argument list for a generative function: the hidden
/// `_sample` closure and the mutable `_trace` parameter come first, followed
/// by every argument the user originally declared.
fn get_new_args(old_args: Group) -> Group {
    // The injected prefix is parsed once from a fixed source string; the
    // trailing comma lets the user's own arguments follow directly.
    let mut params = TokenStream::from_str("mut _sample : Rc<dyn FnMut(&String, Distribution, &mut Trace) -> Value>, _trace : &mut Trace, ").unwrap();
    params.extend(old_args.stream());
    Group::new(Delimiter::Parenthesis, params)
}
/// Rewrite a function body so every `sample!` invocation inside it receives
/// the hidden `_sample`/`_trace` arguments.
fn update_body(body: Group) -> Group {
    Group::new(Delimiter::Brace, update_tok_stream(body.stream()))
}
fn update_tok_stream(tok_stream : TokenStream) -> TokenStream {
let mut res = TokenStream::new();
let tracking_stream =
izip!(
tok_stream.clone().into_iter(),
tok_stream.clone().into_iter().skip(1),
tok_stream.clone().into_iter().skip(2));
let mut ti = tok_stream.clone().into_iter();
if let Some(t) = ti.next() {
res.extend(TokenStream::from(t));
} else {
return tok_stream;
}
if let Some(t) = ti.next() {
res.extend(TokenStream::from(t));
} else {
return tok_stream;
}
for (prev_prev, prev, tok) in tracking_stream {
match &tok {
Group(g) => {
match (prev, prev_prev) {
(Punct(p), Ident(i)) => {
if p.as_char() == '!' && i.to_string() == "sample" {
res.extend(update_sample_params(g.clone()));
} else {
res.extend(TokenStream::from(tok));
}
},
_ => {
res.extend(TokenStream::from(TokenTree::Group(Group::new(g.delimiter(), update_tok_stream(g.stream())))));
}
}
}
_ => {
res.extend(TokenStream::from(tok));
}
}
}
res
}
/// Rebuild a `sample!` argument group with the hidden `_sample` and `_trace`
/// tokens injected ahead of the user-supplied parameters.
fn update_sample_params(group: Group) -> TokenStream {
    // `sample!` pattern-matches these two leading tokens; the trailing space
    // keeps them separate from the original arguments.
    let mut params = TokenStream::from_str("_sample _trace ").unwrap();
    params.extend(group.stream());
    let rewritten = Group::new(Delimiter::Parenthesis, params);
    TokenStream::from(TokenTree::Group(rewritten))
}