/// Conditionally overwrites `r` with `x` without branching on the condition.
///
/// When `b == 1` every byte of `r` becomes the matching byte of `x`; when
/// `b == 0`, `r` is left unchanged. `b` must be 0 or 1 and both slices must
/// have the same length (both debug-asserted). Only the common prefix is
/// touched if the lengths differ in a release build.
pub(crate) fn cmov(r: &mut [u8], x: &[u8], b: u8) {
    debug_assert_eq!(r.len(), x.len());
    debug_assert!(b == 0 || b == 1);
    // 0x00 for b == 0, 0xff for b == 1.
    let select = 0u8.wrapping_sub(b);
    r.iter_mut()
        .zip(x)
        .for_each(|(dst, &src)| *dst ^= select & (*dst ^ src));
}
/// Branch-free compare-exchange: afterwards `*a <= *b`.
///
/// The difference is computed in 64 bits so extreme `i32` values cannot
/// overflow, and the resulting sign is turned into a swap mask.
#[inline(always)]
fn int32_minmax(a: &mut i32, b: &mut i32) {
    let ab = (*b) ^ (*a);
    let mut c = ((*b) as i64).wrapping_sub((*a) as i64) as i32;
    c ^= ab & (c ^ (*b));
    c >>= 31;
    c &= ab;
    *a ^= c;
    *b ^= c;
}

/// Runs the fused comparator chain of one merge step at position `j`.
///
/// `x[j + p]` is compare-exchanged against `x[j + q]`, `x[j + q / 2]`, ...,
/// `x[j + 2 * p]`, and the final minimum is stored back in `x[j + p]`.
/// Requires `j + q < x.len()`.
#[inline(always)]
fn int32_minmax_chain(x: &mut [i32], j: usize, p: usize, q: usize) {
    let mut a = x[j + p];
    let mut r = q;
    while r > p {
        int32_minmax(&mut a, &mut x[j + r]);
        r >>= 1;
    }
    x[j + p] = a;
}

/// Sorts `array` ascending with a data-independent comparator network.
///
/// This is a port of djbsort's portable merge-exchange sort as used by the
/// NTRU reference code. The sequence of memory accesses depends only on
/// `array.len()`, never on the values.
///
/// In the merge phase each position is handled once, with the largest `q` for
/// which it is in range. The cursor `j` therefore carries over from one `q` to
/// the next. After the tail pass of a given `q`, `j == max(i, n - q)`, and the
/// next, smaller `q` resumes from there. An earlier version reset this cursor
/// implicitly by iterating with a separate index, which made later `q` passes
/// start from a stale position and apply comparators outside the reference
/// network.
pub(crate) fn crypto_sort_int32(array: &mut [i32]) {
    let n = array.len();
    if n < 2 {
        return;
    }
    // Largest power of two with 2 * top >= n; always top < n here.
    let mut top: usize = 1;
    while top < n - top {
        top += top;
    }
    let mut p = top;
    while p >= 1 {
        // Compare-exchange every lower-half position with its partner p away.
        let mut i = 0usize;
        while i + 2 * p <= n {
            for j in i..i + p {
                let (lo, hi) = array.split_at_mut(j + p);
                int32_minmax(&mut lo[j], &mut hi[0]);
            }
            i += 2 * p;
        }
        for j in i..n.saturating_sub(p) {
            let (lo, hi) = array.split_at_mut(j + p);
            int32_minmax(&mut lo[j], &mut hi[0]);
        }

        // Fused merge passes for q = top, top / 2, ..., 2 * p. Both cursors
        // `i` (block start) and `j` (position) persist across iterations of q.
        let mut i = 0usize;
        let mut j = 0usize;
        let mut q = top;
        while q > p {
            'done: {
                if j != i {
                    // Finish the partially processed block left by the larger q.
                    loop {
                        if j == n - q {
                            break 'done;
                        }
                        int32_minmax_chain(array, j, p, q);
                        j += 1;
                        if j == i + p {
                            i += 2 * p;
                            break;
                        }
                    }
                }
                while i + p <= n - q {
                    for k in i..i + p {
                        int32_minmax_chain(array, k, p, q);
                    }
                    i += 2 * p;
                }
                // Now i + p > n - q. Process the remaining in-range positions
                // and leave `j` where it stopped for the next q.
                j = i;
                while j < n - q {
                    int32_minmax_chain(array, j, p, q);
                    j += 1;
                }
            }
            q >>= 1;
        }
        p >>= 1;
    }
}
/// Returns `-1` (all bits set) when both `x` and `y` are negative, else `0`.
#[inline(always)]
pub(crate) fn both_negative_mask_i16(x: i16, y: i16) -> i16 {
    // The sign bit of `x & y` is set only if both sign bits are set; the
    // arithmetic shift copies it into every bit.
    let sign_bits = x & y;
    sign_bits >> 15
}
/// Computes `a mod 3` for any `u16` without data-dependent branches.
///
/// Since 256 ≡ 16 ≡ 4 ≡ 1 (mod 3), folding the high digits onto the low ones
/// preserves the residue. After four folds the value is at most 5, and one
/// masked conditional subtraction of 3 finishes the reduction.
#[inline]
pub(crate) fn mod3(a: u16) -> u16 {
    let mut folded = (a & 0xff) + (a >> 8);
    folded = (folded & 0xf) + (folded >> 4);
    folded = (folded & 0x3) + (folded >> 2);
    folded = (folded & 0x3) + (folded >> 2);
    let minus_three = (folded as i16) - 3;
    // All ones when folded < 3 (keep it), zero otherwise (take folded - 3).
    let keep = (minus_three >> 15) as u16;
    (keep & folded) | (!keep & (minus_three as u16))
}
/// Computes `a mod 3` for small inputs (`a <= 14`, debug-asserted).
///
/// One fold `(a >> 2) + (a & 3)` (valid because 4 ≡ 1 mod 3) brings the value
/// into `0..=5`. A branch-free conditional subtraction of 3 then finishes it.
#[inline]
pub(crate) fn mod3_u8(a: u8) -> u8 {
    debug_assert!(a <= 14, "mod3_u8 input out of range: {a}");
    let folded = i16::from((a & 3) + (a >> 2));
    let minus_three = folded - 3;
    // All ones iff minus_three is negative, i.e. folded < 3.
    let negative = minus_three >> 5;
    (minus_three ^ (negative & (folded ^ minus_three))) as u8
}
/// Builder-style extension for hashers: `H::new().chain(a).chain(b)`.
///
/// Implemented for every `crate::hash::Digest`.
pub(crate) trait DigestChain: crate::hash::Digest + Sized {
    /// Feeds `data` into the hasher and hands the hasher back.
    fn chain(mut self, data: &[u8]) -> Self {
        self.update(data);
        self
    }
}
impl<D: crate::hash::Digest> DigestChain for D {}
/// Inverts `a` in `Z_2[x] / Phi_N`, with `Phi_N = 1 + x + ... + x^(N-1)`.
///
/// Only bit 0 of each coefficient of `a` is read. `a` is reduced mod `Phi_N`
/// by folding in its top coefficient. The iteration count is fixed at
/// `2 * (N - 1) - 1`, and all swaps and updates are masked rather than
/// branched. On return the coefficients of `r` are 0 or 1 and `r[N - 1] == 0`.
/// Invertibility of `a` is not checked; for a non-invertible input the result
/// is not an inverse.
pub(crate) fn poly_r2_inv<const N: usize>(r: &mut [u16; N], a: &[u16; N]) {
    let mut f = [1u16; N];
    let mut g = [0u16; N];
    let mut v = [0u16; N];
    let mut w = [0u16; N];
    w[0] = 1;
    // g holds (a mod Phi_N) reversed in its first N - 1 slots; g[N - 1] stays 0.
    let top = a[N - 1];
    for (gi, &ai) in g[..N - 1].iter_mut().rev().zip(&a[..N - 1]) {
        *gi = (ai ^ top) & 1;
    }
    let mut delta: i16 = 1;
    for _ in 0..(2 * (N - 1) - 1) {
        // v <- x * v
        v.copy_within(0..N - 1, 1);
        v[0] = 0;
        let sign_bit = g[0] & f[0];
        // All ones iff -delta < 0 and -g[0] < 0 (the both-negative mask).
        let swap = ((-delta) & (-(g[0] as i16))) >> 15;
        delta ^= swap & (delta ^ -delta);
        delta += 1;
        let swap_mask = swap as u16;
        for (fi, gi) in f.iter_mut().zip(g.iter_mut()) {
            let t = swap_mask & (*fi ^ *gi);
            *fi ^= t;
            *gi ^= t;
        }
        for (vi, wi) in v.iter_mut().zip(w.iter_mut()) {
            let t = swap_mask & (*vi ^ *wi);
            *vi ^= t;
            *wi ^= t;
        }
        for (gi, &fi) in g.iter_mut().zip(f.iter()) {
            *gi ^= sign_bit & fi;
        }
        for (wi, &vi) in w.iter_mut().zip(v.iter()) {
            *wi ^= sign_bit & vi;
        }
        // g <- g / x
        g.copy_within(1.., 0);
        g[N - 1] = 0;
    }
    // Undo the reversal while copying the result out.
    for (ri, &vi) in r[..N - 1].iter_mut().zip(v[..N - 1].iter().rev()) {
        *ri = vi;
    }
    r[N - 1] = 0;
}
/// Inverts `a` in `Z_3[x] / Phi_N`, with `Phi_N = 1 + x + ... + x^(N-1)`.
///
/// The low two bits of each coefficient of `a` are read as a value in
/// `{0, 1, 2}` (assumed). `a` is reduced mod `Phi_N` by adding twice its top
/// coefficient (`-1 ≡ 2 mod 3`). The iteration count is fixed at
/// `2 * (N - 1) - 1`, with masked swaps. On return `r` has coefficients in
/// `{0, 1, 2}` and `r[N - 1] == 0`. Invertibility of `a` is not checked.
pub(crate) fn poly_s3_inv<const N: usize>(r: &mut [u16; N], a: &[u16; N]) {
    let mut f = [1u16; N];
    let mut g = [0u16; N];
    let mut v = [0u16; N];
    let mut w = [0u16; N];
    w[0] = 1;
    // g holds (a mod Phi_N) reversed in its first N - 1 slots; g[N - 1] stays 0.
    let top = a[N - 1] & 3;
    for (gi, &ai) in g[..N - 1].iter_mut().rev().zip(&a[..N - 1]) {
        *gi = mod3_u8(((ai & 3) + 2 * top) as u8) as u16;
    }
    let mut delta: i16 = 1;
    for _ in 0..(2 * (N - 1) - 1) {
        // v <- x * v
        v.copy_within(0..N - 1, 1);
        v[0] = 0;
        let sign = mod3_u8((2 * g[0] * f[0]) as u8) as u16;
        let swap = both_negative_mask_i16(-delta, -(g[0] as i16));
        delta ^= swap & (delta ^ -delta);
        delta += 1;
        let swap_mask = swap as u16;
        for (fi, gi) in f.iter_mut().zip(g.iter_mut()) {
            let t = swap_mask & (*fi ^ *gi);
            *fi ^= t;
            *gi ^= t;
        }
        for (vi, wi) in v.iter_mut().zip(w.iter_mut()) {
            let t = swap_mask & (*vi ^ *wi);
            *vi ^= t;
            *wi ^= t;
        }
        for (gi, &fi) in g.iter_mut().zip(f.iter()) {
            *gi = mod3_u8((*gi + sign * fi) as u8) as u16;
        }
        for (wi, &vi) in w.iter_mut().zip(v.iter()) {
            *wi = mod3_u8((*wi + sign * vi) as u8) as u16;
        }
        // g <- g / x
        g.copy_within(1.., 0);
        g[N - 1] = 0;
    }
    // Scale by f[0] (= ±1) and undo the reversal.
    let sign = f[0];
    for (ri, &vi) in r[..N - 1].iter_mut().zip(v[..N - 1].iter().rev()) {
        *ri = mod3_u8((sign * vi) as u8) as u16;
    }
    r[N - 1] = 0;
}
/// Lifts an inverse of `a` mod 2 (`ai`) to an inverse in the 16-bit ring.
///
/// Runs four Newton/Hensel steps `x <- x * (2 - a * x)`, alternating the
/// running estimate between `r` and a scratch polynomial. Each step is built
/// from the cyclic multiplier `poly_mul_cyclic`, and the result ends in `r`.
pub(crate) fn poly_r2_inv_to_rq_inv<const N: usize>(
    r: &mut [u16; N],
    ai: &[u16; N],
    a: &[u16; N],
) {
    use crate::public_key::ntru_poly_mul::poly_mul_cyclic as mul;
    let mut neg_a = [0u16; N];
    for (dst, &src) in neg_a.iter_mut().zip(a.iter()) {
        *dst = src.wrapping_neg();
    }
    *r = *ai;
    let mut c = [0u16; N];
    let mut s = [0u16; N];
    // Two rounds of (r -> s, s -> r) give the four Newton steps.
    for _ in 0..2 {
        mul(&mut c, r, &neg_a);
        c[0] = c[0].wrapping_add(2); // c = 2 - a * r
        mul(&mut s, &c, r); // s = r * c
        mul(&mut c, &s, &neg_a);
        c[0] = c[0].wrapping_add(2); // c = 2 - a * s
        mul(r, &c, &s); // r = s * c
    }
}
/// Generates the typed public KEM API for one NTRU parameter set.
///
/// Expands to:
/// * the `$pk_ty`, `$sk_ty`, `$ct_ty` and `$ss_ty` newtypes over fixed-size byte arrays;
/// * a unit struct `$type_name` that exposes the byte-length constants and
///   `keygen` / `encaps` / `decaps`;
/// * a `#[cfg(test)] mod tests` with round-trip, implicit-rejection and NIST
///   KAT checks against `$kat_path`.
///
/// The invoking module must have `PUBLIC_KEY_BYTES`, `PRIVATE_KEY_BYTES`,
/// `CIPHERTEXT_BYTES`, `SHARED_SECRET_BYTES`, `SAMPLE_FG_BYTES`,
/// `SAMPLE_RM_BYTES`, `OWCPA_MSGBYTES`, `N` and `LOGQ` in scope, because they
/// are referenced unqualified. `$variant` must implement `NtruVariant<N, LOGQ>`.
///
/// The secret-bearing types (`$sk_ty`, `$ss_ty`) do not derive `PartialEq`.
/// Their equality ORs together every byte difference, so a comparison does not
/// stop at the first mismatching byte.
macro_rules! define_pqc_kem {
    (
        namespace = $type_name:ident,
        public_key = $pk_ty:ident,
        private_key = $sk_ty:ident,
        ciphertext = $ct_ty:ident,
        shared_secret = $ss_ty:ident,
        variant = $variant:ident,
        kat_path = $kat_path:literal $(,)?
    ) => {
        /// Public key in its fixed-size wire encoding.
        #[derive(Clone, Eq, PartialEq)]
        pub struct $pk_ty {
            bytes: [u8; PUBLIC_KEY_BYTES],
        }
        /// Private key in its fixed-size wire encoding.
        ///
        /// Equality has no early exit; `Debug` output is redacted.
        #[derive(Clone)]
        pub struct $sk_ty {
            bytes: [u8; PRIVATE_KEY_BYTES],
        }
        /// Ciphertext in its fixed-size wire encoding.
        #[derive(Clone, Eq, PartialEq)]
        pub struct $ct_ty {
            bytes: [u8; CIPHERTEXT_BYTES],
        }
        /// Shared secret produced by `encaps` / `decaps`.
        ///
        /// Equality has no early exit; `Debug` output is redacted.
        #[derive(Clone)]
        pub struct $ss_ty {
            bytes: [u8; SHARED_SECRET_BYTES],
        }
        impl ::core::cmp::PartialEq for $sk_ty {
            fn eq(&self, other: &Self) -> bool {
                let mut diff = 0u8;
                for (x, y) in self.bytes.iter().zip(other.bytes.iter()) {
                    diff |= x ^ y;
                }
                // `black_box` discourages the optimiser from reintroducing an
                // early exit. It is a best-effort hint, not a guarantee.
                ::core::hint::black_box(diff) == 0
            }
        }
        impl ::core::cmp::Eq for $sk_ty {}
        impl ::core::cmp::PartialEq for $ss_ty {
            fn eq(&self, other: &Self) -> bool {
                let mut diff = 0u8;
                for (x, y) in self.bytes.iter().zip(other.bytes.iter()) {
                    diff |= x ^ y;
                }
                ::core::hint::black_box(diff) == 0
            }
        }
        impl ::core::cmp::Eq for $ss_ty {}
        impl $pk_ty {
            /// Parses a public key; `None` unless `bytes` is exactly `PUBLIC_KEY_BYTES` long.
            #[must_use]
            pub fn from_wire_bytes(bytes: &[u8]) -> Option<Self> {
                if bytes.len() != PUBLIC_KEY_BYTES { return None; }
                let mut out = [0u8; PUBLIC_KEY_BYTES];
                out.copy_from_slice(bytes);
                Some(Self { bytes: out })
            }
            /// Returns a copy of the encoded key.
            #[must_use]
            pub fn to_wire_bytes(&self) -> [u8; PUBLIC_KEY_BYTES] { self.bytes }
            /// Borrows the encoded key.
            #[must_use]
            pub fn as_bytes(&self) -> &[u8; PUBLIC_KEY_BYTES] { &self.bytes }
        }
        impl $sk_ty {
            /// Parses a private key; `None` unless `bytes` is exactly `PRIVATE_KEY_BYTES` long.
            #[must_use]
            pub fn from_wire_bytes(bytes: &[u8]) -> Option<Self> {
                if bytes.len() != PRIVATE_KEY_BYTES { return None; }
                let mut out = [0u8; PRIVATE_KEY_BYTES];
                out.copy_from_slice(bytes);
                Some(Self { bytes: out })
            }
            /// Returns a copy of the encoded key.
            #[must_use]
            pub fn to_wire_bytes(&self) -> [u8; PRIVATE_KEY_BYTES] { self.bytes }
            /// Borrows the encoded key.
            #[must_use]
            pub fn as_bytes(&self) -> &[u8; PRIVATE_KEY_BYTES] { &self.bytes }
        }
        impl $ct_ty {
            /// Parses a ciphertext; `None` unless `bytes` is exactly `CIPHERTEXT_BYTES` long.
            #[must_use]
            pub fn from_wire_bytes(bytes: &[u8]) -> Option<Self> {
                if bytes.len() != CIPHERTEXT_BYTES { return None; }
                let mut out = [0u8; CIPHERTEXT_BYTES];
                out.copy_from_slice(bytes);
                Some(Self { bytes: out })
            }
            /// Returns a copy of the encoded ciphertext.
            #[must_use]
            pub fn to_wire_bytes(&self) -> [u8; CIPHERTEXT_BYTES] { self.bytes }
            /// Borrows the encoded ciphertext.
            #[must_use]
            pub fn as_bytes(&self) -> &[u8; CIPHERTEXT_BYTES] { &self.bytes }
        }
        impl $ss_ty {
            /// Returns a copy of the shared secret.
            #[must_use]
            pub fn to_wire_bytes(&self) -> [u8; SHARED_SECRET_BYTES] { self.bytes }
            /// Borrows the shared secret.
            #[must_use]
            pub fn as_bytes(&self) -> &[u8; SHARED_SECRET_BYTES] { &self.bytes }
        }
        impl ::core::fmt::Debug for $pk_ty {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                f.debug_struct(stringify!($pk_ty)).finish()
            }
        }
        impl ::core::fmt::Debug for $ct_ty {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                f.debug_struct(stringify!($ct_ty)).finish()
            }
        }
        impl ::core::fmt::Debug for $sk_ty {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                f.write_str(concat!(stringify!($sk_ty), "(<redacted>)"))
            }
        }
        impl ::core::fmt::Debug for $ss_ty {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                f.write_str(concat!(stringify!($ss_ty), "(<redacted>)"))
            }
        }
        /// Entry point for this parameter set's KEM operations.
        pub struct $type_name;
        impl $type_name {
            pub const PUBLIC_KEY_BYTES: usize = PUBLIC_KEY_BYTES;
            pub const PRIVATE_KEY_BYTES: usize = PRIVATE_KEY_BYTES;
            pub const CIPHERTEXT_BYTES: usize = CIPHERTEXT_BYTES;
            pub const SHARED_SECRET_BYTES: usize = SHARED_SECRET_BYTES;
            /// Generates a key pair, drawing all randomness from `rng`.
            pub fn keygen<R: $crate::Csprng>(rng: &mut R) -> ($pk_ty, $sk_ty) {
                let mut pk = [0u8; PUBLIC_KEY_BYTES];
                let mut sk = [0u8; PRIVATE_KEY_BYTES];
                let mut seed_scratch = [0u8; SAMPLE_FG_BYTES];
                $crate::public_key::ntru_pqc_shared::kem_keypair_seeded::<$variant, R, N, LOGQ>(
                    &mut pk,
                    &mut sk,
                    rng,
                    &mut seed_scratch,
                );
                ($pk_ty { bytes: pk }, $sk_ty { bytes: sk })
            }
            /// Encapsulates to `pk`, returning the ciphertext and shared secret.
            pub fn encaps<R: $crate::Csprng>(
                pk: &$pk_ty,
                rng: &mut R,
            ) -> ($ct_ty, $ss_ty) {
                let mut ct = [0u8; CIPHERTEXT_BYTES];
                let mut ss = [0u8; SHARED_SECRET_BYTES];
                let mut rm_seed_scratch = [0u8; SAMPLE_RM_BYTES];
                let mut rm_scratch = [0u8; OWCPA_MSGBYTES];
                $crate::public_key::ntru_pqc_shared::kem_enc_seeded::<$variant, R, N, LOGQ>(
                    &mut ct,
                    &mut ss,
                    &pk.bytes,
                    rng,
                    &mut rm_seed_scratch,
                    &mut rm_scratch,
                );
                ($ct_ty { bytes: ct }, $ss_ty { bytes: ss })
            }
            /// Decapsulates `ct`. An invalid ciphertext yields a pseudorandom
            /// secret derived from the private key (implicit rejection), not an error.
            pub fn decaps(sk: &$sk_ty, ct: &$ct_ty) -> $ss_ty {
                let mut ss = [0u8; SHARED_SECRET_BYTES];
                let mut rm_scratch = [0u8; OWCPA_MSGBYTES];
                $crate::public_key::ntru_pqc_shared::kem_dec::<$variant, N, LOGQ>(
                    &mut ss,
                    &ct.bytes,
                    &sk.bytes,
                    &mut rm_scratch,
                );
                $ss_ty { bytes: ss }
            }
        }
        #[cfg(test)]
        mod tests {
            use super::*;
            use $crate::CtrDrbgAes256;
            #[test]
            fn parameter_byte_lengths() {
                assert!(PUBLIC_KEY_BYTES > 0);
                assert!(PRIVATE_KEY_BYTES > 0);
                assert!(CIPHERTEXT_BYTES > 0);
                assert_eq!(SHARED_SECRET_BYTES, 32);
            }
            #[test]
            fn roundtrip_random() {
                let mut drbg = CtrDrbgAes256::new(&[0x42u8; 48]);
                let (pk, sk) = $type_name::keygen(&mut drbg);
                let (ct, ss_a) = $type_name::encaps(&pk, &mut drbg);
                let ss_b = $type_name::decaps(&sk, &ct);
                assert_eq!(ss_a.as_bytes(), ss_b.as_bytes());
            }
            #[test]
            fn roundtrip_multiple_seeds() {
                for seed in [0x00u8, 0x55, 0xaa, 0xff] {
                    let mut drbg = CtrDrbgAes256::new(&[seed; 48]);
                    let (pk, sk) = $type_name::keygen(&mut drbg);
                    let (ct, ss_a) = $type_name::encaps(&pk, &mut drbg);
                    let ss_b = $type_name::decaps(&sk, &ct);
                    assert_eq!(
                        ss_a.as_bytes(),
                        ss_b.as_bytes(),
                        "seed byte 0x{seed:02x}"
                    );
                }
            }
            #[test]
            fn implicit_rejection_on_corrupted_ciphertext() {
                let mut drbg = CtrDrbgAes256::new(&[0x99u8; 48]);
                let (pk, sk) = $type_name::keygen(&mut drbg);
                let (ct, ss_a) = $type_name::encaps(&pk, &mut drbg);
                let mut bad = ct.to_wire_bytes();
                bad[0] ^= 0x01;
                let bad_ct = $ct_ty::from_wire_bytes(&bad).unwrap();
                let ss_bad = $type_name::decaps(&sk, &bad_ct);
                assert_ne!(ss_bad.as_bytes(), ss_a.as_bytes());
                let ss_bad2 = $type_name::decaps(&sk, &bad_ct);
                assert_eq!(ss_bad.as_bytes(), ss_bad2.as_bytes());
            }
            #[test]
            fn wire_format_roundtrip() {
                let mut drbg = CtrDrbgAes256::new(&[0x21u8; 48]);
                let (pk, sk) = $type_name::keygen(&mut drbg);
                let (ct, _) = $type_name::encaps(&pk, &mut drbg);
                let pk_bytes = pk.to_wire_bytes();
                let sk_bytes = sk.to_wire_bytes();
                let ct_bytes = ct.to_wire_bytes();
                assert_eq!(pk_bytes.len(), PUBLIC_KEY_BYTES);
                assert_eq!(sk_bytes.len(), PRIVATE_KEY_BYTES);
                assert_eq!(ct_bytes.len(), CIPHERTEXT_BYTES);
                let pk2 = $pk_ty::from_wire_bytes(&pk_bytes).unwrap();
                let sk2 = $sk_ty::from_wire_bytes(&sk_bytes).unwrap();
                let ct2 = $ct_ty::from_wire_bytes(&ct_bytes).unwrap();
                assert_eq!(pk, pk2);
                assert_eq!(sk, sk2);
                assert_eq!(ct, ct2);
            }
            #[test]
            fn nist_kat_sampled_counts() {
                let rsp = include_str!($kat_path);
                for &count in $crate::public_key::ntru_pqc_shared::KAT_SAMPLED_COUNTS {
                    run_kat_count(rsp, count);
                }
            }
            #[test]
            #[ignore]
            fn nist_kat_full() {
                let rsp = include_str!($kat_path);
                for count in 0..100 {
                    run_kat_count(rsp, count);
                }
            }
            fn run_kat_count(rsp: &str, count: usize) {
                let entry = $crate::public_key::ntru_pqc_shared::parse_kat_entry(rsp, count)
                    .unwrap_or_else(|| panic!("KAT count={count} missing"));
                assert_eq!(entry.seed.len(), 48, "seed length");
                let mut seed = [0u8; 48];
                seed.copy_from_slice(&entry.seed);
                let mut drbg = CtrDrbgAes256::new(&seed);
                let (pk, sk) = $type_name::keygen(&mut drbg);
                assert_eq!(pk.to_wire_bytes().as_slice(), entry.pk.as_slice(), "pk @ count={count}");
                assert_eq!(sk.to_wire_bytes().as_slice(), entry.sk.as_slice(), "sk @ count={count}");
                let (ct, ss) = $type_name::encaps(&pk, &mut drbg);
                assert_eq!(ct.to_wire_bytes().as_slice(), entry.ct.as_slice(), "ct @ count={count}");
                assert_eq!(ss.to_wire_bytes().as_slice(), entry.ss.as_slice(), "ss @ count={count}");
                let ss2 = $type_name::decaps(&sk, &ct);
                assert_eq!(ss.as_bytes(), ss2.as_bytes(), "decaps @ count={count}");
            }
        }
    };
}
pub(crate) use define_pqc_kem;
/// Reduces `r` modulo `(3, Phi_N)`.
///
/// Subtracting the top coefficient from every coefficient is written as adding
/// `2 * r[N - 1]`, since `-1 ≡ 2 (mod 3)`. Each sum is then reduced into
/// `{0, 1, 2}` with `mod3`.
///
/// The addition explicitly wraps mod 2^16, matching the `uint16_t` arithmetic
/// of the C reference. The previous plain `+` / `*` would panic in debug builds
/// on large, unreduced inputs.
pub(crate) fn poly_mod_3_phi_n<const N: usize>(r: &mut [u16; N]) {
    let offset = r[N - 1].wrapping_mul(2);
    for c in r.iter_mut() {
        *c = mod3(c.wrapping_add(offset));
    }
}
/// Reduces `r` modulo `Phi_N` by subtracting the top coefficient from every
/// coefficient (wrapping mod 2^16). Afterwards `r[N - 1] == 0`.
pub(crate) fn poly_mod_q_phi_n<const N: usize>(r: &mut [u16; N]) {
    let top = r[N - 1];
    r.iter_mut().for_each(|coeff| *coeff = coeff.wrapping_sub(top));
}
/// Maps coefficients from `Z_3` (expected in `{0, 1, 2}`) to `Z_q` in place.
///
/// Values 0 and 1 are kept. Any value with bit 1 set has every bit of
/// `q_mask` OR-ed in, so 2 (i.e. -1) becomes `q_mask` (`q - 1`).
pub(crate) fn poly_z3_to_zq<const N: usize>(r: &mut [u16; N], q_mask: u16) {
    for coeff in r.iter_mut() {
        // (coeff >> 1) is 1 for the value 2; negating it gives an all-ones mask.
        let minus_one = (*coeff >> 1).wrapping_neg() & q_mask;
        *coeff |= minus_one;
    }
}
/// Maps trinary coefficients from `Z_q` (`{0, 1, q - 1}`) to `Z_3` (`{0, 1, 2}`).
///
/// Each coefficient is first reduced mod `q = 2^LOGQ`. XOR-ing in its top bit
/// then makes the low two bits of `q - 1` read as 2.
pub(crate) fn poly_trinary_zq_to_z3<const N: usize, const LOGQ: usize>(r: &mut [u16; N]) {
    let q_mask = ((1u32 << LOGQ) - 1) as u16;
    for coeff in r.iter_mut() {
        let reduced = *coeff & q_mask;
        let top_bit = reduced >> (LOGQ - 1);
        *coeff = (reduced ^ top_bit) & 3;
    }
}
/// Converts `a` from `R_q` to `S_3`, writing the result into `r`.
///
/// Each coefficient is reduced mod `q = 2^LOGQ`. Coefficients in the upper
/// half (representing negatives) get `(-q) mod 3` added, so the later mod-3
/// step sees the centred representative. Finally `poly_mod_3_phi_n` is applied.
pub(crate) fn poly_rq_to_s3<const N: usize, const LOGQ: usize>(
    r: &mut [u16; N],
    a: &[u16; N],
) {
    let q_mask = ((1u32 << LOGQ) - 1) as u16;
    // (-q) mod 3 = (-2^LOGQ) mod 3 = 1 << (1 - (LOGQ & 1)).
    let neg_q_mod3_shift = 1 - (LOGQ & 1);
    for (dst, &src) in r.iter_mut().zip(a.iter()) {
        let reduced = src & q_mask;
        let upper_half = reduced >> (LOGQ - 1);
        *dst = reduced.wrapping_add(upper_half << neg_q_mod3_shift);
    }
    poly_mod_3_phi_n::<N>(r);
}
/// Inverts `a` in the 16-bit ring: first mod 2 (`poly_r2_inv`), then lifted by
/// Newton iteration (`poly_r2_inv_to_rq_inv`). The result is written to `r`.
pub(crate) fn poly_rq_inv<const N: usize>(r: &mut [u16; N], a: &[u16; N]) {
    let mut inverse_mod_2 = [0u16; N];
    poly_r2_inv::<N>(&mut inverse_mod_2, a);
    poly_r2_inv_to_rq_inv::<N>(r, &inverse_mod_2, a);
}
/// Compile-time description of one NTRU parameter set.
///
/// Implementors supply the size constants and the `S_q` (de)serialisers. The
/// default methods use i.i.d. sampling for `f`/`r`, fixed-type sampling of
/// weight `WEIGHT` for `g`/`m`, `g <- 3 * g` after lifting, and `poly_lift_hps`.
/// A variant can override any of them.
pub(crate) trait NtruVariant<const N: usize, const LOGQ: usize> {
    /// Low-`LOGQ`-bit mask used for reductions mod `q` (expected `2^LOGQ - 1`).
    const Q_MASK: u16;
    const SAMPLE_FG_BYTES: usize;
    const SAMPLE_RM_BYTES: usize;
    /// Bytes used by `poly_s3_tobytes` for one ternary polynomial.
    const PACK_TRINARY_BYTES: usize;
    const OWCPA_PUBLICKEYBYTES: usize;
    const OWCPA_SECRETKEYBYTES: usize;
    const OWCPA_BYTES: usize;
    const OWCPA_MSGBYTES: usize;
    /// Number of nonzero coefficients in fixed-type samples.
    const WEIGHT: usize;
    /// Samples `f` from the first `N - 1` seed bytes and `g` (fixed type) from the rest.
    fn sample_fg(f: &mut [u16; N], g: &mut [u16; N], seed: &[u8]) {
        debug_assert_eq!(seed.len(), Self::SAMPLE_FG_BYTES);
        let (iid_seed, fixed_seed) = seed.split_at(N - 1);
        sample_iid::<N>(f, iid_seed);
        let mut scratch = [0i32; N];
        sample_fixed_type::<N>(g, fixed_seed, Self::WEIGHT, &mut scratch);
    }
    /// Samples `r` from the first `N - 1` seed bytes and `m` (fixed type) from the rest.
    fn sample_rm(r: &mut [u16; N], m: &mut [u16; N], seed: &[u8]) {
        debug_assert_eq!(seed.len(), Self::SAMPLE_RM_BYTES);
        let (iid_seed, fixed_seed) = seed.split_at(N - 1);
        sample_iid::<N>(r, iid_seed);
        let mut scratch = [0i32; N];
        sample_fixed_type::<N>(m, fixed_seed, Self::WEIGHT, &mut scratch);
    }
    /// Multiplies every coefficient of `g` by 3 (wrapping).
    fn update_g_after_z3_to_zq(g: &mut [u16; N]) {
        g.iter_mut().for_each(|coeff| *coeff = coeff.wrapping_mul(3));
    }
    /// Lifts a ternary message polynomial into `R_q`.
    fn poly_lift(r: &mut [u16; N], a: &[u16; N]) {
        poly_lift_hps::<N>(r, a, Self::Q_MASK)
    }
    /// Returns 0 if `m` lies in the message space, 1 otherwise.
    fn check_m(m: &[u16; N]) -> i32 {
        owcpa_check_m::<N>(m, Self::WEIGHT)
    }
    fn poly_sq_tobytes(r: &mut [u8], a: &[u16; N]);
    fn poly_sq_frombytes(r: &mut [u16; N], a: &[u8]);
}
/// Derives an OW-CPA key pair from `seed`.
///
/// `sk` receives, in order:
/// * packed `f`;
/// * packed `f^-1 mod 3`;
/// * packed `h^-1 = (g f)^-1 * f * f`, reduced by `poly_sq_mul`.
///
/// `pk` receives packed `h = (g f)^-1 * g * g`. Here `g` is scaled by
/// `V::update_g_after_z3_to_zq` after lifting.
pub(crate) fn owcpa_keypair<V, const N: usize, const LOGQ: usize>(
    pk: &mut [u8],
    sk: &mut [u8],
    seed: &[u8],
) where
    V: NtruVariant<N, LOGQ>,
{
    debug_assert_eq!(pk.len(), V::OWCPA_PUBLICKEYBYTES);
    debug_assert_eq!(sk.len(), V::OWCPA_SECRETKEYBYTES);
    debug_assert_eq!(seed.len(), V::SAMPLE_FG_BYTES);
    let trinary = V::PACK_TRINARY_BYTES;
    let (sk_f, sk_rest) = sk.split_at_mut(trinary);
    let (sk_f_inv3, sk_h_inv) = sk_rest.split_at_mut(trinary);

    let mut f = [0u16; N];
    let mut g = [0u16; N];
    V::sample_fg(&mut f, &mut g, seed);

    let mut f_inv3 = [0u16; N];
    poly_s3_inv::<N>(&mut f_inv3, &f);
    poly_s3_tobytes::<N>(sk_f, &f);
    poly_s3_tobytes::<N>(sk_f_inv3, &f_inv3);

    poly_z3_to_zq::<N>(&mut f, V::Q_MASK);
    poly_z3_to_zq::<N>(&mut g, V::Q_MASK);
    V::update_g_after_z3_to_zq(&mut g);

    let mut gf = [0u16; N];
    poly_rq_mul::<N>(&mut gf, &g, &f);
    let mut gf_inv = [0u16; N];
    poly_rq_inv::<N>(&mut gf_inv, &gf);

    let mut scratch = [0u16; N];
    let mut h_inv = [0u16; N];
    poly_rq_mul::<N>(&mut scratch, &gf_inv, &f);
    poly_sq_mul::<N>(&mut h_inv, &scratch, &f);
    V::poly_sq_tobytes(sk_h_inv, &h_inv);

    let mut h = [0u16; N];
    poly_rq_mul::<N>(&mut scratch, &gf_inv, &g);
    poly_rq_mul::<N>(&mut h, &scratch, &g);
    V::poly_sq_tobytes(pk, &h);
}
/// OW-CPA encryption: writes the packed `r * h + lift(m)` into `c`.
///
/// `h` is unpacked from `pk`, and its last coefficient is restored so the
/// coefficients sum to zero (`poly_rq_sum_zero_adjust`). `r` is expected to be
/// already lifted to `Z_q`.
pub(crate) fn owcpa_enc<V, const N: usize, const LOGQ: usize>(
    c: &mut [u8],
    r: &[u16; N],
    m: &[u16; N],
    pk: &[u8],
) where
    V: NtruVariant<N, LOGQ>,
{
    debug_assert_eq!(c.len(), V::OWCPA_BYTES);
    debug_assert_eq!(pk.len(), V::OWCPA_PUBLICKEYBYTES);
    let mut h = [0u16; N];
    V::poly_sq_frombytes(&mut h, pk);
    poly_rq_sum_zero_adjust::<N>(&mut h);

    let mut ct = [0u16; N];
    poly_rq_mul::<N>(&mut ct, r, &h);

    let mut lifted = [0u16; N];
    V::poly_lift(&mut lifted, m);
    for (ct_i, &lift_i) in ct.iter_mut().zip(lifted.iter()) {
        *ct_i = ct_i.wrapping_add(lift_i);
    }
    V::poly_sq_tobytes(c, &ct);
}
/// OW-CPA decryption: recovers the packed `(r, m)` pair from `ciphertext`.
///
/// `secretkey` is laid out as:
/// * packed `f` (`PACK_TRINARY_BYTES`);
/// * packed `f^-1 mod 3` (`PACK_TRINARY_BYTES`);
/// * packed `h^-1` (read with `V::poly_sq_frombytes`).
///
/// On return, `rm[..PACK_TRINARY_BYTES]` holds packed `r` and
/// `rm[PACK_TRINARY_BYTES..]` holds packed `m`.
///
/// The result is the OR of the ciphertext-padding check, `V::check_m` and the
/// `r` validity check. It is 0 when every check passed. All checks run
/// unconditionally, with no early exit.
pub(crate) fn owcpa_dec<V, const N: usize, const LOGQ: usize>(
    rm: &mut [u8],
    ciphertext: &[u8],
    secretkey: &[u8],
) -> i32
where
    V: NtruVariant<N, LOGQ>,
{
    debug_assert_eq!(rm.len(), V::OWCPA_MSGBYTES);
    debug_assert_eq!(ciphertext.len(), V::OWCPA_BYTES);
    debug_assert_eq!(secretkey.len(), V::OWCPA_SECRETKEYBYTES);
    // c: unpack N - 1 coefficients, then restore the last so the sum is 0.
    let mut c = [0u16; N];
    V::poly_sq_frombytes(&mut c, ciphertext);
    poly_rq_sum_zero_adjust::<N>(&mut c);
    // f, lifted from Z_3 to Z_q.
    let mut f = [0u16; N];
    poly_s3_frombytes::<N>(&mut f, &secretkey[..V::PACK_TRINARY_BYTES]);
    poly_z3_to_zq::<N>(&mut f, V::Q_MASK);
    // mf = (c * f) converted to S_3.
    let mut cf = [0u16; N];
    poly_rq_mul::<N>(&mut cf, &c, &f);
    let mut mf = [0u16; N];
    poly_rq_to_s3::<N, LOGQ>(&mut mf, &cf);
    // m = mf * f^-1 in S_3.
    let mut finv3 = [0u16; N];
    poly_s3_frombytes::<N>(
        &mut finv3,
        &secretkey[V::PACK_TRINARY_BYTES..2 * V::PACK_TRINARY_BYTES],
    );
    let mut m = [0u16; N];
    poly_s3_mul::<N>(&mut m, &mf, &finv3);
    poly_s3_tobytes::<N>(&mut rm[V::PACK_TRINARY_BYTES..], &m);
    // Validity checks are OR-ed so their outcome never short-circuits work.
    let mut fail = 0i32;
    fail |= owcpa_check_ciphertext::<N, LOGQ>(ciphertext);
    fail |= V::check_m(&m);
    // b = c - lift(m).
    let mut liftm = [0u16; N];
    V::poly_lift(&mut liftm, &m);
    let mut b = [0u16; N];
    for i in 0..N {
        b[i] = c[i].wrapping_sub(liftm[i]);
    }
    // r = b * h^-1, reduced mod Phi_N by poly_sq_mul.
    let mut invh = [0u16; N];
    V::poly_sq_frombytes(&mut invh, &secretkey[2 * V::PACK_TRINARY_BYTES..]);
    let mut r = [0u16; N];
    poly_sq_mul::<N>(&mut r, &b, &invh);
    // r must be trinary in Z_q with r[N - 1] == 0.
    fail |= owcpa_check_r::<N, LOGQ>(&r);
    poly_trinary_zq_to_z3::<N, LOGQ>(&mut r);
    poly_s3_tobytes::<N>(&mut rm[..V::PACK_TRINARY_BYTES], &r);
    fail
}
/// Checks that the unused high bits of the ciphertext's last byte are zero.
///
/// The ciphertext carries `LOGQ * (N - 1)` bits. When that is not a multiple of
/// 8, the top `8 - used` bits of the final byte must be 0. Returns 0 on success
/// and 1 on failure, branch-free. Panics on an empty slice.
pub(crate) fn owcpa_check_ciphertext<const N: usize, const LOGQ: usize>(
    ciphertext: &[u8],
) -> i32 {
    let used_bits = (LOGQ * (N - 1)) % 8;
    let unused_mask: u8 = match used_bits {
        0 => 0,
        bits => 0xffu8 << bits,
    };
    let last = ciphertext.last().copied().expect("non-empty ciphertext");
    let stray = u16::from(last & unused_mask);
    // stray < 256, so -stray has bit 15 set iff stray != 0.
    ((stray.wrapping_neg() >> 15) & 1) as i32
}
/// Checks that `r` is trinary in `Z_q`: the first `N - 1` coefficients are
/// `0`, `1` or `-1` (looking only at the bits the masks test), and
/// `r[N - 1] == 0`.
///
/// Returns 0 on success and 1 on failure, branch-free.
pub(crate) fn owcpa_check_r<const N: usize, const LOGQ: usize>(r: &[u16; N]) -> i32 {
    let q: u16 = if LOGQ < 16 { 1u16 << LOGQ } else { 0 };
    let range_mask = q.wrapping_sub(4);
    let (body, tail) = r.split_at(N - 1);
    let bad = body.iter().fold(u32::from(tail[0]), |acc, &c| {
        // First term is zero iff c is in {-1, 0, 1, 2}; the second rules out 2.
        acc | u32::from(c.wrapping_add(1) & range_mask) | u32::from(c.wrapping_add(2) & 4)
    });
    ((bad.wrapping_neg() >> 31) & 1) as i32
}
/// Checks that `m` is in the fixed-weight message space.
///
/// The number of coefficients with bit 0 set must equal the number with bit 1
/// set, and together they must total `weight`. Coefficients are assumed to be
/// in `{0, 1, 2}`. Returns 0 on success and 1 on failure, branch-free.
pub(crate) fn owcpa_check_m<const N: usize>(m: &[u16; N], weight: usize) -> i32 {
    // `ones` counts bit-0 coefficients; `twos_x2` adds 2 per bit-1 coefficient.
    let (ones, twos_x2) = m.iter().fold((0u16, 0u16), |(o, t), &c| {
        (o.wrapping_add(c & 1), t.wrapping_add(c & 2))
    });
    let bad = u32::from(ones ^ (twos_x2 >> 1)) | u32::from(twos_x2 ^ weight as u16);
    ((bad.wrapping_neg() >> 31) & 1) as i32
}
/// Sets `r[N - 1]` so that all coefficients sum to 0 mod 2^16.
///
/// The old value of `r[N - 1]` is ignored.
pub(crate) fn poly_rq_sum_zero_adjust<const N: usize>(r: &mut [u16; N]) {
    let (head, tail) = r.split_at_mut(N - 1);
    tail[0] = head.iter().fold(0u16, |acc, &c| acc.wrapping_sub(c));
}
/// KEM key generation.
///
/// Steps:
/// 1. Fill `seed_scratch` from `rng`.
/// 2. Derive the OW-CPA key pair into `pk` and the first
///    `OWCPA_SECRETKEYBYTES` of `sk`.
/// 3. Fill the remainder of `sk` with fresh random bytes.
///
/// `kem_dec` hashes that remainder together with the ciphertext to form the
/// implicit-rejection secret.
pub(crate) fn kem_keypair_seeded<V, R, const N: usize, const LOGQ: usize>(
    pk: &mut [u8],
    sk: &mut [u8],
    rng: &mut R,
    seed_scratch: &mut [u8],
) where
    V: NtruVariant<N, LOGQ>,
    R: crate::Csprng,
{
    debug_assert_eq!(seed_scratch.len(), V::SAMPLE_FG_BYTES);
    let (owcpa_sk, prf_key) = sk.split_at_mut(V::OWCPA_SECRETKEYBYTES);
    rng.fill_bytes(seed_scratch);
    owcpa_keypair::<V, N, LOGQ>(pk, owcpa_sk, seed_scratch);
    rng.fill_bytes(prf_key);
}
/// KEM encapsulation.
///
/// Steps:
/// 1. Sample `(r, m)` from a fresh seed.
/// 2. Pack both into `rm_scratch` and set `k = SHA3-256(rm_scratch)`.
/// 3. Encrypt `m` under `pk` with the lifted `r`, writing the ciphertext to `c`.
pub(crate) fn kem_enc_seeded<V, R, const N: usize, const LOGQ: usize>(
    c: &mut [u8],
    k: &mut [u8],
    pk: &[u8],
    rng: &mut R,
    rm_seed_scratch: &mut [u8],
    rm_scratch: &mut [u8],
) where
    V: NtruVariant<N, LOGQ>,
    R: crate::Csprng,
{
    use crate::hash::sha3::Sha3_256;
    debug_assert_eq!(k.len(), 32);
    debug_assert_eq!(rm_seed_scratch.len(), V::SAMPLE_RM_BYTES);
    debug_assert_eq!(rm_scratch.len(), V::OWCPA_MSGBYTES);
    rng.fill_bytes(rm_seed_scratch);
    let mut r = [0u16; N];
    let mut m = [0u16; N];
    V::sample_rm(&mut r, &mut m, rm_seed_scratch);
    {
        let (r_bytes, m_bytes) = rm_scratch.split_at_mut(V::PACK_TRINARY_BYTES);
        poly_s3_tobytes::<N>(r_bytes, &r);
        poly_s3_tobytes::<N>(m_bytes, &m);
    }
    let shared = Sha3_256::new().chain(rm_scratch).finalize();
    k.copy_from_slice(&shared);
    poly_z3_to_zq::<N>(&mut r, V::Q_MASK);
    owcpa_enc::<V, N, LOGQ>(c, &r, &m, pk);
}
/// KEM decapsulation with implicit rejection.
///
/// Computes `k = SHA3-256(rm)` from the OW-CPA decryption. If decryption
/// reported failure, `k` is then replaced, without branching, by
/// `SHA3-256(prf_key || c)`. Here `prf_key` is the part of `sk` after
/// `OWCPA_SECRETKEYBYTES`.
pub(crate) fn kem_dec<V, const N: usize, const LOGQ: usize>(
    k: &mut [u8],
    c: &[u8],
    sk: &[u8],
    rm_scratch: &mut [u8],
) where
    V: NtruVariant<N, LOGQ>,
{
    use crate::hash::sha3::Sha3_256;
    debug_assert_eq!(k.len(), 32);
    debug_assert_eq!(rm_scratch.len(), V::OWCPA_MSGBYTES);
    let (owcpa_sk, prf_key) = sk.split_at(V::OWCPA_SECRETKEYBYTES);
    let fail = owcpa_dec::<V, N, LOGQ>(rm_scratch, c, owcpa_sk);
    let accept = Sha3_256::new().chain(rm_scratch).finalize();
    k.copy_from_slice(&accept);
    let reject = Sha3_256::new().chain(prf_key).chain(c).finalize();
    cmov(k, &reject, fail as u8);
}
/// Fills `r[..N - 1]` with `byte mod 3` for each of the first `N - 1` bytes of
/// `uniform_bytes`, and sets `r[N - 1] = 0`.
pub(crate) fn sample_iid<const N: usize>(r: &mut [u16; N], uniform_bytes: &[u8]) {
    debug_assert_eq!(uniform_bytes.len(), N - 1);
    let (body, tail) = r.split_at_mut(N - 1);
    for (coeff, &byte) in body.iter_mut().zip(&uniform_bytes[..N - 1]) {
        *coeff = mod3(u16::from(byte));
    }
    tail[0] = 0;
}
/// Samples a ternary polynomial with exactly `weight / 2` coefficients equal
/// to 1 and `weight - weight / 2` equal to 2 (the rest 0), positioned by `u`.
///
/// How it works:
/// * Each `i32` in `scratch[..N - 1]` gets 30 bits of `u` in bits 2..=31
///   (4 words per 15 input bytes).
/// * The first `weight / 2` words are tagged with 1 in their low two bits and
///   the next ones with 2; the rest keep tag 0.
/// * `crypto_sort_int32` sorts by the random high bits, which permutes the
///   tags. Only the low two bits are kept.
///
/// `u` must be `ceil(30 * (N - 1) / 8)` bytes and `(N - 1) % 4 == 0` (both
/// debug-asserted). `r[N - 1]` is set to 0.
pub(crate) fn sample_fixed_type<const N: usize>(
    r: &mut [u16; N],
    u: &[u8],
    weight: usize,
    scratch: &mut [i32; N],
) {
    debug_assert_eq!(u.len(), (30 * (N - 1)).div_ceil(8));
    debug_assert_eq!((N - 1) % 4, 0, "sample_fixed_type assumes (N - 1) % 4 == 0");
    let s = &mut scratch[..N - 1];
    for slot in s.iter_mut() {
        *slot = 0;
    }
    let blocks = (N - 1) / 4;
    for i in 0..blocks {
        // 15 bytes = 120 bits -> four 30-bit values, each shifted into bits 2..=31.
        // Bits shifted past bit 31 are discarded; the `as u32 as i32` casts keep
        // the top byte's shift in plain 32-bit arithmetic.
        let base = 15 * i;
        s[4 * i] = ((u[base] as i32) << 2)
            | ((u[base + 1] as i32) << 10)
            | ((u[base + 2] as i32) << 18)
            | ((u[base + 3] as u32 as i32) << 26);
        s[4 * i + 1] = (((u[base + 3] as i32) & 0xc0) >> 4)
            | ((u[base + 4] as i32) << 4)
            | ((u[base + 5] as i32) << 12)
            | ((u[base + 6] as i32) << 20)
            | ((u[base + 7] as u32 as i32) << 28);
        s[4 * i + 2] = (((u[base + 7] as i32) & 0xf0) >> 2)
            | ((u[base + 8] as i32) << 6)
            | ((u[base + 9] as i32) << 14)
            | ((u[base + 10] as i32) << 22)
            | ((u[base + 11] as u32 as i32) << 30);
        s[4 * i + 3] = ((u[base + 11] as i32) & 0xfc)
            | ((u[base + 12] as i32) << 8)
            | ((u[base + 13] as i32) << 16)
            | ((u[base + 14] as u32 as i32) << 24);
    }
    // Tag the low two bits: first weight/2 words with 1, next with 2.
    for i in 0..weight / 2 {
        s[i] |= 1;
    }
    for i in weight / 2..weight {
        s[i] |= 2;
    }
    // Data-independent sort: moves the tags to random positions.
    crypto_sort_int32(s);
    for i in 0..N - 1 {
        r[i] = (s[i] & 3) as u16;
    }
    r[N - 1] = 0;
}
/// Cyclic product `r = a * b`, i.e. mod `x^N - 1`, as computed by
/// `poly_mul_cyclic` (coefficients presumably wrap mod 2^16; confirm in that module).
pub(crate) fn poly_rq_mul<const N: usize>(
    r: &mut [u16; N],
    a: &[u16; N],
    b: &[u16; N],
) {
    use crate::public_key::ntru_poly_mul::poly_mul_cyclic;
    poly_mul_cyclic(r, a, b);
}
/// Product in `S_q`: the cyclic product followed by reduction mod `Phi_N`.
pub(crate) fn poly_sq_mul<const N: usize>(
    r: &mut [u16; N],
    a: &[u16; N],
    b: &[u16; N],
) {
    poly_rq_mul(r, a, b);
    poly_mod_q_phi_n(r);
}
/// Product in `S_3`: the cyclic product followed by reduction mod `(3, Phi_N)`.
///
/// NOTE(review): the PQClean reference masks the product mod q before the
/// mod-3 step. For `{0, 1, 2}` inputs the exact products stay far below 2^16;
/// confirm the omission is intended.
pub(crate) fn poly_s3_mul<const N: usize>(
    r: &mut [u16; N],
    a: &[u16; N],
    b: &[u16; N],
) {
    poly_rq_mul(r, a, b);
    poly_mod_3_phi_n(r);
}
/// Lifts a ternary polynomial into `Z_q`.
///
/// Copies `a` into `r`, then maps each coefficient with `poly_z3_to_zq`, so
/// 2 (i.e. -1) becomes `q_mask`.
pub(crate) fn poly_lift_hps<const N: usize>(r: &mut [u16; N], a: &[u16; N], q_mask: u16) {
    r.copy_from_slice(a);
    poly_z3_to_zq(r, q_mask);
}
/// Packs the first `N - 1` coefficients of `a` into `r`, 11 bits each,
/// least-significant bit first. Every coefficient is masked to 11 bits first.
///
/// `r` must be exactly `ceil(11 * (N - 1) / 8)` bytes (debug-asserted). Full
/// groups of 8 coefficients fill 11 bytes each. A trailing group of 4 or 2
/// coefficients is handled by the `match`; any other remainder of
/// `(N - 1) % 8` panics via `unreachable!()`.
pub(crate) fn poly_sq_tobytes_logq11<const N: usize>(r: &mut [u8], a: &[u16; N]) {
    const Q_MASK_11: u16 = (1u16 << 11) - 1;
    let pack_deg = N - 1;
    debug_assert_eq!(r.len(), (pack_deg * 11).div_ceil(8));
    let mut t = [0u16; 8];
    let full = pack_deg / 8;
    // 8 coefficients * 11 bits = 88 bits = 11 bytes per iteration.
    for i in 0..full {
        for j in 0..8 {
            t[j] = a[8 * i + j] & Q_MASK_11;
        }
        r[11 * i] = (t[0] & 0xff) as u8;
        r[11 * i + 1] = ((t[0] >> 8) | ((t[1] & 0x1f) << 3)) as u8;
        r[11 * i + 2] = ((t[1] >> 5) | ((t[2] & 0x03) << 6)) as u8;
        r[11 * i + 3] = ((t[2] >> 2) & 0xff) as u8;
        r[11 * i + 4] = ((t[2] >> 10) | ((t[3] & 0x7f) << 1)) as u8;
        r[11 * i + 5] = ((t[3] >> 7) | ((t[4] & 0x0f) << 4)) as u8;
        r[11 * i + 6] = ((t[4] >> 4) | ((t[5] & 0x01) << 7)) as u8;
        r[11 * i + 7] = ((t[5] >> 1) & 0xff) as u8;
        r[11 * i + 8] = ((t[5] >> 9) | ((t[6] & 0x3f) << 2)) as u8;
        r[11 * i + 9] = ((t[6] >> 6) | ((t[7] & 0x07) << 5)) as u8;
        r[11 * i + 10] = (t[7] >> 3) as u8;
    }
    // Stage the leftover coefficients, zero-padding unused slots so they add
    // no bits to the final partial byte.
    let i = full;
    let tail = pack_deg - 8 * i;
    for j in 0..tail {
        t[j] = a[8 * i + j] & Q_MASK_11;
    }
    for j in tail..8 {
        t[j] = 0;
    }
    match pack_deg & 0x07 {
        // 4 coefficients = 44 bits -> 6 bytes; the top 4 bits of the last byte are 0.
        4 => {
            r[11 * i] = (t[0] & 0xff) as u8;
            r[11 * i + 1] = ((t[0] >> 8) | ((t[1] & 0x1f) << 3)) as u8;
            r[11 * i + 2] = ((t[1] >> 5) | ((t[2] & 0x03) << 6)) as u8;
            r[11 * i + 3] = ((t[2] >> 2) & 0xff) as u8;
            r[11 * i + 4] = ((t[2] >> 10) | ((t[3] & 0x7f) << 1)) as u8;
            r[11 * i + 5] = ((t[3] >> 7) | ((t[4] & 0x0f) << 4)) as u8;
        }
        // 2 coefficients = 22 bits -> 3 bytes; the top 2 bits of the last byte are 0.
        2 => {
            r[11 * i] = (t[0] & 0xff) as u8;
            r[11 * i + 1] = ((t[0] >> 8) | ((t[1] & 0x1f) << 3)) as u8;
            r[11 * i + 2] = ((t[1] >> 5) | ((t[2] & 0x03) << 6)) as u8;
        }
        0 => {}
        _ => unreachable!(),
    }
}
/// Unpacks `N - 1` 11-bit coefficients (least-significant bit first) from `a`
/// into `r`, and sets `r[N - 1] = 0`. This is the inverse of
/// `poly_sq_tobytes_logq11`.
///
/// `a` must hold at least `ceil(11 * (N - 1) / 8)` bytes (debug-asserted).
/// Only trailing groups of 4 or 2 coefficients are supported; other remainders
/// of `(N - 1) % 8` panic via `unreachable!()`.
pub(crate) fn poly_sq_frombytes_logq11<const N: usize>(r: &mut [u16; N], a: &[u8]) {
    let pack_deg = N - 1;
    debug_assert!(a.len() >= (pack_deg * 11).div_ceil(8));
    let full = pack_deg / 8;
    // 11 bytes -> 8 coefficients per iteration.
    for i in 0..full {
        r[8 * i] = (a[11 * i] as u16) | (((a[11 * i + 1] as u16) & 0x07) << 8);
        r[8 * i + 1] =
            ((a[11 * i + 1] as u16) >> 3) | (((a[11 * i + 2] as u16) & 0x3f) << 5);
        r[8 * i + 2] = ((a[11 * i + 2] as u16) >> 6)
            | (((a[11 * i + 3] as u16) & 0xff) << 2)
            | (((a[11 * i + 4] as u16) & 0x01) << 10);
        r[8 * i + 3] =
            ((a[11 * i + 4] as u16) >> 1) | (((a[11 * i + 5] as u16) & 0x0f) << 7);
        r[8 * i + 4] =
            ((a[11 * i + 5] as u16) >> 4) | (((a[11 * i + 6] as u16) & 0x7f) << 4);
        r[8 * i + 5] = ((a[11 * i + 6] as u16) >> 7)
            | (((a[11 * i + 7] as u16) & 0xff) << 1)
            | (((a[11 * i + 8] as u16) & 0x03) << 9);
        r[8 * i + 6] =
            ((a[11 * i + 8] as u16) >> 2) | (((a[11 * i + 9] as u16) & 0x1f) << 6);
        r[8 * i + 7] =
            ((a[11 * i + 9] as u16) >> 5) | (((a[11 * i + 10] as u16) & 0xff) << 3);
    }
    // Trailing partial group (same bit layout as the first coefficients of a full group).
    let i = full;
    match pack_deg & 0x07 {
        4 => {
            r[8 * i] = (a[11 * i] as u16) | (((a[11 * i + 1] as u16) & 0x07) << 8);
            r[8 * i + 1] =
                ((a[11 * i + 1] as u16) >> 3) | (((a[11 * i + 2] as u16) & 0x3f) << 5);
            r[8 * i + 2] = ((a[11 * i + 2] as u16) >> 6)
                | (((a[11 * i + 3] as u16) & 0xff) << 2)
                | (((a[11 * i + 4] as u16) & 0x01) << 10);
            r[8 * i + 3] =
                ((a[11 * i + 4] as u16) >> 1) | (((a[11 * i + 5] as u16) & 0x0f) << 7);
        }
        2 => {
            r[8 * i] = (a[11 * i] as u16) | (((a[11 * i + 1] as u16) & 0x07) << 8);
            r[8 * i + 1] =
                ((a[11 * i + 1] as u16) >> 3) | (((a[11 * i + 2] as u16) & 0x3f) << 5);
        }
        0 => {}
        _ => unreachable!(),
    }
    r[N - 1] = 0;
}
/// Packs the first `N - 1` coefficients of `a` into `r`, 12 bits each.
///
/// Each pair of coefficients fills 3 bytes, least-significant bit first, after
/// masking every coefficient to 12 bits. `r` must be exactly
/// `ceil(12 * (N - 1) / 8)` bytes (debug-asserted). `N - 1` is expected to be
/// even; an odd trailing coefficient would not be written.
pub(crate) fn poly_sq_tobytes_logq12<const N: usize>(r: &mut [u8], a: &[u16; N]) {
    const Q_MASK_12: u16 = (1u16 << 12) - 1;
    let pack_deg = N - 1;
    debug_assert_eq!(r.len(), (pack_deg * 12).div_ceil(8));
    let pairs = pack_deg / 2;
    let out = &mut r[..3 * pairs];
    for (bytes, coeffs) in out.chunks_exact_mut(3).zip(a[..2 * pairs].chunks_exact(2)) {
        let lo = coeffs[0] & Q_MASK_12;
        let hi = coeffs[1] & Q_MASK_12;
        bytes[0] = lo as u8;
        bytes[1] = ((lo >> 8) | (hi << 4)) as u8;
        bytes[2] = (hi >> 4) as u8;
    }
}
/// Unpacks 12-bit coefficient pairs (3 bytes each, least-significant bit first)
/// from `a` into `r[..N - 1]`, and sets `r[N - 1] = 0`. This is the inverse of
/// `poly_sq_tobytes_logq12`.
///
/// `a` must hold at least `ceil(12 * (N - 1) / 8)` bytes (debug-asserted).
pub(crate) fn poly_sq_frombytes_logq12<const N: usize>(r: &mut [u16; N], a: &[u8]) {
    let pack_deg = N - 1;
    debug_assert!(a.len() >= (pack_deg * 12).div_ceil(8));
    let pairs = pack_deg / 2;
    for (coeffs, bytes) in r[..2 * pairs]
        .chunks_exact_mut(2)
        .zip(a[..3 * pairs].chunks_exact(3))
    {
        let (b0, b1, b2) = (u16::from(bytes[0]), u16::from(bytes[1]), u16::from(bytes[2]));
        coeffs[0] = b0 | ((b1 & 0x0f) << 8);
        coeffs[1] = (b1 >> 4) | (b2 << 4);
    }
    r[N - 1] = 0;
}
/// Packs the first `N - 1` coefficients of `a` into `r`, 13 bits each,
/// least-significant bit first. Every coefficient is masked to 13 bits first.
///
/// `r` must be exactly `ceil(13 * (N - 1) / 8)` bytes (debug-asserted). Full
/// groups of 8 coefficients fill 13 bytes each. A trailing group of 4 or 2
/// coefficients is handled by the `match`; any other remainder of
/// `(N - 1) % 8` panics via `unreachable!()`.
pub(crate) fn poly_sq_tobytes_logq13<const N: usize>(r: &mut [u8], a: &[u16; N]) {
    const Q_MASK_13: u16 = (1u16 << 13) - 1;
    let pack_deg = N - 1;
    debug_assert_eq!(r.len(), (pack_deg * 13).div_ceil(8));
    let mut t = [0u16; 8];
    let full = pack_deg / 8;
    // 8 coefficients * 13 bits = 104 bits = 13 bytes per iteration.
    for i in 0..full {
        for j in 0..8 {
            t[j] = a[8 * i + j] & Q_MASK_13;
        }
        r[13 * i] = (t[0] & 0xff) as u8;
        r[13 * i + 1] = ((t[0] >> 8) | ((t[1] & 0x07) << 5)) as u8;
        r[13 * i + 2] = ((t[1] >> 3) & 0xff) as u8;
        r[13 * i + 3] = ((t[1] >> 11) | ((t[2] & 0x3f) << 2)) as u8;
        r[13 * i + 4] = ((t[2] >> 6) | ((t[3] & 0x01) << 7)) as u8;
        r[13 * i + 5] = ((t[3] >> 1) & 0xff) as u8;
        r[13 * i + 6] = ((t[3] >> 9) | ((t[4] & 0x0f) << 4)) as u8;
        r[13 * i + 7] = ((t[4] >> 4) & 0xff) as u8;
        r[13 * i + 8] = ((t[4] >> 12) | ((t[5] & 0x7f) << 1)) as u8;
        r[13 * i + 9] = ((t[5] >> 7) | ((t[6] & 0x03) << 6)) as u8;
        r[13 * i + 10] = ((t[6] >> 2) & 0xff) as u8;
        r[13 * i + 11] = ((t[6] >> 10) | ((t[7] & 0x1f) << 3)) as u8;
        r[13 * i + 12] = (t[7] >> 5) as u8;
    }
    // Stage the leftover coefficients, zero-padding unused slots so they add
    // no bits to the final partial byte.
    let i = full;
    let tail = pack_deg - 8 * i;
    for j in 0..tail {
        t[j] = a[8 * i + j] & Q_MASK_13;
    }
    for j in tail..8 {
        t[j] = 0;
    }
    match pack_deg & 0x07 {
        // 4 coefficients = 52 bits -> 7 bytes.
        4 => {
            r[13 * i] = (t[0] & 0xff) as u8;
            r[13 * i + 1] = ((t[0] >> 8) | ((t[1] & 0x07) << 5)) as u8;
            r[13 * i + 2] = ((t[1] >> 3) & 0xff) as u8;
            r[13 * i + 3] = ((t[1] >> 11) | ((t[2] & 0x3f) << 2)) as u8;
            r[13 * i + 4] = ((t[2] >> 6) | ((t[3] & 0x01) << 7)) as u8;
            r[13 * i + 5] = ((t[3] >> 1) & 0xff) as u8;
            r[13 * i + 6] = ((t[3] >> 9) | ((t[4] & 0x0f) << 4)) as u8;
        }
        // 2 coefficients = 26 bits -> 4 bytes.
        2 => {
            r[13 * i] = (t[0] & 0xff) as u8;
            r[13 * i + 1] = ((t[0] >> 8) | ((t[1] & 0x07) << 5)) as u8;
            r[13 * i + 2] = ((t[1] >> 3) & 0xff) as u8;
            r[13 * i + 3] = ((t[1] >> 11) | ((t[2] & 0x3f) << 2)) as u8;
        }
        0 => {}
        _ => unreachable!(),
    }
}
/// Unpacks `N - 1` 13-bit coefficients (least-significant bit first) from `a`
/// into `r`, and sets `r[N - 1] = 0`. This is the inverse of
/// `poly_sq_tobytes_logq13`.
///
/// `a` must hold at least `ceil(13 * (N - 1) / 8)` bytes (debug-asserted).
/// Only trailing groups of 4 or 2 coefficients are supported; other remainders
/// of `(N - 1) % 8` panic via `unreachable!()`.
pub(crate) fn poly_sq_frombytes_logq13<const N: usize>(r: &mut [u16; N], a: &[u8]) {
    let pack_deg = N - 1;
    debug_assert!(a.len() >= (pack_deg * 13).div_ceil(8));
    let full = pack_deg / 8;
    // 13 bytes -> 8 coefficients per iteration.
    for i in 0..full {
        r[8 * i] = (a[13 * i] as u16) | (((a[13 * i + 1] as u16) & 0x1f) << 8);
        r[8 * i + 1] = ((a[13 * i + 1] as u16) >> 5)
            | ((a[13 * i + 2] as u16) << 3)
            | (((a[13 * i + 3] as u16) & 0x03) << 11);
        r[8 * i + 2] =
            ((a[13 * i + 3] as u16) >> 2) | (((a[13 * i + 4] as u16) & 0x7f) << 6);
        r[8 * i + 3] = ((a[13 * i + 4] as u16) >> 7)
            | ((a[13 * i + 5] as u16) << 1)
            | (((a[13 * i + 6] as u16) & 0x0f) << 9);
        r[8 * i + 4] = ((a[13 * i + 6] as u16) >> 4)
            | ((a[13 * i + 7] as u16) << 4)
            | (((a[13 * i + 8] as u16) & 0x01) << 12);
        r[8 * i + 5] =
            ((a[13 * i + 8] as u16) >> 1) | (((a[13 * i + 9] as u16) & 0x3f) << 7);
        r[8 * i + 6] = ((a[13 * i + 9] as u16) >> 6)
            | ((a[13 * i + 10] as u16) << 2)
            | (((a[13 * i + 11] as u16) & 0x07) << 10);
        r[8 * i + 7] =
            ((a[13 * i + 11] as u16) >> 3) | ((a[13 * i + 12] as u16) << 5);
    }
    // Trailing partial group (same bit layout as the first coefficients of a full group).
    let i = full;
    match pack_deg & 0x07 {
        4 => {
            r[8 * i] = (a[13 * i] as u16) | (((a[13 * i + 1] as u16) & 0x1f) << 8);
            r[8 * i + 1] = ((a[13 * i + 1] as u16) >> 5)
                | ((a[13 * i + 2] as u16) << 3)
                | (((a[13 * i + 3] as u16) & 0x03) << 11);
            r[8 * i + 2] =
                ((a[13 * i + 3] as u16) >> 2) | (((a[13 * i + 4] as u16) & 0x7f) << 6);
            r[8 * i + 3] = ((a[13 * i + 4] as u16) >> 7)
                | ((a[13 * i + 5] as u16) << 1)
                | (((a[13 * i + 6] as u16) & 0x0f) << 9);
        }
        2 => {
            r[8 * i] = (a[13 * i] as u16) | (((a[13 * i + 1] as u16) & 0x1f) << 8);
            r[8 * i + 1] = ((a[13 * i + 1] as u16) >> 5)
                | ((a[13 * i + 2] as u16) << 3)
                | (((a[13 * i + 3] as u16) & 0x03) << 11);
        }
        0 => {}
        _ => unreachable!(),
    }
    r[N - 1] = 0;
}
/// Packs the first `N - 1` ternary coefficients of `a` into `msg`, five per
/// byte in base 3, with the lowest-index coefficient as the least significant
/// digit.
///
/// A final short group (if `(N - 1) % 5 != 0`) packs the remaining digits the
/// same way. `msg` must be `ceil((N - 1) / 5)` bytes (debug-asserted).
pub(crate) fn poly_s3_tobytes<const N: usize>(msg: &mut [u8], a: &[u16; N]) {
    let pack_deg = N - 1;
    let out_len = pack_deg.div_ceil(5);
    debug_assert_eq!(msg.len(), out_len);
    for (byte, trits) in msg[..out_len].iter_mut().zip(a[..pack_deg].chunks(5)) {
        // Horner's rule from the highest digit down, wrapping like the C `& 255`.
        *byte = trits
            .iter()
            .rev()
            .fold(0u8, |acc, &t| acc.wrapping_mul(3).wrapping_add(t as u8));
    }
}
/// Unpacks base-3 bytes produced by `poly_s3_tobytes` back into `r`.
///
/// Each byte yields five digit candidates: the value divided by 1, 3, 9, 27
/// and 81 via multiply-and-shift. A final short group yields fewer. Then
/// `r[N - 1]` is cleared and `poly_mod_3_phi_n` reduces every candidate to its
/// digit. `msg` must be `ceil((N - 1) / 5)` bytes (debug-asserted).
pub(crate) fn poly_s3_frombytes<const N: usize>(r: &mut [u16; N], msg: &[u8]) {
    let pack_deg = N - 1;
    debug_assert_eq!(msg.len(), pack_deg.div_ceil(5));
    let whole = pack_deg / 5;
    let (packed, partial) = r[..pack_deg].split_at_mut(5 * whole);
    for (five, &byte) in packed.chunks_exact_mut(5).zip(&msg[..whole]) {
        let v = u32::from(byte);
        five[0] = v as u16;
        five[1] = ((v * 171) >> 9) as u16; // v / 3
        five[2] = ((v * 57) >> 9) as u16; // v / 9
        five[3] = ((v * 19) >> 9) as u16; // v / 27
        five[4] = ((v * 203) >> 14) as u16; // v / 81
    }
    if !partial.is_empty() {
        let mut v = u32::from(msg[whole]);
        for coeff in partial.iter_mut() {
            *coeff = v as u16;
            v = (v * 171) >> 9;
        }
    }
    r[N - 1] = 0;
    poly_mod_3_phi_n::<N>(r);
}
/// One hex-decoded record from a NIST `.rsp` KAT file.
#[cfg(test)]
#[derive(Debug)]
pub(crate) struct KatEntry {
    /// DRBG seed for the record (the KAT runner expects 48 bytes).
    pub seed: Vec<u8>,
    /// Expected encoded public key.
    pub pk: Vec<u8>,
    /// Expected encoded secret key.
    pub sk: Vec<u8>,
    /// Expected ciphertext.
    pub ct: Vec<u8>,
    /// Expected shared secret.
    pub ss: Vec<u8>,
}
/// Decodes a hex string, ignoring all whitespace, into bytes.
///
/// # Panics
/// If the digit count is odd or any pair is not valid hex.
#[cfg(test)]
pub(crate) fn hex_to_bytes(s: &str) -> Vec<u8> {
    let digits: String = s.split_whitespace().collect();
    assert!(digits.len() % 2 == 0, "hex length must be even");
    let mut out = Vec::with_capacity(digits.len() / 2);
    let mut pos = 0;
    while pos < digits.len() {
        out.push(u8::from_str_radix(&digits[pos..pos + 2], 16).expect("valid hex"));
        pos += 2;
    }
    out
}
/// Finds the record whose header line is `count = {count}` and decodes its
/// `seed`, `pk`, `sk`, `ct` and `ss` fields.
///
/// Fields are read until a blank line or the next `count = ` header. Values of
/// every `key = value` line are hex-decoded, and unknown keys are ignored.
/// Returns `None` if the header is absent or any of the five fields is missing.
#[cfg(test)]
pub(crate) fn parse_kat_entry(rsp: &str, count: usize) -> Option<KatEntry> {
    let header = format!("count = {count}");
    let mut lines = rsp.lines();
    lines.by_ref().find(|line| line.trim() == header)?;
    let (mut seed, mut pk, mut sk, mut ct, mut ss) = (None, None, None, None, None);
    for trimmed in lines.map(str::trim) {
        if trimmed.is_empty() || trimmed.starts_with("count = ") {
            break;
        }
        let Some((key, value)) = trimmed.split_once(" = ") else {
            continue;
        };
        let bytes = hex_to_bytes(value.trim());
        let slot = match key.trim() {
            "seed" => &mut seed,
            "pk" => &mut pk,
            "sk" => &mut sk,
            "ct" => &mut ct,
            "ss" => &mut ss,
            _ => continue,
        };
        *slot = Some(bytes);
    }
    Some(KatEntry {
        seed: seed?,
        pk: pk?,
        sk: sk?,
        ct: ct?,
        ss: ss?,
    })
}
/// KAT record indices checked by the default (non-ignored) KAT test generated
/// by `define_pqc_kem`; the `#[ignore]`d `nist_kat_full` test sweeps 0..100.
#[cfg(test)]
pub(crate) const KAT_SAMPLED_COUNTS: &[usize] = &[0, 1, 7, 23, 42, 67, 83, 99];
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cmov_copies_when_b_is_one_else_no_change() {
        let mut r = [1u8, 2, 3, 4];
        let x = [9u8, 8, 7, 6];
        cmov(&mut r, &x, 0);
        assert_eq!(r, [1, 2, 3, 4]);
        cmov(&mut r, &x, 1);
        assert_eq!(r, [9, 8, 7, 6]);
    }

    #[test]
    fn crypto_sort_int32_matches_std_sort() {
        let inputs: &[&[i32]] = &[
            &[],
            &[0],
            &[3, 1, 2],
            &[i32::MAX, i32::MIN, 0, -1, 1],
            &[7, 7, 7, 7, 7],
            &[5, -3, 8, 0, -7, 2, 6, -1, 9, 4, -2, 1, -5, 3, -6, 7, -8, -4],
        ];
        for &case in inputs {
            let mut a = case.to_vec();
            let mut b = case.to_vec();
            crypto_sort_int32(&mut a);
            b.sort();
            assert_eq!(a, b, "sort mismatch on {case:?}");
        }
    }

    #[test]
    fn mod3_matches_naive_reduction() {
        for a in 0u16..=u16::MAX {
            assert_eq!(mod3(a), a % 3);
        }
    }

    #[test]
    fn mod3_u8_matches_naive_reduction_on_supported_range() {
        // mod3_u8 debug-asserts that its input is at most 14.
        for a in 0u8..=14 {
            assert_eq!(mod3_u8(a), a % 3, "a = {a}");
        }
    }

    #[test]
    fn s3_pack_unpack_roundtrip() {
        fn roundtrip<const N: usize>() {
            let mut a = [0u16; N];
            for (i, c) in a[..N - 1].iter_mut().enumerate() {
                *c = ((i * 7 + 1) % 3) as u16;
            }
            let mut packed = vec![0u8; (N - 1).div_ceil(5)];
            poly_s3_tobytes::<N>(&mut packed, &a);
            let mut back = [0u16; N];
            poly_s3_frombytes::<N>(&mut back, &packed);
            assert_eq!(back, a, "N = {N}");
        }
        // Covers a trailing partial group (7, 9, 509) and an exact fit (11).
        roundtrip::<7>();
        roundtrip::<9>();
        roundtrip::<11>();
        roundtrip::<509>();
    }

    #[test]
    fn digest_chain_matches_concat_then_update() {
        use crate::hash::sha2::Sha256;
        use crate::hash::sha3::Sha3_256;
        let parts: [&[u8]; 3] = [b"abc", b"defghij", b""];
        let concat: Vec<u8> = parts.concat();
        let [a, b, c] = parts;

        let chained = Sha3_256::new().chain(a).chain(b).chain(c).finalize();
        let oneshot = {
            let mut h = Sha3_256::new();
            h.update(&concat);
            h.finalize()
        };
        assert_eq!(chained.as_slice(), oneshot.as_slice(), "Sha3_256 chain");

        let chained = Sha256::new().chain(a).chain(b).chain(c).finalize();
        let oneshot = {
            let mut h = Sha256::new();
            h.update(&concat);
            h.finalize()
        };
        assert_eq!(chained.as_slice(), oneshot.as_slice(), "Sha256 chain");
    }
}