1extern crate proc_macro;
2
3use proc_macro::{TokenStream, TokenTree};
4use proc_macro_hack::proc_macro_hack;
5use std::iter::FromIterator;
6
7use quote::quote;
8
9#[proc_macro_hack]
10pub fn future_union_impl(item: TokenStream) -> TokenStream {
11 let mut iter = item.into_iter();
12
13 let count_arg_token = iter.next().unwrap_or_else(|| panic!("Too few arguments"));
14 let count_arg = syn::parse::<syn::LitInt>(
15 TokenStream::from(count_arg_token)
16 ).unwrap_or_else(|_| panic!("Expecting integer literal")).value();
17
18 let comma_1 = iter.next().unwrap_or_else(|| panic!("Too few arguments"));
19 match comma_1 {
20 TokenTree::Punct(ref p) if p.as_char() == ',' => (),
21 _ => panic!("Invalid syntax, expected a comma"),
22 }
23
24 let n_arg_token = iter.next().unwrap_or_else(|| panic!("Too few arguments"));
25 let n_arg = syn::parse::<syn::LitInt>(
26 TokenStream::from(n_arg_token)
27 ).unwrap_or_else(|_| panic!("Expecting integer literal")).value();
28
29 let comma_2 = iter.next().unwrap_or_else(|| panic!("Too few arguments"));
30 match comma_2 {
31 TokenTree::Punct(ref p) if p.as_char() == ',' => (),
32 _ => panic!("Invalid syntax, expected a comma"),
33 }
34
35 let remaining_tokens = proc_macro2::TokenStream::from(TokenStream::from_iter(iter));
36
37 TokenStream::from(
38 future_union_make_tree(count_arg, n_arg, remaining_tokens)
39 )
40}
41
42fn future_union_make_tree(count: u64, n: u64, expr: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
43 assert!(n < count);
44
45 if count <= 0 {
46 panic!()
47 } else if count == 1 {
48 expr
49 } else if count == 2 {
50 if n & 1 == 0 {
51 quote!( futures::future::Either::A(#expr) )
52 } else {
53 quote!( futures::future::Either::B(#expr) )
54 }
55 } else {
56 let max_cap = round_up_to_power_of_2(count);
57 if first_half(max_cap, n) {
58 let sub_tree = future_union_make_tree(max_cap/2, n, expr);
59 quote!( futures::future::Either::A(#sub_tree) )
60 } else {
61 let sub_tree = future_union_make_tree(count-max_cap/2, n-max_cap/2, expr);
62 quote!( futures::future::Either::B(#sub_tree) )
63 }
64 }
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn round_up_to_power_of_2_test() {
        // (input, expected power of two)
        let cases = [
            (2u64, 2u64),
            (3, 4),
            (4, 4),
            (5, 8),
            (6, 8),
            (7, 8),
            (8, 8),
            (9, 16),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(round_up_to_power_of_2(input), expected);
        }
    }

    #[test]
    fn first_half_test() {
        for &cap in [2u64, 4, 8].iter() {
            let half = cap / 2;
            // Every index below the midpoint is in the first half ...
            for n in 0..half {
                assert!(first_half(cap, n));
            }
            // ... and every remaining valid index is not.
            for n in half..cap {
                assert!(!first_half(cap, n));
            }
        }
    }
}
103
/// Tells whether slot `n` falls in the lower half of `cap` slots, i.e.
/// whether `n < cap / 2` (integer division).
///
/// # Panics
/// Panics when `n` is not a valid slot index (`n >= cap`).
fn first_half(cap: u64, n: u64) -> bool {
    assert!(n < cap);
    let midpoint = cap / 2;
    midpoint > n
}
108
/// Rounds `n` up to the nearest power of two, e.g. `3 -> 4`, `8 -> 8`,
/// `9 -> 16`. An input of `0` returns `0`.
///
/// Exact integer arithmetic is used on purpose. `f64::log2` has
/// platform-dependent precision, and `u64` values above 2^53 cannot be
/// represented exactly in an `f64`, so a float-based formula can return a
/// power of two that is smaller than `n`.
///
/// # Panics
/// Panics if the result does not fit in a `u64` (i.e. `n > 2^63`).
fn round_up_to_power_of_2(n: u64) -> u64 {
    if n == 0 {
        return 0;
    }
    n.checked_next_power_of_two()
        .expect("round_up_to_power_of_2: result overflows u64")
}
113