Skip to main content

nox_crypto/sphinx/
pow.rs

//! Hashcash-style `PoW` for Sphinx `DoS` prevention, with SHA-256 and BLAKE3 backends. Harder puzzles are solved in parallel when the `rayon-pow` feature is enabled; otherwise solving is single-threaded.
2
3#[cfg(feature = "rayon-pow")]
4use rayon::prelude::*;
5use sha2::{Digest, Sha256};
6#[cfg(feature = "rayon-pow")]
7use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
8#[cfg(feature = "rayon-pow")]
9use std::sync::Arc;
10use thiserror::Error;
11
/// Thread-count value meaning "use every thread of the current rayon pool" (see
/// [`PowSolver::new`]).
pub const DEFAULT_THREADS: usize = 0;

/// Number of consecutive nonces covered by one parallel work item.
#[cfg(feature = "rayon-pow")]
const BATCH_SIZE: u64 = 10_000;

/// Highest difficulty, in leading zero bits, that `solve` accepts and `verify` can pass.
///
/// Each extra bit doubles the expected work, so a puzzle at this cap already needs about
/// 2^64 (roughly 1.8e19) hash evaluations on average, which is far beyond practical. The cap
/// exists so that absurd requests are rejected outright (`solve` returns
/// `PowError::DifficultyTooHigh`, `verify` returns `false`) instead of starting a search that
/// could never finish in practice.
pub const MAX_DIFFICULTY: u32 = 64;

/// Lowest difficulty. At 0 the puzzle is disabled: `solve` returns the starting nonce
/// unchanged and `verify` accepts any nonce.
pub const MIN_DIFFICULTY: u32 = 0;
21
22#[derive(Debug, Error)]
23pub enum PowError {
24    #[error("PoW difficulty {difficulty} exceeds maximum allowed {MAX_DIFFICULTY}")]
25    DifficultyTooHigh { difficulty: u32 },
26}
27
/// Hash function plugged into a [`PowSolver`].
///
/// Implementors map arbitrary input bytes to a 32-byte digest; the solver then counts the
/// leading zero bits of that digest. The `Send + Sync` supertraits allow one algorithm value
/// to be used from the solver's parallel workers.
pub trait PowAlgorithm: Send + Sync {
    /// Short human-readable name of the algorithm.
    fn name(&self) -> &'static str;

    /// Hashes `data` into a 32-byte digest. `PowSolver` passes the header bytes followed by
    /// the 8-byte big-endian nonce.
    fn hash(&self, data: &[u8]) -> [u8; 32];
}
33
34/// SHA-256 based `PoW`.
35#[derive(Debug, Clone, Copy, Default)]
36pub struct Sha256Pow;
37
38impl PowAlgorithm for Sha256Pow {
39    fn hash(&self, data: &[u8]) -> [u8; 32] {
40        let mut hasher = Sha256::new();
41        hasher.update(data);
42        hasher.finalize().into()
43    }
44
45    fn name(&self) -> &'static str {
46        "SHA-256"
47    }
48}
49
50/// Blake3 based `PoW` (~3x faster than SHA-256).
51#[derive(Debug, Clone, Copy, Default)]
52pub struct Blake3Pow;
53
54impl PowAlgorithm for Blake3Pow {
55    fn hash(&self, data: &[u8]) -> [u8; 32] {
56        blake3::hash(data).into()
57    }
58
59    fn name(&self) -> &'static str {
60        "BLAKE3"
61    }
62}
63
/// Returns the number of leading zero bits in `hash`, read as a big-endian bit string.
///
/// Whole 8-byte words are examined first, one `u64::leading_zeros` per word. Any trailing
/// bytes that do not fill a word are then examined one at a time. Scanning stops at the first
/// set bit; an empty slice yields 0.
#[inline]
#[must_use]
pub fn count_leading_zeros(hash: &[u8]) -> u32 {
    let mut words = hash.chunks_exact(8);
    let mut total: u32 = 0;

    for word_bytes in &mut words {
        let mut raw = [0u8; 8];
        raw.copy_from_slice(word_bytes);
        let word = u64::from_be_bytes(raw);
        if word != 0 {
            return total + word.leading_zeros();
        }
        total += 64;
    }

    for &byte in words.remainder() {
        if byte != 0 {
            return total + byte.leading_zeros();
        }
        total += 8;
    }

    total
}
91
92/// Returns true if the hash has at least `difficulty` leading zero bits.
93#[inline]
94#[must_use]
95pub fn meets_difficulty(hash: &[u8], difficulty: u32) -> bool {
96    count_leading_zeros(hash) >= difficulty
97}
98
/// Brute-force `PoW` solver and verifier, generic over the hash backend `A`.
///
/// With the `rayon-pow` feature, puzzles above 8 bits can be searched in parallel on the
/// current rayon pool; without it every search is sequential. Trait bounds live on the `impl`
/// blocks rather than on the type, so the type can be named without repeating them.
#[derive(Clone, Debug)]
pub struct PowSolver<A> {
    /// Hash backend used for every attempt.
    algorithm: A,
    /// Requested worker count; 0 means "all threads of the current rayon pool". Only read by
    /// the `rayon-pow` code path.
    #[cfg_attr(not(feature = "rayon-pow"), allow(dead_code))]
    num_threads: usize,
}
106
107impl<A: PowAlgorithm + Clone + 'static> PowSolver<A> {
108    /// Creates a new solver. `num_threads` = 0 means all cores.
109    pub fn new(algorithm: A, num_threads: usize) -> Self {
110        Self {
111            algorithm,
112            num_threads,
113        }
114    }
115
116    /// Finds a nonce producing a hash with `difficulty` leading zero bits.
117    #[cfg(feature = "rayon-pow")]
118    pub fn solve(
119        &self,
120        header_data: &[u8],
121        difficulty: u32,
122        start_nonce: u64,
123    ) -> Result<u64, PowError> {
124        if difficulty > MAX_DIFFICULTY {
125            return Err(PowError::DifficultyTooHigh { difficulty });
126        }
127
128        if difficulty == 0 {
129            return Ok(start_nonce);
130        }
131
132        let threads = if self.num_threads == 0 {
133            rayon::current_num_threads()
134        } else {
135            self.num_threads
136        };
137
138        if difficulty <= 8 || threads == 1 {
139            Ok(self.solve_single_threaded(header_data, difficulty, start_nonce))
140        } else {
141            Ok(self.solve_parallel(header_data, difficulty, start_nonce))
142        }
143    }
144
145    /// Single-threaded solve (WASM-compatible path).
146    #[cfg(not(feature = "rayon-pow"))]
147    pub fn solve(
148        &self,
149        header_data: &[u8],
150        difficulty: u32,
151        start_nonce: u64,
152    ) -> Result<u64, PowError> {
153        if difficulty > MAX_DIFFICULTY {
154            return Err(PowError::DifficultyTooHigh { difficulty });
155        }
156
157        if difficulty == 0 {
158            return Ok(start_nonce);
159        }
160
161        Ok(self.solve_single_threaded(header_data, difficulty, start_nonce))
162    }
163
164    fn solve_single_threaded(&self, header_data: &[u8], difficulty: u32, start_nonce: u64) -> u64 {
165        let mut nonce = start_nonce;
166        let mut buffer = Vec::with_capacity(header_data.len() + 8);
167        buffer.extend_from_slice(header_data);
168        buffer.extend_from_slice(&[0u8; 8]);
169
170        loop {
171            let nonce_pos = buffer.len() - 8;
172            buffer[nonce_pos..].copy_from_slice(&nonce.to_be_bytes());
173
174            let hash = self.algorithm.hash(&buffer);
175            if meets_difficulty(&hash, difficulty) {
176                return nonce;
177            }
178            nonce = nonce.wrapping_add(1);
179        }
180    }
181
182    #[cfg(feature = "rayon-pow")]
183    fn solve_parallel(&self, header_data: &[u8], difficulty: u32, start_nonce: u64) -> u64 {
184        let found = Arc::new(AtomicBool::new(false));
185        let result = Arc::new(AtomicU64::new(0));
186        let header = header_data.to_vec();
187        let algo = self.algorithm.clone();
188
189        (0..u64::MAX / BATCH_SIZE).into_par_iter().find_any(|&i| {
190            if found.load(Ordering::Relaxed) {
191                return true;
192            }
193
194            let batch_start = start_nonce.wrapping_add(i * BATCH_SIZE);
195            let mut buffer = Vec::with_capacity(header.len() + 8);
196            buffer.extend_from_slice(&header);
197            buffer.extend_from_slice(&[0u8; 8]);
198
199            for offset in 0..BATCH_SIZE {
200                if found.load(Ordering::Relaxed) {
201                    return true;
202                }
203
204                let nonce = batch_start.wrapping_add(offset);
205                let nonce_pos = buffer.len() - 8;
206                buffer[nonce_pos..].copy_from_slice(&nonce.to_be_bytes());
207
208                let hash = algo.hash(&buffer);
209                if meets_difficulty(&hash, difficulty) {
210                    result.store(nonce, Ordering::Relaxed);
211                    found.store(true, Ordering::Relaxed);
212                    return true;
213                }
214            }
215            false
216        });
217
218        result.load(Ordering::Relaxed)
219    }
220
221    /// Verifies a nonce. Returns `false` for difficulty > `MAX_DIFFICULTY`.
222    #[inline]
223    pub fn verify(&self, header_data: &[u8], nonce: u64, difficulty: u32) -> bool {
224        if difficulty == 0 {
225            return true;
226        }
227
228        if difficulty > MAX_DIFFICULTY {
229            return false;
230        }
231
232        let mut buffer = Vec::with_capacity(header_data.len() + 8);
233        buffer.extend_from_slice(header_data);
234        buffer.extend_from_slice(&nonce.to_be_bytes());
235
236        let hash = self.algorithm.hash(&buffer);
237        meets_difficulty(&hash, difficulty)
238    }
239}
240
241/// SHA-256 solver using all available cores.
242#[must_use]
243pub fn default_solver() -> PowSolver<Sha256Pow> {
244    PowSolver::new(Sha256Pow, DEFAULT_THREADS)
245}
246
247/// Blake3 solver using all available cores.
248#[must_use]
249pub fn fast_solver() -> PowSolver<Blake3Pow> {
250    PowSolver::new(Blake3Pow, DEFAULT_THREADS)
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    #[test]
    fn test_pow_low_difficulty() {
        let solver = default_solver();
        let header_data = b"test_header_data";

        let start = Instant::now();
        let nonce = solver.solve(header_data, 4, 0).expect("solve failed");
        let elapsed = start.elapsed();

        assert!(solver.verify(header_data, nonce, 4));
        println!(
            "Low difficulty (4 bits): nonce={}, time={:?}",
            nonce, elapsed
        );
        assert!(elapsed.as_millis() < 100, "Should resolve in <100ms");
    }

    #[test]
    fn test_pow_moderate_difficulty() {
        let solver = default_solver();
        let header_data = b"sphinx_header_ephemeral_key_routing_info_mac";

        let start = Instant::now();
        let nonce = solver.solve(header_data, 16, 0).expect("solve failed");
        let elapsed = start.elapsed();

        assert!(solver.verify(header_data, nonce, 16));
        println!(
            "Moderate difficulty (16 bits): nonce={}, time={:?}",
            nonce, elapsed
        );
    }

    #[test]
    fn test_pow_tamper_invalidates() {
        let solver = default_solver();
        let header_data = b"original_header";
        let nonce = solver.solve(header_data, 12, 0).expect("solve failed");

        assert!(solver.verify(header_data, nonce, 12));

        let tampered = b"tampered_header";
        assert!(!solver.verify(tampered, nonce, 12));
    }

    #[test]
    fn test_zero_difficulty() {
        let solver = default_solver();
        let header_data = b"any_data";

        let nonce = solver.solve(header_data, 0, 42).expect("solve failed");
        assert_eq!(nonce, 42);
        assert!(solver.verify(header_data, nonce, 0));
    }

    #[test]
    fn test_count_leading_zeros() {
        assert_eq!(count_leading_zeros(&[0x00, 0x00, 0x00, 0xFF]), 24);
        assert_eq!(count_leading_zeros(&[0x00, 0x0F, 0x00, 0x00]), 12);
        assert_eq!(count_leading_zeros(&[0x80, 0x00, 0x00, 0x00]), 0);
        assert_eq!(count_leading_zeros(&[0x00, 0x00, 0x00, 0x00]), 32);
    }

    #[test]
    fn test_count_leading_zeros_word_path() {
        // Inputs of 8+ bytes exercise the u64-wide loop, which the 4-byte cases above skip.
        assert_eq!(count_leading_zeros(&[0u8; 32]), 256);
        assert_eq!(count_leading_zeros(&[0, 0, 0, 0, 0, 0, 0, 1]), 63);
        assert_eq!(count_leading_zeros(&[0, 0, 0, 0, 0, 0, 0, 0, 0x01]), 71);
        let mut almost_zero = [0u8; 32];
        almost_zero[31] = 1;
        assert_eq!(count_leading_zeros(&almost_zero), 255);
        assert_eq!(count_leading_zeros(&[0u8; 0]), 0);
    }

    #[test]
    fn test_meets_difficulty_boundary() {
        // 0x00 0x0F -> 8 + 4 = 12 leading zero bits.
        let hash = [0x00, 0x0F, 0xFF];
        assert!(meets_difficulty(&hash, 0));
        assert!(meets_difficulty(&hash, 12));
        assert!(!meets_difficulty(&hash, 13));
    }

    #[test]
    fn test_fast_solver_roundtrip() {
        let solver = fast_solver();
        let header_data = b"blake3_header";
        let nonce = solver.solve(header_data, 12, 0).expect("solve failed");
        assert!(solver.verify(header_data, nonce, 12));
    }

    #[test]
    fn test_low_difficulty_searches_forward_from_start_nonce() {
        // Difficulties <= 8 always take the sequential path, which scans upward from
        // `start_nonce` in both feature configurations.
        let solver = default_solver();
        let header_data = b"start_nonce_header";
        let nonce = solver.solve(header_data, 6, 1_000).expect("solve failed");
        assert!(nonce >= 1_000);
        assert!(solver.verify(header_data, nonce, 6));
    }

    #[test]
    fn test_difficulty_bounds_enforced() {
        let solver = default_solver();
        let header_data = b"test_data";

        let result = solver.solve(header_data, MAX_DIFFICULTY + 1, 0);
        assert!(
            matches!(result, Err(PowError::DifficultyTooHigh { difficulty }) if difficulty == MAX_DIFFICULTY + 1),
            "Difficulty {} should be rejected",
            MAX_DIFFICULTY + 1
        );

        assert!(
            !solver.verify(header_data, 0, MAX_DIFFICULTY + 1),
            "verify() should return false for excessive difficulty"
        );
    }
}