// hekate_math/towers/block32.rs

1// SPDX-License-Identifier: Apache-2.0
2// This file is part of the hekate-math project.
3// Copyright (C) 2026 Andrei Kochergin <zeek@tuta.com>
4// Copyright (C) 2026 Oumuamua Labs. All rights reserved.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10//     http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! BLOCK 32 (GF(2^32))
19use crate::towers::bit::Bit;
20use crate::towers::block8::Block8;
21use crate::towers::block16::Block16;
22use crate::{
23    CanonicalDeserialize, CanonicalSerialize, Flat, FlatPromote, HardwareField, PackableField,
24    PackedFlat, TowerField, constants,
25};
26use core::ops::{Add, AddAssign, BitXor, BitXorAssign, Mul, MulAssign, Sub, SubAssign};
27use serde::{Deserialize, Serialize};
28use zeroize::Zeroize;
29
// Wrapper that over-aligns a change-of-basis table for the constant-time
// (non-"table-math") conversion path.
// NOTE(review): align(64) presumably pins each table to a cache-line
// boundary so every `map_ct_32` pass touches the same lines — confirm.
#[cfg(not(feature = "table-math"))]
#[repr(align(64))]
struct CtConvertBasisU32<const N: usize>([u32; N]);

// Column basis of the tower -> flat (hardware) basis change, one u32 row
// per input bit; consumed by `map_ct_32` in `to_hardware`.
#[cfg(not(feature = "table-math"))]
static TOWER_TO_FLAT_BASIS_32: CtConvertBasisU32<32> =
    CtConvertBasisU32(constants::RAW_TOWER_TO_FLAT_32);

// Inverse basis: flat (hardware) -> tower; consumed in `from_hardware`.
#[cfg(not(feature = "table-math"))]
static FLAT_TO_TOWER_BASIS_32: CtConvertBasisU32<32> =
    CtConvertBasisU32(constants::RAW_FLAT_TO_TOWER_32);
41
/// Element of the tower field GF(2^32), stored as a raw `u32` in the tower
/// basis. The low 16 bits are the GF(2^16) "lo" half and the high 16 bits
/// the "hi" half (see [`Block32::split`]).
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq, Serialize, Deserialize, Zeroize)]
#[repr(transparent)]
pub struct Block32(pub u32);
45
46impl Block32 {
47    // 0x2000 << 16 = 0x2000_0000
48    pub const TAU: Self = Block32(0x2000_0000);
49
50    pub fn new(lo: Block16, hi: Block16) -> Self {
51        Self((hi.0 as u32) << 16 | (lo.0 as u32))
52    }
53
54    #[inline(always)]
55    pub fn split(self) -> (Block16, Block16) {
56        (Block16(self.0 as u16), Block16((self.0 >> 16) as u16))
57    }
58}
59
impl TowerField for Block32 {
    const BITS: usize = 32;
    const ZERO: Self = Block32(0);
    const ONE: Self = Block32(1);

    const EXTENSION_TAU: Self = Self::TAU;

    /// Multiplicative inverse via the quadratic-extension norm trick.
    ///
    /// Write `self = l + h*X` over GF(2^16) with `X^2 = X + tau` (tau =
    /// `Block16::TAU`). The conjugate is `(l + h) + h*X`, so the norm
    /// `N = self * conj = tau*h^2 + h*l + l^2` lies in the GF(2^16)
    /// subfield. Then `self^-1 = conj * N^-1`, giving
    /// `lo = (h + l) * N^-1` and `hi = h * N^-1`.
    ///
    /// NOTE(review): for `self == 0` the norm is 0, so the result is 0
    /// provided `Block16::invert(0) == 0` — the contract the `invert_zero`
    /// test below relies on; confirm `Block16` upholds it.
    fn invert(&self) -> Self {
        let (l, h) = self.split();
        let h2 = h * h;
        let l2 = l * l;
        let hl = h * l;

        // Tau here is Block16::TAU
        let norm = (h2 * Block16::TAU) + hl + l2;

        let norm_inv = norm.invert();
        let res_hi = h * norm_inv;
        let res_lo = (h + l) * norm_inv;

        Self::new(res_lo, res_hi)
    }

    /// Derives an element from 32 uniform bytes by taking the first 4 as a
    /// little-endian u32; the remaining 28 bytes are ignored.
    fn from_uniform_bytes(bytes: &[u8; 32]) -> Self {
        let mut buf = [0u8; 4];
        buf.copy_from_slice(&bytes[0..4]);

        Self(u32::from_le_bytes(buf))
    }
}
90
91impl Add for Block32 {
92    type Output = Self;
93
94    fn add(self, rhs: Self) -> Self {
95        Self(self.0.bitxor(rhs.0))
96    }
97}
98
99impl Sub for Block32 {
100    type Output = Self;
101
102    fn sub(self, rhs: Self) -> Self {
103        self.add(rhs)
104    }
105}
106
107impl Mul for Block32 {
108    type Output = Self;
109
110    fn mul(self, rhs: Self) -> Self {
111        let (a0, a1) = self.split();
112        let (b0, b1) = rhs.split();
113
114        let v0 = a0 * b0;
115        let v1 = a1 * b1;
116        let v_sum = (a0 + a1) * (b0 + b1);
117
118        let c_hi = v0 + v_sum;
119        let c_lo = v0 + (v1 * Block16::TAU);
120
121        Self::new(c_lo, c_hi)
122    }
123}
124
125impl AddAssign for Block32 {
126    fn add_assign(&mut self, rhs: Self) {
127        self.0.bitxor_assign(rhs.0);
128    }
129}
130
131impl SubAssign for Block32 {
132    fn sub_assign(&mut self, rhs: Self) {
133        self.0.bitxor_assign(rhs.0);
134    }
135}
136
137impl MulAssign for Block32 {
138    fn mul_assign(&mut self, rhs: Self) {
139        *self = *self * rhs;
140    }
141}
142
143impl CanonicalSerialize for Block32 {
144    fn serialized_size(&self) -> usize {
145        4
146    }
147
148    fn serialize(&self, writer: &mut [u8]) -> Result<(), ()> {
149        if writer.len() < 4 {
150            return Err(());
151        }
152
153        writer.copy_from_slice(&self.0.to_le_bytes());
154
155        Ok(())
156    }
157}
158
159impl CanonicalDeserialize for Block32 {
160    fn deserialize(bytes: &[u8]) -> Result<Self, ()> {
161        if bytes.len() < 4 {
162            return Err(());
163        }
164
165        let mut buf = [0u8; 4];
166        buf.copy_from_slice(&bytes[0..4]);
167
168        Ok(Self(u32::from_le_bytes(buf)))
169    }
170}
171
172impl From<u8> for Block32 {
173    fn from(val: u8) -> Self {
174        Self(val as u32)
175    }
176}
177
178impl From<u16> for Block32 {
179    #[inline]
180    fn from(val: u16) -> Self {
181        Self::from(val as u32)
182    }
183}
184
185impl From<u32> for Block32 {
186    #[inline]
187    fn from(val: u32) -> Self {
188        Self(val)
189    }
190}
191
192impl From<u64> for Block32 {
193    #[inline]
194    fn from(val: u64) -> Self {
195        Self(val as u32)
196    }
197}
198
199impl From<u128> for Block32 {
200    #[inline]
201    fn from(val: u128) -> Self {
202        Self(val as u32)
203    }
204}
205
206// ========================================
207// FIELD LIFTING
208// ========================================
209
210impl From<Bit> for Block32 {
211    #[inline(always)]
212    fn from(val: Bit) -> Self {
213        Self(val.0 as u32)
214    }
215}
216
217impl From<Block8> for Block32 {
218    #[inline(always)]
219    fn from(val: Block8) -> Self {
220        Self(val.0 as u32)
221    }
222}
223
224impl From<Block16> for Block32 {
225    #[inline(always)]
226    fn from(val: Block16) -> Self {
227        Self(val.0 as u32)
228    }
229}
230
231// ========================================
232// PACKED BLOCK 32 (Width = 4)
233// ========================================
234
/// Number of `Block32` lanes in a `PackedBlock32` (4 x 32 bits = 128 bits).
pub const PACKED_WIDTH_32: usize = 4;
236
/// Four `Block32` lanes stored contiguously. `repr(C, align(16))` makes
/// the array layout-compatible with a 128-bit SIMD register (the aarch64
/// path transmutes it to `uint8x16_t`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[repr(C, align(16))]
pub struct PackedBlock32(pub [Block32; PACKED_WIDTH_32]);
240
241impl PackedBlock32 {
242    #[inline(always)]
243    pub fn zero() -> Self {
244        Self([Block32::ZERO; PACKED_WIDTH_32])
245    }
246}
247
248impl PackableField for Block32 {
249    type Packed = PackedBlock32;
250
251    const WIDTH: usize = PACKED_WIDTH_32;
252
253    #[inline(always)]
254    fn pack(chunk: &[Self]) -> Self::Packed {
255        assert!(
256            chunk.len() >= PACKED_WIDTH_32,
257            "PackableField::pack: input slice too short",
258        );
259
260        let mut arr = [Self::ZERO; PACKED_WIDTH_32];
261        arr.copy_from_slice(&chunk[..PACKED_WIDTH_32]);
262
263        PackedBlock32(arr)
264    }
265
266    #[inline(always)]
267    fn unpack(packed: Self::Packed, output: &mut [Self]) {
268        assert!(
269            output.len() >= PACKED_WIDTH_32,
270            "PackableField::unpack: output slice too short",
271        );
272
273        output[..PACKED_WIDTH_32].copy_from_slice(&packed.0);
274    }
275}
276
277impl Add for PackedBlock32 {
278    type Output = Self;
279
280    #[inline(always)]
281    fn add(self, rhs: Self) -> Self {
282        let mut res = [Block32::ZERO; PACKED_WIDTH_32];
283        for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
284            *out = *l + *r;
285        }
286
287        Self(res)
288    }
289}
290
291impl AddAssign for PackedBlock32 {
292    #[inline(always)]
293    fn add_assign(&mut self, rhs: Self) {
294        for (l, r) in self.0.iter_mut().zip(rhs.0.iter()) {
295            *l += *r;
296        }
297    }
298}
299
300impl Sub for PackedBlock32 {
301    type Output = Self;
302
303    #[inline(always)]
304    fn sub(self, rhs: Self) -> Self {
305        self.add(rhs)
306    }
307}
308
309impl SubAssign for PackedBlock32 {
310    #[inline(always)]
311    fn sub_assign(&mut self, rhs: Self) {
312        self.add_assign(rhs);
313    }
314}
315
impl Mul for PackedBlock32 {
    type Output = Self;

    /// Lane-wise field multiplication.
    ///
    /// On aarch64 each lane is routed through the flat-basis PMULL
    /// multiplier (`mul_iso_32`); elsewhere the scalar tower `Mul` is used.
    /// The `packed_mul_consistency` test pins both paths to identical
    /// results.
    #[inline(always)]
    fn mul(self, rhs: Self) -> Self {
        #[cfg(target_arch = "aarch64")]
        {
            let a0 = mul_iso_32(self.0[0], rhs.0[0]);
            let a1 = mul_iso_32(self.0[1], rhs.0[1]);
            let a2 = mul_iso_32(self.0[2], rhs.0[2]);
            let a3 = mul_iso_32(self.0[3], rhs.0[3]);

            Self([a0, a1, a2, a3])
        }

        #[cfg(not(target_arch = "aarch64"))]
        {
            let mut res = [Block32::ZERO; PACKED_WIDTH_32];
            for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
                *out = *l * *r;
            }

            Self(res)
        }
    }
}
342
343impl MulAssign for PackedBlock32 {
344    #[inline(always)]
345    fn mul_assign(&mut self, rhs: Self) {
346        for (l, r) in self.0.iter_mut().zip(rhs.0.iter()) {
347            *l *= *r;
348        }
349    }
350}
351
352impl Mul<Block32> for PackedBlock32 {
353    type Output = Self;
354
355    #[inline(always)]
356    fn mul(self, rhs: Block32) -> Self {
357        let mut res = [Block32::ZERO; PACKED_WIDTH_32];
358        for (out, v) in res.iter_mut().zip(self.0.iter()) {
359            *out = *v * rhs;
360        }
361
362        Self(res)
363    }
364}
365
366// ===================================
367// Block32 Hardware Field
368// ===================================
369
impl HardwareField for Block32 {
    /// Converts from the tower basis to the hardware ("flat") basis by
    /// applying the precomputed change-of-basis matrix.
    #[inline(always)]
    fn to_hardware(self) -> Flat<Self> {
        #[cfg(feature = "table-math")]
        {
            // Byte-sliced table lookup (indexed by value bytes).
            Flat::from_raw(apply_matrix_32(self, &constants::TOWER_TO_FLAT_32))
        }

        #[cfg(not(feature = "table-math"))]
        {
            // Branch-free bit-by-bit matrix application.
            Flat::from_raw(Block32(map_ct_32(self.0, &TOWER_TO_FLAT_BASIS_32.0)))
        }
    }

    /// Inverse conversion: hardware (flat) basis back to the tower basis.
    #[inline(always)]
    fn from_hardware(value: Flat<Self>) -> Self {
        let value = value.into_raw();
        #[cfg(feature = "table-math")]
        {
            apply_matrix_32(value, &constants::FLAT_TO_TOWER_32)
        }

        #[cfg(not(feature = "table-math"))]
        {
            Block32(map_ct_32(value.0, &FLAT_TO_TOWER_BASIS_32.0))
        }
    }

    /// Addition is XOR in either basis (the basis change is linear), so
    /// the tower `Add` is reused directly on the raw values.
    #[inline(always)]
    fn add_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
        Flat::from_raw(lhs.into_raw() + rhs.into_raw())
    }

    /// Packed addition: one 128-bit NEON XOR on aarch64, lane loop elsewhere.
    #[inline(always)]
    fn add_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
        let lhs = lhs.into_raw();
        let rhs = rhs.into_raw();

        #[cfg(target_arch = "aarch64")]
        {
            PackedFlat::from_raw(neon::add_packed_32(lhs, rhs))
        }

        #[cfg(not(target_arch = "aarch64"))]
        {
            PackedFlat::from_raw(lhs + rhs)
        }
    }

    /// Multiplication in the flat basis.
    ///
    /// aarch64 uses the PMULL carry-less multiplier; the portable fallback
    /// round-trips through the tower basis and the tower `Mul`.
    #[inline(always)]
    fn mul_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
        let lhs = lhs.into_raw();
        let rhs = rhs.into_raw();

        #[cfg(target_arch = "aarch64")]
        {
            Flat::from_raw(neon::mul_flat_32(lhs, rhs))
        }

        #[cfg(not(target_arch = "aarch64"))]
        {
            let a_tower = Self::from_hardware(Flat::from_raw(lhs));
            let b_tower = Self::from_hardware(Flat::from_raw(rhs));

            (a_tower * b_tower).to_hardware()
        }
    }

    /// Packed flat-basis multiplication: four PMULL lanes on aarch64,
    /// per-lane `mul_hardware` elsewhere.
    #[inline(always)]
    fn mul_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
        let lhs = lhs.into_raw();
        let rhs = rhs.into_raw();

        #[cfg(target_arch = "aarch64")]
        {
            PackedFlat::from_raw(neon::mul_flat_packed_32(lhs, rhs))
        }

        #[cfg(not(target_arch = "aarch64"))]
        {
            let mut l = [Self::ZERO; <Self as PackableField>::WIDTH];
            let mut r = [Self::ZERO; <Self as PackableField>::WIDTH];
            let mut res = [Self::ZERO; <Self as PackableField>::WIDTH];

            Self::unpack(lhs, &mut l);
            Self::unpack(rhs, &mut r);

            for i in 0..<Self as PackableField>::WIDTH {
                res[i] = Self::mul_hardware(Flat::from_raw(l[i]), Flat::from_raw(r[i])).into_raw();
            }

            PackedFlat::from_raw(Self::pack(&res))
        }
    }

    /// Scalar-times-vector: broadcasts `rhs` into all lanes and reuses the
    /// packed multiplier.
    #[inline(always)]
    fn mul_hardware_scalar_packed(lhs: PackedFlat<Self>, rhs: Flat<Self>) -> PackedFlat<Self> {
        let broadcasted = PackedBlock32([rhs.into_raw(); PACKED_WIDTH_32]);
        Self::mul_hardware_packed(lhs, PackedFlat::from_raw(broadcasted))
    }

    /// Extracts tower-basis bit `bit_idx` directly from a flat-basis value:
    /// each tower bit is a fixed linear functional of the flat bits, stored
    /// as a mask, and evaluated as the parity of `value & mask`.
    #[inline(always)]
    fn tower_bit_from_hardware(value: Flat<Self>, bit_idx: usize) -> u8 {
        // Panics if bit_idx >= 32 (mask table lookup).
        let mask = constants::FLAT_TO_TOWER_BIT_MASKS_32[bit_idx];

        // Parity of (x & mask) without popcount.
        // Folds 32 bits down to 1
        // using a binary XOR tree.
        let mut v = value.into_raw().0 & mask;
        v ^= v >> 16;
        v ^= v >> 8;
        v ^= v >> 4;
        v ^= v >> 2;
        v ^= v >> 1;

        (v & 1) as u8
    }
}
488
impl FlatPromote<Block8> for Block32 {
    /// Promotes a flat-basis GF(2^8) element into the flat GF(2^32) basis
    /// using the precomputed lift of each input bit.
    #[inline(always)]
    fn promote_flat(val: Flat<Block8>) -> Flat<Self> {
        let val = val.into_raw();
        #[cfg(not(feature = "table-math"))]
        {
            // Branch-free path: XOR in the image of input bit i, selected
            // by an all-ones/all-zeros mask derived from that bit.
            let mut acc = 0u32;
            for i in 0..8 {
                let bit = (val.0 >> i) & 1;
                let mask = 0u32.wrapping_sub(bit as u32);
                acc ^= constants::LIFT_BASIS_8_TO_32[i] & mask;
            }

            Flat::from_raw(Block32(acc))
        }

        #[cfg(feature = "table-math")]
        {
            // Single table lookup indexed by the raw byte.
            Flat::from_raw(Block32(constants::LIFT_TABLE_8_TO_32[val.0 as usize]))
        }
    }
}
511
512// ===========================================
513// UTILS
514// ===========================================
515
/// Tower-basis multiplication routed through the flat basis (aarch64 only):
/// convert both operands, multiply with the PMULL-backed flat multiplier,
/// and convert the product back to the tower basis.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub fn mul_iso_32(a: Block32, b: Block32) -> Block32 {
    let a_flat = a.to_hardware();
    let b_flat = b.to_hardware();
    let c_flat = Flat::from_raw(neon::mul_flat_32(a_flat.into_raw(), b_flat.into_raw()));

    c_flat.to_tower()
}
525
/// Applies a 32x32 GF(2) matrix to `val` via byte-slicing: `table` holds
/// 4 banks of 256 precomputed column combinations, one bank per input byte,
/// and the result is the XOR of one entry per bank.
///
/// NOTE(review): the lookup index depends on `val`, so this path is not
/// data-oblivious — presumably the accepted trade-off of the `table-math`
/// feature (the default path uses the branch-free `map_ct_32`); confirm.
#[cfg(feature = "table-math")]
#[inline(always)]
pub fn apply_matrix_32(val: Block32, table: &[u32; 1024]) -> Block32 {
    let mut res = 0u32;
    let v = val.0;

    // 4 lookups. Safe indexing replaces the previous `get_unchecked`:
    // `i < 4` and `byte <= 0xFF` bound the index by 3 * 256 + 255 = 1023,
    // which is always within the fixed-size [u32; 1024] table, so the
    // `unsafe` block was unnecessary.
    for i in 0..4 {
        let byte = ((v >> (i * 8)) & 0xFF) as usize;
        res ^= table[i * 256 + byte];
    }

    Block32(res)
}
541
/// Applies a 32x32 GF(2) matrix to `x` in a branch-free way: for every set
/// bit `i` of `x`, XOR in row `basis[i]`, selecting each row with an
/// all-ones/all-zeros mask so neither branches nor memory access patterns
/// depend on the value of `x`.
#[cfg(not(feature = "table-math"))]
#[inline(always)]
fn map_ct_32(x: u32, basis: &[u32; 32]) -> u32 {
    let mut acc = 0u32;

    for (i, row) in basis.iter().copied().enumerate() {
        let mask = 0u32.wrapping_sub((x >> i) & 1);
        acc ^= row & mask;
    }

    acc
}
557
558// ===========================================
559// 32-BIT SIMD INSTRUCTIONS
560// ===========================================
561
#[cfg(target_arch = "aarch64")]
mod neon {
    use super::*;
    use core::arch::aarch64::*;
    use core::mem::transmute;

    /// XORs all four 32-bit lanes in a single 128-bit NEON operation.
    #[inline(always)]
    pub fn add_packed_32(lhs: PackedBlock32, rhs: PackedBlock32) -> PackedBlock32 {
        // SAFETY: PackedBlock32 is #[repr(C, align(16))] over four
        // #[repr(transparent)] u32 lanes — 16 bytes, 16-byte aligned —
        // so the transmutes to/from uint8x16_t are layout-compatible.
        unsafe {
            let l: uint8x16_t = transmute::<[Block32; PACKED_WIDTH_32], uint8x16_t>(lhs.0);
            let r: uint8x16_t = transmute::<[Block32; PACKED_WIDTH_32], uint8x16_t>(rhs.0);
            let res = veorq_u8(l, r);
            let out: [Block32; PACKED_WIDTH_32] =
                transmute::<uint8x16_t, [Block32; PACKED_WIDTH_32]>(res);

            PackedBlock32(out)
        }
    }

    /// Packed flat-basis multiplication: one PMULL multiply per lane.
    #[inline(always)]
    pub fn mul_flat_packed_32(lhs: PackedBlock32, rhs: PackedBlock32) -> PackedBlock32 {
        let r0 = mul_flat_32(lhs.0[0], rhs.0[0]);
        let r1 = mul_flat_32(lhs.0[1], rhs.0[1]);
        let r2 = mul_flat_32(lhs.0[2], rhs.0[2]);
        let r3 = mul_flat_32(lhs.0[3], rhs.0[3]);

        PackedBlock32([r0, r1, r2, r3])
    }

    /// Carry-less 32x32 multiply followed by reduction modulo
    /// P(x) = x^32 + R(x), where R(x) = `constants::POLY_32`.
    ///
    /// NOTE(review): `vmull_p64` requires the PMULL ("aes") target
    /// feature; confirm the crate's aarch64 build enables it, since the
    /// only guard here is the `target_arch` cfg.
    #[inline(always)]
    pub fn mul_flat_32(a: Block32, b: Block32) -> Block32 {
        unsafe {
            // 1. Multiply 32x32 -> 64
            // Cast u32 to u64 for vmull
            let prod = vmull_p64(a.0 as u64, b.0 as u64);

            // The result is 128-bit type, but only care
            // about low 64 bits because 32*32 fits in 64 bits.
            let prod_u64: uint64x2_t = transmute(prod);
            let prod_val = vgetq_lane_u64(prod_u64, 0);

            let l = (prod_val & 0xFFFFFFFF) as u32;
            let h = (prod_val >> 32) as u32;

            // 2. Reduce mod P(x) = x^32 + R(x)
            let r_val = constants::POLY_32 as u64;

            // H * R — folds the overflow half back below x^32; this can
            // itself overflow 32 bits, producing a second-level carry.
            let h_red = vmull_p64(h as u64, r_val);
            let h_red_vec: uint64x2_t = transmute(h_red);
            let h_red_val = vgetq_lane_u64(h_red_vec, 0);

            let folded = (h_red_val & 0xFFFFFFFF) as u32;
            let carry = (h_red_val >> 32) as u32;

            let mut res = l ^ folded;

            // 3. Reduce carry — one more fold finishes the reduction.
            let carry_red = vmull_p64(carry as u64, r_val);
            let carry_res_vec: uint64x2_t = transmute(carry_red);
            let carry_val = vgetq_lane_u64(carry_res_vec, 0);

            res ^= carry_val as u32;

            Block32(res)
        }
    }
}
630
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use rand::{RngExt, rng};

    // ==================================
    // BASIC
    // ==================================

    #[test]
    fn tower_constants() {
        // Check that tau is propagated correctly
        // For Block32, tau must split into (Block16::ZERO, Block16::TAU),
        // i.e. tau_32 = X * tau_16.
        let tau32 = Block32::EXTENSION_TAU;
        let (lo32, hi32) = tau32.split();
        assert_eq!(lo32, Block16::ZERO);
        assert_eq!(hi32, Block16::TAU);
    }

    #[test]
    fn add_truth() {
        let zero = Block32::ZERO;
        let one = Block32::ONE;

        assert_eq!(zero + zero, zero);
        assert_eq!(zero + one, one);
        assert_eq!(one + zero, one);
        assert_eq!(one + one, zero);
    }

    #[test]
    fn mul_truth() {
        let zero = Block32::ZERO;
        let one = Block32::ONE;

        assert_eq!(zero * zero, zero);
        assert_eq!(zero * one, zero);
        assert_eq!(one * one, one);
    }

    #[test]
    fn add() {
        // 5 ^ 3 = 6
        // 101 ^ 011 = 110
        assert_eq!(Block32(5) + Block32(3), Block32(6));
    }

    #[test]
    fn mul_simple() {
        // Check for prime numbers (without overflow)
        // x^1 * x^1 = x^2 (2 * 2 = 4)
        assert_eq!(Block32(2) * Block32(2), Block32(4));
    }

    #[test]
    fn mul_overflow() {
        // Reduction verification (AES test vectors)
        // Example from the AES specification:
        // 0x57 * 0x83 = 0xC1
        // NOTE(review): this identity comes from AES's polynomial basis
        // (x^8 + x^4 + x^3 + x + 1); confirm it is also expected to hold
        // in this tower basis rather than coincidentally.
        assert_eq!(Block32(0x57) * Block32(0x83), Block32(0xC1));
    }

    #[test]
    fn karatsuba_correctness() {
        // Let A = X (hi=1, lo=0)
        // Let B = X (hi=1, lo=0)
        // A * B = X^2
        // According to the rule:
        // X^2 = X + tau
        // Where tau for Block16 = 0x2000.
        // So the result should be:
        // hi=1 (X), lo=0x2000 (tau)

        // Construct X manually
        let x = Block32::new(Block16::ZERO, Block16::ONE);
        let squared = x * x;

        // Verify result via splitting
        let (res_lo, res_hi) = squared.split();

        assert_eq!(res_hi, Block16::ONE, "X^2 should contain X component");
        assert_eq!(
            res_lo,
            Block16(0x2000),
            "X^2 should contain tau component (0x2000)"
        );
    }

    #[test]
    fn security_zeroize() {
        // `zeroize` must wipe the secret value in place.
        let mut secret_val = Block32::from(0xDEAD_BEEF_u32);
        assert_ne!(secret_val, Block32::ZERO);

        secret_val.zeroize();

        assert_eq!(secret_val, Block32::ZERO);
        assert_eq!(secret_val.0, 0, "Block32 memory leak detected");
    }

    #[test]
    fn invert_zero() {
        // Verify that inverting zero adheres
        // to the API contract (returns 0).
        assert_eq!(
            Block32::ZERO.invert(),
            Block32::ZERO,
            "invert(0) must return 0"
        );
    }

    #[test]
    fn inversion_random() {
        let mut rng = rng();
        for _ in 0..1000 {
            let val = Block32(rng.random());

            if val != Block32::ZERO {
                let inv = val.invert();
                let res = val * inv;

                assert_eq!(
                    res,
                    Block32::ONE,
                    "Inversion identity failed: a * a^-1 != 1"
                );
            }
        }
    }

    #[test]
    fn tower_embedding() {
        let mut rng = rng();
        for _ in 0..100 {
            let a_u16: u16 = rng.random();
            let b_u16: u16 = rng.random();
            let a = Block16(a_u16);
            let b = Block16(b_u16);

            // 1. Structure check
            let a_lifted: Block32 = a.into();
            let (lo, hi) = a_lifted.split();

            assert_eq!(lo, a, "Embedding structure failed: low part mismatch");
            assert_eq!(
                hi,
                Block16::ZERO,
                "Embedding structure failed: high part must be zero"
            );

            // 2. Addition Homomorphism
            let sum_sub = a + b;
            let sum_lifted: Block32 = sum_sub.into();
            let sum_manual = Block32::from(a) + Block32::from(b);

            assert_eq!(sum_lifted, sum_manual, "Homomorphism failed: add");

            // 3. Multiplication Homomorphism
            let prod_sub = a * b;
            let prod_lifted: Block32 = prod_sub.into();
            let prod_manual = Block32::from(a) * Block32::from(b);

            assert_eq!(prod_lifted, prod_manual, "Homomorphism failed: mul");
        }
    }

    // ==================================
    // HARDWARE
    // ==================================

    #[test]
    fn isomorphism_roundtrip() {
        let mut rng = rng();
        for _ in 0..1000 {
            let val = Block32(rng.random::<u32>());
            assert_eq!(
                val.to_hardware().to_tower(),
                val,
                "Block32 isomorphism roundtrip failed"
            );
        }
    }

    #[test]
    fn flat_mul_homomorphism() {
        let mut rng = rng();
        for _ in 0..1000 {
            let a = Block32(rng.random::<u32>());
            let b = Block32(rng.random::<u32>());
            assert_eq!(a.to_hardware() * b.to_hardware(), (a * b).to_hardware());
        }
    }

    #[test]
    fn packed_consistency() {
        let mut rng = rng();
        let mut a_vals = [Block32::ZERO; 4];
        let mut b_vals = [Block32::ZERO; 4];

        for i in 0..4 {
            a_vals[i] = Block32(rng.random::<u32>());
            b_vals[i] = Block32(rng.random::<u32>());
        }

        // Add consistency
        let a_flat_vals = a_vals.map(|x| x.to_hardware());
        let b_flat_vals = b_vals.map(|x| x.to_hardware());
        let add_res = Block32::add_hardware_packed(
            Flat::<Block32>::pack(&a_flat_vals),
            Flat::<Block32>::pack(&b_flat_vals),
        );

        let mut add_out = [Block32::ZERO.to_hardware(); 4];
        Flat::<Block32>::unpack(add_res, &mut add_out);

        for i in 0..4 {
            assert_eq!(add_out[i], (a_vals[i] + b_vals[i]).to_hardware());
        }

        // Mul consistency (Flat basis)
        let mul_res = Block32::mul_hardware_packed(
            Flat::<Block32>::pack(&a_flat_vals),
            Flat::<Block32>::pack(&b_flat_vals),
        );

        let mut mul_out = [Block32::ZERO.to_hardware(); 4];
        Flat::<Block32>::unpack(mul_res, &mut mul_out);

        for i in 0..4 {
            assert_eq!(mul_out[i], (a_vals[i] * b_vals[i]).to_hardware());
        }
    }

    // ==================================
    // PACKED
    // ==================================

    #[test]
    fn pack_unpack_roundtrip() {
        let mut rng = rng();
        let mut data = [Block32::ZERO; PACKED_WIDTH_32];

        for v in data.iter_mut() {
            *v = Block32(rng.random());
        }

        let packed = Block32::pack(&data);
        let mut unpacked = [Block32::ZERO; PACKED_WIDTH_32];
        Block32::unpack(packed, &mut unpacked);

        assert_eq!(data, unpacked);
    }

    #[test]
    fn packed_add_consistency() {
        let mut rng = rng();
        let a_vals = [
            Block32(rng.random()),
            Block32(rng.random()),
            Block32(rng.random()),
            Block32(rng.random()),
        ];
        let b_vals = [
            Block32(rng.random()),
            Block32(rng.random()),
            Block32(rng.random()),
            Block32(rng.random()),
        ];

        let res_packed = Block32::pack(&a_vals) + Block32::pack(&b_vals);
        let mut res_unpacked = [Block32::ZERO; PACKED_WIDTH_32];
        Block32::unpack(res_packed, &mut res_unpacked);

        for i in 0..PACKED_WIDTH_32 {
            assert_eq!(res_unpacked[i], a_vals[i] + b_vals[i]);
        }
    }

    #[test]
    fn packed_mul_consistency() {
        let mut rng = rng();

        for _ in 0..1000 {
            let mut a_arr = [Block32::ZERO; PACKED_WIDTH_32];
            let mut b_arr = [Block32::ZERO; PACKED_WIDTH_32];

            for i in 0..PACKED_WIDTH_32 {
                let val_a: u32 = rng.random();
                let val_b: u32 = rng.random();
                a_arr[i] = Block32(val_a);
                b_arr[i] = Block32(val_b);
            }

            let a_packed = PackedBlock32(a_arr);
            let b_packed = PackedBlock32(b_arr);

            // Perform SIMD multiplication
            let c_packed = a_packed * b_packed;

            // Verify against Scalar
            let mut c_expected = [Block32::ZERO; PACKED_WIDTH_32];
            for i in 0..PACKED_WIDTH_32 {
                c_expected[i] = a_arr[i] * b_arr[i];
            }

            assert_eq!(c_packed.0, c_expected, "SIMD Block32 mismatch!");
        }
    }

    proptest! {
        // Every tower bit extracted via the parity-mask fast path must
        // agree with a full from_hardware conversion.
        #[test]
        fn parity_masks_match_from_hardware(x_flat in any::<u32>()) {
            let tower = Block32::from_hardware(Flat::from_raw(Block32(x_flat))).0;

            for k in 0..32 {
                let bit = ((tower >> k) & 1) as u8;
                let via_api = Flat::from_raw(Block32(x_flat)).tower_bit(k);

                prop_assert_eq!(
                    via_api, bit,
                    "Block32 tower_bit_from_hardware mismatch at x_flat={:#010x}, bit_idx={}",
                    x_flat, k
                );
            }
        }
    }
}