use std::io::{self, Read};
use rayon::prelude::*;
use crate::bitreader::BitReader;
use crate::block::{self, BlockError};
use crate::block_scan;
/// Number of compressed bytes (100 MiB) that `StreamingBz2Read::fill` tries to
/// pull from the source each time it runs; also used to size the initial
/// capacity of the compressed buffer.
const CHUNK_BYTES: usize = 100 * 1024 * 1024;
/// Streaming bzip2 decoder: wraps a [`Read`] source of compressed data and
/// exposes the decompressed bytes through its own [`Read`] implementation.
///
/// Compressed input is pulled in chunks of `CHUNK_BYTES`, and the blocks found
/// in each chunk are decoded in parallel.
pub struct StreamingBz2Read<R: Read> {
/// Underlying reader that supplies the compressed stream.
source: R,
/// Block size limit handed to the block decoder: 100_000 times the level
/// digit taken from the stream header.
max_blocksize: u32,
/// Compressed bytes read from `source` that have not been decoded yet.
/// Initially holds the 4-byte stream header.
comp_buf: Vec<u8>,
/// Decoded bytes of the most recently decoded batch of blocks.
out_buf: Vec<u8>,
/// Index into `out_buf` of the next byte to hand to the caller.
out_pos: usize,
/// Set once `source` has reported end of input (a read returned 0 bytes).
source_eof: bool,
}
impl<R: Read> StreamingBz2Read<R> {
    /// Creates a streaming decoder over `source`.
    ///
    /// Exactly four bytes are consumed from `source` and validated as a bzip2
    /// stream header: the signature `BZ`, the method byte `h`, and a level
    /// digit `1`..=`9`. The level determines the maximum block size
    /// (`100_000 * level`). The header bytes are kept at the front of the
    /// compressed buffer.
    ///
    /// # Errors
    ///
    /// Returns a [`BlockError`] when the header cannot be read in full, the
    /// signature or method byte is wrong, or the level is not `1`..=`9`.
    pub fn new(mut source: R) -> Result<Self, BlockError> {
        let mut header = [0u8; 4];
        source
            .read_exact(&mut header)
            .map_err(|_| BlockError("failed to read bzip2 header"))?;
        if &header[..2] != b"BZ" {
            return Err(BlockError("bad bzip2 signature"));
        }
        if header[2] != b'h' {
            return Err(BlockError("only huffman bzip2 supported"));
        }
        let level = header[3];
        if !(b'1'..=b'9').contains(&level) {
            return Err(BlockError("invalid bzip2 block size level"));
        }
        let max_blocksize = 100_000 * (level - b'0') as u32;
        let mut comp_buf = Vec::with_capacity(CHUNK_BYTES + 1024 * 1024);
        comp_buf.extend_from_slice(&header);
        Ok(Self {
            source,
            max_blocksize,
            comp_buf,
            out_buf: Vec::new(),
            out_pos: 0,
            source_eof: false,
        })
    }

    /// Reads up to `CHUNK_BYTES` more compressed bytes and decodes, in
    /// parallel, every block that is known to be complete, replacing the
    /// contents of `out_buf` with the decoded bytes.
    ///
    /// While the source has not reached end of input, the last block boundary
    /// found is not decoded (its data may continue in the next chunk); the
    /// compressed bytes from that boundary onward are kept for the next call.
    /// Once the source is exhausted every remaining block is decoded and the
    /// compressed buffer is emptied.
    ///
    /// The call may return `Ok(())` without producing output when no block can
    /// be decoded yet.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the source. Bytes received before the error
    /// stay in the compressed buffer, without padding, so a retry resumes
    /// cleanly. A block that fails to decode yields
    /// [`io::ErrorKind::InvalidData`]; in that case `out_buf` is left empty so
    /// a partial batch is never handed out (and later duplicated by a retry).
    fn fill(&mut self) -> io::Result<()> {
        // Input exhausted and at most the header left over: nothing to decode.
        if self.source_eof && self.comp_buf.len() <= 4 {
            self.out_buf.clear();
            self.out_pos = 0;
            return Ok(());
        }

        if !self.source_eof {
            // `take` caps this call at one chunk. `read_to_end` appends only
            // bytes that were actually read and retries on `Interrupted`, so
            // an error cannot leave placeholder bytes in `comp_buf`.
            let received = self
                .source
                .by_ref()
                .take(CHUNK_BYTES as u64)
                .read_to_end(&mut self.comp_buf)?;
            // Stopping short of the cap means the source returned 0 bytes.
            self.source_eof = received < CHUNK_BYTES;
        }

        let boundaries = block_scan::find_all_blocks(&self.comp_buf);
        if boundaries.is_empty() {
            if self.source_eof {
                self.comp_buf.clear();
                self.out_buf.clear();
                self.out_pos = 0;
            }
            return Ok(());
        }

        // Before end of input the final boundary may belong to a block whose
        // data is not fully buffered yet, so it is held back.
        let decode_count = if self.source_eof {
            boundaries.len()
        } else if boundaries.len() > 1 {
            boundaries.len() - 1
        } else {
            return Ok(());
        };

        let decode_boundaries = &boundaries[..decode_count];
        let comp_data: &[u8] = &self.comp_buf;
        let max_bs = self.max_blocksize;
        let results: Vec<Result<Vec<u8>, BlockError>> = decode_boundaries
            .par_iter()
            .map(|boundary| {
                // Start reading 48 bits past the boundary, i.e. after the block magic.
                let bit_after_magic = boundary.bit_offset + 48;
                let mut reader = BitReader::from_bit_offset(comp_data, bit_after_magic as usize);
                block::decode_block(&mut reader, max_bs)
            })
            .collect();

        self.out_buf.clear();
        self.out_pos = 0;
        // Commit output only when the whole batch decoded; the first failing
        // block (in stream order) decides the error.
        let blocks = results
            .into_iter()
            .collect::<Result<Vec<Vec<u8>>, BlockError>>()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.0))?;
        self.out_buf.reserve(blocks.iter().map(Vec::len).sum());
        for block in &blocks {
            self.out_buf.extend_from_slice(block);
        }

        if decode_count < boundaries.len() {
            // Keep everything from the byte containing the first undecoded
            // boundary; shift it to the front in place instead of copying
            // through a temporary vector.
            let keep_from_byte = boundaries[decode_count].byte_offset();
            self.comp_buf.drain(..keep_from_byte);
        } else {
            self.comp_buf.clear();
        }
        Ok(())
    }
}
impl<R: Read> Read for StreamingBz2Read<R> {
    /// Copies decompressed bytes into `buf`, decoding more input when the
    /// current output batch is used up.
    ///
    /// Returns `Ok(0)` only when `buf` is empty or the stream is finished:
    /// the source has reported end of input and no undecoded compressed data
    /// is left. A `fill` pass that produces no output while the source still
    /// has data (no complete block buffered yet) is followed by another pass
    /// rather than being reported as end of stream. The compressed buffer
    /// keeps growing by one chunk per pass until a block can be decoded or
    /// the source ends.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `fill` (source I/O errors or
    /// [`io::ErrorKind::InvalidData`] for an undecodable block).
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        loop {
            // Serve from the already-decoded batch first.
            if self.out_pos < self.out_buf.len() {
                let available = &self.out_buf[self.out_pos..];
                let n = buf.len().min(available.len());
                buf[..n].copy_from_slice(&available[..n]);
                self.out_pos += n;
                return Ok(n);
            }
            // Same exhaustion test `fill` uses: input ended and at most the
            // 4-byte header remains buffered.
            if self.source_eof && self.comp_buf.len() <= 4 {
                return Ok(0);
            }
            self.fill()?;
        }
    }
}
/// bzip2 decoder over a complete compressed stream held in memory.
///
/// Block boundaries are located once at construction; blocks are then decoded
/// in parallel, one batch at a time, as the [`Read`] implementation is drained.
pub struct ParallelBz2Read<'a> {
/// The entire compressed input, starting with the stream header.
data: &'a [u8],
/// Block size limit handed to the block decoder: 100_000 times the level
/// digit taken from the stream header.
max_blocksize: u32,
/// Block boundaries reported by `block_scan::find_all_blocks` for `data`.
boundaries: Vec<block_scan::BlockBoundary>,
/// Index into `boundaries` of the first block that has not been decoded yet.
next_block: usize,
/// Decoded bytes of the most recently decoded batch.
buf: Vec<u8>,
/// Index into `buf` of the next byte to hand to the caller.
buf_pos: usize,
}
/// Maximum number of blocks `ParallelBz2Read` decodes in one batch.
const BATCH_SIZE: usize = 64;
impl<'a> ParallelBz2Read<'a> {
    /// Builds a decoder over `data`, which must hold a whole bzip2 stream.
    ///
    /// The first four bytes are checked as the stream header (`BZ`, `h`, and
    /// a level digit `1`..=`9`), after which all block boundaries in `data`
    /// are located. No block is decoded until the reader is read from.
    ///
    /// # Errors
    ///
    /// Returns a [`BlockError`] when `data` is shorter than four bytes, the
    /// signature or method byte is wrong, or the level is not `1`..=`9`.
    pub fn new(data: &'a [u8]) -> Result<Self, BlockError> {
        let (signature, method, level) = match data {
            [s0, s1, method, level, ..] => ([*s0, *s1], *method, *level),
            _ => return Err(BlockError("input too short for bzip2 header")),
        };
        if signature != *b"BZ" {
            return Err(BlockError("bad bzip2 signature"));
        }
        if method != b'h' {
            return Err(BlockError("only huffman bzip2 supported"));
        }
        if !matches!(level, b'1'..=b'9') {
            return Err(BlockError("invalid bzip2 block size level"));
        }

        Ok(Self {
            data,
            max_blocksize: u32::from(level - b'0') * 100_000,
            boundaries: block_scan::find_all_blocks(data),
            next_block: 0,
            buf: Vec::new(),
            buf_pos: 0,
        })
    }

    /// Number of block boundaries found in the input.
    pub fn block_count(&self) -> usize {
        self.boundaries.len()
    }

    /// Decodes the next batch of at most `BATCH_SIZE` blocks in parallel and
    /// replaces `buf` with their concatenated output.
    ///
    /// When no blocks remain, `buf` is simply emptied. On a decode failure
    /// the error of the first failing block (in stream order) is returned as
    /// [`io::ErrorKind::InvalidData`], and neither `buf` nor the batch cursor
    /// is modified.
    fn fill_batch(&mut self) -> io::Result<()> {
        let pending = self.boundaries.len() - self.next_block;
        if pending == 0 {
            self.buf.clear();
            self.buf_pos = 0;
            return Ok(());
        }

        let stop = self.next_block + pending.min(BATCH_SIZE);
        let data = self.data;
        let max_bs = self.max_blocksize;
        let outcomes: Vec<Result<Vec<u8>, BlockError>> = self.boundaries[self.next_block..stop]
            .par_iter()
            .map(|boundary| {
                // Decoding starts 48 bits past the boundary, after the block magic.
                let start_bit = boundary.bit_offset + 48;
                let mut reader = BitReader::from_bit_offset(data, start_bit as usize);
                block::decode_block(&mut reader, max_bs)
            })
            .collect();

        // Sequential collect stops at the earliest failure, before any state changes.
        let blocks = outcomes
            .into_iter()
            .collect::<Result<Vec<Vec<u8>>, BlockError>>()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.0))?;

        let total: usize = blocks.iter().map(Vec::len).sum();
        self.buf.clear();
        self.buf.reserve(total);
        for block in &blocks {
            self.buf.extend_from_slice(block);
        }
        self.buf_pos = 0;
        self.next_block = stop;
        Ok(())
    }
}
impl Read for ParallelBz2Read<'_> {
    /// Copies decoded bytes into `buf`, decoding the next batch of blocks
    /// first when the current batch has been fully consumed.
    ///
    /// Returns `Ok(0)` when a freshly decoded batch is empty, which is the
    /// case once every block has been decoded. Errors from batch decoding
    /// are propagated unchanged.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let batch_consumed = self.buf_pos >= self.buf.len();
        if batch_consumed {
            self.fill_batch()?;
            if self.buf.is_empty() {
                return Ok(0);
            }
        }

        let pending = &self.buf[self.buf_pos..];
        let count = pending.len().min(buf.len());
        let (dst, _) = buf.split_at_mut(count);
        dst.copy_from_slice(&pending[..count]);
        self.buf_pos += count;
        Ok(count)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Drains `reader` with `read_to_end` and returns everything it produced.
    fn slurp<T: Read>(mut reader: T) -> Vec<u8> {
        let mut collected = Vec::new();
        reader.read_to_end(&mut collected).unwrap();
        collected
    }

    /// Streaming reader decodes the small single-block fixture.
    #[test]
    fn streaming_hello() {
        let compressed = include_bytes!("../test_data/hello.bz2");
        let reader = StreamingBz2Read::new(io::Cursor::new(compressed.as_slice())).unwrap();
        let output = slurp(reader);
        assert_eq!(&output, b"Hello, World!\n");
    }

    /// Streaming reader matches the sequential decoder on a larger fixture.
    #[test]
    fn streaming_liechtenstein() {
        let compressed = include_bytes!("../test_data/liechtenstein.osm.bz2");
        let reader = StreamingBz2Read::new(io::Cursor::new(compressed.as_slice())).unwrap();
        let output = slurp(reader);
        let sequential = crate::stream::decompress(compressed).unwrap();
        assert_eq!(output.len(), sequential.len());
        assert_eq!(output, sequential);
    }

    /// Streaming reader yields the same bytes when drained in 4 KiB reads.
    #[test]
    fn streaming_incremental() {
        let compressed = include_bytes!("../test_data/liechtenstein.osm.bz2");
        let mut reader = StreamingBz2Read::new(io::Cursor::new(compressed.as_slice())).unwrap();
        let mut output = Vec::new();
        let mut chunk = [0u8; 4096];
        loop {
            match reader.read(&mut chunk).unwrap() {
                0 => break,
                n => output.extend_from_slice(&chunk[..n]),
            }
        }
        let sequential = crate::stream::decompress(compressed).unwrap();
        assert_eq!(output, sequential);
    }

    /// In-memory reader decodes the small single-block fixture.
    #[test]
    fn mmap_hello() {
        let compressed = include_bytes!("../test_data/hello.bz2");
        let reader = ParallelBz2Read::new(compressed).unwrap();
        let output = slurp(reader);
        assert_eq!(&output, b"Hello, World!\n");
    }

    /// In-memory reader matches the sequential decoder on a larger fixture.
    #[test]
    fn mmap_liechtenstein() {
        let compressed = include_bytes!("../test_data/liechtenstein.osm.bz2");
        let reader = ParallelBz2Read::new(compressed).unwrap();
        let output = slurp(reader);
        let sequential = crate::stream::decompress(compressed).unwrap();
        assert_eq!(output, sequential);
    }
}