use hybrid_array::{Array, ArraySize};
use typenum::Unsigned;
use crate::wots::WotsSig;
use crate::{PkSeed, SkSeed};
use crate::{address, wots::WotsParams};
use core::fmt::Debug;
/// An XMSS signature: a WOTS signature plus the sibling nodes needed to walk
/// from the signed leaf back up to the root of the tree.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct XmssSig<P>
where
    P: XmssParams,
{
    /// WOTS signature produced for the selected leaf.
    pub(crate) sig: WotsSig<P>,
    /// Authentication path: `HPrime` nodes of `N` bytes each, stored from
    /// height 0 upwards.
    pub(crate) auth: Array<Array<u8, P::N>, P::HPrime>,
}
impl<P> XmssSig<P>
where
    P: XmssParams,
{
    /// Length in bytes of the serialized signature: the WOTS signature
    /// followed by `HPrime` authentication nodes of `N` bytes each.
    pub const SIZE: usize = WotsSig::<P>::SIZE + P::N::USIZE * P::HPrime::USIZE;

    /// Serializes the signature into `buf`.
    ///
    /// The WOTS signature occupies the first `WotsSig::<P>::SIZE` bytes and
    /// the authentication nodes follow back to back. `buf` is expected to be
    /// exactly [`Self::SIZE`] bytes long; this is only checked in debug builds.
    pub fn write_to(&self, buf: &mut [u8]) {
        debug_assert!(buf.len() == Self::SIZE, "Xmss serialize length mismatch");

        let (wots_bytes, auth_bytes) = buf.split_at_mut(WotsSig::<P>::SIZE);
        self.sig.write_to(wots_bytes);

        // One `N`-byte output chunk per authentication node, in order.
        let chunks = auth_bytes.chunks_exact_mut(P::N::USIZE);
        for (dst, node) in chunks.zip(self.auth.iter()) {
            dst.copy_from_slice(node.as_slice());
        }
    }

    /// Test-only convenience: serializes into a freshly allocated vector of
    /// [`Self::SIZE`] bytes.
    #[cfg(all(test, feature = "alloc"))]
    pub fn to_vec(&self) -> Vec<u8> {
        let mut bytes = vec![0u8; Self::SIZE];
        self.write_to(bytes.as_mut_slice());
        bytes
    }
}
impl<P> TryFrom<&[u8]> for XmssSig<P>
where
    P: XmssParams,
{
    type Error = ();

    /// Parses a signature from exactly [`XmssSig::SIZE`] bytes.
    ///
    /// Returns `Err(())` when the length differs or when the leading WOTS
    /// signature fails to parse.
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        if value.len() != Self::SIZE {
            return Err(());
        }

        // Layout: WOTS signature first, then the authentication nodes.
        let (wots_bytes, auth_bytes) = value.split_at(WotsSig::<P>::SIZE);
        let sig = WotsSig::<P>::try_from(wots_bytes)?;

        let mut auth = Array::<Array<u8, P::N>, P::HPrime>::default();
        // The length check above guarantees exactly `HPrime` chunks of `N` bytes.
        let chunks = auth_bytes.chunks_exact(P::N::USIZE);
        for (node, chunk) in auth.iter_mut().zip(chunks) {
            node.copy_from_slice(chunk);
        }

        Ok(Self { sig, auth })
    }
}
/// XMSS operations layered on top of a WOTS parameter set.
pub(crate) trait XmssParams: WotsParams + Sized {
    /// Height of the XMSS tree, which is also the number of nodes in an
    /// authentication path.
    type HPrime: ArraySize + Debug + Eq;

    /// Computes the tree node with index `node` at the given `height`.
    ///
    /// A node at height 0 is the WOTS public key for key pair `node`. Any
    /// higher node is the hash (`Self::h`) of its two children under a tree
    /// address carrying this node's height and index.
    ///
    /// Debug builds assert that `height <= HPrime` and that `node` is a valid
    /// index on that level.
    fn xmss_node(
        sk_seed: &SkSeed<Self::N>,
        node: u32,
        height: u32,
        pk_seed: &PkSeed<Self::N>,
        adrs: &address::WotsHash,
    ) -> Array<u8, Self::N> {
        debug_assert!(height <= Self::HPrime::U32);
        debug_assert!(node < (1 << (Self::HPrime::U32 - height)));

        // Leaf: derive the WOTS public key for this key pair.
        if height == 0 {
            let mut leaf_adrs = adrs.clone();
            leaf_adrs.key_pair_adrs.set(node);
            return Self::wots_pk_gen(sk_seed, pk_seed, &leaf_adrs);
        }

        // Inner node: recurse into both children, then hash them together.
        let child_height = height - 1;
        let left = Self::xmss_node(sk_seed, 2 * node, child_height, pk_seed, adrs);
        let right = Self::xmss_node(sk_seed, 2 * node + 1, child_height, pk_seed, adrs);

        let mut tree_adrs = adrs.tree_adrs();
        tree_adrs.tree_height.set(height);
        tree_adrs.tree_index.set(node);
        Self::h(pk_seed, &tree_adrs, &left, &right)
    }

    /// Signs `m` with the WOTS key pair at leaf `idx` and collects the
    /// authentication path for that leaf.
    ///
    /// Entry `j` of the path is the sibling, at height `j`, of the ancestor
    /// of leaf `idx` on that level.
    fn xmss_sign(
        m: &Array<u8, Self::N>,
        sk_seed: &SkSeed<Self::N>,
        pk_seed: &PkSeed<Self::N>,
        idx: u32,
        adrs: &address::WotsHash,
    ) -> XmssSig<Self> {
        let mut leaf_adrs = adrs.clone();
        leaf_adrs.key_pair_adrs.set(idx);
        let sig = Self::wots_sign(m, sk_seed, pk_seed, &leaf_adrs);

        let mut auth = Array::<Array<u8, Self::N>, Self::HPrime>::default();
        for (level, slot) in (0..Self::HPrime::U32).zip(auth.iter_mut()) {
            // Ancestor index on this level with its lowest bit flipped gives
            // the sibling node.
            let sibling = (idx >> level) ^ 1;
            *slot = Self::xmss_node(sk_seed, sibling, level, pk_seed, &leaf_adrs);
        }

        XmssSig { sig, auth }
    }

    /// Recomputes a candidate tree root from a signature.
    ///
    /// The WOTS public key is recovered from `sig.sig` and `m`, then hashed
    /// upwards with each authentication node in turn. The caller compares the
    /// result against the expected root.
    fn xmss_pk_from_sig(
        idx: u32,
        sig: &XmssSig<Self>,
        m: &Array<u8, Self::N>,
        pk_seed: &PkSeed<Self::N>,
        adrs: &address::WotsHash,
    ) -> Array<u8, Self::N> {
        let mut leaf_adrs = adrs.clone();
        leaf_adrs.key_pair_adrs.set(idx);
        let mut node = Self::wots_pk_from_sig(&sig.sig, m, pk_seed, &leaf_adrs);

        let mut tree_adrs = leaf_adrs.tree_adrs();
        let mut path = idx;
        for (level, sibling) in (1..=Self::HPrime::U32).zip(sig.auth.iter()) {
            // The low bit says whether the current node is a left (0) or
            // right (1) child; the remaining bits are the parent's index.
            let node_is_left = path & 1 == 0;
            path >>= 1;

            tree_adrs.tree_height.set(level);
            tree_adrs.tree_index.set(path);

            let (left, right) = if node_is_left {
                (&node, sibling)
            } else {
                (sibling, &node)
            };
            node = Self::h(pk_seed, &tree_adrs, left, right);
        }
        node
    }
}
#[cfg(test)]
mod tests {
    use crate::PkSeed;
    use crate::SkSeed;
    use crate::util::macros::test_parameter_sets;
    use hex_literal::hex;
    use hybrid_array::Array;
    use rand::Rng;
    use rand::RngCore;
    use typenum::Unsigned;

    use crate::{address::WotsHash, hashes::Shake128f, xmss::XmssParams};

    /// Known-answer test: root of the Shake128f XMSS tree for fixed seeds.
    #[test]
    fn test_xmss_node_shake128f_kat() {
        let sk_seed = SkSeed(Array([1; 16]));
        let pk_seed = PkSeed(Array([2; 16]));
        let adrs = WotsHash::default();

        // Node 0 at the full tree height is the root.
        let root_height = <Shake128f as XmssParams>::HPrime::U32;
        let root = Shake128f::xmss_node(&sk_seed, 0, root_height, &pk_seed, &adrs);

        let expected = hex!("94e24679fb2460b97332db131c38bec9");
        assert_eq!(root.as_slice(), expected);
    }

    /// Known-answer test: serialized Shake128f XMSS signature at leaf 3 for
    /// fixed seeds and message.
    #[test]
    #[cfg(feature = "alloc")]
    fn test_sign_shake128f_kat() {
        let sk_seed = SkSeed(Array([1; 16]));
        let pk_seed = PkSeed(Array([2; 16]));
        let adrs = WotsHash::default();
        let m = Array([3; 16]);

        let sig = Shake128f::xmss_sign(&m, &sk_seed, &pk_seed, 3, &adrs);

        let expected = hex!(
            "
            a77a0b07e558b023f653a954d886ac66ded67b313f9db7fd93da00686be66a3f
            2e2d3e841292bf5a4060d88509e9a2a51e0bbae6835482bceabce76c5653546d
            08c2f5f78e7491f755f35380d965598891131bdd4c57df2397eed8062a1038fb
            10c758bb30c6ea3859db4eb6296269d170d86cc67804dc63a61e5f30af709aad
            2407624eb81549e87c326c2a646c2b995dfad81cc007286b6f50b56f61352fa2
            752a30aa4f63cc367a7a1c57140a086cc43387ce5f530d84538d0c503d051be2
            9c0040486c2953d34e3817bfcb6f198e545476ddd93930af48333b4e7e0eba03
            3bdbc1badca23875d2f4345699075558a68c8f53865c0b2151208a7a5a4b0c7d
            270b71d5688c6d727525e3fd9c75b9656e13394777faee925fe8cda6e2b7c52a
            684f218679a48b942127f89ffaa069db21659a09266e9304ce870c16094bf585
            6ed93c0748b9479a95d4309c74c2da26b2cf2e5f2090f02601b80c3373b14666
            f0bd973d10c7eb649966d1ffd3e87979899812fef1e23f5703a99924001d9ba9
            522ea93575ad20143eeeeff77b8d192870932b1583459271f634a65441fe1907
            370f71e4d9312b930a66e1b85cba8f4a404c703c7c38ada5c6b95824c2c0ff87
            b1e3f258189d949430c516d2c2192ffbb8d687b10228d7ecf47f86c1299825a8
            b6ee7c560f4bd1720aabdca41c8a5569e9917f906efca17d5f080e65e5a16386
            c9bb4f1ad49404340df212e94d77ff5a25b8649b725e1993dc66f37a89058499
            107bb57a4f699688406e89a44776b95bd1af01290496fb4f3abba58eb407eff9
            c1dfd1362d169170f8b7364c6aa8e6507f049484e5d9b934e86d61b1d3155b5a"
        );
        assert_eq!(sig.to_vec(), expected);
    }

    /// Draws a random secret seed, public seed, message and leaf index (in
    /// that order) for the parameter set `Xmss`.
    fn random_inputs<Xmss: XmssParams>() -> (
        SkSeed<Xmss::N>,
        PkSeed<Xmss::N>,
        Array<u8, Xmss::N>,
        u32,
    ) {
        let mut rng = rand::rngs::OsRng;
        let sk_seed = SkSeed::new(&mut rng);
        let pk_seed = PkSeed::new(&mut rng);
        let mut msg = Array::<u8, Xmss::N>::default();
        rng.fill_bytes(msg.as_mut_slice());
        let idx = rng.gen_range(0..(1 << Xmss::HPrime::U32));
        (sk_seed, pk_seed, msg, idx)
    }

    /// A signature over a message must lead back to the tree root.
    fn test_sign_verify<Xmss: XmssParams>() {
        let (sk_seed, pk_seed, msg, idx) = random_inputs::<Xmss>();
        let adrs = WotsHash::default();

        let root = Xmss::xmss_node(&sk_seed, 0, Xmss::HPrime::U32, &pk_seed, &adrs);
        let sig = Xmss::xmss_sign(&msg, &sk_seed, &pk_seed, idx, &adrs);
        let recovered = Xmss::xmss_pk_from_sig(idx, &sig, &msg, &pk_seed, &adrs);

        assert_eq!(root, recovered);
    }

    test_parameter_sets!(test_sign_verify);

    /// Verifying against a tampered message must not reproduce the root.
    fn test_sign_verify_fail<Xmss: XmssParams>() {
        let (sk_seed, pk_seed, mut msg, idx) = random_inputs::<Xmss>();
        let adrs = WotsHash::default();

        let root = Xmss::xmss_node(&sk_seed, 0, Xmss::HPrime::U32, &pk_seed, &adrs);
        let sig = Xmss::xmss_sign(&msg, &sk_seed, &pk_seed, idx, &adrs);

        // Flip every bit of the first message byte after signing.
        msg[0] ^= 0xff;
        let recovered = Xmss::xmss_pk_from_sig(idx, &sig, &msg, &pk_seed, &adrs);

        assert_ne!(root, recovered);
    }

    test_parameter_sets!(test_sign_verify_fail);
}