Skip to main content

polytune/
block.rs

1//! A 128-bit [`Block`] type.
2//!
3//! Operations on [`Block`]s will use SIMD instructions where possible.
4use std::{
5    fmt,
6    ops::{Add, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use aes::cipher::{self, array::sizes};
10use bytemuck::{Pod, Zeroable};
11use rand::{Rng, distr::StandardUniform, prelude::Distribution};
12use serde::{Deserialize, Serialize};
13use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
14use thiserror::Error;
15use wide::{u8x16, u64x2};
16
17// TODO remove this once OT implementations are refactored and we know
18// what parts we need and which not
19#[allow(dead_code)]
20mod gf128;
21
22/// A 128-bit block. Uses SIMD operations where available.
23///
24/// This type is publicly re-exported when the private `__bench` feature
25/// is enabled at [`crate::bench_reexports::Block`].
26#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, Pod, Zeroable)]
27#[repr(transparent)]
28pub struct Block(u8x16);
29
30impl Block {
31    /// All bits set to 0.
32    pub const ZERO: Self = Self(u8x16::ZERO);
33    /// All bits set to 1.
34    pub const ONES: Self = Self(u8x16::MAX);
35    /// Lsb set to 1, all others zero.
36    pub const ONE: Self = Self::new(1_u128.to_ne_bytes());
37    /// Mask to mask off the LSB of a Block.
38    /// ```rust,ignore
39    /// let b = Block::ONES;
40    /// let masked = b & Block::MASK_LSB;
41    /// assert_eq!(masked, Block::ONES << 1)
42    /// ```
43    pub const MASK_LSB: Self = Self::pack(u64::MAX << 1, u64::MAX);
44
45    /// 16 bytes in a Block.
46    pub const BYTES: usize = 16;
47    /// 128 bits in a block.
48    pub const BITS: usize = 128;
49
50    /// Create a new block from bytes.
51    #[inline]
52    pub const fn new(bytes: [u8; 16]) -> Self {
53        Self(u8x16::new(bytes))
54    }
55
56    /// Create a block with all bytes set to `byte`.
57    #[inline]
58    pub const fn splat(byte: u8) -> Self {
59        Self::new([byte; 16])
60    }
61
62    /// Pack two `u64` into a Block. Usable in const context.
63    ///
64    /// In non-const contexts, using `Block::from([low, high])` is likely
65    /// faster.
66    #[inline]
67    pub const fn pack(low: u64, high: u64) -> Self {
68        let mut bytes = [0; 16];
69        let low = low.to_ne_bytes();
70        let mut i = 0;
71        while i < low.len() {
72            bytes[i] = low[i];
73            i += 1;
74        }
75
76        let high = high.to_ne_bytes();
77        let mut i = 0;
78        while i < high.len() {
79            bytes[i + 8] = high[i];
80            i += 1;
81        }
82
83        Self::new(bytes)
84    }
85
86    /// Bytes of the block.
87    #[inline]
88    pub fn as_bytes(&self) -> &[u8; 16] {
89        self.0.as_array_ref()
90    }
91
92    /// Mutable bytes of the block.
93    #[inline]
94    pub fn as_mut_bytes(&mut self) -> &mut [u8; 16] {
95        self.0.as_array_mut()
96    }
97
98    /// Hash the block with a [`random_oracle`].
99    #[inline]
100    pub fn ro_hash(&self) -> blake3::Hash {
101        blake3::hash(self.as_bytes())
102    }
103
104    ///  Create a block from 128 [`Choice`]s.
105    ///
106    /// # Panics
107    /// If choices.len() != 128
108    #[inline]
109    pub fn from_choices(choices: &[Choice]) -> Self {
110        assert_eq!(128, choices.len(), "choices.len() must be 128");
111        let mut bytes = [0_u8; 16];
112        for (chunk, byte) in choices.chunks_exact(8).zip(&mut bytes) {
113            for (i, choice) in chunk.iter().enumerate() {
114                *byte ^= choice.unwrap_u8() << i;
115            }
116        }
117        Self::new(bytes)
118    }
119
120    /// Low 64 bits of the block.
121    #[inline]
122    pub fn low(&self) -> u64 {
123        let inner: &u64x2 = bytemuck::must_cast_ref(&self.0);
124        inner.as_array_ref()[0]
125    }
126
127    /// High 64 bits of the block.
128    #[inline]
129    pub fn high(&self) -> u64 {
130        let inner: &u64x2 = bytemuck::must_cast_ref(&self.0);
131        inner.as_array_ref()[1]
132    }
133
134    /// Least significant bit of the block
135    #[inline]
136    pub fn lsb(&self) -> bool {
137        *self & Block::ONE == Block::ONE
138    }
139
140    /// Iterator over bits of the Block.
141    #[inline]
142    pub fn bits(&self) -> impl Iterator<Item = bool> + use<> {
143        struct BitIter {
144            blk: Block,
145            idx: usize,
146        }
147        impl Iterator for BitIter {
148            type Item = bool;
149
150            #[inline]
151            fn next(&mut self) -> Option<Self::Item> {
152                if self.idx < Block::BITS {
153                    self.idx += 1;
154                    let bit = (self.blk >> (self.idx - 1)) & Block::ONE != Block::ZERO;
155                    Some(bit)
156                } else {
157                    None
158                }
159            }
160        }
161        BitIter { blk: *self, idx: 0 }
162    }
163}
164
165// Implement standard operators for more ergonomic usage
166impl BitAnd for Block {
167    type Output = Self;
168
169    #[inline]
170    fn bitand(self, rhs: Self) -> Self {
171        Self(self.0 & rhs.0)
172    }
173}
174
175impl BitAndAssign for Block {
176    #[inline]
177    fn bitand_assign(&mut self, rhs: Self) {
178        *self = *self & rhs;
179    }
180}
181
182impl BitOr for Block {
183    type Output = Self;
184
185    #[inline]
186    fn bitor(self, rhs: Self) -> Self {
187        Self(self.0 | rhs.0)
188    }
189}
190
191impl BitOrAssign for Block {
192    #[inline]
193    fn bitor_assign(&mut self, rhs: Self) {
194        *self = *self | rhs;
195    }
196}
197
198impl BitXor for Block {
199    type Output = Self;
200
201    #[inline]
202    fn bitxor(self, rhs: Self) -> Self {
203        Self(self.0 ^ rhs.0)
204    }
205}
206
207impl BitXorAssign for Block {
208    #[inline]
209    fn bitxor_assign(&mut self, rhs: Self) {
210        *self = *self ^ rhs;
211    }
212}
213
214impl<Rhs> Shl<Rhs> for Block
215where
216    u128: Shl<Rhs, Output = u128>,
217{
218    type Output = Block;
219
220    #[inline]
221    fn shl(self, rhs: Rhs) -> Self::Output {
222        Self::from(u128::from(self) << rhs)
223    }
224}
225
226impl<Rhs> Shr<Rhs> for Block
227where
228    u128: Shr<Rhs, Output = u128>,
229{
230    type Output = Block;
231
232    #[inline]
233    fn shr(self, rhs: Rhs) -> Self::Output {
234        Self::from(u128::from(self) >> rhs)
235    }
236}
237
238impl Not for Block {
239    type Output = Self;
240
241    #[inline]
242    fn not(self) -> Self {
243        Self(!self.0)
244    }
245}
246
247impl PartialEq for Block {
248    fn eq(&self, other: &Self) -> bool {
249        let a: u128 = (*self).into();
250        let b: u128 = (*other).into();
251        a.ct_eq(&b).into()
252    }
253}
254
255impl Eq for Block {}
256
257impl Distribution<Block> for StandardUniform {
258    #[inline]
259    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Block {
260        let mut bytes = [0; 16];
261        rng.fill_bytes(&mut bytes);
262        Block::new(bytes)
263    }
264}
265
266impl AsRef<[u8]> for Block {
267    fn as_ref(&self) -> &[u8] {
268        self.as_bytes()
269    }
270}
271
272impl AsMut<[u8]> for Block {
273    #[inline]
274    fn as_mut(&mut self) -> &mut [u8] {
275        self.as_mut_bytes()
276    }
277}
278
279impl From<Block> for cipher::Array<u8, sizes::U16> {
280    #[inline]
281    fn from(value: Block) -> Self {
282        Self(*value.as_bytes())
283    }
284}
285
286impl From<cipher::Array<u8, sizes::U16>> for Block {
287    #[inline]
288    fn from(value: cipher::Array<u8, sizes::U16>) -> Self {
289        Self::new(value.0)
290    }
291}
292
293impl From<[u8; 16]> for Block {
294    #[inline]
295    fn from(value: [u8; 16]) -> Self {
296        Self::new(value)
297    }
298}
299
300impl From<Block> for [u8; 16] {
301    fn from(value: Block) -> Self {
302        *value.as_bytes()
303    }
304}
305
306impl From<[i64; 2]> for Block {
307    #[inline]
308    fn from(value: [i64; 2]) -> Self {
309        bytemuck::must_cast(value)
310    }
311}
312
313impl From<Block> for [i64; 2] {
314    #[inline]
315    fn from(value: Block) -> Self {
316        bytemuck::must_cast(value)
317    }
318}
319
320impl From<[u64; 2]> for Block {
321    #[inline]
322    fn from(value: [u64; 2]) -> Self {
323        bytemuck::must_cast(value)
324    }
325}
326
327impl From<Block> for [u64; 2] {
328    #[inline]
329    fn from(value: Block) -> Self {
330        bytemuck::must_cast(value)
331    }
332}
333
334impl From<Block> for u128 {
335    #[inline]
336    fn from(value: Block) -> Self {
337        u128::from_ne_bytes(*value.as_bytes())
338    }
339}
340
341impl From<&Block> for u128 {
342    #[inline]
343    fn from(value: &Block) -> Self {
344        u128::from_ne_bytes(*value.as_bytes())
345    }
346}
347
348impl From<usize> for Block {
349    fn from(value: usize) -> Self {
350        (value as u128).into()
351    }
352}
353
354impl From<u128> for Block {
355    #[inline]
356    fn from(value: u128) -> Self {
357        Self::new(value.to_ne_bytes())
358    }
359}
360
361impl From<&u128> for Block {
362    #[inline]
363    fn from(value: &u128) -> Self {
364        Self::new(value.to_ne_bytes())
365    }
366}
367
368#[derive(Debug, Error)]
369#[error("slice must have length of 16")]
370pub struct WrongLength;
371
372impl TryFrom<&[u8]> for Block {
373    type Error = WrongLength;
374
375    #[inline]
376    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
377        let arr = value.try_into().map_err(|_| WrongLength)?;
378        Ok(Self::new(arr))
379    }
380}
381
382#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
383mod from_arch_impls {
384    #[cfg(target_arch = "x86")]
385    use std::arch::x86::*;
386    #[cfg(target_arch = "x86_64")]
387    use std::arch::x86_64::*;
388
389    use super::Block;
390
391    impl From<__m128i> for Block {
392        #[inline]
393        fn from(value: __m128i) -> Self {
394            bytemuck::must_cast(value)
395        }
396    }
397
398    impl From<&__m128i> for Block {
399        #[inline]
400        fn from(value: &__m128i) -> Self {
401            bytemuck::must_cast(*value)
402        }
403    }
404
405    impl From<Block> for __m128i {
406        #[inline]
407        fn from(value: Block) -> Self {
408            bytemuck::must_cast(value)
409        }
410    }
411
412    impl From<&Block> for __m128i {
413        #[inline]
414        fn from(value: &Block) -> Self {
415            bytemuck::must_cast(*value)
416        }
417    }
418}
419
420impl ConditionallySelectable for Block {
421    #[inline]
422    // adapted from https://github.com/dalek-cryptography/subtle/blob/369e7463e85921377a5f2df80aabcbbc6d57a930/src/lib.rs#L510-L517
423    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
424        // if choice = 0, mask = (-0) = 0000...0000
425        // if choice = 1, mask = (-1) = 1111...1111
426        let mask = Block::new((-(choice.unwrap_u8() as i128)).to_le_bytes());
427        *a ^ (mask & (*a ^ *b))
428    }
429}
430
431impl Add for Block {
432    type Output = Block;
433
434    #[inline]
435    fn add(self, rhs: Self) -> Self::Output {
436        // todo is this a sensible implementation?
437        let a: u128 = self.into();
438        let b: u128 = rhs.into();
439        Self::from(a.wrapping_add(b))
440    }
441}
442
443impl fmt::Binary for Block {
444    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
445        fmt::Binary::fmt(&u128::from(*self), f)
446    }
447}
448
#[cfg(test)]
mod tests {
    use subtle::{Choice, ConditionallySelectable};

    use super::Block;

    #[test]
    fn test_block_cond_select() {
        let choice = Choice::from(0);
        assert_eq!(
            Block::ZERO,
            Block::conditional_select(&Block::ZERO, &Block::ONES, choice)
        );
        let choice = Choice::from(1);
        assert_eq!(
            Block::ONES,
            Block::conditional_select(&Block::ZERO, &Block::ONES, choice)
        );
    }

    #[test]
    fn test_block_low_high() {
        let b = Block::from(1_u128);
        assert_eq!(1, b.low());
        assert_eq!(0, b.high());
    }

    #[test]
    fn test_from_into_u64_arr() {
        let b = Block::from([42_u64, 65]);
        assert_eq!(42, b.low());
        assert_eq!(65, b.high());
        assert_eq!([42, 65], <[u64; 2]>::from(b));
    }

    #[test]
    fn test_pack() {
        let b = Block::pack(42, 123);
        assert_eq!(42, b.low());
        assert_eq!(123, b.high());
    }

    #[test]
    fn test_mask_lsb() {
        assert_eq!(Block::ONES ^ Block::ONE, Block::MASK_LSB);
    }

    #[test]
    fn test_lsb() {
        assert!(Block::ONE.lsb());
        assert!(Block::ONES.lsb());
        assert!(!Block::ZERO.lsb());
        assert!(!(Block::ONE << 1_u32).lsb());
    }

    #[test]
    fn test_bits() {
        let b: Block = 0b101_u128.into();
        let mut iter = b.bits();
        assert_eq!(Some(true), iter.next());
        assert_eq!(Some(false), iter.next());
        assert_eq!(Some(true), iter.next());
        for rest in iter {
            assert!(!rest);
        }
    }

    #[test]
    fn test_bits_len() {
        assert_eq!(Block::BITS, Block::ONES.bits().count());
        assert!(Block::ONES.bits().all(|bit| bit));
        assert!(Block::ZERO.bits().all(|bit| !bit));
    }

    #[test]
    fn test_bits_roundtrip() {
        let value = 0xDEAD_BEEF_0123_4567_89AB_CDEF_F00D_CAFE_u128;
        let rebuilt = Block::from(value)
            .bits()
            .enumerate()
            .fold(0_u128, |acc, (i, bit)| acc | (u128::from(bit) << i));
        assert_eq!(value, rebuilt);
    }

    #[test]
    fn test_from_choices() {
        let mut choices = vec![Choice::from(0); 128];
        choices[2] = Choice::from(1);
        choices[16] = Choice::from(1);
        let blk = Block::from_choices(&choices);
        assert_eq!(Block::from((1_u128 << 2) | (1_u128 << 16)), blk);
    }

    #[test]
    fn test_try_from_slice() {
        let bytes: Vec<u8> = (0..16).collect();
        let blk = Block::try_from(bytes.as_slice()).expect("16 bytes must convert");
        assert_eq!(bytes.as_slice(), blk.as_bytes());

        assert!(Block::try_from(&bytes[..15]).is_err());
        assert!(Block::try_from(&[0_u8; 17][..]).is_err());
        let empty: &[u8] = &[];
        assert!(Block::try_from(empty).is_err());
    }

    #[test]
    fn test_add_wraps() {
        assert_eq!(Block::ZERO, Block::ONES + Block::ONE);
        assert_eq!(
            Block::from(5_u128),
            Block::from(2_u128) + Block::from(3_u128)
        );
    }
}