1#![allow(long_running_const_eval)]
25
26#[cfg(not(feature = "std"))]
27extern crate alloc;
28
29#[cfg(not(feature = "std"))]
30use alloc::{vec, vec::Vec};
31
32use crate::poly::BinaryPoly128;
33
34#[inline]
39fn pow_2_2_n(value: u128, n: usize, table: &[[[u128; 16]; 32]; 7]) -> u128 {
40 match n {
41 0 => square_gf128(value),
42 1..=7 => {
43 let mut result = 0u128;
44 for nibble_index in 0..32 {
45 let nibble_value = ((value >> (nibble_index * 4)) & 0x0F) as usize;
46 result ^= table[n - 1][nibble_index][nibble_value];
47 }
48 result
49 }
50 _ => value,
51 }
52}
53
54#[inline]
56fn square_gf128(x: u128) -> u128 {
57 let lo = x as u64;
63 let hi = (x >> 64) as u64;
64
65 let lo_spread = spread_bits(lo);
67 let hi_spread = spread_bits(hi);
68
69 let result_lo = lo_spread;
78 let result_hi = hi_spread;
79
80 reduce_256_to_128(result_hi, result_lo)
81}
82
#[inline]
fn spread_bits(x: u64) -> u128 {
    // Spreads the 64 bits of `x` into the even bit positions of a u128
    // (bit i -> bit 2*i); this is squaring of a polynomial over GF(2).
    #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
    {
        use core::arch::x86_64::_pdep_u64;
        // Mask selecting the even bit positions of a 64-bit word.
        const EVEN_BITS_MASK: u64 = 0x5555_5555_5555_5555;

        // Handle each 32-bit half separately: PDEP deposits 32 source bits
        // into the 64 even positions of one output half.
        let lo = (x & 0xFFFF_FFFF) as u64;
        let hi = (x >> 32) as u64;

        // SAFETY: this branch only compiles when the `bmi2` target feature
        // is enabled, so the PDEP instruction is guaranteed to exist.
        let lo_spread = unsafe { _pdep_u64(lo, EVEN_BITS_MASK) };
        let hi_spread = unsafe { _pdep_u64(hi, EVEN_BITS_MASK) };

        (lo_spread as u128) | ((hi_spread as u128) << 64)
    }

    #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))]
    {
        // Portable shift-and-mask fallback for all other targets.
        spread_bits_parallel(x)
    }
}
113
#[inline]
fn spread_bits_parallel(x: u64) -> u128 {
    // Spreads the 64 bits of `x` into the even bit positions of a u128
    // (bit i of `x` moves to bit 2*i) using the classic shift-and-mask
    // interleave: each step doubles the gap between groups of bits.
    let mut v = x as u128;

    // BUG FIX: the 32-bit step was missing. Without it, bits 32..64 of the
    // input never migrate into the upper half of the u128 (e.g. bit 32
    // stayed at position 32 instead of reaching position 64), corrupting
    // every squaring on targets without BMI2.
    v = (v | (v << 32)) & 0x0000_0000_FFFF_FFFF_0000_0000_FFFF_FFFF;
    v = (v | (v << 16)) & 0x0000_FFFF_0000_FFFF_0000_FFFF_0000_FFFF;
    v = (v | (v << 8)) & 0x00FF_00FF_00FF_00FF_00FF_00FF_00FF_00FF;
    v = (v | (v << 4)) & 0x0F0F_0F0F_0F0F_0F0F_0F0F_0F0F_0F0F_0F0F;
    v = (v | (v << 2)) & 0x3333_3333_3333_3333_3333_3333_3333_3333;
    v = (v | (v << 1)) & 0x5555_5555_5555_5555_5555_5555_5555_5555;

    v
}
134
#[inline]
fn reduce_256_to_128(hi: u128, lo: u128) -> u128 {
    // Reduces a 256-bit value (hi * x^128 + lo) modulo the field polynomial
    // x^128 + x^7 + x^2 + x + 1. First fold hi's top bits back in: they are
    // the overflow produced when hi is multiplied by x^7, x^2 and x.
    let folded = hi ^ (hi >> 127) ^ (hi >> 126) ^ (hi >> 121);

    // Then hi * x^128 ≡ folded * (x^7 + x^2 + x + 1) added into lo.
    let mut result = lo ^ folded;
    result ^= folded << 1;
    result ^= folded << 2;
    result ^= folded << 7;
    result
}
148
#[inline]
pub fn invert_gf128(value: u128) -> u128 {
    // Inversion in GF(2^128) via Fermat's little theorem:
    // value^(-1) = value^(2^128 - 2), computed with an Itoh–Tsujii style
    // addition chain. pow_2_2_n(v, n, ...) raises v to 2^(2^n).
    // Zero has no inverse; it is mapped to zero by convention.
    if value == 0 {
        return 0;
    }

    // value^(2^(2^0) - 1) = value^1
    let mut self_pow_2_pow_k1s = value;

    // value^2 = value^(2^(2^1) - 2)
    let mut res = pow_2_2_n(self_pow_2_pow_k1s, 0, &NIBBLE_POW_TABLE);

    let mut self_pow_2_pow_k1s_to_k0s = res;

    // Loop invariant (holding at the end of iteration k):
    //   self_pow_2_pow_k1s = value^(2^(2^k) - 1)
    //   res                = value^(2^(2^(k+1)) - 2)
    // so after k = 6: res = value^(2^128 - 2) = value^(-1).
    for k in 1..7 {
        self_pow_2_pow_k1s = mul_gf128(self_pow_2_pow_k1s, self_pow_2_pow_k1s_to_k0s);

        self_pow_2_pow_k1s_to_k0s = pow_2_2_n(self_pow_2_pow_k1s, k, &NIBBLE_POW_TABLE);

        res = mul_gf128(res, self_pow_2_pow_k1s_to_k0s);
    }

    res
}
185
186pub fn batch_invert_gf128(values: &[u128]) -> Vec<u128> {
202 if values.is_empty() {
203 return Vec::new();
204 }
205
206 let n = values.len();
207 let mut result = vec![0u128; n];
208
209 let non_zero_indices: Vec<usize> = values
211 .iter()
212 .enumerate()
213 .filter(|(_, &v)| v != 0)
214 .map(|(i, _)| i)
215 .collect();
216
217 if non_zero_indices.is_empty() {
218 return result; }
220
221 let mut prefix_products = Vec::with_capacity(non_zero_indices.len());
227 let mut running = values[non_zero_indices[0]];
228 prefix_products.push(running);
229
230 for &idx in &non_zero_indices[1..] {
231 running = mul_gf128(running, values[idx]);
232 prefix_products.push(running);
233 }
234
235 let mut inv_suffix = invert_gf128(running);
237
238 for i in (1..non_zero_indices.len()).rev() {
240 let idx = non_zero_indices[i];
241 result[idx] = mul_gf128(prefix_products[i - 1], inv_suffix);
243 inv_suffix = mul_gf128(inv_suffix, values[idx]);
245 }
246
247 result[non_zero_indices[0]] = inv_suffix;
249
250 result
251}
252
253pub fn batch_invert_gf128_in_place(values: &mut [u128]) {
257 let inverted = batch_invert_gf128(values);
258 values.copy_from_slice(&inverted);
259}
260
261#[inline]
263fn mul_gf128(a: u128, b: u128) -> u128 {
264 use crate::simd::{carryless_mul_128_full, reduce_gf128};
265 let a_poly = BinaryPoly128::new(a);
266 let b_poly = BinaryPoly128::new(b);
267 let product = carryless_mul_128_full(a_poly, b_poly);
268 reduce_gf128(product).value()
269}
270
/// Precomputed tables for `pow_2_2_n`: `NIBBLE_POW_TABLE[n - 1][pos][nib]`
/// is the image of nibble value `nib` placed at nibble position `pos` under
/// the linear map x -> x^(2^(2^n)), for n in 1..=7. Built entirely at
/// compile time (hence the file-level `allow(long_running_const_eval)`).
static NIBBLE_POW_TABLE: [[[u128; 16]; 32]; 7] = generate_nibble_table();
275
276const fn generate_nibble_table() -> [[[u128; 16]; 32]; 7] {
278 let mut table = [[[0u128; 16]; 32]; 7];
279
280 let mut n = 0;
282 while n < 7 {
283 let mut pos = 0;
285 while pos < 32 {
286 let mut val = 0;
288 while val < 16 {
289 let input = (val as u128) << (pos * 4);
291 let result = const_pow_2_k(input, n + 1);
292 table[n][pos][val] = result;
293 val += 1;
294 }
295 pos += 1;
296 }
297 n += 1;
298 }
299
300 table
301}
302
303const fn const_pow_2_k(x: u128, k: usize) -> u128 {
305 let iterations = 1usize << k;
307 let mut result = x;
308 let mut i = 0;
309 while i < iterations {
310 result = const_square_gf128(result);
311 i += 1;
312 }
313 result
314}
315
316const fn const_square_gf128(x: u128) -> u128 {
318 let lo = x as u64;
320 let hi = (x >> 64) as u64;
321
322 let lo_spread = const_spread_bits(lo);
324 let hi_spread = const_spread_bits(hi);
325
326 const_reduce_256_to_128(hi_spread, lo_spread)
328}
329
/// Compile-time bit spreading: bit i of `x` moves to bit 2*i of the result.
/// Simple bit-by-bit loop — slow, but it only runs during const evaluation.
const fn const_spread_bits(x: u64) -> u128 {
    let mut out = 0u128;
    let mut bit = 0;
    while bit < 64 {
        // Branchless: extract bit `bit` and deposit it at position 2*bit.
        out |= (((x >> bit) & 1) as u128) << (2 * bit);
        bit += 1;
    }
    out
}
344
/// Compile-time reduction of (hi * x^128 + lo) modulo x^128 + x^7 + x^2 + x + 1.
const fn const_reduce_256_to_128(hi: u128, lo: u128) -> u128 {
    // Fold hi's top bits back in: the overflow of hi * (x^7 + x^2 + x).
    let folded = hi ^ (hi >> 127) ^ (hi >> 126) ^ (hi >> 121);

    // hi * x^128 ≡ folded * (x^7 + x^2 + x + 1), XORed into lo.
    let mut out = lo ^ folded;
    out ^= folded << 1;
    out ^= folded << 2;
    out ^= folded << 7;
    out
}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{BinaryElem128, BinaryFieldElement};

    /// x * x^(-1) must be the multiplicative identity for every non-zero x.
    #[test]
    fn test_invert_basic() {
        let test_values: [u128; 8] = [
            1,
            2,
            0x12345678,
            0xdeadbeef,
            0xffffffffffffffff,
            0x123456789abcdef0123456789abcdef0,
            u128::MAX,
            u128::MAX - 1,
        ];

        for &x in &test_values {
            let x_inv = invert_gf128(x);
            let product = mul_gf128(x, x_inv);
            assert_eq!(product, 1, "x * x^(-1) should be 1 for x = 0x{:032x}", x);
        }
    }

    /// Zero is not invertible; the convention is to return zero.
    #[test]
    fn test_invert_zero() {
        assert_eq!(invert_gf128(0), 0);
    }

    /// The fast addition-chain inverse must agree with the generic
    /// field-element implementation.
    #[test]
    fn test_invert_matches_slow() {
        let test_values: [u128; 8] = [
            1,
            2,
            0x12345678,
            0xdeadbeef,
            0xffffffffffffffff,
            0x123456789abcdef0123456789abcdef0,
            u128::MAX,
            u128::MAX - 1,
        ];

        for &x in &test_values {
            let fast_inv = invert_gf128(x);

            let elem = BinaryElem128::from(x);
            let slow_inv = elem.inv();
            let slow_inv_val = slow_inv.poly().value();

            assert_eq!(
                fast_inv, slow_inv_val,
                "fast and slow inverse should match for x = 0x{:032x}",
                x
            );
        }
    }

    /// Dedicated squaring must agree with self-multiplication.
    #[test]
    fn test_square_basic() {
        let x = 0x123456789abcdef0u128;
        let x_sq = square_gf128(x);

        let x_sq_mul = mul_gf128(x, x);
        assert_eq!(x_sq, x_sq_mul, "square should match multiplication");
    }

    /// Batch inversion must agree element-wise with individual inversion.
    #[test]
    fn test_batch_invert() {
        let values: Vec<u128> = vec![
            1,
            2,
            0x12345678,
            0xdeadbeef,
            0xffffffffffffffff,
            0x123456789abcdef0123456789abcdef0,
            u128::MAX,
            u128::MAX - 1,
        ];

        let batch_inverted = batch_invert_gf128(&values);

        for (i, &v) in values.iter().enumerate() {
            let individual_inv = invert_gf128(v);
            assert_eq!(
                batch_inverted[i], individual_inv,
                "batch inversion should match individual for index {} value 0x{:032x}",
                i, v
            );
        }
    }

    /// Zeros must pass through batch inversion untouched and must not
    /// disturb the inverses of their non-zero neighbours.
    #[test]
    fn test_batch_invert_with_zeros() {
        let values: Vec<u128> = vec![1, 0, 2, 0, 3, 0];
        let batch_inverted = batch_invert_gf128(&values);

        assert_eq!(batch_inverted[1], 0);
        assert_eq!(batch_inverted[3], 0);
        assert_eq!(batch_inverted[5], 0);

        assert_eq!(batch_inverted[0], invert_gf128(1));
        assert_eq!(batch_inverted[2], invert_gf128(2));
        assert_eq!(batch_inverted[4], invert_gf128(3));
    }

    /// Empty input yields an empty output.
    #[test]
    fn test_batch_invert_empty() {
        let values: Vec<u128> = vec![];
        let batch_inverted = batch_invert_gf128(&values);
        assert!(batch_inverted.is_empty());
    }

    /// A one-element batch degenerates to a single inversion.
    #[test]
    fn test_batch_invert_single() {
        let values = vec![0x12345678u128];
        let batch_inverted = batch_invert_gf128(&values);
        assert_eq!(batch_inverted[0], invert_gf128(0x12345678));
    }

    /// Spot-checks of the bit-spreading primitive (bit i -> bit 2*i).
    /// The original cases only touched the low byte of the input; the
    /// added high-half cases ((1<<16), (1<<32), (1<<63), u64::MAX)
    /// exercise the wider shift steps of the portable fallback, which
    /// low-byte inputs can never reach.
    #[test]
    fn test_spread_bits_correctness() {
        let test_cases: [(u64, u128); 8] = [
            (0b1, 0b1),
            (0b10, 0b100),
            (0b101, 0b10001),
            (0b11111111, 0x5555),
            (1 << 16, 1u128 << 32),
            (1 << 32, 1u128 << 64),
            (1 << 63, 1u128 << 126),
            (u64::MAX, 0x5555_5555_5555_5555_5555_5555_5555_5555),
        ];

        for (input, expected) in test_cases {
            let result = spread_bits(input);
            assert_eq!(
                result, expected,
                "spread_bits(0b{:b}) should be 0b{:b}, got 0b{:b}",
                input, expected, result
            );
        }
    }
}