use std::fs::File;
use std::io::Read;
use std::ops::Range;
use std::path::Path;
use std::sync::Arc;
use bitnuc::BitSize;
use bytemuck::cast_slice;
use memmap2::Mmap;
use super::header::{FileHeader, SIZE_HEADER};
use crate::{
BinseqRecord, DEFAULT_QUALITY_SCORE, Error, ParallelProcessor, ParallelReader,
error::{ReadError, Result},
};
#[derive(Clone, Copy)]
pub struct RefRecord<'a> {
id: u64,
buffer: &'a [u64],
qbuf: &'a [u8],
config: RecordConfig,
header_buf: [u8; 20],
header_len: usize,
}
impl<'a> RefRecord<'a> {
#[must_use]
pub fn new(id: u64, buffer: &'a [u64], qbuf: &'a [u8], config: RecordConfig) -> Self {
assert_eq!(buffer.len(), config.record_size_u64());
Self {
id,
buffer,
qbuf,
config,
header_buf: [0; 20],
header_len: 0,
}
}
#[must_use]
pub fn config(&self) -> RecordConfig {
self.config
}
pub fn set_id(&mut self, id: &[u8]) {
self.header_len = id.len();
self.header_buf[..self.header_len].copy_from_slice(id);
}
}
impl BinseqRecord for RefRecord<'_> {
fn bitsize(&self) -> BitSize {
self.config.bitsize
}
fn index(&self) -> u64 {
self.id
}
fn sheader(&self) -> &[u8] {
&self.header_buf[..self.header_len]
}
fn xheader(&self) -> &[u8] {
self.sheader()
}
fn flag(&self) -> Option<u64> {
if self.config.flags {
Some(self.buffer[0])
} else {
None
}
}
fn slen(&self) -> u64 {
self.config.slen
}
fn xlen(&self) -> u64 {
self.config.xlen
}
fn sbuf(&self) -> &[u64] {
if self.config.flags {
&self.buffer[1..=(self.config.schunk as usize)]
} else {
&self.buffer[..(self.config.schunk as usize)]
}
}
fn xbuf(&self) -> &[u64] {
if self.config.flags {
&self.buffer[1 + self.config.schunk as usize..]
} else {
&self.buffer[self.config.schunk as usize..]
}
}
fn squal(&self) -> &[u8] {
&self.qbuf[..self.config.slen as usize]
}
fn xqual(&self) -> &[u8] {
&self.qbuf[..self.config.xlen as usize]
}
}
pub struct BatchRecord<'a> {
buffer: &'a [u64],
dbuf: &'a [u8],
id: u64,
config: RecordConfig,
qbuf: &'a [u8],
header_buf: [u8; 20],
header_len: usize,
}
impl BinseqRecord for BatchRecord<'_> {
fn bitsize(&self) -> BitSize {
self.config.bitsize
}
fn index(&self) -> u64 {
self.id
}
fn sheader(&self) -> &[u8] {
&self.header_buf[..self.header_len]
}
fn xheader(&self) -> &[u8] {
self.sheader()
}
fn flag(&self) -> Option<u64> {
if self.config.flags {
Some(self.buffer[0])
} else {
None
}
}
fn slen(&self) -> u64 {
self.config.slen
}
fn xlen(&self) -> u64 {
self.config.xlen
}
fn sbuf(&self) -> &[u64] {
if self.config.flags {
&self.buffer[1..=(self.config.schunk as usize)]
} else {
&self.buffer[..(self.config.schunk as usize)]
}
}
fn xbuf(&self) -> &[u64] {
if self.config.flags {
&self.buffer[1 + self.config.schunk as usize..]
} else {
&self.buffer[self.config.schunk as usize..]
}
}
fn decode_s(&self, dbuf: &mut Vec<u8>) -> Result<()> {
dbuf.extend_from_slice(self.sseq());
Ok(())
}
fn decode_x(&self, dbuf: &mut Vec<u8>) -> Result<()> {
dbuf.extend_from_slice(self.xseq());
Ok(())
}
fn sseq(&self) -> &[u8] {
let scalar = self.config.scalar();
let mut lbound = 0;
let mut rbound = self.config.slen();
if self.config.flags {
lbound += scalar;
rbound += scalar;
}
&self.dbuf[lbound..rbound]
}
fn xseq(&self) -> &[u8] {
let scalar = self.config.scalar();
let mut lbound = scalar * self.config.schunk();
let mut rbound = lbound + self.config.xlen();
if self.config.flags {
lbound += scalar;
rbound += scalar;
}
&self.dbuf[lbound..rbound]
}
fn squal(&self) -> &[u8] {
&self.qbuf[..self.config.slen()]
}
fn xqual(&self) -> &[u8] {
&self.qbuf[..self.config.xlen()]
}
}
#[derive(Clone, Copy)]
pub struct RecordConfig {
slen: u64,
xlen: u64,
schunk: u64,
xchunk: u64,
bitsize: BitSize,
flags: bool,
}
impl RecordConfig {
pub fn new(slen: usize, xlen: usize, bitsize: BitSize, flags: bool) -> Self {
let (schunk, xchunk) = match bitsize {
BitSize::Two => (slen.div_ceil(32), xlen.div_ceil(32)),
BitSize::Four => (slen.div_ceil(16), xlen.div_ceil(16)),
};
Self {
slen: slen as u64,
xlen: xlen as u64,
schunk: schunk as u64,
xchunk: xchunk as u64,
bitsize,
flags,
}
}
pub fn from_header(header: &FileHeader) -> Self {
Self::new(
header.slen as usize,
header.xlen as usize,
header.bits,
header.flags,
)
}
pub fn paired(&self) -> bool {
self.xlen > 0
}
pub fn slen(&self) -> usize {
self.slen as usize
}
pub fn xlen(&self) -> usize {
self.xlen as usize
}
pub fn schunk(&self) -> usize {
self.schunk as usize
}
pub fn xchunk(&self) -> usize {
self.xchunk as usize
}
pub fn record_size_bytes(&self) -> usize {
8 * self.record_size_u64()
}
pub fn record_size_u64(&self) -> usize {
if self.flags {
(self.schunk + self.xchunk + 1) as usize
} else {
(self.schunk + self.xchunk) as usize
}
}
pub fn scalar(&self) -> usize {
match self.bitsize {
BitSize::Two => 32,
BitSize::Four => 16,
}
}
}
pub struct MmapReader {
mmap: Arc<Mmap>,
header: FileHeader,
config: RecordConfig,
qbuf: Vec<u8>,
default_quality_score: u8,
}
impl MmapReader {
pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
let file = File::open(path)?;
if !file.metadata()?.is_file() {
return Err(ReadError::IncompatibleFile.into());
}
let mmap = unsafe { Mmap::map(&file)? };
let header = FileHeader::from_buffer(&mmap)?;
let config = RecordConfig::from_header(&header);
if !(mmap.len() - SIZE_HEADER).is_multiple_of(config.record_size_bytes()) {
return Err(ReadError::FileTruncation(mmap.len()).into());
}
let qbuf = vec![DEFAULT_QUALITY_SCORE; header.slen.max(header.xlen) as usize];
Ok(Self {
mmap: Arc::new(mmap),
header,
config,
qbuf,
default_quality_score: DEFAULT_QUALITY_SCORE,
})
}
#[must_use]
pub fn num_records(&self) -> usize {
(self.mmap.len() - SIZE_HEADER) / self.config.record_size_bytes()
}
#[must_use]
pub fn header(&self) -> FileHeader {
self.header
}
#[must_use]
pub fn is_paired(&self) -> bool {
self.header.is_paired()
}
pub fn set_default_quality_score(&mut self, score: u8) {
self.default_quality_score = score;
self.qbuf = self.build_qbuf();
}
#[must_use]
pub fn build_qbuf(&self) -> Vec<u8> {
vec![self.default_quality_score; self.header.slen.max(self.header.xlen) as usize]
}
pub fn get(&self, idx: usize) -> Result<RefRecord<'_>> {
if idx > self.num_records() {
return Err(ReadError::OutOfRange {
requested_index: idx,
max_index: self.num_records(),
}
.into());
}
let rsize = self.config.record_size_bytes();
let lbound = SIZE_HEADER + (idx * rsize);
let rbound = lbound + rsize;
let bytes = &self.mmap[lbound..rbound];
let buffer = cast_slice(bytes);
Ok(RefRecord::new(idx as u64, buffer, &self.qbuf, self.config))
}
pub fn get_buffer_slice(&self, range: Range<usize>) -> Result<&[u64]> {
if range.end > self.num_records() {
return Err(ReadError::OutOfRange {
requested_index: range.end,
max_index: self.num_records(),
}
.into());
}
let rsize = self.config.record_size_bytes();
let total_records = range.end - range.start;
let lbound = SIZE_HEADER + (range.start * rsize);
let rbound = lbound + (total_records * rsize);
let bytes = &self.mmap[lbound..rbound];
let buffer = cast_slice(bytes);
Ok(buffer)
}
}
pub struct StreamReader<R: Read> {
reader: R,
header: Option<FileHeader>,
config: Option<RecordConfig>,
buffer: Vec<u8>,
qbuf: Vec<u8>,
default_quality_score: u8,
buffer_pos: usize,
buffer_len: usize,
}
impl<R: Read> StreamReader<R> {
pub fn new(reader: R) -> Self {
Self::with_capacity(reader, 8192)
}
pub fn with_capacity(reader: R, capacity: usize) -> Self {
Self {
reader,
header: None,
config: None,
buffer: vec![0; capacity],
qbuf: vec![0; capacity],
buffer_pos: 0,
buffer_len: 0,
default_quality_score: DEFAULT_QUALITY_SCORE,
}
}
pub fn set_default_quality_score(&mut self, score: u8) {
if score != self.default_quality_score {
self.qbuf.clear();
}
self.default_quality_score = score;
}
pub fn read_header(&mut self) -> Result<&FileHeader> {
if self.header.is_some() {
return Ok(self
.header
.as_ref()
.expect("Missing header when expected in stream"));
}
while self.buffer_len - self.buffer_pos < SIZE_HEADER {
self.fill_buffer()?;
}
let header_slice = &self.buffer[self.buffer_pos..self.buffer_pos + SIZE_HEADER];
let header = FileHeader::from_buffer(header_slice)?;
self.header = Some(header);
self.config = Some(RecordConfig::from_header(&header));
self.buffer_pos += SIZE_HEADER;
Ok(self.header.as_ref().unwrap())
}
fn fill_buffer(&mut self) -> Result<()> {
if self.buffer_pos > 0 && self.buffer_pos < self.buffer_len {
self.buffer.copy_within(self.buffer_pos..self.buffer_len, 0);
self.buffer_len -= self.buffer_pos;
self.buffer_pos = 0;
} else if self.buffer_pos == self.buffer_len {
self.buffer_len = 0;
self.buffer_pos = 0;
}
let bytes_read = self.reader.read(&mut self.buffer[self.buffer_len..])?;
if bytes_read == 0 {
return Err(ReadError::EndOfStream.into());
}
self.buffer_len += bytes_read;
Ok(())
}
pub fn next_record(&mut self) -> Option<Result<RefRecord<'_>>> {
if self.header.is_none()
&& let Some(e) = self.read_header().err()
{
return Some(Err(e));
}
let config = self
.config
.expect("Missing configuration when expected in stream");
let record_size = config.record_size_bytes();
while self.buffer_len - self.buffer_pos < record_size {
match self.fill_buffer() {
Ok(()) => {}
Err(Error::ReadError(ReadError::EndOfStream)) => {
if self.buffer_len - self.buffer_pos > 0 {
return Some(Err(ReadError::PartialRecord(
self.buffer_len - self.buffer_pos,
)
.into()));
}
return None;
}
Err(e) => return Some(Err(e)),
}
}
let record_start = self.buffer_pos;
self.buffer_pos += record_size;
let record_bytes = &self.buffer[record_start..record_start + record_size];
let record_u64s = cast_slice(record_bytes);
if self.qbuf.is_empty() {
let max_size = config.slen.max(config.xlen) as usize;
self.qbuf.resize(max_size, self.default_quality_score);
}
let id = (record_start - SIZE_HEADER) / record_size;
Some(Ok(RefRecord::new(
id as u64,
record_u64s,
&self.qbuf,
config,
)))
}
pub fn into_inner(self) -> R {
self.reader
}
}
pub const BATCH_SIZE: usize = 1024;
impl ParallelReader for MmapReader {
fn process_parallel<P: ParallelProcessor + Clone + 'static>(
self,
processor: P,
num_threads: usize,
) -> Result<()> {
let num_records = self.num_records();
self.process_parallel_range(processor, num_threads, 0..num_records)
}
fn process_parallel_range<P: ParallelProcessor + Clone + 'static>(
self,
processor: P,
num_threads: usize,
range: Range<usize>,
) -> Result<()> {
let num_threads = if num_threads == 0 {
num_cpus::get()
} else {
num_threads.min(num_cpus::get())
};
let num_records = self.num_records();
self.validate_range(num_records, &range)?;
let range_size = range.end - range.start;
let records_per_thread = range_size.div_ceil(num_threads);
let reader = Arc::new(self);
let mut handles = Vec::new();
for tid in 0..num_threads {
let mut processor = processor.clone();
let reader = reader.clone();
processor.set_tid(tid);
let handle = std::thread::spawn(move || -> Result<()> {
let start_idx = range.start + tid * records_per_thread;
let end_idx = (start_idx + records_per_thread).min(range.end);
if start_idx >= end_idx {
return Ok(()); }
let mut translater = itoa::Buffer::new();
let mut dbuf = Vec::new();
let qbuf = reader.build_qbuf();
let rsize_u64 = reader.config.record_size_bytes() / 8;
let scalar = reader.config.scalar();
let mut dbuf_rsize = { (reader.config.schunk() + reader.config.xchunk()) * scalar };
if reader.config.flags {
dbuf_rsize += scalar;
}
for range_start in (start_idx..end_idx).step_by(BATCH_SIZE) {
let range_end = (range_start + BATCH_SIZE).min(end_idx);
dbuf.clear();
let ebuf = reader.get_buffer_slice(range_start..range_end)?;
reader
.config
.bitsize
.decode(ebuf, ebuf.len() * scalar, &mut dbuf)?;
for (inner_idx, idx) in (range_start..range_end).enumerate() {
let id_str = translater.format(idx);
let mut header_buf = [0; 20];
let header_len = id_str.len();
header_buf[..header_len].copy_from_slice(id_str.as_bytes());
let ebuf_start = inner_idx * rsize_u64;
let dbuf_start = inner_idx * dbuf_rsize;
let record = BatchRecord {
buffer: &ebuf[ebuf_start..(ebuf_start + rsize_u64)],
dbuf: &dbuf[dbuf_start..(dbuf_start + dbuf_rsize)],
qbuf: &qbuf,
id: idx as u64,
config: reader.config,
header_buf,
header_len,
};
processor.process_record(record)?;
}
processor.on_batch_complete()?;
}
Ok(())
});
handles.push(handle);
}
for handle in handles {
handle
.join()
.expect("Error joining handle (1)")
.expect("Error joining handle (2)");
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::BinseqRecord;
use bitnuc::BitSize;
const TEST_BQ_FILE: &str = "./data/subset.bq";
#[test]
fn test_mmap_reader_new() {
let reader = MmapReader::new(TEST_BQ_FILE);
assert!(reader.is_ok(), "Failed to create reader");
}
#[test]
fn test_mmap_reader_num_records() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
let num_records = reader.num_records();
assert!(num_records > 0, "Expected non-zero records");
}
#[test]
fn test_mmap_reader_is_paired() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
let is_paired = reader.is_paired();
assert!(is_paired || !is_paired); }
#[test]
fn test_mmap_reader_header_access() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
let header = reader.header();
assert!(header.slen > 0, "Expected non-zero sequence length");
}
#[test]
fn test_mmap_reader_config_access() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
let header = reader.header();
let config = RecordConfig::from_header(&header);
assert!(
config.slen > 0,
"Expected non-zero sequence length in config"
);
}
#[test]
fn test_get_record() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
let num_records = reader.num_records();
if num_records > 0 {
let record = reader.get(0);
assert!(record.is_ok(), "Expected to get first record");
let record = record.unwrap();
assert_eq!(record.index(), 0, "Expected record index to be 0");
}
}
#[test]
fn test_get_record_out_of_bounds() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
let num_records = reader.num_records();
let record = reader.get(num_records + 100);
assert!(record.is_err(), "Expected error for out of bounds index");
}
#[test]
fn test_record_sequence_data() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
if let Ok(record) = reader.get(0) {
let sbuf = record.sbuf();
assert!(!sbuf.is_empty(), "Expected non-empty sequence buffer");
let slen = record.slen();
assert!(slen > 0, "Expected non-zero sequence length");
}
}
#[test]
fn test_record_quality_data() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
if let Ok(record) = reader.get(0) {
let squal = record.squal();
let slen = record.slen() as usize;
assert_eq!(
squal.len(),
slen,
"Quality length should match sequence length"
);
}
}
#[test]
fn test_set_default_quality_score() {
let mut reader = MmapReader::new(TEST_BQ_FILE).unwrap();
let custom_score = 42u8;
reader.set_default_quality_score(custom_score);
if let Ok(record) = reader.get(0) {
let squal = record.squal();
assert!(
squal.iter().all(|&q| q == custom_score),
"All quality scores should be {}",
custom_score
);
}
}
#[derive(Clone)]
struct CountingProcessor {
count: Arc<std::sync::Mutex<usize>>,
}
impl ParallelProcessor for CountingProcessor {
fn process_record<R: BinseqRecord>(&mut self, _record: R) -> Result<()> {
let mut count = self.count.lock().unwrap();
*count += 1;
Ok(())
}
}
#[test]
fn test_parallel_processing() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
let num_records = reader.num_records();
let count = Arc::new(std::sync::Mutex::new(0));
let processor = CountingProcessor {
count: count.clone(),
};
reader.process_parallel(processor, 2).unwrap();
let final_count = *count.lock().unwrap();
assert_eq!(final_count, num_records, "All records should be processed");
}
#[test]
fn test_parallel_processing_range() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
let num_records = reader.num_records();
if num_records >= 100 {
let start = 10;
let end = 50;
let expected_count = end - start;
let count = Arc::new(std::sync::Mutex::new(0));
let processor = CountingProcessor {
count: count.clone(),
};
reader
.process_parallel_range(processor, 2, start..end)
.unwrap();
let final_count = *count.lock().unwrap();
assert_eq!(
final_count, expected_count,
"Should process exactly {} records",
expected_count
);
}
}
#[test]
fn test_record_config_from_header() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
let header = reader.header();
let config = RecordConfig::from_header(&header);
assert_eq!(config.slen, header.slen as u64, "Sequence length mismatch");
assert_eq!(config.xlen, header.xlen as u64, "Extended length mismatch");
assert_eq!(config.bitsize, header.bits, "Bit size mismatch");
}
#[test]
fn test_record_config_record_size() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
let header = reader.header();
let config = RecordConfig::from_header(&header);
let size_u64 = config.record_size_u64();
assert!(size_u64 > 0, "Record size should be non-zero");
let size_bytes = config.record_size_bytes();
assert_eq!(size_bytes, size_u64 * 8, "Byte size should be 8x u64 size");
}
#[test]
fn test_ref_record_bitsize() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
if let Ok(record) = reader.get(0) {
let bitsize = record.bitsize();
assert!(
matches!(bitsize, BitSize::Two | BitSize::Four),
"Bitsize should be Two or Four"
);
}
}
#[test]
fn test_ref_record_flag() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
if let Ok(record) = reader.get(0) {
let flag = record.flag();
assert!(flag.is_some() || flag.is_none()); }
}
#[test]
fn test_ref_record_paired_data() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
if reader.is_paired() {
if let Ok(record) = reader.get(0) {
let xbuf = record.xbuf();
let xlen = record.xlen();
if xlen > 0 {
assert!(
!xbuf.is_empty(),
"Extended buffer should not be empty for paired"
);
}
}
}
}
#[test]
fn test_nonexistent_file() {
let result = MmapReader::new("./data/nonexistent.bq");
assert!(result.is_err(), "Should fail on nonexistent file");
}
#[test]
fn test_invalid_file_format() {
let result = MmapReader::new("./Cargo.toml");
if let Ok(reader) = result {
let num_records = reader.num_records();
let _ = num_records; }
}
#[test]
fn test_sequential_record_access() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
let num_records = reader.num_records().min(10);
for i in 0..num_records {
let record = reader.get(i);
assert!(record.is_ok(), "Should get record at index {}", i);
assert_eq!(
record.unwrap().index() as usize,
i,
"Record index mismatch at {}",
i
);
}
}
#[test]
fn test_random_record_access() {
let reader = MmapReader::new(TEST_BQ_FILE).unwrap();
let num_records = reader.num_records();
if num_records > 10 {
let indices = [0, 5, num_records / 2, num_records - 1];
for &idx in &indices {
let record = reader.get(idx);
assert!(record.is_ok(), "Should get record at index {}", idx);
assert_eq!(record.unwrap().index() as usize, idx);
}
}
}
}