Skip to main content

irontide_session/
hash_pool.rs

1//! Dedicated thread pool for CPU-bound piece hash verification (M96).
2//!
3//! Workers receive `HashJob`s via `tokio::sync::mpsc` (bridged to `std::sync::mpsc`),
4//! compute SHA1 hashes with `catch_unwind` panic recovery, and send results back
5//! via per-torrent `tokio::sync::mpsc::Sender` carried in each job.
6
7use std::thread;
8
9use tracing::{debug, error};
10
11/// Internal enum threaded through the std mpsc to workers.
12/// `Shutdown` is sent during `Drop` to unblock workers without requiring the
13/// async bridge task to have flushed first.
14enum WorkerMsg {
15    Job(HashJob),
16    Shutdown,
17}
18
19/// Job submitted to the hash pool.
20pub enum HashJob {
21    /// Pre-read piece data (original path).
22    Data {
23        /// Piece index to verify.
24        piece: u32,
25        /// Expected SHA1 hash.
26        expected: irontide_core::Id20,
27        /// Generation counter for staleness detection.
28        generation: u64,
29        /// Pre-extracted piece data.
30        data: Vec<u8>,
31        /// Per-torrent result sender.
32        result_tx: tokio::sync::mpsc::Sender<HashResult>,
33    },
34    /// Streaming verify via backend (M101 — no full-piece alloc).
35    Streaming {
36        /// Piece index to verify.
37        piece: u32,
38        /// Expected SHA1 hash.
39        expected: irontide_core::Id20,
40        /// Generation counter for staleness detection.
41        generation: u64,
42        /// Info hash for backend lookup.
43        info_hash: irontide_core::Id20,
44        /// Disk I/O backend for streaming verification.
45        backend: std::sync::Arc<dyn crate::disk_backend::DiskIoBackend>,
46        /// Per-torrent result sender.
47        result_tx: tokio::sync::mpsc::Sender<HashResult>,
48    },
49}
50
51impl std::fmt::Debug for HashJob {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self {
54            HashJob::Data {
55                piece, generation, ..
56            } => f
57                .debug_struct("HashJob::Data")
58                .field("piece", piece)
59                .field("generation", generation)
60                .finish_non_exhaustive(),
61            HashJob::Streaming {
62                piece, generation, ..
63            } => f
64                .debug_struct("HashJob::Streaming")
65                .field("piece", piece)
66                .field("generation", generation)
67                .finish_non_exhaustive(),
68        }
69    }
70}
71
/// Outcome of verifying a single piece.
///
/// This is a small plain-data value (three scalar fields), so it is `Copy`
/// and supports equality comparison; callers can compare whole results
/// instead of checking field by field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HashResult {
    /// Piece index that was verified.
    pub piece: u32,
    /// Whether the computed hash matched the expected one.
    pub passed: bool,
    /// Generation counter copied from the job (for the caller's staleness check).
    pub generation: u64,
}
82
83/// A thread pool dedicated to CPU-bound piece hash verification.
84///
85/// Uses `tokio::sync::mpsc` for async job submission and `std::sync::mpsc`
86/// internally for the blocking worker threads. Results are sent back via
87/// per-torrent `tokio::sync::mpsc::Sender` carried in each `HashJob`.
88pub struct HashPool {
89    /// Async sender for submitting jobs.
90    job_tx: tokio::sync::mpsc::Sender<HashJob>,
91    /// Direct std sender for sending `Shutdown` sentinels in `Drop`.
92    worker_tx: std::sync::mpsc::SyncSender<WorkerMsg>,
93    /// Worker thread join handles (for clean shutdown).
94    workers: Vec<thread::JoinHandle<()>>,
95}
96
97impl std::fmt::Debug for HashPool {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        f.debug_struct("HashPool")
100            .field("workers", &self.workers.len())
101            .finish()
102    }
103}
104
105impl HashPool {
106    /// Create a new hash pool with `num_workers` threads.
107    ///
108    /// `job_capacity` bounds the submission channel (backpressure if full).
109    ///
110    /// **Must be called from within a tokio runtime** — spawns a bridge task.
111    pub fn new(num_workers: usize, job_capacity: usize) -> Self {
112        let (job_tx, mut job_async_rx) = tokio::sync::mpsc::channel::<HashJob>(job_capacity);
113        // +num_workers headroom so Drop can always enqueue shutdown sentinels.
114        let (worker_tx, worker_rx) =
115            std::sync::mpsc::sync_channel::<WorkerMsg>(job_capacity + num_workers);
116
117        // Bridge task: forwards jobs from tokio channel to std channel.
118        let bridge_tx = worker_tx.clone();
119        tokio::spawn(async move {
120            while let Some(job) = job_async_rx.recv().await {
121                if bridge_tx.send(WorkerMsg::Job(job)).is_err() {
122                    break;
123                }
124            }
125        });
126
127        // Spawn worker threads
128        let worker_rx = std::sync::Arc::new(parking_lot::Mutex::new(worker_rx));
129        let mut workers = Vec::with_capacity(num_workers);
130        for id in 0..num_workers {
131            let rx = worker_rx.clone();
132            let handle = thread::Builder::new()
133                .name(format!("hash-worker-{id}"))
134                .spawn(move || {
135                    Self::worker_loop(id, rx);
136                })
137                .expect("failed to spawn hash worker thread");
138            workers.push(handle);
139        }
140
141        HashPool {
142            job_tx,
143            worker_tx,
144            workers,
145        }
146    }
147
148    /// Submit a hash job. Returns `Err` if the pool has been shut down.
149    pub async fn submit(&self, job: HashJob) -> Result<(), HashJob> {
150        self.job_tx.send(job).await.map_err(|e| e.0)
151    }
152
153    fn worker_loop(
154        id: usize,
155        rx: std::sync::Arc<parking_lot::Mutex<std::sync::mpsc::Receiver<WorkerMsg>>>,
156    ) {
157        loop {
158            let msg = {
159                let rx = rx.lock();
160                match rx.recv() {
161                    Ok(msg) => msg,
162                    Err(_) => break, // Channel closed
163                }
164            };
165
166            let job = match msg {
167                WorkerMsg::Shutdown => break,
168                WorkerMsg::Job(job) => job,
169            };
170
171            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| match &job {
172                HashJob::Data { data, expected, .. } => irontide_core::sha1(data) == *expected,
173                HashJob::Streaming {
174                    info_hash,
175                    piece,
176                    expected,
177                    backend,
178                    ..
179                } => backend
180                    .hash_piece(*info_hash, *piece, expected)
181                    .unwrap_or(false),
182            }));
183
184            let (piece, generation, result_tx) = match job {
185                HashJob::Data {
186                    piece,
187                    generation,
188                    result_tx,
189                    ..
190                } => (piece, generation, result_tx),
191                HashJob::Streaming {
192                    piece,
193                    generation,
194                    result_tx,
195                    ..
196                } => (piece, generation, result_tx),
197            };
198
199            let passed = match result {
200                Ok(passed) => passed,
201                Err(panic) => {
202                    error!(
203                        worker = id,
204                        piece,
205                        "hash worker panicked: {:?}",
206                        panic.downcast_ref::<String>()
207                    );
208                    false
209                }
210            };
211
212            if result_tx
213                .blocking_send(HashResult {
214                    piece,
215                    passed,
216                    generation,
217                })
218                .is_err()
219            {
220                // Torrent removed — result dropped, worker continues
221                continue;
222            }
223        }
224        debug!(worker = id, "hash worker exiting");
225    }
226}
227
228impl Drop for HashPool {
229    fn drop(&mut self) {
230        // Send one Shutdown sentinel per worker so each unblocks from recv().
231        // This works even if the async bridge task hasn't flushed yet, because
232        // we hold a direct std sender (`worker_tx`) that bypasses the bridge.
233        for _ in 0..self.workers.len() {
234            let _ = self.worker_tx.try_send(WorkerMsg::Shutdown);
235        }
236        for handle in self.workers.drain(..) {
237            let _ = handle.join();
238        }
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `HashJob::Data` from its parts.
    fn data_job(
        piece: u32,
        generation: u64,
        expected: irontide_core::Id20,
        data: Vec<u8>,
        result_tx: tokio::sync::mpsc::Sender<HashResult>,
    ) -> HashJob {
        HashJob::Data {
            piece,
            expected,
            generation,
            data,
            result_tx,
        }
    }

    /// Ten jobs across two workers: pieces 0-4 carry the right digest and must
    /// pass, pieces 5-9 carry a bogus digest and must fail.
    #[tokio::test]
    async fn hash_pool_parallel_correctness() {
        let pool = HashPool::new(2, 16);
        let (tx, mut rx) = tokio::sync::mpsc::channel(16);

        for i in 0u32..10 {
            let data = format!("piece-data-{i}").into_bytes();
            let expected = match i {
                0..=4 => irontide_core::sha1(&data),
                _ => irontide_core::Id20([0xff; 20]),
            };
            pool.submit(data_job(i, 0, expected, data, tx.clone()))
                .await
                .unwrap();
        }

        let mut results = Vec::new();
        while results.len() < 10 {
            results.push(rx.recv().await.unwrap());
        }

        results.sort_by_key(|r| r.piece);
        assert_eq!(results.len(), 10);
        let (matching, mismatched) = results.split_at(5);
        for r in matching {
            assert!(r.passed, "piece {} should pass", r.piece);
        }
        for r in mismatched {
            assert!(!r.passed, "piece {} should fail", r.piece);
        }
    }

    /// The generation submitted with a job comes back unchanged, so a caller
    /// that has since moved on to a newer generation can tell it is stale.
    #[tokio::test]
    async fn hash_pool_stale_generation_discard() {
        let pool = HashPool::new(1, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        let payload = b"piece five data".to_vec();
        let digest = irontide_core::sha1(&payload);
        pool.submit(data_job(5, 1, digest, payload, tx))
            .await
            .unwrap();

        let outcome = rx.recv().await.unwrap();
        assert_eq!(outcome.piece, 5);
        assert!(outcome.passed);
        assert_eq!(outcome.generation, 1);

        let current_gen = 2u64;
        assert!(
            outcome.generation != current_gen,
            "generation 1 result should be stale when current is 2"
        );
    }

    /// Two submissions for the same piece under different generations both
    /// complete, and the generation tells them apart.
    #[tokio::test]
    async fn hash_pool_concurrent_cancel_resubmit() {
        let pool = HashPool::new(2, 16);
        let (tx, mut rx) = tokio::sync::mpsc::channel(16);

        let first = b"piece42-attempt1".to_vec();
        let first_digest = irontide_core::sha1(&first);
        pool.submit(data_job(42, 1, first_digest, first, tx.clone()))
            .await
            .unwrap();

        let second = b"piece42-attempt2".to_vec();
        let second_digest = irontide_core::sha1(&second);
        pool.submit(data_job(42, 2, second_digest, second, tx))
            .await
            .unwrap();

        let mut results = Vec::with_capacity(2);
        results.push(rx.recv().await.unwrap());
        results.push(rx.recv().await.unwrap());
        results.sort_by_key(|r| r.generation);

        let (older, newer) = (&results[0], &results[1]);
        assert_eq!(older.generation, 1);
        assert!(older.passed);
        assert_eq!(newer.generation, 2);
        assert!(newer.passed);

        let current_gen = 2u64;
        assert!(older.generation != current_gen);
        assert!(newer.generation == current_gen);
    }

    /// Dropping the pool joins its workers; once the last sender is gone too,
    /// the result channel reports closed.
    #[tokio::test]
    async fn hash_pool_shutdown() {
        let pool = HashPool::new(2, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        // Queue a handful of jobs that all verify successfully.
        for i in 0u32..3 {
            let data = format!("data-{i}").into_bytes();
            let expected = irontide_core::sha1(&data);
            pool.submit(data_job(i, 0, expected, data, tx.clone()))
                .await
                .unwrap();
        }

        // Wait for all three results.
        let mut received = 0;
        while received < 3 {
            let outcome = rx.recv().await.unwrap();
            assert!(outcome.passed);
            received += 1;
        }

        // Dropping the pool should join the worker threads cleanly.
        drop(pool);
        // Release the last sender so the result channel can close.
        drop(tx);

        // With every sender gone, `recv` yields `None`.
        assert!(rx.recv().await.is_none());
    }

    /// A failed verification does not stop the worker from serving later jobs.
    #[tokio::test]
    async fn hash_pool_failure_recovery() {
        let pool = HashPool::new(1, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        // First job: the expected digest is deliberately wrong.
        let corrupt = b"corrupt data".to_vec();
        let wrong_digest = irontide_core::Id20([0xAA; 20]);
        pool.submit(data_job(0, 0, wrong_digest, corrupt, tx.clone()))
            .await
            .unwrap();

        let failed = rx.recv().await.unwrap();
        assert!(!failed.passed, "hash mismatch should report failure");
        assert_eq!(failed.piece, 0);

        // Second job on the same single worker: correct digest.
        let good = b"good data".to_vec();
        let good_digest = irontide_core::sha1(&good);
        pool.submit(data_job(1, 0, good_digest, good, tx))
            .await
            .unwrap();

        let succeeded = rx.recv().await.unwrap();
        assert!(succeeded.passed, "correct hash should pass");
        assert_eq!(succeeded.piece, 1);
    }

    /// The `Streaming` variant verifies through a registered disk backend.
    #[tokio::test]
    async fn hash_pool_streaming_variant() {
        use irontide_core::Lengths;
        use irontide_storage::MemoryStorage;
        use std::sync::Arc;

        let pool = HashPool::new(1, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        // One 16 KiB piece of known content, written into in-memory storage.
        let piece_bytes = vec![0xCDu8; 16384];
        let expected = irontide_core::sha1(&piece_bytes);
        let info_hash = irontide_core::Id20([0x01; 20]);
        let lengths = Lengths::new(16384, 16384, 16384);
        let storage: Arc<dyn irontide_storage::TorrentStorage> =
            Arc::new(MemoryStorage::new(lengths));
        storage.write_chunk(0, 0, &piece_bytes).unwrap();

        // Register that storage with a backend the job can stream from.
        let config = crate::disk::DiskConfig::default();
        let backend: Arc<dyn crate::disk_backend::DiskIoBackend> =
            Arc::new(crate::disk_backend::PosixDiskIo::new(&config));
        backend.register(info_hash, storage);

        let job = HashJob::Streaming {
            piece: 0,
            expected,
            generation: 0,
            info_hash,
            backend,
            result_tx: tx,
        };
        pool.submit(job).await.unwrap();

        let outcome = rx.recv().await.unwrap();
        assert!(outcome.passed, "streaming hash should pass");
        assert_eq!(outcome.piece, 0);
        assert_eq!(outcome.generation, 0);
    }
}