for_streams/
lib.rs

1//! The `for_streams!` macro, for driving multiple async [`Stream`]s concurrently. `for_streams!`
2//! works well with [Tokio](https://tokio.rs/), but it doesn't depend on Tokio.
3//!
4//! # The simplest case
5//!
6//! ```rust
7//! # use std::time::Duration;
8//! # #[tokio::main]
9//! # async fn main() {
10//! use for_streams::for_streams;
11//!
12//! for_streams! {
13//!     x in futures::stream::iter(1..=3) => {
14//!         tokio::time::sleep(Duration::from_millis(1)).await;
15//!         print!("{x} ");
16//!     }
17//!     y in futures::stream::iter(101..=103) => {
18//!         tokio::time::sleep(Duration::from_millis(1)).await;
19//!         print!("{y} ");
20//!     }
21//! }
22//! # }
23//! ```
24//!
25//! That takes three milliseconds and prints `1 101 2 102 3 103`. The behavior there is similar to
26//! using [`StreamExt::for_each`][for_each] and [`futures::join!`][join] together like this:
27//!
28//! ```rust
29//! # use futures::StreamExt;
30//! # use std::time::Duration;
31//! # #[tokio::main]
32//! # async fn main() {
33//! futures::join!(
34//!     futures::stream::iter(1..=3).for_each(|x| async move {
35//!         tokio::time::sleep(Duration::from_millis(1)).await;
36//!         println!("{x}");
37//!     }),
38//!     futures::stream::iter(101..=103).for_each(|x| async move {
39//!         tokio::time::sleep(Duration::from_millis(1)).await;
40//!         println!("{x}");
41//!     }),
42//! );
43//! # }
44//! ```
45//!
46//! However, importantly, using [`select!`] in a loop does _not_ behave the same way:
47//!
48//! ```rust
49//! # use futures::StreamExt;
50//! # use std::time::Duration;
51//! # #[tokio::main]
52//! # async fn main() {
53//! let mut stream1 = futures::stream::iter(1..=3).fuse();
54//! let mut stream2 = futures::stream::iter(101..=103).fuse();
55//! loop {
56//!     futures::select! {
57//!         x = stream1.next() => {
58//!             if let Some(x) = x {
59//!                 tokio::time::sleep(Duration::from_millis(1)).await;
60//!                 println!("{x}");
61//!             }
62//!         }
63//!         y = stream2.next() => {
64//!             if let Some(y) = y {
65//!                 tokio::time::sleep(Duration::from_millis(1)).await;
66//!                 println!("{y}");
67//!             }
68//!         }
69//!         complete => break,
70//!     }
71//! }
72//! # }
73//! ```
74//!
75//! That approach takes _six_ milliseconds, not three. `select!` is [notorious] for cancellation
76//! footguns, but this is actually a different problem: the body of a `select!` arm doesn't run
//! concurrently with any other arms (neither their bodies nor their "scrutinees"). Using `select!`
//! in a loop to drive multiple streams is often a mistake: [occasionally a deadlock][deadlock],
//! but more frequently a silent performance bug.
80//!
81//! And yet, `select!` in a loop gives us an appealing degree of control. Any of the bodies can
82//! `break` the loop, for example, which is awkward to replicate with `join!`. This is what
83//! `for_streams!` is about. It's like `select!` in a loop, but specifically for `Stream`s, with
84//! fewer footguns and several convenience features.
85//!
86//! # More interesting features
87//!
88//! `continue`, `break`, and `return` are all supported. `continue` skips to the next element of
89//! that stream, `break` stops reading from that stream, and `return` ends the whole macro (not the
90//! calling function, similar to `return` in an `async` block). The only valid return type is `()`.
91//! This example prints `a2 b1 c1 a4 b2 c2 a6 c3 a8` and then exits:
92//!
93//! ```rust
94//! # use for_streams::for_streams;
95//! # use std::time::Duration;
96//! # #[tokio::main]
97//! # async fn main() {
98//! for_streams! {
99//!     a in futures::stream::iter(1..1_000_000_000) => {
100//!         if a % 2 == 1 {
101//!             continue; // Skip the odd elements in this arm.
102//!         }
103//!         print!("a{a} ");
104//!         tokio::time::sleep(Duration::from_millis(1)).await;
105//!     }
106//!     b in futures::stream::iter(1..1_000_000_000) => {
107//!         if b > 2 {
108//!             break; // Stop this arm after two elements.
109//!         }
110//!         print!("b{b} ");
111//!         tokio::time::sleep(Duration::from_millis(1)).await;
112//!     }
113//!     c in futures::stream::iter(1..1_000_000_000) => {
114//!         if c > 3 {
115//!             return; // Stop the whole loop after three elements.
116//!         }
117//!         print!("c{c} ");
118//!         tokio::time::sleep(Duration::from_millis(1)).await;
119//!     }
120//! }
121//! # }
122//! ```
123//!
//! Sometimes you have a stream that's finite, like a channel that will eventually close, and
//! another stream that's infinite, like a timer that ticks forever. You can use `in background`
//! (in place of `in`) to tell `for_streams!` not to wait for some arms to finish:
127//!
128//! ```rust
129//! # use for_streams::for_streams;
130//! # use std::time::Duration;
131//! # #[tokio::main]
132//! # async fn main() {
133//! use tokio::time::interval;
134//! use tokio_stream::wrappers::IntervalStream;
135//!
136//! let timer = IntervalStream::new(interval(Duration::from_millis(1)));
137//! for_streams! {
138//!     x in futures::stream::iter(1..10) => {
139//!         tokio::time::sleep(Duration::from_millis(1)).await;
140//!         println!("{x}");
141//!     }
142//!     // We'll never reach the end of this `timer` stream, but `in background`
143//!     // means we'll exit when the first arm is done, instead of ticking forever.
144//!     _ in background timer => {
145//!         println!("tick");
146//!     }
147//! }
148//! # }
149//! ```
150//!
//! The `move` keyword is supported and has the same effect as it would on a closure or an
//! `async` block, making the arm's body take ownership of all the values it references. This can
//! be useful if you need a channel sender or a lock guard to drop promptly when one arm is done:
154//!
155//! ```rust
156//! # use for_streams::for_streams;
157//! # #[tokio::main]
158//! # async fn main() {
159//! use tokio::sync::mpsc::channel;
160//! use tokio_stream::wrappers::ReceiverStream;
161//!
162//! // This is a bounded channel, so the sender will block quickly on the
163//! // second message if the receiver isn't reading concurrently.
164//! let (sender, receiver) = channel::<i32>(1);
165//! let mut outputs = Vec::new();
166//! for_streams! {
167//!     // The `move` keyword makes this arm take ownership of `sender`, which
168//!     // means that `sender` drops as soon as this branch is finished. This
169//!     // example would deadlock without it.
170//!     val in tokio_stream::iter(1..=5) => move {
171//!         sender.send(val).await.unwrap();
172//!     }
173//!     // This arm borrows `outputs` but can't take ownership of it, because
174//!     // we use it again below in the assert.
175//!     val in ReceiverStream::new(receiver) => {
176//!         outputs.push(val);
177//!     }
178//! }
179//! assert_eq!(outputs, vec![1, 2, 3, 4, 5]);
180//! # }
181//! ```
182//!
183//! [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html
184//! [for_each]: https://docs.rs/futures/latest/futures/stream/trait.StreamExt.html#method.for_each
185//! [join]: https://docs.rs/futures/latest/futures/macro.join.html
186//! [`select!`]: https://docs.rs/futures/latest/futures/macro.select.html
187//! [notorious]: https://sunshowers.io/posts/cancelling-async-rust/
188//! [deadlock]: https://rfd.shared.oxide.computer/rfd/0609
189
190use proc_macro2::{Span, TokenStream as TokenStream2};
191use quote::{ToTokens, format_ident, quote};
192use syn::{
193    Block, Expr, Ident, Pat, Token,
194    parse::{Parse, ParseStream},
195    parse_macro_input,
196};
197
198mod kw {
199    syn::custom_keyword!(background);
200}
201
202struct Arm {
203    pattern: Pat,
204    stream_expr: Expr,
205    body: Block,
206    is_background: bool,
207    is_move: bool,
208}
209
210impl Parse for Arm {
211    fn parse(input: ParseStream) -> syn::Result<Self> {
212        let pattern = Pat::parse_single(input)?;
213        _ = input.parse::<Token![in]>()?;
214        // Check whether we can parse a stream expression after `background`. If not, `background`
215        // itself could be the stream expression (i.e. a local variable name).
216        let fork = input.fork();
217        let is_background = fork.parse::<kw::background>().is_ok() && fork.parse::<Expr>().is_ok();
218        if is_background {
219            _ = input.parse::<kw::background>()?;
220        }
221        let stream_expr = input.parse()?;
222        _ = input.parse::<Token![=>]>()?;
223        let is_move = input.parse::<Token![move]>().is_ok();
224        let body = input.parse()?;
225        Ok(Self {
226            pattern,
227            stream_expr,
228            body,
229            is_background,
230            is_move,
231        })
232    }
233}
234
235struct ForStreams {
236    arms: Vec<Arm>,
237}
238
239impl Parse for ForStreams {
240    fn parse(input: ParseStream) -> syn::Result<Self> {
241        let mut arms = Vec::new();
242        while !input.is_empty() {
243            let arm = input.parse::<Arm>()?;
244            arms.push(arm);
245        }
246        Ok(Self { arms })
247    }
248}
249
250impl ToTokens for ForStreams {
251    fn to_tokens(&self, tokens: &mut TokenStream2) {
252        let mut initializers = TokenStream2::new();
253        let cancel_flag = format_ident!("cancel_flag", span = Span::mixed_site());
254        let arm_names: Vec<Ident> = (0..self.arms.len())
255            .map(|i| format_ident!("arm_{}", i, span = Span::mixed_site()))
256            .collect();
257        for i in 0..self.arms.len() {
258            let Arm {
259                pattern,
260                stream_expr,
261                body,
262                is_background: _,
263                is_move,
264            } = &self.arms[i];
265            let move_token = if *is_move {
266                quote! { move }
267            } else {
268                quote! {}
269            };
270            let returned_early = format_ident!("returned_early", span = Span::mixed_site());
271            let returned_early_ref = format_ident!("returned_early_ref", span = Span::mixed_site());
272            let stream = format_ident!("stream", span = Span::mixed_site());
273            let name = &arm_names[i];
274            initializers.extend(quote! {
275                let mut #name = ::std::pin::pin!(::futures::future::FutureExt::fuse({
276                    async {
277                        let mut #returned_early = true;
278                        // For the `move` case, we need to explicitly take a reference to
279                        // `returned_early`, so that we don't copy it.
280                        let #returned_early_ref = &mut #returned_early;
281                        let _: () = async #move_token {
282                            let mut #stream = ::std::pin::pin!(#stream_expr);
283                            while let Some(#pattern) = ::futures::stream::StreamExt::next(&mut #stream).await {
284                                // NOTE: The #body may `continue`, `break`, or `return`.
285                                #body
286                            }
287                            *#returned_early_ref = false;
288                        }.await;
289                        if #returned_early {
290                            ::std::sync::atomic::AtomicBool::store(&#cancel_flag, true, ::std::sync::atomic::Ordering::Relaxed);
291                        }
292                    }
293                }));
294            });
295        }
296
297        let mut poll_calls = TokenStream2::new();
298        let foreground_finished = format_ident!("foreground_finished", span = Span::mixed_site());
299        let cx = format_ident!("cx", span = Span::mixed_site());
300        for i in 0..self.arms.len() {
301            let name = &arm_names[i];
302            poll_calls.extend(quote! {
303                // NOTE: These are fused, so we can poll them unconditionally.
304                _ = ::std::future::Future::poll(::std::pin::Pin::as_mut(&mut #name), #cx);
305            });
306            if !self.arms[i].is_background {
307                poll_calls.extend(quote! {
308                    #foreground_finished &= ::futures::future::FusedFuture::is_terminated(&#name);
309                });
310            }
311        }
312
313        tokens.extend(quote! {
314            {
315                let mut #cancel_flag = ::std::sync::atomic::AtomicBool::new(false);
316                #initializers
317                ::std::future::poll_fn(|#cx| {
318                    let mut #foreground_finished = true;
319                    #poll_calls
320                    if ::std::sync::atomic::AtomicBool::load(&#cancel_flag, ::std::sync::atomic::Ordering::Relaxed) {
321                        return ::std::task::Poll::Ready(());
322                    }
323                    if #foreground_finished {
324                        return ::std::task::Poll::Ready(());
325                    }
326                    ::std::task::Poll::Pending
327                }).await;
328            }
329        });
330    }
331}
332
333#[proc_macro]
334pub fn for_streams(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
335    let c = parse_macro_input!(input as ForStreams);
336    quote! { #c }.into()
337}