#![allow(clippy::indexing_slicing)]
use crate::backend::cache::OnceCache;
#[cfg(feature = "aes-gcm-siv")]
pub(crate) const BLOCK_SIZE: usize = 16;
#[cfg(feature = "aes-gcm-siv")]
pub(crate) const KEY_SIZE: usize = 16;
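// POLYVAL field polynomial x^128 + x^127 + x^126 + x^121 + 1 (RFC 8452),
// truncated to its low 128 bits; used by the tests to sanity-check reduction.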
#[cfg(test)]
const POLY: u128 = (1u128 << 127) | (1u128 << 126) | (1u128 << 121) | 1;
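// x86_64 backend: 128x128 carryless multiply via four PCLMULQDQ schoolbook
// products, followed by the shared POLYVAL Montgomery reduction.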
#[cfg(target_arch = "x86_64")]
mod pclmul {
use core::arch::x86_64::*;
#[target_feature(enable = "pclmulqdq,sse2")]
pub(super) unsafe fn clmul128_reduce(a: u128, b: u128) -> u128 {
unsafe {
let a_xmm = _mm_loadu_si128((&a as *const u128).cast());
let b_xmm = _mm_loadu_si128((&b as *const u128).cast());
let lo = _mm_clmulepi64_si128(a_xmm, b_xmm, 0x00);
let hi = _mm_clmulepi64_si128(a_xmm, b_xmm, 0x11);
let m1 = _mm_clmulepi64_si128(a_xmm, b_xmm, 0x10);
let m2 = _mm_clmulepi64_si128(a_xmm, b_xmm, 0x01);
let mid = _mm_xor_si128(m1, m2);
let lo_128 = _mm_xor_si128(lo, _mm_slli_si128(mid, 8));
let hi_128 = _mm_xor_si128(hi, _mm_srli_si128(mid, 8));
let result = mont_reduce_sse2(lo_128, hi_128);
let mut out = 0u128;
_mm_storeu_si128((&mut out as *mut u128).cast(), result);
out
}
}
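// Two-step Montgomery reduction of a 256-bit carryless product to 128 bits.
// The shift amounts 63/62/57 and 1/2/7 encode the x^127 + x^126 + x^121 + 1
// terms of the POLYVAL polynomial.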
#[inline]
pub(super) unsafe fn mont_reduce_sse2(lo: __m128i, hi: __m128i) -> __m128i {
unsafe {
let left = _mm_xor_si128(
_mm_xor_si128(_mm_slli_epi64(lo, 63), _mm_slli_epi64(lo, 62)),
_mm_slli_epi64(lo, 57),
);
let lo_folded = _mm_xor_si128(lo, _mm_slli_si128(left, 8));
let right = _mm_xor_si128(
_mm_xor_si128(lo_folded, _mm_srli_epi64(lo_folded, 1)),
_mm_xor_si128(_mm_srli_epi64(lo_folded, 2), _mm_srli_epi64(lo_folded, 7)),
);
let left2 = _mm_xor_si128(
_mm_xor_si128(_mm_slli_epi64(lo_folded, 63), _mm_slli_epi64(lo_folded, 62)),
_mm_slli_epi64(lo_folded, 57),
);
_mm_xor_si128(_mm_xor_si128(hi, right), _mm_srli_si128(left2, 8))
}
}
}
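// aarch64 backend: PMULL (vmull_p64) carryless multiply with a NEON port of
// the same Montgomery reduction.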
#[cfg(target_arch = "aarch64")]
mod pmull {
use core::arch::aarch64::*;
#[target_feature(enable = "neon", enable = "aes")]
#[inline]
pub(super) unsafe fn clmul128_reduce_core(a: u128, b: u128) -> u128 {
unsafe {
let a_lo = a as u64;
let a_hi = (a >> 64) as u64;
let b_lo = b as u64;
let b_hi = (b >> 64) as u64;
let ll = vreinterpretq_u64_p128(vmull_p64(a_lo, b_lo));
let hh = vreinterpretq_u64_p128(vmull_p64(a_hi, b_hi));
let lh = vreinterpretq_u64_p128(vmull_p64(a_lo, b_hi));
let hl = vreinterpretq_u64_p128(vmull_p64(a_hi, b_lo));
let mid = veorq_u64(lh, hl);
let zero = vdupq_n_u64(0);
let lo = veorq_u64(ll, vextq_u64(zero, mid, 1));
let hi = veorq_u64(hh, vextq_u64(mid, zero, 1));
let result = mont_reduce_neon(lo, hi);
let r_lo = vgetq_lane_u64(result, 0) as u128;
let r_hi = vgetq_lane_u64(result, 1) as u128;
r_lo | (r_hi << 64)
}
}
#[target_feature(enable = "neon", enable = "aes")]
pub(super) unsafe fn clmul128_reduce(a: u128, b: u128) -> u128 {
unsafe { clmul128_reduce_core(a, b) }
}
#[inline]
unsafe fn mont_reduce_neon(lo: uint64x2_t, hi: uint64x2_t) -> uint64x2_t {
unsafe {
let zero = vdupq_n_u64(0);
let left = veorq_u64(veorq_u64(vshlq_n_u64(lo, 63), vshlq_n_u64(lo, 62)), vshlq_n_u64(lo, 57));
let lo_folded = veorq_u64(lo, vextq_u64(zero, left, 1));
let right = veorq_u64(
veorq_u64(lo_folded, vshrq_n_u64(lo_folded, 1)),
veorq_u64(vshrq_n_u64(lo_folded, 2), vshrq_n_u64(lo_folded, 7)),
);
let left2 = veorq_u64(
veorq_u64(vshlq_n_u64(lo_folded, 63), vshlq_n_u64(lo_folded, 62)),
vshlq_n_u64(lo_folded, 57),
);
veorq_u64(veorq_u64(hi, right), vextq_u64(left2, zero, 1))
}
}
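// Aggregated 4-block update: each block is multiplied by the matching power
// of H (h_powers_rev holds H^4..H^1, highest first), the partial products are
// XOR-summed, and a single Montgomery reduction covers all four blocks.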
#[cfg(feature = "aes-gcm-siv")]
#[target_feature(enable = "neon", enable = "aes")]
#[inline]
pub(super) unsafe fn aggregate_4blocks(acc: u128, h_powers_rev: &[u128; 4], blocks: &[u128; 4]) -> u128 {
unsafe {
let b0 = acc ^ blocks[0];
let b1 = blocks[1];
let b2 = blocks[2];
let b3 = blocks[3];
let ll0 = vreinterpretq_u64_p128(vmull_p64(b0 as u64, h_powers_rev[0] as u64));
let hh0 = vreinterpretq_u64_p128(vmull_p64((b0 >> 64) as u64, (h_powers_rev[0] >> 64) as u64));
let lh0 = vreinterpretq_u64_p128(vmull_p64(b0 as u64, (h_powers_rev[0] >> 64) as u64));
let hl0 = vreinterpretq_u64_p128(vmull_p64((b0 >> 64) as u64, h_powers_rev[0] as u64));
let ll1 = vreinterpretq_u64_p128(vmull_p64(b1 as u64, h_powers_rev[1] as u64));
let hh1 = vreinterpretq_u64_p128(vmull_p64((b1 >> 64) as u64, (h_powers_rev[1] >> 64) as u64));
let lh1 = vreinterpretq_u64_p128(vmull_p64(b1 as u64, (h_powers_rev[1] >> 64) as u64));
let hl1 = vreinterpretq_u64_p128(vmull_p64((b1 >> 64) as u64, h_powers_rev[1] as u64));
let ll2 = vreinterpretq_u64_p128(vmull_p64(b2 as u64, h_powers_rev[2] as u64));
let hh2 = vreinterpretq_u64_p128(vmull_p64((b2 >> 64) as u64, (h_powers_rev[2] >> 64) as u64));
let lh2 = vreinterpretq_u64_p128(vmull_p64(b2 as u64, (h_powers_rev[2] >> 64) as u64));
let hl2 = vreinterpretq_u64_p128(vmull_p64((b2 >> 64) as u64, h_powers_rev[2] as u64));
let ll3 = vreinterpretq_u64_p128(vmull_p64(b3 as u64, h_powers_rev[3] as u64));
let hh3 = vreinterpretq_u64_p128(vmull_p64((b3 >> 64) as u64, (h_powers_rev[3] >> 64) as u64));
let lh3 = vreinterpretq_u64_p128(vmull_p64(b3 as u64, (h_powers_rev[3] >> 64) as u64));
let hl3 = vreinterpretq_u64_p128(vmull_p64((b3 >> 64) as u64, h_powers_rev[3] as u64));
let ll = veorq_u64(veorq_u64(ll0, ll1), veorq_u64(ll2, ll3));
let hh = veorq_u64(veorq_u64(hh0, hh1), veorq_u64(hh2, hh3));
let lh = veorq_u64(veorq_u64(lh0, lh1), veorq_u64(lh2, lh3));
let hl = veorq_u64(veorq_u64(hl0, hl1), veorq_u64(hl2, hl3));
let mid = veorq_u64(lh, hl);
let zero = vdupq_n_u64(0);
let lo = veorq_u64(ll, vextq_u64(zero, mid, 1));
let hi = veorq_u64(hh, vextq_u64(mid, zero, 1));
let result = mont_reduce_neon(lo, hi);
let r_lo = vgetq_lane_u64(result, 0) as u128;
let r_hi = vgetq_lane_u64(result, 1) as u128;
r_lo | (r_hi << 64)
}
}
}
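// s390x backend: carryless multiply via the VGFM instruction through inline
// asm. z/Architecture vectors are big-endian, so element 0 is the most
// significant doubleword throughout this module.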
#[cfg(target_arch = "s390x")]
#[allow(unsafe_code)]
mod s390x_vgfm {
use core::{arch::asm, simd::i64x2};
#[inline]
#[target_feature(enable = "vector")]
unsafe fn mul64(a: u64, b: u64) -> i64x2 {
let va = i64x2::from_array([0, a as i64]);
let vb = i64x2::from_array([0, b as i64]);
unsafe {
let out: i64x2;
asm!(
"vgfm {out}, {a}, {b}, 3",
out = lateout(vreg) out,
a = in(vreg) va,
b = in(vreg) vb,
options(nomem, nostack, pure),
);
out
}
}
#[inline]
#[target_feature(enable = "vector")]
unsafe fn veslg<const N: u32>(a: i64x2) -> i64x2 {
unsafe {
let out: i64x2;
asm!(
"veslg {out}, {a}, {n}",
out = lateout(vreg) out,
a = in(vreg) a,
n = const N,
options(nomem, nostack, pure),
);
out
}
}
#[inline]
#[target_feature(enable = "vector")]
unsafe fn vesrlg<const N: u32>(a: i64x2) -> i64x2 {
unsafe {
let out: i64x2;
asm!(
"vesrlg {out}, {a}, {n}",
out = lateout(vreg) out,
a = in(vreg) a,
n = const N,
options(nomem, nostack, pure),
);
out
}
}
#[inline]
#[target_feature(enable = "vector")]
unsafe fn vsldb<const N: u32>(a: i64x2, b: i64x2) -> i64x2 {
unsafe {
let out: i64x2;
asm!(
"vsldb {out}, {a}, {b}, {n}",
out = lateout(vreg) out,
a = in(vreg) a,
b = in(vreg) b,
n = const N,
options(nomem, nostack, pure),
);
out
}
}
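// Same two-step POLYVAL Montgomery reduction as the other backends, built
// from z/Arch element shifts and shift-left-double.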
#[inline]
#[target_feature(enable = "vector")]
unsafe fn mont_reduce(lo: i64x2, hi: i64x2) -> i64x2 {
unsafe {
let zero = i64x2::from_array([0, 0]);
let left = veslg::<63>(lo) ^ veslg::<62>(lo) ^ veslg::<57>(lo);
let lo_folded = lo ^ vsldb::<8>(left, zero);
let right = lo_folded ^ vesrlg::<1>(lo_folded) ^ vesrlg::<2>(lo_folded) ^ vesrlg::<7>(lo_folded);
let left2 = veslg::<63>(lo_folded) ^ veslg::<62>(lo_folded) ^ veslg::<57>(lo_folded);
hi ^ right ^ vsldb::<8>(zero, left2)
}
}
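// Karatsuba split: three 64x64 carryless multiplies instead of four.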
#[target_feature(enable = "vector")]
#[inline]
pub(super) unsafe fn clmul128_reduce_core(a: u128, b: u128) -> u128 {
unsafe {
let a_lo = a as u64;
let a_hi = (a >> 64) as u64;
let b_lo = b as u64;
let b_hi = (b >> 64) as u64;
let zero = i64x2::from_array([0, 0]);
let v0 = mul64(a_lo, b_lo);
let v1 = mul64(a_hi, b_hi);
let v2 = mul64(a_lo ^ a_hi, b_lo ^ b_hi);
let mid = v2 ^ v0 ^ v1;
let lo_128 = v0 ^ vsldb::<8>(mid, zero);
let hi_128 = v1 ^ vsldb::<8>(zero, mid);
let result = mont_reduce(lo_128, hi_128);
let arr = result.to_array();
((arr[0] as u64 as u128) << 64) | (arr[1] as u64 as u128)
}
}
#[target_feature(enable = "vector")]
pub(super) unsafe fn clmul128_reduce(a: u128, b: u128) -> u128 {
unsafe { clmul128_reduce_core(a, b) }
}
#[cfg(feature = "aes-gcm-siv")]
#[target_feature(enable = "vector")]
#[inline]
pub(super) unsafe fn aggregate_4blocks(acc: u128, h_powers_rev: &[u128; 4], blocks: &[u128; 4]) -> u128 {
unsafe {
let b0 = acc ^ blocks[0];
let b1 = blocks[1];
let b2 = blocks[2];
let b3 = blocks[3];
let zero = i64x2::from_array([0, 0]);
let v0_0 = mul64(b0 as u64, h_powers_rev[0] as u64);
let v1_0 = mul64((b0 >> 64) as u64, (h_powers_rev[0] >> 64) as u64);
let v2_0 = mul64(
b0 as u64 ^ (b0 >> 64) as u64,
h_powers_rev[0] as u64 ^ (h_powers_rev[0] >> 64) as u64,
);
let v0_1 = mul64(b1 as u64, h_powers_rev[1] as u64);
let v1_1 = mul64((b1 >> 64) as u64, (h_powers_rev[1] >> 64) as u64);
let v2_1 = mul64(
b1 as u64 ^ (b1 >> 64) as u64,
h_powers_rev[1] as u64 ^ (h_powers_rev[1] >> 64) as u64,
);
let v0_2 = mul64(b2 as u64, h_powers_rev[2] as u64);
let v1_2 = mul64((b2 >> 64) as u64, (h_powers_rev[2] >> 64) as u64);
let v2_2 = mul64(
b2 as u64 ^ (b2 >> 64) as u64,
h_powers_rev[2] as u64 ^ (h_powers_rev[2] >> 64) as u64,
);
let v0_3 = mul64(b3 as u64, h_powers_rev[3] as u64);
let v1_3 = mul64((b3 >> 64) as u64, (h_powers_rev[3] >> 64) as u64);
let v2_3 = mul64(
b3 as u64 ^ (b3 >> 64) as u64,
h_powers_rev[3] as u64 ^ (h_powers_rev[3] >> 64) as u64,
);
let v0 = v0_0 ^ v0_1 ^ v0_2 ^ v0_3;
let v1 = v1_0 ^ v1_1 ^ v1_2 ^ v1_3;
let v2 = v2_0 ^ v2_1 ^ v2_2 ^ v2_3;
let mid = v2 ^ v0 ^ v1;
let lo_128 = v0 ^ vsldb::<8>(mid, zero);
let hi_128 = v1 ^ vsldb::<8>(zero, mid);
let result = mont_reduce(lo_128, hi_128);
let arr = result.to_array();
((arr[0] as u64 as u128) << 64) | (arr[1] as u64 as u128)
}
}
}
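// powerpc64 backend: carryless multiply via VPMSUMD through inline asm,
// Karatsuba split, and the shared portable Montgomery reduction.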
#[cfg(target_arch = "powerpc64")]
#[allow(unsafe_code)]
mod ppc_vpmsum {
use core::{arch::asm, simd::i64x2};
#[inline]
#[target_feature(enable = "altivec,vsx,power8-vector,power8-crypto")]
unsafe fn mul64(a: u64, b: u64) -> (u64, u64) {
let va = i64x2::from_array([a as i64, 0]);
let vb = i64x2::from_array([b as i64, 0]);
unsafe {
let out: i64x2;
asm!(
"vpmsumd {out}, {a}, {b}",
out = lateout(vreg) out,
a = in(vreg) va,
b = in(vreg) vb,
options(nomem, nostack, pure),
);
let [lo, hi] = out.to_array();
(lo as u64, hi as u64)
}
}
#[target_feature(enable = "altivec,vsx,power8-vector,power8-crypto")]
#[inline]
pub(super) unsafe fn clmul128_reduce_core(a: u128, b: u128) -> u128 {
unsafe {
let a_lo = a as u64;
let a_hi = (a >> 64) as u64;
let b_lo = b as u64;
let b_hi = (b >> 64) as u64;
let (v0_lo, v0_hi) = mul64(a_lo, b_lo);
let (v1_lo, v1_hi) = mul64(a_hi, b_hi);
let (v2_lo, v2_hi) = mul64(a_lo ^ a_hi, b_lo ^ b_hi);
let mid_lo = v2_lo ^ v0_lo ^ v1_lo;
let mid_hi = v2_hi ^ v0_hi ^ v1_hi;
super::mont_reduce([v0_lo, v0_hi ^ mid_lo, v1_lo ^ mid_hi, v1_hi])
}
}
#[target_feature(enable = "altivec,vsx,power8-vector,power8-crypto")]
pub(super) unsafe fn clmul128_reduce(a: u128, b: u128) -> u128 {
unsafe { clmul128_reduce_core(a, b) }
}
#[cfg(feature = "aes-gcm-siv")]
#[target_feature(enable = "altivec,vsx,power8-vector,power8-crypto")]
#[inline]
pub(super) unsafe fn aggregate_4blocks(acc: u128, h_powers_rev: &[u128; 4], blocks: &[u128; 4]) -> u128 {
unsafe {
let b0 = acc ^ blocks[0];
let b1 = blocks[1];
let b2 = blocks[2];
let b3 = blocks[3];
let (z0_0l, z0_0h) = mul64(b0 as u64, h_powers_rev[0] as u64);
let (z1_0l, z1_0h) = mul64((b0 >> 64) as u64, (h_powers_rev[0] >> 64) as u64);
let (z2_0l, z2_0h) = mul64(
b0 as u64 ^ (b0 >> 64) as u64,
h_powers_rev[0] as u64 ^ (h_powers_rev[0] >> 64) as u64,
);
let (z0_1l, z0_1h) = mul64(b1 as u64, h_powers_rev[1] as u64);
let (z1_1l, z1_1h) = mul64((b1 >> 64) as u64, (h_powers_rev[1] >> 64) as u64);
let (z2_1l, z2_1h) = mul64(
b1 as u64 ^ (b1 >> 64) as u64,
h_powers_rev[1] as u64 ^ (h_powers_rev[1] >> 64) as u64,
);
let (z0_2l, z0_2h) = mul64(b2 as u64, h_powers_rev[2] as u64);
let (z1_2l, z1_2h) = mul64((b2 >> 64) as u64, (h_powers_rev[2] >> 64) as u64);
let (z2_2l, z2_2h) = mul64(
b2 as u64 ^ (b2 >> 64) as u64,
h_powers_rev[2] as u64 ^ (h_powers_rev[2] >> 64) as u64,
);
let (z0_3l, z0_3h) = mul64(b3 as u64, h_powers_rev[3] as u64);
let (z1_3l, z1_3h) = mul64((b3 >> 64) as u64, (h_powers_rev[3] >> 64) as u64);
let (z2_3l, z2_3h) = mul64(
b3 as u64 ^ (b3 >> 64) as u64,
h_powers_rev[3] as u64 ^ (h_powers_rev[3] >> 64) as u64,
);
let z0_lo = z0_0l ^ z0_1l ^ z0_2l ^ z0_3l;
let z0_hi = z0_0h ^ z0_1h ^ z0_2h ^ z0_3h;
let z1_lo = z1_0l ^ z1_1l ^ z1_2l ^ z1_3l;
let z1_hi = z1_0h ^ z1_1h ^ z1_2h ^ z1_3h;
let z2_lo = z2_0l ^ z2_1l ^ z2_2l ^ z2_3l;
let z2_hi = z2_0h ^ z2_1h ^ z2_2h ^ z2_3h;
let mid_lo = z2_lo ^ z0_lo ^ z1_lo;
let mid_hi = z2_hi ^ z0_hi ^ z1_hi;
super::mont_reduce([z0_lo, z0_hi ^ mid_lo, z1_lo ^ mid_hi, z1_hi])
}
}
}
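// riscv64 vector backend: Zvbc vclmul/vclmulh on a single e64 element.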
#[cfg(target_arch = "riscv64")]
#[allow(unsafe_code)]
mod rv_clmul {
use core::arch::asm;
#[inline]
#[target_feature(enable = "v", enable = "zvbc")]
unsafe fn mul64(a: u64, b: u64) -> (u64, u64) {
unsafe {
let lo: u64;
let hi: u64;
asm!(
"vsetivli zero, 1, e64, m1, ta, ma",
"vmv.v.x v0, {a}",
"vclmul.vx v1, v0, {b}",
"vclmulh.vx v2, v0, {b}",
"vmv.x.s {lo}, v1",
"vmv.x.s {hi}, v2",
a = in(reg) a,
b = in(reg) b,
lo = lateout(reg) lo,
hi = lateout(reg) hi,
out("v0") _,
out("v1") _,
out("v2") _,
options(nostack),
);
(lo, hi)
}
}
#[target_feature(enable = "v", enable = "zvbc")]
pub(super) unsafe fn clmul128_reduce(a: u128, b: u128) -> u128 {
unsafe {
let a_lo = a as u64;
let a_hi = (a >> 64) as u64;
let b_lo = b as u64;
let b_hi = (b >> 64) as u64;
let (v0_lo, v0_hi) = mul64(a_lo, b_lo);
let (v1_lo, v1_hi) = mul64(a_hi, b_hi);
let (v2_lo, v2_hi) = mul64(a_lo ^ a_hi, b_lo ^ b_hi);
let mid_lo = v2_lo ^ v0_lo ^ v1_lo;
let mid_hi = v2_hi ^ v0_hi ^ v1_hi;
let w0 = v0_lo;
let w1 = v0_hi ^ mid_lo;
let w2 = v1_lo ^ mid_hi;
let w3 = v1_hi;
super::mont_reduce([w0, w1, w2, w3])
}
}
}
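// riscv64 scalar backend: Zbc clmul/clmulh.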
#[cfg(target_arch = "riscv64")]
#[allow(unsafe_code)]
mod rv_scalar_clmul {
use core::arch::asm;
#[inline]
#[target_feature(enable = "zbc")]
unsafe fn mul64(a: u64, b: u64) -> (u64, u64) {
unsafe {
let lo: u64;
let hi: u64;
asm!(
"clmul {lo}, {a}, {b}",
"clmulh {hi}, {a}, {b}",
a = in(reg) a,
b = in(reg) b,
lo = lateout(reg) lo,
hi = lateout(reg) hi,
options(nomem, nostack, pure),
);
(lo, hi)
}
}
#[target_feature(enable = "zbc")]
pub(super) unsafe fn clmul128_reduce(a: u128, b: u128) -> u128 {
unsafe {
let a_lo = a as u64;
let a_hi = (a >> 64) as u64;
let b_lo = b as u64;
let b_hi = (b >> 64) as u64;
let (v0_lo, v0_hi) = mul64(a_lo, b_lo);
let (v1_lo, v1_hi) = mul64(a_hi, b_hi);
let (v2_lo, v2_hi) = mul64(a_lo ^ a_hi, b_lo ^ b_hi);
let mid_lo = v2_lo ^ v0_lo ^ v1_lo;
let mid_hi = v2_hi ^ v0_hi ^ v1_hi;
super::mont_reduce([v0_lo, v0_hi ^ mid_lo, v1_lo ^ mid_hi, v1_hi])
}
}
}
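// x86_64 wide backend: VPCLMULQDQ over 512-bit vectors computes all four
// block-by-H-power products at once; the per-lane partials are XOR-folded
// and reduced with a single SSE2 Montgomery reduction.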
#[cfg(target_arch = "x86_64")]
mod vpclmul {
use core::arch::x86_64::*;
#[cfg(any(feature = "aes-gcm", feature = "aes-gcm-siv"))]
#[target_feature(enable = "avx512f,avx512vl,avx512bw,avx512dq,vpclmulqdq,pclmulqdq,sse2")]
pub(super) unsafe fn aggregate_4blocks(
acc: u128,
h_powers_rev: &[u128; 4],
blocks: &[u128; 4],
) -> u128 {
unsafe {
let mut block_data = *blocks;
block_data[0] ^= acc;
let data = _mm512_loadu_si512(block_data.as_ptr().cast());
let h_vec = _mm512_loadu_si512(h_powers_rev.as_ptr().cast());
let lo = _mm512_clmulepi64_epi128(data, h_vec, 0x00);
let hi = _mm512_clmulepi64_epi128(data, h_vec, 0x11);
let m1 = _mm512_clmulepi64_epi128(data, h_vec, 0x10);
let m2 = _mm512_clmulepi64_epi128(data, h_vec, 0x01);
let mid = _mm512_xor_si512(m1, m2);
let lo_128 = _mm512_xor_si512(lo, _mm512_bslli_epi128(mid, 8));
let hi_128 = _mm512_xor_si512(hi, _mm512_bsrli_epi128(mid, 8));
let lo_0 = _mm512_extracti64x2_epi64(lo_128, 0);
let lo_1 = _mm512_extracti64x2_epi64(lo_128, 1);
let lo_2 = _mm512_extracti64x2_epi64(lo_128, 2);
let lo_3 = _mm512_extracti64x2_epi64(lo_128, 3);
let lo_sum = _mm_xor_si128(_mm_xor_si128(lo_0, lo_1), _mm_xor_si128(lo_2, lo_3));
let hi_0 = _mm512_extracti64x2_epi64(hi_128, 0);
let hi_1 = _mm512_extracti64x2_epi64(hi_128, 1);
let hi_2 = _mm512_extracti64x2_epi64(hi_128, 2);
let hi_3 = _mm512_extracti64x2_epi64(hi_128, 3);
let hi_sum = _mm_xor_si128(_mm_xor_si128(hi_0, hi_1), _mm_xor_si128(hi_2, hi_3));
let result = super::pclmul::mont_reduce_sse2(lo_sum, hi_sum);
let mut out = 0u128;
_mm_storeu_si128((&mut out as *mut u128).cast(), result);
out
}
}
}
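// Thin #[target_feature] re-exports of the backend entry points so callers
// that have already verified CPU support can inline them directly instead of
// paying for the function-pointer dispatch below.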
#[cfg(all(target_arch = "aarch64", feature = "aes-gcm-siv"))]
#[target_feature(enable = "neon", enable = "aes")]
#[inline]
pub(super) unsafe fn aarch64_clmul128_reduce_inline(a: u128, b: u128) -> u128 {
unsafe { pmull::clmul128_reduce_core(a, b) }
}
#[cfg(all(target_arch = "aarch64", feature = "aes-gcm-siv"))]
#[target_feature(enable = "neon", enable = "aes")]
#[inline]
pub(super) unsafe fn aarch64_aggregate_4blocks_inline(acc: u128, h_powers_rev: &[u128; 4], blocks: &[u128; 4]) -> u128 {
unsafe { pmull::aggregate_4blocks(acc, h_powers_rev, blocks) }
}
#[cfg(all(target_arch = "powerpc64", feature = "aes-gcm-siv"))]
#[target_feature(enable = "altivec,vsx,power8-vector,power8-crypto")]
#[inline]
pub(super) unsafe fn ppc_clmul128_reduce_inline(a: u128, b: u128) -> u128 {
unsafe { ppc_vpmsum::clmul128_reduce_core(a, b) }
}
#[cfg(all(target_arch = "powerpc64", feature = "aes-gcm-siv"))]
#[target_feature(enable = "altivec,vsx,power8-vector,power8-crypto")]
#[inline]
pub(super) unsafe fn ppc_aggregate_4blocks_inline(acc: u128, h_powers_rev: &[u128; 4], blocks: &[u128; 4]) -> u128 {
unsafe { ppc_vpmsum::aggregate_4blocks(acc, h_powers_rev, blocks) }
}
#[cfg(all(target_arch = "s390x", feature = "aes-gcm-siv"))]
#[target_feature(enable = "vector")]
#[inline]
pub(super) unsafe fn s390x_clmul128_reduce_inline(a: u128, b: u128) -> u128 {
unsafe { s390x_vgfm::clmul128_reduce_core(a, b) }
}
#[cfg(all(target_arch = "s390x", feature = "aes-gcm-siv"))]
#[target_feature(enable = "vector")]
#[inline]
pub(super) unsafe fn s390x_aggregate_4blocks_inline(acc: u128, h_powers_rev: &[u128; 4], blocks: &[u128; 4]) -> u128 {
unsafe { s390x_vgfm::aggregate_4blocks(acc, h_powers_rev, blocks) }
}
#[cfg(all(target_arch = "riscv64", feature = "aes-gcm-siv"))]
#[inline(always)]
pub(super) fn portable_clmul128_reduce_inline(a: u128, b: u128) -> u128 {
mont_reduce(clmul128(a, b))
}
#[cfg(all(target_arch = "riscv64", feature = "aes-gcm-siv"))]
#[target_feature(enable = "v", enable = "zvbc")]
#[inline]
pub(super) unsafe fn riscv_vector_clmul128_reduce_inline(a: u128, b: u128) -> u128 {
unsafe { rv_clmul::clmul128_reduce(a, b) }
}
#[cfg(all(target_arch = "riscv64", feature = "aes-gcm-siv"))]
#[target_feature(enable = "zbc")]
#[inline]
pub(super) unsafe fn riscv_scalar_clmul128_reduce_inline(a: u128, b: u128) -> u128 {
unsafe { rv_scalar_clmul::clmul128_reduce(a, b) }
}
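// Runtime dispatch: the best available backend is resolved once from the
// detected CPU capabilities and cached for all subsequent calls.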
type Clmul128ReduceFn = fn(u128, u128) -> u128;
static CLMUL128_REDUCE_DISPATCH: OnceCache<Clmul128ReduceFn> = OnceCache::new();
#[inline]
fn current_caps() -> crate::platform::Caps {
#[cfg(feature = "std")]
{
crate::platform::caps()
}
#[cfg(not(feature = "std"))]
{
crate::platform::caps_static()
}
}
#[inline]
fn clmul128_reduce_portable(a: u128, b: u128) -> u128 {
mont_reduce(clmul128(a, b))
}
#[cfg(target_arch = "x86_64")]
#[inline]
fn clmul128_reduce_x86_pclmul(a: u128, b: u128) -> u128 {
unsafe { pclmul::clmul128_reduce(a, b) }
}
#[cfg(target_arch = "aarch64")]
#[inline]
fn clmul128_reduce_aarch64_pmull(a: u128, b: u128) -> u128 {
unsafe { pmull::clmul128_reduce(a, b) }
}
#[cfg(target_arch = "s390x")]
#[inline]
fn clmul128_reduce_s390x_vgfm(a: u128, b: u128) -> u128 {
unsafe { s390x_vgfm::clmul128_reduce(a, b) }
}
#[cfg(target_arch = "powerpc64")]
#[inline]
fn clmul128_reduce_power_vpmsum(a: u128, b: u128) -> u128 {
unsafe { ppc_vpmsum::clmul128_reduce(a, b) }
}
#[cfg(target_arch = "riscv64")]
#[inline]
fn clmul128_reduce_riscv_vector(a: u128, b: u128) -> u128 {
unsafe { rv_clmul::clmul128_reduce(a, b) }
}
#[cfg(target_arch = "riscv64")]
#[inline]
fn clmul128_reduce_riscv_scalar(a: u128, b: u128) -> u128 {
unsafe { rv_scalar_clmul::clmul128_reduce(a, b) }
}
#[inline]
fn resolve_clmul128_reduce() -> Clmul128ReduceFn {
let caps = current_caps();
#[cfg(target_arch = "x86_64")]
if caps.has(crate::platform::caps::x86::PCLMULQDQ) {
return clmul128_reduce_x86_pclmul;
}
#[cfg(target_arch = "aarch64")]
if caps.has(crate::platform::caps::aarch64::PMULL) {
return clmul128_reduce_aarch64_pmull;
}
#[cfg(target_arch = "s390x")]
if caps.has(crate::platform::caps::s390x::VECTOR) {
return clmul128_reduce_s390x_vgfm;
}
#[cfg(target_arch = "powerpc64")]
if caps.has(crate::platform::caps::power::POWER8_CRYPTO) {
return clmul128_reduce_power_vpmsum;
}
#[cfg(target_arch = "riscv64")]
{
use crate::platform::caps::riscv;
if caps.has(riscv::ZVBC) {
return clmul128_reduce_riscv_vector;
}
if caps.has(riscv::ZBC) || caps.has(riscv::ZBKC) {
return clmul128_reduce_riscv_scalar;
}
}
clmul128_reduce_portable
}
pub(super) fn clmul128_reduce(a: u128, b: u128) -> u128 {
let clmul = CLMUL128_REDUCE_DISPATCH.get_or_init(resolve_clmul128_reduce);
clmul(a, b)
}
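// Precomputes H, H^2, H^3, H^4 for the aggregated 4-block path.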
#[cfg(any(
feature = "aes-gcm",
target_arch = "x86_64",
target_arch = "aarch64",
target_arch = "powerpc64",
target_arch = "s390x",
test
))]
pub(super) fn precompute_powers(h: u128) -> [u128; 4] {
let h2 = clmul128_reduce(h, h);
let h3 = clmul128_reduce(h2, h);
let h4 = clmul128_reduce(h3, h);
[h, h2, h3, h4]
}
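// Folds four blocks into the accumulator, preferring the wide VPCLMULQDQ
// path when available and falling back to four multiply-reduce steps.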
#[cfg(any(target_arch = "x86_64", test))]
pub(super) fn accumulate_4blocks(
acc: u128,
h: u128,
h_powers_rev: &[u128; 4],
blocks: &[u128; 4],
) -> u128 {
#[cfg(target_arch = "x86_64")]
{
if current_caps().has(crate::platform::caps::x86::VPCLMUL_READY) {
return unsafe { vpclmul::aggregate_4blocks(acc, h_powers_rev, blocks) };
}
}
let _ = h_powers_rev;
let mut a = acc ^ blocks[0];
a = clmul128_reduce(a, h);
a ^= blocks[1];
a = clmul128_reduce(a, h);
a ^= blocks[2];
a = clmul128_reduce(a, h);
a ^= blocks[3];
clmul128_reduce(a, h)
}
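// POLYVAL universal hash over GF(2^128), as specified in RFC 8452.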
#[cfg(feature = "aes-gcm-siv")]
pub(crate) struct Polyval {
h: u128,
acc: u128,
}
#[cfg(feature = "aes-gcm-siv")]
impl Polyval {
#[inline]
pub(crate) fn new(key: &[u8; KEY_SIZE]) -> Self {
Self {
h: u128::from_le_bytes(*key),
acc: 0,
}
}
#[inline]
pub(crate) fn update_block(&mut self, block: &[u8; BLOCK_SIZE]) {
self.acc ^= u128::from_le_bytes(*block);
self.acc = clmul128_reduce(self.acc, self.h);
}
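// Processes data in 16-byte blocks, zero-padding the trailing partial block
// as RFC 8452 requires.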
pub(crate) fn update_padded(&mut self, data: &[u8]) {
let (blocks, remainder) = data.as_chunks::<BLOCK_SIZE>();
for block in blocks {
self.update_block(block);
}
if !remainder.is_empty() {
let mut block = [0u8; BLOCK_SIZE];
block[..remainder.len()].copy_from_slice(remainder);
self.update_block(&block);
}
}
#[inline]
pub(crate) fn finalize(self) -> [u8; BLOCK_SIZE] {
self.acc.to_le_bytes()
}
}
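// Best-effort zeroization of the key and accumulator on drop; the volatile
// writes plus compiler fence keep the stores from being optimized away.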
#[cfg(feature = "aes-gcm-siv")]
impl Drop for Polyval {
fn drop(&mut self) {
unsafe {
core::ptr::write_volatile(&mut self.acc, 0);
core::ptr::write_volatile(&mut self.h, 0);
}
core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
}
}
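// Portable constant-time 64x64 carryless multiply (low 64 bits): each
// operand is masked into four interleaved bit groups so that integer
// multiplication never propagates carries between polynomial coefficients.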
#[inline]
fn bmul64(x: u64, y: u64) -> u64 {
let x0 = x & 0x1111_1111_1111_1111;
let x1 = x & 0x2222_2222_2222_2222;
let x2 = x & 0x4444_4444_4444_4444;
let x3 = x & 0x8888_8888_8888_8888;
let y0 = y & 0x1111_1111_1111_1111;
let y1 = y & 0x2222_2222_2222_2222;
let y2 = y & 0x4444_4444_4444_4444;
let y3 = y & 0x8888_8888_8888_8888;
let mut z0 = (x0.wrapping_mul(y0)) ^ (x1.wrapping_mul(y3)) ^ (x2.wrapping_mul(y2)) ^ (x3.wrapping_mul(y1));
let mut z1 = (x0.wrapping_mul(y1)) ^ (x1.wrapping_mul(y0)) ^ (x2.wrapping_mul(y3)) ^ (x3.wrapping_mul(y2));
let mut z2 = (x0.wrapping_mul(y2)) ^ (x1.wrapping_mul(y1)) ^ (x2.wrapping_mul(y0)) ^ (x3.wrapping_mul(y3));
let mut z3 = (x0.wrapping_mul(y3)) ^ (x1.wrapping_mul(y2)) ^ (x2.wrapping_mul(y1)) ^ (x3.wrapping_mul(y0));
z0 &= 0x1111_1111_1111_1111;
z1 &= 0x2222_2222_2222_2222;
z2 &= 0x4444_4444_4444_4444;
z3 &= 0x8888_8888_8888_8888;
z0 | z1 | z2 | z3
}
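// Portable 128x128 carryless multiply via Karatsuba (three bmul64 pairs).
// bmul64 only yields the low half of each product, so the multiplies are
// repeated on bit-reversed operands to recover the high halves.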
pub(super) fn clmul128(a: u128, b: u128) -> [u64; 4] {
let a0 = a as u64;
let a1 = (a >> 64) as u64;
let b0 = b as u64;
let b1 = (b >> 64) as u64;
let a2 = a0 ^ a1;
let b2 = b0 ^ b1;
let a0r = a0.reverse_bits();
let a1r = a1.reverse_bits();
let a2r = a2.reverse_bits();
let b0r = b0.reverse_bits();
let b1r = b1.reverse_bits();
let b2r = b2.reverse_bits();
let z0 = bmul64(a0, b0);
let z1 = bmul64(a1, b1);
let mut z2 = bmul64(a2, b2);
let mut z0h = bmul64(a0r, b0r).reverse_bits() >> 1;
let z1h = bmul64(a1r, b1r).reverse_bits() >> 1;
let mut z2h = bmul64(a2r, b2r).reverse_bits() >> 1;
z2 ^= z0 ^ z1;
z2h ^= z0h ^ z1h;
z0h ^= z2;
[z0, z0h, z1 ^ z2h, z1h]
}
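// Portable POLYVAL Montgomery reduction of the 256-bit product v (given as
// four 64-bit words, least significant first) to 128 bits.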
#[inline]
pub(super) fn mont_reduce(v: [u64; 4]) -> u128 {
let (v0, mut v1, mut v2, mut v3) = (v[0], v[1], v[2], v[3]);
v2 ^= v0 ^ (v0 >> 1) ^ (v0 >> 2) ^ (v0 >> 7);
v1 ^= (v0 << 63) ^ (v0 << 62) ^ (v0 << 57);
v3 ^= v1 ^ (v1 >> 1) ^ (v1 >> 2) ^ (v1 >> 7);
v2 ^= (v1 << 63) ^ (v1 << 62) ^ (v1 << 57);
(v2 as u128) | ((v3 as u128) << 64)
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg(feature = "aes-gcm-siv")]
#[test]
fn polyval_rfc8452_appendix_a() {
let h = hex_to_16("25629347589242761d31f826ba4b757b");
let x1 = hex_to_16("4f4f95668c83dfb6401762bb2d01a262");
let x2 = hex_to_16("d1a24ddd2721d006bbe45f20d3c9f362");
let expected = hex_to_16("f7a3b47b846119fae5b7866cf5e5b77e");
let mut pv = Polyval::new(&h);
pv.update_block(&x1);
pv.update_block(&x2);
let result = pv.finalize();
assert_eq!(result, expected, "POLYVAL mismatch");
}
#[cfg(feature = "aes-gcm-siv")]
#[test]
fn polyval_empty() {
let h = [0x42u8; 16];
let pv = Polyval::new(&h);
assert_eq!(pv.finalize(), [0u8; 16]);
}
#[cfg(feature = "aes-gcm-siv")]
#[test]
fn polyval_zero_key() {
let h = [0u8; 16];
let x = [0xffu8; 16];
let mut pv = Polyval::new(&h);
pv.update_block(&x);
assert_eq!(pv.finalize(), [0u8; 16]);
}
#[cfg(feature = "aes-gcm-siv")]
#[test]
fn polyval_padded_matches_manual() {
let h = hex_to_16("25629347589242761d31f826ba4b757b");
let data = b"Hello, World! This is test data for POLYVAL padding.";
let mut manual = Polyval::new(&h);
let mut offset = 0;
while offset + 16 <= data.len() {
let block: [u8; 16] = data[offset..offset + 16].try_into().unwrap();
manual.update_block(&block);
offset += 16;
}
if offset < data.len() {
let mut block = [0u8; 16];
block[..data.len() - offset].copy_from_slice(&data[offset..]);
manual.update_block(&block);
}
let manual_result = manual.finalize();
let mut padded = Polyval::new(&h);
padded.update_padded(data);
let padded_result = padded.finalize();
assert_eq!(manual_result, padded_result);
}
#[test]
fn bmul64_basic() {
assert_eq!(bmul64(1, 0x42), 0x42);
assert_eq!(bmul64(0x42, 1), 0x42);
assert_eq!(bmul64(0, 0xDEAD_BEEF), 0);
assert_eq!(bmul64(0xFF, 0xFF), 0x5555);
assert_eq!(bmul64(0x1234, 0x5678), bmul64(0x5678, 0x1234));
}
#[test]
fn clmul128_identity() {
let v = clmul128(1, 1);
assert_eq!(v, [1, 0, 0, 0]);
let v = clmul128(2, 2);
assert_eq!(v, [4, 0, 0, 0]);
let x64 = 1u128 << 64;
let v = clmul128(x64, x64);
assert_eq!(v, [0, 0, 1, 0]);
}
#[test]
fn clmul128_by_one() {
let val: u128 = 0x7b75_4bba_26f8_311d_7642_9258_4793_6225;
let v = clmul128(1, val);
assert_eq!(v[0], val as u64, "v0 should be val_lo");
assert_eq!(v[1], (val >> 64) as u64, "v1 should be val_hi");
assert_eq!(v[2], 0, "v2 should be 0");
assert_eq!(v[3], 0, "v3 should be 0");
}
#[test]
fn mont_reduce_of_poly() {
let v = [POLY as u64, (POLY >> 64) as u64, 0u64, 0u64];
let result = mont_reduce(v);
assert_eq!(result, 1, "mont_reduce(POLY) should be 1");
}
#[test]
fn precompute_powers_correct() {
let h = u128::from_le_bytes(hex_to_16("25629347589242761d31f826ba4b757b"));
let powers = precompute_powers(h);
assert_eq!(powers[0], h, "powers[0] should be H");
assert_eq!(powers[1], clmul128_reduce(h, h), "powers[1] should be H^2");
assert_eq!(powers[2], clmul128_reduce(powers[1], h), "powers[2] should be H^3");
assert_eq!(powers[3], clmul128_reduce(powers[2], h), "powers[3] should be H^4");
}
#[test]
fn accumulate_4blocks_matches_sequential() {
let h_bytes = hex_to_16("25629347589242761d31f826ba4b757b");
let h = u128::from_le_bytes(h_bytes);
let powers = precompute_powers(h);
let h_powers_rev = [powers[3], powers[2], powers[1], powers[0]];
let blocks = [
u128::from_le_bytes(hex_to_16("4f4f95668c83dfb6401762bb2d01a262")),
u128::from_le_bytes(hex_to_16("d1a24ddd2721d006bbe45f20d3c9f362")),
u128::from_le_bytes(hex_to_16("0100000000000000000000000000000f")),
u128::from_le_bytes(hex_to_16("abcdef0123456789abcdef0123456789")),
];
let acc = 0x42u128;
let mut seq = acc ^ blocks[0];
seq = clmul128_reduce(seq, h);
seq ^= blocks[1];
seq = clmul128_reduce(seq, h);
seq ^= blocks[2];
seq = clmul128_reduce(seq, h);
seq ^= blocks[3];
seq = clmul128_reduce(seq, h);
let wide = accumulate_4blocks(acc, h, &h_powers_rev, &blocks);
assert_eq!(wide, seq, "4-block aggregate must match sequential processing");
}
#[test]
fn accumulate_4blocks_zeros() {
let h = 0x42u128;
let powers = precompute_powers(h);
let h_powers_rev = [powers[3], powers[2], powers[1], powers[0]];
let blocks = [0u128; 4];
let result = accumulate_4blocks(0, h, &h_powers_rev, &blocks);
assert_eq!(result, 0);
}
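// The runtime-dispatched clmul128_reduce must agree with the portable
// fallback for whichever backend this host selects; the field product is
// also commutative.
#[test]
fn clmul128_reduce_dispatch_matches_portable() {
let pairs = [
(0u128, 0u128),
(1, u128::MAX),
(
u128::from_le_bytes(hex_to_16("25629347589242761d31f826ba4b757b")),
u128::from_le_bytes(hex_to_16("4f4f95668c83dfb6401762bb2d01a262")),
),
];
for &(a, b) in &pairs {
assert_eq!(clmul128_reduce(a, b), clmul128_reduce_portable(a, b));
assert_eq!(clmul128_reduce(a, b), clmul128_reduce(b, a), "commutativity");
}
}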
fn hex_to_16(hex: &str) -> [u8; 16] {
let mut out = [0u8; 16];
let mut i = 0;
while i < 16 {
out[i] = u8::from_str_radix(&hex[2 * i..2 * i + 2], 16).unwrap();
i = i.strict_add(1);
}
out
}
}