1use std::thread;
8
9use tracing::{debug, error};
10
/// Message delivered to the worker threads over their shared std channel.
enum WorkerMsg {
    /// A hashing job forwarded from the async submission channel.
    Job(HashJob),
    /// Makes the worker that receives it leave its loop and exit.
    Shutdown,
}
18
/// A request to check one piece against its expected 20-byte digest.
///
/// Both variants carry `generation`, which the worker copies unchanged into
/// the [`HashResult`]. The caller can compare it with its own current
/// generation to recognise results for an attempt it has since abandoned.
pub enum HashJob {
    /// Piece bytes already in memory; the worker hashes them with
    /// `irontide_core::sha1` and compares against `expected`.
    Data {
        /// Piece index, copied into the result.
        piece: u32,
        /// Digest the computed hash must equal for the piece to pass.
        expected: irontide_core::Id20,
        /// Attempt counter, echoed back in the result.
        generation: u64,
        /// Bytes hashed and compared with `expected`.
        data: Vec<u8>,
        /// Where the worker sends the result; if the receiver is gone the
        /// result is dropped.
        result_tx: tokio::sync::mpsc::Sender<HashResult>,
    },
    /// Piece checked by a disk backend through `DiskIoBackend::hash_piece`;
    /// a backend error counts as a failed check.
    Streaming {
        /// Piece index, passed to the backend and copied into the result.
        piece: u32,
        /// Digest handed to the backend for comparison.
        expected: irontide_core::Id20,
        /// Attempt counter, echoed back in the result.
        generation: u64,
        /// Torrent identifier passed to the backend (storage is registered
        /// under this key via `backend.register` in the tests).
        info_hash: irontide_core::Id20,
        /// Backend that performs the read-and-hash.
        backend: std::sync::Arc<dyn crate::disk_backend::DiskIoBackend>,
        /// Where the worker sends the result; if the receiver is gone the
        /// result is dropped.
        result_tx: tokio::sync::mpsc::Sender<HashResult>,
    },
}
50
51impl std::fmt::Debug for HashJob {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 match self {
54 HashJob::Data {
55 piece, generation, ..
56 } => f
57 .debug_struct("HashJob::Data")
58 .field("piece", piece)
59 .field("generation", generation)
60 .finish_non_exhaustive(),
61 HashJob::Streaming {
62 piece, generation, ..
63 } => f
64 .debug_struct("HashJob::Streaming")
65 .field("piece", piece)
66 .field("generation", generation)
67 .finish_non_exhaustive(),
68 }
69 }
70}
71
/// Outcome of one [`HashJob`], sent back on the job's `result_tx`.
///
/// Every field is a plain `Copy` scalar, so the type is cheap to copy and can
/// be compared directly with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HashResult {
    /// Piece index copied from the job.
    pub piece: u32,
    /// `true` if the computed digest matched `expected`; `false` on a
    /// mismatch, on a backend error, or if hashing panicked.
    pub passed: bool,
    /// Generation copied unchanged from the job, so the receiver can discard
    /// results for attempts it no longer cares about.
    pub generation: u64,
}
82
/// Fixed-size pool of OS threads that verify piece hashes.
///
/// Jobs enter through an async channel and are forwarded by a bridge into a
/// std channel shared by the worker threads.
pub struct HashPool {
    /// Async entry point used by `submit`; its receiver is drained by the
    /// bridge that forwards jobs to the workers.
    job_tx: tokio::sync::mpsc::Sender<HashJob>,
    /// Sender into the workers' shared channel, kept so `Drop` can queue
    /// `Shutdown` messages.
    worker_tx: std::sync::mpsc::SyncSender<WorkerMsg>,
    /// Join handles of the worker threads, joined on drop.
    workers: Vec<thread::JoinHandle<()>>,
}
96
97impl std::fmt::Debug for HashPool {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 f.debug_struct("HashPool")
100 .field("workers", &self.workers.len())
101 .finish()
102 }
103}
104
105impl HashPool {
106 pub fn new(num_workers: usize, job_capacity: usize) -> Self {
112 let (job_tx, mut job_async_rx) = tokio::sync::mpsc::channel::<HashJob>(job_capacity);
113 let (worker_tx, worker_rx) =
115 std::sync::mpsc::sync_channel::<WorkerMsg>(job_capacity + num_workers);
116
117 let bridge_tx = worker_tx.clone();
119 tokio::spawn(async move {
120 while let Some(job) = job_async_rx.recv().await {
121 if bridge_tx.send(WorkerMsg::Job(job)).is_err() {
122 break;
123 }
124 }
125 });
126
127 let worker_rx = std::sync::Arc::new(parking_lot::Mutex::new(worker_rx));
129 let mut workers = Vec::with_capacity(num_workers);
130 for id in 0..num_workers {
131 let rx = worker_rx.clone();
132 let handle = thread::Builder::new()
133 .name(format!("hash-worker-{id}"))
134 .spawn(move || {
135 Self::worker_loop(id, rx);
136 })
137 .expect("failed to spawn hash worker thread");
138 workers.push(handle);
139 }
140
141 HashPool {
142 job_tx,
143 worker_tx,
144 workers,
145 }
146 }
147
148 pub async fn submit(&self, job: HashJob) -> Result<(), HashJob> {
150 self.job_tx.send(job).await.map_err(|e| e.0)
151 }
152
153 fn worker_loop(
154 id: usize,
155 rx: std::sync::Arc<parking_lot::Mutex<std::sync::mpsc::Receiver<WorkerMsg>>>,
156 ) {
157 loop {
158 let msg = {
159 let rx = rx.lock();
160 match rx.recv() {
161 Ok(msg) => msg,
162 Err(_) => break, }
164 };
165
166 let job = match msg {
167 WorkerMsg::Shutdown => break,
168 WorkerMsg::Job(job) => job,
169 };
170
171 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| match &job {
172 HashJob::Data { data, expected, .. } => irontide_core::sha1(data) == *expected,
173 HashJob::Streaming {
174 info_hash,
175 piece,
176 expected,
177 backend,
178 ..
179 } => backend
180 .hash_piece(*info_hash, *piece, expected)
181 .unwrap_or(false),
182 }));
183
184 let (piece, generation, result_tx) = match job {
185 HashJob::Data {
186 piece,
187 generation,
188 result_tx,
189 ..
190 } => (piece, generation, result_tx),
191 HashJob::Streaming {
192 piece,
193 generation,
194 result_tx,
195 ..
196 } => (piece, generation, result_tx),
197 };
198
199 let passed = match result {
200 Ok(passed) => passed,
201 Err(panic) => {
202 error!(
203 worker = id,
204 piece,
205 "hash worker panicked: {:?}",
206 panic.downcast_ref::<String>()
207 );
208 false
209 }
210 };
211
212 if result_tx
213 .blocking_send(HashResult {
214 piece,
215 passed,
216 generation,
217 })
218 .is_err()
219 {
220 continue;
222 }
223 }
224 debug!(worker = id, "hash worker exiting");
225 }
226}
227
228impl Drop for HashPool {
229 fn drop(&mut self) {
230 for _ in 0..self.workers.len() {
234 let _ = self.worker_tx.try_send(WorkerMsg::Shutdown);
235 }
236 for handle in self.workers.drain(..) {
237 let _ = handle.join();
238 }
239 }
240}
241
// Tests for `HashPool`: correctness of pass/fail verdicts, generation echo,
// shutdown, recovery after a failed check, and the streaming backend variant.
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten jobs across two workers. Pieces 0-4 carry their real digest and must
    /// pass; pieces 5-9 carry an all-0xff digest and must fail. Results can
    /// arrive in any order, so they are sorted by piece before checking.
    #[tokio::test]
    async fn hash_pool_parallel_correctness() {
        let pool = HashPool::new(2, 16);
        let (tx, mut rx) = tokio::sync::mpsc::channel(16);

        for i in 0u32..10 {
            // First half gets the correct digest, second half a wrong one.
            let data = format!("piece-data-{i}").into_bytes();
            let expected = if i < 5 {
                irontide_core::sha1(&data)
            } else {
                irontide_core::Id20([0xff; 20])
            };
            pool.submit(HashJob::Data {
                piece: i,
                expected,
                generation: 0,
                data,
                result_tx: tx.clone(),
            })
            .await
            .unwrap();
        }

        let mut results = Vec::new();
        for _ in 0..10 {
            results.push(rx.recv().await.unwrap());
        }

        results.sort_by_key(|r| r.piece);
        assert_eq!(results.len(), 10);
        for r in &results[..5] {
            assert!(r.passed, "piece {} should pass", r.piece);
        }
        for r in &results[5..] {
            assert!(!r.passed, "piece {} should fail", r.piece);
        }
    }

    /// The pool does not filter stale work itself. It echoes the job's
    /// `generation` back unchanged, and the caller compares it with its own
    /// current generation (simulated here as 2) to discard the result.
    #[tokio::test]
    async fn hash_pool_stale_generation_discard() {
        let pool = HashPool::new(1, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        let data = b"piece five data".to_vec();
        let expected = irontide_core::sha1(&data);
        pool.submit(HashJob::Data {
            piece: 5,
            expected,
            generation: 1,
            data,
            result_tx: tx,
        })
        .await
        .unwrap();

        let r = rx.recv().await.unwrap();
        assert_eq!(r.piece, 5);
        assert!(r.passed);
        assert_eq!(r.generation, 1);

        // Caller-side staleness check: the piece has moved on to generation 2.
        let current_gen = 2u64;
        assert!(
            r.generation != current_gen,
            "generation 1 result should be stale when current is 2"
        );
    }

    /// Two attempts at piece 42 (generations 1 and 2) are in flight at once on
    /// two workers. Both results are delivered; only the generation tells the
    /// caller which one is current.
    #[tokio::test]
    async fn hash_pool_concurrent_cancel_resubmit() {
        let pool = HashPool::new(2, 16);
        let (tx, mut rx) = tokio::sync::mpsc::channel(16);

        let data1 = b"piece42-attempt1".to_vec();
        let expected1 = irontide_core::sha1(&data1);
        pool.submit(HashJob::Data {
            piece: 42,
            expected: expected1,
            generation: 1,
            data: data1,
            result_tx: tx.clone(),
        })
        .await
        .unwrap();

        let data2 = b"piece42-attempt2".to_vec();
        let expected2 = irontide_core::sha1(&data2);
        pool.submit(HashJob::Data {
            piece: 42,
            expected: expected2,
            generation: 2,
            data: data2,
            result_tx: tx,
        })
        .await
        .unwrap();

        // With two workers the completion order is not fixed; sort by generation.
        let r1 = rx.recv().await.unwrap();
        let r2 = rx.recv().await.unwrap();
        let mut results = vec![r1, r2];
        results.sort_by_key(|r| r.generation);

        assert_eq!(results[0].generation, 1);
        assert!(results[0].passed);
        assert_eq!(results[1].generation, 2);
        assert!(results[1].passed);

        let current_gen = 2u64;
        assert!(results[0].generation != current_gen);
        assert!(results[1].generation == current_gen);
    }

    /// Dropping the pool after its results are consumed returns once the
    /// workers are joined. After that, no `result_tx` clone is left inside the
    /// pool: when the test's own `tx` is dropped, the result channel reports
    /// closed.
    #[tokio::test]
    async fn hash_pool_shutdown() {
        let pool = HashPool::new(2, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        for i in 0u32..3 {
            let data = format!("data-{i}").into_bytes();
            let expected = irontide_core::sha1(&data);
            pool.submit(HashJob::Data {
                piece: i,
                expected,
                generation: 0,
                data,
                result_tx: tx.clone(),
            })
            .await
            .unwrap();
        }

        // Consume all three results before shutting the pool down.
        for _ in 0..3 {
            let r = rx.recv().await.unwrap();
            assert!(r.passed);
        }

        // `Drop for HashPool` joins the worker threads, blocking this thread.
        drop(pool);
        // The last remaining result sender.
        drop(tx);

        assert!(rx.recv().await.is_none());
    }

    /// A digest mismatch is reported as `passed == false`, and the same single
    /// worker keeps serving jobs afterwards.
    #[tokio::test]
    async fn hash_pool_failure_recovery() {
        let pool = HashPool::new(1, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        // The expected digest deliberately does not match the data.
        let data = b"corrupt data".to_vec();
        let expected = irontide_core::Id20([0xAA; 20]);
        pool.submit(HashJob::Data {
            piece: 0,
            expected,
            generation: 0,
            data,
            result_tx: tx.clone(),
        })
        .await
        .unwrap();

        let r = rx.recv().await.unwrap();
        assert!(!r.passed, "hash mismatch should report failure");
        assert_eq!(r.piece, 0);

        // A correct job after the failure must still be processed and pass.
        let data2 = b"good data".to_vec();
        let expected2 = irontide_core::sha1(&data2);
        pool.submit(HashJob::Data {
            piece: 1,
            expected: expected2,
            generation: 0,
            data: data2,
            result_tx: tx,
        })
        .await
        .unwrap();

        let r2 = rx.recv().await.unwrap();
        assert!(r2.passed, "correct hash should pass");
        assert_eq!(r2.piece, 1);
    }

    /// The `Streaming` variant hashes through a `DiskIoBackend`. One 16 KiB
    /// piece is written to in-memory storage and registered with a
    /// `PosixDiskIo` backend under a dummy info hash, then must verify.
    #[tokio::test]
    async fn hash_pool_streaming_variant() {
        use irontide_core::Lengths;
        use irontide_storage::MemoryStorage;
        use std::sync::Arc;

        let pool = HashPool::new(1, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        // One 16 KiB piece filled with 0xCD.
        let data = vec![0xCDu8; 16384];
        let expected = irontide_core::sha1(&data);
        let info_hash = irontide_core::Id20([0x01; 20]);
        // NOTE(review): argument meaning follows `Lengths::new` (presumably
        // total length / piece length / chunk size); confirm in irontide_core.
        let lengths = Lengths::new(16384, 16384, 16384);
        let storage: Arc<dyn irontide_storage::TorrentStorage> =
            Arc::new(MemoryStorage::new(lengths));
        storage.write_chunk(0, 0, &data).unwrap();

        let config = crate::disk::DiskConfig::default();
        let backend: Arc<dyn crate::disk_backend::DiskIoBackend> =
            Arc::new(crate::disk_backend::PosixDiskIo::new(&config));
        backend.register(info_hash, storage);

        pool.submit(HashJob::Streaming {
            piece: 0,
            expected,
            generation: 0,
            info_hash,
            backend,
            result_tx: tx,
        })
        .await
        .unwrap();

        let r = rx.recv().await.unwrap();
        assert!(r.passed, "streaming hash should pass");
        assert_eq!(r.piece, 0);
        assert_eq!(r.generation, 0);
    }
}