cbit/
lib.rs

1#![allow(rustdoc::redundant_explicit_links)] // For cargo-rdme's sake
2
3//! A proc-macro to use callback-based iterators with `for`-loop syntax and functionality.
4//!
5//! ## Overview
6//!
7//! `cbit` (short for **c**losure-**b**ased **it**erator) is a crate which allows you to use iterator
8//! functions which call into a closure to process each element as if they were just a regular Rust
9//! [`Iterator`](::std::iter::Iterator) in a `for` loop. To create an iterator, just define a function
10//! which takes in a closure as its last argument. Both the function and the closure must return a
11//! [`ControlFlow`](::std::ops::ControlFlow) object with some generic `Break` type.
12//!
13//! ```
14//! use std::ops::ControlFlow;
15//!
16//! fn up_to<B>(n: u64, mut f: impl FnMut(u64) -> ControlFlow<B>) -> ControlFlow<B> {
17//!     for i in 0..n {
18//!         f(i)?;
19//!     }
20//!     ControlFlow::Continue(())
21//! }
22//! ```
23//!
24//! From there, you can use the iterator like a regular `for`-loop by driving it using the
25//! [`cbit!`](crate::cbit!) macro.
26//!
27//! ```rust
28//! # use std::ops::ControlFlow;
29//! # fn up_to<B>(n: u64, mut f: impl FnMut(u64) -> ControlFlow<B>) -> ControlFlow<B> {
30//! #     for i in 0..n {
31//! #         f(i)?;
32//! #     }
33//! #     ControlFlow::Continue(())
34//! # }
35//! fn demo(n: u64) -> u64 {
36//!     let mut c = 0;
37//!     cbit::cbit!(for i in up_to(n) {
38//!         c += i;
39//!     });
40//!     c
41//! }
42//! ```
43//!
44//! Although the body of the `for` loop is technically nested in a closure, it supports all the
45//! regular control-flow mechanisms one would expect:
46//!
47//! You can early-`return` to the outer function...
48//!
49//! ```rust
50//! # use std::ops::ControlFlow;
51//! # fn up_to<B>(n: u64, mut f: impl FnMut(u64) -> ControlFlow<B>) -> ControlFlow<B> {
52//! #     for i in 0..n {
53//! #         f(i)?;
54//! #     }
55//! #     ControlFlow::Continue(())
56//! # }
57//! fn demo(n: u64) -> u64 {
58//!     let mut c = 0;
59//!     cbit::cbit!(for i in up_to(n) {
60//!         c += i;
61//!         if c > 1000 {
62//!             return u64::MAX;
63//!         }
64//!     });
65//!     c
66//! }
67//!
68//! assert_eq!(demo(500), u64::MAX);
69//! ```
70//!
71//! You can `break` and `continue` in the body...
72//!
73//! ```rust
74//! # use std::ops::ControlFlow;
75//! # fn up_to<B>(n: u64, mut f: impl FnMut(u64) -> ControlFlow<B>) -> ControlFlow<B> {
76//! #     for i in 0..n {
77//! #         f(i)?;
78//! #     }
79//! #     ControlFlow::Continue(())
80//! # }
81//! fn demo(n: u64) -> u64 {
82//!     let mut c = 0;
83//!     cbit::cbit!('me: for i in up_to(n) {
84//!         if i == 2 {
85//!             continue 'me;  // This label is optional.
86//!         }
87//!
88//!         c += i;
89//!
90//!         if c > 5 {
91//!             break;
92//!         }
93//!     });
94//!     c
95//! }
96//!
97//! assert_eq!(demo(5), 1 + 3 + 4);
98//! ```
99//!
100//! And you can even `break` and `continue` to scopes outside the body!
101//!
102//! ```rust
103//! # use std::ops::ControlFlow;
104//! # fn up_to<B>(n: u64, mut f: impl FnMut(u64) -> ControlFlow<B>) -> ControlFlow<B> {
105//! #     for i in 0..n {
106//! #         f(i)?;
107//! #     }
108//! #     ControlFlow::Continue(())
109//! # }
110//! fn demo(n: u64) -> u64 {
111//!     let mut c = 0;
112//!     'outer_1: loop {
113//!         let something = 'outer_2: {
114//!             cbit::cbit!(for i in up_to(n) break loop 'outer_1, 'outer_2 {
115//!                 if i == 5 && c < 20 {
116//!                     continue 'outer_1;
117//!                 }
118//!                 if i == 8 {
119//!                     break 'outer_2 c < 10;
120//!                 }
121//!                 c += i;
122//!             });
123//!             false
124//!         };
125//!
126//!         if something {
127//!             assert!(c < 10);
128//!         } else {
129//!             break;
130//!         }
131//!     }
132//!     c
133//! }
134//!
135//! demo(10);  // I'm honestly not really sure what this function is supposed to do.
136//! ```
137//!
138//! Check the documentation of [`cbit!`] for more details on its syntax and specific behavior.
139//!
140//! ## Advantages and Drawbacks
141//!
142//! Closure-based iterators play much nicer with the Rust optimizer than coroutines and their
143//! [stable `async` userland counterpart](https://docs.rs/genawaiter/latest/genawaiter/) do
144//! as of `rustc 1.74.0`.
145//!
146//! Here is the disassembly of a regular loop implementation that sums the integers in `0..n`:
147//!
148//! ```
149//! pub fn regular(n: u64) -> u64 {
150//!     let mut c = 0;
151//!     for i in 0..n {
152//!         c += i;
153//!     }
154//!     c
155//! }
156//! ```
157//!
158//! ```text
159//! asm::regular:
160//! Lfunc_begin7:
161//!         push rbp
162//!         mov rbp, rsp
163//!         test rdi, rdi
164//!         je LBB7_1
165//!         lea rax, [rdi - 1]
166//!         lea rcx, [rdi - 2]
167//!         mul rcx
168//!         shld rdx, rax, 63
169//!         lea rax, [rdi + rdx - 1]
170//!         pop rbp
171//!         ret
172//! LBB7_1:
173//!         xor eax, eax
174//!         pop rbp
175//!         ret
176//! ```
177//!
178//! ...and here is the disassembly of the loop reimplemented in cbit:
179//!
180//! ```
181//! use std::ops::ControlFlow;
182//!
183//! pub fn cbit(n: u64) -> u64 {
184//!     let mut c = 0;
185//!     cbit::cbit!(for i in up_to(n) {
186//!         c += i;
187//!     });
188//!     c
189//! }
190//!
191//! fn up_to<B>(n: u64, mut f: impl FnMut(u64) -> ControlFlow<B>) -> ControlFlow<B> {
192//!     for i in 0..n {
193//!         f(i)?;
194//!     }
195//!     ControlFlow::Continue(())
196//! }
197//! ```
198//!
199//! ```text
200//! asm::cbit:
201//! Lfunc_begin8:
202//!         push rbp
203//!         mov rbp, rsp
204//!         test rdi, rdi
205//!         je LBB8_1
206//!         lea rax, [rdi - 1]
207//!         lea rcx, [rdi - 2]
208//!         mul rcx
209//!         shld rdx, rax, 63
210//!         lea rax, [rdi + rdx - 1]
211//!         pop rbp
212//!         ret
213//! LBB8_1:
214//!         xor eax, eax
215//!         pop rbp
216//!         ret
217//! ```
218//!
219//! Except for the label names, they're entirely identical!
220//!
221//! Meanwhile, the same example written with `rustc 1.76.0-nightly (49b3924bd 2023-11-27)`'s coroutines
222//! yields far worse codegen ([permalink](https://godbolt.org/z/Kjh9q195s)):
223//!
224//! ```no_compile
225//! #![feature(coroutines, coroutine_trait, iter_from_coroutine)]
226//!
227//! use std::{iter::from_coroutine, ops::Coroutine};
228//!
229//! fn upto_n(n: u64) -> impl Coroutine<Yield = u64, Return = ()> {
230//!     move || {
231//!         for i in 0..n {
232//!             yield i;
233//!         }
234//!     }
235//! }
236//!
237//! pub fn sum(n: u64) -> u64 {
238//!     let mut c = 0;
239//!     let mut co = std::pin::pin!(upto_n(n));
240//!     for i in from_coroutine(co) {
241//!         c += i;
242//!     }
243//!     c
244//! }
245//! ```
246//!
247//! ```text
248//! example::sum:
249//!         xor     edx, edx
250//!         xor     eax, eax
251//!         test    edx, edx
252//!         je      .LBB0_4
253//! .LBB0_2:
254//!         cmp     edx, 3
255//!         jne     .LBB0_3
256//!         cmp     rcx, rdi
257//!         jb      .LBB0_7
258//!         jmp     .LBB0_6
259//! .LBB0_4:
260//!         xor     ecx, ecx
261//!         cmp     rcx, rdi
262//!         jae     .LBB0_6
263//! .LBB0_7:
264//!         setb    dl
265//!         movzx   edx, dl
266//!         add     rax, rcx
267//!         add     rcx, rdx
268//!         lea     edx, [2*rdx + 1]
269//!         test    edx, edx
270//!         jne     .LBB0_2
271//!         jmp     .LBB0_4
272//! .LBB0_6:
273//!         ret
274//! .LBB0_3:
275//!         push    rax
276//!         lea     rdi, [rip + str.0]
277//!         lea     rdx, [rip + .L__unnamed_1]
278//!         mov     esi, 34
279//!         call    qword ptr [rip + core::panicking::panic@GOTPCREL]
280//!         ud2
281//! ```
282//!
283//! A similar thing can be seen with userland implementations of this feature such as
284//! [`genawaiter`](https://docs.rs/genawaiter/latest/genawaiter/index.html).
285//!
286//! However, what more general coroutine implementations provide in exchange for potential performance
287//! degradation is immense expressivity. Fundamentally, `cbit` iterators cannot be interwoven, making
288//! adapters such as `zip` impossible to implement—something coroutines have no problem doing.
289
290use proc_macro2::{Ident, Span, TokenStream};
291use quote::quote;
292use syn::{punctuated::Punctuated, Lifetime, Token};
293use syntax::CbitForExpr;
294
295mod syntax;
296
297/// A proc-macro to use callback-based iterators with for-loop syntax and functionality.
298///
299/// ## Syntax
300///
301/// ```text
302/// ('<loop-label: lifetime>:)? for <binding: pattern> in <iterator: function-call-expr>
303///     (break ((loop)? '<extern-label: lifetime>)*)?
304/// {
305///     <body: token stream>
306/// }
307/// ```
308///
309/// Arguments:
310///
311/// - `loop-label`: This is the optional label used by your virtual loop. `break`'ing or `continue`'ing
312///   to this label will break out of and continue the cbit iterator respectively.
313/// - `binding`: This is the irrefutable pattern the iterator's arguments will be decomposed into.
314/// - `iterator`: Syntactically, this can be any (potentially generic) function or method call
315///   expression and generics can be explicitly supplied if desired. See the [iteration protocol](#iteration-protocol)
316///   section for details on the semantic requirements for this function.
317/// - The loop also contains an optional list of external control-flow labels which is started by the
318///   `break` keyword and is followed by a non-empty non-trailing comma-separated list of...
319///      - An optional `loop` keyword which, if specified, asserts that the label can accept `continue`s
320///        in addition to `break`s.
321///      - `extern-label`: the label the `cbit!` body is allowed to `break` or `continue` out to.
322///
323/// ## Iteration Protocol
324///
325/// The called function or method can take on any non-zero number of arguments but must accept a
326/// single-argument function closure as its last argument. The closure must be able to return a
327/// [`ControlFlow`](::std::ops::ControlFlow) object with a generic `Break` type and the function must
328/// return a `ControlFlow` object with the same `Break` type.
329///
330/// ```
331/// use std::{iter::IntoIterator, ops::ControlFlow};
332///
333/// // A simple example...
334/// fn up_to<B>(n: u64, mut f: impl FnMut(u64) -> ControlFlow<B>) -> ControlFlow<B> {
335///     for i in 0..n {
336///         f(i)?;
337///     }
338///     ControlFlow::Continue(())
339/// }
340///
341/// // A slightly more involved example...
342/// fn enumerate<I: IntoIterator, B>(
343///     values: I,
344///     index_offset: usize,
345///     mut f: impl FnMut((usize, I::Item),
346/// ) -> ControlFlow<B>) -> ControlFlow<B> {
347///     for (i, v) in values.into_iter().enumerate() {
348///         f((i + index_offset, v))?;
349///     }
350///     ControlFlow::Continue(())
351/// }
352/// ```
353///
354/// The `Continue` parameter of the `ControlFlow` objects, meanwhile, is a lot more flexible. The
355/// `Continue` parameter on the return type of the inner closure designates the type users are expected
356/// to give back to the calling iterator function. Since users can run `continue` in the body, this
357/// type must implement [`Default`].
358///
359/// The `Continue` parameter on the return type of the iterator function, meanwhile, can be used to
360/// return values from the `cbit!` macro expression. If users `break` out of loops with a non-unit
361/// output `Continue` type, they must provide this value themself.
362///
363/// ```
364/// use std::ops::ControlFlow;
365///
366/// fn demo(list: &[i32]) -> i32 {
367///     cbit::cbit!(for (accum, value) in reduce(0, list) {
368///         if *value > 100 {
369///             break -1;
370///         }
371///         accum + value
372///     })
373/// }
374///
375/// fn reduce<T, I: IntoIterator, B>(
376///     initial: T,
377///     values: I,
378///     mut f: impl FnMut((T, I::Item)) -> ControlFlow<B, T>,
379/// ) -> ControlFlow<B, T> {
380///     let mut accum = initial;
381///     for value in values {
382///         accum = f((accum, value))?;
383///     }
384///     ControlFlow::Continue(accum)
385/// }
386///
387/// assert_eq!(demo(&[1, 2, 3]), 6);
388/// assert_eq!(demo(&[1, 2, 3, 4, 101, 8]), -1);
389/// ```
390#[proc_macro]
391pub fn cbit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
392    let input = syn::parse_macro_input!(input as CbitForExpr);
393
394    // Define some common syntax trees
395    let core_ = quote! { ::core };
396    let ops_ = quote! { #core_::ops };
397    let pin_ = quote! { #core_::pin };
398    let task_ = quote! { #core_::task };
399    let future_ = quote! { #core_::future };
400    let option_ = quote! { #core_::option::Option };
401
402    // Extract our break labels
403    let empty_punct_list = Punctuated::new();
404    let in_break_labels = input
405        .breaks
406        .as_ref()
407        .map_or(&empty_punct_list, |breaks| &breaks.lt);
408
409    let derive_early_break_variant_name =
410        |lt: &Lifetime| Ident::new(&format!("EarlyBreakTo_{}", lt.ident), lt.span());
411
412    let derive_early_continue_variant_name =
413        |lt: &Lifetime| Ident::new(&format!("EarlyContinueTo_{}", lt.ident), lt.span());
414
415    // Define an enum for our control flow
416    let control_flow_enum_def;
417    let control_flow_ty_decl;
418    let control_flow_ty_use;
419    {
420        let break_variant_names = in_break_labels
421            .iter()
422            .map(|v| derive_early_break_variant_name(&v.lt))
423            .collect::<Vec<_>>();
424
425        let continue_variant_names = in_break_labels
426            .iter()
427            .filter(|&v| v.kw_loop.is_some())
428            .map(|v| derive_early_continue_variant_name(&v.lt));
429
430        control_flow_enum_def = quote! {
431            #[allow(non_camel_case_types)]
432            #[allow(clippy::enum_variant_names)]
433            enum OurControlFlowResult<EarlyReturn, EarlyBreak #(, #break_variant_names)*> {
434                EarlyReturn(EarlyReturn),
435                EarlyBreak(EarlyBreak),
436                #(#break_variant_names (#break_variant_names),)*
437                #(#continue_variant_names,)*
438            }
439        };
440
441        control_flow_ty_decl = quote! {
442            #[allow(non_camel_case_types)]
443            type OurControlFlow<EarlyReturn, EarlyBreak #(, #break_variant_names)*> = #ops_::ControlFlow<
444                OurControlFlowResult<EarlyReturn, EarlyBreak #(, #break_variant_names)*>,
445                EarlyBreak,
446            >;
447        };
448
449        let underscores =
450            (0..(break_variant_names.len() + 2)).map(|_| Token![_](Span::call_site()));
451
452        control_flow_ty_use = quote! { OurControlFlow<#(#underscores),*> };
453    }
454
455    // Define our initial break layer
456    let aborter = |resolution: TokenStream| {
457        quote! {
458            how_to_resolve_pending = #option_::Some(#resolution);
459            #future_::pending::<()>().await;
460            #core_::unreachable!();
461        }
462    };
463
464    let for_body = input.body.body;
465    let for_body = {
466        let optional_label = &input.label;
467        let break_aborter = aborter(quote! {
468            #ops_::ControlFlow::Break(OurControlFlowResult::EarlyBreak(break_result))
469        });
470
471        quote! {
472            '__cbit_absorber_magic_innermost: {
473                let mut did_run = false;
474                let break_result = #optional_label loop {
475                    if did_run {
476                        // The user must have used `continue`.
477                        break '__cbit_absorber_magic_innermost #core_::default::Default::default();
478                    }
479
480                    did_run = true;
481                    let break_result = { #for_body };
482
483                    // The user completed the loop.
484                    #[allow(unreachable_code)]
485                    break '__cbit_absorber_magic_innermost break_result;
486                };
487
488                // The user broke out of the loop.
489                #[allow(unreachable_code)]
490                {
491                    #break_aborter
492                }
493            }
494        }
495    };
496
497    // Build up an onion of user-specified break layers
498    let for_body = {
499        let mut for_body = for_body;
500        for break_label_entry in in_break_labels {
501            let break_label = &break_label_entry.lt;
502
503            let break_aborter = {
504                let variant_name = derive_early_break_variant_name(break_label);
505                aborter(quote! {
506                    #ops_::ControlFlow::Break(OurControlFlowResult::#variant_name(break_result))
507                })
508            };
509
510            let outer_label = Lifetime::new(
511                &format!("'__cbit_absorber_magic_for_{}", break_label.ident),
512                break_label.span(),
513            );
514
515            if break_label_entry.kw_loop.is_some() {
516                let continue_aborter = {
517                    let variant_name = derive_early_continue_variant_name(break_label);
518                    aborter(quote! {
519                        #ops_::ControlFlow::Break(OurControlFlowResult::#variant_name)
520                    })
521                };
522
523                for_body = quote! {#outer_label: {
524                    let mut did_run = false;
525                    let break_result = #break_label: loop {
526                        if did_run {
527                            // The user must have used `continue`.
528                            #continue_aborter
529                        }
530
531                        did_run = true;
532                        let break_result = { #for_body };
533
534                        // The user completed the loop.
535                        #[allow(unreachable_code)]
536                        break #outer_label break_result;
537                    };
538
539                    // The user broke out of the loop.
540                    #[allow(unreachable_code)]
541                    {
542                        #break_aborter
543                    }
544                }};
545            } else {
546                for_body = quote! {#outer_label: {
547                    let break_result = #break_label: {
548                        let break_result = { #for_body };
549
550                        // The user completed the loop.
551                        #[allow(unreachable_code)]
552                        break #outer_label break_result;
553                    };
554
555                    // The user broke out of the block.
556                    #[allow(unreachable_code)]
557                    {
558                        #break_aborter
559                    }
560                }};
561            }
562        }
563
564        for_body
565    };
566
567    // Build up a layer to capture early returns and generally process arguments
568    let for_body = {
569        let body_input_pat = &input.body_pattern;
570        let termination_aborter = aborter(quote! { #ops_::ControlFlow::Continue(end_result) });
571        quote! {
572            |#body_input_pat| {
573                let mut how_to_resolve_pending = #option_::None;
574
575                let body = #pin_::pin!(async {
576                    let end_result = { #for_body };
577
578                    #[allow(unreachable_code)] { #termination_aborter }
579                });
580
581                match #future_::Future::poll(
582                    body,
583                    &mut #task_::Context::from_waker(&{  // TODO: Use `Waker::noop` once it stabilizes
584                        const VTABLE: #task_::RawWakerVTable = #task_::RawWakerVTable::new(
585                            // Cloning just returns a new no-op raw waker
586                            |_| RAW,
587                            // `wake` does nothing
588                            |_| {},
589                            // `wake_by_ref` does nothing
590                            |_| {},
591                            // Dropping does nothing as we don't allocate anything
592                            |_| {},
593                        );
594                        const RAW: #task_::RawWaker = #task_::RawWaker::new(#core_::ptr::null(), &VTABLE);
595                        unsafe { #task_::Waker::from_raw(RAW) }
596                    })
597                ) {
598                    #task_::Poll::Ready(early_return) => #ops_::ControlFlow::Break(
599                        OurControlFlowResult::EarlyReturn(early_return),
600                    ),
601                    #task_::Poll::Pending => how_to_resolve_pending.expect(
602                        "the async block in a cbit iterator is an implementation detail; do not \
603                         `.await` in it!"
604                    ),
605                }
606            }
607        }
608    };
609
610    // Build up a list of break/continue handlers
611    let break_out_matchers = in_break_labels.iter().map(|v| {
612        let lt = &v.lt;
613        let variant_name = derive_early_break_variant_name(lt);
614        quote! {
615            OurControlFlowResult::#variant_name(break_out) => break #lt break_out,
616        }
617    });
618
619    let continue_out_matchers = in_break_labels
620        .iter()
621        .filter(|v| v.kw_loop.is_some())
622        .map(|v| {
623            let lt = &v.lt;
624            let variant_name = derive_early_continue_variant_name(lt);
625            quote! {
626                OurControlFlowResult::#variant_name => continue #lt,
627            }
628        });
629
630    // Build up our function call site
631    let driver_call_site = match &input.call {
632        syntax::AnyCallExpr::Function(call) => {
633            let driver_attrs = &call.attrs;
634            let driver_fn_expr = &call.func;
635            let driver_fn_args = call.args.iter();
636
637            quote! {
638                #(#driver_attrs)*
639                let result: #control_flow_ty_use = #driver_fn_expr (#(#driver_fn_args,)* #for_body);
640            }
641        }
642        syntax::AnyCallExpr::Method(call) => {
643            let driver_attrs = &call.attrs;
644            let driver_receiver_expr = &call.receiver;
645            let driver_method = &call.method;
646            let driver_turbo = &call.turbofish;
647            let driver_fn_args = call.args.iter();
648
649            quote! {
650                #(#driver_attrs)*
651                let result: #control_flow_ty_use =
652                    #driver_receiver_expr.#driver_method #driver_turbo (
653                        #(#driver_fn_args,)*
654                        #for_body
655                    );
656            }
657        }
658    };
659
660    // Put it all together
661    quote! {{
662        // enum ControlFlowResult<...> { ... }
663        #control_flow_enum_def
664
665        // type ControlFlow<A, B, ...> = core::ops::ControlFlow<ControlFlowResult<A, B, ...>, A>;
666        #control_flow_ty_decl
667
668        // let result = my_fn(args, |...| async { ... });
669        #driver_call_site
670
671        match result {
672            #ops_::ControlFlow::Break(result) => match result {
673                OurControlFlowResult::EarlyReturn(early_result) => return early_result,
674                OurControlFlowResult::EarlyBreak(result) => result,
675                #(#break_out_matchers)*
676                #(#continue_out_matchers)*
677            },
678            #ops_::ControlFlow::Continue(result) => result,
679        }
680    }}
681    .into()
682}