#![cfg_attr(all(not(test), not(feature = "std")), no_std)]
use core::{fmt::Debug, mem};
#[cfg(feature = "std")]
use std::io;
#[cfg(feature = "sign")]
mod sign;
#[cfg(feature = "verify")]
pub mod verify;
/// Length in bytes of a SHA-256 digest (256 bits).
///
/// Used as the size of each authentication-tree node value and as the unit
/// size of the one-time private keys.
pub const SHA256_LEN: usize = 256 / 8;
/// Winternitz parameter `w`: the number of bits in each digit (coefficient)
/// extracted from a byte string by `coef`.
///
/// The enum discriminants are an ordinal encoding (`Reserved = 0`, `W1 = 1`,
/// `W2 = 2`, `W4 = 3`, `W8 = 4`) and are NOT equal to `w`. Use the `From`
/// conversions to obtain the numeric value of `w`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Winternitz {
    /// Converts to/from `w == 0`. This is not a usable parameter: the digit
    /// arithmetic divides by `w`.
    Reserved = 0x0,
    /// `w == 1`
    W1,
    /// `w == 2`
    W2,
    /// `w == 4`
    W4,
    /// `w == 8`
    W8,
}

impl From<Winternitz> for u32 {
    /// Returns the numeric value of `w`: 0, 1, 2, 4 or 8.
    fn from(value: Winternitz) -> Self {
        match value {
            Winternitz::Reserved => 0,
            Winternitz::W1 => 1,
            Winternitz::W2 => 2,
            Winternitz::W4 => 4,
            Winternitz::W8 => 8,
        }
    }
}

impl From<u32> for Winternitz {
    /// Maps a numeric `w` (0, 1, 2, 4 or 8) to its variant.
    ///
    /// # Panics
    /// Panics with an "invalid Winternitz parameter" message for any other
    /// value.
    fn from(value: u32) -> Self {
        match value {
            0 => Winternitz::Reserved,
            1 => Winternitz::W1,
            2 => Winternitz::W2,
            4 => Winternitz::W4,
            8 => Winternitz::W8,
            _ => panic!("invalid Winternitz parameter: {value}"),
        }
    }
}

impl From<Winternitz> for usize {
    /// Returns the numeric value of `w`: 0, 1, 2, 4 or 8.
    fn from(value: Winternitz) -> Self {
        // Lossless: every value is at most 8. Delegating keeps one table.
        u32::from(value) as usize
    }
}

impl From<u8> for Winternitz {
    /// Maps a numeric `w` (0, 1, 2, 4 or 8) to its variant.
    ///
    /// # Panics
    /// Panics with an "invalid Winternitz parameter" message for any other
    /// value.
    fn from(value: u8) -> Self {
        Winternitz::from(u32::from(value))
    }
}
/// Parameters of an LDWM key.
///
/// With the `sign` feature these are serialized as four big-endian `u32`s in
/// the order `h`, `k`, `m`, `w`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LdwmParams {
    /// Winternitz parameter: number of bits per digit extracted by `coef`.
    pub w: Winternitz,
    /// NOTE(review): not read by any code visible in this file; presumably the
    /// hash output length `m` from the LDWM draft — confirm.
    pub m: usize,
    /// Arity of the authentication tree: every interior node has `k` children.
    pub k: usize,
    /// Height of the authentication tree; a private key holds `k.pow(h)`
    /// one-time keys.
    pub h: usize,
}
impl LdwmParams {
    /// Writes the parameters to `fd` as four big-endian `u32`s in the order
    /// `h`, `k`, `m`, `w`, with `w` written as its numeric value (0/1/2/4/8).
    ///
    /// # Errors
    /// Propagates any error from `fd`. Returns `InvalidInput`, before writing
    /// anything, if `h`, `k` or `m` does not fit in a `u32`. Truncating it
    /// silently would produce a file that decodes to different parameters.
    #[cfg(feature = "sign")]
    fn serialize<F>(&self, fd: &mut F) -> io::Result<()>
    where
        F: io::Write,
    {
        let fields = [self.h, self.k, self.m];
        if fields.iter().any(|&field| field > u32::MAX as usize) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "LDWM parameter does not fit in a u32",
            ));
        }
        for field in fields {
            // Checked above, so the cast is lossless.
            fd.write_all(&(field as u32).to_be_bytes())?;
        }
        fd.write_all(&u32::from(self.w).to_be_bytes())
    }

    /// Reads parameters in the layout written by `serialize`.
    ///
    /// # Errors
    /// Propagates read errors (`UnexpectedEof` on truncated input). Returns
    /// `InvalidData` if the stored `w` is not 0, 1, 2, 4 or 8, instead of
    /// panicking on malformed input.
    #[cfg(feature = "sign")]
    fn deserialize<F>(fd: &mut F) -> io::Result<Self>
    where
        F: io::Read,
    {
        /// Reads one big-endian `u32`.
        fn read_u32<R: io::Read>(fd: &mut R) -> io::Result<u32> {
            let mut buf = [0u8; 4];
            fd.read_exact(&mut buf)?;
            Ok(u32::from_be_bytes(buf))
        }

        let h = read_u32(fd)? as usize;
        let k = read_u32(fd)? as usize;
        let m = read_u32(fd)? as usize;
        // Validate before converting: `Winternitz::from(u32)` panics on
        // unknown values, and this data comes from outside the program.
        let w = match read_u32(fd)? {
            raw @ (0 | 1 | 2 | 4 | 8) => Winternitz::from(raw),
            raw => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("invalid Winternitz parameter: {raw}"),
                ))
            }
        };
        Ok(Self { h, k, m, w })
    }
}
/// A borrowed LDWM signature whose byte fields point into caller-owned buffers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Signature<'a> {
    /// Authentication-path bytes.
    /// NOTE(review): presumably the concatenated sibling node values from the
    /// signing leaf up to the root — confirm against the `sign` module.
    pub auth_path: &'a [u8],
    /// NOTE(review): presumably the index of the leaf / one-time key that
    /// produced `ots` — confirm.
    pub node_num: usize,
    /// One-time (Winternitz) signature bytes.
    pub ots: &'a [u8],
}
#[cfg(feature = "sign")]
/// Owned counterpart of [`Signature`], available with the `sign` feature.
///
/// The fields mirror those of `Signature`; call `as_borrowed` to get a
/// `Signature` view without copying.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GeneratedSignature {
    /// Authentication-path bytes (see `Signature::auth_path`).
    pub auth_path: Vec<u8>,
    /// See `Signature::node_num`.
    pub node_num: usize,
    /// One-time signature bytes (see `Signature::ots`).
    pub ots: Vec<u8>,
}
#[cfg(feature = "sign")]
impl GeneratedSignature {
    /// Returns a [`Signature`] that borrows this signature's buffers; no
    /// bytes are copied.
    pub fn as_borrowed(&self) -> Signature<'_> {
        let Self {
            auth_path,
            node_num,
            ots,
        } = self;
        Signature {
            auth_path: auth_path.as_slice(),
            node_num: *node_num,
            ots: ots.as_slice(),
        }
    }
}
#[cfg(feature = "sign")]
/// A node of the authentication tree stored in an `LdwmPrivateKey`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct AuthTreeNode {
    /// The node's hash value. `deserialize` always reads `SHA256_LEN` bytes.
    value: Vec<u8>,
    /// Child subtrees: `k` of them for interior nodes, none for leaves (as
    /// read back by `deserialize`).
    children: Vec<AuthTreeNode>,
}
#[cfg(feature = "sign")]
impl AuthTreeNode {
    /// Writes this subtree to `fd` in pre-order: the node's value, then each
    /// child subtree in order.
    ///
    /// No lengths or shape information are written. The reader must already
    /// know the tree's height and arity, and each value's length.
    fn serialize<F>(&self, fd: &mut F) -> io::Result<()>
    where
        F: io::Write,
    {
        fd.write_all(&self.value)?;
        for child in &self.children {
            child.serialize(fd)?;
        }
        Ok(())
    }

    /// Reads a complete `k`-ary subtree of height `lvl` written by
    /// `serialize`. Every node value is read as `SHA256_LEN` bytes.
    ///
    /// NOTE(review): recursion depth equals `lvl`, which callers take from the
    /// input; a corrupt, very large height can exhaust the stack.
    ///
    /// # Errors
    /// Propagates read errors (`UnexpectedEof` on truncated input).
    fn deserialize<F>(lvl: usize, k: usize, fd: &mut F) -> io::Result<Self>
    where
        F: io::Read,
    {
        let mut value = vec![0u8; SHA256_LEN];
        fd.read_exact(&mut value)?;
        // Leaves have no children, so only interior nodes reserve room for `k`
        // of them. `k` comes from the input, so reserving it for every leaf
        // would waste memory.
        let children = match lvl.checked_sub(1) {
            None => Vec::new(),
            Some(child_lvl) => {
                let mut children = Vec::with_capacity(k);
                for _ in 0..k {
                    children.push(Self::deserialize(child_lvl, k, fd)?);
                }
                children
            }
        };
        Ok(Self { value, children })
    }
}
#[cfg(feature = "sign")]
/// An LDWM private key: its parameters, the one-time private keys and the
/// authentication tree.
///
/// The hand-written `Debug` impl prints only the number of one-time keys.
pub struct LdwmPrivateKey {
    /// Parameters of this key.
    params: LdwmParams,
    /// The one-time private keys. `deserialize` reads `k.pow(h)` of them,
    /// each `SHA256_LEN * p` bytes long.
    ots_keys: Vec<Vec<u8>>,
    /// NOTE(review): presumably the index of the next unused one-time key —
    /// confirm against the `sign` module.
    node_num: usize,
    /// Root of the authentication tree (height `h`, arity `k`).
    tree: AuthTreeNode,
}
#[cfg(feature = "sign")]
impl LdwmPrivateKey {
    /// Writes the private key to `fd`, consuming it.
    ///
    /// Layout:
    /// - `node_num` as a big-endian `u32`;
    /// - the parameters (`LdwmParams::serialize`);
    /// - every one-time key, back to back;
    /// - the authentication tree in pre-order.
    ///
    /// # Errors
    /// Propagates errors from `fd` and from parameter serialization. Returns
    /// `InvalidInput` if `node_num` does not fit in a `u32`, rather than
    /// truncating it silently.
    pub fn serialize<F>(self, fd: &mut F) -> io::Result<()>
    where
        F: io::Write,
    {
        if self.node_num > u32::MAX as usize {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "node_num does not fit in a u32",
            ));
        }
        fd.write_all(&(self.node_num as u32).to_be_bytes())?;
        self.params.serialize(fd)?;
        // Write each key directly. Flattening them first would make a second,
        // full-size copy of the secret key material.
        for key in &self.ots_keys {
            fd.write_all(key)?;
        }
        self.tree.serialize(fd)
    }

    /// Reads a private key in the layout written by `serialize`.
    ///
    /// # Errors
    /// Propagates read errors (`UnexpectedEof` on truncated input) and errors
    /// from `LdwmParams::deserialize`. Returns `InvalidData`, instead of
    /// panicking, if `w` is `Reserved` or if `k.pow(h)` overflows `usize`.
    pub fn deserialize<F>(fd: &mut F) -> io::Result<Self>
    where
        F: io::Read,
    {
        // `num_keys` comes from the input. Cap the up-front reservation so a
        // corrupt header cannot request an enormous allocation before any key
        // bytes have been read. The vector still grows as needed.
        const MAX_PREALLOC_KEYS: usize = 1 << 16;

        fn invalid_data(msg: &'static str) -> io::Error {
            io::Error::new(io::ErrorKind::InvalidData, msg)
        }

        let mut buf = 0u32.to_be_bytes();
        fd.read_exact(&mut buf)?;
        let node_num = u32::from_be_bytes(buf) as usize;
        let params = LdwmParams::deserialize(fd)?;
        // `w == 0` would divide by zero below.
        if params.w == Winternitz::Reserved {
            return Err(invalid_data("reserved Winternitz parameter"));
        }
        let w: u32 = params.w.into();
        // Number of w-bit digits in a SHA256_LEN-byte hash.
        let u = 8 * (SHA256_LEN as u32) / w;
        // NOTE(review): RFC 8554 defines the checksum digit count as
        // v = ceil((floor(lg((2^w - 1) * u)) + 1) / w), e.g. 3 for w = 4.
        // This formula yields 2 there. It must match the key generator in
        // `sign`, so it is left unchanged here — confirm.
        let v = u.ilog2().div_ceil(w);
        let p = (u + v) as usize;
        // `h` was decoded from a u32, so the cast back is lossless.
        let num_keys = params
            .k
            .checked_pow(params.h as u32)
            .ok_or_else(|| invalid_data("number of one-time keys overflows usize"))?;
        let mut ots_keys = Vec::with_capacity(num_keys.min(MAX_PREALLOC_KEYS));
        for _ in 0..num_keys {
            let mut ots_key = vec![0u8; SHA256_LEN * p];
            fd.read_exact(&mut ots_key)?;
            ots_keys.push(ots_key);
        }
        let tree = AuthTreeNode::deserialize(params.h, params.k, fd)?;
        Ok(Self {
            node_num,
            params,
            ots_keys,
            tree,
        })
    }
}
#[cfg(feature = "sign")]
impl Debug for LdwmPrivateKey {
    /// Prints only how many one-time keys the key holds, so no secret key
    /// bytes appear in debug output.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let key_count = self.ots_keys.len();
        f.write_fmt(format_args!(
            "LDWM Private Key {{ num_keys: {} }}",
            key_count
        ))
    }
}
/// Returns the `i`-th `w`-bit digit (coefficient) of the byte string `s`
/// followed by the two checksum bytes `ck`. Digits are taken most-significant
/// bits first within each byte.
///
/// Digit indices whose byte offset is `>= N` are read from `ck`, so `i` can
/// walk through the digits of `s` and then the digits of the checksum.
///
/// # Panics
/// Panics if `w` is `Winternitz::Reserved` (division by zero), or if `i`
/// addresses a byte beyond the end of `ck`.
fn coef<const N: usize>(i: usize, w: Winternitz, s: &[u8; N], ck: &[u8; 2]) -> u8 {
    let w: usize = w.into();
    let max: usize = Winternitz::W8.into();
    // 2^w - 1, computed in u16 because `1u8 << 8` overflows for W8. That is a
    // panic in debug builds, and a mask of 0 (every digit read as 0) in
    // release builds. The result is at most 0xFF, so the cast is lossless.
    let mask = ((1u16 << w) - 1) as u8;
    // Digit `i` is the `(i % (8 / w))`-th w-bit group, counted from the most
    // significant end, of byte `i * w / 8`.
    let shift = max - (w * (i % (max / w)) + w);
    let idx = (i * w) / max;
    let byte = if idx >= N { ck[idx - N] } else { s[idx] };
    mask & (byte >> shift)
}
/// Computes the Winternitz checksum of `s`: the sum, over all `w`-bit digits
/// of `s`, of `(2^w - 1) - digit`. The sum is left-shifted as described below
/// and returned as two big-endian bytes.
///
/// NOTE(review): RFC 8554's `Cksm` shifts the sum left by a fixed
/// `ls = 16 - v * w`. The shift here depends on the value: it moves the most
/// significant non-zero nibble to the top. As a result the encoded digits are
/// not monotone in the sum. For example, 0x100 encodes with leading nibbles
/// 1,0 while the smaller 0x0FF encodes as F,F. This weakens the checksum's
/// protection against forgery. Fixing it changes the signature format, and
/// must be coordinated with the checksum digit count `v` used by the signer,
/// the verifier and `LdwmPrivateKey::deserialize` — confirm before changing.
///
/// # Panics
/// Panics if `w` is `Winternitz::Reserved` (division by zero). In debug
/// builds it also panics if the sum overflows `u16`, which cannot happen for
/// `N <= 257`.
fn checksum<const N: usize>(w: Winternitz, s: &[u8; N]) -> [u8; mem::size_of::<u16>()] {
    let bits = usize::from(w);
    let digits = N * 8 / bits;
    let max_digit: u16 = (1 << bits) - 1;
    let ck: u16 = (0..digits)
        .map(|i| max_digit - u16::from(coef(i, w, s, &[0u8; 2])))
        .sum();
    let shift = ck.leading_zeros() - (ck.leading_zeros() % 4);
    // A zero sum has 16 leading zeros, and `0u16 << 16` is an overflowing
    // shift that panics in debug builds. The shifted value is simply 0.
    ck.checked_shl(shift).unwrap_or(0).to_be_bytes()
}
// `LdwmPrivateKey` exists only with the `sign` feature, so the module must be
// gated on it as well; otherwise `cargo test` without `sign` fails to compile.
#[cfg(all(test, feature = "sign"))]
mod tests {
    use crate::{LdwmParams, LdwmPrivateKey, Winternitz};

    /// Round-trips a freshly generated key through `serialize`/`deserialize`.
    /// Checks that every field survives unchanged and that the encoding is
    /// consumed exactly.
    #[test]
    fn test_serialization_roundtrip() {
        let p1 = LdwmParams {
            h: 2,
            k: 4,
            m: 20,
            w: Winternitz::W4,
        };
        let k1 = LdwmPrivateKey::new(&p1);
        // `serialize` consumes the key, so snapshot its fields first.
        let expected_params = k1.params;
        let expected_node_num = k1.node_num;
        let expected_ots_keys = k1.ots_keys.clone();
        let expected_tree = k1.tree.clone();

        let mut buffer = Vec::new();
        k1.serialize(&mut buffer).unwrap();
        let mut reader = buffer.as_slice();
        let k2 = LdwmPrivateKey::deserialize(&mut reader).unwrap();

        assert!(
            reader.is_empty(),
            "deserialize left {} trailing bytes",
            reader.len()
        );
        assert_eq!(k2.params, expected_params);
        assert_eq!(k2.node_num, expected_node_num);
        assert_eq!(k2.ots_keys, expected_ots_keys);
        assert_eq!(k2.tree, expected_tree);
    }
}