use std::ops::{BitAnd, BitOr, BitXor, Index, IndexMut, Mul, Shl, Shr};
#[cfg(feature = "simd")]
use std::simd::prelude::*;
/// A 16-byte block stored in a portable SIMD vector (used when the `simd` feature is on).
#[cfg(feature = "simd")]
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Block(u8x16);
/// A 16-byte block stored in a plain byte array (used when the `simd` feature is off).
#[cfg(not(feature = "simd"))]
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Block([u8; 16]);
impl Block {
    /// Returns the all-zero block.
    pub fn null() -> Block {
        Block([0; 16].into())
    }

    /// Returns the block whose only set bit is the most significant bit of byte 0
    /// (`0x80, 0, …, 0`).
    pub fn one() -> Block {
        Block([0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0].into())
    }

    /// Returns a copy of the block's 16 bytes.
    pub fn bytes(&self) -> [u8; 16] {
        self.0.into()
    }

    /// Copies the block's 16 bytes into `output`.
    pub fn write_to(&self, output: &mut [u8; 16]) {
        #[cfg(feature = "simd")]
        self.0.copy_to_slice(output);
        #[cfg(not(feature = "simd"))]
        output.copy_from_slice(&self.0);
    }

    /// Returns the underlying SIMD vector.
    #[cfg(feature = "simd")]
    pub(crate) fn simd(&self) -> u8x16 {
        self.0
    }

    /// Wraps a SIMD vector as a block.
    #[cfg(feature = "simd")]
    pub(crate) fn from_simd(value: u8x16) -> Self {
        Block(value)
    }

    /// Builds a block from the first (up to) 16 bytes of `value`.
    ///
    /// Shorter input is zero-padded at the end; bytes past the 16th are ignored.
    pub fn from_slice(value: &[u8]) -> Self {
        let len = value.len().min(16);
        let mut array = [0; 16];
        array[..len].copy_from_slice(&value[..len]);
        Block(array.into())
    }

    /// Builds a block from the big-endian byte encoding of `value` widened to `u128`.
    pub fn from_int<I: Into<u128>>(value: I) -> Self {
        Block(value.into().to_be_bytes().into())
    }

    /// Interprets the block's bytes as a big-endian `u128`.
    pub fn to_int(&self) -> u128 {
        u128::from_be_bytes(self.0.into())
    }

    /// Returns a copy with the bit at position `length` set.
    ///
    /// Positions count from the most significant bit of byte 0 (position 0) to the least
    /// significant bit of byte 15 (position 127).
    ///
    /// # Panics
    /// Panics if `length > 127`.
    pub fn pad(&self, length: usize) -> Block {
        assert!(length <= 127);
        let mut result = *self;
        result[length / 8] |= 1 << (7 - length % 8);
        result
    }

    /// Keeps the first `length` bits (most significant first) and zeroes the rest.
    ///
    /// `length == 0` yields the all-zero block; any `length >= 128` returns the block
    /// unchanged.
    pub fn clip(&self, length: usize) -> Block {
        match length {
            0 => Block::default(),
            1..=127 => Block::from_int(self.to_int() & (u128::MAX << (128 - length))),
            // `128 - length` would underflow above 128; keeping every bit is the natural result.
            _ => *self,
        }
    }

    /// Doubles the block `exponent` times using the `Block * 2` operation.
    pub fn exp(&self, exponent: u32) -> Block {
        match exponent {
            // `1 << exponent` still fits in a u32, so a single multiplication suffices.
            0..=31 => *self * (1 << exponent),
            // Split larger exponents so the multiplier never overflows a u32.
            _ if exponent % 2 == 0 => self.exp(exponent / 2).exp(exponent / 2),
            _ => (*self * 2).exp(exponent - 1),
        }
    }
}
impl From<[u8; 16]> for Block {
fn from(value: [u8; 16]) -> Block {
Block(value.into())
}
}
impl From<&[u8; 16]> for Block {
fn from(value: &[u8; 16]) -> Block {
Block((*value).into())
}
}
impl From<u128> for Block {
fn from(value: u128) -> Block {
Block(value.to_be_bytes().into())
}
}
impl BitXor<Block> for Block {
    type Output = Block;
    /// Bytewise XOR of two blocks.
    fn bitxor(self, rhs: Block) -> Block {
        #[cfg(feature = "simd")]
        {
            Block(self.0 ^ rhs.0)
        }
        #[cfg(not(feature = "simd"))]
        {
            Block(std::array::from_fn(|i| self.0[i] ^ rhs.0[i]))
        }
    }
}
impl Shl<u32> for Block {
    type Output = Block;
    /// Shifts the block left by `rhs` bits, treating it as a big-endian 128-bit integer.
    ///
    /// Bits shifted out of byte 0 are discarded and zeros enter at byte 15; shifting by 128
    /// or more yields the all-zero block.
    fn shl(self, rhs: u32) -> Block {
        #[cfg(feature = "simd")]
        {
            if rhs == 1 {
                // Fast path: each byte takes its own bits shifted left plus the top bit of the
                // following byte (the last byte receives the zero fill value).
                return Block((self.0 << 1) | (self.0.shift_elements_left::<1>(0) >> 7));
            }
        }
        // Plain `<<` panics (debug) or masks the shift amount (release) when rhs >= 128.
        Block::from(self.to_int().checked_shl(rhs).unwrap_or(0))
    }
}
impl Shr<u32> for Block {
    type Output = Block;
    /// Shifts the block right by `rhs` bits, treating it as a big-endian 128-bit integer.
    ///
    /// Zeros enter at byte 0; shifting by 128 or more yields the all-zero block.
    fn shr(self, rhs: u32) -> Block {
        // Plain `>>` panics (debug) or masks the shift amount (release) when rhs >= 128.
        Block::from(self.to_int().checked_shr(rhs).unwrap_or(0))
    }
}
impl BitAnd<Block> for Block {
type Output = Block;
fn bitand(self, rhs: Block) -> Block {
#[cfg(feature = "simd")]
{
Block(self.0 & rhs.0)
}
#[cfg(not(feature = "simd"))]
{
Block::from(self.to_int() & rhs.to_int())
}
}
}
impl BitOr<Block> for Block {
type Output = Block;
fn bitor(self, rhs: Block) -> Block {
#[cfg(feature = "simd")]
{
Block(self.0 | rhs.0)
}
#[cfg(not(feature = "simd"))]
{
Block::from(self.to_int() | rhs.to_int())
}
}
}
impl Index<usize> for Block {
    type Output = u8;
    /// Borrows byte `index` of the block (byte 0 is the most significant byte of `to_int`).
    ///
    /// Panics if `index >= 16`.
    fn index(&self, index: usize) -> &u8 {
        Index::index(&self.0, index)
    }
}
impl IndexMut<usize> for Block {
    /// Mutably borrows byte `index` of the block.
    ///
    /// Panics if `index >= 16`.
    fn index_mut(&mut self, index: usize) -> &mut u8 {
        IndexMut::index_mut(&mut self.0, index)
    }
}
impl Mul<u32> for Block {
    type Output = Block;
    /// Multiplies the block by the polynomial whose coefficients are the bits of `rhs`.
    ///
    /// A doubling is a one-bit left shift, followed by XOR-ing `0x87` into byte 15 when the
    /// shifted-out top bit was set. Each set bit `k` of `rhs` contributes `self` doubled `k`
    /// times, and the contributions are XOR-ed together. `self * 0` is the all-zero block.
    fn mul(self, rhs: u32) -> Block {
        let mut product = Block::null();
        let mut addend = self;
        let mut remaining = rhs;
        while remaining != 0 {
            if remaining & 1 != 0 {
                product = product ^ addend;
            }
            remaining >>= 1;
            // Only double when a higher bit still needs this addend.
            if remaining != 0 {
                let carry = addend[0] & 0x80 != 0;
                addend = addend << 1;
                if carry {
                    addend[15] ^= 0x87;
                }
            }
        }
        product
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_xor() {
        assert_eq!(
            Block::from([1; 16]) ^ Block::from([2; 16]),
            Block::from([3; 16])
        );
    }

    #[test]
    fn test_pad() {
        assert_eq!(
            Block::from([0; 16]).pad(0),
            Block::from([0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        );
        assert_eq!(
            Block::from([0; 16]).pad(1),
            Block::from([0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        );
        assert_eq!(
            Block::from([0; 16]).pad(8),
            Block::from([0, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        );
    }

    #[test]
    fn test_one_and_last_pad() {
        assert_eq!(Block::one(), Block::null().pad(0));
        assert_eq!(
            Block::null().pad(127),
            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
        );
    }

    #[test]
    fn test_shl() {
        assert_eq!(
            Block::from([0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) << 1,
            Block::from([0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        );
        assert_eq!(
            Block::from([0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) << 4,
            Block::from([0x10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        );
        assert_eq!(
            Block::from([0x0A, 0xB0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) << 4,
            Block::from([0xAB, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        );
        assert_eq!(
            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) << 8,
            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]),
        );
    }

    #[test]
    fn test_shr() {
        assert_eq!(
            Block::from([0x10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) >> 4,
            Block::from([0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        );
        assert_eq!(
            Block::from([0xAB, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) >> 4,
            Block::from([0x0A, 0xB0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        );
        assert_eq!(
            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]) >> 8,
            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
        );
    }

    #[test]
    fn test_and_or() {
        let a = Block::from([0b1100; 16]);
        let b = Block::from([0b1010; 16]);
        assert_eq!(a & b, Block::from([0b1000; 16]));
        assert_eq!(a | b, Block::from([0b1110; 16]));
    }

    #[test]
    fn test_int_conversions() {
        assert_eq!(
            Block::from_int(0x0102u16),
            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2]),
        );
        assert_eq!(Block::from(u128::MAX), Block::from([0xFF; 16]));
        let value = 0x0123_4567_89AB_CDEF_FEDC_BA98_7654_3210u128;
        assert_eq!(Block::from_int(value).to_int(), value);
    }

    #[test]
    fn test_slice_and_bytes() {
        let bytes: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        assert_eq!(
            Block::from_slice(&[1, 2, 3]),
            Block::from([1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        );
        // Input longer than 16 bytes is truncated.
        let long: Vec<u8> = (1..=20).collect();
        assert_eq!(Block::from_slice(&long), Block::from(bytes));
        assert_eq!(Block::from(&bytes).bytes(), bytes);
        let mut out = [0u8; 16];
        Block::from(bytes).write_to(&mut out);
        assert_eq!(out, bytes);
    }

    #[test]
    fn test_clip() {
        let block = Block::from([0xFF; 16]);
        assert_eq!(block.clip(0), Block::null());
        assert_eq!(
            block.clip(4),
            Block::from([0xF0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        );
        assert_eq!(
            block.clip(12),
            Block::from([0xFF, 0xF0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        );
        assert_eq!(block.clip(128), block);
    }

    #[test]
    fn test_times() {
        assert_eq!(
            Block::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) * 0,
            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        );
        assert_eq!(
            Block::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) * 1,
            Block::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]),
        );
        assert_eq!(
            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) * 2,
            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]),
        );
        assert_eq!(
            Block::from([128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) * 2,
            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 133]),
        );
        assert_eq!(
            Block::from([129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]) * 2,
            Block::from([2, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 133]),
        );
        assert_eq!(
            Block::from([129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]) * 3,
            Block::from([131, 0, 0, 0, 1, 128, 0, 0, 0, 3, 0, 0, 0, 0, 0, 132]),
        );
        assert_eq!(
            Block::from([129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]) * 4,
            Block::from([4, 0, 0, 0, 2, 0, 0, 0, 0, 4, 0, 0, 0, 0, 1, 10]),
        );
    }

    #[test]
    fn test_exp() {
        let block = Block::from([129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]);
        assert_eq!(block.exp(0), block);
        assert_eq!(block.exp(1), block * 2);
        assert_eq!(block.exp(2), block * 4);
        // Exponents of 32 and above take the recursive path; compare against plain doubling.
        let mut doubled = block;
        for _ in 0..70 {
            doubled = doubled * 2;
        }
        assert_eq!(block.exp(70), doubled);
    }
}