ligerito_binary_fields/
elem.rs

1use crate::{BinaryFieldElement, BinaryPolynomial};
2use crate::poly::{BinaryPoly16, BinaryPoly32, BinaryPoly64, BinaryPoly128};
3
// Reduction polynomials (chosen to match the Julia implementation), with bit i holding the
// coefficient of x^i. The leading x^n term is included, so each constant is stored in a type
// twice as wide as the element it reduces.

// x^16 + x^5 + x^3 + x^2 + 1
const IRREDUCIBLE_16: u32 = (1u32 << 16) | (1u32 << 5) | (1u32 << 3) | (1u32 << 2) | 1u32;
// x^32 + x^15 + x^9 + x^7 + x^4 + x^3 + 1 (described upstream as the Conway polynomial)
const IRREDUCIBLE_32: u64 =
    (1u64 << 32) | (1u64 << 15) | (1u64 << 9) | (1u64 << 7) | (1u64 << 4) | (1u64 << 3) | 1u64;
7
8macro_rules! impl_binary_elem {
9    ($name:ident, $poly_type:ident, $poly_double:ident, $value_type:ty, $value_double:ty, $irreducible:expr, $bitsize:expr) => {
10        #[repr(transparent)]
11        #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
12        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
13        #[cfg_attr(feature = "scale", derive(codec::Encode, codec::Decode, scale_info::TypeInfo))]
14        pub struct $name($poly_type);
15
16        // SAFETY: $name is repr(transparent) over $poly_type which wraps $value_type (a primitive)
17        unsafe impl bytemuck::Pod for $name {}
18        unsafe impl bytemuck::Zeroable for $name {}
19
20        impl $name {
21            pub const fn from_value(val: $value_type) -> Self {
22                Self($poly_type::new(val))
23            }
24
25            fn mod_irreducible_wide(poly: $poly_double) -> Self {
26                // julia-like reduction using leading_zeros (lzcnt)
27                let mut p = poly.value();
28                let irr = $irreducible;
29                let n = $bitsize;
30
31                // use leading_zeros instead of loop - julia uses this (binarypoly.jl:146)
32                let total_bits = core::mem::size_of::<$value_double>() * 8;
33
34                loop {
35                    if p == 0 {
36                        break;  // avoid underflow when p is zero
37                    }
38
39                    let lz = p.leading_zeros() as usize;
40                    let high_bit = total_bits - lz - 1;
41
42                    if high_bit < n {
43                        break;
44                    }
45
46                    p ^= irr << (high_bit - n);
47                }
48
49                Self($poly_type::new(p as $value_type))
50            }
51        }
52
53        impl BinaryFieldElement for $name {
54            type Poly = $poly_type;
55
56            fn zero() -> Self {
57                Self($poly_type::zero())
58            }
59
60            fn one() -> Self {
61                Self($poly_type::one())
62            }
63
64            fn from_poly(poly: Self::Poly) -> Self {
65                // For from_poly, we assume the polynomial is already reduced
66                Self(poly)
67            }
68
69            fn poly(&self) -> Self::Poly {
70                self.0
71            }
72
73            fn add(&self, other: &Self) -> Self {
74                Self(self.0.add(&other.0))
75            }
76
77            fn mul(&self, other: &Self) -> Self {
78                // Perform full multiplication using double-width type
79                let a_wide = $poly_double::from_value(self.0.value() as u64);
80                let b_wide = $poly_double::from_value(other.0.value() as u64);
81                let prod_wide = a_wide.mul(&b_wide);
82
83                // Reduce modulo irreducible polynomial
84                Self::mod_irreducible_wide(prod_wide)
85            }
86
87            fn inv(&self) -> Self {
88                assert_ne!(self.0.value(), 0, "Cannot invert zero");
89
90                // For binary fields, we can use Fermat's little theorem efficiently
91                // a^(2^n - 2) = a^(-1) in GF(2^n)
92
93                // For small fields, use direct exponentiation
94                if $bitsize <= 16 {
95                    let exp = (1u64 << $bitsize) - 2;
96                    return self.pow(exp);
97                }
98
99                // For larger fields, use the addition chain method
100                // 2^n - 2 = 2 + 4 + 8 + ... + 2^(n-1)
101
102                // Start with a^2
103                let mut acc = self.mul(self);
104                let mut result = acc; // a^2
105
106                // Compute a^4, a^8, ..., a^(2^(n-1)) and multiply them all
107                for _ in 2..$bitsize {
108                    acc = acc.mul(&acc); // Square to get next power of 2
109                    result = result.mul(&acc);
110                }
111
112                result
113            }
114
115            fn pow(&self, mut exp: u64) -> Self {
116                if *self == Self::zero() {
117                    return Self::zero();
118                }
119
120                let mut result = Self::one();
121                let mut base = *self;
122
123                while exp > 0 {
124                    if exp & 1 == 1 {
125                        result = result.mul(&base);
126                    }
127                    base = base.mul(&base);
128                    exp >>= 1;
129                }
130
131                result
132            }
133        }
134
135        impl From<$value_type> for $name {
136            fn from(val: $value_type) -> Self {
137                Self::from_value(val)
138            }
139        }
140
141        #[cfg(feature = "rand")]
142        impl rand::distributions::Distribution<$name> for rand::distributions::Standard {
143            fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> $name {
144                $name::from_value(rng.gen())
145            }
146        }
147    };
148}
149
150impl_binary_elem!(BinaryElem16, BinaryPoly16, BinaryPoly32, u16, u32, IRREDUCIBLE_16, 16);
151impl_binary_elem!(BinaryElem32, BinaryPoly32, BinaryPoly64, u32, u64, IRREDUCIBLE_32, 32);
152
153// BinaryElem128 needs special handling since we don't have BinaryPoly256
154#[repr(transparent)]
155#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
156#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
157#[cfg_attr(feature = "scale", derive(codec::Encode, codec::Decode, scale_info::TypeInfo))]
158pub struct BinaryElem128(BinaryPoly128);
159
160// SAFETY: BinaryElem128 is repr(transparent) over BinaryPoly128 which wraps u128 (a primitive)
161unsafe impl bytemuck::Pod for BinaryElem128 {}
162unsafe impl bytemuck::Zeroable for BinaryElem128 {}
163
164impl BinaryElem128 {
165    pub const fn from_value(val: u128) -> Self {
166        Self(BinaryPoly128::new(val))
167    }
168}
169
170impl BinaryFieldElement for BinaryElem128 {
171    type Poly = BinaryPoly128;
172
173    fn zero() -> Self {
174        Self(BinaryPoly128::zero())
175    }
176
177    fn one() -> Self {
178        Self(BinaryPoly128::one())
179    }
180
181    fn from_poly(poly: Self::Poly) -> Self {
182        Self(poly)
183    }
184
185    fn poly(&self) -> Self::Poly {
186        self.0
187    }
188
189    fn add(&self, other: &Self) -> Self {
190        Self(self.0.add(&other.0))
191    }
192
193    fn mul(&self, other: &Self) -> Self {
194        // Use SIMD carryless multiplication + reduction for performance
195        use crate::simd::{carryless_mul_128_full, reduce_gf128};
196
197        let product = carryless_mul_128_full(self.0, other.0);
198        let reduced = reduce_gf128(product);
199
200        Self(reduced)
201    }
202
203    fn inv(&self) -> Self {
204        assert_ne!(self.0.value(), 0, "Cannot invert zero");
205
206        // For GF(2^128), use Fermat's little theorem
207        // a^(2^128 - 2) = a^(-1)
208        // 2^128 - 2 = 111...110 in binary (127 ones followed by a zero)
209
210        // We can compute this efficiently using the fact that:
211        // 2^128 - 2 = 2 + 4 + 8 + ... + 2^127
212        // So a^(2^128 - 2) = a^2 * a^4 * a^8 * ... * a^(2^127)
213
214        // Start with a^2
215        let mut square = self.mul(self);
216        let mut result = square; // a^2
217
218        // Compute a^4, a^8, ..., a^(2^127) and multiply them all together
219        for _ in 1..127 {
220            square = square.mul(&square); // Square to get next power of 2
221            result = result.mul(&square); // Multiply into result
222        }
223
224        result
225    }
226
227    fn pow(&self, mut exp: u64) -> Self {
228        if *self == Self::zero() {
229            return Self::zero();
230        }
231
232        let mut result = Self::one();
233        let mut base = *self;
234
235        while exp > 0 {
236            if exp & 1 == 1 {
237                result = result.mul(&base);
238            }
239            base = base.mul(&base);
240            exp >>= 1;
241        }
242
243        result
244    }
245}
246
247impl From<u128> for BinaryElem128 {
248    fn from(val: u128) -> Self {
249        Self::from_value(val)
250    }
251}
252
#[cfg(feature = "rand")]
impl rand::distributions::Distribution<BinaryElem128> for rand::distributions::Standard {
    /// Draws a random 128-bit pattern from `rng` and wraps it as an element.
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> BinaryElem128 {
        let bits: u128 = rng.gen();
        BinaryElem128::from_value(bits)
    }
}
259
260// BinaryElem64 needs special handling
261#[repr(transparent)]
262#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
263#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
264#[cfg_attr(feature = "scale", derive(codec::Encode, codec::Decode, scale_info::TypeInfo))]
265pub struct BinaryElem64(BinaryPoly64);
266
267// SAFETY: BinaryElem64 is repr(transparent) over BinaryPoly64 which wraps u64 (a primitive)
268unsafe impl bytemuck::Pod for BinaryElem64 {}
269unsafe impl bytemuck::Zeroable for BinaryElem64 {}
270
271impl BinaryElem64 {
272    pub const fn from_value(val: u64) -> Self {
273        Self(BinaryPoly64::new(val))
274    }
275}
276
277impl BinaryFieldElement for BinaryElem64 {
278    type Poly = BinaryPoly64;
279
280    fn zero() -> Self {
281        Self(BinaryPoly64::zero())
282    }
283
284    fn one() -> Self {
285        Self(BinaryPoly64::one())
286    }
287
288    fn from_poly(poly: Self::Poly) -> Self {
289        // For now, no reduction for 64-bit field
290        Self(poly)
291    }
292
293    fn poly(&self) -> Self::Poly {
294        self.0
295    }
296
297    fn add(&self, other: &Self) -> Self {
298        Self(self.0.add(&other.0))
299    }
300
301    fn mul(&self, other: &Self) -> Self {
302        Self(self.0.mul(&other.0))
303    }
304
305    fn inv(&self) -> Self {
306        assert_ne!(self.0.value(), 0, "Cannot invert zero");
307        // Fermat's little theorem: a^(2^64 - 2) = a^(-1)
308        self.pow(0xFFFFFFFFFFFFFFFE)
309    }
310
311    fn pow(&self, mut exp: u64) -> Self {
312        if *self == Self::zero() {
313            return Self::zero();
314        }
315
316        let mut result = Self::one();
317        let mut base = *self;
318
319        while exp > 0 {
320            if exp & 1 == 1 {
321                result = result.mul(&base);
322            }
323            base = base.mul(&base);
324            exp >>= 1;
325        }
326
327        result
328    }
329}
330
331// Field embeddings for Ligerito
332impl From<BinaryElem16> for BinaryElem32 {
333    fn from(elem: BinaryElem16) -> Self {
334        BinaryElem32::from(elem.0.value() as u32)
335    }
336}
337
338impl From<BinaryElem16> for BinaryElem64 {
339    fn from(elem: BinaryElem16) -> Self {
340        BinaryElem64(BinaryPoly64::new(elem.0.value() as u64))
341    }
342}
343
344impl From<BinaryElem16> for BinaryElem128 {
345    fn from(elem: BinaryElem16) -> Self {
346        BinaryElem128::from(elem.0.value() as u128)
347    }
348}
349
350impl From<BinaryElem32> for BinaryElem64 {
351    fn from(elem: BinaryElem32) -> Self {
352        BinaryElem64(BinaryPoly64::new(elem.0.value() as u64))
353    }
354}
355
356impl From<BinaryElem32> for BinaryElem128 {
357    fn from(elem: BinaryElem32) -> Self {
358        BinaryElem128::from(elem.0.value() as u128)
359    }
360}
361
362impl From<BinaryElem64> for BinaryElem128 {
363    fn from(elem: BinaryElem64) -> Self {
364        BinaryElem128::from(elem.0.value() as u128)
365    }
366}