use core::{arch::aarch64::*, mem};
use universal_hash::{
consts::{U1, U16},
crypto_common::{BlockSizeUser, KeySizeUser, ParBlocksSizeUser},
KeyInit, Reset, UhfBackend,
};
use crate::{Block, Key, Tag};
/// POLYVAL hasher state, stored as two 128-bit NEON vectors (`uint8x16_t`).
#[derive(Clone)]
pub struct Polyval {
// Key value: loaded from the `h` argument at construction and passed as one
// operand of the multiplication performed in `mul`.
h: uint8x16_t,
// Running accumulator: initialised from `init_block`, rewritten by every
// call to `mul`, zeroed by `reset`, and returned by `finalize`.
y: uint8x16_t,
}
// Declares the key size for this hasher via the `U16` type-level constant.
impl KeySizeUser for Polyval {
// `U16` is the type-level integer 16 (presumably a size in bytes, matching
// the 16-byte vector loaded from the key in `new_with_init_block`).
type KeySize = U16;
}
impl Polyval {
    /// Builds a hasher from the key `h` and a caller-supplied starting
    /// accumulator value.
    ///
    /// The bytes of `h` are loaded unchanged into a NEON vector. `init_block`
    /// is serialised big-endian (`to_be_bytes`) and those 16 bytes become the
    /// initial value of the running state.
    pub fn new_with_init_block(h: &Key, init_block: u128) -> Self {
        // Keep the serialised block in a named local so the pointer handed to
        // the load below refers to storage that is clearly still alive.
        let init_bytes = init_block.to_be_bytes();

        // SAFETY: `vld1q_u8` reads 16 bytes through each pointer, and
        // `init_bytes` is a `[u8; 16]`.
        // NOTE(review): assumes `Key` is a 16-byte buffer (consistent with
        // `KeySize = U16`) and that NEON is available on the running CPU --
        // confirm where this backend is selected.
        let (key_vec, state_vec) =
            unsafe { (vld1q_u8(h.as_ptr()), vld1q_u8(init_bytes.as_ptr())) };

        Self {
            h: key_vec,
            y: state_vec,
        }
    }
}
impl KeyInit for Polyval {
    /// Creates a hasher keyed with `h` whose accumulator starts at zero.
    fn new(h: &Key) -> Self {
        // An all-zero initial block, i.e. a fresh accumulator.
        const ZERO_BLOCK: u128 = 0;
        Polyval::new_with_init_block(h, ZERO_BLOCK)
    }
}
// Declares the block size for this hasher via the `U16` type-level constant.
impl BlockSizeUser for Polyval {
// `U16` is the type-level integer 16; `mul` loads one 16-byte vector per block.
type BlockSize = U16;
}
// Declares how many blocks this backend handles per step.
impl ParBlocksSizeUser for Polyval {
// `U1`: a single block at a time, matching `proc_block`, which absorbs
// exactly one block per call.
type ParBlocksSize = U1;
}
impl UhfBackend for Polyval {
    /// Absorbs one block `x` into the running state by delegating to `mul`.
    fn proc_block(&mut self, x: &Block) {
        // SAFETY: `mul` is an `unsafe fn` compiled with
        // `target_feature(enable = "neon")`.
        // NOTE(review): presumably this backend is only instantiated after the
        // required CPU features have been detected -- confirm at the call site.
        unsafe { self.mul(x) }
    }
}
impl Reset for Polyval {
    /// Clears the accumulator `y` to all-zero bytes; the key `h` is untouched.
    ///
    /// The state is zeroed even when the hasher was built with a non-zero
    /// `init_block`: that initial value is not stored anywhere and is not
    /// restored here.
    fn reset(&mut self) {
        // SAFETY: `vdupq_n_u8` only broadcasts a byte into a vector and touches
        // no memory. NOTE(review): relies on NEON being available -- confirm.
        self.y = unsafe { vdupq_n_u8(0) };
    }
}
impl Polyval {
    /// Consumes the hasher and returns the accumulator's 16 bytes as the tag.
    pub(crate) fn finalize(self) -> Tag {
        let state = self.y;
        // SAFETY: `transmute` only compiles when `Tag` has the same size as
        // `uint8x16_t`; the bits are reinterpreted unchanged.
        // NOTE(review): assumes `Tag` is a plain byte buffer for which every
        // bit pattern is valid -- confirm against its definition.
        unsafe { mem::transmute(state) }
    }

    /// Folds one block into the state: `y = reduce((y ^ x) * h)`.
    ///
    /// The block is XORed into the accumulator, multiplied by the key with the
    /// two Karatsuba helpers, and the double-width product is brought back to
    /// 128 bits by `mont_reduce`.
    ///
    /// # Safety
    ///
    /// The CPU must support the features required by the intrinsics used here
    /// and in the helpers (this function itself enables `neon`).
    /// NOTE(review): assumes `x` points at 16 readable bytes, since `vld1q_u8`
    /// loads a full 16-byte vector -- confirm against the `Block` definition.
    #[inline]
    #[target_feature(enable = "neon")]
    unsafe fn mul(&mut self, x: &Block) {
        let block = vld1q_u8(x.as_ptr());
        let mixed = veorq_u8(self.y, block);

        let (hi, mid, lo) = karatsuba1(self.h, mixed);
        let (upper, lower) = karatsuba2(hi, mid, lo);

        self.y = mont_reduce(upper, lower);
    }
}
/// First Karatsuba step: produces the three 64x64-bit polynomial products
/// needed to multiply `x` by `y`.
///
/// Returns `(h, m, l)` where
/// - `h = x.hi * y.hi`
/// - `m = (x.hi ^ x.lo) * (y.hi ^ y.lo)`
/// - `l = x.lo * y.lo`
///
/// `lo`/`hi` refer to 64-bit lanes 0 and 1 of each vector.
///
/// # Safety
///
/// The CPU must support the features required by the intrinsics used here and
/// in `pmull`/`pmull2` (this function itself enables `neon`).
#[inline]
#[target_feature(enable = "neon")]
unsafe fn karatsuba1(x: uint8x16_t, y: uint8x16_t) -> (uint8x16_t, uint8x16_t, uint8x16_t) {
    // `vextq_u8(v, v, 8)` swaps the two 64-bit halves of `v`, so XORing it with
    // `v` leaves `v.hi ^ v.lo` in both lanes; `pmull` then reads lane 0.
    let x_folded = veorq_u8(x, vextq_u8(x, x, 8));
    let y_folded = veorq_u8(y, vextq_u8(y, y, 8));

    let m = pmull(x_folded, y_folded);
    let h = pmull2(x, y);
    let l = pmull(x, y);

    (h, m, l)
}
/// Second Karatsuba step: combines the partial products `h`, `m`, `l` into a
/// 256-bit value returned as two vectors `(x23, x01)`.
///
/// Writing each input as 64-bit lanes `{lane0, lane1}`:
///
/// ```text
/// t   = {m0 ^ l1 ^ h0 ^ l0,  m1 ^ h0 ^ h1 ^ l1}
/// x01 = {l0, t0}
/// x23 = {t1, h1}
/// ```
///
/// # Safety
///
/// The CPU must support NEON.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn karatsuba2(h: uint8x16_t, m: uint8x16_t, l: uint8x16_t) -> (uint8x16_t, uint8x16_t) {
    // Lane shuffles used below (`vextq_u8(a, b, 8)` yields `{a1, b0}`).
    let l_swapped = vextq_u8(l, l, 8); // {l1, l0}
    let h_swapped = vextq_u8(h, h, 8); // {h1, h0}
    let l1_h0 = vextq_u8(l, h, 8); // {l1, h0}

    // {m0 ^ l1, m1 ^ h0} ^ {h0 ^ l0, h1 ^ l1}
    let m_mixed = veorq_u8(m, l1_h0);
    let h_xor_l = veorq_u8(h, l);
    let t = veorq_u8(m_mixed, h_xor_l);

    // {l0, t0}
    let x01 = vextq_u8(l_swapped, t, 8);
    // {t1, h1}
    let x23 = vextq_u8(t, h_swapped, 8);

    (x23, x01)
}
/// Reduces the 256-bit value `[x23 : x01]` to 128 bits (a Montgomery-style
/// reduction, going by the function name).
///
/// With `poly` holding the same 64-bit constant in both lanes:
///
/// ```text
/// a      = x01.lo * poly.lo
/// b      = x01 ^ {a1, a0}
/// c      = b.hi * poly.hi
/// result = x23 ^ c ^ b
/// ```
///
/// # Safety
///
/// The CPU must support the features required by the intrinsics used here and
/// in `pmull`/`pmull2` (this function itself enables `neon`).
#[inline]
#[target_feature(enable = "neon")]
unsafe fn mont_reduce(x23: uint8x16_t, x01: uint8x16_t) -> uint8x16_t {
    // Bits 63, 62 and 57 set in each 64-bit half (0xC200_0000_0000_0000 twice).
    const POLY: u128 =
        (1 << 127) | (1 << 126) | (1 << 121) | (1 << 63) | (1 << 62) | (1 << 57);
    let poly = vreinterpretq_u8_p128(POLY);

    let a = pmull(x01, poly);
    // Swap the halves of `a` before folding it into `x01`.
    let a_swapped = vextq_u8(a, a, 8);
    let b = veorq_u8(x01, a_swapped);
    let c = pmull2(b, poly);

    let folded = veorq_u8(c, b);
    veorq_u8(x23, folded)
}
/// Polynomial (carry-less) multiplication of the low 64-bit lanes of `a` and
/// `b`, returning the 128-bit product as a byte vector.
///
/// # Safety
///
/// The CPU must support NEON and the polynomial multiply used by `vmull_p64`.
/// NOTE(review): `vmull_p64` is documented as requiring the `aes` target
/// feature in addition to `neon`, while this function only enables `neon` --
/// confirm callers guarantee PMULL availability.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn pmull(a: uint8x16_t, b: uint8x16_t) -> uint8x16_t {
    let a_lo = vgetq_lane_u64(vreinterpretq_u64_u8(a), 0);
    let b_lo = vgetq_lane_u64(vreinterpretq_u64_u8(b), 0);
    let product = vmull_p64(a_lo, b_lo);
    // Reinterpret the 128-bit product as sixteen bytes; sizes match exactly.
    mem::transmute(product)
}
/// Polynomial (carry-less) multiplication of the high 64-bit lanes of `a` and
/// `b`, returning the 128-bit product as a byte vector.
///
/// # Safety
///
/// The CPU must support NEON and the polynomial multiply used by `vmull_p64`.
/// NOTE(review): `vmull_p64` is documented as requiring the `aes` target
/// feature in addition to `neon`, while this function only enables `neon` --
/// confirm callers guarantee PMULL availability.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn pmull2(a: uint8x16_t, b: uint8x16_t) -> uint8x16_t {
    let a_hi = vgetq_lane_u64(vreinterpretq_u64_u8(a), 1);
    let b_hi = vgetq_lane_u64(vreinterpretq_u64_u8(b), 1);
    let product = vmull_p64(a_hi, b_hi);
    // Reinterpret the 128-bit product as sixteen bytes; sizes match exactly.
    mem::transmute(product)
}