ligerito_binary_fields/
elem.rs

1use crate::{BinaryFieldElement, BinaryPolynomial};
2use crate::poly::{BinaryPoly16, BinaryPoly32, BinaryPoly64, BinaryPoly128};
3
// Reduction polynomials for the generated 16- and 32-bit fields (chosen to match the
// Julia implementation). Each constant includes its leading x^n term, which is why it
// is stored in an integer twice as wide as the field element.

/// x^16 + x^5 + x^3 + x^2 + 1
const IRREDUCIBLE_16: u32 = (1 << 16) | (1 << 5) | (1 << 3) | (1 << 2) | 1;

/// x^32 + x^15 + x^9 + x^7 + x^4 + x^3 + 1 (described upstream as the Conway polynomial)
const IRREDUCIBLE_32: u64 = 0x1_0000_8299;
7
// Generates a binary-field element type backed by a fixed-width polynomial.
//
// Arguments:
// - `$name`:         the element type to define.
// - `$poly_type`:    polynomial type the element wraps (same width as the field).
// - `$poly_double`:  wider polynomial type used to hold unreduced products.
// - `$value_type`:   primitive integer backing `$poly_type`.
// - `$value_double`: primitive integer backing `$poly_double`.
// - `$irreducible`:  reduction polynomial, including its leading x^bitsize term.
// - `$bitsize`:      the field degree n (elements live in GF(2^n)).
macro_rules! impl_binary_elem {
    ($name:ident, $poly_type:ident, $poly_double:ident, $value_type:ty, $value_double:ty, $irreducible:expr, $bitsize:expr) => {
        /// Binary field element stored as a polynomial over GF(2); bit `i` of the
        /// underlying value is the coefficient of `x^i`.
        #[repr(transparent)]
        #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct $name($poly_type);

        // SAFETY: $name is repr(transparent) over $poly_type which wraps $value_type (a primitive)
        unsafe impl bytemuck::Pod for $name {}
        unsafe impl bytemuck::Zeroable for $name {}

        impl $name {
            /// Builds an element directly from its bit representation.
            pub const fn from_value(val: $value_type) -> Self {
                Self($poly_type::new(val))
            }

            /// Reduces a double-width product modulo the field's irreducible polynomial.
            ///
            /// While the degree of `p` is at least the field degree, the irreducible
            /// polynomial is shifted so its leading term lines up with `p`'s highest set
            /// bit and XOR-ed in, clearing that bit. The degree is found with
            /// `leading_zeros` (mirrors the Julia implementation, binarypoly.jl:146).
            fn mod_irreducible_wide(poly: $poly_double) -> Self {
                let mut p = poly.value();
                let irr = $irreducible;
                let n = $bitsize;
                let total_bits = core::mem::size_of::<$value_double>() * 8;

                // `p != 0` must be checked first: for p == 0, leading_zeros() equals
                // total_bits and the degree computation below would underflow.
                while p != 0 {
                    let degree = total_bits - p.leading_zeros() as usize - 1;
                    if degree < n {
                        break;
                    }
                    p ^= irr << (degree - n);
                }

                // The degree is now below n, so the value fits in the narrow type.
                Self($poly_type::new(p as $value_type))
            }
        }

        impl BinaryFieldElement for $name {
            type Poly = $poly_type;

            fn zero() -> Self {
                Self($poly_type::zero())
            }

            fn one() -> Self {
                Self($poly_type::one())
            }

            /// Wraps `poly` without reducing it. This assumes the polynomial type is
            /// exactly as wide as the field, so its degree is already below n.
            fn from_poly(poly: Self::Poly) -> Self {
                Self(poly)
            }

            fn poly(&self) -> Self::Poly {
                self.0
            }

            /// Field addition (XOR of coefficients, via the polynomial type's `add`).
            fn add(&self, other: &Self) -> Self {
                Self(self.0.add(&other.0))
            }

            /// Field multiplication: full polynomial product in the double-width type,
            /// then reduction modulo the irreducible polynomial.
            fn mul(&self, other: &Self) -> Self {
                let a_wide = $poly_double::from_value(self.0.value() as u64);
                let b_wide = $poly_double::from_value(other.0.value() as u64);
                Self::mod_irreducible_wide(a_wide.mul(&b_wide))
            }

            /// Multiplicative inverse via Fermat's little theorem: a^(2^n - 2) = a^(-1).
            ///
            /// # Panics
            /// Panics if `self` is zero.
            fn inv(&self) -> Self {
                assert_ne!(self.0.value(), 0, "Cannot invert zero");

                // Small fields: plain square-and-multiply on the exponent.
                if $bitsize <= 16 {
                    let exp = (1u64 << $bitsize) - 2;
                    return self.pow(exp);
                }

                // Larger fields: 2^n - 2 = 2 + 4 + ... + 2^(n-1), so multiply together
                // the repeated squares a^2, a^4, ..., a^(2^(n-1)).
                let mut acc = self.mul(self); // a^2
                let mut result = acc;
                for _ in 2..$bitsize {
                    acc = acc.mul(&acc); // advance to the next power a^(2^k)
                    result = result.mul(&acc);
                }

                result
            }

            /// Raises `self` to `exp` by square-and-multiply.
            ///
            /// `x^0` is `one()` for every `x`, including zero (the empty product).
            fn pow(&self, mut exp: u64) -> Self {
                if exp == 0 {
                    return Self::one();
                }
                // 0^exp = 0 for exp > 0; skip the loop.
                if *self == Self::zero() {
                    return Self::zero();
                }

                let mut result = Self::one();
                let mut base = *self;
                while exp > 0 {
                    if exp & 1 == 1 {
                        result = result.mul(&base);
                    }
                    base = base.mul(&base);
                    exp >>= 1;
                }

                result
            }
        }

        impl From<$value_type> for $name {
            fn from(val: $value_type) -> Self {
                Self::from_value(val)
            }
        }

        impl rand::distributions::Distribution<$name> for rand::distributions::Standard {
            fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> $name {
                $name::from_value(rng.gen())
            }
        }
    };
}
147
148impl_binary_elem!(BinaryElem16, BinaryPoly16, BinaryPoly32, u16, u32, IRREDUCIBLE_16, 16);
149impl_binary_elem!(BinaryElem32, BinaryPoly32, BinaryPoly64, u32, u64, IRREDUCIBLE_32, 32);
150
151// BinaryElem128 needs special handling since we don't have BinaryPoly256
152#[repr(transparent)]
153#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
154#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
155pub struct BinaryElem128(BinaryPoly128);
156
157// SAFETY: BinaryElem128 is repr(transparent) over BinaryPoly128 which wraps u128 (a primitive)
158unsafe impl bytemuck::Pod for BinaryElem128 {}
159unsafe impl bytemuck::Zeroable for BinaryElem128 {}
160
161impl BinaryElem128 {
162    pub const fn from_value(val: u128) -> Self {
163        Self(BinaryPoly128::new(val))
164    }
165}
166
167impl BinaryFieldElement for BinaryElem128 {
168    type Poly = BinaryPoly128;
169
170    fn zero() -> Self {
171        Self(BinaryPoly128::zero())
172    }
173
174    fn one() -> Self {
175        Self(BinaryPoly128::one())
176    }
177
178    fn from_poly(poly: Self::Poly) -> Self {
179        Self(poly)
180    }
181
182    fn poly(&self) -> Self::Poly {
183        self.0
184    }
185
186    fn add(&self, other: &Self) -> Self {
187        Self(self.0.add(&other.0))
188    }
189
190    fn mul(&self, other: &Self) -> Self {
191        // Use SIMD carryless multiplication + reduction for performance
192        use crate::simd::{carryless_mul_128_full, reduce_gf128};
193
194        let product = carryless_mul_128_full(self.0, other.0);
195        let reduced = reduce_gf128(product);
196
197        Self(reduced)
198    }
199
200    fn inv(&self) -> Self {
201        assert_ne!(self.0.value(), 0, "Cannot invert zero");
202
203        // For GF(2^128), use Fermat's little theorem
204        // a^(2^128 - 2) = a^(-1)
205        // 2^128 - 2 = 111...110 in binary (127 ones followed by a zero)
206
207        // We can compute this efficiently using the fact that:
208        // 2^128 - 2 = 2 + 4 + 8 + ... + 2^127
209        // So a^(2^128 - 2) = a^2 * a^4 * a^8 * ... * a^(2^127)
210
211        // Start with a^2
212        let mut square = self.mul(self);
213        let mut result = square; // a^2
214
215        // Compute a^4, a^8, ..., a^(2^127) and multiply them all together
216        for _ in 1..127 {
217            square = square.mul(&square); // Square to get next power of 2
218            result = result.mul(&square); // Multiply into result
219        }
220
221        result
222    }
223
224    fn pow(&self, mut exp: u64) -> Self {
225        if *self == Self::zero() {
226            return Self::zero();
227        }
228
229        let mut result = Self::one();
230        let mut base = *self;
231
232        while exp > 0 {
233            if exp & 1 == 1 {
234                result = result.mul(&base);
235            }
236            base = base.mul(&base);
237            exp >>= 1;
238        }
239
240        result
241    }
242}
243
244impl From<u128> for BinaryElem128 {
245    fn from(val: u128) -> Self {
246        Self::from_value(val)
247    }
248}
249
250impl rand::distributions::Distribution<BinaryElem128> for rand::distributions::Standard {
251    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> BinaryElem128 {
252        BinaryElem128::from_value(rng.gen())
253    }
254}
255
256// BinaryElem64 needs special handling
257#[repr(transparent)]
258#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
259pub struct BinaryElem64(BinaryPoly64);
260
261// SAFETY: BinaryElem64 is repr(transparent) over BinaryPoly64 which wraps u64 (a primitive)
262unsafe impl bytemuck::Pod for BinaryElem64 {}
263unsafe impl bytemuck::Zeroable for BinaryElem64 {}
264
265impl BinaryElem64 {
266    pub const fn from_value(val: u64) -> Self {
267        Self(BinaryPoly64::new(val))
268    }
269}
270
271impl BinaryFieldElement for BinaryElem64 {
272    type Poly = BinaryPoly64;
273
274    fn zero() -> Self {
275        Self(BinaryPoly64::zero())
276    }
277
278    fn one() -> Self {
279        Self(BinaryPoly64::one())
280    }
281
282    fn from_poly(poly: Self::Poly) -> Self {
283        // For now, no reduction for 64-bit field
284        Self(poly)
285    }
286
287    fn poly(&self) -> Self::Poly {
288        self.0
289    }
290
291    fn add(&self, other: &Self) -> Self {
292        Self(self.0.add(&other.0))
293    }
294
295    fn mul(&self, other: &Self) -> Self {
296        Self(self.0.mul(&other.0))
297    }
298
299    fn inv(&self) -> Self {
300        assert_ne!(self.0.value(), 0, "Cannot invert zero");
301        // Fermat's little theorem: a^(2^64 - 2) = a^(-1)
302        self.pow(0xFFFFFFFFFFFFFFFE)
303    }
304
305    fn pow(&self, mut exp: u64) -> Self {
306        if *self == Self::zero() {
307            return Self::zero();
308        }
309
310        let mut result = Self::one();
311        let mut base = *self;
312
313        while exp > 0 {
314            if exp & 1 == 1 {
315                result = result.mul(&base);
316            }
317            base = base.mul(&base);
318            exp >>= 1;
319        }
320
321        result
322    }
323}
324
325// Field embeddings for Ligerito
326impl From<BinaryElem16> for BinaryElem32 {
327    fn from(elem: BinaryElem16) -> Self {
328        BinaryElem32::from(elem.0.value() as u32)
329    }
330}
331
332impl From<BinaryElem16> for BinaryElem64 {
333    fn from(elem: BinaryElem16) -> Self {
334        BinaryElem64(BinaryPoly64::new(elem.0.value() as u64))
335    }
336}
337
338impl From<BinaryElem16> for BinaryElem128 {
339    fn from(elem: BinaryElem16) -> Self {
340        BinaryElem128::from(elem.0.value() as u128)
341    }
342}
343
344impl From<BinaryElem32> for BinaryElem64 {
345    fn from(elem: BinaryElem32) -> Self {
346        BinaryElem64(BinaryPoly64::new(elem.0.value() as u64))
347    }
348}
349
350impl From<BinaryElem32> for BinaryElem128 {
351    fn from(elem: BinaryElem32) -> Self {
352        BinaryElem128::from(elem.0.value() as u128)
353    }
354}
355
356impl From<BinaryElem64> for BinaryElem128 {
357    fn from(elem: BinaryElem64) -> Self {
358        BinaryElem128::from(elem.0.value() as u128)
359    }
360}