Skip to main content

irontide_session/
hash_pool.rs

1//! Dedicated thread pool for CPU-bound piece hash verification (M96).
2//!
3//! Workers receive `HashJob`s via `tokio::sync::mpsc` (bridged to `std::sync::mpsc`),
4//! compute SHA1 hashes with `catch_unwind` panic recovery, and send results back
5//! via per-torrent `tokio::sync::mpsc::Sender` carried in each job.
6
7use std::thread;
8
9use tracing::{debug, error};
10
/// Message delivered to worker threads over the internal `std::sync::mpsc`
/// channel.
///
/// The bridge task wraps every submitted [`HashJob`] in `Job`; `Drop` pushes
/// `Shutdown` directly onto the same channel so that a worker blocked in
/// `recv()` wakes up and exits without waiting for the async bridge task.
enum WorkerMsg {
    /// A piece verification request to execute.
    Job(HashJob),
    /// Tells the receiving worker to leave its loop.
    Shutdown,
}
18
/// Job submitted to the hash pool.
///
/// Both variants carry the piece index, the expected SHA1 digest, a
/// generation counter that is echoed back unchanged in the [`HashResult`],
/// and the channel on which that result is delivered.
pub enum HashJob {
    /// Pre-read piece data (original path): the worker hashes `data` itself.
    Data {
        /// Piece index to verify.
        piece: u32,
        /// Expected SHA1 hash, compared against the digest of `data`.
        expected: irontide_core::Id20,
        /// Generation counter for staleness detection (copied into the result).
        generation: u64,
        /// Pre-extracted piece data.
        data: Vec<u8>,
        /// Per-torrent result sender.
        result_tx: tokio::sync::mpsc::Sender<HashResult>,
    },
    /// Streaming verify via backend (M101 — no full-piece alloc): the worker
    /// delegates to `backend.hash_piece`.
    Streaming {
        /// Piece index to verify.
        piece: u32,
        /// Expected SHA1 hash, handed to the backend for comparison.
        expected: irontide_core::Id20,
        /// Generation counter for staleness detection (copied into the result).
        generation: u64,
        /// Info hash for backend lookup.
        info_hash: irontide_core::Id20,
        /// Disk I/O backend for streaming verification.
        backend: std::sync::Arc<dyn crate::disk_backend::DiskIoBackend>,
        /// Per-torrent result sender.
        result_tx: tokio::sync::mpsc::Sender<HashResult>,
    },
}
50
51impl std::fmt::Debug for HashJob {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self {
54            Self::Data {
55                piece, generation, ..
56            } => f
57                .debug_struct("HashJob::Data")
58                .field("piece", piece)
59                .field("generation", generation)
60                .finish_non_exhaustive(),
61            Self::Streaming {
62                piece, generation, ..
63            } => f
64                .debug_struct("HashJob::Streaming")
65                .field("piece", piece)
66                .field("generation", generation)
67                .finish_non_exhaustive(),
68        }
69    }
70}
71
/// Outcome of a single piece hash verification.
///
/// This is a small plain value type, so it derives `Clone`, `PartialEq` and
/// `Eq` in addition to `Debug`; callers can compare results directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HashResult {
    /// Piece index that was verified.
    pub piece: u32,
    /// Whether the hash matched.
    pub passed: bool,
    /// Generation counter copied from the job (for staleness check by caller).
    pub generation: u64,
}
82
/// A thread pool dedicated to CPU-bound piece hash verification.
///
/// Jobs are submitted asynchronously through a bounded `tokio::sync::mpsc`
/// channel, forwarded by a bridge task onto a bounded `std::sync::mpsc`
/// channel, and consumed by blocking worker threads. Each result is sent
/// back on the per-torrent `tokio::sync::mpsc::Sender` carried in its
/// `HashJob`.
pub struct HashPool {
    /// Async sender for submitting jobs.
    job_tx: tokio::sync::mpsc::Sender<HashJob>,
    /// Direct std sender for sending `Shutdown` sentinels in `Drop`.
    worker_tx: std::sync::mpsc::SyncSender<WorkerMsg>,
    /// Worker thread join handles (for clean shutdown).
    workers: Vec<thread::JoinHandle<()>>,
}
96
97#[allow(
98    clippy::missing_fields_in_debug,
99    reason = "intentionally omit internal channel fields from Debug output"
100)]
101impl std::fmt::Debug for HashPool {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        f.debug_struct("HashPool")
104            .field("workers", &self.workers.len())
105            .finish()
106    }
107}
108
109impl HashPool {
110    /// Create a new hash pool with `num_workers` threads.
111    ///
112    /// `job_capacity` bounds the submission channel (backpressure if full).
113    ///
114    /// **Must be called from within a tokio runtime** — spawns a bridge task.
115    #[must_use]
116    pub fn new(num_workers: usize, job_capacity: usize) -> Self {
117        let (job_tx, mut job_async_rx) = tokio::sync::mpsc::channel::<HashJob>(job_capacity);
118        // +num_workers headroom so Drop can always enqueue shutdown sentinels.
119        let (worker_tx, worker_rx) =
120            std::sync::mpsc::sync_channel::<WorkerMsg>(job_capacity + num_workers);
121
122        // Bridge task: forwards jobs from tokio channel to std channel.
123        let bridge_tx = worker_tx.clone();
124        tokio::spawn(async move {
125            while let Some(job) = job_async_rx.recv().await {
126                if bridge_tx.send(WorkerMsg::Job(job)).is_err() {
127                    break;
128                }
129            }
130        });
131
132        // Spawn worker threads
133        let worker_rx = std::sync::Arc::new(parking_lot::Mutex::new(worker_rx));
134        let mut workers = Vec::with_capacity(num_workers);
135        for id in 0..num_workers {
136            let rx = worker_rx.clone();
137            let handle = thread::Builder::new()
138                .name(format!("hash-worker-{id}"))
139                .spawn(move || {
140                    Self::worker_loop(id, rx);
141                })
142                .expect("failed to spawn hash worker thread");
143            workers.push(handle);
144        }
145
146        Self {
147            job_tx,
148            worker_tx,
149            workers,
150        }
151    }
152
153    /// Submit a hash job. Returns `Err` if the pool has been shut down.
154    ///
155    /// # Errors
156    ///
157    /// Returns an error if the session is shut down.
158    pub async fn submit(&self, job: HashJob) -> Result<(), HashJob> {
159        self.job_tx.send(job).await.map_err(|e| e.0)
160    }
161
162    #[allow(
163        clippy::needless_pass_by_value,
164        reason = "Arc is moved into worker thread — pass-by-value is the ownership transfer idiom"
165    )]
166    fn worker_loop(
167        id: usize,
168        rx: std::sync::Arc<parking_lot::Mutex<std::sync::mpsc::Receiver<WorkerMsg>>>,
169    ) {
170        loop {
171            let msg = {
172                let rx = rx.lock();
173                match rx.recv() {
174                    Ok(msg) => msg,
175                    Err(_) => break, // Channel closed
176                }
177            };
178
179            let job = match msg {
180                WorkerMsg::Shutdown => break,
181                WorkerMsg::Job(job) => job,
182            };
183
184            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| match &job {
185                HashJob::Data { data, expected, .. } => irontide_core::sha1(data) == *expected,
186                HashJob::Streaming {
187                    info_hash,
188                    piece,
189                    expected,
190                    backend,
191                    ..
192                } => backend
193                    .hash_piece(*info_hash, *piece, expected)
194                    .unwrap_or(false),
195            }));
196
197            let (piece, generation, result_tx) = match job {
198                HashJob::Data {
199                    piece,
200                    generation,
201                    result_tx,
202                    ..
203                }
204                | HashJob::Streaming {
205                    piece,
206                    generation,
207                    result_tx,
208                    ..
209                } => (piece, generation, result_tx),
210            };
211
212            let passed = match result {
213                Ok(passed) => passed,
214                Err(panic) => {
215                    error!(
216                        worker = id,
217                        piece,
218                        "hash worker panicked: {:?}",
219                        panic.downcast_ref::<String>()
220                    );
221                    false
222                }
223            };
224
225            // Torrent removed — result dropped, worker continues
226            let _ = result_tx.blocking_send(HashResult {
227                piece,
228                passed,
229                generation,
230            });
231        }
232        debug!(worker = id, "hash worker exiting");
233    }
234}
235
236impl Drop for HashPool {
237    fn drop(&mut self) {
238        // Send one Shutdown sentinel per worker so each unblocks from recv().
239        // This works even if the async bridge task hasn't flushed yet, because
240        // we hold a direct std sender (`worker_tx`) that bypasses the bridge.
241        for _ in 0..self.workers.len() {
242            let _ = self.worker_tx.try_send(WorkerMsg::Shutdown);
243        }
244        for handle in self.workers.drain(..) {
245            let _ = handle.join();
246        }
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Sender half of the per-test result channel.
    type ResultTx = tokio::sync::mpsc::Sender<HashResult>;

    /// Build a pre-read `HashJob::Data` for `piece`.
    fn data_job(
        piece: u32,
        generation: u64,
        data: Vec<u8>,
        expected: irontide_core::Id20,
        result_tx: ResultTx,
    ) -> HashJob {
        HashJob::Data {
            piece,
            expected,
            generation,
            data,
            result_tx,
        }
    }

    /// Ten jobs across two workers: the first five match, the rest do not.
    #[tokio::test]
    async fn hash_pool_parallel_correctness() {
        let pool = HashPool::new(2, 16);
        let (tx, mut rx) = tokio::sync::mpsc::channel(16);

        for piece in 0u32..10 {
            let data = format!("piece-data-{piece}").into_bytes();
            let expected = if piece < 5 {
                irontide_core::sha1(&data)
            } else {
                irontide_core::Id20([0xff; 20])
            };
            pool.submit(data_job(piece, 0, data, expected, tx.clone()))
                .await
                .unwrap();
        }

        let mut results = Vec::new();
        while results.len() < 10 {
            results.push(rx.recv().await.unwrap());
        }

        results.sort_by_key(|r| r.piece);
        assert_eq!(results.len(), 10);
        let (matching, mismatched) = results.split_at(5);
        for r in matching {
            assert!(r.passed, "piece {} should pass", r.piece);
        }
        for r in mismatched {
            assert!(!r.passed, "piece {} should fail", r.piece);
        }
    }

    /// The generation is echoed back so the caller can detect stale results.
    #[tokio::test]
    async fn hash_pool_stale_generation_discard() {
        let pool = HashPool::new(1, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        let payload = b"piece five data".to_vec();
        let digest = irontide_core::sha1(&payload);
        pool.submit(data_job(5, 1, payload, digest, tx))
            .await
            .unwrap();

        let outcome = rx.recv().await.unwrap();
        assert_eq!(outcome.piece, 5);
        assert!(outcome.passed);
        assert_eq!(outcome.generation, 1);

        let current_gen = 2u64;
        assert!(
            outcome.generation != current_gen,
            "generation 1 result should be stale when current is 2"
        );
    }

    /// Two attempts at the same piece are distinguishable by generation.
    #[tokio::test]
    async fn hash_pool_concurrent_cancel_resubmit() {
        let pool = HashPool::new(2, 16);
        let (tx, mut rx) = tokio::sync::mpsc::channel(16);

        for (generation, text) in [(1u64, "piece42-attempt1"), (2u64, "piece42-attempt2")] {
            let data = text.as_bytes().to_vec();
            let expected = irontide_core::sha1(&data);
            pool.submit(data_job(42, generation, data, expected, tx.clone()))
                .await
                .unwrap();
        }
        drop(tx);

        let first = rx.recv().await.unwrap();
        let second = rx.recv().await.unwrap();
        let mut results = [first, second];
        results.sort_by_key(|r| r.generation);
        let [older, newer] = &results;

        assert_eq!(older.generation, 1);
        assert!(older.passed);
        assert_eq!(newer.generation, 2);
        assert!(newer.passed);

        let current_gen = 2u64;
        assert!(older.generation != current_gen);
        assert!(newer.generation == current_gen);
    }

    /// Dropping the pool joins the workers and releases their result senders.
    #[tokio::test]
    async fn hash_pool_shutdown() {
        let pool = HashPool::new(2, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        // Submit a few jobs
        for piece in 0u32..3 {
            let data = format!("data-{piece}").into_bytes();
            let expected = irontide_core::sha1(&data);
            pool.submit(data_job(piece, 0, data, expected, tx.clone()))
                .await
                .unwrap();
        }

        // Collect results
        let mut received = 0;
        while received < 3 {
            let r = rx.recv().await.unwrap();
            assert!(r.passed);
            received += 1;
        }

        // Drop pool — should join worker threads cleanly
        drop(pool);
        // Drop our tx clone so channel can close
        drop(tx);

        // After drop, recv should return None (channel closed)
        assert!(rx.recv().await.is_none());
    }

    /// A failed verification must not stop the worker from serving more jobs.
    #[tokio::test]
    async fn hash_pool_failure_recovery() {
        let pool = HashPool::new(1, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        // Submit a job with mismatched hash
        let corrupt = b"corrupt data".to_vec();
        let wrong_digest = irontide_core::Id20([0xAA; 20]); // Wrong hash
        pool.submit(data_job(0, 0, corrupt, wrong_digest, tx.clone()))
            .await
            .unwrap();

        let first = rx.recv().await.unwrap();
        assert!(!first.passed, "hash mismatch should report failure");
        assert_eq!(first.piece, 0);

        // Worker should continue — submit a correct job
        let good = b"good data".to_vec();
        let good_digest = irontide_core::sha1(&good);
        pool.submit(data_job(1, 0, good, good_digest, tx))
            .await
            .unwrap();

        let second = rx.recv().await.unwrap();
        assert!(second.passed, "correct hash should pass");
        assert_eq!(second.piece, 1);
    }

    /// The streaming variant verifies a piece through a registered backend.
    #[tokio::test]
    async fn hash_pool_streaming_variant() {
        use irontide_core::Lengths;
        use irontide_storage::MemoryStorage;
        use std::sync::Arc;

        let pool = HashPool::new(1, 8);
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        // Set up a backend with real data
        let piece_bytes = vec![0xCDu8; 16384];
        let expected = irontide_core::sha1(&piece_bytes);
        let info_hash = irontide_core::Id20([0x01; 20]);
        let lengths = Lengths::new(16384, 16384, 16384);
        let storage: Arc<dyn irontide_storage::TorrentStorage> =
            Arc::new(MemoryStorage::new(lengths));
        storage.write_chunk(0, 0, &piece_bytes).unwrap();

        let config = crate::disk::DiskConfig::default();
        let backend: Arc<dyn crate::disk_backend::DiskIoBackend> =
            Arc::new(crate::disk_backend::PosixDiskIo::new(&config));
        backend.register(info_hash, storage);

        let job = HashJob::Streaming {
            piece: 0,
            expected,
            generation: 0,
            info_hash,
            backend,
            result_tx: tx,
        };
        pool.submit(job).await.unwrap();

        let outcome = rx.recv().await.unwrap();
        assert!(outcome.passed, "streaming hash should pass");
        assert_eq!(outcome.piece, 0);
        assert_eq!(outcome.generation, 0);
    }
}