Skip to main content

uhash_prover/cpu/
parallel.rs

1//! Multi-threaded CPU solver.
2
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::{Arc, Mutex};
5
6use anyhow::Result;
7use uhash_core::{meets_difficulty, UniversalHash};
8
9use crate::solver::{ProofResult, Solver};
10
/// Multi-threaded CPU solver. Spawns OS threads with interleaved nonce
/// assignment to avoid contention.
#[derive(Debug, Clone)]
pub struct ParallelCpuSolver {
    /// Configured worker-thread count; 0 means "use the available parallelism".
    threads: usize,
}

impl ParallelCpuSolver {
    /// Create a new parallel solver with the given thread count.
    ///
    /// If `threads` is 0, the actual count is resolved each time one is
    /// needed, using [`std::thread::available_parallelism`] (falling back to
    /// 1 if that query fails).
    pub fn new(threads: usize) -> Self {
        Self { threads }
    }

    /// Number of worker threads to use: the configured count, or the
    /// available parallelism (at least 1) when the solver was built with 0.
    fn effective_threads(&self) -> usize {
        match self.threads {
            0 => std::thread::available_parallelism().map_or(1, |n| n.get()),
            n => n,
        }
    }
}
34
35impl Solver for ParallelCpuSolver {
36    fn backend_name(&self) -> &'static str {
37        "cpu"
38    }
39
40    fn recommended_lanes(&mut self, requested: usize) -> usize {
41        if requested == 0 {
42            self.effective_threads() * 1024
43        } else {
44            requested
45        }
46    }
47
48    fn find_proof_batch(
49        &mut self,
50        header_without_nonce: &[u8],
51        start_nonce: u64,
52        lanes: usize,
53        difficulty: u32,
54    ) -> Result<ProofResult> {
55        let threads = self.effective_threads();
56        let found = Arc::new(AtomicBool::new(false));
57        let winner = Arc::new(Mutex::new(None::<(u64, [u8; 32])>));
58        let challenge = Arc::new(header_without_nonce.to_vec());
59        let mut handles = Vec::with_capacity(threads);
60
61        for tid in 0..threads {
62            let found = Arc::clone(&found);
63            let winner = Arc::clone(&winner);
64            let challenge = Arc::clone(&challenge);
65            handles.push(std::thread::spawn(move || {
66                let mut hasher = UniversalHash::new();
67                let mut input = Vec::with_capacity(challenge.len() + 8);
68                let mut lane = tid;
69                while lane < lanes && !found.load(Ordering::Relaxed) {
70                    let nonce = start_nonce.saturating_add(lane as u64);
71                    input.clear();
72                    input.extend_from_slice(&challenge);
73                    input.extend_from_slice(&nonce.to_le_bytes());
74                    let hash = hasher.hash(&input);
75                    if meets_difficulty(&hash, difficulty) {
76                        if !found.swap(true, Ordering::Relaxed) {
77                            let mut guard = winner.lock().expect("winner mutex poisoned");
78                            *guard = Some((nonce, hash));
79                        }
80                        break;
81                    }
82                    lane += threads;
83                }
84            }));
85        }
86
87        for h in handles {
88            h.join()
89                .map_err(|_| anyhow::anyhow!("cpu mining worker panicked"))?;
90        }
91        let result = *winner.lock().expect("winner mutex poisoned");
92        Ok(result)
93    }
94
95    fn benchmark_hashes(
96        &mut self,
97        header_without_nonce: &[u8],
98        start_nonce: u64,
99        lanes: usize,
100    ) -> Result<usize> {
101        let threads = self.effective_threads();
102        let challenge = Arc::new(header_without_nonce.to_vec());
103        let mut handles = Vec::with_capacity(threads);
104
105        for tid in 0..threads {
106            let challenge = Arc::clone(&challenge);
107            handles.push(std::thread::spawn(move || {
108                let mut hasher = UniversalHash::new();
109                let mut input = Vec::with_capacity(challenge.len() + 8);
110                let mut lane = tid;
111                while lane < lanes {
112                    let nonce = start_nonce.saturating_add(lane as u64);
113                    input.clear();
114                    input.extend_from_slice(&challenge);
115                    input.extend_from_slice(&nonce.to_le_bytes());
116                    let _ = hasher.hash(&input);
117                    lane += threads;
118                }
119            }));
120        }
121
122        for h in handles {
123            h.join()
124                .map_err(|_| anyhow::anyhow!("cpu benchmark worker panicked"))?;
125        }
126        Ok(lanes)
127    }
128}