use crate::poly::{BinaryPoly128, BinaryPoly16, BinaryPoly32, BinaryPoly64};
use crate::{BinaryFieldElement, BinaryPolynomial};
/// Reduction polynomial for the 16-bit field: x^16 + x^5 + x^3 + x^2 + 1 (= 0x1002D).
///
/// Stored in a `u32` because it is used to reduce 32-bit (double-width) products.
const IRREDUCIBLE_16: u32 = (1 << 16) | (1 << 5) | (1 << 3) | (1 << 2) | 1;
/// Reduction polynomial for the 32-bit field: x^32 + x^15 + x^9 + x^7 + x^4 + x^3 + 1.
///
/// Stored in a `u64` because it is used to reduce 64-bit (double-width) products.
const IRREDUCIBLE_32: u64 =
    (1 << 32) | (1 << 15) | (1 << 9) | (1 << 7) | (1 << 4) | (1 << 3) | 1;
/// Generates a binary extension field element type `$name` for GF(2^`$bitsize`).
///
/// Macro parameters:
/// - `$name`: the element type to define.
/// - `$poly_type`: polynomial container wrapped by the element.
/// - `$poly_double`: wider polynomial container that holds unreduced products.
/// - `$value_type` / `$value_double`: integer types behind `$poly_type` / `$poly_double`.
/// - `$irreducible`: reduction polynomial including its `x^$bitsize` term, typed as
///   `$value_double`.
/// - `$bitsize`: the extension degree.
macro_rules! impl_binary_elem {
    ($name:ident, $poly_type:ident, $poly_double:ident, $value_type:ty, $value_double:ty, $irreducible:expr, $bitsize:expr) => {
        /// Binary extension field element, stored as a polynomial over GF(2) whose
        /// coefficients are the bits of the underlying integer.
        #[repr(transparent)]
        #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        #[cfg_attr(
            feature = "scale",
            derive(codec::Encode, codec::Decode, scale_info::TypeInfo)
        )]
        pub struct $name($poly_type);

        // SAFETY: `$name` is `#[repr(transparent)]` over `$poly_type`. Soundness relies on
        // `$poly_type` itself being a transparent wrapper around the primitive `$value_type`,
        // so that every bit pattern (including all zeros) is a valid value.
        // NOTE(review): that property is defined in `crate::poly` — confirm it there.
        unsafe impl bytemuck::Pod for $name {}
        unsafe impl bytemuck::Zeroable for $name {}

        impl $name {
            /// Wraps a raw integer as a field element; no reduction is applied.
            pub const fn from_value(val: $value_type) -> Self {
                Self($poly_type::new(val))
            }

            /// Reduces a double-width polynomial modulo `$irreducible`.
            ///
            /// It repeatedly cancels the highest set coefficient of degree `>= $bitsize`.
            /// Each step XORs in a shifted copy of the reduction polynomial. The remainder
            /// then has degree `< $bitsize` and is truncated to `$value_type`.
            fn mod_irreducible_wide(poly: $poly_double) -> Self {
                let mut p = poly.value();
                let irr = $irreducible;
                let n = $bitsize;
                let total_bits = core::mem::size_of::<$value_double>() * 8;
                loop {
                    if p == 0 {
                        break;
                    }
                    let lz = p.leading_zeros() as usize;
                    let high_bit = total_bits - lz - 1;
                    if high_bit < n {
                        break;
                    }
                    // `irr` has degree `n`, so the shifted copy clears bit `high_bit`
                    // and only touches lower bits.
                    p ^= irr << (high_bit - n);
                }
                Self($poly_type::new(p as $value_type))
            }
        }

        impl BinaryFieldElement for $name {
            type Poly = $poly_type;

            fn zero() -> Self {
                Self($poly_type::zero())
            }

            fn one() -> Self {
                Self($poly_type::one())
            }

            fn from_poly(poly: Self::Poly) -> Self {
                Self(poly)
            }

            fn poly(&self) -> Self::Poly {
                self.0
            }

            /// Field addition (delegates to the polynomial `add`).
            fn add(&self, other: &Self) -> Self {
                Self(self.0.add(&other.0))
            }

            /// Field multiplication.
            ///
            /// Both operands are widened to `$poly_double` so the unreduced product
            /// fits. The product is then reduced modulo `$irreducible`.
            fn mul(&self, other: &Self) -> Self {
                let a_wide = $poly_double::from_value(self.0.value() as u64);
                let b_wide = $poly_double::from_value(other.0.value() as u64);
                let prod_wide = a_wide.mul(&b_wide);
                Self::mod_irreducible_wide(prod_wide)
            }

            /// Multiplicative inverse, computed as `a^(2^$bitsize - 2)`.
            ///
            /// This is the inverse when `$irreducible` is irreducible, since the
            /// multiplicative group then has order `2^$bitsize - 1`.
            ///
            /// # Panics
            /// Panics if `self` is zero.
            fn inv(&self) -> Self {
                assert_ne!(self.0.value(), 0, "Cannot invert zero");
                if $bitsize <= 16 {
                    let exp = (1u64 << $bitsize) - 2;
                    return self.pow(exp);
                }
                // Accumulate a^(2^1) * a^(2^2) * ... * a^(2^(n-1)) = a^(2^n - 2).
                // Each factor is the square of the previous one.
                let mut acc = self.mul(self);
                let mut result = acc;
                for _ in 2..$bitsize {
                    acc = acc.mul(&acc);
                    result = result.mul(&acc);
                }
                result
            }

            /// Computes `self^exp` by right-to-left square-and-multiply.
            ///
            /// `x^0` is one for every `x`, including zero. `0^exp` is zero for `exp > 0`.
            fn pow(&self, mut exp: u64) -> Self {
                // Must be checked before the zero shortcut so that 0^0 == 1.
                if exp == 0 {
                    return Self::one();
                }
                if *self == Self::zero() {
                    return Self::zero();
                }
                let mut result = Self::one();
                let mut base = *self;
                while exp > 0 {
                    if exp & 1 == 1 {
                        result = result.mul(&base);
                    }
                    base = base.mul(&base);
                    exp >>= 1;
                }
                result
            }
        }

        impl From<$value_type> for $name {
            fn from(val: $value_type) -> Self {
                Self::from_value(val)
            }
        }

        #[cfg(feature = "rand")]
        impl rand::distributions::Distribution<$name> for rand::distributions::Standard {
            fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> $name {
                $name::from_value(rng.gen())
            }
        }
    };
}
// 16-bit field: reduced modulo IRREDUCIBLE_16, with products formed in 32-bit polynomials.
impl_binary_elem!(BinaryElem16, BinaryPoly16, BinaryPoly32, u16, u32, IRREDUCIBLE_16, 16);
// 32-bit field: reduced modulo IRREDUCIBLE_32, with products formed in 64-bit polynomials.
impl_binary_elem!(BinaryElem32, BinaryPoly32, BinaryPoly64, u32, u64, IRREDUCIBLE_32, 32);
/// 128-bit binary field element, stored as a `BinaryPoly128` whose bits are the
/// polynomial coefficients.
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "scale",
    derive(codec::Encode, codec::Decode, scale_info::TypeInfo)
)]
pub struct BinaryElem128(BinaryPoly128);

// SAFETY: `BinaryElem128` is `#[repr(transparent)]` over `BinaryPoly128`. Soundness relies on
// `BinaryPoly128` being a transparent wrapper around a `u128`, so that any bit pattern,
// including all zeros, is valid.
// NOTE(review): confirm that layout guarantee in `crate::poly`.
unsafe impl bytemuck::Zeroable for BinaryElem128 {}
unsafe impl bytemuck::Pod for BinaryElem128 {}

impl BinaryElem128 {
    /// Wraps a raw 128-bit value as a field element; no reduction is applied.
    pub const fn from_value(val: u128) -> Self {
        let poly = BinaryPoly128::new(val);
        Self(poly)
    }
}
impl BinaryFieldElement for BinaryElem128 {
    type Poly = BinaryPoly128;

    fn zero() -> Self {
        Self(BinaryPoly128::zero())
    }

    fn one() -> Self {
        Self(BinaryPoly128::one())
    }

    fn from_poly(poly: Self::Poly) -> Self {
        Self(poly)
    }

    fn poly(&self) -> Self::Poly {
        self.0
    }

    /// Field addition (delegates to the polynomial `add`).
    fn add(&self, other: &Self) -> Self {
        Self(self.0.add(&other.0))
    }

    /// Field multiplication.
    ///
    /// The carry-less product comes from `crate::simd::carryless_mul_128_full` and is
    /// reduced by `crate::simd::reduce_gf128`.
    fn mul(&self, other: &Self) -> Self {
        use crate::simd::{carryless_mul_128_full, reduce_gf128};
        let product = carryless_mul_128_full(self.0, other.0);
        let reduced = reduce_gf128(product);
        Self(reduced)
    }

    /// Multiplicative inverse, delegated to `crate::fast_inverse::invert_gf128`.
    ///
    /// # Panics
    /// Panics if `self` is zero.
    fn inv(&self) -> Self {
        assert_ne!(self.0.value(), 0, "Cannot invert zero");
        let result = crate::fast_inverse::invert_gf128(self.0.value());
        Self(BinaryPoly128::new(result))
    }

    /// Computes `self^exp` by right-to-left square-and-multiply.
    ///
    /// `x^0` is one for every `x`, including zero. `0^exp` is zero for `exp > 0`.
    fn pow(&self, mut exp: u64) -> Self {
        // Must be checked before the zero shortcut so that 0^0 == 1.
        if exp == 0 {
            return Self::one();
        }
        if *self == Self::zero() {
            return Self::zero();
        }
        let mut result = Self::one();
        let mut base = *self;
        while exp > 0 {
            if exp & 1 == 1 {
                result = result.mul(&base);
            }
            base = base.mul(&base);
            exp >>= 1;
        }
        result
    }
}
impl BinaryElem128 {
    /// Multiplies `self` by the polynomial `x`.
    ///
    /// The value is shifted left by one bit. If bit 127 was set, the overflowing
    /// `x^128` term is folded back by XOR-ing `0x87` (x^7 + x^2 + x + 1), without
    /// branching on the data. This corresponds to reduction modulo
    /// x^128 + x^7 + x^2 + x + 1.
    /// NOTE(review): assumed to match the polynomial used by `reduce_gf128` — confirm.
    #[inline]
    pub fn mul_by_x(&self) -> Self {
        let bits = self.0.value();
        // All-ones when the top bit is shifted out, zero otherwise.
        let carry_mask = 0u128.wrapping_sub(bits >> 127);
        Self(BinaryPoly128::new((bits << 1) ^ (carry_mask & 0x87)))
    }
}
impl From<u128> for BinaryElem128 {
fn from(val: u128) -> Self {
Self::from_value(val)
}
}
#[cfg(feature = "rand")]
impl rand::distributions::Distribution<BinaryElem128> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> BinaryElem128 {
BinaryElem128::from_value(rng.gen())
}
}
/// 64-bit binary field element, stored as a `BinaryPoly64` whose bits are the
/// polynomial coefficients.
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "scale",
    derive(codec::Encode, codec::Decode, scale_info::TypeInfo)
)]
pub struct BinaryElem64(BinaryPoly64);

// SAFETY: `BinaryElem64` is `#[repr(transparent)]` over `BinaryPoly64`. Soundness relies on
// `BinaryPoly64` being a transparent wrapper around a `u64`, so that any bit pattern,
// including all zeros, is valid.
// NOTE(review): confirm that layout guarantee in `crate::poly`.
unsafe impl bytemuck::Zeroable for BinaryElem64 {}
unsafe impl bytemuck::Pod for BinaryElem64 {}

impl BinaryElem64 {
    /// Wraps a raw 64-bit value as a field element; no reduction is applied.
    pub const fn from_value(val: u64) -> Self {
        let poly = BinaryPoly64::new(val);
        Self(poly)
    }
}
impl BinaryFieldElement for BinaryElem64 {
    type Poly = BinaryPoly64;

    fn zero() -> Self {
        Self(BinaryPoly64::zero())
    }

    fn one() -> Self {
        Self(BinaryPoly64::one())
    }

    fn from_poly(poly: Self::Poly) -> Self {
        Self(poly)
    }

    fn poly(&self) -> Self::Poly {
        self.0
    }

    /// Field addition (delegates to the polynomial `add`).
    fn add(&self, other: &Self) -> Self {
        Self(self.0.add(&other.0))
    }

    /// Field multiplication in GF(2^64) modulo x^64 + x^4 + x^3 + x + 1.
    ///
    /// The carry-less product of two 64-bit polynomials has up to 127 coefficients.
    /// It must therefore be formed at 128 bits and reduced, in the same way the 16-,
    /// 32- and 128-bit types widen and then reduce. Multiplying the 64-bit polynomials
    /// directly does not yield a field product.
    fn mul(&self, other: &Self) -> Self {
        let product = gf64_clmul(self.0.value(), other.0.value());
        Self(BinaryPoly64::new(gf64_reduce(product)))
    }

    /// Multiplicative inverse, computed as `a^(2^64 - 2)`.
    ///
    /// # Panics
    /// Panics if `self` is zero.
    fn inv(&self) -> Self {
        assert_ne!(self.0.value(), 0, "Cannot invert zero");
        self.pow(u64::MAX - 1)
    }

    /// Computes `self^exp` by right-to-left square-and-multiply.
    ///
    /// `x^0` is one for every `x`, including zero. `0^exp` is zero for `exp > 0`.
    fn pow(&self, mut exp: u64) -> Self {
        // Must be checked before the zero shortcut so that 0^0 == 1.
        if exp == 0 {
            return Self::one();
        }
        if *self == Self::zero() {
            return Self::zero();
        }
        let mut result = Self::one();
        let mut base = *self;
        while exp > 0 {
            if exp & 1 == 1 {
                result = result.mul(&base);
            }
            base = base.mul(&base);
            exp >>= 1;
        }
        result
    }
}

/// Carry-less (GF(2)[x]) product of two 64-bit polynomials; the result has degree <= 126.
///
/// Branch-free with respect to the operand bits.
fn gf64_clmul(a: u64, b: u64) -> u128 {
    let a = u128::from(a);
    (0..64u32).fold(0u128, |acc, i| {
        // All-ones if bit `i` of `b` is set, zero otherwise.
        let mask = 0u128.wrapping_sub(u128::from((b >> i) & 1));
        acc ^ ((a << i) & mask)
    })
}

/// Reduces a polynomial of degree < 128 modulo x^64 + x^4 + x^3 + x + 1.
fn gf64_reduce(mut p: u128) -> u64 {
    // x^64 + x^4 + x^3 + x + 1 is listed as irreducible in Seroussi's table of
    // low-weight binary irreducible polynomials.
    const MODULUS: u128 = (1u128 << 64) | 0x1B;
    // Clear coefficients from the top down. Each shifted modulus has its leading term
    // at `bit`, and its other terms sit strictly below it.
    for bit in (64..128u32).rev() {
        let mask = 0u128.wrapping_sub((p >> bit) & 1);
        p ^= (MODULUS << (bit - 64)) & mask;
    }
    p as u64
}
impl From<BinaryElem16> for BinaryElem32 {
fn from(elem: BinaryElem16) -> Self {
BinaryElem32::from(elem.0.value() as u32)
}
}
impl From<BinaryElem16> for BinaryElem64 {
fn from(elem: BinaryElem16) -> Self {
BinaryElem64(BinaryPoly64::new(elem.0.value() as u64))
}
}
impl From<BinaryElem16> for BinaryElem128 {
fn from(elem: BinaryElem16) -> Self {
BinaryElem128::from(elem.0.value() as u128)
}
}
impl From<BinaryElem32> for BinaryElem64 {
fn from(elem: BinaryElem32) -> Self {
BinaryElem64(BinaryPoly64::new(elem.0.value() as u64))
}
}
impl From<BinaryElem32> for BinaryElem128 {
fn from(elem: BinaryElem32) -> Self {
BinaryElem128::from(elem.0.value() as u128)
}
}
impl From<BinaryElem64> for BinaryElem128 {
fn from(elem: BinaryElem64) -> Self {
BinaryElem128::from(elem.0.value() as u128)
}
}