ligerito_binary_fields/
poly.rs

1// src/poly.rs
2use crate::BinaryPolynomial;
3
// Generates a GF(2)[x] polynomial newtype backed by an unsigned integer: bit `i` of the
// wrapped value is the coefficient of x^i.
//
// `$name` is the type to generate and `$value_type` its backing integer. `$double_name`
// names the double-width type; it is accepted by the macro but not referenced by the
// generated code.
macro_rules! impl_binary_poly {
    ($name:ident, $value_type:ty, $double_name:ident) => {
        #[repr(transparent)]
        #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct $name($value_type);

        // SAFETY: `$name` is `repr(transparent)` over `$value_type`, a primitive unsigned
        // integer: it has no padding and every bit pattern (including all zeros) is valid.
        unsafe impl bytemuck::Zeroable for $name {}
        unsafe impl bytemuck::Pod for $name {}

        impl $name {
            /// Wraps a raw coefficient bit pattern.
            pub const fn new(val: $value_type) -> Self {
                Self(val)
            }

            /// Returns the raw coefficient bit pattern.
            pub fn value(&self) -> $value_type {
                let Self(bits) = *self;
                bits
            }

            /// Multiplies by x^n via a left shift; coefficients pushed past the top are
            /// dropped. Plain `<<` semantics apply, so `n` must be below the bit width.
            pub fn shl(&self, n: u32) -> Self {
                let Self(bits) = *self;
                Self(bits << n)
            }

            /// Divides by x^n, discarding the `n` lowest coefficients (`n` below the width).
            pub fn shr(&self, n: u32) -> Self {
                let Self(bits) = *self;
                Self(bits >> n)
            }

            /// Number of zero bits above the highest set coefficient.
            pub fn leading_zeros(&self) -> u32 {
                let Self(bits) = *self;
                bits.leading_zeros()
            }

            /// Splits into `(high, low)` halves, each returned right-aligned in a `$name`.
            // `split` may go unused for some instantiations.
            #[allow(dead_code)]
            pub fn split(&self) -> (Self, Self) {
                let half = <$value_type>::BITS / 2;
                let low_mask = <$value_type>::MAX >> half;
                (Self(self.0 >> half), Self(self.0 & low_mask))
            }
        }

        impl BinaryPolynomial for $name {
            type Value = $value_type;

            fn zero() -> Self {
                Self::new(0)
            }

            fn one() -> Self {
                Self::new(1)
            }

            fn from_value(val: u64) -> Self {
                // Keeps only the low bits that fit in `$value_type`.
                Self::new(val as $value_type)
            }

            fn value(&self) -> Self::Value {
                self.0
            }

            fn add(&self, other: &Self) -> Self {
                // Addition in GF(2)[x] is coefficient-wise XOR.
                Self::new(self.0 ^ other.0)
            }

            fn mul(&self, other: &Self) -> Self {
                // Constant-time shift-and-xor carryless product, truncated to the type's
                // width. Every bit of `other` is visited, and each partial product is
                // selected with an all-ones/all-zeros mask instead of a branch.
                let multiplicand = self.0;
                let multiplier = other.0;
                let width = <$value_type>::BITS;
                let zero: $value_type = 0;
                let product = (0..width).fold(zero, |acc: $value_type, bit: u32| {
                    let select = ((multiplier >> bit) & 1).wrapping_neg();
                    acc ^ (multiplicand.wrapping_shl(bit) & select)
                });
                Self(product)
            }

            fn div_rem(&self, divisor: &Self) -> (Self, Self) {
                assert_ne!(divisor.0, 0, "Division by zero");

                // Bit length (degree + 1) of a value; 0 for zero.
                let width = <$value_type>::BITS;
                let bit_len = |v: $value_type| width - v.leading_zeros();

                let divisor_len = bit_len(divisor.0);
                let one: $value_type = 1;
                let mut quot: $value_type = 0;
                let mut rem = self.0;

                // Schoolbook long division: cancel the remainder's leading term with a
                // shifted copy of the divisor until the remainder is shorter than it.
                while rem != 0 && bit_len(rem) >= divisor_len {
                    let offset = bit_len(rem) - divisor_len;
                    quot |= one << offset;
                    rem ^= divisor.0 << offset;
                }

                (Self(quot), Self(rem))
            }
        }

        impl From<$value_type> for $name {
            fn from(val: $value_type) -> Self {
                Self::new(val)
            }
        }
    };
}
117
118// Define polynomial types
119impl_binary_poly!(BinaryPoly16, u16, BinaryPoly32);
120impl_binary_poly!(BinaryPoly32, u32, BinaryPoly64);
121
122// BinaryPoly64 with SIMD support
123#[repr(transparent)]
124#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
125#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
126pub struct BinaryPoly64(u64);
127
128// SAFETY: BinaryPoly64 is repr(transparent) over u64 (a primitive)
129unsafe impl bytemuck::Pod for BinaryPoly64 {}
130unsafe impl bytemuck::Zeroable for BinaryPoly64 {}
131
132impl BinaryPoly64 {
133    pub const fn new(val: u64) -> Self {
134        Self(val)
135    }
136
137    pub fn value(&self) -> u64 {
138        self.0
139    }
140
141    pub fn shl(&self, n: u32) -> Self {
142        Self(self.0 << n)
143    }
144
145    pub fn shr(&self, n: u32) -> Self {
146        Self(self.0 >> n)
147    }
148
149    pub fn leading_zeros(&self) -> u32 {
150        self.0.leading_zeros()
151    }
152
153    pub fn split(&self) -> (BinaryPoly32, BinaryPoly32) {
154        let lo = BinaryPoly32::new(self.0 as u32);
155        let hi = BinaryPoly32::new((self.0 >> 32) as u32);
156        (hi, lo)
157    }
158}
159
160impl BinaryPolynomial for BinaryPoly64 {
161    type Value = u64;
162
163    fn zero() -> Self {
164        Self(0)
165    }
166
167    fn one() -> Self {
168        Self(1)
169    }
170
171    fn from_value(val: u64) -> Self {
172        Self(val)
173    }
174
175    fn value(&self) -> Self::Value {
176        self.0
177    }
178
179    fn add(&self, other: &Self) -> Self {
180        Self(self.0 ^ other.0)
181    }
182
183    fn mul(&self, other: &Self) -> Self {
184        use crate::simd::carryless_mul_64;
185        carryless_mul_64(*self, *other).truncate_to_64()
186    }
187
188    fn div_rem(&self, divisor: &Self) -> (Self, Self) {
189        assert_ne!(divisor.0, 0, "Division by zero");
190
191        let mut quotient = Self::zero();
192        let mut remainder = *self;
193
194        if remainder.0 == 0 {
195            return (quotient, remainder);
196        }
197
198        let divisor_bits = 64 - divisor.leading_zeros();
199        let mut remainder_bits = 64 - remainder.leading_zeros();
200
201        while remainder_bits >= divisor_bits && remainder.0 != 0 {
202            let shift = remainder_bits - divisor_bits;
203            quotient.0 |= 1 << shift;
204            remainder.0 ^= divisor.0 << shift;
205            remainder_bits = 64 - remainder.leading_zeros();
206        }
207
208        (quotient, remainder)
209    }
210}
211
212impl From<u64> for BinaryPoly64 {
213    fn from(val: u64) -> Self {
214        Self(val)
215    }
216}
217
218// BinaryPoly128
219#[repr(transparent)]
220#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
221#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
222pub struct BinaryPoly128(u128);
223
224// SAFETY: BinaryPoly128 is repr(transparent) over u128 (a primitive)
225unsafe impl bytemuck::Pod for BinaryPoly128 {}
226unsafe impl bytemuck::Zeroable for BinaryPoly128 {}
227
228impl BinaryPoly128 {
229    pub const fn new(val: u128) -> Self {
230        Self(val)
231    }
232
233    pub fn value(&self) -> u128 {
234        self.0
235    }
236
237    pub fn truncate_to_64(&self) -> BinaryPoly64 {
238        BinaryPoly64::new(self.0 as u64)
239    }
240
241    pub fn split(&self) -> (BinaryPoly64, BinaryPoly64) {
242        let lo = BinaryPoly64::new(self.0 as u64);
243        let hi = BinaryPoly64::new((self.0 >> 64) as u64);
244        (hi, lo)
245    }
246
247    pub fn leading_zeros(&self) -> u32 {
248        self.0.leading_zeros()
249    }
250
251    // full 128x128 -> 256 bit multiplication
252    pub fn mul_full(&self, other: &Self) -> BinaryPoly256 {
253        use crate::simd::carryless_mul_128_full;
254        carryless_mul_128_full(*self, *other)
255    }
256}
257
258impl BinaryPolynomial for BinaryPoly128 {
259    type Value = u128;
260
261    fn zero() -> Self {
262        Self(0)
263    }
264
265    fn one() -> Self {
266        Self(1)
267    }
268
269    fn from_value(val: u64) -> Self {
270        Self(val as u128)
271    }
272
273    fn value(&self) -> Self::Value {
274        self.0
275    }
276
277    fn add(&self, other: &Self) -> Self {
278        Self(self.0 ^ other.0)
279    }
280
281    fn mul(&self, other: &Self) -> Self {
282        use crate::simd::carryless_mul_128;
283        carryless_mul_128(*self, *other)
284    }
285
286    fn div_rem(&self, divisor: &Self) -> (Self, Self) {
287        assert_ne!(divisor.0, 0, "Division by zero");
288
289        let mut quotient = Self::zero();
290        let mut remainder = *self;
291
292        if remainder.0 == 0 {
293            return (quotient, remainder);
294        }
295
296        let divisor_bits = 128 - divisor.leading_zeros();
297        let mut remainder_bits = 128 - remainder.leading_zeros();
298
299        while remainder_bits >= divisor_bits && remainder.0 != 0 {
300            let shift = remainder_bits - divisor_bits;
301            quotient.0 |= 1u128 << shift;
302            remainder.0 ^= divisor.0 << shift;
303            remainder_bits = 128 - remainder.leading_zeros();
304        }
305
306        (quotient, remainder)
307    }
308}
309
310impl From<u128> for BinaryPoly128 {
311    fn from(val: u128) -> Self {
312        Self(val)
313    }
314}
315
/// Polynomial over GF(2) with up to 256 coefficients stored as two 128-bit halves. It holds
/// intermediate values such as the full product returned by `BinaryPoly128::mul_full`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BinaryPoly256 {
    /// Coefficients of x^128 through x^255.
    hi: u128,
    /// Coefficients of x^0 through x^127.
    lo: u128,
}
323
324impl BinaryPoly256 {
325    pub fn from_parts(hi: u128, lo: u128) -> Self {
326        Self { hi, lo }
327    }
328
329    pub fn split(&self) -> (BinaryPoly128, BinaryPoly128) {
330        (BinaryPoly128::new(self.hi), BinaryPoly128::new(self.lo))
331    }
332
333    /// reduce modulo a 128-bit polynomial (for field operations)
334    pub fn reduce_mod(&self, modulus: &BinaryPoly128) -> BinaryPoly128 {
335        // for irreducible polynomials of form x^128 + lower terms,
336        // we can use efficient reduction
337
338        // special case for GF(2^128) with x^128 + x^7 + x^2 + x + 1
339        if modulus.value() == (1u128 << 127) | 0x87 {
340            // efficient reduction for gcm polynomial
341            let mut result = self.lo;
342            let mut high = self.hi;
343
344            // reduce 128 bits at a time
345            while high != 0 {
346                // x^128 = x^7 + x^2 + x + 1
347                let feedback = high.wrapping_shl(7)
348                    ^ high.wrapping_shl(2)
349                    ^ high.wrapping_shl(1)
350                    ^ high;
351
352                result ^= feedback;
353                high >>= 121; // process remaining bits
354            }
355
356            return BinaryPoly128::new(result);
357        }
358
359        // general case: polynomial long division
360        if self.hi == 0 {
361            // already reduced
362            return BinaryPoly128::new(self.lo);
363        }
364
365        // work with a copy
366        let mut remainder_hi = self.hi;
367        let mut remainder_lo = self.lo;
368
369        // get modulus without the leading bit
370        let mod_bits = 128 - modulus.leading_zeros();
371        let mod_val = modulus.value();
372        let mod_mask = mod_val ^ (1u128 << (mod_bits - 1));
373
374        // reduce high 128 bits
375        while remainder_hi != 0 {
376            let shift = remainder_hi.leading_zeros();
377
378            if shift < 128 {
379                // align the leading bit
380                let bit_pos = 127 - shift;
381
382                // xor with modulus shifted appropriately
383                remainder_hi ^= 1u128 << bit_pos;
384
385                // xor lower bits of modulus into result
386                if bit_pos >= (mod_bits - 1) {
387                    remainder_hi ^= mod_mask << (bit_pos - (mod_bits - 1));
388                } else {
389                    let right_shift = (mod_bits - 1) - bit_pos;
390                    remainder_hi ^= mod_mask >> right_shift;
391                    remainder_lo ^= mod_mask << (128 - right_shift);
392                }
393            } else {
394                break;
395            }
396        }
397
398        // now reduce remainder_lo if needed
399        let mut remainder = BinaryPoly128::new(remainder_lo);
400
401        if remainder.leading_zeros() < modulus.leading_zeros() {
402            let (_, r) = remainder.div_rem(modulus);
403            remainder = r;
404        }
405
406        remainder
407    }
408
409    /// get the high 128 bits
410    pub fn high(&self) -> BinaryPoly128 {
411        BinaryPoly128::new(self.hi)
412    }
413
414    /// get the low 128 bits
415    pub fn low(&self) -> BinaryPoly128 {
416        BinaryPoly128::new(self.lo)
417    }
418
419    pub fn leading_zeros(&self) -> u32 {
420        if self.hi == 0 {
421            128 + self.lo.leading_zeros()
422        } else {
423            self.hi.leading_zeros()
424        }
425    }
426
427    pub fn add(&self, other: &Self) -> Self {
428        Self {
429            hi: self.hi ^ other.hi,
430            lo: self.lo ^ other.lo,
431        }
432    }
433
434    pub fn shl(&self, n: u32) -> Self {
435        if n == 0 {
436            *self
437        } else if n >= 256 {
438            Self { hi: 0, lo: 0 }
439        } else if n >= 128 {
440            Self {
441                hi: self.lo << (n - 128),
442                lo: 0,
443            }
444        } else {
445            Self {
446                hi: (self.hi << n) | (self.lo >> (128 - n)),
447                lo: self.lo << n,
448            }
449        }
450    }
451
452    pub fn shr(&self, n: u32) -> Self {
453        if n == 0 {
454            *self
455        } else if n >= 256 {
456            Self { hi: 0, lo: 0 }
457        } else if n >= 128 {
458            Self {
459                hi: 0,
460                lo: self.hi >> (n - 128),
461            }
462        } else {
463            Self {
464                hi: self.hi >> n,
465                lo: (self.lo >> n) | (self.hi << (128 - n)),
466            }
467        }
468    }
469}