use crate::bam_io::create_raw_bam_reader;
use crate::progress::ProgressTracker;
#[cfg(test)]
use crate::sam::SamTag;
use crate::sort::inline_buffer::{
ProbeableBuffer, RecordBuffer, TemplateKey, TemplateRecordBuffer,
};
use crate::sort::keys::{QuerynameComparator, RawSortKey, SortOrder};
use crate::sort::memory_probe::{
BufferProbeStats, ConsumerProbeStats, MergeProbe, SpillProbe, force_mi_collect, log_snapshot,
};
use crate::sort::pooled_chunk_writer::PooledChunkWriter;
use crate::sort::read_ahead::{RawReadAheadReader, RecordSource};
use crate::sort::tmp_dir_alloc::TmpDirAllocator;
use crate::sort::worker_pool::SortWorkerPool;
use anyhow::Result;
use crossbeam_channel::{Receiver, Sender, bounded};
use log::{debug, info};
use noodles::sam::Header;
use noodles::sam::header::record::value::map::read_group::tag as rg_tag;
use noodles_bgzf::io::{
MultithreadedWriter, Reader as BgzfReader, Writer as BgzfWriter, multithreaded_writer,
writer::CompressionLevel,
};
use std::collections::HashMap;
use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
use std::num::NonZero;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use tempfile::TempDir;
/// Accumulates wall-clock timings for each phase of an external sort run,
/// for the summary logged by [`SortPhaseTimer::log_summary`].
#[derive(Debug, Default)]
struct SortPhaseTimer {
    // Seconds spent reading + decompressing input records.
    read_secs: f64,
    // Seconds spent in the in-memory sorting of buffered records.
    sort_secs: f64,
    // Seconds spent writing spill chunks to disk.
    spill_write_secs: f64,
    // Seconds spent consolidating temp files (intermediate merges).
    consolidate_secs: f64,
    // Seconds spent in the final k-way merge.
    merge_secs: f64,
    // Seconds spent writing the final output.
    write_output_secs: f64,
    // Number of spill writes attempted (see `time_spill_write`).
    spill_count: usize,
    // Number of consolidation merges that took long enough to count.
    consolidate_count: usize,
    // Total on-disk bytes across all recorded spill files.
    total_spill_bytes: u64,
    // Start of the whole sort; set by `new`, used for total wall clock.
    overall_start: Option<Instant>,
    // Start of the currently open read span, if any.
    read_span_start: Option<Instant>,
}
impl SortPhaseTimer {
    /// Starts the timer: records the overall start and immediately opens the
    /// first read span.
    fn new() -> Self {
        Self {
            overall_start: Some(Instant::now()),
            read_span_start: Some(Instant::now()),
            ..Default::default()
        }
    }
    /// Closes the current read span (if one is open), adds its duration to
    /// `read_secs`, and returns it; returns `Duration::ZERO` if no span was open.
    fn end_read_span(&mut self) -> Duration {
        if let Some(start) = self.read_span_start.take() {
            let elapsed = start.elapsed();
            self.read_secs += elapsed.as_secs_f64();
            elapsed
        } else {
            Duration::ZERO
        }
    }
    /// Opens a new read span (replacing any span already open).
    fn begin_read_span(&mut self) {
        self.read_span_start = Some(Instant::now());
    }
    /// Runs `f` and adds its wall-clock duration (seconds) to `field`.
    fn time<T>(field: &mut f64, f: impl FnOnce() -> T) -> T {
        let start = Instant::now();
        let result = f();
        *field += start.elapsed().as_secs_f64();
        result
    }
    /// Times an in-memory sort into `sort_secs`.
    fn time_sort<T>(&mut self, f: impl FnOnce() -> T) -> T {
        Self::time(&mut self.sort_secs, f)
    }
    /// Times a spill write into `spill_write_secs`.
    /// Note: `spill_count` is incremented even when `f` returns an error.
    fn time_spill_write<T>(&mut self, f: impl FnOnce() -> Result<T>) -> Result<T> {
        let result = Self::time(&mut self.spill_write_secs, f);
        self.spill_count += 1;
        result
    }
    /// Adds the on-disk size of a finished spill file to the running total;
    /// silently skipped if the metadata lookup fails (best effort).
    fn record_spill_size(&mut self, path: &Path) {
        if let Ok(meta) = std::fs::metadata(path) {
            self.total_spill_bytes += meta.len();
        }
    }
    /// Times a consolidation pass. Runs shorter than 1 ms are not counted,
    /// so no-op consolidation checks don't inflate the merge count.
    fn time_consolidate(&mut self, f: impl FnOnce() -> Result<()>) -> Result<()> {
        let start = Instant::now();
        let result = f();
        let elapsed = start.elapsed().as_secs_f64();
        if elapsed > 0.001 {
            self.consolidate_secs += elapsed;
            self.consolidate_count += 1;
        }
        result
    }
    /// Times the final k-way merge into `merge_secs`.
    fn time_merge<T>(&mut self, f: impl FnOnce() -> Result<T>) -> Result<T> {
        Self::time(&mut self.merge_secs, f)
    }
    /// Times final output writing into `write_output_secs`.
    fn time_write_output(&mut self, f: impl FnOnce() -> Result<()>) -> Result<()> {
        Self::time(&mut self.write_output_secs, f)
    }
    /// Logs a human-readable per-phase timing breakdown with percentages of
    /// total wall-clock time. Optional phases are omitted when zero.
    #[allow(clippy::cast_precision_loss)]
    fn log_summary(&self, threads: usize) {
        let overall = self.overall_start.map_or(0.0, |s| s.elapsed().as_secs_f64());
        // Guard against division by zero if the sort finished instantly.
        let overall_nonzero = if overall > 0.0 { overall } else { f64::EPSILON };
        let read_pct = 100.0 * self.read_secs / overall_nonzero;
        let sort_pct = 100.0 * self.sort_secs / overall_nonzero;
        let spill_pct = 100.0 * self.spill_write_secs / overall_nonzero;
        let spill_count = self.spill_count;
        let read_secs = self.read_secs;
        let sort_secs = self.sort_secs;
        let spill_secs = self.spill_write_secs;
        info!("=== Sort Phase Timing ===");
        info!("  Read + decompress: {read_secs:.1}s ({read_pct:.0}%)");
        info!("  In-memory sort: {sort_secs:.1}s ({sort_pct:.0}%) [{spill_count} spills]");
        let spill_mb = self.total_spill_bytes as f64 / (1024.0 * 1024.0);
        info!(
            "  Spill write: {spill_secs:.1}s ({spill_pct:.0}%) [{spill_count} writes, {spill_mb:.1} MB total]"
        );
        if self.consolidate_count > 0 {
            let cons_secs = self.consolidate_secs;
            let cons_pct = 100.0 * cons_secs / overall_nonzero;
            let cons_count = self.consolidate_count;
            info!("  Consolidation: {cons_secs:.1}s ({cons_pct:.0}%) [{cons_count} merges]");
        }
        if self.merge_secs > 0.0 {
            let merge_secs = self.merge_secs;
            let merge_pct = 100.0 * merge_secs / overall_nonzero;
            info!("  K-way merge: {merge_secs:.1}s ({merge_pct:.0}%)");
        }
        if self.write_output_secs > 0.0 {
            let write_secs = self.write_output_secs;
            let write_pct = 100.0 * write_secs / overall_nonzero;
            info!("  Write output: {write_secs:.1}s ({write_pct:.0}%)");
        }
        info!("  Total wall clock: {overall:.1}s");
        info!("  Threads: {threads}");
        info!("=========================");
    }
}
/// Returns the fixed-seed `ahash` hasher state used for cell-barcode hashing.
///
/// The seeds are compile-time constants so that hash values are stable
/// across runs and processes.
#[must_use]
pub fn cb_hasher() -> ahash::RandomState {
    const SEED_0: u64 = 0xa1b2_c3d4_e5f6_0718;
    const SEED_1: u64 = 0x9182_7364_5546_3728;
    const SEED_2: u64 = 0xfede_dcba_0987_6543;
    const SEED_3: u64 = 0x0011_2233_4455_6677;
    ahash::RandomState::with_seeds(SEED_0, SEED_1, SEED_2, SEED_3)
}
/// Maps read-group IDs to small integer "library ordinals" and provides a
/// stable hasher for read names.
pub struct LibraryLookup {
    // Read-group ID bytes -> ordinal of that group's library (0 = none/unknown).
    rg_to_ordinal: HashMap<Vec<u8>, u32>,
    // Fixed-seed hasher so name hashes are reproducible across runs.
    hasher: ahash::RandomState,
}
impl LibraryLookup {
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub fn from_header(header: &Header) -> Self {
let mut libraries: Vec<String> = header
.read_groups()
.iter()
.filter_map(|(_, rg)| {
rg.other_fields().get(&rg_tag::LIBRARY).map(std::string::ToString::to_string)
})
.collect();
libraries.sort();
libraries.dedup();
let mut lib_to_ordinal: HashMap<String, u32> = HashMap::new();
lib_to_ordinal.insert(String::new(), 0);
for (i, lib) in libraries.iter().enumerate() {
lib_to_ordinal.insert(lib.clone(), (i + 1) as u32);
}
let rg_to_ordinal: HashMap<Vec<u8>, u32> = header
.read_groups()
.iter()
.map(|(id, rg)| {
let lib = rg
.other_fields()
.get(&rg_tag::LIBRARY)
.map(std::string::ToString::to_string)
.unwrap_or_default();
let ordinal = *lib_to_ordinal.get(&lib).unwrap_or(&0);
(id.to_vec(), ordinal)
})
.collect();
let hasher = ahash::RandomState::with_seeds(
0x517c_c1b7_2722_0a95,
0x1234_5678_90ab_cdef,
0xfedc_ba98_7654_3210,
0x0123_4567_89ab_cdef,
);
Self { rg_to_ordinal, hasher }
}
#[inline]
#[must_use]
pub fn hash_name(&self, name: &[u8]) -> u64 {
self.hasher.hash_one(name)
}
#[cfg(test)]
#[must_use]
pub fn get_ordinal(&self, bam: &[u8]) -> u32 {
fgumi_raw_bam::RawRecordView::new(bam)
.tags()
.find_string(&SamTag::RG)
.and_then(|rg| self.rg_to_ordinal.get(rg))
.copied()
.unwrap_or(0)
}
#[inline]
#[must_use]
pub fn ordinal_from_rg(&self, rg: Option<&[u8]>) -> u32 {
rg.and_then(|rg| self.rg_to_ordinal.get(rg)).copied().unwrap_or(0)
}
}
/// Number of (key, record) pairs a chunk-reader thread buffers ahead of the merge.
const MERGE_PREFETCH_SIZE: usize = 1024;
/// Default cap on simultaneously existing temp chunk files before consolidation kicks in.
const DEFAULT_MAX_TEMP_FILES: usize = 64;
/// A counting semaphore built from a bounded channel: every `()` sitting in
/// the channel is one available permit (`recv` = acquire, `send` = release).
pub(crate) type ChunkReaderSemaphore = (Sender<()>, Receiver<()>);
/// Creates a semaphore with `max(threads, 1)` permits, pre-filled so that
/// many acquisitions succeed immediately.
pub(crate) fn make_reader_semaphore(threads: usize) -> Arc<ChunkReaderSemaphore> {
    let permits = threads.max(1);
    let (release, acquire) = bounded(permits);
    for _ in 0..permits {
        release
            .send(())
            .expect("semaphore channel must not be disconnected during initialization");
    }
    Arc::new((release, acquire))
}
use std::marker::PhantomData;
/// Output sink for spill chunks: plain buffered file, single-threaded BGZF,
/// or multithreaded BGZF — chosen once at creation time.
enum ChunkWriterInner {
    // Uncompressed output (compression level 0).
    Raw(BufWriter<std::fs::File>),
    // BGZF compression performed on the calling thread.
    SingleThreaded(BgzfWriter<BufWriter<std::fs::File>>),
    // BGZF compression spread over a worker pool.
    MultiThreaded(MultithreadedWriter<BufWriter<std::fs::File>>),
}
impl Write for ChunkWriterInner {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
ChunkWriterInner::Raw(w) => w.write(buf),
ChunkWriterInner::SingleThreaded(w) => w.write(buf),
ChunkWriterInner::MultiThreaded(w) => w.write(buf),
}
}
fn flush(&mut self) -> std::io::Result<()> {
match self {
ChunkWriterInner::Raw(w) => w.flush(),
ChunkWriterInner::SingleThreaded(w) => w.flush(),
ChunkWriterInner::MultiThreaded(w) => w.flush(),
}
}
}
impl ChunkWriterInner {
    /// Finalizes the writer: flushes a raw writer, or runs the BGZF
    /// finalization for the compressed variants.
    fn finish(self) -> Result<()> {
        match self {
            Self::Raw(mut inner) => inner.flush()?,
            Self::SingleThreaded(inner) => {
                inner.finish()?;
            }
            Self::MultiThreaded(mut inner) => {
                inner.finish()?;
            }
        }
        Ok(())
    }
}
/// Writes `(key, record)` entries for sort chunks with key type `K`.
pub struct GenericKeyedChunkWriter<K: RawSortKey> {
    writer: ChunkWriterInner,
    // Ties the writer to its key type without storing a key.
    _marker: PhantomData<K>,
}
impl<K: RawSortKey> GenericKeyedChunkWriter<K> {
    /// Creates a chunk writer at `path`.
    ///
    /// * `compression_level == 0` — records are written uncompressed.
    /// * `threads > 1` — multithreaded BGZF writer with `threads` workers.
    /// * otherwise — single-threaded BGZF writer.
    ///
    /// NOTE(review): an out-of-range `compression_level` silently keeps the
    /// builder default in the multithreaded branch but falls back to level 6
    /// in the single-threaded branch — confirm this asymmetry is intended.
    pub fn create(path: &Path, compression_level: u32, threads: usize) -> Result<Self> {
        let file = std::fs::File::create(path)?;
        // 256 KiB buffered file writes underneath whichever encoder is chosen.
        let buf = BufWriter::with_capacity(256 * 1024, file);
        let writer = if compression_level == 0 {
            ChunkWriterInner::Raw(buf)
        } else if threads > 1 {
            let worker_count = NonZero::new(threads).expect("threads > 1");
            let mut builder =
                multithreaded_writer::Builder::default().set_worker_count(worker_count);
            #[allow(clippy::cast_possible_truncation)]
            if let Some(level) = CompressionLevel::new(compression_level as u8) {
                builder = builder.set_compression_level(level);
            }
            ChunkWriterInner::MultiThreaded(builder.build_from_writer(buf))
        } else {
            #[allow(clippy::cast_possible_truncation)]
            let level = CompressionLevel::new(compression_level as u8).unwrap_or_else(|| {
                CompressionLevel::new(6).expect("compression level 6 is always valid")
            });
            let writer = noodles_bgzf::io::writer::Builder::default()
                .set_compression_level(level)
                .build_from_writer(buf);
            ChunkWriterInner::SingleThreaded(writer)
        };
        Ok(Self { writer, _marker: PhantomData })
    }
    /// Writes one entry: the serialized key (unless `K` is embedded in the
    /// record bytes), then a little-endian `u32` record length, then the
    /// raw record bytes.
    #[inline]
    #[allow(clippy::cast_possible_truncation)]
    pub fn write_record(&mut self, key: &K, record: &[u8]) -> Result<()> {
        if !K::EMBEDDED_IN_RECORD {
            key.write_to(&mut self.writer)?;
        }
        self.writer.write_all(&(record.len() as u32).to_le_bytes())?;
        self.writer.write_all(record)?;
        Ok(())
    }
    /// Flushes and finalizes the underlying writer.
    pub fn finish(self) -> Result<()> {
        self.writer.finish()
    }
}
/// One message from a chunk-reader thread: `Ok(Some((key, bytes)))` for a
/// record, `Ok(None)` for end-of-chunk, `Err` on I/O or decode failure.
type ChunkReadResult<K> = Result<Option<(K, Vec<u8>)>>;
/// Fills `buf` completely from `reader`.
///
/// Returns `Ok(true)` when the buffer was filled, `Ok(false)` on a clean EOF
/// before any byte was read, and an `UnexpectedEof` error when the stream
/// ends partway through the buffer. `Interrupted` reads are retried.
fn read_exact_or_eof<R: Read>(reader: &mut R, buf: &mut [u8]) -> std::io::Result<bool> {
    let total = buf.len();
    let mut filled = 0;
    while filled < total {
        let n = match reader.read(&mut buf[filled..]) {
            Ok(n) => n,
            // Retry interrupted reads instead of surfacing them.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        if n == 0 {
            if filled == 0 {
                // Clean EOF at a record boundary.
                return Ok(false);
            }
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                format!("truncated chunk: read {} of {} bytes", filled, total),
            ));
        }
        filled += n;
    }
    Ok(true)
}
/// Prefetching reader for a keyed chunk file; decoding happens on a
/// background thread that feeds `receiver`.
pub struct GenericKeyedChunkReader<K: RawSortKey + 'static> {
    // Stream of decoded records; see `ChunkReadResult` for the protocol.
    receiver: Receiver<ChunkReadResult<K>>,
    // Returns spent record buffers to the reader thread for reuse.
    buf_return: Sender<Vec<u8>>,
    // Keeps the reader thread's handle alive; not joined explicitly.
    _handle: JoinHandle<()>,
}
impl<K: RawSortKey + 'static> GenericKeyedChunkReader<K> {
    /// Opens `path` on a background thread that prefetches decoded
    /// `(key, record)` pairs into a bounded channel.
    ///
    /// If `concurrency_limit` is given, the reader thread acquires a permit
    /// from the shared semaphore for each batch it decodes, bounding how
    /// many chunk readers do I/O at once.
    pub fn open(path: &Path, concurrency_limit: Option<Arc<ChunkReaderSemaphore>>) -> Result<Self> {
        let (tx, rx) = bounded(MERGE_PREFETCH_SIZE);
        // Channel for recycling record buffers back to the reader thread.
        let (buf_tx, buf_rx) = bounded::<Vec<u8>>(MERGE_PREFETCH_SIZE);
        let path = path.to_path_buf();
        let handle = thread::spawn(move || {
            let file = match std::fs::File::open(&path) {
                Ok(f) => f,
                Err(e) => {
                    // Errors are reported through the channel, not panics.
                    let _ = tx.send(Err(anyhow::anyhow!(
                        "Failed to open keyed chunk {}: {e}",
                        path.display()
                    )));
                    return;
                }
            };
            let mut buf_reader = BufReader::with_capacity(2 * 1024 * 1024, file);
            // Sniff the gzip magic bytes to decide raw vs BGZF decoding; a
            // file shorter than 2 bytes is treated as uncompressed.
            let mut magic = [0u8; 2];
            let is_compressed = if buf_reader.read_exact(&mut magic).is_ok() {
                magic == [0x1f, 0x8b]
            } else {
                false
            };
            // Rewind so decoding starts at the beginning of the file.
            if buf_reader.seek(SeekFrom::Start(0)).is_err() {
                let _ = tx
                    .send(Err(anyhow::anyhow!("Failed to seek in keyed chunk {}", path.display())));
                return;
            }
            if is_compressed {
                let bgzf_reader = BgzfReader::new(buf_reader);
                Self::read_records(bgzf_reader, tx, buf_rx, concurrency_limit);
            } else {
                Self::read_records(buf_reader, tx, buf_rx, concurrency_limit);
            }
        });
        Ok(Self { receiver: rx, buf_return: buf_tx, _handle: handle })
    }
    /// Reader-thread loop: decodes records in batches of 64 and sends them
    /// on `tx`.
    ///
    /// Protocol on `tx`: zero or more `Ok(Some(..))`, then exactly one
    /// terminal message — `Ok(None)` at EOF or `Err(..)` on failure.
    /// Buffers received on `buf_pool` are reused to cut allocations.
    #[allow(clippy::needless_pass_by_value)]
    fn read_records<R: Read>(
        mut reader: R,
        tx: crossbeam_channel::Sender<ChunkReadResult<K>>,
        buf_pool: crossbeam_channel::Receiver<Vec<u8>>,
        semaphore: Option<Arc<ChunkReaderSemaphore>>,
    ) {
        const BATCH_SIZE: usize = 64;
        loop {
            // Acquire a concurrency permit for this batch, if limited.
            if let Some(ref sem) = semaphore {
                let _ = sem.1.recv();
            }
            let mut batch: Vec<(K, Vec<u8>)> = Vec::with_capacity(BATCH_SIZE);
            let mut eof = false;
            let mut read_error: Option<String> = None;
            for _ in 0..BATCH_SIZE {
                if K::EMBEDDED_IN_RECORD {
                    // Layout: u32 LE length, then record bytes; the key is
                    // re-derived from the record itself.
                    let mut len_buf = [0u8; 4];
                    match read_exact_or_eof(&mut reader, &mut len_buf) {
                        Ok(true) => {}
                        Ok(false) => {
                            eof = true;
                            break;
                        }
                        Err(e) => {
                            read_error = Some(format!("Error reading chunk record length: {e}"));
                            break;
                        }
                    }
                    let len = u32::from_le_bytes(len_buf) as usize;
                    // Reuse a returned buffer when one is available.
                    let mut record = buf_pool.try_recv().unwrap_or_default();
                    record.clear();
                    record.resize(len, 0);
                    if let Err(e) = reader.read_exact(&mut record) {
                        read_error = Some(format!("Error reading chunk record: {e}"));
                        break;
                    }
                    let key = K::extract_from_record(&record);
                    batch.push((key, record));
                } else {
                    // Layout: serialized key, u32 LE length, record bytes.
                    // EOF before any key byte means a clean end of chunk.
                    let key = match K::read_from(&mut reader) {
                        Ok(k) => k,
                        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                            eof = true;
                            break;
                        }
                        Err(e) => {
                            read_error = Some(format!("Error reading keyed chunk key: {e}"));
                            break;
                        }
                    };
                    let mut len_buf = [0u8; 4];
                    match reader.read_exact(&mut len_buf) {
                        Ok(()) => {}
                        Err(e) => {
                            read_error = Some(format!("Error reading keyed chunk length: {e}"));
                            break;
                        }
                    }
                    let len = u32::from_le_bytes(len_buf) as usize;
                    let mut record = buf_pool.try_recv().unwrap_or_default();
                    record.clear();
                    record.resize(len, 0);
                    if let Err(e) = reader.read_exact(&mut record) {
                        read_error = Some(format!("Error reading keyed chunk record: {e}"));
                        break;
                    }
                    batch.push((key, record));
                }
            }
            // Release the permit before potentially blocking on sends.
            if let Some(ref sem) = semaphore {
                let _ = sem.0.send(());
            }
            // Flush everything decoded before reporting EOF/error so no
            // successfully read records are lost.
            for record in batch {
                if tx.send(Ok(Some(record))).is_err() {
                    // Receiver dropped; stop reading.
                    return;
                }
            }
            if let Some(msg) = read_error {
                let _ = tx.send(Err(anyhow::anyhow!("{msg}")));
                break;
            }
            if eof {
                let _ = tx.send(Ok(None));
                break;
            }
        }
    }
    /// Blocking: swaps the next record's bytes into `buf` and returns its
    /// key; `Ok(None)` at end of chunk. The old contents of `buf` are
    /// recycled back to the reader thread when there is room.
    pub fn next_record(&mut self, buf: &mut Vec<u8>) -> Result<Option<K>> {
        match self.receiver.recv() {
            Ok(Ok(Some((key, mut data)))) => {
                std::mem::swap(buf, &mut data);
                // Best-effort recycling; dropping the buffer is fine if full.
                let _ = self.buf_return.try_send(data);
                Ok(Some(key))
            }
            Ok(Ok(None)) => Ok(None),
            Ok(Err(e)) => Err(e),
            Err(_) => Err(anyhow::anyhow!("chunk reader thread terminated unexpectedly")),
        }
    }
    /// Non-blocking variant: `None` when nothing is buffered yet.
    pub fn try_next_record(&mut self) -> Option<ChunkReadResult<K>> {
        match self.receiver.try_recv() {
            Ok(result) => Some(result),
            Err(crossbeam_channel::TryRecvError::Disconnected) => {
                Some(Err(anyhow::anyhow!("chunk reader thread terminated unexpectedly")))
            }
            Err(crossbeam_channel::TryRecvError::Empty) => None,
        }
    }
}
/// A single input stream for the k-way merge.
enum ChunkSource<K: RawSortKey + Default + 'static> {
    // Background-thread reader over a keyed chunk file.
    Disk(GenericKeyedChunkReader<K>),
    // Already-sorted in-memory records, consumed by advancing `idx`.
    Memory { records: Vec<(K, fgumi_raw_bam::RawRecord)>, idx: usize },
    // Chunk file decoded by the worker pool; serviced via a
    // `MainThreadChunkConsumer` using this source id.
    PoolDisk { source_id: usize },
}
impl<K: RawSortKey + Default + 'static> ChunkSource<K> {
    /// Pulls this source's next record into `buf` and returns its key, or
    /// `Ok(None)` when the source is exhausted.
    ///
    /// `consumer` is only needed for `PoolDisk` sources; omitting it for
    /// such a source is a pipeline bug and yields an error.
    fn next_record(
        &mut self,
        buf: &mut Vec<u8>,
        consumer: Option<&mut MainThreadChunkConsumer<K>>,
    ) -> Result<Option<K>> {
        match self {
            Self::Disk(reader) => reader.next_record(buf),
            Self::Memory { records, idx } => {
                let Some(entry) = records.get_mut(*idx) else {
                    return Ok(None);
                };
                // Hand the record bytes to the caller and take the key,
                // leaving a default key behind in the consumed slot.
                std::mem::swap(buf, entry.1.as_mut_vec());
                let key = std::mem::take(&mut entry.0);
                *idx += 1;
                Ok(Some(key))
            }
            Self::PoolDisk { source_id } => {
                let Some(c) = consumer else {
                    return Err(anyhow::anyhow!(
                        "PoolDisk source (id {source_id}) requires a MainThreadChunkConsumer \
                        but none was provided — this is a bug in the sort pipeline"
                    ));
                };
                c.next_record(*source_id, buf)
            }
        }
    }
}
/// Cursor over the currently loaded decompressed block for one source.
struct SourceParserState {
    // Bytes of the block currently being parsed.
    current_buf: Vec<u8>,
    // Read offset into `current_buf`.
    current_pos: usize,
}
impl SourceParserState {
    /// Creates an empty state with no block loaded.
    fn new() -> Self {
        SourceParserState { current_buf: Vec::new(), current_pos: 0 }
    }
    /// Number of unread bytes left in the current block.
    fn remaining(&self) -> usize {
        let SourceParserState { current_buf, current_pos } = self;
        current_buf.len() - current_pos
    }
}
/// Main-thread consumer for phase-2 pool-decoded chunk blocks: parses
/// length-prefixed records out of decompressed blocks produced by workers.
pub(crate) struct MainThreadChunkConsumer<K: RawSortKey + 'static> {
    // Shared per-source block queues, filled by the worker pool.
    files: Arc<Vec<crate::sort::worker_pool::Phase2FileState>>,
    // One parse cursor per source in `files`.
    parser_state: Vec<SourceParserState>,
    // Set by workers on BGZF decompression failure.
    decompression_error: std::sync::Arc<std::sync::atomic::AtomicBool>,
    // Set by workers on chunk-file I/O failure.
    chunk_read_error: std::sync::Arc<std::sync::atomic::AtomicBool>,
    // Set when any worker thread panics.
    worker_panicked: std::sync::Arc<std::sync::atomic::AtomicBool>,
    _phantom: std::marker::PhantomData<K>,
}
impl<K: RawSortKey + 'static> MainThreadChunkConsumer<K> {
    /// Creates a consumer over the pool's phase-2 file states, with one
    /// parse cursor per source. The three flags mirror error/panic
    /// conditions raised by worker threads.
    #[must_use]
    pub(crate) fn new(
        files: Arc<Vec<crate::sort::worker_pool::Phase2FileState>>,
        decompression_error: std::sync::Arc<std::sync::atomic::AtomicBool>,
        chunk_read_error: std::sync::Arc<std::sync::atomic::AtomicBool>,
        worker_panicked: std::sync::Arc<std::sync::atomic::AtomicBool>,
    ) -> Self {
        let parser_state = (0..files.len()).map(|_| SourceParserState::new()).collect();
        Self {
            files,
            parser_state,
            decompression_error,
            chunk_read_error,
            worker_panicked,
            _phantom: std::marker::PhantomData,
        }
    }
    /// Reads the next record of `source_id` into `buf` and returns its key,
    /// or `Ok(None)` when that source is exhausted.
    pub fn next_record(&mut self, source_id: usize, buf: &mut Vec<u8>) -> Result<Option<K>> {
        if self.parser_state[source_id].remaining() == 0
            && !self.advance_to_next_block(source_id)?
        {
            return Ok(None);
        }
        self.parse_next_record(source_id, buf)
    }
    /// Installs the next decompressed block of `source_id` as the current
    /// parse buffer. Returns `Ok(false)` once the source is fully drained.
    ///
    /// Blocks by parking this thread until a worker either pushes a block
    /// (workers unpark the main thread) or flags an error.
    fn advance_to_next_block(&mut self, source_id: usize) -> Result<bool> {
        let file = &self.files[source_id];
        loop {
            {
                let mut guard =
                    file.decompressed.lock().expect("phase2 decompressed mutex poisoned");
                if let Some(data) = guard.try_pop_next() {
                    // Release the lock before touching our own state.
                    drop(guard);
                    let st = &mut self.parser_state[source_id];
                    st.current_buf = data;
                    st.current_pos = 0;
                    return Ok(true);
                }
            }
            // No block ready: check failure conditions before sleeping.
            if self.decompression_error.load(std::sync::atomic::Ordering::Acquire) {
                return Err(anyhow::anyhow!(
                    "BGZF decompression error on chunk blocks (see log for details)"
                ));
            }
            if self.chunk_read_error.load(std::sync::atomic::Ordering::Acquire) {
                return Err(anyhow::anyhow!("I/O error reading chunk file (see log for details)"));
            }
            if self.worker_panicked.load(std::sync::atomic::Ordering::Acquire) {
                return Err(anyhow::anyhow!(
                    "a sort worker thread panicked unexpectedly (see log for details)"
                ));
            }
            if file.is_drained() {
                return Ok(false);
            }
            // Wait for a worker to unpark us; spurious wakeups just loop.
            std::thread::park();
        }
    }
    /// Parses one length-prefixed record (plus its key, unless the key is
    /// embedded in the record bytes) from the current block(s).
    fn parse_next_record(&mut self, source_id: usize, buf: &mut Vec<u8>) -> Result<Option<K>> {
        let mut len_buf = [0u8; 4];
        if K::EMBEDDED_IN_RECORD {
            // Layout: u32 LE length + record bytes; key derived from record.
            if !self.read_exact_from_source(source_id, &mut len_buf)? {
                return Ok(None);
            }
            let len = u32::from_le_bytes(len_buf) as usize;
            buf.clear();
            buf.resize(len, 0);
            if !self.read_exact_from_source(source_id, buf)? {
                return Err(anyhow::anyhow!("truncated record in chunk source {source_id}"));
            }
            let key = K::extract_from_record(buf);
            Ok(Some(key))
        } else {
            // Layout: serialized key + u32 LE length + record bytes.
            let Some(key) = self.read_key_from_source::<K>(source_id)? else {
                return Ok(None);
            };
            if !self.read_exact_from_source(source_id, &mut len_buf)? {
                return Err(anyhow::anyhow!("truncated record length in chunk source {source_id}"));
            }
            let len = u32::from_le_bytes(len_buf) as usize;
            buf.clear();
            buf.resize(len, 0);
            if !self.read_exact_from_source(source_id, buf)? {
                return Err(anyhow::anyhow!("truncated record in chunk source {source_id}"));
            }
            Ok(Some(key))
        }
    }
    /// Copies exactly `out.len()` bytes from the source, crossing block
    /// boundaries as needed. Returns `Ok(false)` only on a clean EOF at a
    /// record boundary; a mid-read EOF is an error.
    fn read_exact_from_source(&mut self, source_id: usize, out: &mut [u8]) -> Result<bool> {
        let n = out.len();
        let mut filled = 0;
        while filled < n {
            if self.parser_state[source_id].remaining() == 0
                && !self.advance_to_next_block(source_id)?
            {
                if filled == 0 {
                    return Ok(false);
                }
                return Err(anyhow::anyhow!(
                    "truncated data in chunk source {source_id}: got {filled} of {n} bytes",
                ));
            }
            let st = &mut self.parser_state[source_id];
            let take = (n - filled).min(st.remaining());
            out[filled..filled + take]
                .copy_from_slice(&st.current_buf[st.current_pos..st.current_pos + take]);
            st.current_pos += take;
            filled += take;
        }
        Ok(true)
    }
    /// Deserializes a key via a `Read` adapter over this source.
    /// Returns `Ok(None)` when EOF occurs before any key byte was read.
    fn read_key_from_source<KK: RawSortKey>(&mut self, source_id: usize) -> Result<Option<KK>> {
        let mut adapter = SourceReadAdapter { consumer: self, source_id, bytes_read: 0 };
        match KK::read_from(&mut adapter) {
            Ok(key) => Ok(Some(key)),
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                if adapter.bytes_read == 0 {
                    Ok(None)
                } else {
                    // A partially read key means the chunk was truncated.
                    Err(anyhow::anyhow!(
                        "truncated key in chunk source {source_id}: \
                        got {n} bytes then EOF",
                        n = adapter.bytes_read
                    ))
                }
            }
            Err(e) => Err(anyhow::anyhow!("error reading key from source {source_id}: {e}")),
        }
    }
    /// Aggregates probe statistics across all sources for memory logging.
    fn probe_consumer_stats(&self) -> ConsumerProbeStats {
        let mut pending_blocks: u64 = 0;
        let mut pending_bytes: u64 = 0;
        let mut active_sources: u64 = 0;
        for file in self.files.iter() {
            let (blocks, bytes, active) = file.probe_stats();
            pending_blocks += blocks;
            pending_bytes += bytes;
            if active {
                active_sources += 1;
            }
        }
        ConsumerProbeStats {
            current_bytes: 0,
            current_capacity: 0,
            pending_blocks,
            pending_bytes,
            active_sources,
        }
    }
}
/// `Read` adapter over one consumer source, used to deserialize keys.
/// Tracks how many bytes were served so callers can distinguish a clean EOF
/// from a truncated key.
struct SourceReadAdapter<'a, K: RawSortKey + 'static> {
    consumer: &'a mut MainThreadChunkConsumer<K>,
    source_id: usize,
    // Total bytes handed out through `read` so far.
    bytes_read: usize,
}
impl<K: RawSortKey + 'static> Read for SourceReadAdapter<'_, K> {
    /// Serves bytes from the consumer's current block, pulling in the next
    /// block on demand. Returns `Ok(0)` once the source is fully drained.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let sid = self.source_id;
        if self.consumer.parser_state[sid].remaining() == 0 {
            // Need a fresh block before any bytes can be served.
            match self.consumer.advance_to_next_block(sid) {
                Ok(true) => {}
                Ok(false) => return Ok(0),
                Err(e) => return Err(std::io::Error::other(e.to_string())),
            }
        }
        let state = &mut self.consumer.parser_state[sid];
        let n = state.remaining().min(buf.len());
        let start = state.current_pos;
        buf[..n].copy_from_slice(&state.current_buf[start..start + n]);
        state.current_pos = start + n;
        self.bytes_read += n;
        Ok(n)
    }
}
/// Generates unique file names for spill chunks and consolidation merges,
/// rotating across the configured temp directories via `alloc`.
struct ChunkNamer<'a> {
    alloc: &'a mut TmpDirAllocator,
    // Sequence number for `chunk_NNNN.keyed` files.
    chunk_count: usize,
    // Sequence number for `merged_NNNN.keyed` files.
    merge_count: usize,
}
impl<'a> ChunkNamer<'a> {
    /// Creates a namer with both sequence counters at zero.
    fn new(alloc: &'a mut TmpDirAllocator) -> Self {
        Self { alloc, chunk_count: 0, merge_count: 0 }
    }
    /// Allocates the next temp directory and joins `file_name` onto it.
    /// Shared by the two public path generators below (DRY).
    fn next_path(&mut self, file_name: String) -> Result<PathBuf> {
        Ok(self.alloc.next()?.join(file_name))
    }
    /// Path for the next phase-1 spill chunk (`chunk_NNNN.keyed`).
    /// The counter is only advanced when allocation succeeds.
    fn next_chunk_path(&mut self) -> Result<PathBuf> {
        let path = self.next_path(format!("chunk_{:04}.keyed", self.chunk_count))?;
        self.chunk_count += 1;
        Ok(path)
    }
    /// Path for the next consolidation output (`merged_NNNN.keyed`).
    /// The counter is only advanced when allocation succeeds.
    fn next_merged_path(&mut self) -> Result<PathBuf> {
        let path = self.next_path(format!("merged_{:04}.keyed", self.merge_count))?;
        self.merge_count += 1;
        Ok(path)
    }
}
/// A spill whose asynchronous write may still be in flight.
struct PendingSpill {
    // Handle to wait on for write completion.
    handle: crate::sort::pooled_chunk_writer::SpillWriteHandle,
    // Destination path of the chunk being written.
    chunk_path: PathBuf,
}
/// Snapshots a buffer's memory/record counters as `u64` probe statistics.
#[allow(clippy::cast_possible_truncation)]
fn probe_stats(buf: &impl ProbeableBuffer) -> BufferProbeStats {
    let (usage, capacity) = (buf.memory_usage() as u64, buf.allocated_capacity() as u64);
    let (records, segments) = (buf.len() as u64, buf.num_segments() as u64);
    BufferProbeStats { usage, capacity, records, segments }
}
/// External (spill-to-disk) sorter for raw BAM record bytes.
///
/// Configure with the builder-style setters, then call `sort` or
/// `merge_bams`.
pub struct RawExternalSorter {
    sort_order: SortOrder,
    // In-memory buffer budget in bytes before spilling.
    memory_limit: usize,
    // Directories used for spill files.
    temp_dirs: Vec<PathBuf>,
    threads: usize,
    // BGZF level for the final output.
    output_compression: u32,
    // BGZF level for temp spill files (0 = uncompressed).
    temp_compression: u32,
    // Whether to write a BAM index alongside the output.
    write_index: bool,
    // Optional (version, command line) for an added @PG header record.
    pg_info: Option<(String, String)>,
    // Consolidation threshold for simultaneously existing temp files.
    max_temp_files: usize,
    // Optional cell-barcode tag used by template-coordinate sorting.
    cell_tag: Option<[u8; 2]>,
    // Optional initial buffer capacity hint (clamped to `memory_limit`).
    initial_capacity: Option<usize>,
    // Use the asynchronous input reader.
    async_reader: bool,
}
/// RAII guard for phase-2 mode: on drop (or explicit `deactivate`) it drops
/// the consumer, returns the worker pool to the legacy phase, and clears
/// the pool's phase-2 file list.
struct Phase2Guard<'a, K: RawSortKey + 'static> {
    pool: &'a Arc<SortWorkerPool>,
    consumer: Option<MainThreadChunkConsumer<K>>,
    // False once teardown has run; makes `deactivate` idempotent.
    active: bool,
}
impl<K: RawSortKey + 'static> Phase2Guard<'_, K> {
    /// Mutable access to the phase-2 consumer, if still attached.
    fn consumer_mut(&mut self) -> Option<&mut MainThreadChunkConsumer<K>> {
        self.consumer.as_mut()
    }
    /// Tears down phase-2 state exactly once: drops the consumer, returns
    /// the worker pool to the legacy phase, and clears its file list.
    fn deactivate(&mut self) {
        if !self.active {
            return;
        }
        drop(self.consumer.take());
        self.pool.set_phase(crate::sort::worker_pool::phase::LEGACY);
        self.pool.clear_phase2_files();
        self.active = false;
    }
}
impl<K: RawSortKey + 'static> Drop for Phase2Guard<'_, K> {
    /// Ensures phase-2 teardown runs even on early return or unwind.
    fn drop(&mut self) {
        self.deactivate();
    }
}
impl RawExternalSorter {
#[must_use]
pub fn new(sort_order: SortOrder) -> Self {
Self {
sort_order,
memory_limit: 512 * 1024 * 1024, temp_dirs: Vec::new(),
threads: 1,
output_compression: 6,
temp_compression: 1, write_index: false,
pg_info: None,
max_temp_files: DEFAULT_MAX_TEMP_FILES,
cell_tag: None,
initial_capacity: None,
async_reader: false,
}
}
#[must_use]
pub fn memory_limit(mut self, limit: usize) -> Self {
self.memory_limit = limit;
self
}
#[must_use]
pub fn temp_dir(mut self, path: PathBuf) -> Self {
self.temp_dirs = vec![path];
self
}
#[must_use]
pub fn temp_dirs(mut self, paths: Vec<PathBuf>) -> Self {
self.temp_dirs = paths;
self
}
#[must_use]
pub fn threads(mut self, threads: usize) -> Self {
self.threads = threads;
self
}
#[must_use]
pub fn output_compression(mut self, level: u32) -> Self {
self.output_compression = level;
self
}
#[must_use]
pub fn temp_compression(mut self, level: u32) -> Self {
self.temp_compression = level;
self
}
#[must_use]
pub fn write_index(mut self, enabled: bool) -> Self {
self.write_index = enabled;
self
}
#[must_use]
pub fn pg_info(mut self, version: String, command_line: String) -> Self {
self.pg_info = Some((version, command_line));
self
}
#[must_use]
pub fn max_temp_files(mut self, max: usize) -> Self {
self.max_temp_files = max;
self
}
#[must_use]
pub fn cell_tag(mut self, tag: [u8; 2]) -> Self {
self.cell_tag = Some(tag);
self
}
#[must_use]
pub fn initial_capacity(mut self, bytes: usize) -> Self {
self.initial_capacity = Some(bytes);
self
}
#[must_use]
pub fn async_reader(mut self, enabled: bool) -> Self {
self.async_reader = enabled;
self
}
fn effective_initial_capacity(&self) -> usize {
self.initial_capacity.unwrap_or(self.memory_limit).min(self.memory_limit)
}
fn build_sort_rayon_pool(&self) -> Result<rayon::ThreadPool> {
rayon::ThreadPoolBuilder::new()
.num_threads(self.threads.max(1))
.thread_name(|i| format!("fgumi-sort-rayon-{i}"))
.build()
.map_err(|e| anyhow::anyhow!("failed to build rayon sort pool: {e}"))
}
    /// Completes any in-flight spill write: waits on the async write
    /// handle, records the finished file's size, registers the chunk, bumps
    /// the stats, and then (possibly) consolidates temp files to stay under
    /// `max_temp_files`. No-op when nothing is pending.
    fn drain_pending_spill<K: RawSortKey + Default + 'static>(
        &self,
        pending: &mut Option<PendingSpill>,
        chunk_files: &mut Vec<PathBuf>,
        stats: &mut RawSortStats,
        timer: &mut SortPhaseTimer,
        namer: &mut ChunkNamer<'_>,
        pool: &std::sync::Arc<crate::sort::worker_pool::SortWorkerPool>,
    ) -> Result<()> {
        if let Some(prev) = pending.take() {
            // Block until the previous spill has fully hit disk.
            prev.handle.wait()?;
            timer.record_spill_size(&prev.chunk_path);
            chunk_files.push(prev.chunk_path);
            stats.chunks_written += 1;
            timer.time_consolidate(|| {
                self.maybe_consolidate_temp_files::<K>(chunk_files, namer, pool)
            })?;
        }
        Ok(())
    }
    /// When the number of temp chunk files reaches `max_temp_files`, merges
    /// the oldest files into a single consolidated file to bound open-file
    /// and merge fan-in costs.
    ///
    /// The consolidated file is inserted at the front of `chunk_files` and
    /// the merged inputs are deleted. `max_temp_files == 0` (or < 2)
    /// disables consolidation.
    fn maybe_consolidate_temp_files<K: RawSortKey + Default + 'static>(
        &self,
        chunk_files: &mut Vec<PathBuf>,
        namer: &mut ChunkNamer<'_>,
        pool: &Arc<SortWorkerPool>,
    ) -> Result<()> {
        use crate::sort::loser_tree::LoserTree;
        if self.max_temp_files == 0 || chunk_files.len() < self.max_temp_files {
            return Ok(());
        }
        if self.max_temp_files < 2 {
            return Ok(());
        }
        // Merge at least 2, at most half the cap (and no more than exist).
        let n_to_merge = (self.max_temp_files / 2).max(2).min(chunk_files.len());
        let files_to_merge: Vec<PathBuf> = chunk_files.drain(..n_to_merge).collect();
        info!(
            "Consolidating {} temp files into 1 (total was {})...",
            n_to_merge,
            n_to_merge + chunk_files.len()
        );
        let merged_path = namer.next_merged_path()?;
        // Cap concurrent reader I/O at the configured thread count.
        let sem = make_reader_semaphore(self.threads);
        let mut readers: Vec<GenericKeyedChunkReader<K>> = files_to_merge
            .iter()
            .map(|p| GenericKeyedChunkReader::<K>::open(p, Some(Arc::clone(&sem))))
            .collect::<Result<Vec<_>>>()?;
        let mut writer = PooledChunkWriter::<K>::new(Arc::clone(pool), &merged_path)?;
        // Prime the loser tree with the first record of each non-empty reader.
        let mut initial_keys: Vec<K> = Vec::with_capacity(readers.len());
        let mut records: Vec<Vec<u8>> = Vec::with_capacity(readers.len());
        let mut source_map: Vec<usize> = Vec::with_capacity(readers.len());
        for (reader_idx, reader) in readers.iter_mut().enumerate() {
            let mut record = Vec::new();
            if let Some(key) = reader.next_record(&mut record)? {
                initial_keys.push(key);
                records.push(record);
                source_map.push(reader_idx);
            }
        }
        // All inputs empty: still produce the (empty) merged file so the
        // inputs can be deleted and accounting stays consistent.
        if initial_keys.is_empty() {
            writer.finish()?;
            chunk_files.insert(0, merged_path);
            for path in &files_to_merge {
                let _ = std::fs::remove_file(path);
            }
            return Ok(());
        }
        // K-way merge: repeatedly emit the tree winner and refill from the
        // reader that produced it.
        let mut tree = LoserTree::new(initial_keys);
        while tree.winner_is_active() {
            let winner = tree.winner();
            let reader_idx = source_map[winner];
            writer.write_record(tree.winner_key(), &records[winner])?;
            if let Some(next_key) = readers[reader_idx].next_record(&mut records[winner])? {
                tree.replace_winner(next_key);
            } else {
                tree.remove_winner();
            }
        }
        writer.finish()?;
        // Oldest position: the consolidated file replaces the oldest chunks.
        chunk_files.insert(0, merged_path);
        for path in &files_to_merge {
            let _ = std::fs::remove_file(path);
        }
        info!("Consolidation complete, {} temp files remain", chunk_files.len());
        Ok(())
    }
    /// Sorts `input` into `output` according to `self.sort_order`.
    ///
    /// Sets up the shared worker pool, opens the pool-integrated input
    /// reader, optionally appends a @PG header record, creates the temp
    /// directories, and dispatches to the order-specific implementation.
    ///
    /// # Errors
    /// Propagates I/O, header-rewriting, and temp-dir creation failures.
    pub fn sort(&self, input: &Path, output: &Path) -> Result<RawSortStats> {
        info!("Starting raw-bytes sort with order: {:?}", self.sort_order);
        info!("Memory limit: {} MB", self.memory_limit / (1024 * 1024));
        info!("Threads: {}", self.threads);
        let pool = Arc::new(SortWorkerPool::new(
            self.threads.max(1),
            self.temp_compression,
            self.output_compression,
        ));
        info!("Phase 1: Pool-integrated input reading ({} workers, N+2 model)", pool.num_workers());
        let (record_source, header) = {
            let (reader, header) = crate::bam_io::create_raw_bam_reader_pool_integrated(
                input,
                &pool,
                self.async_reader,
            )?;
            (RecordSource::direct(reader), header)
        };
        // Append the @PG record when version/command-line info was supplied.
        let header = if let Some((ref version, ref command_line)) = self.pg_info {
            crate::header::add_pg_record(header, version, command_line)?
        } else {
            header
        };
        // `_temp_dirs` keeps the TempDir handles alive for the whole sort.
        let (_temp_dirs, mut alloc) = self.create_temp_dirs()?;
        match self.sort_order {
            SortOrder::Coordinate => {
                self.sort_coordinate(record_source, pool, &header, output, &mut alloc)
            }
            SortOrder::Queryname(comparator) => {
                self.sort_queryname(record_source, pool, &header, output, &mut alloc, comparator)
            }
            SortOrder::TemplateCoordinate => {
                self.sort_template_coordinate(record_source, pool, &header, output, &mut alloc)
            }
        }
    }
    /// K-way merges the already-sorted BAM files in `inputs` into `output`,
    /// assuming each input is sorted by `self.sort_order`. Returns the
    /// number of records written.
    ///
    /// Each match arm supplies an order-specific key-extraction closure to
    /// the shared merge loop.
    pub fn merge_bams(&self, inputs: &[PathBuf], header: &Header, output: &Path) -> Result<u64> {
        use crate::sort::inline_buffer::extract_coordinate_key_inline;
        use crate::sort::keys::{
            QuerynameComparator, RawCoordinateKey, RawQuerynameKey, RawQuerynameLexKey, RawSortKey,
            SortContext,
        };
        info!("Starting k-way merge of {} BAM files", inputs.len());
        let mut readers = Self::open_bam_prefetch_readers(inputs)?;
        let output_header = self.create_output_header(header);
        match self.sort_order {
            SortOrder::TemplateCoordinate => {
                let lib_lookup = LibraryLookup::from_header(header);
                let cell_tag = self.cell_tag;
                let hasher = cb_hasher();
                self.run_merge_loop(&mut readers, &output_header, output, |bam| {
                    extract_template_key_inline(bam, &lib_lookup, cell_tag.as_ref(), &hasher)
                })
            }
            SortOrder::Coordinate => {
                #[allow(clippy::cast_possible_truncation)]
                let nref = header.reference_sequences().len() as u32;
                self.run_merge_loop(&mut readers, &output_header, output, |bam| RawCoordinateKey {
                    sort_key: extract_coordinate_key_inline(bam, nref),
                })
            }
            SortOrder::Queryname(QuerynameComparator::Lexicographic) => {
                let ctx = SortContext::from_header(header);
                self.run_merge_loop(&mut readers, &output_header, output, |bam| {
                    RawQuerynameLexKey::extract(bam, &ctx)
                })
            }
            SortOrder::Queryname(QuerynameComparator::Natural) => {
                let ctx = SortContext::from_header(header);
                self.run_merge_loop(&mut readers, &output_header, output, |bam| {
                    RawQuerynameKey::extract(bam, &ctx)
                })
            }
        }
    }
fn open_bam_prefetch_readers(inputs: &[PathBuf]) -> Result<Vec<RawReadAheadReader>> {
inputs
.iter()
.map(|path| {
let (reader, _header) = create_raw_bam_reader(path, 1)?;
Ok(RawReadAheadReader::new(reader))
})
.collect()
}
    /// Shared k-way merge loop over prefetching BAM readers.
    ///
    /// `extract_key` computes a sort key from a record's raw bytes; the
    /// smallest key (via a loser tree) is written next. Returns the total
    /// number of records written to `output`.
    fn run_merge_loop<K: Ord>(
        &self,
        readers: &mut [RawReadAheadReader],
        output_header: &Header,
        output: &Path,
        extract_key: impl Fn(&[u8]) -> K,
    ) -> Result<u64> {
        use crate::sort::loser_tree::LoserTree;
        // Prime the tree with the first record of each non-empty reader.
        let mut initial_keys: Vec<K> = Vec::with_capacity(readers.len());
        let mut records: Vec<Vec<u8>> = Vec::with_capacity(readers.len());
        let mut source_map: Vec<usize> = Vec::with_capacity(readers.len());
        for (idx, reader) in readers.iter_mut().enumerate() {
            if let Some(raw_record) = reader.next() {
                let mut buf = Vec::with_capacity(raw_record.as_ref().len());
                buf.extend_from_slice(raw_record.as_ref());
                initial_keys.push(extract_key(&buf));
                records.push(buf);
                source_map.push(idx);
            }
        }
        // No input records: still write a valid (header-only) output file.
        if initial_keys.is_empty() {
            info!("Merge complete: 0 records merged");
            let writer = crate::bam_io::create_raw_bam_writer(
                output,
                output_header,
                self.threads,
                self.output_compression,
            )?;
            writer.finish()?;
            return Ok(0);
        }
        let mut tree = LoserTree::new(initial_keys);
        let mut writer = crate::bam_io::create_raw_bam_writer(
            output,
            output_header,
            self.threads,
            self.output_compression,
        )?;
        let mut records_merged = 0u64;
        let merge_progress = ProgressTracker::new("Merged records").with_interval(1_000_000);
        while tree.winner_is_active() {
            let winner = tree.winner();
            writer.write_raw_record(&records[winner])?;
            records_merged += 1;
            merge_progress.log_if_needed(1);
            // Refill the winner's slot from the reader that produced it,
            // reusing the record buffer in place.
            let reader_idx = source_map[winner];
            if let Some(raw_record) = readers[reader_idx].next() {
                let buf = &mut records[winner];
                buf.clear();
                buf.extend_from_slice(raw_record.as_ref());
                let new_key = extract_key(buf);
                tree.replace_winner(new_key);
            } else {
                tree.remove_winner();
            }
        }
        writer.finish()?;
        merge_progress.log_final();
        Ok(records_merged)
    }
fn sort_coordinate(
&self,
record_source: RecordSource,
pool: Arc<SortWorkerPool>,
header: &Header,
output: &Path,
alloc: &mut TmpDirAllocator,
) -> Result<RawSortStats> {
if self.write_index {
self.sort_coordinate_with_index(record_source, pool, header, output, alloc)
} else {
self.sort_coordinate_optimized(record_source, pool, header, output, alloc)
}
}
/// Coordinate sort without BAM index generation.
///
/// Phase 1 streams records into an inline `RecordBuffer`; whenever the buffer
/// reaches `self.memory_limit` it is sorted in parallel and spilled as a keyed
/// chunk file. Spill finalization runs in the background (`PendingSpill`) and
/// is drained before the next spill starts, so reading overlaps chunk I/O.
/// Phase 2 merges spilled chunks plus any in-memory remainder into `output`;
/// if nothing spilled, the buffer is sorted and written directly.
///
/// Errors from the record source, chunk I/O, or the merge are propagated.
#[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
fn sort_coordinate_optimized(
    &self,
    mut record_source: RecordSource,
    pool: Arc<SortWorkerPool>,
    header: &Header,
    output: &Path,
    alloc: &mut TmpDirAllocator,
) -> Result<RawSortStats> {
    use crate::sort::keys::RawCoordinateKey;
    let mut stats = RawSortStats::default();
    let mut timer = SortPhaseTimer::new();
    let nref = header.reference_sequences().len() as u32;
    // Capacity heuristic: assume ~240 bytes per record; reserve the rest of
    // the initial capacity for record data after ~24 bytes/record of index
    // overhead. (Heuristic constants — tuned elsewhere, not derived here.)
    let init_cap = self.effective_initial_capacity();
    let estimated_records = init_cap / 240;
    let estimated_data_bytes = init_cap.saturating_sub(estimated_records * 24);
    let mut chunk_files: Vec<PathBuf> = Vec::new();
    let mut buffer = RecordBuffer::with_capacity(estimated_records, estimated_data_bytes, nref);
    let mut namer = ChunkNamer::new(alloc);
    // At most one spill finalizes in the background at any time.
    let mut pending_spill: Option<PendingSpill> = None;
    let rayon_pool = self.build_sort_rayon_pool()?;
    let progress = ProgressTracker::new("Read records").with_interval(1_000_000);
    info!("Phase 1: Reading and sorting chunks (inline buffer, keyed output)...");
    let mut probe = SpillProbe::new("phase1");
    for record in record_source.by_ref() {
        stats.total_records += 1;
        progress.log_if_needed(1);
        buffer.push_coordinate(record.as_ref())?;
        if probe.should_sample_read(stats.total_records) {
            probe.log_mid_read(probe_stats(&buffer), Some(pool.phase1_queue_depths()));
        }
        if buffer.memory_usage() >= self.memory_limit {
            // Buffer full: stop the read clock, wait for the previous spill
            // to finish, then sort and spill the current contents.
            timer.end_read_span();
            let bstats = probe_stats(&buffer);
            let depths = Some(pool.phase1_queue_depths());
            probe.pre_spill(bstats, depths);
            self.drain_pending_spill::<RawCoordinateKey>(
                &mut pending_spill,
                &mut chunk_files,
                &mut stats,
                &mut timer,
                &mut namer,
                &pool,
            )?;
            probe.post_drain(probe_stats(&buffer), Some(pool.phase1_queue_depths()));
            let chunk_path = namer.next_chunk_path()?;
            timer.time_sort(|| {
                rayon_pool.install(|| buffer.par_sort());
            });
            // Write sorted (key, record) pairs; `start_finish` returns a
            // handle so finalization proceeds while we resume reading.
            let handle = timer.time_spill_write(|| {
                let mut writer =
                    PooledChunkWriter::<RawCoordinateKey>::new(Arc::clone(&pool), &chunk_path)?;
                for r in buffer.refs() {
                    let key = RawCoordinateKey { sort_key: r.sort_key };
                    let record_bytes = buffer.get_record(r);
                    writer.write_record(&key, record_bytes)?;
                }
                writer.start_finish()
            })?;
            pending_spill = Some(PendingSpill { handle, chunk_path });
            buffer.clear();
            // Force an allocator collection pass after clearing the buffer.
            force_mi_collect();
            probe.post_spill(Some(pool.phase1_queue_depths()));
            timer.begin_read_span();
        }
    }
    timer.end_read_span();
    progress.log_final();
    // Surface any error captured by the record source during reading.
    if let Some(err) = record_source.take_error() {
        return Err(anyhow::Error::from(err));
    }
    // Finalize the last in-flight spill, if any.
    self.drain_pending_spill::<RawCoordinateKey>(
        &mut pending_spill,
        &mut chunk_files,
        &mut stats,
        &mut timer,
        &mut namer,
        &pool,
    )?;
    probe.phase1_end(buffer.memory_usage() as u64);
    if chunk_files.is_empty() {
        // Fast path: everything fit in memory — sort once and write.
        info!("All records fit in memory, performing in-memory sort");
        timer.time_sort(|| {
            rayon_pool.install(|| buffer.par_sort());
        });
        timer.time_write_output(|| {
            use crate::sort::pooled_bam_writer::PooledBamWriter;
            let output_header = self.create_output_header(header);
            let mut writer = PooledBamWriter::new(Arc::clone(&pool), output, &output_header)?;
            for record_bytes in buffer.iter_sorted() {
                writer.write_raw_record(record_bytes)?;
            }
            writer.finish()?;
            Ok(())
        })?;
    } else {
        // Records still buffered become in-memory chunks merged alongside
        // the on-disk chunk files.
        let memory_chunks: Vec<Vec<(RawCoordinateKey, fgumi_raw_bam::RawRecord)>> = if buffer
            .is_empty()
        {
            Vec::new()
        } else if self.threads > 1 {
            // Multi-threaded: produce several sorted runs directly.
            timer.time_sort(|| rayon_pool.install(|| buffer.par_sort_into_chunks(self.threads)))
        } else {
            // Single-threaded: one fully sorted run, copied out of the buffer.
            timer.time_sort(|| {
                rayon_pool.install(|| buffer.par_sort());
            });
            let chunk = buffer
                .refs()
                .iter()
                .map(|r| {
                    let key = RawCoordinateKey { sort_key: r.sort_key };
                    (key, fgumi_raw_bam::RawRecord::from(buffer.get_record(r).to_vec()))
                })
                .collect();
            vec![chunk]
        };
        let n_memory = memory_chunks.iter().filter(|c| !c.is_empty()).count();
        info!(
            "Phase 2: Merging {} chunks (keyed O(1) comparisons)...",
            chunk_files.len() + n_memory
        );
        timer.time_merge(|| {
            self.merge_chunks_generic::<RawCoordinateKey>(
                &chunk_files,
                memory_chunks,
                header,
                output,
                stats.total_records,
                &pool,
            )
        })?;
    }
    stats.output_records = stats.total_records;
    // Shut the worker pool down only if we hold the last Arc reference.
    if let Ok(pool) = Arc::try_unwrap(pool) {
        pool.shutdown();
    }
    timer.log_summary(self.threads);
    info!("Sort complete: {} records processed", stats.total_records);
    Ok(stats)
}
/// Coordinate sort that additionally writes a BAI index next to the output.
///
/// Phase 1 is identical in structure to [`Self::sort_coordinate_optimized`]:
/// records stream into an inline `RecordBuffer` and spill as keyed chunks
/// when `self.memory_limit` is reached. Phase 2 differs in that the final
/// write goes through an indexing writer, and the resulting index is saved
/// to `<output>.bam.bai`.
#[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
fn sort_coordinate_with_index(
    &self,
    mut record_source: RecordSource,
    pool: Arc<SortWorkerPool>,
    header: &Header,
    output: &Path,
    alloc: &mut TmpDirAllocator,
) -> Result<RawSortStats> {
    use crate::bam_io::{create_indexing_bam_writer, write_bai_index};
    use crate::sort::keys::RawCoordinateKey;
    info!("Indexing enabled: will write BAM index alongside output");
    let mut stats = RawSortStats::default();
    let mut timer = SortPhaseTimer::new();
    let nref = header.reference_sequences().len() as u32;
    // Same capacity heuristic as the non-indexing path: ~240 bytes/record,
    // ~24 bytes/record of index overhead.
    let init_cap = self.effective_initial_capacity();
    let estimated_records = init_cap / 240;
    let estimated_data_bytes = init_cap.saturating_sub(estimated_records * 24);
    let mut chunk_files: Vec<PathBuf> = Vec::new();
    let mut buffer = RecordBuffer::with_capacity(estimated_records, estimated_data_bytes, nref);
    let mut namer = ChunkNamer::new(alloc);
    // At most one spill finalizes in the background at any time.
    let mut pending_spill: Option<PendingSpill> = None;
    let rayon_pool = self.build_sort_rayon_pool()?;
    info!("Phase 1: Reading and sorting chunks (inline buffer, keyed output)...");
    let mut probe = SpillProbe::new("phase1");
    for record in record_source.by_ref() {
        stats.total_records += 1;
        buffer.push_coordinate(record.as_ref())?;
        if probe.should_sample_read(stats.total_records) {
            probe.log_mid_read(probe_stats(&buffer), Some(pool.phase1_queue_depths()));
        }
        if buffer.memory_usage() >= self.memory_limit {
            // Buffer full: drain the previous spill, sort, and spill again.
            timer.end_read_span();
            let bstats = probe_stats(&buffer);
            let depths = Some(pool.phase1_queue_depths());
            probe.pre_spill(bstats, depths);
            self.drain_pending_spill::<RawCoordinateKey>(
                &mut pending_spill,
                &mut chunk_files,
                &mut stats,
                &mut timer,
                &mut namer,
                &pool,
            )?;
            probe.post_drain(probe_stats(&buffer), Some(pool.phase1_queue_depths()));
            let chunk_path = namer.next_chunk_path()?;
            timer.time_sort(|| {
                rayon_pool.install(|| buffer.par_sort());
            });
            // Spill sorted (key, record) pairs; finalization overlaps the
            // next read span via the returned handle.
            let handle = timer.time_spill_write(|| {
                let mut writer =
                    PooledChunkWriter::<RawCoordinateKey>::new(Arc::clone(&pool), &chunk_path)?;
                for r in buffer.refs() {
                    let key = RawCoordinateKey { sort_key: r.sort_key };
                    let record_bytes = buffer.get_record(r);
                    writer.write_record(&key, record_bytes)?;
                }
                writer.start_finish()
            })?;
            pending_spill = Some(PendingSpill { handle, chunk_path });
            buffer.clear();
            // Force an allocator collection pass after clearing the buffer.
            force_mi_collect();
            probe.post_spill(Some(pool.phase1_queue_depths()));
            timer.begin_read_span();
        }
    }
    timer.end_read_span();
    info!("Read {} records total", stats.total_records);
    // Surface any error captured by the record source during reading.
    if let Some(err) = record_source.take_error() {
        return Err(anyhow::Error::from(err));
    }
    // Finalize the last in-flight spill, if any.
    self.drain_pending_spill::<RawCoordinateKey>(
        &mut pending_spill,
        &mut chunk_files,
        &mut stats,
        &mut timer,
        &mut namer,
        &pool,
    )?;
    probe.phase1_end(buffer.memory_usage() as u64);
    let output_header = self.create_output_header(header);
    if chunk_files.is_empty() {
        // Fast path: in-memory sort, indexed write, then save the BAI.
        info!("All records fit in memory, performing in-memory sort");
        timer.time_sort(|| {
            rayon_pool.install(|| buffer.par_sort());
        });
        timer.time_write_output(|| {
            let mut writer = create_indexing_bam_writer(
                output,
                &output_header,
                self.output_compression,
                self.threads,
            )?;
            for record_bytes in buffer.iter_sorted() {
                writer.write_raw_record(record_bytes)?;
            }
            let index = writer.finish()?;
            let index_path = output.with_extension("bam.bai");
            write_bai_index(&index_path, &index)?;
            info!("Wrote BAM index: {}", index_path.display());
            Ok(())
        })?;
    } else {
        // Remaining buffered records become in-memory chunks for the merge.
        let memory_chunks: Vec<Vec<(RawCoordinateKey, fgumi_raw_bam::RawRecord)>> = if buffer
            .is_empty()
        {
            Vec::new()
        } else if self.threads > 1 {
            // Multi-threaded: produce several sorted runs directly.
            timer.time_sort(|| rayon_pool.install(|| buffer.par_sort_into_chunks(self.threads)))
        } else {
            // Single-threaded: one sorted run, copied out of the buffer.
            timer.time_sort(|| {
                rayon_pool.install(|| buffer.par_sort());
            });
            let chunk = buffer
                .refs()
                .iter()
                .map(|r| {
                    let key = RawCoordinateKey { sort_key: r.sort_key };
                    (key, fgumi_raw_bam::RawRecord::from(buffer.get_record(r).to_vec()))
                })
                .collect();
            vec![chunk]
        };
        let n_memory = memory_chunks.iter().filter(|c| !c.is_empty()).count();
        info!(
            "Phase 2: Merging {} chunks with index generation...",
            chunk_files.len() + n_memory
        );
        timer.time_merge(|| {
            // The indexing merge returns the accumulated BAI index.
            let index = self.merge_chunks_with_index::<RawCoordinateKey>(
                &chunk_files,
                memory_chunks,
                header,
                output,
                stats.total_records,
                &pool,
            )?;
            let index_path = output.with_extension("bam.bai");
            write_bai_index(&index_path, &index)?;
            info!("Wrote BAM index: {}", index_path.display());
            Ok(())
        })?;
    }
    stats.output_records = stats.total_records;
    // Shut the worker pool down only if we hold the last Arc reference.
    if let Ok(pool) = Arc::try_unwrap(pool) {
        pool.shutdown();
    }
    timer.log_summary(self.threads);
    info!("Sort complete: {} records processed", stats.total_records);
    Ok(stats)
}
/// Sort by query name, selecting the key type from the requested comparator.
///
/// Both comparator variants share the same keyed implementation
/// (`sort_queryname_keyed`); only the generic key type differs.
fn sort_queryname(
    &self,
    record_source: RecordSource,
    pool: Arc<SortWorkerPool>,
    header: &Header,
    output: &Path,
    alloc: &mut TmpDirAllocator,
    comparator: QuerynameComparator,
) -> Result<RawSortStats> {
    use crate::sort::keys::{RawQuerynameKey, RawQuerynameLexKey};
    info!("Using queryname sort with {comparator} comparator");
    match comparator {
        QuerynameComparator::Natural => self
            .sort_queryname_keyed::<RawQuerynameKey>(record_source, pool, header, output, alloc),
        QuerynameComparator::Lexicographic => self.sort_queryname_keyed::<RawQuerynameLexKey>(
            record_source,
            pool,
            header,
            output,
            alloc,
        ),
    }
}
/// Queryname sort parameterized by the key type `K`.
///
/// Phase 1 extracts a sort key per record and accumulates `(key, record)`
/// pairs in a `Vec`, spilling a sorted, keyed chunk whenever the estimated
/// memory footprint reaches `self.memory_limit`. Phase 2 merges spilled
/// chunks plus any in-memory remainder; when nothing spilled the entries are
/// sorted and written directly.
#[allow(clippy::too_many_lines)]
fn sort_queryname_keyed<K: RawSortKey + Default + 'static>(
    &self,
    mut record_source: RecordSource,
    pool: Arc<SortWorkerPool>,
    header: &Header,
    output: &Path,
    alloc: &mut TmpDirAllocator,
) -> Result<RawSortStats> {
    use crate::sort::keys::SortContext;
    let mut stats = RawSortStats::default();
    let mut timer = SortPhaseTimer::new();
    let ctx = SortContext::from_header(header);
    // Capacity heuristic: ~300 bytes per (key, record) entry.
    let init_cap = self.effective_initial_capacity();
    let estimated_records = init_cap / 300;
    let mut chunk_files: Vec<PathBuf> = Vec::new();
    let mut entries: Vec<(K, fgumi_raw_bam::RawRecord)> = Vec::with_capacity(estimated_records);
    // Approximate footprint: record bytes plus a flat 50-byte overhead each.
    let mut memory_used = 0usize;
    let mut namer = ChunkNamer::new(alloc);
    // At most one spill finalizes in the background at any time.
    let mut pending_spill: Option<PendingSpill> = None;
    let rayon_pool = self.build_sort_rayon_pool()?;
    let progress = ProgressTracker::new("Read records").with_interval(1_000_000);
    info!("Phase 1: Reading and sorting chunks (keyed output)...");
    let mut probe = SpillProbe::new("phase1");
    for record in record_source.by_ref() {
        stats.total_records += 1;
        progress.log_if_needed(1);
        let key = K::extract(record.as_ref(), &ctx);
        let record_size = record.as_ref().len() + 50;
        memory_used += record_size;
        entries.push((key, record));
        if probe.should_sample_read(stats.total_records) {
            let bstats = BufferProbeStats::simple(memory_used as u64, entries.len() as u64);
            probe.log_mid_read(bstats, Some(pool.phase1_queue_depths()));
        }
        if memory_used >= self.memory_limit {
            // Spill: drain the previous spill, sort entries, write them out.
            timer.end_read_span();
            let bstats = BufferProbeStats::simple(memory_used as u64, entries.len() as u64);
            let depths = Some(pool.phase1_queue_depths());
            probe.pre_spill(bstats, depths);
            self.drain_pending_spill::<K>(
                &mut pending_spill,
                &mut chunk_files,
                &mut stats,
                &mut timer,
                &mut namer,
                &pool,
            )?;
            // `entries` is unchanged by the drain, so pre-drain stats still apply.
            probe.post_drain(bstats, Some(pool.phase1_queue_depths()));
            let chunk_path = namer.next_chunk_path()?;
            timer.time_sort(|| {
                use rayon::prelude::*;
                rayon_pool.install(|| entries.par_sort_unstable_by(|a, b| a.0.cmp(&b.0)));
            });
            // Drain entries into the chunk writer; finalization overlaps the
            // next read span via the returned handle.
            let handle = timer.time_spill_write(|| {
                let mut writer = PooledChunkWriter::<K>::new(Arc::clone(&pool), &chunk_path)?;
                for (key, record) in entries.drain(..) {
                    writer.write_record(&key, record.as_ref())?;
                }
                writer.start_finish()
            })?;
            pending_spill = Some(PendingSpill { handle, chunk_path });
            memory_used = 0;
            // Force an allocator collection pass after draining entries.
            force_mi_collect();
            probe.post_spill(Some(pool.phase1_queue_depths()));
            timer.begin_read_span();
        }
    }
    timer.end_read_span();
    progress.log_final();
    // Surface any error captured by the record source during reading.
    if let Some(err) = record_source.take_error() {
        return Err(anyhow::Error::from(err));
    }
    // Finalize the last in-flight spill, if any.
    self.drain_pending_spill::<K>(
        &mut pending_spill,
        &mut chunk_files,
        &mut stats,
        &mut timer,
        &mut namer,
        &pool,
    )?;
    probe.phase1_end(memory_used as u64);
    if chunk_files.is_empty() {
        // Fast path: everything fit in memory — sort once and write.
        info!("All records fit in memory, performing in-memory sort");
        timer.time_sort(|| {
            use rayon::prelude::*;
            rayon_pool.install(|| entries.par_sort_unstable_by(|a, b| a.0.cmp(&b.0)));
        });
        timer.time_write_output(|| {
            use crate::sort::pooled_bam_writer::PooledBamWriter;
            let output_header = self.create_output_header(header);
            let mut writer = PooledBamWriter::new(Arc::clone(&pool), output, &output_header)?;
            for (_key, record) in entries {
                writer.write_raw_record(&record)?;
            }
            writer.finish()?;
            Ok(())
        })?;
    } else {
        // Remaining entries become sorted in-memory chunks for the merge.
        let memory_chunks: Vec<Vec<(K, fgumi_raw_bam::RawRecord)>> = if entries.is_empty() {
            Vec::new()
        } else if self.threads > 1 {
            timer.time_sort(|| {
                use rayon::prelude::*;
                // Sort each fixed-size window in place, in parallel; each
                // window is then an independent sorted run.
                let chunk_size = entries.len().div_ceil(self.threads.max(1));
                rayon_pool.install(|| {
                    entries.par_chunks_mut(chunk_size).for_each(|chunk| {
                        chunk.sort_unstable_by(|a, b| a.0.cmp(&b.0));
                    });
                });
                // Peel the windows off the back with `split_off` (no element
                // copies), handling the short tail window first, then reverse
                // to restore original window order.
                let mut remaining = std::mem::take(&mut entries);
                let num_chunks = remaining.len().div_ceil(chunk_size);
                let mut chunks: Vec<Vec<(K, fgumi_raw_bam::RawRecord)>> =
                    Vec::with_capacity(num_chunks);
                let tail_len = remaining.len() % chunk_size;
                if tail_len != 0 {
                    let split_at = remaining.len() - tail_len;
                    chunks.push(remaining.split_off(split_at));
                }
                while !remaining.is_empty() {
                    let split_at = remaining.len().saturating_sub(chunk_size);
                    chunks.push(remaining.split_off(split_at));
                }
                chunks.reverse();
                chunks
            })
        } else {
            timer.time_sort(|| {
                entries.sort_unstable_by(|a, b| a.0.cmp(&b.0));
            });
            vec![entries]
        };
        let n_memory = memory_chunks.iter().filter(|c| !c.is_empty()).count();
        info!(
            "Phase 2: Merging {} chunks (keyed comparisons)...",
            chunk_files.len() + n_memory
        );
        timer.time_merge(|| {
            self.merge_chunks_generic::<K>(
                &chunk_files,
                memory_chunks,
                header,
                output,
                stats.total_records,
                &pool,
            )
        })?;
    }
    stats.output_records = stats.total_records;
    // Shut the worker pool down only if we hold the last Arc reference.
    if let Ok(pool) = Arc::try_unwrap(pool) {
        pool.shutdown();
    }
    timer.log_summary(self.threads);
    info!("Sort complete: {} records processed", stats.total_records);
    Ok(stats)
}
/// Template-coordinate sort (groups records belonging to the same template).
///
/// Keys are extracted inline via [`extract_template_key_inline`] using the
/// header's library lookup, the optional cell tag, and a fixed-seed hasher.
/// The phase structure mirrors the other sort paths: buffer, spill keyed
/// chunks at the memory limit, then merge (or write directly when nothing
/// spilled).
#[allow(clippy::too_many_lines)]
fn sort_template_coordinate(
    &self,
    mut record_source: RecordSource,
    pool: Arc<SortWorkerPool>,
    header: &Header,
    output: &Path,
    alloc: &mut TmpDirAllocator,
) -> Result<RawSortStats> {
    let mut stats = RawSortStats::default();
    let mut timer = SortPhaseTimer::new();
    let lib_lookup = LibraryLookup::from_header(header);
    let cb_hasher = cb_hasher();
    // Capacity heuristic: ~354 bytes per record; 86% of the initial capacity
    // goes to record data. (Heuristic constants, not derived here.)
    let init_cap = self.effective_initial_capacity();
    let bytes_per_record = 354;
    let estimated_records = init_cap / bytes_per_record;
    let estimated_data_bytes = init_cap * 86 / 100;
    let mut chunk_files: Vec<PathBuf> = Vec::new();
    let mut buffer =
        TemplateRecordBuffer::with_capacity(estimated_records, estimated_data_bytes);
    let mut namer = ChunkNamer::new(alloc);
    // At most one spill finalizes in the background at any time.
    let mut pending_spill: Option<PendingSpill> = None;
    let rayon_pool = self.build_sort_rayon_pool()?;
    let progress = ProgressTracker::new("Read records").with_interval(1_000_000);
    info!("Phase 1: Reading and sorting chunks (inline buffer)...");
    let mut probe = SpillProbe::new("phase1");
    for record in record_source.by_ref() {
        stats.total_records += 1;
        progress.log_if_needed(1);
        let bam_bytes = record.as_ref();
        // Compute the template key from raw BAM bytes without decoding.
        let key = extract_template_key_inline(
            bam_bytes,
            &lib_lookup,
            self.cell_tag.as_ref(),
            &cb_hasher,
        );
        buffer.push(bam_bytes, key)?;
        if probe.should_sample_read(stats.total_records) {
            probe.log_mid_read(probe_stats(&buffer), Some(pool.phase1_queue_depths()));
        }
        if buffer.memory_usage() >= self.memory_limit {
            // Spill: drain the previous spill, sort the buffer, write it out.
            timer.end_read_span();
            let bstats = probe_stats(&buffer);
            let depths = Some(pool.phase1_queue_depths());
            probe.pre_spill(bstats, depths);
            self.drain_pending_spill::<TemplateKey>(
                &mut pending_spill,
                &mut chunk_files,
                &mut stats,
                &mut timer,
                &mut namer,
                &pool,
            )?;
            probe.post_drain(probe_stats(&buffer), Some(pool.phase1_queue_depths()));
            let chunk_path = namer.next_chunk_path()?;
            timer.time_sort(|| {
                rayon_pool.install(|| buffer.par_sort());
            });
            // Write keyed records in sorted order; finalization overlaps the
            // next read span via the returned handle.
            let handle = timer.time_spill_write(|| {
                let mut writer =
                    PooledChunkWriter::<TemplateKey>::new(Arc::clone(&pool), &chunk_path)?;
                for (key, record) in buffer.iter_sorted_keyed() {
                    writer.write_record(&key, record)?;
                }
                writer.start_finish()
            })?;
            pending_spill = Some(PendingSpill { handle, chunk_path });
            buffer.clear();
            // Force an allocator collection pass after clearing the buffer.
            force_mi_collect();
            probe.post_spill(Some(pool.phase1_queue_depths()));
            timer.begin_read_span();
        }
    }
    timer.end_read_span();
    progress.log_final();
    // Surface any error captured by the record source during reading.
    if let Some(err) = record_source.take_error() {
        return Err(anyhow::Error::from(err));
    }
    // Finalize the last in-flight spill, if any.
    self.drain_pending_spill::<TemplateKey>(
        &mut pending_spill,
        &mut chunk_files,
        &mut stats,
        &mut timer,
        &mut namer,
        &pool,
    )?;
    probe.phase1_end(buffer.memory_usage() as u64);
    if chunk_files.is_empty() {
        // Fast path: everything fit in memory — sort once and write.
        info!("All records fit in memory, performing in-memory sort");
        timer.time_sort(|| {
            rayon_pool.install(|| buffer.par_sort());
        });
        timer.time_write_output(|| {
            use crate::sort::pooled_bam_writer::PooledBamWriter;
            let output_header = self.create_output_header(header);
            let mut writer = PooledBamWriter::new(Arc::clone(&pool), output, &output_header)?;
            for record_bytes in buffer.iter_sorted() {
                writer.write_raw_record(record_bytes)?;
            }
            writer.finish()?;
            Ok(())
        })?;
    } else {
        // Remaining buffered records become in-memory chunks for the merge.
        let memory_chunks: Vec<Vec<(TemplateKey, fgumi_raw_bam::RawRecord)>> = if buffer
            .is_empty()
        {
            Vec::new()
        } else if self.threads > 1 {
            // Multi-threaded: produce several sorted runs directly.
            timer.time_sort(|| rayon_pool.install(|| buffer.par_sort_into_chunks(self.threads)))
        } else {
            // Single-threaded: one sorted run, copied out of the buffer.
            timer.time_sort(|| {
                rayon_pool.install(|| buffer.par_sort());
            });
            let chunk = buffer
                .iter_sorted_keyed()
                .map(|(k, r)| (k, fgumi_raw_bam::RawRecord::from(r.to_vec())))
                .collect();
            vec![chunk]
        };
        let n_memory = memory_chunks.iter().filter(|c| !c.is_empty()).count();
        info!("Phase 2: Merging {} chunks...", chunk_files.len() + n_memory);
        timer.time_merge(|| {
            self.merge_chunks_keyed(
                &chunk_files,
                memory_chunks,
                header,
                output,
                stats.total_records,
                &pool,
            )
        })?;
    }
    stats.output_records = stats.total_records;
    // Shut the worker pool down only if we hold the last Arc reference.
    if let Ok(pool) = Arc::try_unwrap(pool) {
        pool.shutdown();
    }
    timer.log_summary(self.threads);
    info!("Sort complete: {} records processed", stats.total_records);
    Ok(stats)
}
/// Assemble the merge inputs: one source per non-empty in-memory chunk and
/// one per spilled chunk file.
///
/// When `pool_decompress` is set (and files exist), the chunk files are
/// registered with the worker pool and referenced by id; otherwise each file
/// gets a direct reader gated by a semaphore of `reader_concurrency` permits.
fn build_chunk_sources<K: RawSortKey + Default + 'static>(
    chunk_files: &[PathBuf],
    memory_chunks: Vec<Vec<(K, fgumi_raw_bam::RawRecord)>>,
    reader_concurrency: usize,
    pool_decompress: bool,
    pool: &Arc<SortWorkerPool>,
) -> Result<Vec<ChunkSource<K>>> {
    let disk_count = chunk_files.len();
    let memory_count = memory_chunks.iter().filter(|c| !c.is_empty()).count();
    let mut sources: Vec<ChunkSource<K>> = Vec::with_capacity(disk_count + memory_count);
    if pool_decompress && disk_count > 0 {
        // Pool-integrated path: workers handle decompression; sources only
        // carry an id into the registered file list.
        pool.set_phase2_files(chunk_files)?;
        sources.extend((0..disk_count).map(|source_id| ChunkSource::PoolDisk { source_id }));
    } else {
        // Direct readers, throttled by a shared semaphore.
        let sem = make_reader_semaphore(reader_concurrency);
        for path in chunk_files {
            let reader = GenericKeyedChunkReader::<K>::open(path, Some(Arc::clone(&sem)))?;
            sources.push(ChunkSource::Disk(reader));
        }
    }
    // Empty in-memory chunks are dropped; the rest become cursor sources.
    sources.extend(
        memory_chunks
            .into_iter()
            .filter(|chunk| !chunk.is_empty())
            .map(|records| ChunkSource::Memory { records, idx: 0 }),
    );
    Ok(sources)
}
/// Merge template-coordinate chunks into `output`.
///
/// Thin wrapper delegating to the generic keyed merge specialized to
/// [`TemplateKey`]; returns the number of records merged.
fn merge_chunks_keyed(
    &self,
    chunk_files: &[PathBuf],
    memory_chunks: Vec<Vec<(TemplateKey, fgumi_raw_bam::RawRecord)>>,
    header: &Header,
    output: &Path,
    total_records: u64,
    pool: &Arc<SortWorkerPool>,
) -> Result<u64> {
    self.merge_chunks_generic::<TemplateKey>(
        chunk_files, memory_chunks, header, output, total_records, pool,
    )
}
/// K-way merge of disk and in-memory chunks into `output` using a loser tree.
///
/// Disk chunks are decompressed by the worker pool (N+2 model: pool workers
/// plus one I/O thread plus the main thread); a [`Phase2Guard`] switches the
/// pool into phase 2 for the duration and holds the main-thread consumer.
/// When debug logging is enabled, write/read/tree time is sampled every
/// 1024 records and extrapolated to estimate per-sub-phase totals.
///
/// Returns the number of records written to `output`.
#[allow(clippy::too_many_lines)]
fn merge_chunks_generic<K: RawSortKey + Default + 'static>(
    &self,
    chunk_files: &[PathBuf],
    memory_chunks: Vec<Vec<(K, fgumi_raw_bam::RawRecord)>>,
    header: &Header,
    output: &Path,
    total_records: u64,
    pool: &Arc<SortWorkerPool>,
) -> Result<u64> {
    use crate::sort::loser_tree::LoserTree;
    use crate::sort::pooled_bam_writer::PooledBamWriter;
    use crate::sort::worker_pool::phase;
    let reader_concurrency: usize = 1;
    let num_disk = chunk_files.len();
    if num_disk > 0 {
        info!(
            "Pool-integrated merge: {} disk sources, {} pool workers (N+2 model)",
            num_disk,
            pool.num_workers()
        );
    }
    // `true` => pool-integrated decompression for disk chunks.
    let mut sources = Self::build_chunk_sources::<K>(
        chunk_files,
        memory_chunks,
        reader_concurrency,
        true,
        pool,
    )?;
    let num_sources = sources.len();
    info!("Merging from {num_sources} sources...");
    // With disk sources, move the pool into phase 2 and attach the
    // main-thread consumer; the guard restores state on deactivate.
    let mut guard: Phase2Guard<'_, K> = if num_disk > 0 {
        let files = pool.phase2_files();
        let consumer = MainThreadChunkConsumer::new(
            files,
            pool.decompress_error_flag(),
            pool.chunk_read_error_flag(),
            pool.worker_panicked_flag(),
        );
        pool.set_phase(phase::PHASE2);
        Phase2Guard { pool, consumer: Some(consumer), active: true }
    } else {
        Phase2Guard { pool, consumer: None, active: false }
    };
    let output_header = self.create_output_header(header);
    // Prime the tree: one leading record per non-empty source; `source_map`
    // maps tree slots back to source indices.
    let mut initial_keys: Vec<K> = Vec::with_capacity(sources.len());
    let mut records: Vec<Vec<u8>> = Vec::with_capacity(sources.len());
    let mut source_map: Vec<usize> = Vec::with_capacity(sources.len());
    for (idx, source) in sources.iter_mut().enumerate() {
        let mut record = Vec::new();
        if let Some(key) = source.next_record(&mut record, guard.consumer_mut())? {
            initial_keys.push(key);
            records.push(record);
            source_map.push(idx);
        }
    }
    if initial_keys.is_empty() {
        // No records anywhere: still emit a valid (empty) BAM.
        info!("Merge complete: 0 records merged");
        guard.deactivate();
        let writer = PooledBamWriter::new(Arc::clone(pool), output, &output_header)?;
        writer.finish()?;
        return Ok(0);
    }
    let mut tree = LoserTree::new(initial_keys);
    info!("Merge thread budget: {} pool workers + 1 I/O + 1 main (N+2)", pool.num_workers());
    let mut writer = PooledBamWriter::new(Arc::clone(pool), output, &output_header)?;
    let mut records_merged = 0u64;
    let merge_progress = ProgressTracker::new("Merged records")
        .with_interval(1_000_000)
        .with_total(total_records);
    let mut merge_probe = MergeProbe::new();
    // Sub-phase timing is sampled (not measured per record) to keep the
    // hot loop cheap; only active when debug logging is on.
    let debug_timing = log::log_enabled!(log::Level::Debug);
    let merge_sample_interval: u64 = 1024;
    let mut merge_write_secs = 0.0f64;
    let mut merge_read_secs = 0.0f64;
    let mut merge_tree_secs = 0.0f64;
    let mut samples_taken: u64 = 0;
    let loop_start = Instant::now();
    while tree.winner_is_active() {
        let winner = tree.winner();
        let sample_this = debug_timing && records_merged.is_multiple_of(merge_sample_interval);
        if sample_this {
            let t0 = Instant::now();
            writer.write_raw_record(&records[winner])?;
            merge_write_secs += t0.elapsed().as_secs_f64();
        } else {
            writer.write_raw_record(&records[winner])?;
        }
        records_merged += 1;
        merge_progress.log_if_needed(1);
        if merge_probe.should_sample(records_merged) {
            let depths = pool.phase1_queue_depths();
            let consumer_stats = guard.consumer_mut().map(|c| c.probe_consumer_stats());
            merge_probe.log_mid_with_depths(depths, consumer_stats);
        }
        // Refill the winner's slot from its source, or retire the source.
        let src_idx = source_map[winner];
        if sample_this {
            let t0 = Instant::now();
            let next =
                sources[src_idx].next_record(&mut records[winner], guard.consumer_mut())?;
            merge_read_secs += t0.elapsed().as_secs_f64();
            let t0 = Instant::now();
            if let Some(key) = next {
                tree.replace_winner(key);
            } else {
                tree.remove_winner();
            }
            merge_tree_secs += t0.elapsed().as_secs_f64();
            samples_taken += 1;
        } else {
            let next =
                sources[src_idx].next_record(&mut records[winner], guard.consumer_mut())?;
            if let Some(key) = next {
                tree.replace_winner(key);
            } else {
                tree.remove_winner();
            }
        }
    }
    guard.deactivate();
    if debug_timing {
        // Scale the sampled timings up to estimates for the full run.
        let loop_total = loop_start.elapsed().as_secs_f64();
        #[allow(clippy::cast_precision_loss)]
        let scale =
            if samples_taken > 0 { records_merged as f64 / samples_taken as f64 } else { 1.0 };
        let est_write = merge_write_secs * scale;
        let est_read = merge_read_secs * scale;
        let est_tree = merge_tree_secs * scale;
        debug!(
            "Merge sub-phases (sampled {samples_taken}/{records_merged}, scale={scale:.1}x): write={est_write:.2}s read={est_read:.2}s tree={est_tree:.2}s total={loop_total:.2}s records={records_merged}"
        );
    }
    writer.finish()?;
    merge_progress.log_final();
    log_snapshot("phase2.end", 0);
    Ok(records_merged)
}
/// Test-only entry point exposing the generic merge.
///
/// Passes `total_records = 0`, which only affects the progress tracker's
/// total, not merge behavior.
#[cfg(test)]
pub fn merge_chunks_for_test<K: RawSortKey + Default + 'static>(
    &self,
    chunk_files: &[PathBuf],
    memory_chunks: Vec<Vec<(K, fgumi_raw_bam::RawRecord)>>,
    header: &Header,
    output: &Path,
    pool: &Arc<SortWorkerPool>,
) -> Result<u64> {
    self.merge_chunks_generic::<K>(chunk_files, memory_chunks, header, output, 0, pool)
}
/// K-way merge that writes through an indexing BAM writer and returns the
/// accumulated BAI index.
///
/// Unlike [`Self::merge_chunks_generic`], disk chunks are read by direct
/// readers (`pool_decompress = false`) rather than through the worker pool,
/// and no phase-2 consumer is attached (`next_record(..., None)`).
fn merge_chunks_with_index<K: RawSortKey + Default + 'static>(
    &self,
    chunk_files: &[PathBuf],
    memory_chunks: Vec<Vec<(K, fgumi_raw_bam::RawRecord)>>,
    header: &Header,
    output: &Path,
    total_records: u64,
    pool: &Arc<SortWorkerPool>,
) -> Result<noodles::bam::bai::Index> {
    use crate::bam_io::create_indexing_bam_writer;
    use crate::sort::loser_tree::LoserTree;
    let writer_threads = self.threads;
    let reader_concurrency: usize = 1;
    // `false` => direct (non-pool) readers for disk chunks.
    let mut sources = Self::build_chunk_sources::<K>(
        chunk_files,
        memory_chunks,
        reader_concurrency,
        false,
        pool,
    )?;
    let num_sources = sources.len();
    info!(
        "Merge thread budget (indexing): {writer_threads} writer + {reader_concurrency} reader + 1 main = {} total",
        writer_threads + reader_concurrency + 1
    );
    info!("Merging from {num_sources} sources (indexing)...");
    let output_header = self.create_output_header(header);
    // Prime the tree: one leading record per non-empty source; `source_map`
    // maps tree slots back to source indices.
    let mut initial_keys: Vec<K> = Vec::with_capacity(sources.len());
    let mut records: Vec<Vec<u8>> = Vec::with_capacity(sources.len());
    let mut source_map: Vec<usize> = Vec::with_capacity(sources.len());
    for (idx, source) in sources.iter_mut().enumerate() {
        let mut record = Vec::new();
        if let Some(key) = source.next_record(&mut record, None)? {
            initial_keys.push(key);
            records.push(record);
            source_map.push(idx);
        }
    }
    if initial_keys.is_empty() {
        // No records: still emit a valid empty BAM plus its (empty) index.
        let writer = create_indexing_bam_writer(
            output,
            &output_header,
            self.output_compression,
            writer_threads,
        )?;
        let index = writer.finish()?;
        info!("Merge complete: 0 records merged");
        return Ok(index);
    }
    let mut tree = LoserTree::new(initial_keys);
    let mut writer = create_indexing_bam_writer(
        output,
        &output_header,
        self.output_compression,
        writer_threads,
    )?;
    let merge_progress = ProgressTracker::new("Merged records")
        .with_interval(1_000_000)
        .with_total(total_records);
    // Standard loser-tree loop: emit the winner, refill its slot from the
    // same source, retire the source when it is exhausted.
    while tree.winner_is_active() {
        let winner = tree.winner();
        writer.write_raw_record(&records[winner])?;
        merge_progress.log_if_needed(1);
        let src_idx = source_map[winner];
        if let Some(key) = sources[src_idx].next_record(&mut records[winner], None)? {
            tree.replace_winner(key);
        } else {
            tree.remove_winner();
        }
    }
    let index = writer.finish()?;
    merge_progress.log_final();
    Ok(index)
}
/// Build the output header for this sorter's configured sort order
/// (delegates to the module-level helper, which sets the HD `SO` tag).
fn create_output_header(&self, header: &Header) -> Header {
    super::create_output_header(self.sort_order, header)
}
fn create_temp_dirs(&self) -> Result<(Vec<TempDir>, TmpDirAllocator)> {
use super::create_temp_dir;
if self.temp_dirs.is_empty() {
let td = create_temp_dir(None)?;
let base = td.path().to_path_buf();
let alloc = TmpDirAllocator::new(vec![base])?;
return Ok((vec![td], alloc));
}
let mut handles = Vec::with_capacity(self.temp_dirs.len());
let mut subdirs = Vec::with_capacity(self.temp_dirs.len());
for base in &self.temp_dirs {
let td = create_temp_dir(Some(base))?;
subdirs.push(td.path().to_path_buf());
handles.push(td);
}
let alloc = TmpDirAllocator::new(subdirs)?;
Ok((handles, alloc))
}
}
/// Extract a [`TemplateKey`] directly from raw BAM record bytes.
///
/// Reads the fixed-layout fields through `RawRecordView` and the template-
/// relevant aux tags (MI, RG, optional cell barcode, MC) without fully
/// decoding the record. The key orders a template's reads by the pair's
/// lower (earlier) unclipped 5' coordinate, so both mates of a pair produce
/// the same leading key fields.
///
/// * `lib_lookup` — maps the RG aux value to a library ordinal and hashes
///   read names.
/// * `cell_tag` — optional 2-byte aux tag holding the cell barcode.
/// * `cb_hasher` — hasher for the cell barcode (0 when absent).
#[must_use]
pub fn extract_template_key_inline(
    bam_bytes: &[u8],
    lib_lookup: &LibraryLookup,
    cell_tag: Option<&[u8; 2]>,
    cb_hasher: &ahash::RandomState,
) -> TemplateKey {
    use crate::sort::bam_fields;
    use bam_fields::{flags, mate_unclipped_5prime, unclipped_5prime_raw};
    let aux = bam_fields::extract_template_aux_tags(bam_bytes, cell_tag);
    let mi = aux.mi;
    let library = lib_lookup.ordinal_from_rg(aux.rg);
    // Cell barcode hash; 0 is the sentinel for "no barcode".
    let cb_hash = aux.cell.map_or(0u64, |cb_bytes| cb_hasher.hash_one(cb_bytes));
    let v = bam_fields::RawRecordView::new(bam_bytes);
    let tid = v.ref_id();
    let pos = v.pos();
    let l_read_name = v.l_read_name() as usize;
    let flag = v.flags();
    let mate_tid = v.mate_ref_id();
    let mate_pos = v.mate_pos();
    let is_unmapped = (flag & flags::UNMAPPED) != 0;
    let mate_unmapped = (flag & flags::MATE_UNMAPPED) != 0;
    let is_reverse = (flag & flags::REVERSE) != 0;
    let mate_reverse = (flag & flags::MATE_REVERSE) != 0;
    let is_paired = (flag & flags::PAIRED) != 0;
    // Read name starts at byte offset 32 in this raw layout; l_read_name
    // includes the trailing NUL, hence the -1.
    let name_len = l_read_name.saturating_sub(1);
    let name = if name_len > 0 && 32 + name_len <= bam_bytes.len() {
        &bam_bytes[32..32 + name_len]
    } else {
        &[]
    };
    let name_hash = lib_lookup.hash_name(name);
    if is_unmapped {
        if is_paired && !mate_unmapped {
            // Unmapped read with a mapped mate: key on the mate's unclipped
            // 5' position (using the MC aux tag when present), with sentinel
            // i32::MAX coordinates for this read's side.
            let mate_unclipped =
                aux.mc.map_or(mate_pos, |mc| mate_unclipped_5prime(mate_pos, mate_reverse, mc));
            return TemplateKey::new(
                mate_tid,
                mate_unclipped,
                mate_reverse,
                i32::MAX,
                i32::MAX,
                false,
                cb_hash,
                library,
                mi,
                name_hash,
                true,
            );
        }
        // Fully unmapped (or unpaired unmapped): name/barcode-only key.
        // 0x80 is the read-2 flag bit.
        let is_read2 = (flag & 0x80) != 0;
        return TemplateKey::unmapped(name_hash, cb_hash, is_read2);
    }
    // This read's unclipped 5' position, derived from raw CIGAR bytes.
    let this_pos = unclipped_5prime_raw(bam_bytes, pos, is_reverse);
    let mate_unclipped = if is_paired && !mate_unmapped {
        aux.mc.map_or(mate_pos, |mc| mate_unclipped_5prime(mate_pos, mate_reverse, mc))
    } else {
        mate_pos
    };
    // Canonicalize the pair so both mates produce the same (tid1, pos1)
    // leading fields: the lower-coordinate end goes first; ties break on
    // strand. `is_upper` records which end this read is.
    let (tid1, tid2, pos1, pos2, neg1, neg2, is_upper) = if is_paired && !mate_unmapped {
        let is_upper = (tid, this_pos) > (mate_tid, mate_unclipped)
            || ((tid, this_pos) == (mate_tid, mate_unclipped) && is_reverse);
        if is_upper {
            (mate_tid, tid, mate_unclipped, this_pos, mate_reverse, is_reverse, true)
        } else {
            (tid, mate_tid, this_pos, mate_unclipped, is_reverse, mate_reverse, false)
        }
    } else {
        // Single-end (or mate unmapped): sentinel second coordinate.
        (tid, i32::MAX, this_pos, i32::MAX, is_reverse, false, false)
    };
    TemplateKey::new(tid1, pos1, neg1, tid2, pos2, neg2, cb_hash, library, mi, name_hash, is_upper)
}
pub use super::SortStats as RawSortStats;
#[cfg(test)]
mod tests {
use super::*;
use crate::sam::builder::SamBuilder;
use bstr::BString;
use noodles::sam::header::record::value::Map;
use noodles::sam::header::record::value::map::ReadGroup;
/// A header with no read groups yields an empty ordinal map.
#[test]
fn test_library_lookup_empty_header() {
    let lookup = LibraryLookup::from_header(&Header::builder().build());
    assert!(lookup.rg_to_ordinal.is_empty());
}
/// A single read group is assigned ordinal 1.
#[test]
fn test_library_lookup_single_rg() {
    let rg = Map::<ReadGroup>::builder()
        .insert(rg_tag::LIBRARY, String::from("LibA"))
        .build()
        .expect("valid");
    let header = Header::builder().add_read_group(BString::from("rg1"), rg).build();
    let lookup = LibraryLookup::from_header(&header);
    assert_eq!(lookup.rg_to_ordinal.len(), 1);
    let ordinal =
        *lookup.rg_to_ordinal.get(b"rg1".as_slice()).expect("rg1 should be in ordinal map");
    assert_eq!(ordinal, 1);
}
/// Ordinals are assigned by library name order, not read-group id order:
/// rg2 (LibA) -> 1, rg3 (LibB) -> 2, rg1 (LibC) -> 3.
#[test]
fn test_library_lookup_multiple_libraries() {
    let make_rg = |lib: &str| {
        Map::<ReadGroup>::builder()
            .insert(rg_tag::LIBRARY, String::from(lib))
            .build()
            .expect("valid")
    };
    let header = Header::builder()
        .add_read_group(BString::from("rg1"), make_rg("LibC"))
        .add_read_group(BString::from("rg2"), make_rg("LibA"))
        .add_read_group(BString::from("rg3"), make_rg("LibB"))
        .build();
    let lookup = LibraryLookup::from_header(&header);
    assert_eq!(lookup.rg_to_ordinal.len(), 3);
    assert_eq!(*lookup.rg_to_ordinal.get(b"rg2".as_slice()).expect("rg2"), 1);
    assert_eq!(*lookup.rg_to_ordinal.get(b"rg3".as_slice()).expect("rg3"), 2);
    assert_eq!(*lookup.rg_to_ordinal.get(b"rg1".as_slice()).expect("rg1"), 3);
}
/// A record whose RG is not in the header falls back to ordinal 0.
#[test]
fn test_library_lookup_unknown_rg_returns_zero() {
    let rg = Map::<ReadGroup>::builder()
        .insert(rg_tag::LIBRARY, String::from("LibA"))
        .build()
        .expect("valid");
    let header = Header::builder().add_read_group(BString::from("rg1"), rg).build();
    let lookup = LibraryLookup::from_header(&header);
    // Minimal raw record: l_read_name = 4, NUL-terminated name at offset 32,
    // and no RG aux field at all.
    let mut bam = vec![0u8; 36];
    bam[8] = 4;
    bam[32..36].copy_from_slice(b"rea\0");
    assert_eq!(lookup.get_ordinal(&bam), 0);
}
/// A freshly constructed sorter carries the documented defaults.
#[test]
fn test_raw_sorter_defaults() {
    let sorter = RawExternalSorter::new(SortOrder::Coordinate);
    // Memory and temp-file defaults.
    assert_eq!(sorter.memory_limit, 512 * 1024 * 1024);
    assert!(sorter.temp_dirs.is_empty());
    assert_eq!(sorter.max_temp_files, DEFAULT_MAX_TEMP_FILES);
    // Threading and compression defaults.
    assert_eq!(sorter.threads, 1);
    assert_eq!(sorter.output_compression, 6);
    assert_eq!(sorter.temp_compression, 1);
    // Optional features default to off/absent.
    assert!(!sorter.write_index);
    assert!(sorter.pg_info.is_none());
}
/// Every builder setter is reflected in the resulting configuration.
#[test]
fn test_raw_sorter_builder_chain() {
    let sorter = RawExternalSorter::new(SortOrder::Queryname(QuerynameComparator::default()))
        .memory_limit(1024)
        .temp_dir(PathBuf::from("/tmp/test"))
        .threads(8)
        .output_compression(9)
        .temp_compression(3)
        .write_index(true)
        .pg_info("1.0".to_string(), "fgumi sort".to_string())
        .max_temp_files(128);
    let expected_pg = Some(("1.0".to_string(), "fgumi sort".to_string()));
    assert_eq!(sorter.memory_limit, 1024);
    assert_eq!(sorter.temp_dirs, vec![PathBuf::from("/tmp/test")]);
    assert_eq!(sorter.threads, 8);
    assert_eq!(sorter.output_compression, 9);
    assert_eq!(sorter.temp_compression, 3);
    assert_eq!(sorter.max_temp_files, 128);
    assert_eq!(sorter.pg_info, expected_pg);
    assert!(sorter.write_index);
}
#[test]
fn test_raw_sorter_memory_limit() {
    // 256 MiB — distinct from the 512 MiB default — to prove the setter sticks.
    let limit = 256 * 1024 * 1024;
    let configured = RawExternalSorter::new(SortOrder::Coordinate).memory_limit(limit);
    assert_eq!(configured.memory_limit, limit);
}
#[test]
fn test_raw_sorter_temp_compression() {
    // Level 0 (store uncompressed) is the edge value for temp-chunk writing.
    let configured = RawExternalSorter::new(SortOrder::Coordinate).temp_compression(0);
    assert_eq!(configured.temp_compression, 0);
}
#[test]
fn test_raw_sorter_max_temp_files() {
    // Zero is accepted and stored as-is (no clamping back to the default).
    let configured = RawExternalSorter::new(SortOrder::Coordinate).max_temp_files(0);
    assert_eq!(configured.max_temp_files, 0);
}
#[test]
fn test_create_output_header_coordinate() {
    // An empty input header should gain an HD record with SO:coordinate.
    let input_header = Header::builder().build();
    let sorter = RawExternalSorter::new(SortOrder::Coordinate);
    let output_header = sorter.create_output_header(&input_header);
    let hd = output_header.header().expect("header should have HD record");
    let so = hd.other_fields().get(b"SO").expect("should have SO tag");
    assert_eq!(<_ as AsRef<[u8]>>::as_ref(so), b"coordinate");
}
#[test]
fn test_create_output_header_queryname() {
    // An empty input header should gain an HD record with SO:queryname.
    let input_header = Header::builder().build();
    let sorter = RawExternalSorter::new(SortOrder::Queryname(QuerynameComparator::default()));
    let output_header = sorter.create_output_header(&input_header);
    let hd = output_header.header().expect("header should have HD record");
    let so = hd.other_fields().get(b"SO").expect("should have SO tag");
    assert_eq!(<_ as AsRef<[u8]>>::as_ref(so), b"queryname");
}
#[test]
// Template-coordinate is not a SAM-spec SO value, so the output header is
// tagged SO:unsorted with GO:query plus the SS:template-coordinate sub-sort.
fn test_create_output_header_template_coordinate() {
let sorter = RawExternalSorter::new(SortOrder::TemplateCoordinate);
let header = Header::builder().build();
let output_header = sorter.create_output_header(&header);
let hd = output_header.header().expect("header should have HD record");
let fields = hd.other_fields();
let so = fields.get(b"SO").expect("should have SO tag");
assert_eq!(<_ as AsRef<[u8]>>::as_ref(so), b"unsorted");
let go = fields.get(b"GO").expect("should have GO tag");
assert_eq!(<_ as AsRef<[u8]>>::as_ref(go), b"query");
let ss = fields.get(b"SS").expect("should have SS tag");
assert_eq!(<_ as AsRef<[u8]>>::as_ref(ss), b"template-coordinate");
}
fn build_bam_with_aux(aux_data: &[u8]) -> Vec<u8> {
    // 32-byte zeroed fixed-length section, then a 4-byte read name
    // ("rea" + NUL, so l_read_name = 4), then the caller's raw aux bytes.
    let mut record = vec![0u8; 32];
    record[8] = 4;
    record.extend_from_slice(b"rea\0");
    record.extend_from_slice(aux_data);
    record
}
// `find_string` over a record's aux block: an RG:Z tag yields its
// NUL-stripped value; an empty aux block yields None.
#[rstest::rstest]
#[case::present(b"RGZgroup1\0".as_slice(), Some(b"group1".as_slice()))]
#[case::absent(b"".as_slice(), None)]
fn test_find_rg_tag(#[case] aux_data: &[u8], #[case] expected: Option<&[u8]>) {
let bam = build_bam_with_aux(aux_data);
assert_eq!(
fgumi_raw_bam::RawRecordView::new(&bam).tags().find_string(&SamTag::RG),
expected
);
}
#[test]
fn test_find_rg_tag_after_other_tags() {
    // Put an unrelated i32 tag (XY:i:42) ahead of RG to prove the scanner
    // skips non-matching tags instead of only inspecting the first one.
    let mut aux_data = Vec::new();
    aux_data.extend_from_slice(b"XYi");
    aux_data.extend_from_slice(&42i32.to_le_bytes());
    aux_data.extend_from_slice(b"RGZmygroup\0");
    let record = build_bam_with_aux(&aux_data);
    let found = fgumi_raw_bam::RawRecordView::new(&record).tags().find_string(&SamTag::RG);
    assert_eq!(found, Some(b"mygroup".as_slice()));
}
#[test]
fn test_raw_sorter_cell_tag_default_is_none() {
    // No cell-barcode tag is configured unless the builder sets one.
    let fresh = RawExternalSorter::new(SortOrder::TemplateCoordinate);
    assert!(fresh.cell_tag.is_none());
}
#[test]
fn test_raw_sorter_cell_tag_builder() {
    // The cell_tag setter stores the tag wrapped in Some.
    let configured =
        RawExternalSorter::new(SortOrder::TemplateCoordinate).cell_tag(*SamTag::CB);
    assert_eq!(configured.cell_tag, Some(*SamTag::CB));
}
// Thin wrapper over this module's `cb_hasher()` so tests hash CB values with
// the same hasher construction the sorter uses; the determinism test below
// relies on successive calls yielding identically-seeded hashers.
fn test_cb_hasher() -> ahash::RandomState {
cb_hasher()
}
/// Builds a minimal mapped raw BAM record with `tid`/`pos` for both the read
/// and its mate, flag 0x3 in the flag field (offset 14..16), followed by the
/// NUL-terminated `name` and the raw `aux` bytes.
///
/// # Panics
/// Panics if `name` is 255 bytes or longer: `l_read_name` (name length
/// including the NUL) must fit in a `u8`. The previous `as u8` cast silently
/// truncated in that case, producing a corrupt record.
fn build_mapped_bam(tid: i32, pos: i32, name: &[u8], aux: &[u8]) -> Vec<u8> {
    let l_read_name =
        u8::try_from(name.len() + 1).expect("read name too long for l_read_name (u8)");
    let mut bam = vec![0u8; 32];
    bam[0..4].copy_from_slice(&tid.to_le_bytes()); // ref id
    bam[4..8].copy_from_slice(&pos.to_le_bytes()); // position
    bam[8] = l_read_name;
    bam[14..16].copy_from_slice(&3u16.to_le_bytes()); // flag field
    bam[20..24].copy_from_slice(&tid.to_le_bytes()); // mate ref id (same as read)
    bam[24..28].copy_from_slice(&pos.to_le_bytes()); // mate position (same as read)
    bam.extend_from_slice(name);
    bam.push(0); // NUL terminator for the read name
    bam.extend_from_slice(aux);
    bam
}
fn cb_aux(value: &[u8]) -> Vec<u8> {
    // Encode a CB:Z aux field: two-byte tag "CB", type 'Z', NUL-terminated value.
    let mut encoded = Vec::with_capacity(3 + value.len() + 1);
    encoded.extend_from_slice(b"CBZ");
    encoded.extend_from_slice(value);
    encoded.push(0);
    encoded
}
#[test]
fn test_extract_template_key_cb_present_has_nonzero_hash() {
    // A record carrying a CB:Z aux tag should hash it into the key.
    let lib_lookup = LibraryLookup::from_header(&Header::builder().build());
    let record = build_mapped_bam(0, 100, b"read1", &cb_aux(b"ACGTACGT"));
    let key = extract_template_key_inline(&record, &lib_lookup, Some(b"CB"), &test_cb_hasher());
    assert_ne!(key.cb_hash, 0, "CB present should produce non-zero cb_hash");
}
#[test]
fn test_extract_template_key_cb_absent_has_zero_hash() {
    // The record has no aux data at all, so the CB tag cannot be found.
    let lib_lookup = LibraryLookup::from_header(&Header::builder().build());
    let record = build_mapped_bam(0, 100, b"read1", &[]);
    let key = extract_template_key_inline(&record, &lib_lookup, Some(b"CB"), &test_cb_hasher());
    assert_eq!(key.cb_hash, 0, "missing CB tag should produce cb_hash=0");
}
#[test]
fn test_extract_template_key_cell_tag_none_has_zero_hash() {
    // CB is present in the record, but extraction is disabled via cell_tag=None.
    let lib_lookup = LibraryLookup::from_header(&Header::builder().build());
    let record = build_mapped_bam(0, 100, b"read1", &cb_aux(b"ACGTACGT"));
    let key = extract_template_key_inline(&record, &lib_lookup, None, &test_cb_hasher());
    assert_eq!(key.cb_hash, 0, "cell_tag=None should produce cb_hash=0");
}
#[test]
fn test_extract_template_key_different_cb_values_differ() {
    // Same position and name, different CB values: the hashes must differ.
    let lib_lookup = LibraryLookup::from_header(&Header::builder().build());
    let hasher = test_cb_hasher();
    let hash_of = |cb: &[u8]| {
        let record = build_mapped_bam(0, 100, b"read1", &cb_aux(cb));
        extract_template_key_inline(&record, &lib_lookup, Some(b"CB"), &hasher).cb_hash
    };
    assert_ne!(
        hash_of(b"ACGTACGT"),
        hash_of(b"TGCATGCA"),
        "different CB values should produce different hashes"
    );
}
#[test]
fn test_extract_template_key_cb_hash_is_deterministic() {
    // Two extractions of the same record (with freshly built hashers) must agree.
    let lib_lookup = LibraryLookup::from_header(&Header::builder().build());
    let record = build_mapped_bam(0, 100, b"read1", &cb_aux(b"ACGTACGT"));
    let first =
        extract_template_key_inline(&record, &lib_lookup, Some(b"CB"), &test_cb_hasher());
    let second =
        extract_template_key_inline(&record, &lib_lookup, Some(b"CB"), &test_cb_hasher());
    assert_eq!(first.cb_hash, second.cb_hash, "same input should produce same cb_hash");
}
#[test]
fn test_extract_template_key_unmapped_with_cb() {
    // Hand-build a fully unmapped record: flag 0x000D in the flag field,
    // ref ids and positions all -1, name "read1", plus a CB:Z aux tag.
    let lib_lookup = LibraryLookup::from_header(&Header::builder().build());
    let mut record = vec![0u8; 32];
    record[8] = 6; // l_read_name: "read1" + NUL
    record[14..16].copy_from_slice(&0x000Du16.to_le_bytes());
    let minus_one = (-1i32).to_le_bytes();
    record[0..4].copy_from_slice(&minus_one);
    record[4..8].copy_from_slice(&minus_one);
    record[20..24].copy_from_slice(&minus_one);
    record[24..28].copy_from_slice(&minus_one);
    record.extend_from_slice(b"read1\0");
    record.extend_from_slice(&cb_aux(b"ACGTACGT"));
    let key = extract_template_key_inline(&record, &lib_lookup, Some(b"CB"), &test_cb_hasher());
    assert_ne!(key.cb_hash, 0, "unmapped read with CB should have non-zero cb_hash");
    assert_eq!(key.primary, u64::MAX, "unmapped both-mates should have MAX primary");
}
fn count_bam_records(path: &std::path::Path) -> u64 {
    // Stream the BAM with a single-threaded raw reader and count the records.
    use crate::sort::read_ahead::RawReadAheadReader;
    let (raw_reader, _header) =
        create_raw_bam_reader(path, 1).expect("failed to create raw BAM reader");
    let total = RawReadAheadReader::new(raw_reader).count();
    total as u64
}
// End-to-end sort with a tiny (1 KiB) memory limit and max_temp_files(4) so
// many spill chunks are written and the temp-file budget forces
// consolidation; verifies no records are lost across every sort order, and
// (per-case) that index writing works alongside consolidation.
#[rstest::rstest]
#[case::coordinate(SortOrder::Coordinate, false)]
#[case::coordinate_with_index(SortOrder::Coordinate, true)]
#[case::queryname(SortOrder::Queryname(QuerynameComparator::default()), false)]
#[case::queryname_natural(SortOrder::Queryname(QuerynameComparator::Natural), false)]
#[case::template_coordinate(SortOrder::TemplateCoordinate, false)]
fn test_sort_with_consolidation_preserves_all_records(
#[case] sort_order: SortOrder,
#[case] write_index: bool,
) {
use crate::sam::builder::SamBuilder;
let num_pairs = 30;
let mut builder = SamBuilder::new();
// Pairs at non-overlapping positions so coordinate order is well defined.
for i in 0..num_pairs {
let _ = builder
.add_pair()
.name(&format!("read{i}"))
.start1(i * 200 + 1)
.start2(i * 200 + 101)
.build();
}
let dir = tempfile::tempdir().expect("failed to create temp directory");
let input = dir.path().join("input.bam");
let output = dir.path().join("output.bam");
builder.write_bam(&input).expect("failed to write BAM");
let stats = RawExternalSorter::new(sort_order)
.memory_limit(1024) .max_temp_files(4)
.temp_compression(0)
.output_compression(0)
.write_index(write_index)
.sort(&input, &output)
.expect("sort should succeed");
// >= 5 chunks with a cap of 4 means at least one consolidation must have
// happened, exercising post-consolidation chunk naming.
assert!(
stats.chunks_written >= 5,
"expected at least 5 chunks to exercise post-consolidation naming, got {}",
stats.chunks_written
);
let expected = (num_pairs * 2) as u64;
let observed = count_bam_records(&output);
assert_eq!(observed, expected, "chunk filename collision likely lost data");
}
// Sort 400 records with a small (32 KiB) memory limit, two threads, and
// max_temp_files(0) to force many spill chunks through the merge path; the
// test name indicates the merge is semaphore-capped — verifies no records
// are lost under that cap for every sort order.
#[rstest::rstest]
#[case::coordinate(SortOrder::Coordinate)]
#[case::queryname(SortOrder::Queryname(QuerynameComparator::default()))]
#[case::queryname_natural(SortOrder::Queryname(QuerynameComparator::Natural))]
#[case::template_coordinate(SortOrder::TemplateCoordinate)]
fn test_sort_many_chunks_with_semaphore(#[case] sort_order: SortOrder) {
use crate::sam::builder::SamBuilder;
let num_pairs = 200;
let mut builder = SamBuilder::new();
for i in 0..num_pairs {
let _ = builder
.add_pair()
.name(&format!("read{i}"))
.start1(i * 200 + 1)
.start2(i * 200 + 101)
.build();
}
let dir = tempfile::tempdir().expect("failed to create temp directory");
let input = dir.path().join("input.bam");
let output = dir.path().join("output.bam");
builder.write_bam(&input).expect("failed to write BAM");
let stats = RawExternalSorter::new(sort_order)
.memory_limit(32 * 1024)
.max_temp_files(0) .threads(2) .temp_compression(0)
.output_compression(0)
.sort(&input, &output)
.expect("sort should succeed");
assert!(
stats.chunks_written >= 2,
"expected multiple chunks to exercise merge, got {}",
stats.chunks_written
);
let expected = (num_pairs * 2) as u64;
let observed = count_bam_records(&output);
assert_eq!(observed, expected, "semaphore-capped merge lost data");
}
// Sorts the same input once with 1 thread and once with 2 threads (both with
// a 16 KiB memory limit so spilling occurs) and requires byte-for-byte
// identical record order, proving parallel sub-array sorting is
// deterministic relative to the single-threaded path.
#[rstest::rstest]
#[case::coordinate(SortOrder::Coordinate)]
#[case::queryname(SortOrder::Queryname(QuerynameComparator::default()))]
#[case::queryname_natural(SortOrder::Queryname(QuerynameComparator::Natural))]
#[case::template_coordinate(SortOrder::TemplateCoordinate)]
fn test_sort_sub_arrays_match_single_thread(#[case] sort_order: SortOrder) {
use crate::sam::builder::SamBuilder;
let num_pairs = 50;
let mut builder = SamBuilder::new();
for i in 0..num_pairs {
let _ = builder
.add_pair()
.name(&format!("read{i}"))
.start1(i * 200 + 1)
.start2(i * 200 + 101)
.build();
}
let dir = tempfile::tempdir().expect("failed to create temp directory");
let input = dir.path().join("input.bam");
let output_st = dir.path().join("output_1t.bam");
let output_mt = dir.path().join("output_2t.bam");
builder.write_bam(&input).expect("failed to write BAM");
RawExternalSorter::new(sort_order)
.memory_limit(16 * 1024) .threads(1)
.temp_compression(0)
.output_compression(0)
.sort(&input, &output_st)
.expect("sort should succeed");
RawExternalSorter::new(sort_order)
.memory_limit(16 * 1024)
.threads(2)
.output_compression(0)
.temp_compression(0)
.sort(&input, &output_mt)
.expect("sort should succeed")
}
// Multi-threaded sort with a 10 MiB memory limit — large enough (per the
// test name) that the 40 records should stay in memory and never spill;
// verifies the in-memory parallel path preserves all records.
#[rstest::rstest]
#[case::coordinate(SortOrder::Coordinate)]
#[case::queryname(SortOrder::Queryname(QuerynameComparator::default()))]
#[case::queryname_natural(SortOrder::Queryname(QuerynameComparator::Natural))]
#[case::template_coordinate(SortOrder::TemplateCoordinate)]
fn test_sort_sub_arrays_in_memory_only(#[case] sort_order: SortOrder) {
use crate::sam::builder::SamBuilder;
let num_pairs = 20;
let mut builder = SamBuilder::new();
for i in 0..num_pairs {
let _ = builder
.add_pair()
.name(&format!("read{i}"))
.start1(i * 200 + 1)
.start2(i * 200 + 101)
.build();
}
let dir = tempfile::tempdir().expect("failed to create temp directory");
let input = dir.path().join("input.bam");
let output = dir.path().join("output.bam");
builder.write_bam(&input).expect("failed to write BAM");
RawExternalSorter::new(sort_order)
.memory_limit(10 * 1024 * 1024)
.threads(2)
.output_compression(0)
.sort(&input, &output)
.expect("sort should succeed");
let expected = (num_pairs * 2) as u64;
let observed = count_bam_records(&output);
assert_eq!(observed, expected, "sort lost data");
}
// Builds `num_pairs` read pairs named "{prefix}_readNNNN" (zero-padded, so
// lexical order matches numeric order) at positions shifted by
// `start_offset`, writes them unsorted, sorts with `sort_order`, and returns
// the sorted BAM path plus the generated read names.
fn create_sorted_bam(
dir: &Path,
prefix: &str,
num_pairs: usize,
start_offset: usize,
sort_order: SortOrder,
) -> (PathBuf, Vec<String>) {
let mut builder = SamBuilder::new();
let mut names = Vec::with_capacity(num_pairs);
for i in 0..num_pairs {
let name = format!("{prefix}_read{i:04}");
names.push(name.clone());
let _ = builder
.add_pair()
.name(&name)
.start1((start_offset + i * 200) + 1)
.start2((start_offset + i * 200) + 101)
.build();
}
let unsorted = dir.join(format!("{prefix}_unsorted.bam"));
let sorted = dir.join(format!("{prefix}_sorted.bam"));
builder.write_bam(&unsorted).expect("failed to write BAM");
RawExternalSorter::new(sort_order)
.output_compression(0)
.sort(&unsorted, &sorted)
.expect("sort should succeed");
(sorted, names)
}
fn collect_read_names(path: &Path) -> Vec<String> {
    // Read names in file order; panics if any name is not valid UTF-8.
    use crate::sort::read_ahead::RawReadAheadReader;
    let (raw_reader, _header) =
        create_raw_bam_reader(path, 1).expect("failed to create raw BAM reader");
    let mut names = Vec::new();
    for rec in RawReadAheadReader::new(raw_reader) {
        let view = fgumi_raw_bam::RawRecordView::new(rec.as_ref());
        let name = String::from_utf8(view.read_name().to_vec())
            .expect("read name should be valid UTF-8");
        names.push(name);
    }
    names
}
fn collect_positions(path: &Path) -> Vec<(i32, i32)> {
    // (ref_id, pos) pairs in file order, for checking coordinate ordering.
    use crate::sort::read_ahead::RawReadAheadReader;
    let (raw_reader, _header) =
        create_raw_bam_reader(path, 1).expect("failed to create raw BAM reader");
    let mut positions = Vec::new();
    for rec in RawReadAheadReader::new(raw_reader) {
        let view = fgumi_raw_bam::RawRecordView::new(rec.as_ref());
        positions.push((view.ref_id(), view.pos()));
    }
    positions
}
// Header passed to merge_bams in the tests below: the merge inputs are all
// produced via `create_sorted_bam`, which builds with `SamBuilder::new()`,
// so the builder's default header is the shared one.
fn default_merge_header() -> Header {
SamBuilder::new().header.clone()
}
#[test]
// Merges two coordinate-sorted inputs (position ranges offset by 10 kb) and
// checks both the returned count and on-disk count, plus non-decreasing
// (ref_id, pos) across the merged output.
fn test_merge_bams_coordinate_sort() {
let dir = tempfile::tempdir().expect("failed to create temp directory");
let (bam_a, _) = create_sorted_bam(dir.path(), "a", 10, 0, SortOrder::Coordinate);
let (bam_b, _) = create_sorted_bam(dir.path(), "b", 10, 10_000, SortOrder::Coordinate);
let merged = dir.path().join("merged.bam");
let header = default_merge_header();
let count = RawExternalSorter::new(SortOrder::Coordinate)
.output_compression(0)
.merge_bams(&[bam_a, bam_b], &header, &merged)
.expect("sort should succeed");
assert_eq!(count, 40);
assert_eq!(count_bam_records(&merged), 40);
let positions = collect_positions(&merged);
// Tuple comparison: ref_id first, then pos.
for w in positions.windows(2) {
assert!(w[0] <= w[1], "coordinate sort violated: {:?} > {:?}", w[0], w[1]);
}
}
#[test]
// Template-coordinate merge of two sorted inputs; only record counts are
// asserted here (ordering for this mode is not easily checked from raw
// positions alone).
fn test_merge_bams_template_coordinate_sort() {
let dir = tempfile::tempdir().expect("failed to create temp directory");
let (bam_a, _) = create_sorted_bam(dir.path(), "a", 10, 0, SortOrder::TemplateCoordinate);
let (bam_b, _) =
create_sorted_bam(dir.path(), "b", 10, 10_000, SortOrder::TemplateCoordinate);
let merged = dir.path().join("merged.bam");
let header = default_merge_header();
let count = RawExternalSorter::new(SortOrder::TemplateCoordinate)
.output_compression(0)
.merge_bams(&[bam_a, bam_b], &header, &merged)
.expect("sort should succeed");
assert_eq!(count, 40);
assert_eq!(count_bam_records(&merged), 40);
}
#[test]
// Merges two queryname-sorted inputs and verifies counts plus non-decreasing
// read names in the merged output (names are zero-padded, so lexical
// comparison is the right check).
fn test_merge_bams_queryname_sort() {
let dir = tempfile::tempdir().expect("failed to create temp directory");
let (bam_a, _) = create_sorted_bam(
dir.path(),
"a",
10,
0,
SortOrder::Queryname(QuerynameComparator::default()),
);
let (bam_b, _) = create_sorted_bam(
dir.path(),
"b",
10,
10_000,
SortOrder::Queryname(QuerynameComparator::default()),
);
let merged = dir.path().join("merged.bam");
let header = default_merge_header();
let count = RawExternalSorter::new(SortOrder::Queryname(QuerynameComparator::default()))
.output_compression(0)
.merge_bams(&[bam_a, bam_b], &header, &merged)
.expect("sort should succeed");
assert_eq!(count, 40);
assert_eq!(count_bam_records(&merged), 40);
let names = collect_read_names(&merged);
for w in names.windows(2) {
assert!(w[0] <= w[1], "queryname sort violated: {:?} > {:?}", w[0], w[1]);
}
}
#[test]
fn test_merge_bams_single_input() {
    // Degenerate merge: one input of 15 pairs should pass all 30 records through.
    let dir = tempfile::tempdir().expect("failed to create temp directory");
    let (only_input, _) = create_sorted_bam(dir.path(), "a", 15, 0, SortOrder::Coordinate);
    let merged = dir.path().join("merged.bam");
    let header = default_merge_header();
    let count = RawExternalSorter::new(SortOrder::Coordinate)
        .output_compression(0)
        .merge_bams(&[only_input], &header, &merged)
        .expect("sort should succeed");
    assert_eq!(count, 30);
    assert_eq!(count_bam_records(&merged), 30);
}
#[test]
// Set-membership check: every read name generated for either input must
// appear in the merged output — catches dropped records independent of order.
fn test_merge_bams_preserves_all_records() {
let dir = tempfile::tempdir().expect("failed to create temp directory");
let (bam_a, names_a) = create_sorted_bam(
dir.path(),
"a",
5,
0,
SortOrder::Queryname(QuerynameComparator::default()),
);
let (bam_b, names_b) = create_sorted_bam(
dir.path(),
"b",
5,
10_000,
SortOrder::Queryname(QuerynameComparator::default()),
);
let merged = dir.path().join("merged.bam");
let header = default_merge_header();
RawExternalSorter::new(SortOrder::Queryname(QuerynameComparator::default()))
.output_compression(0)
.merge_bams(&[bam_a, bam_b], &header, &merged)
.expect("sort should succeed");
let merged_names: std::collections::HashSet<String> =
collect_read_names(&merged).into_iter().collect();
for name in names_a.iter().chain(names_b.iter()) {
assert!(merged_names.contains(name), "read name {name:?} missing from merged output");
}
}
#[test]
// k-way merge with 8 inputs whose position ranges are disjoint (50 kb apart);
// verifies the total count and global coordinate order of the result.
fn test_merge_bams_many_inputs() {
let dir = tempfile::tempdir().expect("failed to create temp directory");
let k = 8;
let pairs_per_input = 5;
let mut inputs = Vec::with_capacity(k);
for i in 0..k {
let (bam, _) = create_sorted_bam(
dir.path(),
&format!("in{i}"),
pairs_per_input,
i * 50_000,
SortOrder::Coordinate,
);
inputs.push(bam);
}
let merged = dir.path().join("merged.bam");
let header = default_merge_header();
let count = RawExternalSorter::new(SortOrder::Coordinate)
.output_compression(0)
.merge_bams(&inputs, &header, &merged)
.expect("sort should succeed");
let expected = (k * pairs_per_input * 2) as u64; assert_eq!(count, expected);
assert_eq!(count_bam_records(&merged), expected);
let positions = collect_positions(&merged);
for w in positions.windows(2) {
assert!(w[0] <= w[1], "coordinate sort violated with k={k}: {:?} > {:?}", w[0], w[1]);
}
}
#[test]
// Merge under the natural queryname comparator. The lexical `<=` check below
// is valid here because create_sorted_bam zero-pads the numeric suffix
// ("{prefix}_read{i:04}"), so lexical and natural order coincide.
fn test_merge_bams_queryname_natural_sort() {
let dir = tempfile::tempdir().expect("failed to create temp directory");
let nat = SortOrder::Queryname(QuerynameComparator::Natural);
let (bam_a, _) = create_sorted_bam(dir.path(), "a", 10, 0, nat);
let (bam_b, _) = create_sorted_bam(dir.path(), "b", 10, 10_000, nat);
let merged = dir.path().join("merged.bam");
let header = default_merge_header();
let count = RawExternalSorter::new(nat)
.output_compression(0)
.merge_bams(&[bam_a, bam_b], &header, &merged)
.expect("merge should succeed");
assert_eq!(count, 40);
assert_eq!(count_bam_records(&merged), 40);
let names = collect_read_names(&merged);
for w in names.windows(2) {
assert!(w[0] <= w[1], "natural queryname sort violated: {:?} > {:?}", w[0], w[1]);
}
}
#[test]
// Exercises every SortPhaseTimer accessor/timing method in sequence and
// checks the counters/accumulators they maintain.
fn test_sort_phase_timer_all_methods() {
let mut timer = SortPhaseTimer::new();
assert!(timer.overall_start.is_some());
assert!(timer.read_span_start.is_some());
let elapsed = timer.end_read_span();
assert!(timer.read_secs >= 0.0);
assert!(timer.read_span_start.is_none());
let _ = elapsed;
// Second end without a begin: span is already taken, so ZERO is returned.
let elapsed2 = timer.end_read_span();
assert_eq!(elapsed2, std::time::Duration::ZERO);
assert!(timer.read_secs >= 0.0);
timer.begin_read_span();
assert!(timer.read_span_start.is_some());
timer.time_sort(|| {});
assert!(timer.sort_secs >= 0.0);
// time_spill_write passes the closure's Result through and bumps spill_count.
let result = timer.time_spill_write(|| Ok::<u32, anyhow::Error>(42));
assert_eq!(result.unwrap(), 42);
assert_eq!(timer.spill_count, 1);
assert!(timer.spill_write_secs >= 0.0);
let dir = tempfile::tempdir().expect("tempdir");
let path = dir.path().join("spill.bin");
std::fs::write(&path, b"hello world").expect("write");
timer.record_spill_size(&path);
assert_eq!(timer.total_spill_bytes, 11);
// A missing file must not change the accumulated byte count.
timer.record_spill_size(&dir.path().join("nonexistent.bin"));
assert_eq!(timer.total_spill_bytes, 11);
// NOTE(review): a near-instant consolidate is apparently NOT counted (count
// stays 0), while the ~10 ms one below is — looks like time_consolidate only
// counts work above some duration threshold; confirm in its implementation.
timer.time_consolidate(|| Ok(())).expect("consolidate ok");
assert_eq!(timer.consolidate_count, 0);
timer
.time_consolidate(|| {
std::thread::sleep(std::time::Duration::from_millis(10));
Ok(())
})
.expect("consolidate ok");
assert_eq!(timer.consolidate_count, 1);
assert!(timer.consolidate_secs > 0.0);
timer.time_merge(|| Ok::<(), anyhow::Error>(())).expect("merge ok");
assert!(timer.merge_secs >= 0.0);
timer.time_write_output(|| Ok(())).expect("write ok");
assert!(timer.write_output_secs >= 0.0);
// Smoke-check that summary logging with a thread count does not panic.
timer.log_summary(4);
}
#[test]
#[allow(clippy::cast_possible_truncation)]
// Verifies that re-chunking a Vec via repeated `split_off` from the tail
// reproduces exactly the chunk boundaries rayon's `par_chunks_mut` used:
// each chunk was sorted in place by rayon, so if the boundaries line up,
// every reassembled chunk must still be sorted.
fn test_split_off_aligns_with_par_chunks_mut() {
use rayon::prelude::*;
let pool =
rayon::ThreadPoolBuilder::new().num_threads(4).build().expect("build rayon pool");
// (n, threads) pairs chosen so n divides evenly, leaves a short tail, etc.
for &(n, threads) in &[(10, 3), (11, 4), (17, 3), (100, 7), (1000, 8), (13, 5)] {
let mut entries: Vec<(u64, Vec<u8>)> =
(0..n).rev().map(|i| (i as u64, vec![i as u8])).collect();
let chunk_size = entries.len().div_ceil(threads);
// Sort each chunk_size-sized front-to-back chunk in parallel.
pool.install(|| {
entries.par_chunks_mut(chunk_size).for_each(|chunk| {
chunk.sort_unstable_by(|a, b| a.0.cmp(&b.0));
});
});
let mut remaining = std::mem::take(&mut entries);
let num_chunks = remaining.len().div_ceil(chunk_size);
let mut chunks: Vec<Vec<(u64, Vec<u8>)>> = Vec::with_capacity(num_chunks);
// par_chunks_mut puts the short chunk LAST, so split the tail off first...
let tail_len = remaining.len() % chunk_size;
if tail_len != 0 {
let split_at = remaining.len() - tail_len;
chunks.push(remaining.split_off(split_at));
}
// ...then peel full chunk_size chunks off the back until nothing is left.
while !remaining.is_empty() {
let split_at = remaining.len().saturating_sub(chunk_size);
chunks.push(remaining.split_off(split_at));
}
// Chunks were collected back-to-front; restore front-to-back order.
chunks.reverse();
for (ci, chunk) in chunks.iter().enumerate() {
for i in 1..chunk.len() {
assert!(
chunk[i - 1].0 <= chunk[i].0,
"n={n} threads={threads}: chunk {ci} not sorted at index {i} \
({} > {})",
chunk[i - 1].0,
chunk[i].0,
);
}
}
let total: usize = chunks.iter().map(Vec::len).sum();
assert_eq!(total, n, "n={n} threads={threads}: total mismatch");
}
}
#[test]
// Sorts the same input with spill chunks spread across two temp directories
// and with the default single directory (8 KiB memory limit so spilling is
// guaranteed), then requires identical record order from both runs.
fn test_sort_with_two_temp_dirs_matches_single_dir() {
use crate::sam::builder::SamBuilder;
let num_pairs = 200;
let mut builder = SamBuilder::new();
for i in 0..num_pairs {
let _ = builder
.add_pair()
.name(&format!("read{i:05}"))
.start1(i * 200 + 1)
.start2(i * 200 + 101)
.build();
}
let workdir = tempfile::tempdir().expect("workdir");
let input = workdir.path().join("input.bam");
let output_multi = workdir.path().join("output_multi.bam");
let output_single = workdir.path().join("output_single.bam");
builder.write_bam(&input).expect("write bam");
let tmp_a = tempfile::tempdir().expect("tmp a");
let tmp_b = tempfile::tempdir().expect("tmp b");
let stats_multi = RawExternalSorter::new(SortOrder::Coordinate)
.memory_limit(8 * 1024)
.threads(1)
.temp_compression(0)
.output_compression(0)
.temp_dirs(vec![tmp_a.path().to_path_buf(), tmp_b.path().to_path_buf()])
.sort(&input, &output_multi)
.expect("multi-dir sort should succeed");
// Without at least two spill chunks the multi-dir path is not exercised.
assert!(stats_multi.chunks_written >= 2, "expected multiple spill chunks");
RawExternalSorter::new(SortOrder::Coordinate)
.memory_limit(8 * 1024)
.threads(1)
.temp_compression(0)
.output_compression(0)
.sort(&input, &output_single)
.expect("single-dir sort should succeed");
let names_multi = collect_read_names(&output_multi);
let names_single = collect_read_names(&output_single);
assert_eq!(names_multi.len(), num_pairs * 2, "record count mismatch");
assert_eq!(
names_multi, names_single,
"multi-dir and single-dir sort produced different record orders"
);
}
}