cbit/lib.rs
1#![allow(rustdoc::redundant_explicit_links)] // For cargo-rdme's sake
2
3//! A proc-macro to use callback-based iterators with `for`-loop syntax and functionality.
4//!
5//! ## Overview
6//!
7//! `cbit` (short for **c**losure-**b**ased **it**erator) is a crate which allows you to use iterator
8//! functions which call into a closure to process each element as if they were just a regular Rust
9//! [`Iterator`](::std::iter::Iterator) in a `for` loop. To create an iterator, just define a function
10//! which takes in a closure as its last argument. Both the function and the closure must return a
11//! [`ControlFlow`](::std::ops::ControlFlow) object with some generic `Break` type.
12//!
13//! ```
14//! use std::ops::ControlFlow;
15//!
16//! fn up_to<B>(n: u64, mut f: impl FnMut(u64) -> ControlFlow<B>) -> ControlFlow<B> {
17//! for i in 0..n {
18//! f(i)?;
19//! }
20//! ControlFlow::Continue(())
21//! }
22//! ```
23//!
24//! From there, you can use the iterator like a regular `for`-loop by driving it using the
25//! [`cbit!`](crate::cbit!) macro.
26//!
27//! ```rust
28//! # use std::ops::ControlFlow;
29//! # fn up_to<B>(n: u64, mut f: impl FnMut(u64) -> ControlFlow<B>) -> ControlFlow<B> {
30//! # for i in 0..n {
31//! # f(i)?;
32//! # }
33//! # ControlFlow::Continue(())
34//! # }
35//! fn demo(n: u64) -> u64 {
36//! let mut c = 0;
37//! cbit::cbit!(for i in up_to(n) {
38//! c += i;
39//! });
40//! c
41//! }
42//! ```
43//!
44//! Although the body of the `for` loop is technically nested in a closure, it supports all the
45//! regular control-flow mechanisms one would expect:
46//!
47//! You can early-`return` to the outer function...
48//!
49//! ```rust
50//! # use std::ops::ControlFlow;
51//! # fn up_to<B>(n: u64, mut f: impl FnMut(u64) -> ControlFlow<B>) -> ControlFlow<B> {
52//! # for i in 0..n {
53//! # f(i)?;
54//! # }
55//! # ControlFlow::Continue(())
56//! # }
57//! fn demo(n: u64) -> u64 {
58//! let mut c = 0;
59//! cbit::cbit!(for i in up_to(n) {
60//! c += i;
61//! if c > 1000 {
62//! return u64::MAX;
63//! }
64//! });
65//! c
66//! }
67//!
68//! assert_eq!(demo(500), u64::MAX);
69//! ```
70//!
71//! You can `break` and `continue` in the body...
72//!
73//! ```rust
74//! # use std::ops::ControlFlow;
75//! # fn up_to<B>(n: u64, mut f: impl FnMut(u64) -> ControlFlow<B>) -> ControlFlow<B> {
76//! # for i in 0..n {
77//! # f(i)?;
78//! # }
79//! # ControlFlow::Continue(())
80//! # }
81//! fn demo(n: u64) -> u64 {
82//! let mut c = 0;
83//! cbit::cbit!('me: for i in up_to(n) {
84//! if i == 2 {
85//! continue 'me; // This label is optional.
86//! }
87//!
88//! c += i;
89//!
90//! if c > 5 {
91//! break;
92//! }
93//! });
94//! c
95//! }
96//!
97//! assert_eq!(demo(5), 1 + 3 + 4);
98//! ```
99//!
100//! And you can even `break` and `continue` to scopes outside the body!
101//!
102//! ```rust
103//! # use std::ops::ControlFlow;
104//! # fn up_to<B>(n: u64, mut f: impl FnMut(u64) -> ControlFlow<B>) -> ControlFlow<B> {
105//! # for i in 0..n {
106//! # f(i)?;
107//! # }
108//! # ControlFlow::Continue(())
109//! # }
110//! fn demo(n: u64) -> u64 {
111//! let mut c = 0;
112//! 'outer_1: loop {
113//! let something = 'outer_2: {
114//! cbit::cbit!(for i in up_to(n) break loop 'outer_1, 'outer_2 {
115//! if i == 5 && c < 20 {
116//! continue 'outer_1;
117//! }
118//! if i == 8 {
119//! break 'outer_2 c < 10;
120//! }
121//! c += i;
122//! });
123//! false
124//! };
125//!
126//! if something {
127//! assert!(c < 10);
128//! } else {
129//! break;
130//! }
131//! }
132//! c
133//! }
134//!
135//! assert_eq!(demo(10), 38);
136//! ```
137//!
138//! Check the documentation of [`cbit!`] for more details on its syntax and specific behavior.
139//!
140//! ## Advantages and Drawbacks
141//!
142//! Closure-based iterators play much nicer with the Rust optimizer than coroutines and their
143//! [stable `async` userland counterpart](https://docs.rs/genawaiter/latest/genawaiter/) do
144//! as of `rustc 1.74.0`.
145//!
146//! Here is the disassembly of a regular loop which sums the integers in `0..n`:
147//!
148//! ```
149//! pub fn regular(n: u64) -> u64 {
150//! let mut c = 0;
151//! for i in 0..n {
152//! c += i;
153//! }
154//! c
155//! }
156//! ```
157//!
158//! ```text
159//! asm::regular:
160//! Lfunc_begin7:
161//! push rbp
162//! mov rbp, rsp
163//! test rdi, rdi
164//! je LBB7_1
165//! lea rax, [rdi - 1]
166//! lea rcx, [rdi - 2]
167//! mul rcx
168//! shld rdx, rax, 63
169//! lea rax, [rdi + rdx - 1]
170//! pop rbp
171//! ret
172//! LBB7_1:
173//! xor eax, eax
174//! pop rbp
175//! ret
176//! ```
177//!
178//! ...and here is the disassembly of the loop reimplemented in cbit:
179//!
180//! ```
181//! use std::ops::ControlFlow;
182//!
183//! pub fn cbit(n: u64) -> u64 {
184//! let mut c = 0;
185//! cbit::cbit!(for i in up_to(n) {
186//! c += i;
187//! });
188//! c
189//! }
190//!
191//! fn up_to<B>(n: u64, mut f: impl FnMut(u64) -> ControlFlow<B>) -> ControlFlow<B> {
192//! for i in 0..n {
193//! f(i)?;
194//! }
195//! ControlFlow::Continue(())
196//! }
197//! ```
198//!
199//! ```text
200//! asm::cbit:
201//! Lfunc_begin8:
202//! push rbp
203//! mov rbp, rsp
204//! test rdi, rdi
205//! je LBB8_1
206//! lea rax, [rdi - 1]
207//! lea rcx, [rdi - 2]
208//! mul rcx
209//! shld rdx, rax, 63
210//! lea rax, [rdi + rdx - 1]
211//! pop rbp
212//! ret
213//! LBB8_1:
214//! xor eax, eax
215//! pop rbp
216//! ret
217//! ```
218//!
219//! Except for the label names, they're entirely identical!
220//!
221//! Meanwhile, the same example written with `rustc 1.76.0-nightly (49b3924bd 2023-11-27)`'s coroutines
222//! yields far worse codegen ([permalink](https://godbolt.org/z/Kjh9q195s)):
223//!
224//! ```ignore
225//! #![feature(coroutines, coroutine_trait, iter_from_coroutine)]
226//!
227//! use std::{iter::from_coroutine, ops::Coroutine};
228//!
229//! fn upto_n(n: u64) -> impl Coroutine<Yield = u64, Return = ()> {
230//! move || {
231//! for i in 0..n {
232//! yield i;
233//! }
234//! }
235//! }
236//!
237//! pub fn sum(n: u64) -> u64 {
238//! let mut c = 0;
239//! let mut co = std::pin::pin!(upto_n(n));
240//! for i in from_coroutine(co) {
241//! c += i;
242//! }
243//! c
244//! }
245//! ```
246//!
247//! ```text
248//! example::sum:
249//! xor edx, edx
250//! xor eax, eax
251//! test edx, edx
252//! je .LBB0_4
253//! .LBB0_2:
254//! cmp edx, 3
255//! jne .LBB0_3
256//! cmp rcx, rdi
257//! jb .LBB0_7
258//! jmp .LBB0_6
259//! .LBB0_4:
260//! xor ecx, ecx
261//! cmp rcx, rdi
262//! jae .LBB0_6
263//! .LBB0_7:
264//! setb dl
265//! movzx edx, dl
266//! add rax, rcx
267//! add rcx, rdx
268//! lea edx, [2*rdx + 1]
269//! test edx, edx
270//! jne .LBB0_2
271//! jmp .LBB0_4
272//! .LBB0_6:
273//! ret
274//! .LBB0_3:
275//! push rax
276//! lea rdi, [rip + str.0]
277//! lea rdx, [rip + .L__unnamed_1]
278//! mov esi, 34
279//! call qword ptr [rip + core::panicking::panic@GOTPCREL]
280//! ud2
281//! ```
282//!
283//! A similar thing can be seen with userland implementations of this feature such as
284//! [`genawaiter`](https://docs.rs/genawaiter/latest/genawaiter/index.html).
285//!
286//! However, what more general coroutine implementations provide in exchange for potential performance
287//! degradation is immense expressivity. Fundamentally, `cbit` iterators cannot be interwoven, making
288//! adapters such as `zip` impossible to implement—something coroutines have no problem doing.
289
290use proc_macro2::{Ident, Span, TokenStream};
291use quote::quote;
292use syn::{punctuated::Punctuated, Lifetime, Token};
293use syntax::CbitForExpr;
294
295mod syntax;
296
297/// A proc-macro to use callback-based iterators with for-loop syntax and functionality.
298///
299/// ## Syntax
300///
301/// ```text
302/// ('<loop-label: lifetime>:)? for <binding: pattern> in <iterator: function-call-expr>
303/// (break ((loop)? '<extern-label: lifetime>)*)?
304/// {
305/// <body: token stream>
306/// }
307/// ```
308///
309/// Arguments:
310///
311/// - `loop-label`: This is the optional label used by your virtual loop. `break`'ing or `continue`'ing
312/// to this label will break out of and continue the cbit iterator respectively.
313/// - `binding`: This is the irrefutable pattern the iterator's arguments will be decomposed into.
314/// - `iterator`: Syntactically, this can be any (potentially generic) function or method call
315/// expression and generics can be explicitly supplied if desired. See the [iteration protocol](#iteration-protocol)
316/// section for details on the semantic requirements for this function.
317/// - The loop also contains an optional list of external control-flow labels which is started by the
318/// `break` keyword and is followed by a non-empty non-trailing comma-separated list of...
319/// - An optional `loop` keyword which, if specified, asserts that the label can accept `continue`s
320/// in addition to `break`s.
321/// - `extern-label`: the label the `cbit!` body is allowed to `break` or `continue` out to.
322///
323/// ## Iteration Protocol
324///
325/// The called function or method can take on any non-zero number of arguments but must accept a
326/// single-argument function closure as its last argument. The closure must be able to return a
327/// [`ControlFlow`](::std::ops::ControlFlow) object with a generic `Break` type and the function must
328/// return a `ControlFlow` object with the same `Break` type.
329///
330/// ```
331/// use std::{iter::IntoIterator, ops::ControlFlow};
332///
333/// // A simple example...
334/// fn up_to<B>(n: u64, mut f: impl FnMut(u64) -> ControlFlow<B>) -> ControlFlow<B> {
335/// for i in 0..n {
336/// f(i)?;
337/// }
338/// ControlFlow::Continue(())
339/// }
340///
341/// // A slightly more involved example...
342/// fn enumerate<I: IntoIterator, B>(
343/// values: I,
344/// index_offset: usize,
345/// mut f: impl FnMut((usize, I::Item),
346/// ) -> ControlFlow<B>) -> ControlFlow<B> {
347/// for (i, v) in values.into_iter().enumerate() {
348/// f((i + index_offset, v))?;
349/// }
350/// ControlFlow::Continue(())
351/// }
352/// ```
353///
354/// The `Continue` parameter of the `ControlFlow` objects, meanwhile, is a lot more flexible. The
355/// `Continue` parameter on the return type of the inner closure designates the type users are expected
356/// to give back to the calling iterator function. Since users can run `continue` in the body, this
357/// type must implement [`Default`].
358///
359/// The `Continue` parameter on the return type of the iterator function, meanwhile, can be used to
360/// return values from the `cbit!` macro expression. If users `break` out of loops with a non-unit
361/// output `Continue` type, they must provide this value themself.
362///
363/// ```
364/// use std::ops::ControlFlow;
365///
366/// fn demo(list: &[i32]) -> i32 {
367/// cbit::cbit!(for (accum, value) in reduce(0, list) {
368/// if *value > 100 {
369/// break -1;
370/// }
371/// accum + value
372/// })
373/// }
374///
375/// fn reduce<T, I: IntoIterator, B>(
376/// initial: T,
377/// values: I,
378/// mut f: impl FnMut((T, I::Item)) -> ControlFlow<B, T>,
379/// ) -> ControlFlow<B, T> {
380/// let mut accum = initial;
381/// for value in values {
382/// accum = f((accum, value))?;
383/// }
384/// ControlFlow::Continue(accum)
385/// }
386///
387/// assert_eq!(demo(&[1, 2, 3]), 6);
388/// assert_eq!(demo(&[1, 2, 3, 4, 101, 8]), -1);
389/// ```
390#[proc_macro]
391pub fn cbit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
392 let input = syn::parse_macro_input!(input as CbitForExpr);
393
394 // Define some common syntax trees
395 let core_ = quote! { ::core };
396 let ops_ = quote! { #core_::ops };
397 let pin_ = quote! { #core_::pin };
398 let task_ = quote! { #core_::task };
399 let future_ = quote! { #core_::future };
400 let option_ = quote! { #core_::option::Option };
401
402 // Extract our break labels
403 let empty_punct_list = Punctuated::new();
404 let in_break_labels = input
405 .breaks
406 .as_ref()
407 .map_or(&empty_punct_list, |breaks| &breaks.lt);
408
409 let derive_early_break_variant_name =
410 |lt: &Lifetime| Ident::new(&format!("EarlyBreakTo_{}", lt.ident), lt.span());
411
412 let derive_early_continue_variant_name =
413 |lt: &Lifetime| Ident::new(&format!("EarlyContinueTo_{}", lt.ident), lt.span());
414
415 // Define an enum for our control flow
416 let control_flow_enum_def;
417 let control_flow_ty_decl;
418 let control_flow_ty_use;
419 {
420 let break_variant_names = in_break_labels
421 .iter()
422 .map(|v| derive_early_break_variant_name(&v.lt))
423 .collect::<Vec<_>>();
424
425 let continue_variant_names = in_break_labels
426 .iter()
427 .filter(|&v| v.kw_loop.is_some())
428 .map(|v| derive_early_continue_variant_name(&v.lt));
429
430 control_flow_enum_def = quote! {
431 #[allow(non_camel_case_types)]
432 #[allow(clippy::enum_variant_names)]
433 enum OurControlFlowResult<EarlyReturn, EarlyBreak #(, #break_variant_names)*> {
434 EarlyReturn(EarlyReturn),
435 EarlyBreak(EarlyBreak),
436 #(#break_variant_names (#break_variant_names),)*
437 #(#continue_variant_names,)*
438 }
439 };
440
441 control_flow_ty_decl = quote! {
442 #[allow(non_camel_case_types)]
443 type OurControlFlow<EarlyReturn, EarlyBreak #(, #break_variant_names)*> = #ops_::ControlFlow<
444 OurControlFlowResult<EarlyReturn, EarlyBreak #(, #break_variant_names)*>,
445 EarlyBreak,
446 >;
447 };
448
449 let underscores =
450 (0..(break_variant_names.len() + 2)).map(|_| Token));
451
452 control_flow_ty_use = quote! { OurControlFlow<#(#underscores),*> };
453 }
454
455 // Define our initial break layer
456 let aborter = |resolution: TokenStream| {
457 quote! {
458 how_to_resolve_pending = #option_::Some(#resolution);
459 #future_::pending::<()>().await;
460 #core_::unreachable!();
461 }
462 };
463
464 let for_body = input.body.body;
465 let for_body = {
466 let optional_label = &input.label;
467 let break_aborter = aborter(quote! {
468 #ops_::ControlFlow::Break(OurControlFlowResult::EarlyBreak(break_result))
469 });
470
471 quote! {
472 '__cbit_absorber_magic_innermost: {
473 let mut did_run = false;
474 let break_result = #optional_label loop {
475 if did_run {
476 // The user must have used `continue`.
477 break '__cbit_absorber_magic_innermost #core_::default::Default::default();
478 }
479
480 did_run = true;
481 let break_result = { #for_body };
482
483 // The user completed the loop.
484 #[allow(unreachable_code)]
485 break '__cbit_absorber_magic_innermost break_result;
486 };
487
488 // The user broke out of the loop.
489 #[allow(unreachable_code)]
490 {
491 #break_aborter
492 }
493 }
494 }
495 };
496
497 // Build up an onion of user-specified break layers
498 let for_body = {
499 let mut for_body = for_body;
500 for break_label_entry in in_break_labels {
501 let break_label = &break_label_entry.lt;
502
503 let break_aborter = {
504 let variant_name = derive_early_break_variant_name(break_label);
505 aborter(quote! {
506 #ops_::ControlFlow::Break(OurControlFlowResult::#variant_name(break_result))
507 })
508 };
509
510 let outer_label = Lifetime::new(
511 &format!("'__cbit_absorber_magic_for_{}", break_label.ident),
512 break_label.span(),
513 );
514
515 if break_label_entry.kw_loop.is_some() {
516 let continue_aborter = {
517 let variant_name = derive_early_continue_variant_name(break_label);
518 aborter(quote! {
519 #ops_::ControlFlow::Break(OurControlFlowResult::#variant_name)
520 })
521 };
522
523 for_body = quote! {#outer_label: {
524 let mut did_run = false;
525 let break_result = #break_label: loop {
526 if did_run {
527 // The user must have used `continue`.
528 #continue_aborter
529 }
530
531 did_run = true;
532 let break_result = { #for_body };
533
534 // The user completed the loop.
535 #[allow(unreachable_code)]
536 break #outer_label break_result;
537 };
538
539 // The user broke out of the loop.
540 #[allow(unreachable_code)]
541 {
542 #break_aborter
543 }
544 }};
545 } else {
546 for_body = quote! {#outer_label: {
547 let break_result = #break_label: {
548 let break_result = { #for_body };
549
550 // The user completed the loop.
551 #[allow(unreachable_code)]
552 break #outer_label break_result;
553 };
554
555 // The user broke out of the block.
556 #[allow(unreachable_code)]
557 {
558 #break_aborter
559 }
560 }};
561 }
562 }
563
564 for_body
565 };
566
567 // Build up a layer to capture early returns and generally process arguments
568 let for_body = {
569 let body_input_pat = &input.body_pattern;
570 let termination_aborter = aborter(quote! { #ops_::ControlFlow::Continue(end_result) });
571 quote! {
572 |#body_input_pat| {
573 let mut how_to_resolve_pending = #option_::None;
574
575 let body = #pin_::pin!(async {
576 let end_result = { #for_body };
577
578 #[allow(unreachable_code)] { #termination_aborter }
579 });
580
581 match #future_::Future::poll(
582 body,
583 &mut #task_::Context::from_waker(&{ // TODO: Use `Waker::noop` once it stabilizes
584 const VTABLE: #task_::RawWakerVTable = #task_::RawWakerVTable::new(
585 // Cloning just returns a new no-op raw waker
586 |_| RAW,
587 // `wake` does nothing
588 |_| {},
589 // `wake_by_ref` does nothing
590 |_| {},
591 // Dropping does nothing as we don't allocate anything
592 |_| {},
593 );
594 const RAW: #task_::RawWaker = #task_::RawWaker::new(#core_::ptr::null(), &VTABLE);
595 unsafe { #task_::Waker::from_raw(RAW) }
596 })
597 ) {
598 #task_::Poll::Ready(early_return) => #ops_::ControlFlow::Break(
599 OurControlFlowResult::EarlyReturn(early_return),
600 ),
601 #task_::Poll::Pending => how_to_resolve_pending.expect(
602 "the async block in a cbit iterator is an implementation detail; do not \
603 `.await` in it!"
604 ),
605 }
606 }
607 }
608 };
609
610 // Build up a list of break/continue handlers
611 let break_out_matchers = in_break_labels.iter().map(|v| {
612 let lt = &v.lt;
613 let variant_name = derive_early_break_variant_name(lt);
614 quote! {
615 OurControlFlowResult::#variant_name(break_out) => break #lt break_out,
616 }
617 });
618
619 let continue_out_matchers = in_break_labels
620 .iter()
621 .filter(|v| v.kw_loop.is_some())
622 .map(|v| {
623 let lt = &v.lt;
624 let variant_name = derive_early_continue_variant_name(lt);
625 quote! {
626 OurControlFlowResult::#variant_name => continue #lt,
627 }
628 });
629
630 // Build up our function call site
631 let driver_call_site = match &input.call {
632 syntax::AnyCallExpr::Function(call) => {
633 let driver_attrs = &call.attrs;
634 let driver_fn_expr = &call.func;
635 let driver_fn_args = call.args.iter();
636
637 quote! {
638 #(#driver_attrs)*
639 let result: #control_flow_ty_use = #driver_fn_expr (#(#driver_fn_args,)* #for_body);
640 }
641 }
642 syntax::AnyCallExpr::Method(call) => {
643 let driver_attrs = &call.attrs;
644 let driver_receiver_expr = &call.receiver;
645 let driver_method = &call.method;
646 let driver_turbo = &call.turbofish;
647 let driver_fn_args = call.args.iter();
648
649 quote! {
650 #(#driver_attrs)*
651 let result: #control_flow_ty_use =
652 #driver_receiver_expr.#driver_method #driver_turbo (
653 #(#driver_fn_args,)*
654 #for_body
655 );
656 }
657 }
658 };
659
660 // Put it all together
661 quote! {{
662 // enum ControlFlowResult<...> { ... }
663 #control_flow_enum_def
664
665 // type ControlFlow<A, B, ...> = core::ops::ControlFlow<ControlFlowResult<A, B, ...>, A>;
666 #control_flow_ty_decl
667
668 // let result = my_fn(args, |...| async { ... });
669 #driver_call_site
670
671 match result {
672 #ops_::ControlFlow::Break(result) => match result {
673 OurControlFlowResult::EarlyReturn(early_result) => return early_result,
674 OurControlFlowResult::EarlyBreak(result) => result,
675 #(#break_out_matchers)*
676 #(#continue_out_matchers)*
677 },
678 #ops_::ControlFlow::Continue(result) => result,
679 }
680 }}
681 .into()
682}