use crate::towers::bit::Bit;
use crate::towers::block8::Block8;
use crate::{
CanonicalDeserialize, CanonicalSerialize, Flat, FlatPromote, HardwareField, PackableField,
PackedFlat, TowerField, constants,
};
use core::ops::{Add, AddAssign, BitXor, BitXorAssign, Mul, MulAssign, Sub, SubAssign};
use serde::{Deserialize, Serialize};
use zeroize::Zeroize;
/// Over-aligned wrapper for an `N`-entry basis-conversion table of `u16` columns.
///
/// `repr(align(64))` places the table at a 64-byte boundary; for `N = 16` (as used
/// below) the 32-byte table therefore never straddles a 64-byte block.
/// NOTE(review): presumably chosen so the whole table occupies one cache line —
/// confirm intent.
#[cfg(not(feature = "table-math"))]
#[repr(align(64))]
struct CtConvertBasisU16<const N: usize>([u16; N]);
/// Column `i` is the flat-basis image of tower-basis bit `i`; consumed by
/// `map_ct_16` in `HardwareField::to_hardware` when `table-math` is disabled.
#[cfg(not(feature = "table-math"))]
static TOWER_TO_FLAT_BASIS_16: CtConvertBasisU16<16> =
    CtConvertBasisU16(constants::RAW_TOWER_TO_FLAT_16);
/// Column `i` is the tower-basis image of flat-basis bit `i`; consumed by
/// `map_ct_16` in `HardwareField::from_hardware` when `table-math` is disabled.
#[cfg(not(feature = "table-math"))]
static FLAT_TO_TOWER_BASIS_16: CtConvertBasisU16<16> =
    CtConvertBasisU16(constants::RAW_FLAT_TO_TOWER_16);
/// Element of the 16-bit tower field, stored in the tower basis.
///
/// The low byte is the `lo` coefficient and the high byte the `hi` coefficient
/// over `Block8` (see `Block16::new` / `Block16::split`). Conversion to the flat
/// basis goes through `HardwareField::to_hardware`.
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq, Serialize, Deserialize, Zeroize)]
#[repr(transparent)]
pub struct Block16(pub u16);
impl Block16 {
    /// Element whose high `Block8` coefficient is `0x20` and low coefficient is zero;
    /// exported as `TowerField::EXTENSION_TAU` for this level.
    pub const TAU: Self = Block16(0x2000);

    /// Assembles `lo + hi·X` from its two `Block8` coefficients
    /// (`lo` becomes the low byte, `hi` the high byte).
    pub fn new(lo: Block8, hi: Block8) -> Self {
        Self(u16::from_le_bytes([lo.0, hi.0]))
    }

    /// Splits `self` into its `(lo, hi)` `Block8` coefficients
    /// (low byte first, high byte second).
    #[inline(always)]
    pub fn split(self) -> (Block8, Block8) {
        let [low_byte, high_byte] = self.0.to_le_bytes();
        (Block8(low_byte), Block8(high_byte))
    }
}
impl TowerField for Block16 {
    const BITS: usize = 16;
    const ZERO: Self = Block16(0);
    const ONE: Self = Block16(1);
    const EXTENSION_TAU: Self = Self::TAU;

    /// Multiplicative inverse via the norm down to `Block8`.
    ///
    /// Writing `self = lo + hi·X` with `X² = X + τ` (τ = `Block8::EXTENSION_TAU`,
    /// the rule used by `Mul for Block16`), the conjugate is `(lo + hi) + hi·X`
    /// and the norm is `hi²·τ + hi·lo + lo²`. The result is `conjugate · norm⁻¹`.
    /// For zero input the norm is zero, so the result is whatever
    /// `Block8::invert` yields for zero scaled into both halves.
    fn invert(&self) -> Self {
        let (lo, hi) = self.split();
        let norm = hi * hi * Block8::EXTENSION_TAU + hi * lo + lo * lo;
        let scale = norm.invert();
        Self::new((hi + lo) * scale, hi * scale)
    }

    /// Builds an element from the first two bytes of `bytes`, read little-endian;
    /// the remaining 30 bytes are ignored.
    fn from_uniform_bytes(bytes: &[u8; 32]) -> Self {
        Self(u16::from_le_bytes([bytes[0], bytes[1]]))
    }
}
impl Add for Block16 {
    type Output = Self;
    /// Field addition: XOR of the 16-bit representations.
    fn add(self, rhs: Self) -> Self {
        let Self(lhs) = self;
        Self(lhs.bitxor(rhs.0))
    }
}
impl Sub for Block16 {
    type Output = Self;
    /// Field subtraction: the same XOR as addition.
    fn sub(self, rhs: Self) -> Self {
        let Self(lhs) = self;
        Self(lhs.bitxor(rhs.0))
    }
}
impl Mul for Block16 {
    type Output = Self;
    /// Karatsuba multiplication of `a0 + a1·X` by `b0 + b1·X` using three
    /// `Block8` products.
    ///
    /// The result is reduced with `X² = X + τ` (τ = `Block8::EXTENSION_TAU`):
    /// low half `a0·b0 + a1·b1·τ`, high half `a0·b0 + (a0 + a1)(b0 + b1)`.
    fn mul(self, rhs: Self) -> Self {
        let (lhs_lo, lhs_hi) = self.split();
        let (rhs_lo, rhs_hi) = rhs.split();
        let lo_prod = lhs_lo * rhs_lo;
        let hi_prod = lhs_hi * rhs_hi;
        let mixed = (lhs_lo + lhs_hi) * (rhs_lo + rhs_hi);
        Self::new(lo_prod + hi_prod * Block8::EXTENSION_TAU, lo_prod + mixed)
    }
}
impl AddAssign for Block16 {
fn add_assign(&mut self, rhs: Self) {
self.0.bitxor_assign(rhs.0);
}
}
impl SubAssign for Block16 {
fn sub_assign(&mut self, rhs: Self) {
self.0.bitxor_assign(rhs.0);
}
}
impl MulAssign for Block16 {
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl CanonicalSerialize for Block16 {
    /// Always two bytes.
    fn serialized_size(&self) -> usize {
        2
    }
    /// Writes the value little-endian into `writer[..2]`.
    ///
    /// Returns `Err(())` when `writer` is shorter than two bytes; bytes past
    /// index 1 are left untouched.
    fn serialize(&self, writer: &mut [u8]) -> Result<(), ()> {
        let dst = writer.get_mut(..2).ok_or(())?;
        dst.copy_from_slice(&self.0.to_le_bytes());
        Ok(())
    }
}
impl CanonicalDeserialize for Block16 {
    /// Reads a little-endian value from the first two bytes of `bytes`.
    ///
    /// Returns `Err(())` when fewer than two bytes are supplied; any trailing
    /// bytes are ignored.
    fn deserialize(bytes: &[u8]) -> Result<Self, ()> {
        match *bytes {
            [b0, b1, ..] => Ok(Self(u16::from_le_bytes([b0, b1]))),
            _ => Err(()),
        }
    }
}
impl From<u8> for Block16 {
fn from(val: u8) -> Self {
Self(val as u16)
}
}
impl From<u16> for Block16 {
#[inline]
fn from(val: u16) -> Self {
Self(val)
}
}
impl From<u32> for Block16 {
#[inline]
fn from(val: u32) -> Self {
Self(val as u16)
}
}
impl From<u64> for Block16 {
#[inline]
fn from(val: u64) -> Self {
Self(val as u16)
}
}
impl From<u128> for Block16 {
#[inline]
fn from(val: u128) -> Self {
Self(val as u16)
}
}
impl From<Bit> for Block16 {
#[inline(always)]
fn from(val: Bit) -> Self {
Self(val.0 as u16)
}
}
impl From<Block8> for Block16 {
#[inline(always)]
fn from(val: Block8) -> Self {
Self(val.0 as u16)
}
}
pub const PACKED_WIDTH_16: usize = 8;
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[repr(C, align(16))]
pub struct PackedBlock16(pub [Block16; PACKED_WIDTH_16]);
impl PackedBlock16 {
#[inline(always)]
pub fn zero() -> Self {
Self([Block16::ZERO; PACKED_WIDTH_16])
}
}
impl PackableField for Block16 {
type Packed = PackedBlock16;
const WIDTH: usize = PACKED_WIDTH_16;
#[inline(always)]
fn pack(chunk: &[Self]) -> Self::Packed {
assert!(
chunk.len() >= PACKED_WIDTH_16,
"PackableField::pack: input slice too short",
);
let mut arr = [Self::ZERO; PACKED_WIDTH_16];
arr.copy_from_slice(&chunk[..PACKED_WIDTH_16]);
PackedBlock16(arr)
}
#[inline(always)]
fn unpack(packed: Self::Packed, output: &mut [Self]) {
assert!(
output.len() >= PACKED_WIDTH_16,
"PackableField::unpack: output slice too short",
);
output[..PACKED_WIDTH_16].copy_from_slice(&packed.0);
}
}
impl Add for PackedBlock16 {
    type Output = Self;
    /// Lane-wise field addition.
    #[inline(always)]
    fn add(self, rhs: Self) -> Self {
        let mut sum = self;
        sum += rhs;
        sum
    }
}
impl AddAssign for PackedBlock16 {
    /// In-place lane-wise field addition.
    #[inline(always)]
    fn add_assign(&mut self, rhs: Self) {
        self.0
            .iter_mut()
            .zip(rhs.0)
            .for_each(|(lane, other)| *lane += other);
    }
}
impl Sub for PackedBlock16 {
    type Output = Self;
    /// Lane-wise subtraction; `Block16` subtraction is the same XOR as addition.
    #[inline(always)]
    fn sub(self, rhs: Self) -> Self {
        self + rhs
    }
}
impl SubAssign for PackedBlock16 {
    /// In-place lane-wise subtraction; identical to `add_assign` here.
    #[inline(always)]
    fn sub_assign(&mut self, rhs: Self) {
        *self += rhs;
    }
}
impl Mul for PackedBlock16 {
    type Output = Self;
    /// Lane-wise tower-basis multiplication.
    ///
    /// On aarch64 each lane goes through `mul_iso_16` (flat-basis PMULL path);
    /// elsewhere it uses `Mul for Block16` directly.
    #[inline(always)]
    fn mul(self, rhs: Self) -> Self {
        let mut lanes = [Block16::ZERO; PACKED_WIDTH_16];
        for (idx, lane) in lanes.iter_mut().enumerate() {
            let (x, y) = (self.0[idx], rhs.0[idx]);
            #[cfg(target_arch = "aarch64")]
            {
                *lane = mul_iso_16(x, y);
            }
            #[cfg(not(target_arch = "aarch64"))]
            {
                *lane = x * y;
            }
        }
        Self(lanes)
    }
}
impl MulAssign for PackedBlock16 {
    /// In-place lane-wise multiplication.
    #[inline(always)]
    fn mul_assign(&mut self, rhs: Self) {
        let lhs = *self;
        *self = lhs * rhs;
    }
}
impl Mul<Block16> for PackedBlock16 {
    type Output = Self;
    /// Multiplies every lane by the same scalar using `Mul for Block16`.
    #[inline(always)]
    fn mul(self, rhs: Block16) -> Self {
        let mut lanes = self.0;
        for lane in lanes.iter_mut() {
            *lane *= rhs;
        }
        Self(lanes)
    }
}
impl HardwareField for Block16 {
#[inline(always)]
fn to_hardware(self) -> Flat<Self> {
#[cfg(feature = "table-math")]
{
Flat::from_raw(apply_matrix_16(self, &constants::TOWER_TO_FLAT_16))
}
#[cfg(not(feature = "table-math"))]
{
Flat::from_raw(Block16(map_ct_16(self.0, &TOWER_TO_FLAT_BASIS_16.0)))
}
}
#[inline(always)]
fn from_hardware(value: Flat<Self>) -> Self {
let value = value.into_raw();
#[cfg(feature = "table-math")]
{
apply_matrix_16(value, &constants::FLAT_TO_TOWER_16)
}
#[cfg(not(feature = "table-math"))]
{
Block16(map_ct_16(value.0, &FLAT_TO_TOWER_BASIS_16.0))
}
}
#[inline(always)]
fn add_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
Flat::from_raw(lhs.into_raw() + rhs.into_raw())
}
#[inline(always)]
fn add_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
let lhs = lhs.into_raw();
let rhs = rhs.into_raw();
#[cfg(target_arch = "aarch64")]
{
PackedFlat::from_raw(neon::add_packed_16(lhs, rhs))
}
#[cfg(not(target_arch = "aarch64"))]
{
PackedFlat::from_raw(lhs + rhs)
}
}
#[inline(always)]
fn mul_hardware(lhs: Flat<Self>, rhs: Flat<Self>) -> Flat<Self> {
let lhs = lhs.into_raw();
let rhs = rhs.into_raw();
#[cfg(target_arch = "aarch64")]
{
Flat::from_raw(neon::mul_flat_16(lhs, rhs))
}
#[cfg(not(target_arch = "aarch64"))]
{
let a_tower = Self::from_hardware(Flat::from_raw(lhs));
let b_tower = Self::from_hardware(Flat::from_raw(rhs));
(a_tower * b_tower).to_hardware()
}
}
#[inline(always)]
fn mul_hardware_packed(lhs: PackedFlat<Self>, rhs: PackedFlat<Self>) -> PackedFlat<Self> {
let lhs = lhs.into_raw();
let rhs = rhs.into_raw();
#[cfg(target_arch = "aarch64")]
{
PackedFlat::from_raw(neon::mul_flat_packed_16(lhs, rhs))
}
#[cfg(not(target_arch = "aarch64"))]
{
let mut l = [Self::ZERO; <Self as PackableField>::WIDTH];
let mut r = [Self::ZERO; <Self as PackableField>::WIDTH];
let mut res = [Self::ZERO; <Self as PackableField>::WIDTH];
Self::unpack(lhs, &mut l);
Self::unpack(rhs, &mut r);
for i in 0..<Self as PackableField>::WIDTH {
res[i] = Self::mul_hardware(Flat::from_raw(l[i]), Flat::from_raw(r[i])).into_raw();
}
PackedFlat::from_raw(Self::pack(&res))
}
}
#[inline(always)]
fn mul_hardware_scalar_packed(lhs: PackedFlat<Self>, rhs: Flat<Self>) -> PackedFlat<Self> {
let broadcasted = PackedBlock16([rhs.into_raw(); PACKED_WIDTH_16]);
Self::mul_hardware_packed(lhs, PackedFlat::from_raw(broadcasted))
}
#[inline(always)]
fn tower_bit_from_hardware(value: Flat<Self>, bit_idx: usize) -> u8 {
let mask = constants::FLAT_TO_TOWER_BIT_MASKS_16[bit_idx];
let mut v = value.into_raw().0 & mask;
v ^= v >> 8;
v ^= v >> 4;
v ^= v >> 2;
v ^= v >> 1;
(v & 1) as u8
}
}
impl FlatPromote<Block8> for Block16 {
    /// Lifts a flat-basis `Block8` element into the flat basis of `Block16`.
    ///
    /// Without `table-math`, this XORs together the `LIFT_BASIS_8_TO_16` entries of the
    /// set input bits using branch-free masks. With `table-math`, it reads
    /// `LIFT_TABLE_8_TO_16` indexed by the input byte.
    #[inline(always)]
    fn promote_flat(val: Flat<Block8>) -> Flat<Self> {
        let byte = val.into_raw().0;
        #[cfg(not(feature = "table-math"))]
        {
            let lifted = (0..8usize).fold(0u16, |acc, bit_pos| {
                let select = 0u16.wrapping_sub(u16::from((byte >> bit_pos) & 1));
                acc ^ (constants::LIFT_BASIS_8_TO_16[bit_pos] & select)
            });
            Flat::from_raw(Block16(lifted))
        }
        #[cfg(feature = "table-math")]
        {
            Flat::from_raw(Block16(constants::LIFT_TABLE_8_TO_16[byte as usize]))
        }
    }
}
/// Multiplies two tower-basis elements via the flat basis.
///
/// Both operands are mapped to the flat basis and multiplied with the NEON
/// carry-less kernel, and the product is mapped back to the tower basis.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub fn mul_iso_16(a: Block16, b: Block16) -> Block16 {
    let (a_raw, b_raw) = (a.to_hardware().into_raw(), b.to_hardware().into_raw());
    Flat::from_raw(neon::mul_flat_16(a_raw, b_raw)).to_tower()
}
/// Applies a 16×16 GF(2) linear map stored as two byte-indexed lookup tables.
///
/// Table layout: `table[0..256]` holds the image of every possible low byte and
/// `table[256..512]` the image of every possible high byte. The result is the XOR
/// of the two images.
///
/// Both indices are provably in range (`u8 < 256` and `256 + u8 < 512`), so plain
/// indexing into the fixed-size array needs no `unsafe` and its bounds checks are
/// elided by the compiler. Note that the memory addresses read depend on the
/// input value.
#[cfg(feature = "table-math")]
#[inline(always)]
pub fn apply_matrix_16(val: Block16, table: &[u16; 512]) -> Block16 {
    let [low_byte, high_byte] = val.0.to_le_bytes();
    Block16(table[low_byte as usize] ^ table[256 + high_byte as usize])
}
/// Applies a 16×16 GF(2) linear map given by its columns, without data-dependent branches.
///
/// `basis[i]` is the image of input bit `i`. The output is the XOR of the columns
/// whose input bit is set; each column is selected with an all-ones/all-zeros mask
/// derived from that bit rather than a conditional.
#[cfg(not(feature = "table-math"))]
#[inline(always)]
fn map_ct_16(x: u16, basis: &[u16; 16]) -> u16 {
    basis
        .iter()
        .enumerate()
        .fold(0u16, |acc, (bit_pos, &column)| {
            let select = 0u16.wrapping_sub((x >> bit_pos) & 1);
            acc ^ (column & select)
        })
}
/// AArch64 NEON / PMULL kernels operating on flat-basis `Block16` values.
///
/// NOTE(review): `vmull_p64` is a PMULL intrinsic that `core::arch` declares with
/// `#[target_feature(enable = "neon,aes")]`, but this module is gated only on
/// `target_arch = "aarch64"`. On a core without the crypto/PMULL extension the
/// multiply paths would hit an illegal instruction. Confirm that every supported
/// target guarantees `aes`, or gate on `target_feature = "aes"`.
#[cfg(target_arch = "aarch64")]
mod neon {
    use super::*;
    use core::arch::aarch64::*;
    use core::mem::transmute;
    /// Lane-wise addition (XOR) of two packed vectors as a single 128-bit `veorq_u8`.
    #[inline(always)]
    pub fn add_packed_16(lhs: PackedBlock16, rhs: PackedBlock16) -> PackedBlock16 {
        // SAFETY: `[Block16; 8]` (repr(transparent) over u16) and `uint8x16_t` are both
        // 16 bytes and every bit pattern is valid for each; the result is transmuted
        // back into the 16-byte `PackedBlock16`. `veorq_u8` only needs baseline NEON.
        unsafe {
            let res = veorq_u8(
                transmute::<[Block16; 8], uint8x16_t>(lhs.0),
                transmute::<[Block16; 8], uint8x16_t>(rhs.0),
            );
            transmute(res)
        }
    }
    /// Flat-basis multiply: a carry-less 16×16 product followed by reduction.
    ///
    /// The product is split into the low 16 bits `l` and the bits above `h`.
    /// `h` is folded back by a carry-less multiply with `constants::POLY_16`
    /// (presumably the low part of the field polynomial, i.e. `x^16 ≡ POLY_16` —
    /// confirm). Any spill of that fold past bit 16 (`carry`) is folded a second time.
    /// NOTE(review): the final `c_val as u16` discards bits above 15. That is only
    /// lossless if `POLY_16` has low enough degree — verify against the constant.
    #[inline(always)]
    pub fn mul_flat_16(a: Block16, b: Block16) -> Block16 {
        // SAFETY: requires PMULL (`aes`) support at runtime (see the module note).
        // The `u128`/`uint64x2_t` transmutes are same-size (16 bytes) plain-data
        // reinterpretations, and lane index 0 is in range for `vgetq_lane_u64`.
        unsafe {
            // Full carry-less product of two 16-bit polynomials (fits in 31 bits).
            let prod = vmull_p64(a.0 as u64, b.0 as u64);
            let prod_val = vgetq_lane_u64(transmute::<u128, uint64x2_t>(prod), 0);
            let l = (prod_val & 0xFFFF) as u16;
            let h = (prod_val >> 16) as u16;
            let r_val = constants::POLY_16 as u64;
            // First fold: replace the high part h·x^16 with h·POLY_16.
            let h_red = vmull_p64(h as u64, r_val);
            let h_red_val = vgetq_lane_u64(transmute::<u128, uint64x2_t>(h_red), 0);
            let folded = (h_red_val & 0xFFFF) as u16;
            let carry = (h_red_val >> 16) as u16;
            let mut res = l ^ folded;
            // Second fold: the first fold may itself spill past bit 16.
            let c_red = vmull_p64(carry as u64, r_val);
            let c_val = vgetq_lane_u64(transmute::<u128, uint64x2_t>(c_red), 0);
            res ^= c_val as u16;
            Block16(res)
        }
    }
    /// Lane-wise flat-basis multiply: eight independent `mul_flat_16` calls.
    #[inline(always)]
    pub fn mul_flat_packed_16(lhs: PackedBlock16, rhs: PackedBlock16) -> PackedBlock16 {
        let r0 = mul_flat_16(lhs.0[0], rhs.0[0]);
        let r1 = mul_flat_16(lhs.0[1], rhs.0[1]);
        let r2 = mul_flat_16(lhs.0[2], rhs.0[2]);
        let r3 = mul_flat_16(lhs.0[3], rhs.0[3]);
        let r4 = mul_flat_16(lhs.0[4], rhs.0[4]);
        let r5 = mul_flat_16(lhs.0[5], rhs.0[5]);
        let r6 = mul_flat_16(lhs.0[6], rhs.0[6]);
        let r7 = mul_flat_16(lhs.0[7], rhs.0[7]);
        PackedBlock16([r0, r1, r2, r3, r4, r5, r6, r7])
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{RngExt, rng};

    /// `EXTENSION_TAU` has a zero low coefficient and `0x20` as high coefficient.
    #[test]
    fn tower_constants() {
        let tau16 = Block16::EXTENSION_TAU;
        let (lo16, hi16) = tau16.split();
        assert_eq!(lo16, Block8::ZERO);
        assert_eq!(hi16, Block8(0x20));
    }

    #[test]
    fn add_truth() {
        let zero = Block16::ZERO;
        let one = Block16::ONE;
        assert_eq!(zero + zero, zero);
        assert_eq!(zero + one, one);
        assert_eq!(one + zero, one);
        assert_eq!(one + one, zero);
    }

    #[test]
    fn mul_truth() {
        let zero = Block16::ZERO;
        let one = Block16::ONE;
        assert_eq!(zero * zero, zero);
        assert_eq!(zero * one, zero);
        assert_eq!(one * one, one);
    }

    #[test]
    fn add() {
        assert_eq!(Block16(5) + Block16(3), Block16(6));
    }

    #[test]
    fn mul_simple() {
        assert_eq!(Block16(2) * Block16(2), Block16(4));
    }

    #[test]
    fn mul_overflow() {
        assert_eq!(Block16(0x57) * Block16(0x83), Block16(0xC1));
    }

    /// Checks the reduction rule X^2 = X + tau used by `Mul for Block16`.
    #[test]
    fn karatsuba_correctness() {
        let x = Block16::new(Block8::ZERO, Block8::ONE);
        let squared = x * x;
        let (res_lo, res_hi) = squared.split();
        assert_eq!(res_hi, Block8::ONE, "X^2 should contain X component");
        assert_eq!(
            res_lo,
            Block8(0x20),
            "X^2 should contain tau component (0x20)"
        );
    }

    #[test]
    fn security_zeroize() {
        let mut secret_val = Block16::from(0xDEAD_u16);
        assert_ne!(secret_val, Block16::ZERO);
        secret_val.zeroize();
        assert_eq!(secret_val, Block16::ZERO);
        assert_eq!(secret_val.0, 0, "Block16 memory leak detected");
    }

    #[test]
    fn invert_zero() {
        assert_eq!(
            Block16::ZERO.invert(),
            Block16::ZERO,
            "invert(0) must return 0"
        );
    }

    #[test]
    fn inversion_random() {
        let mut rng = rng();
        for _ in 0..1000 {
            let val_u16: u16 = rng.random();
            let val = Block16(val_u16);
            if val != Block16::ZERO {
                let inv = val.invert();
                let res = val * inv;
                assert_eq!(
                    res,
                    Block16::ONE,
                    "Inversion identity failed: a * a^-1 != 1"
                );
            }
        }
    }

    /// `From<Block8>` must be a ring homomorphism into the low coefficient.
    #[test]
    fn tower_embedding() {
        let mut rng = rng();
        for _ in 0..100 {
            let a = Block8(rng.random());
            let b = Block8(rng.random());
            let a_lifted: Block16 = a.into();
            let (lo, hi) = a_lifted.split();
            assert_eq!(lo, a, "Embedding structure failed: low part mismatch");
            assert_eq!(
                hi,
                Block8::ZERO,
                "Embedding structure failed: high part must be zero"
            );
            let sum_sub = a + b;
            let sum_lifted: Block16 = sum_sub.into();
            let sum_manual = Block16::from(a) + Block16::from(b);
            assert_eq!(sum_lifted, sum_manual, "Homomorphism failed: add");
            let prod_sub = a * b;
            let prod_lifted: Block16 = prod_sub.into();
            let prod_manual = Block16::from(a) * Block16::from(b);
            assert_eq!(prod_lifted, prod_manual, "Homomorphism failed: mul");
        }
    }

    #[test]
    fn isomorphism_roundtrip() {
        let mut rng = rng();
        for _ in 0..1000 {
            let val = Block16(rng.random::<u16>());
            assert_eq!(
                val.to_hardware().to_tower(),
                val,
                "Block16 isomorphism roundtrip failed"
            );
        }
    }

    /// Multiplying in the flat basis must agree with multiplying in the tower basis.
    #[test]
    fn flat_mul_homomorphism() {
        let mut rng = rng();
        for _ in 0..1000 {
            let a = Block16(rng.random::<u16>());
            let b = Block16(rng.random::<u16>());
            let expected_flat = (a * b).to_hardware();
            let actual_flat = a.to_hardware() * b.to_hardware();
            assert_eq!(
                actual_flat, expected_flat,
                "Block16 flat multiplication mismatch"
            );
        }
    }

    /// Packed flat-basis add/mul must agree lane-by-lane with scalar tower ops.
    /// Lane count comes from `PACKED_WIDTH_16` (not a literal) so this test
    /// tracks the packed width like the other packed tests do.
    #[test]
    fn packed_consistency() {
        let mut rng = rng();
        for _ in 0..100 {
            let mut a_vals = [Block16::ZERO; PACKED_WIDTH_16];
            let mut b_vals = [Block16::ZERO; PACKED_WIDTH_16];
            for (a, b) in a_vals.iter_mut().zip(b_vals.iter_mut()) {
                *a = Block16(rng.random::<u16>());
                *b = Block16(rng.random::<u16>());
            }
            let a_flat_vals = a_vals.map(|x| x.to_hardware());
            let b_flat_vals = b_vals.map(|x| x.to_hardware());
            let a_packed = Flat::<Block16>::pack(&a_flat_vals);
            let b_packed = Flat::<Block16>::pack(&b_flat_vals);
            let add_res = Block16::add_hardware_packed(a_packed, b_packed);
            let mut add_out = [Block16::ZERO.to_hardware(); PACKED_WIDTH_16];
            Flat::<Block16>::unpack(add_res, &mut add_out);
            for i in 0..PACKED_WIDTH_16 {
                assert_eq!(
                    add_out[i],
                    (a_vals[i] + b_vals[i]).to_hardware(),
                    "Block16 packed add mismatch"
                );
            }
            let mul_res = Block16::mul_hardware_packed(a_packed, b_packed);
            let mut mul_out = [Block16::ZERO.to_hardware(); PACKED_WIDTH_16];
            Flat::<Block16>::unpack(mul_res, &mut mul_out);
            for i in 0..PACKED_WIDTH_16 {
                assert_eq!(
                    mul_out[i],
                    (a_vals[i] * b_vals[i]).to_hardware(),
                    "Block16 packed mul mismatch"
                );
            }
        }
    }

    #[test]
    fn pack_unpack_roundtrip() {
        let mut rng = rng();
        let mut data = [Block16::ZERO; PACKED_WIDTH_16];
        for v in data.iter_mut() {
            *v = Block16(rng.random());
        }
        let packed = Block16::pack(&data);
        let mut unpacked = [Block16::ZERO; PACKED_WIDTH_16];
        Block16::unpack(packed, &mut unpacked);
        assert_eq!(data, unpacked, "Block16 pack/unpack roundtrip failed");
    }

    #[test]
    fn packed_add_consistency() {
        let mut rng = rng();
        let mut a_vals = [Block16::ZERO; PACKED_WIDTH_16];
        let mut b_vals = [Block16::ZERO; PACKED_WIDTH_16];
        for i in 0..PACKED_WIDTH_16 {
            a_vals[i] = Block16(rng.random());
            b_vals[i] = Block16(rng.random());
        }
        let res_packed = Block16::pack(&a_vals) + Block16::pack(&b_vals);
        let mut res_unpacked = [Block16::ZERO; PACKED_WIDTH_16];
        Block16::unpack(res_packed, &mut res_unpacked);
        for i in 0..PACKED_WIDTH_16 {
            assert_eq!(
                res_unpacked[i],
                a_vals[i] + b_vals[i],
                "Block16 packed add mismatch"
            );
        }
    }

    /// Tower-basis packed multiply (the `mul_iso_16` path on aarch64) vs scalar `Mul`.
    #[test]
    fn packed_mul_consistency() {
        let mut rng = rng();
        for _ in 0..1000 {
            let mut a_arr = [Block16::ZERO; PACKED_WIDTH_16];
            let mut b_arr = [Block16::ZERO; PACKED_WIDTH_16];
            for i in 0..PACKED_WIDTH_16 {
                let val_a_u16: u16 = rng.random();
                let val_b_u16: u16 = rng.random();
                a_arr[i] = Block16(val_a_u16);
                b_arr[i] = Block16(val_b_u16);
            }
            let a_packed = PackedBlock16(a_arr);
            let b_packed = PackedBlock16(b_arr);
            let c_packed = a_packed * b_packed;
            let mut c_expected = [Block16::ZERO; PACKED_WIDTH_16];
            for i in 0..PACKED_WIDTH_16 {
                c_expected[i] = a_arr[i] * b_arr[i];
            }
            assert_eq!(c_packed.0, c_expected, "SIMD Block16 mismatch!");
        }
    }

    /// Exhaustively checks `tower_bit` (parity masks) against a full `from_hardware`.
    #[test]
    fn parity_masks_match_from_hardware() {
        for x_flat in 0u16..=u16::MAX {
            let tower = Block16::from_hardware(Flat::from_raw(Block16(x_flat))).0;
            for k in 0..16 {
                let bit = ((tower >> k) & 1) as u8;
                let via_api = Flat::from_raw(Block16(x_flat)).tower_bit(k);
                assert_eq!(
                    via_api, bit,
                    "Block16 tower_bit_from_hardware mismatch at x_flat={x_flat:#06x}, bit_idx={k}"
                );
            }
        }
    }
}