1use std::thread;
8
9use tracing::{debug, error};
10
11enum WorkerMsg {
15 Job(HashJob),
16 Shutdown,
17}
18
19pub enum HashJob {
21 Data {
23 piece: u32,
25 expected: irontide_core::Id20,
27 generation: u64,
29 data: Vec<u8>,
31 result_tx: tokio::sync::mpsc::Sender<HashResult>,
33 },
34 Streaming {
36 piece: u32,
38 expected: irontide_core::Id20,
40 generation: u64,
42 info_hash: irontide_core::Id20,
44 backend: std::sync::Arc<dyn crate::disk_backend::DiskIoBackend>,
46 result_tx: tokio::sync::mpsc::Sender<HashResult>,
48 },
49}
50
51impl std::fmt::Debug for HashJob {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 match self {
54 Self::Data {
55 piece, generation, ..
56 } => f
57 .debug_struct("HashJob::Data")
58 .field("piece", piece)
59 .field("generation", generation)
60 .finish_non_exhaustive(),
61 Self::Streaming {
62 piece, generation, ..
63 } => f
64 .debug_struct("HashJob::Streaming")
65 .field("piece", piece)
66 .field("generation", generation)
67 .finish_non_exhaustive(),
68 }
69 }
70}
71
/// Outcome of one hash job, delivered on that job's `result_tx`.
///
/// `generation` is copied unchanged from the job. A receiver whose notion of the
/// current generation has moved on can compare against it to recognise stale
/// results.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HashResult {
    /// Index of the piece that was checked.
    pub piece: u32,
    /// `true` only when the computed digest matched the expected one.
    ///
    /// It is also `false` when hashing panicked or the streaming backend could
    /// not produce a result.
    pub passed: bool,
    /// Generation tag echoed from the originating job.
    pub generation: u64,
}
82
83pub struct HashPool {
89 job_tx: tokio::sync::mpsc::Sender<HashJob>,
91 worker_tx: std::sync::mpsc::SyncSender<WorkerMsg>,
93 workers: Vec<thread::JoinHandle<()>>,
95}
96
97#[allow(
98 clippy::missing_fields_in_debug,
99 reason = "intentionally omit internal channel fields from Debug output"
100)]
101impl std::fmt::Debug for HashPool {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 f.debug_struct("HashPool")
104 .field("workers", &self.workers.len())
105 .finish()
106 }
107}
108
109impl HashPool {
110 #[must_use]
116 pub fn new(num_workers: usize, job_capacity: usize) -> Self {
117 let (job_tx, mut job_async_rx) = tokio::sync::mpsc::channel::<HashJob>(job_capacity);
118 let (worker_tx, worker_rx) =
120 std::sync::mpsc::sync_channel::<WorkerMsg>(job_capacity + num_workers);
121
122 let bridge_tx = worker_tx.clone();
124 tokio::spawn(async move {
125 while let Some(job) = job_async_rx.recv().await {
126 if bridge_tx.send(WorkerMsg::Job(job)).is_err() {
127 break;
128 }
129 }
130 });
131
132 let worker_rx = std::sync::Arc::new(parking_lot::Mutex::new(worker_rx));
134 let mut workers = Vec::with_capacity(num_workers);
135 for id in 0..num_workers {
136 let rx = worker_rx.clone();
137 let handle = thread::Builder::new()
138 .name(format!("hash-worker-{id}"))
139 .spawn(move || {
140 Self::worker_loop(id, rx);
141 })
142 .expect("failed to spawn hash worker thread");
143 workers.push(handle);
144 }
145
146 Self {
147 job_tx,
148 worker_tx,
149 workers,
150 }
151 }
152
153 pub async fn submit(&self, job: HashJob) -> Result<(), HashJob> {
159 self.job_tx.send(job).await.map_err(|e| e.0)
160 }
161
162 #[allow(
163 clippy::needless_pass_by_value,
164 reason = "Arc is moved into worker thread — pass-by-value is the ownership transfer idiom"
165 )]
166 fn worker_loop(
167 id: usize,
168 rx: std::sync::Arc<parking_lot::Mutex<std::sync::mpsc::Receiver<WorkerMsg>>>,
169 ) {
170 loop {
171 let msg = {
172 let rx = rx.lock();
173 match rx.recv() {
174 Ok(msg) => msg,
175 Err(_) => break, }
177 };
178
179 let job = match msg {
180 WorkerMsg::Shutdown => break,
181 WorkerMsg::Job(job) => job,
182 };
183
184 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| match &job {
185 HashJob::Data { data, expected, .. } => irontide_core::sha1(data) == *expected,
186 HashJob::Streaming {
187 info_hash,
188 piece,
189 expected,
190 backend,
191 ..
192 } => backend
193 .hash_piece(*info_hash, *piece, expected)
194 .unwrap_or(false),
195 }));
196
197 let (piece, generation, result_tx) = match job {
198 HashJob::Data {
199 piece,
200 generation,
201 result_tx,
202 ..
203 }
204 | HashJob::Streaming {
205 piece,
206 generation,
207 result_tx,
208 ..
209 } => (piece, generation, result_tx),
210 };
211
212 let passed = match result {
213 Ok(passed) => passed,
214 Err(panic) => {
215 error!(
216 worker = id,
217 piece,
218 "hash worker panicked: {:?}",
219 panic.downcast_ref::<String>()
220 );
221 false
222 }
223 };
224
225 let _ = result_tx.blocking_send(HashResult {
227 piece,
228 passed,
229 generation,
230 });
231 }
232 debug!(worker = id, "hash worker exiting");
233 }
234}
235
236impl Drop for HashPool {
237 fn drop(&mut self) {
238 for _ in 0..self.workers.len() {
242 let _ = self.worker_tx.try_send(WorkerMsg::Shutdown);
243 }
244 for handle in self.workers.drain(..) {
245 let _ = handle.join();
246 }
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    // Ten in-memory jobs across two workers. Pieces 0..5 carry their real SHA-1
    // and must pass; pieces 5..10 carry an all-0xff digest and must fail.
    // Results can arrive in any order, so they are sorted by piece before
    // checking.
    #[tokio::test]
    async fn hash_pool_parallel_correctness() {
        let pool = HashPool::new(2, 16);
        let (tx, mut rx) = tokio::sync::mpsc::channel(16);

        for i in 0u32..10 {
            // First half: matching digest. Second half: deliberately wrong digest.
            let data = format!("piece-data-{i}").into_bytes();
            let expected = if i < 5 {
                irontide_core::sha1(&data)
            } else {
                irontide_core::Id20([0xff; 20])
            };
            pool.submit(HashJob::Data {
                piece: i,
                expected,
                generation: 0,
                data,
                result_tx: tx.clone(),
            })
            .await
            .unwrap();
        }

        let mut results = Vec::new();
        for _ in 0..10 {
            results.push(rx.recv().await.unwrap());
        }

        results.sort_by_key(|r| r.piece);
        assert_eq!(results.len(), 10);
        for r in &results[..5] {
            assert!(r.passed, "piece {} should pass", r.piece);
        }
        for r in &results[5..] {
            assert!(!r.passed, "piece {} should fail", r.piece);
        }
    }

    // The pool echoes a job's `generation` back unchanged in the result. The
    // final assertion shows how a caller whose current generation has moved on
    // (to 2 here) can recognise the result as stale. The pool itself delivers
    // every result and discards nothing.
    #[tokio::test]
    async fn hash_pool_stale_generation_discard() {
        let pool = HashPool::new(1, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        let data = b"piece five data".to_vec();
        let expected = irontide_core::sha1(&data);
        pool.submit(HashJob::Data {
            piece: 5,
            expected,
            generation: 1,
            data,
            result_tx: tx,
        })
        .await
        .unwrap();

        let r = rx.recv().await.unwrap();
        assert_eq!(r.piece, 5);
        assert!(r.passed);
        assert_eq!(r.generation, 1);

        // Caller-side staleness check: compare against the caller's current generation.
        let current_gen = 2u64;
        assert!(
            r.generation != current_gen,
            "generation 1 result should be stale when current is 2"
        );
    }

    // Two jobs for the same piece but different generations, submitted
    // back-to-back to a two-worker pool. Both results are delivered, and they
    // are told apart by generation, not by arrival order.
    #[tokio::test]
    async fn hash_pool_concurrent_cancel_resubmit() {
        let pool = HashPool::new(2, 16);
        let (tx, mut rx) = tokio::sync::mpsc::channel(16);

        let data1 = b"piece42-attempt1".to_vec();
        let expected1 = irontide_core::sha1(&data1);
        pool.submit(HashJob::Data {
            piece: 42,
            expected: expected1,
            generation: 1,
            data: data1,
            result_tx: tx.clone(),
        })
        .await
        .unwrap();

        let data2 = b"piece42-attempt2".to_vec();
        let expected2 = irontide_core::sha1(&data2);
        pool.submit(HashJob::Data {
            piece: 42,
            expected: expected2,
            generation: 2,
            data: data2,
            result_tx: tx,
        })
        .await
        .unwrap();

        // Either worker may finish first; order the pair by generation.
        let r1 = rx.recv().await.unwrap();
        let r2 = rx.recv().await.unwrap();
        let mut results = [r1, r2];
        results.sort_by_key(|r| r.generation);

        assert_eq!(results[0].generation, 1);
        assert!(results[0].passed);
        assert_eq!(results[1].generation, 2);
        assert!(results[1].passed);

        let current_gen = 2u64;
        assert!(results[0].generation != current_gen);
        assert!(results[1].generation == current_gen);
    }

    // After all submitted work completes, dropping the pool joins its workers.
    // Once the test's own sender is dropped as well, the result channel must
    // report closed. This shows that no `result_tx` clone outlived the pool.
    #[tokio::test]
    async fn hash_pool_shutdown() {
        let pool = HashPool::new(2, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        for i in 0u32..3 {
            // Every job carries its correct digest.
            let data = format!("data-{i}").into_bytes();
            let expected = irontide_core::sha1(&data);
            pool.submit(HashJob::Data {
                piece: i,
                expected,
                generation: 0,
                data,
                result_tx: tx.clone(),
            })
            .await
            .unwrap();
        }

        for _ in 0..3 {
            // Drain all results before shutting the pool down.
            let r = rx.recv().await.unwrap();
            assert!(r.passed);
        }

        // Order matters: the pool is dropped (joining its workers) before the
        // last sender, so a `None` below means no sender survived the pool.
        drop(pool);
        drop(tx);

        assert!(rx.recv().await.is_none());
    }

    // A mismatching job reports failure, and the single worker then still
    // handles the next job correctly. A failed check must not wedge the pool.
    #[tokio::test]
    async fn hash_pool_failure_recovery() {
        let pool = HashPool::new(1, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        let data = b"corrupt data".to_vec();
        // A digest that cannot match `data`.
        let expected = irontide_core::Id20([0xAA; 20]);
        pool.submit(HashJob::Data {
            piece: 0,
            expected,
            generation: 0,
            data,
            result_tx: tx.clone(),
        })
        .await
        .unwrap();

        let r = rx.recv().await.unwrap();
        assert!(!r.passed, "hash mismatch should report failure");
        assert_eq!(r.piece, 0);

        let data2 = b"good data".to_vec();
        // Correct digest this time, on the same (only) worker.
        let expected2 = irontide_core::sha1(&data2);
        pool.submit(HashJob::Data {
            piece: 1,
            expected: expected2,
            generation: 0,
            data: data2,
            result_tx: tx,
        })
        .await
        .unwrap();

        let r2 = rx.recv().await.unwrap();
        assert!(r2.passed, "correct hash should pass");
        assert_eq!(r2.piece, 1);
    }

    // Exercises the Streaming variant end to end. A 16 KiB buffer is written
    // into a MemoryStorage and registered with a PosixDiskIo backend under
    // `info_hash`. The worker must then obtain a passing result from
    // `backend.hash_piece` for piece 0.
    #[tokio::test]
    async fn hash_pool_streaming_variant() {
        use irontide_core::Lengths;
        use irontide_storage::MemoryStorage;
        use std::sync::Arc;

        let pool = HashPool::new(1, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        let data = vec![0xCDu8; 16384];
        // The digest the backend is expected to reproduce from storage.
        let expected = irontide_core::sha1(&data);
        let info_hash = irontide_core::Id20([0x01; 20]);
        let lengths = Lengths::new(16384, 16384, 16384);
        let storage: Arc<dyn irontide_storage::TorrentStorage> =
            Arc::new(MemoryStorage::new(lengths));
        storage.write_chunk(0, 0, &data).unwrap();

        let config = crate::disk::DiskConfig::default();
        let backend: Arc<dyn crate::disk_backend::DiskIoBackend> =
            Arc::new(crate::disk_backend::PosixDiskIo::new(&config));
        backend.register(info_hash, storage);

        pool.submit(HashJob::Streaming {
            piece: 0,
            expected,
            generation: 0,
            info_hash,
            backend,
            result_tx: tx,
        })
        .await
        .unwrap();

        let r = rx.recv().await.unwrap();
        assert!(r.passed, "streaming hash should pass");
        assert_eq!(r.piece, 0);
        assert_eq!(r.generation, 0);
    }
}