future_union_impl/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::{TokenStream, TokenTree};
4use proc_macro_hack::proc_macro_hack;
5use std::iter::FromIterator;
6
7use quote::quote;
8
9#[proc_macro_hack]
10pub fn future_union_impl(item: TokenStream) -> TokenStream {
11    let mut iter = item.into_iter();
12
13    let count_arg_token = iter.next().unwrap_or_else(|| panic!("Too few arguments"));
14    let count_arg = syn::parse::<syn::LitInt>(
15        TokenStream::from(count_arg_token)
16    ).unwrap_or_else(|_| panic!("Expecting integer literal")).value();
17
18    let comma_1 = iter.next().unwrap_or_else(|| panic!("Too few arguments"));
19    match comma_1 {
20        TokenTree::Punct(ref p) if p.as_char() == ',' => (),
21        _ => panic!("Invalid syntax, expected a comma"),
22    }
23
24    let n_arg_token = iter.next().unwrap_or_else(|| panic!("Too few arguments"));
25    let n_arg = syn::parse::<syn::LitInt>(
26        TokenStream::from(n_arg_token)
27    ).unwrap_or_else(|_| panic!("Expecting integer literal")).value();
28
29    let comma_2 = iter.next().unwrap_or_else(|| panic!("Too few arguments"));
30    match comma_2 {
31        TokenTree::Punct(ref p) if p.as_char() == ',' => (),
32        _ => panic!("Invalid syntax, expected a comma"),
33    }
34
35    let remaining_tokens = proc_macro2::TokenStream::from(TokenStream::from_iter(iter));
36
37    TokenStream::from(
38        future_union_make_tree(count_arg, n_arg, remaining_tokens)
39    )
40}
41
42fn future_union_make_tree(count: u64, n: u64, expr: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
43    assert!(n < count);
44
45    if count <= 0 {
46        panic!()
47    } else if count == 1 {
48        expr
49    } else if count == 2 {
50        if n & 1 == 0 {
51            quote!( futures::future::Either::A(#expr) )
52        } else {
53            quote!( futures::future::Either::B(#expr) )
54        }
55    } else {
56        let max_cap = round_up_to_power_of_2(count);
57        if first_half(max_cap, n) {
58            let sub_tree = future_union_make_tree(max_cap/2, n, expr);
59            quote!( futures::future::Either::A(#sub_tree) )
60        } else {
61            let sub_tree = future_union_make_tree(count-max_cap/2, n-max_cap/2, expr);
62            quote!( futures::future::Either::B(#sub_tree) )
63        }
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `round_up_to_power_of_2` against a table of `(input, expected)` pairs.
    #[test]
    fn round_up_to_power_of_2_test() {
        let cases: [(u64, u64); 8] = [
            (2, 2),
            (3, 4),
            (4, 4),
            (5, 8),
            (6, 8),
            (7, 8),
            (8, 8),
            (9, 16),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(round_up_to_power_of_2(input), expected);
        }
    }

    /// For capacities 2, 4 and 8, every index below the listed split point must be
    /// reported as being in the first half and every index from the split point up must not.
    #[test]
    fn first_half_test() {
        let cases: [(u64, u64); 3] = [(2, 1), (4, 2), (8, 4)];

        for &(cap, split) in cases.iter() {
            for n in 0..split {
                assert!(first_half(cap, n));
            }
            for n in split..cap {
                assert!(!first_half(cap, n));
            }
        }
    }
}
103
/// Reports whether index `n` lies in the lower half of a range of `cap` slots,
/// i.e. whether `n` is strictly below `cap / 2` (integer division).
///
/// # Panics
///
/// Panics if `n` is not less than `cap`.
fn first_half(cap: u64, n: u64) -> bool {
    assert!(n < cap);

    let midpoint = cap / 2;
    n < midpoint
}
108
/// Returns the smallest power of two that is greater than or equal to `n`.
///
/// The result is computed with integer arithmetic, so it is exact for every
/// representable input (an `f64` round-trip cannot represent all `u64` values
/// above 2^53 and would round some inputs down). An input of `0` yields `1`.
///
/// # Panics
///
/// Panics if the result does not fit in a `u64`, i.e. when `n` is greater than 2^63.
fn round_up_to_power_of_2(n: u64) -> u64 {
    n.checked_next_power_of_two()
        .expect("power of two does not fit in u64")
}
113