use rayon::prelude::*;
use crate::bitreader::BitReader;
use crate::block::{self, BlockError};
use crate::block_scan;
use crate::BLOCK_MAGIC;
use crate::FINAL_MAGIC;
pub fn decompress_parallel(data: &[u8]) -> Result<Vec<u8>, BlockError> {
if data.len() < 4 {
return Err(BlockError("input too short for bzip2 header"));
}
if &data[..2] != b"BZ" {
return Err(BlockError("bad bzip2 signature"));
}
if data[2] != b'h' {
return Err(BlockError("only huffman bzip2 supported"));
}
let level = data[3];
if !(b'1'..=b'9').contains(&level) {
return Err(BlockError("invalid bzip2 block size level"));
}
let max_blocksize = 100_000 * (level - b'0') as u32;
let boundaries = block_scan::find_all_blocks(data);
if boundaries.is_empty() {
return Ok(Vec::new());
}
let results: Vec<Result<Vec<u8>, BlockError>> = boundaries
.par_iter()
.map(|boundary| {
let bit_after_magic = boundary.bit_offset + 48;
let mut reader = BitReader::from_bit_offset(data, bit_after_magic as usize);
block::decode_block(&mut reader, max_blocksize)
})
.collect();
let mut total_size = 0usize;
for r in &results {
match r {
Ok(v) => total_size += v.len(),
Err(e) => return Err(BlockError(e.0)),
}
}
let mut output = Vec::with_capacity(total_size);
for r in results {
output.extend_from_slice(&r?);
}
Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixtures, embedded once so the large file is not duplicated in the test binary.
    const HELLO: &[u8] = include_bytes!("../test_data/hello.bz2");
    const LIECHTENSTEIN: &[u8] = include_bytes!("../test_data/liechtenstein.osm.bz2");

    /// Asserts that decompressing `input` fails with exactly `expected` as its message.
    fn assert_error(input: &[u8], expected: &str) {
        match decompress_parallel(input) {
            Ok(out) => panic!(
                "expected error {:?}, got {} decoded bytes",
                expected,
                out.len()
            ),
            Err(e) => assert_eq!(e.0, expected),
        }
    }

    #[test]
    fn parallel_hello() {
        let output = decompress_parallel(HELLO).unwrap();
        assert_eq!(&output, b"Hello, World!\n");
    }

    /// The parallel decoder must agree with the sequential decoder on every fixture.
    /// Length is checked first so a mismatch reports a readable size difference
    /// instead of dumping two large buffers.
    #[test]
    fn parallel_matches_sequential() {
        for (name, compressed) in [("hello", HELLO), ("liechtenstein", LIECHTENSTEIN)] {
            let par = decompress_parallel(compressed).unwrap();
            let seq = crate::stream::decompress(compressed).unwrap();
            assert_eq!(par.len(), seq.len(), "{}: size mismatch", name);
            assert_eq!(par, seq, "{}: content mismatch", name);
        }
    }

    #[test]
    fn rejects_truncated_header() {
        assert_error(b"", "input too short for bzip2 header");
        assert_error(b"BZh", "input too short for bzip2 header");
    }

    #[test]
    fn rejects_bad_signature() {
        assert_error(b"PK\x03\x04", "bad bzip2 signature");
    }

    #[test]
    fn rejects_non_huffman_variant() {
        assert_error(b"BZ09", "only huffman bzip2 supported");
    }

    #[test]
    fn rejects_invalid_level() {
        assert_error(b"BZh0", "invalid bzip2 block size level");
        assert_error(b"BZhx", "invalid bzip2 block size level");
    }
}