1use crate::BinaryPolynomial;
3
4macro_rules! impl_binary_poly {
6 ($name:ident, $value_type:ty, $double_name:ident) => {
7 #[repr(transparent)]
8 #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
9 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
10 #[cfg_attr(feature = "scale", derive(codec::Encode, codec::Decode, scale_info::TypeInfo))]
11 pub struct $name($value_type);
12
13 unsafe impl bytemuck::Pod for $name {}
15 unsafe impl bytemuck::Zeroable for $name {}
16
17 impl $name {
18 pub const fn new(val: $value_type) -> Self {
19 Self(val)
20 }
21
22 pub fn value(&self) -> $value_type {
23 self.0
24 }
25
26 pub fn shl(&self, n: u32) -> Self {
27 Self(self.0 << n)
28 }
29
30 pub fn shr(&self, n: u32) -> Self {
31 Self(self.0 >> n)
32 }
33
34 pub fn leading_zeros(&self) -> u32 {
35 self.0.leading_zeros()
36 }
37
38 #[allow(dead_code)]
39 pub fn split(&self) -> (Self, Self) {
40 let half_bits = core::mem::size_of::<$value_type>() * 4;
41 let mask = ((1u64 << half_bits) - 1) as $value_type;
42 let lo = Self(self.0 & mask);
43 let hi = Self(self.0 >> half_bits);
44 (hi, lo)
45 }
46 }
47
48 impl BinaryPolynomial for $name {
49 type Value = $value_type;
50
51 fn zero() -> Self {
52 Self(0)
53 }
54
55 fn one() -> Self {
56 Self(1)
57 }
58
59 fn from_value(val: u64) -> Self {
60 Self(val as $value_type)
61 }
62
63 fn value(&self) -> Self::Value {
64 self.0
65 }
66
67 fn add(&self, other: &Self) -> Self {
68 Self(self.0 ^ other.0)
69 }
70
71 fn mul(&self, other: &Self) -> Self {
72 let mut result = 0 as $value_type;
74 let a = self.0;
75 let b = other.0;
76 let bits = core::mem::size_of::<$value_type>() * 8;
77
78 for i in 0..bits {
79 let mask = (0 as $value_type).wrapping_sub((b >> i) & 1);
81 result ^= a.wrapping_shl(i as u32) & mask;
82 }
83
84 Self(result)
85 }
86
87 fn div_rem(&self, divisor: &Self) -> (Self, Self) {
88 assert_ne!(divisor.0, 0, "Division by zero");
89
90 let mut quotient = Self::zero();
91 let mut remainder = *self;
92
93 if remainder.0 == 0 {
94 return (quotient, remainder);
95 }
96
97 let divisor_bits = (core::mem::size_of::<$value_type>() * 8) as u32 - divisor.leading_zeros();
98 let mut remainder_bits = (core::mem::size_of::<$value_type>() * 8) as u32 - remainder.leading_zeros();
99
100 while remainder_bits >= divisor_bits && remainder.0 != 0 {
101 let shift = remainder_bits - divisor_bits;
102 quotient.0 |= 1 << shift;
103 remainder.0 ^= divisor.0 << shift;
104 remainder_bits = (core::mem::size_of::<$value_type>() * 8) as u32 - remainder.leading_zeros();
105 }
106
107 (quotient, remainder)
108 }
109 }
110
111 impl From<$value_type> for $name {
112 fn from(val: $value_type) -> Self {
113 Self(val)
114 }
115 }
116 };
117}
118
119impl_binary_poly!(BinaryPoly16, u16, BinaryPoly32);
121impl_binary_poly!(BinaryPoly32, u32, BinaryPoly64);
122
123#[repr(transparent)]
125#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
126#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
127#[cfg_attr(feature = "scale", derive(codec::Encode, codec::Decode, scale_info::TypeInfo))]
128pub struct BinaryPoly64(u64);
129
130unsafe impl bytemuck::Pod for BinaryPoly64 {}
132unsafe impl bytemuck::Zeroable for BinaryPoly64 {}
133
134impl BinaryPoly64 {
135 pub const fn new(val: u64) -> Self {
136 Self(val)
137 }
138
139 pub fn value(&self) -> u64 {
140 self.0
141 }
142
143 pub fn shl(&self, n: u32) -> Self {
144 Self(self.0 << n)
145 }
146
147 pub fn shr(&self, n: u32) -> Self {
148 Self(self.0 >> n)
149 }
150
151 pub fn leading_zeros(&self) -> u32 {
152 self.0.leading_zeros()
153 }
154
155 pub fn split(&self) -> (BinaryPoly32, BinaryPoly32) {
156 let lo = BinaryPoly32::new(self.0 as u32);
157 let hi = BinaryPoly32::new((self.0 >> 32) as u32);
158 (hi, lo)
159 }
160}
161
162impl BinaryPolynomial for BinaryPoly64 {
163 type Value = u64;
164
165 fn zero() -> Self {
166 Self(0)
167 }
168
169 fn one() -> Self {
170 Self(1)
171 }
172
173 fn from_value(val: u64) -> Self {
174 Self(val)
175 }
176
177 fn value(&self) -> Self::Value {
178 self.0
179 }
180
181 fn add(&self, other: &Self) -> Self {
182 Self(self.0 ^ other.0)
183 }
184
185 fn mul(&self, other: &Self) -> Self {
186 use crate::simd::carryless_mul_64;
187 carryless_mul_64(*self, *other).truncate_to_64()
188 }
189
190 fn div_rem(&self, divisor: &Self) -> (Self, Self) {
191 assert_ne!(divisor.0, 0, "Division by zero");
192
193 let mut quotient = Self::zero();
194 let mut remainder = *self;
195
196 if remainder.0 == 0 {
197 return (quotient, remainder);
198 }
199
200 let divisor_bits = 64 - divisor.leading_zeros();
201 let mut remainder_bits = 64 - remainder.leading_zeros();
202
203 while remainder_bits >= divisor_bits && remainder.0 != 0 {
204 let shift = remainder_bits - divisor_bits;
205 quotient.0 |= 1 << shift;
206 remainder.0 ^= divisor.0 << shift;
207 remainder_bits = 64 - remainder.leading_zeros();
208 }
209
210 (quotient, remainder)
211 }
212}
213
214impl From<u64> for BinaryPoly64 {
215 fn from(val: u64) -> Self {
216 Self(val)
217 }
218}
219
220#[repr(transparent)]
222#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
223#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
224#[cfg_attr(feature = "scale", derive(codec::Encode, codec::Decode, scale_info::TypeInfo))]
225pub struct BinaryPoly128(u128);
226
227unsafe impl bytemuck::Pod for BinaryPoly128 {}
229unsafe impl bytemuck::Zeroable for BinaryPoly128 {}
230
231impl BinaryPoly128 {
232 pub const fn new(val: u128) -> Self {
233 Self(val)
234 }
235
236 pub fn value(&self) -> u128 {
237 self.0
238 }
239
240 pub fn truncate_to_64(&self) -> BinaryPoly64 {
241 BinaryPoly64::new(self.0 as u64)
242 }
243
244 pub fn split(&self) -> (BinaryPoly64, BinaryPoly64) {
245 let lo = BinaryPoly64::new(self.0 as u64);
246 let hi = BinaryPoly64::new((self.0 >> 64) as u64);
247 (hi, lo)
248 }
249
250 pub fn leading_zeros(&self) -> u32 {
251 self.0.leading_zeros()
252 }
253
254 pub fn mul_full(&self, other: &Self) -> BinaryPoly256 {
256 use crate::simd::carryless_mul_128_full;
257 carryless_mul_128_full(*self, *other)
258 }
259}
260
261impl BinaryPolynomial for BinaryPoly128 {
262 type Value = u128;
263
264 fn zero() -> Self {
265 Self(0)
266 }
267
268 fn one() -> Self {
269 Self(1)
270 }
271
272 fn from_value(val: u64) -> Self {
273 Self(val as u128)
274 }
275
276 fn value(&self) -> Self::Value {
277 self.0
278 }
279
280 fn add(&self, other: &Self) -> Self {
281 Self(self.0 ^ other.0)
282 }
283
284 fn mul(&self, other: &Self) -> Self {
285 use crate::simd::carryless_mul_128;
286 carryless_mul_128(*self, *other)
287 }
288
289 fn div_rem(&self, divisor: &Self) -> (Self, Self) {
290 assert_ne!(divisor.0, 0, "Division by zero");
291
292 let mut quotient = Self::zero();
293 let mut remainder = *self;
294
295 if remainder.0 == 0 {
296 return (quotient, remainder);
297 }
298
299 let divisor_bits = 128 - divisor.leading_zeros();
300 let mut remainder_bits = 128 - remainder.leading_zeros();
301
302 while remainder_bits >= divisor_bits && remainder.0 != 0 {
303 let shift = remainder_bits - divisor_bits;
304 quotient.0 |= 1u128 << shift;
305 remainder.0 ^= divisor.0 << shift;
306 remainder_bits = 128 - remainder.leading_zeros();
307 }
308
309 (quotient, remainder)
310 }
311}
312
313impl From<u128> for BinaryPoly128 {
314 fn from(val: u128) -> Self {
315 Self(val)
316 }
317}
318
/// Binary polynomial with up to 256 coefficients, stored as two 128-bit halves.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "scale", derive(codec::Encode, codec::Decode, scale_info::TypeInfo))]
pub struct BinaryPoly256 {
    /// Coefficients of `x^128` through `x^255`.
    hi: u128,
    /// Coefficients of `x^0` through `x^127`.
    lo: u128,
}
327
328impl BinaryPoly256 {
329 pub fn from_parts(hi: u128, lo: u128) -> Self {
330 Self { hi, lo }
331 }
332
333 pub fn split(&self) -> (BinaryPoly128, BinaryPoly128) {
334 (BinaryPoly128::new(self.hi), BinaryPoly128::new(self.lo))
335 }
336
337 pub fn reduce_mod(&self, modulus: &BinaryPoly128) -> BinaryPoly128 {
339 if modulus.value() == (1u128 << 127) | 0x87 {
344 let mut result = self.lo;
346 let mut high = self.hi;
347
348 while high != 0 {
350 let feedback = high.wrapping_shl(7)
352 ^ high.wrapping_shl(2)
353 ^ high.wrapping_shl(1)
354 ^ high;
355
356 result ^= feedback;
357 high >>= 121; }
359
360 return BinaryPoly128::new(result);
361 }
362
363 if self.hi == 0 {
365 return BinaryPoly128::new(self.lo);
367 }
368
369 let mut remainder_hi = self.hi;
371 let mut remainder_lo = self.lo;
372
373 let mod_bits = 128 - modulus.leading_zeros();
375 let mod_val = modulus.value();
376 let mod_mask = mod_val ^ (1u128 << (mod_bits - 1));
377
378 while remainder_hi != 0 {
380 let shift = remainder_hi.leading_zeros();
381
382 if shift < 128 {
383 let bit_pos = 127 - shift;
385
386 remainder_hi ^= 1u128 << bit_pos;
388
389 if bit_pos >= (mod_bits - 1) {
391 remainder_hi ^= mod_mask << (bit_pos - (mod_bits - 1));
392 } else {
393 let right_shift = (mod_bits - 1) - bit_pos;
394 remainder_hi ^= mod_mask >> right_shift;
395 remainder_lo ^= mod_mask << (128 - right_shift);
396 }
397 } else {
398 break;
399 }
400 }
401
402 let mut remainder = BinaryPoly128::new(remainder_lo);
404
405 if remainder.leading_zeros() < modulus.leading_zeros() {
406 let (_, r) = remainder.div_rem(modulus);
407 remainder = r;
408 }
409
410 remainder
411 }
412
413 pub fn high(&self) -> BinaryPoly128 {
415 BinaryPoly128::new(self.hi)
416 }
417
418 pub fn low(&self) -> BinaryPoly128 {
420 BinaryPoly128::new(self.lo)
421 }
422
423 pub fn leading_zeros(&self) -> u32 {
424 if self.hi == 0 {
425 128 + self.lo.leading_zeros()
426 } else {
427 self.hi.leading_zeros()
428 }
429 }
430
431 pub fn add(&self, other: &Self) -> Self {
432 Self {
433 hi: self.hi ^ other.hi,
434 lo: self.lo ^ other.lo,
435 }
436 }
437
438 pub fn shl(&self, n: u32) -> Self {
439 if n == 0 {
440 *self
441 } else if n >= 256 {
442 Self { hi: 0, lo: 0 }
443 } else if n >= 128 {
444 Self {
445 hi: self.lo << (n - 128),
446 lo: 0,
447 }
448 } else {
449 Self {
450 hi: (self.hi << n) | (self.lo >> (128 - n)),
451 lo: self.lo << n,
452 }
453 }
454 }
455
456 pub fn shr(&self, n: u32) -> Self {
457 if n == 0 {
458 *self
459 } else if n >= 256 {
460 Self { hi: 0, lo: 0 }
461 } else if n >= 128 {
462 Self {
463 hi: 0,
464 lo: self.hi >> (n - 128),
465 }
466 } else {
467 Self {
468 hi: self.hi >> n,
469 lo: (self.lo >> n) | (self.hi << (128 - n)),
470 }
471 }
472 }
473}