bignumbe_rs_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    parse::{self, Parse},
5    parse_macro_input, Ident, Lit, Token, Visibility,
6};
7
/// Precomputed facts about a numeric base, as produced by `get_base_data`.
struct BaseData {
    /// `(min, max)` exponent pair that ends up in the generated `bignumbe_rs::ExpRange`.
    exp_range: (u32, u32),
    /// `(min, max)` significand pair that ends up in the generated `bignumbe_rs::SigRange`.
    sig_range: (u64, u64),
    /// `base^0, base^1, ...` for every power of the base that fits in a `u64`.
    powers: Vec<u64>,
    /// `base^0, base^1, ...` for every power of the base that fits in a `u128`.
    powers_u128: Vec<u128>,
}
14
15struct BaseInput {
16    num: Lit,
17    vis: Visibility,
18    name: Ident,
19}
20
21impl Parse for BaseInput {
22    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
23        let num = input.parse()?;
24        let _com: Token![,] = input.parse()?;
25        let vis = input.parse()?;
26        let name = input.parse()?;
27
28        Ok(Self { num, vis, name })
29    }
30}
31
32/// Called like create_efficient_base(n, (pub) IntName), where n is the number of the base
33/// and IntName is the name of the type you want to create, with optional visibility
34/// qualifiers
35#[proc_macro]
36pub fn make_bignum(input: TokenStream) -> TokenStream {
37    let BaseInput { num, vis, name } = parse_macro_input!(input as BaseInput);
38    let (core, base_ident) = create_efficient_base_core(num);
39
40    quote! {
41        #core
42
43        #vis type #name = bignumbe_rs::BigNumBase<#base_ident>;
44    }
45    .into()
46}
47
48/// Called like create_efficient_base(n), where n is the number of the base
49#[proc_macro]
50pub fn create_efficient_base(input: TokenStream) -> TokenStream {
51    create_efficient_base_core(parse_macro_input!(input as Lit))
52        .0
53        .into()
54}
55
56fn create_efficient_base_core(lit: Lit) -> (proc_macro2::TokenStream, Ident) {
57    let number: u16 = if let Lit::Int(li) = lit {
58        li.base10_parse()
59            .expect("Input must be a valid base-10 number")
60    } else {
61        panic!("Input must be a valid u16 value greater than 2");
62    };
63    let number = number as u64;
64
65    let base_ident = format_ident!("__Base{}", number);
66
67    let BaseData {
68        exp_range,
69        sig_range,
70        powers,
71        powers_u128,
72    } = get_base_data(number as u16);
73
74    let power_tables = generate_power_tables(number, powers, powers_u128);
75    let impl_code = generate_impl(number, &base_ident, exp_range, sig_range);
76
77    // Create a default
78
79    (
80        quote! {
81            #[derive(Clone, Copy, Debug)]
82            struct #base_ident();
83
84            #power_tables
85            #impl_code
86        },
87        base_ident,
88    )
89}
90
91fn generate_impl(
92    number: u64,
93    base_ident: &Ident,
94    exp_range: (u32, u32),
95    sig_range: (u64, u64),
96) -> proc_macro2::TokenStream {
97    let powers_ident = format_ident!("__BASE_{}_POWERS", number);
98    let powers_u128_ident = format_ident!("__BASE_{}_U128_POWERS", number);
99
100    let (min_exp, max_exp) = exp_range;
101    let (min_sig, max_sig) = sig_range;
102
103    let shared = quote! {
104        const NUMBER: u16 = #number as u16;
105
106        fn new() -> Self {
107            Self()
108        }
109
110        fn exp_range(&self) -> bignumbe_rs::ExpRange {
111            bignumbe_rs::ExpRange(#min_exp, #max_exp)
112        }
113
114        fn sig_range(&self) -> bignumbe_rs::SigRange {
115            bignumbe_rs::SigRange(#min_sig, #max_sig)
116        }
117
118        fn pow(exp: u32) -> u64 {
119            #powers_ident[exp as usize]
120        }
121
122        fn pow_u128(exp: u32) -> u128 {
123            #powers_u128_ident[exp as usize]
124        }
125    };
126
127    if number.is_power_of_two() {
128        let log = number.ilog2();
129
130        quote! {
131            impl bignumbe_rs::Base for #base_ident {
132                #shared
133
134                fn rshift(lhs: u64, exp: u32) -> u64 {
135                    lhs >> (#log * exp)
136                }
137
138                fn rshift_u128(lhs: u128, exp: u32) -> u128 {
139                    lhs >> (#log * exp)
140                }
141
142                fn lshift(lhs: u64, exp: u32) -> u64 {
143                    lhs << (#log * exp)
144                }
145
146                fn lshift_u128(lhs: u128, exp: u32) -> u128 {
147                    lhs << (#log * exp)
148                }
149            }
150        }
151    } else {
152        quote! {
153            impl bignumbe_rs::Base for #base_ident {
154                #shared
155
156                fn rshift(lhs: u64, exp: u32) -> u64 {
157                    lhs / Self::pow(exp)
158                }
159
160                fn rshift_u128(lhs: u128, exp: u32) -> u128 {
161                    lhs / Self::pow_u128(exp)
162                }
163
164                fn lshift(lhs: u64, exp: u32) -> u64 {
165                    lhs * Self::pow(exp)
166                }
167
168                fn lshift_u128(lhs: u128, exp: u32) -> u128 {
169                    lhs * Self::pow_u128(exp)
170                }
171            }
172        }
173    }
174}
175
176fn generate_power_tables(
177    number: u64,
178    powers: Vec<u64>,
179    powers_u128: Vec<u128>,
180) -> proc_macro2::TokenStream {
181    let powers_len = powers.len();
182    let powers_u128_len = powers_u128.len();
183
184    let table_ident = format_ident!("__BASE_{}_POWERS", number);
185    let table_u128_ident = format_ident!("__BASE_{}_U128_POWERS", number);
186    quote! {
187        const #table_ident: [u64; #powers_len] = [
188            #(
189                #powers
190            ),*
191        ];
192
193        const #table_u128_ident: [u128; #powers_u128_len] = [
194            #(
195                #powers_u128
196            ),*
197        ];
198    }
199}
200
201fn get_base_data(number: u16) -> BaseData {
202    let mut curr = 1u128;
203
204    let mut powers = Vec::new();
205    let mut powers_u128 = Vec::new();
206
207    loop {
208        if curr <= u64::MAX as u128 {
209            powers.push(curr as u64);
210        }
211
212        powers_u128.push(curr);
213
214        match curr.checked_mul(number as u128) {
215            Some(res) => curr = res,
216            None => break,
217        }
218    }
219
220    let number = number as u64;
221    // TODO consider rewriting this to use the length and content of powers array instead
222    let (exp_range, sig_range) = if number.is_power_of_two() && number.ilog2().is_power_of_two() {
223        // This is a special case where sig_max = u64::MAX. We have to handle it
224        // specially to avoid overflowing the u64
225        let pow = number.ilog2();
226        let exp = 64 / pow;
227        let sig = number.pow(exp - 1);
228
229        ((exp - 1, exp), (sig, u64::MAX))
230    } else {
231        let exp = u64::MAX.ilog(number);
232        ((exp - 1, exp), (number.pow(exp - 1), number.pow(exp) - 1))
233    };
234
235    BaseData {
236        powers,
237        powers_u128,
238        exp_range,
239        sig_range,
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that both power tables hold `num^0, num^1, ...` and stop exactly at the
    /// last power that fits in their element type.
    fn assert_power_tables(num: u64, data: &BaseData) {
        for (i, n) in data.powers.iter().enumerate() {
            assert_eq!(*n, num.pow(i as u32));
        }
        for (i, n) in data.powers_u128.iter().enumerate() {
            assert_eq!(*n, u128::from(num).pow(i as u32));
        }

        let last = *data.powers.last().expect("u64 power table is empty");
        assert!(
            last.checked_mul(num).is_none(),
            "u64 power table stops too early"
        );
        let last_u128 = *data.powers_u128.last().expect("u128 power table is empty");
        assert!(
            last_u128.checked_mul(u128::from(num)).is_none(),
            "u128 power table stops too early"
        );
    }

    macro_rules! test_base {
        // `spec`: bases where base^k == 2^64 for some k (2, 4, 16, 256). The largest
        // significand is then u64::MAX, and base^max_exp itself does not fit in a u64.
        (spec $num:expr, $min_exp:expr) => {{
            let num = $num as u64;
            let data = get_base_data($num);
            let min_exp: u32 = $min_exp;

            assert_eq!(data.exp_range, (min_exp, min_exp + 1));
            assert_eq!(data.sig_range, (num.pow(min_exp), u64::MAX));
            assert_eq!(data.powers.len(), data.exp_range.1 as usize);
            assert_eq!(data.powers_u128.len(), data.exp_range.1 as usize * 2);

            assert_power_tables(num, &data);
        }};
        // Default: every other base. The `powers_u128` length check assumes the largest
        // u128 power is base^(2 * max_exp). That does not hold for every base (e.g. 5),
        // so choose test bases accordingly.
        ($num:expr, $min_exp:expr) => {{
            let num = $num as u64;
            let data = get_base_data($num);
            let min_exp: u32 = $min_exp;

            assert_eq!(data.exp_range, (min_exp, min_exp + 1));
            assert_eq!(
                data.sig_range,
                (num.pow(min_exp), num.pow(min_exp + 1) - 1)
            );
            assert_eq!(data.powers.len(), data.exp_range.1 as usize + 1);
            assert_eq!(data.powers_u128.len(), data.exp_range.1 as usize * 2 + 1);

            assert_power_tables(num, &data);
        }};
    }

    #[test]
    fn get_base_data_test() {
        test_base!(10, 18);
        test_base!(8, 20);
        test_base!(spec 256, 7);
        test_base!(spec 16, 15);
        test_base!(spec 2, 63);
    }

    #[test]
    fn get_base_data_edge_bases_test() {
        // Smallest odd base, the remaining exact power-of-two base, and the largest u16 base.
        test_base!(3, 39);
        test_base!(spec 4, 31);
        test_base!(65535, 3);
    }
}