use super::frame;
use crate::decoding;
use crate::decoding::dictionary::Dictionary;
use crate::decoding::scratch::DecoderScratch;
use std::collections::HashMap;
use std::convert::TryInto;
use std::hash::Hasher;
use std::io::Read;
/// Streaming decoder for a single zstd frame at a time.
///
/// Holds the per-frame decoding state (absent until a frame header has been read)
/// together with the dictionaries registered through `add_dict`, keyed by dictionary id.
pub struct FrameDecoder {
    /// Per-frame state; `None` until `init`/`reset` (or `decode_from_to`) loads a frame.
    state: Option<FrameDecoderState>,
    /// Dictionaries available to frames that reference a dictionary id.
    dicts: HashMap<u32, Dictionary>,
}
/// Everything the decoder tracks while working through one frame.
struct FrameDecoderState {
    /// The parsed header of the frame currently being decoded.
    pub frame: frame::Frame,
    /// Decode buffer and per-block working data reused across the frame's blocks.
    decoder_scratch: DecoderScratch,
    /// Set once the block flagged as the last block of the frame has been decoded.
    frame_finished: bool,
    /// Number of blocks decoded so far in this frame.
    block_counter: usize,
    /// Bytes consumed from the input: frame header, block headers/bodies and checksum.
    bytes_read_counter: u64,
    /// Content checksum read from the end of the frame, once it has been read.
    check_sum: Option<u32>,
    /// Id of the dictionary currently loaded into `decoder_scratch`, if any.
    using_dict: Option<u32>,
}
/// Controls how far a single `FrameDecoder::decode_blocks` call proceeds.
///
/// Regardless of the variant, decoding always stops after the frame's last block, and
/// the limit is only checked after a block has been decoded, so at least one block is
/// decoded per call.
pub enum BlockDecodingStrategy {
    /// Decode until the end of the frame.
    All,
    /// Stop once at least this many blocks have been decoded during the call.
    UptoBlocks(usize),
    /// Stop once the decode buffer has grown by at least this many bytes during the call.
    UptoBytes(usize),
}
/// Largest frame window size (100 MiB) the decoder is willing to allocate for.
const MAX_WINDOW_SIZE: u64 = 100 * 1024 * 1024;
impl FrameDecoderState {
    /// Reads and validates a frame header from `source` and builds fresh state for it.
    ///
    /// # Errors
    /// Returns a message if the header cannot be read or fails validation, or if the
    /// frame requests a window larger than [`MAX_WINDOW_SIZE`].
    pub fn new(source: &mut dyn Read) -> Result<FrameDecoderState, String> {
        let (frame, header_size) = frame::read_frame_header(source)?;
        let window_size = frame.header.window_size()?;
        frame.check_valid()?;
        // Enforce the limit before the scratch buffer is sized from the (untrusted)
        // header; previously only `reset` did this, so a fresh decoder skipped the check.
        Self::check_window_size(window_size)?;
        Ok(FrameDecoderState {
            frame,
            frame_finished: false,
            block_counter: 0,
            decoder_scratch: DecoderScratch::new(window_size as usize),
            bytes_read_counter: u64::from(header_size),
            check_sum: None,
            using_dict: None,
        })
    }

    /// Reads and validates the next frame header from `source` and resets all per-frame
    /// state, reusing the existing `DecoderScratch` via its `reset`.
    ///
    /// On error the previous state is left untouched.
    ///
    /// # Errors
    /// Same conditions as [`FrameDecoderState::new`].
    pub fn reset(&mut self, source: &mut dyn Read) -> Result<(), String> {
        let (frame, header_size) = frame::read_frame_header(source)?;
        let window_size = frame.header.window_size()?;
        frame.check_valid()?;
        Self::check_window_size(window_size)?;
        self.frame = frame;
        self.frame_finished = false;
        self.block_counter = 0;
        self.decoder_scratch.reset(window_size as usize);
        self.bytes_read_counter = u64::from(header_size);
        self.check_sum = None;
        self.using_dict = None;
        Ok(())
    }

    /// Rejects window sizes above [`MAX_WINDOW_SIZE`].
    fn check_window_size(window_size: u64) -> Result<(), String> {
        if window_size > MAX_WINDOW_SIZE {
            return Err(format!(
                "Dont support window_sizes (requested: {}) over: {}",
                window_size, MAX_WINDOW_SIZE
            ));
        }
        Ok(())
    }
}
impl Default for FrameDecoder {
fn default() -> Self {
Self::new()
}
}
impl FrameDecoder {
pub fn new() -> FrameDecoder {
FrameDecoder {
state: None,
dicts: HashMap::new(),
}
}
pub fn init(&mut self, source: &mut dyn Read) -> Result<(), String> {
self.reset(source)
}
pub fn init_with_dict(&mut self, source: &mut dyn Read, dict: &[u8]) -> Result<(), String> {
self.reset_with_dict(source, dict)
}
pub fn reset(&mut self, source: &mut dyn Read) -> Result<(), String> {
match &mut self.state {
Some(s) => s.reset(source),
None => {
self.state = Some(FrameDecoderState::new(source)?);
Ok(())
}
}
}
pub fn reset_with_dict(&mut self, source: &mut dyn Read, dict: &[u8]) -> Result<(), String> {
self.reset(source)?;
if let Some(state) = &mut self.state {
let id = state.decoder_scratch.load_dict(dict)?;
state.using_dict = Some(id);
};
Ok(())
}
pub fn add_dict(&mut self, raw_dict: &[u8]) -> Result<(), String> {
let dict = Dictionary::decode_dict(raw_dict)?;
self.dicts.insert(dict.id, dict);
Ok(())
}
pub fn content_size(&self) -> Option<u64> {
let state = match &self.state {
None => return Some(0),
Some(s) => s,
};
match state.frame.header.frame_content_size() {
Err(_) => None,
Ok(x) => Some(x),
}
}
pub fn get_checksum_from_data(&self) -> Option<u32> {
let state = match &self.state {
None => return None,
Some(s) => s,
};
state.check_sum
}
pub fn get_calculated_checksum(&self) -> Option<u32> {
let state = match &self.state {
None => return None,
Some(s) => s,
};
let cksum_64bit = state.decoder_scratch.buffer.hash.finish();
Some(cksum_64bit as u32)
}
pub fn bytes_read_from_source(&self) -> u64 {
let state = match &self.state {
None => return 0,
Some(s) => s,
};
state.bytes_read_counter
}
pub fn is_finished(&self) -> bool {
let state = match &self.state {
None => return true,
Some(s) => s,
};
if state.frame.header.descriptor.content_checksum_flag() {
state.frame_finished && state.check_sum.is_some()
} else {
state.frame_finished
}
}
pub fn blocks_decoded(&self) -> usize {
let state = match &self.state {
None => return 0,
Some(s) => s,
};
state.block_counter
}
pub fn decode_blocks(
&mut self,
source: &mut dyn Read,
strat: BlockDecodingStrategy,
) -> Result<bool, crate::errors::FrameDecoderError> {
let state = match &mut self.state {
None => return Err(crate::errors::FrameDecoderError::NotYetInitialized),
Some(s) => s,
};
match state.frame.header.dictiornary_id() {
Ok(Some(id)) => {
match state.using_dict {
Some(using_id) => {
debug_assert!(id == using_id);
}
None => {
let dict = match self.dicts.get(&id) {
Some(dict) => dict,
None => return Err(crate::errors::FrameDecoderError::DictNotProvided),
};
state.decoder_scratch.use_dict(dict);
state.using_dict = Some(id);
}
}
}
Ok(None) => {}
Err(e) => {
return Err(crate::errors::FrameDecoderError::FailedToInitialize(e));
}
}
let mut block_dec = decoding::block_decoder::new();
let buffer_size_before = state.decoder_scratch.buffer.len();
let block_counter_before = state.block_counter;
loop {
if crate::VERBOSE {
println!("################");
println!("Next Block: {}", state.block_counter);
println!("################");
}
let (block_header, block_header_size) = match block_dec.read_block_header(source) {
Ok(h) => h,
Err(m) => return Err(crate::errors::FrameDecoderError::FailedToReadBlockHeader(m)),
};
state.bytes_read_counter += u64::from(block_header_size);
if crate::VERBOSE {
println!();
println!(
"Found {} block with size: {}, which will be of size: {}",
block_header.block_type,
block_header.content_size,
block_header.decompressed_size
);
}
let bytes_read_in_block_body = match block_dec.decode_block_content(
&block_header,
&mut state.decoder_scratch,
source,
) {
Ok(h) => h,
Err(m) => return Err(crate::errors::FrameDecoderError::FailedToReadBlockBody(m)),
};
state.bytes_read_counter += bytes_read_in_block_body;
state.block_counter += 1;
if crate::VERBOSE {
println!("Output: {}", state.decoder_scratch.buffer.len());
}
if block_header.last_block {
state.frame_finished = true;
if state.frame.header.descriptor.content_checksum_flag() {
let mut chksum = [0u8; 4];
match source.read_exact(&mut chksum) {
Err(_) => {
return Err(crate::errors::FrameDecoderError::FailedToReadChecksum)
}
Ok(()) => {
state.bytes_read_counter += 4;
let chksum = u32::from_le_bytes(chksum);
state.check_sum = Some(chksum);
}
};
}
break;
}
match strat {
BlockDecodingStrategy::All => { }
BlockDecodingStrategy::UptoBlocks(n) => {
if state.block_counter - block_counter_before >= n {
break;
}
}
BlockDecodingStrategy::UptoBytes(n) => {
if state.decoder_scratch.buffer.len() - buffer_size_before >= n {
break;
}
}
}
}
Ok(state.frame_finished)
}
pub fn collect(&mut self) -> Option<Vec<u8>> {
let finished = self.is_finished();
let state = match &mut self.state {
None => return None,
Some(s) => s,
};
if finished {
Some(state.decoder_scratch.buffer.drain())
} else {
state.decoder_scratch.buffer.drain_to_window_size()
}
}
pub fn collect_to_writer(
&mut self,
w: &mut dyn std::io::Write,
) -> Result<usize, std::io::Error> {
let finished = self.is_finished();
let state = match &mut self.state {
None => return Ok(0),
Some(s) => s,
};
if finished {
state.decoder_scratch.buffer.drain_to_writer(w)
} else {
state.decoder_scratch.buffer.drain_to_window_size_writer(w)
}
}
pub fn can_collect(&self) -> usize {
let finished = self.is_finished();
let state = match &self.state {
None => return 0,
Some(s) => s,
};
if finished {
state.decoder_scratch.buffer.can_drain()
} else {
state
.decoder_scratch
.buffer
.can_drain_to_window_size()
.unwrap_or(0)
}
}
pub fn decode_from_to(
&mut self,
source: &[u8],
target: &mut [u8],
) -> Result<(usize, usize), crate::errors::FrameDecoderError> {
let bytes_read_at_start = match &mut self.state {
Some(s) => s.bytes_read_counter,
None => 0,
};
if !self.is_finished() || self.state.is_none() {
let mut mt_source = &source[..];
if self.state.is_none() {
match self.init(&mut mt_source) {
Ok(()) => {}
Err(m) => return Err(crate::errors::FrameDecoderError::FailedToInitialize(m)),
}
}
{
let mut state = match &mut self.state {
Some(s) => s,
None => panic!("Bug in library"),
};
let mut block_dec = decoding::block_decoder::new();
if state.frame.header.descriptor.content_checksum_flag()
&& state.frame_finished
&& state.check_sum.is_none()
{
if mt_source.len() >= 4 {
let chksum = mt_source[..4].try_into().expect("optimized away");
state.bytes_read_counter += 4;
let chksum = u32::from_le_bytes(chksum);
state.check_sum = Some(chksum);
}
return Ok((4, 0));
}
match state.frame.header.dictiornary_id() {
Ok(Some(id)) => {
match state.using_dict {
Some(using_id) => {
debug_assert!(id == using_id);
}
None => {
let dict = match self.dicts.get(&id) {
Some(dict) => dict,
None => {
return Err(
crate::errors::FrameDecoderError::DictNotProvided,
)
}
};
state.decoder_scratch.use_dict(dict);
state.using_dict = Some(id);
}
}
}
Ok(None) => {}
Err(e) => {
return Err(crate::errors::FrameDecoderError::FailedToInitialize(e));
}
}
loop {
if mt_source.len() < 3 {
break;
}
let (block_header, block_header_size) =
match block_dec.read_block_header(&mut mt_source) {
Ok(h) => h,
Err(m) => {
return Err(
crate::errors::FrameDecoderError::FailedToReadBlockHeader(m),
)
}
};
if mt_source.len() < block_header.content_size as usize {
break;
}
state.bytes_read_counter += u64::from(block_header_size);
let bytes_read_in_block_body = match block_dec.decode_block_content(
&block_header,
&mut state.decoder_scratch,
&mut mt_source,
) {
Ok(h) => h,
Err(m) => {
return Err(crate::errors::FrameDecoderError::FailedToReadBlockBody(m))
}
};
state.bytes_read_counter += bytes_read_in_block_body;
state.block_counter += 1;
if block_header.last_block {
state.frame_finished = true;
if state.frame.header.descriptor.content_checksum_flag() {
if mt_source.len() >= 4 {
let chksum = mt_source[..4].try_into().expect("optimized away");
state.bytes_read_counter += 4;
let chksum = u32::from_le_bytes(chksum);
state.check_sum = Some(chksum);
}
}
break;
}
}
}
}
let result_len = match self.read(target) {
Ok(x) => x,
Err(_) => return Err(crate::errors::FrameDecoderError::FailedToDrainDecodebuffer),
};
let bytes_read_at_end = match &mut self.state {
Some(s) => s.bytes_read_counter,
None => panic!("Bug in library"),
};
let read_len = bytes_read_at_end - bytes_read_at_start;
Ok((read_len as usize, result_len))
}
}
impl std::io::Read for FrameDecoder {
    /// Copies decoded bytes into `target`, returning how many were written.
    ///
    /// After the frame's last block has been decoded the buffer's `read_all` is used.
    /// Before that its `read` is used, which presumably holds back the bytes still needed
    /// as the decoding window. Returns `Ok(0)` when no frame is loaded.
    fn read(&mut self, target: &mut [u8]) -> std::result::Result<usize, std::io::Error> {
        if let Some(frame_state) = self.state.as_mut() {
            let last_block_done = frame_state.frame_finished;
            let buf = &mut frame_state.decoder_scratch.buffer;
            if last_block_done {
                buf.read_all(target)
            } else {
                buf.read(target)
            }
        } else {
            Ok(0)
        }
    }
}