Skip to main content

hekate_math/towers/
block16.rs

1// SPDX-License-Identifier: Apache-2.0
2// This file is part of the hekate-math project.
3// Copyright (C) 2026 Andrei Kochergin <andrei@oumuamua.dev>
4// Copyright (C) 2026 Oumuamua Labs <info@oumuamua.dev>. All rights reserved.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10//     http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! BLOCK 16 (GF(2^16))
19use crate::towers::bit::Bit;
20use crate::towers::block8::Block8;
21use crate::{
22    BinaryFieldExtras, CanonicalDeserialize, CanonicalSerialize, Flat, FlatPromote, HardwareField,
23    PackableField, PackedFlat, TowerField, constants,
24};
25use core::ops::{Add, AddAssign, BitXor, BitXorAssign, Mul, MulAssign, Sub, SubAssign};
26use serde::{Deserialize, Serialize};
27use zeroize::Zeroize;
28
/// Wrapper that forces 64-byte alignment on a basis table of `N` `u16` entries.
///
/// Only compiled when the `table-math` feature is disabled.
// NOTE(review): presumably 64 bytes is chosen to keep the table within one
// cache line for the constant-time conversion — confirm.
#[cfg(not(feature = "table-math"))]
#[repr(align(64))]
struct CtConvertBasisU16<const N: usize>([u16; N]);

/// Basis passed to `map_ct_16` by `to_hardware` (tower -> flat) when
/// `table-math` is disabled.
#[cfg(not(feature = "table-math"))]
static TOWER_TO_FLAT_BASIS_16: CtConvertBasisU16<16> =
    CtConvertBasisU16(constants::RAW_TOWER_TO_FLAT_16);

/// Basis passed to `map_ct_16` by `from_hardware` (flat -> tower) when
/// `table-math` is disabled.
#[cfg(not(feature = "table-math"))]
static FLAT_TO_TOWER_BASIS_16: CtConvertBasisU16<16> =
    CtConvertBasisU16(constants::RAW_FLAT_TO_TOWER_16);
40
/// Element of GF(2^16), stored as its raw 16-bit representation.
///
/// The low byte is the low `Block8` half and the high byte is the high
/// `Block8` half (see `Block16::new` / `Block16::split`).
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq, Serialize, Deserialize, Zeroize)]
#[repr(transparent)]
pub struct Block16(pub u16);
44
45impl Block16 {
46    pub const TAU: Self = Block16(0x2000);
47
48    pub fn new(lo: Block8, hi: Block8) -> Self {
49        Self((hi.0 as u16) << 8 | (lo.0 as u16))
50    }
51
52    #[inline(always)]
53    pub fn split(self) -> (Block8, Block8) {
54        (Block8(self.0 as u8), Block8((self.0 >> 8) as u8))
55    }
56}
57
58impl TowerField for Block16 {
59    const BITS: usize = 16;
60    const ZERO: Self = Block16(0);
61    const ONE: Self = Block16(1);
62
63    const EXTENSION_TAU: Self = Self::TAU;
64
65    fn invert(&self) -> Self {
66        let (l, h) = self.split();
67
68        // Norm = h^2 * tau + h*l + l^2
69        let h2 = h * h;
70        let l2 = l * l;
71        let hl = h * l;
72        let norm = (h2 * Block8::EXTENSION_TAU) + hl + l2;
73
74        let norm_inv = norm.invert();
75
76        // Res = (h*norm_inv) X + (h+l)*norm_inv
77        let res_hi = h * norm_inv;
78        let res_lo = (h + l) * norm_inv;
79
80        Self::new(res_lo, res_hi)
81    }
82
83    fn from_uniform_bytes(bytes: &[u8; 32]) -> Self {
84        let mut buf = [0u8; 2];
85        buf.copy_from_slice(&bytes[0..2]);
86
87        Self(u16::from_le_bytes(buf))
88    }
89}
90
91impl Add for Block16 {
92    type Output = Self;
93
94    fn add(self, rhs: Self) -> Self {
95        Self(self.0.bitxor(rhs.0))
96    }
97}
98
99impl Sub for Block16 {
100    type Output = Self;
101
102    fn sub(self, rhs: Self) -> Self {
103        Self(self.0.bitxor(rhs.0))
104    }
105}
106
107impl Mul for Block16 {
108    type Output = Self;
109
110    fn mul(self, rhs: Self) -> Self {
111        let (a0, a1) = self.split();
112        let (b0, b1) = rhs.split();
113
114        // Karatsuba
115        let v0 = a0 * b0;
116        let v1 = a1 * b1;
117        let v_sum = (a0 + a1) * (b0 + b1);
118
119        // Reconstruction with reduction X^2 = X + tau
120        // Hi
121        let c_hi = v0 + v_sum;
122
123        // Lo
124        let c_lo = v0 + (v1 * Block8::EXTENSION_TAU);
125
126        Self::new(c_lo, c_hi)
127    }
128}
129
130impl AddAssign for Block16 {
131    fn add_assign(&mut self, rhs: Self) {
132        self.0.bitxor_assign(rhs.0);
133    }
134}
135
136impl SubAssign for Block16 {
137    fn sub_assign(&mut self, rhs: Self) {
138        self.0.bitxor_assign(rhs.0);
139    }
140}
141
142impl MulAssign for Block16 {
143    fn mul_assign(&mut self, rhs: Self) {
144        *self = *self * rhs;
145    }
146}
147
148impl CanonicalSerialize for Block16 {
149    fn serialized_size(&self) -> usize {
150        2
151    }
152
153    fn serialize(&self, writer: &mut [u8]) -> Result<(), ()> {
154        if writer.len() < 2 {
155            return Err(());
156        }
157
158        writer[..2].copy_from_slice(&self.0.to_le_bytes());
159
160        Ok(())
161    }
162}
163
164impl CanonicalDeserialize for Block16 {
165    fn deserialize(bytes: &[u8]) -> Result<Self, ()> {
166        if bytes.len() < 2 {
167            return Err(());
168        }
169
170        let mut buf = [0u8; 2];
171        buf.copy_from_slice(&bytes[0..2]);
172
173        Ok(Self(u16::from_le_bytes(buf)))
174    }
175}
176
177impl From<u8> for Block16 {
178    fn from(val: u8) -> Self {
179        Self(val as u16)
180    }
181}
182
183impl From<u16> for Block16 {
184    #[inline]
185    fn from(val: u16) -> Self {
186        Self(val)
187    }
188}
189
190impl From<u32> for Block16 {
191    #[inline]
192    fn from(val: u32) -> Self {
193        Self(val as u16)
194    }
195}
196
197impl From<u64> for Block16 {
198    #[inline]
199    fn from(val: u64) -> Self {
200        Self(val as u16)
201    }
202}
203
204impl From<u128> for Block16 {
205    #[inline]
206    fn from(val: u128) -> Self {
207        Self(val as u16)
208    }
209}
210
211// ========================================
212// FIELD LIFTING
213// ========================================
214
215impl From<Bit> for Block16 {
216    #[inline(always)]
217    fn from(val: Bit) -> Self {
218        Self(val.0 as u16)
219    }
220}
221
222impl From<Block8> for Block16 {
223    #[inline(always)]
224    fn from(val: Block8) -> Self {
225        Self(val.0 as u16)
226    }
227}
228
229// ===================================
230// PACKED BLOCK 16 (Width = 8)
231// ===================================
232
/// Number of `Block16` lanes in one packed value:
/// 128 bits / 16 = 8 elements.
pub const PACKED_WIDTH_16: usize = 8;

/// Eight `Block16` lanes stored contiguously with 16-byte alignment
/// (128 bits in total).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[repr(C, align(16))]
pub struct PackedBlock16(pub [Block16; PACKED_WIDTH_16]);

impl PackedBlock16 {
    /// Returns a packed value with every lane set to `Block16::ZERO`.
    #[inline(always)]
    pub fn zero() -> Self {
        Self([Block16::ZERO; PACKED_WIDTH_16])
    }
}
246
247impl PackableField for Block16 {
248    type Packed = PackedBlock16;
249
250    const WIDTH: usize = PACKED_WIDTH_16;
251
252    #[inline(always)]
253    fn pack(chunk: &[Self]) -> Self::Packed {
254        assert!(
255            chunk.len() >= PACKED_WIDTH_16,
256            "PackableField::pack: input slice too short",
257        );
258
259        let mut arr = [Self::ZERO; PACKED_WIDTH_16];
260        arr.copy_from_slice(&chunk[..PACKED_WIDTH_16]);
261
262        PackedBlock16(arr)
263    }
264
265    #[inline(always)]
266    fn unpack(packed: Self::Packed, output: &mut [Self]) {
267        assert!(
268            output.len() >= PACKED_WIDTH_16,
269            "PackableField::unpack: output slice too short",
270        );
271
272        output[..PACKED_WIDTH_16].copy_from_slice(&packed.0);
273    }
274}
275
276impl Add for PackedBlock16 {
277    type Output = Self;
278
279    #[inline(always)]
280    fn add(self, rhs: Self) -> Self {
281        let mut res = [Block16::ZERO; PACKED_WIDTH_16];
282        for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
283            *out = *l + *r;
284        }
285
286        Self(res)
287    }
288}
289
290impl AddAssign for PackedBlock16 {
291    #[inline(always)]
292    fn add_assign(&mut self, rhs: Self) {
293        for (l, r) in self.0.iter_mut().zip(rhs.0.iter()) {
294            *l += *r;
295        }
296    }
297}
298
299impl Sub for PackedBlock16 {
300    type Output = Self;
301
302    #[inline(always)]
303    fn sub(self, rhs: Self) -> Self {
304        self.add(rhs)
305    }
306}
307
308impl SubAssign for PackedBlock16 {
309    #[inline(always)]
310    fn sub_assign(&mut self, rhs: Self) {
311        self.add_assign(rhs);
312    }
313}
314
315impl Mul for PackedBlock16 {
316    type Output = Self;
317
318    #[inline(always)]
319    fn mul(self, rhs: Self) -> Self {
320        #[cfg(target_arch = "aarch64")]
321        {
322            let mut res = [Block16::ZERO; PACKED_WIDTH_16];
323            for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
324                *out = mul_iso_16(*l, *r);
325            }
326
327            Self(res)
328        }
329
330        #[cfg(not(target_arch = "aarch64"))]
331        {
332            let mut res = [Block16::ZERO; PACKED_WIDTH_16];
333            for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
334                *out = *l * *r;
335            }
336
337            Self(res)
338        }
339    }
340}
341
342impl MulAssign for PackedBlock16 {
343    #[inline(always)]
344    fn mul_assign(&mut self, rhs: Self) {
345        *self = *self * rhs;
346    }
347}
348
349impl Mul<Block16> for PackedBlock16 {
350    type Output = Self;
351
352    #[inline(always)]
353    fn mul(self, rhs: Block16) -> Self {
354        let mut res = [Block16::ZERO; PACKED_WIDTH_16];
355        for (out, v) in res.iter_mut().zip(self.0.iter()) {
356            *out = *v * rhs;
357        }
358
359        Self(res)
360    }
361}
362
363// ===================================
364// Hardware Field
365// ===================================
366
367impl HardwareField for Block16 {
368    #[inline(always)]
369    fn to_hardware(self) -> Flat<Self> {
370        #[cfg(feature = "table-math")]
371        {
372            Flat::from_raw(apply_matrix_16(self, &constants::TOWER_TO_FLAT_16))
373        }
374
375        #[cfg(not(feature = "table-math"))]
376        {
377            Flat::from_raw(Block16(map_ct_16(self.0, &TOWER_TO_FLAT_BASIS_16.0)))
378        }
379    }
380
381    #[inline(always)]
382    fn from_hardware(value: Flat<Self>) -> Self {
383        let value = value.into_raw();
384
385        #[cfg(feature = "table-math")]
386        {
387            apply_matrix_16(value, &constants::FLAT_TO_TOWER_16)
388        }
389
390        #[cfg(not(feature = "table-math"))]
391        {
392            Block16(map_ct_16(value.0, &FLAT_TO_TOWER_BASIS_16.0))
393        }
394    }
395
396    #[inline(always)]
397    fn add_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
398        Flat::from_raw(lhs.into_raw() + rhs.into_raw())
399    }
400
401    #[inline(always)]
402    fn add_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
403        let lhs = lhs.into_raw();
404        let rhs = rhs.into_raw();
405
406        #[cfg(target_arch = "aarch64")]
407        {
408            PackedFlat::from_raw(neon::add_packed_16(lhs, rhs))
409        }
410
411        #[cfg(not(target_arch = "aarch64"))]
412        {
413            PackedFlat::from_raw(lhs + rhs)
414        }
415    }
416
417    #[inline(always)]
418    fn mul_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
419        let lhs = lhs.into_raw();
420        let rhs = rhs.into_raw();
421
422        #[cfg(target_arch = "aarch64")]
423        {
424            Flat::from_raw(neon::mul_flat_16(lhs, rhs))
425        }
426
427        #[cfg(not(target_arch = "aarch64"))]
428        {
429            let a_tower = Self::from_hardware(Flat::from_raw(lhs));
430            let b_tower = Self::from_hardware(Flat::from_raw(rhs));
431
432            (a_tower * b_tower).to_hardware()
433        }
434    }
435
436    #[inline(always)]
437    fn mul_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
438        let lhs = lhs.into_raw();
439        let rhs = rhs.into_raw();
440
441        #[cfg(target_arch = "aarch64")]
442        {
443            PackedFlat::from_raw(neon::mul_flat_packed_16(lhs, rhs))
444        }
445
446        #[cfg(not(target_arch = "aarch64"))]
447        {
448            let mut l = [Self::ZERO; <Self as PackableField>::WIDTH];
449            let mut r = [Self::ZERO; <Self as PackableField>::WIDTH];
450            let mut res = [Self::ZERO; <Self as PackableField>::WIDTH];
451
452            Self::unpack(lhs, &mut l);
453            Self::unpack(rhs, &mut r);
454
455            for i in 0..<Self as PackableField>::WIDTH {
456                res[i] = Self::mul_hardware(Flat::from_raw(l[i]), Flat::from_raw(r[i])).into_raw();
457            }
458
459            PackedFlat::from_raw(Self::pack(&res))
460        }
461    }
462
463    #[inline(always)]
464    fn mul_hardware_scalar_packed(lhs: PackedFlat<Self>, rhs: Flat<Self>) -> PackedFlat<Self> {
465        #[cfg(target_arch = "aarch64")]
466        {
467            PackedFlat::from_raw(neon::mul_flat_scalar_packed_16(
468                lhs.into_raw(),
469                rhs.into_raw(),
470            ))
471        }
472
473        #[cfg(not(target_arch = "aarch64"))]
474        {
475            let broadcasted = PackedBlock16([rhs.into_raw(); PACKED_WIDTH_16]);
476            Self::mul_hardware_packed(lhs, PackedFlat::from_raw(broadcasted))
477        }
478    }
479
480    #[inline(always)]
481    fn tower_bit_from_hardware(value: Flat<Self>, bit_idx: usize) -> u8 {
482        let mask = constants::FLAT_TO_TOWER_BIT_MASKS_16[bit_idx];
483
484        // Parity of (x & mask) without
485        // popcount. Folds 16 bits down
486        // to 1 using a binary XOR tree.
487        let mut v = value.into_raw().0 & mask;
488        v ^= v >> 8;
489        v ^= v >> 4;
490        v ^= v >> 2;
491        v ^= v >> 1;
492
493        (v & 1) as u8
494    }
495}
496
497impl FlatPromote<Block8> for Block16 {
498    #[inline(always)]
499    fn promote_flat(val: Flat<Block8>) -> Flat<Self> {
500        let val = val.into_raw();
501
502        #[cfg(not(feature = "table-math"))]
503        {
504            let mut acc = 0u16;
505            for i in 0..8 {
506                let bit = (val.0 >> i) & 1;
507                let mask = 0u16.wrapping_sub(bit as u16);
508                acc ^= constants::LIFT_BASIS_8_TO_16[i] & mask;
509            }
510
511            Flat::from_raw(Block16(acc))
512        }
513
514        #[cfg(feature = "table-math")]
515        {
516            Flat::from_raw(Block16(constants::LIFT_TABLE_8_TO_16[val.0 as usize]))
517        }
518    }
519}
520
521// ===================================
522// Binary Field Extras
523// ===================================
524
525impl BinaryFieldExtras for Block16 {
526    #[inline(always)]
527    fn square(&self) -> Self {
528        // char 2:
529        // (lo + hi·X)^2 = lo^2 + hi^2·X^2,
530        // no cross term.
531        let (lo, hi) = self.split();
532        let hi2 = hi.square();
533
534        Self::new(lo.square() + hi2 * Block8::EXTENSION_TAU, hi2)
535    }
536
537    #[inline(always)]
538    fn trace(&self) -> Bit {
539        Bit(((self.0 & constants::TRACE_MASK_16).count_ones() & 1) as u8)
540    }
541
542    #[inline(always)]
543    fn solve_quadratic(c: Self) -> Option<Self> {
544        match c.trace() {
545            Bit(0) => Some(Block16(map_ct_16(
546                c.0,
547                &constants::SOLVE_QUADRATIC_BASIS_16,
548            ))),
549            _ => None,
550        }
551    }
552}
553
554// ===========================================
555// UTILS
556// ===========================================
557
/// Multiplies two tower-basis elements through the flat basis: converts
/// both operands with `to_hardware`, multiplies them with the NEON flat
/// kernel, and converts the product back to the tower basis.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub fn mul_iso_16(a: Block16, b: Block16) -> Block16 {
    let lhs_flat = a.to_hardware().into_raw();
    let rhs_flat = b.to_hardware().into_raw();
    let product = neon::mul_flat_16(lhs_flat, rhs_flat);

    Flat::from_raw(product).to_tower()
}
567
/// Applies a 16x16 GF(2) linear map given as two 256-entry lookup tables.
///
/// `table[0..256]` is indexed by the low byte of `val` and
/// `table[256..512]` by the high byte; the two looked-up words are XORed
/// together.
#[cfg(feature = "table-math")]
#[inline(always)]
pub fn apply_matrix_16(val: Block16, table: &[u16; 512]) -> Block16 {
    let [lo, hi] = val.0.to_le_bytes();

    // Both indices are at most 511 (a `u8`, and 256 plus a `u8`), so safe
    // indexing into the fixed-size table needs no `unsafe`.
    Block16(table[usize::from(lo)] ^ table[256 + usize::from(hi)])
}
582
/// Applies a GF(2)-linear map to `x`: XORs together `basis[i]` for every
/// set bit `i` of `x`.
///
/// Entries are selected with an all-ones / all-zeros mask rather than a
/// branch, and all 16 entries are always read.
#[inline(always)]
fn map_ct_16(x: u16, basis: &[u16; 16]) -> u16 {
    basis.iter().enumerate().fold(0u16, |acc, (pos, &column)| {
        // 0xFFFF when bit `pos` of `x` is set, 0x0000 otherwise.
        let select = ((x >> pos) & 1).wrapping_neg();

        acc ^ (column & select)
    })
}
598
599// ===========================================
600// SIMD INSTRUCTIONS
601// ===========================================
602
/// aarch64 NEON kernels for flat-basis `Block16` arithmetic: packed
/// addition plus scalar, packed, and scalar-times-packed multiplication
/// with reduction by `x^16 + POLY_16`.
// NOTE(review): `vmull_p64` needs the PMULL (`aes`) target feature in
// addition to NEON — confirm the build enables it on every aarch64 target.
#[cfg(target_arch = "aarch64")]
mod neon {
    use super::*;
    use core::arch::aarch64::*;
    use core::mem::transmute;

    // Shifts 5,3,1,0 in reduce_packed_16 encode R = POLY_16 (0x2b)
    const _: () = assert!(constants::POLY_16 == 0x2b, "packed fold hardcodes R = 0x2b");

    /// Lane-wise addition of two packed values as one 128-bit XOR.
    #[inline(always)]
    pub fn add_packed_16(lhs: PackedBlock16, rhs: PackedBlock16) -> PackedBlock16 {
        // SAFETY: `[Block16; 8]` is eight `repr(transparent)` `u16`s, i.e.
        // 16 bytes, the same size as `uint8x16_t`; `PackedBlock16` wraps
        // exactly that array, and every bit pattern is valid on both sides.
        unsafe {
            let res = veorq_u8(
                transmute::<[Block16; 8], uint8x16_t>(lhs.0),
                transmute::<[Block16; 8], uint8x16_t>(rhs.0),
            );
            transmute(res)
        }
    }

    /// Multiplies two flat elements: one carry-less 16x16 multiply
    /// followed by two unconditional folds of the overflow by
    /// `R = POLY_16`.
    #[inline(always)]
    pub fn mul_flat_16(a: Block16, b: Block16) -> Block16 {
        // SAFETY: the `u128` returned by `vmull_p64` and `uint64x2_t` are
        // both 16 bytes with no invalid bit patterns.
        unsafe {
            // Note: Using 64-bit PMULL for 16-bit blocks
            // is optimal on Apple Silicon. The pipeline
            // parallelism of scalar `vmull_p64` outperforms
            // complex SIMD Karatsuba.
            let prod = vmull_p64(a.0 as u64, b.0 as u64);
            let prod_val = vgetq_lane_u64(transmute::<u128, uint64x2_t>(prod), 0);

            let l = (prod_val & 0xFFFF) as u16;
            let h = (prod_val >> 16) as u16; // The rest fits in u16 for 16x16

            // P(x) = x^16 + R
            let r_val = constants::POLY_16 as u64;

            // h * R
            let h_red = vmull_p64(h as u64, r_val);
            let h_red_val = vgetq_lane_u64(transmute::<u128, uint64x2_t>(h_red), 0);

            // Result of h*R fits in 32 bits max (16+16).
            // It's x^16 * H = H * R.
            // res = L ^ (H*R)
            // Since H*R > 16 bits, we have carry.

            let folded = (h_red_val & 0xFFFF) as u16;
            let carry = (h_red_val >> 16) as u16;

            let mut res = l ^ folded;

            // Unconditional reduction
            // ensures constant-time.
            let c_red = vmull_p64(carry as u64, r_val);
            let c_val = vgetq_lane_u64(transmute::<u128, uint64x2_t>(c_red), 0);

            res ^= c_val as u16;

            Block16(res)
        }
    }

    /// 8-wide GF(2^16) flat multiply, bit-identical
    /// to eight scalar mul_flat_16 calls.
    ///
    /// Each lane is split into bytes and multiplied with three 8x8
    /// carry-less products (low*low, high*high, and the product of the
    /// XORed halves), which `reduce_packed_16` recombines and reduces.
    #[inline(always)]
    pub fn mul_flat_packed_16(lhs: PackedBlock16, rhs: PackedBlock16) -> PackedBlock16 {
        // SAFETY: every transmute is between 16-byte (or 8-byte) plain-data
        // vector/array types of equal size with no invalid bit patterns.
        unsafe {
            let a = transmute::<[Block16; 8], uint16x8_t>(lhs.0);
            let b = transmute::<[Block16; 8], uint16x8_t>(rhs.0);

            // Low and high byte of every 16-bit lane.
            let a_lo = vmovn_u16(a);
            let a_hi = vmovn_u16(vshrq_n_u16(a, 8));
            let b_lo = vmovn_u16(b);
            let b_hi = vmovn_u16(vshrq_n_u16(b, 8));

            let ll = transmute::<poly16x8_t, uint16x8_t>(vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(a_lo),
                transmute::<uint8x8_t, poly8x8_t>(b_lo),
            ));

            let hh = transmute::<poly16x8_t, uint16x8_t>(vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(a_hi),
                transmute::<uint8x8_t, poly8x8_t>(b_hi),
            ));

            let mm = transmute::<poly16x8_t, uint16x8_t>(vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(veor_u8(a_lo, a_hi)),
                transmute::<uint8x8_t, poly8x8_t>(veor_u8(b_lo, b_hi)),
            ));

            PackedBlock16(transmute::<uint16x8_t, [Block16; 8]>(reduce_packed_16(
                ll, mm, hh,
            )))
        }
    }

    /// Hoists the scalar twiddle's lane-uniform byte split
    /// out of the eight lanes; otherwise as mul_flat_packed_16.
    #[inline(always)]
    pub fn mul_flat_scalar_packed_16(lhs: PackedBlock16, scalar: Block16) -> PackedBlock16 {
        // SAFETY: every transmute is between 16-byte (or 8-byte) plain-data
        // vector/array types of equal size with no invalid bit patterns.
        unsafe {
            let a = transmute::<[Block16; 8], uint16x8_t>(lhs.0);

            let s_lo = (scalar.0 & 0xff) as u8;
            let s_hi = (scalar.0 >> 8) as u8;

            // Scalar bytes broadcast once to all eight lanes.
            let b_lo = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(s_lo));
            let b_hi = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(s_hi));
            let b_mid = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(s_lo ^ s_hi));

            let a_lo = vmovn_u16(a);
            let a_hi = vmovn_u16(vshrq_n_u16(a, 8));

            let ll = transmute::<poly16x8_t, uint16x8_t>(vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(a_lo),
                b_lo,
            ));

            let hh = transmute::<poly16x8_t, uint16x8_t>(vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(a_hi),
                b_hi,
            ));

            let mm = transmute::<poly16x8_t, uint16x8_t>(vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(veor_u8(a_lo, a_hi)),
                b_mid,
            ));

            PackedBlock16(transmute::<uint16x8_t, [Block16; 8]>(reduce_packed_16(
                ll, mm, hh,
            )))
        }
    }

    // Bit-identical to the scalar mul_flat_16 fold.
    //
    // Inputs are the three Karatsuba partial products per lane:
    // `ll` (low*low), `hh` (high*high) and `mm` (product of XORed halves).
    #[inline(always)]
    fn reduce_packed_16(ll: uint16x8_t, mm: uint16x8_t, hh: uint16x8_t) -> uint16x8_t {
        unsafe {
            // Karatsuba middle term, then the low (`l`) and high (`h`)
            // 16-bit words of the 32-bit product ll + mid*x^8 + hh*x^16.
            let mid = veorq_u16(veorq_u16(mm, ll), hh);
            let l = veorq_u16(ll, vshlq_n_u16(mid, 8));
            let h = veorq_u16(hh, vshrq_n_u16(mid, 8));

            // Low 16 bits of h * R, with R = x^5 + x^3 + x + 1 (0x2b).
            let h_fold = veorq_u16(
                veorq_u16(vshlq_n_u16(h, 5), vshlq_n_u16(h, 3)),
                veorq_u16(vshlq_n_u16(h, 1), h),
            );

            // h <= deg 14, so h*R spills <= 4 bits past bit 15.
            let carry = veorq_u16(
                veorq_u16(vshrq_n_u16(h, 11), vshrq_n_u16(h, 13)),
                vshrq_n_u16(h, 15),
            );

            // carry * R; fits in 16 bits, so no further fold is needed.
            let carry_fold = veorq_u16(
                veorq_u16(vshlq_n_u16(carry, 5), vshlq_n_u16(carry, 3)),
                veorq_u16(vshlq_n_u16(carry, 1), carry),
            );

            veorq_u16(veorq_u16(l, h_fold), carry_fold)
        }
    }
}
764
/// Unit and property tests for `Block16`: tower arithmetic, flat-basis
/// conversion, packed operations and (on aarch64) NEON kernel agreement.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{RngExt, rng};

    #[cfg(target_arch = "aarch64")]
    use proptest::prelude::*;

    // ==================================
    // BASIC
    // ==================================

    #[test]
    fn tower_constants() {
        // Check that tau is propagated correctly.
        // For Block16, tau must split into
        // (lo, hi) = (0, 0x20).
        let tau16 = Block16::EXTENSION_TAU;
        let (lo16, hi16) = tau16.split();
        assert_eq!(lo16, Block8::ZERO);
        assert_eq!(hi16, Block8(0x20));
    }

    #[test]
    fn add_truth() {
        let zero = Block16::ZERO;
        let one = Block16::ONE;

        assert_eq!(zero + zero, zero);
        assert_eq!(zero + one, one);
        assert_eq!(one + zero, one);
        assert_eq!(one + one, zero);
    }

    #[test]
    fn mul_truth() {
        let zero = Block16::ZERO;
        let one = Block16::ONE;

        assert_eq!(zero * zero, zero);
        assert_eq!(zero * one, zero);
        assert_eq!(one * one, one);
    }

    #[test]
    fn add() {
        // 5 ^ 3 = 6
        // 101 ^ 011 = 110
        assert_eq!(Block16(5) + Block16(3), Block16(6));
    }

    #[test]
    fn mul_simple() {
        // Small operands (no reduction needed):
        // x^1 * x^1 = x^2 (2 * 2 = 4)
        assert_eq!(Block16(2) * Block16(2), Block16(4));
    }

    #[test]
    fn mul_overflow() {
        // Reduction verification (AES test vectors)
        // Example from the AES specification:
        // 0x57 * 0x83 = 0xC1
        assert_eq!(Block16(0x57) * Block16(0x83), Block16(0xC1));
    }

    #[test]
    fn karatsuba_correctness() {
        // Let A = X (hi=1, lo=0)
        // Let B = X (hi=1, lo=0)
        // A * B = X^2
        // According to the rule:
        // X^2 = X + tau
        // Where tau for Block8 = 0x20.
        // So the result should be:
        // hi=1 (X), lo=0x20 (tau)

        // Construct X manually
        let x = Block16::new(Block8::ZERO, Block8::ONE);
        let squared = x * x;

        // Verify result via splitting
        let (res_lo, res_hi) = squared.split();

        assert_eq!(res_hi, Block8::ONE, "X^2 should contain X component");
        assert_eq!(
            res_lo,
            Block8(0x20),
            "X^2 should contain tau component (0x20)"
        );
    }

    #[test]
    fn security_zeroize() {
        let mut secret_val = Block16::from(0xDEAD_u16);
        assert_ne!(secret_val, Block16::ZERO);

        secret_val.zeroize();

        assert_eq!(secret_val, Block16::ZERO);
        assert_eq!(secret_val.0, 0, "Block16 memory leak detected");
    }

    #[test]
    fn invert_zero() {
        // Critical safety check:
        // Inverting zero must return 0 by convention.
        assert_eq!(
            Block16::ZERO.invert(),
            Block16::ZERO,
            "invert(0) must return 0"
        );
    }

    #[test]
    fn inversion_random() {
        let mut rng = rng();

        // Test a significant number of random elements
        for _ in 0..1000 {
            let val_u16: u16 = rng.random();
            let val = Block16(val_u16);

            if val != Block16::ZERO {
                let inv = val.invert();
                let res = val * inv;

                assert_eq!(
                    res,
                    Block16::ONE,
                    "Inversion identity failed: a * a^-1 != 1"
                );
            }
        }
    }

    #[test]
    fn tower_embedding() {
        let mut rng = rng();
        for _ in 0..100 {
            let a = Block8(rng.random());
            let b = Block8(rng.random());

            // 1. Structure check:
            // Lifting puts the value in the low part
            // and zero in the high part: a subfield
            // element 'a' inside the extension must
            // look like (a, 0).
            let a_lifted: Block16 = a.into();
            let (lo, hi) = a_lifted.split();

            assert_eq!(lo, a, "Embedding structure failed: low part mismatch");
            assert_eq!(
                hi,
                Block8::ZERO,
                "Embedding structure failed: high part must be zero"
            );

            // 2. Addition Homomorphism:
            // lift(a + b) == lift(a) + lift(b)
            let sum_sub = a + b;
            let sum_lifted: Block16 = sum_sub.into();
            let sum_manual = Block16::from(a) + Block16::from(b);

            assert_eq!(sum_lifted, sum_manual, "Homomorphism failed: add");

            // 3. Multiplication Homomorphism:
            // lift(a * b) == lift(a) * lift(b)
            // Operations in the subfield must
            // match operations in the superfield.
            let prod_sub = a * b;
            let prod_lifted: Block16 = prod_sub.into();
            let prod_manual = Block16::from(a) * Block16::from(b);

            assert_eq!(prod_lifted, prod_manual, "Homomorphism failed: mul");
        }
    }

    // ==================================
    // HARDWARE
    // ==================================

    #[test]
    fn isomorphism_roundtrip() {
        // tower -> flat -> tower must be the identity.
        let mut rng = rng();
        for _ in 0..1000 {
            let val = Block16(rng.random::<u16>());
            assert_eq!(
                val.to_hardware().to_tower(),
                val,
                "Block16 isomorphism roundtrip failed"
            );
        }
    }

    #[test]
    fn flat_mul_homomorphism() {
        // Multiplying in the flat basis must agree with
        // multiplying in the tower basis and converting.
        let mut rng = rng();
        for _ in 0..1000 {
            let a = Block16(rng.random::<u16>());
            let b = Block16(rng.random::<u16>());

            let expected_flat = (a * b).to_hardware();
            let actual_flat = a.to_hardware() * b.to_hardware();

            assert_eq!(
                actual_flat, expected_flat,
                "Block16 flat multiplication mismatch"
            );
        }
    }

    #[test]
    fn packed_consistency() {
        let mut rng = rng();
        for _ in 0..100 {
            let mut a_vals = [Block16::ZERO; 8];
            let mut b_vals = [Block16::ZERO; 8];

            for i in 0..8 {
                a_vals[i] = Block16(rng.random::<u16>());
                b_vals[i] = Block16(rng.random::<u16>());
            }

            let a_flat_vals = a_vals.map(|x| x.to_hardware());
            let b_flat_vals = b_vals.map(|x| x.to_hardware());
            let a_packed = Flat::<Block16>::pack(&a_flat_vals);
            let b_packed = Flat::<Block16>::pack(&b_flat_vals);

            // Test SIMD Add
            let add_res = Block16::add_hardware_packed(a_packed, b_packed);

            let mut add_out = [Block16::ZERO.to_hardware(); 8];
            Flat::<Block16>::unpack(add_res, &mut add_out);

            for i in 0..8 {
                assert_eq!(
                    add_out[i],
                    (a_vals[i] + b_vals[i]).to_hardware(),
                    "Block16 packed add mismatch"
                );
            }

            // Test SIMD Mul
            let mul_res = Block16::mul_hardware_packed(a_packed, b_packed);

            let mut mul_out = [Block16::ZERO.to_hardware(); 8];
            Flat::<Block16>::unpack(mul_res, &mut mul_out);

            for i in 0..8 {
                assert_eq!(
                    mul_out[i],
                    (a_vals[i] * b_vals[i]).to_hardware(),
                    "Block16 packed mul mismatch"
                );
            }
        }
    }

    // ==================================
    // PACKED
    // ==================================

    #[test]
    fn pack_unpack_roundtrip() {
        let mut rng = rng();
        let mut data = [Block16::ZERO; PACKED_WIDTH_16];

        for v in data.iter_mut() {
            *v = Block16(rng.random());
        }

        let packed = Block16::pack(&data);
        let mut unpacked = [Block16::ZERO; PACKED_WIDTH_16];
        Block16::unpack(packed, &mut unpacked);

        assert_eq!(data, unpacked, "Block16 pack/unpack roundtrip failed");
    }

    #[test]
    fn packed_add_consistency() {
        let mut rng = rng();
        let mut a_vals = [Block16::ZERO; PACKED_WIDTH_16];
        let mut b_vals = [Block16::ZERO; PACKED_WIDTH_16];

        for i in 0..PACKED_WIDTH_16 {
            a_vals[i] = Block16(rng.random());
            b_vals[i] = Block16(rng.random());
        }

        let res_packed = Block16::pack(&a_vals) + Block16::pack(&b_vals);
        let mut res_unpacked = [Block16::ZERO; PACKED_WIDTH_16];
        Block16::unpack(res_packed, &mut res_unpacked);

        for i in 0..PACKED_WIDTH_16 {
            assert_eq!(
                res_unpacked[i],
                a_vals[i] + b_vals[i],
                "Block16 packed add mismatch"
            );
        }
    }

    #[test]
    fn packed_mul_consistency() {
        // Packed tower multiply must match the
        // scalar tower multiply lane by lane.
        let mut rng = rng();

        for _ in 0..1000 {
            let mut a_arr = [Block16::ZERO; PACKED_WIDTH_16];
            let mut b_arr = [Block16::ZERO; PACKED_WIDTH_16];

            for i in 0..PACKED_WIDTH_16 {
                let val_a_u16: u16 = rng.random();
                let val_b_u16: u16 = rng.random();

                a_arr[i] = Block16(val_a_u16);
                b_arr[i] = Block16(val_b_u16);
            }

            let a_packed = PackedBlock16(a_arr);
            let b_packed = PackedBlock16(b_arr);
            let c_packed = a_packed * b_packed;

            let mut c_expected = [Block16::ZERO; PACKED_WIDTH_16];
            for i in 0..PACKED_WIDTH_16 {
                c_expected[i] = a_arr[i] * b_arr[i];
            }

            assert_eq!(c_packed.0, c_expected, "SIMD Block16 mismatch!");
        }
    }

    #[test]
    fn parity_masks_match_from_hardware() {
        // Exhaustive for Block16:
        // 65536 values * 16 bits.
        for x_flat in 0u16..=u16::MAX {
            let tower = Block16::from_hardware(Flat::from_raw(Block16(x_flat))).0;

            for k in 0..16 {
                let bit = ((tower >> k) & 1) as u8;
                let via_api = Flat::from_raw(Block16(x_flat)).tower_bit(k);

                assert_eq!(
                    via_api, bit,
                    "Block16 tower_bit_from_hardware mismatch at x_flat={x_flat:#06x}, bit_idx={k}"
                );
            }
        }
    }

    // The NEON vmull_p8 path vs
    // the scalar vmull_p64 path.
    #[cfg(target_arch = "aarch64")]
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(65536))]

        #[test]
        fn neon_packed_eq_scalar(a in any::<[u16; 8]>(), b in any::<[u16; 8]>()) {
            let pp = neon::mul_flat_packed_16(
                PackedBlock16(a.map(Block16)),
                PackedBlock16(b.map(Block16)),
            );

            let want: [Block16; 8] =
                core::array::from_fn(|i| neon::mul_flat_16(Block16(a[i]), Block16(b[i])));

            prop_assert_eq!(pp.0, want);
        }

        #[test]
        fn neon_scalar_packed_eq_scalar(a in any::<[u16; 8]>(), s in any::<u16>()) {
            let sp = neon::mul_flat_scalar_packed_16(PackedBlock16(a.map(Block16)), Block16(s));

            let want: [Block16; 8] =
                core::array::from_fn(|i| neon::mul_flat_16(Block16(a[i]), Block16(s)));

            prop_assert_eq!(sp.0, want);
        }
    }
}
1143}