1use anyhow::{Context, Result};
28use crossbeam_channel::{Receiver, Sender, bounded};
29use noodles::bam::{self, Record};
30use noodles::bgzf;
31use noodles::sam::Header;
32use std::cmp::Ordering;
33use std::collections::BinaryHeap;
34use std::fs::File;
35use std::io::BufReader;
36use std::path::{Path, PathBuf};
37use std::thread::{self, JoinHandle};
38
39use super::MERGE_BUFFER_SIZE;
40use crate::bam_io::create_bam_writer;
41
/// Capacity of each chunk's prefetch channel: how many decoded records a
/// reader thread may buffer ahead of the merge loop before blocking.
const PREFETCH_BUFFER_SIZE: usize = 128;
44
45pub struct MergeEntry<K> {
47 pub key: K,
48 pub record: Record,
49 pub chunk_idx: usize,
50}
51
52impl<K: PartialEq> PartialEq for MergeEntry<K> {
53 fn eq(&self, other: &Self) -> bool {
54 self.key == other.key
55 }
56}
57
58impl<K: Eq> Eq for MergeEntry<K> {}
59
60impl<K: PartialOrd> PartialOrd for MergeEntry<K> {
61 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
62 self.key.partial_cmp(&other.key)
63 }
64}
65
66impl<K: Ord> Ord for MergeEntry<K> {
67 fn cmp(&self, other: &Self) -> Ordering {
68 self.key.cmp(&other.key)
69 }
70}
71
/// Tuning knobs for the parallel merge pipeline.
pub struct ParallelMergeConfig {
    /// Number of threads used to read chunk files.
    pub reader_threads: usize,
    /// Number of threads used by the output writer.
    pub writer_threads: usize,
    /// Compression level for the merged output.
    pub compression_level: u32,
}

impl Default for ParallelMergeConfig {
    /// Defaults: 4 reader threads, 4 writer threads, compression level 6.
    fn default() -> Self {
        ParallelMergeConfig {
            reader_threads: 4,
            writer_threads: 4,
            compression_level: 6,
        }
    }
}
87
88struct PrefetchingChunkReader {
93 record_rx: Receiver<Option<Record>>,
95 _handle: JoinHandle<()>,
97 idx: usize,
99}
100
101impl PrefetchingChunkReader {
102 #[allow(clippy::unnecessary_wraps)]
104 fn new(path: PathBuf, idx: usize) -> Result<Self> {
105 let (record_tx, record_rx) = bounded(PREFETCH_BUFFER_SIZE);
107
108 let handle = thread::spawn(move || {
110 if let Err(e) = Self::reader_thread(path, record_tx) {
111 log::error!("Chunk reader thread failed: {e}");
112 }
113 });
114
115 Ok(Self { record_rx, _handle: handle, idx })
116 }
117
118 #[allow(clippy::needless_pass_by_value)]
120 fn reader_thread(path: PathBuf, tx: Sender<Option<Record>>) -> Result<()> {
121 let file = File::open(&path).context("Failed to open chunk file")?;
122 let buf_reader = BufReader::with_capacity(MERGE_BUFFER_SIZE, file);
123 let bgzf_reader = bgzf::io::Reader::new(buf_reader);
124 let mut reader = bam::io::Reader::from(bgzf_reader);
125
126 reader.read_header()?;
128
129 let mut record = Record::default();
131 loop {
132 match reader.read_record(&mut record) {
133 Ok(0) => {
134 let _ = tx.send(None);
136 break;
137 }
138 Ok(_) => {
139 let owned_record = std::mem::take(&mut record);
141 if tx.send(Some(owned_record)).is_err() {
142 break;
144 }
145 }
146 Err(e) => {
147 log::error!("Error reading chunk: {e}");
148 let _ = tx.send(None);
149 break;
150 }
151 }
152 }
153
154 Ok(())
155 }
156
157 fn next(&self) -> Option<Record> {
159 match self.record_rx.recv() {
160 Ok(Some(record)) => Some(record),
161 Ok(None) | Err(_) => None,
162 }
163 }
164}
165
166#[allow(clippy::needless_pass_by_value)]
172pub fn parallel_merge<K, F>(
173 chunk_files: &[PathBuf],
174 _header: &Header,
175 output_header: &Header,
176 output: &Path,
177 extract_key: F,
178 config: ParallelMergeConfig,
179) -> Result<u64>
180where
181 K: Clone + Send + Sync + Ord,
182 F: Fn(&Record) -> K + Send + Sync,
183{
184 log::info!(
185 "Starting parallel merge of {} chunks with {} reader threads",
186 chunk_files.len(),
187 config.reader_threads.min(chunk_files.len())
188 );
189
190 let chunk_readers: Vec<PrefetchingChunkReader> = chunk_files
192 .iter()
193 .enumerate()
194 .map(|(idx, path)| PrefetchingChunkReader::new(path.clone(), idx))
195 .collect::<Result<Vec<_>>>()?;
196
197 let mut heap: BinaryHeap<std::cmp::Reverse<MergeEntry<K>>> =
199 BinaryHeap::with_capacity(chunk_files.len());
200
201 for reader in &chunk_readers {
202 if let Some(record) = reader.next() {
203 let key = extract_key(&record);
204 heap.push(std::cmp::Reverse(MergeEntry { key, record, chunk_idx: reader.idx }));
205 }
206 }
207
208 let mut writer =
210 create_bam_writer(output, output_header, config.writer_threads, config.compression_level)?;
211
212 let mut records_merged = 0u64;
213
214 while let Some(std::cmp::Reverse(entry)) = heap.pop() {
216 writer.write_record(output_header, &entry.record)?;
218 records_merged += 1;
219
220 let reader = &chunk_readers[entry.chunk_idx];
222 if let Some(record) = reader.next() {
223 let key = extract_key(&record);
224 heap.push(std::cmp::Reverse(MergeEntry { key, record, chunk_idx: entry.chunk_idx }));
225 }
226 }
227
228 log::info!("Parallel merge complete: {records_merged} records merged");
229
230 Ok(records_merged)
231}
232
233#[allow(clippy::needless_pass_by_value)]
242pub fn parallel_merge_buffered<K, F>(
243 chunk_files: &[PathBuf],
244 _header: &Header,
245 output_header: &Header,
246 output: &Path,
247 extract_key: F,
248 config: ParallelMergeConfig,
249) -> Result<u64>
250where
251 K: Clone + Send + Sync + Ord,
252 F: Fn(&Record) -> K + Send + Sync,
253{
254 const OUTPUT_BUFFER_SIZE: usize = 1024;
255
256 log::info!(
257 "Starting buffered parallel merge of {} chunks with {} reader threads",
258 chunk_files.len(),
259 config.reader_threads.min(chunk_files.len())
260 );
261
262 let chunk_readers: Vec<PrefetchingChunkReader> = chunk_files
264 .iter()
265 .enumerate()
266 .map(|(idx, path)| PrefetchingChunkReader::new(path.clone(), idx))
267 .collect::<Result<Vec<_>>>()?;
268
269 let mut heap: BinaryHeap<std::cmp::Reverse<MergeEntry<K>>> =
271 BinaryHeap::with_capacity(chunk_files.len());
272
273 for reader in &chunk_readers {
274 if let Some(record) = reader.next() {
275 let key = extract_key(&record);
276 heap.push(std::cmp::Reverse(MergeEntry { key, record, chunk_idx: reader.idx }));
277 }
278 }
279
280 let mut writer =
282 create_bam_writer(output, output_header, config.writer_threads, config.compression_level)?;
283
284 let mut records_merged = 0u64;
285 let mut output_buffer: Vec<Record> = Vec::with_capacity(OUTPUT_BUFFER_SIZE);
286
287 while let Some(std::cmp::Reverse(entry)) = heap.pop() {
289 output_buffer.push(entry.record);
290 records_merged += 1;
291
292 if output_buffer.len() >= OUTPUT_BUFFER_SIZE {
294 for record in output_buffer.drain(..) {
295 writer.write_record(output_header, &record)?;
296 }
297 }
298
299 let reader = &chunk_readers[entry.chunk_idx];
301 if let Some(record) = reader.next() {
302 let key = extract_key(&record);
303 heap.push(std::cmp::Reverse(MergeEntry { key, record, chunk_idx: entry.chunk_idx }));
304 }
305 }
306
307 for record in output_buffer {
309 writer.write_record(output_header, &record)?;
310 }
311
312 log::info!("Buffered parallel merge complete: {records_merged} records merged");
313
314 Ok(records_merged)
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_merge_entry_ordering() {
        // A smaller key must order strictly before a larger one.
        let lo = MergeEntry { key: 1, record: Record::default(), chunk_idx: 0 };
        let hi = MergeEntry { key: 2, record: Record::default(), chunk_idx: 1 };
        assert!(lo < hi);
    }

    #[test]
    fn test_config_default() {
        let ParallelMergeConfig { reader_threads, writer_threads, compression_level } =
            ParallelMergeConfig::default();
        assert_eq!(reader_threads, 4);
        assert_eq!(writer_threads, 4);
        assert_eq!(compression_level, 6);
    }

    #[test]
    fn test_merge_entry_equal_keys() {
        // Equal keys compare equal even from different chunks.
        let a = MergeEntry { key: 5, record: Record::default(), chunk_idx: 0 };
        let b = MergeEntry { key: 5, record: Record::default(), chunk_idx: 1 };
        assert_eq!(a.cmp(&b), Ordering::Equal);
    }

    #[test]
    fn test_merge_entry_greater_than() {
        let big = MergeEntry { key: 2, record: Record::default(), chunk_idx: 0 };
        let small = MergeEntry { key: 1, record: Record::default(), chunk_idx: 1 };
        assert!(big > small);
    }

    #[test]
    fn test_merge_entry_ordering_ignores_chunk_idx() {
        // chunk_idx must not participate in ordering at all.
        let a = MergeEntry { key: 42, record: Record::default(), chunk_idx: 0 };
        let b = MergeEntry { key: 42, record: Record::default(), chunk_idx: 99 };
        assert_eq!(a.cmp(&b), Ordering::Equal);
    }

    #[test]
    fn test_merge_entry_partial_eq() {
        let a = MergeEntry { key: 10, record: Record::default(), chunk_idx: 0 };
        let b = MergeEntry { key: 10, record: Record::default(), chunk_idx: 3 };
        assert!(a == b);
    }

    #[test]
    fn test_merge_entry_partial_eq_different() {
        let a = MergeEntry { key: 10, record: Record::default(), chunk_idx: 0 };
        let b = MergeEntry { key: 20, record: Record::default(), chunk_idx: 0 };
        assert!(a != b);
    }

    #[test]
    fn test_merge_entry_string_keys() {
        // String keys order lexicographically and transitively.
        let mk = |k: &str, chunk_idx: usize| MergeEntry {
            key: k.to_string(),
            record: Record::default(),
            chunk_idx,
        };
        let apple = mk("apple", 0);
        let banana = mk("banana", 1);
        let cherry = mk("cherry", 2);

        assert!(apple < banana);
        assert!(banana < cherry);
        assert!(apple < cherry);
    }

    #[test]
    fn test_merge_entry_in_binary_heap() {
        use std::cmp::Reverse;

        // Wrapped in Reverse, the max-heap yields keys in ascending order.
        let mut heap = BinaryHeap::new();
        for (key, chunk_idx) in [(3, 0), (1, 1), (2, 2)] {
            heap.push(Reverse(MergeEntry { key, record: Record::default(), chunk_idx }));
        }

        for expected in 1..=3 {
            assert_eq!(heap.pop().unwrap().0.key, expected);
        }
        assert!(heap.is_empty());
    }

    #[test]
    fn test_config_custom_values() {
        let config =
            ParallelMergeConfig { reader_threads: 8, writer_threads: 16, compression_level: 9 };

        assert_eq!(config.reader_threads, 8);
        assert_eq!(config.writer_threads, 16);
        assert_eq!(config.compression_level, 9);
    }

    #[test]
    fn test_config_single_thread() {
        let config =
            ParallelMergeConfig { reader_threads: 1, writer_threads: 1, compression_level: 1 };

        assert_eq!(config.reader_threads, 1);
        assert_eq!(config.writer_threads, 1);
        assert_eq!(config.compression_level, 1);
    }

    #[test]
    fn test_merge_buffer_size() {
        assert_eq!(MERGE_BUFFER_SIZE, 65536);
    }

    #[test]
    fn test_prefetch_buffer_size() {
        assert_eq!(PREFETCH_BUFFER_SIZE, 128);
    }
}