1use crate::BinaryPolynomial;
3
/// Generates a fixed-width polynomial-over-GF(2) newtype backed by an unsigned
/// integer, together with its inherent helpers, a `BinaryPolynomial` impl and a
/// `From<$value_type>` conversion.
///
/// * `$name` — name of the generated type.
/// * `$value_type` — backing unsigned integer; bit `i` is the coefficient of `x^i`.
/// * `$double_name` — name of the double-width polynomial type; accepted by the
///   matcher but not referenced by the expansion.
macro_rules! impl_binary_poly {
    ($name:ident, $value_type:ty, $double_name:ident) => {
        /// Polynomial over GF(2) whose coefficients are the bits of the wrapped integer.
        #[repr(transparent)]
        #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct $name($value_type);

        // SAFETY: `#[repr(transparent)]` over a primitive unsigned integer: no
        // padding, no invalid bit patterns, and the type is `Copy + 'static`.
        unsafe impl bytemuck::Pod for $name {}
        // SAFETY: the all-zero bit pattern is a valid value (the zero polynomial).
        unsafe impl bytemuck::Zeroable for $name {}

        impl $name {
            /// Wraps a raw coefficient word.
            pub const fn new(val: $value_type) -> Self {
                Self(val)
            }

            /// Returns the raw coefficient word.
            pub fn value(&self) -> $value_type {
                self.0
            }

            /// Multiplies by `x^n`, dropping coefficients pushed past the top bit.
            ///
            /// Shifting by the full bit width or more yields zero (as
            /// `BinaryPoly256::shl` does) instead of panicking in debug builds or
            /// shifting by `n % width` in release builds.
            pub fn shl(&self, n: u32) -> Self {
                Self(self.0.checked_shl(n).unwrap_or(0))
            }

            /// Divides by `x^n`, discarding the low `n` coefficients.
            ///
            /// Shifting by the full bit width or more yields zero.
            pub fn shr(&self, n: u32) -> Self {
                Self(self.0.checked_shr(n).unwrap_or(0))
            }

            /// Number of leading zero bits; `width - leading_zeros()` is
            /// `degree + 1` for a non-zero polynomial.
            pub fn leading_zeros(&self) -> u32 {
                self.0.leading_zeros()
            }

            /// Splits the word into `(high half, low half)`, both right-aligned
            /// in `Self`.
            // NOTE(review): the reason for this allow is not visible in this file;
            // confirm it is still needed.
            #[allow(dead_code)]
            pub fn split(&self) -> (Self, Self) {
                let half_bits = core::mem::size_of::<$value_type>() * 4;
                let mask = ((1u64 << half_bits) - 1) as $value_type;
                let lo = Self(self.0 & mask);
                let hi = Self(self.0 >> half_bits);
                (hi, lo)
            }
        }

        impl BinaryPolynomial for $name {
            type Value = $value_type;

            /// The zero polynomial.
            fn zero() -> Self {
                Self(0)
            }

            /// The constant polynomial `1`.
            fn one() -> Self {
                Self(1)
            }

            /// Converts from `u64` with an `as` cast: coefficients that do not
            /// fit in the backing integer are dropped.
            fn from_value(val: u64) -> Self {
                Self(val as $value_type)
            }

            /// Returns the raw coefficient word.
            fn value(&self) -> Self::Value {
                self.0
            }

            /// Addition in GF(2)[x] is coefficient-wise XOR.
            fn add(&self, other: &Self) -> Self {
                Self(self.0 ^ other.0)
            }

            /// Shift-and-XOR (carry-less) product, truncated to the type's width.
            fn mul(&self, other: &Self) -> Self {
                let a = self.0;
                let b = other.0;
                let bits = core::mem::size_of::<$value_type>() * 8;
                let mut result = 0 as $value_type;

                for i in 0..bits {
                    // All-ones when bit `i` of `b` is set, all-zeros otherwise, so
                    // the partial product is selected without a branch.
                    let mask = (0 as $value_type).wrapping_sub((b >> i) & 1);
                    result ^= a.wrapping_shl(i as u32) & mask;
                }

                Self(result)
            }

            /// Polynomial long division returning `(quotient, remainder)` with
            /// `deg(remainder) < deg(divisor)`.
            ///
            /// # Panics
            /// Panics with "Division by zero" if `divisor` is zero.
            fn div_rem(&self, divisor: &Self) -> (Self, Self) {
                assert_ne!(divisor.0, 0, "Division by zero");

                let width = (core::mem::size_of::<$value_type>() * 8) as u32;
                let mut quotient = Self::zero();
                let mut remainder = *self;

                if remainder.0 == 0 {
                    return (quotient, remainder);
                }

                // Bit lengths are `degree + 1`.
                let divisor_bits = width - divisor.leading_zeros();
                let mut remainder_bits = width - remainder.leading_zeros();

                while remainder_bits >= divisor_bits && remainder.0 != 0 {
                    let shift = remainder_bits - divisor_bits;
                    quotient.0 |= 1 << shift;
                    // Cancels the current leading term of the remainder.
                    remainder.0 ^= divisor.0 << shift;
                    remainder_bits = width - remainder.leading_zeros();
                }

                (quotient, remainder)
            }
        }

        impl From<$value_type> for $name {
            /// Wraps a raw coefficient word.
            fn from(val: $value_type) -> Self {
                Self(val)
            }
        }
    };
}

118impl_binary_poly!(BinaryPoly16, u16, BinaryPoly32);
120impl_binary_poly!(BinaryPoly32, u32, BinaryPoly64);
121
122#[repr(transparent)]
124#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
125#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
126pub struct BinaryPoly64(u64);
127
128unsafe impl bytemuck::Pod for BinaryPoly64 {}
130unsafe impl bytemuck::Zeroable for BinaryPoly64 {}
131
132impl BinaryPoly64 {
133 pub const fn new(val: u64) -> Self {
134 Self(val)
135 }
136
137 pub fn value(&self) -> u64 {
138 self.0
139 }
140
141 pub fn shl(&self, n: u32) -> Self {
142 Self(self.0 << n)
143 }
144
145 pub fn shr(&self, n: u32) -> Self {
146 Self(self.0 >> n)
147 }
148
149 pub fn leading_zeros(&self) -> u32 {
150 self.0.leading_zeros()
151 }
152
153 pub fn split(&self) -> (BinaryPoly32, BinaryPoly32) {
154 let lo = BinaryPoly32::new(self.0 as u32);
155 let hi = BinaryPoly32::new((self.0 >> 32) as u32);
156 (hi, lo)
157 }
158}
159
160impl BinaryPolynomial for BinaryPoly64 {
161 type Value = u64;
162
163 fn zero() -> Self {
164 Self(0)
165 }
166
167 fn one() -> Self {
168 Self(1)
169 }
170
171 fn from_value(val: u64) -> Self {
172 Self(val)
173 }
174
175 fn value(&self) -> Self::Value {
176 self.0
177 }
178
179 fn add(&self, other: &Self) -> Self {
180 Self(self.0 ^ other.0)
181 }
182
183 fn mul(&self, other: &Self) -> Self {
184 use crate::simd::carryless_mul_64;
185 carryless_mul_64(*self, *other).truncate_to_64()
186 }
187
188 fn div_rem(&self, divisor: &Self) -> (Self, Self) {
189 assert_ne!(divisor.0, 0, "Division by zero");
190
191 let mut quotient = Self::zero();
192 let mut remainder = *self;
193
194 if remainder.0 == 0 {
195 return (quotient, remainder);
196 }
197
198 let divisor_bits = 64 - divisor.leading_zeros();
199 let mut remainder_bits = 64 - remainder.leading_zeros();
200
201 while remainder_bits >= divisor_bits && remainder.0 != 0 {
202 let shift = remainder_bits - divisor_bits;
203 quotient.0 |= 1 << shift;
204 remainder.0 ^= divisor.0 << shift;
205 remainder_bits = 64 - remainder.leading_zeros();
206 }
207
208 (quotient, remainder)
209 }
210}
211
212impl From<u64> for BinaryPoly64 {
213 fn from(val: u64) -> Self {
214 Self(val)
215 }
216}
217
218#[repr(transparent)]
220#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
221#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
222pub struct BinaryPoly128(u128);
223
224unsafe impl bytemuck::Pod for BinaryPoly128 {}
226unsafe impl bytemuck::Zeroable for BinaryPoly128 {}
227
228impl BinaryPoly128 {
229 pub const fn new(val: u128) -> Self {
230 Self(val)
231 }
232
233 pub fn value(&self) -> u128 {
234 self.0
235 }
236
237 pub fn truncate_to_64(&self) -> BinaryPoly64 {
238 BinaryPoly64::new(self.0 as u64)
239 }
240
241 pub fn split(&self) -> (BinaryPoly64, BinaryPoly64) {
242 let lo = BinaryPoly64::new(self.0 as u64);
243 let hi = BinaryPoly64::new((self.0 >> 64) as u64);
244 (hi, lo)
245 }
246
247 pub fn leading_zeros(&self) -> u32 {
248 self.0.leading_zeros()
249 }
250
251 pub fn mul_full(&self, other: &Self) -> BinaryPoly256 {
253 use crate::simd::carryless_mul_128_full;
254 carryless_mul_128_full(*self, *other)
255 }
256}
257
258impl BinaryPolynomial for BinaryPoly128 {
259 type Value = u128;
260
261 fn zero() -> Self {
262 Self(0)
263 }
264
265 fn one() -> Self {
266 Self(1)
267 }
268
269 fn from_value(val: u64) -> Self {
270 Self(val as u128)
271 }
272
273 fn value(&self) -> Self::Value {
274 self.0
275 }
276
277 fn add(&self, other: &Self) -> Self {
278 Self(self.0 ^ other.0)
279 }
280
281 fn mul(&self, other: &Self) -> Self {
282 use crate::simd::carryless_mul_128;
283 carryless_mul_128(*self, *other)
284 }
285
286 fn div_rem(&self, divisor: &Self) -> (Self, Self) {
287 assert_ne!(divisor.0, 0, "Division by zero");
288
289 let mut quotient = Self::zero();
290 let mut remainder = *self;
291
292 if remainder.0 == 0 {
293 return (quotient, remainder);
294 }
295
296 let divisor_bits = 128 - divisor.leading_zeros();
297 let mut remainder_bits = 128 - remainder.leading_zeros();
298
299 while remainder_bits >= divisor_bits && remainder.0 != 0 {
300 let shift = remainder_bits - divisor_bits;
301 quotient.0 |= 1u128 << shift;
302 remainder.0 ^= divisor.0 << shift;
303 remainder_bits = 128 - remainder.leading_zeros();
304 }
305
306 (quotient, remainder)
307 }
308}
309
310impl From<u128> for BinaryPoly128 {
311 fn from(val: u128) -> Self {
312 Self(val)
313 }
314}
315
/// Polynomial over GF(2) of degree < 256, stored as two 128-bit words
/// (this is the return type of `BinaryPoly128::mul_full`).
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BinaryPoly256 {
    /// Coefficients of `x^128..x^255` (bit `i` = coefficient of `x^(128 + i)`).
    hi: u128,
    /// Coefficients of `x^0..x^127` (bit `i` = coefficient of `x^i`).
    lo: u128,
}

324impl BinaryPoly256 {
325 pub fn from_parts(hi: u128, lo: u128) -> Self {
326 Self { hi, lo }
327 }
328
329 pub fn split(&self) -> (BinaryPoly128, BinaryPoly128) {
330 (BinaryPoly128::new(self.hi), BinaryPoly128::new(self.lo))
331 }
332
333 pub fn reduce_mod(&self, modulus: &BinaryPoly128) -> BinaryPoly128 {
335 if modulus.value() == (1u128 << 127) | 0x87 {
340 let mut result = self.lo;
342 let mut high = self.hi;
343
344 while high != 0 {
346 let feedback = high.wrapping_shl(7)
348 ^ high.wrapping_shl(2)
349 ^ high.wrapping_shl(1)
350 ^ high;
351
352 result ^= feedback;
353 high >>= 121; }
355
356 return BinaryPoly128::new(result);
357 }
358
359 if self.hi == 0 {
361 return BinaryPoly128::new(self.lo);
363 }
364
365 let mut remainder_hi = self.hi;
367 let mut remainder_lo = self.lo;
368
369 let mod_bits = 128 - modulus.leading_zeros();
371 let mod_val = modulus.value();
372 let mod_mask = mod_val ^ (1u128 << (mod_bits - 1));
373
374 while remainder_hi != 0 {
376 let shift = remainder_hi.leading_zeros();
377
378 if shift < 128 {
379 let bit_pos = 127 - shift;
381
382 remainder_hi ^= 1u128 << bit_pos;
384
385 if bit_pos >= (mod_bits - 1) {
387 remainder_hi ^= mod_mask << (bit_pos - (mod_bits - 1));
388 } else {
389 let right_shift = (mod_bits - 1) - bit_pos;
390 remainder_hi ^= mod_mask >> right_shift;
391 remainder_lo ^= mod_mask << (128 - right_shift);
392 }
393 } else {
394 break;
395 }
396 }
397
398 let mut remainder = BinaryPoly128::new(remainder_lo);
400
401 if remainder.leading_zeros() < modulus.leading_zeros() {
402 let (_, r) = remainder.div_rem(modulus);
403 remainder = r;
404 }
405
406 remainder
407 }
408
409 pub fn high(&self) -> BinaryPoly128 {
411 BinaryPoly128::new(self.hi)
412 }
413
414 pub fn low(&self) -> BinaryPoly128 {
416 BinaryPoly128::new(self.lo)
417 }
418
419 pub fn leading_zeros(&self) -> u32 {
420 if self.hi == 0 {
421 128 + self.lo.leading_zeros()
422 } else {
423 self.hi.leading_zeros()
424 }
425 }
426
427 pub fn add(&self, other: &Self) -> Self {
428 Self {
429 hi: self.hi ^ other.hi,
430 lo: self.lo ^ other.lo,
431 }
432 }
433
434 pub fn shl(&self, n: u32) -> Self {
435 if n == 0 {
436 *self
437 } else if n >= 256 {
438 Self { hi: 0, lo: 0 }
439 } else if n >= 128 {
440 Self {
441 hi: self.lo << (n - 128),
442 lo: 0,
443 }
444 } else {
445 Self {
446 hi: (self.hi << n) | (self.lo >> (128 - n)),
447 lo: self.lo << n,
448 }
449 }
450 }
451
452 pub fn shr(&self, n: u32) -> Self {
453 if n == 0 {
454 *self
455 } else if n >= 256 {
456 Self { hi: 0, lo: 0 }
457 } else if n >= 128 {
458 Self {
459 hi: 0,
460 lo: self.hi >> (n - 128),
461 }
462 } else {
463 Self {
464 hi: self.hi >> n,
465 lo: (self.lo >> n) | (self.hi << (128 - n)),
466 }
467 }
468 }
469}