#![allow(clippy::too_many_arguments)]
mod adrs;
mod hash;
mod params;
use adrs::{Adrs, AdrsType};
use alloc::vec;
use alloc::vec::Vec;
use params::{MAX_N, MAX_WOTS_LEN, Params};
pub use params::{XmssMtParamSet, XmssParamSet};
use crate::ct::ConstantTimeEq;
use crate::rng::{CryptoRng, RngCore};
/// Errors returned by XMSS / XMSS^MT key (de)serialization and signing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Error {
    /// A serialized key had the wrong magic prefix, an unknown OID, or the wrong length.
    InvalidKey,
    /// The private key has no unused one-time signature index left.
    KeyExhausted,
}
/// Derives the `wots_len` secret WOTS+ chain start values for the OTS address
/// in `addr`, writing them back to back (`n` bytes each) into `out`.
///
/// Chain `i` is produced by `hash::prf_keygen(p, sk_seed, pub_seed || ADRS)`
/// with the chain word of `addr` set to `i` and the hash and key-and-mask words
/// set to 0.
fn wots_expand_seed(p: &Params, sk_seed: &[u8], pub_seed: &[u8], addr: &mut Adrs, out: &mut [u8]) {
    let n = p.n;
    let input_len = n + 32;
    addr.set_hash(0);
    addr.set_key_and_mask(0);
    // PRF input layout: pub_seed (n bytes) followed by the 32-byte address.
    let mut prf_in = [0u8; MAX_N + 32];
    prf_in[..n].copy_from_slice(&pub_seed[..n]);
    let mut chain = 0;
    while chain < p.wots_len {
        addr.set_chain(chain as u32);
        prf_in[n..input_len].copy_from_slice(&addr.to_bytes());
        let start = chain * n;
        hash::prf_keygen(p, sk_seed, &prf_in[..input_len], &mut out[start..start + n]);
        chain += 1;
    }
}
/// Advances the `n`-byte value in `inout` along a WOTS+ hash chain.
///
/// Applies the chaining step for positions `start, start + 1, …`, performing at
/// most `steps` steps and never going past position `wots_w - 1`. Each step
/// derives a key and a bitmask from `addr` (key-and-mask = 0 and 1), XORs the
/// bitmask into the value and hashes it with `hash::f`.
fn wots_chain(
    p: &Params,
    pub_seed: &[u8],
    inout: &mut [u8],
    start: u32,
    steps: u32,
    addr: &mut Adrs,
) {
    let n = p.n;
    // Clamp so the chain never runs past its last position.
    let stop = p.wots_w.min(start + steps);
    let base = hash::prf_base(p, pub_seed);
    let mut pos = start;
    while pos < stop {
        addr.set_hash(pos);
        let mut key = [0u8; MAX_N];
        addr.set_key_and_mask(0);
        hash::prf_with(p, &base, pub_seed, &addr.to_bytes(), &mut key);
        let mut mask = [0u8; MAX_N];
        addr.set_key_and_mask(1);
        hash::prf_with(p, &base, pub_seed, &addr.to_bytes(), &mut mask);
        let mut masked = [0u8; MAX_N];
        for ((dst, value), bit) in masked[..n].iter_mut().zip(&inout[..n]).zip(&mask[..n]) {
            *dst = value ^ bit;
        }
        hash::f(p, &key, &masked, inout);
        pos += 1;
    }
}
/// Computes the WOTS+ public key for the OTS address in `addr` into `pk`.
///
/// `pk` first receives the secret chain start values. Each of the `wots_len`
/// chains is then walked `wots_w - 1` steps to its end.
fn wots_pkgen(p: &Params, sk_seed: &[u8], pub_seed: &[u8], addr: &mut Adrs, pk: &mut [u8]) {
    let n = p.n;
    let full_walk = p.wots_w - 1;
    wots_expand_seed(p, sk_seed, pub_seed, addr, pk);
    // wots_expand_seed has already written (and so bounds-checked) wots_len * n bytes.
    for (chain, value) in pk.chunks_exact_mut(n).take(p.wots_len).enumerate() {
        addr.set_chain(chain as u32);
        wots_chain(p, pub_seed, value, 0, full_walk, addr);
    }
}
/// Splits `input` into `wots_log_w`-bit digits, most significant bits first,
/// and writes exactly `out.len()` digits (each `< wots_w`) to `out`.
///
/// Input bytes are consumed only as needed, so `input` must hold at least
/// `out.len() * wots_log_w / 8` bytes (rounded up).
fn base_w(p: &Params, input: &[u8], out: &mut [u32]) {
    let step = p.wots_log_w as i32;
    let digit_mask = p.wots_w - 1;
    let mut next_byte = 0usize;
    let mut byte = 0u8;
    let mut bits_left = 0i32;
    for digit in out {
        if bits_left == 0 {
            byte = input[next_byte];
            next_byte += 1;
            bits_left = 8;
        }
        bits_left -= step;
        *digit = u32::from(byte >> bits_left) & digit_mask;
    }
}
/// Computes the WOTS+ chain lengths for the `n`-byte message digest `msg`.
///
/// The first `wots_len1` entries are the base-w digits of `msg`. The next
/// `wots_len2` entries are the base-w digits of the checksum
/// `sum(wots_w - 1 - digit)`. Before conversion the checksum is left-shifted
/// so that its unused padding bits are the least significant bits.
/// Entries past `wots_len` stay 0.
fn chain_lengths(p: &Params, msg: &[u8]) -> [u32; MAX_WOTS_LEN] {
    let len1 = p.wots_len1;
    let len2 = p.wots_len2;
    let mut digits = [0u32; MAX_WOTS_LEN];
    let (msg_digits, tail) = digits.split_at_mut(len1);
    base_w(p, msg, msg_digits);
    let csum: u32 = msg_digits.iter().map(|&d| p.wots_w - 1 - d).sum();
    let csum_bits = len2 * p.wots_log_w as usize;
    // Align the checksum to a byte boundary with the padding at the bottom.
    let csum = csum << ((8 - csum_bits % 8) % 8);
    let csum_len = csum_bits.div_ceil(8);
    let be = csum.to_be_bytes();
    base_w(p, &be[be.len() - csum_len..], &mut tail[..len2]);
    digits
}
/// Writes the WOTS+ signature of the `n`-byte digest `msg` into `sig`.
///
/// `sig` first receives the secret chain start values. Chain `i` is then
/// advanced by the `i`-th entry of `chain_lengths(msg)`.
fn wots_sign(
    p: &Params,
    msg: &[u8],
    sk_seed: &[u8],
    pub_seed: &[u8],
    addr: &mut Adrs,
    sig: &mut [u8],
) {
    let n = p.n;
    let steps = chain_lengths(p, msg);
    wots_expand_seed(p, sk_seed, pub_seed, addr, sig);
    // wots_expand_seed has already written (and so bounds-checked) wots_len * n bytes.
    for (chain, value) in sig.chunks_exact_mut(n).take(p.wots_len).enumerate() {
        addr.set_chain(chain as u32);
        wots_chain(p, pub_seed, value, 0, steps[chain], addr);
    }
}
/// Recomputes a WOTS+ public key candidate from signature `sig` over digest `msg`.
///
/// Each signature value is copied into `pk` and advanced from position
/// `len_i` to the end of its chain (`wots_w - 1 - len_i` further steps).
fn wots_pk_from_sig(
    p: &Params,
    sig: &[u8],
    msg: &[u8],
    pub_seed: &[u8],
    addr: &mut Adrs,
    pk: &mut [u8],
) {
    let n = p.n;
    let last = p.wots_w - 1;
    let digits = chain_lengths(p, msg);
    for (chain, &digit) in digits[..p.wots_len].iter().enumerate() {
        let span = chain * n..(chain + 1) * n;
        addr.set_chain(chain as u32);
        pk[span.clone()].copy_from_slice(&sig[span.clone()]);
        wots_chain(p, pub_seed, &mut pk[span], digit, last - digit, addr);
    }
}
/// Randomized tree hash of two `n`-byte nodes: `H(KEY, (left ^ BM0) || (right ^ BM1))`.
///
/// KEY, BM0 and BM1 are derived from `addr` with key-and-mask set to 0, 1 and 2.
/// This call overwrites the key-and-mask word of `addr`.
fn rand_hash(
    p: &Params,
    left: &[u8],
    right: &[u8],
    pub_seed: &[u8],
    addr: &mut Adrs,
    out: &mut [u8],
) {
    let n = p.n;
    let base = hash::prf_base(p, pub_seed);
    let mut key = [0u8; MAX_N];
    let mut mask = [0u8; 2 * MAX_N];
    addr.set_key_and_mask(0);
    hash::prf_with(p, &base, pub_seed, &addr.to_bytes(), &mut key);
    {
        let (mask_left, mask_right) = mask.split_at_mut(n);
        addr.set_key_and_mask(1);
        hash::prf_with(p, &base, pub_seed, &addr.to_bytes(), mask_left);
        addr.set_key_and_mask(2);
        hash::prf_with(p, &base, pub_seed, &addr.to_bytes(), &mut mask_right[..n]);
    }
    let mut masked = [0u8; 2 * MAX_N];
    for (i, (&l, &r)) in left[..n].iter().zip(&right[..n]).enumerate() {
        masked[i] = l ^ mask[i];
        masked[n + i] = r ^ mask[n + i];
    }
    hash::h(p, &key, &masked, out);
}
/// Compresses the `wots_len` `n`-byte WOTS+ public-key values into one leaf (L-tree).
///
/// At each level, adjacent pairs are hashed with `rand_hash` at tree index `i`.
/// An odd trailing node is carried up unchanged. `wots_pk` is used as scratch
/// space and is overwritten. The first `n` bytes of `leaf` receive the result.
fn l_tree(p: &Params, wots_pk: &mut [u8], pub_seed: &[u8], addr: &mut Adrs, leaf: &mut [u8]) {
    let n = p.n;
    let mut width = p.wots_len;
    let mut level = 0u32;
    addr.set_tree_height(level);
    while width > 1 {
        let pairs = width / 2;
        for pair in 0..pairs {
            addr.set_tree_index(pair as u32);
            let l = 2 * pair * n;
            let r = l + n;
            let mut parent = [0u8; MAX_N];
            rand_hash(p, &wots_pk[l..l + n], &wots_pk[r..r + n], pub_seed, addr, &mut parent);
            wots_pk[pair * n..pair * n + n].copy_from_slice(&parent[..n]);
        }
        if width % 2 == 1 {
            // Lift the unpaired last node to the next level as-is.
            let last = (width - 1) * n;
            wots_pk.copy_within(last..last + n, pairs * n);
            width = pairs + 1;
        } else {
            width = pairs;
        }
        level += 1;
        addr.set_tree_height(level);
    }
    leaf[..n].copy_from_slice(&wots_pk[..n]);
}
/// Computes one XMSS tree leaf: the L-tree compression of the WOTS+ public key
/// at `ots_addr`, hashed under `ltree_addr`, written to `leaf`.
fn gen_leaf(
    p: &Params,
    sk_seed: &[u8],
    pub_seed: &[u8],
    ltree_addr: &mut Adrs,
    ots_addr: &mut Adrs,
    leaf: &mut [u8],
) {
    let pk_len = p.wots_sig_bytes();
    let mut chains: Vec<u8> = vec![0; pk_len];
    wots_pkgen(p, sk_seed, pub_seed, ots_addr, chains.as_mut_slice());
    l_tree(p, chains.as_mut_slice(), pub_seed, ltree_addr, leaf);
}
/// Every level of one subtree, as built by `build_subtree`: entry `h` holds the
/// `2^(tree_height - h)` nodes of height `h`, concatenated as `n`-byte values.
/// Entry 0 holds the leaves and entry `tree_height` holds the single root.
type SubtreeNodes = Vec<Vec<u8>>;
fn build_subtree(p: &Params, sk_seed: &[u8], pub_seed: &[u8], subtree_addr: &Adrs) -> SubtreeNodes {
let n = p.n;
let th = p.tree_height as usize;
let mut ots_addr = Adrs::new();
let mut ltree_addr = Adrs::new();
let mut node_addr = Adrs::new();
ots_addr.copy_subtree(subtree_addr);
ltree_addr.copy_subtree(subtree_addr);
node_addr.copy_subtree(subtree_addr);
ots_addr.set_type(AdrsType::Ots);
ltree_addr.set_type(AdrsType::Ltree);
node_addr.set_type(AdrsType::HashTree);
let leaf_count = 1usize << th;
let mut leaves = vec![0u8; leaf_count * n];
for idx in 0..leaf_count {
ltree_addr.set_ltree(idx as u32);
ots_addr.set_ots(idx as u32);
gen_leaf(
p,
sk_seed,
pub_seed,
&mut ltree_addr,
&mut ots_addr,
&mut leaves[idx * n..idx * n + n],
);
}
let mut levels: SubtreeNodes = Vec::with_capacity(th + 1);
levels.push(leaves);
for level in 1..=th {
let count = 1usize << (th - level);
let mut nodes = vec![0u8; count * n];
let child = &levels[level - 1];
for i in 0..count {
node_addr.set_tree_height((level - 1) as u32);
node_addr.set_tree_index(i as u32);
let mut parent = [0u8; MAX_N];
rand_hash(
p,
&child[2 * i * n..2 * i * n + n],
&child[(2 * i + 1) * n..(2 * i + 1) * n + n],
pub_seed,
&mut node_addr,
&mut parent,
);
nodes[i * n..i * n + n].copy_from_slice(&parent[..n]);
}
levels.push(nodes);
}
levels
}
/// Copies the authentication path of leaf `idx_leaf` out of a fully built subtree.
///
/// Entry `h` of `auth_path` (`n` bytes) is the sibling, at height `h`, of the
/// node on the leaf's path to the root.
fn auth_path_from_subtree(p: &Params, levels: &[Vec<u8>], idx_leaf: u32, auth_path: &mut [u8]) {
    let n = p.n;
    let mut node = idx_leaf;
    for height in 0..p.tree_height as usize {
        let level = &levels[height];
        let sibling = (node ^ 1) as usize;
        let dst = height * n;
        auth_path[dst..dst + n].copy_from_slice(&level[sibling * n..(sibling + 1) * n]);
        node >>= 1;
    }
}
/// Memoized subtrees, keyed by `(layer, tree)`, used to read authentication
/// paths while signing. `get_or_build` evicts older entries of a layer when it
/// builds a new subtree for that layer.
#[derive(Default)]
struct SubtreeCache {
    // (layer, tree index within the layer, all node levels of that subtree)
    entries: Vec<(u32, u64, SubtreeNodes)>,
}
impl SubtreeCache {
fn seeded(layer: u32, tree: u64, nodes: SubtreeNodes) -> Self {
SubtreeCache {
entries: alloc::vec![(layer, tree, nodes)],
}
}
fn get_or_build(
&mut self,
p: &Params,
sk_seed: &[u8],
pub_seed: &[u8],
layer: u32,
tree: u64,
) -> &[Vec<u8>] {
let pos = match self
.entries
.iter()
.position(|(l, t, _)| *l == layer && *t == tree)
{
Some(pos) => pos,
None => {
self.entries.retain(|(l, _, _)| *l != layer);
let mut subtree_addr = Adrs::new();
subtree_addr.set_layer(layer);
subtree_addr.set_tree(tree);
let nodes = build_subtree(p, sk_seed, pub_seed, &subtree_addr);
self.entries.push((layer, tree, nodes));
self.entries.len() - 1
}
};
&self.entries[pos].2
}
}
/// Recomputes a subtree root from `leaf` and its authentication path.
///
/// At height `h`, the current node is combined with the `h`-th `n`-byte entry
/// of `auth_path`. The low bit of the running index decides which of the two is
/// the left input. The address uses tree height `h` and tree index
/// `leaf_idx >> (h + 1)`. The final hash is written to `root`.
fn root_from_sig(
    p: &Params,
    mut leaf_idx: u32,
    leaf: &[u8],
    auth_path: &[u8],
    pub_seed: &[u8],
    node_addr: &mut Adrs,
    root: &mut [u8],
) {
    let n = p.n;
    let top = p.tree_height;
    let mut current = [0u8; MAX_N];
    current[..n].copy_from_slice(&leaf[..n]);
    for height in 0..top {
        let at = height as usize * n;
        let sibling = &auth_path[at..at + n];
        // An odd index means the current node is a right child.
        let (left, right) = if leaf_idx & 1 == 1 {
            (sibling, &current[..n])
        } else {
            (&current[..n], sibling)
        };
        node_addr.set_tree_height(height);
        leaf_idx >>= 1;
        node_addr.set_tree_index(leaf_idx);
        if height + 1 == top {
            rand_hash(p, left, right, pub_seed, node_addr, root);
        } else {
            let mut parent = [0u8; MAX_N];
            rand_hash(p, left, right, pub_seed, node_addr, &mut parent);
            current = parent;
        }
    }
}
/// Borrowed view over raw private-key bytes laid out as
/// `idx (index_bytes) || SK_SEED || SK_PRF || root || PUB_SEED` (`n` bytes each).
struct SkView<'a> {
    p: &'a Params,
    bytes: &'a [u8],
}
impl SkView<'_> {
    /// The `slot`-th `n`-byte field after the index prefix.
    fn field(&self, slot: usize) -> &[u8] {
        let n = self.p.n;
        let start = self.p.index_bytes + slot * n;
        &self.bytes[start..start + n]
    }

    /// Secret seed used to derive the WOTS+ keys.
    fn sk_seed(&self) -> &[u8] {
        self.field(0)
    }

    /// Secret PRF key used to derive the per-signature randomness `r`.
    fn sk_prf(&self) -> &[u8] {
        self.field(1)
    }

    /// Public root stored alongside the secret key.
    fn root(&self) -> &[u8] {
        self.field(2)
    }

    /// Public seed.
    fn pub_seed(&self) -> &[u8] {
        self.field(3)
    }
}
/// Decodes a big-endian unsigned integer. Only the low 64 bits are kept when
/// more than 8 bytes are given.
fn bytes_to_idx(b: &[u8]) -> u64 {
    let mut acc = 0u64;
    for &byte in b {
        acc = (acc << 8) | u64::from(byte);
    }
    acc
}
/// Encodes `idx` big-endian into exactly `out.len()` bytes.
///
/// If `out` is shorter than 8 bytes, the high bytes of `idx` are dropped
/// (truncation). If it is longer, the extra leading bytes are zero.
fn idx_to_bytes(idx: u64, out: &mut [u8]) {
    let be = idx.to_be_bytes();
    let width = out.len();
    if width <= be.len() {
        out.copy_from_slice(&be[be.len() - width..]);
    } else {
        let (pad, tail) = out.split_at_mut(width - be.len());
        pad.fill(0);
        tail.copy_from_slice(&be);
    }
}
/// Produces the signature over `msg` for leaf index `idx`.
///
/// Output layout: `idx (index_bytes) || r (n)`, then for each of the `d` layers
/// `WOTS+ signature (wots_sig_bytes) || auth path (tree_height * n)`.
/// `r = PRF(SK_PRF, toByte(idx, 32))`. Layer 0 signs `h_msg(r, root, idx, msg)`;
/// each higher layer signs the root of the subtree below it.
/// The stored key index is not touched here (`sk` is borrowed immutably).
fn core_sign(p: &Params, sk: &SkView, idx: u64, msg: &[u8], cache: &mut SubtreeCache) -> Vec<u8> {
    let n = p.n;
    let th = p.tree_height as usize;
    let wots_bytes = p.wots_sig_bytes();
    let layer_bytes = wots_bytes + th * n;
    let r_start = p.index_bytes;
    let r_end = r_start + n;

    let mut sig = vec![0u8; p.sig_bytes()];
    idx_to_bytes(idx, &mut sig[..r_start]);

    let mut prf_msg = [0u8; 32];
    prf_msg[32 - 8..].copy_from_slice(&idx.to_be_bytes());
    hash::prf(p, sk.sk_prf(), &prf_msg, &mut sig[r_start..r_end]);

    // Only the first n bytes of `signed` are ever read.
    let mut signed = [0u8; MAX_N];
    hash::h_msg(p, &sig[r_start..r_end], sk.root(), idx, msg, &mut signed);

    let leaf_mask = (1u64 << p.tree_height) - 1;
    let mut off = r_end;
    let mut remaining = idx;
    let mut ots_addr = Adrs::new();
    ots_addr.set_type(AdrsType::Ots);
    for layer in 0..p.d {
        let leaf = (remaining & leaf_mask) as u32;
        let tree = remaining >> p.tree_height;
        ots_addr.set_layer(layer);
        ots_addr.set_tree(tree);
        ots_addr.set_ots(leaf);
        let (wots_part, auth_part) = sig[off..off + layer_bytes].split_at_mut(wots_bytes);
        wots_sign(p, &signed[..n], sk.sk_seed(), sk.pub_seed(), &mut ots_addr, wots_part);
        let nodes = cache.get_or_build(p, sk.sk_seed(), sk.pub_seed(), layer, tree);
        auth_path_from_subtree(p, nodes, leaf, auth_part);
        signed[..n].copy_from_slice(&nodes[th][..n]);
        off += layer_bytes;
        remaining = tree;
    }
    sig
}
/// Verifies `sig` over `msg` against the public root `pub_root` and `pub_seed`.
///
/// Each layer rebuilds the WOTS+ public key from the signature, compresses it
/// with the L-tree and climbs the authentication path. The final root is
/// compared with `pub_root` in constant time.
///
/// Returns `false` in the following cases:
/// - the signature length is wrong;
/// - the encoded index addresses a leaf beyond the `2^full_height` the key has;
/// - the recomputed root does not match.
fn core_verify(p: &Params, pub_root: &[u8], pub_seed: &[u8], sig: &[u8], msg: &[u8]) -> bool {
    let n = p.n;
    if sig.len() != p.sig_bytes() {
        return false;
    }
    let idx = bytes_to_idx(&sig[..p.index_bytes]);
    // No valid signature uses an index >= 2^full_height. Without this check the
    // excess high bits would silently end up in the top-layer tree address
    // instead of the signature being rejected as malformed.
    if p.full_height < 64 && (idx >> p.full_height) != 0 {
        return false;
    }
    let r = &sig[p.index_bytes..p.index_bytes + n];
    let mut mhash = [0u8; MAX_N];
    hash::h_msg(p, r, pub_root, idx, msg, &mut mhash);
    let mut off = p.index_bytes + n;
    let leaf_mask = (1u64 << p.tree_height) - 1;
    let mut cur_idx = idx;
    let mut root = [0u8; MAX_N];
    root[..n].copy_from_slice(&mhash[..n]);
    let mut ots_addr = Adrs::new();
    let mut ltree_addr = Adrs::new();
    let mut node_addr = Adrs::new();
    ots_addr.set_type(AdrsType::Ots);
    ltree_addr.set_type(AdrsType::Ltree);
    node_addr.set_type(AdrsType::HashTree);
    for layer in 0..p.d {
        let idx_leaf = (cur_idx & leaf_mask) as u32;
        let tree = cur_idx >> p.tree_height;
        ots_addr.set_layer(layer);
        ltree_addr.set_layer(layer);
        node_addr.set_layer(layer);
        ots_addr.set_tree(tree);
        ltree_addr.set_tree(tree);
        node_addr.set_tree(tree);
        ots_addr.set_ots(idx_leaf);
        // `root` holds the digest signed at this layer (the message hash at layer 0).
        let mut wots_pk = vec![0u8; p.wots_sig_bytes()];
        wots_pk_from_sig(
            p,
            &sig[off..off + p.wots_sig_bytes()],
            &root[..n],
            pub_seed,
            &mut ots_addr,
            &mut wots_pk,
        );
        off += p.wots_sig_bytes();
        ltree_addr.set_ltree(idx_leaf);
        let mut leaf = [0u8; MAX_N];
        l_tree(p, &mut wots_pk, pub_seed, &mut ltree_addr, &mut leaf);
        let auth_path = &sig[off..off + p.tree_height as usize * n];
        off += p.tree_height as usize * n;
        let mut new_root = [0u8; MAX_N];
        root_from_sig(
            p,
            idx_leaf,
            &leaf[..n],
            auth_path,
            pub_seed,
            &mut node_addr,
            &mut new_root,
        );
        root[..n].copy_from_slice(&new_root[..n]);
        cur_idx = tree;
    }
    bool::from(root[..n].ct_eq(&pub_root[..n]))
}
/// Derives a key pair from `seed = SK_SEED || SK_PRF || PUB_SEED` (`n` bytes each).
///
/// Returns `(sk, pk, top_levels)`:
/// - `sk = idx (all zero) || SK_SEED || SK_PRF || root || PUB_SEED`;
/// - `pk = root || PUB_SEED`, zero-padded to `pk_bytes()`;
/// - `top_levels` is the fully built top-layer subtree (layer `d - 1`, tree 0).
///   Its root is the public root.
fn core_keygen(p: &Params, seed: &[u8]) -> (Vec<u8>, Vec<u8>, SubtreeNodes) {
    let n = p.n;
    let ib = p.index_bytes;
    // Borrow the seed parts in place. Copying SK_SEED into a temporary Vec (as
    // before) left secret material in freed heap memory that was never wiped.
    let sk_seed = &seed[..n];
    let sk_prf = &seed[n..2 * n];
    let pub_seed = &seed[2 * n..3 * n];

    let mut top_addr = Adrs::new();
    top_addr.set_layer(p.d - 1);
    let levels = build_subtree(p, sk_seed, pub_seed, &top_addr);
    let th = p.tree_height as usize;
    let root = &levels[th][..n];

    let mut sk = vec![0u8; p.sk_bytes()];
    sk[ib..ib + n].copy_from_slice(sk_seed);
    sk[ib + n..ib + 2 * n].copy_from_slice(sk_prf);
    sk[ib + 2 * n..ib + 3 * n].copy_from_slice(root);
    sk[ib + 3 * n..ib + 4 * n].copy_from_slice(pub_seed);

    let mut pk = vec![0u8; p.pk_bytes()];
    pk[..n].copy_from_slice(root);
    pk[n..2 * n].copy_from_slice(pub_seed);
    (sk, pk, levels)
}
/// Magic prefix of a serialized XMSS private key.
const SK_MAGIC: &[u8; 4] = b"XMSk";
/// Magic prefix of a serialized XMSS^MT private key.
const MTSK_MAGIC: &[u8; 4] = b"XMTk";
/// Overwrites every byte of `v` with zero.
///
/// The trailing `black_box` is a best-effort barrier that discourages the
/// optimizer from discarding the stores as dead writes. It is not a hard
/// guarantee.
fn wipe(v: &mut [u8]) {
    v.iter_mut().for_each(|byte| *byte = 0);
    let _ = core::hint::black_box(&v);
}
/// Stateful XMSS private key.
///
/// Each call to `sign` consumes one leaf index. The raw key bytes are zeroed
/// when the key is dropped.
pub struct XmssPrivateKey {
    set: XmssParamSet,
    // Raw key: idx (index_bytes) || SK_SEED || SK_PRF || root || PUB_SEED.
    bytes: Vec<u8>,
    // Memoized subtrees for authentication paths; not serialized by `to_bytes`.
    cache: SubtreeCache,
}
/// XMSS public key.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct XmssPublicKey {
    set: XmssParamSet,
    // root (n bytes) || PUB_SEED (n bytes)
    bytes: Vec<u8>,
}
impl XmssPrivateKey {
pub fn parameter_set(&self) -> XmssParamSet {
self.set
}
pub fn from_seed(set: XmssParamSet, seed: &[u8]) -> Self {
let p = set.params();
assert!(
seed.len() >= 3 * p.n,
"XMSS from_seed: seed must be 3n bytes"
);
let (bytes, _pk, top) = core_keygen(&p, &seed[..3 * p.n]);
XmssPrivateKey {
set,
bytes,
cache: SubtreeCache::seeded(p.d - 1, 0, top),
}
}
pub fn generate<R: RngCore + CryptoRng>(set: XmssParamSet, rng: &mut R) -> Self {
let p = set.params();
let mut seed = vec![0u8; 3 * p.n];
rng.fill_bytes(&mut seed);
let sk = Self::from_seed(set, &seed);
wipe(&mut seed);
sk
}
pub fn public_key(&self) -> XmssPublicKey {
let p = self.set.params();
let n = p.n;
let mut bytes = vec![0u8; 2 * n];
bytes[..n].copy_from_slice(&self.bytes[p.index_bytes + 2 * n..p.index_bytes + 3 * n]);
bytes[n..2 * n].copy_from_slice(&self.bytes[p.index_bytes + 3 * n..p.index_bytes + 4 * n]);
XmssPublicKey {
set: self.set,
bytes,
}
}
pub fn index(&self) -> u64 {
let p = self.set.params();
bytes_to_idx(&self.bytes[..p.index_bytes])
}
pub fn remaining(&self) -> u64 {
let p = self.set.params();
let total = 1u64 << p.full_height;
total.saturating_sub(self.index())
}
pub fn sign(&mut self, msg: &[u8]) -> Result<Vec<u8>, Error> {
let p = self.set.params();
let idx = self.index();
if idx >= (1u64 << p.full_height) {
return Err(Error::KeyExhausted);
}
let sig = {
let view = SkView {
p: &p,
bytes: &self.bytes,
};
core_sign(&p, &view, idx, msg, &mut self.cache)
};
idx_to_bytes(idx + 1, &mut self.bytes[..p.index_bytes]);
Ok(sig)
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut out = Vec::with_capacity(8 + self.bytes.len());
out.extend_from_slice(SK_MAGIC);
out.extend_from_slice(&self.set.oid().to_be_bytes());
out.extend_from_slice(&self.bytes);
out
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
if bytes.len() < 8 || &bytes[..4] != SK_MAGIC {
return Err(Error::InvalidKey);
}
let oid = u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
let set = XmssParamSet::from_oid(oid).ok_or(Error::InvalidKey)?;
let p = set.params();
let raw = &bytes[8..];
if raw.len() != p.sk_bytes() {
return Err(Error::InvalidKey);
}
Ok(XmssPrivateKey {
set,
bytes: raw.to_vec(),
cache: SubtreeCache::default(),
})
}
}
impl Drop for XmssPrivateKey {
    /// Zeroes the raw key bytes (index, seeds, PRF key, root) before they are freed.
    fn drop(&mut self) {
        let raw = self.bytes.as_mut_slice();
        wipe(raw);
    }
}
impl XmssPublicKey {
pub fn parameter_set(&self) -> XmssParamSet {
self.set
}
pub fn verify(&self, msg: &[u8], sig: &[u8]) -> bool {
let p = self.set.params();
let n = p.n;
core_verify(&p, &self.bytes[..n], &self.bytes[n..2 * n], sig, msg)
}
pub fn to_bytes(&self) -> &[u8] {
&self.bytes
}
pub fn from_bytes(set: XmssParamSet, bytes: &[u8]) -> Result<Self, Error> {
if bytes.len() != set.params().pk_bytes() {
return Err(Error::InvalidKey);
}
Ok(XmssPublicKey {
set,
bytes: bytes.to_vec(),
})
}
}
/// Stateful XMSS^MT (multi-tree) private key.
///
/// Each call to `sign` consumes one leaf index. The raw key bytes are zeroed
/// when the key is dropped.
pub struct XmssMtPrivateKey {
    set: XmssMtParamSet,
    // Raw key: idx (index_bytes) || SK_SEED || SK_PRF || root || PUB_SEED.
    bytes: Vec<u8>,
    // Memoized subtrees (at most one per layer); not serialized by `to_bytes`.
    cache: SubtreeCache,
}
/// XMSS^MT public key.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct XmssMtPublicKey {
    set: XmssMtParamSet,
    // root (n bytes) || PUB_SEED (n bytes)
    bytes: Vec<u8>,
}
impl XmssMtPrivateKey {
pub fn parameter_set(&self) -> XmssMtParamSet {
self.set
}
pub fn from_seed(set: XmssMtParamSet, seed: &[u8]) -> Self {
let p = set.params();
assert!(
seed.len() >= 3 * p.n,
"XMSS^MT from_seed: seed must be 3n bytes"
);
let (bytes, _pk, top) = core_keygen(&p, &seed[..3 * p.n]);
XmssMtPrivateKey {
set,
bytes,
cache: SubtreeCache::seeded(p.d - 1, 0, top),
}
}
pub fn generate<R: RngCore + CryptoRng>(set: XmssMtParamSet, rng: &mut R) -> Self {
let p = set.params();
let mut seed = vec![0u8; 3 * p.n];
rng.fill_bytes(&mut seed);
let sk = Self::from_seed(set, &seed);
wipe(&mut seed);
sk
}
pub fn public_key(&self) -> XmssMtPublicKey {
let p = self.set.params();
let n = p.n;
let mut bytes = vec![0u8; 2 * n];
bytes[..n].copy_from_slice(&self.bytes[p.index_bytes + 2 * n..p.index_bytes + 3 * n]);
bytes[n..2 * n].copy_from_slice(&self.bytes[p.index_bytes + 3 * n..p.index_bytes + 4 * n]);
XmssMtPublicKey {
set: self.set,
bytes,
}
}
pub fn index(&self) -> u64 {
let p = self.set.params();
bytes_to_idx(&self.bytes[..p.index_bytes])
}
pub fn remaining(&self) -> u64 {
let p = self.set.params();
let total = if p.full_height >= 64 {
u64::MAX
} else {
1u64 << p.full_height
};
total.saturating_sub(self.index())
}
pub fn sign(&mut self, msg: &[u8]) -> Result<Vec<u8>, Error> {
let p = self.set.params();
let idx = self.index();
let exhausted = if p.full_height >= 64 {
false
} else {
idx >= (1u64 << p.full_height)
};
if exhausted {
return Err(Error::KeyExhausted);
}
let sig = {
let view = SkView {
p: &p,
bytes: &self.bytes,
};
core_sign(&p, &view, idx, msg, &mut self.cache)
};
idx_to_bytes(idx + 1, &mut self.bytes[..p.index_bytes]);
Ok(sig)
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut out = Vec::with_capacity(8 + self.bytes.len());
out.extend_from_slice(MTSK_MAGIC);
out.extend_from_slice(&self.set.oid().to_be_bytes());
out.extend_from_slice(&self.bytes);
out
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
if bytes.len() < 8 || &bytes[..4] != MTSK_MAGIC {
return Err(Error::InvalidKey);
}
let oid = u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
let set = XmssMtParamSet::from_oid(oid).ok_or(Error::InvalidKey)?;
let p = set.params();
let raw = &bytes[8..];
if raw.len() != p.sk_bytes() {
return Err(Error::InvalidKey);
}
Ok(XmssMtPrivateKey {
set,
bytes: raw.to_vec(),
cache: SubtreeCache::default(),
})
}
}
impl Drop for XmssMtPrivateKey {
    /// Zeroes the raw key bytes (index, seeds, PRF key, root) before they are freed.
    fn drop(&mut self) {
        let raw = self.bytes.as_mut_slice();
        wipe(raw);
    }
}
impl XmssMtPublicKey {
pub fn parameter_set(&self) -> XmssMtParamSet {
self.set
}
pub fn verify(&self, msg: &[u8], sig: &[u8]) -> bool {
let p = self.set.params();
let n = p.n;
core_verify(&p, &self.bytes[..n], &self.bytes[n..2 * n], sig, msg)
}
pub fn to_bytes(&self) -> &[u8] {
&self.bytes
}
pub fn from_bytes(set: XmssMtParamSet, bytes: &[u8]) -> Result<Self, Error> {
if bytes.len() != set.params().pk_bytes() {
return Err(Error::InvalidKey);
}
Ok(XmssMtPublicKey {
set,
bytes: bytes.to_vec(),
})
}
}
#[cfg(test)]
mod tests;