use hybrid_array::{Array, ArraySize};
use typenum::Unsigned;
use typenum::generic_const_mappings::U;
use crate::hashes::HashSuite;
use crate::util::base_2b;
use crate::{PkSeed, SkSeed, address};
use core::fmt::Debug;
/// Number of bits encoded by one WOTS+ digit (`lg_w`).
const LOG_W: usize = 4;
/// Winternitz parameter: the number of distinct digit values, `2^LOG_W`.
/// Derived from `LOG_W` so the two constants cannot drift apart.
const W: u32 = 1 << LOG_W;
/// Number of base-`W` digits used to encode the WOTS+ checksum (`len2`).
/// NOTE(review): hard-coded; 3 digits suffice as long as the maximum checksum
/// `message_digits * (W - 1)` stays below `W^3 = 4096`.
const CK_LEN: usize = 3;
/// A WOTS+ signature: one `N`-byte chain value per signature chain
/// (`WotsSigLen` of them in total).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WotsSig<P>(Array<Array<u8, P::N>, P::WotsSigLen>)
where
    P: WotsParams;
impl<P: WotsParams> WotsSig<P> {
    /// Serialized length in bytes: `WotsSigLen` chain values of `N` bytes each.
    pub const SIZE: usize = P::N::USIZE * P::WotsSigLen::USIZE;

    /// Writes the signature into `buf` as the concatenation of its chain values,
    /// in chain order.
    ///
    /// # Panics
    ///
    /// Panics if `buf.len() != Self::SIZE`. This is checked in release builds
    /// too: with a mismatched buffer the chunked copy below would otherwise
    /// silently write only part of the signature (buffer too short) or leave
    /// trailing bytes untouched (buffer too long).
    pub fn write_to(&self, buf: &mut [u8]) {
        assert_eq!(buf.len(), Self::SIZE, "WOTS+ serialize length mismatch");
        for (dst, src) in buf.chunks_exact_mut(P::N::USIZE).zip(self.0.iter()) {
            dst.copy_from_slice(src.as_slice());
        }
    }

    /// Returns the serialized signature as a freshly allocated vector
    /// (test-only convenience around [`WotsSig::write_to`]).
    #[cfg(feature = "alloc")]
    #[cfg(test)]
    pub fn to_vec(&self) -> Vec<u8> {
        let mut vec = vec![0u8; Self::SIZE];
        self.write_to(&mut vec);
        vec
    }
}
/// Parses a signature previously produced by [`WotsSig::write_to`].
impl<P: WotsParams> TryFrom<&[u8]> for WotsSig<P> {
    type Error = ();

    /// Splits `value` into consecutive `N`-byte chain values.
    ///
    /// Returns `Err(())` unless `value` is exactly `WotsSig::<P>::SIZE` bytes long.
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        if value.len() == Self::SIZE {
            let mut chains = Array::<Array<u8, P::N>, P::WotsSigLen>::default();
            // The length check above guarantees exactly `WotsSigLen` full chunks.
            chains
                .iter_mut()
                .zip(value.chunks_exact(P::N::USIZE))
                .for_each(|(chain, bytes)| chain.copy_from_slice(bytes));
            Ok(WotsSig(chains))
        } else {
            Err(())
        }
    }
}
pub(crate) trait WotsParams: HashSuite {
type WotsMsgLen: ArraySize; type WotsSigLen: ArraySize + Debug + Eq;
fn wots_chain(
x: &Array<u8, Self::N>,
i: u32,
s: u32,
pk_seed: &PkSeed<Self::N>,
adrs: &address::WotsHash,
) -> Array<u8, Self::N> {
debug_assert!(i + s < 1 << LOG_W, "Invalid wots_chain index");
let mut tmp = x.clone(); let mut adrs = adrs.clone(); for j in i..(i + s) {
adrs.hash_adrs.set(j);
tmp = Self::f(pk_seed, &adrs, &tmp); }
tmp
}
fn wots_pk_gen(
sk_seed: &SkSeed<Self::N>,
pk_seed: &PkSeed<Self::N>,
adrs: &address::WotsHash,
) -> Array<u8, Self::N> {
let mut adrs = adrs.clone();
let mut sk_adrs = adrs.prf_adrs();
let tmp = Array::<Array<u8, Self::N>, Self::WotsSigLen>::from_fn(|i: usize| {
let i: u32 = i.try_into().expect("i is less than 2^32");
sk_adrs.chain_adrs.set(i);
adrs.chain_adrs.set(i);
let sk = Self::prf_sk(pk_seed, sk_seed, &sk_adrs);
Self::wots_chain(&sk, 0, (1 << LOG_W) - 1, pk_seed, &adrs)
});
let pk_adrs = adrs.pk_adrs();
Self::t(pk_seed, &pk_adrs, &tmp)
}
fn wots_sign(
m: &Array<u8, Self::N>,
sk_seed: &SkSeed<Self::N>,
pk_seed: &PkSeed<Self::N>,
adrs: &address::WotsHash,
) -> WotsSig<Self> {
let msg = base_2b::<Self::WotsMsgLen, U<LOG_W>>(m.as_slice());
let csum = msg.iter().map(|&x| (1 << LOG_W) - 1 - x).sum::<u16>() << 4;
let csum_bytes = csum.to_be_bytes();
let csum_chunks = base_2b::<U<CK_LEN>, U<LOG_W>>(&csum_bytes);
let mut msg_csum = msg.iter().chain(csum_chunks.iter());
let mut adrs = adrs.clone();
let mut sk_adrs = adrs.prf_adrs();
let sig = Array::<Array<u8, Self::N>, Self::WotsSigLen>::from_fn(|i: usize| {
let i: u32 = i.try_into().expect("i is less than 2^32");
sk_adrs.chain_adrs.set(i);
adrs.chain_adrs.set(i);
let sk = Self::prf_sk(pk_seed, sk_seed, &sk_adrs);
Self::wots_chain(&sk, 0, u32::from(*msg_csum.next().unwrap()), pk_seed, &adrs)
});
WotsSig(sig)
}
fn wots_pk_from_sig(
sig: &WotsSig<Self>,
m: &Array<u8, Self::N>,
pk_seed: &PkSeed<Self::N>,
adrs: &address::WotsHash,
) -> Array<u8, Self::N> {
let msg = base_2b::<Self::WotsMsgLen, U<LOG_W>>(m.as_slice());
let csum = msg.iter().map(|&x| (1 << LOG_W) - 1 - x).sum::<u16>() << 4; let csum_bytes = csum.to_be_bytes();
let csum_chunks = base_2b::<U<CK_LEN>, U<LOG_W>>(&csum_bytes);
let mut msg_csum = msg.iter().chain(csum_chunks.iter());
let mut adrs = adrs.clone();
let tmp = Array::<Array<u8, Self::N>, Self::WotsSigLen>::from_fn(|i: usize| {
adrs.chain_adrs
.set(i.try_into().expect("i is less than 2^32"));
let msg_i = u32::from(*msg_csum.next().unwrap());
Self::wots_chain(&sig.0[i], msg_i, W - 1 - msg_i, pk_seed, &adrs)
});
Self::t(pk_seed, &adrs.pk_adrs(), &tmp)
}
}
#[cfg(test)]
mod tests {
    use crate::{PkSeed, SkSeed, util::macros::test_parameter_sets};
    use hex_literal::hex;
    use hybrid_array::Array;
    use rand::RngCore;

    use crate::{address::WotsHash, hashes::Shake128f};

    use super::WotsParams;

    /// A signature over a random message must reproduce the generated public key.
    fn test_sign_verify<Wots: WotsParams>() {
        let mut rng = rand::rngs::OsRng;
        let sk_seed = SkSeed::new(&mut rng);
        let pk_seed = PkSeed::new(&mut rng);
        let mut msg = Array::<u8, _>::default();
        rng.fill_bytes(msg.as_mut_slice());
        let adrs = &WotsHash::default();
        let pk = Wots::wots_pk_gen(&sk_seed, &pk_seed, adrs);
        let sig = Wots::wots_sign(&msg, &sk_seed, &pk_seed, adrs);
        let pk_recovered = Wots::wots_pk_from_sig(&sig, &msg, &pk_seed, adrs);
        assert_eq!(pk, pk_recovered);
    }
    test_parameter_sets!(test_sign_verify);

    /// Recovering a public key against a modified message must not match.
    fn test_sign_verify_fail<Wots: WotsParams>() {
        let mut rng = rand::rngs::OsRng;
        let sk_seed = SkSeed::new(&mut rng);
        let pk_seed = PkSeed::new(&mut rng);
        let mut msg = Array::<u8, _>::default();
        rng.fill_bytes(msg.as_mut_slice());
        let adrs = &WotsHash::default();
        let pk = Wots::wots_pk_gen(&sk_seed, &pk_seed, adrs);
        let sig = Wots::wots_sign(&msg, &sk_seed, &pk_seed, adrs);
        msg[0] ^= 0xff;
        let pk_recovered = Wots::wots_pk_from_sig(&sig, &msg, &pk_seed, adrs);
        assert_ne!(
            pk, pk_recovered,
            "Signature verification should fail with a modified message"
        );
    }
    test_parameter_sets!(test_sign_verify_fail);

    /// Known-answer test for public key generation with SHAKE-128f parameters.
    #[test]
    fn test_pk_gen_shake128f_kat() {
        let sk_seed = SkSeed(Array([1; 16]));
        let pk_seed = PkSeed(Array([2; 16]));
        let adrs = WotsHash::default();
        let expected = Array(hex!("98b63dd1574484876b1f8a1120421eac"));
        let result = Shake128f::wots_pk_gen(&sk_seed, &pk_seed, &adrs);
        assert_eq!(result, expected);
    }

    /// Known-answer test for signing with SHAKE-128f parameters.
    #[test]
    #[cfg(feature = "alloc")]
    fn test_sign_shake128f_kat() {
        let sk_seed = SkSeed(Array([1; 16]));
        let pk_seed = PkSeed(Array([2; 16]));
        let adrs = &WotsHash::default();
        let msg = Array([3; 16]);
        let expected = &hex!(
            "f7bcb9575590faae2e6a8ae33149082d2ec777cff4051f43177ef44bcbd2c18d
a94146c50037c914461dd6ed720192b059bd2be6ed8d8cf26e4e9d68fbf9ded1
6c334bed21677c6a3679f17a8425de40431b4317326c5d825d931b4a54a1b81f
e7ad259086ea665109a7eca79f03e3619d99af5d0419fece8300973f29467f28
d2b18639eeaa826488f6c785d492703463e80f8b088e64de9ca3b373cead611f
d356bf6c22f70f98f229174a9ac815342f0439eb289a78f49f47aa8c3f272a15
f5f0f5020b5d71981254daa9e1f01a90248935c1c67ad1cf71d9224184820cf9
ece9b737ec986c86ba0a9431ff8485c274140bebc9d856316d49128eb075f81a
c00d32b9f949940f2dd684a2e615e16b47093eb49e3bc9d77e69c7944d7063c6
f8b4b5aa46fe759999fa2892ce4c7881b80f38d684427a0b77f3ad43377833d2
d94c600b340ea408a0ad7c32c409bdb4ebaade3b1dda4ac8584acba979c845a9
b0ddfc69ea22ffb415745b779b45d7af00ca9fde87e5d59385d7b5cedec6e30f
3346f573f59a00af993a2ec314ed951e3a8c00f69364a82fa34d14933fe3cdb7
bd5e5d511297695bad5cda22daea8d39f61d4ed34412acd1f5399a54953ae04b
09828f90877ad7f01605631ace0a4e7c773cc887e2d0fa0bd3d6db811794df3a
a8721c308482ccb511c9133311653ce8f9c2336e2980c2ab554c41bad436c0c7
1c394d3f7eafcea2806c153113d6291a912c0e73e44197763b9ead341c298585
bc6e16d8458fc1917ff4ac57de461ee1"
        );
        let result = Shake128f::wots_sign(&msg, &sk_seed, &pk_seed, adrs);
        assert_eq!(result.to_vec(), expected.as_slice());
    }

    /// Serializing a signature and parsing it back must yield the same bytes.
    /// Inputs that are one byte too short or too long must be rejected.
    #[test]
    #[cfg(feature = "alloc")]
    fn test_sig_serialize_roundtrip_shake128f() {
        let sk_seed = SkSeed(Array([1; 16]));
        let pk_seed = PkSeed(Array([2; 16]));
        let adrs = &WotsHash::default();
        let msg = Array([3; 16]);
        let sig = Shake128f::wots_sign(&msg, &sk_seed, &pk_seed, adrs);
        let bytes = sig.to_vec();

        let parsed = super::WotsSig::<Shake128f>::try_from(bytes.as_slice())
            .expect("a buffer of exactly SIZE bytes must parse");
        assert_eq!(parsed.to_vec(), bytes);

        assert!(super::WotsSig::<Shake128f>::try_from(&bytes[1..]).is_err());
        let mut too_long = bytes.clone();
        too_long.push(0);
        assert!(super::WotsSig::<Shake128f>::try_from(too_long.as_slice()).is_err());
    }
}