use crate::block::{DecompressError, MINMATCH};
use crate::fastcpy_unsafe;
use crate::sink::SliceSink;
use crate::sink::{PtrSink, Sink};
#[allow(unused_imports)]
use alloc::vec::Vec;
/// Copies `match_length` bytes from `start` to `*output_ptr` and advances `*output_ptr`
/// by `match_length`.
///
/// Picks between two strategies:
/// * a 16-byte-stepped "wild" copy when both the source/destination distance and the
///   room left in the output are at least `match_length + 16 - 1` bytes, and
/// * the byte-wise `duplicate_overlapping` otherwise.
///
/// # Safety
/// `start`, `*output_ptr` and `output_end` must point into the same allocation with
/// `start <= *output_ptr <= output_end`, and `match_length` bytes must be readable at
/// `start` and writable at `*output_ptr`.
#[inline]
unsafe fn duplicate(
    output_ptr: &mut *mut u8,
    output_end: *mut u8,
    start: *const u8,
    match_length: usize,
) {
    // The wild copy works in 16-byte steps, so it can touch up to `16 - 1` bytes more
    // than `match_length`.
    let wild_copy_span = match_length + 16 - 1;
    let wild_copy_fits = (output_ptr.offset_from(start) as usize) >= wild_copy_span
        && (output_end.offset_from(*output_ptr) as usize) >= wild_copy_span;

    if wild_copy_fits {
        // `match_length` rounded up to a multiple of 16 must still end inside the output.
        debug_assert!(
            output_ptr.add(match_length / 16 * 16 + ((match_length % 16) != 0) as usize * 16)
                <= output_end
        );
        wild_copy_from_src_16(start, *output_ptr, match_length);
        *output_ptr = output_ptr.add(match_length);
    } else {
        // Source and destination are too close, or the output end is too near.
        duplicate_overlapping(output_ptr, start, match_length);
    }
}
/// Copies from `source` to `dst_ptr` in fixed 16-byte steps until at least `num_items`
/// bytes have been written.
///
/// At least one step is always performed (even for `num_items == 0`), and the total
/// written is `num_items` rounded up to a multiple of 16. The caller must make sure
/// that many bytes are readable at `source` and writable at `dst_ptr`, and that each
/// 16-byte source/destination pair does not overlap.
#[inline]
fn wild_copy_from_src_16(mut source: *const u8, mut dst_ptr: *mut u8, num_items: usize) {
    const STEP: usize = 16;
    unsafe {
        let dst_limit = dst_ptr.add(num_items);
        // Do-while shape: the first step runs unconditionally.
        let mut keep_going = true;
        while keep_going {
            core::ptr::copy_nonoverlapping(source, dst_ptr, STEP);
            source = source.add(STEP);
            dst_ptr = dst_ptr.add(STEP);
            keep_going = dst_ptr < dst_limit;
        }
    }
}
/// Byte-wise forward copy of `match_length` bytes from `start` to `*output_ptr`, valid
/// even when the source range runs into the destination range.
///
/// Each byte is written before the next one is read, so a source that overlaps the
/// destination repeats the bytes between `start` and the initial `*output_ptr`.
/// On return `*output_ptr` has been advanced by `match_length`.
///
/// # Safety
/// `match_length` bytes must be writable at `*output_ptr`, and every byte read from
/// `start` must be readable at the time it is read.
#[inline]
#[cfg_attr(feature = "nightly", optimize(size))]
unsafe fn duplicate_overlapping(
    output_ptr: &mut *mut u8,
    mut start: *const u8,
    match_length: usize,
) {
    let mut dst = *output_ptr;

    // Two bytes per iteration, strictly in order.
    for _ in 0..match_length / 2 {
        dst.write(start.read());
        dst.add(1).write(start.add(1).read());
        start = start.add(2);
        dst = dst.add(2);
    }

    // Odd length: one byte is still missing.
    if match_length % 2 == 1 {
        dst.write(start.read());
        dst = dst.add(1);
    }

    *output_ptr = dst;
}
/// Copies the part of a match that lies in the external dictionary to `*output_ptr`.
///
/// `offset` is measured backwards from the current output position and must reach past
/// the start of the output (`offset > *output_ptr - output_base`), so the match begins
/// inside `ext_dict`. At most `match_length` bytes are copied, and never more than what
/// is left in `ext_dict` from the match start.
///
/// Advances `*output_ptr` by the number of bytes copied and returns that number; a
/// return value smaller than `match_length` means the rest of the match continues in
/// the output buffer.
///
/// # Safety
/// `output_base` and `*output_ptr` must belong to the same allocation, `offset` must not
/// exceed `ext_dict.len()` plus the bytes already written, and the returned number of
/// bytes must be writable at `*output_ptr`.
#[inline]
unsafe fn copy_from_dict(
    output_base: *mut u8,
    output_ptr: &mut *mut u8,
    ext_dict: &[u8],
    offset: usize,
    match_length: usize,
) -> usize {
    debug_assert!(output_ptr.offset_from(output_base) >= 0);
    debug_assert!(offset > output_ptr.offset_from(output_base) as usize);
    debug_assert!(ext_dict.len() + output_ptr.offset_from(output_base) as usize >= offset);

    // Bytes already present in the output in front of the cursor.
    let bytes_written = output_ptr.offset_from(output_base) as usize;
    // Index in `ext_dict` where the match begins.
    let dict_start = ext_dict.len() + bytes_written - offset;
    // The match may cross from the dictionary into the output; stop at the dictionary end.
    let dict_remaining = ext_dict.len() - dict_start;
    let copy_len = usize::min(match_length, dict_remaining);

    let dict_src = ext_dict.as_ptr().add(dict_start);
    core::ptr::copy_nonoverlapping(dict_src, *output_ptr, copy_len);
    *output_ptr = output_ptr.add(copy_len);

    copy_len
}
/// Reads a variable-length integer from the input and advances `input_ptr` past it.
///
/// Bytes are summed one after another; a byte of `0xFF` means another byte follows, any
/// other byte is added and ends the integer. For example `255, 255, 4` decodes to
/// `255 + 255 + 4 = 514`.
///
/// # Errors
/// Returns [`DecompressError::ExpectedAnotherByte`] if `_input_ptr_end` is reached before
/// a terminating (non-`0xFF`) byte was read.
#[inline]
pub(super) fn read_integer_ptr(
    input_ptr: &mut *const u8,
    _input_ptr_end: *const u8,
) -> Result<usize, DecompressError> {
    let mut total: usize = 0;
    while *input_ptr < _input_ptr_end {
        let byte = unsafe { input_ptr.read() };
        *input_ptr = unsafe { input_ptr.add(1) };
        total += byte as usize;
        // Anything below 0xFF terminates the integer.
        if byte != 0xFF {
            return Ok(total);
        }
    }
    // Ran out of input while a continuation byte was still expected.
    Err(DecompressError::ExpectedAnotherByte)
}
/// Reads the little-endian 16-bit match offset at `*input_ptr` and advances the pointer
/// by two bytes.
///
/// No bounds check is performed here; the caller must guarantee that two bytes are
/// readable at `*input_ptr`.
///
/// # Errors
/// Returns [`DecompressError::OffsetZero`] if the decoded offset is `0`. The pointer has
/// already been advanced in that case.
#[inline]
fn read_match_offset(input_ptr: &mut *const u8) -> Result<u16, DecompressError> {
    let raw: [u8; 2] = unsafe {
        // `[u8; 2]` has alignment 1, so a plain read is fine at any address.
        let pair = (*input_ptr as *const [u8; 2]).read();
        *input_ptr = input_ptr.add(2);
        pair
    };
    match u16::from_le_bytes(raw) {
        0 => Err(DecompressError::OffsetZero),
        offset => Ok(offset),
    }
}
/// Mask selecting the lower nibble of a token.
///
/// Note: the decoder takes the match length from the lower nibble (`token & 0xF`) and
/// the literal length from the upper nibble (`token >> 4`), so the `LITERAL`/`MATCH`
/// suffixes of these two constants are swapped relative to the nibbles they mask.
/// `does_token_fit` treats both nibbles identically, so the result is unaffected.
const FIT_TOKEN_MASK_LITERAL: u8 = 0x0F;
/// Mask selecting the upper nibble of a token (see the note on the constant above).
const FIT_TOKEN_MASK_MATCH: u8 = 0xF0;

#[test]
fn check_token() {
    // (token, expected result of `does_token_fit`)
    let cases: [(u8, bool); 5] = [
        (15, false),
        (14, true),
        (114, true),
        (0b11110000, false),
        (0b10110000, true),
    ];
    for &(token, fits) in cases.iter() {
        assert!(does_token_fit(token) == fits);
    }
}

/// Returns `true` if neither nibble of `token` is saturated (all four bits set).
///
/// A saturated nibble signals that the corresponding length continues in extra bytes;
/// when this returns `true` both lengths are fully contained in the token itself.
#[inline]
fn does_token_fit(token: u8) -> bool {
    let low_nibble_saturated = token & FIT_TOKEN_MASK_LITERAL == FIT_TOKEN_MASK_LITERAL;
    let high_nibble_saturated = token & FIT_TOKEN_MASK_MATCH == FIT_TOKEN_MASK_MATCH;
    !low_nibble_saturated && !high_nibble_saturated
}
/// Decodes every sequence of the compressed block `input` into `output`.
///
/// Writing starts at `output.pos()`. On success the sink position is updated through
/// `set_pos` and the number of bytes written by this call is returned.
///
/// With `USE_DICT == false` the `ext_dict` argument is ignored (debug-asserted to be
/// empty) and replaced by an empty slice. With `USE_DICT == true`, match offsets that
/// reach back past the start of the output are resolved against the end of `ext_dict`.
///
/// # Errors
/// * `ExpectedAnotherByte` - `input` is empty, or it ends where a match offset, a length
///   extension byte or the next token is required.
/// * `LiteralOutOfBounds` - a literal run is longer than the remaining input.
/// * `OutputTooSmall` - a literal run or a match does not fit into the remaining
///   capacity of `output`.
/// * `OffsetOutOfBounds` - a match offset is larger than the bytes available in
///   `output` plus `ext_dict`.
/// * Any error returned by `read_match_offset`.
#[inline]
pub(crate) fn decompress_internal<const USE_DICT: bool, S: Sink>(
input: &[u8],
output: &mut S,
ext_dict: &[u8],
) -> Result<usize, DecompressError> {
// The main loop reads the first token without a bounds check, so empty input has to be
// rejected up front.
if input.is_empty() {
return Err(DecompressError::ExpectedAnotherByte);
}
// Without `USE_DICT` the dictionary is replaced by an empty slice, so every
// `ext_dict.len()` below is zero.
let ext_dict = if USE_DICT {
ext_dict
} else {
debug_assert!(ext_dict.is_empty());
&[]
};
// Output cursor setup: `output_base..output_end` spans the sink's capacity and writing
// starts at the sink's current position.
let output_base = unsafe { output.base_mut_ptr() };
let output_end = unsafe { output_base.add(output.capacity()) };
let output_start_pos_ptr = unsafe { output.base_mut_ptr().add(output.pos()) as *mut u8 };
let mut output_ptr = output_start_pos_ptr;
let mut input_ptr = input.as_ptr();
let input_ptr_end = unsafe { input.as_ptr().add(input.len()) };
// Input needed by one fast-path iteration: 16 bytes for the fixed-size literal copy,
// 2 bytes for the match offset and 1 byte for the next token. Clamped to the input
// length so the pointer subtraction below stays inside `input`.
let safe_distance_from_end = (16 + 2 + 1 ).min(input.len()) ;
let input_ptr_safe = unsafe { input_ptr_end.sub(safe_distance_from_end) };
// Output positions below `safe_output_ptr` leave room for the fast path's fixed-size
// writes: a 16-byte literal copy followed by an 18-byte match copy. With a dictionary,
// a further 17 bytes are reserved (presumably the longest dictionary copy that does
// not already complete an 18-byte match - TODO confirm against `MINMATCH`).
let safe_output_ptr = unsafe {
let mut output_num_safe_bytes = output
.capacity()
.saturating_sub(16 + 18 );
if USE_DICT {
output_num_safe_bytes = output_num_safe_bytes.saturating_sub(17);
};
output_base.add(output_num_safe_bytes)
};
// One iteration per sequence: token, literals, then (unless the input ends) a match.
loop {
// Unchecked read: every path that loops back has made sure one more byte exists.
let token = unsafe { input_ptr.read() };
input_ptr = unsafe { input_ptr.add(1) };
// Fast path: both lengths are fully contained in the token and both cursors are far
// enough from their buffer ends for the fixed-size copies below.
if does_token_fit(token)
&& (input_ptr as usize) <= input_ptr_safe as usize
&& output_ptr < safe_output_ptr
{
// Upper nibble: literal length. Lower nibble: match length minus `MINMATCH`.
let literal_length = (token >> 4) as usize;
let mut match_length = MINMATCH + (token & 0xF) as usize;
debug_assert!(
unsafe { output_ptr.add(literal_length + match_length) } <= output_end,
"{literal_length} + {match_length} {} wont fit ",
literal_length + match_length
);
// Always copy 16 bytes (the literal length is at most 14 here); only
// `literal_length` of them are kept because the cursors advance by that amount.
unsafe {
core::ptr::copy_nonoverlapping(input_ptr, output_ptr, 16);
input_ptr = input_ptr.add(literal_length);
output_ptr = output_ptr.add(literal_length);
}
debug_assert!(input_ptr_end as usize - input_ptr as usize >= 2);
let offset = read_match_offset(&mut input_ptr)? as usize;
let output_len = unsafe { output_ptr.offset_from(output_base) as usize };
// The match may not start before the beginning of dictionary + output.
if offset > output_len + ext_dict.len() {
return Err(DecompressError::OffsetOutOfBounds);
}
// The match starts inside the external dictionary.
if USE_DICT && offset > output_len {
let copied = unsafe {
copy_from_dict(output_base, &mut output_ptr, ext_dict, offset, match_length)
};
if copied == match_length {
continue;
}
// The match crosses from the dictionary into the output; the remainder is
// copied from the output below.
match_length -= copied;
}
let start_ptr = unsafe { output_ptr.sub(offset) };
debug_assert!(start_ptr >= output_base);
debug_assert!(start_ptr < output_end);
debug_assert!(unsafe { output_end.offset_from(start_ptr) as usize } >= match_length);
// If the match source ends at or before the write cursor, a single fixed-size
// 18-byte `copy` (which tolerates overlapping ranges) yields the correct first
// `match_length` bytes; the cursor only advances by `match_length`.
// Otherwise the match repeats its own output and is copied byte by byte.
if offset >= match_length {
unsafe {
core::ptr::copy(start_ptr, output_ptr, 18);
output_ptr = output_ptr.add(match_length);
}
} else {
unsafe {
duplicate_overlapping(&mut output_ptr, start_ptr, match_length);
}
}
continue;
}
// Slow path: every read and write is bounds-checked.
let mut literal_length = (token >> 4) as usize;
if literal_length != 0 {
// A nibble of 15 means the length continues in extension bytes.
if literal_length == 15 {
literal_length += read_integer_ptr(&mut input_ptr, input_ptr_end)? as usize;
}
{
// The literal run must fit into both the remaining input and the remaining output.
if literal_length > input_ptr_end as usize - input_ptr as usize {
return Err(DecompressError::LiteralOutOfBounds);
}
if literal_length > unsafe { output_end.offset_from(output_ptr) as usize } {
return Err(DecompressError::OutputTooSmall {
expected: unsafe { output_ptr.offset_from(output_base) as usize }
+ literal_length,
actual: output.capacity(),
});
}
}
unsafe {
fastcpy_unsafe::slice_copy(input_ptr, output_ptr, literal_length);
output_ptr = output_ptr.add(literal_length);
input_ptr = input_ptr.add(literal_length);
}
}
// Input exhausted after the literals: the block is complete.
if input_ptr >= input_ptr_end {
break;
}
{
// The match offset takes two bytes.
if (input_ptr_end as usize) - (input_ptr as usize) < 2 {
return Err(DecompressError::ExpectedAnotherByte);
}
}
let offset = read_match_offset(&mut input_ptr)? as usize;
let mut match_length = MINMATCH + (token & 0xF) as usize;
// A saturated lower nibble means the match length continues in extension bytes.
if match_length == MINMATCH + 15 {
match_length += read_integer_ptr(&mut input_ptr, input_ptr_end)? as usize;
}
let output_len = unsafe { output_ptr.offset_from(output_base) as usize };
{
// The match must start inside dictionary + output and fit into the remaining capacity.
if offset > output_len + ext_dict.len() {
return Err(DecompressError::OffsetOutOfBounds);
}
if match_length > unsafe { output_end.offset_from(output_ptr) as usize } {
return Err(DecompressError::OutputTooSmall {
expected: output_len + match_length,
actual: output.capacity(),
});
}
}
// The match starts inside the external dictionary.
if USE_DICT && offset > output_len {
let copied = unsafe {
copy_from_dict(output_base, &mut output_ptr, ext_dict, offset, match_length)
};
if copied == match_length {
{
// A match is never the last element of the input: the next iteration reads a
// token without a bounds check.
if input_ptr >= input_ptr_end {
return Err(DecompressError::ExpectedAnotherByte);
}
}
continue;
}
// The remainder of the match is copied from the output below.
match_length -= copied;
}
let start_ptr = unsafe { output_ptr.sub(offset) };
debug_assert!(start_ptr >= output_base);
debug_assert!(start_ptr < output_end);
debug_assert!(unsafe { output_end.offset_from(start_ptr) as usize } >= match_length);
unsafe {
duplicate(&mut output_ptr, output_end, start_ptr, match_length);
}
{
// Same as above: the next token read is unchecked, so one more byte must exist.
if input_ptr >= input_ptr_end {
return Err(DecompressError::ExpectedAnotherByte);
}
}
}
// Publish the new write position and report the bytes produced by this call.
unsafe {
output.set_pos(output_ptr.offset_from(output_base) as usize);
Ok(output_ptr.offset_from(output_start_pos_ptr) as usize)
}
}
/// Decompresses all bytes of `input` into the caller-provided `output` slice, without
/// an external dictionary.
///
/// Returns the number of bytes written to `output`.
///
/// # Errors
/// Forwards any [`DecompressError`] reported by `decompress_internal`, e.g. when
/// `output` is too small or `input` is malformed.
#[inline]
pub fn decompress_into(input: &[u8], output: &mut [u8]) -> Result<usize, DecompressError> {
    let mut sink = SliceSink::new(output, 0);
    decompress_internal::<false, _>(input, &mut sink, b"")
}
/// Decompresses all bytes of `input` into the caller-provided `output` slice, resolving
/// matches that reach before the start of the output against `ext_dict`.
///
/// Returns the number of bytes written to `output`.
///
/// # Errors
/// Forwards any [`DecompressError`] reported by `decompress_internal`, e.g. when
/// `output` is too small or `input` is malformed.
#[inline]
pub fn decompress_into_with_dict(
    input: &[u8],
    output: &mut [u8],
    ext_dict: &[u8],
) -> Result<usize, DecompressError> {
    let mut sink = SliceSink::new(output, 0);
    decompress_internal::<true, _>(input, &mut sink, ext_dict)
}
/// Decompresses all bytes of `input` into a freshly allocated `Vec`, resolving matches
/// that reach before the start of the output against `ext_dict`.
///
/// `min_uncompressed_size` is the capacity the vector is allocated with; it has to be at
/// least as large as the uncompressed data.
///
/// # Errors
/// Forwards any [`DecompressError`] reported by `decompress_internal`, e.g. when the
/// allocated capacity is too small or `input` is malformed.
#[inline]
pub fn decompress_with_dict(
    input: &[u8],
    min_uncompressed_size: usize,
    ext_dict: &[u8],
) -> Result<Vec<u8>, DecompressError> {
    let mut vec = Vec::with_capacity(min_uncompressed_size);
    let written = {
        let mut sink = PtrSink::from_vec(&mut vec, 0);
        decompress_internal::<true, _>(input, &mut sink, ext_dict)?
    };
    // SAFETY: `decompress_internal` reported `written` bytes produced from position 0
    // (assumes `PtrSink::from_vec` writes into `vec`'s own buffer - confirm in `sink`).
    unsafe {
        vec.set_len(written);
    }
    Ok(vec)
}
/// Decompresses `input` whose uncompressed size is stored in front of the compressed
/// data, as parsed by `super::uncompressed_size`.
///
/// # Errors
/// Returns an error if the size prefix cannot be read or if decompressing the remaining
/// bytes fails.
#[inline]
pub fn decompress_size_prepended(input: &[u8]) -> Result<Vec<u8>, DecompressError> {
    // Split off the size prefix; the rest is the compressed block.
    let (expected_len, compressed) = super::uncompressed_size(input)?;
    decompress(compressed, expected_len)
}
/// Decompresses all bytes of `input` into a freshly allocated `Vec`, without an
/// external dictionary.
///
/// `min_uncompressed_size` is the capacity the vector is allocated with; it has to be at
/// least as large as the uncompressed data.
///
/// # Errors
/// Forwards any [`DecompressError`] reported by `decompress_internal`, e.g. when the
/// allocated capacity is too small or `input` is malformed.
#[inline]
pub fn decompress(input: &[u8], min_uncompressed_size: usize) -> Result<Vec<u8>, DecompressError> {
    let mut vec = Vec::with_capacity(min_uncompressed_size);
    // No dictionary is involved, so use the `USE_DICT = false` instantiation (as
    // `decompress_into` does). The dictionary variant with an empty dictionary gives the
    // same result but compiles in the dictionary branches and reserves 17 extra output
    // bytes before it may use the fast path.
    let decomp_len =
        decompress_internal::<false, _>(input, &mut PtrSink::from_vec(&mut vec, 0), b"")?;
    // SAFETY: `decompress_internal` reported `decomp_len` bytes produced from position 0
    // (assumes `PtrSink::from_vec` writes into `vec`'s own buffer - confirm in `sink`).
    unsafe {
        vec.set_len(decomp_len);
    }
    Ok(vec)
}
/// Decompresses `input` whose uncompressed size is stored in front of the compressed
/// data (as parsed by `super::uncompressed_size`), using `ext_dict` as external
/// dictionary.
///
/// # Errors
/// Returns an error if the size prefix cannot be read or if decompressing the remaining
/// bytes fails.
#[inline]
pub fn decompress_size_prepended_with_dict(
    input: &[u8],
    ext_dict: &[u8],
) -> Result<Vec<u8>, DecompressError> {
    // Split off the size prefix; the rest is the compressed block.
    let (expected_len, compressed) = super::uncompressed_size(input)?;
    decompress_with_dict(compressed, expected_len, ext_dict)
}
#[cfg(test)]
mod test {
    use super::*;

    /// A block that consists of one literal run decodes to exactly those literals.
    #[test]
    fn all_literal() {
        // Token 0x30: three literal bytes, nothing after them.
        let block: &[u8] = &[0x30, b'a', b'4', b'9'];
        let decoded = decompress(block, 3).unwrap();
        assert_eq!(decoded, b"a49");
    }

    /// Inputs that stop where more bytes are required.
    #[test]
    fn incomplete_input() {
        // No token at all.
        let no_token: &[u8] = &[];
        assert!(matches!(
            decompress(no_token, 255),
            Err(DecompressError::ExpectedAnotherByte)
        ));

        // Literal length nibble is 15 but the extension byte is missing.
        let missing_literal_extension: &[u8] = &[0xF0];
        assert!(matches!(
            decompress(missing_literal_extension, 255),
            Err(DecompressError::ExpectedAnotherByte)
        ));

        // Only one of the two match-offset bytes is present.
        let half_offset: &[u8] = &[0x0F, 0];
        assert!(matches!(
            decompress(half_offset, 255),
            Err(DecompressError::ExpectedAnotherByte)
        ));

        // Match length nibble is 15 but the extension byte is missing.
        let missing_match_extension: &[u8] = &[0x0F, 1, 0];
        assert!(matches!(
            decompress(missing_match_extension, 255),
            Err(DecompressError::ExpectedAnotherByte)
        ));
    }

    /// Literals, matches and offsets that leave the input or output bounds.
    #[test]
    fn offset_oob() {
        // Four literals announced, only three bytes left in the input.
        let literal_past_input: &[u8] = &[0x40, b'a', 1, 0];
        assert!(matches!(
            decompress(literal_past_input, 4),
            Err(DecompressError::LiteralOutOfBounds)
        ));

        // Two literals do not fit into an output of capacity one.
        let literal_past_output: &[u8] = &[0x20, b'a', b'a', 1, 0];
        assert!(matches!(
            decompress(literal_past_output, 1),
            Err(DecompressError::OutputTooSmall {
                expected: 2,
                actual: 1
            })
        ));

        // One literal plus the match does not fit into an output of capacity four.
        let match_past_output: &[u8] = &[0x10, b'a', 1, 0];
        assert!(matches!(
            decompress(match_past_output, 4),
            Err(DecompressError::OutputTooSmall {
                expected: 5,
                actual: 4
            })
        ));

        // Offset 255 with nothing in the output yet.
        let far_offset: &[u8] = &[
            0x0E, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ];
        assert!(matches!(
            decompress(far_offset, 256),
            Err(DecompressError::OffsetOutOfBounds)
        ));

        // Offset 255 is larger than the 250-byte dictionary plus the empty output.
        let past_dict: &[u8] = &[0x0E, 255, 0, 0x70, 0, 0, 0, 0, 0, 0, 0];
        assert!(matches!(
            decompress_with_dict(past_dict, 256, &[0_u8; 250]),
            Err(DecompressError::OffsetOutOfBounds)
        ));

        // Offset 1 with nothing in the output yet, extended match length.
        let extended_match: &[u8] = &[0x0F, 1, 0, 1, 0x70, 0, 0, 0, 0, 0, 0, 0];
        assert!(matches!(
            decompress(extended_match, 256),
            Err(DecompressError::OffsetOutOfBounds)
        ));

        // Offset 255 after only four literal bytes of output.
        let after_literals: &[u8] = &[0x40, 0, 0, 0, 0, 255, 0, 0x70, 0, 0, 0, 0, 0, 0, 0];
        assert!(matches!(
            decompress(after_literals, 256),
            Err(DecompressError::OffsetOutOfBounds)
        ));
    }

    /// A match offset of zero is rejected.
    #[test]
    fn offset_0() {
        let zero_offset: &[u8] = &[0x0E, 0, 0, 0x70, 0, 0, 0, 0, 0, 0, 0];
        assert!(matches!(
            decompress(zero_offset, 256),
            Err(DecompressError::OffsetZero)
        ));
    }
}