// hekate_math/towers/block64.rs
1// SPDX-License-Identifier: Apache-2.0
2// This file is part of the hekate-math project.
3// Copyright (C) 2026 Andrei Kochergin <zeek@tuta.com>
4// Copyright (C) 2026 Oumuamua Labs. All rights reserved.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10//     http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! BLOCK 64 (GF(2^64))
19use crate::constants::FLAT_TO_TOWER_BIT_MASKS_64;
20use crate::towers::bit::Bit;
21use crate::towers::block8::Block8;
22use crate::towers::block16::Block16;
23use crate::towers::block32::Block32;
24use crate::{
25    CanonicalDeserialize, CanonicalSerialize, Flat, FlatPromote, HardwareField, PackableField,
26    PackedFlat, TowerField, constants,
27};
28use core::ops::{Add, AddAssign, BitXor, BitXorAssign, Mul, MulAssign, Sub, SubAssign};
29use serde::{Deserialize, Serialize};
30use zeroize::Zeroize;
31
// Cache-line-aligned wrapper for the 64-row basis-change tables used by the
// constant-time (non-table) conversion path.
// NOTE(review): align(64) presumably pins each table to cache-line boundaries
// so the branch-free scan in `map_ct_64` has a predictable access pattern —
// confirm against the project's constant-time policy.
#[cfg(not(feature = "table-math"))]
#[repr(align(64))]
struct CtConvertBasisU64<const N: usize>([u64; N]);

// Column basis for tower -> flat (hardware) basis conversion.
#[cfg(not(feature = "table-math"))]
static TOWER_TO_FLAT_BASIS_64: CtConvertBasisU64<64> =
    CtConvertBasisU64(constants::RAW_TOWER_TO_FLAT_64);

// Column basis for flat (hardware) -> tower basis conversion.
#[cfg(not(feature = "table-math"))]
static FLAT_TO_TOWER_BASIS_64: CtConvertBasisU64<64> =
    CtConvertBasisU64(constants::RAW_FLAT_TO_TOWER_64);
43
/// Element of the binary tower field GF(2^64), stored as a raw `u64`
/// in the tower basis. `repr(transparent)` guarantees it has exactly
/// the layout of `u64` (relied on by the NEON transmutes below).
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq, Serialize, Deserialize, Zeroize)]
#[repr(transparent)]
pub struct Block64(pub u64);
47
48impl Block64 {
49    // 0x2000_0000 << 32 = 0x2000_0000_0000_0000
50    pub const TAU: Self = Block64(0x2000_0000_0000_0000);
51
52    pub fn new(lo: Block32, hi: Block32) -> Self {
53        Self((hi.0 as u64) << 32 | (lo.0 as u64))
54    }
55
56    #[inline(always)]
57    pub fn split(self) -> (Block32, Block32) {
58        (Block32(self.0 as u32), Block32((self.0 >> 32) as u32))
59    }
60}
61
62impl TowerField for Block64 {
63    const BITS: usize = 64;
64    const ZERO: Self = Block64(0);
65    const ONE: Self = Block64(1);
66
67    const EXTENSION_TAU: Self = Self::TAU;
68
69    fn invert(&self) -> Self {
70        let (l, h) = self.split();
71        let h2 = h * h;
72        let l2 = l * l;
73        let hl = h * l;
74        let norm = (h2 * Block32::TAU) + hl + l2;
75
76        let norm_inv = norm.invert();
77        let res_hi = h * norm_inv;
78        let res_lo = (h + l) * norm_inv;
79
80        Self::new(res_lo, res_hi)
81    }
82
83    fn from_uniform_bytes(bytes: &[u8; 32]) -> Self {
84        let mut buf = [0u8; 8];
85        buf.copy_from_slice(&bytes[0..8]);
86
87        Self(u64::from_le_bytes(buf))
88    }
89}
90
91impl Add for Block64 {
92    type Output = Self;
93
94    fn add(self, rhs: Self) -> Self {
95        Self(self.0.bitxor(rhs.0))
96    }
97}
98
99impl Sub for Block64 {
100    type Output = Self;
101
102    fn sub(self, rhs: Self) -> Self {
103        self.add(rhs)
104    }
105}
106
107impl Mul for Block64 {
108    type Output = Self;
109
110    fn mul(self, rhs: Self) -> Self {
111        let (a0, a1) = self.split();
112        let (b0, b1) = rhs.split();
113
114        let v0 = a0 * b0;
115        let v1 = a1 * b1;
116        let v_sum = (a0 + a1) * (b0 + b1);
117
118        let c_hi = v0 + v_sum;
119        let c_lo = v0 + (v1 * Block32::TAU);
120
121        Self::new(c_lo, c_hi)
122    }
123}
124
125impl AddAssign for Block64 {
126    fn add_assign(&mut self, rhs: Self) {
127        self.0.bitxor_assign(rhs.0);
128    }
129}
130
131impl SubAssign for Block64 {
132    fn sub_assign(&mut self, rhs: Self) {
133        self.0.bitxor_assign(rhs.0);
134    }
135}
136
137impl MulAssign for Block64 {
138    fn mul_assign(&mut self, rhs: Self) {
139        *self = *self * rhs;
140    }
141}
142
143impl CanonicalSerialize for Block64 {
144    fn serialized_size(&self) -> usize {
145        8
146    }
147
148    fn serialize(&self, writer: &mut [u8]) -> Result<(), ()> {
149        if writer.len() < 8 {
150            return Err(());
151        }
152
153        writer.copy_from_slice(&self.0.to_le_bytes());
154
155        Ok(())
156    }
157}
158
159impl CanonicalDeserialize for Block64 {
160    fn deserialize(bytes: &[u8]) -> Result<Self, ()> {
161        if bytes.len() < 8 {
162            return Err(());
163        }
164
165        let mut buf = [0u8; 8];
166        buf.copy_from_slice(&bytes[0..8]);
167
168        Ok(Self(u64::from_le_bytes(buf)))
169    }
170}
171
172impl From<u8> for Block64 {
173    #[inline(always)]
174    fn from(val: u8) -> Self {
175        Self(val as u64)
176    }
177}
178
179impl From<u32> for Block64 {
180    #[inline(always)]
181    fn from(val: u32) -> Self {
182        Self::from(val as u64)
183    }
184}
185
186impl From<u64> for Block64 {
187    #[inline(always)]
188    fn from(val: u64) -> Self {
189        Self(val)
190    }
191}
192
193impl From<u128> for Block64 {
194    #[inline(always)]
195    fn from(val: u128) -> Self {
196        Self(val as u64)
197    }
198}
199
200// ========================================
201// FIELD LIFTING
202// ========================================
203
impl From<Bit> for Block64 {
    /// Embeds GF(2) into GF(2^64) by zero-extending the raw bit value.
    #[inline(always)]
    fn from(val: Bit) -> Self {
        Self(val.0 as u64)
    }
}
210
impl From<Block8> for Block64 {
    /// Embeds GF(2^8) into GF(2^64) by zero-extending the tower-basis bits.
    /// (The `tower_embedding` test below checks this embedding is a field
    /// homomorphism for the Block32 level; Block8 follows the same scheme.)
    #[inline(always)]
    fn from(val: Block8) -> Self {
        Self(val.0 as u64)
    }
}
217
impl From<Block16> for Block64 {
    /// Embeds GF(2^16) into GF(2^64) by zero-extending the tower-basis bits.
    #[inline(always)]
    fn from(val: Block16) -> Self {
        Self(val.0 as u64)
    }
}
224
225impl From<Block32> for Block64 {
226    #[inline(always)]
227    fn from(val: Block32) -> Self {
228        Self(val.0 as u64)
229    }
230}
231
232// ===================================
233// PACKED BLOCK 64 (Width = 2)
234// ===================================
235
/// Number of `Block64` lanes carried by one packed vector.
pub const PACKED_WIDTH_64: usize = 2;

/// Two GF(2^64) lanes packed for SIMD processing.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[repr(C, align(16))] // 128-bit alignment
pub struct PackedBlock64(pub [Block64; PACKED_WIDTH_64]);
241
impl PackedBlock64 {
    /// Returns the all-zero vector (additive identity in every lane).
    #[inline(always)]
    pub fn zero() -> Self {
        Self([Block64::ZERO; PACKED_WIDTH_64])
    }
}
248
249impl PackableField for Block64 {
250    type Packed = PackedBlock64;
251
252    const WIDTH: usize = PACKED_WIDTH_64;
253
254    #[inline(always)]
255    fn pack(chunk: &[Self]) -> Self::Packed {
256        assert!(
257            chunk.len() >= PACKED_WIDTH_64,
258            "PackableField::pack: input slice too short",
259        );
260
261        let mut arr = [Self::ZERO; PACKED_WIDTH_64];
262        arr.copy_from_slice(&chunk[..PACKED_WIDTH_64]);
263
264        PackedBlock64(arr)
265    }
266
267    #[inline(always)]
268    fn unpack(packed: Self::Packed, output: &mut [Self]) {
269        assert!(
270            output.len() >= PACKED_WIDTH_64,
271            "PackableField::unpack: output slice too short",
272        );
273
274        output[..PACKED_WIDTH_64].copy_from_slice(&packed.0);
275    }
276}
277
278impl Add for PackedBlock64 {
279    type Output = Self;
280
281    #[inline(always)]
282    fn add(self, rhs: Self) -> Self {
283        let mut res = [Block64::ZERO; PACKED_WIDTH_64];
284        for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
285            *out = *l + *r;
286        }
287
288        Self(res)
289    }
290}
291
292impl AddAssign for PackedBlock64 {
293    #[inline(always)]
294    fn add_assign(&mut self, rhs: Self) {
295        for (l, r) in self.0.iter_mut().zip(rhs.0.iter()) {
296            *l += *r;
297        }
298    }
299}
300
301impl Sub for PackedBlock64 {
302    type Output = Self;
303
304    #[inline(always)]
305    fn sub(self, rhs: Self) -> Self {
306        self.add(rhs)
307    }
308}
309
310impl SubAssign for PackedBlock64 {
311    #[inline(always)]
312    fn sub_assign(&mut self, rhs: Self) {
313        self.add_assign(rhs);
314    }
315}
316
impl Mul for PackedBlock64 {
    type Output = Self;

    /// Lane-wise field multiplication.
    #[inline(always)]
    fn mul(self, rhs: Self) -> Self {
        // aarch64: route each lane through the flat-basis PMULL path
        // (convert, carry-less multiply, convert back).
        #[cfg(target_arch = "aarch64")]
        {
            let a0 = mul_iso_64(self.0[0], rhs.0[0]);
            let a1 = mul_iso_64(self.0[1], rhs.0[1]);

            Self([a0, a1])
        }

        // Portable fallback: scalar Karatsuba multiply per lane.
        #[cfg(not(target_arch = "aarch64"))]
        {
            let mut res = [Block64::ZERO; PACKED_WIDTH_64];
            for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
                *out = *l * *r;
            }

            Self(res)
        }
    }
}
341
342impl MulAssign for PackedBlock64 {
343    #[inline(always)]
344    fn mul_assign(&mut self, rhs: Self) {
345        for (l, r) in self.0.iter_mut().zip(rhs.0.iter()) {
346            *l *= *r;
347        }
348    }
349}
350
351impl Mul<Block64> for PackedBlock64 {
352    type Output = Self;
353
354    #[inline(always)]
355    fn mul(self, rhs: Block64) -> Self {
356        let mut res = [Block64::ZERO; PACKED_WIDTH_64];
357        for (out, v) in res.iter_mut().zip(self.0.iter()) {
358            *out = *v * rhs;
359        }
360
361        Self(res)
362    }
363}
364
365// ===================================
366// Hardware Field
367// ===================================
368
impl HardwareField for Block64 {
    /// Converts from the tower basis into the hardware ("flat") basis,
    /// either via byte-indexed lookup tables (`table-math`) or a
    /// branch-free constant-time basis scan.
    #[inline(always)]
    fn to_hardware(self) -> Flat<Self> {
        #[cfg(feature = "table-math")]
        {
            Flat::from_raw(apply_matrix_64(self, &constants::TOWER_TO_FLAT_64))
        }

        #[cfg(not(feature = "table-math"))]
        {
            Flat::from_raw(Block64(map_ct_64(self.0, &TOWER_TO_FLAT_BASIS_64.0)))
        }
    }

    /// Inverse of [`HardwareField::to_hardware`]: flat basis -> tower basis
    /// (round-trip checked by the `isomorphism_roundtrip` test).
    #[inline(always)]
    fn from_hardware(value: Flat<Self>) -> Self {
        let value = value.into_raw();

        #[cfg(feature = "table-math")]
        {
            apply_matrix_64(value, &constants::FLAT_TO_TOWER_64)
        }

        #[cfg(not(feature = "table-math"))]
        {
            Block64(map_ct_64(value.0, &FLAT_TO_TOWER_BASIS_64.0))
        }
    }

    /// Addition is XOR in either basis, so no conversion is needed.
    #[inline(always)]
    fn add_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
        Flat::from_raw(lhs.into_raw() + rhs.into_raw())
    }

    /// Packed flat-basis addition; one 128-bit NEON XOR on aarch64,
    /// lane-wise XOR elsewhere.
    #[inline(always)]
    fn add_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
        let lhs = lhs.into_raw();
        let rhs = rhs.into_raw();

        #[cfg(target_arch = "aarch64")]
        {
            PackedFlat::from_raw(neon::add_packed_64(lhs, rhs))
        }

        #[cfg(not(target_arch = "aarch64"))]
        {
            PackedFlat::from_raw(lhs + rhs)
        }
    }

    /// Flat-basis multiplication: carry-less multiply (PMULL) on aarch64;
    /// other targets convert to the tower basis, multiply there, and
    /// convert back.
    #[inline(always)]
    fn mul_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
        let lhs = lhs.into_raw();
        let rhs = rhs.into_raw();

        #[cfg(target_arch = "aarch64")]
        {
            Flat::from_raw(neon::mul_flat_64(lhs, rhs))
        }

        #[cfg(not(target_arch = "aarch64"))]
        {
            let a_tower = Self::from_hardware(Flat::from_raw(lhs));
            let b_tower = Self::from_hardware(Flat::from_raw(rhs));

            (a_tower * b_tower).to_hardware()
        }
    }

    /// Packed flat-basis multiplication; lane-wise PMULL on aarch64,
    /// unpack / `mul_hardware` / repack elsewhere.
    #[inline(always)]
    fn mul_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
        let lhs = lhs.into_raw();
        let rhs = rhs.into_raw();

        #[cfg(target_arch = "aarch64")]
        {
            PackedFlat::from_raw(neon::mul_flat_packed_64(lhs, rhs))
        }

        #[cfg(not(target_arch = "aarch64"))]
        {
            let mut l = [Self::ZERO; <Self as PackableField>::WIDTH];
            let mut r = [Self::ZERO; <Self as PackableField>::WIDTH];
            let mut res = [Self::ZERO; <Self as PackableField>::WIDTH];

            Self::unpack(lhs, &mut l);
            Self::unpack(rhs, &mut r);

            for i in 0..<Self as PackableField>::WIDTH {
                res[i] = Self::mul_hardware(Flat::from_raw(l[i]), Flat::from_raw(r[i])).into_raw();
            }

            PackedFlat::from_raw(Self::pack(&res))
        }
    }

    /// Scalar-times-vector: broadcasts `rhs` into every lane, then reuses
    /// the packed multiply.
    #[inline(always)]
    fn mul_hardware_scalar_packed(lhs: PackedFlat<Self>, rhs: Flat<Self>) -> PackedFlat<Self> {
        let broadcasted = PackedBlock64([rhs.into_raw(); PACKED_WIDTH_64]);
        Self::mul_hardware_packed(lhs, PackedFlat::from_raw(broadcasted))
    }

    /// Extracts tower-basis bit `bit_idx` directly from a flat-basis value:
    /// the bit equals the parity of the flat bits selected by a precomputed
    /// mask (validated by `parity_masks_match_from_hardware`).
    #[inline(always)]
    fn tower_bit_from_hardware(value: Flat<Self>, bit_idx: usize) -> u8 {
        let mask = FLAT_TO_TOWER_BIT_MASKS_64[bit_idx];

        // Parity of (x & mask) without popcount
        // Folds 64 bits down to 4,
        // then uses a lookup table.
        let mut v = value.into_raw().0 & mask;
        v ^= v >> 32;
        v ^= v >> 16;
        v ^= v >> 8;
        v ^= v >> 4;

        let idx = (v & 0xF) as u8;

        // Nibble parity lookup encoded
        // in a 16-bit constant (0x6996).
        ((0x6996u16 >> idx) & 1) as u8
    }
}
491
impl FlatPromote<Block8> for Block64 {
    /// Lifts a flat-basis GF(2^8) element into flat-basis GF(2^64).
    #[inline(always)]
    fn promote_flat(val: Flat<Block8>) -> Flat<Self> {
        let val = val.into_raw();

        // Constant-time path: XOR in one lift-basis row per set input bit,
        // selected branch-free via an all-ones/all-zeros mask.
        #[cfg(not(feature = "table-math"))]
        {
            let mut acc = 0u64;
            for i in 0..8 {
                let bit = (val.0 >> i) & 1;
                let mask = 0u64.wrapping_sub(bit as u64);
                acc ^= constants::LIFT_BASIS_8_TO_64[i] & mask;
            }

            Flat::from_raw(Block64(acc))
        }

        // Table path: a single 256-entry lookup (data-dependent index,
        // not constant-time).
        #[cfg(feature = "table-math")]
        {
            Flat::from_raw(Block64(constants::LIFT_TABLE_8_TO_64[val.0 as usize]))
        }
    }
}
515
516// ===========================================
517// UTILS
518// ===========================================
519
// Tower-basis multiplication routed through the flat (hardware) basis:
// convert both operands, do one carry-less multiply + reduction, convert back.
// NOTE(review): correctness relies on `to_hardware`/`to_tower` being inverse
// isomorphisms — see the `isomorphism_roundtrip` and `flat_mul_homomorphism`
// tests below.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub fn mul_iso_64(a: Block64, b: Block64) -> Block64 {
    let a_flat = a.to_hardware();
    let b_flat = b.to_hardware();

    let c_flat = Flat::from_raw(neon::mul_flat_64(a_flat.into_raw(), b_flat.into_raw()));

    c_flat.to_tower()
}
530
// Applies a 64x64 GF(2) basis-change matrix via 8 precomputed 256-entry
// lookup tables (one table per input byte); `table` holds 8 * 256 = 2048 rows.
#[cfg(feature = "table-math")]
#[inline(always)]
pub fn apply_matrix_64(val: Block64, table: &[u64; 2048]) -> Block64 {
    let mut res = 0u64;
    let v = val.0;

    // 8 lookups (8-bit window)
    for i in 0..8 {
        let byte = (v >> (i * 8)) & 0xFF;
        let idx = (i * 256) + (byte as usize);
        // SAFETY: i <= 7 and byte <= 0xFF, so idx <= 7 * 256 + 255 = 2047,
        // always in bounds for the 2048-entry table.
        res ^= unsafe { *table.get_unchecked(idx) };
    }

    Block64(res)
}
546
// Constant-time basis change: XORs `basis[i]` into the accumulator for every
// set bit `i` of `x`. Selection is branch-free (an all-ones/all-zeros mask),
// so the memory access pattern does not depend on the value of `x`.
#[cfg(not(feature = "table-math"))]
#[inline(always)]
fn map_ct_64(x: u64, basis: &[u64; 64]) -> u64 {
    let mut acc = 0u64;

    for (i, &row) in basis.iter().enumerate() {
        let bit = (x >> i) & 1;
        // bit == 1 -> 0xFFFF..FF, bit == 0 -> 0.
        acc ^= row & bit.wrapping_neg();
    }

    acc
}
562
563// ===========================================
564// SIMD INSTRUCTIONS
565// ===========================================
566
#[cfg(target_arch = "aarch64")]
mod neon {
    use super::*;
    use core::arch::aarch64::*;
    use core::mem::transmute;

    /// Lane-wise XOR of both 64-bit lanes with a single 128-bit NEON op.
    #[inline(always)]
    pub fn add_packed_64(lhs: PackedBlock64, rhs: PackedBlock64) -> PackedBlock64 {
        // SAFETY: PackedBlock64 is #[repr(C, align(16))] over [Block64; 2]
        // and Block64 is #[repr(transparent)] over u64, so every transmute
        // here is between 16-byte plain-data types; veorq_u8 is a plain
        // vector XOR with no extra validity requirements.
        unsafe {
            let l: uint8x16_t = transmute::<[Block64; PACKED_WIDTH_64], uint8x16_t>(lhs.0);
            let r: uint8x16_t = transmute::<[Block64; PACKED_WIDTH_64], uint8x16_t>(rhs.0);
            let res = veorq_u8(l, r);
            let out: [Block64; PACKED_WIDTH_64] =
                transmute::<uint8x16_t, [Block64; PACKED_WIDTH_64]>(res);

            PackedBlock64(out)
        }
    }

    /// Lane-wise carry-less multiply + reduction of two packed flat-basis
    /// values (one PMULL per lane, then `reduce_64`).
    #[inline(always)]
    pub fn mul_flat_packed_64(lhs: PackedBlock64, rhs: PackedBlock64) -> PackedBlock64 {
        // SAFETY: same 16-byte layout argument as `add_packed_64`; all lane
        // indices passed to the vget* intrinsics are in-range constants.
        unsafe {
            let a: uint64x2_t = transmute(lhs.0);
            let b: uint64x2_t = transmute(rhs.0);

            let a_lo = vget_low_u64(a);
            let b_lo = vget_low_u64(b);

            // 64x64 -> 128-bit polynomial product of the low lanes.
            let p0: uint64x2_t =
                transmute(vmull_p64(vget_lane_u64(a_lo, 0), vget_lane_u64(b_lo, 0)));

            let a_hi = vget_high_u64(a);
            let b_hi = vget_high_u64(b);
            let p1: uint64x2_t =
                transmute(vmull_p64(vget_lane_u64(a_hi, 0), vget_lane_u64(b_hi, 0)));

            let r0 = reduce_64(p0);
            let r1 = reduce_64(p1);

            PackedBlock64([r0, r1])
        }
    }

    /// Reduces a 128-bit carry-less product modulo P(x) = x^64 + R(x),
    /// where R(x) is encoded by `constants::POLY_64`: two folds of the
    /// high half through R(x).
    #[inline(always)]
    fn reduce_64(prod: uint64x2_t) -> Block64 {
        // SAFETY: lane indices are the constants 0 and 1, valid for
        // uint64x2_t; transmutes convert between 128-bit vector types.
        unsafe {
            let l = vgetq_lane_u64(prod, 0);
            let h = vgetq_lane_u64(prod, 1);

            let r_val = constants::POLY_64;

            // First fold: H(x) * R(x).
            let h_red: uint64x2_t = transmute(vmull_p64(h, r_val));

            let folded = vgetq_lane_u64(h_red, 0);
            let carry = vgetq_lane_u64(h_red, 1);

            let mut res = l ^ folded;

            // Second fold for the overflow of the first one.
            let carry_red: uint64x2_t = transmute(vmull_p64(carry, r_val));
            res ^= vgetq_lane_u64(carry_red, 0);

            Block64(res)
        }
    }

    /// Scalar flat-basis multiply: one PMULL, then the same two-fold
    /// reduction as `reduce_64` (kept inline here).
    #[inline(always)]
    pub fn mul_flat_64(a: Block64, b: Block64) -> Block64 {
        // SAFETY: vmull_p64 / vgetq_lane_u64 only use constant in-range
        // lanes; transmutes convert between 128-bit vector representations.
        unsafe {
            // Multiply 64x64 -> 128
            let prod = vmull_p64(a.0, b.0);
            let prod_u64: uint64x2_t = transmute(prod);

            let l = vgetq_lane_u64(prod_u64, 0);
            let h = vgetq_lane_u64(prod_u64, 1);

            // Reduce mod P(x) = x^64 + R(x).
            let r_val = constants::POLY_64; // u64

            // H * R
            let h_red = vmull_p64(h, r_val);
            let h_red_u64: uint64x2_t = transmute(h_red);

            let folded = vgetq_lane_u64(h_red_u64, 0);
            let carry = vgetq_lane_u64(h_red_u64, 1);

            let mut res = l ^ folded;

            // Reduce carry (if exists)
            let carry_red = vmull_p64(carry, r_val);
            let carry_res_vec: uint64x2_t = transmute(carry_red);
            let carry_val = vgetq_lane_u64(carry_res_vec, 0);

            res ^= carry_val;

            Block64(res)
        }
    }
}
665
666// ==================================
667// BLOCK 64 TESTS
668// ==================================
669
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use rand::{RngExt, rng};

    // ==================================
    // BASIC
    // ==================================

    #[test]
    fn tower_constants() {
        // Check that tau is propagated correctly
        // For Block64, tau must be (0, 1) from Block32.
        let tau64 = Block64::EXTENSION_TAU;
        let (lo64, hi64) = tau64.split();
        assert_eq!(lo64, Block32::ZERO);
        assert_eq!(hi64, Block32::TAU);
    }

    // Exhaustive truth table of addition over {0, 1}.
    #[test]
    fn add_truth() {
        let zero = Block64::ZERO;
        let one = Block64::ONE;

        assert_eq!(zero + zero, zero);
        assert_eq!(zero + one, one);
        assert_eq!(one + zero, one);
        assert_eq!(one + one, zero);
    }

    // Truth table of multiplication over {0, 1}.
    #[test]
    fn mul_truth() {
        let zero = Block64::ZERO;
        let one = Block64::ONE;

        assert_eq!(zero * zero, zero);
        assert_eq!(zero * one, zero);
        assert_eq!(one * one, one);
    }

    #[test]
    fn add() {
        // 5 ^ 3 = 6
        // 101 ^ 011 = 110
        assert_eq!(Block64(5) + Block64(3), Block64(6));
    }

    #[test]
    fn mul_simple() {
        // Check for prime numbers (without overflow)
        // x^1 * x^1 = x^2 (2 * 2 = 4)
        assert_eq!(Block64(2) * Block64(2), Block64(4));
    }

    #[test]
    fn mul_overflow() {
        // Reduction verification (AES test vectors)
        // Example from the AES specification:
        // 0x57 * 0x83 = 0xC1
        assert_eq!(Block64(0x57) * Block64(0x83), Block64(0xC1));
    }

    #[test]
    fn karatsuba_correctness() {
        // Let's check using Block64 as an example
        // Let A = X (hi=1, lo=0)
        // Let B = X (hi=1, lo=0)
        // A * B = X^2
        // According to the rule:
        // X^2 = X + tau
        // Where tau for Block32 = 0x2000_0000.
        // So the result should be:
        // hi=1 (X), lo=0x20 (tau)

        // Construct X manually
        let x = Block64::new(Block32::ZERO, Block32::ONE);
        let squared = x * x;

        // Verify result via splitting
        let (res_lo, res_hi) = squared.split();

        assert_eq!(res_hi, Block32::ONE, "X^2 should contain X component");
        assert_eq!(
            res_lo,
            Block32(0x2000_0000),
            "X^2 should contain tau component (0x2000_0000)"
        );
    }

    // Ensure `zeroize` actually clears the raw word.
    #[test]
    fn security_zeroize() {
        let mut secret_val = Block64::from(0xDEAD_BEEF_CAFE_BABE_u64);
        assert_ne!(secret_val, Block64::ZERO);

        secret_val.zeroize();

        assert_eq!(secret_val, Block64::ZERO);
        assert_eq!(secret_val.0, 0, "Block64 memory leak detected");
    }

    #[test]
    fn invert_zero() {
        // Zero check
        assert_eq!(
            Block64::ZERO.invert(),
            Block64::ZERO,
            "invert(0) must return 0"
        );
    }

    // Randomized check of a * a^-1 == 1.
    #[test]
    fn inversion_random() {
        let mut rng = rng();
        for _ in 0..1000 {
            let val = Block64(rng.random());
            if val != Block64::ZERO {
                let inv = val.invert();
                assert_eq!(
                    val * inv,
                    Block64::ONE,
                    "Inversion identity failed: a * a^-1 != 1"
                );
            }
        }
    }

    // The Block32 -> Block64 embedding must be a field homomorphism.
    #[test]
    fn tower_embedding() {
        let mut rng = rng();
        for _ in 0..100 {
            let a = Block32(rng.random());
            let b = Block32(rng.random());

            // 1. Structure check
            let a_lifted: Block64 = a.into();
            let (lo, hi) = a_lifted.split();

            assert_eq!(lo, a, "Embedding structure failed: low part mismatch");
            assert_eq!(
                hi,
                Block32::ZERO,
                "Embedding structure failed: high part must be zero"
            );

            // 2. Addition Homomorphism
            let sum_sub = a + b;
            let sum_lifted: Block64 = sum_sub.into();
            let sum_in_super = Block64::from(a) + Block64::from(b);

            assert_eq!(sum_lifted, sum_in_super, "Homomorphism failed: add");

            // 3. Multiplication Homomorphism
            let prod_sub = a * b;
            let prod_lifted: Block64 = prod_sub.into();
            let prod_in_super = Block64::from(a) * Block64::from(b);

            assert_eq!(prod_lifted, prod_in_super, "Homomorphism failed: mul");
        }
    }

    // ==================================
    // HARDWARE
    // ==================================

    #[test]
    fn isomorphism_roundtrip() {
        let mut rng = rng();
        for _ in 0..1000 {
            let val = Block64(rng.random::<u64>());
            assert_eq!(val.to_hardware().to_tower(), val);
        }
    }

    // (a*b)^H == a^H * b^H: the basis change must commute with multiplication.
    #[test]
    fn flat_mul_homomorphism() {
        let mut rng = rng();
        for _ in 0..1000 {
            let a = Block64(rng.random());
            let b = Block64(rng.random());

            let expected_flat = (a * b).to_hardware();
            let actual_flat = a.to_hardware() * b.to_hardware();

            assert_eq!(
                actual_flat, expected_flat,
                "Block64 flat multiplication mismatch: (a*b)^H != a^H * b^H"
            );
        }
    }

    // SIMD packed ops must agree with the scalar ops lane by lane.
    #[test]
    fn packed_consistency() {
        let mut rng = rng();
        for _ in 0..100 {
            let a_vals = [Block64(rng.random()), Block64(rng.random())];
            let b_vals = [Block64(rng.random()), Block64(rng.random())];

            let a_flat_vals = a_vals.map(|x| x.to_hardware());
            let b_flat_vals = b_vals.map(|x| x.to_hardware());
            let a_packed = Flat::<Block64>::pack(&a_flat_vals);
            let b_packed = Flat::<Block64>::pack(&b_flat_vals);

            // 1. Test SIMD Add (XOR)
            let add_res = Block64::add_hardware_packed(a_packed, b_packed);

            let mut add_out = [Block64::ZERO.to_hardware(); 2];
            Flat::<Block64>::unpack(add_res, &mut add_out);

            assert_eq!(add_out[0], (a_vals[0] + b_vals[0]).to_hardware());
            assert_eq!(add_out[1], (a_vals[1] + b_vals[1]).to_hardware());

            // 2. Test SIMD Mul (Isomorphic/Flat basis)
            let mul_res = Block64::mul_hardware_packed(a_packed, b_packed);

            let mut mul_out = [Block64::ZERO.to_hardware(); 2];
            Flat::<Block64>::unpack(mul_res, &mut mul_out);

            assert_eq!(
                mul_out[0],
                (a_vals[0] * b_vals[0]).to_hardware(),
                "Block64 SIMD mul mismatch at index 0"
            );
            assert_eq!(
                mul_out[1],
                (a_vals[1] * b_vals[1]).to_hardware(),
                "Block64 SIMD mul mismatch at index 1"
            );
        }
    }

    // ==================================
    // PACKED
    // ==================================

    #[test]
    fn pack_unpack_roundtrip() {
        let mut rng = rng();
        let data = [Block64(rng.random()), Block64(rng.random())];

        let packed = Block64::pack(&data);
        let mut unpacked = [Block64::ZERO; 2];

        Block64::unpack(packed, &mut unpacked);
        assert_eq!(data, unpacked);
    }

    #[test]
    fn packed_add_consistency() {
        let mut rng = rng();
        let a_vals = [Block64(rng.random()), Block64(rng.random())];
        let b_vals = [Block64(rng.random()), Block64(rng.random())];

        let res_packed = Block64::pack(&a_vals) + Block64::pack(&b_vals);
        let mut res_unpacked = [Block64::ZERO; 2];
        Block64::unpack(res_packed, &mut res_unpacked);

        assert_eq!(res_unpacked[0], a_vals[0] + b_vals[0]);
        assert_eq!(res_unpacked[1], a_vals[1] + b_vals[1]);
    }

    // Packed `*` (SIMD on aarch64) must match the scalar tower multiply.
    #[test]
    fn packed_mul_consistency() {
        let mut rng = rng();

        for _ in 0..1000 {
            let mut a_arr = [Block64::ZERO; PACKED_WIDTH_64];
            let mut b_arr = [Block64::ZERO; PACKED_WIDTH_64];

            for i in 0..PACKED_WIDTH_64 {
                let val_a: u64 = rng.random();
                let val_b: u64 = rng.random();
                a_arr[i] = Block64(val_a);
                b_arr[i] = Block64(val_b);
            }

            let a_packed = PackedBlock64(a_arr);
            let b_packed = PackedBlock64(b_arr);

            // Perform SIMD multiplication
            let c_packed = a_packed * b_packed;

            // Verify against Scalar
            let mut c_expected = [Block64::ZERO; PACKED_WIDTH_64];
            for i in 0..PACKED_WIDTH_64 {
                c_expected[i] = a_arr[i] * b_arr[i];
            }

            assert_eq!(c_packed.0, c_expected, "SIMD Block64 mismatch!");
        }
    }

    proptest! {
        #[test]
        fn parity_masks_match_from_hardware(x_flat in any::<u64>()) {
            let tower = Block64::from_hardware(Flat::from_raw(Block64(x_flat))).0;

            for (k, &mask) in FLAT_TO_TOWER_BIT_MASKS_64.iter().enumerate() {
                // Ensure the static masks
                // themselves are correct.
                let parity = ((x_flat & mask).count_ones() & 1) as u8;
                let bit = ((tower >> k) & 1) as u8;
                prop_assert_eq!(parity, bit, "Block64 static mask mismatch at k={}", k);

                // Ensure XOR-tree implementation matches.
                let via_api = Flat::from_raw(Block64(x_flat)).tower_bit(k);
                prop_assert_eq!(
                    via_api, bit,
                    "Block64 tower_bit_from_hardware mismatch at x_flat={:#018x}, bit_idx={}",
                    x_flat, k
                );
            }
        }
    }
}