1use crate::towers::bit::Bit;
20use crate::towers::block8::Block8;
21use crate::{
22 BinaryFieldExtras, CanonicalDeserialize, CanonicalSerialize, Flat, FlatPromote, HardwareField,
23 PackableField, PackedFlat, TowerField, constants,
24};
25use core::ops::{Add, AddAssign, BitXor, BitXorAssign, Mul, MulAssign, Sub, SubAssign};
26use serde::{Deserialize, Serialize};
27use zeroize::Zeroize;
28
/// Fixed-size array of `N` basis images (one `u16` per input bit) used by the
/// non-table basis conversion (`map_ct_16`).
///
/// NOTE(review): the 64-byte alignment presumably keeps each table inside a
/// single cache line; confirm that this is the intent.
#[cfg(not(feature = "table-math"))]
#[repr(align(64))]
struct CtConvertBasisU16<const N: usize>([u16; N]);

/// Images of the 16 tower-basis bits in the flat basis; consumed by
/// `HardwareField::to_hardware` when `table-math` is disabled.
#[cfg(not(feature = "table-math"))]
static TOWER_TO_FLAT_BASIS_16: CtConvertBasisU16<16> =
    CtConvertBasisU16(constants::RAW_TOWER_TO_FLAT_16);

/// Images of the 16 flat-basis bits in the tower basis; consumed by
/// `HardwareField::from_hardware` when `table-math` is disabled.
#[cfg(not(feature = "table-math"))]
static FLAT_TO_TOWER_BASIS_16: CtConvertBasisU16<16> =
    CtConvertBasisU16(constants::RAW_FLAT_TO_TOWER_16);
40
/// A 16-bit binary tower field element stored as a raw `u16`.
///
/// In tower form the low byte is the `Block8` coefficient of `1` and the high
/// byte is the `Block8` coefficient of the extension generator `X`; see
/// `Block16::new` / `Block16::split`. `repr(transparent)` guarantees the
/// same layout as `u16`, which the packed NEON code relies on when
/// transmuting arrays of `Block16`.
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq, Serialize, Deserialize, Zeroize)]
#[repr(transparent)]
pub struct Block16(pub u16);
44
45impl Block16 {
46 pub const TAU: Self = Block16(0x2000);
47
48 pub fn new(lo: Block8, hi: Block8) -> Self {
49 Self((hi.0 as u16) << 8 | (lo.0 as u16))
50 }
51
52 #[inline(always)]
53 pub fn split(self) -> (Block8, Block8) {
54 (Block8(self.0 as u8), Block8((self.0 >> 8) as u8))
55 }
56}
57
58impl TowerField for Block16 {
59 const BITS: usize = 16;
60 const ZERO: Self = Block16(0);
61 const ONE: Self = Block16(1);
62
63 const EXTENSION_TAU: Self = Self::TAU;
64
65 fn invert(&self) -> Self {
66 let (l, h) = self.split();
67
68 let h2 = h * h;
70 let l2 = l * l;
71 let hl = h * l;
72 let norm = (h2 * Block8::EXTENSION_TAU) + hl + l2;
73
74 let norm_inv = norm.invert();
75
76 let res_hi = h * norm_inv;
78 let res_lo = (h + l) * norm_inv;
79
80 Self::new(res_lo, res_hi)
81 }
82
83 fn from_uniform_bytes(bytes: &[u8; 32]) -> Self {
84 let mut buf = [0u8; 2];
85 buf.copy_from_slice(&bytes[0..2]);
86
87 Self(u16::from_le_bytes(buf))
88 }
89}
90
91impl Add for Block16 {
92 type Output = Self;
93
94 fn add(self, rhs: Self) -> Self {
95 Self(self.0.bitxor(rhs.0))
96 }
97}
98
99impl Sub for Block16 {
100 type Output = Self;
101
102 fn sub(self, rhs: Self) -> Self {
103 Self(self.0.bitxor(rhs.0))
104 }
105}
106
107impl Mul for Block16 {
108 type Output = Self;
109
110 fn mul(self, rhs: Self) -> Self {
111 let (a0, a1) = self.split();
112 let (b0, b1) = rhs.split();
113
114 let v0 = a0 * b0;
116 let v1 = a1 * b1;
117 let v_sum = (a0 + a1) * (b0 + b1);
118
119 let c_hi = v0 + v_sum;
122
123 let c_lo = v0 + (v1 * Block8::EXTENSION_TAU);
125
126 Self::new(c_lo, c_hi)
127 }
128}
129
130impl AddAssign for Block16 {
131 fn add_assign(&mut self, rhs: Self) {
132 self.0.bitxor_assign(rhs.0);
133 }
134}
135
136impl SubAssign for Block16 {
137 fn sub_assign(&mut self, rhs: Self) {
138 self.0.bitxor_assign(rhs.0);
139 }
140}
141
142impl MulAssign for Block16 {
143 fn mul_assign(&mut self, rhs: Self) {
144 *self = *self * rhs;
145 }
146}
147
148impl CanonicalSerialize for Block16 {
149 fn serialized_size(&self) -> usize {
150 2
151 }
152
153 fn serialize(&self, writer: &mut [u8]) -> Result<(), ()> {
154 if writer.len() < 2 {
155 return Err(());
156 }
157
158 writer[..2].copy_from_slice(&self.0.to_le_bytes());
159
160 Ok(())
161 }
162}
163
164impl CanonicalDeserialize for Block16 {
165 fn deserialize(bytes: &[u8]) -> Result<Self, ()> {
166 if bytes.len() < 2 {
167 return Err(());
168 }
169
170 let mut buf = [0u8; 2];
171 buf.copy_from_slice(&bytes[0..2]);
172
173 Ok(Self(u16::from_le_bytes(buf)))
174 }
175}
176
177impl From<u8> for Block16 {
178 fn from(val: u8) -> Self {
179 Self(val as u16)
180 }
181}
182
183impl From<u16> for Block16 {
184 #[inline]
185 fn from(val: u16) -> Self {
186 Self(val)
187 }
188}
189
190impl From<u32> for Block16 {
191 #[inline]
192 fn from(val: u32) -> Self {
193 Self(val as u16)
194 }
195}
196
197impl From<u64> for Block16 {
198 #[inline]
199 fn from(val: u64) -> Self {
200 Self(val as u16)
201 }
202}
203
204impl From<u128> for Block16 {
205 #[inline]
206 fn from(val: u128) -> Self {
207 Self(val as u16)
208 }
209}
210
211impl From<Bit> for Block16 {
216 #[inline(always)]
217 fn from(val: Bit) -> Self {
218 Self(val.0 as u16)
219 }
220}
221
222impl From<Block8> for Block16 {
223 #[inline(always)]
224 fn from(val: Block8) -> Self {
225 Self(val.0 as u16)
226 }
227}
228
/// Number of `Block16` lanes in one `PackedBlock16` (8 × 16 bits = 128 bits).
pub const PACKED_WIDTH_16: usize = 8;

/// Eight `Block16` lanes stored contiguously with 16-byte alignment.
///
/// The aarch64 `neon` module reinterprets the inner array as a single 128-bit
/// vector (`uint8x16_t` / `uint16x8_t`) via `transmute`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[repr(C, align(16))]
pub struct PackedBlock16(pub [Block16; PACKED_WIDTH_16]);
239
240impl PackedBlock16 {
241 #[inline(always)]
242 pub fn zero() -> Self {
243 Self([Block16::ZERO; PACKED_WIDTH_16])
244 }
245}
246
247impl PackableField for Block16 {
248 type Packed = PackedBlock16;
249
250 const WIDTH: usize = PACKED_WIDTH_16;
251
252 #[inline(always)]
253 fn pack(chunk: &[Self]) -> Self::Packed {
254 assert!(
255 chunk.len() >= PACKED_WIDTH_16,
256 "PackableField::pack: input slice too short",
257 );
258
259 let mut arr = [Self::ZERO; PACKED_WIDTH_16];
260 arr.copy_from_slice(&chunk[..PACKED_WIDTH_16]);
261
262 PackedBlock16(arr)
263 }
264
265 #[inline(always)]
266 fn unpack(packed: Self::Packed, output: &mut [Self]) {
267 assert!(
268 output.len() >= PACKED_WIDTH_16,
269 "PackableField::unpack: output slice too short",
270 );
271
272 output[..PACKED_WIDTH_16].copy_from_slice(&packed.0);
273 }
274}
275
276impl Add for PackedBlock16 {
277 type Output = Self;
278
279 #[inline(always)]
280 fn add(self, rhs: Self) -> Self {
281 let mut res = [Block16::ZERO; PACKED_WIDTH_16];
282 for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
283 *out = *l + *r;
284 }
285
286 Self(res)
287 }
288}
289
290impl AddAssign for PackedBlock16 {
291 #[inline(always)]
292 fn add_assign(&mut self, rhs: Self) {
293 for (l, r) in self.0.iter_mut().zip(rhs.0.iter()) {
294 *l += *r;
295 }
296 }
297}
298
299impl Sub for PackedBlock16 {
300 type Output = Self;
301
302 #[inline(always)]
303 fn sub(self, rhs: Self) -> Self {
304 self.add(rhs)
305 }
306}
307
308impl SubAssign for PackedBlock16 {
309 #[inline(always)]
310 fn sub_assign(&mut self, rhs: Self) {
311 self.add_assign(rhs);
312 }
313}
314
315impl Mul for PackedBlock16 {
316 type Output = Self;
317
318 #[inline(always)]
319 fn mul(self, rhs: Self) -> Self {
320 #[cfg(target_arch = "aarch64")]
321 {
322 let mut res = [Block16::ZERO; PACKED_WIDTH_16];
323 for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
324 *out = mul_iso_16(*l, *r);
325 }
326
327 Self(res)
328 }
329
330 #[cfg(not(target_arch = "aarch64"))]
331 {
332 let mut res = [Block16::ZERO; PACKED_WIDTH_16];
333 for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
334 *out = *l * *r;
335 }
336
337 Self(res)
338 }
339 }
340}
341
342impl MulAssign for PackedBlock16 {
343 #[inline(always)]
344 fn mul_assign(&mut self, rhs: Self) {
345 *self = *self * rhs;
346 }
347}
348
349impl Mul<Block16> for PackedBlock16 {
350 type Output = Self;
351
352 #[inline(always)]
353 fn mul(self, rhs: Block16) -> Self {
354 let mut res = [Block16::ZERO; PACKED_WIDTH_16];
355 for (out, v) in res.iter_mut().zip(self.0.iter()) {
356 *out = *v * rhs;
357 }
358
359 Self(res)
360 }
361}
362
363impl HardwareField for Block16 {
368 #[inline(always)]
369 fn to_hardware(self) -> Flat<Self> {
370 #[cfg(feature = "table-math")]
371 {
372 Flat::from_raw(apply_matrix_16(self, &constants::TOWER_TO_FLAT_16))
373 }
374
375 #[cfg(not(feature = "table-math"))]
376 {
377 Flat::from_raw(Block16(map_ct_16(self.0, &TOWER_TO_FLAT_BASIS_16.0)))
378 }
379 }
380
381 #[inline(always)]
382 fn from_hardware(value: Flat<Self>) -> Self {
383 let value = value.into_raw();
384
385 #[cfg(feature = "table-math")]
386 {
387 apply_matrix_16(value, &constants::FLAT_TO_TOWER_16)
388 }
389
390 #[cfg(not(feature = "table-math"))]
391 {
392 Block16(map_ct_16(value.0, &FLAT_TO_TOWER_BASIS_16.0))
393 }
394 }
395
396 #[inline(always)]
397 fn add_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
398 Flat::from_raw(lhs.into_raw() + rhs.into_raw())
399 }
400
401 #[inline(always)]
402 fn add_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
403 let lhs = lhs.into_raw();
404 let rhs = rhs.into_raw();
405
406 #[cfg(target_arch = "aarch64")]
407 {
408 PackedFlat::from_raw(neon::add_packed_16(lhs, rhs))
409 }
410
411 #[cfg(not(target_arch = "aarch64"))]
412 {
413 PackedFlat::from_raw(lhs + rhs)
414 }
415 }
416
417 #[inline(always)]
418 fn mul_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
419 let lhs = lhs.into_raw();
420 let rhs = rhs.into_raw();
421
422 #[cfg(target_arch = "aarch64")]
423 {
424 Flat::from_raw(neon::mul_flat_16(lhs, rhs))
425 }
426
427 #[cfg(not(target_arch = "aarch64"))]
428 {
429 let a_tower = Self::from_hardware(Flat::from_raw(lhs));
430 let b_tower = Self::from_hardware(Flat::from_raw(rhs));
431
432 (a_tower * b_tower).to_hardware()
433 }
434 }
435
436 #[inline(always)]
437 fn mul_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
438 let lhs = lhs.into_raw();
439 let rhs = rhs.into_raw();
440
441 #[cfg(target_arch = "aarch64")]
442 {
443 PackedFlat::from_raw(neon::mul_flat_packed_16(lhs, rhs))
444 }
445
446 #[cfg(not(target_arch = "aarch64"))]
447 {
448 let mut l = [Self::ZERO; <Self as PackableField>::WIDTH];
449 let mut r = [Self::ZERO; <Self as PackableField>::WIDTH];
450 let mut res = [Self::ZERO; <Self as PackableField>::WIDTH];
451
452 Self::unpack(lhs, &mut l);
453 Self::unpack(rhs, &mut r);
454
455 for i in 0..<Self as PackableField>::WIDTH {
456 res[i] = Self::mul_hardware(Flat::from_raw(l[i]), Flat::from_raw(r[i])).into_raw();
457 }
458
459 PackedFlat::from_raw(Self::pack(&res))
460 }
461 }
462
463 #[inline(always)]
464 fn mul_hardware_scalar_packed(lhs: PackedFlat<Self>, rhs: Flat<Self>) -> PackedFlat<Self> {
465 #[cfg(target_arch = "aarch64")]
466 {
467 PackedFlat::from_raw(neon::mul_flat_scalar_packed_16(
468 lhs.into_raw(),
469 rhs.into_raw(),
470 ))
471 }
472
473 #[cfg(not(target_arch = "aarch64"))]
474 {
475 let broadcasted = PackedBlock16([rhs.into_raw(); PACKED_WIDTH_16]);
476 Self::mul_hardware_packed(lhs, PackedFlat::from_raw(broadcasted))
477 }
478 }
479
480 #[inline(always)]
481 fn tower_bit_from_hardware(value: Flat<Self>, bit_idx: usize) -> u8 {
482 let mask = constants::FLAT_TO_TOWER_BIT_MASKS_16[bit_idx];
483
484 let mut v = value.into_raw().0 & mask;
488 v ^= v >> 8;
489 v ^= v >> 4;
490 v ^= v >> 2;
491 v ^= v >> 1;
492
493 (v & 1) as u8
494 }
495}
496
497impl FlatPromote<Block8> for Block16 {
498 #[inline(always)]
499 fn promote_flat(val: Flat<Block8>) -> Flat<Self> {
500 let val = val.into_raw();
501
502 #[cfg(not(feature = "table-math"))]
503 {
504 let mut acc = 0u16;
505 for i in 0..8 {
506 let bit = (val.0 >> i) & 1;
507 let mask = 0u16.wrapping_sub(bit as u16);
508 acc ^= constants::LIFT_BASIS_8_TO_16[i] & mask;
509 }
510
511 Flat::from_raw(Block16(acc))
512 }
513
514 #[cfg(feature = "table-math")]
515 {
516 Flat::from_raw(Block16(constants::LIFT_TABLE_8_TO_16[val.0 as usize]))
517 }
518 }
519}
520
521impl BinaryFieldExtras for Block16 {
526 #[inline(always)]
527 fn square(&self) -> Self {
528 let (lo, hi) = self.split();
532 let hi2 = hi.square();
533
534 Self::new(lo.square() + hi2 * Block8::EXTENSION_TAU, hi2)
535 }
536
537 #[inline(always)]
538 fn trace(&self) -> Bit {
539 Bit(((self.0 & constants::TRACE_MASK_16).count_ones() & 1) as u8)
540 }
541
542 #[inline(always)]
543 fn solve_quadratic(c: Self) -> Option<Self> {
544 match c.trace() {
545 Bit(0) => Some(Block16(map_ct_16(
546 c.0,
547 &constants::SOLVE_QUADRATIC_BASIS_16,
548 ))),
549 _ => None,
550 }
551 }
552}
553
/// Tower-basis multiplication performed in the flat basis.
///
/// Both operands are converted to the flat basis, multiplied with the NEON
/// kernel, and the product is converted back to the tower basis.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub fn mul_iso_16(a: Block16, b: Block16) -> Block16 {
    let flat_product = neon::mul_flat_16(a.to_hardware().into_raw(), b.to_hardware().into_raw());

    Flat::from_raw(flat_product).to_tower()
}
567
/// Applies a 16×16 GF(2) basis-change matrix using two byte lookup tables.
///
/// `table[0..256]` holds the images of every low-byte value and
/// `table[256..512]` the images of every high-byte value. The result is the
/// XOR of the two entries selected by `val`'s bytes.
///
/// Both indices are provably in bounds (`u8 as usize <= 255`, so `256 + 255 <= 511`).
/// Safe indexing therefore replaces the previous `get_unchecked`: the bounds
/// checks can be elided and no `unsafe` is needed.
#[cfg(feature = "table-math")]
#[inline(always)]
pub fn apply_matrix_16(val: Block16, table: &[u16; 512]) -> Block16 {
    let [lo, hi] = val.0.to_le_bytes();

    Block16(table[lo as usize] ^ table[256 + hi as usize])
}
582
/// Applies the GF(2)-linear map whose column `i` is `basis[i]` to `x`.
///
/// The result is the XOR of `basis[i]` over every set bit `i` of `x`. Each
/// column is selected with an all-zeros / all-ones mask rather than a branch,
/// so there is no data-dependent control flow.
#[inline(always)]
fn map_ct_16(x: u16, basis: &[u16; 16]) -> u16 {
    basis
        .iter()
        .enumerate()
        .fold(0u16, |acc, (bit_pos, &column)| {
            let select = ((x >> bit_pos) & 1).wrapping_neg();
            acc ^ (column & select)
        })
}
598
#[cfg(target_arch = "aarch64")]
mod neon {
    use super::*;
    use core::arch::aarch64::*;
    use core::mem::transmute;

    // `reduce_packed_16` hard-codes the shift pattern of R = 0x2b
    // (x^5 + x^3 + x + 1); fail the build if the field polynomial changes.
    const _: () = assert!(constants::POLY_16 == 0x2b, "packed fold hardcodes R = 0x2b");

    /// Lane-wise XOR (field addition) of eight packed 16-bit elements.
    #[inline(always)]
    pub fn add_packed_16(lhs: PackedBlock16, rhs: PackedBlock16) -> PackedBlock16 {
        // SAFETY: `[Block16; 8]` is 16 bytes (`Block16` is `repr(transparent)`
        // over `u16`), the same size as `uint8x16_t`, and every bit pattern is
        // valid for both. NEON is a baseline feature on aarch64.
        unsafe {
            let res = veorq_u8(
                transmute::<[Block16; 8], uint8x16_t>(lhs.0),
                transmute::<[Block16; 8], uint8x16_t>(rhs.0),
            );
            transmute(res)
        }
    }

    /// Carry-less product of two operands that each fit in 16 bits, returned in
    /// the low bits of a `u64` (at most 31 significant bits).
    ///
    /// PMULL (`vmull_p64`) is only used when FEAT_PMULL is statically enabled.
    #[cfg(target_feature = "aes")]
    #[inline(always)]
    fn clmul_16(a: u64, b: u64) -> u64 {
        // SAFETY: `vmull_p64` requires the `neon` and `aes` target features
        // (FEAT_PMULL). This function is only compiled when `aes` is enabled
        // for the build, and `neon` is baseline on aarch64.
        let prod = unsafe { vmull_p64(a, b) };
        // Operands are < 2^16, so the product fits in the low 64 bits.
        prod as u64
    }

    /// Portable carry-less product of two operands that each fit in 16 bits.
    ///
    /// Used on aarch64 builds without the `aes` target feature. Calling
    /// `vmull_p64` there would be undefined behaviour, and it faults on cores
    /// lacking PMULL. The loop has a fixed trip count and selects partial
    /// products with masks, so it has no data-dependent branches.
    #[cfg(not(target_feature = "aes"))]
    #[inline(always)]
    fn clmul_16(a: u64, b: u64) -> u64 {
        debug_assert!(a >> 16 == 0 && b >> 16 == 0, "clmul_16 operands must fit in 16 bits");

        let mut acc = 0u64;
        for i in 0..16 {
            let mask = ((b >> i) & 1).wrapping_neg();
            acc ^= (a << i) & mask;
        }
        acc
    }

    /// Multiplies two flat-basis elements modulo `x^16 + R(x)`, `R = POLY_16`.
    #[inline(always)]
    pub fn mul_flat_16(a: Block16, b: Block16) -> Block16 {
        let r_val = constants::POLY_16 as u64;

        // Full (up to 31-bit) carry-less product, split at bit 16.
        let prod = clmul_16(a.0 as u64, b.0 as u64);
        let l = (prod & 0xFFFF) as u16;
        let h = (prod >> 16) as u16;

        // x^16 ≡ R, so the high half folds in as h·R (at most 20 bits).
        let h_red = clmul_16(h as u64, r_val);
        let folded = (h_red & 0xFFFF) as u16;
        let carry = (h_red >> 16) as u16;

        // `carry` has at most 4 bits, so carry·R fits in 16 bits and one more
        // fold completes the reduction.
        let c_red = clmul_16(carry as u64, r_val);

        Block16(l ^ folded ^ (c_red as u16))
    }

    /// Multiplies eight flat-basis lanes pairwise.
    ///
    /// Each 16-bit lane is split into bytes and multiplied with three 8×8
    /// polynomial multiplies (Karatsuba):
    /// - `ll = a_lo·b_lo`
    /// - `hh = a_hi·b_hi`
    /// - `mm = (a_lo^a_hi)·(b_lo^b_hi)`
    ///
    /// `reduce_packed_16` recombines and reduces the three products.
    #[inline(always)]
    pub fn mul_flat_packed_16(lhs: PackedBlock16, rhs: PackedBlock16) -> PackedBlock16 {
        // SAFETY: `[Block16; 8]` and `uint16x8_t` are both 16 bytes, and
        // `uint8x8_t`/`poly8x8_t` and `poly16x8_t`/`uint16x8_t` share sizes.
        // None of these types has invalid bit patterns. All intrinsics are
        // baseline NEON on aarch64.
        unsafe {
            let a = transmute::<[Block16; 8], uint16x8_t>(lhs.0);
            let b = transmute::<[Block16; 8], uint16x8_t>(rhs.0);

            let a_lo = vmovn_u16(a);
            let a_hi = vmovn_u16(vshrq_n_u16(a, 8));
            let b_lo = vmovn_u16(b);
            let b_hi = vmovn_u16(vshrq_n_u16(b, 8));

            let ll = transmute::<poly16x8_t, uint16x8_t>(vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(a_lo),
                transmute::<uint8x8_t, poly8x8_t>(b_lo),
            ));

            let hh = transmute::<poly16x8_t, uint16x8_t>(vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(a_hi),
                transmute::<uint8x8_t, poly8x8_t>(b_hi),
            ));

            let mm = transmute::<poly16x8_t, uint16x8_t>(vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(veor_u8(a_lo, a_hi)),
                transmute::<uint8x8_t, poly8x8_t>(veor_u8(b_lo, b_hi)),
            ));

            PackedBlock16(transmute::<uint16x8_t, [Block16; 8]>(reduce_packed_16(
                ll, mm, hh,
            )))
        }
    }

    /// Multiplies every flat-basis lane of `lhs` by the flat-basis `scalar`.
    ///
    /// Same Karatsuba scheme as `mul_flat_packed_16`, with the scalar's bytes
    /// (and their XOR) broadcast once across all lanes.
    #[inline(always)]
    pub fn mul_flat_scalar_packed_16(lhs: PackedBlock16, scalar: Block16) -> PackedBlock16 {
        // SAFETY: same size/validity argument as `mul_flat_packed_16`. All
        // intrinsics are baseline NEON on aarch64.
        unsafe {
            let a = transmute::<[Block16; 8], uint16x8_t>(lhs.0);

            let s_lo = (scalar.0 & 0xff) as u8;
            let s_hi = (scalar.0 >> 8) as u8;

            let b_lo = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(s_lo));
            let b_hi = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(s_hi));
            let b_mid = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(s_lo ^ s_hi));

            let a_lo = vmovn_u16(a);
            let a_hi = vmovn_u16(vshrq_n_u16(a, 8));

            let ll = transmute::<poly16x8_t, uint16x8_t>(vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(a_lo),
                b_lo,
            ));

            let hh = transmute::<poly16x8_t, uint16x8_t>(vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(a_hi),
                b_hi,
            ));

            let mm = transmute::<poly16x8_t, uint16x8_t>(vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(veor_u8(a_lo, a_hi)),
                b_mid,
            ));

            PackedBlock16(transmute::<uint16x8_t, [Block16; 8]>(reduce_packed_16(
                ll, mm, hh,
            )))
        }
    }

    /// Recombines Karatsuba byte products into full 31-bit products and reduces
    /// them modulo `x^16 + x^5 + x^3 + x + 1` (R = 0x2b), lane-wise.
    #[inline(always)]
    fn reduce_packed_16(ll: uint16x8_t, mm: uint16x8_t, hh: uint16x8_t) -> uint16x8_t {
        // SAFETY: only baseline NEON arithmetic intrinsics on plain vectors.
        unsafe {
            // Middle Karatsuba term a_lo·b_hi + a_hi·b_lo.
            let mid = veorq_u16(veorq_u16(mm, ll), hh);
            // Product = ll + mid·x^8 + hh·x^16, split into low/high 16-bit halves.
            let l = veorq_u16(ll, vshlq_n_u16(mid, 8));
            let h = veorq_u16(hh, vshrq_n_u16(mid, 8));

            // Low 16 bits of h·R, with R = x^5 + x^3 + x + 1.
            let h_fold = veorq_u16(
                veorq_u16(vshlq_n_u16(h, 5), vshlq_n_u16(h, 3)),
                veorq_u16(vshlq_n_u16(h, 1), h),
            );

            // Bits of h·R shifted past bit 15 (from the x^5, x^3 and x terms).
            let carry = veorq_u16(
                veorq_u16(vshrq_n_u16(h, 11), vshrq_n_u16(h, 13)),
                vshrq_n_u16(h, 15),
            );

            // carry has at most 4 bits, so carry·R fits in 16 bits.
            let carry_fold = veorq_u16(
                veorq_u16(vshlq_n_u16(carry, 5), vshlq_n_u16(carry, 3)),
                veorq_u16(vshlq_n_u16(carry, 1), carry),
            );

            veorq_u16(veorq_u16(l, h_fold), carry_fold)
        }
    }
}
764
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{RngExt, rng};

    #[cfg(target_arch = "aarch64")]
    use proptest::prelude::*;

    // EXTENSION_TAU for Block16 must be 0x20·X: zero low half, high half 0x20.
    #[test]
    fn tower_constants() {
        let tau16 = Block16::EXTENSION_TAU;
        let (lo16, hi16) = tau16.split();
        assert_eq!(lo16, Block8::ZERO);
        assert_eq!(hi16, Block8(0x20));
    }

    // Addition truth table over {0, 1}: XOR semantics.
    #[test]
    fn add_truth() {
        let zero = Block16::ZERO;
        let one = Block16::ONE;

        assert_eq!(zero + zero, zero);
        assert_eq!(zero + one, one);
        assert_eq!(one + zero, one);
        assert_eq!(one + one, zero);
    }

    // Multiplication truth table over {0, 1}.
    #[test]
    fn mul_truth() {
        let zero = Block16::ZERO;
        let one = Block16::ONE;

        assert_eq!(zero * zero, zero);
        assert_eq!(zero * one, zero);
        assert_eq!(one * one, one);
    }

    // 5 XOR 3 == 6: addition is carry-less.
    #[test]
    fn add() {
        assert_eq!(Block16(5) + Block16(3), Block16(6));
    }

    // Small product that stays inside the Block8 subfield.
    #[test]
    fn mul_simple() {
        assert_eq!(Block16(2) * Block16(2), Block16(4));
    }

    // Both operands have a zero high half, so this exercises Block8
    // multiplication (including its reduction) through the Block16 API.
    #[test]
    fn mul_overflow() {
        assert_eq!(Block16(0x57) * Block16(0x83), Block16(0xC1));
    }

    // X·X must reduce to X + tau, the defining relation of this tower level.
    #[test]
    fn karatsuba_correctness() {
        let x = Block16::new(Block8::ZERO, Block8::ONE);
        let squared = x * x;

        let (res_lo, res_hi) = squared.split();

        assert_eq!(res_hi, Block8::ONE, "X^2 should contain X component");
        assert_eq!(
            res_lo,
            Block8(0x20),
            "X^2 should contain tau component (0x20)"
        );
    }

    // `zeroize` must clear the underlying u16.
    #[test]
    fn security_zeroize() {
        let mut secret_val = Block16::from(0xDEAD_u16);
        assert_ne!(secret_val, Block16::ZERO);

        secret_val.zeroize();

        assert_eq!(secret_val, Block16::ZERO);
        assert_eq!(secret_val.0, 0, "Block16 memory leak detected");
    }

    // Convention: inverting zero yields zero instead of panicking.
    #[test]
    fn invert_zero() {
        assert_eq!(
            Block16::ZERO.invert(),
            Block16::ZERO,
            "invert(0) must return 0"
        );
    }

    // a · a^-1 == 1 for random nonzero a.
    #[test]
    fn inversion_random() {
        let mut rng = rng();

        for _ in 0..1000 {
            let val_u16: u16 = rng.random();
            let val = Block16(val_u16);

            if val != Block16::ZERO {
                let inv = val.invert();
                let res = val * inv;

                assert_eq!(
                    res,
                    Block16::ONE,
                    "Inversion identity failed: a * a^-1 != 1"
                );
            }
        }
    }

    // Block8 -> Block16 embedding puts the value in the low half and is a ring
    // homomorphism for both + and *.
    #[test]
    fn tower_embedding() {
        let mut rng = rng();
        for _ in 0..100 {
            let a = Block8(rng.random());
            let b = Block8(rng.random());

            let a_lifted: Block16 = a.into();
            let (lo, hi) = a_lifted.split();

            assert_eq!(lo, a, "Embedding structure failed: low part mismatch");
            assert_eq!(
                hi,
                Block8::ZERO,
                "Embedding structure failed: high part must be zero"
            );

            let sum_sub = a + b;
            let sum_lifted: Block16 = sum_sub.into();
            let sum_manual = Block16::from(a) + Block16::from(b);

            assert_eq!(sum_lifted, sum_manual, "Homomorphism failed: add");

            let prod_sub = a * b;
            let prod_lifted: Block16 = prod_sub.into();
            let prod_manual = Block16::from(a) * Block16::from(b);

            assert_eq!(prod_lifted, prod_manual, "Homomorphism failed: mul");
        }
    }

    // tower -> flat -> tower must be the identity.
    #[test]
    fn isomorphism_roundtrip() {
        let mut rng = rng();
        for _ in 0..1000 {
            let val = Block16(rng.random::<u16>());
            assert_eq!(
                val.to_hardware().to_tower(),
                val,
                "Block16 isomorphism roundtrip failed"
            );
        }
    }

    // Multiplying in the flat basis must agree with multiplying in the tower
    // basis and then converting.
    #[test]
    fn flat_mul_homomorphism() {
        let mut rng = rng();
        for _ in 0..1000 {
            let a = Block16(rng.random::<u16>());
            let b = Block16(rng.random::<u16>());

            let expected_flat = (a * b).to_hardware();
            let actual_flat = a.to_hardware() * b.to_hardware();

            assert_eq!(
                actual_flat, expected_flat,
                "Block16 flat multiplication mismatch"
            );
        }
    }

    // Packed flat-basis add/mul must match the scalar tower results lane by lane.
    #[test]
    fn packed_consistency() {
        let mut rng = rng();
        for _ in 0..100 {
            let mut a_vals = [Block16::ZERO; 8];
            let mut b_vals = [Block16::ZERO; 8];

            for i in 0..8 {
                a_vals[i] = Block16(rng.random::<u16>());
                b_vals[i] = Block16(rng.random::<u16>());
            }

            let a_flat_vals = a_vals.map(|x| x.to_hardware());
            let b_flat_vals = b_vals.map(|x| x.to_hardware());
            let a_packed = Flat::<Block16>::pack(&a_flat_vals);
            let b_packed = Flat::<Block16>::pack(&b_flat_vals);

            let add_res = Block16::add_hardware_packed(a_packed, b_packed);

            let mut add_out = [Block16::ZERO.to_hardware(); 8];
            Flat::<Block16>::unpack(add_res, &mut add_out);

            for i in 0..8 {
                assert_eq!(
                    add_out[i],
                    (a_vals[i] + b_vals[i]).to_hardware(),
                    "Block16 packed add mismatch"
                );
            }

            let mul_res = Block16::mul_hardware_packed(a_packed, b_packed);

            let mut mul_out = [Block16::ZERO.to_hardware(); 8];
            Flat::<Block16>::unpack(mul_res, &mut mul_out);

            for i in 0..8 {
                assert_eq!(
                    mul_out[i],
                    (a_vals[i] * b_vals[i]).to_hardware(),
                    "Block16 packed mul mismatch"
                );
            }
        }
    }

    // pack followed by unpack must reproduce the input.
    #[test]
    fn pack_unpack_roundtrip() {
        let mut rng = rng();
        let mut data = [Block16::ZERO; PACKED_WIDTH_16];

        for v in data.iter_mut() {
            *v = Block16(rng.random());
        }

        let packed = Block16::pack(&data);
        let mut unpacked = [Block16::ZERO; PACKED_WIDTH_16];
        Block16::unpack(packed, &mut unpacked);

        assert_eq!(data, unpacked, "Block16 pack/unpack roundtrip failed");
    }

    // Tower-basis packed `+` must match scalar `+` per lane.
    #[test]
    fn packed_add_consistency() {
        let mut rng = rng();
        let mut a_vals = [Block16::ZERO; PACKED_WIDTH_16];
        let mut b_vals = [Block16::ZERO; PACKED_WIDTH_16];

        for i in 0..PACKED_WIDTH_16 {
            a_vals[i] = Block16(rng.random());
            b_vals[i] = Block16(rng.random());
        }

        let res_packed = Block16::pack(&a_vals) + Block16::pack(&b_vals);
        let mut res_unpacked = [Block16::ZERO; PACKED_WIDTH_16];
        Block16::unpack(res_packed, &mut res_unpacked);

        for i in 0..PACKED_WIDTH_16 {
            assert_eq!(
                res_unpacked[i],
                a_vals[i] + b_vals[i],
                "Block16 packed add mismatch"
            );
        }
    }

    // Tower-basis packed `*` (mul_iso_16 per lane on aarch64, scalar Mul
    // elsewhere) must match scalar `*` per lane.
    #[test]
    fn packed_mul_consistency() {
        let mut rng = rng();

        for _ in 0..1000 {
            let mut a_arr = [Block16::ZERO; PACKED_WIDTH_16];
            let mut b_arr = [Block16::ZERO; PACKED_WIDTH_16];

            for i in 0..PACKED_WIDTH_16 {
                let val_a_u16: u16 = rng.random();
                let val_b_u16: u16 = rng.random();

                a_arr[i] = Block16(val_a_u16);
                b_arr[i] = Block16(val_b_u16);
            }

            let a_packed = PackedBlock16(a_arr);
            let b_packed = PackedBlock16(b_arr);
            let c_packed = a_packed * b_packed;

            let mut c_expected = [Block16::ZERO; PACKED_WIDTH_16];
            for i in 0..PACKED_WIDTH_16 {
                c_expected[i] = a_arr[i] * b_arr[i];
            }

            assert_eq!(c_packed.0, c_expected, "SIMD Block16 mismatch!");
        }
    }

    // Exhaustive over all 2^16 flat values: every tower bit extracted via
    // `tower_bit` must match the bit of a full `from_hardware` conversion.
    // NOTE(review): assumes `Flat::tower_bit` routes through
    // `tower_bit_from_hardware`, as the assertion message implies.
    #[test]
    fn parity_masks_match_from_hardware() {
        for x_flat in 0u16..=u16::MAX {
            let tower = Block16::from_hardware(Flat::from_raw(Block16(x_flat))).0;

            for k in 0..16 {
                let bit = ((tower >> k) & 1) as u8;
                let via_api = Flat::from_raw(Block16(x_flat)).tower_bit(k);

                assert_eq!(
                    via_api, bit,
                    "Block16 tower_bit_from_hardware mismatch at x_flat={x_flat:#06x}, bit_idx={k}"
                );
            }
        }
    }

    // Property tests: the packed NEON kernels must agree lane by lane with the
    // scalar NEON multiply.
    #[cfg(target_arch = "aarch64")]
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(65536))]

        #[test]
        fn neon_packed_eq_scalar(a in any::<[u16; 8]>(), b in any::<[u16; 8]>()) {
            let pp = neon::mul_flat_packed_16(
                PackedBlock16(a.map(Block16)),
                PackedBlock16(b.map(Block16)),
            );

            let want: [Block16; 8] =
                core::array::from_fn(|i| neon::mul_flat_16(Block16(a[i]), Block16(b[i])));

            prop_assert_eq!(pp.0, want);
        }

        #[test]
        fn neon_scalar_packed_eq_scalar(a in any::<[u16; 8]>(), s in any::<u16>()) {
            let sp = neon::mul_flat_scalar_packed_16(PackedBlock16(a.map(Block16)), Block16(s));

            let want: [Block16; 8] =
                core::array::from_fn(|i| neon::mul_flat_16(Block16(a[i]), Block16(s)));

            prop_assert_eq!(sp.0, want);
        }
    }
}
1143}