use blake3::Hasher;
use borsh::{BorshDeserialize, BorshSerialize};
use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
use subtle::ConstantTimeEq;
/// Default client-side ceiling on the difficulty (leading zero bits) a server may demand,
/// intended as the `max_difficulty` argument of `PoWChallenge::solve_capped`.
pub const MAX_CLIENT_POW_DIFFICULTY: u8 = 24;

/// Candidate bound used by `PoWChallenge::solve`: `2^(MAX_CLIENT_POW_DIFFICULTY + 8)`.
///
/// A difficulty-`d` puzzle needs about `2^d` attempts on average, so the extra 8 bits give
/// 256x headroom at the maximum accepted difficulty.
pub const MAX_SOLVE_ITERATIONS: u64 = 2u64.pow(MAX_CLIENT_POW_DIFFICULTY as u32 + 8);
/// Reasons a proof-of-work solve can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PowError {
    /// The challenge's difficulty exceeds the cap passed to `PoWChallenge::solve_capped`.
    DifficultyTooHigh {
        /// Difficulty (required leading zero bits) the challenge asked for.
        demanded: u8,
        /// Largest difficulty the caller was willing to solve.
        cap: u8,
    },
    /// No candidate within the iteration bound met the difficulty.
    Exhausted,
}

impl core::fmt::Display for PowError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            PowError::DifficultyTooHigh { demanded, cap } => write!(
                f,
                "server demanded PoW difficulty {demanded} exceeding client cap {cap}"
            ),
            PowError::Exhausted => write!(f, "PoW solve exhausted the iteration bound"),
        }
    }
}

// Lets callers propagate `PowError` with `?` into `Box<dyn Error>` or other error wrappers.
impl std::error::Error for PowError {}
/// A proof-of-work puzzle.
///
/// It is solved by a `u64` whose unkeyed BLAKE3 hash of `nonce || solution.to_le_bytes()`
/// starts with at least `difficulty` zero bits.
#[derive(Debug, Clone, BorshSerialize, BorshDeserialize, SerdeSerialize, SerdeDeserialize)]
pub struct PoWChallenge {
    /// Challenge nonce (see `PoWChallenge::new_stateless` for the layout it produces).
    pub nonce: [u8; 32],
    /// Required number of leading zero bits in the solution hash.
    pub difficulty: u8,
}
/// A candidate answer to a [`PoWChallenge`].
#[derive(Debug, Clone, BorshSerialize, BorshDeserialize, SerdeSerialize, SerdeDeserialize)]
pub struct PoWSolution {
    /// Copy of the challenge nonce this solution was computed for.
    pub nonce: [u8; 32],
    /// The counter value that, hashed with `nonce`, meets the difficulty.
    pub solution: u64,
}
impl PoWChallenge {
    /// How long, in seconds, an issued challenge remains acceptable to [`PoWChallenge::verify`].
    const CHALLENGE_TTL_SECS: u64 = 120;

    /// Issues a challenge that can later be checked by [`PoWChallenge::verify`] without any
    /// server-side state.
    ///
    /// The nonce is laid out as `timestamp || mac`:
    /// * bytes `0..8`: issue time in whole seconds since the Unix epoch, little-endian
    ///   (0 if the system clock reads earlier than the epoch);
    /// * bytes `8..32`: the first 24 bytes of a keyed BLAKE3 hash of
    ///   `timestamp || client_id` under `secret`.
    ///
    /// `difficulty` is stored as given; it is not covered by the MAC.
    pub fn new_stateless(difficulty: u8, client_id: &[u8], secret: &[u8; 32]) -> Self {
        let timestamp_bytes = Self::unix_now_secs().to_le_bytes();
        let mac = Self::nonce_mac(&timestamp_bytes, client_id, secret);

        let mut nonce = [0u8; 32];
        nonce[..8].copy_from_slice(&timestamp_bytes);
        nonce[8..].copy_from_slice(&mac[..24]);
        Self { nonce, difficulty }
    }

    /// Checks that `solution` answers this challenge for `client_id`.
    ///
    /// Returns `true` only if all of the following hold:
    /// * `solution.nonce` equals `self.nonce`;
    /// * the MAC stored in `self.nonce` matches `client_id` and `secret` (constant-time
    ///   comparison);
    /// * the embedded timestamp is not in the future and at most `CHALLENGE_TTL_SECS`
    ///   (120 s) in the past;
    /// * `solution.solution` meets `self.difficulty`.
    ///
    /// NOTE(review): `self.difficulty` is not authenticated by the MAC. If a caller rebuilds
    /// `self` from client-supplied data, a client could lower the difficulty. Confirm that
    /// callers take the difficulty from server configuration. Also, nothing here records
    /// used nonces, so a valid solution can presumably be resubmitted until it expires.
    pub fn verify(&self, solution: &PoWSolution, client_id: &[u8], secret: &[u8; 32]) -> bool {
        if self.nonce != solution.nonce {
            return false;
        }

        let mut timestamp_bytes = [0u8; 8];
        timestamp_bytes.copy_from_slice(&self.nonce[..8]);
        let expected_mac = Self::nonce_mac(&timestamp_bytes, client_id, secret);
        if !bool::from(self.nonce[8..].ct_eq(&expected_mac[..24])) {
            return false;
        }

        // The timestamp is only trusted after the MAC check. Still, saturate so a
        // near-`u64::MAX` value cannot overflow the window arithmetic.
        let timestamp = u64::from_le_bytes(timestamp_bytes);
        let now = Self::unix_now_secs();
        if now < timestamp || now > timestamp.saturating_add(Self::CHALLENGE_TTL_SECS) {
            return false;
        }

        let hash = compute_blake3_hash(&self.nonce, solution.solution);
        check_leading_zeros(&hash, self.difficulty)
    }

    /// Brute-forces a solution, trying at most [`MAX_SOLVE_ITERATIONS`] candidates.
    ///
    /// No difficulty cap is applied. Use [`PoWChallenge::solve_capped`] to refuse oversized
    /// demands up front.
    ///
    /// # Errors
    /// [`PowError::Exhausted`] if no candidate within the bound meets the difficulty.
    pub fn solve(&self) -> Result<PoWSolution, PowError> {
        self.solve_with_bound(MAX_SOLVE_ITERATIONS)
    }

    /// Like [`PoWChallenge::solve`], but first refuses if `self.difficulty` exceeds
    /// `max_difficulty`.
    ///
    /// # Errors
    /// * [`PowError::DifficultyTooHigh`] if `self.difficulty > max_difficulty` (no hashing
    ///   is done in that case);
    /// * [`PowError::Exhausted`] as for [`PoWChallenge::solve`].
    pub fn solve_capped(&self, max_difficulty: u8) -> Result<PoWSolution, PowError> {
        if self.difficulty > max_difficulty {
            return Err(PowError::DifficultyTooHigh {
                demanded: self.difficulty,
                cap: max_difficulty,
            });
        }
        self.solve()
    }

    /// Tries candidates `0..max_iters` in ascending order and returns the first that meets
    /// the difficulty, or [`PowError::Exhausted`] if none does.
    fn solve_with_bound(&self, max_iters: u64) -> Result<PoWSolution, PowError> {
        (0..max_iters)
            .find(|&candidate| {
                check_leading_zeros(&compute_blake3_hash(&self.nonce, candidate), self.difficulty)
            })
            .map(|solution| PoWSolution {
                nonce: self.nonce,
                solution,
            })
            .ok_or(PowError::Exhausted)
    }

    /// Current time in whole seconds since the Unix epoch, or 0 if the clock reads earlier.
    fn unix_now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Keyed BLAKE3 of `timestamp_bytes || client_id` under `secret`.
    ///
    /// Issuing and verifying share this helper so both always MAC exactly the same input.
    fn nonce_mac(timestamp_bytes: &[u8; 8], client_id: &[u8], secret: &[u8; 32]) -> [u8; 32] {
        let mut hasher = Hasher::new_keyed(secret);
        hasher.update(timestamp_bytes);
        hasher.update(client_id);
        *hasher.finalize().as_bytes()
    }
}
/// Unkeyed BLAKE3 digest of `nonce` followed by the little-endian bytes of `solution`.
fn compute_blake3_hash(nonce: &[u8; 32], solution: u64) -> [u8; 32] {
    let digest = Hasher::new()
        .update(nonce)
        .update(&solution.to_le_bytes())
        .finalize();
    *digest.as_bytes()
}
/// Returns `true` if `hash` begins with at least `difficulty` zero bits, counting each
/// byte from its most significant bit.
///
/// The count is kept in a `u32`, and the scan stops as soon as the target is reached. An
/// all-zero 32-byte hash (256 zero bits) therefore cannot overflow the counter, which the
/// previous `u8` accumulator did. Bits past the end of `hash` are not counted, so a
/// `difficulty` above `8 * hash.len()` always fails.
fn check_leading_zeros(hash: &[u8], difficulty: u8) -> bool {
    let required = u32::from(difficulty);
    let mut zeros: u32 = 0;
    for &byte in hash {
        if zeros >= required {
            return true;
        }
        // `u8::leading_zeros` is 8 for a zero byte, so this covers both cases.
        zeros += byte.leading_zeros();
        if byte != 0 {
            break;
        }
    }
    zeros >= required
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a challenge carrying an arbitrary `timestamp`, using the same nonce layout and
    /// MAC input as `new_stateless`.
    fn challenge_at(timestamp: u64, difficulty: u8, client_id: &[u8], secret: &[u8; 32]) -> PoWChallenge {
        let timestamp_bytes = timestamp.to_le_bytes();
        let mut hasher = Hasher::new_keyed(secret);
        hasher.update(&timestamp_bytes);
        hasher.update(client_id);
        let mac = hasher.finalize();

        let mut nonce = [0u8; 32];
        nonce[..8].copy_from_slice(&timestamp_bytes);
        nonce[8..].copy_from_slice(&mac.as_bytes()[..24]);
        PoWChallenge { nonce, difficulty }
    }

    #[test]
    fn test_pow_stateless_verify() {
        let secret = [42u8; 32];
        let client_id = b"127.0.0.1";
        let challenge = PoWChallenge::new_stateless(8, client_id, &secret);
        let solution = challenge.solve().expect("difficulty 8 is solvable");
        assert!(challenge.verify(&solution, client_id, &secret));
    }

    #[test]
    fn test_pow_invalid_mac() {
        let secret = [42u8; 32];
        let client_id = b"127.0.0.1";
        let mut challenge = PoWChallenge::new_stateless(8, client_id, &secret);
        challenge.nonce[10] ^= 0xFF;
        let solution = challenge.solve().expect("difficulty 8 is solvable");
        assert!(!challenge.verify(&solution, client_id, &secret));
    }

    #[test]
    fn test_pow_invalid_client() {
        let secret = [42u8; 32];
        let client_id = b"127.0.0.1";
        let other_client = b"192.168.1.1";
        let challenge = PoWChallenge::new_stateless(8, client_id, &secret);
        let solution = challenge.solve().expect("difficulty 8 is solvable");
        assert!(!challenge.verify(&solution, other_client, &secret));
    }

    #[test]
    fn verify_rejects_wrong_secret() {
        let secret = [42u8; 32];
        let client_id = b"127.0.0.1";
        let challenge = PoWChallenge::new_stateless(8, client_id, &secret);
        let solution = challenge.solve().expect("difficulty 8 is solvable");
        assert!(!challenge.verify(&solution, client_id, &[43u8; 32]));
    }

    #[test]
    fn verify_rejects_solution_for_a_different_nonce() {
        let secret = [42u8; 32];
        let client_id = b"127.0.0.1";
        let challenge = PoWChallenge::new_stateless(8, client_id, &secret);
        let mut solution = challenge.solve().expect("difficulty 8 is solvable");
        solution.nonce[0] ^= 0x01;
        assert!(!challenge.verify(&solution, client_id, &secret));
    }

    #[test]
    fn verify_rejects_expired_challenge() {
        let secret = [42u8; 32];
        let client_id = b"127.0.0.1";
        // A correctly MAC'd nonce from 1970: only the freshness window can reject it.
        let challenge = challenge_at(1_000, 0, client_id, &secret);
        let solution = challenge.solve().expect("difficulty 0 is trivially solvable");
        assert!(!challenge.verify(&solution, client_id, &secret));
    }

    #[test]
    fn check_leading_zeros_counts_bits_across_bytes() {
        assert!(check_leading_zeros(&[0x0F], 4));
        assert!(!check_leading_zeros(&[0x0F], 5));
        assert!(check_leading_zeros(&[0x00, 0x01, 0xFF], 15));
        assert!(!check_leading_zeros(&[0x00, 0x01, 0xFF], 16));
        assert!(check_leading_zeros(&[0xFF], 0));
        assert!(check_leading_zeros(&[], 0));
        assert!(!check_leading_zeros(&[], 1));
    }

    #[test]
    fn solve_capped_rejects_oversized_difficulty() {
        let challenge = PoWChallenge {
            nonce: [7u8; 32],
            difficulty: 255,
        };
        match challenge.solve_capped(MAX_CLIENT_POW_DIFFICULTY) {
            Err(PowError::DifficultyTooHigh { demanded, cap }) => {
                assert_eq!(demanded, 255);
                assert_eq!(cap, MAX_CLIENT_POW_DIFFICULTY);
            }
            other => panic!("expected DifficultyTooHigh, got {:?}", other),
        }
    }

    #[test]
    fn solve_capped_accepts_within_cap() {
        let secret = [9u8; 32];
        let client_id = b"127.0.0.1";
        let challenge = PoWChallenge::new_stateless(8, client_id, &secret);
        assert!(challenge.difficulty <= MAX_CLIENT_POW_DIFFICULTY);
        let solution = challenge
            .solve_capped(MAX_CLIENT_POW_DIFFICULTY)
            .expect("difficulty 8 is solvable");
        assert!(challenge.verify(&solution, client_id, &secret));
    }

    #[test]
    fn solve_is_bounded_and_fails_closed() {
        let challenge = PoWChallenge {
            nonce: [3u8; 32],
            difficulty: 250,
        };
        assert!(matches!(
            challenge.solve_with_bound(1_000),
            Err(PowError::Exhausted)
        ));
    }

    #[test]
    fn max_client_pow_difficulty_admits_the_server_max() {
        // Keep the client cap at or above 20, the value this test treats as the server's
        // maximum difficulty. The check is evaluated at compile time, so a regression
        // breaks the build.
        const _: () = assert!(MAX_CLIENT_POW_DIFFICULTY >= 20);
    }
}