use core::mem::MaybeUninit;
use crate::WinternitzError;
/// A 32-byte Merkle root, as returned by `WinternitzPubkey::merklize`.
///
/// It can also be built directly from raw bytes via `new` / `From<[u8; 32]>`.
/// `#[repr(transparent)]` guarantees the same layout as a bare `[u8; 32]`.
#[repr(transparent)]
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub struct WinternitzRoot([u8; 32]);
/// A Winternitz public key: `N` 32-byte scalar values followed by two 32-byte checksum values.
///
/// `#[repr(C)]` with only `u8`-array fields gives the struct alignment 1, no padding, and a size
/// of exactly `(N + 2) * 32` bytes laid out in field order. `as_bytes` and the
/// `TryFrom<&[u8]>` impl reinterpret memory based on this layout, so the field order and types
/// must not change.
#[repr(C)]
pub struct WinternitzPubkey<const N: usize> {
    /// The `N` scalar values, stored first.
    scalars: [[u8; 32]; N],
    /// The two checksum values, stored immediately after the scalars.
    checksum: [[u8; 32]; 2],
}
impl<'a, const N: usize> TryFrom<&'a [u8]> for &'a WinternitzPubkey<N> {
    type Error = WinternitzError;

    /// Borrows `bytes` in place as a public key, without copying.
    ///
    /// # Errors
    /// Returns [`WinternitzError::InvalidLength`] unless `bytes` is exactly `(N + 2) * 32`
    /// bytes long.
    fn try_from(bytes: &'a [u8]) -> Result<Self, Self::Error> {
        const { crate::assert_n::<N>() };
        let expected_len = (N + 2) * 32;
        if bytes.len() == expected_len {
            let key_ptr = bytes.as_ptr().cast::<WinternitzPubkey<N>>();
            // SAFETY: `WinternitzPubkey<N>` is `#[repr(C)]` over `u8` arrays only, so its
            // alignment is 1 and its size is `(N + 2) * 32`, which was just checked. Every bit
            // pattern is a valid `u8`, and the returned reference inherits the lifetime `'a`
            // of `bytes`.
            Ok(unsafe { &*key_ptr })
        } else {
            Err(WinternitzError::InvalidLength)
        }
    }
}
impl WinternitzRoot {
pub const fn new(bytes: [u8; 32]) -> Self {
Self(bytes)
}
pub fn as_bytes(&self) -> &[u8; 32] {
&self.0
}
}
impl From<[u8; 32]> for WinternitzRoot {
fn from(bytes: [u8; 32]) -> Self {
Self(bytes)
}
}
impl core::fmt::Display for WinternitzRoot {
    /// Renders the root as `0x` followed by 64 lowercase, zero-padded hex digits.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("0x")?;
        self.0
            .iter()
            .try_for_each(|byte| write!(f, "{:02x}", byte))
    }
}
impl<const N: usize> WinternitzPubkey<N> {
    /// Builds a public key from its `N` scalar values and its two checksum values.
    pub(crate) fn new(scalars: [[u8; 32]; N], checksum: [[u8; 32]; 2]) -> Self {
        const { crate::assert_n::<N>() };
        Self { scalars, checksum }
    }

    /// Views the key as raw bytes: the `N` scalars followed by the two checksum values,
    /// `(N + 2) * 32` bytes in total.
    pub fn as_bytes(&self) -> &[u8] {
        // SAFETY: `WinternitzPubkey` is `#[repr(C)]` and contains only `u8` arrays, so it has
        // alignment 1, no padding, and every byte is initialized. The slice spans exactly
        // `size_of::<Self>()` bytes and borrows `self`, so it cannot outlive the key.
        unsafe {
            core::slice::from_raw_parts(
                self as *const Self as *const u8,
                core::mem::size_of::<Self>(),
            )
        }
    }

    /// Computes the Merkle root over all `N + 2` values (scalars first, then checksums).
    ///
    /// Leaves are hashed as `hashv([0x00, value])` and interior nodes as
    /// `hashv([0x01, left, right])` using `solana_sha256_hasher`. Leaves are folded in a
    /// single pass with a binary-counter stack: two subtrees of equal height are merged as soon
    /// as both exist. Any subtrees left over (when `N + 2` is not a power of two) are combined
    /// right-to-left. Before each combination, the shorter right subtree is raised to its left
    /// neighbour's height by repeatedly hashing it with itself.
    ///
    /// A compile-time assertion guarantees the fixed-size stack is large enough for `N`.
    pub fn merklize(&self) -> WinternitzRoot {
        const { crate::assert_n::<N>() };
        const LEAF_TAG: &[u8] = &[0x00];
        const NODE_TAG: &[u8] = &[0x01];
        // Capacity of the subtree stack. After `k` leaves the stack holds `popcount(k)`
        // subtrees, so this is large enough iff no `k <= N + 2` has more than `MAX_DEPTH` set
        // bits, i.e. iff `N + 2 < 2^(MAX_DEPTH + 1) - 1` (at most 62 leaves for 5 slots).
        const MAX_DEPTH: usize = 5;
        const {
            assert!(
                N + 2 < (1usize << (MAX_DEPTH + 1)) - 1,
                "WinternitzPubkey::merklize: too many leaves for the fixed-size stack"
            )
        };
        // Pending subtree roots; invariant: `stack[..len]` is initialized.
        let mut stack: [MaybeUninit<[u8; 32]>; MAX_DEPTH] =
            [const { MaybeUninit::uninit() }; MAX_DEPTH];
        // `levels[i]` is the height of the subtree whose root is in `stack[i]`.
        let mut levels = [0u8; MAX_DEPTH];
        let mut len: usize = 0;
        for leaf in self.scalars.iter().chain(self.checksum.iter()) {
            let mut h: [u8; 32] = solana_sha256_hasher::hashv(&[LEAF_TAG, leaf]).to_bytes();
            let mut level: u8 = 0;
            // Merge with equal-height subtrees on top of the stack (like a binary carry).
            while len > 0 && levels[len - 1] == level {
                // SAFETY: `len - 1 < len` and every slot below `len` has been written.
                let top = unsafe { stack[len - 1].assume_init_read() };
                h = solana_sha256_hasher::hashv(&[NODE_TAG, &top, &h]).to_bytes();
                level += 1;
                len -= 1;
            }
            stack[len].write(h);
            levels[len] = level;
            len += 1;
        }
        // Collapse the remaining subtrees into one, working from the top of the stack down.
        while len > 1 {
            // SAFETY: `len - 1 < len`, so this slot is initialized.
            let mut top = unsafe { stack[len - 1].assume_init_read() };
            let mut top_level = levels[len - 1];
            let next_level = levels[len - 2];
            // Raise the shorter right subtree to its left neighbour's height by self-pairing.
            while top_level < next_level {
                top = solana_sha256_hasher::hashv(&[NODE_TAG, &top, &top]).to_bytes();
                top_level += 1;
            }
            // SAFETY: `len - 2 < len`, so this slot is initialized.
            let next = unsafe { stack[len - 2].assume_init_read() };
            let combined = solana_sha256_hasher::hashv(&[NODE_TAG, &next, &top]).to_bytes();
            stack[len - 2].write(combined);
            levels[len - 2] = top_level + 1;
            len -= 1;
        }
        // SAFETY: there are `N + 2 >= 2` leaves, so slot 0 was written, and the collapse loop
        // leaves exactly one subtree (`len == 1`) in slot 0.
        let root = unsafe { stack[0].assume_init_read() };
        WinternitzRoot(root)
    }
}
impl<const N: usize> core::fmt::Display for WinternitzPubkey<N> {
    /// Renders the whole key as one `0x`-prefixed lowercase hex string: every scalar byte,
    /// then every checksum byte, with no separators.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("0x")?;
        self.scalars
            .iter()
            .chain(&self.checksum)
            .flatten()
            .try_for_each(|byte| write!(f, "{:02x}", byte))
    }
}
impl<const N: usize> core::fmt::Debug for WinternitzPubkey<N> {
    /// Multi-line diagnostic dump: one labelled hex line per scalar and per checksum value.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Writes one value's bytes as zero-padded lowercase hex and ends the line.
        fn hex_row(f: &mut core::fmt::Formatter<'_>, bytes: &[u8; 32]) -> core::fmt::Result {
            bytes.iter().try_for_each(|b| write!(f, "{:02x}", b))?;
            writeln!(f)
        }

        writeln!(f, "WinternitzPubkey {{")?;
        for (idx, value) in self.scalars.iter().enumerate() {
            write!(f, " scalars[{}] = 0x", idx)?;
            hex_row(f, value)?;
        }
        for (idx, value) in self.checksum.iter().enumerate() {
            write!(f, " checksum[{}] = 0x", idx)?;
            hex_row(f, value)?;
        }
        write!(f, "}}")
    }
}
impl<const N: usize> From<WinternitzPubkey<N>> for WinternitzRoot {
fn from(pk: WinternitzPubkey<N>) -> Self {
const { crate::assert_n::<N>() };
pk.merklize()
}
}
impl<const N: usize> From<&WinternitzPubkey<N>> for WinternitzRoot {
fn from(pk: &WinternitzPubkey<N>) -> Self {
const { crate::assert_n::<N>() };
pk.merklize()
}
}