for_streams/lib.rs
1//! The `for_streams!` macro, for driving multiple async [`Stream`]s concurrently. `for_streams!`
2//! works well with [Tokio](https://tokio.rs/), but it doesn't depend on Tokio.
3//!
4//! # The simplest case
5//!
6//! ```rust
7//! # use std::time::Duration;
8//! # #[tokio::main]
9//! # async fn main() {
10//! use for_streams::for_streams;
11//!
12//! for_streams! {
13//! x in futures::stream::iter(1..=3) => {
14//! tokio::time::sleep(Duration::from_millis(1)).await;
15//! print!("{x} ");
16//! }
17//! y in futures::stream::iter(101..=103) => {
18//! tokio::time::sleep(Duration::from_millis(1)).await;
19//! print!("{y} ");
20//! }
21//! }
22//! # }
23//! ```
24//!
25//! That takes three milliseconds and prints `1 101 2 102 3 103`. The behavior there is similar to
26//! using [`StreamExt::for_each`][for_each] and [`futures::join!`][join] together like this:
27//!
28//! ```rust
29//! # use futures::StreamExt;
30//! # use std::time::Duration;
31//! # #[tokio::main]
32//! # async fn main() {
33//! futures::join!(
34//! futures::stream::iter(1..=3).for_each(|x| async move {
35//! tokio::time::sleep(Duration::from_millis(1)).await;
36//! println!("{x}");
37//! }),
38//! futures::stream::iter(101..=103).for_each(|x| async move {
39//! tokio::time::sleep(Duration::from_millis(1)).await;
40//! println!("{x}");
41//! }),
42//! );
43//! # }
44//! ```
45//!
46//! However, importantly, using [`select!`] in a loop does _not_ behave the same way:
47//!
48//! ```rust
49//! # use futures::StreamExt;
50//! # use std::time::Duration;
51//! # #[tokio::main]
52//! # async fn main() {
53//! let mut stream1 = futures::stream::iter(1..=3).fuse();
54//! let mut stream2 = futures::stream::iter(101..=103).fuse();
55//! loop {
56//! futures::select! {
57//! x = stream1.next() => {
58//! if let Some(x) = x {
59//! tokio::time::sleep(Duration::from_millis(1)).await;
60//! println!("{x}");
61//! }
62//! }
63//! y = stream2.next() => {
64//! if let Some(y) = y {
65//! tokio::time::sleep(Duration::from_millis(1)).await;
66//! println!("{y}");
67//! }
68//! }
69//! complete => break,
70//! }
71//! }
72//! # }
73//! ```
74//!
75//! That approach takes _six_ milliseconds, not three. `select!` is [notorious] for cancellation
76//! footguns, but this is actually a different problem: the body of a `select!` arm doesn't run
77//! concurrently with any other arms (neither their bodies nor their "scrutinees"). Using `select!`
78//! in a loop to drive multiple streams is often a mistake, [occasionally a deadlock][deadlock] but
79//! frequently a silent performance bug.
80//!
81//! And yet, `select!` in a loop gives us an appealing degree of control. Any of the bodies can
82//! `break` the loop, for example, which is awkward to replicate with `join!`. This is what
83//! `for_streams!` is about. It's like `select!` in a loop, but specifically for `Stream`s, with
84//! fewer footguns and several convenience features.
85//!
86//! # More interesting features
87//!
88//! `continue`, `break`, and `return` are all supported. `continue` skips to the next element of
89//! that stream, `break` stops reading from that stream, and `return` ends the whole macro (not the
90//! calling function, similar to `return` in an `async` block). The only valid return type is `()`.
91//! This example prints `a2 b1 c1 a4 b2 c2 a6 c3 a8` and then exits:
92//!
93//! ```rust
94//! # use for_streams::for_streams;
95//! # use std::time::Duration;
96//! # #[tokio::main]
97//! # async fn main() {
98//! for_streams! {
99//! a in futures::stream::iter(1..1_000_000_000) => {
100//! if a % 2 == 1 {
101//! continue; // Skip the odd elements in this arm.
102//! }
103//! print!("a{a} ");
104//! tokio::time::sleep(Duration::from_millis(1)).await;
105//! }
106//! b in futures::stream::iter(1..1_000_000_000) => {
107//! if b > 2 {
108//! break; // Stop this arm after two elements.
109//! }
110//! print!("b{b} ");
111//! tokio::time::sleep(Duration::from_millis(1)).await;
112//! }
113//! c in futures::stream::iter(1..1_000_000_000) => {
114//! if c > 3 {
115//! return; // Stop the whole loop after three elements.
116//! }
117//! print!("c{c} ");
118//! tokio::time::sleep(Duration::from_millis(1)).await;
119//! }
120//! }
121//! # }
122//! ```
123//!
124//! Sometimes you have a stream that's finite, like a channel that will eventually close, and
//! another stream that's infinite, like a timer that ticks forever. You can use `in background`
126//! (in place of `in`) to tell `for_streams!` not to wait for some arms to finish:
127//!
128//! ```rust
129//! # use for_streams::for_streams;
130//! # use std::time::Duration;
131//! # #[tokio::main]
132//! # async fn main() {
133//! use tokio::time::interval;
134//! use tokio_stream::wrappers::IntervalStream;
135//!
136//! let timer = IntervalStream::new(interval(Duration::from_millis(1)));
137//! for_streams! {
138//! x in futures::stream::iter(1..10) => {
139//! tokio::time::sleep(Duration::from_millis(1)).await;
140//! println!("{x}");
141//! }
142//! // We'll never reach the end of this `timer` stream, but `in background`
143//! // means we'll exit when the first arm is done, instead of ticking forever.
144//! _ in background timer => {
145//! println!("tick");
146//! }
147//! }
148//! # }
149//! ```
150//!
//! The `move` keyword is supported and has the same effect as it would on a closure or an `async
152//! move` block, making the block take ownership of all the values it references. This can be
153//! useful if you need a channel writer or a lock guard to drop promptly when one arm is done:
154//!
155//! ```rust
156//! # use for_streams::for_streams;
157//! # #[tokio::main]
158//! # async fn main() {
159//! use tokio::sync::mpsc::channel;
160//! use tokio_stream::wrappers::ReceiverStream;
161//!
162//! // This is a bounded channel, so the sender will block quickly on the
163//! // second message if the receiver isn't reading concurrently.
164//! let (sender, receiver) = channel::<i32>(1);
165//! let mut outputs = Vec::new();
166//! for_streams! {
167//! // The `move` keyword makes this arm take ownership of `sender`, which
//!     // means that `sender` drops as soon as this arm is finished. This
169//! // example would deadlock without it.
170//! val in tokio_stream::iter(1..=5) => move {
171//! sender.send(val).await.unwrap();
172//! }
173//! // This arm borrows `outputs` but can't take ownership of it, because
174//! // we use it again below in the assert.
175//! val in ReceiverStream::new(receiver) => {
176//! outputs.push(val);
177//! }
178//! }
179//! assert_eq!(outputs, vec![1, 2, 3, 4, 5]);
180//! # }
181//! ```
182//!
183//! [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html
184//! [for_each]: https://docs.rs/futures/latest/futures/stream/trait.StreamExt.html#method.for_each
185//! [join]: https://docs.rs/futures/latest/futures/macro.join.html
186//! [`select!`]: https://docs.rs/futures/latest/futures/macro.select.html
187//! [notorious]: https://sunshowers.io/posts/cancelling-async-rust/
188//! [deadlock]: https://rfd.shared.oxide.computer/rfd/0609
189
190use proc_macro2::{Span, TokenStream as TokenStream2};
191use quote::{ToTokens, format_ident, quote};
192use syn::{
193 Block, Expr, Ident, Pat, Token,
194 parse::{Parse, ParseStream},
195 parse_macro_input,
196};
197
198mod kw {
199 syn::custom_keyword!(background);
200}
201
202struct Arm {
203 pattern: Pat,
204 stream_expr: Expr,
205 body: Block,
206 is_background: bool,
207 is_move: bool,
208}
209
210impl Parse for Arm {
211 fn parse(input: ParseStream) -> syn::Result<Self> {
212 let pattern = Pat::parse_single(input)?;
213 _ = input.parse::<Token![in]>()?;
214 // Check whether we can parse a stream expression after `background`. If not, `background`
215 // itself could be the stream expression (i.e. a local variable name).
216 let fork = input.fork();
217 let is_background = fork.parse::<kw::background>().is_ok() && fork.parse::<Expr>().is_ok();
218 if is_background {
219 _ = input.parse::<kw::background>()?;
220 }
221 let stream_expr = input.parse()?;
222 _ = input.parse::<Token![=>]>()?;
223 let is_move = input.parse::<Token![move]>().is_ok();
224 let body = input.parse()?;
225 Ok(Self {
226 pattern,
227 stream_expr,
228 body,
229 is_background,
230 is_move,
231 })
232 }
233}
234
235struct ForStreams {
236 arms: Vec<Arm>,
237}
238
239impl Parse for ForStreams {
240 fn parse(input: ParseStream) -> syn::Result<Self> {
241 let mut arms = Vec::new();
242 while !input.is_empty() {
243 let arm = input.parse::<Arm>()?;
244 arms.push(arm);
245 }
246 Ok(Self { arms })
247 }
248}
249
250impl ToTokens for ForStreams {
251 fn to_tokens(&self, tokens: &mut TokenStream2) {
252 let mut initializers = TokenStream2::new();
253 let cancel_flag = format_ident!("cancel_flag", span = Span::mixed_site());
254 let arm_names: Vec<Ident> = (0..self.arms.len())
255 .map(|i| format_ident!("arm_{}", i, span = Span::mixed_site()))
256 .collect();
257 for i in 0..self.arms.len() {
258 let Arm {
259 pattern,
260 stream_expr,
261 body,
262 is_background: _,
263 is_move,
264 } = &self.arms[i];
265 let move_token = if *is_move {
266 quote! { move }
267 } else {
268 quote! {}
269 };
270 let returned_early = format_ident!("returned_early", span = Span::mixed_site());
271 let returned_early_ref = format_ident!("returned_early_ref", span = Span::mixed_site());
272 let stream = format_ident!("stream", span = Span::mixed_site());
273 let name = &arm_names[i];
274 initializers.extend(quote! {
275 let mut #name = ::std::pin::pin!(::futures::future::FutureExt::fuse({
276 async {
277 let mut #returned_early = true;
278 // For the `move` case, we need to explicitly take a reference to
279 // `returned_early`, so that we don't copy it.
280 let #returned_early_ref = &mut #returned_early;
281 let _: () = async #move_token {
282 let mut #stream = ::std::pin::pin!(#stream_expr);
283 while let Some(#pattern) = ::futures::stream::StreamExt::next(&mut #stream).await {
284 // NOTE: The #body may `continue`, `break`, or `return`.
285 #body
286 }
287 *#returned_early_ref = false;
288 }.await;
289 if #returned_early {
290 ::std::sync::atomic::AtomicBool::store(&#cancel_flag, true, ::std::sync::atomic::Ordering::Relaxed);
291 }
292 }
293 }));
294 });
295 }
296
297 let mut poll_calls = TokenStream2::new();
298 let foreground_finished = format_ident!("foreground_finished", span = Span::mixed_site());
299 let cx = format_ident!("cx", span = Span::mixed_site());
300 for i in 0..self.arms.len() {
301 let name = &arm_names[i];
302 poll_calls.extend(quote! {
303 // NOTE: These are fused, so we can poll them unconditionally.
304 _ = ::std::future::Future::poll(::std::pin::Pin::as_mut(&mut #name), #cx);
305 });
306 if !self.arms[i].is_background {
307 poll_calls.extend(quote! {
308 #foreground_finished &= ::futures::future::FusedFuture::is_terminated(&#name);
309 });
310 }
311 }
312
313 tokens.extend(quote! {
314 {
315 let mut #cancel_flag = ::std::sync::atomic::AtomicBool::new(false);
316 #initializers
317 ::std::future::poll_fn(|#cx| {
318 let mut #foreground_finished = true;
319 #poll_calls
320 if ::std::sync::atomic::AtomicBool::load(&#cancel_flag, ::std::sync::atomic::Ordering::Relaxed) {
321 return ::std::task::Poll::Ready(());
322 }
323 if #foreground_finished {
324 return ::std::task::Poll::Ready(());
325 }
326 ::std::task::Poll::Pending
327 }).await;
328 }
329 });
330 }
331}
332
333#[proc_macro]
334pub fn for_streams(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
335 let c = parse_macro_input!(input as ForStreams);
336 quote! { #c }.into()
337}