#![recursion_limit = "128"]
extern crate proc_macro;
use proc_macro::TokenStream as P1TokenStream;
use proc_macro2::{TokenStream as P2TokenStream, Ident, Span};
use g2poly::{G2Poly, extended_gcd};
use quote::quote;
use syn::{
parse::{
Parse,
ParseStream,
},
parse_macro_input,
Token,
};
#[proc_macro]
pub fn g2p(input: P1TokenStream) -> P1TokenStream {
let args = parse_macro_input!(input as ParsedInput);
let settings = Settings::from_input(args).unwrap();
let ident = settings.ident;
let ident_name = settings.ident_name;
let modulus = settings.modulus;
let generator = settings.generator;
let p = settings.p_val;
let field_size = 1_usize << p;
let mask = (1_u64 << p).wrapping_sub(1);
let ty = match p {
0 => panic!("p must be > 0"),
1..=8 => quote!(u8),
9..=16 => quote!(u16),
17..=32 => quote!(u32),
_ => unimplemented!("p > 32 is not implemented right now"),
};
let mod_name = Ident::new(&format!("{}_mod", ident_name), Span::call_site());
let struct_def = quote![
pub struct #ident(pub #ty);
];
let struct_impl = quote![
impl #ident {
pub const MASK: #ty = #mask as #ty;
}
];
let from = quote![
impl ::core::convert::From<#ident> for #ty {
fn from(v: #ident) -> #ty {
v.0
}
}
];
let into = quote![
impl ::core::convert::From<#ty> for #ident {
fn from(v: #ty) -> #ident {
#ident(v & #ident::MASK)
}
}
];
let eq = quote![
impl ::core::cmp::PartialEq<#ident> for #ident {
fn eq(&self, other: &#ident) -> bool {
self.0 == other.0
}
}
impl ::core::cmp::Eq for #ident {}
];
let tmpl = format!("{{}}_{}", ident_name);
let debug = quote![
impl ::core::fmt::Debug for #ident {
fn fmt<'a>(&self, f: &mut ::core::fmt::Formatter<'a>) -> ::core::fmt::Result {
write!(f, #tmpl, self.0)
}
}
];
let display = quote![
impl ::core::fmt::Display for #ident {
fn fmt<'a>(&self, f: &mut ::core::fmt::Formatter<'a>) -> ::core::fmt::Result {
write!(f, #tmpl, self.0)
}
}
];
let clone = quote![
impl ::core::clone::Clone for #ident {
fn clone(&self) -> Self {
*self
}
}
];
let copy = quote![
impl ::core::marker::Copy for #ident {}
];
let add = quote![
impl ::core::ops::Add for #ident {
type Output = #ident;
fn add(self, rhs: #ident) -> #ident {
#ident(self.0 ^ rhs.0)
}
}
impl ::core::ops::AddAssign for #ident {
fn add_assign(&mut self, rhs: #ident) {
*self = *self + rhs;
}
}
];
let sub = quote![
impl ::core::ops::Sub for #ident {
type Output = #ident;
fn sub(self, rhs: #ident) -> #ident {
#ident(self.0 ^ rhs.0)
}
}
impl ::core::ops::SubAssign for #ident {
fn sub_assign(&mut self, rhs: #ident) {
*self = *self - rhs;
}
}
];
let gen = generator.0;
let modulus_val = modulus.0;
let galois_trait_impl = quote![
impl ::g2p::GaloisField for #ident {
const SIZE: usize = #field_size;
const MODULUS: ::g2p::G2Poly = ::g2p::G2Poly(#modulus_val);
const ZERO: #ident = #ident(0);
const ONE: #ident = #ident(1);
const GENERATOR: #ident = #ident(#gen as #ty);
}
];
let (tables, mul, div) = generate_mul_impl(
ident.clone(),
&ident_name,
modulus,
ty,
field_size,
mask,
);
P1TokenStream::from(quote![
#struct_def
mod #mod_name {
use super::#ident;
#struct_impl
#tables
#from
#into
#eq
#debug
#display
#clone
#copy
#add
#sub
#mul
#div
#galois_trait_impl
}
])
}
struct ParsedInput {
ident: syn::Ident,
p: syn::LitInt,
modulus: Option<syn::LitInt>,
}
impl Parse for ParsedInput {
fn parse(input: ParseStream) -> syn::Result<Self> {
let ident = input.parse()?;
let _sep: Token![,] = input.parse()?;
let p = input.parse()?;
let mut modulus = None;
loop {
let sep: Option<Token![,]> = input.parse()?;
if sep.is_none() || input.is_empty() {
break;
}
let ident: syn::Ident = input.parse()?;
let ident_name = ident.to_string();
let _sep: Token![:] = input.parse()?;
match ident_name.as_str() {
"modulus" => {
if modulus.is_some() {
Err(syn::parse::Error::new(ident.span(), "Double declaration of 'modulus'"))?
}
modulus = Some(input.parse()?);
}
_ => {
Err(syn::parse::Error::new(ident.span(), "Expected 'modulus'"))?
}
}
}
Ok(ParsedInput {
ident,
p,
modulus,
})
}
}
/// Configuration for one `g2p!` invocation, built from `ParsedInput` by `Settings::from_input`.
#[derive(Debug, Clone, Eq, PartialEq)]
struct Settings {
    /// Name of the generated field type.
    ident: syn::Ident,
    /// `ident` rendered as a string (used for the helper-module name and Debug/Display output).
    ident_name: String,
    /// Value of the `p` literal; the generated field has `1 << p` elements.
    p_val: u64,
    /// Modulus polynomial defining the field (checked irreducible by `from_input`).
    modulus: G2Poly,
    /// Element satisfying `is_generator(modulus)`; exported as `GaloisField::GENERATOR`.
    generator: G2Poly,
}
/// Returns the first irreducible polynomial of degree `p`.
///
/// Candidates are bit-encoded polynomials scanned upward from `x^p + 1` (`(1 << p) + 1`)
/// to the all-ones polynomial of degree `p` (`2^(p+1) - 1`).
///
/// # Panics
/// Panics unless `0 < p < 64`. There is no irreducible polynomial of degree 0, and a
/// degree-64 polynomial does not fit in a `u64`.
fn find_modulus_poly(p: u64) -> G2Poly {
    assert!(p > 0 && p < 64, "modulus degree must be in 1..64, got {}", p);
    let top = 1_u64 << p;
    let start = top + 1;
    // `top | (top - 1)` equals `2^(p+1) - 1` without shifting by 64 when `p == 63`
    // (the previous `1 << (p + 1)` overflowed there).
    let end = top | (top - 1);
    for m in start..=end {
        let candidate = G2Poly(m);
        if candidate.is_irreducible() {
            return candidate;
        }
    }
    unreachable!("There are irreducible polynomial for any degree!")
}
fn find_generator(m: G2Poly) -> G2Poly {
let max = m.degree().expect("Modulus must have positive degree");
for g in 1..(2 << max) {
let g = G2Poly(g);
if g.is_generator(m) {
return g;
}
}
unreachable!("There must be a generator element")
}
/// Number of base-256 digits (bytes) needed to cover `n` values.
///
/// Returns 0 for `n == 0`. Otherwise returns the smallest `c >= 1` with `256^c >= n`, so
/// `ceil_log256(1) == 1`, `ceil_log256(256) == 1` and `ceil_log256(257) == 2`.
fn ceil_log256(n: usize) -> usize {
    if n == 0 {
        return 0;
    }
    let mut parts = 1;
    let mut remaining = n;
    while remaining > 256 {
        parts += 1;
        // ceil(remaining / 256), written so it cannot overflow near `usize::MAX`
        // (the former `(remaining + 255) >> 8` did).
        remaining = (remaining >> 8) + usize::from(remaining & 0xFF != 0);
    }
    parts
}
/// Renders the byte-sliced multiplication table for `GF(2)[x] / modulus` as a Rust
/// array-literal string.
///
/// The table has shape `[nparts][nparts][256][256]`, with `nparts = ceil_log256(2^deg)`.
/// Entry `[left][right][i][j]` is `((i << 8*left) * (j << 8*right)) mod modulus`. It is `0`
/// when either shifted operand is not below the field size. Polynomial multiplication
/// distributes over XOR, so a full product is the XOR of one entry per pair of byte parts.
///
/// # Panics
/// Panics if `modulus` is not irreducible.
fn generate_mul_table_string(modulus: G2Poly) -> String {
    assert!(modulus.is_irreducible());
    let degree = modulus.degree().expect("Irreducible polynomial has positive degree");
    // Explicitly `u64`: left unannotated, these integers defaulted to `i32`. That type
    // overflows at degree 32 and goes negative (then sign-extends into `G2Poly`) for
    // degree 31 or byte values shifted past bit 30.
    let field_size: u64 = 1 << degree;
    let nparts = ceil_log256(field_size as usize);
    let mut mul_table = Vec::with_capacity(nparts);
    for left in 0..nparts {
        let mut left_parts = Vec::with_capacity(nparts);
        for right in 0..nparts {
            let mut right_parts = Vec::with_capacity(256);
            for i in 0..256_u64 {
                let a = i << (8 * left);
                let mut row = Vec::with_capacity(256);
                for j in 0..256_u64 {
                    let b = j << (8 * right);
                    let v = if a < field_size && b < field_size {
                        (G2Poly(a) * G2Poly(b) % modulus).0
                    } else {
                        0
                    };
                    row.push(v.to_string());
                }
                right_parts.push(format!("[{}]", row.join(",")));
            }
            left_parts.push(format!("[{}]", right_parts.join(",")));
        }
        mul_table.push(format!("[{}]", left_parts.join(",")));
    }
    format!("[{}]", mul_table.join(","))
}
/// Renders the multiplicative-inverse table for `GF(2)[x] / modulus` as a Rust
/// array-literal string of the form `[v0,v1,...,]` (with a trailing comma).
///
/// Entry 0 is 0. For `i > 0`, the entry is the second component `x` of
/// `extended_gcd(i, modulus)` (presumably the Bézout coefficient with `i*x + modulus*y = 1`).
/// Entries are filled in pairs: when `i`'s entry `x` is computed, entry `x` is set to `i`,
/// and already-filled slots are skipped.
///
/// # Panics
/// Panics if `modulus` is not irreducible.
fn generate_inv_table_string(modulus: G2Poly) -> String {
    use std::fmt::Write;

    assert!(modulus.is_irreducible());
    let size = 1_u64 << modulus.degree().expect("Irreducible polynomial has positive degree");
    let mut inverses = vec![0_u64; size as usize];
    for elem in 1..size {
        let slot = elem as usize;
        if inverses[slot] == 0 {
            let (_gcd, inv, _y) = extended_gcd(G2Poly(elem), modulus);
            inverses[slot] = inv.0;
            inverses[inv.0 as usize] = elem;
        }
    }
    let mut rendered = String::with_capacity(3 * size as usize);
    rendered.push('[');
    for value in &inverses {
        write!(rendered, "{},", value).unwrap();
    }
    rendered.push(']');
    rendered
}
/// Builds the lookup tables and the multiplication/division impls for a field type.
///
/// Returns `(tables, mul, div)`:
/// * `tables`: `pub static MUL_TABLE`, indexed `[left part][right part][left byte][right
///   byte]` and rendered by `generate_mul_table_string`, plus `pub static INV_TABLE`,
///   indexed by element and rendered by `generate_inv_table_string`.
/// * `mul`: `Mul`/`MulAssign`. The product is the `+` (the type's `Add`) of one
///   `MUL_TABLE` lookup for every pair of byte parts of the two operands.
/// * `div`: `Div`/`DivAssign`. Multiplies by `INV_TABLE[rhs]` and panics with
///   "Division by 0 in <name>" on a zero divisor.
///
/// Operands are masked with `mask` before they are used as table indices.
fn generate_mul_impl(ident: syn::Ident, ident_name: &str, modulus: G2Poly, ty: P2TokenStream, field_size: usize, mask: u64) -> (P2TokenStream, P2TokenStream, P2TokenStream) {
    let mul_table_tokens: P2TokenStream = generate_mul_table_string(modulus).parse().unwrap();
    let inv_table_tokens: P2TokenStream = generate_inv_table_string(modulus).parse().unwrap();
    let nparts = ceil_log256(field_size);
    let tables = quote! {
        pub static MUL_TABLE: [[[[#ty; 256]; 256]; #nparts]; #nparts] = #mul_table_tokens;
        pub static INV_TABLE: [#ty; #field_size] = #inv_table_tokens;
    };
    // One lookup term per (left byte part, right byte part), left-major order.
    let mul_ops: Vec<P2TokenStream> = (0..nparts)
        .flat_map(|left| (0..nparts).map(move |right| (left, right)))
        .map(|(left, right)| quote![
            #ident(MUL_TABLE[#left][#right][(((self.0 & #mask as #ty) >> (8*#left)) & 255) as usize][(((rhs.0 & #mask as #ty) >> (8*#right)) & 255) as usize])
        ])
        .collect();
    let mul = quote![
        impl ::core::ops::Mul for #ident {
            type Output = #ident;
            fn mul(self, rhs: #ident) -> #ident {
                #(#mul_ops)+*
            }
        }
        impl ::core::ops::MulAssign for #ident {
            fn mul_assign(&mut self, rhs: #ident) {
                *self = *self * rhs;
            }
        }
    ];
    let err_msg = format!("Division by 0 in {}", ident_name);
    let div = quote![
        impl ::core::ops::Div for #ident {
            type Output = #ident;
            fn div(self, rhs: #ident) -> #ident {
                if (rhs.0 & #mask as #ty) == 0 {
                    panic!(#err_msg);
                }
                self * #ident(INV_TABLE[(rhs.0 & #mask as #ty) as usize])
            }
        }
        impl ::core::ops::DivAssign for #ident {
            fn div_assign(&mut self, rhs: #ident) {
                *self = *self / rhs;
            }
        }
    ];
    (tables, mul, div)
}
impl Settings {
pub fn from_input(input: ParsedInput) -> syn::Result<Self> {
let ident = input.ident;
let ident_name = ident.to_string();
let p_val = input.p.value();
let modulus = input.modulus
.map(|m| G2Poly(m.value()))
.unwrap_or_else(|| find_modulus_poly(p_val));
if !modulus.is_irreducible() {
Err(syn::Error::new(syn::export::Span::call_site(), format!("Modulus {} is not irreducible", modulus)))?;
}
let generator = find_generator(modulus);
if !generator.is_generator(modulus) {
Err(syn::Error::new(syn::export::Span::call_site(), format!("{} is not a generator", generator)))?;
}
Ok(Settings {
ident,
ident_name,
p_val,
modulus,
generator,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With no explicit modulus and p = 3, `from_input` picks x^3 + x + 1 and the
    /// generator x.
    #[test]
    fn test_settings_parser() {
        let span = syn::export::Span::call_site();
        let parsed = ParsedInput {
            ident: syn::Ident::new("foo", span),
            p: syn::LitInt::new(3, syn::IntSuffix::None, span),
            modulus: None,
        };
        let expected = Settings {
            ident: syn::Ident::new("foo", span),
            ident_name: "foo".to_string(),
            p_val: 3,
            modulus: G2Poly(0b1011),
            generator: G2Poly(0b10),
        };
        let settings = Settings::from_input(parsed);
        assert!(settings.is_ok());
        assert_eq!(settings.unwrap(), expected);
    }

    /// Multiplication table for GF(4) (modulus x^2 + x + 1) matches the stored fixture.
    #[test]
    fn test_generate_mul_table() {
        let modulus = G2Poly(0b111);
        let expected = include_str!("../tests/mul_table.txt").trim();
        assert_eq!(expected, generate_mul_table_string(modulus));
    }

    /// Inverse table for GF(256) (modulus 0x11B) matches the stored fixture.
    #[test]
    fn test_generate_inv_table_string() {
        let modulus = G2Poly(0b1_0001_1011);
        let expected = include_str!("../tests/inv_table.txt").trim();
        assert_eq!(expected, generate_inv_table_string(modulus));
    }

    #[test]
    fn test_ceil_log256() {
        let cases: &[(usize, usize)] = &[
            (0, 0),
            (1, 1),
            (256, 1),
            (257, 2),
            (65536, 2),
            (65537, 3),
            (131072, 3),
            (16777216, 3),
            (16777217, 4),
        ];
        for &(n, expected) in cases {
            assert_eq!(expected, ceil_log256(n), "ceil_log256({})", n);
        }
    }
}