// src/poly.rs (crate: ligerito_binary_fields)
use crate::BinaryPolynomial;

4// Macro to implement binary polynomials for different sizes
5macro_rules! impl_binary_poly {
6    ($name:ident, $value_type:ty, $double_name:ident) => {
7        #[repr(transparent)]
8        #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
9        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
10        #[cfg_attr(
11            feature = "scale",
12            derive(codec::Encode, codec::Decode, scale_info::TypeInfo)
13        )]
14        pub struct $name($value_type);
15
16        // SAFETY: $name is repr(transparent) over $value_type (a primitive integer type)
17        unsafe impl bytemuck::Pod for $name {}
18        unsafe impl bytemuck::Zeroable for $name {}
19
20        impl $name {
21            pub const fn new(val: $value_type) -> Self {
22                Self(val)
23            }
24
25            pub fn value(&self) -> $value_type {
26                self.0
27            }
28
29            pub fn shl(&self, n: u32) -> Self {
30                Self(self.0 << n)
31            }
32
33            pub fn shr(&self, n: u32) -> Self {
34                Self(self.0 >> n)
35            }
36
37            pub fn leading_zeros(&self) -> u32 {
38                self.0.leading_zeros()
39            }
40
41            #[allow(dead_code)]
42            pub fn split(&self) -> (Self, Self) {
43                let half_bits = core::mem::size_of::<$value_type>() * 4;
44                let mask = ((1u64 << half_bits) - 1) as $value_type;
45                let lo = Self(self.0 & mask);
46                let hi = Self(self.0 >> half_bits);
47                (hi, lo)
48            }
49        }
50
51        impl BinaryPolynomial for $name {
52            type Value = $value_type;
53
54            fn zero() -> Self {
55                Self(0)
56            }
57
58            fn one() -> Self {
59                Self(1)
60            }
61
62            fn from_value(val: u64) -> Self {
63                Self(val as $value_type)
64            }
65
66            fn value(&self) -> Self::Value {
67                self.0
68            }
69
70            fn add(&self, other: &Self) -> Self {
71                Self(self.0 ^ other.0)
72            }
73
74            fn mul(&self, other: &Self) -> Self {
75                // constant-time carryless multiplication
76                let mut result = 0 as $value_type;
77                let a = self.0;
78                let b = other.0;
79                let bits = core::mem::size_of::<$value_type>() * 8;
80
81                for i in 0..bits {
82                    // constant-time conditional xor
83                    let mask = (0 as $value_type).wrapping_sub((b >> i) & 1);
84                    result ^= a.wrapping_shl(i as u32) & mask;
85                }
86
87                Self(result)
88            }
89
90            fn div_rem(&self, divisor: &Self) -> (Self, Self) {
91                assert_ne!(divisor.0, 0, "Division by zero");
92
93                let mut quotient = Self::zero();
94                let mut remainder = *self;
95
96                if remainder.0 == 0 {
97                    return (quotient, remainder);
98                }
99
100                let divisor_bits =
101                    (core::mem::size_of::<$value_type>() * 8) as u32 - divisor.leading_zeros();
102                let mut remainder_bits =
103                    (core::mem::size_of::<$value_type>() * 8) as u32 - remainder.leading_zeros();
104
105                while remainder_bits >= divisor_bits && remainder.0 != 0 {
106                    let shift = remainder_bits - divisor_bits;
107                    quotient.0 |= 1 << shift;
108                    remainder.0 ^= divisor.0 << shift;
109                    remainder_bits = (core::mem::size_of::<$value_type>() * 8) as u32
110                        - remainder.leading_zeros();
111                }
112
113                (quotient, remainder)
114            }
115        }
116
117        impl From<$value_type> for $name {
118            fn from(val: $value_type) -> Self {
119                Self(val)
120            }
121        }
122    };
123}
124
125// Define polynomial types
126impl_binary_poly!(BinaryPoly16, u16, BinaryPoly32);
127impl_binary_poly!(BinaryPoly32, u32, BinaryPoly64);
128
129// BinaryPoly64 with SIMD support
130#[repr(transparent)]
131#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
132#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
133#[cfg_attr(
134    feature = "scale",
135    derive(codec::Encode, codec::Decode, scale_info::TypeInfo)
136)]
137pub struct BinaryPoly64(u64);
138
139// SAFETY: BinaryPoly64 is repr(transparent) over u64 (a primitive)
140unsafe impl bytemuck::Pod for BinaryPoly64 {}
141unsafe impl bytemuck::Zeroable for BinaryPoly64 {}
142
143impl BinaryPoly64 {
144    pub const fn new(val: u64) -> Self {
145        Self(val)
146    }
147
148    pub fn value(&self) -> u64 {
149        self.0
150    }
151
152    pub fn shl(&self, n: u32) -> Self {
153        Self(self.0 << n)
154    }
155
156    pub fn shr(&self, n: u32) -> Self {
157        Self(self.0 >> n)
158    }
159
160    pub fn leading_zeros(&self) -> u32 {
161        self.0.leading_zeros()
162    }
163
164    pub fn split(&self) -> (BinaryPoly32, BinaryPoly32) {
165        let lo = BinaryPoly32::new(self.0 as u32);
166        let hi = BinaryPoly32::new((self.0 >> 32) as u32);
167        (hi, lo)
168    }
169}
170
171impl BinaryPolynomial for BinaryPoly64 {
172    type Value = u64;
173
174    fn zero() -> Self {
175        Self(0)
176    }
177
178    fn one() -> Self {
179        Self(1)
180    }
181
182    fn from_value(val: u64) -> Self {
183        Self(val)
184    }
185
186    fn value(&self) -> Self::Value {
187        self.0
188    }
189
190    fn add(&self, other: &Self) -> Self {
191        Self(self.0 ^ other.0)
192    }
193
194    fn mul(&self, other: &Self) -> Self {
195        use crate::simd::carryless_mul_64;
196        carryless_mul_64(*self, *other).truncate_to_64()
197    }
198
199    fn div_rem(&self, divisor: &Self) -> (Self, Self) {
200        assert_ne!(divisor.0, 0, "Division by zero");
201
202        let mut quotient = Self::zero();
203        let mut remainder = *self;
204
205        if remainder.0 == 0 {
206            return (quotient, remainder);
207        }
208
209        let divisor_bits = 64 - divisor.leading_zeros();
210        let mut remainder_bits = 64 - remainder.leading_zeros();
211
212        while remainder_bits >= divisor_bits && remainder.0 != 0 {
213            let shift = remainder_bits - divisor_bits;
214            quotient.0 |= 1 << shift;
215            remainder.0 ^= divisor.0 << shift;
216            remainder_bits = 64 - remainder.leading_zeros();
217        }
218
219        (quotient, remainder)
220    }
221}
222
223impl From<u64> for BinaryPoly64 {
224    fn from(val: u64) -> Self {
225        Self(val)
226    }
227}
228
229// BinaryPoly128
230#[repr(transparent)]
231#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
232#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
233#[cfg_attr(
234    feature = "scale",
235    derive(codec::Encode, codec::Decode, scale_info::TypeInfo)
236)]
237pub struct BinaryPoly128(u128);
238
239// SAFETY: BinaryPoly128 is repr(transparent) over u128 (a primitive)
240unsafe impl bytemuck::Pod for BinaryPoly128 {}
241unsafe impl bytemuck::Zeroable for BinaryPoly128 {}
242
243impl BinaryPoly128 {
244    pub const fn new(val: u128) -> Self {
245        Self(val)
246    }
247
248    pub fn value(&self) -> u128 {
249        self.0
250    }
251
252    pub fn truncate_to_64(&self) -> BinaryPoly64 {
253        BinaryPoly64::new(self.0 as u64)
254    }
255
256    pub fn split(&self) -> (BinaryPoly64, BinaryPoly64) {
257        let lo = BinaryPoly64::new(self.0 as u64);
258        let hi = BinaryPoly64::new((self.0 >> 64) as u64);
259        (hi, lo)
260    }
261
262    pub fn leading_zeros(&self) -> u32 {
263        self.0.leading_zeros()
264    }
265
266    // full 128x128 -> 256 bit multiplication
267    pub fn mul_full(&self, other: &Self) -> BinaryPoly256 {
268        use crate::simd::carryless_mul_128_full;
269        carryless_mul_128_full(*self, *other)
270    }
271}
272
273impl BinaryPolynomial for BinaryPoly128 {
274    type Value = u128;
275
276    fn zero() -> Self {
277        Self(0)
278    }
279
280    fn one() -> Self {
281        Self(1)
282    }
283
284    fn from_value(val: u64) -> Self {
285        Self(val as u128)
286    }
287
288    fn value(&self) -> Self::Value {
289        self.0
290    }
291
292    fn add(&self, other: &Self) -> Self {
293        Self(self.0 ^ other.0)
294    }
295
296    fn mul(&self, other: &Self) -> Self {
297        use crate::simd::carryless_mul_128;
298        carryless_mul_128(*self, *other)
299    }
300
301    fn div_rem(&self, divisor: &Self) -> (Self, Self) {
302        assert_ne!(divisor.0, 0, "Division by zero");
303
304        let mut quotient = Self::zero();
305        let mut remainder = *self;
306
307        if remainder.0 == 0 {
308            return (quotient, remainder);
309        }
310
311        let divisor_bits = 128 - divisor.leading_zeros();
312        let mut remainder_bits = 128 - remainder.leading_zeros();
313
314        while remainder_bits >= divisor_bits && remainder.0 != 0 {
315            let shift = remainder_bits - divisor_bits;
316            quotient.0 |= 1u128 << shift;
317            remainder.0 ^= divisor.0 << shift;
318            remainder_bits = 128 - remainder.leading_zeros();
319        }
320
321        (quotient, remainder)
322    }
323}
324
325impl From<u128> for BinaryPoly128 {
326    fn from(val: u128) -> Self {
327        Self(val)
328    }
329}
330
331// BinaryPoly256 for intermediate calculations
332#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
333#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
334#[cfg_attr(
335    feature = "scale",
336    derive(codec::Encode, codec::Decode, scale_info::TypeInfo)
337)]
338pub struct BinaryPoly256 {
339    hi: u128,
340    lo: u128,
341}
342
343impl BinaryPoly256 {
344    pub fn from_parts(hi: u128, lo: u128) -> Self {
345        Self { hi, lo }
346    }
347
348    pub fn split(&self) -> (BinaryPoly128, BinaryPoly128) {
349        (BinaryPoly128::new(self.hi), BinaryPoly128::new(self.lo))
350    }
351
352    /// reduce modulo a 128-bit polynomial (for field operations)
353    pub fn reduce_mod(&self, modulus: &BinaryPoly128) -> BinaryPoly128 {
354        // for irreducible polynomials of form x^128 + lower terms,
355        // we can use efficient reduction
356
357        // special case for GF(2^128) with x^128 + x^7 + x^2 + x + 1
358        if modulus.value() == (1u128 << 127) | 0x87 {
359            // efficient reduction for gcm polynomial
360            let mut result = self.lo;
361            let mut high = self.hi;
362
363            // reduce 128 bits at a time
364            while high != 0 {
365                // x^128 = x^7 + x^2 + x + 1
366                let feedback =
367                    high.wrapping_shl(7) ^ high.wrapping_shl(2) ^ high.wrapping_shl(1) ^ high;
368
369                result ^= feedback;
370                high >>= 121; // process remaining bits
371            }
372
373            return BinaryPoly128::new(result);
374        }
375
376        // general case: polynomial long division
377        if self.hi == 0 {
378            // already reduced
379            return BinaryPoly128::new(self.lo);
380        }
381
382        // work with a copy
383        let mut remainder_hi = self.hi;
384        let mut remainder_lo = self.lo;
385
386        // get modulus without the leading bit
387        let mod_bits = 128 - modulus.leading_zeros();
388        let mod_val = modulus.value();
389        let mod_mask = mod_val ^ (1u128 << (mod_bits - 1));
390
391        // reduce high 128 bits
392        while remainder_hi != 0 {
393            let shift = remainder_hi.leading_zeros();
394
395            if shift < 128 {
396                // align the leading bit
397                let bit_pos = 127 - shift;
398
399                // xor with modulus shifted appropriately
400                remainder_hi ^= 1u128 << bit_pos;
401
402                // xor lower bits of modulus into result
403                if bit_pos >= (mod_bits - 1) {
404                    remainder_hi ^= mod_mask << (bit_pos - (mod_bits - 1));
405                } else {
406                    let right_shift = (mod_bits - 1) - bit_pos;
407                    remainder_hi ^= mod_mask >> right_shift;
408                    remainder_lo ^= mod_mask << (128 - right_shift);
409                }
410            } else {
411                break;
412            }
413        }
414
415        // now reduce remainder_lo if needed
416        let mut remainder = BinaryPoly128::new(remainder_lo);
417
418        if remainder.leading_zeros() < modulus.leading_zeros() {
419            let (_, r) = remainder.div_rem(modulus);
420            remainder = r;
421        }
422
423        remainder
424    }
425
426    /// get the high 128 bits
427    pub fn high(&self) -> BinaryPoly128 {
428        BinaryPoly128::new(self.hi)
429    }
430
431    /// get the low 128 bits
432    pub fn low(&self) -> BinaryPoly128 {
433        BinaryPoly128::new(self.lo)
434    }
435
436    pub fn leading_zeros(&self) -> u32 {
437        if self.hi == 0 {
438            128 + self.lo.leading_zeros()
439        } else {
440            self.hi.leading_zeros()
441        }
442    }
443
444    pub fn add(&self, other: &Self) -> Self {
445        Self {
446            hi: self.hi ^ other.hi,
447            lo: self.lo ^ other.lo,
448        }
449    }
450
451    pub fn shl(&self, n: u32) -> Self {
452        if n == 0 {
453            *self
454        } else if n >= 256 {
455            Self { hi: 0, lo: 0 }
456        } else if n >= 128 {
457            Self {
458                hi: self.lo << (n - 128),
459                lo: 0,
460            }
461        } else {
462            Self {
463                hi: (self.hi << n) | (self.lo >> (128 - n)),
464                lo: self.lo << n,
465            }
466        }
467    }
468
469    pub fn shr(&self, n: u32) -> Self {
470        if n == 0 {
471            *self
472        } else if n >= 256 {
473            Self { hi: 0, lo: 0 }
474        } else if n >= 128 {
475            Self {
476                hi: 0,
477                lo: self.hi >> (n - 128),
478            }
479        } else {
480            Self {
481                hi: self.hi >> n,
482                lo: (self.lo >> n) | (self.hi << (128 - n)),
483            }
484        }
485    }
486}