implicit_fn/
lib.rs

1//! A macro that adds support for implicit closures to Rust.
2//!
3//! This provides a concise alternative to regular closure syntax
4//! for when each parameter is used at most once
5//! and is not deeply nested.
6//!
7//! This feature [has been suggested before](https://github.com/rust-lang/rfcs/issues/2554),
8//! but this macro mostly exists for fun.
9//!
10//! # Examples
11//!
12//! Using implicit closures to concisely sum a list:
13//!
14//! ```
15//! #[implicit_fn]
16//! fn main() {
17//!     let n = [1, 2, 3].into_iter().fold(0, _ + _);
18//!     assert_eq!(n, 6);
19//! }
20//! # use implicit_fn::implicit_fn;
21//! ```
22//!
23//! Copying all the elements of an array:
24//!
25//! ```
26//! #[implicit_fn]
27//! fn main() {
28//!     let array: [&u32; 3] = [&1, &2, &3];
29//!     let array: [u32; 3] = array.map(*_);
30//!     assert_eq!(array, [1, 2, 3]);
31//! }
32//! # use implicit_fn::implicit_fn;
33//! ```
34//!
35//! Running a fallible function in an iterator:
36//!
37//! ```
38//! #[implicit_fn]
39//! fn main() -> Result<(), Box<dyn Error>> {
40//!     let names = fs::read_dir(concat!(env!("CARGO_MANIFEST_DIR"), "/src"))?
41//!         .map(_?.file_name().into_string().map_err(|_| "file not UTF-8")?)
42//!         .collect::<Result<Vec<_>, Box<dyn Error>>>()?;
43//!     assert_eq!(names, ["lib.rs"]);
44//!     Ok(())
45//! }
46//! # use implicit_fn::implicit_fn;
47//! # use std::fs;
48//! # use std::error::Error;
49//! # use std::io;
50//! ```
51//!
52//! Running a match on an array of options:
53//!
54//! ```
55//! #[implicit_fn]
56//! fn main() {
57//!     let options = [Some(16), None, Some(2)];
58//!     let numbers = options.map(match _ {
59//!         Some(x) => x + 1,
60//!         None => 0,
61//!     });
62//!     assert_eq!(numbers, [17, 0, 3]);
63//! }
64//! # use implicit_fn::implicit_fn;
65//! ```
66//!
67//! Printing the elements of an iterator:
68//!
69//! ```
70//! #[implicit_fn]
71//! fn main() {
72//!     [1, 2, 3].into_iter().for_each(println!("{}", _));
73//! }
74//! # use implicit_fn::implicit_fn;
75//! ```
// Crate-wide lint configuration.
#![warn(clippy::pedantic)]
#![warn(redundant_lifetimes)]
#![warn(rust_2018_idioms)]
#![warn(single_use_lifetimes)]
#![warn(unit_bindings)]
#![warn(unused_crate_dependencies)]
#![warn(unused_lifetimes)]
#![warn(unused_qualifications)]
// The `use` items of this file are placed at the bottom, after the test module.
#![allow(clippy::items_after_test_module)]
85
86/// Transform all the implicit closures inside the expressions of an item.
87///
88/// The size of the closure, i.e. where the `||` is put, is determined syntactically
89/// according to the following rules:
90/// + The closure is always large enough such that
91///   identity functions (i.e. `|x| x`) are never generated,
92///   except for `(_)` which is always an identity function.
93/// + The closure encompasses as many “transparent” syntactic elements as it can,
94///   where the transparent elements are
95///   - unary and binary operators (e.g. prefix `*` and `!`, `/`, or `+=`);
96///   - `.await`, `?`, field access, and
97///     the left hand side of method calls and indexing expressions;
98///   - the conditions of `while` and `if`,
99///     the iterator in `for` and the scrutinee of a `match`, and
100///   - both sides of a range expression (e.g. `_..=5`).
101///
102/// Notably, `f(_ + 5)` will always parse as `f(|x| x + 5)` and not `|x| f(x + 5)`,
103/// and `(_ + 1) * 2` will parse as `(|x| x + 1) * 2`.
104///
105/// For examples, see [the crate docs](crate).
106///
107/// # Limitations
108///
109/// We only support `?` with [`Result`] and not other types implementing `Try`,
110/// due to the lack of a stable `Try` trait.
111///
112/// Macro bodies are not transformed except for builtin macros.
113#[proc_macro_attribute]
114pub fn implicit_fn(
115    attr: proc_macro::TokenStream,
116    body: proc_macro::TokenStream,
117) -> proc_macro::TokenStream {
118    inner(attr.into(), body.into()).into()
119}
120
121fn inner(attr: TokenStream, body: TokenStream) -> TokenStream {
122    let mut diagnostics = Diagnostics(None);
123
124    if !attr.is_empty() {
125        diagnostics.add(attr, "expected no tokens");
126    }
127
128    let mut item = match syn::parse2::<Item>(body) {
129        Ok(input) => input,
130        Err(e) => return diagnostics.with(e).into_compile_error(),
131    };
132
133    Transformer.visit_item_mut(&mut item);
134
135    item.into_token_stream()
136}
137
138/// The top-level transformer.
139struct Transformer;
140
141impl VisitMut for Transformer {
142    fn visit_expr_mut(&mut self, e: &mut Expr) {
143        maybe_make_closure(e);
144    }
145}
146
147/// Potentially make the given expression into a closure.
148///
149/// Also applies the transform to all subexpressions.
150fn maybe_make_closure(e: &mut Expr) {
151    if let Expr::Infer(_) = e {
152        return;
153    }
154    maybe_make_closure_force(e);
155}
156
157fn maybe_make_closure_force(e: &mut Expr) {
158    let mut helper = ReplaceHoles {
159        is_async: None,
160        is_try: None,
161        holes: Vec::new(),
162    };
163    replace_holes(&mut helper, e);
164    if helper.holes.is_empty() {
165        return;
166    }
167
168    let or1_span = helper.holes.first().unwrap().span();
169    let or2_span = helper.holes.last().unwrap().span();
170    let mut inputs = helper
171        .holes
172        .into_iter()
173        .map(|id| {
174            let span = id.span();
175            punctuated::Pair::Punctuated(ident_pat(id), Token![,](span))
176        })
177        .collect::<Punctuated<Pat, Token![,]>>();
178    inputs.pop_punct();
179
180    let old_expr = replace(e, Expr::PLACEHOLDER);
181    *e = Expr::Closure(ExprClosure {
182        attrs: Vec::new(),
183        lifetimes: None,
184        constness: None,
185        movability: None,
186        asyncness: helper.is_async.map(Token![async]),
187        capture: None,
188        or1_token: Token![|](or1_span),
189        inputs,
190        or2_token: Token![|](or2_span),
191        output: ReturnType::Default,
192        body: match helper.is_try {
193            Some(try_span) => Box::new(Expr::Verbatim(
194                quote_spanned!(try_span=> ::core::result::Result::Ok(#old_expr)),
195            )),
196            None => Box::new(old_expr),
197        },
198    });
199}
200
/// Traverses an expression’s subexpressions,
/// ignoring subexpressions transparent to the macro,
/// replacing holes with variables.
struct ReplaceHoles {
    /// Span of the first `.await` met in a transparent position, if any.
    /// When set, the generated closure is marked `async`.
    is_async: Option<Span>,
    /// Span of the first `?` met in a transparent position, if any.
    /// When set, the generated closure body is wrapped in `Result::Ok`.
    is_try: Option<Span>,
    /// The parameter names generated so far, one per replaced hole, in order.
    holes: Vec<Ident>,
}
209
210fn replace_holes(st: &mut ReplaceHoles, e: &mut Expr) {
211    match e {
212        Expr::Await(e) => {
213            st.is_async.get_or_insert(e.await_token.span);
214            replace_holes(st, &mut e.base);
215        }
216        Expr::Binary(e) => {
217            replace_holes(st, &mut e.left);
218            replace_holes(st, &mut e.right);
219        }
220        Expr::Call(e) => {
221            replace_holes(st, &mut e.func);
222        }
223        Expr::Field(e) => replace_holes(st, &mut e.base),
224        Expr::ForLoop(e) => replace_holes(st, &mut e.expr),
225        Expr::If(e) => {
226            replace_holes(st, &mut e.cond);
227            if let Some((_, else_branch)) = &mut e.else_branch {
228                replace_holes(st, else_branch);
229            }
230        }
231        Expr::Index(e) => replace_holes(st, &mut e.expr),
232        Expr::Let(e) => replace_holes(st, &mut e.expr),
233        Expr::Match(e) => replace_holes(st, &mut e.expr),
234        Expr::MethodCall(e) => replace_holes(st, &mut e.receiver),
235        Expr::Range(e) => {
236            if let Some(start) = &mut e.start {
237                replace_holes(st, start);
238            }
239            if let Some(end) = &mut e.end {
240                replace_holes(st, end);
241            }
242        }
243        Expr::Try(e) => {
244            st.is_try.get_or_insert(e.question_token.span);
245            replace_holes(st, &mut e.expr);
246        }
247        Expr::Unary(e) => replace_holes(st, &mut e.expr),
248        Expr::While(e) => replace_holes(st, &mut e.cond),
249
250        Expr::Infer(infer) => *e = st.add(infer.underscore_token),
251
252        // As an exception to the “no identity closures” rule,
253        // we allow (_) itself to be an identity closure.
254        Expr::Paren(e) => return maybe_make_closure_force(&mut e.expr),
255
256        Expr::Assign(e) => {
257            struct Helper<'a>(&'a mut ReplaceHoles);
258            impl VisitMut for Helper<'_> {
259                fn visit_expr_mut(&mut self, e: &mut Expr) {
260                    replace_holes(self.0, e);
261                }
262            }
263            IgnoreAssigneeExpr(Helper(st)).visit_expr_mut(&mut e.left);
264            st.visit_expr_mut(&mut e.right);
265            return;
266        }
267
268        _ => {}
269    }
270    // Call the freestanding function instead of the method
271    // to look for `_`s in top-level expressions.
272    visit_expr_mut(st, e);
273}
274
275// `VisitMut` implementation for all the top-level expressions,
276// to avoid producing identity closures.
277impl VisitMut for ReplaceHoles {
278    fn visit_expr_mut(&mut self, e: &mut Expr) {
279        if let Expr::Infer(infer) = e {
280            *e = self.add(infer.underscore_token);
281        } else {
282            maybe_make_closure_force(e);
283        }
284    }
285    // Override to treat `let _: [T; _] = …;` properly.
286    fn visit_type_array_mut(&mut self, i: &mut TypeArray) {
287        maybe_make_closure(&mut i.len);
288    }
289    // Override to treat `f::<{ _ }>` properly.
290    fn visit_generic_argument_mut(&mut self, a: &mut GenericArgument) {
291        if let GenericArgument::Const(Expr::Block(block)) = a {
292            self.visit_expr_block_mut(block);
293            return;
294        }
295        visit_generic_argument_mut(self, a);
296    }
297    fn visit_macro_mut(&mut self, m: &mut Macro) {
298        visit_macro(self, m);
299    }
300}
301
302impl ReplaceHoles {
303    fn add(&mut self, underscore_token: Token![_]) -> Expr {
304        let n = self.holes.len();
305        let ident = Ident::new(&format!("p{n}"), underscore_token.span);
306        self.holes.push(ident.clone());
307        Expr::Verbatim(ident.into_token_stream())
308    }
309}
310
311fn ident_pat(ident: Ident) -> Pat {
312    Pat::Ident(PatIdent {
313        attrs: Vec::new(),
314        by_ref: None,
315        mutability: None,
316        ident,
317        subpat: None,
318    })
319}
320
321struct IgnoreAssigneeExpr<V>(V);
322
323impl<V: VisitMut> VisitMut for IgnoreAssigneeExpr<V> {
324    // For the definition of assignee expressions:
325    // https://doc.rust-lang.org/reference/expressions.html#place-expressions-and-value-expressions
326    fn visit_expr_mut(&mut self, e: &mut Expr) {
327        match e {
328            Expr::Infer(_) => {}
329            Expr::Tuple(e) => {
330                for elem in &mut e.elems {
331                    self.visit_expr_mut(elem);
332                }
333            }
334            Expr::Array(e) => {
335                for elem in &mut e.elems {
336                    self.visit_expr_mut(elem);
337                }
338            }
339            Expr::Call(e) => {
340                self.0.visit_expr_mut(&mut e.func);
341                for arg in &mut e.args {
342                    self.visit_expr_mut(arg);
343                }
344            }
345            Expr::Struct(e) => {
346                for field in &mut e.fields {
347                    self.visit_field_value_mut(field);
348                }
349                if let Some(rest) = &mut e.rest {
350                    self.visit_expr_mut(rest);
351                }
352            }
353            _ => self.0.visit_expr_mut(e),
354        }
355    }
356}
357
358#[derive(Default)]
359struct Diagnostics(Option<syn::Error>);
360
361impl Diagnostics {
362    fn add_error(&mut self, new_e: syn::Error) {
363        *self = Self(Some(take(self).with(new_e)));
364    }
365    fn with(self, mut new_e: syn::Error) -> syn::Error {
366        if let Some(mut e) = self.0 {
367            e.combine(new_e);
368            new_e = e;
369        }
370        new_e
371    }
372    fn add(&mut self, tokens: impl ToTokens, message: impl Display) {
373        self.add_error(syn::Error::new_spanned(tokens, message));
374    }
375}
376
377fn visit_macro(v: &mut impl VisitMut, m: &mut Macro) {
378    let _ = visit_macro_inner(v, m);
379}
380
381fn visit_macro_inner(v: &mut impl VisitMut, m: &mut Macro) -> syn::Result<()> {
382    let Some(ident) = m.path.get_ident() else {
383        return Ok(());
384    };
385    let ident = ident.to_string();
386
387    if [
388        "assert",
389        "assert_eq",
390        "assert_ne",
391        "dbg",
392        "debug_assert",
393        "debug_assert_eq",
394        "debug_assert_ne",
395        "eprint",
396        "eprintln",
397        "format",
398        "format_args",
399        "panic",
400        "print",
401        "println",
402        "todo",
403        "unimplemented",
404        "unreachable",
405        "vec",
406        "write",
407        "writeln",
408    ]
409    .contains(&&*ident)
410    {
411        let mut args = m.parse_body_with(|input: ParseStream<'_>| {
412            <Punctuated<Expr, Token![,]>>::parse_terminated(input)
413        })?;
414        for arg in &mut args {
415            v.visit_expr_mut(arg);
416        }
417        m.tokens = args.into_token_stream();
418    } else if ident == "matches" {
419        let (mut scrutinee, comma, pattern, mut guard, comma2) =
420            m.parse_body_with(|input: ParseStream<'_>| {
421                let scrutinee: Expr = input.parse()?;
422                let comma = input.parse::<Token![,]>()?;
423                let pattern = Pat::parse_multi(input)?;
424                let guard = match input.parse::<Option<Token![if]>>()? {
425                    Some(r#if) => Some((r#if, input.parse::<Expr>()?)),
426                    None => None,
427                };
428                let comma2 = input.parse::<Option<Token![,]>>()?;
429                Ok((scrutinee, comma, pattern, guard, comma2))
430            })?;
431        v.visit_expr_mut(&mut scrutinee);
432        if let Some((_, guard)) = &mut guard {
433            v.visit_expr_mut(guard);
434        }
435
436        m.tokens = TokenStream::new();
437        scrutinee.to_tokens(&mut m.tokens);
438        comma.to_tokens(&mut m.tokens);
439        pattern.to_tokens(&mut m.tokens);
440        if let Some((r#if, guard)) = guard {
441            r#if.to_tokens(&mut m.tokens);
442            guard.to_tokens(&mut m.tokens);
443        }
444        comma2.to_tokens(&mut m.tokens);
445    }
446
447    Ok(())
448}
449
// Each test feeds an expression to `maybe_make_closure`
// and compares the printed result with the printed expected expression.
#[cfg(test)]
mod tests {
    /// A bare `_` is left alone, while a `_` statement inside a block
    /// makes the block a closure.
    #[test]
    fn top_level_underscores() {
        assert_output!(_, _);
        assert_output!(
            {
                _;
            },
            |p0| {
                p0;
            }
        );
    }

    /// `(_)` is the one case that yields an identity closure.
    #[test]
    fn identity() {
        assert_output!((_), (|p0| p0));
        assert_output!({ { _ } }, { |p0| { p0 } });
    }

    /// The closure grows until it is more than just the hole.
    #[test]
    fn avoid_identity() {
        assert_output!(f(_), |p0| f(p0));
        assert_output!((f(_)), (|p0| f(p0)));
        assert_output!({ _ }, |p0| { p0 });
        assert_output!(
            if x {
                _
            },
            |p0| if x {
                p0
            },
        );
        assert_output!(|_| _, |p0| |_| p0);
        assert_output!(
            {
                x;
                _
            },
            |p0| {
                x;
                p0
            },
        );
        assert_output!(
            {
                _;
                _
            },
            |p0, p1| {
                p0;
                p1
            },
        );
        assert_output!(
            if true {
                _;
                _
            },
            |p0, p1| if true {
                p0;
                p1
            },
        );
    }

    /// `_` in assignee position on the left of `=` is not a hole.
    #[test]
    fn assignee_expression_underscores() {
        assert_output!(_ = x, _ = x);
        assert_output!((_, _) = x, (_, _) = x);
        assert_output!([_] = x, [_] = x);
        assert_output!(f(_) = x, f(_) = x);
        assert_output!(S { x: _ } = x, S { x: _ } = x);
    }

    /// Holes on either side of an assignment outside assignee position.
    #[test]
    fn assignment() {
        assert_output!(*f(_) = x, |p0| *f(p0) = x);
        assert_output!(*f(g(_)) = x, *f(|p0| g(p0)) = x);
        assert_output!(x = _, |p0| x = p0);
        assert_output!(x = f(_), x = |p0| f(p0));
    }

    /// `_` as an inferred array length or generic argument is not a hole,
    /// but holes inside such expressions are still transformed.
    #[test]
    fn infer_underscores() {
        assert_output!(
            {
                let x: [u8; _];
            },
            {
                let x: [u8; _];
            }
        );
        assert_output!(
            {
                let x: [u8; f(_)];
            },
            {
                let x: [u8; |p0| f(p0)];
            }
        );
        assert_output!(f::<_>(), f::<_>());
        assert_output!(f::<{ _ }>(), |p0| f::<{ p0 }>());
        assert_output!(f::<{ f(_) }>(), f::<{ |p0| f(p0) }>());
    }

    /// Every transparent construct is included in the closure around `f(_)`.
    #[test]
    fn transparent() {
        assert_output!(f(_).await, async |p0| f(p0).await);
        assert_output!(f(_) + 1, |p0| f(p0) + 1);
        assert_output!(x += f(_), |p0| x += f(p0));
        assert_output!(f(_).f, |p0| f(p0).f);
        assert_output!(for _ in f(_) {}, |p0| for _ in f(p0) {});
        assert_output!(if f(_) {}, |p0| if f(p0) {});
        assert_output!(
            if x {
            } else if f(_) {
            },
            |p0| if x {
            } else if f(p0) {
            }
        );
        assert_output!(f(_)[0], |p0| f(p0)[0]);
        assert_output!(if let P = f(_) {}, |p0| if let P = f(p0) {});
        assert_output!(match f(_) {}, |p0| match f(p0) {});
        assert_output!(f(_).g(), |p0| f(p0).g());
        assert_output!(_..x, |p0| p0..x);
        // This doesn’t even parse for some reason:
        // assert_output!(x..=_, |p0| x..=p0);
        assert_output!(f(_)?, |p0| ::core::result::Result::Ok(f(p0)?));
        assert_output!(!f(_), |p0| !f(p0));
        assert_output!(while f(_) {}, |p0| while f(p0) {});
    }

    /// Only the arguments of known macros are transformed.
    #[test]
    fn macros() {
        assert_output!(arbitrary!(_), arbitrary!(_));
        assert_output!(assert!(_), |p0| assert!(p0));
        assert_output!(assert_ne!(_, _, _), |p0, p1, p2| assert_ne!(p0, p1, p2));
        assert_output!(vec![_], |p0| vec![p0]);
        assert_output!(matches!(_, _), |p0| matches!(p0, _));
        assert_output!(matches!(_, _ if _), |p0, p1| matches!(p0, _ if p1));
    }

    // Takes the input expression and the expected output expression.
    macro_rules! assert_output {
        ($in:expr, $out:expr $(,)?) => {
            assert_output_inner(quote!($in), quote!($out))
        };
    }
    // Makes the macro usable by path, including in the tests defined above it.
    use assert_output;

    /// Parses both token streams as expressions, transforms the first one,
    /// and asserts that both print to the same string.
    #[track_caller]
    fn assert_output_inner(r#in: TokenStream, out: TokenStream) {
        let mut r#in = syn::parse2::<Expr>(r#in).unwrap();
        let out = syn::parse2::<Expr>(out).unwrap();
        maybe_make_closure(&mut r#in);
        assert_eq!(
            r#in.into_token_stream().to_string(),
            out.into_token_stream().to_string()
        );
    }

    use super::maybe_make_closure;
    use proc_macro2::TokenStream;
    use quote::ToTokens as _;
    use quote::quote;
    use syn::Expr;
}
619
620use proc_macro2::Ident;
621use proc_macro2::Span;
622use proc_macro2::TokenStream;
623use quote::ToTokens;
624use quote::quote_spanned;
625use std::fmt::Display;
626use std::mem::replace;
627use std::mem::take;
628use syn::Expr;
629use syn::ExprClosure;
630use syn::GenericArgument;
631use syn::Item;
632use syn::Macro;
633use syn::Pat;
634use syn::PatIdent;
635use syn::ReturnType;
636use syn::Token;
637use syn::TypeArray;
638use syn::parse::ParseStream;
639use syn::punctuated;
640use syn::punctuated::Punctuated;
641use syn::visit_mut::VisitMut;
642use syn::visit_mut::visit_expr_mut;
643use syn::visit_mut::visit_generic_argument_mut;