1use std::{
5 fmt,
6 ops::{Add, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use aes::cipher::{self, array::sizes};
10use bytemuck::{Pod, Zeroable};
11use rand::{Rng, distr::StandardUniform, prelude::Distribution};
12use serde::{Deserialize, Serialize};
13use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
14use thiserror::Error;
15use wide::{u8x16, u64x2};
16
// NOTE(review): `dead_code` is silenced for the whole `gf128` submodule,
// presumably because some of its items are not (yet) used by this crate —
// confirm, and narrow the allow to the specific items if possible.
#[allow(dead_code)]
mod gf128;
21
/// A 128-bit block, stored as a 16-lane `u8` SIMD vector (`wide::u8x16`).
///
/// `#[repr(transparent)]` gives it exactly the layout of `u8x16`, and the
/// `Pod`/`Zeroable` derives are what allow the `bytemuck` casts to and from
/// `[u64; 2]`, `[i64; 2]` and `__m128i` used in this file.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, Pod, Zeroable)]
#[repr(transparent)]
pub struct Block(u8x16);
29
30impl Block {
31 pub const ZERO: Self = Self(u8x16::ZERO);
33 pub const ONES: Self = Self(u8x16::MAX);
35 pub const ONE: Self = Self::new(1_u128.to_ne_bytes());
37 pub const MASK_LSB: Self = Self::pack(u64::MAX << 1, u64::MAX);
44
45 pub const BYTES: usize = 16;
47 pub const BITS: usize = 128;
49
50 #[inline]
52 pub const fn new(bytes: [u8; 16]) -> Self {
53 Self(u8x16::new(bytes))
54 }
55
56 #[inline]
58 pub const fn splat(byte: u8) -> Self {
59 Self::new([byte; 16])
60 }
61
62 #[inline]
67 pub const fn pack(low: u64, high: u64) -> Self {
68 let mut bytes = [0; 16];
69 let low = low.to_ne_bytes();
70 let mut i = 0;
71 while i < low.len() {
72 bytes[i] = low[i];
73 i += 1;
74 }
75
76 let high = high.to_ne_bytes();
77 let mut i = 0;
78 while i < high.len() {
79 bytes[i + 8] = high[i];
80 i += 1;
81 }
82
83 Self::new(bytes)
84 }
85
86 #[inline]
88 pub fn as_bytes(&self) -> &[u8; 16] {
89 self.0.as_array_ref()
90 }
91
92 #[inline]
94 pub fn as_mut_bytes(&mut self) -> &mut [u8; 16] {
95 self.0.as_array_mut()
96 }
97
98 #[inline]
100 pub fn ro_hash(&self) -> blake3::Hash {
101 blake3::hash(self.as_bytes())
102 }
103
104 #[inline]
109 pub fn from_choices(choices: &[Choice]) -> Self {
110 assert_eq!(128, choices.len(), "choices.len() must be 128");
111 let mut bytes = [0_u8; 16];
112 for (chunk, byte) in choices.chunks_exact(8).zip(&mut bytes) {
113 for (i, choice) in chunk.iter().enumerate() {
114 *byte ^= choice.unwrap_u8() << i;
115 }
116 }
117 Self::new(bytes)
118 }
119
120 #[inline]
122 pub fn low(&self) -> u64 {
123 let inner: &u64x2 = bytemuck::must_cast_ref(&self.0);
124 inner.as_array_ref()[0]
125 }
126
127 #[inline]
129 pub fn high(&self) -> u64 {
130 let inner: &u64x2 = bytemuck::must_cast_ref(&self.0);
131 inner.as_array_ref()[1]
132 }
133
134 #[inline]
136 pub fn lsb(&self) -> bool {
137 *self & Block::ONE == Block::ONE
138 }
139
140 #[inline]
142 pub fn bits(&self) -> impl Iterator<Item = bool> + use<> {
143 struct BitIter {
144 blk: Block,
145 idx: usize,
146 }
147 impl Iterator for BitIter {
148 type Item = bool;
149
150 #[inline]
151 fn next(&mut self) -> Option<Self::Item> {
152 if self.idx < Block::BITS {
153 self.idx += 1;
154 let bit = (self.blk >> (self.idx - 1)) & Block::ONE != Block::ZERO;
155 Some(bit)
156 } else {
157 None
158 }
159 }
160 }
161 BitIter { blk: *self, idx: 0 }
162 }
163}
164
165impl BitAnd for Block {
167 type Output = Self;
168
169 #[inline]
170 fn bitand(self, rhs: Self) -> Self {
171 Self(self.0 & rhs.0)
172 }
173}
174
175impl BitAndAssign for Block {
176 #[inline]
177 fn bitand_assign(&mut self, rhs: Self) {
178 *self = *self & rhs;
179 }
180}
181
182impl BitOr for Block {
183 type Output = Self;
184
185 #[inline]
186 fn bitor(self, rhs: Self) -> Self {
187 Self(self.0 | rhs.0)
188 }
189}
190
191impl BitOrAssign for Block {
192 #[inline]
193 fn bitor_assign(&mut self, rhs: Self) {
194 *self = *self | rhs;
195 }
196}
197
198impl BitXor for Block {
199 type Output = Self;
200
201 #[inline]
202 fn bitxor(self, rhs: Self) -> Self {
203 Self(self.0 ^ rhs.0)
204 }
205}
206
207impl BitXorAssign for Block {
208 #[inline]
209 fn bitxor_assign(&mut self, rhs: Self) {
210 *self = *self ^ rhs;
211 }
212}
213
214impl<Rhs> Shl<Rhs> for Block
215where
216 u128: Shl<Rhs, Output = u128>,
217{
218 type Output = Block;
219
220 #[inline]
221 fn shl(self, rhs: Rhs) -> Self::Output {
222 Self::from(u128::from(self) << rhs)
223 }
224}
225
226impl<Rhs> Shr<Rhs> for Block
227where
228 u128: Shr<Rhs, Output = u128>,
229{
230 type Output = Block;
231
232 #[inline]
233 fn shr(self, rhs: Rhs) -> Self::Output {
234 Self::from(u128::from(self) >> rhs)
235 }
236}
237
238impl Not for Block {
239 type Output = Self;
240
241 #[inline]
242 fn not(self) -> Self {
243 Self(!self.0)
244 }
245}
246
247impl PartialEq for Block {
248 fn eq(&self, other: &Self) -> bool {
249 let a: u128 = (*self).into();
250 let b: u128 = (*other).into();
251 a.ct_eq(&b).into()
252 }
253}
254
// `PartialEq::eq` compares every one of the 128 bits, so the relation is a
// full equivalence and `Eq` holds.
impl Eq for Block {}
256
257impl Distribution<Block> for StandardUniform {
258 #[inline]
259 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Block {
260 let mut bytes = [0; 16];
261 rng.fill_bytes(&mut bytes);
262 Block::new(bytes)
263 }
264}
265
266impl AsRef<[u8]> for Block {
267 fn as_ref(&self) -> &[u8] {
268 self.as_bytes()
269 }
270}
271
272impl AsMut<[u8]> for Block {
273 #[inline]
274 fn as_mut(&mut self) -> &mut [u8] {
275 self.as_mut_bytes()
276 }
277}
278
279impl From<Block> for cipher::Array<u8, sizes::U16> {
280 #[inline]
281 fn from(value: Block) -> Self {
282 Self(*value.as_bytes())
283 }
284}
285
286impl From<cipher::Array<u8, sizes::U16>> for Block {
287 #[inline]
288 fn from(value: cipher::Array<u8, sizes::U16>) -> Self {
289 Self::new(value.0)
290 }
291}
292
293impl From<[u8; 16]> for Block {
294 #[inline]
295 fn from(value: [u8; 16]) -> Self {
296 Self::new(value)
297 }
298}
299
300impl From<Block> for [u8; 16] {
301 fn from(value: Block) -> Self {
302 *value.as_bytes()
303 }
304}
305
306impl From<[i64; 2]> for Block {
307 #[inline]
308 fn from(value: [i64; 2]) -> Self {
309 bytemuck::must_cast(value)
310 }
311}
312
313impl From<Block> for [i64; 2] {
314 #[inline]
315 fn from(value: Block) -> Self {
316 bytemuck::must_cast(value)
317 }
318}
319
320impl From<[u64; 2]> for Block {
321 #[inline]
322 fn from(value: [u64; 2]) -> Self {
323 bytemuck::must_cast(value)
324 }
325}
326
327impl From<Block> for [u64; 2] {
328 #[inline]
329 fn from(value: Block) -> Self {
330 bytemuck::must_cast(value)
331 }
332}
333
334impl From<Block> for u128 {
335 #[inline]
336 fn from(value: Block) -> Self {
337 u128::from_ne_bytes(*value.as_bytes())
338 }
339}
340
341impl From<&Block> for u128 {
342 #[inline]
343 fn from(value: &Block) -> Self {
344 u128::from_ne_bytes(*value.as_bytes())
345 }
346}
347
348impl From<usize> for Block {
349 fn from(value: usize) -> Self {
350 (value as u128).into()
351 }
352}
353
354impl From<u128> for Block {
355 #[inline]
356 fn from(value: u128) -> Self {
357 Self::new(value.to_ne_bytes())
358 }
359}
360
361impl From<&u128> for Block {
362 #[inline]
363 fn from(value: &u128) -> Self {
364 Self::new(value.to_ne_bytes())
365 }
366}
367
/// Error returned by `Block::try_from(&[u8])` when the slice is not exactly
/// 16 bytes long.
#[derive(Debug, Error)]
#[error("slice must have length of 16")]
pub struct WrongLength;
371
372impl TryFrom<&[u8]> for Block {
373 type Error = WrongLength;
374
375 #[inline]
376 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
377 let arr = value.try_into().map_err(|_| WrongLength)?;
378 Ok(Self::new(arr))
379 }
380}
381
382#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
383mod from_arch_impls {
384 #[cfg(target_arch = "x86")]
385 use std::arch::x86::*;
386 #[cfg(target_arch = "x86_64")]
387 use std::arch::x86_64::*;
388
389 use super::Block;
390
391 impl From<__m128i> for Block {
392 #[inline]
393 fn from(value: __m128i) -> Self {
394 bytemuck::must_cast(value)
395 }
396 }
397
398 impl From<&__m128i> for Block {
399 #[inline]
400 fn from(value: &__m128i) -> Self {
401 bytemuck::must_cast(*value)
402 }
403 }
404
405 impl From<Block> for __m128i {
406 #[inline]
407 fn from(value: Block) -> Self {
408 bytemuck::must_cast(value)
409 }
410 }
411
412 impl From<&Block> for __m128i {
413 #[inline]
414 fn from(value: &Block) -> Self {
415 bytemuck::must_cast(*value)
416 }
417 }
418}
419
420impl ConditionallySelectable for Block {
421 #[inline]
422 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
424 let mask = Block::new((-(choice.unwrap_u8() as i128)).to_le_bytes());
427 *a ^ (mask & (*a ^ *b))
428 }
429}
430
431impl Add for Block {
432 type Output = Block;
433
434 #[inline]
435 fn add(self, rhs: Self) -> Self::Output {
436 let a: u128 = self.into();
438 let b: u128 = rhs.into();
439 Self::from(a.wrapping_add(b))
440 }
441}
442
443impl fmt::Binary for Block {
444 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
445 fmt::Binary::fmt(&u128::from(*self), f)
446 }
447}
448
#[cfg(test)]
mod tests {
    use subtle::{Choice, ConditionallySelectable};

    use super::Block;

    #[test]
    fn test_block_cond_select() {
        // Choice 0 keeps the first operand, choice 1 takes the second.
        for (flag, expected) in [(0_u8, Block::ZERO), (1_u8, Block::ONES)] {
            let picked = Block::conditional_select(&Block::ZERO, &Block::ONES, Choice::from(flag));
            assert_eq!(expected, picked);
        }
    }

    #[test]
    fn test_block_low_high() {
        let one = Block::from(1_u128);
        assert_eq!((one.low(), one.high()), (1, 0));
    }

    #[test]
    fn test_from_into_u64_arr() {
        let halves = [42_u64, 65];
        let blk = Block::from(halves);
        assert_eq!((blk.low(), blk.high()), (42, 65));
        assert_eq!(halves, <[u64; 2]>::from(blk));
    }

    #[test]
    fn test_pack() {
        let blk = Block::pack(42, 123);
        assert_eq!((blk.low(), blk.high()), (42, 123));
    }

    #[test]
    fn test_mask_lsb() {
        // Clearing bit 0 of the all-ones block must give MASK_LSB.
        let expected = Block::ONES ^ Block::ONE;
        assert_eq!(expected, Block::MASK_LSB);
    }

    #[test]
    fn test_bits() {
        let bits: Vec<bool> = Block::from(0b101_u128).bits().collect();
        assert_eq!(bits[..3], [true, false, true]);
        assert!(bits[3..].iter().all(|&bit| !bit));
    }

    #[test]
    fn test_from_choices() {
        let choices: Vec<Choice> = (0..128)
            .map(|i| Choice::from(u8::from(i == 2 || i == 16)))
            .collect();
        let blk = Block::from_choices(&choices);
        assert_eq!(Block::from((1_u128 << 2) | (1_u128 << 16)), blk);
    }
}