#![allow(long_running_const_eval)]
#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
use crate::poly::BinaryPoly128;
/// Raises `value` to the power 2^(2^n) in GF(2^128).
///
/// `n == 0` is a single squaring. For `n` in `1..=7` the result is read from
/// the precomputed tables: the input is split into 32 nibbles and the
/// per-nibble results are XOR-combined (squaring is additive over XOR in
/// characteristic 2, which is what makes the table decomposition valid).
/// Any other `n` returns the input unchanged.
#[inline]
fn pow_2_2_n(value: u128, n: usize, table: &[[[u128; 16]; 32]; 7]) -> u128 {
    match n {
        0 => square_gf128(value),
        1..=7 => (0..32).fold(0u128, |acc, chunk| {
            let nib = ((value >> (chunk * 4)) & 0x0F) as usize;
            acc ^ table[n - 1][chunk][nib]
        }),
        _ => value,
    }
}
#[inline]
fn square_gf128(x: u128) -> u128 {
    // Squaring in GF(2^128), implemented as a plain field multiply of x by
    // itself (no dedicated bit-spreading fast path at runtime).
    mul_gf128(x, x)
}
#[inline]
pub fn invert_gf128(value: u128) -> u128 {
    // Inversion via Fermat's little theorem: for non-zero `value`,
    // value^(-1) = value^(2^128 - 2), computed with an Itoh-Tsujii style
    // addition chain over exponents of the form 2^(2^k) - 1.
    if value == 0 {
        // Zero has no inverse; this implementation returns 0 by convention.
        return 0;
    }
    // Invariant entering loop iteration k (verified for k = 1 below):
    //   self_pow_2_pow_k1s        = value^(2^(2^(k-1)) - 1)
    //   self_pow_2_pow_k1s_to_k0s = value^((2^(2^(k-1)) - 1) * 2^(2^(k-1)))
    //   res                       = value^(2^(2^k) - 2)
    // Initially: s = value^(2^1 - 1), res = value^2 = value^(2^(2^1) - 2).
    let mut self_pow_2_pow_k1s = value;
    let mut res = pow_2_2_n(self_pow_2_pow_k1s, 0, &NIBBLE_POW_TABLE);
    let mut self_pow_2_pow_k1s_to_k0s = res;
    for k in 1..7 {
        // Extend value^(2^(2^(k-1)) - 1) to value^(2^(2^k) - 1).
        self_pow_2_pow_k1s = mul_gf128(self_pow_2_pow_k1s, self_pow_2_pow_k1s_to_k0s);
        // Raise it to the power 2^(2^k) using the nibble tables.
        self_pow_2_pow_k1s_to_k0s = pow_2_2_n(self_pow_2_pow_k1s, k, &NIBBLE_POW_TABLE);
        // res becomes value^(2^(2^(k+1)) - 2); after the final k = 6 this is
        // value^(2^128 - 2) = value^(-1).
        res = mul_gf128(res, self_pow_2_pow_k1s_to_k0s);
    }
    res
}
/// Inverts a batch of GF(2^128) elements with Montgomery's trick: one field
/// inversion plus O(n) multiplications instead of one inversion per element.
/// Zero entries are skipped and map to zero in the output.
pub fn batch_invert_gf128(values: &[u128]) -> Vec<u128> {
    if values.is_empty() {
        return Vec::new();
    }
    let mut out = vec![0u128; values.len()];
    // Positions of the invertible (non-zero) entries; zeros stay zero in `out`.
    let nz: Vec<usize> = (0..values.len()).filter(|&i| values[i] != 0).collect();
    if nz.is_empty() {
        return out;
    }
    // Forward pass: prefix[j] = product of the first j + 1 non-zero values.
    let mut prefix = Vec::with_capacity(nz.len());
    let mut acc = values[nz[0]];
    prefix.push(acc);
    for &i in nz.iter().skip(1) {
        acc = mul_gf128(acc, values[i]);
        prefix.push(acc);
    }
    // Single inversion of the total product, then a backward pass that peels
    // off one factor per step: out[i] = prefix[j-1] * inv(suffix product).
    let mut tail_inv = invert_gf128(acc);
    for j in (1..nz.len()).rev() {
        let i = nz[j];
        out[i] = mul_gf128(prefix[j - 1], tail_inv);
        tail_inv = mul_gf128(tail_inv, values[i]);
    }
    out[nz[0]] = tail_inv;
    out
}
/// In-place variant of [`batch_invert_gf128`]: each non-zero entry is
/// replaced by its inverse; zero entries are left as zero.
pub fn batch_invert_gf128_in_place(values: &mut [u128]) {
    let inverses = batch_invert_gf128(values);
    values.copy_from_slice(&inverses);
}
/// Multiplies two GF(2^128) elements: a carry-less 128x128 -> 256-bit
/// polynomial product followed by reduction back to 128 bits.
#[inline]
fn mul_gf128(a: u128, b: u128) -> u128 {
    use crate::simd::{carryless_mul_128_full, reduce_gf128};
    let wide = carryless_mul_128_full(BinaryPoly128::new(a), BinaryPoly128::new(b));
    reduce_gf128(wide).value()
}
/// Precomputed powers for `pow_2_2_n`: entry `[k - 1][pos][nib]` holds
/// `(nib << (4 * pos))^(2^(2^k))`, built once at compile time.
static NIBBLE_POW_TABLE: [[[u128; 16]; 32]; 7] = generate_nibble_table();

/// Builds `NIBBLE_POW_TABLE`. Uses `while` loops because `for` is not
/// permitted in a `const fn`.
const fn generate_nibble_table() -> [[[u128; 16]; 32]; 7] {
    let mut table = [[[0u128; 16]; 32]; 7];
    let mut level = 0;
    while level < 7 {
        let mut chunk = 0;
        while chunk < 32 {
            let mut nib = 0;
            while nib < 16 {
                // Place the nibble at its chunk position, then square it
                // 2^(level + 1) times via const_pow_2_k.
                table[level][chunk][nib] =
                    const_pow_2_k((nib as u128) << (chunk * 4), level + 1);
                nib += 1;
            }
            chunk += 1;
        }
        level += 1;
    }
    table
}
/// Compile-time computation of `x^(2^(2^k))`, i.e. `x` squared `2^k` times.
const fn const_pow_2_k(x: u128, k: usize) -> u128 {
    let squarings = 1usize << k;
    let mut acc = x;
    let mut done = 0;
    while done < squarings {
        acc = const_square_gf128(acc);
        done += 1;
    }
    acc
}
/// Compile-time squaring in GF(2^128). Squaring a binary polynomial just
/// spreads each bit to an even position (cross terms cancel mod 2), giving a
/// 256-bit value that is then reduced back to 128 bits.
const fn const_square_gf128(x: u128) -> u128 {
    let high_spread = const_spread_bits((x >> 64) as u64);
    let low_spread = const_spread_bits(x as u64);
    const_reduce_256_to_128(high_spread, low_spread)
}
/// Moves bit `i` of `x` to bit `2 * i` of the result, interleaving zeros
/// between all source bits. `while` is used because `for` is not permitted
/// in a `const fn`.
const fn const_spread_bits(x: u64) -> u128 {
    let mut spread = 0u128;
    let mut i = 0;
    while i < 64 {
        spread |= (((x >> i) & 1) as u128) << (2 * i);
        i += 1;
    }
    spread
}
/// Reduces a 256-bit carry-less product (`hi`:`lo`) modulo the GF(2^128)
/// polynomial x^128 + x^7 + x^2 + x + 1: the folded high half contributes
/// shifted copies at offsets 0, 1, 2 and 7 (the low-order terms 0x87).
const fn const_reduce_256_to_128(hi: u128, lo: u128) -> u128 {
    let folded = hi ^ (hi >> 127) ^ (hi >> 126) ^ (hi >> 121);
    let contribution = folded ^ (folded << 1) ^ (folded << 2) ^ (folded << 7);
    lo ^ contribution
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{BinaryElem128, BinaryFieldElement};

    /// Representative field elements (small, mid-range, and extreme values)
    /// shared by the inversion tests. Previously this fixture was duplicated
    /// in two tests.
    fn sample_values() -> [u128; 8] {
        [
            1,
            2,
            0x12345678,
            0xdeadbeef,
            0xffffffffffffffff,
            0x123456789abcdef0123456789abcdef0,
            u128::MAX,
            u128::MAX - 1,
        ]
    }

    /// x * x^(-1) must equal the field identity for every non-zero sample.
    #[test]
    fn test_invert_basic() {
        for &x in &sample_values() {
            let x_inv = invert_gf128(x);
            let product = mul_gf128(x, x_inv);
            assert_eq!(product, 1, "x * x^(-1) should be 1 for x = 0x{:032x}", x);
        }
    }

    /// Zero has no inverse; the documented convention is invert(0) == 0.
    #[test]
    fn test_invert_zero() {
        assert_eq!(invert_gf128(0), 0);
    }

    /// The fast table-driven inverse must agree with the generic
    /// field-element implementation.
    #[test]
    fn test_invert_matches_slow() {
        for &x in &sample_values() {
            let fast_inv = invert_gf128(x);
            let elem = BinaryElem128::from(x);
            let slow_inv = elem.inv();
            let slow_inv_val = slow_inv.poly().value();
            assert_eq!(
                fast_inv, slow_inv_val,
                "fast and slow inverse should match for x = 0x{:032x}",
                x
            );
        }
    }

    /// square_gf128 must be equivalent to multiplying an element by itself.
    #[test]
    fn test_square_basic() {
        let x = 0x123456789abcdef0u128;
        let x_sq = square_gf128(x);
        let x_sq_mul = mul_gf128(x, x);
        assert_eq!(x_sq, x_sq_mul, "square should match multiplication");
    }

    /// Batch inversion must match element-by-element inversion.
    #[test]
    fn test_batch_invert() {
        let values: Vec<u128> = sample_values().to_vec();
        let batch_inverted = batch_invert_gf128(&values);
        for (i, &v) in values.iter().enumerate() {
            let individual_inv = invert_gf128(v);
            assert_eq!(
                batch_inverted[i], individual_inv,
                "batch inversion should match individual for index {} value 0x{:032x}",
                i, v
            );
        }
    }

    /// Zeros must be skipped (mapped to zero) without corrupting neighbors.
    #[test]
    fn test_batch_invert_with_zeros() {
        let values: Vec<u128> = vec![1, 0, 2, 0, 3, 0];
        let batch_inverted = batch_invert_gf128(&values);
        assert_eq!(batch_inverted[1], 0);
        assert_eq!(batch_inverted[3], 0);
        assert_eq!(batch_inverted[5], 0);
        assert_eq!(batch_inverted[0], invert_gf128(1));
        assert_eq!(batch_inverted[2], invert_gf128(2));
        assert_eq!(batch_inverted[4], invert_gf128(3));
    }

    /// Empty input yields an empty output.
    #[test]
    fn test_batch_invert_empty() {
        let values: Vec<u128> = vec![];
        let batch_inverted = batch_invert_gf128(&values);
        assert!(batch_inverted.is_empty());
    }

    /// A one-element batch degenerates to a single inversion.
    #[test]
    fn test_batch_invert_single() {
        let values = vec![0x12345678u128];
        let batch_inverted = batch_invert_gf128(&values);
        assert_eq!(batch_inverted[0], invert_gf128(0x12345678));
    }

    /// New coverage: the in-place wrapper must agree with the allocating
    /// batch inversion, including zero entries.
    #[test]
    fn test_batch_invert_in_place_matches_batch() {
        let mut values: Vec<u128> = vec![0, 1, 2, 0xdeadbeef, 0, u128::MAX];
        let expected = batch_invert_gf128(&values);
        batch_invert_gf128_in_place(&mut values);
        assert_eq!(values, expected);
    }
}