// ligerito_binary_fields/elem.rs

1use crate::poly::{BinaryPoly128, BinaryPoly16, BinaryPoly32, BinaryPoly64};
2use crate::{BinaryFieldElement, BinaryPolynomial};
3
// Irreducible (modulus) polynomials for the macro-generated fields, chosen to match the
// Julia implementation. Each constant includes its leading x^n term, so it needs one bit
// more than the field's own value type: u32 for the 16-bit field, u64 for the 32-bit one.

// x^16 + x^5 + x^3 + x^2 + 1
const IRREDUCIBLE_16: u32 = (1u32 << 16) | (1 << 5) | (1 << 3) | (1 << 2) | 1;
// x^32 + x^15 + x^9 + x^7 + x^4 + x^3 + 1
// (documented as "x^32 + Conway polynomial"; not re-verified here)
const IRREDUCIBLE_32: u64 =
    (1u64 << 32) | (1 << 15) | (1 << 9) | (1 << 7) | (1 << 4) | (1 << 3) | 1;
7
// Defines a binary extension-field element type `$name` for GF(2^$bitsize), together with its
// `BinaryFieldElement`, `From<$value_type>` and (with the "rand" feature) `Distribution` impls.
//
// Macro arguments:
// - `$name`: element type to define; a `#[repr(transparent)]` wrapper around `$poly_type`.
// - `$poly_type`: polynomial type storing a reduced element, constructed from `$value_type`.
// - `$poly_double`: wider polynomial type that holds the unreduced product of two elements.
// - `$value_type`: raw integer type accepted by `$poly_type::new`.
// - `$value_double`: raw integer type of `$poly_double` values; its bit width is used to
//   compute polynomial degrees during reduction.
// - `$irreducible`: modulus polynomial, including its leading x^$bitsize term.
// - `$bitsize`: the extension degree n.
macro_rules! impl_binary_elem {
    ($name:ident, $poly_type:ident, $poly_double:ident, $value_type:ty, $value_double:ty, $irreducible:expr, $bitsize:expr) => {
        #[repr(transparent)]
        #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        #[cfg_attr(
            feature = "scale",
            derive(codec::Encode, codec::Decode, scale_info::TypeInfo)
        )]
        pub struct $name($poly_type);

        // SAFETY: $name is repr(transparent) over $poly_type which wraps $value_type (a primitive)
        unsafe impl bytemuck::Pod for $name {}
        unsafe impl bytemuck::Zeroable for $name {}

        impl $name {
            /// Wraps a raw value as a field element. No reduction is performed.
            pub const fn from_value(val: $value_type) -> Self {
                Self($poly_type::new(val))
            }

            /// Reduces a double-width polynomial modulo the field's irreducible polynomial.
            ///
            /// Mirrors the Julia implementation (binarypoly.jl:146). The degree of the current
            /// remainder is read off `leading_zeros` instead of scanning bits. The modulus,
            /// shifted up to that degree, is XOR-ed in until the degree drops below n.
            fn mod_irreducible_wide(poly: $poly_double) -> Self {
                let mut p = poly.value();
                let irr = $irreducible;
                let n = $bitsize;
                let total_bits = core::mem::size_of::<$value_double>() * 8;

                // Stop at p == 0: leading_zeros() would equal total_bits and the degree
                // computation below would underflow.
                while p != 0 {
                    let high_bit = total_bits - p.leading_zeros() as usize - 1;
                    if high_bit < n {
                        break;
                    }
                    p ^= irr << (high_bit - n);
                }

                // The remainder now has degree < n, so the cast to $value_type is lossless.
                Self($poly_type::new(p as $value_type))
            }
        }

        impl BinaryFieldElement for $name {
            type Poly = $poly_type;

            fn zero() -> Self {
                Self($poly_type::zero())
            }

            fn one() -> Self {
                Self($poly_type::one())
            }

            /// Wraps `poly` as-is; the polynomial is assumed to be already reduced.
            fn from_poly(poly: Self::Poly) -> Self {
                Self(poly)
            }

            fn poly(&self) -> Self::Poly {
                self.0
            }

            fn add(&self, other: &Self) -> Self {
                Self(self.0.add(&other.0))
            }

            /// Field multiplication: full polynomial product in the double-width type, then
            /// reduction modulo the irreducible polynomial.
            fn mul(&self, other: &Self) -> Self {
                let a_wide = $poly_double::from_value(self.0.value() as u64);
                let b_wide = $poly_double::from_value(other.0.value() as u64);
                let prod_wide = a_wide.mul(&b_wide);

                Self::mod_irreducible_wide(prod_wide)
            }

            /// Multiplicative inverse, computed as a^(2^n - 2), since a^(2^n - 1) = 1 for every
            /// nonzero a in GF(2^n).
            ///
            /// # Panics
            /// Panics if `self` is zero.
            fn inv(&self) -> Self {
                assert_ne!(self.0.value(), 0, "Cannot invert zero");

                // Small fields: plain square-and-multiply on the exponent 2^n - 2.
                if $bitsize <= 16 {
                    let exp = (1u64 << $bitsize) - 2;
                    return self.pow(exp);
                }

                // Larger fields: 2^n - 2 = 2 + 4 + ... + 2^(n-1). Multiply together the
                // repeated squares a^2, a^4, ..., a^(2^(n-1)). This takes n-1 squarings and
                // n-2 multiplications.
                let mut acc = self.mul(self); // a^2
                let mut result = acc;

                for _ in 2..$bitsize {
                    acc = acc.mul(&acc);
                    result = result.mul(&acc);
                }

                result
            }

            /// Raises `self` to the power `exp` by square-and-multiply.
            ///
            /// `x.pow(0)` is one for every `x`, including zero, matching `u64::pow`.
            fn pow(&self, mut exp: u64) -> Self {
                if exp == 0 {
                    return Self::one();
                }
                // Any positive power of zero is zero.
                if *self == Self::zero() {
                    return Self::zero();
                }

                let mut result = Self::one();
                let mut base = *self;

                while exp > 0 {
                    if exp & 1 == 1 {
                        result = result.mul(&base);
                    }
                    base = base.mul(&base);
                    exp >>= 1;
                }

                result
            }
        }

        impl From<$value_type> for $name {
            fn from(val: $value_type) -> Self {
                Self::from_value(val)
            }
        }

        #[cfg(feature = "rand")]
        impl rand::distributions::Distribution<$name> for rand::distributions::Standard {
            fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> $name {
                $name::from_value(rng.gen())
            }
        }
    };
}
152
153impl_binary_elem!(
154    BinaryElem16,
155    BinaryPoly16,
156    BinaryPoly32,
157    u16,
158    u32,
159    IRREDUCIBLE_16,
160    16
161);
162impl_binary_elem!(
163    BinaryElem32,
164    BinaryPoly32,
165    BinaryPoly64,
166    u32,
167    u64,
168    IRREDUCIBLE_32,
169    32
170);
171
172// BinaryElem128 needs special handling since we don't have BinaryPoly256
173#[repr(transparent)]
174#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
175#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
176#[cfg_attr(
177    feature = "scale",
178    derive(codec::Encode, codec::Decode, scale_info::TypeInfo)
179)]
180pub struct BinaryElem128(BinaryPoly128);
181
182// SAFETY: BinaryElem128 is repr(transparent) over BinaryPoly128 which wraps u128 (a primitive)
183unsafe impl bytemuck::Pod for BinaryElem128 {}
184unsafe impl bytemuck::Zeroable for BinaryElem128 {}
185
186impl BinaryElem128 {
187    pub const fn from_value(val: u128) -> Self {
188        Self(BinaryPoly128::new(val))
189    }
190}
191
192impl BinaryFieldElement for BinaryElem128 {
193    type Poly = BinaryPoly128;
194
195    fn zero() -> Self {
196        Self(BinaryPoly128::zero())
197    }
198
199    fn one() -> Self {
200        Self(BinaryPoly128::one())
201    }
202
203    fn from_poly(poly: Self::Poly) -> Self {
204        Self(poly)
205    }
206
207    fn poly(&self) -> Self::Poly {
208        self.0
209    }
210
211    fn add(&self, other: &Self) -> Self {
212        Self(self.0.add(&other.0))
213    }
214
215    fn mul(&self, other: &Self) -> Self {
216        // Use SIMD carryless multiplication + reduction for performance
217        use crate::simd::{carryless_mul_128_full, reduce_gf128};
218
219        let product = carryless_mul_128_full(self.0, other.0);
220        let reduced = reduce_gf128(product);
221
222        Self(reduced)
223    }
224
225    fn inv(&self) -> Self {
226        assert_ne!(self.0.value(), 0, "Cannot invert zero");
227
228        // Use Itoh-Tsujii fast inversion with precomputed nibble tables
229        // Reduces from ~127 multiplications to ~9
230        let result = crate::fast_inverse::invert_gf128(self.0.value());
231        Self(BinaryPoly128::new(result))
232    }
233
234    fn pow(&self, mut exp: u64) -> Self {
235        if *self == Self::zero() {
236            return Self::zero();
237        }
238
239        let mut result = Self::one();
240        let mut base = *self;
241
242        while exp > 0 {
243            if exp & 1 == 1 {
244                result = result.mul(&base);
245            }
246            base = base.mul(&base);
247            exp >>= 1;
248        }
249
250        result
251    }
252}
253
254impl BinaryElem128 {
255    /// Multiply by x (field element 2) - very fast special case
256    ///
257    /// In GF(2^128) with irreducible x^128 + x^7 + x^2 + x + 1,
258    /// multiplying by x is just a left shift with conditional reduction.
259    /// This is ~10x faster than general multiplication.
260    #[inline]
261    pub fn mul_by_x(&self) -> Self {
262        let val = self.0.value();
263
264        // Shift left by 1 (multiply by x in polynomial ring)
265        let shifted = val << 1;
266
267        // If bit 128 would be set (overflow), reduce by the irreducible polynomial
268        // x^128 = x^7 + x^2 + x + 1 (mod irreducible)
269        // So we add 0x87 if the high bit was set
270        let overflow = (val >> 127) & 1;
271        let reduced = shifted ^ (overflow * 0x87);
272
273        Self(BinaryPoly128::new(reduced))
274    }
275}
276
277impl From<u128> for BinaryElem128 {
278    fn from(val: u128) -> Self {
279        Self::from_value(val)
280    }
281}
282
// Random elements: draw a `u128` from the RNG via `Standard` and wrap it directly.
#[cfg(feature = "rand")]
impl rand::distributions::Distribution<BinaryElem128> for rand::distributions::Standard {
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> BinaryElem128 {
        let raw: u128 = rng.gen();
        BinaryElem128::from_value(raw)
    }
}
289
290// BinaryElem64 needs special handling
291#[repr(transparent)]
292#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
293#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
294#[cfg_attr(
295    feature = "scale",
296    derive(codec::Encode, codec::Decode, scale_info::TypeInfo)
297)]
298pub struct BinaryElem64(BinaryPoly64);
299
300// SAFETY: BinaryElem64 is repr(transparent) over BinaryPoly64 which wraps u64 (a primitive)
301unsafe impl bytemuck::Pod for BinaryElem64 {}
302unsafe impl bytemuck::Zeroable for BinaryElem64 {}
303
304impl BinaryElem64 {
305    pub const fn from_value(val: u64) -> Self {
306        Self(BinaryPoly64::new(val))
307    }
308}
309
310impl BinaryFieldElement for BinaryElem64 {
311    type Poly = BinaryPoly64;
312
313    fn zero() -> Self {
314        Self(BinaryPoly64::zero())
315    }
316
317    fn one() -> Self {
318        Self(BinaryPoly64::one())
319    }
320
321    fn from_poly(poly: Self::Poly) -> Self {
322        // For now, no reduction for 64-bit field
323        Self(poly)
324    }
325
326    fn poly(&self) -> Self::Poly {
327        self.0
328    }
329
330    fn add(&self, other: &Self) -> Self {
331        Self(self.0.add(&other.0))
332    }
333
334    fn mul(&self, other: &Self) -> Self {
335        Self(self.0.mul(&other.0))
336    }
337
338    fn inv(&self) -> Self {
339        assert_ne!(self.0.value(), 0, "Cannot invert zero");
340        // Fermat's little theorem: a^(2^64 - 2) = a^(-1)
341        self.pow(0xFFFFFFFFFFFFFFFE)
342    }
343
344    fn pow(&self, mut exp: u64) -> Self {
345        if *self == Self::zero() {
346            return Self::zero();
347        }
348
349        let mut result = Self::one();
350        let mut base = *self;
351
352        while exp > 0 {
353            if exp & 1 == 1 {
354                result = result.mul(&base);
355            }
356            base = base.mul(&base);
357            exp >>= 1;
358        }
359
360        result
361    }
362}
363
364// Field embeddings for Ligerito
365impl From<BinaryElem16> for BinaryElem32 {
366    fn from(elem: BinaryElem16) -> Self {
367        BinaryElem32::from(elem.0.value() as u32)
368    }
369}
370
371impl From<BinaryElem16> for BinaryElem64 {
372    fn from(elem: BinaryElem16) -> Self {
373        BinaryElem64(BinaryPoly64::new(elem.0.value() as u64))
374    }
375}
376
377impl From<BinaryElem16> for BinaryElem128 {
378    fn from(elem: BinaryElem16) -> Self {
379        BinaryElem128::from(elem.0.value() as u128)
380    }
381}
382
383impl From<BinaryElem32> for BinaryElem64 {
384    fn from(elem: BinaryElem32) -> Self {
385        BinaryElem64(BinaryPoly64::new(elem.0.value() as u64))
386    }
387}
388
389impl From<BinaryElem32> for BinaryElem128 {
390    fn from(elem: BinaryElem32) -> Self {
391        BinaryElem128::from(elem.0.value() as u128)
392    }
393}
394
395impl From<BinaryElem64> for BinaryElem128 {
396    fn from(elem: BinaryElem64) -> Self {
397        BinaryElem128::from(elem.0.value() as u128)
398    }
399}