1#![warn(
2 unused,
3 future_incompatible,
4 nonstandard_style,
5 rust_2018_idioms,
6 rust_2021_compatibility
7)]
8#![forbid(unsafe_code)]
9
10use num_bigint::BigUint;
11use proc_macro::TokenStream;
12use quote::format_ident;
13use syn::{Expr, ExprLit, Item, ItemFn, Lit, Meta};
14
15pub(crate) mod montgomery;
16mod small_fp;
17mod unroll;
18
19pub(crate) mod utils;
20
21#[proc_macro]
22pub fn to_sign_and_limbs(input: TokenStream) -> TokenStream {
23 let num = utils::parse_string(input).expect("expected decimal string");
24 let (is_positive, limbs) = utils::str_to_limbs(&num);
25
26 let limbs: String = limbs.join(", ");
27 let limbs_and_sign = format!("({is_positive}") + ", [" + &limbs + "])";
28 let tuple: Expr = syn::parse_str(&limbs_and_sign).unwrap();
29 quote::quote!(#tuple).into()
30}
31
32#[proc_macro]
37pub fn define_field(input: TokenStream) -> TokenStream {
38 let args = syn::parse_macro_input!(input as utils::FieldArgs);
39
40 let modulus_big = args
41 .modulus
42 .parse::<BigUint>()
43 .expect("modulus should be a decimal integer string");
44
45 let limbs = utils::str_to_limbs_u64(&args.modulus).1.len();
46
47 let name = args.name;
48 let config_name = format_ident!("{}Config", name);
49 let is_small_modulus = modulus_big < (BigUint::from(1u128) << 64);
50
51 if is_small_modulus {
52 let modulus_u128: u128 = args
53 .modulus
54 .parse()
55 .expect("modulus should fit in u128 for small field");
56 let generator_u128: u128 = args
57 .generator
58 .parse()
59 .expect("generator should fit in u128 for small field");
60
61 let config_impl =
62 small_fp::small_fp_config_helper(modulus_u128, generator_u128, config_name.clone());
63
64 quote::quote! {
65 pub struct #config_name;
66 pub type #name = ark_ff::SmallFp<#config_name>;
67 #config_impl
68 }
69 .into()
70 } else {
71 let fp_alias = utils::fp_alias_for_limbs(limbs);
72
73 let generator_big = args
74 .generator
75 .parse::<BigUint>()
76 .expect("generator should be a decimal integer string");
77
78 let (small_subgroup_base, small_subgroup_power) =
79 match utils::find_conservative_subgroup_base(&modulus_big) {
80 Some((base, power)) => (Some(base), Some(power)),
81 None => (None, None),
82 };
83
84 let config_impl = montgomery::mont_config_helper(
85 modulus_big,
86 generator_big,
87 small_subgroup_base,
88 small_subgroup_power,
89 config_name.clone(),
90 );
91
92 quote::quote! {
93 pub struct #config_name;
94 pub type #name = #fp_alias<ark_ff::MontBackend<#config_name, #limbs>>;
95 #config_impl
96 }
97 .into()
98 }
99}
100
101#[proc_macro_derive(
112 MontConfig,
113 attributes(modulus, generator, small_subgroup_base, small_subgroup_power)
114)]
115pub fn mont_config(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
116 let ast: syn::DeriveInput = syn::parse(input).unwrap();
118
119 let modulus: BigUint = fetch_attr("modulus", &ast.attrs)
121 .expect("Please supply a modulus attribute")
122 .parse()
123 .expect("Modulus should be a number");
124
125 let generator: BigUint = fetch_attr("generator", &ast.attrs)
128 .expect("Please supply a generator attribute")
129 .parse()
130 .expect("Generator should be a number");
131
132 let small_subgroup_base: Option<u32> = fetch_attr("small_subgroup_base", &ast.attrs)
133 .map(|s| s.parse().expect("small_subgroup_base should be a number"));
134
135 let small_subgroup_power: Option<u32> = fetch_attr("small_subgroup_power", &ast.attrs)
136 .map(|s| s.parse().expect("small_subgroup_power should be a number"));
137
138 montgomery::mont_config_helper(
139 modulus,
140 generator,
141 small_subgroup_base,
142 small_subgroup_power,
143 ast.ident,
144 )
145 .into()
146}
147
148#[proc_macro_derive(SmallFpConfig, attributes(modulus, generator))]
156pub fn small_fp_config(input: TokenStream) -> TokenStream {
157 let ast: syn::DeriveInput = syn::parse(input).unwrap();
158
159 let modulus: u128 = fetch_attr("modulus", &ast.attrs)
160 .expect("Please supply a modulus attribute")
161 .parse()
162 .expect("Modulus should be a number");
163
164 let generator: u128 = fetch_attr("generator", &ast.attrs)
165 .expect("Please supply a generator attribute")
166 .parse()
167 .expect("Generator should be a number");
168
169 small_fp::small_fp_config_helper(modulus, generator, ast.ident).into()
170}
171
172const ARG_MSG: &str = "Failed to parse unroll threshold; must be a positive integer";
173
174#[proc_macro_attribute]
176pub fn unroll_for_loops(args: TokenStream, input: TokenStream) -> TokenStream {
177 let unroll_by = match syn::parse2::<syn::Lit>(args.into()).expect(ARG_MSG) {
178 Lit::Int(int) => int.base10_parse().expect(ARG_MSG),
179 _ => panic!("{}", ARG_MSG),
180 };
181
182 let item: Item = syn::parse(input).expect("Failed to parse input.");
183
184 if let Item::Fn(item_fn) = item {
185 let new_block = {
186 let ItemFn {
187 block: ref box_block,
188 ..
189 } = item_fn;
190 unroll::unroll_in_block(box_block, unroll_by)
191 };
192 let new_item = Item::Fn(ItemFn {
193 block: Box::new(new_block),
194 ..item_fn
195 });
196 quote::quote! ( #new_item ).into()
197 } else {
198 quote::quote! ( #item ).into()
199 }
200}
201
202fn fetch_attr(name: &str, attrs: &[syn::Attribute]) -> Option<String> {
204 for attr in attrs {
206 match attr.meta {
207 Meta::NameValue(ref nv) if nv.path.is_ident(name) => {
210 if let Expr::Lit(ExprLit {
213 lit: Lit::Str(ref s),
214 ..
215 }) = nv.value
216 {
217 return Some(s.value());
218 }
219 panic!("attribute {name} should be a string")
220 },
221 _ => {},
222 }
223 }
224 None
225}
226
/// Checks `utils::str_to_limbs` on signed powers of two written in binary, octal,
/// hexadecimal and decimal notation, then on a handful of fixed values.
///
/// The fixed cases below show the expected output format: a sign flag (`true` for
/// zero and positives) plus the minimal number of `u64` limb literals, least
/// significant limb first.
#[test]
fn test_str_to_limbs() {
    use num_bigint::Sign::{Minus, Plus};

    for i in 0..100 {
        let number = 1i128 << i;
        let magnitude = number.unsigned_abs();
        for sign in [Plus, Minus] {
            let signed_number = if sign == Minus { -number } else { number };
            for base in [2, 8, 16, 10] {
                let mut string = match base {
                    2 => format!("{number:#b}"),
                    8 => format!("{number:#o}"),
                    16 => format!("{number:#x}"),
                    10 => number.to_string(),
                    _ => unreachable!(),
                };
                if sign == Minus {
                    string.insert(0, '-');
                }
                let (is_positive, limbs) = utils::str_to_limbs(&string);

                // 2^i needs one limb below 2^64 and exactly two limbs up to 2^99.
                let expected_len = if i > 63 { 2 } else { 1 };
                assert_eq!(limbs.len(), expected_len, "{signed_number}, {i}");
                // Truncating casts are intentional: they extract the low / high 64 bits.
                assert_eq!(
                    limbs[0],
                    format!("{}u64", magnitude as u64),
                    "{signed_number}, {i}"
                );
                if i > 63 {
                    assert_eq!(
                        limbs[1],
                        format!("{}u64", (magnitude >> 64) as u64),
                        "{signed_number}, {i}"
                    );
                }

                assert_eq!(is_positive, sign == Plus);
            }
        }
    }

    let (is_positive, limbs) = utils::str_to_limbs("0");
    assert!(is_positive);
    assert_eq!(&limbs, &["0u64".to_string()]);

    let (is_positive, limbs) = utils::str_to_limbs("-5");
    assert!(!is_positive);
    assert_eq!(&limbs, &["5u64".to_string()]);

    let (is_positive, limbs) = utils::str_to_limbs("100");
    assert!(is_positive);
    assert_eq!(&limbs, &["100u64".to_string()]);

    let large_num = -((1i128 << 64) + 101234001234i128);
    let (is_positive, limbs) = utils::str_to_limbs(&large_num.to_string());
    assert!(!is_positive);
    assert_eq!(&limbs, &["101234001234u64".to_string(), "1u64".to_string()]);

    let num = "80949648264912719408558363140637477264845294720710499478137287262712535938301461879813459410946";
    let (is_positive, limbs) = utils::str_to_limbs(num);
    assert!(is_positive);
    let expected_limbs = [
        format!("{}u64", 0x8508c00000000002u64),
        format!("{}u64", 0x452217cc90000000u64),
        format!("{}u64", 0xc5ed1347970dec00u64),
        format!("{}u64", 0x619aaf7d34594aabu64),
        format!("{}u64", 0x9b3af05dd14f6ecu64),
    ];
    assert_eq!(&limbs, &expected_limbs);
}