1use crate::BinaryPolynomial;
3
4macro_rules! impl_binary_poly {
6 ($name:ident, $value_type:ty, $double_name:ident) => {
7 #[repr(transparent)]
8 #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
9 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
10 #[cfg_attr(
11 feature = "scale",
12 derive(codec::Encode, codec::Decode, scale_info::TypeInfo)
13 )]
14 pub struct $name($value_type);
15
16 unsafe impl bytemuck::Pod for $name {}
18 unsafe impl bytemuck::Zeroable for $name {}
19
20 impl $name {
21 pub const fn new(val: $value_type) -> Self {
22 Self(val)
23 }
24
25 pub fn value(&self) -> $value_type {
26 self.0
27 }
28
29 pub fn shl(&self, n: u32) -> Self {
30 Self(self.0 << n)
31 }
32
33 pub fn shr(&self, n: u32) -> Self {
34 Self(self.0 >> n)
35 }
36
37 pub fn leading_zeros(&self) -> u32 {
38 self.0.leading_zeros()
39 }
40
41 #[allow(dead_code)]
42 pub fn split(&self) -> (Self, Self) {
43 let half_bits = core::mem::size_of::<$value_type>() * 4;
44 let mask = ((1u64 << half_bits) - 1) as $value_type;
45 let lo = Self(self.0 & mask);
46 let hi = Self(self.0 >> half_bits);
47 (hi, lo)
48 }
49 }
50
51 impl BinaryPolynomial for $name {
52 type Value = $value_type;
53
54 fn zero() -> Self {
55 Self(0)
56 }
57
58 fn one() -> Self {
59 Self(1)
60 }
61
62 fn from_value(val: u64) -> Self {
63 Self(val as $value_type)
64 }
65
66 fn value(&self) -> Self::Value {
67 self.0
68 }
69
70 fn add(&self, other: &Self) -> Self {
71 Self(self.0 ^ other.0)
72 }
73
74 fn mul(&self, other: &Self) -> Self {
75 let mut result = 0 as $value_type;
77 let a = self.0;
78 let b = other.0;
79 let bits = core::mem::size_of::<$value_type>() * 8;
80
81 for i in 0..bits {
82 let mask = (0 as $value_type).wrapping_sub((b >> i) & 1);
84 result ^= a.wrapping_shl(i as u32) & mask;
85 }
86
87 Self(result)
88 }
89
90 fn div_rem(&self, divisor: &Self) -> (Self, Self) {
91 assert_ne!(divisor.0, 0, "Division by zero");
92
93 let mut quotient = Self::zero();
94 let mut remainder = *self;
95
96 if remainder.0 == 0 {
97 return (quotient, remainder);
98 }
99
100 let divisor_bits =
101 (core::mem::size_of::<$value_type>() * 8) as u32 - divisor.leading_zeros();
102 let mut remainder_bits =
103 (core::mem::size_of::<$value_type>() * 8) as u32 - remainder.leading_zeros();
104
105 while remainder_bits >= divisor_bits && remainder.0 != 0 {
106 let shift = remainder_bits - divisor_bits;
107 quotient.0 |= 1 << shift;
108 remainder.0 ^= divisor.0 << shift;
109 remainder_bits = (core::mem::size_of::<$value_type>() * 8) as u32
110 - remainder.leading_zeros();
111 }
112
113 (quotient, remainder)
114 }
115 }
116
117 impl From<$value_type> for $name {
118 fn from(val: $value_type) -> Self {
119 Self(val)
120 }
121 }
122 };
123}
124
125impl_binary_poly!(BinaryPoly16, u16, BinaryPoly32);
127impl_binary_poly!(BinaryPoly32, u32, BinaryPoly64);
128
129#[repr(transparent)]
131#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
132#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
133#[cfg_attr(
134 feature = "scale",
135 derive(codec::Encode, codec::Decode, scale_info::TypeInfo)
136)]
137pub struct BinaryPoly64(u64);
138
139unsafe impl bytemuck::Pod for BinaryPoly64 {}
141unsafe impl bytemuck::Zeroable for BinaryPoly64 {}
142
143impl BinaryPoly64 {
144 pub const fn new(val: u64) -> Self {
145 Self(val)
146 }
147
148 pub fn value(&self) -> u64 {
149 self.0
150 }
151
152 pub fn shl(&self, n: u32) -> Self {
153 Self(self.0 << n)
154 }
155
156 pub fn shr(&self, n: u32) -> Self {
157 Self(self.0 >> n)
158 }
159
160 pub fn leading_zeros(&self) -> u32 {
161 self.0.leading_zeros()
162 }
163
164 pub fn split(&self) -> (BinaryPoly32, BinaryPoly32) {
165 let lo = BinaryPoly32::new(self.0 as u32);
166 let hi = BinaryPoly32::new((self.0 >> 32) as u32);
167 (hi, lo)
168 }
169}
170
171impl BinaryPolynomial for BinaryPoly64 {
172 type Value = u64;
173
174 fn zero() -> Self {
175 Self(0)
176 }
177
178 fn one() -> Self {
179 Self(1)
180 }
181
182 fn from_value(val: u64) -> Self {
183 Self(val)
184 }
185
186 fn value(&self) -> Self::Value {
187 self.0
188 }
189
190 fn add(&self, other: &Self) -> Self {
191 Self(self.0 ^ other.0)
192 }
193
194 fn mul(&self, other: &Self) -> Self {
195 use crate::simd::carryless_mul_64;
196 carryless_mul_64(*self, *other).truncate_to_64()
197 }
198
199 fn div_rem(&self, divisor: &Self) -> (Self, Self) {
200 assert_ne!(divisor.0, 0, "Division by zero");
201
202 let mut quotient = Self::zero();
203 let mut remainder = *self;
204
205 if remainder.0 == 0 {
206 return (quotient, remainder);
207 }
208
209 let divisor_bits = 64 - divisor.leading_zeros();
210 let mut remainder_bits = 64 - remainder.leading_zeros();
211
212 while remainder_bits >= divisor_bits && remainder.0 != 0 {
213 let shift = remainder_bits - divisor_bits;
214 quotient.0 |= 1 << shift;
215 remainder.0 ^= divisor.0 << shift;
216 remainder_bits = 64 - remainder.leading_zeros();
217 }
218
219 (quotient, remainder)
220 }
221}
222
223impl From<u64> for BinaryPoly64 {
224 fn from(val: u64) -> Self {
225 Self(val)
226 }
227}
228
229#[repr(transparent)]
231#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
232#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
233#[cfg_attr(
234 feature = "scale",
235 derive(codec::Encode, codec::Decode, scale_info::TypeInfo)
236)]
237pub struct BinaryPoly128(u128);
238
239unsafe impl bytemuck::Pod for BinaryPoly128 {}
241unsafe impl bytemuck::Zeroable for BinaryPoly128 {}
242
243impl BinaryPoly128 {
244 pub const fn new(val: u128) -> Self {
245 Self(val)
246 }
247
248 pub fn value(&self) -> u128 {
249 self.0
250 }
251
252 pub fn truncate_to_64(&self) -> BinaryPoly64 {
253 BinaryPoly64::new(self.0 as u64)
254 }
255
256 pub fn split(&self) -> (BinaryPoly64, BinaryPoly64) {
257 let lo = BinaryPoly64::new(self.0 as u64);
258 let hi = BinaryPoly64::new((self.0 >> 64) as u64);
259 (hi, lo)
260 }
261
262 pub fn leading_zeros(&self) -> u32 {
263 self.0.leading_zeros()
264 }
265
266 pub fn mul_full(&self, other: &Self) -> BinaryPoly256 {
268 use crate::simd::carryless_mul_128_full;
269 carryless_mul_128_full(*self, *other)
270 }
271}
272
273impl BinaryPolynomial for BinaryPoly128 {
274 type Value = u128;
275
276 fn zero() -> Self {
277 Self(0)
278 }
279
280 fn one() -> Self {
281 Self(1)
282 }
283
284 fn from_value(val: u64) -> Self {
285 Self(val as u128)
286 }
287
288 fn value(&self) -> Self::Value {
289 self.0
290 }
291
292 fn add(&self, other: &Self) -> Self {
293 Self(self.0 ^ other.0)
294 }
295
296 fn mul(&self, other: &Self) -> Self {
297 use crate::simd::carryless_mul_128;
298 carryless_mul_128(*self, *other)
299 }
300
301 fn div_rem(&self, divisor: &Self) -> (Self, Self) {
302 assert_ne!(divisor.0, 0, "Division by zero");
303
304 let mut quotient = Self::zero();
305 let mut remainder = *self;
306
307 if remainder.0 == 0 {
308 return (quotient, remainder);
309 }
310
311 let divisor_bits = 128 - divisor.leading_zeros();
312 let mut remainder_bits = 128 - remainder.leading_zeros();
313
314 while remainder_bits >= divisor_bits && remainder.0 != 0 {
315 let shift = remainder_bits - divisor_bits;
316 quotient.0 |= 1u128 << shift;
317 remainder.0 ^= divisor.0 << shift;
318 remainder_bits = 128 - remainder.leading_zeros();
319 }
320
321 (quotient, remainder)
322 }
323}
324
325impl From<u128> for BinaryPoly128 {
326 fn from(val: u128) -> Self {
327 Self(val)
328 }
329}
330
/// A polynomial over GF(2) with up to 256 coefficient bits, kept as two 128-bit words.
///
/// `hi` holds the coefficients of `x^128..=x^255` and `lo` those of `x^0..=x^127`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "scale",
    derive(codec::Encode, codec::Decode, scale_info::TypeInfo)
)]
pub struct BinaryPoly256 {
    // Upper 128 coefficient bits.
    hi: u128,
    // Lower 128 coefficient bits.
    lo: u128,
}
342
343impl BinaryPoly256 {
344 pub fn from_parts(hi: u128, lo: u128) -> Self {
345 Self { hi, lo }
346 }
347
348 pub fn split(&self) -> (BinaryPoly128, BinaryPoly128) {
349 (BinaryPoly128::new(self.hi), BinaryPoly128::new(self.lo))
350 }
351
352 pub fn reduce_mod(&self, modulus: &BinaryPoly128) -> BinaryPoly128 {
354 if modulus.value() == (1u128 << 127) | 0x87 {
359 let mut result = self.lo;
361 let mut high = self.hi;
362
363 while high != 0 {
365 let feedback =
367 high.wrapping_shl(7) ^ high.wrapping_shl(2) ^ high.wrapping_shl(1) ^ high;
368
369 result ^= feedback;
370 high >>= 121; }
372
373 return BinaryPoly128::new(result);
374 }
375
376 if self.hi == 0 {
378 return BinaryPoly128::new(self.lo);
380 }
381
382 let mut remainder_hi = self.hi;
384 let mut remainder_lo = self.lo;
385
386 let mod_bits = 128 - modulus.leading_zeros();
388 let mod_val = modulus.value();
389 let mod_mask = mod_val ^ (1u128 << (mod_bits - 1));
390
391 while remainder_hi != 0 {
393 let shift = remainder_hi.leading_zeros();
394
395 if shift < 128 {
396 let bit_pos = 127 - shift;
398
399 remainder_hi ^= 1u128 << bit_pos;
401
402 if bit_pos >= (mod_bits - 1) {
404 remainder_hi ^= mod_mask << (bit_pos - (mod_bits - 1));
405 } else {
406 let right_shift = (mod_bits - 1) - bit_pos;
407 remainder_hi ^= mod_mask >> right_shift;
408 remainder_lo ^= mod_mask << (128 - right_shift);
409 }
410 } else {
411 break;
412 }
413 }
414
415 let mut remainder = BinaryPoly128::new(remainder_lo);
417
418 if remainder.leading_zeros() < modulus.leading_zeros() {
419 let (_, r) = remainder.div_rem(modulus);
420 remainder = r;
421 }
422
423 remainder
424 }
425
426 pub fn high(&self) -> BinaryPoly128 {
428 BinaryPoly128::new(self.hi)
429 }
430
431 pub fn low(&self) -> BinaryPoly128 {
433 BinaryPoly128::new(self.lo)
434 }
435
436 pub fn leading_zeros(&self) -> u32 {
437 if self.hi == 0 {
438 128 + self.lo.leading_zeros()
439 } else {
440 self.hi.leading_zeros()
441 }
442 }
443
444 pub fn add(&self, other: &Self) -> Self {
445 Self {
446 hi: self.hi ^ other.hi,
447 lo: self.lo ^ other.lo,
448 }
449 }
450
451 pub fn shl(&self, n: u32) -> Self {
452 if n == 0 {
453 *self
454 } else if n >= 256 {
455 Self { hi: 0, lo: 0 }
456 } else if n >= 128 {
457 Self {
458 hi: self.lo << (n - 128),
459 lo: 0,
460 }
461 } else {
462 Self {
463 hi: (self.hi << n) | (self.lo >> (128 - n)),
464 lo: self.lo << n,
465 }
466 }
467 }
468
469 pub fn shr(&self, n: u32) -> Self {
470 if n == 0 {
471 *self
472 } else if n >= 256 {
473 Self { hi: 0, lo: 0 }
474 } else if n >= 128 {
475 Self {
476 hi: 0,
477 lo: self.hi >> (n - 128),
478 }
479 } else {
480 Self {
481 hi: self.hi >> n,
482 lo: (self.lo >> n) | (self.hi << (128 - n)),
483 }
484 }
485 }
486}