use super::keystream::Keystream;
use super::u256::U256;
use core::fmt;
use core::mem::size_of;
use core::ops::{BitXorAssign, ShlAssign};
use cipher::generic_array::typenum;
use cipher::generic_array::ArrayLength;
use cipher::generic_array::GenericArray;
use numeric_cast::TruncatingCast;
/// A message word type for the MAC.
///
/// The message is consumed as big-endian words of this type, and the running tag held by
/// `MacCore` is stored in the same type.
pub trait MacWord:
    Sized
    + Copy
    + Eq
    + fmt::Debug
    + fmt::LowerHex
    + fmt::UpperHex
    + BitXorAssign
    + ShlAssign<usize>
{
    /// Keystream window paired with this word type (its `Word` is `Self`).
    type KeyPair: MacKeyPair<Word = Self>;
    /// Width of one word in bytes, as a type-level integer.
    type ByteSize: ArrayLength<u8>;

    /// Builds one word from 32-bit words pulled out of the keystream `zuc`.
    fn gen_word(zuc: &mut impl Keystream<Word = u32>) -> Self;

    /// Decodes a word from big-endian bytes.
    ///
    /// The implementations in this module expect `chunk` to be exactly one word long.
    fn from_be_slice(chunk: &[u8]) -> Self;

    /// Encodes the word as big-endian bytes.
    fn to_be_array(self) -> GenericArray<u8, Self::ByteSize>;

    /// Returns `true` when the most significant bit of the word is set.
    fn test_high_bit(&self) -> bool;
}
/// A keystream window twice as wide as its `Word` (u32 -> u64, u64 -> u128, u128 -> U256 in
/// this module).
///
/// `MacCore` XORs the high half into the tag and shifts the whole window left one bit per
/// message bit.
pub trait MacKeyPair: Sized + Copy + ShlAssign<usize> {
    /// The message/tag word type whose `KeyPair` is `Self`.
    type Word: MacWord<KeyPair = Self>;

    /// Fills a fresh window from the keystream `zuc`.
    fn gen_key_pair(zuc: &mut impl Keystream<Word = u32>) -> Self;

    /// Returns the upper half of the window.
    fn high(&self) -> Self::Word;

    /// Loads `low` into the lower half of the window.
    fn set_low(&mut self, low: Self::Word);
}
impl MacWord for u32 {
    type KeyPair = u64;
    type ByteSize = typenum::U4;

    /// Takes the next 32-bit word straight from the keystream.
    #[inline(always)]
    fn gen_word(zuc: &mut impl Keystream<Word = u32>) -> u32 {
        zuc.next_key()
    }

    /// Decodes a big-endian `u32`.
    ///
    /// # Panics
    ///
    /// Panics if `chunk` is not exactly 4 bytes long.
    #[inline(always)]
    fn from_be_slice(chunk: &[u8]) -> u32 {
        // This is reachable through the public trait, so report the violated precondition
        // rather than claiming the branch is impossible.
        let bytes: [u8; 4] = chunk
            .try_into()
            .expect("MacWord::from_be_slice: chunk must be exactly 4 bytes");
        u32::from_be_bytes(bytes)
    }

    /// Encodes the word as 4 big-endian bytes.
    #[inline(always)]
    fn to_be_array(self) -> GenericArray<u8, Self::ByteSize> {
        GenericArray::from(self.to_be_bytes())
    }

    /// Returns `true` when bit 31 is set.
    #[inline(always)]
    fn test_high_bit(&self) -> bool {
        self.leading_zeros() == 0
    }
}
impl MacKeyPair for u64 {
type Word = u32;
#[inline(always)]
fn gen_key_pair(zuc: &mut impl Keystream<Word = u32>) -> u64 {
u64::gen_word(zuc)
}
#[inline(always)]
fn high(&self) -> u32 {
(self >> 32) as u32
}
#[inline(always)]
fn set_low(&mut self, low: Self::Word) {
*self |= Self::from(low);
}
}
impl MacWord for u64 {
    type KeyPair = u128;
    type ByteSize = typenum::U8;

    /// Builds a 64-bit word from two keystream words; the first one read becomes the upper half.
    #[inline(always)]
    fn gen_word(zuc: &mut impl Keystream<Word = u32>) -> u64 {
        // Named locals make the keystream read order explicit.
        let hi = u64::from(zuc.next_key());
        let lo = u64::from(zuc.next_key());
        (hi << 32) | lo
    }

    /// Decodes a big-endian `u64`.
    ///
    /// # Panics
    ///
    /// Panics if `chunk` is not exactly 8 bytes long.
    #[inline(always)]
    fn from_be_slice(chunk: &[u8]) -> u64 {
        // This is reachable through the public trait, so report the violated precondition
        // rather than claiming the branch is impossible.
        let bytes: [u8; 8] = chunk
            .try_into()
            .expect("MacWord::from_be_slice: chunk must be exactly 8 bytes");
        u64::from_be_bytes(bytes)
    }

    /// Encodes the word as 8 big-endian bytes.
    #[inline(always)]
    fn to_be_array(self) -> GenericArray<u8, Self::ByteSize> {
        GenericArray::from(self.to_be_bytes())
    }

    /// Returns `true` when bit 63 is set.
    #[inline(always)]
    fn test_high_bit(&self) -> bool {
        self.leading_zeros() == 0
    }
}
impl MacKeyPair for u128 {
type Word = u64;
#[inline(always)]
fn gen_key_pair(zuc: &mut impl Keystream<Word = u32>) -> u128 {
u128::gen_word(zuc)
}
#[inline(always)]
fn high(&self) -> u64 {
(self >> 64) as u64
}
#[inline(always)]
fn set_low(&mut self, low: Self::Word) {
*self |= Self::from(low);
}
}
impl MacWord for u128 {
    type KeyPair = U256;
    type ByteSize = typenum::U16;

    /// Builds a 128-bit word from four keystream words, most significant first.
    #[inline(always)]
    fn gen_word(zuc: &mut impl Keystream<Word = u32>) -> u128 {
        // Sequential statements make the keystream read order explicit.
        let w0 = u128::from(zuc.next_key());
        let w1 = u128::from(zuc.next_key());
        let w2 = u128::from(zuc.next_key());
        let w3 = u128::from(zuc.next_key());
        (w0 << 96) | (w1 << 64) | (w2 << 32) | w3
    }

    /// Decodes a big-endian `u128`.
    ///
    /// # Panics
    ///
    /// Panics if `chunk` is not exactly 16 bytes long.
    #[inline(always)]
    fn from_be_slice(chunk: &[u8]) -> u128 {
        // This is reachable through the public trait, so report the violated precondition
        // rather than claiming the branch is impossible.
        let bytes: [u8; 16] = chunk
            .try_into()
            .expect("MacWord::from_be_slice: chunk must be exactly 16 bytes");
        u128::from_be_bytes(bytes)
    }

    /// Encodes the word as 16 big-endian bytes.
    #[inline(always)]
    fn to_be_array(self) -> GenericArray<u8, Self::ByteSize> {
        GenericArray::from(self.to_be_bytes())
    }

    /// Returns `true` when bit 127 is set.
    #[inline(always)]
    fn test_high_bit(&self) -> bool {
        self.leading_zeros() == 0
    }
}
impl MacKeyPair for U256 {
type Word = u128;
fn gen_key_pair(zuc: &mut impl Keystream<Word = u32>) -> U256 {
let high = u128::gen_word(zuc);
let low = u128::gen_word(zuc);
U256::new(high, low)
}
fn high(&self) -> u128 {
self.high
}
fn set_low(&mut self, low: Self::Word) {
self.low = low;
}
}
/// Copies all of `src` into the front of `dst`, leaving the rest of `dst` untouched.
///
/// Panics (slice index out of range) if `dst` is shorter than `src`.
#[inline(always)]
fn copy(dst: &mut [u8], src: &[u8]) {
    let len = src.len();
    let front = &mut dst[..len];
    front.copy_from_slice(src);
}
/// Incremental MAC state processing the message in words of type `T`, with key material
/// drawn from the keystream `S`.
pub struct MacCore<S: Keystream, T: MacWord> {
    /// Keystream source used to refill `key` after each full word.
    pub zuc: S,
    /// Sliding keystream window; its high half is XORed into `tag` for each set message bit.
    pub key: T::KeyPair,
    /// Running tag accumulator.
    pub tag: T,
    /// Holds message bytes that do not yet form a complete word.
    pub rem: GenericArray<u8, T::ByteSize>,
    /// Number of valid bytes at the front of `rem`.
    pub cnt: u8,
}
impl<S, T> MacCore<S, T>
where
    S: Keystream<Word = u32>,
    T: MacWord,
{
    /// Consumes the most significant bit of `bits`.
    ///
    /// If that bit is set, the current key word (`key.high()`) is XORed into `tag`. Then both
    /// `bits` and `key` are shifted left by one.
    #[inline(always)]
    fn xor_step(bits: &mut T, tag: &mut T, key: &mut T::KeyPair) {
        if bits.test_high_bit() {
            *tag ^= key.high();
        }
        *bits <<= 1;
        *key <<= 1;
    }

    /// Runs [`Self::xor_step`] once per bit of a full word (MSB first).
    ///
    /// It then loads the next keystream word into the low half of `key`, whose previous
    /// contents the shifts have just moved into the high half.
    #[inline(always)]
    fn feed_word(mut bits: T, tag: &mut T, key: &mut T::KeyPair, zuc: &mut S) {
        for _ in 0..size_of::<T>() * 8 {
            Self::xor_step(&mut bits, tag, key);
        }
        key.set_low(T::gen_word(zuc));
    }

    /// Absorbs the whole bytes of `msg`.
    ///
    /// Bytes are grouped into big-endian words of `size_of::<T>()` bytes, and each complete
    /// word is fed into the tag. Bytes that do not complete a word are buffered in `rem`, with
    /// their count in `cnt`, and are completed by the next call.
    pub fn update(&mut self, mut msg: &[u8]) {
        if msg.is_empty() {
            return;
        }
        let word_len = size_of::<T>();
        let zuc = &mut self.zuc;
        let mut key = self.key;
        let mut tag = self.tag;
        let rem = self.rem.as_mut_slice();
        let cnt = usize::from(self.cnt);

        if cnt > 0 {
            // Still not enough to finish the buffered word: append and wait for more input.
            if cnt + msg.len() < word_len {
                copy(&mut rem[cnt..], msg);
                self.cnt += msg.len().truncating_cast::<u8>();
                return;
            }
            // Complete the buffered word first.
            let (head, tail) = msg.split_at(word_len - cnt);
            copy(&mut rem[cnt..], head);
            msg = tail;
            Self::feed_word(T::from_be_slice(rem), &mut tag, &mut key, zuc);
        }

        let mut chunks = msg.chunks_exact(word_len);
        for chunk in &mut chunks {
            Self::feed_word(T::from_be_slice(chunk), &mut tag, &mut key, zuc);
        }

        // Buffer whatever is left (fewer than `word_len` bytes) for the next call.
        let rest = chunks.remainder();
        copy(rem, rest);
        self.cnt = rest.len().truncating_cast::<u8>();

        self.key = key;
        self.tag = tag;
    }

    /// Absorbs the final `bitlen` bits of `tail` (most significant bit of each byte first),
    /// together with any bytes still buffered from earlier [`Self::update`] calls.
    ///
    /// Whole bytes are routed through [`Self::update`]. The buffered bytes plus the trailing
    /// `bitlen % 8` bits are then processed bit by bit. Bits of the last byte beyond `bitlen`
    /// are ignored.
    ///
    /// Returns the number of bits processed in that final bit-by-bit step: the number of
    /// buffered bytes times 8, plus `bitlen % 8`.
    ///
    /// # Panics
    ///
    /// Panics if `bitlen` exceeds the number of bits in `tail`.
    #[must_use]
    pub fn finish(&mut self, mut tail: &[u8], mut bitlen: usize) -> usize {
        // Compare in bytes. `tail.len() * 8` can overflow `usize` on 32-bit targets for large
        // but valid inputs, which would panic in debug builds or reject the input in release.
        let needed_bytes = bitlen / 8 + usize::from(bitlen % 8 != 0);
        assert!(
            needed_bytes <= tail.len(),
            "bitlen ({}) exceeds the bit length of tail ({} bytes)",
            bitlen,
            tail.len()
        );
        if bitlen >= 8 {
            let whole = bitlen / 8;
            self.update(&tail[..whole]);
            tail = &tail[whole..];
            bitlen %= 8;
        }

        let mut key = self.key;
        let mut tag = self.tag;
        let rem = self.rem.as_mut_slice();
        let cnt = usize::from(self.cnt);

        // Place the partial final byte right after the buffered bytes; only its top
        // `bitlen` bits are consumed below.
        if bitlen != 0 {
            rem[cnt] = tail[0];
        }

        let total_bits = cnt * 8 + bitlen;
        if total_bits != 0 {
            let mut bits = T::from_be_slice(rem);
            for _ in 0..total_bits {
                Self::xor_step(&mut bits, &mut tag, &mut key);
            }
            self.tag = tag;
            self.key = key;
        }
        total_bits
    }
}