use core::cmp::Ordering;
use core::fmt;
use core::iter::{Product, Sum};
use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};
/// Errors produced by tower field arithmetic.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BinaryFieldError {
    /// Attempted to invert the zero element, which has no multiplicative inverse.
    InverseOfZero,
}

impl fmt::Display for BinaryFieldError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BinaryFieldError::InverseOfZero => f.write_str("cannot invert the zero element"),
        }
    }
}
/// An element of a binary tower field.
///
/// An element at tower level `k` is represented with `2^k` bits stored in `value`.
/// Elements built through `TowerFieldElement::new` have `value` masked to that width and
/// `num_level` clamped to at most 7 (a full `u128`).
#[derive(Clone, Copy, Debug, Default)]
pub struct TowerFieldElement {
    /// Bit representation of the element.
    pub value: u128,
    /// Tower level `k` of the element.
    pub num_level: usize,
}
impl TowerFieldElement {
pub fn new(val: u128, num_level: usize) -> Self {
let safe_level = if num_level > 7 { 7 } else { num_level };
let bits = 1 << safe_level;
let mask = if bits >= 128 {
u128::MAX
} else {
(1 << bits) - 1
};
Self {
value: val & mask,
num_level: safe_level,
}
}
pub fn is_zero(&self) -> bool {
self.value == 0
}
#[inline]
pub fn is_one(&self) -> bool {
self.value == 1
}
#[inline]
pub fn value(&self) -> u128 {
self.value
}
#[inline]
pub fn num_level(&self) -> usize {
self.num_level
}
#[inline]
pub fn num_bits(&self) -> usize {
1 << self.num_level()
}
#[cfg(feature = "std")]
pub fn to_binary_string(&self) -> String {
format!("{:0width$b}", self.value, width = self.num_bits())
}
pub fn split(&self) -> (Self, Self) {
let half_bits = self.num_bits() / 2;
let mask = (1 << half_bits) - 1;
let lo = self.value() & mask;
let hi = (self.value() >> half_bits) & mask;
(
Self::new(hi, self.num_level() - 1),
Self::new(lo, self.num_level() - 1),
)
}
pub fn join(&self, low: &Self) -> Self {
let joined = (self.value() << self.num_bits()) | low.value();
Self::new(joined, self.num_level() + 1)
}
pub fn extend_num_level(&mut self, new_level: usize) {
if self.num_level() < new_level {
self.num_level = new_level;
}
}
pub fn zero() -> Self {
Self::new(0, 0)
}
pub fn one() -> Self {
Self::new(1, 0)
}
fn add_elements(&self, other: &Self) -> Self {
let num_level = self.num_level().max(other.num_level());
Self::new(self.value() ^ other.value(), num_level)
}
fn mul(self, other: Self) -> Self {
match self.num_level().cmp(&other.num_level()) {
Ordering::Greater => {
let (a_hi, a_lo) = self.split();
a_hi.mul(other).join(&a_lo.mul(other))
}
Ordering::Less => {
other.mul(self)
}
Ordering::Equal => {
if self.num_level() == 0 {
return Self::new(self.value() & other.value(), 0);
}
let (a_high, a_low) = self.split();
let (b_high, b_low) = other.split();
let low_product = a_low.mul(b_low); let high_product = a_high.mul(b_high);
let x_value = if self.num_level() == 1 {
Self::new(1, 0)
} else {
Self::new(1 << (self.num_bits() / 4), self.num_level() - 1)
};
let shifted_high_product = high_product.mul(x_value);
let sum_product = (a_low + a_high).mul(b_low + b_high);
let middle_term = sum_product - low_product - high_product;
(shifted_high_product + middle_term).join(&(high_product + low_product))
}
}
}
pub fn inv(&self) -> Result<Self, BinaryFieldError> {
if self.is_zero() {
return Err(BinaryFieldError::InverseOfZero);
}
if self.num_level() <= 1 || self.num_bits() <= 4 {
let exponent = (1 << self.num_bits()) - 2;
Ok(Self::pow(self, exponent as u32))
} else {
let (a_hi, a_lo) = self.split();
let two_pow_k_minus_one = Self::new(1 << (self.num_bits() / 4), self.num_level() - 1);
let a_lo_next = a_lo + a_hi * two_pow_k_minus_one;
let delta = a_lo * a_lo_next + a_hi * a_hi;
let delta_inverse = delta.inv()?;
let out_hi = delta_inverse * a_hi;
let out_lo = delta_inverse * a_lo_next;
Ok(out_hi.join(&out_lo))
}
}
pub fn pow(&self, exp: u32) -> Self {
let mut result = Self::one();
let mut base = *self;
let mut exp_val = exp;
while exp_val > 0 {
if exp_val & 1 == 1 {
result *= base;
}
base = base * base;
exp_val >>= 1;
}
result
}
}
/// Elements are equal when their bit representations match; the tower level is not
/// compared, so e.g. `TowerFieldElement::one()` (level 0) equals `TowerFieldElement::new(1, 4)`.
impl PartialEq for TowerFieldElement {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.value(), other.value());
        lhs == rhs
    }
}

impl Eq for TowerFieldElement {}
impl Add for TowerFieldElement {
type Output = Self;
fn add(self, other: Self) -> Self {
self.add_elements(&other)
}
}
/// Adds two borrowed elements without copying them first.
///
/// The two references carry independent lifetimes, matching the borrowed `Mul` impl, so
/// operands borrowed from differently-scoped places can be added directly.
impl Add<&TowerFieldElement> for &TowerFieldElement {
    type Output = TowerFieldElement;

    fn add(self, other: &TowerFieldElement) -> TowerFieldElement {
        self.add_elements(other)
    }
}
/// In-place field addition: `a += b` is `a = a + b`.
impl AddAssign for TowerFieldElement {
    fn add_assign(&mut self, rhs: Self) {
        let sum = self.add_elements(&rhs);
        *self = sum;
    }
}
/// Field subtraction. In characteristic 2 every element is its own additive inverse, so
/// subtraction is the same XOR as addition.
impl Sub for TowerFieldElement {
    type Output = Self;

    fn sub(self, rhs: Self) -> Self::Output {
        self.add_elements(&rhs)
    }
}
/// Negation is the identity map, since `a + a == 0` for every element (XOR addition).
impl Neg for TowerFieldElement {
    type Output = Self;

    fn neg(self) -> Self::Output {
        let negated = self;
        negated
    }
}
/// Field multiplication via the `*` operator; the product takes the higher operand level.
impl Mul for TowerFieldElement {
    type Output = Self;
    fn mul(self, other: Self) -> Self {
        // Method-call syntax resolves to the private inherent `TowerFieldElement::mul`
        // (inherent methods are preferred over trait methods), not back into this trait
        // method, so this does not recurse. Renaming the inherent method would break that.
        self.mul(other)
    }
}
/// Multiplies two borrowed elements by copying them and using the by-value `*`.
impl Mul<&TowerFieldElement> for &TowerFieldElement {
    type Output = TowerFieldElement;

    fn mul(self, other: &TowerFieldElement) -> TowerFieldElement {
        let (lhs, rhs) = (*self, *other);
        lhs * rhs
    }
}
/// In-place field multiplication: `a *= b` is `a = a * b`.
impl MulAssign for TowerFieldElement {
    fn mul_assign(&mut self, rhs: Self) {
        let lhs = *self;
        *self = lhs * rhs;
    }
}
/// Multiplies all elements together; an empty iterator yields `TowerFieldElement::one()`.
impl Product for TowerFieldElement {
    fn product<I>(iter: I) -> Self
    where
        I: Iterator<Item = Self>,
    {
        let mut acc = Self::one();
        for factor in iter {
            acc *= factor;
        }
        acc
    }
}
/// Adds all elements together; an empty iterator yields `TowerFieldElement::zero()`.
impl Sum for TowerFieldElement {
    fn sum<I>(iter: I) -> Self
    where
        I: Iterator<Item = Self>,
    {
        let mut total = Self::zero();
        for term in iter {
            total += term;
        }
        total
    }
}
/// Formats the element as the decimal value of its bits.
///
/// Width, fill and alignment flags (e.g. `{:>8}`) apply to the printed value.
impl fmt::Display for TowerFieldElement {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate instead of `write!(f, "{}", ..)`, which would discard the caller's flags.
        fmt::Display::fmt(&self.value, f)
    }
}
impl From<u128> for TowerFieldElement {
fn from(val: u128) -> Self {
TowerFieldElement::new(val, 7)
}
}
impl From<u64> for TowerFieldElement {
fn from(val: u64) -> Self {
TowerFieldElement::new(val as u128, 6)
}
}
impl From<u32> for TowerFieldElement {
fn from(val: u32) -> Self {
TowerFieldElement::new(val as u128, 5)
}
}
impl From<u16> for TowerFieldElement {
fn from(val: u16) -> Self {
TowerFieldElement::new(val as u128, 4)
}
}
impl From<u8> for TowerFieldElement {
fn from(val: u8) -> Self {
TowerFieldElement::new(val as u128, 3)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // Only the std-gated property tests below use proptest.
    #[cfg(feature = "std")]
    use proptest::prelude::*;

    #[test]
    fn test_new_safe() {
        // Levels above 7 are clamped to 7.
        let elem = TowerFieldElement::new(0, 8);
        assert_eq!(elem.num_level, 7);
        // Level 1 keeps only 2 bits, so 0b100 is truncated to 0.
        let elem = TowerFieldElement::new(4, 1);
        assert_eq!(elem.value, 0);
    }

    #[test]
    fn test_addition() {
        // `a` is clamped to level 7; the sum takes the higher of the two levels.
        let a = TowerFieldElement::new(5, 9);
        let b = TowerFieldElement::new(3, 2);
        let c = a + b;
        assert_eq!(c.value, 6);
        assert_eq!(c.num_level, 7);
        let d = b + a;
        assert_eq!(d, c);
    }

    #[test]
    fn negation_and_subtraction_match_addition() {
        // In characteristic 2 every element is its own additive inverse.
        let a = TowerFieldElement::new(0b1011, 3);
        let b = TowerFieldElement::new(0b0110, 3);
        assert_eq!(-a, a);
        assert_eq!(a - b, a + b);
        assert!((a - a).is_zero());
    }

    #[test]
    fn sum_and_product_fold_over_iterators() {
        let xs = [
            TowerFieldElement::new(0b0110, 2),
            TowerFieldElement::new(0b0111, 2),
        ];
        let total: TowerFieldElement = xs.iter().copied().sum();
        assert_eq!(total.value, 0b0001);
        let product: TowerFieldElement = xs.iter().copied().product();
        assert_eq!(product, TowerFieldElement::new(0b1100, 2));
        let empty: [TowerFieldElement; 0] = [];
        assert!(empty.iter().copied().sum::<TowerFieldElement>().is_zero());
        assert!(empty.iter().copied().product::<TowerFieldElement>().is_one());
    }

    #[test]
    fn mul_in_level_0() {
        let a = TowerFieldElement::new(0, 0);
        let b = TowerFieldElement::new(1, 0);
        assert_eq!(a * a, a);
        assert_eq!(a * b, a);
        assert_eq!(b * b, b);
    }

    #[test]
    fn mul_in_level_1() {
        let a = TowerFieldElement::new(0b00, 1);
        let b = TowerFieldElement::new(0b01, 1);
        let c = TowerFieldElement::new(0b10, 1);
        let d = TowerFieldElement::new(0b11, 1);
        assert_eq!(a * a, a);
        assert_eq!(a * b, a);
        assert_eq!(b * c, c);
        assert_eq!(c * d, b);
    }

    #[test]
    fn mul_in_level_2() {
        let a = TowerFieldElement::new(0b0000, 2);
        let b = TowerFieldElement::new(0b0001, 2);
        let c = TowerFieldElement::new(0b0010, 2);
        let d = TowerFieldElement::new(0b0011, 2);
        let e = TowerFieldElement::new(0b0100, 2);
        let f = TowerFieldElement::new(0b0101, 2);
        let g = TowerFieldElement::new(0b0110, 2);
        let h = TowerFieldElement::new(0b0111, 2);
        let i = TowerFieldElement::new(0b1000, 2);
        let j = TowerFieldElement::new(0b1001, 2);
        let k = TowerFieldElement::new(0b1010, 2);
        let l = TowerFieldElement::new(0b1011, 2);
        let n = TowerFieldElement::new(0b1100, 2);
        let m = TowerFieldElement::new(0b1101, 2);
        let o = TowerFieldElement::new(0b1110, 2);
        let p = TowerFieldElement::new(0b1111, 2);
        assert_eq!(a * p, a);
        assert_eq!(a * l, a);
        assert_eq!(b * m, m);
        assert_eq!(c * e, i);
        assert_eq!(c * c, d);
        assert_eq!(g * h, n);
        assert_eq!(k * j, b);
        assert_eq!(j * f, d);
        assert_eq!(e * e, j);
        assert_eq!(n * o, k);
    }

    #[test]
    fn mul_between_different_levels() {
        let a = TowerFieldElement::new(0b10, 1);
        let b = TowerFieldElement::new(0b0100, 2);
        let c = TowerFieldElement::new(0b1000, 2);
        assert_eq!(a * b, c);
    }

    #[test]
    fn test_correct_level_mul() {
        let a = TowerFieldElement::new(0b1111, 5);
        let b = TowerFieldElement::new(0b1010, 2);
        assert_eq!((a * b).num_level, 5);
    }

    #[test]
    fn mul_is_associative() {
        let a = TowerFieldElement::new(83, 7);
        let b = TowerFieldElement::new(31, 5);
        let c = TowerFieldElement::new(3, 2);
        let ab = a * b;
        let bc = b * c;
        assert_eq!(ab * c, a * bc);
    }

    #[test]
    fn mul_is_commutative() {
        let a = TowerFieldElement::new(127, 7);
        let b = TowerFieldElement::new(6, 3);
        let ab = a * b;
        let ba = b * a;
        assert_eq!(ab, ba);
    }

    #[test]
    fn test_inverse() {
        let a0 = TowerFieldElement::new(1, 0);
        let inv_a0 = a0.inv().unwrap();
        assert_eq!(inv_a0.value, 1);
        assert_eq!(inv_a0.num_level, 0);

        let a1 = TowerFieldElement::new(2, 1);
        let inv_a1 = a1.inv().unwrap();
        assert_eq!(inv_a1.value, 3);
        assert_eq!(inv_a1.num_level, 1);

        let a2 = TowerFieldElement::new(15, 4);
        let inv_a2 = a2.inv().unwrap();
        let one = TowerFieldElement::new(1, 4);
        assert_eq!(a2 * inv_a2, one);

        let a3 = TowerFieldElement::new(30, 5);
        let inv_a3 = a3.inv().unwrap();
        let one = TowerFieldElement::new(1, 5);
        assert_eq!(a3 * inv_a3, one);

        let zero = TowerFieldElement::zero();
        assert!(matches!(zero.inv(), Err(BinaryFieldError::InverseOfZero)));
    }

    #[test]
    fn test_multiplication_overflow() {
        for level in 0..7 {
            let max_value = (1u128 << (1 << level)) - 1;
            let a = TowerFieldElement::new(max_value, level);
            let b = TowerFieldElement::new(max_value, level);
            let result = a * b;
            assert!(result.value < (1u128 << result.num_bits()));
        }
    }

    #[test]
    fn test_split_join_consistency() {
        for i in 0..20 {
            let original = TowerFieldElement::new(i, 3);
            let (hi, lo) = original.split();
            let rejoined = hi.join(&lo);
            assert_eq!(rejoined, original);
        }
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_bin_representation() {
        let a = TowerFieldElement::new(0b1010, 5);
        assert_eq!(a.to_binary_string(), "00000000000000000000000000001010");
        let b = TowerFieldElement::new(0b1010, 4);
        assert_eq!(b.to_binary_string(), "0000000000001010");
    }

    /// Strategy producing an arbitrary element at any level 0..=7 with an in-range value.
    #[cfg(feature = "std")]
    fn arb_tower_element_any() -> impl Strategy<Value = TowerFieldElement> {
        (0usize..=7)
            .prop_flat_map(|level| {
                // A level-`k` element uses 2^k bits; level 7 fills the whole u128.
                let max_val = if level == 7 {
                    u128::MAX
                } else {
                    (1u128 << (1 << level)) - 1
                };
                (Just(level), 0u128..=max_val)
            })
            .prop_map(|(level, val)| TowerFieldElement::new(val, level))
    }

    #[cfg(feature = "std")]
    proptest! {
        #[test]
        fn test_mul_commutative(a in arb_tower_element_any(), b in arb_tower_element_any()) {
            prop_assert_eq!(a * b, b * a);
        }

        #[test]
        fn test_mul_associative(
            a in arb_tower_element_any(),
            b in arb_tower_element_any(),
            c in arb_tower_element_any(),
        ) {
            prop_assert_eq!((a * b) * c, a * (b * c));
        }

        #[test]
        fn test_inverse_roundtrip(a in arb_tower_element_any()) {
            prop_assume!(!a.is_zero());
            let inv = a.inv().unwrap();
            prop_assert_eq!(a * inv, TowerFieldElement::one());
        }
    }
}