#![allow(dead_code)]
use super::address::{Address, AddressType};
use super::hash::SlhDsaHash;
use super::params::SlhDsaParams;
use super::wots::{Wots, WotsSignature};
use core::marker::PhantomData;
use std::vec::Vec;
/// An XMSS signature: a WOTS+ signature plus the authentication path that links
/// the signing leaf to the XMSS root.
///
/// The serialized layout (see [`XmssSignature::to_bytes`]) is the WOTS+ signature
/// (`P::WOTS_LEN * P::N` bytes) followed by `P::H_PRIME` authentication-path nodes
/// of `P::N` bytes each.
pub struct XmssSignature<P: SlhDsaParams> {
    /// Raw WOTS+ signature bytes.
    wots_sig: Vec<u8>,
    /// Authentication-path nodes; entry `j` is the sibling node at height `j`.
    auth_path: Vec<Vec<u8>>,
    _params: PhantomData<P>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would require
// `P: Clone` even though `P` is only a phantom parameter-set marker.
impl<P: SlhDsaParams> Clone for XmssSignature<P> {
    fn clone(&self) -> Self {
        Self {
            wots_sig: self.wots_sig.clone(),
            auth_path: self.auth_path.clone(),
            _params: PhantomData,
        }
    }
}

impl<P: SlhDsaParams> XmssSignature<P> {
    /// Builds a signature from its WOTS+ part and authentication path.
    ///
    /// Component lengths are NOT checked here; [`XmssSignature::from_bytes`] is
    /// the length-validating constructor.
    pub fn new(wots_sig: Vec<u8>, auth_path: Vec<Vec<u8>>) -> Self {
        Self {
            wots_sig,
            auth_path,
            _params: PhantomData,
        }
    }

    /// Returns the raw WOTS+ signature bytes.
    pub fn wots_sig(&self) -> &[u8] {
        &self.wots_sig
    }

    /// Returns the authentication-path nodes in stored order.
    pub fn auth_path(&self) -> &[Vec<u8>] {
        &self.auth_path
    }

    /// Serializes as `wots_sig || auth_path[0] || auth_path[1] || ...`.
    ///
    /// The result is exactly [`XmssSignature::size`] bytes when the components
    /// have their expected lengths (always true for values from `from_bytes`).
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(Self::size());
        bytes.extend_from_slice(&self.wots_sig);
        for node in &self.auth_path {
            bytes.extend_from_slice(node);
        }
        bytes
    }

    /// Parses the layout produced by [`XmssSignature::to_bytes`].
    ///
    /// Returns `None` unless `bytes` is exactly [`XmssSignature::size`] bytes long.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        if bytes.len() != Self::size() {
            return None;
        }
        let (wots_bytes, auth_bytes) = bytes.split_at(P::WOTS_LEN * P::N);
        // Explicit ranges (not `chunks_exact`) so a degenerate `P::N == 0`
        // still yields `P::H_PRIME` empty nodes instead of panicking.
        let auth_path = (0..P::H_PRIME)
            .map(|level| auth_bytes[level * P::N..(level + 1) * P::N].to_vec())
            .collect();
        Some(Self::new(wots_bytes.to_vec(), auth_path))
    }

    /// Serialized length in bytes: `P::WOTS_LEN * P::N + P::H_PRIME * P::N`.
    pub fn size() -> usize {
        P::WOTS_LEN * P::N + P::H_PRIME * P::N
    }
}
/// XMSS operations on one Merkle tree of height `P::H_PRIME` whose leaves are
/// WOTS+ public keys.
///
/// `P` is the parameter set and `H` the hash family supplying `H::h`. The type
/// holds no data: every operation is an associated function, and the caller's
/// `Address` is only ever copied, never mutated.
pub struct Xmss<P: SlhDsaParams, H: SlhDsaHash<P>> {
    _params: PhantomData<P>,
    _hash: PhantomData<H>,
}

impl<P: SlhDsaParams, H: SlhDsaHash<P>> Xmss<P, H> {
    /// Computes node `i` at height `height` of the tree identified by `adrs`
    /// (cf. FIPS 205, Algorithm 9).
    ///
    /// * Height 0: the WOTS+ public key of key pair `i`, derived under a
    ///   `WotsHash` address.
    /// * Height > 0: `H::h` over children `2i` and `2i + 1` one level below,
    ///   under a `Tree` address carrying this node's height and index.
    ///
    /// The recursion generates every one of the `2^height` leaves under the node.
    pub fn xmss_node(
        sk_seed: &[u8],
        pk_seed: &[u8],
        i: u32,
        height: u32,
        adrs: &Address,
    ) -> Vec<u8> {
        match height.checked_sub(1) {
            None => {
                let mut leaf_adrs = *adrs;
                leaf_adrs.set_type(AddressType::WotsHash);
                leaf_adrs.set_keypair_address(i);
                Wots::<P, H>::keygen(sk_seed, pk_seed, &leaf_adrs)
            }
            Some(child_height) => {
                let left_child =
                    Self::xmss_node(sk_seed, pk_seed, 2 * i, child_height, adrs);
                let right_child =
                    Self::xmss_node(sk_seed, pk_seed, 2 * i + 1, child_height, adrs);
                let mut parent_adrs = *adrs;
                parent_adrs.set_type(AddressType::Tree);
                parent_adrs.set_tree_height(height);
                parent_adrs.set_tree_index(i);
                H::h(pk_seed, &parent_adrs, &left_child, &right_child)
            }
        }
    }

    /// Signs `msg` with the WOTS+ key pair at leaf `idx` and attaches that
    /// leaf's authentication path (cf. FIPS 205, Algorithm 10).
    ///
    /// Path entry `level` is node `(idx >> level) ^ 1` at height `level`, i.e.
    /// the sibling of the leaf's ancestor at that height.
    pub fn xmss_sign(
        msg: &[u8],
        sk_seed: &[u8],
        pk_seed: &[u8],
        idx: u32,
        adrs: &Address,
    ) -> XmssSignature<P> {
        let mut leaf_adrs = *adrs;
        leaf_adrs.set_type(AddressType::WotsHash);
        leaf_adrs.set_keypair_address(idx);
        let wots_sig = Wots::<P, H>::sign(msg, sk_seed, pk_seed, &leaf_adrs);

        // Index of the leaf's ancestor at each height: idx, idx/2, idx/4, ...
        let ancestors = std::iter::successors(Some(idx), |&k| Some(k >> 1));
        let auth_path: Vec<Vec<u8>> = (0..P::H_PRIME as u32)
            .zip(ancestors)
            .map(|(level, ancestor)| {
                Self::xmss_node(sk_seed, pk_seed, ancestor ^ 1, level, adrs)
            })
            .collect();

        XmssSignature::new(wots_sig.as_bytes().to_vec(), auth_path)
    }

    /// Recomputes the root implied by `sig` for message `msg` at leaf `idx`
    /// (cf. FIPS 205, Algorithm 11).
    ///
    /// The WOTS+ part yields a candidate leaf, which is hashed upward with the
    /// authentication path: at each level the running node is the left input
    /// when its index is even and the right input when odd. A valid signature
    /// reproduces [`Xmss::xmss_root`].
    ///
    /// # Panics
    ///
    /// Panics if `sig.wots_sig()` is rejected by `WotsSignature::from_bytes`, or
    /// if `sig.auth_path()` has fewer than `P::H_PRIME` entries.
    pub fn xmss_pk_from_sig(
        idx: u32,
        sig: &XmssSignature<P>,
        msg: &[u8],
        pk_seed: &[u8],
        adrs: &Address,
    ) -> Vec<u8> {
        let mut leaf_adrs = *adrs;
        leaf_adrs.set_type(AddressType::WotsHash);
        leaf_adrs.set_keypair_address(idx);
        let wots_sig = WotsSignature::<P>::from_bytes(sig.wots_sig())
            .expect("Invalid WOTS+ signature length");
        let leaf = Wots::<P, H>::pk_from_sig(&wots_sig, msg, pk_seed, &leaf_adrs);

        let auth = sig.auth_path();
        let ancestors = std::iter::successors(Some(idx), |&k| Some(k >> 1));
        (0..P::H_PRIME)
            .zip(ancestors)
            .fold(leaf, |node, (level, k)| {
                let mut parent_adrs = *adrs;
                parent_adrs.set_type(AddressType::Tree);
                parent_adrs.set_tree_height(level as u32 + 1);
                parent_adrs.set_tree_index(k >> 1);
                let sibling = &auth[level];
                let (left, right) = if (k & 1) == 0 {
                    (&node, sibling)
                } else {
                    (sibling, &node)
                };
                H::h(pk_seed, &parent_adrs, left, right)
            })
    }

    /// Computes the tree root: node 0 at height `P::H_PRIME`.
    pub fn xmss_root(sk_seed: &[u8], pk_seed: &[u8], adrs: &Address) -> Vec<u8> {
        let root_height = P::H_PRIME as u32;
        Self::xmss_node(sk_seed, pk_seed, 0, root_height, adrs)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::slh_dsa::hash::Sha2Hash;
    use crate::slh_dsa::params::Sha2_128f;

    type TestXmss = Xmss<Sha2_128f, Sha2Hash<Sha2_128f>>;

    #[test]
    fn test_xmss_leaf_is_wots_pk() {
        let sk_seed = [1u8; 16];
        let pk_seed = [2u8; 16];
        let adrs = Address::tree(0, 0, 0, 0);
        let leaf = TestXmss::xmss_node(&sk_seed, &pk_seed, 0, 0, &adrs);
        assert_eq!(leaf.len(), Sha2_128f::N);
        let mut wots_adrs = adrs;
        wots_adrs.set_type(AddressType::WotsHash);
        wots_adrs.set_keypair_address(0);
        let wots_pk =
            Wots::<Sha2_128f, Sha2Hash<Sha2_128f>>::keygen(&sk_seed, &pk_seed, &wots_adrs);
        assert_eq!(leaf, wots_pk);
    }

    #[test]
    fn test_xmss_node_deterministic() {
        let sk_seed = [1u8; 16];
        let pk_seed = [2u8; 16];
        let adrs = Address::tree(0, 0, 0, 0);
        let node1 = TestXmss::xmss_node(&sk_seed, &pk_seed, 0, 2, &adrs);
        let node2 = TestXmss::xmss_node(&sk_seed, &pk_seed, 0, 2, &adrs);
        assert_eq!(node1, node2);
    }

    #[test]
    fn test_xmss_different_indices_different_leaves() {
        let sk_seed = [1u8; 16];
        let pk_seed = [2u8; 16];
        let adrs = Address::tree(0, 0, 0, 0);
        let leaf0 = TestXmss::xmss_node(&sk_seed, &pk_seed, 0, 0, &adrs);
        let leaf1 = TestXmss::xmss_node(&sk_seed, &pk_seed, 1, 0, &adrs);
        assert_ne!(leaf0, leaf1);
    }

    #[test]
    fn test_xmss_sign_verify_roundtrip() {
        let sk_seed = [1u8; 16];
        let pk_seed = [2u8; 16];
        let adrs = Address::tree(0, 0, 0, 0);
        let msg = [42u8; 16];
        let idx = 0u32;
        let root = TestXmss::xmss_root(&sk_seed, &pk_seed, &adrs);
        let sig = TestXmss::xmss_sign(&msg, &sk_seed, &pk_seed, idx, &adrs);
        let computed_root = TestXmss::xmss_pk_from_sig(idx, &sig, &msg, &pk_seed, &adrs);
        assert_eq!(root, computed_root);
    }

    #[test]
    fn test_xmss_sign_verify_different_indices() {
        let sk_seed = [1u8; 16];
        let pk_seed = [2u8; 16];
        let adrs = Address::tree(0, 0, 0, 0);
        let msg = [42u8; 16];
        let root = TestXmss::xmss_root(&sk_seed, &pk_seed, &adrs);
        for idx in [0, 1, 2, 3] {
            let sig = TestXmss::xmss_sign(&msg, &sk_seed, &pk_seed, idx, &adrs);
            let computed_root = TestXmss::xmss_pk_from_sig(idx, &sig, &msg, &pk_seed, &adrs);
            assert_eq!(root, computed_root, "Failed for idx={}", idx);
        }
    }

    #[test]
    fn test_xmss_wrong_message_fails() {
        let sk_seed = [1u8; 16];
        let pk_seed = [2u8; 16];
        let adrs = Address::tree(0, 0, 0, 0);
        let msg1 = [1u8; 16];
        let msg2 = [2u8; 16];
        let idx = 0u32;
        let root = TestXmss::xmss_root(&sk_seed, &pk_seed, &adrs);
        let sig = TestXmss::xmss_sign(&msg1, &sk_seed, &pk_seed, idx, &adrs);
        let computed_root = TestXmss::xmss_pk_from_sig(idx, &sig, &msg2, &pk_seed, &adrs);
        assert_ne!(root, computed_root);
    }

    #[test]
    fn test_xmss_wrong_index_fails() {
        let sk_seed = [1u8; 16];
        let pk_seed = [2u8; 16];
        let adrs = Address::tree(0, 0, 0, 0);
        let msg = [42u8; 16];
        let root = TestXmss::xmss_root(&sk_seed, &pk_seed, &adrs);
        let sig = TestXmss::xmss_sign(&msg, &sk_seed, &pk_seed, 0, &adrs);
        let computed_root = TestXmss::xmss_pk_from_sig(1, &sig, &msg, &pk_seed, &adrs);
        assert_ne!(root, computed_root);
    }

    #[test]
    fn test_xmss_tampered_auth_path_fails() {
        let sk_seed = [1u8; 16];
        let pk_seed = [2u8; 16];
        let adrs = Address::tree(0, 0, 0, 0);
        let msg = [42u8; 16];
        let root = TestXmss::xmss_root(&sk_seed, &pk_seed, &adrs);
        let sig = TestXmss::xmss_sign(&msg, &sk_seed, &pk_seed, 0, &adrs);
        // Flip one bit in the last (topmost) authentication-path node.
        let mut bytes = sig.to_bytes();
        let last = bytes.len() - 1;
        bytes[last] ^= 0x01;
        let tampered = XmssSignature::<Sha2_128f>::from_bytes(&bytes).unwrap();
        let computed_root = TestXmss::xmss_pk_from_sig(0, &tampered, &msg, &pk_seed, &adrs);
        assert_ne!(root, computed_root);
    }

    #[test]
    fn test_xmss_signature_serialization() {
        let sk_seed = [1u8; 16];
        let pk_seed = [2u8; 16];
        let adrs = Address::tree(0, 0, 0, 0);
        let msg = [42u8; 16];
        let sig = TestXmss::xmss_sign(&msg, &sk_seed, &pk_seed, 0, &adrs);
        let bytes = sig.to_bytes();
        assert_eq!(bytes.len(), XmssSignature::<Sha2_128f>::size());
        let restored = XmssSignature::<Sha2_128f>::from_bytes(&bytes).unwrap();
        assert_eq!(sig.wots_sig(), restored.wots_sig());
        // Compare node contents, not just the number of nodes.
        assert_eq!(sig.auth_path(), restored.auth_path());
        assert_eq!(restored.to_bytes(), bytes);
        // A deserialized signature must still verify against the tree root.
        let root = TestXmss::xmss_root(&sk_seed, &pk_seed, &adrs);
        let computed_root = TestXmss::xmss_pk_from_sig(0, &restored, &msg, &pk_seed, &adrs);
        assert_eq!(root, computed_root);
    }

    #[test]
    fn test_xmss_signature_from_bytes_rejects_bad_length() {
        let size = XmssSignature::<Sha2_128f>::size();
        assert!(XmssSignature::<Sha2_128f>::from_bytes(&[]).is_none());
        assert!(XmssSignature::<Sha2_128f>::from_bytes(&vec![0u8; size - 1]).is_none());
        assert!(XmssSignature::<Sha2_128f>::from_bytes(&vec![0u8; size + 1]).is_none());
        assert!(XmssSignature::<Sha2_128f>::from_bytes(&vec![0u8; size]).is_some());
    }

    #[test]
    fn test_xmss_root_deterministic() {
        let sk_seed = [1u8; 16];
        let pk_seed = [2u8; 16];
        let adrs = Address::tree(0, 0, 0, 0);
        let root1 = TestXmss::xmss_root(&sk_seed, &pk_seed, &adrs);
        let root2 = TestXmss::xmss_root(&sk_seed, &pk_seed, &adrs);
        assert_eq!(root1, root2);
    }
}