use std::fmt;
use std::ops::{Deref, DerefMut};
use std::ptr;
use crate::bytes;
use crate::error::{Error, Result};
use crate::{MAX_BLOCK_SIZE, MAX_INPUT_SIZE};
/// The maximum number of slots in the hash table used to find matches.
const MAX_TABLE_SIZE: usize = 1 << 14;
/// The size of the hash table stored inline in every `Encoder`. Blocks that
/// need a bigger table use the heap allocated `big` table instead.
const SMALL_TABLE_SIZE: usize = 1 << 10;
/// The number of bytes at the end of a block in which we never look for
/// matches. Stopping short keeps the unaligned loads used during match
/// finding and literal emission in bounds.
const INPUT_MARGIN: usize = 16 - 1;
/// Blocks shorter than this are always emitted as a single literal, since
/// there is not enough room for both a match and the required input margin.
const MIN_NON_LITERAL_BLOCK_SIZE: usize = 1 + 1 + INPUT_MARGIN;
/// The two-bit tag stored in the low bits of the first byte of every
/// element in the compressed stream.
///
/// `Copy4` (a copy with a four byte offset) is never emitted by this
/// encoder: blocks are at most 64KB, so every offset fits in the two bytes
/// of a `Copy2`.
enum Tag {
Literal = 0b00,
Copy1 = 0b01,
Copy2 = 0b10,
#[allow(dead_code)]
Copy4 = 0b11,
}
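/// Returns the maximum compressed size of `input_len` bytes of input, or
/// `0` if the input is too large to be compressed.
///
/// The expected values below follow directly from the formula in the body;
/// the snippet is marked `ignore` since it is illustrative only and assumes
/// it runs with this module's items in scope:
///
/// ```ignore
/// assert_eq!(max_compress_len(0), 32);
/// assert_eq!(max_compress_len(6), 32 + 6 + 1);
/// assert_eq!(max_compress_len(60), 32 + 60 + 10);
/// ```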
pub fn max_compress_len(input_len: usize) -> usize {
let input_len = input_len as u64;
if input_len > MAX_INPUT_SIZE {
return 0;
}
    // This is the same worst case bound used by the reference Snappy
    // implementation: 32 + input_len + input_len / 6.
    let max = 32 + input_len + (input_len / 6);
if max > MAX_INPUT_SIZE {
0
} else {
max as usize
}
}
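/// An encoder that can be reused to compress many inputs.
///
/// The encoder owns the hash tables used for match finding, so reusing one
/// encoder across calls amortizes their allocation. A minimal usage sketch
/// (marked `ignore` since it is illustrative and assumes this module's
/// items are in scope):
///
/// ```ignore
/// let mut enc = Encoder::new();
/// let compressed = enc.compress_vec(b"some bytes to compress")?;
/// ```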
pub struct Encoder {
    /// A small hash table stored inline, used for small blocks.
    small: [u16; SMALL_TABLE_SIZE],
    /// A larger hash table, allocated lazily on first use and then reused.
    big: Vec<u16>,
}
impl fmt::Debug for Encoder {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Encoder(...)")
}
}
impl Encoder {
pub fn new() -> Encoder {
Encoder { small: [0; SMALL_TABLE_SIZE], big: vec![] }
}
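    /// Compresses all of `input` into `output`, returning the number of
    /// compressed bytes written.
    ///
    /// `output` must be at least `max_compress_len(input.len())` bytes
    /// long or `Error::BufferTooSmall` is returned. A minimal sketch of
    /// the calling pattern (marked `ignore` since it is illustrative only):
    ///
    /// ```ignore
    /// let mut enc = Encoder::new();
    /// let mut buf = vec![0; max_compress_len(input.len())];
    /// let n = enc.compress(input, &mut buf)?;
    /// let compressed = &buf[..n];
    /// ```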
pub fn compress(
&mut self,
mut input: &[u8],
output: &mut [u8],
) -> Result<usize> {
match max_compress_len(input.len()) {
0 => {
return Err(Error::TooBig {
given: input.len() as u64,
max: MAX_INPUT_SIZE,
});
}
min if output.len() < min => {
return Err(Error::BufferTooSmall {
given: output.len() as u64,
min: min as u64,
});
}
_ => {}
}
        // Even an empty input gets a header: its uncompressed length,
        // zero, encoded as a single varint byte.
        if input.is_empty() {
            output[0] = 0;
            return Ok(1);
        }
        // Write the total uncompressed length as a varint header.
        let mut d = bytes::write_varu64(output, input.len() as u64);
        // Compress one block at a time. Matches never cross block
        // boundaries, so every copy offset fits in two bytes.
        while !input.is_empty() {
            let mut src = input;
            if src.len() > MAX_BLOCK_SIZE {
                src = &src[..MAX_BLOCK_SIZE];
            }
            input = &input[src.len()..];
            let mut block = Block::new(src, output, d);
            if block.src.len() < MIN_NON_LITERAL_BLOCK_SIZE {
                // The block is too small for match finding to be safe or
                // worthwhile, so emit all of it as one literal.
                let lit_end = block.src.len();
                unsafe {
                    // SAFETY: output was checked above to hold at least
                    // max_compress_len bytes, so the literal fits.
                    block.emit_literal(lit_end);
                }
            } else {
                let table = self.block_table(block.src.len());
                block.compress(table);
            }
d = block.d;
}
Ok(d)
}
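    /// Like `compress`, but allocates a buffer of the maximal compressed
    /// size and truncates it to the bytes actually written.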
pub fn compress_vec(&mut self, input: &[u8]) -> Result<Vec<u8>> {
let mut buf = vec![0; max_compress_len(input.len())];
let n = self.compress(input, &mut buf)?;
buf.truncate(n);
Ok(buf)
}
}
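/// The state for compressing a single block of input: the source bytes,
/// the destination buffer, and the current read and write positions in
/// each.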
struct Block<'s, 'd> {
    /// The source bytes of this block.
    src: &'s [u8],
    /// The current position in `src`.
    s: usize,
    /// The position in `src` past which we no longer look for matches.
    s_limit: usize,
    /// The destination buffer for compressed bytes.
    dst: &'d mut [u8],
    /// The current position in `dst`.
    d: usize,
    /// The start of the pending literal run; `src[next_emit..s]` has not
    /// been emitted yet.
    next_emit: usize,
}
impl<'s, 'd> Block<'s, 'd> {
#[inline(always)]
fn new(src: &'s [u8], dst: &'d mut [u8], d: usize) -> Block<'s, 'd> {
        Block {
            src,
            s: 0,
            s_limit: src.len(),
            dst,
            d,
            next_emit: 0,
        }
}
    /// Compresses `self.src` into `self.dst`, using `table` to find
    /// matches between positions in the block.
    #[inline(always)]
    fn compress(&mut self, mut table: BlockTable<'_>) {
debug_assert!(!table.is_empty());
debug_assert!(self.src.len() >= MIN_NON_LITERAL_BLOCK_SIZE);
        // The first byte can never be part of a copy (there is nothing
        // before it to copy from), so begin the search at position 1.
        self.s += 1;
        // Stop match finding INPUT_MARGIN bytes early so the unaligned
        // loads below are always in bounds.
        self.s_limit -= INPUT_MARGIN;
        let mut next_hash =
            table.hash(bytes::read_u32_le(&self.src[self.s..]));
        loop {
            // The skip heuristic: after every 32 misses, probe one byte
            // further apart, so incompressible data is scanned quickly
            // instead of hashing at every single position.
            let mut skip = 32;
            let mut candidate;
            let mut s_next = self.s;
            loop {
                self.s = s_next;
                let bytes_between_hash_lookups = skip >> 5;
                s_next = self.s + bytes_between_hash_lookups;
                skip += bytes_between_hash_lookups;
                if s_next > self.s_limit {
                    return self.done();
                }
                unsafe {
                    // Fetch the previous position with the same hash and
                    // record the current position in its place.
                    candidate = *table.get_unchecked(next_hash) as usize;
                    *table.get_unchecked_mut(next_hash) = self.s as u16;
                    let srcp = self.src.as_ptr();
                    // Hash the bytes at the next probe position ahead of
                    // time...
                    let x = bytes::loadu_u32_le(srcp.add(s_next));
                    next_hash = table.hash(x);
                    // ...and check whether the four bytes at the candidate
                    // actually match the four bytes at the current
                    // position. Equality does not depend on byte order, so
                    // native endian loads suffice here.
                    let cur = bytes::loadu_u32_ne(srcp.add(self.s));
                    let cand = bytes::loadu_u32_ne(srcp.add(candidate));
                    if cur == cand {
                        break;
                    }
                }
            }
            // A match was found: first flush the pending literal that
            // precedes it.
            let lit_end = self.s;
            unsafe {
                // SAFETY: next_emit <= s <= src.len() and dst was sized
                // with max_compress_len, so the copy is in bounds.
                self.emit_literal(lit_end);
            }
            // Keep emitting copies for as long as another match starts
            // immediately after the previous one.
            loop {
                let base = self.s;
                // The candidate is known to match at least four bytes;
                // skip them and extend the match as far as it goes.
                self.s += 4;
                unsafe {
                    self.extend_match(candidate + 4);
                }
                let (offset, len) = (base - candidate, self.s - base);
                self.emit_copy(offset, len);
                self.next_emit = self.s;
                if self.s >= self.s_limit {
                    return self.done();
                }
                // Load eight bytes starting at s - 1 and reuse them to
                // update the table at s - 1, probe for a match at s and,
                // failing that, pre-hash position s + 1.
                unsafe {
                    let srcp = self.src.as_ptr();
                    let x = bytes::loadu_u64_le(srcp.add(self.s - 1));
                    let prev_hash = table.hash(x as u32);
                    *table.get_unchecked_mut(prev_hash) = (self.s - 1) as u16;
                    let cur_hash = table.hash((x >> 8) as u32);
                    candidate = *table.get_unchecked(cur_hash) as usize;
                    *table.get_unchecked_mut(cur_hash) = self.s as u16;
                    let y = bytes::loadu_u32_le(srcp.add(candidate));
                    if (x >> 8) as u32 != y {
                        // No immediate match; resume the outer search at
                        // the next position.
                        next_hash = table.hash((x >> 16) as u32);
                        self.s += 1;
                        break;
                    }
                }
            }
}
}
    /// Emits one or more copy operations for a match of `len` bytes found
    /// `offset` bytes back.
    #[inline(always)]
    fn emit_copy(&mut self, offset: usize, mut len: usize) {
debug_assert!(1 <= offset && offset <= 65535);
debug_assert!(4 <= len && len <= 65535);
        // A single copy operation encodes at most 64 bytes, so long
        // matches are split into several copies.
        while len >= 68 {
            self.emit_copy2(offset, 64);
            len -= 64;
        }
        // For a remainder in (64, 68), emit 60 bytes rather than 64 so
        // that the final piece is at least 4 bytes long, the minimum a
        // copy 1 operation can encode.
        if len > 64 {
            self.emit_copy2(offset, 60);
            len -= 60;
        }
        // Use the cheaper two byte copy 1 operation whenever it fits.
        if len <= 11 && offset <= 2047 {
self.dst[self.d] = (((offset >> 8) as u8) << 5)
| (((len - 4) as u8) << 2)
| (Tag::Copy1 as u8);
self.dst[self.d + 1] = offset as u8;
self.d += 2;
} else {
self.emit_copy2(offset, len);
}
}
    /// Emits a single copy operation with a two byte little endian offset
    /// (a "copy 2").
    #[inline(always)]
    fn emit_copy2(&mut self, offset: usize, len: usize) {
debug_assert!(1 <= offset && offset <= 65535);
debug_assert!(1 <= len && len <= 64);
self.dst[self.d] = (((len - 1) as u8) << 2) | (Tag::Copy2 as u8);
bytes::write_u16_le(offset as u16, &mut self.dst[self.d + 1..]);
self.d += 3;
}
    /// Extends the current match by comparing `src[s..]` against
    /// `src[cand..]` and advancing `self.s` past the bytes that match.
    ///
    /// This is unsafe because of its unchecked unaligned loads; the caller
    /// must guarantee `cand < self.s`.
    #[inline(always)]
    unsafe fn extend_match(&mut self, mut cand: usize) {
        debug_assert!(cand < self.s);
        // Compare eight bytes at a time while a full load is in bounds.
        while self.s + 8 <= self.src.len() {
            let srcp = self.src.as_ptr();
            let x = bytes::loadu_u64_ne(srcp.add(self.s));
            let y = bytes::loadu_u64_ne(srcp.add(cand));
            if x == y {
                self.s += 8;
                cand += 8;
            } else {
                // In little endian order, the trailing zeros of x ^ y
                // count how many of the lowest addressed bytes matched.
                let z = x.to_le() ^ y.to_le();
                self.s += z.trailing_zeros() as usize / 8;
                return;
            }
        }
        // Finish the tail one byte at a time.
        while self.s < self.src.len() && self.src[self.s] == self.src[cand] {
            self.s += 1;
            cand += 1;
        }
    }
    /// Emits any bytes in `src[next_emit..]` that have not yet been
    /// emitted as one final literal.
    #[inline(always)]
    fn done(&mut self) {
if self.next_emit < self.src.len() {
let lit_end = self.src.len();
unsafe {
self.emit_literal(lit_end);
}
}
}
    /// Emits `src[next_emit..lit_end]` as a literal.
    ///
    /// This is unsafe because of its raw pointer copies; the destination
    /// must have room for the literal, which the `max_compress_len` check
    /// in `compress` guarantees.
    #[inline(always)]
    unsafe fn emit_literal(&mut self, lit_end: usize) {
        let lit_start = self.next_emit;
        let len = lit_end - lit_start;
        // The tag byte encodes len - 1; literals are never empty.
        let n = len.checked_sub(1).unwrap();
        if n <= 59 {
            // Lengths up to 60 fit in the upper six bits of the tag byte.
            self.dst[self.d] = ((n as u8) << 2) | (Tag::Literal as u8);
            self.d += 1;
            // Fast path: copy a fixed 16 bytes whenever doing so cannot
            // read past the end of src. Only `len` of them count; the
            // slack in dst absorbs the overcopy.
            if len <= 16 && lit_start + 16 <= self.src.len() {
                let srcp = self.src.as_ptr().add(lit_start);
                let dstp = self.dst.as_mut_ptr().add(self.d);
                ptr::copy_nonoverlapping(srcp, dstp, 16);
                self.d += len;
                return;
            }
        } else if n < 256 {
            // Lengths 61 through 256 use one extra length byte.
            self.dst[self.d] = (60 << 2) | (Tag::Literal as u8);
            self.dst[self.d + 1] = n as u8;
            self.d += 2;
        } else {
            // Longer literals use a two byte little endian length.
            self.dst[self.d] = (61 << 2) | (Tag::Literal as u8);
            bytes::write_u16_le(n as u16, &mut self.dst[self.d + 1..]);
            self.d += 3;
        }
        let srcp = self.src.as_ptr().add(lit_start);
        let dstp = self.dst.as_mut_ptr().add(self.d);
        ptr::copy_nonoverlapping(srcp, dstp, len);
        self.d += len;
    }
}
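/// A view into the hash table used for match finding on one block,
/// together with the shift that maps a 32-bit hash down to an index into
/// it. The `Deref` impls below let the table be indexed directly.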
struct BlockTable<'a> {
table: &'a mut [u16],
shift: u32,
}
impl Encoder {
    /// Returns a cleared hash table sized for a block of `block_size`
    /// bytes: the smallest power of two that is at least `block_size`,
    /// capped at `MAX_TABLE_SIZE`, with a matching shift for `hash`.
    fn block_table(&mut self, block_size: usize) -> BlockTable<'_> {
        let mut shift: u32 = 32 - 8;
        let mut table_size = 256;
        while table_size < MAX_TABLE_SIZE && table_size < block_size {
            shift -= 1;
            table_size *= 2;
        }
        // Use the inline table when it is big enough; otherwise allocate
        // the big table once and reuse it for every subsequent block.
        let table: &mut [u16] = if table_size <= SMALL_TABLE_SIZE {
            &mut self.small[0..table_size]
        } else {
            if self.big.is_empty() {
                self.big = vec![0; MAX_TABLE_SIZE];
            }
            &mut self.big[0..table_size]
        };
        // The tables are reused across blocks, so clear stale entries.
        for x in &mut *table {
            *x = 0;
        }
        BlockTable { table, shift }
    }
}
impl<'a> BlockTable<'a> {
    /// Hashes four input bytes down to a table index. The multiplier is
    /// the magic constant used by the reference Snappy implementation;
    /// the shift keeps just enough high bits to index this table's size.
    #[inline(always)]
    fn hash(&self, x: u32) -> usize {
        (x.wrapping_mul(0x1E35A7BD) >> self.shift) as usize
    }
}
impl<'a> Deref for BlockTable<'a> {
type Target = [u16];
fn deref(&self) -> &[u16] {
self.table
}
}
impl<'a> DerefMut for BlockTable<'a> {
fn deref_mut(&mut self) -> &mut [u16] {
self.table
}
}
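#[cfg(test)]
mod tests {
    use super::*;

    // A small sanity sketch rather than a full test suite; both properties
    // below follow directly from the code in this module.

    // Even an empty input gets its varint header: a single zero byte.
    #[test]
    fn empty_input_compresses_to_single_zero_byte() {
        let mut enc = Encoder::new();
        assert_eq!(enc.compress_vec(&[]).unwrap(), vec![0]);
    }

    // Compressed output never exceeds the advertised maximum length, and
    // highly repetitive input actually shrinks.
    #[test]
    fn output_fits_in_max_compress_len() {
        let input = vec![b'a'; 1 << 12];
        let mut enc = Encoder::new();
        let compressed = enc.compress_vec(&input).unwrap();
        assert!(compressed.len() <= max_compress_len(input.len()));
        assert!(compressed.len() < input.len());
    }
}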