1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 parse::{self, Parse},
5 parse_macro_input, Ident, Lit, Token, Visibility,
6};
7
/// Precomputed numeric data for one base, produced by `get_base_data` and
/// consumed when generating the base's tables and trait impl.
struct BaseData {
    /// `(exp - 1, exp)`: the exponent range passed to `bignumbe_rs::ExpRange`.
    exp_range: (u32, u32),
    /// `(min, max)` significand bounds passed to `bignumbe_rs::SigRange`.
    sig_range: (u64, u64),
    /// `base^0, base^1, ...` for every power that fits in a `u64`.
    powers: Vec<u64>,
    /// `base^0, base^1, ...` for every power that fits in a `u128`.
    powers_u128: Vec<u128>,
}
14
15struct BaseInput {
16 num: Lit,
17 vis: Visibility,
18 name: Ident,
19}
20
21impl Parse for BaseInput {
22 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
23 let num = input.parse()?;
24 let _com: Token![,] = input.parse()?;
25 let vis = input.parse()?;
26 let name = input.parse()?;
27
28 Ok(Self { num, vis, name })
29 }
30}
31
32#[proc_macro]
36pub fn make_bignum(input: TokenStream) -> TokenStream {
37 let BaseInput { num, vis, name } = parse_macro_input!(input as BaseInput);
38 let (core, base_ident) = create_efficient_base_core(num);
39
40 quote! {
41 #core
42
43 #vis type #name = bignumbe_rs::BigNumBase<#base_ident>;
44 }
45 .into()
46}
47
48#[proc_macro]
50pub fn create_efficient_base(input: TokenStream) -> TokenStream {
51 create_efficient_base_core(parse_macro_input!(input as Lit))
52 .0
53 .into()
54}
55
56fn create_efficient_base_core(lit: Lit) -> (proc_macro2::TokenStream, Ident) {
57 let number: u16 = if let Lit::Int(li) = lit {
58 li.base10_parse()
59 .expect("Input must be a valid base-10 number")
60 } else {
61 panic!("Input must be a valid u16 value greater than 2");
62 };
63 let number = number as u64;
64
65 let base_ident = format_ident!("__Base{}", number);
66
67 let BaseData {
68 exp_range,
69 sig_range,
70 powers,
71 powers_u128,
72 } = get_base_data(number as u16);
73
74 let power_tables = generate_power_tables(number, powers, powers_u128);
75 let impl_code = generate_impl(number, &base_ident, exp_range, sig_range);
76
77 (
80 quote! {
81 #[derive(Clone, Copy, Debug)]
82 struct #base_ident();
83
84 #power_tables
85 #impl_code
86 },
87 base_ident,
88 )
89}
90
91fn generate_impl(
92 number: u64,
93 base_ident: &Ident,
94 exp_range: (u32, u32),
95 sig_range: (u64, u64),
96) -> proc_macro2::TokenStream {
97 let powers_ident = format_ident!("__BASE_{}_POWERS", number);
98 let powers_u128_ident = format_ident!("__BASE_{}_U128_POWERS", number);
99
100 let (min_exp, max_exp) = exp_range;
101 let (min_sig, max_sig) = sig_range;
102
103 let shared = quote! {
104 const NUMBER: u16 = #number as u16;
105
106 fn new() -> Self {
107 Self()
108 }
109
110 fn exp_range(&self) -> bignumbe_rs::ExpRange {
111 bignumbe_rs::ExpRange(#min_exp, #max_exp)
112 }
113
114 fn sig_range(&self) -> bignumbe_rs::SigRange {
115 bignumbe_rs::SigRange(#min_sig, #max_sig)
116 }
117
118 fn pow(exp: u32) -> u64 {
119 #powers_ident[exp as usize]
120 }
121
122 fn pow_u128(exp: u32) -> u128 {
123 #powers_u128_ident[exp as usize]
124 }
125 };
126
127 if number.is_power_of_two() {
128 let log = number.ilog2();
129
130 quote! {
131 impl bignumbe_rs::Base for #base_ident {
132 #shared
133
134 fn rshift(lhs: u64, exp: u32) -> u64 {
135 lhs >> (#log * exp)
136 }
137
138 fn rshift_u128(lhs: u128, exp: u32) -> u128 {
139 lhs >> (#log * exp)
140 }
141
142 fn lshift(lhs: u64, exp: u32) -> u64 {
143 lhs << (#log * exp)
144 }
145
146 fn lshift_u128(lhs: u128, exp: u32) -> u128 {
147 lhs << (#log * exp)
148 }
149 }
150 }
151 } else {
152 quote! {
153 impl bignumbe_rs::Base for #base_ident {
154 #shared
155
156 fn rshift(lhs: u64, exp: u32) -> u64 {
157 lhs / Self::pow(exp)
158 }
159
160 fn rshift_u128(lhs: u128, exp: u32) -> u128 {
161 lhs / Self::pow_u128(exp)
162 }
163
164 fn lshift(lhs: u64, exp: u32) -> u64 {
165 lhs * Self::pow(exp)
166 }
167
168 fn lshift_u128(lhs: u128, exp: u32) -> u128 {
169 lhs * Self::pow_u128(exp)
170 }
171 }
172 }
173 }
174}
175
176fn generate_power_tables(
177 number: u64,
178 powers: Vec<u64>,
179 powers_u128: Vec<u128>,
180) -> proc_macro2::TokenStream {
181 let powers_len = powers.len();
182 let powers_u128_len = powers_u128.len();
183
184 let table_ident = format_ident!("__BASE_{}_POWERS", number);
185 let table_u128_ident = format_ident!("__BASE_{}_U128_POWERS", number);
186 quote! {
187 const #table_ident: [u64; #powers_len] = [
188 #(
189 #powers
190 ),*
191 ];
192
193 const #table_u128_ident: [u128; #powers_u128_len] = [
194 #(
195 #powers_u128
196 ),*
197 ];
198 }
199}
200
201fn get_base_data(number: u16) -> BaseData {
202 let mut curr = 1u128;
203
204 let mut powers = Vec::new();
205 let mut powers_u128 = Vec::new();
206
207 loop {
208 if curr <= u64::MAX as u128 {
209 powers.push(curr as u64);
210 }
211
212 powers_u128.push(curr);
213
214 match curr.checked_mul(number as u128) {
215 Some(res) => curr = res,
216 None => break,
217 }
218 }
219
220 let number = number as u64;
221 let (exp_range, sig_range) = if number.is_power_of_two() && number.ilog2().is_power_of_two() {
223 let pow = number.ilog2();
226 let exp = 64 / pow;
227 let sig = number.pow(exp - 1);
228
229 ((exp - 1, exp), (sig, u64::MAX))
230 } else {
231 let exp = u64::MAX.ilog(number);
232 ((exp - 1, exp), (number.pow(exp - 1), number.pow(exp) - 1))
233 };
234
235 BaseData {
236 powers,
237 powers_u128,
238 exp_range,
239 sig_range,
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `get_base_data(num)` matches independently computed values.
    ///
    /// `min_exp` is the expected lower bound of `exp_range`. Set `fills_u64` to
    /// `true` for bases where `num^(min_exp + 1)` is exactly 2^64 (2, 4, 16, 256).
    /// For those, the significand range ends at `u64::MAX` and `powers` has no entry
    /// for `num^(min_exp + 1)`.
    ///
    /// The power tables are compared element-wise against `checked_pow`. A length
    /// formula like `2 * exp + 1` for `powers_u128` does not hold for every base:
    /// base 7 has 46 entries, not 45.
    fn check_base(num: u16, min_exp: u32, fills_u64: bool) {
        let data = get_base_data(num);
        let n = u64::from(num);
        let n128 = u128::from(num);

        assert_eq!(data.exp_range, (min_exp, min_exp + 1));

        let max_sig = if fills_u64 {
            u64::MAX
        } else {
            n.pow(min_exp + 1) - 1
        };
        assert_eq!(data.sig_range, (n.pow(min_exp), max_sig));

        let expected: Vec<u64> = (0..).map_while(|i| n.checked_pow(i)).collect();
        assert_eq!(data.powers, expected);
        assert_eq!(
            data.powers.len(),
            data.exp_range.1 as usize + usize::from(!fills_u64)
        );

        let expected_u128: Vec<u128> = (0..).map_while(|i| n128.checked_pow(i)).collect();
        assert_eq!(data.powers_u128, expected_u128);
    }

    #[test]
    fn get_base_data_test() {
        check_base(10, 18, false);
        check_base(8, 20, false);
        check_base(3, 39, false);
        check_base(7, 21, false);
        check_base(u16::MAX, 3, false);
        check_base(256, 7, true);
        check_base(16, 15, true);
        check_base(4, 31, true);
        check_base(2, 63, true);
    }
}