1use crate::{
19 Block8, CanonicalDeserialize, CanonicalSerialize, Flat, FlatPromote, HardwareField,
20 PackableField, PackedFlat, TowerField,
21};
22use core::ops::{Add, AddAssign, BitAnd, BitXor, Mul, MulAssign, Sub, SubAssign};
23use serde::{Deserialize, Serialize};
24use zeroize::Zeroize;
25
26#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Serialize, Deserialize, Zeroize)]
31#[repr(transparent)]
32pub struct Bit(pub u8);
33
34impl Bit {
35 pub const fn new(val: u8) -> Self {
36 Self(val & 1) }
38}
39
40impl TowerField for Bit {
41 const BITS: usize = 1;
42 const ZERO: Self = Bit(0);
43 const ONE: Self = Bit(1);
44
45 const EXTENSION_TAU: Self = Bit(1);
47
48 fn invert(&self) -> Self {
49 *self
55 }
56
57 fn from_uniform_bytes(bytes: &[u8; 32]) -> Self {
58 Self(bytes[0] & 1)
60 }
61}
62
63impl Add for Bit {
66 type Output = Self;
67
68 fn add(self, rhs: Self) -> Self::Output {
69 Self(self.0.bitxor(rhs.0))
70 }
71}
72
73impl Sub for Bit {
75 type Output = Self;
76
77 fn sub(self, rhs: Self) -> Self::Output {
78 self.add(rhs)
79 }
80}
81
82impl Mul for Bit {
85 type Output = Self;
86
87 fn mul(self, rhs: Self) -> Self::Output {
88 Self(self.0.bitand(rhs.0))
89 }
90}
91
92impl AddAssign for Bit {
93 fn add_assign(&mut self, rhs: Self) {
94 *self = *self + rhs
95 }
96}
97
98impl SubAssign for Bit {
99 fn sub_assign(&mut self, rhs: Self) {
100 *self = *self - rhs
101 }
102}
103
104impl MulAssign for Bit {
105 fn mul_assign(&mut self, rhs: Self) {
106 *self = *self * rhs;
107 }
108}
109
110impl CanonicalSerialize for Bit {
111 #[inline]
112 fn serialized_size(&self) -> usize {
113 1
114 }
115
116 #[inline]
117 fn serialize(&self, writer: &mut [u8]) -> Result<(), ()> {
118 if writer.is_empty() {
119 return Err(());
120 }
121
122 writer[0] = self.0;
123
124 Ok(())
125 }
126}
127
128impl CanonicalDeserialize for Bit {
129 fn deserialize(bytes: &[u8]) -> Result<Self, ()> {
130 if bytes.is_empty() {
131 return Err(());
132 }
133
134 if bytes[0] > 1 {
135 return Err(());
136 }
137
138 Ok(Self(bytes[0]))
139 }
140}
141
142impl From<u8> for Bit {
143 #[inline]
144 fn from(val: u8) -> Self {
145 Self(val & 1)
146 }
147}
148
149impl From<u32> for Bit {
150 #[inline]
151 fn from(val: u32) -> Self {
152 Self((val & 1) as u8)
153 }
154}
155
156impl From<u64> for Bit {
157 #[inline]
158 fn from(val: u64) -> Self {
159 Self((val & 1) as u8)
160 }
161}
162
163impl From<u128> for Bit {
164 #[inline]
165 fn from(val: u128) -> Self {
166 Self((val & 1) as u8)
167 }
168}
169
/// Number of `Bit` lanes held by one `PackedBit`.
///
/// The aarch64 NEON kernels in this file transmute a `PackedBit` to and from
/// exactly four 16-byte `uint8x16_t` vectors, so they must be updated if this
/// value ever changes.
pub const PACKED_WIDTH_BIT: usize = 64;

/// A fixed-width vector of `PACKED_WIDTH_BIT` GF(2) elements, one `Bit`
/// (one byte) per lane.
///
/// `#[repr(C, align(64))]` fixes the layout to the lane array aligned to a
/// 64-byte boundary. The NEON kernels rely on the lane array being exactly
/// 64 bytes when reinterpreting it as `[uint8x16_t; 4]`.
#[repr(C, align(64))]
pub struct PackedBit(pub [Bit; PACKED_WIDTH_BIT]);
179
180impl Clone for PackedBit {
181 #[inline(always)]
182 fn clone(&self) -> Self {
183 *self
184 }
185}
186
187impl Copy for PackedBit {}
188
189impl Default for PackedBit {
190 #[inline(always)]
191 fn default() -> Self {
192 Self::zero()
193 }
194}
195
196impl PartialEq for PackedBit {
197 fn eq(&self, other: &Self) -> bool {
198 self.0[..] == other.0[..]
201 }
202}
203
204impl Eq for PackedBit {}
205
206impl core::fmt::Debug for PackedBit {
207 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
208 write!(f, "PackedBit([size={}])", PACKED_WIDTH_BIT)
209 }
210}
211
212impl PackedBit {
213 #[inline(always)]
214 pub fn zero() -> Self {
215 Self([Bit::ZERO; PACKED_WIDTH_BIT])
216 }
217}
218
219impl PackableField for Bit {
220 type Packed = PackedBit;
221
222 const WIDTH: usize = PACKED_WIDTH_BIT;
223
224 #[inline(always)]
225 fn pack(chunk: &[Self]) -> Self::Packed {
226 assert!(
227 chunk.len() >= PACKED_WIDTH_BIT,
228 "PackableField::pack: input slice too short",
229 );
230
231 let mut arr = [Self::ZERO; PACKED_WIDTH_BIT];
232 arr.copy_from_slice(&chunk[..PACKED_WIDTH_BIT]);
233
234 PackedBit(arr)
235 }
236
237 #[inline(always)]
238 fn unpack(packed: Self::Packed, output: &mut [Self]) {
239 assert!(
240 output.len() >= PACKED_WIDTH_BIT,
241 "PackableField::unpack: output slice too short",
242 );
243
244 output[..PACKED_WIDTH_BIT].copy_from_slice(&packed.0);
245 }
246}
247
248impl Add for PackedBit {
249 type Output = Self;
250
251 #[inline(always)]
252 fn add(self, rhs: Self) -> Self {
253 #[cfg(target_arch = "aarch64")]
254 {
255 neon::add_packed_bit(self, rhs)
256 }
257
258 #[cfg(not(target_arch = "aarch64"))]
259 {
260 let mut res = [Bit::ZERO; PACKED_WIDTH_BIT];
261 for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
262 *out = *l + *r;
263 }
264
265 Self(res)
266 }
267 }
268}
269
270impl AddAssign for PackedBit {
271 #[inline(always)]
272 fn add_assign(&mut self, rhs: Self) {
273 for (l, r) in self.0.iter_mut().zip(rhs.0.iter()) {
274 *l += *r;
275 }
276 }
277}
278
279impl Sub for PackedBit {
280 type Output = Self;
281
282 #[inline(always)]
283 fn sub(self, rhs: Self) -> Self {
284 self.add(rhs)
285 }
286}
287
288impl SubAssign for PackedBit {
289 #[inline(always)]
290 fn sub_assign(&mut self, rhs: Self) {
291 self.add_assign(rhs)
292 }
293}
294
295impl Mul for PackedBit {
296 type Output = Self;
297
298 #[inline(always)]
299 fn mul(self, rhs: Self) -> Self {
300 #[cfg(target_arch = "aarch64")]
301 {
302 neon::mul_packed_bit(self, rhs)
303 }
304
305 #[cfg(not(target_arch = "aarch64"))]
306 {
307 let mut res = [Bit::ZERO; PACKED_WIDTH_BIT];
308 for ((out, l), r) in res.iter_mut().zip(self.0.iter()).zip(rhs.0.iter()) {
309 *out = *l * *r;
310 }
311
312 Self(res)
313 }
314 }
315}
316
317impl MulAssign for PackedBit {
318 #[inline(always)]
319 fn mul_assign(&mut self, rhs: Self) {
320 *self = *self * rhs;
321 }
322}
323
324impl Mul<Bit> for PackedBit {
325 type Output = Self;
326
327 #[inline(always)]
328 fn mul(self, rhs: Bit) -> Self {
329 let mut res = [Bit::ZERO; PACKED_WIDTH_BIT];
330 for (out, v) in res.iter_mut().zip(self.0.iter()) {
331 *out = *v * rhs;
332 }
333
334 Self(res)
335 }
336}
337
338impl HardwareField for Bit {
343 #[inline(always)]
344 fn to_hardware(self) -> Flat<Self> {
345 Flat::from_raw(self)
346 }
347
348 #[inline(always)]
349 fn from_hardware(value: Flat<Self>) -> Self {
350 value.into_raw()
351 }
352
353 #[inline(always)]
354 fn add_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
355 let lhs = lhs.into_raw();
356 let rhs = rhs.into_raw();
357
358 Flat::from_raw(Self(lhs.0 ^ rhs.0))
360 }
361
362 #[inline(always)]
363 fn add_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
364 PackedFlat::from_raw(lhs.into_raw() + rhs.into_raw())
365 }
366
367 #[inline(always)]
368 fn mul_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
369 let lhs = lhs.into_raw();
370 let rhs = rhs.into_raw();
371
372 Flat::from_raw(Self(lhs.0 & rhs.0))
374 }
375
376 #[inline(always)]
377 fn mul_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
378 PackedFlat::from_raw(lhs.into_raw() * rhs.into_raw())
379 }
380
381 #[inline(always)]
382 fn mul_hardware_scalar_packed(lhs: PackedFlat<Self>, rhs: Flat<Self>) -> PackedFlat<Self> {
383 let broadcasted = PackedBit([rhs.into_raw(); PACKED_WIDTH_BIT]);
384 Self::mul_hardware_packed(lhs, PackedFlat::from_raw(broadcasted))
385 }
386
387 #[inline(always)]
388 fn tower_bit_from_hardware(value: Flat<Self>, bit_idx: usize) -> u8 {
389 assert_eq!(bit_idx, 0, "bit index out of bounds for Bit");
390
391 value.into_raw().0
394 }
395}
396
397impl FlatPromote<Block8> for Bit {
398 #[inline(always)]
399 fn promote_flat(val: Flat<Block8>) -> Flat<Self> {
400 Flat::from_raw(Bit(val.into_raw().0 & 1))
402 }
403}
404
/// aarch64 NEON kernels for lane-wise `PackedBit` arithmetic.
///
/// Each kernel reinterprets the 64 one-byte lanes of a `PackedBit` as four
/// 128-bit `uint8x16_t` vectors, applies one bitwise intrinsic per vector,
/// and reinterprets the result back into lanes.
#[cfg(target_arch = "aarch64")]
mod neon {
    use super::*;
    use core::arch::aarch64::*;
    use core::mem::transmute;

    /// Lane-wise XOR (GF(2) addition) of two `PackedBit`s using `veorq_u8`.
    #[inline(always)]
    pub fn add_packed_bit(lhs: PackedBit, rhs: PackedBit) -> PackedBit {
        // SAFETY: `Bit` is `#[repr(transparent)]` over `u8`, so
        // `[Bit; PACKED_WIDTH_BIT]` is 64 plain bytes, the same size as
        // `[uint8x16_t; 4]`. `transmute` also rejects a size mismatch at
        // compile time. Every byte pattern is a valid `u8` and therefore a
        // valid `Bit`, so both directions of the transmute produce valid
        // values. XOR of bytes that are 0/1 stays 0/1, so canonical inputs
        // give canonical outputs.
        // NOTE(review): assumes the target enables the `neon` feature (the
        // default for standard aarch64 targets) so the intrinsics are usable.
        unsafe {
            let l: [uint8x16_t; 4] = transmute::<[Bit; PACKED_WIDTH_BIT], [uint8x16_t; 4]>(lhs.0);
            let r: [uint8x16_t; 4] = transmute::<[Bit; PACKED_WIDTH_BIT], [uint8x16_t; 4]>(rhs.0);

            // One XOR per 16-lane vector covers all 64 lanes.
            let res = [
                veorq_u8(l[0], r[0]),
                veorq_u8(l[1], r[1]),
                veorq_u8(l[2], r[2]),
                veorq_u8(l[3], r[3]),
            ];

            PackedBit(transmute::<[uint8x16_t; 4], [Bit; PACKED_WIDTH_BIT]>(res))
        }
    }

    /// Lane-wise AND (GF(2) multiplication) of two `PackedBit`s using `vandq_u8`.
    #[inline(always)]
    pub fn mul_packed_bit(lhs: PackedBit, rhs: PackedBit) -> PackedBit {
        // SAFETY: same layout argument as `add_packed_bit`. The 64-byte lane
        // array and `[uint8x16_t; 4]` have identical size, and every byte is a
        // valid `Bit`. AND of 0/1 bytes stays 0/1.
        unsafe {
            let l: [uint8x16_t; 4] = transmute::<[Bit; PACKED_WIDTH_BIT], [uint8x16_t; 4]>(lhs.0);
            let r: [uint8x16_t; 4] = transmute::<[Bit; PACKED_WIDTH_BIT], [uint8x16_t; 4]>(rhs.0);

            // One AND per 16-lane vector covers all 64 lanes.
            let res = [
                vandq_u8(l[0], r[0]),
                vandq_u8(l[1], r[1]),
                vandq_u8(l[2], r[2]),
                vandq_u8(l[3], r[3]),
            ];

            PackedBit(transmute::<[uint8x16_t; 4], [Bit; PACKED_WIDTH_BIT]>(res))
        }
    }
}
454
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{RngExt, rng};

    #[test]
    fn add_truth() {
        let zero = Bit::ZERO;
        let one = Bit::ONE;

        assert_eq!(zero + zero, zero);
        assert_eq!(zero + one, one);
        assert_eq!(one + zero, one);
        assert_eq!(one + one, zero);
    }

    /// Full GF(2) multiplication truth table (all four input pairs).
    #[test]
    fn mul_truth() {
        let zero = Bit::ZERO;
        let one = Bit::ONE;

        assert_eq!(zero * zero, zero);
        assert_eq!(zero * one, zero);
        assert_eq!(one * zero, zero);
        assert_eq!(one * one, one);
    }

    #[test]
    fn security_zeroize() {
        let mut secret_bit = Bit::ONE;
        assert_eq!(secret_bit.0, 1);

        secret_bit.zeroize();

        assert_eq!(secret_bit, Bit::ZERO);
        assert_eq!(secret_bit.0, 0, "Bit memory leak detected");
    }

    #[test]
    fn invert_truth() {
        let one = Bit::ONE;
        let zero = Bit::ZERO;

        assert_eq!(one.invert(), Bit::ONE, "Inversion of 1 must be 1");
        assert_eq!(zero.invert(), Bit::ZERO, "Inversion of 0 must be 0");
    }

    /// `Bit::new` and every `From` integer conversion keep only the lowest bit.
    #[test]
    fn constructors_mask_to_lowest_bit() {
        assert_eq!(Bit::new(0xFE), Bit::ZERO);
        assert_eq!(Bit::new(0xFF), Bit::ONE);
        assert_eq!(Bit::from(2u8), Bit::ZERO);
        assert_eq!(Bit::from(3u32), Bit::ONE);
        assert_eq!(Bit::from(u64::MAX), Bit::ONE);
        assert_eq!(Bit::from(1u128 << 100), Bit::ZERO);
    }

    /// Canonical encoding round-trips 0 and 1 and rejects malformed input.
    ///
    /// Calls are fully qualified because `Bit` also implements serde's
    /// `Serialize`/`Deserialize`, whose methods share these names.
    #[test]
    fn canonical_serialization_roundtrip() {
        for bit in [Bit::ZERO, Bit::ONE] {
            assert_eq!(bit.serialized_size(), 1);

            let mut buf = [0xAAu8; 1];
            <Bit as CanonicalSerialize>::serialize(&bit, &mut buf[..])
                .expect("a one-byte buffer must be accepted");
            assert_eq!(<Bit as CanonicalDeserialize>::deserialize(&buf[..]), Ok(bit));
        }

        let mut empty: [u8; 0] = [];
        assert!(<Bit as CanonicalSerialize>::serialize(&Bit::ONE, &mut empty[..]).is_err());

        assert!(<Bit as CanonicalDeserialize>::deserialize(&[]).is_err());
        assert!(<Bit as CanonicalDeserialize>::deserialize(&[2]).is_err());
        assert!(<Bit as CanonicalDeserialize>::deserialize(&[0xFF]).is_err());
    }

    #[test]
    fn isomorphism_roundtrip() {
        let mut rng = rng();
        for _ in 0..100 {
            let val = Bit::new(rng.random::<u8>());

            assert_eq!(
                val.to_hardware().to_tower(),
                val,
                "Bit isomorphism roundtrip failed"
            );
        }
    }

    #[test]
    fn flat_mul_homomorphism() {
        let mut rng = rng();
        for _ in 0..100 {
            let a = Bit::new(rng.random::<u8>());
            let b = Bit::new(rng.random::<u8>());

            let expected_flat = (a * b).to_hardware();
            let actual_flat = a.to_hardware() * b.to_hardware();

            assert_eq!(
                actual_flat, expected_flat,
                "Bit flat multiplication mismatch"
            );
        }
    }

    #[test]
    fn packed_consistency() {
        let mut rng = rng();
        for _ in 0..100 {
            let mut a_vals = [Bit::ZERO; 64];
            let mut b_vals = [Bit::ZERO; 64];

            for i in 0..64 {
                a_vals[i] = Bit::new(rng.random::<u8>());
                b_vals[i] = Bit::new(rng.random::<u8>());
            }

            let a_flat_vals = a_vals.map(|x| x.to_hardware());
            let b_flat_vals = b_vals.map(|x| x.to_hardware());
            let a_packed = Flat::<Bit>::pack(&a_flat_vals);
            let b_packed = Flat::<Bit>::pack(&b_flat_vals);

            let add_res = Bit::add_hardware_packed(a_packed, b_packed);

            let mut add_out = [Bit::ZERO.to_hardware(); 64];
            Flat::<Bit>::unpack(add_res, &mut add_out);

            for i in 0..64 {
                assert_eq!(
                    add_out[i],
                    (a_vals[i] + b_vals[i]).to_hardware(),
                    "Bit packed add mismatch at index {}",
                    i
                );
            }

            let mul_res = Bit::mul_hardware_packed(a_packed, b_packed);

            let mut mul_out = [Bit::ZERO.to_hardware(); 64];
            Flat::<Bit>::unpack(mul_res, &mut mul_out);

            for i in 0..64 {
                assert_eq!(
                    mul_out[i],
                    (a_vals[i] * b_vals[i]).to_hardware(),
                    "Bit packed mul mismatch at index {}",
                    i
                );
            }
        }
    }

    #[test]
    fn pack_unpack_roundtrip() {
        let mut rng = rng();
        let mut data = [Bit::ZERO; PACKED_WIDTH_BIT];

        for v in data.iter_mut() {
            *v = Bit::new(rng.random());
        }

        let packed = Bit::pack(&data);
        let mut unpacked = [Bit::ZERO; PACKED_WIDTH_BIT];
        Bit::unpack(packed, &mut unpacked);

        assert_eq!(data, unpacked, "Bit pack/unpack roundtrip failed");
    }

    #[test]
    fn packed_add_consistency() {
        let mut rng = rng();
        let mut a_vals = [Bit::ZERO; PACKED_WIDTH_BIT];
        let mut b_vals = [Bit::ZERO; PACKED_WIDTH_BIT];

        for i in 0..PACKED_WIDTH_BIT {
            a_vals[i] = Bit::new(rng.random());
            b_vals[i] = Bit::new(rng.random());
        }

        let a_packed = Bit::pack(&a_vals);
        let b_packed = Bit::pack(&b_vals);

        let res_packed = a_packed + b_packed;

        let mut res_unpacked = [Bit::ZERO; PACKED_WIDTH_BIT];
        Bit::unpack(res_packed, &mut res_unpacked);

        for i in 0..PACKED_WIDTH_BIT {
            assert_eq!(
                res_unpacked[i],
                a_vals[i] + b_vals[i],
                "Bit packed add mismatch"
            );
        }
    }

    #[test]
    fn packed_mul_consistency() {
        let mut rng = rng();

        for _ in 0..100 {
            let mut a_arr = [Bit::ZERO; PACKED_WIDTH_BIT];
            let mut b_arr = [Bit::ZERO; PACKED_WIDTH_BIT];

            for i in 0..PACKED_WIDTH_BIT {
                a_arr[i] = Bit::new(rng.random());
                b_arr[i] = Bit::new(rng.random());
            }

            let a_packed = PackedBit(a_arr);
            let b_packed = PackedBit(b_arr);

            let c_packed = a_packed * b_packed;

            let mut c_expected = [Bit::ZERO; PACKED_WIDTH_BIT];
            for i in 0..PACKED_WIDTH_BIT {
                c_expected[i] = a_arr[i] * b_arr[i];
            }

            assert_eq!(c_packed.0, c_expected, "Bit packed mul mismatch");
        }
    }

    /// `PackedBit * Bit` must match lane-by-lane scalar multiplication.
    #[test]
    fn packed_scalar_mul_consistency() {
        let mut rng = rng();
        let mut vals = [Bit::ZERO; PACKED_WIDTH_BIT];
        for v in vals.iter_mut() {
            *v = Bit::new(rng.random());
        }

        let packed = Bit::pack(&vals);
        for scalar in [Bit::ZERO, Bit::ONE] {
            let res = packed * scalar;
            for (i, (&got, &v)) in res.0.iter().zip(vals.iter()).enumerate() {
                assert_eq!(
                    got,
                    v * scalar,
                    "Bit packed scalar mul mismatch at index {}",
                    i
                );
            }
        }
    }
}