Skip to main content

ligerito_binary_fields/
fast_inverse.rs

1//! Fast inversion for GF(2^128) using Itoh-Tsujii with nibble tables
2//!
3//! This implements efficient field inversion using precomputed tables
//! of Frobenius powers. Plain square-and-multiply needs 127 squarings and 126
//! multiplications; this needs 1 squaring, 12 multiplications and 6 table-driven
//! Frobenius evaluations.
6//!
7//! ## Algorithm
8//!
9//! Uses the identity: x^(-1) = x^(2^128 - 2)
10//!
11//! The exponent 2^128 - 2 is computed using addition chains:
12//! - 2^128 - 2 = 2 * (2^127 - 1)
13//! - 2^127 - 1 has binary representation of 127 ones
14//!
15//! By using the linearity of frobenius in characteristic 2, we can
16//! decompose computations by nibbles and use table lookups.
17//!
18//! ## References
19//!
20//! - Itoh-Tsujii: "A fast algorithm for computing multiplicative inverses in GF(2^m)"
21//! - Binius implementation: <https://github.com/IrreducibleOSS/binius>
22//! - GF2t Java library: <https://github.com/reyzin/GF2t>
23
24#![allow(long_running_const_eval)]
25
26#[cfg(not(feature = "std"))]
27extern crate alloc;
28
29#[cfg(not(feature = "std"))]
30use alloc::{vec, vec::Vec};
31
32use crate::poly::BinaryPoly128;
33
34/// Compute x^(2^(2^n)) using nibble table lookup
35///
36/// Uses the linearity of frobenius: (a + b)^(2^k) = a^(2^k) + b^(2^k)
37/// So we can decompose x into 32 nibbles and XOR the precomputed results.
38#[inline]
39fn pow_2_2_n(value: u128, n: usize, table: &[[[u128; 16]; 32]; 7]) -> u128 {
40    match n {
41        0 => square_gf128(value),
42        1..=7 => {
43            let mut result = 0u128;
44            for nibble_index in 0..32 {
45                let nibble_value = ((value >> (nibble_index * 4)) & 0x0F) as usize;
46                result ^= table[n - 1][nibble_index][nibble_value];
47            }
48            result
49        }
50        _ => value,
51    }
52}
53
54/// Square in GF(2^128) with reduction modulo x^128 + x^7 + x^2 + x + 1
55#[inline]
56fn square_gf128(x: u128) -> u128 {
57    // Squaring in binary field: spread bits and reduce
58    // x^2 doubles the bit positions, creating a 256-bit result
59    // then reduce mod the irreducible polynomial
60
61    // Split into two 64-bit halves
62    let lo = x as u64;
63    let hi = (x >> 64) as u64;
64
65    // Spread bits: each bit at position i moves to position 2i
66    let lo_spread = spread_bits(lo);
67    let hi_spread = spread_bits(hi);
68
69    // Combine: hi_spread goes into bits [128..256), lo_spread into [0..128)
70    // But hi_spread bits 0..64 go into result bits 128..192
71    // and hi_spread bits 64..128 go into result bits 192..256
72
73    // The 256-bit result is [hi_spread : lo_spread]
74    // We need to reduce the high 128 bits
75
76    // Actually simpler: just use the polynomial multiplication reduction
77    let result_lo = lo_spread;
78    let result_hi = hi_spread;
79
80    reduce_256_to_128(result_hi, result_lo)
81}
82
83/// Spread 64 bits to 128 bits (bit i -> bit 2i)
84#[inline]
85fn spread_bits(x: u64) -> u128 {
86    // Use BMI2 PDEP instruction when available for ~10x speedup
87    #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
88    {
89        // PDEP deposits bits from x into positions specified by mask
90        // Mask 0x5555...5555 has 1s at even positions (0,2,4,6,...)
91        use core::arch::x86_64::_pdep_u64;
92        const EVEN_BITS_MASK: u64 = 0x5555_5555_5555_5555;
93
94        // Split into two 32-bit halves for 64-bit PDEP
95        let lo = (x & 0xFFFF_FFFF) as u64;
96        let hi = (x >> 32) as u64;
97
98        // Deposit low 32 bits into even positions 0-62
99        let lo_spread = unsafe { _pdep_u64(lo, EVEN_BITS_MASK) };
100        // Deposit high 32 bits into even positions 64-126
101        let hi_spread = unsafe { _pdep_u64(hi, EVEN_BITS_MASK) };
102
103        (lo_spread as u128) | ((hi_spread as u128) << 64)
104    }
105
106    #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))]
107    {
108        // Fallback: parallel bit deposit using magic multiplies
109        // This is ~4x faster than the naive loop
110        spread_bits_parallel(x)
111    }
112}
113
/// Portable bit spread: bit `i` of `x` moves to bit `2i` of the result; odd bits are zero.
///
/// This is the classic "binary magic numbers" interleave. Each step splits every group in two
/// and moves the upper half up by half the group width. A 64-bit input needs six steps
/// (shifts 32, 16, 8, 4, 2, 1).
///
/// The first (`<< 32`) step is essential: without it, input bits 32..64 collide with bits 16..32
/// and land at the wrong positions.
#[inline]
fn spread_bits_parallel(x: u64) -> u128 {
    let mut v = x as u128;

    // 32-bit halves: input bits 32..64 move to 64..96.
    v = (v | (v << 32)) & 0x0000_0000_FFFF_FFFF_0000_0000_FFFF_FFFF;
    // 16-bit quarters.
    v = (v | (v << 16)) & 0x0000_FFFF_0000_FFFF_0000_FFFF_0000_FFFF;
    // Bytes.
    v = (v | (v << 8)) & 0x00FF_00FF_00FF_00FF_00FF_00FF_00FF_00FF;
    // Nibbles.
    v = (v | (v << 4)) & 0x0F0F_0F0F_0F0F_0F0F_0F0F_0F0F_0F0F_0F0F;
    // Bit pairs.
    v = (v | (v << 2)) & 0x3333_3333_3333_3333_3333_3333_3333_3333;
    // Single bits: every input bit now sits at an even position.
    v = (v | (v << 1)) & 0x5555_5555_5555_5555_5555_5555_5555_5555;

    v
}
134
/// Reduce the 256-bit polynomial `hi * x^128 + lo` modulo `x^128 + x^7 + x^2 + x + 1`.
///
/// Because `x^128 = x^7 + x^2 + x + 1` in the field, the high half is folded down by
/// multiplying it by that tail. The `x`, `x^2` and `x^7` terms push up to 7 bits past bit 127.
/// Those spilled bits are folded once more; their product with the tail fits in 14 bits, so no
/// further reduction is needed. Same arithmetic as `const_reduce_256_to_128`.
#[inline]
fn reduce_256_to_128(hi: u128, lo: u128) -> u128 {
    // Bits of `hi` that the x^1, x^2 and x^7 shifts would push out of the top.
    let spill = (hi >> 127) ^ (hi >> 126) ^ (hi >> 121);
    let folded = hi ^ spill;

    // lo + folded * (1 + x + x^2 + x^7)
    [0u32, 1, 2, 7]
        .iter()
        .fold(lo, |acc, &shift| acc ^ (folded << shift))
}
148
149/// Invert a field element in GF(2^128)
150///
151/// Computes x^(-1) = x^(2^128 - 2) using Itoh-Tsujii algorithm.
152/// Returns 0 if x is 0 (not mathematically correct but safe).
153#[inline]
154pub fn invert_gf128(value: u128) -> u128 {
155    if value == 0 {
156        return 0;
157    }
158
159    // Computes value^(2^128-2)
160    // value * value^(2^128 - 2) = value^(2^128-1) = 1
161
162    // self_pow_2_pow_k1s contains value raised to power with 2^k ones in binary
163    let mut self_pow_2_pow_k1s = value;
164
165    // Square to get exponent = 2 (binary: 10)
166    let mut res = pow_2_2_n(self_pow_2_pow_k1s, 0, &NIBBLE_POW_TABLE);
167
168    // self_pow_2_pow_k1s_to_k0s = value^(2^k ones followed by 2^k zeros)
169    let mut self_pow_2_pow_k1s_to_k0s = res;
170
171    // Build up the exponent 2^128 - 2 = 111...110 (127 ones followed by a zero)
172    for k in 1..7 {
173        // Fill in zeros in exponent with ones
174        self_pow_2_pow_k1s = mul_gf128(self_pow_2_pow_k1s, self_pow_2_pow_k1s_to_k0s);
175
176        // Append 2^k zeros to exponent
177        self_pow_2_pow_k1s_to_k0s = pow_2_2_n(self_pow_2_pow_k1s, k, &NIBBLE_POW_TABLE);
178
179        // Prepend 2^k ones to result
180        res = mul_gf128(res, self_pow_2_pow_k1s_to_k0s);
181    }
182
183    res
184}
185
186/// Batch invert multiple field elements using Montgomery's trick
187///
188/// Given N elements, computes N inversions using only 3(N-1) multiplications
189/// plus a single inversion, instead of N separate inversions.
190///
191/// This is ~3x faster for N > 3, and ~9x faster for large N.
192///
193/// # Example
194/// ```ignore
195/// let inputs = [a, b, c, d];
196/// let outputs = batch_invert_gf128(&inputs);
197/// assert_eq!(outputs[0], a.inv());
198/// assert_eq!(outputs[1], b.inv());
199/// // etc
200/// ```
201pub fn batch_invert_gf128(values: &[u128]) -> Vec<u128> {
202    if values.is_empty() {
203        return Vec::new();
204    }
205
206    let n = values.len();
207    let mut result = vec![0u128; n];
208
209    // Handle zeros by tracking their positions
210    let non_zero_indices: Vec<usize> = values
211        .iter()
212        .enumerate()
213        .filter(|(_, &v)| v != 0)
214        .map(|(i, _)| i)
215        .collect();
216
217    if non_zero_indices.is_empty() {
218        return result; // All zeros
219    }
220
221    // Montgomery's trick:
222    // 1. Compute prefix products: p[i] = a[0] * a[1] * ... * a[i]
223    // 2. Invert the final product: inv_all = p[n-1]^(-1)
224    // 3. Recover individual inverses using suffix products
225
226    let mut prefix_products = Vec::with_capacity(non_zero_indices.len());
227    let mut running = values[non_zero_indices[0]];
228    prefix_products.push(running);
229
230    for &idx in &non_zero_indices[1..] {
231        running = mul_gf128(running, values[idx]);
232        prefix_products.push(running);
233    }
234
235    // Single inversion of the cumulative product
236    let mut inv_suffix = invert_gf128(running);
237
238    // Work backwards to recover individual inverses
239    for i in (1..non_zero_indices.len()).rev() {
240        let idx = non_zero_indices[i];
241        // inv(a[i]) = prefix[i-1] * inv_suffix
242        result[idx] = mul_gf128(prefix_products[i - 1], inv_suffix);
243        // Update inv_suffix = inv_suffix * a[i] = inv(prefix[i-1])
244        inv_suffix = mul_gf128(inv_suffix, values[idx]);
245    }
246
247    // First element's inverse is just inv_suffix at the end
248    result[non_zero_indices[0]] = inv_suffix;
249
250    result
251}
252
253/// Batch invert in-place (more memory efficient)
254///
255/// Modifies the input slice in-place, replacing each element with its inverse.
256pub fn batch_invert_gf128_in_place(values: &mut [u128]) {
257    let inverted = batch_invert_gf128(values);
258    values.copy_from_slice(&inverted);
259}
260
261/// Multiply two elements in GF(2^128)
262#[inline]
263fn mul_gf128(a: u128, b: u128) -> u128 {
264    use crate::simd::{carryless_mul_128_full, reduce_gf128};
265    let a_poly = BinaryPoly128::new(a);
266    let b_poly = BinaryPoly128::new(b);
267    let product = carryless_mul_128_full(a_poly, b_poly);
268    reduce_gf128(product).value()
269}
270
271/// Precomputed table: table[n][nibble_pos][nibble_val] = (nibble_val << 4*nibble_pos)^(2^(2^(n+1)))
272///
273/// Generated for GF(2^128) with irreducible x^128 + x^7 + x^2 + x + 1
274static NIBBLE_POW_TABLE: [[[u128; 16]; 32]; 7] = generate_nibble_table();
275
276/// Generate the nibble power table at compile time
277const fn generate_nibble_table() -> [[[u128; 16]; 32]; 7] {
278    let mut table = [[[0u128; 16]; 32]; 7];
279
280    // For each power level n (computing x^(2^(2^(n+1))))
281    let mut n = 0;
282    while n < 7 {
283        // For each nibble position (0..32)
284        let mut pos = 0;
285        while pos < 32 {
286            // For each nibble value (0..16)
287            let mut val = 0;
288            while val < 16 {
289                // Compute (val << 4*pos)^(2^(2^(n+1)))
290                let input = (val as u128) << (pos * 4);
291                let result = const_pow_2_k(input, n + 1);
292                table[n][pos][val] = result;
293                val += 1;
294            }
295            pos += 1;
296        }
297        n += 1;
298    }
299
300    table
301}
302
303/// Compute x^(2^(2^k)) at compile time
304const fn const_pow_2_k(x: u128, k: usize) -> u128 {
305    // 2^k squarings
306    let iterations = 1usize << k;
307    let mut result = x;
308    let mut i = 0;
309    while i < iterations {
310        result = const_square_gf128(result);
311        i += 1;
312    }
313    result
314}
315
316/// Square in GF(2^128) at compile time
317const fn const_square_gf128(x: u128) -> u128 {
318    // Split into 64-bit halves
319    let lo = x as u64;
320    let hi = (x >> 64) as u64;
321
322    // Spread bits
323    let lo_spread = const_spread_bits(lo);
324    let hi_spread = const_spread_bits(hi);
325
326    // Reduce
327    const_reduce_256_to_128(hi_spread, lo_spread)
328}
329
/// Const-evaluable bit spread: moves bit `i` of `x` to bit `2i` of the result.
///
/// Works one bit at a time so it can run inside `const fn` table generation.
const fn const_spread_bits(x: u64) -> u128 {
    let mut spread = 0u128;
    let mut bit = 0u32;
    while bit < 64 {
        spread |= (((x >> bit) & 1) as u128) << (2 * bit);
        bit += 1;
    }
    spread
}
344
/// Const-evaluable reduction of `hi * x^128 + lo` modulo `x^128 + x^7 + x^2 + x + 1`.
///
/// Folds the high half down using `x^128 = x^7 + x^2 + x + 1`. Before that, it adds back in the
/// at most 7 bits that the `x`, `x^2` and `x^7` shifts would carry past bit 127.
/// Same arithmetic as `reduce_256_to_128`.
const fn const_reduce_256_to_128(hi: u128, lo: u128) -> u128 {
    // Bits spilled off the top by the x^1, x^2 and x^7 shifts.
    let spill = (hi >> 127) ^ (hi >> 126) ^ (hi >> 121);
    let folded = hi ^ spill;

    // lo + folded * (1 + x + x^2 + x^7)
    let mut out = lo ^ folded;
    out ^= folded << 1;
    out ^= folded << 2;
    out ^= folded << 7;
    out
}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{BinaryElem128, BinaryFieldElement};

    /// Shared sample inputs: small values, values confined to the low 64 bits (including bits
    /// 32..64), and values with bits set in the high 64 bits.
    const SAMPLES: [u128; 8] = [
        1,
        2,
        0x12345678,
        0xdeadbeef,
        0xffffffffffffffff,
        0x123456789abcdef0123456789abcdef0,
        u128::MAX,
        u128::MAX - 1,
    ];

    /// Bit-by-bit reference for spreading: bit `i` -> bit `2i`.
    fn spread_bits_reference(x: u64) -> u128 {
        (0..64)
            .filter(|&i| (x >> i) & 1 == 1)
            .fold(0u128, |acc, i| acc | (1u128 << (2 * i)))
    }

    #[test]
    fn test_invert_basic() {
        // x * x^(-1) must be 1 for every non-zero x.
        for &x in &SAMPLES {
            let x_inv = invert_gf128(x);
            let product = mul_gf128(x, x_inv);
            assert_eq!(product, 1, "x * x^(-1) should be 1 for x = 0x{:032x}", x);
        }
    }

    #[test]
    fn test_invert_zero() {
        assert_eq!(invert_gf128(0), 0);
    }

    #[test]
    fn test_invert_matches_slow() {
        // Compare with the slow Fermat-based inverse of the existing field type.
        for &x in &SAMPLES {
            let fast_inv = invert_gf128(x);

            let elem = BinaryElem128::from(x);
            let slow_inv = elem.inv();
            let slow_inv_val = slow_inv.poly().value();

            assert_eq!(
                fast_inv, slow_inv_val,
                "fast and slow inverse should match for x = 0x{:032x}",
                x
            );
        }
    }

    #[test]
    fn test_square_basic() {
        // Squaring must agree with general multiplication. The samples have bits in both 64-bit
        // halves and in the upper 32 bits of each half, which exercises every spread path.
        for &x in &SAMPLES {
            assert_eq!(
                square_gf128(x),
                mul_gf128(x, x),
                "square should match multiplication for x = 0x{:032x}",
                x
            );
        }
    }

    #[test]
    fn test_batch_invert() {
        let values: Vec<u128> = SAMPLES.to_vec();

        let batch_inverted = batch_invert_gf128(&values);

        // Verify each batch result matches individual inversion
        for (i, &v) in values.iter().enumerate() {
            let individual_inv = invert_gf128(v);
            assert_eq!(
                batch_inverted[i], individual_inv,
                "batch inversion should match individual for index {} value 0x{:032x}",
                i, v
            );
        }
    }

    #[test]
    fn test_batch_invert_with_zeros() {
        let values: Vec<u128> = vec![1, 0, 2, 0, 3, 0];
        let batch_inverted = batch_invert_gf128(&values);

        // Zeros should remain zeros
        assert_eq!(batch_inverted[1], 0);
        assert_eq!(batch_inverted[3], 0);
        assert_eq!(batch_inverted[5], 0);

        // Non-zeros should be correctly inverted
        assert_eq!(batch_inverted[0], invert_gf128(1));
        assert_eq!(batch_inverted[2], invert_gf128(2));
        assert_eq!(batch_inverted[4], invert_gf128(3));
    }

    #[test]
    fn test_batch_invert_empty() {
        let values: Vec<u128> = vec![];
        let batch_inverted = batch_invert_gf128(&values);
        assert!(batch_inverted.is_empty());
    }

    #[test]
    fn test_batch_invert_single() {
        let values = vec![0x12345678u128];
        let batch_inverted = batch_invert_gf128(&values);
        assert_eq!(batch_inverted[0], invert_gf128(0x12345678));
    }

    #[test]
    fn test_batch_invert_in_place() {
        // The in-place variant must agree with the allocating one, zeros included.
        let mut values = [0x12345678u128, 0, u128::MAX, 7, 0];
        let expected = batch_invert_gf128(&values);
        batch_invert_gf128_in_place(&mut values);
        assert_eq!(&values[..], &expected[..]);
    }

    #[test]
    fn test_spread_bits_correctness() {
        // Includes inputs with bits in the upper 32 positions, which land in bits 64..128.
        let test_cases: [(u64, u128); 7] = [
            (0b1, 0b1),           // bit 0 -> bit 0
            (0b10, 0b100),        // bit 1 -> bit 2
            (0b101, 0b10001),     // bits 0,2 -> bits 0,4
            (0b11111111, 0x5555), // first 8 bits spread
            (1u64 << 32, 1u128 << 64),
            (1u64 << 63, 1u128 << 126),
            (u64::MAX, 0x5555_5555_5555_5555_5555_5555_5555_5555),
        ];

        for (input, expected) in test_cases {
            let result = spread_bits(input);
            assert_eq!(
                result, expected,
                "spread_bits(0b{:b}) should be 0b{:b}, got 0b{:b}",
                input, expected, result
            );
        }
    }

    #[test]
    fn test_spread_bits_parallel_matches_reference() {
        // On BMI2 builds `spread_bits` never reaches the portable fallback, so test it directly.
        let inputs: [u64; 8] = [
            0,
            1,
            0xdead_beef,
            1 << 31,
            1 << 32,
            1 << 63,
            0x1234_5678_9abc_def0,
            u64::MAX,
        ];
        for x in inputs {
            assert_eq!(
                spread_bits_parallel(x),
                spread_bits_reference(x),
                "spread_bits_parallel mismatch for x = 0x{:016x}",
                x
            );
        }
    }
}