Skip to main content

ark_ff_macros/
lib.rs

1#![warn(
2    unused,
3    future_incompatible,
4    nonstandard_style,
5    rust_2018_idioms,
6    rust_2021_compatibility
7)]
8#![forbid(unsafe_code)]
9
10use num_bigint::BigUint;
11use proc_macro::TokenStream;
12use quote::format_ident;
13use syn::{Expr, ExprLit, Item, ItemFn, Lit, Meta};
14
15pub(crate) mod montgomery;
16mod small_fp;
17mod unroll;
18
19pub(crate) mod utils;
20
21#[proc_macro]
22pub fn to_sign_and_limbs(input: TokenStream) -> TokenStream {
23    let num = utils::parse_string(input).expect("expected decimal string");
24    let (is_positive, limbs) = utils::str_to_limbs(&num);
25
26    let limbs: String = limbs.join(", ");
27    let limbs_and_sign = format!("({is_positive}") + ", [" + &limbs + "])";
28    let tuple: Expr = syn::parse_str(&limbs_and_sign).unwrap();
29    quote::quote!(#tuple).into()
30}
31
32/// Define optimal field type and its corresponding config
33///  
34/// If modulus fits into a native datatype, the resulting type is SmallFp<<name>Config>
35/// Otherwise, it is the appropriately sized Fp*<MontBackend<<name>Config, N>>
36#[proc_macro]
37pub fn define_field(input: TokenStream) -> TokenStream {
38    let args = syn::parse_macro_input!(input as utils::FieldArgs);
39
40    let modulus_big = args
41        .modulus
42        .parse::<BigUint>()
43        .expect("modulus should be a decimal integer string");
44
45    let limbs = utils::str_to_limbs_u64(&args.modulus).1.len();
46
47    let name = args.name;
48    let config_name = format_ident!("{}Config", name);
49    let is_small_modulus = modulus_big < (BigUint::from(1u128) << 64);
50
51    if is_small_modulus {
52        let modulus_u128: u128 = args
53            .modulus
54            .parse()
55            .expect("modulus should fit in u128 for small field");
56        let generator_u128: u128 = args
57            .generator
58            .parse()
59            .expect("generator should fit in u128 for small field");
60
61        let config_impl =
62            small_fp::small_fp_config_helper(modulus_u128, generator_u128, config_name.clone());
63
64        quote::quote! {
65            pub struct #config_name;
66            pub type #name = ark_ff::SmallFp<#config_name>;
67            #config_impl
68        }
69        .into()
70    } else {
71        let fp_alias = utils::fp_alias_for_limbs(limbs);
72
73        let generator_big = args
74            .generator
75            .parse::<BigUint>()
76            .expect("generator should be a decimal integer string");
77
78        let (small_subgroup_base, small_subgroup_power) =
79            match utils::find_conservative_subgroup_base(&modulus_big) {
80                Some((base, power)) => (Some(base), Some(power)),
81                None => (None, None),
82            };
83
84        let config_impl = montgomery::mont_config_helper(
85            modulus_big,
86            generator_big,
87            small_subgroup_base,
88            small_subgroup_power,
89            config_name.clone(),
90        );
91
92        quote::quote! {
93            pub struct #config_name;
94            pub type #name = #fp_alias<ark_ff::MontBackend<#config_name, #limbs>>;
95            #config_impl
96        }
97        .into()
98    }
99}
100
101/// Derive the `MontConfig` trait.
102///
103/// The attributes available to this macro are
104/// * `modulus`: Specify the prime modulus underlying this prime field.
105/// * `generator`: Specify the generator of the multiplicative subgroup of this
106///   prime field. This value must be a quadratic non-residue in the field.
107/// * `small_subgroup_base` and `small_subgroup_power` (optional): If the field
108///   has insufficient two-adicity, specify an additional subgroup of size
109///   `small_subgroup_base.pow(small_subgroup_power)`.
110// This code was adapted from the `PrimeField` Derive Macro in ff-derive.
111#[proc_macro_derive(
112    MontConfig,
113    attributes(modulus, generator, small_subgroup_base, small_subgroup_power)
114)]
115pub fn mont_config(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
116    // Parse the type definition
117    let ast: syn::DeriveInput = syn::parse(input).unwrap();
118
119    // We're given the modulus p of the prime field
120    let modulus: BigUint = fetch_attr("modulus", &ast.attrs)
121        .expect("Please supply a modulus attribute")
122        .parse()
123        .expect("Modulus should be a number");
124
125    // We may be provided with a generator of p - 1 order. It is required that this
126    // generator be quadratic nonresidue.
127    let generator: BigUint = fetch_attr("generator", &ast.attrs)
128        .expect("Please supply a generator attribute")
129        .parse()
130        .expect("Generator should be a number");
131
132    let small_subgroup_base: Option<u32> = fetch_attr("small_subgroup_base", &ast.attrs)
133        .map(|s| s.parse().expect("small_subgroup_base should be a number"));
134
135    let small_subgroup_power: Option<u32> = fetch_attr("small_subgroup_power", &ast.attrs)
136        .map(|s| s.parse().expect("small_subgroup_power should be a number"));
137
138    montgomery::mont_config_helper(
139        modulus,
140        generator,
141        small_subgroup_base,
142        small_subgroup_power,
143        ast.ident,
144    )
145    .into()
146}
147
148/// Derive the `SmallFpConfig` trait for small prime fields.
149///
150/// The attributes available to this macro are:
151/// * `modulus`: Specify the prime modulus underlying this prime field.
152/// * `generator`: Specify the generator of the multiplicative subgroup.
153///
154/// Note: Only Montgomery backend is supported.
155#[proc_macro_derive(SmallFpConfig, attributes(modulus, generator))]
156pub fn small_fp_config(input: TokenStream) -> TokenStream {
157    let ast: syn::DeriveInput = syn::parse(input).unwrap();
158
159    let modulus: u128 = fetch_attr("modulus", &ast.attrs)
160        .expect("Please supply a modulus attribute")
161        .parse()
162        .expect("Modulus should be a number");
163
164    let generator: u128 = fetch_attr("generator", &ast.attrs)
165        .expect("Please supply a generator attribute")
166        .parse()
167        .expect("Generator should be a number");
168
169    small_fp::small_fp_config_helper(modulus, generator, ast.ident).into()
170}
171
172const ARG_MSG: &str = "Failed to parse unroll threshold; must be a positive integer";
173
174/// Attribute used to unroll for loops found inside a function block.
175#[proc_macro_attribute]
176pub fn unroll_for_loops(args: TokenStream, input: TokenStream) -> TokenStream {
177    let unroll_by = match syn::parse2::<syn::Lit>(args.into()).expect(ARG_MSG) {
178        Lit::Int(int) => int.base10_parse().expect(ARG_MSG),
179        _ => panic!("{}", ARG_MSG),
180    };
181
182    let item: Item = syn::parse(input).expect("Failed to parse input.");
183
184    if let Item::Fn(item_fn) = item {
185        let new_block = {
186            let ItemFn {
187                block: ref box_block,
188                ..
189            } = item_fn;
190            unroll::unroll_in_block(box_block, unroll_by)
191        };
192        let new_item = Item::Fn(ItemFn {
193            block: Box::new(new_block),
194            ..item_fn
195        });
196        quote::quote! ( #new_item ).into()
197    } else {
198        quote::quote! ( #item ).into()
199    }
200}
201
202/// Fetch an attribute string from the derived struct.
203fn fetch_attr(name: &str, attrs: &[syn::Attribute]) -> Option<String> {
204    // Go over each attribute
205    for attr in attrs {
206        match attr.meta {
207            // If the attribute's path matches `name`, and if the attribute is of
208            // the form `#[name = "value"]`, return `value`
209            Meta::NameValue(ref nv) if nv.path.is_ident(name) => {
210                // Extract and return the string value.
211                // If `value` is not a string, return an error
212                if let Expr::Lit(ExprLit {
213                    lit: Lit::Str(ref s),
214                    ..
215                }) = nv.value
216                {
217                    return Some(s.value());
218                }
219                panic!("attribute {name} should be a string")
220            },
221            _ => {},
222        }
223    }
224    None
225}
226
/// Checks that `utils::str_to_limbs` turns signed integers written in binary, octal, hex
/// and decimal into the right sign flag and `u64`-suffixed limb literals, least-significant
/// limb first.
#[test]
fn test_str_to_limbs() {
    use num_bigint::Sign::{Minus, Plus};

    for i in 0..100 {
        let number = 1i128 << i;
        for sign in [Plus, Minus] {
            let signed_number = if sign == Minus { -number } else { number };
            let magnitude = signed_number.unsigned_abs();
            // Powers of two up to 2^63 need one limb; larger ones (up to 2^99) need two.
            let expected_len = if i > 63 { 2 } else { 1 };

            for base in [2, 8, 16, 10] {
                let digits = match base {
                    2 => format!("{number:#b}"),
                    8 => format!("{number:#o}"),
                    16 => format!("{number:#x}"),
                    10 => number.to_string(),
                    _ => unreachable!(),
                };
                let string = if sign == Minus {
                    format!("-{digits}")
                } else {
                    digits
                };

                let (is_positive, limbs) = utils::str_to_limbs(&string);
                assert_eq!(limbs.len(), expected_len, "{signed_number}, {i}");
                // The truncating cast keeps exactly the low 64 bits of |n|.
                assert_eq!(
                    limbs[0],
                    format!("{}u64", magnitude as u64),
                    "{signed_number}, {i}"
                );
                if i > 63 {
                    assert_eq!(
                        limbs[1],
                        format!("{}u64", (magnitude >> 64) as u64),
                        "{signed_number}, {i}"
                    );
                }

                assert_eq!(is_positive, sign == Plus);
            }
        }
    }

    let (is_positive, limbs) = utils::str_to_limbs("0");
    assert!(is_positive);
    assert_eq!(&limbs, &["0u64".to_string()]);

    let (is_positive, limbs) = utils::str_to_limbs("-5");
    assert!(!is_positive);
    assert_eq!(&limbs, &["5u64".to_string()]);

    let (is_positive, limbs) = utils::str_to_limbs("100");
    assert!(is_positive);
    assert_eq!(&limbs, &["100u64".to_string()]);

    let large_num = -((1i128 << 64) + 101234001234i128);
    let (is_positive, limbs) = utils::str_to_limbs(&large_num.to_string());
    assert!(!is_positive);
    assert_eq!(&limbs, &["101234001234u64".to_string(), "1u64".to_string()]);

    let num = "80949648264912719408558363140637477264845294720710499478137287262712535938301461879813459410946";
    let (is_positive, limbs) = utils::str_to_limbs(num);
    assert!(is_positive);
    let expected_limbs = [
        format!("{}u64", 0x8508c00000000002u64),
        format!("{}u64", 0x452217cc90000000u64),
        format!("{}u64", 0xc5ed1347970dec00u64),
        format!("{}u64", 0x619aaf7d34594aabu64),
        format!("{}u64", 0x9b3af05dd14f6ecu64),
    ];
    assert_eq!(&limbs, &expected_limbs);
}