use crossbeam_channel::{Receiver, Sender, bounded};
use std::thread::{self, JoinHandle};
use crate::bam_io::RawBamReaderAuto;
use fgumi_raw_bam::{RawBamReader, RawRecord};
use std::io::Read as IoRead;
/// Number of records grouped into one message on the read-ahead channel.
const BATCH_SIZE: usize = 256;
/// Capacity of the bounded channel, measured in batches (so the thread may
/// prefetch up to `BATCH_SIZE * CHANNEL_BUFFER_SIZE` records ahead).
const CHANNEL_BUFFER_SIZE: usize = 16;
/// Wraps a raw BAM reader with a background thread that reads ahead,
/// delivering records in batches over a bounded channel.
pub struct RawReadAheadReader {
    // Receiving side of the batch channel; taken (set to `None`) in `Drop`
    // so the channel disconnects and the producer thread can exit.
    receiver: Option<Receiver<Vec<RawRecord>>>,
    // Background reader thread, joined in `Drop`.
    handle: Option<JoinHandle<()>>,
    // Batch currently being consumed by `next_record`.
    current_batch: Vec<RawRecord>,
    // Index of the next unconsumed record in `current_batch`.
    batch_index: usize,
    // Slot where the reader thread stores the first I/O error it hits;
    // retrieved (and cleared) by `take_error`.
    error_slot: std::sync::Arc<std::sync::Mutex<Option<std::io::Error>>>,
}
impl RawReadAheadReader {
    /// Create a read-ahead reader with the default batch and channel sizes.
    #[must_use]
    pub fn new(reader: RawBamReaderAuto) -> Self {
        Self::from_reader(reader)
    }

    /// Create a read-ahead reader over any `Send` byte source, using
    /// `BATCH_SIZE` / `CHANNEL_BUFFER_SIZE` defaults.
    #[must_use]
    pub(crate) fn from_reader<R: IoRead + Send + 'static>(reader: RawBamReader<R>) -> Self {
        Self::from_reader_with_batch_size(reader, BATCH_SIZE, CHANNEL_BUFFER_SIZE)
    }

    /// Spawn the background reader thread and connect it to the consumer
    /// through a bounded channel.
    ///
    /// `batch_size` is the number of records per channel message;
    /// `channel_buffer` is the channel capacity in batches.
    #[must_use]
    pub(crate) fn from_reader_with_batch_size<R: IoRead + Send + 'static>(
        mut reader: RawBamReader<R>,
        batch_size: usize,
        channel_buffer: usize,
    ) -> Self {
        // Shared slot so the consumer can retrieve an error raised on the thread.
        let error_slot = std::sync::Arc::new(std::sync::Mutex::new(None));
        let error_slot_thread = std::sync::Arc::clone(&error_slot);
        // Bounded channel applies backpressure to the reader thread.
        let (tx, rx) = bounded(channel_buffer);
        let handle = thread::spawn(move || {
            Self::reader_thread_generic(&mut reader, tx, batch_size, error_slot_thread);
        });
        Self {
            receiver: Some(rx),
            handle: Some(handle),
            current_batch: Vec::new(),
            batch_index: 0,
            error_slot,
        }
    }

    /// Body of the background thread: read records, group them into batches,
    /// and send them downstream. An empty batch acts as the end-of-stream
    /// sentinel, sent on both clean EOF and error.
    #[allow(clippy::needless_pass_by_value)] // tx is consumed so the channel disconnects when this returns
    fn reader_thread_generic<R: IoRead>(
        reader: &mut RawBamReader<R>,
        tx: Sender<Vec<RawRecord>>,
        batch_size: usize,
        error_slot: std::sync::Arc<std::sync::Mutex<Option<std::io::Error>>>,
    ) {
        let mut record = RawRecord::new();
        let mut batch = Vec::with_capacity(batch_size);
        loop {
            match reader.read_record(&mut record) {
                // Ok(0) signals end of input.
                Ok(0) => {
                    // Flush the partial batch, then send the empty sentinel.
                    if !batch.is_empty() {
                        let _ = tx.send(batch);
                    }
                    let _ = tx.send(Vec::new());
                    break;
                }
                Ok(_) => {
                    // `take` moves the record out, leaving a fresh default
                    // value to be filled by the next read.
                    batch.push(std::mem::take(&mut record));
                    if batch.len() >= batch_size {
                        // A send error means the consumer hung up; stop reading.
                        if tx.send(batch).is_err() {
                            break;
                        }
                        batch = Vec::with_capacity(batch_size);
                    }
                }
                Err(e) => {
                    log::error!("Error reading raw BAM record: {e}");
                    // Record the error BEFORE sending the sentinel, so a
                    // consumer that sees end-of-stream can already observe it
                    // via `take_error`.
                    if let Ok(mut slot) = error_slot.lock() {
                        *slot = Some(e);
                    }
                    // Deliver any records successfully read before the failure.
                    if !batch.is_empty() {
                        let _ = tx.send(batch);
                    }
                    let _ = tx.send(Vec::new());
                    break;
                }
            }
        }
    }

    /// Take the error recorded by the reader thread, if any (clears the slot).
    /// Also returns `None` if the mutex was poisoned.
    #[must_use]
    pub fn take_error(&self) -> Option<std::io::Error> {
        self.error_slot.lock().ok()?.take()
    }

    /// Pop the next record, pulling a fresh batch from the channel when the
    /// current one is exhausted. Returns `None` at end of stream — which may
    /// be due to an error; check `take_error` afterwards.
    #[inline]
    #[must_use]
    pub fn next_record(&mut self) -> Option<RawRecord> {
        // Fast path: serve from the batch already in memory.
        if self.batch_index < self.current_batch.len() {
            let record = std::mem::take(&mut self.current_batch[self.batch_index]);
            self.batch_index += 1;
            return Some(record);
        }
        let receiver = self.receiver.as_ref()?;
        match receiver.recv() {
            Ok(batch) if !batch.is_empty() => {
                self.current_batch = batch;
                // Hand out element 0 now; resume from index 1 on the next call.
                self.batch_index = 1; Some(std::mem::take(&mut self.current_batch[0]))
            }
            // Empty sentinel batch, or channel disconnected: end of stream.
            Ok(_) | Err(_) => None, }
    }
}
impl Drop for RawReadAheadReader {
    /// Shut down the background reader thread.
    fn drop(&mut self) {
        // Close our end of the channel first: the producer's next `send`
        // fails, letting the background thread exit its loop.
        self.receiver = None;
        // Then wait for the thread; a join error (thread panic) is ignored.
        if let Some(worker) = self.handle.take() {
            let _ = worker.join();
        }
    }
}
impl Iterator for RawReadAheadReader {
    type Item = RawRecord;

    /// Delegates to [`RawReadAheadReader::next_record`].
    fn next(&mut self) -> Option<Self::Item> {
        Self::next_record(self)
    }
}
/// A `Read` adapter over decompressed blocks that arrive out of order on a
/// shared queue, restored to their original order via serial numbers.
pub struct PooledInputStream {
    // Queue of (serial, decompressed bytes) pairs filled by the producer side
    // (presumably a decompression worker pool — not visible here).
    decompressed_input: std::sync::Arc<crossbeam_queue::ArrayQueue<(u64, Vec<u8>)>>,
    // Set by the producer once no more blocks will be queued.
    decompressed_input_done: std::sync::Arc<std::sync::atomic::AtomicBool>,
    // Set when reading the input (upstream I/O) failed; details go to the log.
    input_read_error: std::sync::Arc<std::sync::atomic::AtomicBool>,
    // Set when BGZF decompression of a block failed; details go to the log.
    decompression_error: std::sync::Arc<std::sync::atomic::AtomicBool>,
    // Holds out-of-order blocks until the next expected serial arrives.
    reorder: crate::reorder_buffer::ReorderBuffer<Vec<u8>>,
    // Block currently being served byte-wise to `read`.
    current_buf: Vec<u8>,
    // Read position within `current_buf`.
    current_pos: usize,
}
impl PooledInputStream {
    /// Build a stream over the shared queue and status flags; all four are
    /// shared with the producer side.
    #[must_use]
    pub fn new(
        decompressed_input: std::sync::Arc<crossbeam_queue::ArrayQueue<(u64, Vec<u8>)>>,
        decompressed_input_done: std::sync::Arc<std::sync::atomic::AtomicBool>,
        input_read_error: std::sync::Arc<std::sync::atomic::AtomicBool>,
        decompression_error: std::sync::Arc<std::sync::atomic::AtomicBool>,
    ) -> Self {
        Self {
            decompressed_input,
            decompressed_input_done,
            input_read_error,
            decompression_error,
            reorder: crate::reorder_buffer::ReorderBuffer::new(),
            current_buf: Vec::new(),
            current_pos: 0,
        }
    }

    /// True once the producer has signalled completion AND the queue has been
    /// fully drained. (The reorder buffer may still hold poppable blocks.)
    fn is_eof(&self) -> bool {
        self.decompressed_input_done.load(std::sync::atomic::Ordering::Acquire)
            && self.decompressed_input.is_empty()
    }

    /// Move queued blocks into the reorder buffer, stopping early when the
    /// next in-order block is already available and the buffer holds at least
    /// twice the queue capacity — bounds consumer-side memory.
    fn drain_queue(&mut self) {
        let reorder_cap = self.decompressed_input.capacity() * 2;
        loop {
            if self.reorder.can_pop() && self.reorder.buffer_len() >= reorder_cap {
                break;
            }
            match self.decompressed_input.pop() {
                Some((serial, data)) => self.reorder.insert(serial, data),
                None => break,
            }
        }
    }

    /// Block until the next in-order decompressed block is available.
    /// Returns `None` on end of input or when an error flag is raised.
    fn next_block(&mut self) -> Option<Vec<u8>> {
        loop {
            self.drain_queue();
            if let Some(data) = self.reorder.try_pop_next() {
                return Some(data);
            }
            if self.is_eof() {
                // Drain once more: blocks may have landed between the failed
                // pop above and the EOF check.
                self.drain_queue();
                if let Some(data) = self.reorder.try_pop_next() {
                    return Some(data);
                }
                return None;
            }
            // Wait for the producer to wake us. NOTE(review): this relies on
            // the producer calling `Thread::unpark` after each push / on
            // completion — confirm on the producer side; `park` may also
            // return spuriously, which the surrounding loop tolerates.
            std::thread::park();
            if self.input_read_error.load(std::sync::atomic::Ordering::Acquire)
                || self.decompression_error.load(std::sync::atomic::Ordering::Acquire)
            {
                return None;
            }
        }
    }
}
/// Build the error to surface to the caller when a producer-side failure flag
/// has been set; `None` when neither flag is raised.
fn pending_stream_error(
    input_read_error: &std::sync::atomic::AtomicBool,
    decompression_error: &std::sync::atomic::AtomicBool,
) -> Option<std::io::Error> {
    if input_read_error.load(std::sync::atomic::Ordering::Acquire) {
        return Some(std::io::Error::other(
            "I/O error reading input BAM blocks (see log for details)",
        ));
    }
    if decompression_error.load(std::sync::atomic::Ordering::Acquire) {
        return Some(std::io::Error::other(
            "BGZF decompression error on input blocks (see log for details)",
        ));
    }
    None
}

impl IoRead for PooledInputStream {
    /// Serve bytes from the current block, fetching further in-order blocks
    /// from the reorder machinery as needed.
    ///
    /// # Errors
    /// Returns an error when a producer-side read or decompression failure
    /// has been flagged.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // Fast path: leftover bytes from a previously fetched block.
        if self.current_pos < self.current_buf.len() {
            let available = &self.current_buf[self.current_pos..];
            let n = available.len().min(buf.len());
            buf[..n].copy_from_slice(&available[..n]);
            self.current_pos += n;
            return Ok(n);
        }
        // Report producer failures before attempting to fetch more data.
        if let Some(e) = pending_stream_error(&self.input_read_error, &self.decompression_error) {
            return Err(e);
        }
        // Fetch blocks until one actually carries data. Skipping empty blocks
        // (e.g. a zero-length decompressed BGZF block) is the bug fix: the
        // previous code returned Ok(0) for them, which the `Read` contract
        // defines as end of stream, truncating the input prematurely.
        loop {
            let Some(data) = self.next_block() else {
                // No more blocks: distinguish error from clean EOF.
                if let Some(e) =
                    pending_stream_error(&self.input_read_error, &self.decompression_error)
                {
                    return Err(e);
                }
                return Ok(0);
            };
            if data.is_empty() {
                continue;
            }
            let n = data.len().min(buf.len());
            buf[..n].copy_from_slice(&data[..n]);
            if n < data.len() {
                // Keep the unread tail of the block for the next call.
                self.current_buf = data;
                self.current_pos = n;
            } else {
                self.current_buf.clear();
                self.current_pos = 0;
            }
            return Ok(n);
        }
    }
}
/// Where raw BAM records come from: a threaded read-ahead reader, or a
/// direct (inline) reader over a `PooledInputStream`.
pub enum RecordSource {
    /// Records prefetched in batches on a background thread.
    ReadAhead(RawReadAheadReader),
    /// Records read inline; the second field stores the first read error
    /// encountered during iteration.
    Direct(RawBamReader<PooledInputStream>, Option<std::io::Error>),
}
impl RecordSource {
#[must_use]
pub fn direct(reader: RawBamReader<PooledInputStream>) -> Self {
Self::Direct(reader, None)
}
pub fn take_error(&mut self) -> Option<std::io::Error> {
match self {
Self::Direct(_, err) => err.take(),
Self::ReadAhead(r) => r.take_error(),
}
}
}
impl Iterator for RecordSource {
type Item = RawRecord;
fn next(&mut self) -> Option<Self::Item> {
match self {
Self::ReadAhead(r) => r.next(),
Self::Direct(reader, error_slot) => {
let mut record = RawRecord::default();
match reader.read_record(&mut record) {
Ok(0) => None, Ok(_) => Some(record),
Err(e) => {
log::error!("Error reading raw BAM record: {e}");
if error_slot.is_none() {
*error_slot = Some(e);
}
None
}
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bam_io::create_raw_bam_reader;
    use noodles::sam::Header;
    use noodles::sam::alignment::RecordBuf;
    use noodles::sam::alignment::io::Write as AlignmentWrite;
    use noodles::sam::alignment::record::Flags;
    use noodles::sam::header::record::value::Map;
    use noodles::sam::header::record::value::map::ReferenceSequence;
    use std::num::NonZeroUsize;
    use tempfile::NamedTempFile;

    /// Write `num_records` unmapped reads to a temporary BAM file and return
    /// the temp-file handle together with the header that was written.
    fn create_test_bam_file(num_records: usize) -> (NamedTempFile, Header) {
        let reference =
            Map::<ReferenceSequence>::new(NonZeroUsize::new(1000).expect("1000 is non-zero"));
        let header = Header::builder().add_reference_sequence("chr1", reference).build();
        let tmp = NamedTempFile::new().expect("creating temp file/dir should succeed");
        {
            let file = std::fs::File::create(tmp.path()).expect("creating file should succeed");
            let mut writer = noodles::bam::io::Writer::new(file);
            writer.write_header(&header).expect("writing header should succeed");
            for idx in 0..num_records {
                let name = format!("read{idx}");
                let rec = RecordBuf::builder()
                    .set_name(&*name)
                    .set_flags(Flags::UNMAPPED)
                    .build();
                writer
                    .write_alignment_record(&header, &rec)
                    .expect("writing alignment record should succeed");
            }
        }
        (tmp, header)
    }

    #[test]
    fn test_default_batch_config() {
        assert_eq!(BATCH_SIZE, 256);
        assert_eq!(CHANNEL_BUFFER_SIZE, 16);
        assert_eq!(BATCH_SIZE * CHANNEL_BUFFER_SIZE, 4096);
    }

    #[test]
    fn test_prefetch_total_records() {
        let total = BATCH_SIZE * CHANNEL_BUFFER_SIZE;
        assert_eq!(total, 4096, "Total prefetch should be BATCH_SIZE * CHANNEL_BUFFER_SIZE = 4096");
    }

    #[test]
    fn test_raw_read_ahead_empty() {
        let (tmp, _header) = create_test_bam_file(0);
        let (reader, _header) =
            create_raw_bam_reader(tmp.path(), 1).expect("creating BAM reader should succeed");
        let mut ahead = RawReadAheadReader::new(reader);
        assert!(ahead.next_record().is_none(), "Empty BAM should yield no raw records");
    }

    #[test]
    fn test_raw_read_ahead_multiple() {
        let num = 10;
        let (tmp, _header) = create_test_bam_file(num);
        let (reader, _header) =
            create_raw_bam_reader(tmp.path(), 1).expect("creating BAM reader should succeed");
        let ahead = RawReadAheadReader::new(reader);
        let records: Vec<RawRecord> = ahead.collect();
        assert_eq!(records.len(), num, "Raw read-ahead should yield exactly {num} records");
    }
}
}