1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::{
4 parenthesized,
5 parse::{Parse, ParseStream, Result},
6 parse_macro_input,
7 punctuated::Punctuated,
8 Expr, Ident, Token,
9};
10
11struct Receiver {
12 channel: Ident,
13 msg: Ident,
14 handler: Expr,
15}
16
17impl Parse for Receiver {
18 fn parse(input: ParseStream) -> Result<Self> {
19 let channel: Ident = input.parse()?;
20 let in_parens;
21 let _parens = parenthesized!(in_parens in input);
22 let msg: Ident = in_parens.parse()?;
23 let _arrow = input.parse::<Token![=>]>()?;
24 let handler: Expr = input.parse()?;
25
26 Ok(Receiver {
27 channel,
28 msg,
29 handler,
30 })
31 }
32}
33
34type Receivers = Punctuated<Receiver, Token![,]>;
35
36fn parse_receivers(input: ParseStream) -> Result<Receivers> {
37 input.parse_terminated(Receiver::parse)
38}
39
40#[proc_macro]
68pub fn drain(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
69 let receivers = parse_macro_input!(input with parse_receivers);
70
71 let channels: Vec<&Ident> = receivers.iter().map(|r| &r.channel).collect();
72 let channel_len = channels.len();
73
74 let selectors = build_selectors(&channels);
75 let op_matches = build_op_match(&receivers);
76
77 let whole = quote! {{
78 let mut channels_open: usize = #channel_len;
82
83 let mut sel = crossbeam::channel::Select::new();
84 #selectors
85
86 while channels_open > 0 {
88 let op = sel.select();
89 match op.index() {
90 #op_matches
91 wut => unreachable!("Unexpected index {}", wut)
92 }
93 }
94 }};
95
96 whole.into()
97}
98
99fn build_selectors(channels: &[&Ident]) -> TokenStream {
100 let mut selectors = TokenStream::new();
101 for (i, channel) in channels.iter().enumerate() {
102 selectors.extend(quote! {
103 assert_eq!(sel.recv(&#channel), #i);
104 })
105 }
106 selectors
107}
108
109fn build_op_match(receivers: &Receivers) -> TokenStream {
110 let mut match_arms = TokenStream::new();
111 for (
112 i,
113 Receiver {
114 channel,
115 msg,
116 handler,
117 },
118 ) in receivers.iter().enumerate()
119 {
120 match_arms.extend(quote! {
121 idx if idx == #i => {
122 match op.recv(&#channel) {
123 Ok(#msg) => #handler,
124 Err(_) => {
125 sel.remove(#i);
127 assert!(channels_open > 0);
128 channels_open -= 1;
129 }
130 }
131 },
132 })
133 }
134 match_arms
135}