use super::params::{D_MESG, D_PBLC, LmotsType, MAX_P, N};
use crate::hash::{Digest, Sha256};
/// Returns the `i`-th `w`-bit digit of `s`, counting from the most
/// significant bits of `s[0]` (the RFC 8554 `coef` function).
///
/// `w` is expected to divide 8 (1, 2, 4 or 8).
fn coef(s: &[u8], i: usize, w: u32) -> u32 {
    let digits_per_byte = 8 / w as usize;
    let byte = u32::from(s[i * w as usize / 8]);
    // Position of digit `i` inside its byte, counted from the MSB side.
    let slot = (i % digits_per_byte) as u32;
    let shift = 8 - (w * slot + w);
    (byte >> shift) & ((1u32 << w) - 1)
}
/// Computes the LM-OTS checksum of the digest `s`.
///
/// The checksum is the (wrapping, u16) sum of `t.max_digit() - coef(s, i, w)`
/// over the `N * 8 / w` digits of `s`, shifted left by `t.ls()`.
fn cksm(s: &[u8], t: LmotsType) -> u16 {
    let w = t.w();
    let top = t.max_digit();
    let digit_count = N * 8 / w as usize;
    let total = (0..digit_count)
        .map(|i| (top - coef(s, i, w)) as u16)
        .fold(0u16, u16::wrapping_add);
    total << t.ls()
}
/// Returns `Q || Cksm(Q)`: the digest followed by its checksum encoded as
/// two big-endian bytes. The signing and verification chains take their
/// digits from this buffer.
fn q_with_checksum(t: LmotsType, q_digest: &[u8; N]) -> [u8; N + 2] {
    let checksum = cksm(q_digest, t).to_be_bytes();
    let mut joined = [0u8; N + 2];
    let (digest_part, checksum_part) = joined.split_at_mut(N);
    digest_part.copy_from_slice(q_digest);
    checksum_part.copy_from_slice(&checksum);
    joined
}
/// Hashes the message into the digest `Q`:
/// `H(I || u32(q) || u16(D_MESG) || C || message)`.
fn compute_q(i_id: &[u8; 16], q: u32, c: &[u8; N], message: &[u8]) -> [u8; N] {
    let q_bytes = q.to_be_bytes();
    let domain = D_MESG.to_be_bytes();
    let mut hasher = Sha256::new();
    for part in [&i_id[..], &q_bytes[..], &domain[..], &c[..], message] {
        hasher.update(part);
    }
    hasher.finalize()
}
/// Advances one Winternitz chain by a single hash, in place:
/// `tmp = H(I || u32(q) || u16(chain) || u8(j) || tmp)`.
fn chain_step(i_id: &[u8; 16], q: u32, chain: u16, j: u8, tmp: &mut [u8; N]) {
    let q_bytes = q.to_be_bytes();
    let chain_bytes = chain.to_be_bytes();
    let mut hasher = Sha256::new();
    for part in [&i_id[..], &q_bytes[..], &chain_bytes[..], &[j][..], &tmp[..]] {
        hasher.update(part);
    }
    *tmp = hasher.finalize();
}
/// Derives the secret starting value of chain `chain` for leaf `q` from
/// `seed`, writing it to `out`:
/// `H(I || u32(q) || u16(chain) || u8(0xff) || SEED)`.
pub(crate) fn derive_x(i_id: &[u8; 16], seed: &[u8; N], q: u32, chain: u16, out: &mut [u8; N]) {
    let q_bytes = q.to_be_bytes();
    let chain_bytes = chain.to_be_bytes();
    let mut hasher = Sha256::new();
    for part in [&i_id[..], &q_bytes[..], &chain_bytes[..], &[0xffu8][..], &seed[..]] {
        hasher.update(part);
    }
    *out = hasher.finalize();
}
/// Derives the per-signature randomizer `C` deterministically from `seed`:
/// `H(I || u32(q) || u16(0xfffd) || u8(0xff) || SEED || message)`.
///
/// NOTE(review): RFC 8554 Appendix A's suggested derivation of C does not
/// absorb the message. Including it here does not affect interoperability,
/// because verification (`recover_public_key`) reads C from the signature
/// instead of recomputing it.
pub(crate) fn derive_c(i_id: &[u8; 16], seed: &[u8; N], q: u32, message: &[u8]) -> [u8; N] {
    let q_bytes = q.to_be_bytes();
    let tag = 0xfffdu16.to_be_bytes();
    let mut hasher = Sha256::new();
    for part in [&i_id[..], &q_bytes[..], &tag[..], &[0xffu8][..], &seed[..], message] {
        hasher.update(part);
    }
    hasher.finalize()
}
/// Computes the LM-OTS public key hash `K` for leaf `q`.
///
/// On x86_64 builds with the `std` feature, this uses the multi-buffer
/// SHA-256 path when `sha256_mb::supported()` reports it is available.
/// Otherwise it falls back to the scalar implementation. Both paths are
/// expected to produce identical output (see `lmots_x8_tests`).
pub(crate) fn public_key(t: LmotsType, i_id: &[u8; 16], seed: &[u8; N], q: u32) -> [u8; N] {
    #[cfg(all(feature = "std", target_arch = "x86_64"))]
    {
        if crate::hash::sha256_mb::supported() {
            return lmots_x8::public_key_x8(t, i_id, seed, q);
        }
    }
    public_key_scalar(t, i_id, seed, q)
}
/// Scalar computation of the LM-OTS public key hash `K` for leaf `q`.
///
/// Each of the `t.p()` chains starts at its secret value (`derive_x`) and is
/// hashed `t.max_digit()` times. The chain ends are absorbed in order into
/// `H(I || u32(q) || u16(D_PBLC) || ...)`.
pub(crate) fn public_key_scalar(t: LmotsType, i_id: &[u8; 16], seed: &[u8; N], q: u32) -> [u8; N] {
    let chain_count = t.p();
    let steps = t.max_digit();
    let mut outer = Sha256::new();
    outer.update(i_id);
    outer.update(&q.to_be_bytes());
    outer.update(&D_PBLC.to_be_bytes());
    let mut node = [0u8; N];
    for idx in 0..chain_count {
        let chain_id = idx as u16;
        derive_x(i_id, seed, q, chain_id, &mut node);
        (0..steps).for_each(|j| chain_step(i_id, q, chain_id, j as u8, &mut node));
        outer.update(&node);
    }
    outer.finalize()
}
#[cfg(all(feature = "std", target_arch = "x86_64"))]
mod lmots_x8 {
    use super::{D_PBLC, LmotsType, N};
    use crate::hash::sha256::H256;
    use crate::hash::sha256_mb::{LANES, compress8};
    use crate::hash::{Digest, Sha256};

    /// Builds the single, already padded SHA-256 block for the 55-byte
    /// message `I || u32(q) || u16(chain) || u8(byte22) || tail`.
    ///
    /// Both `derive_x` (tag 0xff, tail = SEED) and `chain_step`
    /// (tag j, tail = previous node) hash messages of exactly this shape.
    #[inline]
    fn block55(i_id: &[u8; 16], q: u32, chain: u16, byte22: u8, tail: &[u8; N]) -> [u8; 64] {
        const Q_AT: usize = 16;
        const CHAIN_AT: usize = 20;
        const TAG_AT: usize = 22;
        const TAIL_AT: usize = 23;
        // Message length in bytes; the padding marker goes right after it.
        const PAD_AT: usize = 55;
        const LEN_AT: usize = 56;
        const MSG_BITS: u64 = (PAD_AT as u64) * 8;

        let mut block = [0u8; 64];
        block[..Q_AT].copy_from_slice(i_id);
        block[Q_AT..CHAIN_AT].copy_from_slice(&q.to_be_bytes());
        block[CHAIN_AT..TAG_AT].copy_from_slice(&chain.to_be_bytes());
        block[TAG_AT] = byte22;
        block[TAIL_AT..PAD_AT].copy_from_slice(tail);
        block[PAD_AT] = 0x80;
        block[LEN_AT..].copy_from_slice(&MSG_BITS.to_be_bytes());
        block
    }

    /// Serializes a SHA-256 state (eight words) as a big-endian digest.
    #[inline]
    fn state_be(s: &[u32; 8]) -> [u8; N] {
        let mut digest = [0u8; N];
        for (dst, word) in digest.chunks_exact_mut(4).zip(s.iter()) {
            dst.copy_from_slice(&word.to_be_bytes());
        }
        digest
    }

    /// Compresses one block per lane, each starting from the SHA-256 initial
    /// state `H256`, and returns every lane's digest. `make_block(lane)`
    /// supplies the block for each lane index.
    #[inline]
    fn hash_lanes(mut make_block: impl FnMut(usize) -> [u8; 64]) -> [[u8; N]; LANES] {
        let mut blocks = [[0u8; 64]; LANES];
        for (lane, blk) in blocks.iter_mut().enumerate() {
            *blk = make_block(lane);
        }
        let mut states = [H256; LANES];
        compress8(&mut states, &blocks);
        let mut digests = [[0u8; N]; LANES];
        for (digest, state) in digests.iter_mut().zip(states.iter()) {
            *digest = state_be(state);
        }
        digests
    }

    /// Multi-buffer equivalent of `public_key_scalar`: processes up to
    /// `LANES` chains per batch and absorbs the chain ends into `K` in the
    /// same order.
    pub(super) fn public_key_x8(t: LmotsType, i_id: &[u8; 16], seed: &[u8; N], q: u32) -> [u8; N] {
        let total = t.p();
        let steps = t.max_digit();
        let mut k_hash = Sha256::new();
        k_hash.update(i_id);
        k_hash.update(&q.to_be_bytes());
        k_hash.update(&D_PBLC.to_be_bytes());

        let mut base = 0usize;
        while base < total {
            let live = LANES.min(total - base);
            // Lanes past the end of the chain range redo chain `base`; their
            // results are discarded below.
            let lane_chain = |lane: usize| (if lane < live { base + lane } else { base }) as u16;

            let mut nodes = hash_lanes(|lane| block55(i_id, q, lane_chain(lane), 0xff, seed));
            for j in 0..steps {
                nodes = hash_lanes(|lane| block55(i_id, q, lane_chain(lane), j as u8, &nodes[lane]));
            }
            for node in nodes.iter().take(live) {
                k_hash.update(node);
            }
            base += live;
        }
        k_hash.finalize()
    }
}
#[cfg(all(test, feature = "std", target_arch = "x86_64"))]
mod lmots_x8_tests {
    use super::{LmotsType, N, lmots_x8, public_key_scalar};

    /// One xorshift64 step (shifts 13/7/17). It is a deterministic input
    /// generator, not a cryptographic RNG.
    fn xorshift(state: &mut u64) -> u64 {
        *state ^= *state << 13;
        *state ^= *state >> 7;
        *state ^= *state << 17;
        *state
    }

    /// The batched public-key path must agree with the scalar one for
    /// every LM-OTS parameter set.
    #[test]
    fn batched_matches_scalar() {
        if !crate::hash::sha256_mb::supported() {
            return;
        }
        let mut rng = 0x1234_5678_9abc_def0u64;
        let all_types = [
            LmotsType::Sha256N32W8,
            LmotsType::Sha256N32W4,
            LmotsType::Sha256N32W2,
            LmotsType::Sha256N32W1,
        ];
        for t in all_types {
            for _ in 0..8 {
                let mut i_id = [0u8; 16];
                i_id.iter_mut().for_each(|b| *b = (xorshift(&mut rng) >> 24) as u8);
                let mut seed = [0u8; N];
                seed.iter_mut().for_each(|b| *b = (xorshift(&mut rng) >> 24) as u8);
                let q = xorshift(&mut rng) as u32;
                let want = public_key_scalar(t, &i_id, &seed, q);
                let got = lmots_x8::public_key_x8(t, &i_id, &seed, q);
                assert_eq!(got, want, "type {t:?}, q={q}");
            }
        }
    }
}
/// Writes an LM-OTS signature for `message` under leaf `q` into `out`.
///
/// Layout: `u32(typecode) || C || y[0] || ... || y[p-1]`. Each `y[i]` is
/// chain `i` advanced `coef(Q || Cksm(Q), i, w)` times from its secret
/// start value.
///
/// # Panics
///
/// Panics if `out` is shorter than `4 + N * (t.p() + 1)` bytes. The check
/// runs before any hashing or writing, so a short buffer is never left
/// partially filled.
pub(crate) fn sign(
    t: LmotsType,
    i_id: &[u8; 16],
    seed: &[u8; N],
    q: u32,
    c: &[u8; N],
    message: &[u8],
    out: &mut [u8],
) {
    let p = t.p();
    // typecode + C + p chain values: exactly the bytes written below.
    let needed = 4 + N * (p + 1);
    assert!(
        out.len() >= needed,
        "LM-OTS signature buffer too small: need {needed} bytes, got {}",
        out.len()
    );

    let q_digest = compute_q(i_id, q, c, message);
    let qc = q_with_checksum(t, &q_digest);
    let w = t.w();

    let (header, ys) = out.split_at_mut(4 + N);
    header[..4].copy_from_slice(&t.typecode().to_be_bytes());
    header[4..].copy_from_slice(c);

    let mut tmp = [0u8; N];
    for (chain, y) in ys.chunks_exact_mut(N).take(p).enumerate() {
        let a = coef(&qc, chain, w);
        derive_x(i_id, seed, q, chain as u16, &mut tmp);
        for j in 0..a {
            chain_step(i_id, q, chain as u16, j as u8, &mut tmp);
        }
        y.copy_from_slice(&tmp);
    }
}
/// Recomputes the candidate public key hash `Kc` from an LM-OTS signature.
///
/// Returns `None` in three cases: `sig` is shorter than 4 bytes, its
/// typecode differs from `pubtype`, or its length is not
/// `pubtype.sig_len()`. Otherwise each chain value `y[i]` is advanced from
/// digit `coef(Q || Cksm(Q), i, w)` up to `max_digit`. The results are
/// absorbed into `H(I || u32(q) || u16(D_PBLC) || ...)`. The caller
/// compares that hash against the stored key.
pub(crate) fn recover_public_key(
    pubtype: LmotsType,
    i_id: &[u8; 16],
    q: u32,
    message: &[u8],
    sig: &[u8],
) -> Option<[u8; N]> {
    let sigtype = match sig {
        [b0, b1, b2, b3, ..] => u32::from_be_bytes([*b0, *b1, *b2, *b3]),
        _ => return None,
    };
    if sigtype != pubtype.typecode() || sig.len() != pubtype.sig_len() {
        return None;
    }

    let p = pubtype.p();
    let w = pubtype.w();
    let top = pubtype.max_digit();

    let mut c = [0u8; N];
    c.copy_from_slice(&sig[4..4 + N]);
    let qc = q_with_checksum(pubtype, &compute_q(i_id, q, &c, message));

    let mut kc = Sha256::new();
    kc.update(i_id);
    kc.update(&q.to_be_bytes());
    kc.update(&D_PBLC.to_be_bytes());
    debug_assert!(p <= MAX_P);

    // The p chain values follow the typecode and C.
    let ys = &sig[4 + N..4 + N + p * N];
    let mut node = [0u8; N];
    for (chain, y) in ys.chunks_exact(N).enumerate() {
        node.copy_from_slice(y);
        for j in coef(&qc, chain, w)..top {
            chain_step(i_id, q, chain as u16, j as u8, &mut node);
        }
        kc.update(&node);
    }
    Some(kc.finalize())
}