use super::ots;
use super::params::{D_INTR, D_LEAF, LmotsType, LmsType, N};
use crate::ct::ConstantTimeEq;
use crate::hash::{Digest, Sha256};
use alloc::vec;
use alloc::vec::Vec;
/// Hashes the leaf for OTS key index `q`.
///
/// The leaf's node number is `2^h + q` (leaves follow the `2^h` interior nodes in
/// the 1-based numbering used by `node_value`). The digest input is:
/// `i_id || be32(node number) || be(D_LEAF) || OTS public key for q`.
fn leaf_hash(
    lms: LmsType,
    ots_type: LmotsType,
    i_id: &[u8; 16],
    seed: &[u8; N],
    q: u32,
) -> [u8; N] {
    // OTS public key of leaf q, derived from the seed.
    let ots_pub = ots::public_key(ots_type, i_id, seed, q);
    let leaf_node = (1u64 << lms.h()) as u32 + q;

    let mut hasher = Sha256::new();
    hasher.update(i_id);
    hasher.update(&leaf_node.to_be_bytes());
    hasher.update(&D_LEAF.to_be_bytes());
    hasher.update(&ots_pub);
    hasher.finalize()
}
/// Hashes an interior tree node from its two children.
///
/// Digest input: `i_id || be32(node_num) || be(D_INTR) || left || right`.
fn interior_hash(i_id: &[u8; 16], node_num: u32, left: &[u8; N], right: &[u8; N]) -> [u8; N] {
    let mut hasher = Sha256::new();
    // Feed the fields strictly in this order; the order defines the digest.
    for field in [
        &i_id[..],
        &node_num.to_be_bytes()[..],
        &D_INTR.to_be_bytes()[..],
        &left[..],
        &right[..],
    ] {
        hasher.update(field);
    }
    hasher.finalize()
}
/// Computes the value of tree node `node_num`.
///
/// Nodes use 1-based heap numbering: node 1 is the root, and node `n` has children
/// `2n` and `2n + 1`. Nodes `>= 2^h` are leaves. The whole subtree below `node_num`
/// is recomputed from `seed` on every call (there is no caching), so cost grows with
/// subtree size. `node_num` must be non-zero; node 0 would recurse on itself.
fn node_value(
    lms: LmsType,
    ots_type: LmotsType,
    i_id: &[u8; 16],
    seed: &[u8; N],
    node_num: u32,
) -> [u8; N] {
    let first_leaf = (1u64 << lms.h()) as u32;
    if node_num < first_leaf {
        // Interior node: hash the left child `2n` and right child `2n + 1`.
        let left_child = 2 * node_num;
        let left = node_value(lms, ots_type, i_id, seed, left_child);
        let right = node_value(lms, ots_type, i_id, seed, left_child + 1);
        return interior_hash(i_id, node_num, &left, &right);
    }
    leaf_hash(lms, ots_type, i_id, seed, node_num - first_leaf)
}
/// Computes the Merkle root of the tree derived from `seed`.
///
/// This evaluates every leaf, so it costs `2^h` OTS public-key derivations.
pub(crate) fn compute_root(
    lms: LmsType,
    ots_type: LmotsType,
    i_id: &[u8; 16],
    seed: &[u8; N],
) -> [u8; N] {
    // In the 1-based node numbering used by `node_value`, the root is node 1.
    const ROOT_NODE: u32 = 1;
    node_value(lms, ots_type, i_id, seed, ROOT_NODE)
}
/// Serializes an LMS public key.
///
/// Layout: `be(lms typecode) || be(OTS typecode) || i_id (16 bytes) || root (N bytes)`.
/// This is the layout that `verify` parses.
pub(crate) fn encode_public_key(
    lms: LmsType,
    ots_type: LmotsType,
    i_id: &[u8; 16],
    root: &[u8; N],
) -> Vec<u8> {
    let lms_code = lms.typecode().to_be_bytes();
    let ots_code = ots_type.typecode().to_be_bytes();
    [&lms_code[..], &ots_code[..], &i_id[..], &root[..]].concat()
}
/// Produces an LMS signature using leaf `q`.
///
/// Output layout:
/// `be32(q) || OTS signature (ots_type.sig_len() bytes) || be(lms typecode) ||
///  h authentication-path nodes of N bytes each`.
/// Path entry `i` is the sibling of the level-`i` ancestor of leaf node `2^h + q`.
/// `c` is forwarded unchanged to `ots::sign`. It is presumably the per-signature
/// randomizer; confirm against `ots::sign`.
///
/// # Panics
/// Panics if `q` is not a valid leaf index (`q >= lms.leaves()`). Without this
/// check, `2^h + q` can overflow `u32`. In release builds the overflow wraps, and
/// the resulting sibling index can be 0, which makes `node_value` recurse forever.
/// Any out-of-range `q` would also yield a signature that `recover_root` rejects.
pub(crate) fn sign(
    lms: LmsType,
    ots_type: LmotsType,
    i_id: &[u8; 16],
    seed: &[u8; N],
    q: u32,
    c: &[u8; N],
    message: &[u8],
) -> Vec<u8> {
    assert!(
        u64::from(q) < lms.leaves(),
        "LMS leaf index q is out of range for this tree"
    );
    let h = lms.h();
    let ots_len = ots_type.sig_len();
    let mut sig = vec![0u8; 4 + ots_len + 4 + h as usize * N];

    sig[..4].copy_from_slice(&q.to_be_bytes());
    ots::sign(
        ots_type,
        i_id,
        seed,
        q,
        c,
        message,
        &mut sig[4..4 + ots_len],
    );

    let lms_type_off = 4 + ots_len;
    sig[lms_type_off..lms_type_off + 4].copy_from_slice(&lms.typecode().to_be_bytes());

    // Leaf node number. It cannot overflow because q < 2^h was asserted above.
    let r = (1u64 << h) as u32 + q;
    // The tail is exactly h * N bytes, so this yields exactly h path slots.
    for (i, slot) in sig[lms_type_off + 4..].chunks_exact_mut(N).enumerate() {
        let sibling = (r >> i) ^ 1;
        slot.copy_from_slice(&node_value(lms, ots_type, i_id, seed, sibling));
    }
    sig
}
pub(crate) fn recover_root(
pubtype: LmsType,
ots_pubtype: LmotsType,
i_id: &[u8; 16],
message: &[u8],
sig: &[u8],
) -> Option<[u8; N]> {
if sig.len() < 8 {
return None;
}
let q = u32::from_be_bytes([sig[0], sig[1], sig[2], sig[3]]);
let otssigtype = u32::from_be_bytes([sig[4], sig[5], sig[6], sig[7]]);
if otssigtype != ots_pubtype.typecode() {
return None;
}
let ots_len = ots_pubtype.sig_len();
let h = pubtype.h();
let expected = 4 + ots_len + 4 + h as usize * N;
if sig.len() != expected {
return None;
}
let ots_sig = &sig[4..4 + ots_len];
let lms_type_off = 4 + ots_len;
let sigtype = u32::from_be_bytes([
sig[lms_type_off],
sig[lms_type_off + 1],
sig[lms_type_off + 2],
sig[lms_type_off + 3],
]);
if sigtype != pubtype.typecode() {
return None;
}
if q as u64 >= pubtype.leaves() {
return None;
}
let kc = ots::recover_public_key(ots_pubtype, i_id, q, message, ots_sig)?;
let mut node_num = (1u64 << h) as u32 + q;
let mut tmp = {
let mut hh = Sha256::new();
hh.update(i_id);
hh.update(&node_num.to_be_bytes());
hh.update(&D_LEAF.to_be_bytes());
hh.update(&kc);
hh.finalize()
};
let path_base = lms_type_off + 4;
let mut i = 0usize;
while node_num > 1 {
let off = path_base + i * N;
let path_node = &sig[off..off + N];
let parent = node_num / 2;
if node_num & 1 == 1 {
let mut hh = Sha256::new();
hh.update(i_id);
hh.update(&parent.to_be_bytes());
hh.update(&D_INTR.to_be_bytes());
hh.update(path_node);
hh.update(&tmp);
tmp = hh.finalize();
} else {
let mut hh = Sha256::new();
hh.update(i_id);
hh.update(&parent.to_be_bytes());
hh.update(&D_INTR.to_be_bytes());
hh.update(&tmp);
hh.update(path_node);
tmp = hh.finalize();
}
node_num = parent;
i += 1;
}
Some(tmp)
}
/// Verifies `sig` over `message` against an encoded LMS public key.
///
/// The public key must be exactly `24 + N` bytes:
/// `be(lms typecode) || be(OTS typecode) || i_id (16 bytes) || root (N bytes)`.
/// The function returns `false` in any of these cases:
/// - the key is malformed or has an unknown typecode;
/// - root recovery fails;
/// - the recovered root differs from the key's root.
///
/// The roots are compared via `ConstantTimeEq`, which is presumably constant-time.
pub(crate) fn verify(public_key: &[u8], message: &[u8], sig: &[u8]) -> bool {
    if public_key.len() < 8 {
        return false;
    }
    let word = |at: usize| {
        u32::from_be_bytes([
            public_key[at],
            public_key[at + 1],
            public_key[at + 2],
            public_key[at + 3],
        ])
    };
    let (pubtype, ots_pubtype) = match (LmsType::from_u32(word(0)), LmotsType::from_u32(word(4))) {
        (Some(lms), Some(ots)) => (lms, ots),
        _ => return false,
    };
    if public_key.len() != 24 + N {
        return false;
    }

    let mut i_id = [0u8; 16];
    i_id.copy_from_slice(&public_key[8..24]);
    let expected_root = &public_key[24..24 + N];

    recover_root(pubtype, ots_pubtype, &i_id, message, sig)
        .map_or(false, |candidate| bool::from(candidate[..].ct_eq(expected_root)))
}