1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
/*!
Postfix macros on stable Rust, today.

This is the crate containing the proc macro implementation
of the [`postfix_macros!`] macro.

The `postfix-macros` crate reexports the macro
defined by this crate, and adds some macros of its
own that are helpful in postfix macro context.
If you don't need these extra macros,
you can use this crate instead and save
the extra dependency.

```
# use postfix_macros_impl::postfix_macros;
# #[derive(Debug, Clone, Copy)] enum Custom { Enum(()), EnumOther}
# let val = [((),Custom::EnumOther,)];
postfix_macros! {
	"hello".assert_ne!("world");

	val.iter()
		.map(|v| v.1)
		.find(|z| z.matches!(Custom::Enum(_) | Custom::EnumOther))
		.dbg!();
}
```

*/
#![forbid(unsafe_code)]

extern crate proc_macro;
use proc_macro::{TokenStream, TokenTree as Tt, Punct, Group, Spacing,
	Delimiter};

/// Entry point of the `postfix_macros!` procedural macro.
///
/// The whole macro input is handed to [`Visitor::visit_stream`],
/// and whatever token stream that returns becomes the macro's output.
#[proc_macro]
pub fn postfix_macros(stream :TokenStream) -> TokenStream {
	// `Visitor` carries no state, so a temporary is all we need.
	Visitor.visit_stream(stream)
}

/// Stateless token walker that rewrites postfix macro invocations
/// (`<expr>.mac!(args)`) into regular ones (`mac!(<expr>, args)`).
struct Visitor;

impl Visitor {
	/// Rewrites every postfix macro invocation in `stream`, recursing
	/// into nested groups, and returns the resulting token stream.
	///
	/// Tokens are copied into an output buffer one at a time. When a group
	/// is directly preceded by `. <ident> !` (with both the dot and the bang
	/// being stand-alone punctuation), the three tokens are taken back out of
	/// the buffer, the expression in front of them is located with
	/// `expression_length`, and that expression is moved into the group as
	/// the first macro argument.
	///
	/// Every group is visited exactly once.
	///
	/// # Panics
	///
	/// Panics if nothing that counts as an expression precedes the postfix
	/// macro invocation, or if `expression_length` panics on syntax it
	/// does not support.
	fn visit_stream(&mut self, stream :TokenStream) -> TokenStream {
		let mut res :Vec<Tt> = Vec::new();
		for tt in stream {
			let group = match tt {
				Tt::Group(group) => group,
				// Idents, punctuation and literals are passed through as-is.
				other => {
					res.push(other);
					continue;
				},
			};

			// Is the group the argument list of a postfix macro,
			// i.e. do the last three emitted tokens read `. <ident> !`?
			let postfix_macro = match res.as_slice() {
				[.., Tt::Punct(dot), Tt::Ident(_), Tt::Punct(bang)] => {
					(dot.as_char(), dot.spacing(), bang.as_char(), bang.spacing())
						== ('.', Spacing::Alone, '!', Spacing::Alone)
				},
				_ => false,
			};
			if !postfix_macro {
				res.push(Tt::Group(self.visit_group(group)));
				continue;
			}

			// Remove the ! and macro ident
			let mac_bang = res.pop().expect("pattern above matched three tokens");
			let mac = res.pop().expect("pattern above matched three tokens");
			// Remove the . before the macro
			res.pop().expect("pattern above matched three tokens");

			// Walk the entire chain of tt's that
			// form the expression we want to feed to the macro.
			let expr_len = expression_length(&res);
			if expr_len == 0 {
				panic!("expected something before the postfix macro invocation");
			}

			// Build the group. The expression tokens come from `res` and
			// were therefore already visited; the original arguments are
			// visited here. Nothing needs a second pass afterwards.
			let visited_args = self.visit_group(group);
			let expr_start = res.len() - expr_len;
			let call_args = prepend_macro_arg_to_group(&res[expr_start..], visited_args);
			res.truncate(expr_start);

			// Add back the macro ident and bang, followed by the new arguments
			res.push(mac);
			res.push(mac_bang);
			res.push(Tt::Group(call_args));
		}
		res.into_iter().collect()
	}
	/// Applies [`Visitor::visit_stream`] to the contents of `group` and
	/// returns a group with the same delimiter and span around the result.
	fn visit_group(&mut self, group :Group) -> Group {
		let stream = self.visit_stream(group.stream());
		let mut gr = Group::new(group.delimiter(), stream);
		// `Group::new` does not carry over the span, so restore it.
		gr.set_span(group.span());
		gr
	}
}


/// Walk the entire chain of tt's that
/// form an expression that a postfix macro call
/// would be part of.
///
/// The walk starts at the *end* of `tts` and moves towards the front,
/// one token tree at a time, until it hits something that cannot belong
/// to the expression any more.
///
/// Returns the number of token tree items that
/// belong to the expression, counted from the end of `tts`.
/// A return value of `0` means no expression was found.
///
/// # Panics
///
/// Panics on constructs that are not supported (yet), among them
/// closures in front of a `{}` block, `==` and `if let` in `if`
/// conditions, `Delimiter::None` groups in the binop partner search,
/// and punctuation that is not handled by any of the match arms below.
fn expression_length(tts :&[Tt]) -> usize {
	// Number of trailing tokens accepted so far. The token currently
	// inspected is always `tts[tts.len() - 1 - expr_len]`.
	let mut expr_len = 0;
	// Kind of the token that was accepted in the previous iteration
	// (i.e. the one to the *right* of the current token).
	// Both start as `true` so that the very first token is never rejected.
	let mut last_was_punctuation = true;
	let mut last_was_group = true;
	'outer: while expr_len < tts.len() {
		let tt = &tts[tts.len() - 1 - expr_len];
		let mut is_punctuation = false;
		let mut is_group = false;
		match tt {
			Tt::Group(group) => {
				is_group = true;
				// If the group wasn't terminated by a punctuation,
				// it belongs to e.g. a function body, if clause, etc,
				// but not to our expression
				if !last_was_punctuation {
					break;
				}

				// If the group was terminated by a punctuation,
				// it belongs to the postfix macro chain.
				// If it's delimitered by braces, so is { ... },
				// we need to check whether the group was an if,
				// match, else, or else if block, and add stuff accordingly.

				// If we have {}. it might be an if, match or else block.
				if group.delimiter() == Delimiter::Brace {
					// Invariant at the top of each iteration: the current
					// token `tts[tts.len() - 1 - expr_len]` is a brace group
					// of the if/else/match chain that is not counted yet.
					loop {
						// We are at the end, it was a {} block.
						if expr_len + 1 >= tts.len() {
							break;
						}
						let tt_before = &tts[tts.len() - 2 - expr_len];
						match tt_before {
							Tt::Group(_group) => {
								// e.g. `if foo() {}`, `if { true } {}`, `if if {true } else { false } {}`,
								// `if bools[..] {}`.
								// Basically, just start the expression search and hope for the best :)
							},
							Tt::Ident(id) => {
								let id_str = id.to_string();
								if id_str == "else" {
									// Consume the current block and the `else`,
									// which makes the block in front of the `else`
									// the current token.
									expr_len += 2;
									// Continue the chain search
									continue;
								} else {
									// Any other ident: must be part of an expression like if something.expr {}.foo().
									// Start the full expression search
								}
							},
							Tt::Punct(p) => match p.as_char() {
								// These indicate the end of the expression
								';' | ',' => {
									expr_len += 1;
									break 'outer;
								},
								// This indicates the group was part of something else,
								// like a prior macro foo! {} . bar!().
								// Just continue the outer search normally
								'!' => break,
								// Unsupported stuff
								// TODO support closures
								'|' => panic!("Closures not supported yet"),
								p => panic!("Group expr search encountered unsupported punctuation {}", p),
							},
							Tt::Literal(_lit) => {
								// Start the expression search
							},
						}
						// Perform the expression search on everything in front
						// of the current block; this finds the condition/scrutinee.
						let sub_expr_len = expression_length(&tts[..tts.len() - 1 - expr_len]);
						expr_len += sub_expr_len;
						// Now check what's beyond the expression.
						// Note that `expr_len` still doesn't include the block itself,
						// hence the offsets of 2 and 3.
						let tt_before = if tts.len() < 2 + expr_len {
							None
						} else {
							tts.get(tts.len() - 2 - expr_len)
						};
						let tt_before_that = if tts.len() < 3 + expr_len {
							None
						} else {
							tts.get(tts.len() - 3 - expr_len)
						};

						match (tt_before_that, tt_before) {
							(Some(Tt::Ident(id_t)), Some(Tt::Ident(id))) => {
								let id_t = id_t.to_string();
								let id = id.to_string();
								if id_t == "else" && id == "if" {
									// Else if clause.
									// Consume the block, the `if` and the `else`.
									expr_len += 3;
									// Continue the chain search.
								} else if id == "if" || id == "match" {
									// Done with the if/match chain search.
									// The ident in front (e.g. `return`) is not
									// part of the expression.
									is_group = false;
									expr_len += 1;
									break;
								}
							},
							(_, Some(Tt::Ident(id))) => {
								let id = id.to_string();
								if id == "if" || id == "match" {
									// Done with the if/match chain search.
									is_group = false;
									expr_len += 1;
									break;
								} else {
									// IDK something failed
								}
							},
							(_, Some(Tt::Punct(p))) => {
								match p.as_char() {
									// This can be either == or if let Foo() =
									'=' => {
										if let Some(Tt::Punct(p_t)) = tt_before_that {
											if p_t.as_char() == '=' {
												// Parse another expr
												// TODO
												// TODO maybe instead of calling expression_length above,
												// create a new function that calls expression_length internally and
												// handles this case, calling expression_length again if needed?
												// Or pass some kind of precedence setting to expression_length?
												panic!("== in if clause not supported yet");
											}
										}
										panic!("if let not supported");
									},
									_ => panic!("{} in if not supported yet", p),
								}
							},
							(None, None) => {
								// Nothing comes before tt.
								// We are done
								break;
							},
							_ => {
								panic!("Hit unsupported case: {:?} {:?}", tt_before_that.map(|v| v.to_string()),
									tt_before.map(|v| v.to_string()));
							},
						}
					}
				}
			},
			Tt::Ident(id) => {
				if !last_was_punctuation && !last_was_group {
					// two idents following another... must be `if <something>.foo!() { <stuff> }`
					// or something like it.
					// We need to special case the keyword mut though because `&mut` is usually
					// prefixed to an expression.
					let id_str = id.to_string();
					if id_str != "mut" {
						break;
					}
				}
			},
			Tt::Punct(p) => {
				is_punctuation = true;
				match p.as_char() {
					// No expression termination
					'.' if p.spacing() == Spacing::Alone => (),
					':' | '?' | '!' => (),
					// These all terminate expressions
					'.' if p.spacing() == Spacing::Joint => break,
					',' | ';' | '+' | '/' | '%' | '=' | '<' | '>' | '|' | '^' => break,
					// All of & * and - can be safely prepended to expressions in any number,
					// However the symbols also exist in a binop context.
					// Only the leading symbol can be a binop, but what makes matters a bit
					// more complicated is that `&&` is a valid binop as well.
					'&' | '*' | '-' => {
						// First, we find the end of our binop partner.
						// The counter is the number of tokens between the
						// current punctuation and the partner.
						let mut offs_until_binop_partner = 0;
						for tt in tts[.. tts.len() - expr_len - 1].iter().rev() {
							match tt {
								Tt::Group(gr) => {
									match gr.delimiter() {
										// `{0} & 7;` is invalid and `{}` is.
										// => all &*- are unary ops.
										Delimiter::Brace => {
											expr_len += offs_until_binop_partner + 1;
											break 'outer;
										}
										// Both [] and () are other-parties/partners of binops
										// e.g. `(4) & 7` is valid while `(()) & 7` isn't
										// => this group belongs to our binop partner
										Delimiter::Parenthesis | Delimiter::Bracket => {
											break;
										},

										// IDK what to do here, let's just error
										Delimiter::None => {
											panic!("We don't support groups delimitered by none yet: {}", gr);
										},
									}
								},

								Tt::Ident(id) => {
									let id_str = id.to_string();
									match id_str.as_str() {
										// `if` and `match` indicate that there is no binop
										// but instead all prefix &*-s were unary ops.
										"if" | "match" => {
											expr_len += offs_until_binop_partner + 1;
											break 'outer;
										},
										// `mut` is allowed part of a prefix
										"mut" => (),
										// If we encounter any other ident,
										// it's part of the binop partner.
										_ => break,
									}
								},
								Tt::Punct(p) => {
									match p.as_char() {
										// ; either stands for the separator in array types/definitions,
										// or it stands for a new statement. In both cases, unary op.
										';' |
										// , is used in tuples, argument lists, etc. Implies an unary op
										',' |
										// If we encounter =, it means an assignment OR comparison,
										// both implying that all leading &*- were unary ops.
										// (even though == is a binop, but that would be a binop at a higher level)
										'=' => {
											expr_len += offs_until_binop_partner + 1;
											break 'outer;
										},
										// The ! mark may occur in places like `&!&false`
										// and has to be treated like any leading unary
										// operator.
										'!' |
										// Continue the search
										'&' | '*' | '-' => (),

										// We don't support special symbols yet
										// TODO support more
										_ => panic!("Binop partner search encountered punct '{}'", p),
									}
								},
								Tt::Literal(_lit) => {
									// Literals are binop partners
									break;
								},
							}
							offs_until_binop_partner += 1;
						}
						// If there is nothing beyond the one unary op in tts,
						// no binop partner could be found,
						// and we know that the sole punctuation
						// was an unary op.
						if offs_until_binop_partner == tts.len() - expr_len - 1 {
							expr_len += offs_until_binop_partner + 1;
							break;
						}
						// `first` is the token right after the binop partner,
						// `second` the one following it. `second` doesn't exist if
						// the current punctuation is the very last token of `tts`.
						let first_idx = tts.len() - (expr_len + 1) - offs_until_binop_partner;
						let first = &tts[first_idx];
						let second = tts.get(first_idx + 1);
						let mut binop_tts = 1;
						match first {
							Tt::Group(_gr) => unreachable!(),
							// This can occur, as of current code only when we have code like `(mut hello.foo!())`,
							// which would indicate a pattern context I guess... but for now we don't support
							// our macro to be called in pattern contexts.
							Tt::Ident(id) => panic!("Can't start a binop chain with ident '{}'", id),
							Tt::Punct(p1) => {
								if let Some(Tt::Punct(p2)) = second {
									let is_binop_and_and = p1.spacing() == Spacing::Joint &&
										p1.as_char() == '&' && p2.as_char() == '&';
									if is_binop_and_and {
										binop_tts = 2;
									}
								}
							},
							Tt::Literal(_lit) => unreachable!(),
						}
						// We finally know how many tt's the binop operator takes up (1 or 2).
						// Set the length of the expression and emit the expression.
						expr_len += 1 + offs_until_binop_partner - binop_tts;
						break;
					},
					c => panic!("Encountered unsupported punctuation {}", c),
				}
			},
			Tt::Literal(_lit) => {
			},
		}
		// The current token belongs to the expression
		expr_len += 1;
		last_was_punctuation = is_punctuation;
		last_was_group = is_group;
	}
	expr_len
}

/// Builds the argument group of the rewritten macro invocation.
///
/// `tokens` is the expression the postfix macro was applied to, and
/// `gr` is the group holding the arguments originally passed to the macro.
/// The returned group has the same delimiter and span as `gr`, and contains
/// the expression, followed by a `,` and the original arguments
/// if there were any.
fn prepend_macro_arg_to_group(tokens :&[Tt], gr :Group) -> Group {
	// Build the expr's tt.
	// If there is only one token and it's
	// a variable/constant/static name, or a literal,
	// we pass it directly, otherwise we wrap it in {}
	// to make it safer.
	let expr = match tokens {
		[tt] if matches!(tt, Tt::Literal(_) | Tt::Ident(_)) => tt.clone(),
		_ => Tt::Group(Group::new(Delimiter::Brace, tokens.iter().cloned().collect())),
	};

	let orig_args = gr.stream();
	let mut res_stream = TokenStream::from(expr);
	// Only emit the separating comma if something follows it
	if !orig_args.is_empty() {
		res_stream.extend(std::iter::once(Tt::Punct(Punct::new(',', Spacing::Alone))));
		res_stream.extend(orig_args);
	}

	let mut res = Group::new(gr.delimiter(), res_stream);
	// `Group::new` does not carry over the span of the group we replace.
	// Restore it, so that the new group keeps pointing at the
	// user's original argument list.
	res.set_span(gr.span());
	res
}