use std::{
fmt,
hash::Hasher,
io::{self, Write},
};
use twox_hash::XxHash32;
use crate::{
block::{
compress::compress_internal,
hashtable::{HashTable, HashTable4K},
},
sink::vec_sink_for_compression,
};
use super::Error;
use super::{
header::{BlockInfo, BlockMode, FrameInfo, BLOCK_INFO_SIZE, MAX_FRAME_INFO_SIZE},
BlockSize,
};
use crate::block::WINDOW_SIZE;
/// A writer that compresses everything written to it into frames (frame header,
/// compressed or stored blocks, end mark, optional checksums) and forwards the
/// encoded bytes to an inner writer `W`.
///
/// Use `finish`, `try_finish` or `auto_finish` to terminate the frame.
pub struct FrameEncoder<W: io::Write> {
    /// Underlying writer that receives the encoded bytes.
    w: W,
    /// Frame parameters; a `BlockSize::Auto` block size is resolved when a frame begins.
    frame_info: FrameInfo,

    /// Input buffer. Holds pending (not yet compressed) input and, in linked block
    /// mode, earlier input that later blocks may reference.
    src: Vec<u8>,
    /// Index in `src` where not-yet-compressed input starts.
    src_start: usize,
    /// Index in `src` where buffered input ends.
    src_end: usize,
    /// Start of the external dictionary inside `src` (linked mode only).
    ext_dict_offset: usize,
    /// Length of the external dictionary inside `src`; zero when there is none.
    ext_dict_len: usize,
    /// Position offset handed to the block compressor for `src`; advanced when the
    /// buffer is recycled and rebased when it grows large.
    src_stream_offset: usize,
    /// Match-finding hash table reused across the blocks of a frame.
    compression_table: HashTable4K,
    /// Scratch buffer receiving the compressed output of one block.
    dst: Vec<u8>,

    /// Running hash of the uncompressed content, used when content checksums are enabled.
    content_hasher: XxHash32,
    /// Number of uncompressed bytes written to the current frame.
    content_len: u64,
    /// True while a frame header has been written and the frame is not yet ended.
    is_frame_open: bool,
    /// Set once `try_finish` has completed a frame; used to emit an empty frame
    /// when nothing was ever written.
    data_to_frame_written: bool,
}
impl<W: io::Write> FrameEncoder<W> {
fn init(&mut self) {
let max_block_size = self.frame_info.block_size.get_size();
let src_size = if self.frame_info.block_mode == BlockMode::Linked {
max_block_size * 2 + WINDOW_SIZE
} else {
max_block_size
};
self.src
.reserve(src_size.saturating_sub(self.src.capacity()));
self.dst.reserve(
crate::block::compress::get_maximum_output_size(max_block_size)
.saturating_sub(self.dst.capacity()),
);
}
pub fn auto_finish(self) -> AutoFinishEncoder<W> {
AutoFinishEncoder {
encoder: Some(self),
}
}
pub fn with_frame_info(frame_info: FrameInfo, wtr: W) -> Self {
FrameEncoder {
src: Vec::new(),
w: wtr,
compression_table: HashTable4K::new(),
content_hasher: XxHash32::with_seed(0),
content_len: 0,
dst: Vec::new(),
is_frame_open: false,
data_to_frame_written: false,
frame_info,
src_start: 0,
src_end: 0,
ext_dict_offset: 0,
ext_dict_len: 0,
src_stream_offset: 0,
}
}
pub fn new(wtr: W) -> Self {
Self::with_frame_info(Default::default(), wtr)
}
pub fn frame_info(&mut self) -> &FrameInfo {
&self.frame_info
}
pub fn finish(mut self) -> Result<W, Error> {
self.try_finish()?;
Ok(self.w)
}
pub fn try_finish(&mut self) -> Result<(), Error> {
match self.flush() {
Ok(()) => {
if !self.is_frame_open && !self.data_to_frame_written {
self.begin_frame(0)?;
}
self.end_frame()?;
self.data_to_frame_written = true;
Ok(())
}
Err(err) => Err(err.into()),
}
}
pub fn into_inner(self) -> W {
self.w
}
pub fn get_ref(&self) -> &W {
&self.w
}
pub fn get_mut(&mut self) -> &mut W {
&mut self.w
}
fn end_frame(&mut self) -> Result<(), Error> {
debug_assert!(self.is_frame_open);
self.is_frame_open = false;
if let Some(expected) = self.frame_info.content_size {
if expected != self.content_len {
return Err(Error::ContentLengthError {
expected,
actual: self.content_len,
});
}
}
let mut block_info_buffer = [0u8; BLOCK_INFO_SIZE];
BlockInfo::EndMark.write(&mut block_info_buffer[..])?;
self.w.write_all(&block_info_buffer[..])?;
if self.frame_info.content_checksum {
let content_checksum = self.content_hasher.finish() as u32;
self.w.write_all(&content_checksum.to_le_bytes())?;
}
Ok(())
}
fn begin_frame(&mut self, buf_len: usize) -> io::Result<()> {
self.is_frame_open = true;
if self.frame_info.block_size == BlockSize::Auto {
self.frame_info.block_size = BlockSize::from_buf_length(buf_len);
}
self.init();
let mut frame_info_buffer = [0u8; MAX_FRAME_INFO_SIZE];
let size = self.frame_info.write(&mut frame_info_buffer)?;
self.w.write_all(&frame_info_buffer[..size])?;
if self.content_len != 0 {
self.content_len = 0;
self.src_stream_offset = 0;
self.src.clear();
self.src_start = 0;
self.src_end = 0;
self.ext_dict_len = 0;
self.content_hasher = XxHash32::with_seed(0);
self.compression_table.clear();
}
Ok(())
}
fn write_block(&mut self) -> io::Result<()> {
debug_assert!(self.is_frame_open);
let max_block_size = self.frame_info.block_size.get_size();
debug_assert!(self.src_end - self.src_start <= max_block_size);
if self.src_stream_offset + max_block_size + WINDOW_SIZE >= u32::MAX as usize / 2 {
self.compression_table
.reposition((self.src_stream_offset - self.ext_dict_len) as _);
self.src_stream_offset = self.ext_dict_len;
}
let input = &self.src[..self.src_end];
let src = &input[self.src_start..];
let dst_required_size = crate::block::compress::get_maximum_output_size(src.len());
let compress_result = if self.ext_dict_len != 0 {
debug_assert_eq!(self.frame_info.block_mode, BlockMode::Linked);
compress_internal::<_, true, _>(
input,
self.src_start,
&mut vec_sink_for_compression(&mut self.dst, 0, 0, dst_required_size),
&mut self.compression_table,
&self.src[self.ext_dict_offset..self.ext_dict_offset + self.ext_dict_len],
self.src_stream_offset,
)
} else {
compress_internal::<_, false, _>(
input,
self.src_start,
&mut vec_sink_for_compression(&mut self.dst, 0, 0, dst_required_size),
&mut self.compression_table,
b"",
self.src_stream_offset,
)
};
let (block_info, block_data) = match compress_result.map_err(Error::CompressionError)? {
comp_len if comp_len < src.len() => {
(BlockInfo::Compressed(comp_len as _), &self.dst[..comp_len])
}
_ => (BlockInfo::Uncompressed(src.len() as _), src),
};
let mut block_info_buffer = [0u8; BLOCK_INFO_SIZE];
block_info.write(&mut block_info_buffer[..])?;
self.w.write_all(&block_info_buffer[..])?;
self.w.write_all(block_data)?;
if self.frame_info.block_checksums {
let block_checksum = XxHash32::oneshot(0, block_data);
self.w.write_all(&block_checksum.to_le_bytes())?;
}
if self.frame_info.content_checksum {
self.content_hasher.write(src);
}
self.content_len += src.len() as u64;
self.src_start += src.len();
debug_assert_eq!(self.src_start, self.src_end);
if self.frame_info.block_mode == BlockMode::Linked {
debug_assert_eq!(self.src.capacity(), max_block_size * 2 + WINDOW_SIZE);
if self.src_start >= max_block_size + WINDOW_SIZE {
self.ext_dict_offset = self.src_end - WINDOW_SIZE;
self.ext_dict_len = WINDOW_SIZE;
self.src_stream_offset += self.src_end;
self.src_start = 0;
self.src_end = 0;
} else if self.src_start + self.ext_dict_len > WINDOW_SIZE {
let delta = self
.ext_dict_len
.min(self.src_start + self.ext_dict_len - WINDOW_SIZE);
self.ext_dict_offset += delta;
self.ext_dict_len -= delta;
debug_assert!(self.src_start + self.ext_dict_len >= WINDOW_SIZE)
}
debug_assert!(
self.ext_dict_len == 0 || self.src_start + max_block_size <= self.ext_dict_offset
);
} else {
debug_assert_eq!(self.ext_dict_len, 0);
debug_assert_eq!(self.src.capacity(), max_block_size);
self.src_start = 0;
self.src_end = 0;
self.src_stream_offset += src.len();
}
debug_assert!(self.src_start <= self.src_end);
debug_assert!(self.src_start + max_block_size <= self.src.capacity());
Ok(())
}
}
impl<W: io::Write> io::Write for FrameEncoder<W> {
fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
if !self.is_frame_open && !buf.is_empty() {
self.begin_frame(buf.len())?;
}
let buf_len = buf.len();
while !buf.is_empty() {
let src_filled = self.src_end - self.src_start;
let max_fill_len = self.frame_info.block_size.get_size() - src_filled;
if max_fill_len == 0 {
self.write_block()?;
debug_assert_eq!(self.src_end, self.src_start);
continue;
}
let fill_len = max_fill_len.min(buf.len());
vec_copy_overwriting(&mut self.src, self.src_end, &buf[..fill_len]);
buf = &buf[fill_len..];
self.src_end += fill_len;
}
Ok(buf_len)
}
fn flush(&mut self) -> io::Result<()> {
if self.src_start != self.src_end {
self.write_block()?;
}
Ok(())
}
}
/// Wrapper around a [`FrameEncoder`] that finishes the frame when it is dropped.
///
/// Created by `FrameEncoder::auto_finish`. Errors raised while finishing on drop
/// are discarded.
pub struct AutoFinishEncoder<W: io::Write> {
    /// The wrapped encoder; taken out (leaving `None`) only when the wrapper is dropped.
    encoder: Option<FrameEncoder<W>>,
}
impl<W: io::Write> Drop for AutoFinishEncoder<W> {
    /// Finishes the wrapped encoder's frame on a best-effort basis.
    fn drop(&mut self) {
        // `drop` has no way to report failure, so the result is intentionally ignored.
        let pending = self.encoder.take();
        if let Some(mut inner) = pending {
            let _ = inner.try_finish();
        }
    }
}
impl<W: io::Write> io::Write for AutoFinishEncoder<W> {
    /// Forwards the write to the wrapped [`FrameEncoder`].
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // The encoder is only taken out in `drop`, so it is present here.
        let inner = self.encoder.as_mut().unwrap();
        inner.write(buf)
    }

    /// Forwards the flush to the wrapped [`FrameEncoder`].
    fn flush(&mut self) -> io::Result<()> {
        let inner = self.encoder.as_mut().unwrap();
        inner.flush()
    }
}
impl<W: fmt::Debug + io::Write> fmt::Debug for FrameEncoder<W> {
    /// Formats the encoder's state; the hash table and the byte buffers are
    /// elided as placeholders.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("FrameEncoder");
        out.field("w", &self.w);
        out.field("frame_info", &self.frame_info);
        out.field("is_frame_open", &self.is_frame_open);
        out.field("content_hasher", &self.content_hasher);
        out.field("content_len", &self.content_len);
        out.field("compression_table", &"{ ... }");
        out.field("data_to_frame_written", &self.data_to_frame_written);
        out.field("dst", &"[...]");
        out.field("src", &"[...]");
        out.field("src_start", &self.src_start);
        out.field("src_end", &self.src_end);
        out.field("ext_dict_offset", &self.ext_dict_offset);
        out.field("ext_dict_len", &self.ext_dict_len);
        out.field("src_stream_offset", &self.src_stream_offset);
        out.finish()
    }
}
/// Copies `src` into `target` starting at `target_start`.
///
/// Bytes that fall inside `target`'s current length are overwritten in place;
/// the remainder is appended. `target_start` must not exceed `target.len()`, and
/// the result is expected to fit in the existing capacity (checked in debug builds).
#[inline]
fn vec_copy_overwriting(target: &mut Vec<u8>, target_start: usize, src: &[u8]) {
    debug_assert!(target_start + src.len() <= target.capacity());
    let in_place = (target.len() - target_start).min(src.len());
    let (head, tail) = src.split_at(in_place);
    target[target_start..target_start + head.len()].copy_from_slice(head);
    target.extend_from_slice(tail);
}