use std::collections::BTreeMap;
use std::io::{BufReader, BufWriter, Read, Write};
use crossbeam::channel::{bounded, Receiver, Sender};
use super::boundary::BoundaryResolver;
use super::encoding::{
buffer_and_write_block, encoding_worker, send_job_and_drain, write_single_block, EncodedBlock,
EncodingJob,
};
use super::splitter::{BlockSplitter, DefaultSplitter, FastqSplitter};
use crate::bgzf::{GziEntry, BGZF_EOF};
use crate::deflate::{DeflateParser, LZ77Token};
use crate::error::{Error, Result};
use crate::gzip::GzipHeader;
use crate::{FormatProfile, TranscodeConfig, TranscodeStats, Transcoder};
/// Multi-threaded gzip-to-BGZF transcoder.
///
/// Input is parsed on the calling thread while re-encoding of the output
/// blocks is farmed out to worker threads; when the configuration resolves to
/// a single thread the work is delegated to the single-threaded transcoder.
pub struct ParallelTranscoder {
/// Settings read by this file: thread count, buffer size, block size,
/// format profile, smart-boundary / fixed-Huffman switches and index building.
config: TranscodeConfig,
}
impl ParallelTranscoder {
pub fn new(config: TranscodeConfig) -> Self {
Self { config }
}
}
impl Transcoder for ParallelTranscoder {
    /// Transcodes `input` into `output`, choosing the execution strategy from
    /// the configured thread count.
    ///
    /// When `effective_threads()` reports exactly one thread the request is
    /// forwarded to `SingleThreadedTranscoder` (built from a clone of the
    /// configuration); otherwise the parallel pipeline is used.
    ///
    /// # Errors
    /// Propagates whatever error the selected transcoder returns.
    fn transcode<R: Read, W: Write>(&mut self, input: R, output: W) -> Result<TranscodeStats> {
        let worker_count = self.config.effective_threads();
        if worker_count != 1 {
            return self.transcode_parallel(input, output, worker_count);
        }
        // One thread: skip the channel/worker machinery entirely.
        let mut fallback = super::single::SingleThreadedTranscoder::new(self.config.clone());
        fallback.transcode(input, output)
    }
}
impl ParallelTranscoder {
fn transcode_parallel<R: Read, W: Write>(
&mut self,
input: R,
mut output: W,
num_threads: usize,
) -> Result<TranscodeStats> {
let channel_capacity = num_threads * 4;
let (job_tx, job_rx): (Sender<EncodingJob>, Receiver<EncodingJob>) =
bounded(channel_capacity);
let (result_tx, result_rx): (Sender<Result<EncodedBlock>>, Receiver<Result<EncodedBlock>>) =
bounded(channel_capacity);
let use_fixed_huffman = self.config.use_fixed_huffman();
let result = crossbeam::scope(|scope| {
for _ in 0..num_threads {
let job_rx = job_rx.clone();
let result_tx = result_tx.clone();
scope.spawn(move |_| {
encoding_worker(job_rx, result_tx, use_fixed_huffman);
});
}
drop(job_rx);
drop(result_tx);
self.parse_dispatch_and_write(input, &mut output, job_tx, result_rx)
});
result.map_err(|_| Error::Internal("Thread panicked".to_string()))?
}
fn parse_dispatch_and_write<R: Read, W: Write>(
&self,
input: R,
output: &mut W,
job_tx: Sender<EncodingJob>,
result_rx: Receiver<Result<EncodedBlock>>,
) -> Result<TranscodeStats> {
let mut reader = BufReader::with_capacity(self.config.buffer_size, input);
let mut writer = BufWriter::with_capacity(self.config.buffer_size, output);
let _gzip_header = GzipHeader::parse(&mut reader)?;
let mut parser = DeflateParser::new(&mut reader);
let mut resolver = BoundaryResolver::new();
let use_smart = self.config.use_smart_boundaries();
let mut splitter: Box<dyn BlockSplitter> =
if use_smart && self.config.format == FormatProfile::Fastq {
Box::new(FastqSplitter::new())
} else {
Box::new(DefaultSplitter)
};
let max_block_size = if use_smart {
(self.config.block_size as f64 * 1.1) as usize
} else {
self.config.block_size
};
let mut pending_tokens: Vec<LZ77Token> = Vec::with_capacity(8192);
let mut pending_uncompressed_size: usize = 0;
let mut block_start_position: u64 = 0;
let mut next_block_id: u64 = 0;
let mut blocks_written: u64 = 0;
let mut output_bytes: u64 = 0;
let build_index = self.config.build_index;
let mut index_entries: Vec<GziEntry> = Vec::new();
let mut current_compressed_offset: u64 = 0;
let mut current_uncompressed_offset: u64 = 0;
let mut pending_blocks: BTreeMap<u64, EncodedBlock> = BTreeMap::new();
let mut next_write_id: u64 = 0;
loop {
while let Some(deflate_block) = parser.parse_block()? {
for token in deflate_block.tokens {
if matches!(token, LZ77Token::EndOfBlock) {
continue;
}
let token_size = token.uncompressed_size();
splitter.process_token(&token);
let should_emit = if use_smart {
let near_target =
pending_uncompressed_size + token_size >= self.config.block_size;
let at_good_split = splitter.is_good_split_point();
let exceeds_max = pending_uncompressed_size + token_size > max_block_size;
!pending_tokens.is_empty()
&& ((near_target && at_good_split) || exceeds_max)
} else {
pending_uncompressed_size + token_size > self.config.block_size
&& !pending_tokens.is_empty()
};
if should_emit {
let (resolved, crc, uncompressed_size) =
resolver.resolve_block(block_start_position, &pending_tokens);
let job = EncodingJob {
block_id: next_block_id,
tokens: resolved,
uncompressed_size,
crc,
};
next_block_id += 1;
send_job_and_drain(
&job_tx,
&result_rx,
job,
&mut writer,
&mut pending_blocks,
&mut next_write_id,
&mut blocks_written,
&mut output_bytes,
build_index,
&mut index_entries,
&mut current_compressed_offset,
&mut current_uncompressed_offset,
)?;
block_start_position = resolver.position();
pending_tokens.clear();
pending_uncompressed_size = 0;
splitter.reset();
}
pending_tokens.push(token);
pending_uncompressed_size += token_size;
}
}
if !parser.read_trailer_and_check_next()? {
break; }
}
if !pending_tokens.is_empty() {
let (resolved, crc, uncompressed_size) =
resolver.resolve_block(block_start_position, &pending_tokens);
let job =
EncodingJob { block_id: next_block_id, tokens: resolved, uncompressed_size, crc };
next_block_id += 1;
send_job_and_drain(
&job_tx,
&result_rx,
job,
&mut writer,
&mut pending_blocks,
&mut next_write_id,
&mut blocks_written,
&mut output_bytes,
build_index,
&mut index_entries,
&mut current_compressed_offset,
&mut current_uncompressed_offset,
)?;
}
drop(job_tx);
while blocks_written + (pending_blocks.len() as u64) < next_block_id {
match result_rx.recv() {
Ok(result) => {
let block = result?;
buffer_and_write_block(
&mut writer,
block,
&mut pending_blocks,
&mut next_write_id,
&mut blocks_written,
&mut output_bytes,
build_index,
&mut index_entries,
&mut current_compressed_offset,
&mut current_uncompressed_offset,
)?;
}
Err(_) => break,
}
}
while let Some(block) = pending_blocks.remove(&next_write_id) {
write_single_block(
&mut writer,
&block.data,
block.uncompressed_size,
&mut output_bytes,
build_index,
&mut index_entries,
&mut current_compressed_offset,
&mut current_uncompressed_offset,
)?;
blocks_written += 1;
next_write_id += 1;
}
writer.write_all(&BGZF_EOF)?;
output_bytes += 28;
writer.flush()?;
let (refs_resolved, _refs_preserved) = resolver.stats();
Ok(TranscodeStats {
input_bytes: parser.bytes_read(),
output_bytes,
blocks_written,
boundary_refs_resolved: refs_resolved,
copied_directly: false,
index_entries: if build_index { Some(index_entries) } else { None },
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Transcodes a small gzip stream with two threads and checks that the
    /// output starts with a gzip header carrying the BGZF "BC" extra subfield.
    #[test]
    fn test_parallel_transcode() {
        use std::io::Write as IoWrite;

        let payload = b"Hello, World! This is some test data for parallel transcoding.";
        let mut gz = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        gz.write_all(payload).unwrap();
        let gzip_data = gz.finish().unwrap();

        let mut transcoder =
            ParallelTranscoder::new(TranscodeConfig { num_threads: 2, ..Default::default() });
        let mut output = Vec::new();
        let stats = transcoder.transcode(Cursor::new(&gzip_data), &mut output).unwrap();

        assert!(stats.blocks_written >= 1);
        assert!(!output.is_empty());
        // gzip magic bytes
        assert_eq!(&output[..2], &[0x1f_u8, 0x8b][..]);
        // FEXTRA flag must be set
        assert_eq!(output[3] & 0x04, 0x04);
        // BGZF subfield identifier
        assert_eq!(&output[12..14], &b"BC"[..]);
    }

    /// `effective_threads` must resolve 0 to a value in 1..=32 and cap large
    /// requests at 32.
    #[test]
    fn test_effective_threads() {
        let auto_detected =
            TranscodeConfig { num_threads: 0, ..Default::default() }.effective_threads();
        assert!((1..=32).contains(&auto_detected));

        let oversized = TranscodeConfig { num_threads: 100, ..Default::default() };
        assert_eq!(oversized.effective_threads(), 32);
    }
}