postfix_macros_impl/lib.rs
1/*!
2Postfix macros on stable Rust, today.
3
4This is the crate containing the proc macro implementation
5of the [`postfix_macros!`] macro.
6
The `postfix-macros` crate re-exports the macro
defined by this crate, and adds some macros of its
own that are helpful in a postfix macro context.
If you don't need these extra macros,
you can use this crate instead and save
the extra dependency.
13
14```
15# use postfix_macros_impl::postfix_macros;
16# #[derive(Debug, Clone, Copy)] enum Custom { Enum(()), EnumOther}
17# let val = [((),Custom::EnumOther,)];
18postfix_macros! {
19 "hello".assert_ne!("world");
20
21 val.iter()
22 .map(|v| v.1)
23 .find(|z| z.matches!(Custom::Enum(_) | Custom::EnumOther))
24 .dbg!();
25}
26```
27
28*/
29#![forbid(unsafe_code)]
30
31extern crate proc_macro;
32use proc_macro::{TokenStream, TokenTree as Tt, Punct, Group, Spacing,
33 Delimiter};
34
35#[proc_macro]
36pub fn postfix_macros(stream :TokenStream) -> TokenStream {
37 let mut vis = Visitor;
38 let res = vis.visit_stream(stream);
39 //println!("{}", res);
40 res
41}
42
43struct Visitor;
44
45impl Visitor {
46 fn visit_stream(&mut self, stream :TokenStream) -> TokenStream {
47 let mut res = Vec::new();
48 let mut stream_iter = stream.into_iter();
49 while let Some(tt) = stream_iter.next() {
50 match tt {
51 Tt::Group(group) => {
52 let mut postfix_macro = false;
53 {
54 let last_three = res.rchunks(3).next();
55 if let Some(&[Tt::Punct(ref p1), Tt::Ident(_), Tt::Punct(ref p2)]) = last_three {
56 if (p1.as_char(), p1.spacing(), p2.as_char(), p2.spacing()) == ('.', Spacing::Alone, '!', Spacing::Alone) {
57 postfix_macro = true;
58 }
59 }
60 }
61 let group = if postfix_macro {
62 // Remove the ! and macro ident
63 let mac_bang = res.pop().unwrap();
64 let mac = res.pop().unwrap();
65 // Remove the . before the macro
66 res.pop().unwrap();
67
68 // Walk the entire chain of tt's that
69 // form the expression we want to feed to the macro.
70 let expr_len = expression_length(&res);
71
72 if expr_len == 0 {
73 panic!("expected something before the postfix macro invocation");
74 }
75 //println!(" -> built");
76
77 // Build the group
78 let gr = self.visit_group(group);
79 let arg_tokens = &res[(res.len() - expr_len)..];
80 let gr = prepend_macro_arg_to_group(arg_tokens, gr);
81 res.truncate(res.len() - expr_len);
82
83 // Add back the macro ident and bang
84 res.push(mac);
85 res.push(mac_bang);
86
87 /*println!("res so far: {}",
88 res.iter().cloned().collect::<TokenStream>());*/
89
90 gr
91 } else {
92 group
93 };
94 let tt = Tt::Group(self.visit_group(group));
95 res.push(tt);
96 },
97 Tt::Ident(id) => {
98 res.push(Tt::Ident(id));
99 },
100 Tt::Punct(p) => {
101 res.push(Tt::Punct(p));
102 },
103 Tt::Literal(lit) => {
104 res.push(Tt::Literal(lit));
105 },
106 }
107 }
108 res.into_iter().collect()
109 }
110 fn visit_group(&mut self, group :Group) -> Group {
111 let delim = group.delimiter();
112 let span = group.span();
113 let stream = self.visit_stream(group.stream());
114 let mut gr = Group::new(delim, stream);
115 gr.set_span(span);
116 gr
117 }
118}
119
120
/// Walks the entire chain of token trees that form the expression a postfix
/// macro call would be applied to.
///
/// `tts` holds the tokens preceding the `.mac!` of the call. The walk starts
/// at the end of the slice and moves backwards.
///
/// Returns the number of trailing token trees that belong to the expression.
/// The result is 0 for an empty slice, or when the last token terminates
/// expressions.
///
/// # Panics
///
/// Panics on syntax the walker does not support yet: closures, `if let`, `==`
/// inside an `if` condition, groups with invisible delimiters, and unknown
/// punctuation.
fn expression_length(tts: &[Tt]) -> usize {
    let mut expr_len = 0;
    // Both start out true so that the first token examined is never rejected
    // by the "group/ident not followed by punctuation" checks below.
    let mut last_was_punctuation = true;
    let mut last_was_group = true;
    'outer: while expr_len < tts.len() {
        // The token under examination. The `expr_len` tokens to its right are
        // already counted; this one is counted at the bottom of the loop.
        let tt = &tts[tts.len() - 1 - expr_len];
        let mut is_punctuation = false;
        let mut is_group = false;
        match tt {
            Tt::Group(group) => {
                is_group = true;
                // If the group wasn't terminated by a punctuation,
                // it belongs to e.g. a function body, if clause, etc,
                // but not to our expression
                if !last_was_punctuation {
                    break;
                }

                // If the group was terminated by a punctuation,
                // it belongs to the postfix macro chain.
                // If it's delimited by braces, i.e. `{ ... }.`, it might be an
                // if, match, else, or else if block, and we add the rest of
                // that chain accordingly.
                if group.delimiter() == Delimiter::Brace {
                    loop {
                        // We are at the start, it was a plain {} block.
                        if expr_len + 1 >= tts.len() {
                            break;
                        }
                        let tt_before = &tts[tts.len() - 2 - expr_len];
                        match tt_before {
                            Tt::Group(_group) => {
                                // e.g. `if foo() {}`, `if { true } {}`, `if if {true } else { false } {}`,
                                // `if bools[..] {}`.
                                // Basically, just start the expression search and hope for the best :)
                            },
                            Tt::Ident(id) => {
                                if id.to_string() == "else" {
                                    // Count the current block and the `else`.
                                    // The block before the `else` becomes the
                                    // current token, and the chain search
                                    // continues from there.
                                    expr_len += 2;
                                    continue;
                                }
                                // Any other ident: must be part of an expression like if something.expr {}.foo().
                                // Start the full expression search
                            },
                            Tt::Punct(p) => match p.as_char() {
                                // These indicate the end of the expression
                                ';' | ',' => {
                                    expr_len += 1;
                                    break 'outer;
                                },
                                // This indicates the group was part of something else,
                                // like a prior macro foo! {} . bar!().
                                // Just continue the outer search normally
                                '!' => break,
                                // Unsupported stuff
                                // TODO support closures
                                '|' => panic!("Closures not supported yet"),
                                p => panic!("Group expr search encountered unsupported punctuation {}", p),
                            },
                            Tt::Literal(_lit) => {
                                // Start the expression search
                            },
                        }
                        // Measure the if/match head that precedes the current block.
                        let sub_expr_len = expression_length(&tts[..tts.len() - 1 - expr_len]);
                        expr_len += sub_expr_len;
                        // Now check what's beyond the expression
                        let tt_before = if tts.len() < 2 + expr_len {
                            None
                        } else {
                            tts.get(tts.len() - 2 - expr_len)
                        };
                        let tt_before_that = if tts.len() < 3 + expr_len {
                            None
                        } else {
                            tts.get(tts.len() - 3 - expr_len)
                        };

                        match (tt_before_that, tt_before) {
                            (Some(Tt::Ident(id_t)), Some(Tt::Ident(id)))
                                if id_t.to_string() == "else" && id.to_string() == "if" =>
                            {
                                // Else if clause: count the head's first token,
                                // the `if` and the `else`. The block before the
                                // `else` becomes current.
                                // Continue the chain search.
                                expr_len += 3;
                            },
                            // This arm is checked whatever precedes the keyword.
                            // That way `return if c {a}.foo!()` stops before `return`.
                            (_, Some(Tt::Ident(id))) => {
                                let id = id.to_string();
                                if id == "if" || id == "match" {
                                    // Done with the if/match chain search.
                                    is_group = false;
                                    expr_len += 1;
                                    break;
                                }
                                // Otherwise IDK, something failed; keep searching.
                            },
                            (_, Some(Tt::Punct(p))) => {
                                match p.as_char() {
                                    // This can be either == or if let Foo() =
                                    '=' => {
                                        if let Some(Tt::Punct(p_t)) = tt_before_that {
                                            if p_t.as_char() == '=' {
                                                // TODO parse another expr, e.g. via a wrapper
                                                // around expression_length or a precedence setting.
                                                panic!("== in if clause not supported yet");
                                            }
                                        }
                                        panic!("if let not supported");
                                    },
                                    _ => panic!("{} in if not supported yet", p),
                                }
                            },
                            (None, None) => {
                                // Nothing comes before tt.
                                // We are done
                                break;
                            },
                            _ => {
                                panic!("Hit unsupported case: {:?} {:?}", tt_before_that.map(|v| v.to_string()),
                                    tt_before.map(|v| v.to_string()));
                            },
                        }
                    }
                }
            },
            Tt::Ident(id) => {
                if !last_was_punctuation && !last_was_group {
                    // two idents following another... must be `if <something>.foo!() { <stuff> }`
                    // or something like it.
                    // We need to special case the keyword mut though because `&mut` is usually
                    // prefixed to an expression.
                    if id.to_string() != "mut" {
                        break;
                    }
                }
            },
            Tt::Punct(p) => {
                is_punctuation = true;
                match p.as_char() {
                    // No expression termination, unless this `.` ends a range.
                    '.' if p.spacing() == Spacing::Alone => {
                        // `a..b` lexes as `.` (Joint) followed by `.` (Alone),
                        // and walking backwards meets the Alone half first.
                        // A Joint `.` right before this one means a `..` range
                        // operator, which terminates the expression.
                        let idx = tts.len() - 1 - expr_len;
                        if let Some(Tt::Punct(prev)) = idx.checked_sub(1).map(|i| &tts[i]) {
                            if prev.as_char() == '.' && prev.spacing() == Spacing::Joint {
                                break;
                            }
                        }
                    },
                    ':' | '?' | '!' => (),
                    // These all terminate expressions
                    '.' if p.spacing() == Spacing::Joint => break,
                    ',' | ';' | '+' | '/' | '%' | '=' | '<' | '>' | '|' | '^' => break,
                    // All of & * and - can be safely prepended to expressions in any number,
                    // However the symbols also exist in a binop context.
                    // Only the leading symbol can be a binop, but what makes matters a bit
                    // more complicated is that `&&` is a valid binop as well.
                    '&' | '*' | '-' => {
                        // First, we find the end of our binop partner
                        let mut offs_until_binop_partner = 0;
                        for tt in tts[.. tts.len() - expr_len - 1].iter().rev() {
                            match tt {
                                Tt::Group(gr) => {
                                    match gr.delimiter() {
                                        // `{0} & 7;` is invalid and `{}` is.
                                        // => all &*- are unary ops.
                                        Delimiter::Brace => {
                                            expr_len += offs_until_binop_partner + 1;
                                            break 'outer;
                                        }
                                        // Both [] and () are other-parties/partners of binops
                                        // e.g. `(4) & 7` is valid while `(()) & 7` isn't
                                        // => this group belongs to our binop partner
                                        Delimiter::Parenthesis | Delimiter::Bracket => {
                                            break;
                                        },

                                        // IDK what to do here, let's just error
                                        Delimiter::None => {
                                            panic!("We don't support groups delimitered by none yet: {}", gr);
                                        },
                                    }
                                },

                                Tt::Ident(id) => {
                                    let id_str = id.to_string();
                                    match id_str.as_str() {
                                        // `if` and `match` indicate that there is no binop
                                        // but instead all prefix &*-s were unary ops.
                                        "if" | "match" => {
                                            expr_len += offs_until_binop_partner + 1;
                                            break 'outer;
                                        },
                                        // `mut` is allowed part of a prefix
                                        "mut" => (),
                                        // If we encounter any other ident,
                                        // it's part of the binop partner.
                                        _ => break,
                                    }
                                },
                                Tt::Punct(p) => {
                                    match p.as_char() {
                                        // ; either stands for the separator in array types/definitions,
                                        // or it stands for a new statement. In both cases, unary op.
                                        ';' |
                                        // , is used in tuples, argument lists, etc. Implies an unary op
                                        ',' |
                                        // If we encounter =, it means an assignment OR comparison,
                                        // both implying that all leading &*- were unary ops.
                                        // (even though == is a binop, but that would be a binop at a higher level)
                                        '=' => {
                                            expr_len += offs_until_binop_partner + 1;
                                            break 'outer;
                                        },
                                        // The ! mark may occur in places like `&!&false`
                                        // and has to be treated like any leading unary
                                        // operator.
                                        '!' |
                                        // Continue the search
                                        '&' | '*' | '-' => (),

                                        // We don't support special symbols yet
                                        // TODO support more
                                        _ => panic!("Binop partner search encountered punct '{}'", p),
                                    }
                                },
                                Tt::Literal(_lit) => {
                                    // Literals are binop partners
                                    break;
                                },
                            }
                            offs_until_binop_partner += 1;
                        }
                        // If there is nothing beyond the one unary op in tts,
                        // no binop partner could be found,
                        // and we know that the sole punctuation
                        // was an unary op.
                        if offs_until_binop_partner == tts.len() - expr_len - 1 {
                            expr_len += offs_until_binop_partner + 1;
                            break;
                        }
                        let first = &tts[tts.len() - (expr_len + 1) - offs_until_binop_partner];
                        let second = &tts[tts.len() - (expr_len + 1) - offs_until_binop_partner + 1];
                        let mut binop_tts = 1;
                        match first {
                            Tt::Group(_gr) => unreachable!(),
                            // This can occur, as of current code only when we have code like `(mut hello.foo!())`,
                            // which would indicate a pattern context I guess... but for now we don't support
                            // our macro to be called in pattern contexts.
                            Tt::Ident(id) => panic!("Can't start a binop chain with ident '{}'", id),
                            Tt::Punct(p1) => {
                                if let Tt::Punct(p2) = second {
                                    let is_binop_and_and = p1.spacing() == Spacing::Joint &&
                                        p1.as_char() == '&' && p2.as_char() == '&';
                                    if is_binop_and_and {
                                        binop_tts = 2;
                                    }
                                }
                            },
                            Tt::Literal(_lit) => unreachable!(),
                        }
                        // We finally know how many tt's the binop operator takes up (1 or 2).
                        // Set the length of the expression and emit the expression.
                        expr_len += 1 + offs_until_binop_partner - binop_tts;
                        break;
                    },
                    c => panic!("Encountered unsupported punctuation {}", c),
                }
            },
            Tt::Literal(_lit) => {
            },
        }
        expr_len += 1;
        last_was_punctuation = is_punctuation;
        last_was_group = is_group;
    }
    expr_len
}
420}
421
/// Builds the argument group of a rewritten macro call: the receiver
/// expression `tokens` followed by the original arguments of `gr`.
///
/// A receiver that is a single identifier or literal is passed as-is.
/// Anything else is wrapped in a `{ ... }` block, so it stays one
/// self-contained expression inside the macro's argument list. A `,`
/// separates the receiver from the original arguments only when `gr` had any.
///
/// The returned group keeps both the delimiter and the span of `gr`.
fn prepend_macro_arg_to_group(tokens: &[Tt], gr: Group) -> Group {
    let expr = match tokens {
        [tt] if matches!(tt, Tt::Literal(_) | Tt::Ident(_)) => tt.clone(),
        _ => {
            let expr_stream = tokens.iter().cloned().collect();
            Tt::Group(Group::new(Delimiter::Brace, expr_stream))
        },
    };

    let stream = gr.stream();
    let mut res_stream = TokenStream::from(expr);
    if !stream.is_empty() {
        res_stream.extend(std::iter::once(Tt::Punct(Punct::new(',', Spacing::Alone))));
        res_stream.extend(stream);
    }
    let mut res = Group::new(gr.delimiter(), res_stream);
    // `Group::new` assigns `Span::call_site()`; restore the span of the
    // user's original argument delimiters instead.
    res.set_span(gr.span());
    res
}
447}