//! hekate_math/towers/bit.rs — `Bit`: the GF(2) base level of the binary tower.

1// SPDX-License-Identifier: Apache-2.0
2// This file is part of the hekate-math project.
3// Copyright (C) 2026 Andrei Kochergin <zeek@tuta.com>
4// Copyright (C) 2026 Oumuamua Labs. All rights reserved.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10//     http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18use crate::{
19    Block8, CanonicalDeserialize, CanonicalSerialize, Flat, FlatPromote, HardwareField,
20    PackableField, PackedFlat, TowerField,
21};
22use core::ops::{Add, AddAssign, BitAnd, BitXor, Mul, MulAssign, Sub, SubAssign};
23use serde::{Deserialize, Serialize};
24use zeroize::Zeroize;
25
26// ==================================
27// BIT (GF(2))
28// ==================================
29
/// A single GF(2) element stored in the low bit of a byte.
///
/// Invariant: the wrapped value is always 0 or 1. The field is `pub`,
/// so `Bit(v)` with `v > 1` can bypass the invariant; prefer
/// [`Bit::new`], which masks the input to its least-significant bit.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Serialize, Deserialize, Zeroize)]
#[repr(transparent)]
pub struct Bit(pub u8);
33
impl Bit {
    /// Builds a `Bit` from an arbitrary byte, keeping only the
    /// least-significant bit so the 0/1 invariant always holds.
    pub const fn new(val: u8) -> Self {
        Self(val & 1)
    }
}
39
40impl TowerField for Bit {
41    const BITS: usize = 1;
42    const ZERO: Self = Bit(0);
43    const ONE: Self = Bit(1);
44
45    // x^2 + x + 1 = 0 -> Irreducible over GF(2)
46    const EXTENSION_TAU: Self = Bit(1);
47
48    fn invert(&self) -> Self {
49        // In GF(2), the inverse of 1 is 1.
50        // By cryptographic convention, the
51        // inverse of 0 is defined as 0.
52        // Thus, inversion in GF(2) is
53        // just the identity function.
54        *self
55    }
56
57    fn from_uniform_bytes(bytes: &[u8; 32]) -> Self {
58        // Take LSB of first byte
59        Self(bytes[0] & 1)
60    }
61}
62
63/// Add (XOR)
64/// 0+0=0, 0+1=1, 1+0=1, 1+1=0
65impl Add for Bit {
66    type Output = Self;
67
68    fn add(self, rhs: Self) -> Self::Output {
69        Self(self.0.bitxor(rhs.0))
70    }
71}
72
73/// Sub is the same as add
74impl Sub for Bit {
75    type Output = Self;
76
77    fn sub(self, rhs: Self) -> Self::Output {
78        self.add(rhs)
79    }
80}
81
82/// Mul (AND)
83/// 0*0=0, 0*1=0, 1*1=1
84impl Mul for Bit {
85    type Output = Self;
86
87    fn mul(self, rhs: Self) -> Self::Output {
88        Self(self.0.bitand(rhs.0))
89    }
90}
91
92impl AddAssign for Bit {
93    fn add_assign(&mut self, rhs: Self) {
94        *self = *self + rhs
95    }
96}
97
98impl SubAssign for Bit {
99    fn sub_assign(&mut self, rhs: Self) {
100        *self = *self - rhs
101    }
102}
103
104impl MulAssign for Bit {
105    fn mul_assign(&mut self, rhs: Self) {
106        *self = *self * rhs;
107    }
108}
109
110impl CanonicalSerialize for Bit {
111    #[inline]
112    fn serialized_size(&self) -> usize {
113        1
114    }
115
116    #[inline]
117    fn serialize(&self, writer: &mut [u8]) -> Result<(), ()> {
118        if writer.is_empty() {
119            return Err(());
120        }
121
122        writer[0] = self.0;
123
124        Ok(())
125    }
126}
127
128impl CanonicalDeserialize for Bit {
129    fn deserialize(bytes: &[u8]) -> Result<Self, ()> {
130        if bytes.is_empty() {
131            return Err(());
132        }
133
134        if bytes[0] > 1 {
135            return Err(());
136        }
137
138        Ok(Self(bytes[0]))
139    }
140}
141
142impl From<u8> for Bit {
143    #[inline]
144    fn from(val: u8) -> Self {
145        Self(val & 1)
146    }
147}
148
149impl From<u32> for Bit {
150    #[inline]
151    fn from(val: u32) -> Self {
152        Self((val & 1) as u8)
153    }
154}
155
156impl From<u64> for Bit {
157    #[inline]
158    fn from(val: u64) -> Self {
159        Self((val & 1) as u8)
160    }
161}
162
163impl From<u128> for Bit {
164    #[inline]
165    fn from(val: u128) -> Self {
166        Self((val & 1) as u8)
167    }
168}
169
170// ===================================
171// PACKED BIT (Width = 64)
172// ===================================
173
// 64 bytes = 512 bits = 4 SIMD registers (128-bit each)
pub const PACKED_WIDTH_BIT: usize = 64;

/// 64 `Bit` lanes stored one byte per bit, aligned to a 64-byte
/// boundary so the whole value maps onto four 128-bit NEON registers.
#[repr(C, align(64))]
pub struct PackedBit(pub [Bit; PACKED_WIDTH_BIT]);
179
impl Clone for PackedBit {
    // Clone is a plain bitwise copy via the `Copy` impl below.
    // NOTE(review): written manually rather than derived — presumably
    // to keep derive machinery off the large aligned array; confirm.
    #[inline(always)]
    fn clone(&self) -> Self {
        *self
    }
}

impl Copy for PackedBit {}
188
189impl Default for PackedBit {
190    #[inline(always)]
191    fn default() -> Self {
192        Self::zero()
193    }
194}
195
196impl PartialEq for PackedBit {
197    fn eq(&self, other: &Self) -> bool {
198        // Bit(u8) is transparent, direct slice
199        // comparison works and is fast.
200        self.0[..] == other.0[..]
201    }
202}
203
204impl Eq for PackedBit {}
205
impl core::fmt::Debug for PackedBit {
    // Intentionally terse: only the width is printed; the 64 lane
    // values are omitted to keep debug output readable.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "PackedBit([size={}])", PACKED_WIDTH_BIT)
    }
}
211
212impl PackedBit {
213    #[inline(always)]
214    pub fn zero() -> Self {
215        Self([Bit::ZERO; PACKED_WIDTH_BIT])
216    }
217}
218
219impl PackableField for Bit {
220    type Packed = PackedBit;
221
222    const WIDTH: usize = PACKED_WIDTH_BIT;
223
224    #[inline(always)]
225    fn pack(chunk: &[Self]) -> Self::Packed {
226        assert!(
227            chunk.len() >= PACKED_WIDTH_BIT,
228            "PackableField::pack: input slice too short",
229        );
230
231        let mut arr = [Self::ZERO; PACKED_WIDTH_BIT];
232        arr.copy_from_slice(&chunk[..PACKED_WIDTH_BIT]);
233
234        PackedBit(arr)
235    }
236
237    #[inline(always)]
238    fn unpack(packed: Self::Packed, output: &mut [Self]) {
239        assert!(
240            output.len() >= PACKED_WIDTH_BIT,
241            "PackableField::unpack: output slice too short",
242        );
243
244        output[..PACKED_WIDTH_BIT].copy_from_slice(&packed.0);
245    }
246}
247
248impl Add for PackedBit {
249    type Output = Self;
250
251    #[inline(always)]
252    fn add(self, rhs: Self) -> Self {
253        #[cfg(target_arch = "aarch64")]
254        {
255            neon::add_packed_bit(self, rhs)
256        }
257
258        #[cfg(not(target_arch = "aarch64"))]
259        {
260            let mut res = [Bit::ZERO; PACKED_WIDTH_BIT];
261            for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
262                *out = *l + *r;
263            }
264
265            Self(res)
266        }
267    }
268}
269
270impl AddAssign for PackedBit {
271    #[inline(always)]
272    fn add_assign(&mut self, rhs: Self) {
273        for (l, r) in self.0.iter_mut().zip(rhs.0.iter()) {
274            *l += *r;
275        }
276    }
277}
278
279impl Sub for PackedBit {
280    type Output = Self;
281
282    #[inline(always)]
283    fn sub(self, rhs: Self) -> Self {
284        self.add(rhs)
285    }
286}
287
288impl SubAssign for PackedBit {
289    #[inline(always)]
290    fn sub_assign(&mut self, rhs: Self) {
291        self.add_assign(rhs)
292    }
293}
294
295impl Mul for PackedBit {
296    type Output = Self;
297
298    #[inline(always)]
299    fn mul(self, rhs: Self) -> Self {
300        #[cfg(target_arch = "aarch64")]
301        {
302            neon::mul_packed_bit(self, rhs)
303        }
304
305        #[cfg(not(target_arch = "aarch64"))]
306        {
307            let mut res = [Bit::ZERO; PACKED_WIDTH_BIT];
308            for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
309                *out = *l * *r;
310            }
311
312            Self(res)
313        }
314    }
315}
316
317impl MulAssign for PackedBit {
318    #[inline(always)]
319    fn mul_assign(&mut self, rhs: Self) {
320        *self = *self * rhs;
321    }
322}
323
324impl Mul<Bit> for PackedBit {
325    type Output = Self;
326
327    #[inline(always)]
328    fn mul(self, rhs: Bit) -> Self {
329        let mut res = [Bit::ZERO; PACKED_WIDTH_BIT];
330        for (out, v) in res.iter_mut().zip(self.0.iter()) {
331            *out = *v * rhs;
332        }
333
334        Self(res)
335    }
336}
337
338// ===================================
339// Hardware Field
340// ===================================
341
impl HardwareField for Bit {
    /// Tower -> hardware basis. For GF(2) the two bases coincide,
    /// so this is just a wrap into `Flat`.
    #[inline(always)]
    fn to_hardware(self) -> Flat<Self> {
        Flat::from_raw(self)
    }

    /// Hardware -> tower basis: the inverse unwrap (identity for GF(2)).
    #[inline(always)]
    fn from_hardware(value: Flat<Self>) -> Self {
        value.into_raw()
    }

    /// Scalar addition in the hardware basis.
    #[inline(always)]
    fn add_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
        let lhs = lhs.into_raw();
        let rhs = rhs.into_raw();

        // Hardware addition for bits is XOR
        Flat::from_raw(Self(lhs.0 ^ rhs.0))
    }

    /// Packed addition: defers to `Add for PackedBit` (SIMD on aarch64).
    #[inline(always)]
    fn add_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
        PackedFlat::from_raw(lhs.into_raw() + rhs.into_raw())
    }

    /// Scalar multiplication in the hardware basis.
    #[inline(always)]
    fn mul_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
        let lhs = lhs.into_raw();
        let rhs = rhs.into_raw();

        // Hardware multiplication for bits is AND
        Flat::from_raw(Self(lhs.0 & rhs.0))
    }

    /// Packed multiplication: defers to `Mul for PackedBit` (SIMD on aarch64).
    #[inline(always)]
    fn mul_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
        PackedFlat::from_raw(lhs.into_raw() * rhs.into_raw())
    }

    /// Multiplies every lane of `lhs` by the scalar `rhs` by
    /// broadcasting `rhs` into a full packed vector first.
    #[inline(always)]
    fn mul_hardware_scalar_packed(lhs: PackedFlat<Self>, rhs: Flat<Self>) -> PackedFlat<Self> {
        let broadcasted = PackedBit([rhs.into_raw(); PACKED_WIDTH_BIT]);
        Self::mul_hardware_packed(lhs, PackedFlat::from_raw(broadcasted))
    }

    /// Extracts tower-basis bit `bit_idx` from a hardware value.
    /// `Bit` has a single bit, so only index 0 is valid (panics otherwise).
    #[inline(always)]
    fn tower_bit_from_hardware(value: Flat<Self>, bit_idx: usize) -> u8 {
        assert_eq!(bit_idx, 0, "bit index out of bounds for Bit");

        // In GF(2), Tower and Flat
        // bases are identical.
        value.into_raw().0
    }
}
396
impl FlatPromote<Block8> for Bit {
    /// Maps a flat `Block8` value into a flat `Bit` by keeping its
    /// least-significant bit; the remaining bits are discarded.
    #[inline(always)]
    fn promote_flat(val: Flat<Block8>) -> Flat<Self> {
        // Take LSB
        Flat::from_raw(Bit(val.into_raw().0 & 1))
    }
}
404
405// ===========================================
406// SIMD INSTRUCTIONS
407// ===========================================
408
#[cfg(target_arch = "aarch64")]
mod neon {
    use super::*;
    use core::arch::aarch64::*;
    use core::mem::transmute;

    /// XOR for 64 bits (represented as bytes).
    /// Uses 4 NEON registers.
    #[inline(always)]
    pub fn add_packed_bit(lhs: PackedBit, rhs: PackedBit) -> PackedBit {
        // SAFETY: `Bit` is `#[repr(transparent)]` over `u8`, so
        // `[Bit; 64]` and `[uint8x16_t; 4]` are both exactly 64 bytes
        // of plain bytes; transmuting by value in either direction is
        // sound, and every byte pattern is valid for both types. NEON
        // is a baseline aarch64 feature, so the intrinsics are
        // available under this cfg.
        unsafe {
            // Cast [Bit; 64] -> [uint8x16_t; 4]
            let l: [uint8x16_t; 4] = transmute::<[Bit; PACKED_WIDTH_BIT], [uint8x16_t; 4]>(lhs.0);
            let r: [uint8x16_t; 4] = transmute::<[Bit; PACKED_WIDTH_BIT], [uint8x16_t; 4]>(rhs.0);

            let res = [
                veorq_u8(l[0], r[0]),
                veorq_u8(l[1], r[1]),
                veorq_u8(l[2], r[2]),
                veorq_u8(l[3], r[3]),
            ];

            PackedBit(transmute::<[uint8x16_t; 4], [Bit; PACKED_WIDTH_BIT]>(res))
        }
    }

    /// AND for 64 bits (represented as bytes).
    /// Uses 4 NEON registers.
    #[inline(always)]
    pub fn mul_packed_bit(lhs: PackedBit, rhs: PackedBit) -> PackedBit {
        // SAFETY: same layout argument as `add_packed_bit` — both
        // array types are 64 identically-valid bytes, and NEON is
        // always present on aarch64.
        unsafe {
            let l: [uint8x16_t; 4] = transmute::<[Bit; PACKED_WIDTH_BIT], [uint8x16_t; 4]>(lhs.0);
            let r: [uint8x16_t; 4] = transmute::<[Bit; PACKED_WIDTH_BIT], [uint8x16_t; 4]>(rhs.0);

            let res = [
                vandq_u8(l[0], r[0]),
                vandq_u8(l[1], r[1]),
                vandq_u8(l[2], r[2]),
                vandq_u8(l[3], r[3]),
            ];

            PackedBit(transmute::<[uint8x16_t; 4], [Bit; PACKED_WIDTH_BIT]>(res))
        }
    }
}
454
// Unit tests: truth tables, zeroization, inversion, the tower/hardware
// isomorphism, and scalar-vs-packed (SIMD) consistency.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{RngExt, rng};

    // ==================================
    // BASIC
    // ==================================

    /// Exhaustive XOR truth table for GF(2) addition.
    #[test]
    fn add_truth() {
        let zero = Bit::ZERO;
        let one = Bit::ONE;

        assert_eq!(zero + zero, zero);
        assert_eq!(zero + one, one);
        assert_eq!(one + zero, one);
        assert_eq!(one + one, zero);
    }

    /// AND truth table for GF(2) multiplication.
    #[test]
    fn mul_truth() {
        let zero = Bit::ZERO;
        let one = Bit::ONE;

        assert_eq!(zero * zero, zero);
        assert_eq!(zero * one, zero);
        assert_eq!(one * one, one);
    }

    /// `Zeroize` must clear the stored byte.
    #[test]
    fn security_zeroize() {
        // Setup sensitive bit (1)
        let mut secret_bit = Bit::ONE;
        assert_eq!(secret_bit.0, 1);

        // Nuke it
        secret_bit.zeroize();

        // Verify
        assert_eq!(secret_bit, Bit::ZERO);
        assert_eq!(secret_bit.0, 0, "Bit memory leak detected");
    }

    /// Inversion is the identity on GF(2) (0 -> 0 by convention).
    #[test]
    fn invert_truth() {
        // In GF(2):
        // invert(1) = 1
        // invert(0) = 0 (by convention)

        let one = Bit::ONE;
        let zero = Bit::ZERO;

        assert_eq!(one.invert(), Bit::ONE, "Inversion of 1 must be 1");
        assert_eq!(zero.invert(), Bit::ZERO, "Inversion of 0 must be 0");
    }

    // ==================================
    // HARDWARE
    // ==================================

    /// Tower -> hardware -> tower must be the identity.
    #[test]
    fn isomorphism_roundtrip() {
        let mut rng = rng();
        for _ in 0..100 {
            // Generate random bit (0 or 1)
            let val = Bit::new(rng.random::<u8>());

            // Roundtrip: Tower -> Hardware -> Tower must be identity.
            // For Bit, this is trivial (identity),
            // but we verify the trait contract.
            assert_eq!(
                val.to_hardware().to_tower(),
                val,
                "Bit isomorphism roundtrip failed"
            );
        }
    }

    /// Multiplication must commute with the tower/hardware map.
    #[test]
    fn flat_mul_homomorphism() {
        let mut rng = rng();
        for _ in 0..100 {
            let a = Bit::new(rng.random::<u8>());
            let b = Bit::new(rng.random::<u8>());

            let expected_flat = (a * b).to_hardware();
            let actual_flat = a.to_hardware() * b.to_hardware();

            // Check if multiplication in Flat basis matches Tower
            assert_eq!(
                actual_flat, expected_flat,
                "Bit flat multiplication mismatch"
            );
        }
    }

    /// Packed hardware add/mul must agree lane-by-lane with scalar ops.
    #[test]
    fn packed_consistency() {
        let mut rng = rng();
        for _ in 0..100 {
            // PACKED_WIDTH_BIT = 64
            let mut a_vals = [Bit::ZERO; 64];
            let mut b_vals = [Bit::ZERO; 64];

            for i in 0..64 {
                a_vals[i] = Bit::new(rng.random::<u8>());
                b_vals[i] = Bit::new(rng.random::<u8>());
            }

            let a_flat_vals = a_vals.map(|x| x.to_hardware());
            let b_flat_vals = b_vals.map(|x| x.to_hardware());
            let a_packed = Flat::<Bit>::pack(&a_flat_vals);
            let b_packed = Flat::<Bit>::pack(&b_flat_vals);

            // 1. Test SIMD Add (XOR)
            let add_res = Bit::add_hardware_packed(a_packed, b_packed);

            let mut add_out = [Bit::ZERO.to_hardware(); 64];
            Flat::<Bit>::unpack(add_res, &mut add_out);

            for i in 0..64 {
                assert_eq!(
                    add_out[i],
                    (a_vals[i] + b_vals[i]).to_hardware(),
                    "Bit packed add mismatch at index {}",
                    i
                );
            }

            // 2. Test SIMD Mul (AND)
            let mul_res = Bit::mul_hardware_packed(a_packed, b_packed);

            let mut mul_out = [Bit::ZERO.to_hardware(); 64];
            Flat::<Bit>::unpack(mul_res, &mut mul_out);

            for i in 0..64 {
                assert_eq!(
                    mul_out[i],
                    (a_vals[i] * b_vals[i]).to_hardware(),
                    "Bit packed mul mismatch at index {}",
                    i
                );
            }
        }
    }

    // ==================================
    // PACKED
    // ==================================

    /// pack followed by unpack must reproduce the input lanes.
    #[test]
    fn pack_unpack_roundtrip() {
        let mut rng = rng();
        // Width is 64
        let mut data = [Bit::ZERO; PACKED_WIDTH_BIT];

        for v in data.iter_mut() {
            *v = Bit::new(rng.random());
        }

        let packed = Bit::pack(&data);
        let mut unpacked = [Bit::ZERO; PACKED_WIDTH_BIT];
        Bit::unpack(packed, &mut unpacked);

        assert_eq!(data, unpacked, "Bit pack/unpack roundtrip failed");
    }

    /// Packed `+` must match scalar `Bit` addition lane-by-lane.
    #[test]
    fn packed_add_consistency() {
        let mut rng = rng();
        let mut a_vals = [Bit::ZERO; PACKED_WIDTH_BIT];
        let mut b_vals = [Bit::ZERO; PACKED_WIDTH_BIT];

        for i in 0..PACKED_WIDTH_BIT {
            a_vals[i] = Bit::new(rng.random());
            b_vals[i] = Bit::new(rng.random());
        }

        let a_packed = Bit::pack(&a_vals);
        let b_packed = Bit::pack(&b_vals);

        // Uses the SIMD add impl (which uses aarch64::add_packed_bit)
        let res_packed = a_packed + b_packed;

        let mut res_unpacked = [Bit::ZERO; PACKED_WIDTH_BIT];
        Bit::unpack(res_packed, &mut res_unpacked);

        for i in 0..PACKED_WIDTH_BIT {
            assert_eq!(
                res_unpacked[i],
                a_vals[i] + b_vals[i], // Regular Bit add (XOR)
                "Bit packed add mismatch"
            );
        }
    }

    /// Packed `*` must match scalar `Bit` multiplication lane-by-lane.
    #[test]
    fn packed_mul_consistency() {
        let mut rng = rng();

        for _ in 0..100 {
            let mut a_arr = [Bit::ZERO; PACKED_WIDTH_BIT];
            let mut b_arr = [Bit::ZERO; PACKED_WIDTH_BIT];

            for i in 0..PACKED_WIDTH_BIT {
                a_arr[i] = Bit::new(rng.random());
                b_arr[i] = Bit::new(rng.random());
            }

            let a_packed = PackedBit(a_arr); // Using constructor directly or pack
            let b_packed = PackedBit(b_arr);

            // Uses the SIMD mul impl (which uses aarch64::mul_packed_bit)
            let c_packed = a_packed * b_packed;

            let mut c_expected = [Bit::ZERO; PACKED_WIDTH_BIT];
            for i in 0..PACKED_WIDTH_BIT {
                c_expected[i] = a_arr[i] * b_arr[i]; // Regular Bit mul (AND)
            }

            assert_eq!(c_packed.0, c_expected, "Bit packed mul mismatch");
        }
    }
}
679}