use crate::encodings::{pk_decode, pk_encode, sig_decode, sig_encode, sk_decode, w1_encode};
use crate::hashing::{expand_a, expand_mask, expand_s, h256_xof, sample_in_ball};
use crate::helpers::{
add_vector_ntt, center_mod, full_reduce32, infinity_norm, mat_vec_mul, mont_reduce,
partial_reduce32, to_mont,
};
use crate::high_low::{high_bits, low_bits, make_hint, power2round, use_hint};
use crate::ntt::{inv_ntt, ntt};
use crate::types::{PrivateKey, PublicKey, R, T};
use crate::{D, Q};
use rand_core::CryptoRngCore;
use sha3::digest::XofReader;
/// Generates a key pair from 32 seed bytes drawn from `rng`.
///
/// The seed buffer is filled with `rng.try_fill_bytes` and then passed, together with `eta`,
/// to [`key_gen_internal`], which does all of the actual key derivation.
///
/// # Errors
/// Returns `"KeyGen: Random number generator failed"` when the generator cannot fill the seed.
pub(crate) fn key_gen<
    const CTEST: bool,
    const K: usize,
    const L: usize,
    const PK_LEN: usize,
    const SK_LEN: usize,
>(
    rng: &mut impl CryptoRngCore, eta: i32,
) -> Result<(PublicKey<K, L>, PrivateKey<K, L>), &'static str> {
    let mut seed = [0u8; 32];
    if rng.try_fill_bytes(&mut seed).is_err() {
        return Err("KeyGen: Random number generator failed");
    }
    let key_pair = key_gen_internal::<CTEST, K, L, PK_LEN, SK_LEN>(eta, &seed);
    Ok(key_pair)
}
/// Derives a public/private key pair from the 32-byte seed `xi`.
///
/// The seed, followed by the low bytes of `K` and `L`, is fed to an XOF from which `rho`
/// (32 bytes), `rho_prime` (64 bytes) and `cap_k` (32 bytes) are read in that order.
/// `rho` seeds `expand_a`, `rho_prime` seeds `expand_s`, and `cap_k` is stored in the
/// private key. The returned structs hold their polynomial vectors in the form produced by
/// `to_mont(ntt(..))`.
pub(crate) fn key_gen_internal<
    const CTEST: bool,
    const K: usize,
    const L: usize,
    const PK_LEN: usize,
    const SK_LEN: usize,
>(
    eta: i32, xi: &[u8; 32],
) -> (PublicKey<K, L>, PrivateKey<K, L>) {
    // Expand the seed; the order of the three reads determines which bytes go where.
    let k_byte = K.to_le_bytes()[0];
    let l_byte = L.to_le_bytes()[0];
    let mut seed_xof = h256_xof(&[xi, &[k_byte], &[l_byte]]);
    let (mut rho, mut rho_prime, mut cap_k) = ([0u8; 32], [0u8; 64], [0u8; 32]);
    seed_xof.read(&mut rho);
    seed_xof.read(&mut rho_prime);
    seed_xof.read(&mut cap_k);

    // Secret vectors sampled from rho_prime.
    let (s_1, s_2): ([R; L], [R; K]) = expand_s::<CTEST, K, L>(eta, &rho_prime);

    // t = inv_ntt(A_hat * ntt(s_1)) + s_2, fully reduced, then split by power2round.
    let matrix_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(&rho);
    let s_1_hat: [T; L] = ntt(&s_1);
    let product_hat: [T; K] = mat_vec_mul(&matrix_hat, &s_1_hat);
    let t_unreduced: [R; K] = add_vector_ntt(&inv_ntt(&product_hat), &s_2);
    let reduce_poly = |poly: &R| R(core::array::from_fn(|n| full_reduce32(poly.0[n])));
    let t: [R; K] = core::array::from_fn(|k| reduce_poly(&t_unreduced[k]));
    let (t_1, t_0): ([R; K], [R; K]) = power2round(&t);

    // tr is 64 bytes of XOF output over the encoded public key (rho, t_1).
    let pk_bytes = pk_encode::<K, PK_LEN>(&rho, &t_1);
    let mut tr_xof = h256_xof(&[&pk_bytes]);
    let mut tr = [0u8; 64];
    tr_xof.read(&mut tr);

    // Public key caches t_1 shifted left by D bits, kept in the to_mont form.
    let t1_hat_mont: [T; K] = to_mont(&ntt(&t_1));
    let t1_shifted_hat: [T; K] = core::array::from_fn(|k| {
        T(core::array::from_fn(|n| mont_reduce(i64::from(t1_hat_mont[k].0[n]) << D)))
    });
    let t1_d2_hat_mont: [T; K] = to_mont(&t1_shifted_hat);
    let public = PublicKey { rho, tr, t1_d2_hat_mont };

    // Private key caches s_1, s_2 and t_0 in the same form.
    let s_1_hat_mont: [T; L] = to_mont(&ntt(&s_1));
    let s_2_hat_mont: [T; K] = to_mont(&ntt(&s_2));
    let t_0_hat_mont: [T; K] = to_mont(&ntt(&t_0));
    let private = PrivateKey { rho, cap_k, tr, s_1_hat_mont, s_2_hat_mont, t_0_hat_mont };

    (public, private)
}
/// Builds and returns the encoded signature bytes using the expanded private key `esk`.
///
/// The 64-byte digest `mu` is read from an XOF over `tr` followed by:
/// * `message` alone when `nist` is true;
/// * `0 || low byte of ctx.len() || ctx || message` when `nist` is false and `oid` is empty;
/// * `1 || low byte of ctx.len() || ctx || oid || phm` otherwise.
///
/// A second XOF over `cap_k`, `rnd` and `mu` yields the 64-byte `rho_prime` that seeds
/// `expand_mask`. The loop then builds candidate `(c_tilde, z, h)` values, advancing
/// `kappa_ctr` by `L` and retrying whenever one of the norm or hint-weight bounds is exceeded.
/// When `CTEST` is true both rejection checks are skipped, so the first candidate is encoded.
///
/// # Panics
/// Panics on a retry if `L` does not fit in a `u16` (the `expect` on the counter increment).
///
/// NOTE(review): only the least-significant byte of `ctx.len()` is hashed; callers are
/// presumably expected to reject contexts longer than 255 bytes before calling — confirm.
#[allow(
clippy::similar_names,
clippy::many_single_char_names,
clippy::too_many_arguments,
clippy::too_many_lines
)]
pub(crate) fn sign_internal<
const CTEST: bool,
const K: usize,
const L: usize,
const LAMBDA_DIV4: usize,
const SIG_LEN: usize,
const SK_LEN: usize,
const W1_LEN: usize,
>(
beta: i32, gamma1: i32, gamma2: i32, omega: i32, tau: i32, esk: &PrivateKey<K, L>,
message: &[u8], ctx: &[u8], oid: &[u8], phm: &[u8], rnd: [u8; 32], nist: bool,
) -> [u8; SIG_LEN] {
let PrivateKey { rho, cap_k, tr, s_1_hat_mont, s_2_hat_mont, t_0_hat_mont } = esk;
// The matrix is re-expanded from rho on every call; it is not stored in the private key.
let cap_a_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(rho);
// Select the input framing for mu (see the function docs for the three layouts).
let mut h6 = if nist {
h256_xof(&[tr, message])
} else if oid.is_empty() {
h256_xof(&[tr, &[0u8], &[ctx.len().to_le_bytes()[0]], ctx, message])
} else {
h256_xof(&[tr, &[1u8], &[ctx.len().to_le_bytes()[0]], ctx, oid, phm])
};
let mut mu = [0u8; 64];
h6.read(&mut mu);
// rho_prime mixes the private cap_k, the caller-supplied rnd and mu.
let mut h7 = h256_xof(&[cap_k, &rnd, &mu]);
let mut rho_prime = [0u8; 64];
h7.read(&mut rho_prime);
// Counter handed to expand_mask; bumped by L on every rejected attempt.
let mut kappa_ctr = 0u16;
// z and h are assigned inside the loop and read after it breaks.
let mut z: [R; L];
let mut h: [R; K];
let mut c_tilde = [0u8; LAMBDA_DIV4];
loop {
// Fresh mask vector y for this attempt.
let y: [R; L] = expand_mask(gamma1, &rho_prime, kappa_ctr);
// w = inv_ntt(A_hat * ntt(y))
let w: [R; K] = {
let y_hat: [T; L] = ntt(&y);
let ay_hat: [T; K] = mat_vec_mul(&cap_a_hat, &y_hat);
inv_ntt(&ay_hat)
};
// w_1 holds high_bits of every coefficient of w.
let w_1: [R; K] =
core::array::from_fn(|k| R(core::array::from_fn(|n| high_bits(gamma2, w[k].0[n]))));
let mut w1_tilde = [0u8; W1_LEN];
w1_encode::<K>(gamma2, &w_1, &mut w1_tilde);
// Commitment hash c_tilde is read from an XOF over mu and the encoded w_1.
let mut h15 = h256_xof(&[&mu, &w1_tilde]);
h15.read(&mut c_tilde);
// Challenge polynomial c derived from c_tilde, then moved to the NTT domain.
let c: R = sample_in_ball::<CTEST>(tau, &c_tilde);
let c_hat: &T = &ntt(&[c])[0];
// c * s_1, computed pointwise in the NTT domain and transformed back.
let c_s_1: [R; L] = {
let cs1_hat: [T; L] = core::array::from_fn(|l| {
T(core::array::from_fn(|n| {
mont_reduce(i64::from(c_hat.0[n]) * i64::from(s_1_hat_mont[l].0[n]))
}))
});
inv_ntt(&cs1_hat)
};
// c * s_2, same construction.
let c_s_2: [R; K] = {
let cs2_hat: [T; K] = core::array::from_fn(|k| {
T(core::array::from_fn(|n| {
mont_reduce(i64::from(c_hat.0[n]) * i64::from(s_2_hat_mont[k].0[n]))
}))
});
inv_ntt(&cs2_hat)
};
// z = y + c*s_1 (partially reduced)
z = core::array::from_fn(|l| {
R(core::array::from_fn(|n| partial_reduce32(y[l].0[n] + c_s_1[l].0[n])))
});
// r0 = low_bits(w - c*s_2)
let r0: [R; K] = core::array::from_fn(|k| {
R(core::array::from_fn(|n| {
low_bits(gamma2, partial_reduce32(w[k].0[n] - c_s_2[k].0[n]))
}))
});
// Both norms are computed before the branch, regardless of the outcome.
let z_norm = infinity_norm(&z);
let r0_norm = infinity_norm(&r0);
// First rejection check: retry with the next counter value if either norm is too large.
if !CTEST && ((z_norm >= (gamma1 - beta)) || (r0_norm >= (gamma2 - beta))) {
kappa_ctr += u16::try_from(L).expect("cannot fail; L is static parameter");
continue;
}
// c * t_0, same construction as c*s_1 and c*s_2.
let c_t_0: [R; K] = {
let ct0_hat: [T; K] = core::array::from_fn(|k| {
T(core::array::from_fn(|n| {
mont_reduce(i64::from(c_hat.0[n]) * i64::from(t_0_hat_mont[k].0[n]))
}))
});
inv_ntt(&ct0_hat)
};
// Hint per coefficient, from (Q - c*t_0) and (w - c*s_2 + c*t_0); stored as an i32 via i32::from.
h = core::array::from_fn(|k| {
R(core::array::from_fn(|n| {
i32::from(make_hint(
gamma2,
Q - c_t_0[k].0[n], partial_reduce32(w[k].0[n] - c_s_2[k].0[n] + c_t_0[k].0[n]),
))
}))
});
// Second rejection check: c*t_0 norm too large, or the hint coefficients sum to more than omega.
if !CTEST
&& ((infinity_norm(&c_t_0) >= gamma2)
|| (h.iter().map(|h_i| h_i.0.iter().sum::<i32>()).sum::<i32>() > omega))
{
kappa_ctr += u16::try_from(L).expect("cannot fail; L is static parameter");
continue;
}
break;
}
// Apply center_mod to every coefficient of z before encoding.
let zmodq: [R; L] =
core::array::from_fn(|l| R(core::array::from_fn(|n| center_mod(z[l].0[n]))));
sig_encode::<CTEST, K, L, LAMBDA_DIV4, SIG_LEN>(gamma1, omega, &c_tilde, &zmodq, &h)
}
/// Checks `sig` against the expanded public key `epk` and returns `true` only if it verifies.
///
/// The signature is decoded into `(c_tilde, z, h)`; a decode error or a missing hint yields
/// `false`. The 64-byte digest `mu` is read from an XOF over `tr` followed by:
/// * `m` alone when `nist` is true;
/// * `0 || ctx.len() as one byte || ctx || m` when `nist` is false and `oid` is empty;
/// * `1 || ctx.len() as one byte || ctx || oid || phm` otherwise.
///
/// A commitment hash is then recomputed from `mu` and the hint-corrected high bits of
/// `A*z - c*t_1*2^D`. Verification succeeds when that hash equals `c_tilde` and the infinity
/// norm of `z` is below `gamma1 - beta`.
///
/// When `nist` is false, a `ctx` longer than 255 bytes is rejected (`false`), because its
/// length cannot be represented in the single length byte that is hashed. Previously the
/// length was silently truncated to its low byte.
///
/// In debug builds, asserts that the decoded `z` has infinity norm at most `gamma1`.
#[allow(clippy::too_many_arguments, clippy::similar_names, clippy::type_complexity)]
pub(crate) fn verify_internal<
    const CTEST: bool,
    const K: usize,
    const L: usize,
    const LAMBDA_DIV4: usize,
    const PK_LEN: usize,
    const SIG_LEN: usize,
    const W1_LEN: usize,
>(
    beta: i32, gamma1: i32, gamma2: i32, omega: i32, tau: i32, epk: &PublicKey<K, L>, m: &[u8],
    sig: &[u8; SIG_LEN], ctx: &[u8], oid: &[u8], phm: &[u8], nist: bool,
) -> bool {
    let PublicKey { rho, tr, t1_d2_hat_mont } = epk;

    // Malformed signatures (decode error or absent hint) never verify.
    let Ok((c_tilde, z, h)): Result<([u8; LAMBDA_DIV4], [R; L], Option<[R; K]>), &'static str> =
        sig_decode(gamma1, omega, sig)
    else {
        return false;
    };
    let Some(h) = h else { return false };
    debug_assert!(infinity_norm(&z) <= gamma1, "Alg 8: i_norm out of range");

    // Message digest mu; the framing must match what the signer hashed.
    let mut h7 = if nist {
        h256_xof(&[tr, m])
    } else {
        // The context length is hashed as exactly one byte, so longer contexts are invalid.
        let Ok(ctx_len) = u8::try_from(ctx.len()) else {
            return false;
        };
        if oid.is_empty() {
            h256_xof(&[tr, &[0u8], &[ctx_len], ctx, m])
        } else {
            h256_xof(&[tr, &[1u8], &[ctx_len], ctx, oid, phm])
        }
    };
    let mut mu = [0u8; 64];
    h7.read(&mut mu);

    // Challenge polynomial rebuilt from the signature's commitment hash.
    let c: R = sample_in_ball::<false>(tau, &c_tilde);

    // w'_approx = inv_ntt(A_hat * ntt(z) - ntt(c) * t1_d2_hat_mont)
    let wp_approx: [R; K] = {
        let cap_a_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(rho);
        let z_hat: [T; L] = ntt(&z);
        let az_hat: [T; K] = mat_vec_mul(&cap_a_hat, &z_hat);
        let c_hat: &T = &ntt(&[c])[0];
        inv_ntt(&core::array::from_fn(|k| {
            T(core::array::from_fn(|n| {
                az_hat[k].0[n]
                    - mont_reduce(i64::from(c_hat.0[n]) * i64::from(t1_d2_hat_mont[k].0[n]))
            }))
        }))
    };

    // Apply the hints to recover the high bits, then encode them for hashing.
    let wp_1: [R; K] = core::array::from_fn(|k| {
        R(core::array::from_fn(|n| use_hint(gamma2, h[k].0[n], wp_approx[k].0[n])))
    });
    let mut tmp = [0u8; W1_LEN];
    w1_encode::<K>(gamma2, &wp_1, &mut tmp);

    // Recompute the commitment hash from mu and the encoded high bits.
    let mut h12 = h256_xof(&[&mu, &tmp]);
    let mut c_tilde_p = [0u8; LAMBDA_DIV4];
    h12.read(&mut c_tilde_p);

    // Both conditions are evaluated before being combined.
    let left = infinity_norm(&z) < (gamma1 - beta);
    let right = c_tilde == c_tilde_p;
    left && right
}
/// Decodes the serialized private key `sk` and expands it into a [`PrivateKey`].
///
/// `sk_decode` supplies `rho`, `cap_k` and `tr`, which are copied as-is, plus the vectors
/// `s_1`, `s_2` and `t_0`, each of which is stored as `to_mont(ntt(..))`.
///
/// # Errors
/// Propagates any error returned by `sk_decode`.
pub(crate) fn expand_private<const K: usize, const L: usize, const SK_LEN: usize>(
    eta: i32, sk: &[u8; SK_LEN],
) -> Result<PrivateKey<K, L>, &'static str> {
    let (rho, cap_k, tr, s_1, s_2, t_0) = sk_decode(eta, sk)?;
    let expanded = PrivateKey {
        rho: *rho,
        cap_k: *cap_k,
        tr: *tr,
        s_1_hat_mont: to_mont(&ntt(&s_1)),
        s_2_hat_mont: to_mont(&ntt(&s_2)),
        t_0_hat_mont: to_mont(&ntt(&t_0)),
    };
    Ok(expanded)
}
/// Decodes the serialized public key `pk` and expands it into a [`PublicKey`].
///
/// `tr` is 64 bytes of XOF output over the raw `pk` bytes. The decoded `t_1` is taken through
/// `to_mont(ntt(..))`, each coefficient is shifted left by `D` bits and passed through
/// `mont_reduce`, and the result goes through `to_mont` once more before being stored.
///
/// # Errors
/// Propagates any error returned by `pk_decode`.
pub(crate) fn expand_public<const K: usize, const L: usize, const PK_LEN: usize>(
    pk: &[u8; PK_LEN],
) -> Result<PublicKey<K, L>, &'static str> {
    let (rho, t_1): (&[u8; 32], [R; K]) = pk_decode(pk)?;

    // Digest of the encoded public key.
    let mut tr = [0u8; 64];
    let mut pk_xof = h256_xof(&[pk]);
    pk_xof.read(&mut tr);

    // Precompute t_1 * 2^D in the form the verifier multiplies against.
    let t1_hat_mont: [T; K] = to_mont(&ntt(&t_1));
    let shifted_hat: [T; K] = core::array::from_fn(|k| {
        let row = &t1_hat_mont[k];
        T(core::array::from_fn(|n| mont_reduce(i64::from(row.0[n]) << D)))
    });
    let t1_d2_hat_mont: [T; K] = to_mont(&shifted_hat);

    Ok(PublicKey { rho: *rho, tr, t1_d2_hat_mont })
}
/// Rebuilds the expanded [`PublicKey`] that corresponds to the expanded private key `sk`.
///
/// The stored `s_1`, `s_2` and `t_0` values are passed through `mont_reduce`; `s_2` and `t_0`
/// are additionally taken through `inv_ntt` and have coefficients above `Q / 2` lowered by `Q`.
/// `t` is then recomputed as `inv_ntt(A_hat * s_1_hat) + s_2`, fully reduced, and split with
/// `power2round`. `rho` and `tr` are copied from the private key.
///
/// In debug builds, asserts that the recomputed low part of `t` equals the `t_0` recovered
/// from the private key.
pub(crate) fn private_to_public_key<const K: usize, const L: usize>(
    sk: &PrivateKey<K, L>,
) -> PublicKey<K, L> {
    let PrivateKey { rho, cap_k: _, tr, s_1_hat_mont, s_2_hat_mont, t_0_hat_mont } = sk;

    // Element-wise helpers shared by the vectors below.
    let strip_mont =
        |poly: &T| T(core::array::from_fn(|n| mont_reduce(i64::from(poly.0[n]))));
    let recenter = |poly: &R| {
        R(core::array::from_fn(|n| {
            let coeff = poly.0[n];
            if coeff > (Q / 2) {
                coeff - Q
            } else {
                coeff
            }
        }))
    };

    let cap_a_hat: [[T; L]; K] = expand_a::<false, K, L>(rho);

    // s_1 stays in the NTT domain; s_2 and t_0 are brought back and recentered.
    let s_1_hat: [T; L] = core::array::from_fn(|l| strip_mont(&s_1_hat_mont[l]));
    let s_2_raw: [R; K] = inv_ntt(&core::array::from_fn(|k| strip_mont(&s_2_hat_mont[k])));
    let s_2: [R; K] = core::array::from_fn(|k| recenter(&s_2_raw[k]));
    let t_0_raw: [R; K] = inv_ntt(&core::array::from_fn(|k| strip_mont(&t_0_hat_mont[k])));
    let sk_t_0: [R; K] = core::array::from_fn(|k| recenter(&t_0_raw[k]));

    // t = inv_ntt(A_hat * s_1_hat) + s_2, fully reduced.
    let as1_hat: [T; K] = mat_vec_mul(&cap_a_hat, &s_1_hat);
    let t_not_reduced: [R; K] = add_vector_ntt(&inv_ntt(&as1_hat), &s_2);
    let t: [R; K] = core::array::from_fn(|k| {
        let row = &t_not_reduced[k];
        R(core::array::from_fn(|n| full_reduce32(row.0[n])))
    });

    let (t_1, pk_t_0): ([R; K], [R; K]) = power2round(&t);
    debug_assert_eq!(sk_t_0, pk_t_0);

    // Same t_1 * 2^D precomputation as used when expanding a public key.
    let t1_hat_mont: [T; K] = to_mont(&ntt(&t_1));
    let shifted_hat: [T; K] = core::array::from_fn(|k| {
        T(core::array::from_fn(|n| mont_reduce(i64::from(t1_hat_mont[k].0[n]) << D)))
    });
    let t1_d2_hat_mont: [T; K] = to_mont(&shifted_hat);

    PublicKey { rho: *rho, tr: *tr, t1_d2_hat_mont }
}