use crate::hashers::Hashers;
use crate::helpers::{base_2b, to_byte};
use crate::types::{Adrs, WotsPk, WotsSig, WOTS_PK, WOTS_PRF};
/// Advances a WOTS+ hash chain by `s` steps, starting at position `i`, from the value `cap_x`.
///
/// Step `j` (for `j` in `i..i + s`) sets the hash address of a private copy of `adrs` to `j` and
/// replaces the running value with `F(pk_seed, adrs, value)`. With `s == 0` the input is
/// returned unchanged. The caller's `adrs` is never modified.
pub(crate) fn chain<const K: usize, const LEN: usize, const M: usize, const N: usize>(
    hashers: &Hashers<K, LEN, M, N>, cap_x: [u8; N], i: u32, s: u32, pk_seed: &[u8], adrs: &Adrs,
) -> [u8; N] {
    debug_assert!(i + s < u32::MAX);
    // Work on a copy so that the hash-address updates stay local to this call.
    let mut step_adrs = adrs.clone();
    (i..(i + s)).fold(cap_x, |value, hash_addr| {
        step_adrs.set_hash_address(hash_addr);
        (hashers.f)(pk_seed, &step_adrs, &value)
    })
}
/// Derives a WOTS+ public key for the key pair addressed by `adrs`.
///
/// Each of the `LEN` chains starts from `PRF(pk_seed, sk_seed, prf_adrs)`, where `prf_adrs` is a
/// `WOTS_PRF` address with the chain index set. The chain is advanced the full `W - 1` steps. The
/// resulting chain ends are compressed with `t_l` under a `WOTS_PK` address that carries the same
/// key-pair address.
///
/// # Panics
/// Panics if `LEN` does not fit in a `u32`.
// `sk_seed` / `pk_seed` intentionally mirror each other's names, hence the lint allowance.
#[allow(clippy::similar_names)]
pub(crate) fn wots_pkgen<const K: usize, const LEN: usize, const M: usize, const N: usize>(
    hashers: &Hashers<K, LEN, M, N>, sk_seed: &[u8], pk_seed: &[u8], adrs: &Adrs,
) -> WotsPk<N> {
    let chain_count = u32::try_from(LEN).unwrap();
    let mut chain_adrs = adrs.clone();

    // Address used to derive each chain's secret starting value.
    let mut prf_adrs = chain_adrs.clone();
    prf_adrs.set_type_and_clear(WOTS_PRF);
    prf_adrs.set_key_pair_address(chain_adrs.get_key_pair_address());

    let mut ends = [[0u8; N]; LEN];
    for (end, chain_idx) in ends.iter_mut().zip(0..chain_count) {
        prf_adrs.set_chain_address(chain_idx);
        let start = (hashers.prf)(pk_seed, sk_seed, &prf_adrs);
        chain_adrs.set_chain_address(chain_idx);
        *end = chain(hashers, start, 0, crate::W - 1, pk_seed, &chain_adrs);
    }

    // Compress all chain ends into the single public-key value.
    let mut pk_adrs = chain_adrs.clone();
    pk_adrs.set_type_and_clear(WOTS_PK);
    pk_adrs.set_key_pair_address(chain_adrs.get_key_pair_address());
    WotsPk((hashers.t_l)(pk_seed, &pk_adrs, &ends))
}
/// Expands the `n`-byte message `m` into the `LEN` base-`w` digits that decide how far each
/// WOTS+ chain is advanced.
///
/// The output contains `len1 = 2 * N` message digits followed by `crate::LEN2` checksum digits.
/// Both `wots_sign` and `wots_pk_from_sig` use this helper, so signing and verification are
/// guaranteed to derive the identical digit vector.
///
/// `len1 = 2 * N` equals `8 * N / lg_w` only when `crate::LGW == 4`. NOTE(review): this relies
/// on that constant and should be revisited if `LGW` ever changes.
///
/// # Panics
/// Panics if `N` does not fit in a `u32`, or if `LEN < 2 * N`.
fn wots_msg_and_csum_digits<const LEN: usize, const N: usize>(m: &[u8]) -> [u32; LEN] {
    let n32 = u32::try_from(N).unwrap();
    let len1 = 2 * N;
    let mut msg = [0u32; LEN];
    base_2b(m, crate::LGW, 2 * n32, &mut msg[..len1]);

    // Each message digit d contributes (W - 1 - d) to the checksum.
    let mut csum = 0_u32;
    for digit in &msg[..len1] {
        csum += crate::W - 1 - *digit;
    }

    // Left-align the LEN2 * LGW checksum bits within whole bytes before encoding, so that
    // base_2b reads them from the start of the `to_byte` output.
    csum <<= (8 - ((crate::LEN2 * crate::LGW) & 0x07)) & 0x07;
    base_2b(
        &to_byte(csum, (crate::LEN2 * crate::LGW + 7) / 8),
        crate::LGW,
        crate::LEN2,
        &mut msg[len1..],
    );
    msg
}

/// Produces a WOTS+ signature on the `n`-byte message `m` for the key pair addressed by `adrs`.
///
/// Chain `i` starts from `PRF(pk_seed, sk_seed, sk_adrs)`, where `sk_adrs` is a `WOTS_PRF`
/// address with chain index `i`. It is advanced `msg[i]` steps, where `msg` is the
/// message-plus-checksum digit vector from `wots_msg_and_csum_digits`.
///
/// # Panics
/// Panics if `N` does not fit in a `u32`, or if `LEN < 2 * N`.
// `sk_seed` / `pk_seed` intentionally mirror each other's names, hence the lint allowance.
#[allow(clippy::similar_names)]
pub(crate) fn wots_sign<const K: usize, const LEN: usize, const M: usize, const N: usize>(
    hashers: &Hashers<K, LEN, M, N>, m: &[u8], sk_seed: &[u8], pk_seed: &[u8], adrs: &Adrs,
) -> WotsSig<LEN, N> {
    let msg = wots_msg_and_csum_digits::<LEN, N>(m);
    let mut adrs = adrs.clone();
    let mut sig: WotsSig<LEN, N> = WotsSig { data: [[0u8; N]; LEN] };

    // Address used to derive each chain's secret starting value.
    let mut sk_adrs = adrs.clone();
    sk_adrs.set_type_and_clear(WOTS_PRF);
    sk_adrs.set_key_pair_address(adrs.get_key_pair_address());

    // Zipping with `0u32..` yields the chain address without a `usize -> u32` cast.
    for (idx, chain_addr) in (0..LEN).zip(0u32..) {
        sk_adrs.set_chain_address(chain_addr);
        let sk = (hashers.prf)(pk_seed, sk_seed, &sk_adrs);
        adrs.set_chain_address(chain_addr);
        sig.data[idx] = chain(hashers, sk, 0, msg[idx], pk_seed, &adrs);
    }
    sig
}

/// Recomputes the WOTS+ public key implied by signature `sig` on the `n`-byte message `m`.
///
/// Signature element `i` sits at position `msg[i]` of its chain. It is advanced the remaining
/// `W - 1 - msg[i]` steps to reach the chain end. The chain ends are then compressed with `t_l`
/// under a `WOTS_PK` address, matching `wots_pkgen`. A valid signature therefore reproduces the
/// key that `wots_pkgen` derives.
///
/// # Panics
/// Panics if `N` does not fit in a `u32`, or if `LEN < 2 * N`.
pub(crate) fn wots_pk_from_sig<const K: usize, const LEN: usize, const M: usize, const N: usize>(
    hashers: &Hashers<K, LEN, M, N>, sig: &WotsSig<LEN, N>, m: &[u8], pk_seed: &[u8], adrs: &Adrs,
) -> WotsPk<N> {
    let msg = wots_msg_and_csum_digits::<LEN, N>(m);
    let mut adrs = adrs.clone();
    let mut tmp = [[0u8; N]; LEN];

    // Zipping with `0u32..` yields the chain address without a truncating `as u32` cast.
    for (idx, chain_addr) in (0..LEN).zip(0u32..) {
        adrs.set_chain_address(chain_addr);
        tmp[idx] = chain(hashers, sig.data[idx], msg[idx], crate::W - 1 - msg[idx], pk_seed, &adrs);
    }

    let mut wotspk_adrs = adrs.clone();
    wotspk_adrs.set_type_and_clear(WOTS_PK);
    wotspk_adrs.set_key_pair_address(adrs.get_key_pair_address());
    let pk = (hashers.t_l)(pk_seed, &wotspk_adrs, &tmp);
    WotsPk(pk)
}