g2gen/
lib.rs

// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Procedural macro to generate finite field types.
//!
//! This is only the procedural macro; see [g2p](https://docs.rs/g2p) for more information.
10
11#![recursion_limit = "128"]
12extern crate proc_macro;
13
14use proc_macro::TokenStream as P1TokenStream;
15use proc_macro2::{Ident, Span, TokenStream as P2TokenStream};
16
17use g2poly::{extended_gcd, G2Poly};
18use quote::quote;
19use syn::{
20    parse::{Parse, ParseStream},
21    parse_macro_input, Token,
22};
23
24/// Generate a newtype of the given name and implement finite field arithmetic on it.
25///
26/// The generated type have implementations for [`Add`](::core::ops::Add),
27/// [`Sub`](::core::ops::Sub), [`Mul`](::core::ops::Mul) and [`Div`](::core::ops::Div).
28///
29/// There are also implementations for equality, copy and debug. Conversion from and to the base
30/// type are implemented via the From trait.
31/// Depending on the size of `p` the underlying type is u8, u16 or u32.
32///
33/// # Example
34/// ```ignore
35/// g2gen::g2p!(
36///     GF256,                  // Name of the newtype
37///     8,                      // The power of 2 specifying the field size 2^8 = 256 in this
38///                             // case.
39///     modulus: 0b1_0001_1101, // The reduction polynomial to use, each bit is a coefficient.
40///                             // Can be left out in case it is not needed.
41/// );
42///
43/// # fn main() {
44/// let a: GF256 = 255.into();  // Conversion from the base type
45/// assert_eq!(a - a, a + a);   // Finite field arithmetic.
46/// assert_eq!(format!("{}", a), "255_GF256");
47/// # }
48/// ```
49#[proc_macro]
50pub fn g2p(input: P1TokenStream) -> P1TokenStream {
51    let args = parse_macro_input!(input as ParsedInput);
52    let settings = Settings::from_input(args).unwrap();
53    let ident = settings.ident;
54    let ident_name = settings.ident_name;
55    let modulus = settings.modulus;
56    let generator = settings.generator;
57    let p = settings.p_val;
58    let field_size = 1_usize << p;
59    let mask = (1_u64 << p).wrapping_sub(1);
60
61    let ty = match p {
62        0 => panic!("p must be > 0"),
63        1..=8 => quote!(u8),
64        9..=16 => quote!(u16),
65        17..=32 => quote!(u32),
66        _ => unimplemented!("p > 32 is not implemented right now"),
67    };
68
69    let mod_name = Ident::new(&format!("{}_mod", ident_name), Span::call_site());
70
71    let struct_def = quote![
72        #[derive(Clone, Copy, Eq, PartialEq, Hash)]
73        pub struct #ident(pub #ty);
74    ];
75
76    let struct_impl = quote![
77        impl #ident {
78            pub const MASK: #ty = #mask as #ty;
79        }
80    ];
81
82    let from = quote![
83        impl ::core::convert::From<#ident> for #ty {
84            fn from(v: #ident) -> #ty {
85                v.0
86            }
87        }
88    ];
89
90    let into = quote![
91        impl ::core::convert::From<#ty> for #ident {
92            fn from(v: #ty) -> #ident {
93                #ident(v & #ident::MASK)
94            }
95        }
96    ];
97
98    let tmpl = format!("{{}}_{}", ident_name);
99    let debug = quote![
100        impl ::core::fmt::Debug for #ident {
101            fn fmt<'a>(&self, f: &mut ::core::fmt::Formatter<'a>) -> ::core::fmt::Result {
102                write!(f, #tmpl, self.0)
103            }
104        }
105    ];
106    let display = quote![
107        impl ::core::fmt::Display for #ident {
108            fn fmt<'a>(&self, f: &mut ::core::fmt::Formatter<'a>) -> ::core::fmt::Result {
109                write!(f, #tmpl, self.0)
110            }
111        }
112    ];
113    let add = quote![
114        impl ::core::ops::Add for #ident {
115            type Output = Self;
116
117            #[allow(clippy::suspicious_arithmetic_impl)]
118            fn add(self, rhs: Self) -> Self {
119                Self(self.0 ^ rhs.0)
120            }
121        }
122        impl ::core::ops::AddAssign for #ident {
123            fn add_assign(&mut self, rhs: Self) {
124                *self = *self + rhs;
125            }
126        }
127    ];
128    let sum = quote![
129        impl ::core::iter::Sum for #ident {
130            fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
131                iter.fold(<Self as ::g2p::GaloisField>::ZERO, ::core::ops::Add::add)
132            }
133        }
134    ];
135    let sub = quote![
136        impl ::core::ops::Sub for #ident {
137            type Output = Self;
138
139
140            #[allow(clippy::suspicious_arithmetic_impl)]
141            fn sub(self, rhs: Self) -> Self {
142                Self(self.0 ^ rhs.0)
143            }
144        }
145        impl ::core::ops::SubAssign for #ident {
146            fn sub_assign(&mut self, rhs: Self) {
147                *self = *self - rhs;
148            }
149        }
150        impl ::core::ops::Neg for #ident {
151            type Output = Self;
152
153            fn neg(self) -> Self::Output {
154                self
155            }
156        }
157    ];
158    let gen = generator.0;
159    let modulus_val = modulus.0;
160    let galois_trait_impl = quote![
161        impl ::g2p::GaloisField for #ident {
162            const SIZE: usize = #field_size;
163            const MODULUS: ::g2p::G2Poly = ::g2p::G2Poly(#modulus_val);
164            const ZERO: Self = Self(0);
165            const ONE: Self = Self(1);
166            const GENERATOR: Self = Self(#gen as #ty);
167        }
168    ];
169
170    let (tables, mul, div) =
171        generate_mul_impl(ident.clone(), &ident_name, modulus, ty, field_size, mask);
172    let product = quote![
173        impl ::core::iter::Product for #ident {
174            fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
175                iter.fold(<Self as ::g2p::GaloisField>::ONE, ::core::ops::Mul::mul)
176            }
177        }
178    ];
179
180    P1TokenStream::from(quote![
181        #struct_def
182
183        mod #mod_name {
184            use super::#ident;
185            #struct_impl
186            #tables
187            #from
188            #into
189            #debug
190            #display
191            #add
192            #sum
193            #sub
194            #mul
195            #product
196            #div
197            #galois_trait_impl
198        }
199    ])
200}
201
202struct ParsedInput {
203    ident: syn::Ident,
204    p: syn::LitInt,
205    modulus: Option<syn::LitInt>,
206}
207
208impl Parse for ParsedInput {
209    fn parse(input: ParseStream) -> syn::Result<Self> {
210        let ident = input.parse()?;
211        let _sep: Token![,] = input.parse()?;
212        let p = input.parse()?;
213
214        let mut modulus = None;
215
216        loop {
217            let sep: Option<Token![,]> = input.parse()?;
218            if sep.is_none() || input.is_empty() {
219                break;
220            }
221            let ident: syn::Ident = input.parse()?;
222            let ident_name = ident.to_string();
223            let _sep: Token![:] = input.parse()?;
224            match ident_name.as_str() {
225                "modulus" => {
226                    if modulus.is_some() {
227                        Err(syn::parse::Error::new(
228                            ident.span(),
229                            "Double declaration of 'modulus'",
230                        ))?
231                    }
232                    modulus = Some(input.parse()?);
233                }
234                _ => Err(syn::parse::Error::new(ident.span(), "Expected 'modulus'"))?,
235            }
236        }
237
238        Ok(ParsedInput { ident, p, modulus })
239    }
240}
241
242#[derive(Debug, Clone, Eq, PartialEq)]
243struct Settings {
244    ident: syn::Ident,
245    ident_name: String,
246    p_val: u64,
247    modulus: G2Poly,
248    generator: G2Poly,
249}
250
251fn find_modulus_poly(p: u64) -> G2Poly {
252    assert!(p < 64);
253
254    let start = (1 << p) + 1;
255    let end = (1_u64 << (p + 1)).wrapping_sub(1);
256
257    for m in start..=end {
258        let p = G2Poly(m);
259        if p.is_irreducible() {
260            return p;
261        }
262    }
263
264    unreachable!("There are irreducible polynomial for any degree!")
265}
266
267fn find_generator(m: G2Poly) -> G2Poly {
268    let max = m.degree().expect("Modulus must have positive degree");
269
270    for g in 1..(2 << max) {
271        let g = G2Poly(g);
272        if g.is_generator(m) {
273            return g;
274        }
275    }
276
277    unreachable!("There must be a generator element")
278}
279
/// Number of 8-bit chunks needed to index `n` distinct values.
///
/// For `n >= 2` this is the log base 256 of `n`, rounded up. That is the number of bytes needed
/// to represent every value in `0..n`. Any `n` in `1..=256` needs one chunk, and `n == 0` needs
/// none.
fn ceil_log256(mut n: usize) -> usize {
    if n == 0 {
        return 0;
    }

    let mut c = 1;
    while n > 256 {
        c += 1;
        // Divide by 256, rounding up. The rounding is written without `n + 255`, so values
        // close to `usize::MAX` do not overflow. Repeated ceiling divisions compose, so this
        // equals a single `ceil(n / 256^k)`.
        n = (n >> 8) + usize::from(n & 0xff != 0);
    }
    c
}
298
299/// Generate multiplication array
300///
301/// Generate a string representing a 5d multiplication array. This array uses the associativity
302/// of multiplication `(a + b) * (c + d) == a*c + a*d + b*c + b*d` to reduce table size.
303///
304/// The input is split into bit chunks e.g. for a GF_1024 number we take the lower 8 bit and the
305/// remaining 2 and calculate the multiplications for each separately. Then we can cheaply add them
306/// together to get the the result with requiring a full 1024 * 1024 input.
307fn generate_mul_table_string(modulus: G2Poly) -> String {
308    assert!(modulus.is_irreducible());
309
310    let field_size = 1
311        << modulus
312            .degree()
313            .expect("Irreducible polynomial has positive degree");
314    let nparts = ceil_log256(field_size as usize);
315
316    let mut mul_table = Vec::with_capacity(nparts);
317    for left in 0..nparts {
318        let mut left_parts = Vec::with_capacity(nparts);
319        for right in 0..nparts {
320            let mut right_parts = Vec::with_capacity(256);
321            for i in 0..256 {
322                let i = i << (8 * left);
323                let mut row = Vec::with_capacity(256);
324                for j in 0..256 {
325                    let j = j << (8 * right);
326                    let v = if i < field_size && j < field_size {
327                        G2Poly(i as u64) * G2Poly(j as u64) % modulus
328                    } else {
329                        G2Poly(0)
330                    };
331
332                    row.push(format!("{}", v.0));
333                }
334                right_parts.push(format!("[{}]", row.join(",")));
335            }
336            left_parts.push(format!("[{}]", right_parts.join(",")));
337        }
338        mul_table.push(format!("[{}]", left_parts.join(",")));
339    }
340
341    format!("[{}]", mul_table.join(","))
342}
343
344fn generate_inv_table_string(modulus: G2Poly) -> String {
345    assert!(modulus.is_irreducible());
346
347    let field_size = 1
348        << modulus
349            .degree()
350            .expect("Irreducible polynomial has positive degree");
351    let mut inv_table = vec![0; field_size as usize];
352    // Inverse table is small enough to compute directly
353    for i in 1..field_size {
354        if inv_table[i as usize] != 0 {
355            // Already computed inverse
356            continue;
357        }
358
359        let a = G2Poly(i);
360
361        // Returns (gcd, x, y) such that gcd(a, m) == a * x + y * m
362        // Since we know that gcd(a, m) == 1 and that we operate modulo m, y * m === 0 mod m
363        // So we have 1 === a * x mod m
364
365        let (_gcd, x, _y) = extended_gcd(a, modulus);
366        inv_table[i as usize] = x.0;
367        inv_table[x.0 as usize] = i;
368    }
369
370    use std::fmt::Write;
371    let mut res = String::with_capacity(3 * field_size as usize);
372    write!(&mut res, "[").unwrap();
373    for v in inv_table {
374        write!(&mut res, "{},", v).unwrap();
375    }
376    write!(&mut res, "]").unwrap();
377    res
378}
379
380fn generate_mul_impl(
381    ident: syn::Ident,
382    ident_name: &str,
383    modulus: G2Poly,
384    ty: P2TokenStream,
385    field_size: usize,
386    mask: u64,
387) -> (P2TokenStream, P2TokenStream, P2TokenStream) {
388    let mul_table = generate_mul_table_string(modulus);
389    let inv_table = generate_inv_table_string(modulus);
390
391    // Faster generation than using quote
392    let mul_table_string: proc_macro2::TokenStream = mul_table.parse().unwrap();
393    let inv_table_string: proc_macro2::TokenStream = inv_table.parse().unwrap();
394
395    let nparts = ceil_log256(field_size);
396
397    // NB: We generate static arrays, as they are guaranteed to have a fixed location in memory.
398    //     Using const would mean the compiler is free to create copies on the stack etc. Since
399    //     The arrays are quite large, this could lead to stack overflows.
400    let tables = quote! {
401        pub static MUL_TABLE: [[[[#ty; 256]; 256]; #nparts]; #nparts] = #mul_table_string;
402        pub static INV_TABLE: [#ty; #field_size] = #inv_table_string;
403    };
404
405    let mut mul_ops = Vec::with_capacity(nparts * nparts);
406    for left in 0..nparts {
407        for right in 0..nparts {
408            mul_ops.push(quote![
409                #ident(MUL_TABLE[#left][#right][(((self.0 & #mask as #ty) >> (8*#left)) & 255) as usize][(((rhs.0 & #mask as #ty) >> (8*#right)) & 255) as usize])
410            ]);
411        }
412    }
413
414    let mul = quote![
415        impl ::core::ops::Mul for #ident {
416            type Output = Self;
417            fn mul(self, rhs: Self) -> Self {
418                #(#mul_ops)+*
419            }
420        }
421        impl ::core::ops::MulAssign for #ident {
422            fn mul_assign(&mut self, rhs: Self) {
423                *self = *self * rhs;
424            }
425        }
426    ];
427
428    let err_msg = format!("Division by 0 in {}", ident_name);
429
430    let div = quote![
431        impl ::core::ops::Div for #ident {
432            type Output = Self;
433
434            fn div(self, rhs: Self) -> Self {
435                if (rhs.0 & #mask as #ty) == 0 {
436                    panic!(#err_msg);
437                }
438                self * Self(INV_TABLE[(rhs.0 & #mask as #ty) as usize])
439            }
440        }
441        impl ::core::ops::DivAssign for #ident {
442            fn div_assign(&mut self, rhs: Self) {
443                *self = *self / rhs;
444            }
445        }
446    ];
447
448    (tables, mul, div)
449}
450
451impl Settings {
452    pub fn from_input(input: ParsedInput) -> syn::Result<Self> {
453        let ident = input.ident;
454        let ident_name = ident.to_string();
455        let p_val = input.p.base10_parse()?;
456        let modulus = match input.modulus {
457            Some(lit) => G2Poly(lit.base10_parse()?),
458            None => find_modulus_poly(p_val),
459        };
460
461        if !modulus.is_irreducible() {
462            Err(syn::Error::new(
463                Span::call_site(),
464                format!("Modulus {} is not irreducible", modulus),
465            ))?;
466        }
467
468        let generator = find_generator(modulus);
469
470        if !generator.is_generator(modulus) {
471            Err(syn::Error::new(
472                Span::call_site(),
473                format!("{} is not a generator", generator),
474            ))?;
475        }
476
477        Ok(Settings {
478            ident,
479            ident_name,
480            p_val,
481            modulus,
482            generator,
483        })
484    }
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// Without an explicit modulus, the smallest irreducible polynomial and generator are chosen.
    #[test]
    fn test_settings_parser() {
        let span = Span::call_site();

        let input = ParsedInput {
            ident: Ident::new("foo", span),
            p: syn::LitInt::new("3", span),
            modulus: None,
        };
        let expected = Settings {
            ident: syn::Ident::new("foo", span),
            ident_name: "foo".to_string(),
            p_val: 3,
            modulus: G2Poly(0b1011),
            generator: G2Poly(0b10),
        };

        let parsed = Settings::from_input(input).expect("valid input must be accepted");
        assert_eq!(parsed, expected);
    }

    /// The multiplication table for GF(4) matches the stored fixture.
    #[test]
    fn test_generate_mul_table() {
        let expected = include_str!("../tests/mul_table.txt").trim();
        assert_eq!(expected, generate_mul_table_string(G2Poly(0b111)));
    }

    /// The inverse table for GF(256) (AES modulus) matches the stored fixture.
    #[test]
    fn test_generate_inv_table_string() {
        let expected = include_str!("../tests/inv_table.txt").trim();
        assert_eq!(expected, generate_inv_table_string(G2Poly(0b1_0001_1011)));
    }

    /// Boundary values around each power of 256.
    #[test]
    fn test_ceil_log256() {
        let cases: [(usize, usize); 9] = [
            (0, 0),
            (1, 1),
            (256, 1),
            (257, 2),
            (65536, 2),
            (65537, 3),
            (131072, 3),
            (16777216, 3),
            (16777217, 4),
        ];
        for &(n, expected) in &cases {
            assert_eq!(expected, ceil_log256(n), "ceil_log256({})", n);
        }
    }
}