channel_drain/
lib.rs

use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned};
use syn::{
    parenthesized,
    parse::{Parse, ParseStream, Result},
    parse_macro_input,
    punctuated::Punctuated,
    Expr, Ident, Token,
};
10
11struct Receiver {
12    channel: Ident,
13    msg: Ident,
14    handler: Expr,
15}
16
17impl Parse for Receiver {
18    fn parse(input: ParseStream) -> Result<Self> {
19        let channel: Ident = input.parse()?;
20        let in_parens;
21        let _parens = parenthesized!(in_parens in input);
22        let msg: Ident = in_parens.parse()?;
23        let _arrow = input.parse::<Token![=>]>()?;
24        let handler: Expr = input.parse()?;
25
26        Ok(Receiver {
27            channel,
28            msg,
29            handler,
30        })
31    }
32}
33
34type Receivers = Punctuated<Receiver, Token![,]>;
35
36fn parse_receivers(input: ParseStream) -> Result<Receivers> {
37    input.parse_terminated(Receiver::parse)
38}
39
40/// Receive on all channels until they are... drained.
41///
42/// Given a list of [`Receiver`s](crossbeam::channel::Receiver)
43/// and expressions to handle received messages, e.g.,
44/// `receiver(msg) => handle(msg)`,
45/// receive in a loop until all channels are closed and empty
46/// ([`recv()`](crossbeam::channel::Receiver::recv) returns `Err`).
47///
48/// ```
49/// # use channel_drain::drain;
50/// # use crossbeam::channel::bounded;
51/// let (tx1, rx1) = bounded(10);
52/// let (tx2, rx2) = bounded(10);
53///
54/// tx1.send(25.0).unwrap();
55/// tx1.send(62.4).unwrap();
56/// tx2.send(42).unwrap();
57/// tx2.send(22).unwrap();
58/// tx2.send(99).unwrap();
59///
60/// drop(tx1);
61/// drop(tx2);
62///
63/// drain! {
64///     rx1(dubs) => { println!("Some double: \"{}\"", dubs) },
65///     rx2(i) => println!("Some int: {}", i)
66/// };
67#[proc_macro]
68pub fn drain(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
69    let receivers = parse_macro_input!(input with parse_receivers);
70
71    let channels: Vec<&Ident> = receivers.iter().map(|r| &r.channel).collect();
72    let channel_len = channels.len();
73
74    let selectors = build_selectors(&channels);
75    let op_matches = build_op_match(&receivers);
76
77    let whole = quote! {{
78        // We can just keep track of the number of remaining channels open,
79        // since we remove each channel from the `Select` below as soon as
80        // it errors once. (We could skip this entirely if `Select` had a len().)
81        let mut channels_open: usize = #channel_len;
82
83        let mut sel = crossbeam::channel::Select::new();
84        #selectors
85
86        // While any channels are open, keep receiving.
87        while channels_open > 0 {
88            let op = sel.select();
89            match op.index() {
90                #op_matches
91                wut => unreachable!("Unexpected index {}", wut)
92            }
93        }
94    }};
95
96    whole.into()
97}
98
99fn build_selectors(channels: &[&Ident]) -> TokenStream {
100    let mut selectors = TokenStream::new();
101    for (i, channel) in channels.iter().enumerate() {
102        selectors.extend(quote! {
103            assert_eq!(sel.recv(&#channel), #i);
104        })
105    }
106    selectors
107}
108
109fn build_op_match(receivers: &Receivers) -> TokenStream {
110    let mut match_arms = TokenStream::new();
111    for (
112        i,
113        Receiver {
114            channel,
115            msg,
116            handler,
117        },
118    ) in receivers.iter().enumerate()
119    {
120        match_arms.extend(quote! {
121            idx if idx == #i => {
122                match op.recv(&#channel) {
123                    Ok(#msg) => #handler,
124                    Err(_) => {
125                        // Indexes are stable; this doesn't shift remaining channels.
126                        sel.remove(#i);
127                        assert!(channels_open > 0);
128                        channels_open -= 1;
129                    }
130                }
131            },
132        })
133    }
134    match_arms
135}