use std::{fs, io, path::Path, sync::Arc, thread};
use memmap2::Mmap;
use zstd::{stream::copy_decode, zstd_safe};
use crate::{
BinseqRecord, ParallelProcessor, ParallelReader, Result,
cbq::core::{
BlockHeader, BlockRange, ColumnarBlock, FileHeader, Index, IndexFooter, IndexHeader,
},
};
/// Streaming CBQ reader over any [`io::Read`] source.
///
/// Blocks are consumed sequentially via `read_block`; once the index section
/// is reached, `read_index` recovers the file's [`Index`].
pub struct Reader<R: io::Read> {
    // Underlying byte stream; positioned just past the file header after `new`.
    inner: R,
    /// Columnar block storage refilled by each successful `read_block` call.
    pub block: ColumnarBlock,
    // Set when `read_block` encounters the index header; consumed by `read_index`.
    iheader: Option<IndexHeader>,
}
impl<R: io::Read> Reader<R> {
    /// Creates a streaming reader, consuming and validating the file header.
    ///
    /// # Errors
    /// Fails if the header bytes cannot be read or do not parse as a
    /// [`FileHeader`].
    pub fn new(mut inner: R) -> Result<Self> {
        let mut header_buf = [0u8; size_of::<FileHeader>()];
        inner.read_exact(&mut header_buf)?;
        let header = FileHeader::from_bytes(&header_buf)?;
        Ok(Self {
            inner,
            block: ColumnarBlock::new(header),
            iheader: None,
        })
    }

    /// Sets the quality score substituted when records carry no stored qualities.
    pub fn set_default_quality_score(&mut self, score: u8) {
        self.block.set_default_quality_score(score);
    }

    /// Reads the next block into `self.block`, returning its header.
    ///
    /// Returns `Ok(None)` on clean EOF, or when the index section is reached
    /// (the index header is stashed for a later `read_index` call).
    ///
    /// # Errors
    /// Fails on input truncated mid-header, or if the block header cannot be
    /// parsed or the block payload cannot be read.
    pub fn read_block(&mut self) -> Result<Option<BlockHeader>> {
        // Block and index sections are distinguished by parsing the first
        // `IndexHeader`-sized prefix, so read only that much up front.
        let mut iheader_buf = [0u8; size_of::<IndexHeader>()];
        match self.inner.read_exact(&mut iheader_buf) {
            Ok(()) => {}
            // Clean EOF at a section boundary simply means "no more blocks".
            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
            Err(e) => return Err(e.into()),
        }
        // If the prefix parses as an index header, the block stream is done.
        if let Ok(iheader) = IndexHeader::from_bytes(&iheader_buf) {
            self.iheader = Some(iheader);
            return Ok(None);
        }
        // Otherwise read the remainder of the (larger) block header and
        // reassemble the full header bytes.
        let mut diff_buf = [0u8; size_of::<BlockHeader>() - size_of::<IndexHeader>()];
        self.inner.read_exact(&mut diff_buf)?;
        let mut header_buf = [0u8; size_of::<BlockHeader>()];
        header_buf[..iheader_buf.len()].copy_from_slice(&iheader_buf);
        header_buf[iheader_buf.len()..].copy_from_slice(&diff_buf);
        let header = BlockHeader::from_bytes(&header_buf)?;
        self.block.read_from(&mut self.inner, header)?;
        Ok(Some(header))
    }

    /// Reads and decompresses the file index, if `read_block` has reached it.
    ///
    /// Returns `Ok(None)` when no index header has been seen yet.
    ///
    /// # Errors
    /// Fails on truncated input, or if decompression or parsing of the index
    /// or its footer fails.
    pub fn read_index(&mut self) -> Result<Option<Index>> {
        let Some(header) = self.iheader else {
            return Ok(None);
        };
        let mut z_index_buf = vec![0u8; header.z_bytes as usize];
        self.inner.read_exact(&mut z_index_buf)?;
        let mut index_buf = Vec::new();
        copy_decode(z_index_buf.as_slice(), &mut index_buf)?;
        let index = Index::from_bytes(&index_buf)?;
        // Validate the trailing footer even though its contents are unused here.
        let mut footer_buf = [0u8; size_of::<IndexFooter>()];
        self.inner.read_exact(&mut footer_buf)?;
        let _footer = IndexFooter::from_bytes(&footer_buf)?;
        Ok(Some(index))
    }
}
/// Random-access CBQ reader backed by a memory-mapped file.
///
/// Cheap to clone: the map and index are shared via `Arc`, while each clone
/// carries its own scratch block and decompression context.
pub struct MmapReader {
    // Shared read-only memory map of the entire file.
    inner: Arc<Mmap>,
    // Decoded file index, shared across clones.
    index: Arc<Index>,
    // Per-reader scratch storage for the currently loaded block.
    block: ColumnarBlock,
    // Per-reader zstd decompression context (each clone creates its own fresh one).
    dctx: zstd_safe::DCtx<'static>,
}
impl Clone for MmapReader {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
index: self.index.clone(),
block: self.block.clone(),
dctx: zstd_safe::DCtx::create(),
}
}
}
impl MmapReader {
    /// Memory-maps the CBQ file at `path` and eagerly decodes its index.
    ///
    /// # Errors
    /// Fails if the file cannot be opened or mapped, is too small to hold the
    /// mandatory header and footer, or if the header, footer, or index fail
    /// to parse.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = fs::File::open(path)?;
        // The mapping is `unsafe` because the file could be mutated externally
        // while mapped; we only ever treat the contents as read-only bytes.
        let inner = unsafe { Mmap::map(&file) }?;
        // Reject files too small for a header plus an index footer: without
        // this guard the slicing/subtraction below would panic instead of
        // returning an error for malformed input.
        if inner.len() < size_of::<FileHeader>() + size_of::<IndexFooter>() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "file too small to contain a CBQ header and index footer",
            )
            .into());
        }
        let header = FileHeader::from_bytes(&inner[..size_of::<FileHeader>()])?;
        let index = {
            let footer_start = inner.len() - size_of::<IndexFooter>();
            let mut footer_buf = [0u8; size_of::<IndexFooter>()];
            footer_buf.copy_from_slice(&inner[footer_start..]);
            let index_footer = IndexFooter::from_bytes(&footer_buf)?;
            // Guard against a corrupt footer claiming more index bytes than
            // the file contains (would otherwise underflow and panic).
            let z_index_start = footer_start
                .checked_sub(index_footer.bytes as usize)
                .ok_or_else(|| {
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        "index footer reports more bytes than the file contains",
                    )
                })?;
            let z_index_slice = &inner[z_index_start..footer_start];
            let mut index_buf = Vec::default();
            copy_decode(z_index_slice, &mut index_buf)?;
            Index::from_bytes(&index_buf)
        }?;
        Ok(Self {
            inner: Arc::new(inner),
            index: Arc::new(index),
            block: ColumnarBlock::new(header),
            dctx: zstd_safe::DCtx::create(),
        })
    }

    /// Sets the quality score substituted when records carry no stored qualities.
    pub fn set_default_quality_score(&mut self, score: u8) {
        self.block.set_default_quality_score(score);
    }

    /// Returns a copy of the file header.
    #[must_use]
    pub fn header(&self) -> FileHeader {
        self.block.header
    }

    /// Whether the file stores paired records (per the file header flags).
    #[must_use]
    pub fn is_paired(&self) -> bool {
        self.block.header.is_paired()
    }

    /// Total number of records reported by the index.
    #[must_use]
    pub fn num_records(&self) -> usize {
        self.index.num_records()
    }

    /// Total number of blocks reported by the index.
    #[must_use]
    pub fn num_blocks(&self) -> usize {
        self.index.num_blocks()
    }

    /// Borrow the decoded file index.
    #[must_use]
    pub fn index(&self) -> &Index {
        &self.index
    }

    /// Decompresses the block described by `range` into `self.block`.
    ///
    /// Note: offsets come from the already-validated index; a corrupt index
    /// could still cause an out-of-bounds slice panic here.
    fn load_block(&mut self, range: BlockRange) -> Result<()> {
        let header_start = range.offset as usize;
        let header_end = size_of::<BlockHeader>() + header_start;
        let block_header = {
            let mut block_header_buf = [0u8; size_of::<BlockHeader>()];
            block_header_buf.copy_from_slice(&self.inner[header_start..header_end]);
            BlockHeader::from_bytes(&block_header_buf)
        }?;
        let data_end = header_end + block_header.block_len();
        let block_data_slice = &self.inner[header_end..data_end];
        self.block
            .decompress_from_bytes(block_data_slice, block_header, &mut self.dctx)?;
        Ok(())
    }

    /// Iterates the parsed headers of every block, in index order, without
    /// decompressing any block payloads.
    pub fn iter_block_headers(&self) -> impl Iterator<Item = Result<BlockHeader>> {
        self.index.iter_blocks().map(|range| {
            let mut block_header_buf = [0u8; size_of::<BlockHeader>()];
            block_header_buf.copy_from_slice(
                &self.inner
                    [range.offset as usize..range.offset as usize + size_of::<BlockHeader>()],
            );
            BlockHeader::from_bytes(&block_header_buf)
        })
    }
}
impl ParallelReader for MmapReader {
    /// Processes every record in the file across `num_threads` worker threads.
    fn process_parallel<P: ParallelProcessor + Clone + 'static>(
        self,
        processor: P,
        num_threads: usize,
    ) -> crate::Result<()> {
        let num_records = self.num_records();
        self.process_parallel_range(processor, num_threads, 0..num_records)
    }

    /// Processes the records whose global indices fall in `range`, splitting
    /// the overlapping blocks as evenly as possible across worker threads.
    ///
    /// `num_threads == 0` means "use all available CPUs"; larger requests are
    /// capped at the CPU count.
    ///
    /// # Errors
    /// Fails if `range` is invalid for this file, if any block fails to load,
    /// or if the processor returns an error on any thread.
    fn process_parallel_range<P: ParallelProcessor + Clone + 'static>(
        self,
        processor: P,
        num_threads: usize,
        range: std::ops::Range<usize>,
    ) -> crate::Result<()> {
        let num_threads = if num_threads == 0 {
            num_cpus::get()
        } else {
            num_threads.min(num_cpus::get())
        };
        let total_records = self.num_records();
        self.validate_range(total_records, &range)?;

        // Keep only blocks whose half-open record interval intersects `range`.
        let mut iv_start = 0;
        let relevant_blocks = self
            .index
            .iter_blocks()
            .filter(|block| {
                let iv_end = block.cumulative_records as usize;
                // Strict `<` on the upper bound: a block starting exactly at
                // `range.end` contributes no requested records, so loading
                // (and decompressing) it would be wasted work.
                let relevant = iv_start < range.end && iv_end > range.start;
                iv_start = iv_end;
                relevant
            })
            .collect::<Vec<_>>();
        if relevant_blocks.is_empty() {
            return Ok(());
        }
        let num_blocks = relevant_blocks.len();

        // Even distribution: the first `extra_blocks` threads take one block
        // more than the rest.
        let base_blocks_per_thread = num_blocks / num_threads;
        let extra_blocks = num_blocks % num_threads;
        let mut handles = Vec::new();
        for thread_id in 0..num_threads {
            let blocks_for_this_thread = if thread_id < extra_blocks {
                base_blocks_per_thread + 1
            } else {
                base_blocks_per_thread
            };
            if blocks_for_this_thread == 0 {
                continue;
            }
            let start_block_idx = if thread_id < extra_blocks {
                thread_id * (base_blocks_per_thread + 1)
            } else {
                extra_blocks * (base_blocks_per_thread + 1)
                    + (thread_id - extra_blocks) * base_blocks_per_thread
            };
            let end_block_idx = start_block_idx + blocks_for_this_thread;
            // Each worker gets its own reader clone (own scratch block and
            // decompression context) plus its contiguous slice of blocks.
            let mut t_reader = self.clone();
            let mut t_proc = processor.clone();
            let t_block_ranges = relevant_blocks[start_block_idx..end_block_idx].to_vec();
            let thread_handle = thread::spawn(move || -> crate::Result<()> {
                for b_range in t_block_ranges {
                    t_reader.load_block(b_range)?;
                    for record in t_reader.block.iter_records(b_range) {
                        // Boundary blocks may straddle `range`; filter per record.
                        let global_record_idx = record.index() as usize;
                        if global_record_idx >= range.start && global_record_idx < range.end {
                            t_proc.process_record(record)?;
                        }
                    }
                    t_proc.on_batch_complete()?;
                }
                Ok(())
            });
            handles.push(thread_handle);
        }
        // Propagate the first processor/I-O error; a panicked worker re-panics
        // here via the `unwrap` on `join`.
        for handle in handles {
            handle.join().unwrap()?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BinseqRecord;

    /// Checked-in fixture exercised by every test in this module.
    const TEST_CBQ_FILE: &str = "./data/subset.cbq";

    /// Test processor that tallies every record it is handed.
    #[derive(Clone)]
    struct CbqCountingProcessor {
        count: Arc<std::sync::Mutex<usize>>,
    }

    impl ParallelProcessor for CbqCountingProcessor {
        fn process_record<R: BinseqRecord>(&mut self, _record: R) -> Result<()> {
            *self.count.lock().unwrap() += 1;
            Ok(())
        }
    }

    // Builds a counting processor along with a handle to its shared counter.
    fn counting_processor() -> (CbqCountingProcessor, Arc<std::sync::Mutex<usize>>) {
        let count = Arc::new(std::sync::Mutex::new(0));
        let processor = CbqCountingProcessor {
            count: count.clone(),
        };
        (processor, count)
    }

    #[test]
    fn test_mmap_reader_new() {
        assert!(
            MmapReader::new(TEST_CBQ_FILE).is_ok(),
            "Failed to create CBQ reader"
        );
    }

    #[test]
    fn test_mmap_reader_num_records() {
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        assert!(reader.num_records() > 0, "Expected non-zero records");
    }

    #[test]
    fn test_mmap_reader_is_paired() {
        // Smoke test: the call must succeed; either answer is acceptable.
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        let is_paired = reader.is_paired();
        assert!(is_paired || !is_paired);
    }

    #[test]
    fn test_mmap_reader_header_access() {
        let header = MmapReader::new(TEST_CBQ_FILE).unwrap().header();
        assert!(header.block_size > 0, "Expected non-zero block size");
    }

    #[test]
    fn test_mmap_reader_index_access() {
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        assert!(
            reader.index().num_records() > 0,
            "Index should have records"
        );
    }

    #[test]
    fn test_mmap_reader_num_blocks() {
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        assert!(reader.num_blocks() > 0, "Should have at least one block");
    }

    #[test]
    fn test_set_default_quality_score() {
        // Only verifies the setter is callable on a live reader.
        let mut reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        reader.set_default_quality_score(42u8);
    }

    #[test]
    fn test_parallel_processing() {
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        let expected = reader.num_records();
        let (processor, count) = counting_processor();
        reader.process_parallel(processor, 2).unwrap();
        assert_eq!(
            *count.lock().unwrap(),
            expected,
            "All records should be processed"
        );
    }

    #[test]
    fn test_parallel_processing_range() {
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        // Only meaningful when the fixture is large enough for the sub-range.
        if reader.num_records() >= 100 {
            let (start, end) = (10, 50);
            let expected_count = end - start;
            let (processor, count) = counting_processor();
            reader
                .process_parallel_range(processor, 2, start..end)
                .unwrap();
            assert_eq!(
                *count.lock().unwrap(),
                expected_count,
                "Should process exactly {} records",
                expected_count
            );
        }
    }

    #[test]
    fn test_parallel_processing_with_record_data() {
        // Asserts a basic per-record invariant while counting records.
        #[derive(Clone)]
        struct RecordValidator {
            valid_count: Arc<std::sync::Mutex<usize>>,
        }
        impl ParallelProcessor for RecordValidator {
            fn process_record<R: BinseqRecord>(&mut self, record: R) -> Result<()> {
                assert!(record.slen() > 0, "Record should have non-zero length");
                *self.valid_count.lock().unwrap() += 1;
                Ok(())
            }
        }
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        let expected = reader.num_records();
        let count = Arc::new(std::sync::Mutex::new(0));
        let validator = RecordValidator {
            valid_count: count.clone(),
        };
        reader.process_parallel(validator, 2).unwrap();
        assert_eq!(*count.lock().unwrap(), expected);
    }

    #[test]
    fn test_index_num_records() {
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        assert_eq!(
            reader.index().num_records(),
            reader.num_records(),
            "Index and reader should report same number of records"
        );
    }

    #[test]
    fn test_index_num_blocks() {
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        assert!(
            reader.index().num_blocks() > 0,
            "Should have at least one block"
        );
    }

    #[test]
    fn test_index_iter_blocks() {
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        let collected: Vec<_> = reader.index().iter_blocks().collect();
        assert!(!collected.is_empty(), "Should have at least one block");
        assert_eq!(
            collected.len(),
            reader.num_blocks(),
            "Block count should match"
        );
    }

    #[test]
    fn test_nonexistent_file() {
        assert!(
            MmapReader::new("./data/nonexistent.cbq").is_err(),
            "Should fail on nonexistent file"
        );
    }

    #[test]
    fn test_invalid_file_format() {
        assert!(
            MmapReader::new("./Cargo.toml").is_err(),
            "Should fail on invalid file format"
        );
    }

    #[test]
    fn test_iter_block_headers() {
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        let headers = reader
            .iter_block_headers()
            .take(5)
            .collect::<Result<Vec<_>>>()
            .unwrap();
        assert!(!headers.is_empty(), "Should have at least one block header");
        for header in &headers {
            assert!(header.num_records > 0, "Block should have records");
        }
    }

    #[test]
    fn test_iter_block_headers_count() {
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        let headers = reader
            .iter_block_headers()
            .collect::<Result<Vec<_>>>()
            .unwrap();
        assert_eq!(
            headers.len(),
            reader.num_blocks(),
            "Should iterate all block headers"
        );
    }

    #[test]
    fn test_parallel_processing_empty_range() {
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        let (processor, count) = counting_processor();
        reader.process_parallel_range(processor, 2, 0..0).unwrap();
        assert_eq!(
            *count.lock().unwrap(),
            0,
            "Empty range should process no records"
        );
    }

    #[test]
    fn test_parallel_processing_invalid_range() {
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        let n = reader.num_records();
        let (processor, _count) = counting_processor();
        let result = reader.process_parallel_range(processor, 2, n + 100..n + 200);
        assert!(result.is_err(), "Should handle out of bounds as error");
    }

    #[test]
    fn test_parallel_processing_single_thread() {
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        let expected = reader.num_records();
        let (processor, count) = counting_processor();
        reader.process_parallel(processor, 1).unwrap();
        assert_eq!(*count.lock().unwrap(), expected);
    }

    #[test]
    fn test_parallel_processing_many_threads() {
        let reader = MmapReader::new(TEST_CBQ_FILE).unwrap();
        let expected = reader.num_records();
        let (processor, count) = counting_processor();
        reader.process_parallel(processor, 8).unwrap();
        assert_eq!(*count.lock().unwrap(), expected);
    }
}