1#![recursion_limit = "128"]
12extern crate proc_macro;
13
14use proc_macro::TokenStream as P1TokenStream;
15use proc_macro2::{Ident, Span, TokenStream as P2TokenStream};
16
17use g2poly::{extended_gcd, G2Poly};
18use quote::quote;
19use syn::{
20 parse::{Parse, ParseStream},
21 parse_macro_input, Token,
22};
23
24#[proc_macro]
50pub fn g2p(input: P1TokenStream) -> P1TokenStream {
51 let args = parse_macro_input!(input as ParsedInput);
52 let settings = Settings::from_input(args).unwrap();
53 let ident = settings.ident;
54 let ident_name = settings.ident_name;
55 let modulus = settings.modulus;
56 let generator = settings.generator;
57 let p = settings.p_val;
58 let field_size = 1_usize << p;
59 let mask = (1_u64 << p).wrapping_sub(1);
60
61 let ty = match p {
62 0 => panic!("p must be > 0"),
63 1..=8 => quote!(u8),
64 9..=16 => quote!(u16),
65 17..=32 => quote!(u32),
66 _ => unimplemented!("p > 32 is not implemented right now"),
67 };
68
69 let mod_name = Ident::new(&format!("{}_mod", ident_name), Span::call_site());
70
71 let struct_def = quote![
72 #[derive(Clone, Copy, Eq, PartialEq, Hash)]
73 pub struct #ident(pub #ty);
74 ];
75
76 let struct_impl = quote![
77 impl #ident {
78 pub const MASK: #ty = #mask as #ty;
79 }
80 ];
81
82 let from = quote![
83 impl ::core::convert::From<#ident> for #ty {
84 fn from(v: #ident) -> #ty {
85 v.0
86 }
87 }
88 ];
89
90 let into = quote![
91 impl ::core::convert::From<#ty> for #ident {
92 fn from(v: #ty) -> #ident {
93 #ident(v & #ident::MASK)
94 }
95 }
96 ];
97
98 let tmpl = format!("{{}}_{}", ident_name);
99 let debug = quote![
100 impl ::core::fmt::Debug for #ident {
101 fn fmt<'a>(&self, f: &mut ::core::fmt::Formatter<'a>) -> ::core::fmt::Result {
102 write!(f, #tmpl, self.0)
103 }
104 }
105 ];
106 let display = quote![
107 impl ::core::fmt::Display for #ident {
108 fn fmt<'a>(&self, f: &mut ::core::fmt::Formatter<'a>) -> ::core::fmt::Result {
109 write!(f, #tmpl, self.0)
110 }
111 }
112 ];
113 let add = quote![
114 impl ::core::ops::Add for #ident {
115 type Output = Self;
116
117 #[allow(clippy::suspicious_arithmetic_impl)]
118 fn add(self, rhs: Self) -> Self {
119 Self(self.0 ^ rhs.0)
120 }
121 }
122 impl ::core::ops::AddAssign for #ident {
123 fn add_assign(&mut self, rhs: Self) {
124 *self = *self + rhs;
125 }
126 }
127 ];
128 let sum = quote![
129 impl ::core::iter::Sum for #ident {
130 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
131 iter.fold(<Self as ::g2p::GaloisField>::ZERO, ::core::ops::Add::add)
132 }
133 }
134 ];
135 let sub = quote![
136 impl ::core::ops::Sub for #ident {
137 type Output = Self;
138
139
140 #[allow(clippy::suspicious_arithmetic_impl)]
141 fn sub(self, rhs: Self) -> Self {
142 Self(self.0 ^ rhs.0)
143 }
144 }
145 impl ::core::ops::SubAssign for #ident {
146 fn sub_assign(&mut self, rhs: Self) {
147 *self = *self - rhs;
148 }
149 }
150 impl ::core::ops::Neg for #ident {
151 type Output = Self;
152
153 fn neg(self) -> Self::Output {
154 self
155 }
156 }
157 ];
158 let gen = generator.0;
159 let modulus_val = modulus.0;
160 let galois_trait_impl = quote![
161 impl ::g2p::GaloisField for #ident {
162 const SIZE: usize = #field_size;
163 const MODULUS: ::g2p::G2Poly = ::g2p::G2Poly(#modulus_val);
164 const ZERO: Self = Self(0);
165 const ONE: Self = Self(1);
166 const GENERATOR: Self = Self(#gen as #ty);
167 }
168 ];
169
170 let (tables, mul, div) =
171 generate_mul_impl(ident.clone(), &ident_name, modulus, ty, field_size, mask);
172 let product = quote![
173 impl ::core::iter::Product for #ident {
174 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
175 iter.fold(<Self as ::g2p::GaloisField>::ONE, ::core::ops::Mul::mul)
176 }
177 }
178 ];
179
180 P1TokenStream::from(quote![
181 #struct_def
182
183 mod #mod_name {
184 use super::#ident;
185 #struct_impl
186 #tables
187 #from
188 #into
189 #debug
190 #display
191 #add
192 #sum
193 #sub
194 #mul
195 #product
196 #div
197 #galois_trait_impl
198 }
199 ])
200}
201
202struct ParsedInput {
203 ident: syn::Ident,
204 p: syn::LitInt,
205 modulus: Option<syn::LitInt>,
206}
207
208impl Parse for ParsedInput {
209 fn parse(input: ParseStream) -> syn::Result<Self> {
210 let ident = input.parse()?;
211 let _sep: Token![,] = input.parse()?;
212 let p = input.parse()?;
213
214 let mut modulus = None;
215
216 loop {
217 let sep: Option<Token![,]> = input.parse()?;
218 if sep.is_none() || input.is_empty() {
219 break;
220 }
221 let ident: syn::Ident = input.parse()?;
222 let ident_name = ident.to_string();
223 let _sep: Token![:] = input.parse()?;
224 match ident_name.as_str() {
225 "modulus" => {
226 if modulus.is_some() {
227 Err(syn::parse::Error::new(
228 ident.span(),
229 "Double declaration of 'modulus'",
230 ))?
231 }
232 modulus = Some(input.parse()?);
233 }
234 _ => Err(syn::parse::Error::new(ident.span(), "Expected 'modulus'"))?,
235 }
236 }
237
238 Ok(ParsedInput { ident, p, modulus })
239 }
240}
241
242#[derive(Debug, Clone, Eq, PartialEq)]
243struct Settings {
244 ident: syn::Ident,
245 ident_name: String,
246 p_val: u64,
247 modulus: G2Poly,
248 generator: G2Poly,
249}
250
251fn find_modulus_poly(p: u64) -> G2Poly {
252 assert!(p < 64);
253
254 let start = (1 << p) + 1;
255 let end = (1_u64 << (p + 1)).wrapping_sub(1);
256
257 for m in start..=end {
258 let p = G2Poly(m);
259 if p.is_irreducible() {
260 return p;
261 }
262 }
263
264 unreachable!("There are irreducible polynomial for any degree!")
265}
266
267fn find_generator(m: G2Poly) -> G2Poly {
268 let max = m.degree().expect("Modulus must have positive degree");
269
270 for g in 1..(2 << max) {
271 let g = G2Poly(g);
272 if g.is_generator(m) {
273 return g;
274 }
275 }
276
277 unreachable!("There must be a generator element")
278}
279
/// Returns how many base-256 digits (bytes) are needed to write every index in `0..n`.
///
/// This equals `ceil(log256(n))` for `n >= 2`, with `ceil_log256(0) == 0` and
/// `ceil_log256(1) == 1`. The multiplication tables use it to decide how many 8-bit
/// slices a field element is split into.
fn ceil_log256(mut n: usize) -> usize {
    if n == 0 {
        return 0;
    }

    let mut digits = 1;
    while n > 256 {
        digits += 1;
        // ceil(n / 256) without computing `n + 255`, which overflows near usize::MAX.
        n = n / 256 + usize::from(n % 256 != 0);
    }
    digits
}
298
299fn generate_mul_table_string(modulus: G2Poly) -> String {
308 assert!(modulus.is_irreducible());
309
310 let field_size = 1
311 << modulus
312 .degree()
313 .expect("Irreducible polynomial has positive degree");
314 let nparts = ceil_log256(field_size as usize);
315
316 let mut mul_table = Vec::with_capacity(nparts);
317 for left in 0..nparts {
318 let mut left_parts = Vec::with_capacity(nparts);
319 for right in 0..nparts {
320 let mut right_parts = Vec::with_capacity(256);
321 for i in 0..256 {
322 let i = i << (8 * left);
323 let mut row = Vec::with_capacity(256);
324 for j in 0..256 {
325 let j = j << (8 * right);
326 let v = if i < field_size && j < field_size {
327 G2Poly(i as u64) * G2Poly(j as u64) % modulus
328 } else {
329 G2Poly(0)
330 };
331
332 row.push(format!("{}", v.0));
333 }
334 right_parts.push(format!("[{}]", row.join(",")));
335 }
336 left_parts.push(format!("[{}]", right_parts.join(",")));
337 }
338 mul_table.push(format!("[{}]", left_parts.join(",")));
339 }
340
341 format!("[{}]", mul_table.join(","))
342}
343
344fn generate_inv_table_string(modulus: G2Poly) -> String {
345 assert!(modulus.is_irreducible());
346
347 let field_size = 1
348 << modulus
349 .degree()
350 .expect("Irreducible polynomial has positive degree");
351 let mut inv_table = vec![0; field_size as usize];
352 for i in 1..field_size {
354 if inv_table[i as usize] != 0 {
355 continue;
357 }
358
359 let a = G2Poly(i);
360
361 let (_gcd, x, _y) = extended_gcd(a, modulus);
366 inv_table[i as usize] = x.0;
367 inv_table[x.0 as usize] = i;
368 }
369
370 use std::fmt::Write;
371 let mut res = String::with_capacity(3 * field_size as usize);
372 write!(&mut res, "[").unwrap();
373 for v in inv_table {
374 write!(&mut res, "{},", v).unwrap();
375 }
376 write!(&mut res, "]").unwrap();
377 res
378}
379
380fn generate_mul_impl(
381 ident: syn::Ident,
382 ident_name: &str,
383 modulus: G2Poly,
384 ty: P2TokenStream,
385 field_size: usize,
386 mask: u64,
387) -> (P2TokenStream, P2TokenStream, P2TokenStream) {
388 let mul_table = generate_mul_table_string(modulus);
389 let inv_table = generate_inv_table_string(modulus);
390
391 let mul_table_string: proc_macro2::TokenStream = mul_table.parse().unwrap();
393 let inv_table_string: proc_macro2::TokenStream = inv_table.parse().unwrap();
394
395 let nparts = ceil_log256(field_size);
396
397 let tables = quote! {
401 pub static MUL_TABLE: [[[[#ty; 256]; 256]; #nparts]; #nparts] = #mul_table_string;
402 pub static INV_TABLE: [#ty; #field_size] = #inv_table_string;
403 };
404
405 let mut mul_ops = Vec::with_capacity(nparts * nparts);
406 for left in 0..nparts {
407 for right in 0..nparts {
408 mul_ops.push(quote![
409 #ident(MUL_TABLE[#left][#right][(((self.0 & #mask as #ty) >> (8*#left)) & 255) as usize][(((rhs.0 & #mask as #ty) >> (8*#right)) & 255) as usize])
410 ]);
411 }
412 }
413
414 let mul = quote![
415 impl ::core::ops::Mul for #ident {
416 type Output = Self;
417 fn mul(self, rhs: Self) -> Self {
418 #(#mul_ops)+*
419 }
420 }
421 impl ::core::ops::MulAssign for #ident {
422 fn mul_assign(&mut self, rhs: Self) {
423 *self = *self * rhs;
424 }
425 }
426 ];
427
428 let err_msg = format!("Division by 0 in {}", ident_name);
429
430 let div = quote![
431 impl ::core::ops::Div for #ident {
432 type Output = Self;
433
434 fn div(self, rhs: Self) -> Self {
435 if (rhs.0 & #mask as #ty) == 0 {
436 panic!(#err_msg);
437 }
438 self * Self(INV_TABLE[(rhs.0 & #mask as #ty) as usize])
439 }
440 }
441 impl ::core::ops::DivAssign for #ident {
442 fn div_assign(&mut self, rhs: Self) {
443 *self = *self / rhs;
444 }
445 }
446 ];
447
448 (tables, mul, div)
449}
450
451impl Settings {
452 pub fn from_input(input: ParsedInput) -> syn::Result<Self> {
453 let ident = input.ident;
454 let ident_name = ident.to_string();
455 let p_val = input.p.base10_parse()?;
456 let modulus = match input.modulus {
457 Some(lit) => G2Poly(lit.base10_parse()?),
458 None => find_modulus_poly(p_val),
459 };
460
461 if !modulus.is_irreducible() {
462 Err(syn::Error::new(
463 Span::call_site(),
464 format!("Modulus {} is not irreducible", modulus),
465 ))?;
466 }
467
468 let generator = find_generator(modulus);
469
470 if !generator.is_generator(modulus) {
471 Err(syn::Error::new(
472 Span::call_site(),
473 format!("{} is not a generator", generator),
474 ))?;
475 }
476
477 Ok(Settings {
478 ident,
479 ident_name,
480 p_val,
481 modulus,
482 generator,
483 })
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// With p = 3 and no explicit modulus, x^3 + x + 1 and generator x are chosen.
    #[test]
    fn test_settings_parser() {
        let span = Span::call_site();

        let expected = Settings {
            ident: syn::Ident::new("foo", span),
            ident_name: "foo".to_string(),
            p_val: 3,
            modulus: G2Poly(0b1011),
            generator: G2Poly(0b10),
        };

        let parsed = ParsedInput {
            ident: Ident::new("foo", span),
            p: syn::LitInt::new("3", span),
            modulus: None,
        };
        let settings = Settings::from_input(parsed).expect("p = 3 without a modulus is valid");

        assert_eq!(settings, expected);
    }

    /// Compares the rendered GF(4) multiplication table with the checked-in fixture.
    #[test]
    fn test_generate_mul_table() {
        let rendered = generate_mul_table_string(G2Poly(0b111));
        let fixture = include_str!("../tests/mul_table.txt").trim();
        assert_eq!(fixture, rendered);
    }

    /// Compares the rendered GF(256) inverse table with the checked-in fixture.
    #[test]
    fn test_generate_inv_table_string() {
        let rendered = generate_inv_table_string(G2Poly(0b1_0001_1011));
        let fixture = include_str!("../tests/inv_table.txt").trim();
        assert_eq!(fixture, rendered);
    }

    #[test]
    fn test_ceil_log256() {
        let cases: [(usize, usize); 9] = [
            (0, 0),
            (1, 1),
            (256, 1),
            (257, 2),
            (65536, 2),
            (65537, 3),
            (131072, 3),
            (16777216, 3),
            (16777217, 4),
        ];
        for &(n, expected) in cases.iter() {
            assert_eq!(expected, ceil_log256(n), "ceil_log256({})", n);
        }
    }
}