use std::{
fs::File,
io::{BufReader, Read},
path::Path,
};
use crate::{Header, IbuError, Record, HEADER_SIZE, RECORD_SIZE};
/// Size in bytes of the batch buffer: room for 48 Ki whole records per refill.
const DEFAULT_BUFFER_SIZE: usize = RECORD_SIZE * (48 * 1024);
/// Type-erased `Send` byte source used by the path/stdin constructors.
type BoxedReader = Box<dyn Read + Send>;
#[derive(Clone)]
pub struct Reader<R: Read> {
inner: R,
buffer: Vec<u8>,
header: Header,
pos: usize,
cap: usize,
bytes_read: usize,
eof: bool,
}
impl<R: Read> Reader<R> {
    /// Creates a reader over `inner`, consuming and validating the file header.
    ///
    /// Exactly `HEADER_SIZE` bytes are read from `inner` before this returns. Record
    /// data is read lazily, in batches, by [`Reader::read_batch`] and the `Iterator` impl.
    ///
    /// # Errors
    ///
    /// Returns an error if the header cannot be read in full (an I/O error, or a
    /// stream shorter than `HEADER_SIZE`), or if `Header::validate` rejects it.
    pub fn new(mut inner: R) -> crate::Result<Self> {
        let mut header_bytes = [0u8; HEADER_SIZE];
        inner.read_exact(&mut header_bytes)?;
        // The stack array carries no alignment guarantee for `Header`, so decode by copy.
        let header: Header = bytemuck::pod_read_unaligned(&header_bytes);
        header.validate()?;

        Ok(Self {
            inner,
            // Size the buffer by *length*, not capacity. `Vec::with_capacity` only
            // guarantees a lower bound, and a derived `Clone` of an empty `Vec`
            // has capacity 0, which made a cloned reader report EOF immediately.
            buffer: vec![0u8; DEFAULT_BUFFER_SIZE],
            header,
            pos: 0,
            cap: 0,
            bytes_read: HEADER_SIZE,
            eof: false,
        })
    }

    /// Refills the internal buffer with the next batch of whole records.
    ///
    /// Reads from the inner reader until the buffer is full or the reader reports
    /// end-of-stream, then resets the record cursor to the start of the new batch.
    /// Any records still unconsumed from the previous batch are discarded.
    ///
    /// Returns `Ok(true)` if at least one record was read and `Ok(false)` at end of
    /// stream.
    ///
    /// # Errors
    ///
    /// * Propagates I/O errors from the inner reader. `ErrorKind::Interrupted` is
    ///   retried, matching `Read::read_exact`.
    /// * Returns `IbuError::TruncatedRecord` if the stream ends partway through a
    ///   record. Its `pos` is the byte offset from the start of the stream (header
    ///   included) at which the incomplete record begins.
    pub fn read_batch(&mut self) -> crate::Result<bool> {
        let mut filled = 0;
        while filled < self.buffer.len() {
            match self.inner.read(&mut self.buffer[filled..]) {
                Ok(0) => break,
                Ok(n) => filled += n,
                // A signal interrupted the read before any data arrived; just retry.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e.into()),
            }
        }

        let partial = filled % RECORD_SIZE;
        if partial != 0 {
            return Err(IbuError::TruncatedRecord {
                pos: self.bytes_read + (filled - partial),
            });
        }

        self.pos = 0;
        self.cap = filled / RECORD_SIZE;
        self.bytes_read += filled;
        Ok(filled > 0)
    }

    /// Returns a copy of the header read when this reader was created.
    pub fn header(&self) -> Header {
        self.header
    }
}
impl<R: Read> Iterator for Reader<R> {
    type Item = Result<Record, IbuError>;

    /// Yields the next record, refilling the batch buffer from the inner reader
    /// when the current batch is exhausted.
    ///
    /// Returns `Some(Err(_))` if a refill fails (I/O error or truncated record).
    /// Once the inner reader reports end-of-stream, this and every later call
    /// returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        if self.eof {
            return None;
        }
        if self.pos >= self.cap {
            match self.read_batch() {
                Ok(true) => {}
                Ok(false) => {
                    self.eof = true;
                    return None;
                }
                Err(e) => return Some(Err(e)),
            }
        }

        let start = self.pos * RECORD_SIZE;
        let bytes = &self.buffer[start..start + RECORD_SIZE];
        // `buffer` is a `Vec<u8>`, whose allocation has no guaranteed alignment for
        // `Record`. `bytemuck::cast_slice` would panic on a misaligned buffer, so
        // decode by copy instead (as `Reader::new` does for the header).
        let record: Record = bytemuck::pod_read_unaligned(bytes);
        self.pos += 1;
        Some(Ok(record))
    }
}
impl Reader<BoxedReader> {
    /// Opens the file at `path` (wrapped in a `BufReader`) and reads its header.
    ///
    /// With the `niffler` feature enabled, the stream first goes through
    /// `niffler::send::get_reader` and the detected format is discarded.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be opened, if niffler returns an error (feature-gated),
    /// or if [`Reader::new`] rejects the header.
    pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self, IbuError> {
        let file = File::open(path)?;
        Self::from_boxed_stream(Box::new(BufReader::new(file)))
    }

    /// Reads an IBU stream from standard input.
    ///
    /// Applies the same optional niffler wrapping as [`Reader::from_path`].
    pub fn from_stdin() -> Result<Self, IbuError> {
        Self::from_boxed_stream(Box::new(std::io::stdin()))
    }

    /// Reads from `path` when one is given, otherwise from standard input.
    pub fn from_optional_path<P: AsRef<Path>>(path: Option<P>) -> Result<Self, IbuError> {
        if let Some(path) = path {
            Self::from_path(path)
        } else {
            Self::from_stdin()
        }
    }

    /// Shared tail of the constructors above.
    ///
    /// Optionally routes `stream` through niffler, then parses the header.
    fn from_boxed_stream(stream: BoxedReader) -> Result<Self, IbuError> {
        #[cfg(feature = "niffler")]
        let stream = niffler::send::get_reader(stream)?.0;
        Self::new(stream)
    }
}
/// Reads an entire IBU file into memory in one pass.
///
/// The header is read and validated. The rest of the file is then read directly
/// into a `Vec<Record>` whose length comes from the file's metadata.
///
/// # Errors
///
/// * I/O errors from opening, stat-ing, or reading the file.
/// * Any error from `Header::validate`.
/// * `IbuError::InvalidMapSize` in three cases:
///   - the data section is not a whole number of records;
///   - the reported file length is smaller than the header;
///   - the reported file length does not fit in `usize`.
///
///   A pipe or other non-regular file reports a length of 0, so it also hits this error.
pub fn load_to_vec<P: AsRef<Path>>(path: P) -> crate::Result<(Header, Vec<Record>)> {
    let mut file = File::open(path)?;

    let mut header_bytes = [0u8; HEADER_SIZE];
    file.read_exact(&mut header_bytes)?;
    let header = crate::Header::from_bytes(&header_bytes);
    header.validate()?;

    // Checked arithmetic: non-regular files report a metadata length of 0, which
    // would underflow the subtraction, and `u64 -> usize` can truncate on 32-bit.
    let data_size = usize::try_from(file.metadata()?.len())
        .ok()
        .and_then(|len| len.checked_sub(HEADER_SIZE))
        .ok_or(IbuError::InvalidMapSize)?;
    if !data_size.is_multiple_of(RECORD_SIZE) {
        return Err(IbuError::InvalidMapSize);
    }

    let mut records = vec![Record::default(); data_size / RECORD_SIZE];
    // A `Vec<Record>` is correctly aligned for `Record`, so its bytes can be filled in place.
    file.read_exact(bytemuck::cast_slice_mut(&mut records))?;
    Ok((header, records))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Header, Record, Writer};
    use std::io::Cursor;
    use std::path::PathBuf;

    /// Serializes `records` behind a header with `bc_len = 16` and `umi_len = 12`.
    fn create_test_data(records: &[Record]) -> Vec<u8> {
        let header = Header::new(16, 12);
        let buffer = Vec::new();
        let mut writer = Writer::new(buffer, header).unwrap();
        writer.write_batch(records).unwrap();
        writer.finish().unwrap();
        writer.into_inner()
    }

    /// File in the system temp directory that is removed when dropped.
    ///
    /// Tests never write into the current working directory, and a failing
    /// assertion does not leave the file behind.
    struct TempFile(PathBuf);

    impl TempFile {
        /// Writes `contents` to `<temp_dir>/<pid>_<name>`.
        ///
        /// The process id keeps concurrent test runs from clobbering each other's files.
        fn with_contents(name: &str, contents: &[u8]) -> Self {
            let path = std::env::temp_dir().join(format!("{}_{name}", std::process::id()));
            std::fs::write(&path, contents).unwrap();
            Self(path)
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup: a failure here must not mask the test outcome.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_reader_creation() {
        let records = vec![Record::new(1, 2, 3), Record::new(4, 5, 6)];
        let buffer = create_test_data(&records);
        let cursor = Cursor::new(buffer);
        let reader = Reader::new(cursor).unwrap();
        let header = reader.header();
        assert_eq!(header.bc_len, 16);
        assert_eq!(header.umi_len, 12);
        assert_eq!(header.magic, crate::MAGIC);
        assert_eq!(header.version, crate::VERSION);
    }

    #[test]
    fn test_reader_invalid_header() {
        let invalid_data = vec![0u8; 32];
        let cursor = Cursor::new(invalid_data);
        let result = Reader::new(cursor);
        assert!(matches!(result, Err(IbuError::InvalidMagicNumber { .. })));
    }

    #[test]
    fn test_reader_iterator() {
        let records = vec![
            Record::new(1, 2, 3),
            Record::new(4, 5, 6),
            Record::new(7, 8, 9),
        ];
        let buffer = create_test_data(&records);
        let cursor = Cursor::new(buffer);
        let reader = Reader::new(cursor).unwrap();
        let read_records: Result<Vec<_>, _> = reader.collect();
        let read_records = read_records.unwrap();
        assert_eq!(records, read_records);
    }

    #[test]
    fn test_reader_empty_file() {
        let records: Vec<Record> = vec![];
        let buffer = create_test_data(&records);
        let cursor = Cursor::new(buffer);
        let reader = Reader::new(cursor).unwrap();
        let read_records: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(read_records.len(), 0);
    }

    #[test]
    fn test_reader_large_batch() {
        let records: Vec<Record> = (0..100_000).map(|i| Record::new(i, i * 2, i * 3)).collect();
        let buffer = create_test_data(&records);
        let cursor = Cursor::new(buffer);
        let reader = Reader::new(cursor).unwrap();
        let read_records: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(records, read_records);
    }

    #[test]
    fn test_reader_truncated_data() {
        let records = vec![Record::new(1, 2, 3)];
        let mut buffer = create_test_data(&records);
        buffer.truncate(buffer.len() - 5);
        let cursor = Cursor::new(buffer);
        let mut reader = Reader::new(cursor).unwrap();
        let result = reader.next();
        assert!(result.is_some());
        assert!(matches!(
            result.unwrap(),
            Err(IbuError::TruncatedRecord { .. })
        ));
    }

    #[test]
    fn test_reader_manual_batch_reading() {
        let records = vec![Record::new(1, 2, 3)];
        let buffer = create_test_data(&records);
        let cursor = Cursor::new(buffer);
        let mut reader = Reader::new(cursor).unwrap();
        let has_data = reader.read_batch().unwrap();
        assert!(has_data);
        let has_data = reader.read_batch().unwrap();
        assert!(!has_data);
    }

    #[test]
    fn test_reader_clone() {
        let records = vec![Record::new(1, 2, 3)];
        let buffer = create_test_data(&records);
        let cursor = Cursor::new(buffer.clone());
        let reader = Reader::new(cursor).unwrap();
        let reader_clone = reader.clone();
        assert_eq!(reader.header(), reader_clone.header());
    }

    #[test]
    fn test_load_to_vec_basic() {
        let records = vec![
            Record::new(1, 2, 3),
            Record::new(4, 5, 6),
            Record::new(7, 8, 9),
        ];
        let file = TempFile::with_contents("test_load_to_vec.ibu", &create_test_data(&records));
        let (header, loaded_records) = load_to_vec(&file.0).unwrap();
        assert_eq!(header.bc_len, 16);
        assert_eq!(header.umi_len, 12);
        assert_eq!(loaded_records, records);
    }

    #[test]
    fn test_load_to_vec_empty_file() {
        let records: Vec<Record> = vec![];
        let file = TempFile::with_contents("test_load_empty.ibu", &create_test_data(&records));
        let (header, loaded_records) = load_to_vec(&file.0).unwrap();
        assert_eq!(header.bc_len, 16);
        assert_eq!(header.umi_len, 12);
        assert_eq!(loaded_records.len(), 0);
    }

    #[test]
    fn test_load_to_vec_invalid_size() {
        let mut buffer = create_test_data(&[Record::new(1, 2, 3)]);
        buffer.truncate(buffer.len() - 5);
        let file = TempFile::with_contents("test_invalid_size.ibu", &buffer);
        let result = load_to_vec(&file.0);
        assert!(matches!(result, Err(IbuError::InvalidMapSize)));
    }

    #[test]
    fn test_reader_bytes_read_tracking() {
        let records = vec![Record::new(1, 2, 3), Record::new(4, 5, 6)];
        let buffer = create_test_data(&records);
        let cursor = Cursor::new(buffer);
        let mut reader = Reader::new(cursor).unwrap();
        assert_eq!(reader.bytes_read, HEADER_SIZE);
        let _ = reader.next().unwrap().unwrap();
        assert!(reader.bytes_read > HEADER_SIZE);
        let _: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();
    }
}