// fgumi_lib/sort/pipeline.rs
1//! Pipelined merge with parallel I/O for BAM sorting.
2//!
3//! This module provides a multi-threaded merge implementation that overlaps
4//! I/O with computation to maximize throughput during the merge phase of
5//! external sort.
6//!
7//! # Architecture
8//!
9//! ```text
10//! ┌─────────────┐    ┌─────────────┐    ┌─────────────┐
11//! │ Reader Pool │───>│ Merge Heap  │───>│   Writer    │
12//! │ (N threads) │    │ (1 thread)  │    │ (M threads) │
13//! └─────────────┘    └─────────────┘    └─────────────┘
14//!      │                   │                   │
15//!      ▼                   ▼                   ▼
16//!   Decompress          K-way              Compress
17//!   in parallel         merge              in parallel
18//! ```
19//!
20//! # Performance Benefits
21//!
22//! - **Parallel decompression**: Each chunk file is read by its own thread
23//! - **Overlapped I/O**: Reading and decompression happen in background
24//! - **Buffered prefetch**: Records are prefetched ahead of merge consumption
25//! - **Multi-threaded output**: Writing uses parallel BGZF compression
26
27use anyhow::{Context, Result};
28use crossbeam_channel::{Receiver, Sender, bounded};
29use noodles::bam::{self, Record};
30use noodles::bgzf;
31use noodles::sam::Header;
32use std::cmp::Ordering;
33use std::collections::BinaryHeap;
34use std::fs::File;
35use std::io::BufReader;
36use std::path::{Path, PathBuf};
37use std::thread::{self, JoinHandle};
38
39use super::MERGE_BUFFER_SIZE;
40use crate::bam_io::create_bam_writer;
41
/// Number of records to prefetch per chunk reader.
///
/// Capacity of the bounded channel between each reader thread and the merge
/// loop: a reader blocks once it is this many records ahead of consumption,
/// which bounds memory use and applies backpressure.
const PREFETCH_BUFFER_SIZE: usize = 128;
44
/// A record with its sort key and source chunk index.
///
/// Equality and ordering are defined by `key` ALONE: `record` and `chunk_idx`
/// are ignored, so two entries with equal keys compare `Equal` even when they
/// come from different chunks. This is exactly what the k-way merge heap
/// needs, but note it makes `Eq` coarser than structural equality.
pub struct MergeEntry<K> {
    /// Sort key extracted from `record` by the caller-supplied key function.
    pub key: K,
    /// The BAM record carried alongside its key.
    pub record: Record,
    /// Index of the source chunk, used to refill the heap from the same
    /// reader after this entry is popped.
    pub chunk_idx: usize,
}

impl<K: PartialEq> PartialEq for MergeEntry<K> {
    // Key-only comparison; see the struct-level note.
    fn eq(&self, other: &Self) -> bool {
        self.key == other.key
    }
}

impl<K: Eq> Eq for MergeEntry<K> {}

impl<K: PartialOrd> PartialOrd for MergeEntry<K> {
    // Bounds are intentionally looser than `Ord`'s (only `K: PartialOrd`),
    // so this cannot simply delegate to `Ord::cmp`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        self.key.partial_cmp(&other.key)
    }
}

impl<K: Ord> Ord for MergeEntry<K> {
    fn cmp(&self, other: &Self) -> Ordering {
        self.key.cmp(&other.key)
    }
}
71
/// Configuration for parallel merge.
pub struct ParallelMergeConfig {
    /// Number of reader threads (typically `min(num_chunks, available_threads)`).
    ///
    /// NOTE(review): the merge functions in this module currently spawn one
    /// reader thread per chunk regardless of this value — confirm whether it
    /// is intended to cap reader concurrency.
    pub reader_threads: usize,
    /// Number of writer threads for output compression.
    pub writer_threads: usize,
    /// Compression level for output.
    pub compression_level: u32,
}

impl Default for ParallelMergeConfig {
    /// Conservative defaults: 4 reader threads, 4 writer threads, and
    /// mid-range compression level 6.
    fn default() -> Self {
        Self { reader_threads: 4, writer_threads: 4, compression_level: 6 }
    }
}
87
/// Prefetching chunk reader that runs in a background thread.
///
/// This reader maintains a buffer of prefetched records, allowing the merge
/// thread to consume records without blocking on I/O.
///
/// End-of-stream and mid-stream read errors are both surfaced to the
/// consumer as "no more records" (see [`PrefetchingChunkReader::next`]);
/// errors are logged but not propagated.
struct PrefetchingChunkReader {
    /// Receiver for prefetched records.
    ///
    /// A `None` message marks end-of-stream (EOF or a read error) from the
    /// reader thread.
    record_rx: Receiver<Option<Record>>,
    /// Handle to the reader thread.
    ///
    /// Held for ownership only; it is never joined. The thread exits on EOF
    /// or when this struct (and thus the receiver) is dropped.
    _handle: JoinHandle<()>,
    /// Chunk index for heap management.
    idx: usize,
}

impl PrefetchingChunkReader {
    /// Create a new prefetching reader for a chunk file.
    ///
    /// Spawns the background thread immediately; it begins reading and
    /// filling the bounded prefetch channel right away.
    ///
    /// # Errors
    ///
    /// Currently infallible (hence the `unnecessary_wraps` allow): file-open
    /// and read errors occur inside the spawned thread, where they are only
    /// logged.
    #[allow(clippy::unnecessary_wraps)]
    fn new(path: PathBuf, idx: usize) -> Result<Self> {
        // Bounded channel: caps prefetch memory at PREFETCH_BUFFER_SIZE
        // records and applies backpressure to the reader thread.
        let (record_tx, record_rx) = bounded(PREFETCH_BUFFER_SIZE);

        // Spawn reader thread
        let handle = thread::spawn(move || {
            if let Err(e) = Self::reader_thread(path, record_tx) {
                log::error!("Chunk reader thread failed: {e}");
            }
        });

        Ok(Self { record_rx, _handle: handle, idx })
    }

    /// Reader thread function.
    ///
    /// Reads BAM records from `path` and pushes them into `tx` until EOF, a
    /// read error, or the receiver hangs up. A trailing `None` is sent on
    /// EOF and on read errors. Returns `Err` only for open/header failures,
    /// in which case no `None` is sent and the channel simply closes.
    #[allow(clippy::needless_pass_by_value)]
    fn reader_thread(path: PathBuf, tx: Sender<Option<Record>>) -> Result<()> {
        let file = File::open(&path).context("Failed to open chunk file")?;
        let buf_reader = BufReader::with_capacity(MERGE_BUFFER_SIZE, file);
        let bgzf_reader = bgzf::io::Reader::new(buf_reader);
        let mut reader = bam::io::Reader::from(bgzf_reader);

        // Read and discard header
        reader.read_header()?;

        // Read records and send to channel
        let mut record = Record::default();
        loop {
            match reader.read_record(&mut record) {
                Ok(0) => {
                    // EOF - send None to signal end
                    let _ = tx.send(None);
                    break;
                }
                Ok(_) => {
                    // Move the record out, leaving a fresh default for the
                    // next read — no clone per record.
                    let owned_record = std::mem::take(&mut record);
                    if tx.send(Some(owned_record)).is_err() {
                        // Receiver dropped, exit
                        break;
                    }
                }
                Err(e) => {
                    // NOTE(review): a mid-stream read error is logged and
                    // then reported downstream as a normal end-of-stream, so
                    // a merge can complete "successfully" with a truncated
                    // chunk — confirm this is acceptable.
                    log::error!("Error reading chunk: {e}");
                    let _ = tx.send(None);
                    break;
                }
            }
        }

        Ok(())
    }

    /// Get the next record from the prefetch buffer.
    ///
    /// Blocks until a record is available. Returns `None` on end-of-stream,
    /// after a reader-side error, or if the reader thread died (channel
    /// closed).
    fn next(&self) -> Option<Record> {
        match self.record_rx.recv() {
            Ok(Some(record)) => Some(record),
            Ok(None) | Err(_) => None,
        }
    }
}
165
166/// Parallel merge implementation using prefetching readers.
167///
168/// # Errors
169///
170/// Returns an error if reading chunks, writing output, or merging fails.
171#[allow(clippy::needless_pass_by_value)]
172pub fn parallel_merge<K, F>(
173    chunk_files: &[PathBuf],
174    _header: &Header,
175    output_header: &Header,
176    output: &Path,
177    extract_key: F,
178    config: ParallelMergeConfig,
179) -> Result<u64>
180where
181    K: Clone + Send + Sync + Ord,
182    F: Fn(&Record) -> K + Send + Sync,
183{
184    log::info!(
185        "Starting parallel merge of {} chunks with {} reader threads",
186        chunk_files.len(),
187        config.reader_threads.min(chunk_files.len())
188    );
189
190    // Create prefetching readers for each chunk
191    let chunk_readers: Vec<PrefetchingChunkReader> = chunk_files
192        .iter()
193        .enumerate()
194        .map(|(idx, path)| PrefetchingChunkReader::new(path.clone(), idx))
195        .collect::<Result<Vec<_>>>()?;
196
197    // Initialize heap with first record from each chunk
198    let mut heap: BinaryHeap<std::cmp::Reverse<MergeEntry<K>>> =
199        BinaryHeap::with_capacity(chunk_files.len());
200
201    for reader in &chunk_readers {
202        if let Some(record) = reader.next() {
203            let key = extract_key(&record);
204            heap.push(std::cmp::Reverse(MergeEntry { key, record, chunk_idx: reader.idx }));
205        }
206    }
207
208    // Create output writer with multi-threaded compression
209    let mut writer =
210        create_bam_writer(output, output_header, config.writer_threads, config.compression_level)?;
211
212    let mut records_merged = 0u64;
213
214    // Merge loop
215    while let Some(std::cmp::Reverse(entry)) = heap.pop() {
216        // Write record to output
217        writer.write_record(output_header, &entry.record)?;
218        records_merged += 1;
219
220        // Get next record from the same chunk (non-blocking due to prefetch buffer)
221        let reader = &chunk_readers[entry.chunk_idx];
222        if let Some(record) = reader.next() {
223            let key = extract_key(&record);
224            heap.push(std::cmp::Reverse(MergeEntry { key, record, chunk_idx: entry.chunk_idx }));
225        }
226    }
227
228    log::info!("Parallel merge complete: {records_merged} records merged");
229
230    Ok(records_merged)
231}
232
233/// Parallel merge with output buffering for even higher throughput.
234///
235/// This version adds an output buffer that accumulates records before
236/// writing them to the output file, reducing the number of write calls.
237///
238/// # Errors
239///
240/// Returns an error if reading chunks, writing output, or merging fails.
241#[allow(clippy::needless_pass_by_value)]
242pub fn parallel_merge_buffered<K, F>(
243    chunk_files: &[PathBuf],
244    _header: &Header,
245    output_header: &Header,
246    output: &Path,
247    extract_key: F,
248    config: ParallelMergeConfig,
249) -> Result<u64>
250where
251    K: Clone + Send + Sync + Ord,
252    F: Fn(&Record) -> K + Send + Sync,
253{
254    const OUTPUT_BUFFER_SIZE: usize = 1024;
255
256    log::info!(
257        "Starting buffered parallel merge of {} chunks with {} reader threads",
258        chunk_files.len(),
259        config.reader_threads.min(chunk_files.len())
260    );
261
262    // Create prefetching readers for each chunk
263    let chunk_readers: Vec<PrefetchingChunkReader> = chunk_files
264        .iter()
265        .enumerate()
266        .map(|(idx, path)| PrefetchingChunkReader::new(path.clone(), idx))
267        .collect::<Result<Vec<_>>>()?;
268
269    // Initialize heap with first record from each chunk
270    let mut heap: BinaryHeap<std::cmp::Reverse<MergeEntry<K>>> =
271        BinaryHeap::with_capacity(chunk_files.len());
272
273    for reader in &chunk_readers {
274        if let Some(record) = reader.next() {
275            let key = extract_key(&record);
276            heap.push(std::cmp::Reverse(MergeEntry { key, record, chunk_idx: reader.idx }));
277        }
278    }
279
280    // Create output writer with multi-threaded compression
281    let mut writer =
282        create_bam_writer(output, output_header, config.writer_threads, config.compression_level)?;
283
284    let mut records_merged = 0u64;
285    let mut output_buffer: Vec<Record> = Vec::with_capacity(OUTPUT_BUFFER_SIZE);
286
287    // Merge loop with output buffering
288    while let Some(std::cmp::Reverse(entry)) = heap.pop() {
289        output_buffer.push(entry.record);
290        records_merged += 1;
291
292        // Flush buffer if full
293        if output_buffer.len() >= OUTPUT_BUFFER_SIZE {
294            for record in output_buffer.drain(..) {
295                writer.write_record(output_header, &record)?;
296            }
297        }
298
299        // Get next record from the same chunk
300        let reader = &chunk_readers[entry.chunk_idx];
301        if let Some(record) = reader.next() {
302            let key = extract_key(&record);
303            heap.push(std::cmp::Reverse(MergeEntry { key, record, chunk_idx: entry.chunk_idx }));
304        }
305    }
306
307    // Flush remaining buffered records
308    for record in output_buffer {
309        writer.write_record(output_header, &record)?;
310    }
311
312    log::info!("Buffered parallel merge complete: {records_merged} records merged");
313
314    Ok(records_merged)
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Reverse;

    /// Build an integer-keyed entry with a default (empty) record.
    fn entry(key: i32, chunk_idx: usize) -> MergeEntry<i32> {
        MergeEntry { key, record: Record::default(), chunk_idx }
    }

    /// Build a string-keyed entry with a default (empty) record.
    fn str_entry(key: &str, chunk_idx: usize) -> MergeEntry<String> {
        MergeEntry { key: key.to_string(), record: Record::default(), chunk_idx }
    }

    #[test]
    fn test_merge_entry_ordering() {
        assert!(entry(1, 0) < entry(2, 1));
    }

    #[test]
    fn test_config_default() {
        let ParallelMergeConfig { reader_threads, writer_threads, compression_level } =
            ParallelMergeConfig::default();
        assert_eq!((reader_threads, writer_threads, compression_level), (4, 4, 6));
    }

    #[test]
    fn test_merge_entry_equal_keys() {
        assert_eq!(entry(5, 0).cmp(&entry(5, 1)), Ordering::Equal);
    }

    #[test]
    fn test_merge_entry_greater_than() {
        assert!(entry(2, 0) > entry(1, 1));
    }

    #[test]
    fn test_merge_entry_ordering_ignores_chunk_idx() {
        // Identical keys from very different chunks still compare Equal.
        assert_eq!(entry(42, 0).cmp(&entry(42, 99)), Ordering::Equal);
    }

    #[test]
    fn test_merge_entry_partial_eq() {
        assert!(entry(10, 0) == entry(10, 3));
    }

    #[test]
    fn test_merge_entry_partial_eq_different() {
        assert!(entry(10, 0) != entry(20, 0));
    }

    #[test]
    fn test_merge_entry_string_keys() {
        let apple = str_entry("apple", 0);
        let banana = str_entry("banana", 1);
        let cherry = str_entry("cherry", 2);

        assert!(apple < banana);
        assert!(banana < cherry);
        assert!(apple < cherry);
    }

    #[test]
    fn test_merge_entry_in_binary_heap() {
        let mut heap = BinaryHeap::new();
        for (chunk, &key) in [3, 1, 2].iter().enumerate() {
            heap.push(Reverse(entry(key, chunk)));
        }

        // Min-heap via Reverse: keys must pop in ascending order.
        for expected in 1..=3 {
            assert_eq!(heap.pop().unwrap().0.key, expected);
        }
        assert!(heap.is_empty());
    }

    #[test]
    fn test_config_custom_values() {
        let config =
            ParallelMergeConfig { reader_threads: 8, writer_threads: 16, compression_level: 9 };

        assert_eq!(
            (config.reader_threads, config.writer_threads, config.compression_level),
            (8, 16, 9)
        );
    }

    #[test]
    fn test_config_single_thread() {
        let config =
            ParallelMergeConfig { reader_threads: 1, writer_threads: 1, compression_level: 1 };

        assert_eq!(
            (config.reader_threads, config.writer_threads, config.compression_level),
            (1, 1, 1)
        );
    }

    #[test]
    fn test_merge_buffer_size() {
        assert_eq!(MERGE_BUFFER_SIZE, 65536);
    }

    #[test]
    fn test_prefetch_buffer_size() {
        assert_eq!(PREFETCH_BUFFER_SIZE, 128);
    }
}