g2gen/
lib.rs

1// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4// option. This file may not be copied, modified, or distributed
5// except according to those terms.
6
7//! Procedural macro to generate finite field types
8//!
9//! This is just the procedural macro, for more information look at [g2p](https://docs.rs/g2p).
10
11#![recursion_limit = "128"]
12extern crate proc_macro;
13
14use proc_macro::TokenStream as P1TokenStream;
15use proc_macro2::{Ident, Span, TokenStream as P2TokenStream};
16
17use g2poly::{extended_gcd, G2Poly};
18use quote::quote;
19use syn::{
20    parse::{Parse, ParseStream},
21    parse_macro_input, Token,
22};
23
24/// Generate a newtype of the given name and implement finite field arithmetic on it.
25///
26/// The generated type have implementations for [`Add`](::core::ops::Add),
27/// [`Sub`](::core::ops::Sub), [`Mul`](::core::ops::Mul) and [`Div`](::core::ops::Div).
28///
29/// There are also implementations for equality, copy and debug. Conversion from and to the base
30/// type are implemented via the From trait.
31/// Depending on the size of `p` the underlying type is u8, u16 or u32.
32///
33/// # Example
34/// ```ignore
35/// g2gen::g2p!(
36///     GF256,                  // Name of the newtype
37///     8,                      // The power of 2 specifying the field size 2^8 = 256 in this
38///                             // case.
39///     modulus: 0b1_0001_1101, // The reduction polynomial to use, each bit is a coefficient.
40///                             // Can be left out in case it is not needed.
41/// );
42///
43/// # fn main() {
44/// let a: GF256 = 255.into();  // Conversion from the base type
45/// assert_eq!(a - a, a + a);   // Finite field arithmetic.
46/// assert_eq!(format!("{}", a), "255_GF256");
47/// # }
48/// ```
49#[proc_macro]
50pub fn g2p(input: P1TokenStream) -> P1TokenStream {
51    let args = parse_macro_input!(input as ParsedInput);
52    let settings = Settings::from_input(args).unwrap();
53    let ident = settings.ident;
54    let ident_name = settings.ident_name;
55    let modulus = settings.modulus;
56    let generator = settings.generator;
57    let p = settings.p_val;
58    let field_size = 1_usize << p;
59    let mask = (1_u64 << p).wrapping_sub(1);
60
61    let ty = match p {
62        0 => panic!("p must be > 0"),
63        1..=8 => quote!(u8),
64        9..=16 => quote!(u16),
65        17..=32 => quote!(u32),
66        _ => unimplemented!("p > 32 is not implemented right now"),
67    };
68
69    let mod_name = Ident::new(&format!("{}_mod", ident_name), Span::call_site());
70
71    let struct_def = quote![
72        #[derive(Clone, Copy, Hash)]
73        pub struct #ident(pub #ty);
74    ];
75
76    let struct_impl = quote![
77        impl #ident {
78            pub const MASK: #ty = #mask as #ty;
79        }
80    ];
81
82    let from = quote![
83        impl ::core::convert::From<#ident> for #ty {
84            fn from(v: #ident) -> #ty {
85                v.0
86            }
87        }
88    ];
89
90    let into = quote![
91        impl ::core::convert::From<#ty> for #ident {
92            fn from(v: #ty) -> #ident {
93                #ident(v & #ident::MASK)
94            }
95        }
96    ];
97
98    let eq = quote![
99        impl ::core::cmp::PartialEq<Self> for #ident {
100            fn eq(&self, other: &Self) -> bool {
101                self.0 == other.0
102            }
103        }
104
105        impl ::core::cmp::Eq for #ident {}
106    ];
107
108    let tmpl = format!("{{}}_{}", ident_name);
109    let debug = quote![
110        impl ::core::fmt::Debug for #ident {
111            fn fmt<'a>(&self, f: &mut ::core::fmt::Formatter<'a>) -> ::core::fmt::Result {
112                write!(f, #tmpl, self.0)
113            }
114        }
115    ];
116    let display = quote![
117        impl ::core::fmt::Display for #ident {
118            fn fmt<'a>(&self, f: &mut ::core::fmt::Formatter<'a>) -> ::core::fmt::Result {
119                write!(f, #tmpl, self.0)
120            }
121        }
122    ];
123    let add = quote![
124        impl ::core::ops::Add for #ident {
125            type Output = Self;
126
127            #[allow(clippy::suspicious_arithmetic_impl)]
128            fn add(self, rhs: Self) -> Self {
129                Self(self.0 ^ rhs.0)
130            }
131        }
132        impl ::core::ops::AddAssign for #ident {
133            fn add_assign(&mut self, rhs: Self) {
134                *self = *self + rhs;
135            }
136        }
137    ];
138    let sum = quote![
139        impl ::core::iter::Sum for #ident {
140            fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
141                iter.fold(<Self as ::g2p::GaloisField>::ZERO, ::core::ops::Add::add)
142            }
143        }
144    ];
145    let sub = quote![
146        impl ::core::ops::Sub for #ident {
147            type Output = Self;
148
149
150            #[allow(clippy::suspicious_arithmetic_impl)]
151            fn sub(self, rhs: Self) -> Self {
152                Self(self.0 ^ rhs.0)
153            }
154        }
155        impl ::core::ops::SubAssign for #ident {
156            fn sub_assign(&mut self, rhs: Self) {
157                *self = *self - rhs;
158            }
159        }
160        impl ::core::ops::Neg for #ident {
161            type Output = Self;
162
163            fn neg(self) -> Self::Output {
164                self
165            }
166        }
167    ];
168    let gen = generator.0;
169    let modulus_val = modulus.0;
170    let galois_trait_impl = quote![
171        impl ::g2p::GaloisField for #ident {
172            const SIZE: usize = #field_size;
173            const MODULUS: ::g2p::G2Poly = ::g2p::G2Poly(#modulus_val);
174            const ZERO: Self = Self(0);
175            const ONE: Self = Self(1);
176            const GENERATOR: Self = Self(#gen as #ty);
177        }
178    ];
179
180    let (tables, mul, div) =
181        generate_mul_impl(ident.clone(), &ident_name, modulus, ty, field_size, mask);
182    let product = quote![
183        impl ::core::iter::Product for #ident {
184            fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
185                iter.fold(<Self as ::g2p::GaloisField>::ONE, ::core::ops::Mul::mul)
186            }
187        }
188    ];
189
190    P1TokenStream::from(quote![
191        #struct_def
192
193        mod #mod_name {
194            use super::#ident;
195            #struct_impl
196            #tables
197            #from
198            #into
199            #eq
200            #debug
201            #display
202            #add
203            #sum
204            #sub
205            #mul
206            #product
207            #div
208            #galois_trait_impl
209        }
210    ])
211}
212
213struct ParsedInput {
214    ident: syn::Ident,
215    p: syn::LitInt,
216    modulus: Option<syn::LitInt>,
217}
218
219impl Parse for ParsedInput {
220    fn parse(input: ParseStream) -> syn::Result<Self> {
221        let ident = input.parse()?;
222        let _sep: Token![,] = input.parse()?;
223        let p = input.parse()?;
224
225        let mut modulus = None;
226
227        loop {
228            let sep: Option<Token![,]> = input.parse()?;
229            if sep.is_none() || input.is_empty() {
230                break;
231            }
232            let ident: syn::Ident = input.parse()?;
233            let ident_name = ident.to_string();
234            let _sep: Token![:] = input.parse()?;
235            match ident_name.as_str() {
236                "modulus" => {
237                    if modulus.is_some() {
238                        Err(syn::parse::Error::new(
239                            ident.span(),
240                            "Double declaration of 'modulus'",
241                        ))?
242                    }
243                    modulus = Some(input.parse()?);
244                }
245                _ => Err(syn::parse::Error::new(ident.span(), "Expected 'modulus'"))?,
246            }
247        }
248
249        Ok(ParsedInput { ident, p, modulus })
250    }
251}
252
253#[derive(Debug, Clone, Eq, PartialEq)]
254struct Settings {
255    ident: syn::Ident,
256    ident_name: String,
257    p_val: u64,
258    modulus: G2Poly,
259    generator: G2Poly,
260}
261
262fn find_modulus_poly(p: u64) -> G2Poly {
263    assert!(p < 64);
264
265    let start = (1 << p) + 1;
266    let end = (1_u64 << (p + 1)).wrapping_sub(1);
267
268    for m in start..=end {
269        let p = G2Poly(m);
270        if p.is_irreducible() {
271            return p;
272        }
273    }
274
275    unreachable!("There are irreducible polynomial for any degree!")
276}
277
278fn find_generator(m: G2Poly) -> G2Poly {
279    let max = m.degree().expect("Modulus must have positive degree");
280
281    for g in 1..(2 << max) {
282        let g = G2Poly(g);
283        if g.is_generator(m) {
284            return g;
285        }
286    }
287
288    unreachable!("There must be a generator element")
289}
290
/// Calculate the number of bytes needed to represent every value in `0..n`.
///
/// For `n == 0` this returns 0. Otherwise it returns `max(1, ceil(log256(n)))`, which can be
/// thought of as the log base 256 of `n`, rounded up. The multiplication table generator uses it
/// to decide into how many byte-sized parts a field element is split.
fn ceil_log256(mut n: usize) -> usize {
    if n == 0 {
        return 0;
    }

    let mut c = 1;
    while n > 256 {
        c += 1;
        // Ceiling division by 256. If n is a proper power of 256 this divides exactly;
        // otherwise it rounds up. Written as `(n - 1) / 256 + 1` (n >= 1 here) instead of
        // `(n + 255) / 256` so that it cannot overflow for n close to `usize::MAX`.
        n = (n - 1) / 256 + 1;
    }
    c
}
309
310/// Generate multiplication array
311///
312/// Generate a string representing a 5d multiplication array. This array uses the associativity
313/// of multiplication `(a + b) * (c + d) == a*c + a*d + b*c + b*d` to reduce table size.
314///
315/// The input is split into bit chunks e.g. for a GF_1024 number we take the lower 8 bit and the
316/// remaining 2 and calculate the multiplications for each separately. Then we can cheaply add them
317/// together to get the the result with requiring a full 1024 * 1024 input.
318fn generate_mul_table_string(modulus: G2Poly) -> String {
319    assert!(modulus.is_irreducible());
320
321    let field_size = 1
322        << modulus
323            .degree()
324            .expect("Irreducible polynomial has positive degree");
325    let nparts = ceil_log256(field_size as usize);
326
327    let mut mul_table = Vec::with_capacity(nparts);
328    for left in 0..nparts {
329        let mut left_parts = Vec::with_capacity(nparts);
330        for right in 0..nparts {
331            let mut right_parts = Vec::with_capacity(256);
332            for i in 0..256 {
333                let i = i << (8 * left);
334                let mut row = Vec::with_capacity(256);
335                for j in 0..256 {
336                    let j = j << (8 * right);
337                    let v = if i < field_size && j < field_size {
338                        G2Poly(i as u64) * G2Poly(j as u64) % modulus
339                    } else {
340                        G2Poly(0)
341                    };
342
343                    row.push(format!("{}", v.0));
344                }
345                right_parts.push(format!("[{}]", row.join(",")));
346            }
347            left_parts.push(format!("[{}]", right_parts.join(",")));
348        }
349        mul_table.push(format!("[{}]", left_parts.join(",")));
350    }
351
352    format!("[{}]", mul_table.join(","))
353}
354
355fn generate_inv_table_string(modulus: G2Poly) -> String {
356    assert!(modulus.is_irreducible());
357
358    let field_size = 1
359        << modulus
360            .degree()
361            .expect("Irreducible polynomial has positive degree");
362    let mut inv_table = vec![0; field_size as usize];
363    // Inverse table is small enough to compute directly
364    for i in 1..field_size {
365        if inv_table[i as usize] != 0 {
366            // Already computed inverse
367            continue;
368        }
369
370        let a = G2Poly(i);
371
372        // Returns (gcd, x, y) such that gcd(a, m) == a * x + y * m
373        // Since we know that gcd(a, m) == 1 and that we operate modulo m, y * m === 0 mod m
374        // So we have 1 === a * x mod m
375
376        let (_gcd, x, _y) = extended_gcd(a, modulus);
377        inv_table[i as usize] = x.0;
378        inv_table[x.0 as usize] = i;
379    }
380
381    use std::fmt::Write;
382    let mut res = String::with_capacity(3 * field_size as usize);
383    write!(&mut res, "[").unwrap();
384    for v in inv_table {
385        write!(&mut res, "{},", v).unwrap();
386    }
387    write!(&mut res, "]").unwrap();
388    res
389}
390
391fn generate_mul_impl(
392    ident: syn::Ident,
393    ident_name: &str,
394    modulus: G2Poly,
395    ty: P2TokenStream,
396    field_size: usize,
397    mask: u64,
398) -> (P2TokenStream, P2TokenStream, P2TokenStream) {
399    let mul_table = generate_mul_table_string(modulus);
400    let inv_table = generate_inv_table_string(modulus);
401
402    // Faster generation than using quote
403    let mul_table_string: proc_macro2::TokenStream = mul_table.parse().unwrap();
404    let inv_table_string: proc_macro2::TokenStream = inv_table.parse().unwrap();
405
406    let nparts = ceil_log256(field_size);
407
408    // NB: We generate static arrays, as they are guaranteed to have a fixed location in memory.
409    //     Using const would mean the compiler is free to create copies on the stack etc. Since
410    //     The arrays are quite large, this could lead to stack overflows.
411    let tables = quote! {
412        pub static MUL_TABLE: [[[[#ty; 256]; 256]; #nparts]; #nparts] = #mul_table_string;
413        pub static INV_TABLE: [#ty; #field_size] = #inv_table_string;
414    };
415
416    let mut mul_ops = Vec::with_capacity(nparts * nparts);
417    for left in 0..nparts {
418        for right in 0..nparts {
419            mul_ops.push(quote![
420                #ident(MUL_TABLE[#left][#right][(((self.0 & #mask as #ty) >> (8*#left)) & 255) as usize][(((rhs.0 & #mask as #ty) >> (8*#right)) & 255) as usize])
421            ]);
422        }
423    }
424
425    let mul = quote![
426        impl ::core::ops::Mul for #ident {
427            type Output = Self;
428            fn mul(self, rhs: Self) -> Self {
429                #(#mul_ops)+*
430            }
431        }
432        impl ::core::ops::MulAssign for #ident {
433            fn mul_assign(&mut self, rhs: Self) {
434                *self = *self * rhs;
435            }
436        }
437    ];
438
439    let err_msg = format!("Division by 0 in {}", ident_name);
440
441    let div = quote![
442        impl ::core::ops::Div for #ident {
443            type Output = Self;
444
445            fn div(self, rhs: Self) -> Self {
446                if (rhs.0 & #mask as #ty) == 0 {
447                    panic!(#err_msg);
448                }
449                self * Self(INV_TABLE[(rhs.0 & #mask as #ty) as usize])
450            }
451        }
452        impl ::core::ops::DivAssign for #ident {
453            fn div_assign(&mut self, rhs: Self) {
454                *self = *self / rhs;
455            }
456        }
457    ];
458
459    (tables, mul, div)
460}
461
462impl Settings {
463    pub fn from_input(input: ParsedInput) -> syn::Result<Self> {
464        let ident = input.ident;
465        let ident_name = ident.to_string();
466        let p_val = input.p.base10_parse()?;
467        let modulus = match input.modulus {
468            Some(lit) => G2Poly(lit.base10_parse()?),
469            None => find_modulus_poly(p_val),
470        };
471
472        if !modulus.is_irreducible() {
473            Err(syn::Error::new(
474                Span::call_site(),
475                format!("Modulus {} is not irreducible", modulus),
476            ))?;
477        }
478
479        let generator = find_generator(modulus);
480
481        if !generator.is_generator(modulus) {
482            Err(syn::Error::new(
483                Span::call_site(),
484                format!("{} is not a generator", generator),
485            ))?;
486        }
487
488        Ok(Settings {
489            ident,
490            ident_name,
491            p_val,
492            modulus,
493            generator,
494        })
495    }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Without an explicit modulus, p = 3 picks x^3 + x + 1 and the generator x.
    #[test]
    fn test_settings_parser() {
        let span = Span::call_site();

        let parsed = ParsedInput {
            ident: Ident::new("foo", span),
            p: syn::LitInt::new("3", span),
            modulus: None,
        };
        let expected = Settings {
            ident: syn::Ident::new("foo", span),
            ident_name: "foo".to_string(),
            p_val: 3,
            modulus: G2Poly(0b1011),
            generator: G2Poly(0b10),
        };

        let settings = Settings::from_input(parsed).expect("p = 3 without modulus is valid");
        assert_eq!(settings, expected);
    }

    /// The rendered multiplication table for GF(4) matches the stored fixture.
    #[test]
    fn test_generate_mul_table() {
        let expected = include_str!("../tests/mul_table.txt");
        let modulus = G2Poly(0b111);

        assert_eq!(expected.trim(), generate_mul_table_string(modulus));
    }

    /// The rendered inverse table for GF(256) matches the stored fixture.
    #[test]
    fn test_generate_inv_table_string() {
        let expected = include_str!("../tests/inv_table.txt");
        let modulus = G2Poly(0b1_0001_1011);

        assert_eq!(expected.trim(), generate_inv_table_string(modulus));
    }

    /// Boundary values around powers of 256.
    #[test]
    fn test_ceil_log256() {
        let cases: [(usize, usize); 9] = [
            (0, 0),
            (1, 1),
            (256, 1),
            (257, 2),
            (65536, 2),
            (65537, 3),
            (131072, 3),
            (16777216, 3),
            (16777217, 4),
        ];

        for &(n, bytes) in cases.iter() {
            assert_eq!(bytes, ceil_log256(n), "ceil_log256({})", n);
        }
    }
}