use anyhow::{Context, Result};
use crossbeam_channel::{Receiver, Sender, bounded};
use noodles::bam::{self, Record};
use noodles::bgzf;
use noodles::sam::Header;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::thread::{self, JoinHandle};
use super::MERGE_BUFFER_SIZE;
use crate::bam_io::create_bam_writer;
// Capacity of each per-chunk prefetch channel: number of records a reader
// thread may buffer ahead of the merge loop before blocking (backpressure).
const PREFETCH_BUFFER_SIZE: usize = 128;
/// Heap entry for the k-way merge: one buffered record, its extracted sort
/// key, and the index of the chunk it was read from.
pub struct MergeEntry<K> {
/// Sort key extracted from `record`; all ordering/equality impls below use
/// only this field.
pub key: K,
/// The buffered BAM record.
pub record: Record,
/// Index into the chunk-reader list, used to refill the heap from the same
/// chunk after this entry is popped.
pub chunk_idx: usize,
}
impl<K: PartialEq> PartialEq for MergeEntry<K> {
    /// Two entries are equal when their keys are equal; `record` and
    /// `chunk_idx` are intentionally ignored.
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.key, &other.key)
    }
}
// Marker impl is sound: equality is delegated entirely to `K`, so `K: Eq`
// gives `MergeEntry<K>` a full equivalence relation.
impl<K: Eq> Eq for MergeEntry<K> {}
impl<K: PartialOrd> PartialOrd for MergeEntry<K> {
    /// Orders entries by key alone, consistent with `PartialEq`/`Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        PartialOrd::partial_cmp(&self.key, &other.key)
    }
}
impl<K: Ord> Ord for MergeEntry<K> {
    /// Total order on the key; this is what `BinaryHeap` uses during the merge.
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.key, &other.key)
    }
}
/// Tuning knobs for the parallel merge.
pub struct ParallelMergeConfig {
/// Requested number of reader threads.
/// NOTE(review): the merge functions currently spawn one prefetch thread per
/// chunk and only use this value in a log message — confirm whether it was
/// meant to cap reader threads.
pub reader_threads: usize,
/// Thread count forwarded to `create_bam_writer` for the output writer.
pub writer_threads: usize,
/// Compression level forwarded to `create_bam_writer`.
pub compression_level: u32,
}
impl Default for ParallelMergeConfig {
    /// Defaults: 4 reader threads, 4 writer threads, compression level 6.
    fn default() -> Self {
        Self {
            reader_threads: 4,
            writer_threads: 4,
            compression_level: 6,
        }
    }
}
/// Handle to a background thread that prefetches records from one chunk file.
struct PrefetchingChunkReader {
// Receives prefetched records; a `None` marks end-of-stream (EOF, or a read
// error — see `reader_thread`).
record_rx: Receiver<Option<Record>>,
// Join handle for the reader thread; never joined (underscore-prefixed), so
// the thread is detached when this struct is dropped.
_handle: JoinHandle<()>,
// Position of this chunk in the caller's chunk list (used as `chunk_idx`).
idx: usize,
}
impl PrefetchingChunkReader {
/// Spawns the prefetching reader thread for `path` and returns the handle.
/// Always returns `Ok`; the `Result` keeps the call site uniform (hence the
/// clippy allow).
#[allow(clippy::unnecessary_wraps)]
fn new(path: PathBuf, idx: usize) -> Result<Self> {
// Bounded channel applies backpressure: the reader thread blocks once
// PREFETCH_BUFFER_SIZE records are queued ahead of the consumer.
let (record_tx, record_rx) = bounded(PREFETCH_BUFFER_SIZE);
let handle = thread::spawn(move || {
if let Err(e) = Self::reader_thread(path, record_tx) {
log::error!("Chunk reader thread failed: {e}");
}
});
Ok(Self { record_rx, _handle: handle, idx })
}
/// Thread body: streams records from the BGZF-compressed BAM chunk at `path`
/// into `tx`. Sends `Some(record)` per record, then a final `None` sentinel.
#[allow(clippy::needless_pass_by_value)]
fn reader_thread(path: PathBuf, tx: Sender<Option<Record>>) -> Result<()> {
let file = File::open(&path).context("Failed to open chunk file")?;
let buf_reader = BufReader::with_capacity(MERGE_BUFFER_SIZE, file);
let bgzf_reader = bgzf::io::Reader::new(buf_reader);
let mut reader = bam::io::Reader::from(bgzf_reader);
// Skip the chunk's own header; the merge writes its own output header.
reader.read_header()?;
let mut record = Record::default();
loop {
match reader.read_record(&mut record) {
// 0 bytes read means end of stream.
Ok(0) => {
let _ = tx.send(None);
break;
}
Ok(_) => {
// Move the record out, leaving a fresh default for the next read.
let owned_record = std::mem::take(&mut record);
// A send error means the receiver was dropped; stop reading.
if tx.send(Some(owned_record)).is_err() {
break;
}
}
Err(e) => {
// NOTE(review): a mid-stream read error is logged and then reported
// to the consumer as `None` — indistinguishable from EOF, so the
// merge finishes "successfully" with a truncated chunk. Confirm
// this is acceptable, or propagate errors through the channel.
log::error!("Error reading chunk: {e}");
let _ = tx.send(None);
break;
}
}
}
Ok(())
}
/// Blocks until the next prefetched record is available. Returns `None` at
/// end-of-stream or if the reader thread has gone away (channel disconnected).
fn next(&self) -> Option<Record> {
match self.record_rx.recv() {
Ok(Some(record)) => Some(record),
Ok(None) | Err(_) => None,
}
}
}
/// Merges pre-sorted BAM chunk files into a single sorted output file.
///
/// One prefetching reader thread is spawned per chunk; a min-heap
/// (`BinaryHeap` of `Reverse`-wrapped entries) keyed by `extract_key` drives
/// the k-way merge, and each record is written as soon as it is popped.
///
/// * `chunk_files` — paths of the sorted chunk files to merge.
/// * `_header` — unused; kept for interface compatibility with callers.
/// * `output_header` — header used when writing records to `output`.
/// * `output` — destination BAM path.
/// * `extract_key` — extracts the sort key from a record.
/// * `config` — writer thread count and compression level; note that
///   `reader_threads` does NOT cap reader threads (one thread per chunk).
///
/// Returns the total number of records written.
///
/// # Errors
///
/// Fails if a chunk reader cannot be created or writing the output fails.
#[allow(clippy::needless_pass_by_value)]
pub fn parallel_merge<K, F>(
    chunk_files: &[PathBuf],
    _header: &Header,
    output_header: &Header,
    output: &Path,
    extract_key: F,
    config: ParallelMergeConfig,
) -> Result<u64>
where
    K: Clone + Send + Sync + Ord,
    F: Fn(&Record) -> K + Send + Sync,
{
    // Fix: the old message reported `config.reader_threads.min(len)` threads,
    // but this function unconditionally spawns one prefetch thread per chunk.
    // Log what actually happens.
    log::info!(
        "Starting parallel merge of {} chunks (one prefetch reader thread per chunk)",
        chunk_files.len()
    );
    let chunk_readers: Vec<PrefetchingChunkReader> = chunk_files
        .iter()
        .enumerate()
        .map(|(idx, path)| PrefetchingChunkReader::new(path.clone(), idx))
        .collect::<Result<Vec<_>>>()?;
    // Seed the heap with the first record of every non-empty chunk.
    let mut heap: BinaryHeap<std::cmp::Reverse<MergeEntry<K>>> =
        BinaryHeap::with_capacity(chunk_files.len());
    for reader in &chunk_readers {
        if let Some(record) = reader.next() {
            let key = extract_key(&record);
            heap.push(std::cmp::Reverse(MergeEntry { key, record, chunk_idx: reader.idx }));
        }
    }
    let mut writer =
        create_bam_writer(output, output_header, config.writer_threads, config.compression_level)?;
    let mut records_merged = 0u64;
    // Pop the globally smallest record, write it, then refill from its chunk.
    while let Some(std::cmp::Reverse(entry)) = heap.pop() {
        writer.write_record(output_header, &entry.record)?;
        records_merged += 1;
        let reader = &chunk_readers[entry.chunk_idx];
        if let Some(record) = reader.next() {
            let key = extract_key(&record);
            heap.push(std::cmp::Reverse(MergeEntry { key, record, chunk_idx: entry.chunk_idx }));
        }
    }
    log::info!("Parallel merge complete: {records_merged} records merged");
    Ok(records_merged)
}
/// Like `parallel_merge`, but batches popped records into an in-memory buffer
/// of `OUTPUT_BUFFER_SIZE` records before writing them out.
///
/// One prefetching reader thread is spawned per chunk; a min-heap keyed by
/// `extract_key` drives the k-way merge.
///
/// * `chunk_files` — paths of the sorted chunk files to merge.
/// * `_header` — unused; kept for interface compatibility with callers.
/// * `output_header` — header used when writing records to `output`.
/// * `output` — destination BAM path.
/// * `extract_key` — extracts the sort key from a record.
/// * `config` — writer thread count and compression level; note that
///   `reader_threads` does NOT cap reader threads (one thread per chunk).
///
/// Returns the total number of records written.
///
/// # Errors
///
/// Fails if a chunk reader cannot be created or writing the output fails.
#[allow(clippy::needless_pass_by_value)]
pub fn parallel_merge_buffered<K, F>(
    chunk_files: &[PathBuf],
    _header: &Header,
    output_header: &Header,
    output: &Path,
    extract_key: F,
    config: ParallelMergeConfig,
) -> Result<u64>
where
    K: Clone + Send + Sync + Ord,
    F: Fn(&Record) -> K + Send + Sync,
{
    // Records accumulated between writer flushes.
    const OUTPUT_BUFFER_SIZE: usize = 1024;
    // Fix: the old message reported `config.reader_threads.min(len)` threads,
    // but this function unconditionally spawns one prefetch thread per chunk.
    log::info!(
        "Starting buffered parallel merge of {} chunks (one prefetch reader thread per chunk)",
        chunk_files.len()
    );
    let chunk_readers: Vec<PrefetchingChunkReader> = chunk_files
        .iter()
        .enumerate()
        .map(|(idx, path)| PrefetchingChunkReader::new(path.clone(), idx))
        .collect::<Result<Vec<_>>>()?;
    // Seed the heap with the first record of every non-empty chunk.
    let mut heap: BinaryHeap<std::cmp::Reverse<MergeEntry<K>>> =
        BinaryHeap::with_capacity(chunk_files.len());
    for reader in &chunk_readers {
        if let Some(record) = reader.next() {
            let key = extract_key(&record);
            heap.push(std::cmp::Reverse(MergeEntry { key, record, chunk_idx: reader.idx }));
        }
    }
    let mut writer =
        create_bam_writer(output, output_header, config.writer_threads, config.compression_level)?;
    let mut records_merged = 0u64;
    let mut output_buffer: Vec<Record> = Vec::with_capacity(OUTPUT_BUFFER_SIZE);
    while let Some(std::cmp::Reverse(entry)) = heap.pop() {
        output_buffer.push(entry.record);
        records_merged += 1;
        // Flush the batch once it reaches capacity; `drain` keeps the
        // allocation for reuse.
        if output_buffer.len() >= OUTPUT_BUFFER_SIZE {
            for record in output_buffer.drain(..) {
                writer.write_record(output_header, &record)?;
            }
        }
        // Refill the heap from the chunk the popped entry came from.
        let reader = &chunk_readers[entry.chunk_idx];
        if let Some(record) = reader.next() {
            let key = extract_key(&record);
            heap.push(std::cmp::Reverse(MergeEntry { key, record, chunk_idx: entry.chunk_idx }));
        }
    }
    // Write any records left in the final, partially filled batch.
    for record in output_buffer {
        writer.write_record(output_header, &record)?;
    }
    log::info!("Buffered parallel merge complete: {records_merged} records merged");
    Ok(records_merged)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Reverse;

    /// Builds a `MergeEntry` around a default record, for ordering tests.
    fn entry<K>(key: K, chunk_idx: usize) -> MergeEntry<K> {
        MergeEntry { key, record: Record::default(), chunk_idx }
    }

    #[test]
    fn test_merge_entry_ordering() {
        assert!(entry(1, 0) < entry(2, 1));
    }

    #[test]
    fn test_config_default() {
        let ParallelMergeConfig { reader_threads, writer_threads, compression_level } =
            ParallelMergeConfig::default();
        assert_eq!(reader_threads, 4);
        assert_eq!(writer_threads, 4);
        assert_eq!(compression_level, 6);
    }

    #[test]
    fn test_merge_entry_equal_keys() {
        assert_eq!(entry(5, 0).cmp(&entry(5, 1)), Ordering::Equal);
    }

    #[test]
    fn test_merge_entry_greater_than() {
        assert!(entry(2, 0) > entry(1, 1));
    }

    #[test]
    fn test_merge_entry_ordering_ignores_chunk_idx() {
        // Same key, wildly different chunk indices: still Equal.
        assert_eq!(entry(42, 0).cmp(&entry(42, 99)), Ordering::Equal);
    }

    #[test]
    fn test_merge_entry_partial_eq() {
        assert!(entry(10, 0) == entry(10, 3));
    }

    #[test]
    fn test_merge_entry_partial_eq_different() {
        assert!(entry(10, 0) != entry(20, 0));
    }

    #[test]
    fn test_merge_entry_string_keys() {
        let apple = entry("apple".to_string(), 0);
        let banana = entry("banana".to_string(), 1);
        let cherry = entry("cherry".to_string(), 2);
        assert!(apple < banana);
        assert!(banana < cherry);
        assert!(apple < cherry);
    }

    #[test]
    fn test_merge_entry_in_binary_heap() {
        // Push out of order; Reverse turns the max-heap into a min-heap.
        let mut heap = BinaryHeap::new();
        for (key, chunk_idx) in [(3, 0), (1, 1), (2, 2)] {
            heap.push(Reverse(entry(key, chunk_idx)));
        }
        for expected in 1..=3 {
            assert_eq!(heap.pop().expect("heap should have elements").0.key, expected);
        }
        assert!(heap.is_empty());
    }

    #[test]
    fn test_config_custom_values() {
        let ParallelMergeConfig { reader_threads, writer_threads, compression_level } =
            ParallelMergeConfig { reader_threads: 8, writer_threads: 16, compression_level: 9 };
        assert_eq!(reader_threads, 8);
        assert_eq!(writer_threads, 16);
        assert_eq!(compression_level, 9);
    }

    #[test]
    fn test_config_single_thread() {
        let ParallelMergeConfig { reader_threads, writer_threads, compression_level } =
            ParallelMergeConfig { reader_threads: 1, writer_threads: 1, compression_level: 1 };
        assert_eq!(reader_threads, 1);
        assert_eq!(writer_threads, 1);
        assert_eq!(compression_level, 1);
    }

    #[test]
    fn test_merge_buffer_size() {
        assert_eq!(MERGE_BUFFER_SIZE, 65536);
    }

    #[test]
    fn test_prefetch_buffer_size() {
        assert_eq!(PREFETCH_BUFFER_SIZE, 128);
    }
}