use crate::ots::util::coefs;
use crate::types::Typecode;
use digest::{Digest, Output};
use generic_array::{ArrayLength, GenericArray};
use sha2::Sha256;
use static_assertions::const_assert_eq;
use std::marker::PhantomData;
use typenum::consts::{U133, U265, U34, U67};
use typenum::Unsigned;
/// Parameter set of an LM-OTS one-time signature mode (RFC 8554, Section 4.1).
///
/// The associated constants mirror the RFC's parameter names `n`, `w`, `u`,
/// `v`, `p` and `ls`. The type-level `NLen`/`PLen` carry `N`/`P` so that
/// fixed-size arrays can be allocated for them.
pub trait LmsOtsMode: Typecode {
    /// Hash function of the mode; [`expand`](Self::expand) consumes its output.
    type Hasher: Digest;
    /// Type-level `N` (hash output length in bytes).
    type NLen: ArrayLength<u8>;
    /// Type-level `P`, usable both as a byte-array length and as the length of
    /// an array of `P` hash outputs.
    type PLen: ArrayLength<Output<Self::Hasher>> + ArrayLength<u8>;
    /// Hash output length in bytes (`n`).
    const N: usize;
    /// Winternitz parameter: number of bits per coefficient (`w`).
    const W: usize;
    /// Number of `W`-bit coefficients taken from the message hash (`u`).
    const U: usize;
    /// Number of `W`-bit coefficients taken from the checksum (`v`).
    const V: usize;
    /// Total number of coefficients produced by `expand`, expected to be `U + V` (`p`).
    const P: usize;
    /// Left shift applied to the 16-bit checksum before it is split into
    /// coefficients, so its significant bits come first (`ls`).
    const LS: usize;
    /// Length in bytes of an encoded signature.
    const SIG_LEN: usize;

    /// Expands a message hash `Q` into the `P` coefficients of `Q || Cksm(Q)`.
    ///
    /// The first `U` entries are the `W`-bit coefficients of `message`. The
    /// remaining `V` entries are the leading `W`-bit coefficients of the
    /// big-endian 16-bit checksum `sum(2^W - 1 - coef_i) << LS`.
    ///
    /// # Panics
    ///
    /// Panics if `PLen` is shorter than `U + V`. In debug builds, also panics
    /// if `PLen` differs from `P`.
    fn expand(message: &Output<Self::Hasher>) -> GenericArray<u8, Self::PLen> {
        let mut arr: GenericArray<u8, Self::PLen> = GenericArray::default();
        // A PLen/P mismatch would either panic on indexing below or leave
        // trailing zero coefficients; report it explicitly instead.
        debug_assert_eq!(
            arr.len(),
            Self::P,
            "PLen must equal P (= U + V) for this LM-OTS mode"
        );

        for (i, c) in coefs(message, Self::W).enumerate().take(Self::U) {
            arr[i] = c;
        }

        // Largest value a single W-bit coefficient can take: 2^W - 1.
        let max_coef = (1u16 << Self::W) - 1;
        // The shift by LS aligns the checksum's significant bits with the
        // first V coefficients of its big-endian encoding.
        let cksum: u16 = arr
            .iter()
            .take(Self::U)
            .map(|&c| max_coef - u16::from(c))
            .sum::<u16>()
            << Self::LS;

        let cksum_bytes = cksum.to_be_bytes();
        for (i, c) in coefs(&cksum_bytes, Self::W).take(Self::V).enumerate() {
            arr[Self::U + i] = c;
        }
        arr
    }
}
/// Generic LM-OTS parameter set that carries no data, only type parameters.
///
/// * `Hasher` — hash function; its output size determines `N`.
/// * `W` — Winternitz width in bits per coefficient.
/// * `PP` — type-level number of coefficients; must equal the `P` derived
///   from `Hasher` and `W`.
/// * `TC` — numeric typecode exposed through `Typecode::TYPECODE`.
#[derive(Debug)]
pub struct LmsOtsModeInternal<Hasher, const W: usize, PP, const TC: u32>
where
    Hasher: Digest,
    PP: ArrayLength<GenericArray<u8, Hasher::OutputSize>> + ArrayLength<u8>,
{
    _phantomdata: PhantomData<(Hasher, PP)>,
}
/// The typecode of a generic mode is simply its `TC` const parameter.
impl<Hasher, const W: usize, PP, const TC: u32> Typecode
    for LmsOtsModeInternal<Hasher, W, PP, TC>
where
    Hasher: Digest,
    PP: ArrayLength<GenericArray<u8, Hasher::OutputSize>> + ArrayLength<u8>,
{
    const TYPECODE: u32 = TC;
}
/// Derives every LM-OTS parameter from the hash output size and `W`,
/// following RFC 8554 Appendix B. `PP` is expected to equal the derived `P`.
impl<Hasher, const W: usize, PP, const TC: u32> LmsOtsMode
    for LmsOtsModeInternal<Hasher, W, PP, TC>
where
    Hasher: Digest,
    PP: ArrayLength<GenericArray<u8, Hasher::OutputSize>> + ArrayLength<u8>,
{
    type Hasher = Hasher;
    type NLen = Hasher::OutputSize;
    type PLen = PP;

    const N: usize = Hasher::OutputSize::USIZE;
    const W: usize = W;
    // u = ceil(8n / w): W-bit digits needed to cover the N-byte hash.
    const U: usize = (8 * Self::N + W - 1) / W;
    // v = ceil((floor(lg((2^w - 1) * u)) + 1) / w): W-bit digits needed to hold
    // the largest possible checksum.
    const V: usize = {
        let max_checksum = ((1usize << W) - 1) * Self::U;
        let checksum_bits = max_checksum.ilog2() as usize + 1;
        (checksum_bits + W - 1) / W
    };
    const P: usize = Self::U + Self::V;
    // The checksum lives in 16 bits; shift it so the V digits sit at the top.
    const LS: usize = 16 - W * Self::V;
    const SIG_LEN: usize = 4 + (Self::P + 1) * Self::N;
}
/// `LMOTS_SHA256_N32_W1`: SHA-256, n = 32, w = 1, p = 265, typecode 1.
pub type LmsOtsSha256N32W1 = LmsOtsModeInternal<Sha256, 1, U265, 1>;
/// `LMOTS_SHA256_N32_W2`: SHA-256, n = 32, w = 2, p = 133, typecode 2.
pub type LmsOtsSha256N32W2 = LmsOtsModeInternal<Sha256, 2, U133, 2>;
/// `LMOTS_SHA256_N32_W4`: SHA-256, n = 32, w = 4, p = 67, typecode 3.
pub type LmsOtsSha256N32W4 = LmsOtsModeInternal<Sha256, 4, U67, 3>;
/// `LMOTS_SHA256_N32_W8`: SHA-256, n = 32, w = 8, p = 34, typecode 4.
pub type LmsOtsSha256N32W8 = LmsOtsModeInternal<Sha256, 8, U34, 4>;
const_assert_eq!(
<LmsOtsSha256N32W1 as LmsOtsMode>::NLen::USIZE,
LmsOtsSha256N32W1::N
);
const_assert_eq!(
<LmsOtsSha256N32W1 as LmsOtsMode>::PLen::USIZE,
LmsOtsSha256N32W1::P
);
const_assert_eq!(LmsOtsSha256N32W1::N, 32);
const_assert_eq!(LmsOtsSha256N32W1::P, 265);
const_assert_eq!(LmsOtsSha256N32W1::LS, 7);
const_assert_eq!(LmsOtsSha256N32W1::SIG_LEN, 8516);
const_assert_eq!(
<LmsOtsSha256N32W2 as LmsOtsMode>::NLen::USIZE,
LmsOtsSha256N32W2::N
);
const_assert_eq!(
<LmsOtsSha256N32W2 as LmsOtsMode>::PLen::USIZE,
LmsOtsSha256N32W2::P
);
const_assert_eq!(LmsOtsSha256N32W2::N, 32);
const_assert_eq!(LmsOtsSha256N32W2::P, 133);
const_assert_eq!(LmsOtsSha256N32W2::LS, 6);
const_assert_eq!(LmsOtsSha256N32W2::SIG_LEN, 4292);
const_assert_eq!(
<LmsOtsSha256N32W4 as LmsOtsMode>::NLen::USIZE,
LmsOtsSha256N32W4::N
);
const_assert_eq!(
<LmsOtsSha256N32W4 as LmsOtsMode>::PLen::USIZE,
LmsOtsSha256N32W4::P
);
const_assert_eq!(LmsOtsSha256N32W4::N, 32);
const_assert_eq!(LmsOtsSha256N32W4::P, 67);
const_assert_eq!(LmsOtsSha256N32W4::LS, 4);
const_assert_eq!(LmsOtsSha256N32W4::SIG_LEN, 2180);
const_assert_eq!(
<LmsOtsSha256N32W8 as LmsOtsMode>::NLen::USIZE,
LmsOtsSha256N32W8::N
);
const_assert_eq!(
<LmsOtsSha256N32W8 as LmsOtsMode>::PLen::USIZE,
LmsOtsSha256N32W8::P
);
const_assert_eq!(LmsOtsSha256N32W8::N, 32);
const_assert_eq!(LmsOtsSha256N32W8::P, 34);
const_assert_eq!(LmsOtsSha256N32W8::LS, 0);
const_assert_eq!(LmsOtsSha256N32W8::SIG_LEN, 1124);
#[cfg(test)]
mod test {
    use super::{
        LmsOtsMode, LmsOtsSha256N32W1, LmsOtsSha256N32W2, LmsOtsSha256N32W4, LmsOtsSha256N32W8,
    };
    use generic_array::GenericArray;

    // All-zero hash: 256 coefficients each contribute 1, so the sum is 0x100.
    // Shifted by LS = 7 this gives 0x8000.
    #[test]
    fn test_checksum_zero_w1() {
        let arr = [0u8; LmsOtsSha256N32W1::N];
        let cksm = LmsOtsSha256N32W1::expand(GenericArray::from_slice(&arr));
        assert_eq!(
            &cksm[LmsOtsSha256N32W1::U..],
            &[1, 0, 0, 0, 0, 0, 0, 0, 0]
        );
    }

    // All-ones hash: every coefficient is already maximal, so the checksum is 0.
    #[test]
    fn test_checksum_ones_w1() {
        let arr = [255u8; LmsOtsSha256N32W1::N];
        let cksm = LmsOtsSha256N32W1::expand(GenericArray::from_slice(&arr));
        assert_eq!(
            &cksm[LmsOtsSha256N32W1::U..],
            &[0, 0, 0, 0, 0, 0, 0, 0, 0]
        );
    }

    // All-zero hash: 128 coefficients each contribute 3, so the sum is 0x180.
    // Shifted by LS = 6 this gives 0x6000, read as 2-bit digits 01 10 00 00 00.
    #[test]
    fn test_checksum_zero_w2() {
        let arr = [0u8; LmsOtsSha256N32W2::N];
        let cksm = LmsOtsSha256N32W2::expand(GenericArray::from_slice(&arr));
        assert_eq!(&cksm[LmsOtsSha256N32W2::U..], &[1, 2, 0, 0, 0]);
    }

    #[test]
    fn test_checksum_ones_w2() {
        let arr = [0xffu8; LmsOtsSha256N32W2::N];
        let cksm = LmsOtsSha256N32W2::expand(GenericArray::from_slice(&arr));
        assert_eq!(&cksm[LmsOtsSha256N32W2::U..], &[0, 0, 0, 0, 0]);
    }

    // 0xaa bytes yield nibbles 0xa, each contributing 5: the sum is 0x140.
    // Shifted by LS = 4 this gives 0x1400.
    #[test]
    fn test_checksum_ten_w4() {
        let arr = [0xaa; LmsOtsSha256N32W4::N];
        let cksm = LmsOtsSha256N32W4::expand(GenericArray::from_slice(&arr));
        assert_eq!(&cksm[LmsOtsSha256N32W4::U..], &[0x01, 0x04, 0x00]);
    }

    // The message part of the expansion is the hash split into W-bit digits.
    #[test]
    fn test_message_coefs_w4() {
        let arr = [0xaa; LmsOtsSha256N32W4::N];
        let expanded = LmsOtsSha256N32W4::expand(GenericArray::from_slice(&arr));
        assert!(expanded[..LmsOtsSha256N32W4::U]
            .iter()
            .all(|&c| c == 0x0a));
    }

    // All-zero hash: 32 * 255 = 0x1fe0 with LS = 0, giving two byte-sized
    // checksum coefficients.
    #[test]
    fn test_expand_zero_w8() {
        let arr = [0u8; LmsOtsSha256N32W8::N];
        let expanded = LmsOtsSha256N32W8::expand(GenericArray::from_slice(&arr));
        let mut expected = [0u8; LmsOtsSha256N32W8::P];
        expected[LmsOtsSha256N32W8::U] = 0x1f;
        expected[LmsOtsSha256N32W8::U + 1] = 0xe0;
        assert_eq!(expanded.as_slice(), &expected[..]);
    }
}