Skip to main content

hekate_math/towers/
block8.rs

1// SPDX-License-Identifier: Apache-2.0
2// This file is part of the hekate-math project.
3// Copyright (C) 2026 Andrei Kochergin <andrei@oumuamua.dev>
4// Copyright (C) 2026 Oumuamua Labs <info@oumuamua.dev>. All rights reserved.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10//     http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! BLOCK 8 (GF(2^8))
19use crate::constants::FLAT_TO_TOWER_BIT_MASKS_8;
20use crate::towers::bit::Bit;
21use crate::{
22    CanonicalDeserialize, CanonicalSerialize, Flat, HardwareField, PackableField, PackedFlat,
23    TowerField, constants,
24};
25use core::ops::{Add, AddAssign, BitXor, Mul, MulAssign, Sub, SubAssign};
26use serde::{Deserialize, Serialize};
27use zeroize::Zeroize;
28
29#[cfg(not(feature = "table-math"))]
30#[repr(align(64))]
31struct CtConvertBasisU8<const N: usize>([u8; N]);
32
33#[cfg(not(feature = "table-math"))]
34static TOWER_TO_FLAT_BASIS_8: CtConvertBasisU8<8> =
35    CtConvertBasisU8(constants::RAW_TOWER_TO_FLAT_8);
36
37#[cfg(not(feature = "table-math"))]
38static FLAT_TO_TOWER_BASIS_8: CtConvertBasisU8<8> =
39    CtConvertBasisU8(constants::RAW_FLAT_TO_TOWER_8);
40
// ------------------------------------------------------------
// Log/exp lookup tables for GF(2^8) (compiled only with `table-math`).
// Reduction polynomial: x^8 + x^4 + x^3 + x + 1 (0x11B), as in AES.
// Generator: g = 0x03, i.e. the polynomial x + 1.
// ------------------------------------------------------------

/// Powers of the generator: `EXP_TABLE[i] == g^i` for every `i` in `0..=255`.
///
/// The multiplicative group has order 255, so the last entry wraps around:
/// `EXP_TABLE[0] == EXP_TABLE[255] == 1`. That duplicate lets inversion
/// index `255 - log(x)` directly, without a modulo.
#[cfg(feature = "table-math")]
const EXP_TABLE: [u8; 256] = generate_exp_table();

/// Discrete logarithms base `g`: `LOG_TABLE[x] == i` where `g^i == x`.
///
/// Entries `1..=255` hold values in `0..=254`. `LOG_TABLE[0]` is a `0`
/// placeholder because `log(0)` is undefined; callers must handle zero first.
#[cfg(feature = "table-math")]
const LOG_TABLE: [u8; 256] = generate_log_table();
60
61/// Field element GF(2^8).
62#[derive(Copy, Clone, Default, Debug, Eq, PartialEq, Serialize, Deserialize, Zeroize)]
63#[repr(transparent)]
64pub struct Block8(pub u8);
65
66impl Block8 {
67    pub const fn new(val: u8) -> Self {
68        Self(val)
69    }
70
71    #[inline(always)]
72    pub fn square(self) -> Self {
73        // Carryless square (bit spread), then fold the
74        // high half twice by 0x1b (= x^4 + x^3 + x + 1).
75        let mut s = self.0 as u16;
76        s = (s | (s << 4)) & 0x0f0f;
77        s = (s | (s << 2)) & 0x3333;
78        s = (s | (s << 1)) & 0x5555;
79
80        let hi = s >> 8;
81        let s = (s & 0x00ff) ^ (hi ^ (hi << 1) ^ (hi << 3) ^ (hi << 4));
82
83        let hi = s >> 8;
84
85        Block8(((s & 0x00ff) ^ (hi ^ (hi << 1) ^ (hi << 3) ^ (hi << 4))) as u8)
86    }
87}
88
89impl TowerField for Block8 {
90    const BITS: usize = 8;
91    const ZERO: Self = Block8(0);
92    const ONE: Self = Block8(1);
93
94    const EXTENSION_TAU: Self = Block8(0x20);
95
96    fn invert(&self) -> Self {
97        #[cfg(feature = "table-math")]
98        {
99            if self.0 == 0 {
100                return Self::ZERO;
101            }
102
103            let i = LOG_TABLE[self.0 as usize] as usize;
104            Block8(EXP_TABLE[255 - i])
105        }
106
107        #[cfg(not(feature = "table-math"))]
108        {
109            // Fermat's Little Theorem:
110            // a^-1 = a^254 in GF(2^8)
111            // Constant-time, no branching.
112            let x = *self;
113            let x2 = x * x;
114            let x4 = x2 * x2;
115            let x8 = x4 * x4;
116            let x16 = x8 * x8;
117            let x32 = x16 * x16;
118            let x64 = x32 * x32;
119            let x128 = x64 * x64;
120
121            // 254 = 128 + 64 + 32 + 16 + 8 + 4 + 2
122            x128 * x64 * x32 * x16 * x8 * x4 * x2
123        }
124    }
125
126    fn from_uniform_bytes(bytes: &[u8; 32]) -> Self {
127        Self(bytes[0])
128    }
129}
130
131/// Add (XOR)
132impl Add for Block8 {
133    type Output = Self;
134
135    fn add(self, rhs: Self) -> Self::Output {
136        Self(self.0.bitxor(rhs.0))
137    }
138}
139
140impl Sub for Block8 {
141    type Output = Self;
142
143    fn sub(self, rhs: Self) -> Self::Output {
144        self.add(rhs)
145    }
146}
147
148/// Mul (Galois Field Multiplication)
149impl Mul for Block8 {
150    type Output = Self;
151
152    fn mul(self, rhs: Self) -> Self::Output {
153        #[cfg(feature = "table-math")]
154        {
155            // Handle zero explicitly (log(0) is undefined)
156            if self.0 == 0 || rhs.0 == 0 {
157                return Self::ZERO;
158            }
159
160            // Lookup Logarithms
161            // Math:
162            // a * b = g^(log(a) + log(b))
163            let i = LOG_TABLE[self.0 as usize] as usize;
164            let j = LOG_TABLE[rhs.0 as usize] as usize;
165
166            // Add exponents modulo 255
167            // Since max(i) = 254, max(i+j) = 508.
168            // Check if sum >= 255 and subtract.
169            let k = i + j;
170            let idx = if k >= 255 { k - 255 } else { k };
171
172            // Lookup Exponent result
173            Self(EXP_TABLE[idx])
174        }
175
176        #[cfg(not(feature = "table-math"))]
177        {
178            #[cfg(target_arch = "aarch64")]
179            {
180                neon::mul_8(self, rhs)
181            }
182
183            #[cfg(not(target_arch = "aarch64"))]
184            {
185                let mut a = self.0;
186                let mut b = rhs.0;
187                let mut res = 0u8;
188
189                // Constant-time shift-and-add
190                // over GF(2^8) with poly 0x11B.
191                for _ in 0..8 {
192                    let bit = b & 1;
193                    let mask = 0u8.wrapping_sub(bit);
194                    res ^= a & mask;
195
196                    let high_bit = a >> 7;
197                    let overflow_mask = 0u8.wrapping_sub(high_bit);
198                    a = (a << 1) ^ (0x1B & overflow_mask);
199
200                    b >>= 1;
201                }
202
203                Self(res)
204            }
205        }
206    }
207}
208
209impl AddAssign for Block8 {
210    fn add_assign(&mut self, rhs: Self) {
211        *self = *self + rhs;
212    }
213}
214
215impl SubAssign for Block8 {
216    fn sub_assign(&mut self, rhs: Self) {
217        *self = *self - rhs;
218    }
219}
220
221impl MulAssign for Block8 {
222    fn mul_assign(&mut self, rhs: Self) {
223        *self = *self * rhs;
224    }
225}
226
227impl CanonicalSerialize for Block8 {
228    #[inline]
229    fn serialized_size(&self) -> usize {
230        1
231    }
232
233    #[inline]
234    fn serialize(&self, writer: &mut [u8]) -> Result<(), ()> {
235        if writer.is_empty() {
236            return Err(());
237        }
238
239        writer[0] = self.0;
240
241        Ok(())
242    }
243}
244
245impl CanonicalDeserialize for Block8 {
246    fn deserialize(bytes: &[u8]) -> Result<Self, ()> {
247        if bytes.is_empty() {
248            return Err(());
249        }
250
251        Ok(Self(bytes[0]))
252    }
253}
254
255impl From<u8> for Block8 {
256    #[inline]
257    fn from(val: u8) -> Self {
258        Self::new(val)
259    }
260}
261
262impl From<u32> for Block8 {
263    #[inline]
264    fn from(val: u32) -> Self {
265        Self(val as u8)
266    }
267}
268
269impl From<u64> for Block8 {
270    #[inline]
271    fn from(val: u64) -> Self {
272        Self(val as u8)
273    }
274}
275
276impl From<u128> for Block8 {
277    #[inline]
278    fn from(val: u128) -> Self {
279        Self(val as u8)
280    }
281}
282
283// ========================================
284// FIELD LIFTING
285// ========================================
286
287impl From<Bit> for Block8 {
288    #[inline(always)]
289    fn from(val: Bit) -> Self {
290        Self(val.0)
291    }
292}
293
294// ===================================
295// PACKED BLOCK 8 (Width = 16)
296// ===================================
297
298// 128 bits / 8 = 16 elements
299pub const PACKED_WIDTH_8: usize = 16;
300
301#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
302#[repr(C, align(16))]
303pub struct PackedBlock8(pub [Block8; PACKED_WIDTH_8]);
304
305impl PackedBlock8 {
306    #[inline(always)]
307    pub fn zero() -> Self {
308        Self([Block8::ZERO; PACKED_WIDTH_8])
309    }
310}
311
312impl PackableField for Block8 {
313    type Packed = PackedBlock8;
314
315    const WIDTH: usize = PACKED_WIDTH_8;
316
317    #[inline(always)]
318    fn pack(chunk: &[Self]) -> Self::Packed {
319        assert!(
320            chunk.len() >= PACKED_WIDTH_8,
321            "PackableField::pack: input slice too short",
322        );
323
324        let mut arr = [Self::ZERO; PACKED_WIDTH_8];
325        arr.copy_from_slice(&chunk[..PACKED_WIDTH_8]);
326
327        PackedBlock8(arr)
328    }
329
330    #[inline(always)]
331    fn unpack(packed: Self::Packed, output: &mut [Self]) {
332        assert!(
333            output.len() >= PACKED_WIDTH_8,
334            "PackableField::unpack: output slice too short",
335        );
336
337        output[..PACKED_WIDTH_8].copy_from_slice(&packed.0);
338    }
339}
340
341impl Add for PackedBlock8 {
342    type Output = Self;
343
344    #[inline(always)]
345    fn add(self, rhs: Self) -> Self {
346        let mut res = [Block8::ZERO; PACKED_WIDTH_8];
347        for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
348            *out = *l + *r;
349        }
350
351        Self(res)
352    }
353}
354
355impl AddAssign for PackedBlock8 {
356    #[inline(always)]
357    fn add_assign(&mut self, rhs: Self) {
358        for (l, r) in self.0.iter_mut().zip(rhs.0.iter()) {
359            *l += *r;
360        }
361    }
362}
363
364impl Sub for PackedBlock8 {
365    type Output = Self;
366
367    #[inline(always)]
368    fn sub(self, rhs: Self) -> Self {
369        self.add(rhs)
370    }
371}
372
373impl SubAssign for PackedBlock8 {
374    #[inline(always)]
375    fn sub_assign(&mut self, rhs: Self) {
376        self.add_assign(rhs);
377    }
378}
379
380impl Mul for PackedBlock8 {
381    type Output = Self;
382
383    #[inline(always)]
384    fn mul(self, rhs: Self) -> Self {
385        #[cfg(target_arch = "aarch64")]
386        {
387            let mut res = [Block8::ZERO; PACKED_WIDTH_8];
388            for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
389                *out = mul_iso_8(*l, *r);
390            }
391
392            Self(res)
393        }
394
395        #[cfg(not(target_arch = "aarch64"))]
396        {
397            let mut res = [Block8::ZERO; PACKED_WIDTH_8];
398            for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
399                *out = *l * *r;
400            }
401
402            Self(res)
403        }
404    }
405}
406
407impl MulAssign for PackedBlock8 {
408    #[inline(always)]
409    fn mul_assign(&mut self, rhs: Self) {
410        *self = *self * rhs;
411    }
412}
413
414impl Mul<Block8> for PackedBlock8 {
415    type Output = Self;
416
417    #[inline(always)]
418    fn mul(self, rhs: Block8) -> Self {
419        let mut res = [Block8::ZERO; PACKED_WIDTH_8];
420        for (out, v) in res.iter_mut().zip(self.0.iter()) {
421            *out = *v * rhs;
422        }
423
424        Self(res)
425    }
426}
427
428// ===================================
429// Hardware Field
430// ===================================
431
432impl HardwareField for Block8 {
433    #[inline(always)]
434    fn to_hardware(self) -> Flat<Self> {
435        #[cfg(feature = "table-math")]
436        {
437            Flat::from_raw(apply_matrix_8(self, &constants::TOWER_TO_FLAT_8))
438        }
439
440        #[cfg(not(feature = "table-math"))]
441        {
442            Flat::from_raw(Block8(map_ct_8(self.0, &TOWER_TO_FLAT_BASIS_8.0)))
443        }
444    }
445
446    #[inline(always)]
447    fn from_hardware(value: Flat<Self>) -> Self {
448        let value = value.into_raw();
449        #[cfg(feature = "table-math")]
450        {
451            apply_matrix_8(value, &constants::FLAT_TO_TOWER_8)
452        }
453
454        #[cfg(not(feature = "table-math"))]
455        {
456            Block8(map_ct_8(value.0, &FLAT_TO_TOWER_BASIS_8.0))
457        }
458    }
459
460    #[inline(always)]
461    fn add_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
462        Flat::from_raw(lhs.into_raw() + rhs.into_raw())
463    }
464
465    #[inline(always)]
466    fn add_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
467        let lhs = lhs.into_raw();
468        let rhs = rhs.into_raw();
469        #[cfg(target_arch = "aarch64")]
470        {
471            PackedFlat::from_raw(neon::add_packed_8(lhs, rhs))
472        }
473
474        #[cfg(not(target_arch = "aarch64"))]
475        {
476            PackedFlat::from_raw(lhs + rhs)
477        }
478    }
479
480    #[inline(always)]
481    fn mul_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
482        let lhs = lhs.into_raw();
483        let rhs = rhs.into_raw();
484        #[cfg(target_arch = "aarch64")]
485        {
486            Flat::from_raw(neon::mul_8(lhs, rhs))
487        }
488
489        #[cfg(not(target_arch = "aarch64"))]
490        {
491            let a_tower = Self::from_hardware(Flat::from_raw(lhs));
492            let b_tower = Self::from_hardware(Flat::from_raw(rhs));
493
494            (a_tower * b_tower).to_hardware()
495        }
496    }
497
498    #[inline(always)]
499    fn mul_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
500        let lhs = lhs.into_raw();
501        let rhs = rhs.into_raw();
502
503        #[cfg(target_arch = "aarch64")]
504        {
505            PackedFlat::from_raw(neon::mul_flat_packed_8(lhs, rhs))
506        }
507
508        #[cfg(not(target_arch = "aarch64"))]
509        {
510            let mut l = [Self::ZERO; <Self as PackableField>::WIDTH];
511            let mut r = [Self::ZERO; <Self as PackableField>::WIDTH];
512            let mut res = [Self::ZERO; <Self as PackableField>::WIDTH];
513
514            Self::unpack(lhs, &mut l);
515            Self::unpack(rhs, &mut r);
516
517            for i in 0..<Self as PackableField>::WIDTH {
518                res[i] = Self::mul_hardware(Flat::from_raw(l[i]), Flat::from_raw(r[i])).into_raw();
519            }
520
521            PackedFlat::from_raw(Self::pack(&res))
522        }
523    }
524
525    #[inline(always)]
526    fn mul_hardware_scalar_packed(lhs: PackedFlat<Self>, rhs: Flat<Self>) -> PackedFlat<Self> {
527        let broadcasted = PackedBlock8([rhs.into_raw(); PACKED_WIDTH_8]);
528        Self::mul_hardware_packed(lhs, PackedFlat::from_raw(broadcasted))
529    }
530
531    #[inline(always)]
532    fn tower_bit_from_hardware(value: Flat<Self>, bit_idx: usize) -> u8 {
533        let mask = FLAT_TO_TOWER_BIT_MASKS_8[bit_idx];
534
535        // Parity of (x & mask) without popcount
536        let mut v = value.into_raw().0 & mask;
537        v ^= v >> 4;
538        v ^= v >> 2;
539        v ^= v >> 1;
540
541        v & 1
542    }
543}
544
// ===========================================
// UTILS
// ===========================================

/// Tower-basis multiplication routed through the flat basis. Both operands
/// are converted with `to_hardware`, multiplied with the NEON `mul_8`, and
/// the product is converted back with `to_tower`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mul_iso_8(a: Block8, b: Block8) -> Block8 {
    let product_flat = neon::mul_8(a.to_hardware().into_raw(), b.to_hardware().into_raw());
    Flat::from_raw(product_flat).to_tower()
}
558
/// Applies a basis change by table lookup: returns `table[val]`.
///
/// `table` is a 256-entry image table such as `constants::TOWER_TO_FLAT_8`,
/// so the whole linear map costs one load. Used only with `table-math`. The
/// load address depends on `val`, so this is not constant-time.
#[cfg(feature = "table-math")]
#[inline(always)]
fn apply_matrix_8(val: Block8, table: &[u8; 256]) -> Block8 {
    // A `u8` index is always < 256 == table.len(). The compiler can prove
    // this and drops the bounds check, so `unsafe { get_unchecked }` bought
    // nothing.
    Block8(table[usize::from(val.0)])
}
565
/// Applies the GF(2)-linear map with columns `basis` to `x`. The result is
/// the XOR of `basis[i]` over every set bit `i` of `x`.
///
/// Constant-time: all eight entries are read and combined through an
/// all-ones/all-zeros mask, with no branch on the bits of `x`.
#[cfg(not(feature = "table-math"))]
#[inline(always)]
fn map_ct_8(x: u8, basis: &[u8; 8]) -> u8 {
    basis.iter().enumerate().fold(0u8, |acc, (i, &column)| {
        // 0xFF when bit i of x is set, 0x00 otherwise.
        let select = 0u8.wrapping_sub((x >> i) & 1);
        acc ^ (column & select)
    })
}
581
/// Multiplies `val` by the generator `g = 3` (the polynomial `x + 1`) in
/// GF(2^8) mod 0x11B: `val * (x + 1) = xtime(val) ^ val`. Here `xtime` shifts
/// left by one and XORs in 0x1B when the top bit was shifted out.
///
/// Shared by both table generators so the generator step is defined once.
#[cfg(feature = "table-math")]
const fn mul_by_generator(val: u8) -> u8 {
    let mut shifted = val << 1;
    if val & 0x80 != 0 {
        // x^8 ≡ x^4 + x^3 + x + 1 (0x1B) modulo the AES polynomial.
        shifted ^= 0x1B;
    }
    shifted ^ val
}

/// Builds `EXP_TABLE`, where `table[i] = g^i` for `i` in `0..=255`.
///
/// The powers of `g` cycle with period 255, so `table[255] == table[0] == 1`.
/// That extra entry lets inversion read `EXP_TABLE[255 - log(x)]` safely.
#[cfg(feature = "table-math")]
const fn generate_exp_table() -> [u8; 256] {
    let mut table = [0u8; 256];
    let mut val: u8 = 1;

    let mut i = 0;
    while i < 256 {
        table[i] = val;
        val = mul_by_generator(val);
        i += 1;
    }

    table
}

/// Builds `LOG_TABLE`, where `table[g^i] = i` for `i` in `0..255`.
///
/// The loop stops before `i = 255`, because `g^255 == 1` would overwrite the
/// canonical `log(1) = 0`. `table[0]` stays `0` since `log(0)` is undefined;
/// callers special-case zero.
#[cfg(feature = "table-math")]
const fn generate_log_table() -> [u8; 256] {
    let mut table = [0u8; 256];
    let mut val: u8 = 1;

    let mut i = 0;
    while i < 255 {
        table[val as usize] = i as u8;
        val = mul_by_generator(val);
        i += 1;
    }

    table
}
651
// ===========================================
// 8-BIT SIMD INSTRUCTIONS
// ===========================================

/// AArch64 NEON kernels for GF(2^8) arithmetic on raw bytes.
///
/// Products use the carry-less `vmull_p8` (PMULL). They are reduced modulo
/// x^8 + low-byte: `mul_8` reads that byte from `constants::POLY_8`
/// (documented below as 0x1B), while `mul_flat_packed_8` hard-codes tables
/// built for 0x1B.
/// NOTE(review): the two must agree; confirm `POLY_8 == 0x1B`.
#[cfg(target_arch = "aarch64")]
mod neon {
    use super::*;
    use core::arch::aarch64::*;
    use core::mem::transmute;

    /// Lane-wise XOR of two packed values (packed field addition).
    #[inline(always)]
    pub fn add_packed_8(lhs: PackedBlock8, rhs: PackedBlock8) -> PackedBlock8 {
        // SAFETY: NEON is part of the baseline aarch64 target. `[Block8; 16]`
        // is 16 bytes because `Block8` is `repr(transparent)` over `u8`. That
        // matches the size of `uint8x16_t`, and every bit pattern is valid
        // for both types, so the transmutes are sound.
        unsafe {
            let res = veorq_u8(
                transmute::<[Block8; 16], uint8x16_t>(lhs.0),
                transmute::<[Block8; 16], uint8x16_t>(rhs.0),
            );
            transmute(res)
        }
    }

    /// Scalar GF(2^8) multiply using PMULL.
    ///
    /// The 8x8 carry-less product has degree <= 14. Its high byte `h`
    /// (degree <= 6) is folded as `h * POLY_8`, which has degree <= 10. The
    /// resulting carry byte (degree <= 2) is folded once more; that last
    /// product has degree <= 6 and fits in the low byte.
    #[inline(always)]
    pub fn mul_8(a: Block8, b: Block8) -> Block8 {
        // SAFETY: NEON is baseline on aarch64. All transmutes are between
        // same-sized vector types (uint8x8_t <-> poly8x8_t and
        // poly16x8_t <-> uint16x8_t), and every bit pattern is valid for each.
        unsafe {
            // Load 8-bit scalars
            // into NEON vectors.
            let a_poly = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(a.0));
            let b_poly = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(b.0));

            // Multiply:
            // 8-bit x 8-bit -> 16-bit
            let prod = vmull_p8(a_poly, b_poly);

            // Extract the 16-bit result
            let prod_u16 = vgetq_lane_u16(transmute::<poly16x8_t, uint16x8_t>(prod), 0);

            let l = (prod_u16 & 0xFF) as u8;
            let h = (prod_u16 >> 8) as u8;

            // P(x) = x^8 + 0x1B
            let r_val = constants::POLY_8; // u8

            // Fold high bits (h * 0x1B)
            let h_poly = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(h));
            let r_poly = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(r_val));
            let h_red = vmull_p8(h_poly, r_poly);

            let h_red_u16 = vgetq_lane_u16(transmute::<poly16x8_t, uint16x8_t>(h_red), 0);

            let folded = (h_red_u16 & 0xFF) as u8;
            let carry = (h_red_u16 >> 8) as u8;

            let mut res = l ^ folded;

            // Unconditional carry reduction:
            // If carry is 0, c_poly is 0,
            // c_red is 0, and XOR does nothing.
            let c_poly = transmute::<uint8x8_t, poly8x8_t>(vdup_n_u8(carry));
            let c_red = vmull_p8(c_poly, r_poly);
            let c_red_u16 = vgetq_lane_u16(transmute::<poly16x8_t, uint16x8_t>(c_red), 0);

            res ^= (c_red_u16 & 0xFF) as u8;

            Block8(res)
        }
    }

    /// Vectorized multiplication for Block8 (16 elements at once).
    /// Uses vmull_p8 for the carry-less products and vqtbl1_u8 (a 64-bit
    /// lookup into a 128-bit table) for the reduction.
    #[inline(always)]
    pub fn mul_flat_packed_8(lhs: PackedBlock8, rhs: PackedBlock8) -> PackedBlock8 {
        // SAFETY: NEON is baseline on aarch64. `[Block8; 16]` and `uint8x16_t`
        // are both 16 bytes (`Block8` is `repr(transparent)` over `u8`), and
        // all other transmutes are between same-sized vector types. Each
        // `vld1q_u8` reads exactly 16 bytes from a 16-element array temporary,
        // which lives until the end of its `let` statement.
        unsafe {
            let a: uint8x16_t = transmute(lhs.0);
            let b: uint8x16_t = transmute(rhs.0);

            // Split into low/high 64-bit halves
            let a_lo = vget_low_u8(a);
            let a_hi = vget_high_u8(a);
            let b_lo = vget_low_u8(b);
            let b_hi = vget_high_u8(b);

            // Multiply 8x8 -> 16 bits
            // (poly16x8_t, which is 128-bit wide)
            let res_lo = vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(a_lo),
                transmute::<uint8x8_t, poly8x8_t>(b_lo),
            );
            let res_hi = vmull_p8(
                transmute::<uint8x8_t, poly8x8_t>(a_hi),
                transmute::<uint8x8_t, poly8x8_t>(b_hi),
            );

            // Reduction using Table Lookup
            // Load the tables once.
            // tbl_lo[i] = (i * x^8) mod P, i.e. the carry-less product i * 0x1B.
            let tbl_lo = vld1q_u8(
                [
                    0x00, 0x1b, 0x36, 0x2d, 0x6c, 0x77, 0x5a, 0x41, 0xd8, 0xc3, 0xee, 0xf5, 0xb4,
                    0xaf, 0x82, 0x99,
                ]
                .as_ptr(),
            );

            // tbl_hi[i] = (i * x^12) mod P, the reduced value of the high
            // nibble of the carry byte.
            let tbl_hi = vld1q_u8(
                [
                    0x00, 0xab, 0x4d, 0xe6, 0x9a, 0x31, 0xd7, 0x7c, 0x2f, 0x84, 0x62, 0xc9, 0xb5,
                    0x1e, 0xf8, 0x53,
                ]
                .as_ptr(),
            );

            // Helper to reduce a 128-bit vector
            // of 16-bit polys down to a 64-bit
            // vector of 8-bit results.
            // Products have degree <= 14, so each carry byte is <= 0x7F and
            // both nibble indices stay inside the 16-entry tables. The table
            // values are already reduced, so one XOR pass completes the
            // reduction.
            let reduce_tbl = |val_poly: poly16x8_t| -> uint8x8_t {
                let val: uint16x8_t = transmute(val_poly);

                // vmovn_u16 narrows 128-bit (u16x8) to 64-bit (u8x8)
                let data = vmovn_u16(val);
                let carry_u16 = vshrq_n_u16(val, 8);
                let carry = vmovn_u16(carry_u16);

                // Operations on 64-bit vectors
                let mask_lo = vdup_n_u8(0x0F);
                let h_lo = vand_u8(carry, mask_lo);
                let h_hi = vshr_n_u8(carry, 4);

                // Lookup:
                // Table is 128-bit (q),
                // Index is 64-bit.
                // Result is 64-bit.
                let r_lo = vqtbl1_u8(tbl_lo, h_lo);
                let r_hi = vqtbl1_u8(tbl_hi, h_hi);

                // XOR everything together
                veor_u8(data, veor_u8(r_lo, r_hi))
            };

            let final_lo = reduce_tbl(res_lo);
            let final_hi = reduce_tbl(res_hi);

            // Combine two 64-bit results
            // back into one 128-bit vector.
            let res = vcombine_u8(final_lo, final_hi);

            PackedBlock8(transmute::<uint8x16_t, [Block8; 16]>(res))
        }
    }
}
800
// Unit tests for Block8: scalar field axioms, the tower <-> flat isomorphism,
// and agreement between the packed/SIMD paths and scalar arithmetic.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{RngExt, rng};

    // ==================================
    // BASIC
    // ==================================

    #[test]
    fn tower_constants() {
        // Check that tau is propagated correctly
        // For Block8 we set 0x20
        assert_eq!(Block8::EXTENSION_TAU, Block8(0x20));
    }

    #[test]
    fn add_truth() {
        // GF(2) truth table for addition (XOR) on the identity elements.
        let zero = Block8::ZERO;
        let one = Block8::ONE;

        assert_eq!(zero + zero, zero);
        assert_eq!(zero + one, one);
        assert_eq!(one + zero, one);
        assert_eq!(one + one, zero);
    }

    #[test]
    fn mul_truth() {
        // Multiplicative behaviour of zero and one.
        let zero = Block8::ZERO;
        let one = Block8::ONE;

        assert_eq!(zero * zero, zero);
        assert_eq!(zero * one, zero);
        assert_eq!(one * one, one);
    }

    #[test]
    fn add() {
        // 5 ^ 3 = 6
        // 101 ^ 011 = 110
        assert_eq!(Block8(5) + Block8(3), Block8(6));
    }

    #[test]
    fn mul_simple() {
        // Low-degree product that needs no modular reduction:
        // x^1 * x^1 = x^2 (2 * 2 = 4)
        assert_eq!(Block8(2) * Block8(2), Block8(4));
    }

    #[test]
    fn mul_overflow() {
        // Reduction verification (AES test vectors)
        // Example from the AES specification:
        // 0x57 * 0x83 = 0xC1
        assert_eq!(Block8(0x57) * Block8(0x83), Block8(0xC1));
    }

    #[test]
    fn square_exhaustive() {
        // The bit-spread `square` must agree with general multiplication on
        // all 256 elements (a u16 counter avoids u8 overflow at 255).
        for i in 0u16..=255 {
            let x = Block8(i as u8);
            assert_eq!(x.square(), x * x, "Block8 square mismatch at {i:#04x}");
        }
    }

    #[test]
    fn security_zeroize() {
        // `zeroize` must leave the inner byte at 0.
        let mut secret_val = Block8::from(0xFF_u32);
        assert_ne!(secret_val, Block8::ZERO);

        secret_val.zeroize();

        assert_eq!(secret_val, Block8::ZERO);
        assert_eq!(secret_val.0, 0, "Block8 memory leak detected");
    }

    #[test]
    fn inversion_exhaustive() {
        // Iterate over all possible field elements (0..=255)
        for i in 0u8..=255 {
            let val = Block8(i);

            if val == Block8::ZERO {
                // Case 1:
                // Zero inversion safety check
                assert_eq!(val.invert(), Block8::ZERO, "invert(0) must return 0");
            } else {
                // Case 2:
                // Algebraic correctness a * a^-1 = 1
                let inv = val.invert();
                let product = val * inv;

                assert_eq!(
                    product,
                    Block8::ONE,
                    "Inversion identity failed: a * a^-1 != 1"
                );
            }
        }
    }

    // ==================================
    // HARDWARE
    // ==================================

    #[test]
    fn isomorphism_roundtrip() {
        let mut rng = rng();
        for _ in 0..1000 {
            let val = Block8::from(rng.random::<u8>());

            // Roundtrip:
            // Tower -> Flat -> Tower must be identity
            assert_eq!(
                val.to_hardware().to_tower(),
                val,
                "Block8 isomorphism roundtrip failed"
            );
        }
    }

    #[test]
    fn parity_masks_match_from_hardware() {
        // Exhaustive for Block8:
        // 256 values * every mask in FLAT_TO_TOWER_BIT_MASKS_8.
        // Tower bit k of from_hardware(x) must equal parity(x & mask_k), and
        // the `tower_bit` API must report the same bit.
        for x in 0u16..=255 {
            let x_flat = x as u8;
            let tower = Block8::from_hardware(Flat::from_raw(Block8(x_flat))).0;

            for (k, &mask) in FLAT_TO_TOWER_BIT_MASKS_8.iter().enumerate() {
                let parity = ((x_flat & mask).count_ones() & 1) as u8;
                let bit = (tower >> k) & 1;
                assert_eq!(
                    parity, bit,
                    "Block8 mask mismatch at x={x_flat:#04x}, k={k}"
                );

                let via_api = Flat::from_raw(Block8(x_flat)).tower_bit(k);
                assert_eq!(via_api, bit, "Block8 tower_bit_from_hardware mismatch");
            }
        }
    }

    #[test]
    fn flat_mul_homomorphism() {
        // to_hardware must commute with multiplication:
        // to_flat(a * b) == to_flat(a) * to_flat(b).
        let mut rng = rng();
        for _ in 0..1000 {
            let a = Block8::from(rng.random::<u8>());
            let b = Block8::from(rng.random::<u8>());

            let expected_flat = (a * b).to_hardware();
            let actual_flat = a.to_hardware() * b.to_hardware();

            // Check if multiplication in Flat basis matches Tower
            assert_eq!(
                actual_flat, expected_flat,
                "Block8 flat multiplication mismatch"
            );
        }
    }

    #[test]
    fn packed_consistency() {
        // Packed flat-basis add/mul must match scalar tower arithmetic mapped
        // into the flat basis, lane by lane.
        let mut rng = rng();
        for _ in 0..100 {
            let mut a_vals = [Block8::ZERO; 16];
            let mut b_vals = [Block8::ZERO; 16];

            for i in 0..16 {
                a_vals[i] = Block8::from(rng.random::<u8>());
                b_vals[i] = Block8::from(rng.random::<u8>());
            }

            let a_flat_vals = a_vals.map(|x| x.to_hardware());
            let b_flat_vals = b_vals.map(|x| x.to_hardware());
            let a_packed = Flat::<Block8>::pack(&a_flat_vals);
            let b_packed = Flat::<Block8>::pack(&b_flat_vals);

            // Test SIMD Add (XOR)
            let add_res = Block8::add_hardware_packed(a_packed, b_packed);

            let mut add_out = [Block8::ZERO.to_hardware(); 16];
            Flat::<Block8>::unpack(add_res, &mut add_out);

            for i in 0..16 {
                assert_eq!(
                    add_out[i],
                    (a_vals[i] + b_vals[i]).to_hardware(),
                    "Block8 packed add mismatch"
                );
            }

            // Test SIMD Mul (Flat basis)
            let mul_res = Block8::mul_hardware_packed(a_packed, b_packed);

            let mut mul_out = [Block8::ZERO.to_hardware(); 16];
            Flat::<Block8>::unpack(mul_res, &mut mul_out);

            for i in 0..16 {
                assert_eq!(
                    mul_out[i],
                    (a_vals[i] * b_vals[i]).to_hardware(),
                    "Block8 packed mul mismatch"
                );
            }
        }
    }

    // ==================================
    // PACKED
    // ==================================

    #[test]
    fn pack_unpack_roundtrip() {
        // pack followed by unpack must reproduce the original lanes.
        let mut rng = rng();
        let mut data = [Block8::ZERO; PACKED_WIDTH_8];

        for v in data.iter_mut() {
            *v = Block8(rng.random());
        }

        let packed = Block8::pack(&data);
        let mut unpacked = [Block8::ZERO; PACKED_WIDTH_8];
        Block8::unpack(packed, &mut unpacked);

        assert_eq!(data, unpacked, "Block8 pack/unpack roundtrip failed");
    }

    #[test]
    fn packed_add_consistency() {
        // Tower-basis packed addition must equal scalar addition per lane.
        let mut rng = rng();
        let mut a_vals = [Block8::ZERO; PACKED_WIDTH_8];
        let mut b_vals = [Block8::ZERO; PACKED_WIDTH_8];

        for i in 0..PACKED_WIDTH_8 {
            a_vals[i] = Block8(rng.random());
            b_vals[i] = Block8(rng.random());
        }

        let a_packed = Block8::pack(&a_vals);
        let b_packed = Block8::pack(&b_vals);
        let res_packed = a_packed + b_packed;

        let mut res_unpacked = [Block8::ZERO; PACKED_WIDTH_8];
        Block8::unpack(res_packed, &mut res_unpacked);

        for i in 0..PACKED_WIDTH_8 {
            assert_eq!(
                res_unpacked[i],
                a_vals[i] + b_vals[i],
                "Block8 packed add mismatch at index {}",
                i
            );
        }
    }

    #[test]
    fn packed_mul_consistency() {
        // Tower-basis packed multiplication (the `mul_iso_8` path on aarch64)
        // must equal scalar multiplication per lane.
        let mut rng = rng();

        for _ in 0..1000 {
            let mut a_arr = [Block8::ZERO; PACKED_WIDTH_8];
            let mut b_arr = [Block8::ZERO; PACKED_WIDTH_8];

            for i in 0..PACKED_WIDTH_8 {
                let val_a: u8 = rng.random();
                let val_b: u8 = rng.random();
                a_arr[i] = Block8(val_a);
                b_arr[i] = Block8(val_b);
            }

            let a_packed = PackedBlock8(a_arr);
            let b_packed = PackedBlock8(b_arr);
            let c_packed = a_packed * b_packed;

            let mut c_expected = [Block8::ZERO; PACKED_WIDTH_8];
            for i in 0..PACKED_WIDTH_8 {
                c_expected[i] = a_arr[i] * b_arr[i];
            }

            assert_eq!(c_packed.0, c_expected, "SIMD Block8 mismatch!");
        }
    }
}