1#[cfg(feature = "rayon-pow")]
4use rayon::prelude::*;
5use sha2::{Digest, Sha256};
6#[cfg(feature = "rayon-pow")]
7use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
8#[cfg(feature = "rayon-pow")]
9use std::sync::Arc;
10use thiserror::Error;
11
/// Thread-count value meaning "no explicit choice".
///
/// With the `rayon-pow` feature enabled, [`PowSolver::solve`] resolves a
/// configured count of `0` to `rayon::current_num_threads()`.
pub const DEFAULT_THREADS: usize = 0;

/// Number of consecutive nonces scanned by a single parallel work item.
#[cfg(feature = "rayon-pow")]
const BATCH_SIZE: u64 = 10_000;

/// Largest difficulty (in leading zero bits) accepted by the solver.
///
/// [`PowSolver::solve`] returns [`PowError::DifficultyTooHigh`] above this
/// value and [`PowSolver::verify`] returns `false`.
pub const MAX_DIFFICULTY: u32 = 64;

/// Smallest difficulty value.
///
/// At difficulty `0` the solver returns the starting nonce unchanged and the
/// verifier accepts any nonce. This constant is not referenced elsewhere in
/// this file.
pub const MIN_DIFFICULTY: u32 = 0;
21
/// Errors reported by the proof-of-work solver.
#[derive(Debug, Error)]
pub enum PowError {
    /// The requested `difficulty` is greater than [`MAX_DIFFICULTY`]; the
    /// rejected value is carried in the `difficulty` field.
    #[error("PoW difficulty {difficulty} exceeds maximum allowed {MAX_DIFFICULTY}")]
    DifficultyTooHigh { difficulty: u32 },
}
27
/// A hash function that can back the proof-of-work puzzle.
///
/// The `Send + Sync` supertraits allow one algorithm value to be used from
/// several worker threads at once.
pub trait PowAlgorithm: Send + Sync {
    /// Hashes `data` and returns the 32-byte digest.
    fn hash(&self, data: &[u8]) -> [u8; 32];
    /// Returns a static, human-readable name for the algorithm.
    fn name(&self) -> &'static str;
}
33
34#[derive(Debug, Clone, Copy, Default)]
36pub struct Sha256Pow;
37
38impl PowAlgorithm for Sha256Pow {
39 fn hash(&self, data: &[u8]) -> [u8; 32] {
40 let mut hasher = Sha256::new();
41 hasher.update(data);
42 hasher.finalize().into()
43 }
44
45 fn name(&self) -> &'static str {
46 "SHA-256"
47 }
48}
49
50#[derive(Debug, Clone, Copy, Default)]
52pub struct Blake3Pow;
53
54impl PowAlgorithm for Blake3Pow {
55 fn hash(&self, data: &[u8]) -> [u8; 32] {
56 blake3::hash(data).into()
57 }
58
59 fn name(&self) -> &'static str {
60 "BLAKE3"
61 }
62}
63
/// Counts the zero bits at the front of `hash`, reading it as one big-endian
/// bit string.
///
/// The slice is scanned eight bytes at a time as big-endian `u64` words; any
/// trailing bytes that do not fill a whole word are scanned one by one. The
/// scan stops at the first non-zero word or byte. An empty slice yields `0`,
/// and an all-zero slice yields `8 * hash.len()`.
#[inline]
#[must_use]
pub fn count_leading_zeros(hash: &[u8]) -> u32 {
    let mut words = hash.chunks_exact(8);
    let mut total = 0u32;

    for word_bytes in &mut words {
        let mut raw = [0u8; 8];
        raw.copy_from_slice(word_bytes);
        match u64::from_be_bytes(raw) {
            0 => total += u64::BITS,
            word => return total + word.leading_zeros(),
        }
    }

    for &byte in words.remainder() {
        match byte {
            0 => total += u8::BITS,
            nonzero => return total + nonzero.leading_zeros(),
        }
    }

    total
}
91
92#[inline]
94#[must_use]
95pub fn meets_difficulty(hash: &[u8], difficulty: u32) -> bool {
96 count_leading_zeros(hash) >= difficulty
97}
98
/// Proof-of-work solver and verifier, generic over the hash algorithm.
#[derive(Clone)]
pub struct PowSolver<A: PowAlgorithm> {
    /// Hash function applied to the header bytes followed by the nonce.
    algorithm: A,
    /// Requested thread count, consulted by `solve` when the `rayon-pow`
    /// feature is enabled; `0` is resolved to `rayon::current_num_threads()`.
    /// Without that feature the field is never read, hence the `dead_code`
    /// allowance.
    #[cfg_attr(not(feature = "rayon-pow"), allow(dead_code))]
    num_threads: usize,
}
106
107impl<A: PowAlgorithm + Clone + 'static> PowSolver<A> {
108 pub fn new(algorithm: A, num_threads: usize) -> Self {
110 Self {
111 algorithm,
112 num_threads,
113 }
114 }
115
116 #[cfg(feature = "rayon-pow")]
118 pub fn solve(
119 &self,
120 header_data: &[u8],
121 difficulty: u32,
122 start_nonce: u64,
123 ) -> Result<u64, PowError> {
124 if difficulty > MAX_DIFFICULTY {
125 return Err(PowError::DifficultyTooHigh { difficulty });
126 }
127
128 if difficulty == 0 {
129 return Ok(start_nonce);
130 }
131
132 let threads = if self.num_threads == 0 {
133 rayon::current_num_threads()
134 } else {
135 self.num_threads
136 };
137
138 if difficulty <= 8 || threads == 1 {
139 Ok(self.solve_single_threaded(header_data, difficulty, start_nonce))
140 } else {
141 Ok(self.solve_parallel(header_data, difficulty, start_nonce))
142 }
143 }
144
145 #[cfg(not(feature = "rayon-pow"))]
147 pub fn solve(
148 &self,
149 header_data: &[u8],
150 difficulty: u32,
151 start_nonce: u64,
152 ) -> Result<u64, PowError> {
153 if difficulty > MAX_DIFFICULTY {
154 return Err(PowError::DifficultyTooHigh { difficulty });
155 }
156
157 if difficulty == 0 {
158 return Ok(start_nonce);
159 }
160
161 Ok(self.solve_single_threaded(header_data, difficulty, start_nonce))
162 }
163
164 fn solve_single_threaded(&self, header_data: &[u8], difficulty: u32, start_nonce: u64) -> u64 {
165 let mut nonce = start_nonce;
166 let mut buffer = Vec::with_capacity(header_data.len() + 8);
167 buffer.extend_from_slice(header_data);
168 buffer.extend_from_slice(&[0u8; 8]);
169
170 loop {
171 let nonce_pos = buffer.len() - 8;
172 buffer[nonce_pos..].copy_from_slice(&nonce.to_be_bytes());
173
174 let hash = self.algorithm.hash(&buffer);
175 if meets_difficulty(&hash, difficulty) {
176 return nonce;
177 }
178 nonce = nonce.wrapping_add(1);
179 }
180 }
181
182 #[cfg(feature = "rayon-pow")]
183 fn solve_parallel(&self, header_data: &[u8], difficulty: u32, start_nonce: u64) -> u64 {
184 let found = Arc::new(AtomicBool::new(false));
185 let result = Arc::new(AtomicU64::new(0));
186 let header = header_data.to_vec();
187 let algo = self.algorithm.clone();
188
189 (0..u64::MAX / BATCH_SIZE).into_par_iter().find_any(|&i| {
190 if found.load(Ordering::Relaxed) {
191 return true;
192 }
193
194 let batch_start = start_nonce.wrapping_add(i * BATCH_SIZE);
195 let mut buffer = Vec::with_capacity(header.len() + 8);
196 buffer.extend_from_slice(&header);
197 buffer.extend_from_slice(&[0u8; 8]);
198
199 for offset in 0..BATCH_SIZE {
200 if found.load(Ordering::Relaxed) {
201 return true;
202 }
203
204 let nonce = batch_start.wrapping_add(offset);
205 let nonce_pos = buffer.len() - 8;
206 buffer[nonce_pos..].copy_from_slice(&nonce.to_be_bytes());
207
208 let hash = algo.hash(&buffer);
209 if meets_difficulty(&hash, difficulty) {
210 result.store(nonce, Ordering::Relaxed);
211 found.store(true, Ordering::Relaxed);
212 return true;
213 }
214 }
215 false
216 });
217
218 result.load(Ordering::Relaxed)
219 }
220
221 #[inline]
223 pub fn verify(&self, header_data: &[u8], nonce: u64, difficulty: u32) -> bool {
224 if difficulty == 0 {
225 return true;
226 }
227
228 if difficulty > MAX_DIFFICULTY {
229 return false;
230 }
231
232 let mut buffer = Vec::with_capacity(header_data.len() + 8);
233 buffer.extend_from_slice(header_data);
234 buffer.extend_from_slice(&nonce.to_be_bytes());
235
236 let hash = self.algorithm.hash(&buffer);
237 meets_difficulty(&hash, difficulty)
238 }
239}
240
241#[must_use]
243pub fn default_solver() -> PowSolver<Sha256Pow> {
244 PowSolver::new(Sha256Pow, DEFAULT_THREADS)
245}
246
247#[must_use]
249pub fn fast_solver() -> PowSolver<Blake3Pow> {
250 PowSolver::new(Blake3Pow, DEFAULT_THREADS)
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// A 4-bit puzzle is solved, verifies, and finishes quickly.
    #[test]
    fn test_pow_low_difficulty() {
        let solver = default_solver();
        let header_data = b"test_header_data";

        let start = Instant::now();
        let nonce = solver.solve(header_data, 4, 0).expect("solve failed");
        let elapsed = start.elapsed();

        assert!(solver.verify(header_data, nonce, 4));
        println!(
            "Low difficulty (4 bits): nonce={}, time={:?}",
            nonce, elapsed
        );
        assert!(elapsed.as_millis() < 100, "Should resolve in <100ms");
    }

    /// A 16-bit puzzle is solved and the returned nonce verifies.
    #[test]
    fn test_pow_moderate_difficulty() {
        let solver = default_solver();
        let header_data = b"sphinx_header_ephemeral_key_routing_info_mac";

        let start = Instant::now();
        let nonce = solver.solve(header_data, 16, 0).expect("solve failed");
        let elapsed = start.elapsed();

        assert!(solver.verify(header_data, nonce, 16));
        println!(
            "Moderate difficulty (16 bits): nonce={}, time={:?}",
            nonce, elapsed
        );
    }

    /// A nonce found for one header is rejected for a different header.
    #[test]
    fn test_pow_tamper_invalidates() {
        let solver = default_solver();
        let header_data = b"original_header";
        let nonce = solver.solve(header_data, 12, 0).expect("solve failed");

        assert!(solver.verify(header_data, nonce, 12));

        let tampered = b"tampered_header";
        assert!(!solver.verify(tampered, nonce, 12));
    }

    /// Difficulty 0 returns the starting nonce untouched and always verifies.
    #[test]
    fn test_zero_difficulty() {
        let solver = default_solver();
        let header_data = b"any_data";

        let nonce = solver.solve(header_data, 0, 42).expect("solve failed");
        assert_eq!(nonce, 42);
        assert!(solver.verify(header_data, nonce, 0));
    }

    /// Covers both scan paths of `count_leading_zeros`: the byte-wise tail and
    /// the 8-byte word loop, plus empty input.
    #[test]
    fn test_count_leading_zeros() {
        // Inputs shorter than 8 bytes only exercise the byte-wise tail.
        assert_eq!(count_leading_zeros(&[0x00, 0x00, 0x00, 0xFF]), 24);
        assert_eq!(count_leading_zeros(&[0x00, 0x0F, 0x00, 0x00]), 12);
        assert_eq!(count_leading_zeros(&[0x80, 0x00, 0x00, 0x00]), 0);
        assert_eq!(count_leading_zeros(&[0x00, 0x00, 0x00, 0x00]), 32);

        // Empty input has no bits at all.
        assert_eq!(count_leading_zeros(&[]), 0);

        // Exactly one 8-byte word.
        assert_eq!(count_leading_zeros(&[0u8; 8]), 64);
        assert_eq!(count_leading_zeros(&[0, 0, 0, 0, 0, 0, 0, 0x01]), 63);
        assert_eq!(count_leading_zeros(&[0x01, 0, 0, 0, 0, 0, 0, 0]), 7);

        // A zero word followed by a byte-wise tail.
        assert_eq!(count_leading_zeros(&[0, 0, 0, 0, 0, 0, 0, 0, 0x01]), 71);

        // Digest-sized input: all zero, and first set bit in the second word.
        assert_eq!(count_leading_zeros(&[0u8; 32]), 256);
        let mut digest = [0u8; 32];
        digest[9] = 0x10;
        assert_eq!(count_leading_zeros(&digest), 75);
    }

    /// Difficulties above `MAX_DIFFICULTY` are rejected by both `solve` and
    /// `verify`.
    #[test]
    fn test_difficulty_bounds_enforced() {
        let solver = default_solver();
        let header_data = b"test_data";

        let result = solver.solve(header_data, MAX_DIFFICULTY + 1, 0);
        assert!(
            matches!(result, Err(PowError::DifficultyTooHigh { difficulty }) if difficulty == MAX_DIFFICULTY + 1),
            "Difficulty {} should be rejected",
            MAX_DIFFICULTY + 1
        );

        assert!(
            !solver.verify(header_data, 0, MAX_DIFFICULTY + 1),
            "verify() should return false for excessive difficulty"
        );
    }
}