Skip to main content

hekate_math/towers/
block256.rs

1// SPDX-License-Identifier: Apache-2.0
2// This file is part of the hekate-math project.
3// Copyright (C) 2026 Andrei Kochergin <andrei@oumuamua.dev>
4// Copyright (C) 2026 Oumuamua Labs <info@oumuamua.dev>. All rights reserved.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10//     http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! BLOCK 256 (GF(2^256))
19use crate::{Bit, Block8, Block16, Block32, Block64, Block128};
20use crate::{
21    CanonicalDeserialize, CanonicalSerialize, Flat, FlatPromote, HardwareField, PackableField,
22    PackedFlat, TowerField,
23};
24use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
25use serde::{Deserialize, Serialize};
26use zeroize::Zeroize;
27
// Flat<Block256> = Flat<Block128>[y] / (y² + y + τ_flat).
// τ_flat = to_hardware(Block128::EXTENSION_TAU).
//
// The value is a hard-coded literal; the `tau_flat_matches_derived` unit test
// compares it against `Block128::EXTENSION_TAU.to_hardware()` so it cannot
// drift silently.
const TAU_FLAT: u128 = 0x66340c45203fe3685d08f8c248334a81;
31
/// Element of the 256-bit field, stored as two 128-bit limbs.
///
/// `self.0[0]` is the low half and `self.0[1]` is the high half; the halves
/// are exposed as `Block128` values through `Block256::split`.
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq, Serialize, Deserialize, Zeroize)]
#[repr(C, align(32))]
pub struct Block256(pub [u128; 2]); // [lo, hi]
35
36impl Block256 {
37    const TAU: Self = Block256([0, 0x2000_0000_0000_0000_0000_0000_0000_0000]);
38
39    pub fn new(lo: Block128, hi: Block128) -> Self {
40        Self([lo.0, hi.0])
41    }
42
43    #[inline(always)]
44    pub fn split(self) -> (Block128, Block128) {
45        (Block128(self.0[0]), Block128(self.0[1]))
46    }
47}
48
49impl TowerField for Block256 {
50    const BITS: usize = 256;
51    const ZERO: Self = Block256([0, 0]);
52    const ONE: Self = Block256([1, 0]);
53
54    const EXTENSION_TAU: Self = Self::TAU;
55
56    fn invert(&self) -> Self {
57        let (l, h) = self.split();
58        let h2 = h * h;
59        let l2 = l * l;
60        let hl = h * l;
61        let norm = (h2 * Block128::EXTENSION_TAU) + hl + l2;
62
63        let norm_inv = norm.invert();
64        let res_hi = h * norm_inv;
65        let res_lo = (h + l) * norm_inv;
66
67        Self::new(res_lo, res_hi)
68    }
69
70    fn from_uniform_bytes(bytes: &[u8; 32]) -> Self {
71        let mut lo_buf = [0u8; 16];
72        let mut hi_buf = [0u8; 16];
73
74        lo_buf.copy_from_slice(&bytes[0..16]);
75        hi_buf.copy_from_slice(&bytes[16..32]);
76
77        Self([u128::from_le_bytes(lo_buf), u128::from_le_bytes(hi_buf)])
78    }
79}
80
81impl Add for Block256 {
82    type Output = Self;
83
84    fn add(self, rhs: Self) -> Self {
85        Self([self.0[0] ^ rhs.0[0], self.0[1] ^ rhs.0[1]])
86    }
87}
88
89impl Sub for Block256 {
90    type Output = Self;
91
92    fn sub(self, rhs: Self) -> Self {
93        self.add(rhs)
94    }
95}
96
97impl Mul for Block256 {
98    type Output = Self;
99
100    fn mul(self, rhs: Self) -> Self {
101        let (a0, a1) = self.split();
102        let (b0, b1) = rhs.split();
103
104        let v0 = a0 * b0;
105        let v1 = a1 * b1;
106        let v_sum = (a0 + a1) * (b0 + b1);
107
108        let c_hi = v0 + v_sum;
109        let c_lo = v0 + (v1 * Block128::EXTENSION_TAU);
110
111        Self::new(c_lo, c_hi)
112    }
113}
114
115impl AddAssign for Block256 {
116    fn add_assign(&mut self, rhs: Self) {
117        self.0[0] ^= rhs.0[0];
118        self.0[1] ^= rhs.0[1];
119    }
120}
121
122impl SubAssign for Block256 {
123    fn sub_assign(&mut self, rhs: Self) {
124        self.0[0] ^= rhs.0[0];
125        self.0[1] ^= rhs.0[1];
126    }
127}
128
129impl MulAssign for Block256 {
130    fn mul_assign(&mut self, rhs: Self) {
131        *self = *self * rhs;
132    }
133}
134
135impl CanonicalSerialize for Block256 {
136    fn serialized_size(&self) -> usize {
137        32
138    }
139
140    fn serialize(&self, writer: &mut [u8]) -> Result<(), ()> {
141        if writer.len() < 32 {
142            return Err(());
143        }
144
145        writer[0..16].copy_from_slice(&self.0[0].to_le_bytes());
146        writer[16..32].copy_from_slice(&self.0[1].to_le_bytes());
147
148        Ok(())
149    }
150}
151
152impl CanonicalDeserialize for Block256 {
153    fn deserialize(bytes: &[u8]) -> Result<Self, ()> {
154        if bytes.len() < 32 {
155            return Err(());
156        }
157
158        let mut lo_buf = [0u8; 16];
159        let mut hi_buf = [0u8; 16];
160
161        lo_buf.copy_from_slice(&bytes[0..16]);
162        hi_buf.copy_from_slice(&bytes[16..32]);
163
164        Ok(Self([
165            u128::from_le_bytes(lo_buf),
166            u128::from_le_bytes(hi_buf),
167        ]))
168    }
169}
170
171impl From<u8> for Block256 {
172    fn from(val: u8) -> Self {
173        Self([val as u128, 0])
174    }
175}
176
177impl From<u32> for Block256 {
178    #[inline]
179    fn from(val: u32) -> Self {
180        Self([val as u128, 0])
181    }
182}
183
184impl From<u64> for Block256 {
185    #[inline]
186    fn from(val: u64) -> Self {
187        Self([val as u128, 0])
188    }
189}
190
191impl From<u128> for Block256 {
192    #[inline]
193    fn from(val: u128) -> Self {
194        Self([val, 0])
195    }
196}
197
198impl From<Bit> for Block256 {
199    #[inline(always)]
200    fn from(val: Bit) -> Self {
201        Self([val.0 as u128, 0])
202    }
203}
204
205impl From<Block8> for Block256 {
206    #[inline(always)]
207    fn from(val: Block8) -> Self {
208        Self([val.0 as u128, 0])
209    }
210}
211
212impl From<Block16> for Block256 {
213    #[inline(always)]
214    fn from(val: Block16) -> Self {
215        Self([val.0 as u128, 0])
216    }
217}
218
219impl From<Block32> for Block256 {
220    #[inline(always)]
221    fn from(val: Block32) -> Self {
222        Self([val.0 as u128, 0])
223    }
224}
225
226impl From<Block64> for Block256 {
227    #[inline(always)]
228    fn from(val: Block64) -> Self {
229        Self([val.0 as u128, 0])
230    }
231}
232
233impl From<Block128> for Block256 {
234    #[inline(always)]
235    fn from(val: Block128) -> Self {
236        Self([val.0, 0])
237    }
238}
239
240// ===================================
241// PACKED BLOCK 256 (Width = 2)
242// ===================================
243
/// Number of `Block256` lanes held by one `PackedBlock256`.
pub const PACKED_WIDTH_256: usize = 2;
245
/// Fixed-width bundle of `PACKED_WIDTH_256` `Block256` lanes.
///
/// The arithmetic operators implemented for this type in this file act on
/// each lane independently.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[repr(C, align(64))]
pub struct PackedBlock256(pub [Block256; PACKED_WIDTH_256]);
249
250impl PackedBlock256 {
251    #[inline(always)]
252    pub fn zero() -> Self {
253        Self([Block256::ZERO; PACKED_WIDTH_256])
254    }
255
256    #[inline(always)]
257    pub fn broadcast(val: Block256) -> Self {
258        Self([val; PACKED_WIDTH_256])
259    }
260}
261
262impl PackableField for Block256 {
263    type Packed = PackedBlock256;
264
265    const WIDTH: usize = PACKED_WIDTH_256;
266
267    #[inline(always)]
268    fn pack(chunk: &[Self]) -> Self::Packed {
269        assert!(
270            chunk.len() >= PACKED_WIDTH_256,
271            "PackableField::pack: input slice too short",
272        );
273
274        let mut arr = [Self::ZERO; PACKED_WIDTH_256];
275        arr.copy_from_slice(&chunk[..PACKED_WIDTH_256]);
276
277        PackedBlock256(arr)
278    }
279
280    #[inline(always)]
281    fn unpack(packed: Self::Packed, output: &mut [Self]) {
282        assert!(
283            output.len() >= PACKED_WIDTH_256,
284            "PackableField::unpack: output slice too short",
285        );
286
287        output[..PACKED_WIDTH_256].copy_from_slice(&packed.0);
288    }
289}
290
291impl Add for PackedBlock256 {
292    type Output = Self;
293
294    #[inline(always)]
295    fn add(self, rhs: Self) -> Self {
296        let mut res = [Block256::ZERO; PACKED_WIDTH_256];
297        for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
298            *out = *l + *r;
299        }
300
301        Self(res)
302    }
303}
304
305impl AddAssign for PackedBlock256 {
306    #[inline(always)]
307    fn add_assign(&mut self, rhs: Self) {
308        for (l, r) in self.0.iter_mut().zip(rhs.0.iter()) {
309            *l += *r;
310        }
311    }
312}
313
314impl Sub for PackedBlock256 {
315    type Output = Self;
316
317    #[inline(always)]
318    fn sub(self, rhs: Self) -> Self {
319        self.add(rhs)
320    }
321}
322
323impl SubAssign for PackedBlock256 {
324    #[inline(always)]
325    fn sub_assign(&mut self, rhs: Self) {
326        self.add_assign(rhs);
327    }
328}
329
330impl Mul for PackedBlock256 {
331    type Output = Self;
332
333    #[inline(always)]
334    fn mul(self, rhs: Self) -> Self {
335        let mut res = [Block256::ZERO; PACKED_WIDTH_256];
336        for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
337            *out = *l * *r;
338        }
339
340        Self(res)
341    }
342}
343
344impl MulAssign for PackedBlock256 {
345    #[inline(always)]
346    fn mul_assign(&mut self, rhs: Self) {
347        for (l, r) in self.0.iter_mut().zip(rhs.0.iter()) {
348            *l *= *r;
349        }
350    }
351}
352
353impl Mul<Block256> for PackedBlock256 {
354    type Output = Self;
355
356    #[inline(always)]
357    fn mul(self, rhs: Block256) -> Self {
358        let mut res = [Block256::ZERO; PACKED_WIDTH_256];
359        for (out, v) in res.iter_mut().zip(self.0.iter()) {
360            *out = *v * rhs;
361        }
362
363        Self(res)
364    }
365}
366
367impl MulAssign<Block256> for PackedBlock256 {
368    #[inline(always)]
369    fn mul_assign(&mut self, rhs: Block256) {
370        for v in self.0.iter_mut() {
371            *v *= rhs;
372        }
373    }
374}
375
376impl HardwareField for Block256 {
377    #[inline(always)]
378    fn to_hardware(self) -> Flat<Self> {
379        let (lo, hi) = self.split();
380        let flat_lo = lo.to_hardware().into_raw().0;
381        let flat_hi = hi.to_hardware().into_raw().0;
382
383        Flat::from_raw(Block256([flat_lo, flat_hi]))
384    }
385
386    #[inline(always)]
387    fn from_hardware(value: Flat<Self>) -> Self {
388        let raw = value.into_raw();
389        let lo = Block128::from_hardware(Flat::from_raw(Block128(raw.0[0])));
390        let hi = Block128::from_hardware(Flat::from_raw(Block128(raw.0[1])));
391
392        Self::new(lo, hi)
393    }
394
395    #[inline(always)]
396    fn add_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
397        let l = lhs.into_raw();
398        let r = rhs.into_raw();
399
400        Flat::from_raw(Block256([l.0[0] ^ r.0[0], l.0[1] ^ r.0[1]]))
401    }
402
403    #[inline(always)]
404    fn add_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
405        PackedFlat::from_raw(lhs.into_raw() + rhs.into_raw())
406    }
407
408    #[inline(always)]
409    fn mul_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
410        let a_lo = Flat::from_raw(Block128(lhs.into_raw().0[0]));
411        let a_hi = Flat::from_raw(Block128(lhs.into_raw().0[1]));
412        let b_lo = Flat::from_raw(Block128(rhs.into_raw().0[0]));
413        let b_hi = Flat::from_raw(Block128(rhs.into_raw().0[1]));
414
415        let tau = Flat::from_raw(Block128(TAU_FLAT));
416
417        let v0 = Block128::mul_hardware(a_lo, b_lo);
418        let v1 = Block128::mul_hardware(a_hi, b_hi);
419
420        let a_sum = Block128::add_hardware(a_lo, a_hi);
421        let b_sum = Block128::add_hardware(b_lo, b_hi);
422        let v_sum = Block128::mul_hardware(a_sum, b_sum);
423
424        let c_hi = Block128::add_hardware(v0, v_sum);
425
426        let v1_tau = Block128::mul_hardware(v1, tau);
427        let c_lo = Block128::add_hardware(v0, v1_tau);
428
429        Flat::from_raw(Block256([c_lo.into_raw().0, c_hi.into_raw().0]))
430    }
431
432    #[inline(always)]
433    fn mul_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
434        let lhs = lhs.into_raw().0;
435        let rhs = rhs.into_raw().0;
436
437        let mut res = [Block256::ZERO; PACKED_WIDTH_256];
438        for i in 0..PACKED_WIDTH_256 {
439            res[i] = Self::mul_hardware(Flat::from_raw(lhs[i]), Flat::from_raw(rhs[i])).into_raw();
440        }
441
442        PackedFlat::from_raw(PackedBlock256(res))
443    }
444
445    #[inline(always)]
446    fn mul_hardware_scalar_packed(lhs: PackedFlat<Self>, rhs: Flat<Self>) -> PackedFlat<Self> {
447        let broadcasted = PackedBlock256::broadcast(rhs.into_raw());
448        Self::mul_hardware_packed(lhs, PackedFlat::from_raw(broadcasted))
449    }
450
451    #[inline(always)]
452    fn tower_bit_from_hardware(value: Flat<Self>, bit_idx: usize) -> u8 {
453        if bit_idx < 128 {
454            Block128::tower_bit_from_hardware(
455                Flat::from_raw(Block128(value.into_raw().0[0])),
456                bit_idx,
457            )
458        } else {
459            Block128::tower_bit_from_hardware(
460                Flat::from_raw(Block128(value.into_raw().0[1])),
461                bit_idx - 128,
462            )
463        }
464    }
465}
466
// Number of elements `promote_chunked` converts per pass; also the length of
// its stack-allocated scratch buffer.
const PROMOTE_CHUNK: usize = 64;
468
469impl FlatPromote<Block8> for Block256 {
470    #[inline(always)]
471    fn promote_flat(val: Flat<Block8>) -> Flat<Self> {
472        let promoted = Block128::promote_flat(val);
473        Flat::from_raw(Block256([promoted.into_raw().0, 0]))
474    }
475
476    fn promote_flat_batch(input: &[Flat<Block8>], output: &mut [Flat<Self>]) {
477        promote_chunked(input, output);
478    }
479}
480
481impl FlatPromote<Block16> for Block256 {
482    #[inline(always)]
483    fn promote_flat(val: Flat<Block16>) -> Flat<Self> {
484        let promoted = Block128::promote_flat(val);
485        Flat::from_raw(Block256([promoted.into_raw().0, 0]))
486    }
487
488    fn promote_flat_batch(input: &[Flat<Block16>], output: &mut [Flat<Self>]) {
489        promote_chunked(input, output);
490    }
491}
492
493impl FlatPromote<Block32> for Block256 {
494    #[inline(always)]
495    fn promote_flat(val: Flat<Block32>) -> Flat<Self> {
496        let promoted = Block128::promote_flat(val);
497        Flat::from_raw(Block256([promoted.into_raw().0, 0]))
498    }
499
500    fn promote_flat_batch(input: &[Flat<Block32>], output: &mut [Flat<Self>]) {
501        promote_chunked(input, output);
502    }
503}
504
505impl FlatPromote<Block64> for Block256 {
506    #[inline(always)]
507    fn promote_flat(val: Flat<Block64>) -> Flat<Self> {
508        let promoted = Block128::promote_flat(val);
509        Flat::from_raw(Block256([promoted.into_raw().0, 0]))
510    }
511
512    fn promote_flat_batch(input: &[Flat<Block64>], output: &mut [Flat<Self>]) {
513        promote_chunked(input, output);
514    }
515}
516
517impl FlatPromote<Block128> for Block256 {
518    #[inline(always)]
519    fn promote_flat(val: Flat<Block128>) -> Flat<Self> {
520        Flat::from_raw(Block256([val.into_raw().0, 0]))
521    }
522
523    fn promote_flat_batch(input: &[Flat<Block128>], output: &mut [Flat<Self>]) {
524        let n = input.len().min(output.len());
525        for i in 0..n {
526            output[i] = Flat::from_raw(Block256([input[i].into_raw().0, 0]));
527        }
528    }
529}
530
531#[inline(always)]
532fn promote_chunked<FromF>(input: &[Flat<FromF>], output: &mut [Flat<Block256>])
533where
534    FromF: HardwareField,
535    Block128: FlatPromote<FromF>,
536{
537    let n = input.len().min(output.len());
538
539    let mut scratch = [Flat::from_raw(Block128::ZERO); PROMOTE_CHUNK];
540    let mut i = 0;
541
542    while i < n {
543        let len = (n - i).min(PROMOTE_CHUNK);
544        Block128::promote_flat_batch(&input[i..i + len], &mut scratch[..len]);
545
546        for j in 0..len {
547            output[i + j] = Flat::from_raw(Block256([scratch[j].into_raw().0, 0]));
548        }
549
550        i += len;
551    }
552}
553
// Unit tests for `Block256`: tower-basis arithmetic, the flat ("hardware")
// representation, packed operations, and flat promotion from smaller blocks.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{RngExt, rng};

    // Guards the hard-coded `TAU_FLAT` literal against the value derived
    // from `Block128::EXTENSION_TAU`.
    #[test]
    fn tau_flat_matches_derived() {
        let derived = Block128::EXTENSION_TAU.to_hardware().into_raw().0;
        assert_eq!(
            TAU_FLAT, derived,
            "TAU_FLAT drifted from Block128::EXTENSION_TAU.to_hardware()",
        );
    }

    // ==================================
    // BASIC
    // ==================================

    #[test]
    fn tower_constants() {
        // Check that tau is propagated correctly
        // For Block256, tau must be (0, EXTENSION_TAU) from Block128.
        let tau256 = Block256::EXTENSION_TAU;
        let (lo256, hi256) = tau256.split();
        assert_eq!(lo256, Block128::ZERO);
        assert_eq!(hi256, Block128::EXTENSION_TAU);
    }

    // Addition table over {0, 1}; `one + one` must be zero.
    #[test]
    fn add_truth() {
        let zero = Block256::ZERO;
        let one = Block256::ONE;

        assert_eq!(zero + zero, zero);
        assert_eq!(zero + one, one);
        assert_eq!(one + zero, one);
        assert_eq!(one + one, zero);
    }

    // Multiplication table over {0, 1}.
    #[test]
    fn mul_truth() {
        let zero = Block256::ZERO;
        let one = Block256::ONE;

        assert_eq!(zero * zero, zero);
        assert_eq!(zero * one, zero);
        assert_eq!(one * one, one);
    }

    #[test]
    fn add() {
        // 5 ^ 3 = 6
        // 101 ^ 011 = 110
        assert_eq!(Block256([5, 0]) + Block256([3, 0]), Block256([6, 0]));
    }

    #[test]
    fn mul_simple() {
        // x^1 * x^1 = x^2 (2 * 2 = 4) inside the Block8 subfield
        assert_eq!(
            Block256::from(2u32) * Block256::from(2u32),
            Block256::from(4u32)
        );
    }

    #[test]
    fn mul_overflow() {
        // AES reduction: 0x57 * 0x83 = 0xC1 inside the Block8 subfield
        assert_eq!(
            Block256::from(0x57u32) * Block256::from(0x83u32),
            Block256::from(0xC1u32)
        );
    }

    #[test]
    fn karatsuba_correctness() {
        // Y = (hi=ONE, lo=ZERO). Y^2 = Y + tau_256.
        // So the result must be:
        // hi = Block128::ONE (the Y component),
        // lo = Block128::EXTENSION_TAU (the tau component).
        let y = Block256::new(Block128::ZERO, Block128::ONE);
        let squared = y * y;

        let (res_lo, res_hi) = squared.split();

        assert_eq!(res_hi, Block128::ONE, "Y^2 should contain Y component");
        assert_eq!(
            res_lo,
            Block128::EXTENSION_TAU,
            "Y^2 should contain tau_256 component"
        );
    }

    // `zeroize()` must clear both limbs.
    #[test]
    fn security_zeroize() {
        let mut secret_val = Block256([0xDEAD_BEEF_CAFE_BABE_u128, 0xFEED_FACE_BAAD_F00D_u128]);
        assert_ne!(secret_val, Block256::ZERO);

        secret_val.zeroize();

        assert_eq!(secret_val, Block256::ZERO, "Memory was not wiped!");
        assert_eq!(
            secret_val.0,
            [0u128, 0u128],
            "Underlying memory leak detected"
        );
    }

    // Inverting zero is defined to return zero rather than panic.
    #[test]
    fn invert_zero() {
        assert_eq!(
            Block256::ZERO.invert(),
            Block256::ZERO,
            "invert(0) must return 0"
        );
    }

    // For random non-zero values, `a * a.invert()` must equal one.
    #[test]
    fn inversion_random() {
        let mut rng = rng();
        for _ in 0..1000 {
            let val = Block256([rng.random(), rng.random()]);

            if val != Block256::ZERO {
                let inv = val.invert();
                let identity = val * inv;

                assert_eq!(
                    identity,
                    Block256::ONE,
                    "Inversion identity failed: a * a^-1 != 1"
                );
            }
        }
    }

    // Lifting `Block128` into `Block256` must preserve structure, addition
    // and multiplication.
    #[test]
    fn tower_embedding() {
        let mut rng = rng();
        for _ in 0..100 {
            let a = Block128(rng.random());
            let b = Block128(rng.random());

            // 1. Structure:
            // Block128 -> Block256
            let a_lifted: Block256 = a.into();
            let (lo, hi) = a_lifted.split();

            assert_eq!(lo, a, "Embedding structure failed: low part mismatch");
            assert_eq!(
                hi,
                Block128::ZERO,
                "Embedding structure failed: high part must be zero"
            );

            // 2. Addition Homomorphism
            let sum_sub = a + b;
            let sum_lifted: Block256 = sum_sub.into();
            let sum_in_super = Block256::from(a) + Block256::from(b);

            assert_eq!(sum_lifted, sum_in_super, "Homomorphism failed: add");

            // 3. Multiplication Homomorphism
            let prod_sub = a * b;
            let prod_lifted: Block256 = prod_sub.into();
            let prod_in_super = Block256::from(a) * Block256::from(b);

            assert_eq!(prod_lifted, prod_in_super, "Homomorphism failed: mul");
        }
    }

    // ==================================
    // HARDWARE
    // ==================================

    // Tower -> flat -> tower must be the identity.
    #[test]
    fn isomorphism_roundtrip() {
        let mut rng = rng();
        for _ in 0..1000 {
            let val = Block256([rng.random::<u128>(), rng.random::<u128>()]);
            assert_eq!(val.to_hardware().to_tower(), val);
        }
    }

    // Multiplying in the flat basis must agree with multiplying in the tower
    // basis and converting the product.
    #[test]
    fn flat_mul_homomorphism() {
        let mut rng = rng();
        for _ in 0..1000 {
            let a = Block256([rng.random(), rng.random()]);
            let b = Block256([rng.random(), rng.random()]);

            let expected_flat = (a * b).to_hardware();
            let actual_flat = a.to_hardware() * b.to_hardware();

            assert_eq!(
                actual_flat, expected_flat,
                "Block256 flat multiplication mismatch: (a*b)^H != a^H * b^H"
            );
        }
    }

    // Packed add/mul must match the scalar operations lane by lane.
    #[test]
    fn packed_consistency() {
        let mut rng = rng();
        for _ in 0..100 {
            let mut a_vals = [Block256::ZERO; PACKED_WIDTH_256];
            let mut b_vals = [Block256::ZERO; PACKED_WIDTH_256];

            for i in 0..PACKED_WIDTH_256 {
                a_vals[i] = Block256([rng.random::<u128>(), rng.random::<u128>()]);
                b_vals[i] = Block256([rng.random::<u128>(), rng.random::<u128>()]);
            }

            let a_flat_vals = a_vals.map(|x| x.to_hardware());
            let b_flat_vals = b_vals.map(|x| x.to_hardware());
            let a_packed = Flat::<Block256>::pack(&a_flat_vals);
            let b_packed = Flat::<Block256>::pack(&b_flat_vals);

            let add_res = Block256::add_hardware_packed(a_packed, b_packed);

            let mut add_out = [Block256::ZERO.to_hardware(); PACKED_WIDTH_256];
            Flat::<Block256>::unpack(add_res, &mut add_out);

            for i in 0..PACKED_WIDTH_256 {
                assert_eq!(
                    add_out[i],
                    (a_vals[i] + b_vals[i]).to_hardware(),
                    "Block256 SIMD add mismatch at index {}",
                    i
                );
            }

            let mul_res = Block256::mul_hardware_packed(a_packed, b_packed);

            let mut mul_out = [Block256::ZERO.to_hardware(); PACKED_WIDTH_256];
            Flat::<Block256>::unpack(mul_res, &mut mul_out);

            for i in 0..PACKED_WIDTH_256 {
                let expected_flat = (a_vals[i] * b_vals[i]).to_hardware();
                assert_eq!(
                    mul_out[i], expected_flat,
                    "Block256 SIMD mul mismatch at index {}",
                    i
                );
            }
        }
    }

    // Every tower-basis bit read from the flat form must equal the bit of
    // the original tower-basis limbs.
    #[test]
    fn tower_bit_from_hardware_matches_tower() {
        let mut rng = rng();
        for _ in 0..64 {
            let val = Block256([rng.random::<u128>(), rng.random::<u128>()]);
            let flat = val.to_hardware();

            for bit in 0..Block256::BITS {
                let expected = if bit < 128 {
                    ((val.0[0] >> bit) & 1) as u8
                } else {
                    ((val.0[1] >> (bit - 128)) & 1) as u8
                };

                assert_eq!(
                    Block256::tower_bit_from_hardware(flat, bit),
                    expected,
                    "tower_bit mismatch at bit {}",
                    bit
                );
            }
        }
    }

    // ==================================
    // PROMOTE
    // ==================================

    // Batch promotion from `Block8` must agree with the scalar promotion.
    #[test]
    fn promote_flat_batch_matches_scalar_block8() {
        let mut rng = rng();
        let input: Vec<Flat<Block8>> = (0..200)
            .map(|_| Block8(rng.random::<u8>()).to_hardware())
            .collect();

        let mut batch_out = vec![Flat::from_raw(Block256::ZERO); input.len()];
        <Block256 as FlatPromote<Block8>>::promote_flat_batch(&input, &mut batch_out);

        for i in 0..input.len() {
            let scalar = <Block256 as FlatPromote<Block8>>::promote_flat(input[i]);
            assert_eq!(
                batch_out[i], scalar,
                "Block8 batch/scalar mismatch at {}",
                i
            );
        }
    }

    // Batch promotion from `Block16` must agree with the scalar promotion.
    #[test]
    fn promote_flat_batch_matches_scalar_block16() {
        let mut rng = rng();
        let input: Vec<Flat<Block16>> = (0..200)
            .map(|_| Block16(rng.random::<u16>()).to_hardware())
            .collect();

        let mut batch_out = vec![Flat::from_raw(Block256::ZERO); input.len()];
        <Block256 as FlatPromote<Block16>>::promote_flat_batch(&input, &mut batch_out);

        for i in 0..input.len() {
            let scalar = <Block256 as FlatPromote<Block16>>::promote_flat(input[i]);
            assert_eq!(
                batch_out[i], scalar,
                "Block16 batch/scalar mismatch at {}",
                i
            );
        }
    }

    // Batch promotion from `Block32` must agree with the scalar promotion.
    #[test]
    fn promote_flat_batch_matches_scalar_block32() {
        let mut rng = rng();
        let input: Vec<Flat<Block32>> = (0..200)
            .map(|_| Block32(rng.random::<u32>()).to_hardware())
            .collect();

        let mut batch_out = vec![Flat::from_raw(Block256::ZERO); input.len()];
        <Block256 as FlatPromote<Block32>>::promote_flat_batch(&input, &mut batch_out);

        for i in 0..input.len() {
            let scalar = <Block256 as FlatPromote<Block32>>::promote_flat(input[i]);
            assert_eq!(
                batch_out[i], scalar,
                "Block32 batch/scalar mismatch at {}",
                i
            );
        }
    }

    // Batch promotion from `Block64` must agree with the scalar promotion.
    #[test]
    fn promote_flat_batch_matches_scalar_block64() {
        let mut rng = rng();
        let input: Vec<Flat<Block64>> = (0..200)
            .map(|_| Block64(rng.random::<u64>()).to_hardware())
            .collect();

        let mut batch_out = vec![Flat::from_raw(Block256::ZERO); input.len()];
        <Block256 as FlatPromote<Block64>>::promote_flat_batch(&input, &mut batch_out);

        for i in 0..input.len() {
            let scalar = <Block256 as FlatPromote<Block64>>::promote_flat(input[i]);
            assert_eq!(
                batch_out[i], scalar,
                "Block64 batch/scalar mismatch at {}",
                i
            );
        }
    }

    // Batch promotion from `Block128` must agree with the scalar promotion.
    #[test]
    fn promote_flat_batch_matches_scalar_block128() {
        let mut rng = rng();
        let input: Vec<Flat<Block128>> = (0..200)
            .map(|_| Block128(rng.random::<u128>()).to_hardware())
            .collect();

        let mut batch_out = vec![Flat::from_raw(Block256::ZERO); input.len()];
        <Block256 as FlatPromote<Block128>>::promote_flat_batch(&input, &mut batch_out);

        for i in 0..input.len() {
            let scalar = <Block256 as FlatPromote<Block128>>::promote_flat(input[i]);
            assert_eq!(
                batch_out[i], scalar,
                "Block128 batch/scalar mismatch at {}",
                i
            );
        }
    }

    // Mismatched input/output lengths: only the common prefix is written and
    // the rest of a longer output stays untouched.
    #[test]
    fn promote_flat_batch_partial_slice() {
        let mut rng = rng();
        let input: Vec<Flat<Block8>> = (0..10)
            .map(|_| Block8(rng.random::<u8>()).to_hardware())
            .collect();

        let mut out_short = vec![Flat::from_raw(Block256::ZERO); 5];
        <Block256 as FlatPromote<Block8>>::promote_flat_batch(&input, &mut out_short);

        for i in 0..5 {
            let scalar = <Block256 as FlatPromote<Block8>>::promote_flat(input[i]);
            assert_eq!(out_short[i], scalar);
        }

        let short_input = &input[..3];

        let mut out_long = vec![Flat::from_raw(Block256::ZERO); 10];
        <Block256 as FlatPromote<Block8>>::promote_flat_batch(short_input, &mut out_long);

        for i in 0..3 {
            let scalar = <Block256 as FlatPromote<Block8>>::promote_flat(short_input[i]);
            assert_eq!(out_long[i], scalar);
        }

        for val in out_long.iter().skip(3) {
            assert_eq!(*val, Flat::from_raw(Block256::ZERO));
        }
    }

    #[test]
    fn promote_flat_batch_across_chunk_boundary() {
        let mut rng = rng();
        // Exercise lengths straddling PROMOTE_CHUNK.
        for &n in &[
            PROMOTE_CHUNK - 1,
            PROMOTE_CHUNK,
            PROMOTE_CHUNK + 1,
            PROMOTE_CHUNK * 2 + 3,
        ] {
            let input: Vec<Flat<Block8>> = (0..n)
                .map(|_| Block8(rng.random::<u8>()).to_hardware())
                .collect();

            let mut batch_out = vec![Flat::from_raw(Block256::ZERO); n];
            <Block256 as FlatPromote<Block8>>::promote_flat_batch(&input, &mut batch_out);

            for i in 0..n {
                let scalar = <Block256 as FlatPromote<Block8>>::promote_flat(input[i]);
                assert_eq!(batch_out[i], scalar, "n={}, idx={}", n, i);
            }
        }
    }
}