use core::slice;
use ark_std::rand::{CryptoRng, RngCore};
use serde::{Deserialize, Serialize};
#[cfg(feature = "tracing")]
use tracing::instrument;
use zerocopy::IntoBytes;
use crate::{
bits::Bits,
engines::EngineId,
hash::{Hash, BLAKE3, ENGINES},
transcript::{
codecs::U64, Codec, Decoding, DuplexSpongeInterface, ProverState, VerificationResult,
VerifierMessage, VerifierState,
},
utils::zip_strict,
verify,
};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Config {
pub hash_id: EngineId,
pub threshold: u64,
}
pub fn threshold(difficulty: Bits) -> u64 {
assert!((0.0..=60.0).contains(&difficulty.into()));
let threshold = (64.0 - f64::from(difficulty)).exp2().ceil();
#[allow(clippy::cast_sign_loss)]
if threshold >= u64::MAX as f64 {
u64::MAX
} else {
threshold as u64
}
}
pub fn difficulty(threshold: u64) -> Bits {
Bits::from(64.0 - (threshold as f64).log2())
}
impl Config {
pub fn from_difficulty(difficulty: Bits) -> Self {
Self {
hash_id: BLAKE3,
threshold: threshold(difficulty),
}
}
pub fn difficulty(&self) -> Bits {
difficulty(self.threshold)
}
#[cfg_attr(feature = "tracing", instrument(skip_all, fields(engine)))]
pub fn prove<H, R>(&self, prover_state: &mut ProverState<H, R>)
where
H: DuplexSpongeInterface,
R: RngCore + CryptoRng,
[u8; 32]: Decoding<[H::U]>,
U64: Codec<[H::U]>,
{
if self.threshold == u64::MAX {
return;
}
let engine = ENGINES
.retrieve(self.hash_id)
.expect("Hash Engine not found");
#[cfg(feature = "tracing")]
tracing::Span::current().record("engine", engine.name().as_ref());
let batch_size = engine.preferred_batch_size();
let challenge: [u8; 32] = prover_state.verifier_message();
#[cfg(not(feature = "parallel"))]
let nonce = (0_u64..)
.step_by(batch_size)
.find_map({
let mut inputs = vec![[0u8; 64]; batch_size];
for input in &mut inputs {
input[..32].copy_from_slice(&challenge);
}
let mut outputs = vec![Hash::default(); batch_size];
move |nonce| {
let input_len = inputs.len();
for (input, nonce) in zip_strict(inputs.iter_mut(), (nonce..).take(input_len)) {
input[32..40].copy_from_slice(&nonce.to_le_bytes());
}
engine.hash_many(64, inputs.as_bytes(), &mut outputs);
let output_len = outputs.len();
for (output, nonce) in zip_strict(outputs.iter(), (nonce..).take(output_len)) {
let value = u64::from_le_bytes(output.0[..8].try_into().unwrap());
if value <= self.threshold {
return Some(nonce);
}
}
None
}
})
.expect("Proof of Work failed to solve.");
#[cfg(feature = "parallel")]
let nonce = {
use std::sync::atomic::{AtomicU64, Ordering};
let global_min = AtomicU64::new(u64::MAX);
rayon::broadcast(|ctx| {
let thread_nonces =
((batch_size * ctx.index()) as u64..).step_by(batch_size * ctx.num_threads());
let mut inputs = vec![[0u8; 64]; batch_size];
for input in &mut inputs {
input[..32].copy_from_slice(&challenge);
}
let mut outputs = vec![Hash::default(); batch_size];
for batch_start in thread_nonces {
if batch_start >= global_min.load(Ordering::Relaxed) {
break;
}
let input_len = inputs.len();
for (input, nonce) in
zip_strict(inputs.iter_mut(), (batch_start..).take(input_len))
{
input[32..40].copy_from_slice(&nonce.to_le_bytes());
}
engine.hash_many(64, inputs.as_bytes(), &mut outputs);
let output_len = outputs.len();
for (output, nonce) in
zip_strict(outputs.iter(), (batch_start..).take(output_len))
{
let value = u64::from_le_bytes(output.0[..8].try_into().unwrap());
if value <= self.threshold {
global_min.fetch_min(nonce, Ordering::SeqCst);
break;
}
}
}
});
let nonce = global_min.load(Ordering::SeqCst);
assert!(nonce != u64::MAX, "Proof of Work failed to solve.");
nonce
};
prover_state.prover_message(&U64(nonce));
}
pub fn verify<H>(&self, verifier_state: &mut VerifierState<H>) -> VerificationResult<()>
where
H: DuplexSpongeInterface,
[u8; 32]: Decoding<[H::U]>,
U64: Codec<[H::U]>,
{
if self.threshold == u64::MAX {
return Ok(());
}
let engine = ENGINES.retrieve(self.hash_id);
verify!(engine.is_some());
let engine = engine.unwrap();
let challenge: [u8; 32] = verifier_state.verifier_message();
let nonce: U64 = verifier_state.prover_message()?;
let mut input = [0u8; 64];
input[..32].copy_from_slice(&challenge);
input[32..40].copy_from_slice(&nonce.0.to_le_bytes());
let mut output = Hash::default();
engine.hash_many(64, &input, slice::from_mut(&mut output));
let value = u64::from_le_bytes(output.0[..8].try_into().unwrap());
verify!(value <= self.threshold);
Ok(())
}
}
#[cfg(test)]
mod tests {
use proptest::proptest;
use super::*;
#[test]
fn test_threshold_integer() {
assert_eq!(threshold(Bits::new(0.0)), u64::MAX);
assert_eq!(threshold(Bits::new(60.0)), 1 << 4);
proptest!(|(bits in 1_u64..=60)| {
assert_eq!(threshold(Bits::new(bits as f64)), 1 << (64 - bits));
});
}
#[test]
fn test_threshold_fractional() {
proptest!(|(bits in 0.0..=60.0)| {
let t = threshold(Bits::new(bits));
let min = threshold(Bits::new(bits.ceil()));
let max = threshold(Bits::new(bits.floor()));
assert!((min..=max).contains(&t));
});
}
#[test]
fn test_threshold_monotonic() {
proptest!(|(bits in 0.0..=59.0, delta in 0.0..=1.0)| {
let low = threshold(Bits::new(bits + delta));
let high = threshold(Bits::new(bits));
assert!(low <= high);
});
}
#[test]
fn test_difficulty_integer() {
assert_eq!(difficulty(u64::MAX), Bits::new(0.0));
assert_eq!(difficulty(1 << 4), Bits::new(60.0));
proptest!(|(bits in 1_u64..=60)| {
assert_eq!(difficulty(1 << (64 - bits)), Bits::new(bits as f64));
});
}
#[test]
fn test_difficulty_fractional() {
proptest!(|(threshold in 16_u64..)| {
let d = difficulty(threshold);
let min = difficulty(threshold.checked_next_power_of_two().unwrap_or(u64::MAX));
let max = Bits::new(f64::from(min) + 1.0);
assert!((min..=max).contains(&d));
});
}
#[test]
fn test_difficulty_monotonic() {
proptest!(|(threshold in 16_u64.., delta: u64)| {
let high = difficulty(threshold);
let low = difficulty(threshold.saturating_add(delta));
assert!(low <= high);
});
}
}