ligerito_binary_fields/
poly.rs

1// src/poly.rs
2use crate::BinaryPolynomial;
3
// Macro to implement fixed-width binary polynomials (polynomials over GF(2),
// bit `i` of the backing integer = coefficient of `x^i`) for different sizes.
//
// Arguments: the type name, the backing unsigned integer type, and the name of
// the next-wider polynomial type. NOTE(review): `$double_name` is accepted but
// not referenced by the expansion; it is kept so existing invocations still match.
macro_rules! impl_binary_poly {
    ($name:ident, $value_type:ty, $double_name:ident) => {
        #[repr(transparent)]
        #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        #[cfg_attr(feature = "scale", derive(codec::Encode, codec::Decode, scale_info::TypeInfo))]
        pub struct $name($value_type);

        // SAFETY: $name is repr(transparent) over $value_type (a primitive
        // unsigned integer): it has no padding and every bit pattern, including
        // all zeros, is a valid value.
        unsafe impl bytemuck::Pod for $name {}
        unsafe impl bytemuck::Zeroable for $name {}

        impl $name {
            /// Wraps a raw coefficient word.
            pub const fn new(val: $value_type) -> Self {
                Self(val)
            }

            /// Returns the raw coefficient word.
            pub fn value(&self) -> $value_type {
                self.0
            }

            /// Multiplies by `x^n`, dropping terms that no longer fit.
            ///
            /// A shift by the full bit width or more yields zero instead of
            /// overflowing (a plain `<<` panics in debug builds and masks `n`
            /// in release builds).
            pub fn shl(&self, n: u32) -> Self {
                Self(self.0.checked_shl(n).unwrap_or(0))
            }

            /// Divides by `x^n`, dropping the low terms; a shift by the full
            /// bit width or more yields zero.
            pub fn shr(&self, n: u32) -> Self {
                Self(self.0.checked_shr(n).unwrap_or(0))
            }

            /// Number of leading zero bits (the full bit width for zero).
            pub fn leading_zeros(&self) -> u32 {
                self.0.leading_zeros()
            }

            /// Splits into `(high, low)` halves, each returned in the low bits
            /// of `Self`.
            // NOTE(review): this dead_code allow predates the comment; confirm
            // it is still needed.
            #[allow(dead_code)]
            pub fn split(&self) -> (Self, Self) {
                let half_bits = core::mem::size_of::<$value_type>() * 4;
                let mask = ((1u64 << half_bits) - 1) as $value_type;
                let lo = Self(self.0 & mask);
                let hi = Self(self.0 >> half_bits);
                (hi, lo)
            }
        }

        impl BinaryPolynomial for $name {
            type Value = $value_type;

            fn zero() -> Self {
                Self(0)
            }

            fn one() -> Self {
                Self(1)
            }

            /// Truncates `val` to this type's width.
            fn from_value(val: u64) -> Self {
                Self(val as $value_type)
            }

            fn value(&self) -> Self::Value {
                self.0
            }

            /// Addition over GF(2) is bitwise XOR.
            fn add(&self, other: &Self) -> Self {
                Self(self.0 ^ other.0)
            }

            /// Carry-less product truncated to this type's width (higher terms
            /// are dropped). The loop is branch-free: each partial product is
            /// masked in rather than conditionally added.
            fn mul(&self, other: &Self) -> Self {
                let mut result = 0 as $value_type;
                let a = self.0;
                let b = other.0;
                let bits = core::mem::size_of::<$value_type>() * 8;

                for i in 0..bits {
                    // all-ones when bit i of b is set, zero otherwise
                    let mask = (0 as $value_type).wrapping_sub((b >> i) & 1);
                    result ^= a.wrapping_shl(i as u32) & mask;
                }

                Self(result)
            }

            /// Polynomial long division returning `(quotient, remainder)` with
            /// `deg(remainder) < deg(divisor)`.
            ///
            /// # Panics
            /// Panics with "Division by zero" if `divisor` is zero.
            fn div_rem(&self, divisor: &Self) -> (Self, Self) {
                assert_ne!(divisor.0, 0, "Division by zero");

                let width = (core::mem::size_of::<$value_type>() * 8) as u32;
                let mut quotient = Self::zero();
                let mut remainder = *self;

                if remainder.0 == 0 {
                    return (quotient, remainder);
                }

                // "bits" = 1-based position of the highest set bit (degree + 1)
                let divisor_bits = width - divisor.leading_zeros();
                let mut remainder_bits = width - remainder.leading_zeros();

                while remainder_bits >= divisor_bits && remainder.0 != 0 {
                    let shift = remainder_bits - divisor_bits;
                    quotient.0 |= (1 as $value_type) << shift;
                    remainder.0 ^= divisor.0 << shift;
                    remainder_bits = width - remainder.leading_zeros();
                }

                (quotient, remainder)
            }
        }

        impl From<$value_type> for $name {
            fn from(val: $value_type) -> Self {
                Self(val)
            }
        }
    };
}
118
119// Define polynomial types
120impl_binary_poly!(BinaryPoly16, u16, BinaryPoly32);
121impl_binary_poly!(BinaryPoly32, u32, BinaryPoly64);
122
123// BinaryPoly64 with SIMD support
124#[repr(transparent)]
125#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
126#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
127#[cfg_attr(feature = "scale", derive(codec::Encode, codec::Decode, scale_info::TypeInfo))]
128pub struct BinaryPoly64(u64);
129
130// SAFETY: BinaryPoly64 is repr(transparent) over u64 (a primitive)
131unsafe impl bytemuck::Pod for BinaryPoly64 {}
132unsafe impl bytemuck::Zeroable for BinaryPoly64 {}
133
134impl BinaryPoly64 {
135    pub const fn new(val: u64) -> Self {
136        Self(val)
137    }
138
139    pub fn value(&self) -> u64 {
140        self.0
141    }
142
143    pub fn shl(&self, n: u32) -> Self {
144        Self(self.0 << n)
145    }
146
147    pub fn shr(&self, n: u32) -> Self {
148        Self(self.0 >> n)
149    }
150
151    pub fn leading_zeros(&self) -> u32 {
152        self.0.leading_zeros()
153    }
154
155    pub fn split(&self) -> (BinaryPoly32, BinaryPoly32) {
156        let lo = BinaryPoly32::new(self.0 as u32);
157        let hi = BinaryPoly32::new((self.0 >> 32) as u32);
158        (hi, lo)
159    }
160}
161
162impl BinaryPolynomial for BinaryPoly64 {
163    type Value = u64;
164
165    fn zero() -> Self {
166        Self(0)
167    }
168
169    fn one() -> Self {
170        Self(1)
171    }
172
173    fn from_value(val: u64) -> Self {
174        Self(val)
175    }
176
177    fn value(&self) -> Self::Value {
178        self.0
179    }
180
181    fn add(&self, other: &Self) -> Self {
182        Self(self.0 ^ other.0)
183    }
184
185    fn mul(&self, other: &Self) -> Self {
186        use crate::simd::carryless_mul_64;
187        carryless_mul_64(*self, *other).truncate_to_64()
188    }
189
190    fn div_rem(&self, divisor: &Self) -> (Self, Self) {
191        assert_ne!(divisor.0, 0, "Division by zero");
192
193        let mut quotient = Self::zero();
194        let mut remainder = *self;
195
196        if remainder.0 == 0 {
197            return (quotient, remainder);
198        }
199
200        let divisor_bits = 64 - divisor.leading_zeros();
201        let mut remainder_bits = 64 - remainder.leading_zeros();
202
203        while remainder_bits >= divisor_bits && remainder.0 != 0 {
204            let shift = remainder_bits - divisor_bits;
205            quotient.0 |= 1 << shift;
206            remainder.0 ^= divisor.0 << shift;
207            remainder_bits = 64 - remainder.leading_zeros();
208        }
209
210        (quotient, remainder)
211    }
212}
213
214impl From<u64> for BinaryPoly64 {
215    fn from(val: u64) -> Self {
216        Self(val)
217    }
218}
219
220// BinaryPoly128
221#[repr(transparent)]
222#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
223#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
224#[cfg_attr(feature = "scale", derive(codec::Encode, codec::Decode, scale_info::TypeInfo))]
225pub struct BinaryPoly128(u128);
226
227// SAFETY: BinaryPoly128 is repr(transparent) over u128 (a primitive)
228unsafe impl bytemuck::Pod for BinaryPoly128 {}
229unsafe impl bytemuck::Zeroable for BinaryPoly128 {}
230
231impl BinaryPoly128 {
232    pub const fn new(val: u128) -> Self {
233        Self(val)
234    }
235
236    pub fn value(&self) -> u128 {
237        self.0
238    }
239
240    pub fn truncate_to_64(&self) -> BinaryPoly64 {
241        BinaryPoly64::new(self.0 as u64)
242    }
243
244    pub fn split(&self) -> (BinaryPoly64, BinaryPoly64) {
245        let lo = BinaryPoly64::new(self.0 as u64);
246        let hi = BinaryPoly64::new((self.0 >> 64) as u64);
247        (hi, lo)
248    }
249
250    pub fn leading_zeros(&self) -> u32 {
251        self.0.leading_zeros()
252    }
253
254    // full 128x128 -> 256 bit multiplication
255    pub fn mul_full(&self, other: &Self) -> BinaryPoly256 {
256        use crate::simd::carryless_mul_128_full;
257        carryless_mul_128_full(*self, *other)
258    }
259}
260
261impl BinaryPolynomial for BinaryPoly128 {
262    type Value = u128;
263
264    fn zero() -> Self {
265        Self(0)
266    }
267
268    fn one() -> Self {
269        Self(1)
270    }
271
272    fn from_value(val: u64) -> Self {
273        Self(val as u128)
274    }
275
276    fn value(&self) -> Self::Value {
277        self.0
278    }
279
280    fn add(&self, other: &Self) -> Self {
281        Self(self.0 ^ other.0)
282    }
283
284    fn mul(&self, other: &Self) -> Self {
285        use crate::simd::carryless_mul_128;
286        carryless_mul_128(*self, *other)
287    }
288
289    fn div_rem(&self, divisor: &Self) -> (Self, Self) {
290        assert_ne!(divisor.0, 0, "Division by zero");
291
292        let mut quotient = Self::zero();
293        let mut remainder = *self;
294
295        if remainder.0 == 0 {
296            return (quotient, remainder);
297        }
298
299        let divisor_bits = 128 - divisor.leading_zeros();
300        let mut remainder_bits = 128 - remainder.leading_zeros();
301
302        while remainder_bits >= divisor_bits && remainder.0 != 0 {
303            let shift = remainder_bits - divisor_bits;
304            quotient.0 |= 1u128 << shift;
305            remainder.0 ^= divisor.0 << shift;
306            remainder_bits = 128 - remainder.leading_zeros();
307        }
308
309        (quotient, remainder)
310    }
311}
312
313impl From<u128> for BinaryPoly128 {
314    fn from(val: u128) -> Self {
315        Self(val)
316    }
317}
318
/// 256-bit polynomial over GF(2), stored as two 128-bit words; used for
/// intermediate results such as the full product from `BinaryPoly128::mul_full`.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "scale", derive(codec::Encode, codec::Decode, scale_info::TypeInfo))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BinaryPoly256 {
    /// Coefficients of `x^128 ..= x^255`.
    hi: u128,
    /// Coefficients of `x^0 ..= x^127`.
    lo: u128,
}
327
328impl BinaryPoly256 {
329    pub fn from_parts(hi: u128, lo: u128) -> Self {
330        Self { hi, lo }
331    }
332
333    pub fn split(&self) -> (BinaryPoly128, BinaryPoly128) {
334        (BinaryPoly128::new(self.hi), BinaryPoly128::new(self.lo))
335    }
336
337    /// reduce modulo a 128-bit polynomial (for field operations)
338    pub fn reduce_mod(&self, modulus: &BinaryPoly128) -> BinaryPoly128 {
339        // for irreducible polynomials of form x^128 + lower terms,
340        // we can use efficient reduction
341
342        // special case for GF(2^128) with x^128 + x^7 + x^2 + x + 1
343        if modulus.value() == (1u128 << 127) | 0x87 {
344            // efficient reduction for gcm polynomial
345            let mut result = self.lo;
346            let mut high = self.hi;
347
348            // reduce 128 bits at a time
349            while high != 0 {
350                // x^128 = x^7 + x^2 + x + 1
351                let feedback = high.wrapping_shl(7)
352                    ^ high.wrapping_shl(2)
353                    ^ high.wrapping_shl(1)
354                    ^ high;
355
356                result ^= feedback;
357                high >>= 121; // process remaining bits
358            }
359
360            return BinaryPoly128::new(result);
361        }
362
363        // general case: polynomial long division
364        if self.hi == 0 {
365            // already reduced
366            return BinaryPoly128::new(self.lo);
367        }
368
369        // work with a copy
370        let mut remainder_hi = self.hi;
371        let mut remainder_lo = self.lo;
372
373        // get modulus without the leading bit
374        let mod_bits = 128 - modulus.leading_zeros();
375        let mod_val = modulus.value();
376        let mod_mask = mod_val ^ (1u128 << (mod_bits - 1));
377
378        // reduce high 128 bits
379        while remainder_hi != 0 {
380            let shift = remainder_hi.leading_zeros();
381
382            if shift < 128 {
383                // align the leading bit
384                let bit_pos = 127 - shift;
385
386                // xor with modulus shifted appropriately
387                remainder_hi ^= 1u128 << bit_pos;
388
389                // xor lower bits of modulus into result
390                if bit_pos >= (mod_bits - 1) {
391                    remainder_hi ^= mod_mask << (bit_pos - (mod_bits - 1));
392                } else {
393                    let right_shift = (mod_bits - 1) - bit_pos;
394                    remainder_hi ^= mod_mask >> right_shift;
395                    remainder_lo ^= mod_mask << (128 - right_shift);
396                }
397            } else {
398                break;
399            }
400        }
401
402        // now reduce remainder_lo if needed
403        let mut remainder = BinaryPoly128::new(remainder_lo);
404
405        if remainder.leading_zeros() < modulus.leading_zeros() {
406            let (_, r) = remainder.div_rem(modulus);
407            remainder = r;
408        }
409
410        remainder
411    }
412
413    /// get the high 128 bits
414    pub fn high(&self) -> BinaryPoly128 {
415        BinaryPoly128::new(self.hi)
416    }
417
418    /// get the low 128 bits
419    pub fn low(&self) -> BinaryPoly128 {
420        BinaryPoly128::new(self.lo)
421    }
422
423    pub fn leading_zeros(&self) -> u32 {
424        if self.hi == 0 {
425            128 + self.lo.leading_zeros()
426        } else {
427            self.hi.leading_zeros()
428        }
429    }
430
431    pub fn add(&self, other: &Self) -> Self {
432        Self {
433            hi: self.hi ^ other.hi,
434            lo: self.lo ^ other.lo,
435        }
436    }
437
438    pub fn shl(&self, n: u32) -> Self {
439        if n == 0 {
440            *self
441        } else if n >= 256 {
442            Self { hi: 0, lo: 0 }
443        } else if n >= 128 {
444            Self {
445                hi: self.lo << (n - 128),
446                lo: 0,
447            }
448        } else {
449            Self {
450                hi: (self.hi << n) | (self.lo >> (128 - n)),
451                lo: self.lo << n,
452            }
453        }
454    }
455
456    pub fn shr(&self, n: u32) -> Self {
457        if n == 0 {
458            *self
459        } else if n >= 256 {
460            Self { hi: 0, lo: 0 }
461        } else if n >= 128 {
462            Self {
463                hi: 0,
464                lo: self.hi >> (n - 128),
465            }
466        } else {
467            Self {
468                hi: self.hi >> n,
469                lo: (self.lo >> n) | (self.hi << (128 - n)),
470            }
471        }
472    }
473}