#![warn(missing_docs)]
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
use thiserror::Error;
/// Length in bytes of every digest handled by this crate (SHA-256 output size).
pub const HASH_SIZE: usize = 32;
/// Number of domain-separated double-SHA-256 rounds applied to a private key to
/// derive the master seed.
pub const MASTER_SEED_ROUNDS: usize = 2048;
/// Number of domain-separated double-SHA-256 rounds applied to
/// `master_seed || ladder_id` (big-endian) to derive a ladder root.
pub const ROOT_GENERATION_ROUNDS: usize = 2048;
/// Default ladder height, i.e. the number of one-time proofs a ladder can
/// yield, for callers that do not choose their own.
pub const DEFAULT_LADDER_HEIGHT: usize = 100_000;
// Domain-separation tags prepended to the data of each hashing purpose.
// None of them is a prefix of another, so the three purposes cannot collide.
const DOMAIN_MASTER: &[u8] = b"YEDAD_MASTER";
const DOMAIN_ROOT: &[u8] = b"YEDAD_ROOT";
const DOMAIN_STEP: &[u8] = b"YEDAD_STEP";
/// A single SHA-256 digest.
pub type Hash = [u8; HASH_SIZE];
/// Errors returned by the hash-ladder operations.
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum YedadError {
    /// An input buffer had an unexpected length.
    #[error("invalid input size")]
    InvalidInputSize,
    /// The supplied proof does not hash (via `double_hash`) to the currently
    /// stored ladder value.
    #[error("invalid proof")]
    InvalidProof,
    /// All proofs of the ladder have been consumed (`position >= ladder_height`).
    #[error("ladder exhausted")]
    LadderExhausted,
    /// Ladder id 0 was requested; ladder ids start at 1.
    #[error("invalid ladder index (must be > 0)")]
    InvalidLadderIndex,
    /// The requested step has no predecessor on the ladder (the hash is the
    /// ladder root).
    #[error("step out of range")]
    StepOutOfRange,
    /// The hash does not occur among steps `0..=ladder_height` of the ladder.
    #[error("hash not found in ladder")]
    HashNotFound,
}
/// Plain (single-pass) SHA-256 of `data`, copied into a fixed-size [`Hash`].
#[inline]
fn sha256(data: &[u8]) -> Hash {
    let digest = Sha256::digest(data);
    let mut bytes: Hash = [0u8; HASH_SIZE];
    bytes.copy_from_slice(digest.as_slice());
    bytes
}
/// Equality check for two hashes using `subtle::ConstantTimeEq`, so the
/// comparison itself does not short-circuit on the first differing byte.
#[inline]
fn ct_eq(a: &Hash, b: &Hash) -> bool {
    let same = a[..].ct_eq(&b[..]);
    bool::from(same)
}
/// Domain-separated double SHA-256: `SHA256(SHA256(domain || data))`.
///
/// NOTE: the tag is concatenated with `data` without a length prefix, so
/// distinct tags must never be prefixes of one another.
#[inline]
fn double_hash_domain(domain: &[u8], data: &[u8]) -> Hash {
    let inner_digest = {
        let mut inner = Sha256::new();
        inner.update(domain);
        inner.update(data);
        inner.finalize()
    };
    sha256(inner_digest.as_slice())
}
/// The ladder step function: double SHA-256 of `data` under the `YEDAD_STEP`
/// domain tag.
///
/// Every ladder step is `double_hash` of the previous one. A proof `p` is
/// therefore accepted for a stored value `v` exactly when `double_hash(&p) == v`.
#[inline]
#[must_use]
pub fn double_hash(data: &[u8]) -> Hash {
    double_hash_domain(DOMAIN_STEP, data)
}
/// Applies `double_hash_domain(domain, ·)` repeatedly, starting from `initial`.
///
/// The first round hashes `initial` (any length); each following round hashes
/// the previous 32-byte digest. `rounds == 0` behaves like `rounds == 1`: at
/// least one round is always applied.
fn iterative_double_hash(initial: &[u8], rounds: usize, domain: &[u8]) -> Hash {
    let first = double_hash_domain(domain, initial);
    // `1..rounds` is empty for rounds <= 1, leaving just the first round.
    (1..rounds).fold(first, |acc, _| double_hash_domain(domain, &acc))
}
/// Walks a ladder upward from `root` and returns the step index of `target`.
///
/// Index 0 is `root` itself; index `i` is `root` hashed `i` times with
/// [`double_hash`]. Indices `0..=height` (that is, `height + 1` values) are
/// examined, and `None` is returned if `target` is not among them.
///
/// Each comparison is constant-time, but the walk stops at the first match,
/// so the total running time still reflects the position found.
fn find_position(root: Hash, target: &Hash, height: usize) -> Option<usize> {
    let mut current = root;
    for idx in 0..=height {
        if ct_eq(&current, target) {
            return Some(idx);
        }
        current = double_hash(&current);
    }
    None
}
#[derive(Debug, Clone)]
pub struct Yedad {
master_seed: Hash,
}
impl Yedad {
pub fn from_private_key<T: AsRef<[u8]>>(private_key: T) -> Self {
let master_seed =
iterative_double_hash(private_key.as_ref(), MASTER_SEED_ROUNDS, DOMAIN_MASTER);
Self { master_seed }
}
pub fn root(&self, ladder_id: u32) -> Result<Hash, YedadError> {
if ladder_id == 0 {
return Err(YedadError::InvalidLadderIndex);
}
let mut buf = [0u8; HASH_SIZE + 4];
buf[..HASH_SIZE].copy_from_slice(&self.master_seed);
buf[HASH_SIZE..].copy_from_slice(&ladder_id.to_be_bytes());
Ok(iterative_double_hash(
&buf,
ROOT_GENERATION_ROUNDS,
DOMAIN_ROOT,
))
}
pub fn step(&self, ladder_id: u32, step_index: usize) -> Result<Hash, YedadError> {
let mut current = self.root(ladder_id)?;
for _ in 0..step_index {
current = double_hash(¤t);
}
Ok(current)
}
pub fn last_step(&self, ladder_id: u32, ladder_height: usize) -> Result<Hash, YedadError> {
self.step(ladder_id, ladder_height)
}
pub fn get_proof(
&self,
ladder_id: u32,
position: usize,
ladder_height: usize,
) -> Result<Hash, YedadError> {
if position >= ladder_height {
return Err(YedadError::LadderExhausted);
}
let step_index = ladder_height - 1 - position;
self.step(ladder_id, step_index)
}
pub fn find_proof_by_current_hash(
&self,
ladder_id: u32,
current_hash: &Hash,
ladder_height: usize,
) -> Result<Hash, YedadError> {
let root = self.root(ladder_id)?;
if let Some(idx) = find_position(root, current_hash, ladder_height) {
if idx == 0 {
return Err(YedadError::StepOutOfRange);
}
let mut prev = root;
let mut cur = double_hash(&prev);
for _ in 1..idx {
prev = cur;
cur = double_hash(&cur);
}
Ok(prev)
} else {
Err(YedadError::HashNotFound)
}
}
pub fn master_seed(&self) -> &Hash {
&self.master_seed
}
}
/// Verifier-side view of one hash ladder.
///
/// A slot holds the most recently accepted hash, the ladder id, and the number
/// of proofs accepted so far. It holds no secret: a proof is accepted only if
/// its [`double_hash`] equals the stored hash. The ladder height is not
/// stored, so callers must pass the same `ladder_height` to every method.
#[derive(Debug, Clone, Copy)]
pub struct Slot {
    current_hash: Hash,
    ladder_id: u32,
    position: usize,
}

impl Slot {
    /// Creates a slot at the top of ladder `ladder_id` (step `ladder_height`)
    /// with no proofs consumed yet.
    ///
    /// # Errors
    /// Propagates errors from computing the ladder top, e.g.
    /// [`YedadError::InvalidLadderIndex`] when `ladder_id` is 0.
    pub fn new(yedad: &Yedad, ladder_id: u32, ladder_height: usize) -> Result<Self, YedadError> {
        let current_hash = yedad.last_step(ladder_id, ladder_height)?;
        Ok(Self::from_state(current_hash, ladder_id, 0))
    }

    /// Restores a slot from previously saved state, such as the tuple returned
    /// by [`Slot::state`]. No validation is performed.
    pub fn from_state(current_hash: Hash, ladder_id: u32, position: usize) -> Self {
        Self {
            current_hash,
            ladder_id,
            position,
        }
    }

    /// Verifies `proof` against the stored hash. On success, stores `proof`
    /// as the new current hash and increments the position.
    ///
    /// On error the slot is left unchanged.
    ///
    /// # Errors
    /// - [`YedadError::LadderExhausted`] if `position >= ladder_height`.
    /// - [`YedadError::InvalidProof`] if `double_hash(proof)` differs from the
    ///   stored hash.
    pub fn verify_and_advance(
        &mut self,
        proof: &Hash,
        ladder_height: usize,
    ) -> Result<(), YedadError> {
        if self.is_exhausted(ladder_height) {
            return Err(YedadError::LadderExhausted);
        }
        if !ct_eq(&double_hash(proof), &self.current_hash) {
            return Err(YedadError::InvalidProof);
        }
        self.current_hash = *proof;
        // Cannot overflow: position < ladder_height <= usize::MAX here.
        self.position += 1;
        Ok(())
    }

    /// Returns `true` once `ladder_height` proofs have been accepted.
    pub fn is_exhausted(&self, ladder_height: usize) -> bool {
        self.position >= ladder_height
    }

    /// Number of proofs that can still be accepted (0 when exhausted).
    pub fn remaining(&self, ladder_height: usize) -> usize {
        ladder_height.saturating_sub(self.position)
    }

    /// Returns `(current_hash, ladder_id, position)`, suitable for
    /// [`Slot::from_state`].
    pub fn state(&self) -> (Hash, u32, usize) {
        (self.current_hash, self.ladder_id, self.position)
    }

    /// The most recently accepted hash (initially the ladder top).
    pub fn current_hash(&self) -> &Hash {
        &self.current_hash
    }

    /// The ladder this slot tracks.
    pub fn ladder_id(&self) -> u32 {
        self.ladder_id
    }

    /// Number of proofs accepted so far.
    pub fn position(&self) -> usize {
        self.position
    }
}
/// Prover-side account on one ladder.
///
/// An account holds the key-derived [`Yedad`], the ladder in use, and the last
/// accepted hash, which is enough to produce the next proof.
#[derive(Debug, Clone)]
pub struct Account {
    yedad: Yedad,
    ladder_id: u32,
    position: usize,
    current_hash: Hash,
}

impl Account {
    /// Creates an account on ladder 1 from `private_key`, positioned at the
    /// top of the ladder with no proofs consumed.
    ///
    /// Returns the account together with the ladder-top hash (step
    /// `ladder_height`), the same value [`Slot::new`] starts from.
    ///
    /// # Errors
    /// Propagates errors from computing the ladder top.
    pub fn new_private(
        private_key: &[u8],
        ladder_height: usize,
    ) -> Result<(Self, Hash), YedadError> {
        let yedad = Yedad::from_private_key(private_key);
        let ladder_id = 1;
        let current_hash = yedad.last_step(ladder_id, ladder_height)?;
        let account = Self {
            yedad,
            ladder_id,
            position: 0,
            current_hash,
        };
        Ok((account, current_hash))
    }

    /// Rebuilds an account from `private_key`, `ladder_id`, and the last
    /// accepted `current_hash`.
    ///
    /// The position is recovered by locating `current_hash` on the ladder. A
    /// hash at step `i` means `ladder_height - i` proofs have been consumed.
    ///
    /// # Errors
    /// - [`YedadError::InvalidLadderIndex`] if `ladder_id` is 0.
    /// - [`YedadError::HashNotFound`] if `current_hash` is not among steps
    ///   `0..=ladder_height`.
    pub fn load(
        private_key: &[u8],
        ladder_id: u32,
        current_hash: &Hash,
        ladder_height: usize,
    ) -> Result<Self, YedadError> {
        let yedad = Yedad::from_private_key(private_key);
        let root = yedad.root(ladder_id)?;
        let idx = find_position(root, current_hash, ladder_height)
            .ok_or(YedadError::HashNotFound)?;
        // find_position only returns indices <= ladder_height, so no underflow.
        let position = ladder_height - idx;
        Ok(Self {
            yedad,
            ladder_id,
            position,
            current_hash: *current_hash,
        })
    }

    /// Computes the next proof to reveal (step `ladder_height - 1 - position`).
    ///
    /// # Errors
    /// [`YedadError::LadderExhausted`] if every proof has been consumed.
    pub fn next_proof(&self, ladder_height: usize) -> Result<Hash, YedadError> {
        // Same position-to-step mapping as the standalone generator.
        self.yedad
            .get_proof(self.ladder_id, self.position, ladder_height)
    }

    /// Records that `proof` was revealed. The proof must hash to the current
    /// hash. On success it becomes the current hash and the position advances.
    ///
    /// On error the account is left unchanged.
    ///
    /// # Errors
    /// - [`YedadError::LadderExhausted`] if `position >= ladder_height`.
    /// - [`YedadError::InvalidProof`] if `double_hash(proof)` differs from the
    ///   current hash.
    pub fn advance(&mut self, proof: &Hash, ladder_height: usize) -> Result<(), YedadError> {
        if self.position >= ladder_height {
            return Err(YedadError::LadderExhausted);
        }
        if !ct_eq(&double_hash(proof), &self.current_hash) {
            return Err(YedadError::InvalidProof);
        }
        self.current_hash = *proof;
        // Cannot overflow: position < ladder_height <= usize::MAX here.
        self.position += 1;
        Ok(())
    }

    /// Returns `(ladder_id, current_hash)`: the public state needed, together
    /// with the private key, to [`Account::load`] this account later.
    pub fn state(&self) -> (u32, Hash) {
        (self.ladder_id, self.current_hash)
    }

    /// Number of proofs still available (0 when exhausted).
    pub fn remaining(&self, ladder_height: usize) -> usize {
        ladder_height.saturating_sub(self.position)
    }

    /// The underlying key-derived ladder generator.
    pub fn yedad(&self) -> &Yedad {
        &self.yedad
    }

    /// The ladder this account uses.
    pub fn ladder_id(&self) -> u32 {
        self.ladder_id
    }

    /// Number of proofs consumed so far.
    pub fn position(&self) -> usize {
        self.position
    }

    /// The most recently accepted hash (initially the ladder top).
    pub fn current_hash(&self) -> &Hash {
        &self.current_hash
    }
}