1#![recursion_limit = "128"]
12extern crate proc_macro;
13
14use proc_macro::TokenStream as P1TokenStream;
15use proc_macro2::{Ident, Span, TokenStream as P2TokenStream};
16
17use g2poly::{extended_gcd, G2Poly};
18use quote::quote;
19use syn::{
20 parse::{Parse, ParseStream},
21 parse_macro_input, Token,
22};
23
24#[proc_macro]
50pub fn g2p(input: P1TokenStream) -> P1TokenStream {
51 let args = parse_macro_input!(input as ParsedInput);
52 let settings = Settings::from_input(args).unwrap();
53 let ident = settings.ident;
54 let ident_name = settings.ident_name;
55 let modulus = settings.modulus;
56 let generator = settings.generator;
57 let p = settings.p_val;
58 let field_size = 1_usize << p;
59 let mask = (1_u64 << p).wrapping_sub(1);
60
61 let ty = match p {
62 0 => panic!("p must be > 0"),
63 1..=8 => quote!(u8),
64 9..=16 => quote!(u16),
65 17..=32 => quote!(u32),
66 _ => unimplemented!("p > 32 is not implemented right now"),
67 };
68
69 let mod_name = Ident::new(&format!("{}_mod", ident_name), Span::call_site());
70
71 let struct_def = quote![
72 #[derive(Clone, Copy, Hash)]
73 pub struct #ident(pub #ty);
74 ];
75
76 let struct_impl = quote![
77 impl #ident {
78 pub const MASK: #ty = #mask as #ty;
79 }
80 ];
81
82 let from = quote![
83 impl ::core::convert::From<#ident> for #ty {
84 fn from(v: #ident) -> #ty {
85 v.0
86 }
87 }
88 ];
89
90 let into = quote![
91 impl ::core::convert::From<#ty> for #ident {
92 fn from(v: #ty) -> #ident {
93 #ident(v & #ident::MASK)
94 }
95 }
96 ];
97
98 let eq = quote![
99 impl ::core::cmp::PartialEq<Self> for #ident {
100 fn eq(&self, other: &Self) -> bool {
101 self.0 == other.0
102 }
103 }
104
105 impl ::core::cmp::Eq for #ident {}
106 ];
107
108 let tmpl = format!("{{}}_{}", ident_name);
109 let debug = quote![
110 impl ::core::fmt::Debug for #ident {
111 fn fmt<'a>(&self, f: &mut ::core::fmt::Formatter<'a>) -> ::core::fmt::Result {
112 write!(f, #tmpl, self.0)
113 }
114 }
115 ];
116 let display = quote![
117 impl ::core::fmt::Display for #ident {
118 fn fmt<'a>(&self, f: &mut ::core::fmt::Formatter<'a>) -> ::core::fmt::Result {
119 write!(f, #tmpl, self.0)
120 }
121 }
122 ];
123 let add = quote![
124 impl ::core::ops::Add for #ident {
125 type Output = Self;
126
127 #[allow(clippy::suspicious_arithmetic_impl)]
128 fn add(self, rhs: Self) -> Self {
129 Self(self.0 ^ rhs.0)
130 }
131 }
132 impl ::core::ops::AddAssign for #ident {
133 fn add_assign(&mut self, rhs: Self) {
134 *self = *self + rhs;
135 }
136 }
137 ];
138 let sum = quote![
139 impl ::core::iter::Sum for #ident {
140 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
141 iter.fold(<Self as ::g2p::GaloisField>::ZERO, ::core::ops::Add::add)
142 }
143 }
144 ];
145 let sub = quote![
146 impl ::core::ops::Sub for #ident {
147 type Output = Self;
148
149
150 #[allow(clippy::suspicious_arithmetic_impl)]
151 fn sub(self, rhs: Self) -> Self {
152 Self(self.0 ^ rhs.0)
153 }
154 }
155 impl ::core::ops::SubAssign for #ident {
156 fn sub_assign(&mut self, rhs: Self) {
157 *self = *self - rhs;
158 }
159 }
160 impl ::core::ops::Neg for #ident {
161 type Output = Self;
162
163 fn neg(self) -> Self::Output {
164 self
165 }
166 }
167 ];
168 let gen = generator.0;
169 let modulus_val = modulus.0;
170 let galois_trait_impl = quote![
171 impl ::g2p::GaloisField for #ident {
172 const SIZE: usize = #field_size;
173 const MODULUS: ::g2p::G2Poly = ::g2p::G2Poly(#modulus_val);
174 const ZERO: Self = Self(0);
175 const ONE: Self = Self(1);
176 const GENERATOR: Self = Self(#gen as #ty);
177 }
178 ];
179
180 let (tables, mul, div) =
181 generate_mul_impl(ident.clone(), &ident_name, modulus, ty, field_size, mask);
182 let product = quote![
183 impl ::core::iter::Product for #ident {
184 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
185 iter.fold(<Self as ::g2p::GaloisField>::ONE, ::core::ops::Mul::mul)
186 }
187 }
188 ];
189
190 P1TokenStream::from(quote![
191 #struct_def
192
193 mod #mod_name {
194 use super::#ident;
195 #struct_impl
196 #tables
197 #from
198 #into
199 #eq
200 #debug
201 #display
202 #add
203 #sum
204 #sub
205 #mul
206 #product
207 #div
208 #galois_trait_impl
209 }
210 ])
211}
212
/// Raw macro arguments exactly as written at the call site, before any
/// validation or defaulting has been applied.
struct ParsedInput {
    /// First argument: the name of the type to generate.
    ident: syn::Ident,
    /// Second argument: integer literal, later parsed into `Settings::p_val`.
    p: syn::LitInt,
    /// Value of the optional `modulus: <literal>` argument; `None` if omitted.
    modulus: Option<syn::LitInt>,
}
218
219impl Parse for ParsedInput {
220 fn parse(input: ParseStream) -> syn::Result<Self> {
221 let ident = input.parse()?;
222 let _sep: Token![,] = input.parse()?;
223 let p = input.parse()?;
224
225 let mut modulus = None;
226
227 loop {
228 let sep: Option<Token![,]> = input.parse()?;
229 if sep.is_none() || input.is_empty() {
230 break;
231 }
232 let ident: syn::Ident = input.parse()?;
233 let ident_name = ident.to_string();
234 let _sep: Token![:] = input.parse()?;
235 match ident_name.as_str() {
236 "modulus" => {
237 if modulus.is_some() {
238 Err(syn::parse::Error::new(
239 ident.span(),
240 "Double declaration of 'modulus'",
241 ))?
242 }
243 modulus = Some(input.parse()?);
244 }
245 _ => Err(syn::parse::Error::new(ident.span(), "Expected 'modulus'"))?,
246 }
247 }
248
249 Ok(ParsedInput { ident, p, modulus })
250 }
251}
252
/// Validated configuration for one `g2p!` expansion, built by
/// `Settings::from_input`.
#[derive(Debug, Clone, Eq, PartialEq)]
struct Settings {
    /// Identifier of the type to generate.
    ident: syn::Ident,
    /// String form of `ident`, used for the module name and in messages.
    ident_name: String,
    /// The exponent `p` taken from the macro's second argument.
    p_val: u64,
    /// Reduction polynomial: user supplied, or found by `find_modulus_poly`.
    modulus: G2Poly,
    /// Generator element found by `find_generator` for `modulus`.
    generator: G2Poly,
}
261
262fn find_modulus_poly(p: u64) -> G2Poly {
263 assert!(p < 64);
264
265 let start = (1 << p) + 1;
266 let end = (1_u64 << (p + 1)).wrapping_sub(1);
267
268 for m in start..=end {
269 let p = G2Poly(m);
270 if p.is_irreducible() {
271 return p;
272 }
273 }
274
275 unreachable!("There are irreducible polynomial for any degree!")
276}
277
278fn find_generator(m: G2Poly) -> G2Poly {
279 let max = m.degree().expect("Modulus must have positive degree");
280
281 for g in 1..(2 << max) {
282 let g = G2Poly(g);
283 if g.is_generator(m) {
284 return g;
285 }
286 }
287
288 unreachable!("There must be a generator element")
289}
290
/// Rounded-up base-256 logarithm of `n`, with a floor of one for every
/// non-zero `n`.
///
/// Returns `0` for `n == 0`, `1` for `1..=256`, `2` for `257..=65536`, and
/// so on. For a field of `n` elements this is the number of bytes needed to
/// index all of them.
fn ceil_log256(mut n: usize) -> usize {
    // Zero needs no digits; anything else needs at least one.
    let mut digits = usize::from(n != 0);

    // Each round strips one base-256 digit, rounding the quotient up.
    while n > 256 {
        n = (n + 255) / 256;
        digits += 1;
    }

    digits
}
309
310fn generate_mul_table_string(modulus: G2Poly) -> String {
319 assert!(modulus.is_irreducible());
320
321 let field_size = 1
322 << modulus
323 .degree()
324 .expect("Irreducible polynomial has positive degree");
325 let nparts = ceil_log256(field_size as usize);
326
327 let mut mul_table = Vec::with_capacity(nparts);
328 for left in 0..nparts {
329 let mut left_parts = Vec::with_capacity(nparts);
330 for right in 0..nparts {
331 let mut right_parts = Vec::with_capacity(256);
332 for i in 0..256 {
333 let i = i << (8 * left);
334 let mut row = Vec::with_capacity(256);
335 for j in 0..256 {
336 let j = j << (8 * right);
337 let v = if i < field_size && j < field_size {
338 G2Poly(i as u64) * G2Poly(j as u64) % modulus
339 } else {
340 G2Poly(0)
341 };
342
343 row.push(format!("{}", v.0));
344 }
345 right_parts.push(format!("[{}]", row.join(",")));
346 }
347 left_parts.push(format!("[{}]", right_parts.join(",")));
348 }
349 mul_table.push(format!("[{}]", left_parts.join(",")));
350 }
351
352 format!("[{}]", mul_table.join(","))
353}
354
355fn generate_inv_table_string(modulus: G2Poly) -> String {
356 assert!(modulus.is_irreducible());
357
358 let field_size = 1
359 << modulus
360 .degree()
361 .expect("Irreducible polynomial has positive degree");
362 let mut inv_table = vec![0; field_size as usize];
363 for i in 1..field_size {
365 if inv_table[i as usize] != 0 {
366 continue;
368 }
369
370 let a = G2Poly(i);
371
372 let (_gcd, x, _y) = extended_gcd(a, modulus);
377 inv_table[i as usize] = x.0;
378 inv_table[x.0 as usize] = i;
379 }
380
381 use std::fmt::Write;
382 let mut res = String::with_capacity(3 * field_size as usize);
383 write!(&mut res, "[").unwrap();
384 for v in inv_table {
385 write!(&mut res, "{},", v).unwrap();
386 }
387 write!(&mut res, "]").unwrap();
388 res
389}
390
391fn generate_mul_impl(
392 ident: syn::Ident,
393 ident_name: &str,
394 modulus: G2Poly,
395 ty: P2TokenStream,
396 field_size: usize,
397 mask: u64,
398) -> (P2TokenStream, P2TokenStream, P2TokenStream) {
399 let mul_table = generate_mul_table_string(modulus);
400 let inv_table = generate_inv_table_string(modulus);
401
402 let mul_table_string: proc_macro2::TokenStream = mul_table.parse().unwrap();
404 let inv_table_string: proc_macro2::TokenStream = inv_table.parse().unwrap();
405
406 let nparts = ceil_log256(field_size);
407
408 let tables = quote! {
412 pub static MUL_TABLE: [[[[#ty; 256]; 256]; #nparts]; #nparts] = #mul_table_string;
413 pub static INV_TABLE: [#ty; #field_size] = #inv_table_string;
414 };
415
416 let mut mul_ops = Vec::with_capacity(nparts * nparts);
417 for left in 0..nparts {
418 for right in 0..nparts {
419 mul_ops.push(quote![
420 #ident(MUL_TABLE[#left][#right][(((self.0 & #mask as #ty) >> (8*#left)) & 255) as usize][(((rhs.0 & #mask as #ty) >> (8*#right)) & 255) as usize])
421 ]);
422 }
423 }
424
425 let mul = quote![
426 impl ::core::ops::Mul for #ident {
427 type Output = Self;
428 fn mul(self, rhs: Self) -> Self {
429 #(#mul_ops)+*
430 }
431 }
432 impl ::core::ops::MulAssign for #ident {
433 fn mul_assign(&mut self, rhs: Self) {
434 *self = *self * rhs;
435 }
436 }
437 ];
438
439 let err_msg = format!("Division by 0 in {}", ident_name);
440
441 let div = quote![
442 impl ::core::ops::Div for #ident {
443 type Output = Self;
444
445 fn div(self, rhs: Self) -> Self {
446 if (rhs.0 & #mask as #ty) == 0 {
447 panic!(#err_msg);
448 }
449 self * Self(INV_TABLE[(rhs.0 & #mask as #ty) as usize])
450 }
451 }
452 impl ::core::ops::DivAssign for #ident {
453 fn div_assign(&mut self, rhs: Self) {
454 *self = *self / rhs;
455 }
456 }
457 ];
458
459 (tables, mul, div)
460}
461
462impl Settings {
463 pub fn from_input(input: ParsedInput) -> syn::Result<Self> {
464 let ident = input.ident;
465 let ident_name = ident.to_string();
466 let p_val = input.p.base10_parse()?;
467 let modulus = match input.modulus {
468 Some(lit) => G2Poly(lit.base10_parse()?),
469 None => find_modulus_poly(p_val),
470 };
471
472 if !modulus.is_irreducible() {
473 Err(syn::Error::new(
474 Span::call_site(),
475 format!("Modulus {} is not irreducible", modulus),
476 ))?;
477 }
478
479 let generator = find_generator(modulus);
480
481 if !generator.is_generator(modulus) {
482 Err(syn::Error::new(
483 Span::call_site(),
484 format!("{} is not a generator", generator),
485 ))?;
486 }
487
488 Ok(Settings {
489 ident,
490 ident_name,
491 p_val,
492 modulus,
493 generator,
494 })
495 }
496}
497
// Unit tests for settings resolution, table rendering and `ceil_log256`.
#[cfg(test)]
mod tests {
    use super::*;

    // With no explicit modulus, p = 3 must resolve to the modulus
    // x^3 + x + 1 (0b1011) and the generator x (0b10).
    #[test]
    fn test_settings_parser() {
        let span = Span::call_site();

        let input = ParsedInput {
            ident: Ident::new("foo", span),
            p: syn::LitInt::new("3", span),
            modulus: None,
        };

        let r = Settings::from_input(input);
        assert!(r.is_ok());
        assert_eq!(
            r.unwrap(),
            Settings {
                ident: syn::Ident::new("foo", span),
                ident_name: "foo".to_string(),
                p_val: 3,
                modulus: G2Poly(0b1011),
                generator: G2Poly(0b10),
            }
        );
    }

    // The rendered multiplication table for x^2 + x + 1 must match the
    // fixture file, ignoring surrounding whitespace in the fixture.
    #[test]
    fn test_generate_mul_table() {
        let m = G2Poly(0b111);

        assert_eq!(
            include_str!("../tests/mul_table.txt").trim(),
            generate_mul_table_string(m)
        );
    }

    // The rendered inverse table for x^8 + x^4 + x^3 + x + 1 must match the
    // fixture file, ignoring surrounding whitespace in the fixture.
    #[test]
    fn test_generate_inv_table_string() {
        let m = G2Poly(0b1_0001_1011);

        assert_eq!(
            include_str!("../tests/inv_table.txt").trim(),
            generate_inv_table_string(m)
        );
    }

    // Values on both sides of each power of 256, plus the zero special case.
    #[test]
    fn test_ceil_log256() {
        assert_eq!(0, ceil_log256(0));
        assert_eq!(1, ceil_log256(1));
        assert_eq!(1, ceil_log256(256));
        assert_eq!(2, ceil_log256(257));
        assert_eq!(2, ceil_log256(65536));
        assert_eq!(3, ceil_log256(65537));
        assert_eq!(3, ceil_log256(131072));
        assert_eq!(3, ceil_log256(16777216));
        assert_eq!(4, ceil_log256(16777217));
    }
}