1use crate::constants::FLAT_TO_TOWER_BIT_MASKS_8;
20use crate::towers::bit::Bit;
21use crate::{
22 CanonicalDeserialize, CanonicalSerialize, Flat, HardwareField, PackableField, PackedFlat,
23 TowerField, constants,
24};
25use core::ops::{Add, AddAssign, BitXor, Mul, MulAssign, Sub, SubAssign};
26use serde::{Deserialize, Serialize};
27use zeroize::Zeroize;
28
/// 64-byte-aligned holder for the `N` basis bytes of a linear basis conversion.
///
/// Only compiled when the `table-math` feature is disabled; the statics built
/// from it have their inner arrays passed to `map_ct_8`.
// NOTE(review): the 64-byte alignment presumably keeps the whole table inside
// a single cache line — confirm the intent.
#[cfg(not(feature = "table-math"))]
#[repr(align(64))]
struct CtConvertBasisU8<const N: usize>([u8; N]);
32
/// Basis bytes for the tower → flat conversion.
///
/// `HardwareField::to_hardware` passes this array to `map_ct_8`, which XORs
/// entry `i` into the result whenever bit `i` of the input is set.
#[cfg(not(feature = "table-math"))]
static TOWER_TO_FLAT_BASIS_8: CtConvertBasisU8<8> =
    CtConvertBasisU8(constants::RAW_TOWER_TO_FLAT_8);
36
/// Basis bytes for the flat → tower conversion.
///
/// `HardwareField::from_hardware` passes this array to `map_ct_8`, which XORs
/// entry `i` into the result whenever bit `i` of the input is set.
#[cfg(not(feature = "table-math"))]
static FLAT_TO_TOWER_BASIS_8: CtConvertBasisU8<8> =
    CtConvertBasisU8(constants::RAW_FLAT_TO_TOWER_8);
40
/// Exponentiation table: `EXP_TABLE[i]` is the value reached by
/// `generate_exp_table` after `i` multiplications by `0x03`, starting from 1.
///
/// The table has 256 entries (rather than 255) so that `invert` can index
/// `255 - log` for every `log` in `0..=254` without wrapping.
#[cfg(feature = "table-math")]
const EXP_TABLE: [u8; 256] = generate_exp_table();
53
/// Discrete-log table matching `EXP_TABLE`: for every value `v` visited by
/// `generate_log_table`, `LOG_TABLE[v]` is the step index at which it appeared.
///
/// Entry 0 is left at its initial value of 0; callers test for a zero operand
/// before consulting the table.
#[cfg(feature = "table-math")]
const LOG_TABLE: [u8; 256] = generate_log_table();
60
/// An element of the 8-bit binary field, stored as a single byte.
///
/// Addition is XOR; the portable multiplication in this file reduces modulo
/// `x^8 + x^4 + x^3 + x + 1` (feedback constant `0x1B`).
///
/// `#[repr(transparent)]` gives the type the same layout as `u8`, which the
/// NEON code relies on when it transmutes `[Block8; 16]` to a byte vector.
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq, Serialize, Deserialize, Zeroize)]
#[repr(transparent)]
pub struct Block8(pub u8);
65
66impl Block8 {
67 pub const fn new(val: u8) -> Self {
68 Self(val)
69 }
70
71 #[inline(always)]
72 pub fn square(self) -> Self {
73 let mut s = self.0 as u16;
76 s = (s | (s << 4)) & 0x0f0f;
77 s = (s | (s << 2)) & 0x3333;
78 s = (s | (s << 1)) & 0x5555;
79
80 let hi = s >> 8;
81 let s = (s & 0x00ff) ^ (hi ^ (hi << 1) ^ (hi << 3) ^ (hi << 4));
82
83 let hi = s >> 8;
84
85 Block8(((s & 0x00ff) ^ (hi ^ (hi << 1) ^ (hi << 3) ^ (hi << 4))) as u8)
86 }
87}
88
89impl TowerField for Block8 {
90 const BITS: usize = 8;
91 const ZERO: Self = Block8(0);
92 const ONE: Self = Block8(1);
93
94 const EXTENSION_TAU: Self = Block8(0x20);
95
96 fn invert(&self) -> Self {
97 #[cfg(feature = "table-math")]
98 {
99 if self.0 == 0 {
100 return Self::ZERO;
101 }
102
103 let i = LOG_TABLE[self.0 as usize] as usize;
104 Block8(EXP_TABLE[255 - i])
105 }
106
107 #[cfg(not(feature = "table-math"))]
108 {
109 let x = *self;
113 let x2 = x * x;
114 let x4 = x2 * x2;
115 let x8 = x4 * x4;
116 let x16 = x8 * x8;
117 let x32 = x16 * x16;
118 let x64 = x32 * x32;
119 let x128 = x64 * x64;
120
121 x128 * x64 * x32 * x16 * x8 * x4 * x2
123 }
124 }
125
126 fn from_uniform_bytes(bytes: &[u8; 32]) -> Self {
127 Self(bytes[0])
128 }
129}
130
131impl Add for Block8 {
133 type Output = Self;
134
135 fn add(self, rhs: Self) -> Self::Output {
136 Self(self.0.bitxor(rhs.0))
137 }
138}
139
140impl Sub for Block8 {
141 type Output = Self;
142
143 fn sub(self, rhs: Self) -> Self::Output {
144 self.add(rhs)
145 }
146}
147
148impl Mul for Block8 {
150 type Output = Self;
151
152 fn mul(self, rhs: Self) -> Self::Output {
153 #[cfg(feature = "table-math")]
154 {
155 if self.0 == 0 || rhs.0 == 0 {
157 return Self::ZERO;
158 }
159
160 let i = LOG_TABLE[self.0 as usize] as usize;
164 let j = LOG_TABLE[rhs.0 as usize] as usize;
165
166 let k = i + j;
170 let idx = if k >= 255 { k - 255 } else { k };
171
172 Self(EXP_TABLE[idx])
174 }
175
176 #[cfg(not(feature = "table-math"))]
177 {
178 #[cfg(target_arch = "aarch64")]
179 {
180 neon::mul_8(self, rhs)
181 }
182
183 #[cfg(not(target_arch = "aarch64"))]
184 {
185 let mut a = self.0;
186 let mut b = rhs.0;
187 let mut res = 0u8;
188
189 for _ in 0..8 {
192 let bit = b & 1;
193 let mask = 0u8.wrapping_sub(bit);
194 res ^= a & mask;
195
196 let high_bit = a >> 7;
197 let overflow_mask = 0u8.wrapping_sub(high_bit);
198 a = (a << 1) ^ (0x1B & overflow_mask);
199
200 b >>= 1;
201 }
202
203 Self(res)
204 }
205 }
206 }
207}
208
209impl AddAssign for Block8 {
210 fn add_assign(&mut self, rhs: Self) {
211 *self = *self + rhs;
212 }
213}
214
215impl SubAssign for Block8 {
216 fn sub_assign(&mut self, rhs: Self) {
217 *self = *self - rhs;
218 }
219}
220
221impl MulAssign for Block8 {
222 fn mul_assign(&mut self, rhs: Self) {
223 *self = *self * rhs;
224 }
225}
226
227impl CanonicalSerialize for Block8 {
228 #[inline]
229 fn serialized_size(&self) -> usize {
230 1
231 }
232
233 #[inline]
234 fn serialize(&self, writer: &mut [u8]) -> Result<(), ()> {
235 if writer.is_empty() {
236 return Err(());
237 }
238
239 writer[0] = self.0;
240
241 Ok(())
242 }
243}
244
245impl CanonicalDeserialize for Block8 {
246 fn deserialize(bytes: &[u8]) -> Result<Self, ()> {
247 if bytes.is_empty() {
248 return Err(());
249 }
250
251 Ok(Self(bytes[0]))
252 }
253}
254
255impl From<u8> for Block8 {
256 #[inline]
257 fn from(val: u8) -> Self {
258 Self::new(val)
259 }
260}
261
262impl From<u32> for Block8 {
263 #[inline]
264 fn from(val: u32) -> Self {
265 Self(val as u8)
266 }
267}
268
269impl From<u64> for Block8 {
270 #[inline]
271 fn from(val: u64) -> Self {
272 Self(val as u8)
273 }
274}
275
276impl From<u128> for Block8 {
277 #[inline]
278 fn from(val: u128) -> Self {
279 Self(val as u8)
280 }
281}
282
/// Lifts a [`Bit`] into the 8-bit field by reusing its inner value unchanged.
// NOTE(review): presumably `Bit` only ever holds 0 or 1, so the result is
// `Block8(0)` or `Block8(1)` — confirm against `towers::bit`.
impl From<Bit> for Block8 {
    #[inline(always)]
    fn from(val: Bit) -> Self {
        Self(val.0)
    }
}
293
/// Number of [`Block8`] lanes held by one [`PackedBlock8`].
///
/// Sixteen one-byte lanes fill the 128-bit `uint8x16_t` vector that the
/// aarch64 code transmutes packed values to.
pub const PACKED_WIDTH_8: usize = 16;
300
/// A fixed group of `PACKED_WIDTH_8` [`Block8`] lanes processed together.
///
/// The 16-byte alignment and `repr(C)` layout make the value the same size
/// as the 128-bit vector type the aarch64 code transmutes it to and from.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[repr(C, align(16))]
pub struct PackedBlock8(pub [Block8; PACKED_WIDTH_8]);
304
305impl PackedBlock8 {
306 #[inline(always)]
307 pub fn zero() -> Self {
308 Self([Block8::ZERO; PACKED_WIDTH_8])
309 }
310}
311
312impl PackableField for Block8 {
313 type Packed = PackedBlock8;
314
315 const WIDTH: usize = PACKED_WIDTH_8;
316
317 #[inline(always)]
318 fn pack(chunk: &[Self]) -> Self::Packed {
319 assert!(
320 chunk.len() >= PACKED_WIDTH_8,
321 "PackableField::pack: input slice too short",
322 );
323
324 let mut arr = [Self::ZERO; PACKED_WIDTH_8];
325 arr.copy_from_slice(&chunk[..PACKED_WIDTH_8]);
326
327 PackedBlock8(arr)
328 }
329
330 #[inline(always)]
331 fn unpack(packed: Self::Packed, output: &mut [Self]) {
332 assert!(
333 output.len() >= PACKED_WIDTH_8,
334 "PackableField::unpack: output slice too short",
335 );
336
337 output[..PACKED_WIDTH_8].copy_from_slice(&packed.0);
338 }
339}
340
341impl Add for PackedBlock8 {
342 type Output = Self;
343
344 #[inline(always)]
345 fn add(self, rhs: Self) -> Self {
346 let mut res = [Block8::ZERO; PACKED_WIDTH_8];
347 for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
348 *out = *l + *r;
349 }
350
351 Self(res)
352 }
353}
354
355impl AddAssign for PackedBlock8 {
356 #[inline(always)]
357 fn add_assign(&mut self, rhs: Self) {
358 for (l, r) in self.0.iter_mut().zip(rhs.0.iter()) {
359 *l += *r;
360 }
361 }
362}
363
364impl Sub for PackedBlock8 {
365 type Output = Self;
366
367 #[inline(always)]
368 fn sub(self, rhs: Self) -> Self {
369 self.add(rhs)
370 }
371}
372
373impl SubAssign for PackedBlock8 {
374 #[inline(always)]
375 fn sub_assign(&mut self, rhs: Self) {
376 self.add_assign(rhs);
377 }
378}
379
380impl Mul for PackedBlock8 {
381 type Output = Self;
382
383 #[inline(always)]
384 fn mul(self, rhs: Self) -> Self {
385 #[cfg(target_arch = "aarch64")]
386 {
387 let mut res = [Block8::ZERO; PACKED_WIDTH_8];
388 for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
389 *out = mul_iso_8(*l, *r);
390 }
391
392 Self(res)
393 }
394
395 #[cfg(not(target_arch = "aarch64"))]
396 {
397 let mut res = [Block8::ZERO; PACKED_WIDTH_8];
398 for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
399 *out = *l * *r;
400 }
401
402 Self(res)
403 }
404 }
405}
406
407impl MulAssign for PackedBlock8 {
408 #[inline(always)]
409 fn mul_assign(&mut self, rhs: Self) {
410 *self = *self * rhs;
411 }
412}
413
414impl Mul<Block8> for PackedBlock8 {
415 type Output = Self;
416
417 #[inline(always)]
418 fn mul(self, rhs: Block8) -> Self {
419 let mut res = [Block8::ZERO; PACKED_WIDTH_8];
420 for (out, v) in res.iter_mut().zip(self.0.iter()) {
421 *out = *v * rhs;
422 }
423
424 Self(res)
425 }
426}
427
428impl HardwareField for Block8 {
433 #[inline(always)]
434 fn to_hardware(self) -> Flat<Self> {
435 #[cfg(feature = "table-math")]
436 {
437 Flat::from_raw(apply_matrix_8(self, &constants::TOWER_TO_FLAT_8))
438 }
439
440 #[cfg(not(feature = "table-math"))]
441 {
442 Flat::from_raw(Block8(map_ct_8(self.0, &TOWER_TO_FLAT_BASIS_8.0)))
443 }
444 }
445
446 #[inline(always)]
447 fn from_hardware(value: Flat<Self>) -> Self {
448 let value = value.into_raw();
449 #[cfg(feature = "table-math")]
450 {
451 apply_matrix_8(value, &constants::FLAT_TO_TOWER_8)
452 }
453
454 #[cfg(not(feature = "table-math"))]
455 {
456 Block8(map_ct_8(value.0, &FLAT_TO_TOWER_BASIS_8.0))
457 }
458 }
459
460 #[inline(always)]
461 fn add_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
462 Flat::from_raw(lhs.into_raw() + rhs.into_raw())
463 }
464
465 #[inline(always)]
466 fn add_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
467 let lhs = lhs.into_raw();
468 let rhs = rhs.into_raw();
469 #[cfg(target_arch = "aarch64")]
470 {
471 PackedFlat::from_raw(neon::add_packed_8(lhs, rhs))
472 }
473
474 #[cfg(not(target_arch = "aarch64"))]
475 {
476 PackedFlat::from_raw(lhs + rhs)
477 }
478 }
479
480 #[inline(always)]
481 fn mul_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
482 let lhs = lhs.into_raw();
483 let rhs = rhs.into_raw();
484 #[cfg(target_arch = "aarch64")]
485 {
486 Flat::from_raw(neon::mul_8(lhs, rhs))
487 }
488
489 #[cfg(not(target_arch = "aarch64"))]
490 {
491 let a_tower = Self::from_hardware(Flat::from_raw(lhs));
492 let b_tower = Self::from_hardware(Flat::from_raw(rhs));
493
494 (a_tower * b_tower).to_hardware()
495 }
496 }
497
498 #[inline(always)]
499 fn mul_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
500 let lhs = lhs.into_raw();
501 let rhs = rhs.into_raw();
502
503 #[cfg(target_arch = "aarch64")]
504 {
505 PackedFlat::from_raw(neon::mul_flat_packed_8(lhs, rhs))
506 }
507
508 #[cfg(not(target_arch = "aarch64"))]
509 {
510 let mut l = [Self::ZERO; <Self as PackableField>::WIDTH];
511 let mut r = [Self::ZERO; <Self as PackableField>::WIDTH];
512 let mut res = [Self::ZERO; <Self as PackableField>::WIDTH];
513
514 Self::unpack(lhs, &mut l);
515 Self::unpack(rhs, &mut r);
516
517 for i in 0..<Self as PackableField>::WIDTH {
518 res[i] = Self::mul_hardware(Flat::from_raw(l[i]), Flat::from_raw(r[i])).into_raw();
519 }
520
521 PackedFlat::from_raw(Self::pack(&res))
522 }
523 }
524
525 #[inline(always)]
526 fn mul_hardware_scalar_packed(lhs: PackedFlat<Self>, rhs: Flat<Self>) -> PackedFlat<Self> {
527 let broadcasted = PackedBlock8([rhs.into_raw(); PACKED_WIDTH_8]);
528 Self::mul_hardware_packed(lhs, PackedFlat::from_raw(broadcasted))
529 }
530
531 #[inline(always)]
532 fn tower_bit_from_hardware(value: Flat<Self>, bit_idx: usize) -> u8 {
533 let mask = FLAT_TO_TOWER_BIT_MASKS_8[bit_idx];
534
535 let mut v = value.into_raw().0 & mask;
537 v ^= v >> 4;
538 v ^= v >> 2;
539 v ^= v >> 1;
540
541 v & 1
542 }
543}
544
/// Multiplies two tower-basis elements by converting both to the flat basis,
/// multiplying there with `neon::mul_8`, and converting the product back.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mul_iso_8(a: Block8, b: Block8) -> Block8 {
    let left = a.to_hardware().into_raw();
    let right = b.to_hardware().into_raw();
    let product = neon::mul_8(left, right);

    Flat::from_raw(product).to_tower()
}
558
/// Applies a precomputed 256-entry lookup table to `val`.
///
/// The index comes from a `u8`, so it is always below the table length of
/// 256. Plain indexing therefore can never panic, and the compiler can prove
/// the bound itself, so no `unsafe` unchecked access is needed.
#[cfg(feature = "table-math")]
#[inline(always)]
fn apply_matrix_8(val: Block8, table: &[u8; 256]) -> Block8 {
    Block8(table[usize::from(val.0)])
}
565
/// Applies the linear map described by `basis` to `x`.
///
/// The result is the XOR of every `basis[i]` whose bit `i` is set in `x`.
/// Selection uses an all-ones / all-zeros mask rather than a branch, so all
/// eight basis bytes are read whatever the value of `x`.
#[cfg(not(feature = "table-math"))]
#[inline(always)]
fn map_ct_8(x: u8, basis: &[u8; 8]) -> u8 {
    basis.iter().enumerate().fold(0u8, |acc, (bit, &column)| {
        // 0xFF when bit `bit` of `x` is set, 0x00 otherwise.
        let select = ((x >> bit) & 1).wrapping_neg();
        acc ^ (column & select)
    })
}
581
/// Builds the exponentiation table at compile time.
///
/// Entry `i` holds the value after `i` multiplications by `0x03`
/// (`x + 1`) starting from 1, reducing with `0x1B` whenever bit 7 is
/// shifted out.
#[cfg(feature = "table-math")]
const fn generate_exp_table() -> [u8; 256] {
    let mut exp = [0u8; 256];
    let mut power: u8 = 1;

    let mut slot = 0;
    while slot < 256 {
        exp[slot] = power;

        // power * x, folding the overflowing bit back in.
        let doubled = if power & 0x80 != 0 {
            (power << 1) ^ 0x1B
        } else {
            power << 1
        };

        // power * (x + 1) = power * x + power.
        power ^= doubled;
        slot += 1;
    }

    exp
}
616
/// Builds the discrete-log table at compile time.
///
/// Walks the same sequence as `generate_exp_table` for 255 steps and records,
/// for each value visited, the step at which it appeared. Index 0 is never
/// written and keeps its initial 0.
#[cfg(feature = "table-math")]
const fn generate_log_table() -> [u8; 256] {
    let mut log = [0u8; 256];

    let mut power: u8 = 1;
    let mut exponent: u8 = 0;

    while exponent < 255 {
        log[power as usize] = exponent;

        // power * x, folding the overflowing bit back in.
        let doubled = if power & 0x80 != 0 {
            (power << 1) ^ 0x1B
        } else {
            power << 1
        };

        // power * (x + 1) = power * x + power.
        power ^= doubled;

        exponent += 1;
    }

    log
}
651
/// aarch64 NEON back end: packed XOR and carry-less-multiply based field
/// multiplication.
// NOTE(review): every function here calls NEON intrinsics from a safe `fn`;
// this presumes the `neon` target feature is enabled for the build target —
// confirm for any custom aarch64 targets.
#[cfg(target_arch = "aarch64")]
mod neon {
    use super::*;
    use core::arch::aarch64::*;
    use core::mem::transmute;

    /// XORs two packed values as a single 128-bit vector operation.
    #[inline(always)]
    pub fn add_packed_8(lhs: PackedBlock8, rhs: PackedBlock8) -> PackedBlock8 {
        // SAFETY: `Block8` is `repr(transparent)` over `u8`, so `[Block8; 16]`
        // and `uint8x16_t` are both 16 bytes of plain data; the final
        // transmute turns the 16-byte vector back into the 16-byte
        // `PackedBlock8`.
        unsafe {
            let res = veorq_u8(
                transmute::<[Block8; 16], uint8x16_t>(lhs.0),
                transmute::<[Block8; 16], uint8x16_t>(rhs.0),
            );
            transmute(res)
        }
    }

    /// Multiplies two elements with polynomial (carry-less) multiplies and
    /// reduces the 16-bit product by folding with `constants::POLY_8`.
    ///
    /// Both operands are broadcast to all eight lanes, but only lane 0 of
    /// each product is read.
    #[inline(always)]
    pub fn mul_8(a: Block8, b: Block8) -> Block8 {
        // SAFETY: all transmutes convert between NEON vector types of equal
        // size (`uint8x8_t`/`poly8x8_t`, `poly16x8_t`/`uint16x8_t`), and the
        // lane index 0 is in range.
        unsafe {
            let a_poly = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(a.0));
            let b_poly = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(b.0));

            // Unreduced product of the two 8-bit polynomials.
            let prod = vmull_p8(a_poly, b_poly);

            let prod_u16 = vgetq_lane_u16(transmute::<poly16x8_t, uint16x8_t>(prod), 0);

            // Split into the in-range low byte and the part that overflowed.
            let l = (prod_u16 & 0xFF) as u8;
            let h = (prod_u16 >> 8) as u8;

            // First fold: multiply the overflow by the reduction constant.
            let r_val = constants::POLY_8;
            let h_poly = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(h));
            let r_poly = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(r_val));
            let h_red = vmull_p8(h_poly, r_poly);

            let h_red_u16 = vgetq_lane_u16(transmute::<poly16x8_t, uint16x8_t>(h_red), 0);

            // The fold itself may spill above bit 7; keep that as `carry`.
            let folded = (h_red_u16 & 0xFF) as u8;
            let carry = (h_red_u16 >> 8) as u8;

            let mut res = l ^ folded;

            // Second fold for the spill of the first one.
            let c_poly = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(carry));
            let c_red = vmull_p8(c_poly, r_poly);
            let c_red_u16 = vgetq_lane_u16(transmute::<poly16x8_t, uint16x8_t>(c_red), 0);

            // NOTE(review): only the low byte of the second fold is used;
            // this presumes `POLY_8` is small enough that nothing spills
            // again (true for 0x1B) — confirm against `constants`.
            res ^= (c_red_u16 & 0xFF) as u8;

            Block8(res)
        }
    }

    /// Multiplies sixteen lanes at once: two 8-lane polynomial multiplies
    /// followed by a table-driven reduction of each 16-bit product.
    #[inline(always)]
    pub fn mul_flat_packed_8(lhs: PackedBlock8, rhs: PackedBlock8) -> PackedBlock8 {
        // SAFETY: `[Block8; 16]` and `uint8x16_t` are both 16 bytes of plain
        // data (`Block8` is `repr(transparent)` over `u8`); the remaining
        // transmutes are between equally sized NEON vector types; each
        // `vld1q_u8` reads exactly 16 bytes from a 16-element array that
        // lives for the duration of the call.
        unsafe {
            let a: uint8x16_t = transmute(lhs.0);
            let b: uint8x16_t = transmute(rhs.0);

            // `vmull_p8` widens 8 lanes at a time, so split into halves.
            let a_lo = vget_low_u8(a);
            let a_hi = vget_high_u8(a);
            let b_lo = vget_low_u8(b);
            let b_hi = vget_high_u8(b);

            let res_lo = vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(a_lo),
                transmute::<uint8x8_t, poly8x8_t>(b_lo),
            );
            let res_hi = vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(a_hi),
                transmute::<uint8x8_t, poly8x8_t>(b_hi),
            );

            // tbl_lo[n] is `n * x^8` reduced with 0x1B, i.e. the contribution
            // of product bits 8..=11.
            let tbl_lo = vld1q_u8(
                [
                    0x00, 0x1b, 0x36, 0x2d, 0x6c, 0x77, 0x5a, 0x41, 0xd8, 0xc3, 0xee, 0xf5, 0xb4,
                    0xaf, 0x82, 0x99,
                ]
                .as_ptr(),
            );

            // tbl_hi[n] is `n * x^12` reduced with 0x1B, i.e. the
            // contribution of product bits 12..=15.
            let tbl_hi = vld1q_u8(
                [
                    0x00, 0xab, 0x4d, 0xe6, 0x9a, 0x31, 0xd7, 0x7c, 0x2f, 0x84, 0x62, 0xc9, 0xb5,
                    0x1e, 0xf8, 0x53,
                ]
                .as_ptr(),
            );

            // Reduces eight 16-bit products to eight field elements.
            let reduce_tbl = |val_poly: poly16x8_t| -> uint8x8_t {
                let val: uint16x8_t = transmute(val_poly);

                // Low byte of each product, and the overflowing high byte.
                let data = vmovn_u16(val);
                let carry_u16 = vshrq_n_u16(val, 8);
                let carry = vmovn_u16(carry_u16);

                // Split the high byte into nibbles to index the two tables.
                let mask_lo = vdup_n_u8(0x0F);
                let h_lo = vand_u8(carry, mask_lo);
                let h_hi = vshr_n_u8(carry, 4);

                let r_lo = vqtbl1_u8(tbl_lo, h_lo);
                let r_hi = vqtbl1_u8(tbl_hi, h_hi);

                veor_u8(data, veor_u8(r_lo, r_hi))
            };

            let final_lo = reduce_tbl(res_lo);
            let final_hi = reduce_tbl(res_hi);

            let res = vcombine_u8(final_lo, final_hi);

            PackedBlock8(transmute::<uint8x16_t, [Block8; 16]>(res))
        }
    }
}
800
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{RngExt, rng};

    /// The extension constant must keep its expected value.
    #[test]
    fn tower_constants() {
        let tau = Block8::EXTENSION_TAU;
        assert_eq!(tau, Block8(0x20));
    }

    /// Addition restricted to {0, 1} behaves like XOR.
    #[test]
    fn add_truth() {
        let (zero, one) = (Block8::ZERO, Block8::ONE);
        let cases = [
            (zero, zero, zero),
            (zero, one, one),
            (one, zero, one),
            (one, one, zero),
        ];

        for (lhs, rhs, expected) in cases {
            assert_eq!(lhs + rhs, expected);
        }
    }

    /// Multiplication restricted to {0, 1} behaves like AND.
    #[test]
    fn mul_truth() {
        let (zero, one) = (Block8::ZERO, Block8::ONE);
        let cases = [(zero, zero, zero), (zero, one, zero), (one, one, one)];

        for (lhs, rhs, expected) in cases {
            assert_eq!(lhs * rhs, expected);
        }
    }

    #[test]
    fn add() {
        let sum = Block8(5) + Block8(3);
        assert_eq!(sum, Block8(6));
    }

    #[test]
    fn mul_simple() {
        let product = Block8(2) * Block8(2);
        assert_eq!(product, Block8(4));
    }

    /// A product that needs reduction.
    #[test]
    fn mul_overflow() {
        let product = Block8(0x57) * Block8(0x83);
        assert_eq!(product, Block8(0xC1));
    }

    /// `square` must agree with `x * x` for every byte.
    #[test]
    fn square_exhaustive() {
        (0u16..=255).for_each(|i| {
            let x = Block8(i as u8);
            assert_eq!(x.square(), x * x, "Block8 square mismatch at {i:#04x}");
        });
    }

    /// `zeroize` must clear the stored byte.
    #[test]
    fn security_zeroize() {
        let mut secret = Block8::from(0xFF_u32);
        assert_ne!(secret, Block8::ZERO);

        secret.zeroize();

        assert_eq!(secret, Block8::ZERO);
        assert_eq!(secret.0, 0, "Block8 memory leak detected");
    }

    /// Every non-zero element times its inverse is one; zero inverts to zero.
    #[test]
    fn inversion_exhaustive() {
        for raw in u8::MIN..=u8::MAX {
            let val = Block8(raw);

            if val == Block8::ZERO {
                assert_eq!(val.invert(), Block8::ZERO, "invert(0) must return 0");
                continue;
            }

            let inv = val.invert();
            let product = val * inv;

            assert_eq!(
                product,
                Block8::ONE,
                "Inversion identity failed: a * a^-1 != 1"
            );
        }
    }

    /// Tower → flat → tower must be the identity.
    #[test]
    fn isomorphism_roundtrip() {
        let mut prng = rng();
        for _ in 0..1000 {
            let val = Block8::from(prng.random::<u8>());
            let roundtrip = val.to_hardware().to_tower();

            assert_eq!(roundtrip, val, "Block8 isomorphism roundtrip failed");
        }
    }

    /// The per-bit parity masks must agree with the full flat → tower
    /// conversion, both when evaluated by hand and through `tower_bit`.
    #[test]
    fn parity_masks_match_from_hardware() {
        for x_flat in u8::MIN..=u8::MAX {
            let tower = Block8::from_hardware(Flat::from_raw(Block8(x_flat))).0;

            for (k, &mask) in FLAT_TO_TOWER_BIT_MASKS_8.iter().enumerate() {
                let parity = ((x_flat & mask).count_ones() & 1) as u8;
                let bit = (tower >> k) & 1;
                assert_eq!(
                    parity, bit,
                    "Block8 mask mismatch at x={x_flat:#04x}, k={k}"
                );

                let via_api = Flat::from_raw(Block8(x_flat)).tower_bit(k);
                assert_eq!(via_api, bit, "Block8 tower_bit_from_hardware mismatch");
            }
        }
    }

    /// Multiplying in the flat basis must match converting the tower product.
    #[test]
    fn flat_mul_homomorphism() {
        let mut prng = rng();
        for _ in 0..1000 {
            let a = Block8::from(prng.random::<u8>());
            let b = Block8::from(prng.random::<u8>());

            let expected_flat = (a * b).to_hardware();
            let actual_flat = a.to_hardware() * b.to_hardware();

            assert_eq!(
                actual_flat, expected_flat,
                "Block8 flat multiplication mismatch"
            );
        }
    }

    /// Packed flat add/mul must match the scalar tower operations lane by lane.
    #[test]
    fn packed_consistency() {
        let mut prng = rng();
        for _ in 0..100 {
            let mut a_vals = [Block8::ZERO; 16];
            let mut b_vals = [Block8::ZERO; 16];

            for (a, b) in a_vals.iter_mut().zip(b_vals.iter_mut()) {
                *a = Block8::from(prng.random::<u8>());
                *b = Block8::from(prng.random::<u8>());
            }

            let a_flat_vals = a_vals.map(|x| x.to_hardware());
            let b_flat_vals = b_vals.map(|x| x.to_hardware());
            let a_packed = Flat::<Block8>::pack(&a_flat_vals);
            let b_packed = Flat::<Block8>::pack(&b_flat_vals);

            let mut add_out = [Block8::ZERO.to_hardware(); 16];
            Flat::<Block8>::unpack(
                Block8::add_hardware_packed(a_packed, b_packed),
                &mut add_out,
            );

            for ((got, a), b) in add_out.iter().zip(a_vals.iter()).zip(b_vals.iter()) {
                assert_eq!(
                    *got,
                    (*a + *b).to_hardware(),
                    "Block8 packed add mismatch"
                );
            }

            let mut mul_out = [Block8::ZERO.to_hardware(); 16];
            Flat::<Block8>::unpack(
                Block8::mul_hardware_packed(a_packed, b_packed),
                &mut mul_out,
            );

            for ((got, a), b) in mul_out.iter().zip(a_vals.iter()).zip(b_vals.iter()) {
                assert_eq!(
                    *got,
                    (*a * *b).to_hardware(),
                    "Block8 packed mul mismatch"
                );
            }
        }
    }

    /// `unpack(pack(x))` must return `x`.
    #[test]
    fn pack_unpack_roundtrip() {
        let mut prng = rng();
        let mut data = [Block8::ZERO; PACKED_WIDTH_8];

        for slot in data.iter_mut() {
            *slot = Block8(prng.random());
        }

        let mut unpacked = [Block8::ZERO; PACKED_WIDTH_8];
        Block8::unpack(Block8::pack(&data), &mut unpacked);

        assert_eq!(data, unpacked, "Block8 pack/unpack roundtrip failed");
    }

    /// Packed addition must match scalar addition in every lane.
    #[test]
    fn packed_add_consistency() {
        let mut prng = rng();
        let mut a_vals = [Block8::ZERO; PACKED_WIDTH_8];
        let mut b_vals = [Block8::ZERO; PACKED_WIDTH_8];

        for (a, b) in a_vals.iter_mut().zip(b_vals.iter_mut()) {
            *a = Block8(prng.random());
            *b = Block8(prng.random());
        }

        let res_packed = Block8::pack(&a_vals) + Block8::pack(&b_vals);

        let mut res_unpacked = [Block8::ZERO; PACKED_WIDTH_8];
        Block8::unpack(res_packed, &mut res_unpacked);

        for (i, got) in res_unpacked.iter().enumerate() {
            assert_eq!(
                *got,
                a_vals[i] + b_vals[i],
                "Block8 packed add mismatch at index {}",
                i
            );
        }
    }

    /// Packed multiplication must match scalar multiplication in every lane.
    #[test]
    fn packed_mul_consistency() {
        let mut prng = rng();

        for _ in 0..1000 {
            let mut a_arr = [Block8::ZERO; PACKED_WIDTH_8];
            let mut b_arr = [Block8::ZERO; PACKED_WIDTH_8];

            for (a, b) in a_arr.iter_mut().zip(b_arr.iter_mut()) {
                let val_a: u8 = prng.random();
                let val_b: u8 = prng.random();
                *a = Block8(val_a);
                *b = Block8(val_b);
            }

            let c_packed = PackedBlock8(a_arr) * PackedBlock8(b_arr);

            let mut c_expected = [Block8::ZERO; PACKED_WIDTH_8];
            for ((want, a), b) in c_expected.iter_mut().zip(a_arr.iter()).zip(b_arr.iter()) {
                *want = *a * *b;
            }

            assert_eq!(c_packed.0, c_expected, "SIMD Block8 mismatch!");
        }
    }
}