use crate::codec::{decode, encode};
use crate::error::{Error, Result};
use crate::gf16::{mat_add, mat_mul, mul_f};
use crate::keygen::expand_p1_p2;
use crate::keypair::derive_cpk_from_csk;
use crate::matrix_ops::{compute_m_and_vpv, p1p1t_times_o};
use crate::params::{F_TAIL_LEN, MAX_M_VEC_LIMBS, MayoParameter};
use crate::sample::sample_solution;
use crate::verify::mayo_verify;
use rand::CryptoRng;
use sha3::Shake256;
use sha3::digest::{ExtendableOutput, Update, XofReader};
use zeroize::Zeroizing;
/// Expands a compact secret key into the matrices the signer works with.
///
/// The first `P::SK_SEED_BYTES` bytes of `csk` are absorbed by SHAKE256, which is squeezed
/// into `seed_pk || O_bytes`. `O_bytes` is nibble-decoded into the `P::V * P::O` entries of
/// `O`. `seed_pk` is expanded with `expand_p1_p2`. The limbs following the first
/// `P::P1_LIMBS` (called `l` here) are then updated in place by `p1p1t_times_o` from `P1`
/// and `O`.
///
/// Returns `(P1 || L limbs, O)`; both buffers are wiped on drop.
///
/// # Panics
/// Panics if `csk` is shorter than `P::SK_SEED_BYTES`.
fn expand_sk<P: MayoParameter>(csk: &[u8]) -> (Zeroizing<Vec<u64>>, Zeroizing<Vec<u8>>) {
    let pk_seed_len = P::PK_SEED_BYTES;
    let o_len = P::V * P::O;

    // SHAKE256(seed_sk) -> seed_pk || encoded O
    let mut squeezed = Zeroizing::new(vec![0u8; pk_seed_len + P::O_BYTES]);
    let mut xof = {
        let mut shake = Shake256::default();
        shake.update(&csk[..P::SK_SEED_BYTES]);
        shake.finalize_xof()
    };
    xof.read(&mut squeezed);

    let pk_seed = squeezed[..pk_seed_len].to_vec();
    let mut o_mat = Zeroizing::new(vec![0u8; o_len]);
    decode(&squeezed[pk_seed_len..], &mut o_mat, o_len);

    let mut limbs = Zeroizing::new(expand_p1_p2::<P>(&pk_seed));
    let (p1, l) = limbs.split_at_mut(P::P1_LIMBS);
    p1p1t_times_o::<P>(p1, &o_mat, l);

    (limbs, o_mat)
}
/// Transposes, in place, the 16x16 matrix of 4-bit entries stored in `m[0..16]`.
///
/// Row `r` is the word `m[r]`, and column `c` is nibble `c` of that word
/// (bits `4c..4c + 4`). After the call, nibble `c` of `m[r]` holds what was nibble `r`
/// of `m[c]`. Only the first 16 words are touched.
///
/// # Panics
/// Panics if `m` has fewer than 16 elements.
fn transpose_16x16_nibbles(m: &mut [u64]) {
    /// Swaps the `mask` lanes of `m[hi]` with the `mask << shift` lanes of `m[lo]`.
    fn exchange(m: &mut [u64], lo: usize, hi: usize, shift: u32, mask: u64) {
        let diff = ((m[lo] >> shift) ^ m[hi]) & mask;
        m[lo] ^= diff << shift;
        m[hi] ^= diff;
    }

    // One butterfly level per bit of the 4-bit row/column index. Each entry is
    // (row distance, bit shift, mask of the lanes of the higher row that get swapped).
    const LEVELS: [(usize, u32, u64); 4] = [
        (1, 4, 0x0f0f_0f0f_0f0f_0f0f),
        (2, 8, 0x00ff_00ff_00ff_00ff),
        (4, 16, 0x0000_ffff_0000_ffff),
        (8, 32, 0x0000_0000_ffff_ffff),
    ];

    for &(distance, shift, mask) in &LEVELS {
        // Pair every row whose `distance` bit is clear with the row `distance` below it.
        for row in (0..16).filter(|row| (*row & distance) == 0) {
            exchange(m, row, row + distance, shift, mask);
        }
    }
}
/// Computes the right-hand side `y` of the signer's linear system from `vpv` and `t`.
///
/// `vpv` is read as a `P::K x P::K` grid of m-vectors of `P::M_VEC_LIMBS` limbs each;
/// entry `(i, j)` starts at `(i * K + j) * M_VEC_LIMBS`. Coefficient `n` of an m-vector is
/// nibble `n % 16` of limb `n / 16`.
///
/// The function walks the pairs with `i` descending and `j` running from `i` to `K - 1`.
/// For each pair it keeps a running accumulator, Horner-style:
/// - it multiplies the accumulator by `X` (moving every coefficient up one nibble);
/// - it folds the coefficient shifted out of position `M - 1` back in, via `mul_f` with `F_TAIL`;
/// - it XORs in `vpv[i][j]` and, off the diagonal, `vpv[j][i]`.
///
/// Finally `y[n] = t[n] ^ acc_n` for `n < P::M`.
///
/// Side effect: in each of the first `K * K` m-vectors of `vpv`, the nibbles of the last
/// limb above coefficient `M - 1` are cleared (only when `M` is not a multiple of 16).
pub(crate) fn compute_rhs<P: MayoParameter>(vpv: &mut [u64], t: &[u8], y: &mut [u8]) {
    let limbs = P::M_VEC_LIMBS;
    let m = P::M;
    let k = P::K;
    let f_tail = P::F_TAIL;
    let last = limbs - 1;
    // Bit offset, inside the last limb, of coefficient M - 1 (the highest meaningful one).
    let top_shift = ((m - 1) % 16) * 4;

    let used_in_last = m % 16;
    if used_in_last != 0 {
        let keep = (1u64 << (used_in_last * 4)) - 1;
        for vec_idx in 0..k * k {
            vpv[vec_idx * limbs + last] &= keep;
        }
    }

    let mut acc = [0u64; MAX_M_VEC_LIMBS];
    for i in (0..k).rev() {
        for j in i..k {
            // acc <- acc * X, remembering the coefficient that leaves position M - 1.
            let top = ((acc[last] >> top_shift) & 0xF) as u8;
            acc[last] <<= 4;
            for limb in (1..limbs).rev() {
                acc[limb] ^= acc[limb - 1] >> 60;
                acc[limb - 1] <<= 4;
            }

            // Reduce: add top * F_TAIL[pos] into coefficient `pos`.
            for (pos, &coeff) in f_tail.iter().enumerate().take(F_TAIL_LEN) {
                acc[pos / 16] ^= u64::from(mul_f(top, coeff)) << ((pos % 16) * 4);
            }

            // Add vPv[i][j], plus its mirror vPv[j][i] off the diagonal.
            let upper = (i * k + j) * limbs;
            let lower = (j * k + i) * limbs;
            for limb in 0..limbs {
                let mirrored = if i == j { 0 } else { vpv[lower + limb] };
                acc[limb] ^= vpv[upper + limb] ^ mirrored;
            }
        }
    }

    for idx in 0..m {
        let coeff = ((acc[idx / 16] >> ((idx % 16) * 4)) & 0xF) as u8;
        y[idx] = t[idx] ^ coeff;
    }
}
fn compute_a<P: MayoParameter>(vtl: &mut [u64], a_out: &mut [u8]) {
let m_vec_limbs = P::M_VEC_LIMBS;
let param_m = P::M;
let param_o = P::O;
let param_k = P::K;
let a_cols = P::A_COLS;
let f_tail = P::F_TAIL;
let m_over_8 = param_m.div_ceil(8);
let a_width = (param_o * param_k).div_ceil(16) * 16;
let mut bits_to_shift: usize = 0;
let mut words_to_shift: usize = 0;
let a_total = a_width * m_over_8;
let mut a = vec![0u64; a_total];
if param_m % 16 != 0 {
let mut mask: u64 = 1;
mask <<= (param_m % 16) * 4;
mask -= 1;
for i in 0..(param_o * param_k) {
vtl[i * m_vec_limbs + m_vec_limbs - 1] &= mask;
}
}
for i in 0..param_k {
for j in (i..param_k).rev() {
let mj_base = j * m_vec_limbs * param_o;
for c in 0..param_o {
for k in 0..m_vec_limbs {
let src = vtl[mj_base + k + c * m_vec_limbs];
let dst_idx = param_o * i + c + (k + words_to_shift) * a_width;
debug_assert!(dst_idx < a_total);
a[dst_idx] ^= src << bits_to_shift;
if bits_to_shift > 0 {
let dst_idx2 = param_o * i + c + (k + words_to_shift + 1) * a_width;
if dst_idx2 < a_total {
a[dst_idx2] ^= src >> (64 - bits_to_shift);
}
}
}
}
if i != j {
let mi_base = i * m_vec_limbs * param_o;
for c in 0..param_o {
for k in 0..m_vec_limbs {
let src = vtl[mi_base + k + c * m_vec_limbs];
let dst_idx = param_o * j + c + (k + words_to_shift) * a_width;
debug_assert!(dst_idx < a_total);
a[dst_idx] ^= src << bits_to_shift;
if bits_to_shift > 0 {
let dst_idx2 = param_o * j + c + (k + words_to_shift + 1) * a_width;
if dst_idx2 < a_total {
a[dst_idx2] ^= src >> (64 - bits_to_shift);
}
}
}
}
}
bits_to_shift += 4;
if bits_to_shift == 64 {
words_to_shift += 1;
bits_to_shift = 0;
}
}
}
let total_transpose = a_width * (param_m + (param_k + 1) * param_k / 2).div_ceil(16);
let mut c = 0;
while c < total_transpose {
debug_assert!(c + 16 <= a.len());
transpose_16x16_nibbles(&mut a[c..c + 16]);
c += 16;
}
let mut tab = [0u8; F_TAIL_LEN * 4];
for i in 0..F_TAIL_LEN {
tab[4 * i] = mul_f(f_tail[i], 1);
tab[4 * i + 1] = mul_f(f_tail[i], 2);
tab[4 * i + 2] = mul_f(f_tail[i], 4);
tab[4 * i + 3] = mul_f(f_tail[i], 8);
}
let low_bit_in_nibble: u64 = 0x1111111111111111;
let mut c = 0;
while c < a_width {
for r in param_m..(param_m + (param_k + 1) * param_k / 2) {
let pos = (r / 16) * a_width + c + (r % 16);
debug_assert!(pos < a.len());
let val = a[pos];
let t0 = val & low_bit_in_nibble;
let t1 = (val >> 1) & low_bit_in_nibble;
let t2 = (val >> 2) & low_bit_in_nibble;
let t3 = (val >> 3) & low_bit_in_nibble;
for t in 0..F_TAIL_LEN {
let target_r = r + t - param_m;
let target_pos = (target_r / 16) * a_width + c + (target_r % 16);
debug_assert!(target_pos < a.len());
a[target_pos] ^= t0.wrapping_mul(u64::from(tab[4 * t]))
^ t1.wrapping_mul(u64::from(tab[4 * t + 1]))
^ t2.wrapping_mul(u64::from(tab[4 * t + 2]))
^ t3.wrapping_mul(u64::from(tab[4 * t + 3]));
}
}
c += 16;
}
for r in (0..param_m).step_by(16) {
let mut c = 0;
while c < a_cols - 1 {
for i in 0..16 {
if r + i >= param_m {
break;
}
let src_pos = r * a_width / 16 + c + i;
let decode_len = 16.min(a_cols - 1 - c);
debug_assert!(src_pos < a.len());
let src_bytes = a[src_pos].to_le_bytes();
decode_packed_nibbles(&src_bytes, &mut a_out[(r + i) * a_cols + c..], decode_len);
}
c += 16;
}
}
}
/// Unpacks 4-bit values from `input` into `output`, one value per byte.
///
/// Each input byte yields its low nibble first, then its high nibble. At most `len`
/// values are written, and never more than `2 * input.len()`. The remaining bytes of
/// `output` are left untouched.
///
/// # Panics
/// Panics if `output` is shorter than the number of values to write.
fn decode_packed_nibbles(input: &[u8], output: &mut [u8], len: usize) {
    let count = len.min(input.len() * 2);
    for (idx, slot) in output[..count].iter_mut().enumerate() {
        let byte = input[idx / 2];
        *slot = if idx % 2 == 0 { byte & 0xf } else { byte >> 4 };
    }
}
/// Signs `msg` with the compact secret key `csk`, writing `P::SIG_BYTES` bytes into `sig`.
///
/// Outline:
/// - Expand `csk` (see `expand_sk`).
/// - `digest = SHAKE256(msg)`.
/// - `salt = SHAKE256(digest || randomness || seed_sk)`.
/// - `t = decode(SHAKE256(digest || salt))`.
/// - For `ctr` in `0..=255`:
///   - derive the `v` vectors and `r` from `SHAKE256(digest || salt || seed_sk || ctr)`;
///   - build `A` and `y`;
///   - try `sample_solution`.
/// - From the first solution `x`, each block of the signature is `v_i + O x_i || x_i`.
///   The blocks are encoded and followed by `salt`.
/// - The result is checked with `mayo_verify` against the public key re-derived from `csk`.
///
/// Returns the number of signature bytes written (`P::SIG_BYTES`).
///
/// # Errors
/// Returns [`Error::Signing`] in two cases:
/// - No counter value yields a solvable system. Nothing is written to `sig`.
/// - The produced signature fails self-verification. `sig[..P::SIG_BYTES]` is zeroed so a
///   faulty signature is never left in the caller's buffer.
///
/// # Panics
/// Panics if `csk` is shorter than `P::SK_SEED_BYTES` or `sig` is shorter than
/// `P::SIG_BYTES` (slice indexing).
pub(crate) fn mayo_sign_signature<P: MayoParameter>(
    sig: &mut [u8],
    msg: &[u8],
    csk: &[u8],
    rng: &mut impl CryptoRng,
) -> Result<usize> {
    let param_m = P::M;
    let param_n = P::N;
    let param_o = P::O;
    let param_k = P::K;
    let param_v = P::V;
    let param_m_bytes = P::M_BYTES;
    let param_v_bytes = P::V_BYTES;
    let param_r_bytes = P::R_BYTES;
    let param_sig_bytes = P::SIG_BYTES;
    let param_a_cols = P::A_COLS;
    let param_digest_bytes = P::DIGEST_BYTES;
    let param_sk_seed_bytes = P::SK_SEED_BYTES;
    let param_salt_bytes = P::SALT_BYTES;
    let m_vec_limbs = P::M_VEC_LIMBS;
    let digest_salt_len = param_digest_bytes + param_salt_bytes;

    let (p, o_mat) = expand_sk::<P>(csk);
    let seed_sk = &csk[..param_sk_seed_bytes];
    let p1 = &p[..P::P1_LIMBS];
    let l = &p[P::P1_LIMBS..];

    // tmp = digest || slot; the slot first holds fresh randomness, later the salt.
    let mut tmp = Zeroizing::new(vec![0u8; digest_salt_len]);
    {
        let mut hasher = Shake256::default();
        hasher.update(msg);
        let mut reader = hasher.finalize_xof();
        reader.read(&mut tmp[..param_digest_bytes]);
    }
    rng.fill_bytes(&mut tmp[param_digest_bytes..digest_salt_len]);

    // salt = SHAKE256(digest || randomness || seed_sk)
    let mut salt = Zeroizing::new(vec![0u8; param_salt_bytes]);
    {
        let mut hasher = Shake256::default();
        hasher.update(&tmp[..digest_salt_len]);
        hasher.update(seed_sk);
        let mut reader = hasher.finalize_xof();
        reader.read(&mut salt);
    }

    // t = decode(SHAKE256(digest || salt))
    tmp[param_digest_bytes..digest_salt_len].copy_from_slice(&salt);
    let mut tenc = vec![0u8; param_m_bytes];
    let mut t = vec![0u8; param_m];
    {
        let mut hasher = Shake256::default();
        hasher.update(&tmp[..digest_salt_len]);
        let mut reader = hasher.finalize_xof();
        reader.read(&mut tenc);
    }
    decode(&tenc, &mut t, param_m);

    let mut x = Zeroizing::new(vec![0u8; param_k * param_n]);
    let mut s = vec![0u8; param_k * param_n];
    let mut vdec = Zeroizing::new(vec![0u8; param_v * param_k]);
    let mut v_and_r = Zeroizing::new(vec![0u8; param_k * param_v_bytes + param_r_bytes]);
    let mut mtmp = Zeroizing::new(vec![0u64; param_k * param_o * m_vec_limbs]);
    let mut vpv = Zeroizing::new(vec![0u64; param_k * param_k * m_vec_limbs]);
    let mut pv = Zeroizing::new(vec![0u64; param_v * param_k * m_vec_limbs]);
    let mut y = vec![0u8; param_m];
    let a_row_size = param_m.div_ceil(8) * 8;
    let mut a_matrix = Zeroizing::new(vec![0u8; a_row_size * param_a_cols]);
    let mut r = Zeroizing::new(vec![0u8; param_k * param_o + 1]);

    let mut solved = false;
    for ctr in 0..=255u8 {
        // v_and_r = SHAKE256(digest || salt || seed_sk || ctr)
        {
            let mut hasher = Shake256::default();
            hasher.update(&tmp[..digest_salt_len]);
            hasher.update(seed_sk);
            hasher.update(&[ctr]);
            let mut reader = hasher.finalize_xof();
            reader.read(&mut v_and_r);
        }
        for i in 0..param_k {
            decode(
                &v_and_r[i * param_v_bytes..],
                &mut vdec[i * param_v..],
                param_v,
            );
        }
        mtmp.fill(0);
        vpv.fill(0);
        compute_m_and_vpv::<P>(&vdec, l, p1, &mut mtmp, &mut vpv, &mut pv);
        y.fill(0);
        compute_rhs::<P>(&mut vpv, &t, &mut y);
        a_matrix.fill(0);
        compute_a::<P>(&mut mtmp, &mut a_matrix);
        // Clear the last column of every row of A before solving.
        for i in 0..param_m {
            a_matrix[(1 + i) * param_a_cols - 1] = 0;
        }
        r.fill(0);
        decode(
            &v_and_r[param_k * param_v_bytes..],
            &mut r,
            param_k * param_o,
        );
        if sample_solution(
            &mut a_matrix,
            &y,
            &r,
            &mut x,
            param_k,
            param_o,
            param_m,
            param_a_cols,
        ) {
            solved = true;
            break;
        }
    }
    if !solved {
        // No counter produced a solvable system: `x`/`vdec` hold leftovers from the last
        // failed attempt, so never turn them into a signature.
        return Err(Error::Signing);
    }

    // s_i = (v_i + O x_i) || x_i
    let mut ox = Zeroizing::new(vec![0u8; param_v]);
    for i in 0..param_k {
        let vi = &vdec[i * param_v..(i + 1) * param_v];
        let xi = &x[i * param_o..(i + 1) * param_o];
        ox.fill(0);
        mat_mul(&o_mat, xi, &mut ox, param_o, param_v, 1);
        mat_add(vi, &ox, &mut s[i * param_n..], param_v, 1);
        s[i * param_n + param_v..i * param_n + param_n]
            .copy_from_slice(&x[i * param_o..(i + 1) * param_o]);
    }
    encode(&s, sig, param_n * param_k);
    sig[param_sig_bytes - param_salt_bytes..param_sig_bytes].copy_from_slice(&salt);

    // Self-check against the public key derived from `csk`.
    let mut derived_cpk = vec![0u8; P::CPK_BYTES];
    derive_cpk_from_csk::<P>(csk, &mut derived_cpk);
    if mayo_verify::<P>(msg, sig, &derived_cpk).is_err() {
        // Do not leave a signature that fails verification in the caller's buffer.
        sig[..param_sig_bytes].fill(0);
        return Err(Error::Signing);
    }
    Ok(param_sig_bytes)
}