postfix_macros_impl/
lib.rs

1/*!
2Postfix macros on stable Rust, today.
3
4This is the crate containing the proc macro implementation
5of the [`postfix_macros!`] macro.
6
7The `postfix-macros` crate reexports the macro
8defined by this crate, and adds some macros of its
9own that are helpful in postfix macro context.
10If you don't need these extra macros,
11you can use this crate instead and save
12the extra dependency.
13
14```
15# use postfix_macros_impl::postfix_macros;
16# #[derive(Debug, Clone, Copy)] enum Custom { Enum(()), EnumOther}
17# let val = [((),Custom::EnumOther,)];
18postfix_macros! {
19	"hello".assert_ne!("world");
20
21	val.iter()
22		.map(|v| v.1)
23		.find(|z| z.matches!(Custom::Enum(_) | Custom::EnumOther))
24		.dbg!();
25}
26```
27
28*/
29#![forbid(unsafe_code)]
30
31extern crate proc_macro;
32use proc_macro::{TokenStream, TokenTree as Tt, Punct, Group, Spacing,
33	Delimiter};
34
35#[proc_macro]
36pub fn postfix_macros(stream :TokenStream) -> TokenStream {
37	let mut vis = Visitor;
38	let res = vis.visit_stream(stream);
39	//println!("{}", res);
40	res
41}
42
43struct Visitor;
44
45impl Visitor {
46	fn visit_stream(&mut self, stream :TokenStream) -> TokenStream {
47		let mut res = Vec::new();
48		let mut stream_iter = stream.into_iter();
49		while let Some(tt) = stream_iter.next() {
50			match tt {
51				Tt::Group(group) => {
52					let mut postfix_macro = false;
53					{
54						let last_three = res.rchunks(3).next();
55						if let Some(&[Tt::Punct(ref p1), Tt::Ident(_), Tt::Punct(ref p2)]) = last_three {
56							if (p1.as_char(), p1.spacing(), p2.as_char(), p2.spacing()) == ('.', Spacing::Alone, '!', Spacing::Alone) {
57								postfix_macro = true;
58							}
59						}
60					}
61					let group = if postfix_macro {
62						// Remove the ! and macro ident
63						let mac_bang = res.pop().unwrap();
64						let mac = res.pop().unwrap();
65						// Remove the . before the macro
66						res.pop().unwrap();
67
68						// Walk the entire chain of tt's that
69						// form the expression we want to feed to the macro.
70						let expr_len = expression_length(&res);
71
72						if expr_len == 0 {
73							panic!("expected something before the postfix macro invocation");
74						}
75						//println!("  -> built");
76
77						// Build the group
78						let gr = self.visit_group(group);
79						let arg_tokens = &res[(res.len() - expr_len)..];
80						let gr = prepend_macro_arg_to_group(arg_tokens, gr);
81						res.truncate(res.len() - expr_len);
82
83						// Add back the macro ident and bang
84						res.push(mac);
85						res.push(mac_bang);
86
87						/*println!("res so far: {}",
88							res.iter().cloned().collect::<TokenStream>());*/
89
90						gr
91					} else {
92						group
93					};
94					let tt = Tt::Group(self.visit_group(group));
95					res.push(tt);
96				},
97				Tt::Ident(id) => {
98					res.push(Tt::Ident(id));
99				},
100				Tt::Punct(p) => {
101					res.push(Tt::Punct(p));
102				},
103				Tt::Literal(lit) => {
104					res.push(Tt::Literal(lit));
105				},
106			}
107		}
108		res.into_iter().collect()
109	}
110	fn visit_group(&mut self, group :Group) -> Group {
111		let delim = group.delimiter();
112		let span = group.span();
113		let stream = self.visit_stream(group.stream());
114		let mut gr = Group::new(delim, stream);
115		gr.set_span(span);
116		gr
117	}
118}
119
120
/// Walk the entire chain of tt's that
/// form an expression that a postfix macro call
/// would be part of.
///
/// `tts` are the tokens preceding the `.` of the postfix macro call. The
/// search runs backwards from the end of `tts`. It stops at the first token
/// that cannot belong to the receiver expression, for example `;`, `,`, `=`,
/// a binary operator, a range `..`, or an identifier directly followed by
/// another identifier or literal (other than `mut`).
///
/// Returns the number of token tree items at the end of `tts` that
/// belong to the expression; `0` means no expression was found.
///
/// # Panics
///
/// Panics on constructs the search does not support yet, such as closures,
/// `if let`, `==` inside an `if` condition, and several punctuation
/// characters.
fn expression_length(tts :&[Tt]) -> usize {
	let mut expr_len = 0;
	let mut last_was_punctuation = true;
	let mut last_was_group = true;
	'outer: while expr_len < tts.len() {
		let tt = &tts[tts.len() - 1 - expr_len];
		let mut is_punctuation = false;
		let mut is_group = false;
		match tt {
			Tt::Group(group) => {
				is_group = true;
				// If the group wasn't terminated by a punctuation,
				// it belongs to e.g. a function body, if clause, etc,
				// but not to our expression
				if !last_was_punctuation {
					break;
				}

				// If the group was terminated by a punctuation,
				// it belongs to the postfix macro chain.
				// If it's delimited by braces, i.e. { ... },
				// we need to check whether the group was an if,
				// match, else, or else if block, and add stuff accordingly.
				if group.delimiter() == Delimiter::Brace {
					// Bookkeeping for this loop: the tokens at indices
					// `tts.len() - 1 - expr_len ..= tts.len() - 1` are known to be
					// part of the expression. `expr_len` lags one behind that
					// count because the outer loop's trailing `expr_len += 1`
					// accounts for the leftmost of them.
					loop {
						// We are at the end, it was a {} block.
						if expr_len + 1 >= tts.len() {
							break;
						}
						let tt_before = &tts[tts.len() - 2 - expr_len];
						match tt_before {
							// e.g. `if foo() {}`, `if { true } {}`, `if if {true } else { false } {}`,
							// `if bools[..] {}`, or a literal condition.
							// Basically, just start the expression search and hope for the best :)
							Tt::Group(_) | Tt::Literal(_) => (),
							Tt::Ident(id) => {
								if id.to_string() == "else" {
									// Take in the `else` and the block before it,
									// then continue the chain search from that block.
									expr_len += 2;
									continue;
								}
								// Any other ident: must be part of an expression like if something.expr {}.foo().
								// Start the full expression search
							},
							Tt::Punct(p) => match p.as_char() {
								// These indicate the end of the expression
								';' | ',' => {
									expr_len += 1;
									break 'outer;
								},
								// This indicates the group was part of something else,
								// like a prior macro foo! {} . bar!().
								// Just continue the outer search normally
								'!' => break,
								// Unsupported stuff
								// TODO support closures
								'|' => panic!("Closures not supported yet"),
								p => panic!("Group expr search encountered unsupported punctuation {}", p),
							},
						}
						// Perform the expression search (if condition / match scrutinee)
						let sub_expr_len = expression_length(&tts[..tts.len() - 1 - expr_len]);
						expr_len += sub_expr_len;
						// Now check what's beyond the expression
						let tt_before = tts.len().checked_sub(2 + expr_len).and_then(|i| tts.get(i));
						let tt_before_that = tts.len().checked_sub(3 + expr_len).and_then(|i| tts.get(i));

						match (tt_before_that, tt_before) {
							(_, Some(Tt::Ident(id))) => {
								let id = id.to_string();
								let follows_else = match tt_before_that {
									Some(Tt::Ident(id_t)) => id_t.to_string() == "else",
									_ => false,
								};
								if id == "if" && follows_else {
									// Else if clause: take in `if`, `else` and the block
									// before the `else`, then continue the chain search.
									expr_len += 3;
								} else if id == "if" || id == "match" {
									// Done with the if/match chain search.
									// (This also covers e.g. `return if ...`.)
									is_group = false;
									expr_len += 1;
									break;
								}
								// Otherwise: IDK something failed, keep searching.
							},
							(_, Some(Tt::Punct(p))) => {
								match p.as_char() {
									// This can be either == or if let Foo() =
									'=' => {
										if let Some(Tt::Punct(p_t)) = tt_before_that {
											if p_t.as_char() == '=' {
												// TODO: parse another expr, e.g. via a
												// precedence-aware wrapper around expression_length.
												panic!("== in if clause not supported yet");
											}
										}
										panic!("if let not supported");
									},
									_ => panic!("{} in if not supported yet", p),
								}
							},
							(None, None) => {
								// Nothing comes before tt.
								// We are done
								break;
							},
							_ => {
								panic!("Hit unsupported case: {:?} {:?}", tt_before_that.map(|v| v.to_string()),
									tt_before.map(|v| v.to_string()));
							},
						}
					}
				}
			},
			Tt::Ident(id) => {
				if !last_was_punctuation && !last_was_group {
					// two idents following another... must be `if <something>.foo!() { <stuff> }`
					// or something like it.
					// We need to special case the keyword mut though because `&mut` is usually
					// prefixed to an expression.
					if id.to_string() != "mut" {
						break;
					}
				}
			},
			Tt::Punct(p) => {
				is_punctuation = true;
				// `..` is lexed as a joint `.` followed by an alone `.`.
				let follows_joint_dot = matches!(
					tts[..tts.len() - 1 - expr_len].last(),
					Some(Tt::Punct(q)) if q.as_char() == '.' && q.spacing() == Spacing::Joint
				);
				match p.as_char() {
					// The second `.` of a range/rest operator `..` terminates the
					// expression instead of being a method call dot.
					'.' if p.spacing() == Spacing::Alone && follows_joint_dot => break,
					// No expression termination
					'.' if p.spacing() == Spacing::Alone => (),
					':' | '?' | '!' => (),
					// These all terminate expressions
					'.' if p.spacing() == Spacing::Joint => break,
					',' | ';' | '+' | '/' | '%' | '=' | '<' | '>' | '|' | '^' => break,
					// All of & * and - can be safely prepended to expressions in any number,
					// However the symbols also exist in a binop context.
					// Only the leading symbol can be a binop, but what makes matters a bit
					// more complicated is that `&&` is a valid binop as well.
					'&' | '*' | '-' => {
						// First, we find the end of our binop partner
						let mut offs_until_binop_partner = 0;
						for tt in tts[.. tts.len() - expr_len - 1].iter().rev() {
							match tt {
								Tt::Group(gr) => {
									match gr.delimiter() {
										// `{0} & 7;` is invalid and `{}` is.
										// => all &*- are unary ops.
										Delimiter::Brace => {
											expr_len += offs_until_binop_partner + 1;
											break 'outer;
										}
										// Both [] and () are other-parties/partners of binops
										// e.g. `(4) & 7` is valid while `(()) & 7` isn't
										// => this group belongs to our binop partner
										Delimiter::Parenthesis | Delimiter::Bracket => {
											break;
										},

										// IDK what to do here, let's just error
										Delimiter::None => {
											panic!("We don't support groups delimitered by none yet: {}", gr);
										},
									}
								},

								Tt::Ident(id) => {
									let id_str = id.to_string();
									match id_str.as_str() {
										// `if` and `match` indicate that there is no binop
										// but instead all prefix &*-s were unary ops.
										"if" | "match" => {
											expr_len += offs_until_binop_partner + 1;
											break 'outer;
										},
										// `mut` is allowed part of a prefix
										"mut" => (),
										// If we encounter any other ident,
										// it's part of the binop partner.
										_ => break,
									}
								},
								Tt::Punct(p) => {
									match p.as_char() {
										// ; either stands for the separator in array types/definitions,
										// or it stands for a new statement. In both cases, unary op.
										';' |
										// , is used in tuples, argument lists, etc. Implies an unary op
										',' |
										// If we encounter =, it means an assignment OR comparison,
										// both implying that all leading &*- were unary ops.
										// (even though == is a binop, but that would be a binop at a higher level)
										'=' => {
											expr_len += offs_until_binop_partner + 1;
											break 'outer;
										},
										// The ! mark may occur in places like `&!&false`
										// and has to be treated like any leading unary
										// operator.
										'!' |
										// Continue the search
										'&' | '*' | '-' => (),

										// We don't support special symbols yet
										// TODO support more
										_ => panic!("Binop partner search encountered punct '{}'", p),
									}
								},
								Tt::Literal(_lit) => {
									// Literals are binop partners
									break;
								},
							}
							offs_until_binop_partner += 1;
						}
						// If there is nothing beyond the one unary op in tts,
						// no binop partner could be found,
						// and we know that the sole punctuation
						// was an unary op.
						if offs_until_binop_partner == tts.len() - expr_len - 1 {
							expr_len += offs_until_binop_partner + 1;
							break;
						}
						let first_idx = tts.len() - (expr_len + 1) - offs_until_binop_partner;
						let first = &tts[first_idx];
						// May be absent when the operator is the very last token.
						let second = tts.get(first_idx + 1);
						let mut binop_tts = 1;
						match first {
							Tt::Group(_gr) => unreachable!(),
							// This can occur, as of current code only when we have code like `(mut hello.foo!())`,
							// which would indicate a pattern context I guess... but for now we don't support
							// our macro to be called in pattern contexts.
							Tt::Ident(id) => panic!("Can't start a binop chain with ident '{}'", id),
							Tt::Punct(p1) => {
								if let Some(Tt::Punct(p2)) = second {
									let is_binop_and_and = p1.spacing() == Spacing::Joint &&
										p1.as_char() == '&' && p2.as_char() == '&';
									if is_binop_and_and {
										binop_tts = 2;
									}
								}
							},
							Tt::Literal(_lit) => unreachable!(),
						}
						// We finally know how many tt's the binop operator takes up (1 or 2).
						// Set the length of the expression and emit the expression.
						expr_len += 1 + offs_until_binop_partner - binop_tts;
						break;
					},
					c => panic!("Encountered unsupported punctuation {}", c),
				}
			},
			Tt::Literal(_lit) => {},
		}
		expr_len += 1;
		last_was_punctuation = is_punctuation;
		last_was_group = is_group;
	}
	expr_len
}
421
/// Builds the argument group of a rewritten prefix macro invocation.
///
/// `tokens` is the receiver expression that stood before `.mac!`, and `gr`
/// is the argument group of the postfix call. The result has the same
/// delimiter and span as `gr`. It contains the receiver expression and, if
/// `gr` was not empty, a `,` followed by the original arguments.
///
/// A receiver consisting of exactly one literal or identifier is passed
/// directly. Anything else is wrapped in a `{ ... }` block so the macro
/// receives it as a single token tree.
fn prepend_macro_arg_to_group(tokens :&[Tt], gr :Group) -> Group {
	let expr = match tokens {
		[tt] if matches!(tt, Tt::Literal(_) | Tt::Ident(_)) => tt.clone(),
		_ => Tt::Group(Group::new(Delimiter::Brace, tokens.iter().cloned().collect())),
	};

	let args = gr.stream();
	let mut res_stream = TokenStream::from(expr);
	if !args.is_empty() {
		res_stream.extend(std::iter::once(Tt::Punct(Punct::new(',', Spacing::Alone))));
		res_stream.extend(args);
	}
	let mut res = Group::new(gr.delimiter(), res_stream);
	// `Group::new` uses the call-site span; restore the original group's span
	// (as `visit_group` does) so diagnostics about the group point at user code.
	res.set_span(gr.span());
	res
}