#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "zstd_decompress.h"
/* Upper bound on the regenerated size of one literals section; both literals
 * decoders reject anything larger as corruption. */
#define MAX_LITERALS_SIZE ((size_t)128 * 1024)
/* Generic max/min helpers. Arguments may be evaluated twice, so do not pass
 * expressions with side effects. */
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
/* Fatal error reporting: prints "Error: <s>" to stderr and terminates the
 * process with exit status 1. Callers never see an error return. */
#define ERROR(s) \
do { \
fprintf(stderr, "Error: %s\n", s); \
exit(1); \
} while (0)
/* Named wrappers around ERROR for the recurring failure categories:
 * truncated input, insufficient output space, malformed data, allocation
 * failure, and internal "cannot happen" states. */
#define INP_SIZE() \
ERROR("Input buffer smaller than it should be or input is " \
"corrupted")
#define OUT_SIZE() ERROR("Output buffer too small for output")
#define CORRUPTION() ERROR("Corruption detected while decompressing")
#define BAD_ALLOC() ERROR("Memory allocation error")
#define IMPOSSIBLE() ERROR("An impossibility has occurred")
/* Short aliases for the <stdint.h> fixed-width types used throughout. */
typedef uint8_t u8;
typedef uint16_t u16;
typedef uint32_t u32;
typedef uint64_t u64;
typedef int8_t i8;
typedef int16_t i16;
typedef int32_t i32;
typedef int64_t i64;
/* Output stream: `ptr` is where the next byte is written and `len` is the
 * number of bytes of space still available. */
typedef struct {
u8 *ptr;
size_t len;
} ostream_t;
/* Input stream: `ptr`/`len` cover the bytes not yet fully consumed. This
 * includes a partially consumed byte when bit_offset != 0. `bit_offset`
 * (0-7) counts the bits of *ptr that have already been read. */
typedef struct {
const u8 *ptr;
size_t len;
int bit_offset;
} istream_t;
/* Stream primitives (defined later in this file). The read/write/advance
 * functions abort via ERROR on overrun. IO_get_read_ptr and IO_advance_input
 * additionally require the input stream to be byte aligned. */
static inline u64 IO_read_bits(istream_t *const in, const int num_bits);
static inline void IO_rewind_bits(istream_t *const in, const int num_bits);
static inline void IO_align_stream(istream_t *const in);
static inline void IO_write_byte(ostream_t *const out, u8 symb);
static inline size_t IO_istream_len(const istream_t *const in);
static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len);
static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len);
static inline void IO_advance_input(istream_t *const in, size_t len);
static inline ostream_t IO_make_ostream(u8 *out, size_t len);
static inline istream_t IO_make_istream(const u8 *in, size_t len);
static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len);
/* Raw bit readers. read_bits_LE reads forward from a bit offset, least
 * significant bit first. STREAM_read_bits reads a backward bitstream
 * through a signed bit offset that it decrements. */
static inline u64 read_bits_LE(const u8 *src, const int num_bits,
const size_t offset);
static inline u64 STREAM_read_bits(const u8 *src, const int bits,
i64 *const offset);
/* Index of the most significant set bit of num, or -1 when num == 0. */
static inline int highest_set_bit(const u64 num);
/* Huffman limits: longest permitted code length and alphabet size. */
#define HUF_MAX_BITS (16)
#define HUF_MAX_SYMBS (256)
/* Huffman decoding table with 1 << max_bits entries, indexed by the decoder
 * state. symbols[s] is the byte decoded in state s, and num_bits[s] is how
 * many bits that code consumes. Both arrays are heap allocated by
 * HUF_init_dtable. */
typedef struct {
u8 *symbols;
u8 *num_bits;
int max_bits;
} HUF_dtable;
/* Huffman decoding primitives: per-symbol decode, state initialisation, and
 * whole-stream decoding for the 1- and 4-stream literal layouts. */
static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
u16 *const state, const u8 *const src,
i64 *const offset);
static inline void HUF_init_state(const HUF_dtable *const dtable,
u16 *const state, const u8 *const src,
i64 *const offset);
static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
ostream_t *const out, istream_t *const in);
static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
ostream_t *const out, istream_t *const in);
/* Table construction from code lengths or from transmitted weights, plus
 * release and deep copy. */
static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
const int num_symbs);
static void HUF_init_dtable_usingweights(HUF_dtable *const table,
const u8 *const weights,
const int num_symbs);
static void HUF_free_dtable(HUF_dtable *const dtable);
static void HUF_copy_dtable(HUF_dtable *const dst, const HUF_dtable *const src);
/* FSE limits: largest accuracy (table) log and alphabet size. */
#define FSE_MAX_ACCURACY_LOG (15)
#define FSE_MAX_SYMBS (256)
/* FSE decoding table. Its construction is not in this part of the file.
 * Presumably each state indexes symbols[], num_bits[] and new_state_base[],
 * with 1 << accuracy_log states. TODO confirm against FSE_init_dtable. */
typedef struct {
u8 *symbols;
u8 *num_bits;
u16 *new_state_base;
int accuracy_log;
} FSE_dtable;
/* State-level primitives. decode_sequence calls FSE_peek_symbol before
 * FSE_update_state, so peek presumably reads the current symbol without
 * advancing the state. */
static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
const u16 state);
static inline void FSE_update_state(const FSE_dtable *const dtable,
u16 *const state, const u8 *const src,
i64 *const offset);
static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
u16 *const state, const u8 *const src,
i64 *const offset);
static inline void FSE_init_state(const FSE_dtable *const dtable,
u16 *const state, const u8 *const src,
i64 *const offset);
/* Used for FSE-compressed Huffman weights; the return value is the number
 * of weights written. */
static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
ostream_t *const out,
istream_t *const in);
/* Table construction: from normalized frequencies, from a serialized
 * header, or for a single repeated symbol (RLE mode). */
static void FSE_init_dtable(FSE_dtable *const dtable,
const i16 *const norm_freqs, const int num_symbs,
const int accuracy_log);
static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
const int max_accuracy_log);
static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb);
static void FSE_free_dtable(FSE_dtable *const dtable);
static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src);
/* Values parsed from a frame header by parse_frame_header(). */
typedef struct {
    size_t window_size;         // Back-reference window; equals
                                // frame_content_size for single-segment frames.
    size_t frame_content_size;  // Decompressed size, or 0 when not present.
    u32 dictionary_id;          // 0 when the header carries no dictionary ID.
    int content_checksum_flag;  // Non-zero: a 4-byte checksum follows the blocks.
    int single_segment_flag;
} frame_header_t;

/* Per-frame decoding state carried from one block to the next. */
typedef struct {
    frame_header_t header;
    size_t current_total_output;  // Bytes produced so far in this frame.
    const u8 *dict_content;       // Borrowed from the dictionary; not freed here.
    size_t dict_content_len;
    HUF_dtable literals_dtable;   // Kept for blocks that repeat the table.
    FSE_dtable ll_dtable;
    FSE_dtable ml_dtable;
    FSE_dtable of_dtable;
    u64 previous_offsets[3];      // Repeat-offset history, most recent first.
} frame_context_t;

/* Parsed dictionary; owns its content buffer and entropy tables
 * (released by free_dictionary). */
struct dictionary_s {
    HUF_dtable literals_dtable;
    FSE_dtable ll_dtable;
    FSE_dtable ml_dtable;
    FSE_dtable of_dtable;
    u8 *content;
    size_t content_size;
    u64 previous_offsets[3];
    u32 dictionary_id;
};

/* One decoded sequence: copy literal_length literals, then a match of
 * match_length bytes. `offset` holds the raw offset value, which
 * compute_offset() resolves (values 1-3 select repeat offsets). */
typedef struct {
    u32 literal_length;
    u32 match_length;
    u32 offset;
} sequence_command_t;

/* Frame / block / sequence pipeline, defined later in this file. */
static void decode_frame(ostream_t *const out, istream_t *const in,
                         const dictionary_t *const dict);
static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
                             istream_t *const in);
static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
                              u8 **const literals);
static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in,
                               sequence_command_t **const sequences);
static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
                              const u8 *const literals,
                              const size_t literals_len,
                              const sequence_command_t *const sequences,
                              const size_t num_sequences);
/* The first parameter is a byte count, not a sequence; the name matches the
 * definition. */
static u32 copy_literals(const size_t literal_length, istream_t *litstream,
                         ostream_t *const out);
static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist);
static void execute_match_copy(frame_context_t *const ctx, size_t offset,
                               size_t match_length, size_t total_output,
                               ostream_t *const out);
/* Decompress one zstd frame from src into dst without a dictionary.
 * An empty dictionary is created and released around the call so the
 * dictionary-aware path can be reused. Returns the number of bytes written. */
size_t ZSTD_decompress(void *const dst, const size_t dst_len,
                       const void *const src, const size_t src_len) {
    dictionary_t *empty = create_dictionary();
    const size_t written =
        ZSTD_decompress_with_dict(dst, dst_len, src, src_len, empty);
    free_dictionary(empty);
    return written;
}
/* Decompress the single zstd frame at src into dst using parsed_dict.
 * Returns the number of bytes written to dst. Any error aborts the process. */
size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len,
                                 const void *const src, const size_t src_len,
                                 dictionary_t* parsed_dict) {
    u8 *const out_start = (u8 *)dst;
    ostream_t out = IO_make_ostream(out_start, dst_len);
    istream_t in = IO_make_istream(src, src_len);
    decode_frame(&out, &in, parsed_dict);
    // The output cursor has advanced exactly by the number of bytes produced.
    return (size_t)(out.ptr - out_start);
}
/* Frame-level helpers defined below. decode_data_frame drives one frame.
 * init_frame_context and free_frame_context manage the per-frame state.
 * parse_frame_header reads the header fields. frame_context_apply_dict
 * installs dictionary content and tables. decompress_data walks the frame's
 * blocks. */
static void decode_data_frame(ostream_t *const out, istream_t *const in,
const dictionary_t *const dict);
static void init_frame_context(frame_context_t *const context,
istream_t *const in,
const dictionary_t *const dict);
static void free_frame_context(frame_context_t *const context);
static void parse_frame_header(frame_header_t *const header,
istream_t *const in);
static void frame_context_apply_dict(frame_context_t *const ctx,
const dictionary_t *const dict);
static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
istream_t *const in);
/* Check the frame magic number and decode the data frame that follows.
 * Any other frame type (including skippable frames) is a fatal error. */
static void decode_frame(ostream_t *const out, istream_t *const in,
                         const dictionary_t *const dict) {
    // Every zstd data frame starts with this 4-byte little-endian magic.
    const u32 magic = IO_read_bits(in, 32);
    if (magic != 0xFD2FB528U) {
        ERROR("Tried to decode non-ZSTD frame");
    }
    decode_data_frame(out, in, dict);
}
/* Decode one data frame (header already past its magic number) into out. */
static void decode_data_frame(ostream_t *const out, istream_t *const in,
                              const dictionary_t *const dict) {
    frame_context_t frame;
    init_frame_context(&frame, in, dict);

    // A content size of 0 is treated as "unknown" here, so only a non-zero
    // declared size can be checked against the output capacity up front.
    const size_t declared = frame.header.frame_content_size;
    if (declared > out->len && declared != 0) {
        OUT_SIZE();
    }

    decompress_data(&frame, out, in);
    free_frame_context(&frame);
}
/* Zero the context, parse the frame header from `in`, seed the repeat-offset
 * history, then let the dictionary (if any) override content/tables/offsets. */
static void init_frame_context(frame_context_t *const context,
                               istream_t *const in,
                               const dictionary_t *const dict) {
    memset(context, 0, sizeof(*context));
    parse_frame_header(&context->header, in);

    // Repeat-offset history defaults to {1, 4, 8} before any dictionary.
    static const u64 initial_offsets[3] = {1, 4, 8};
    memcpy(context->previous_offsets, initial_offsets, sizeof(initial_offsets));

    frame_context_apply_dict(context, dict);
}
/* Release the decoding tables owned by the context and wipe it. Dictionary
 * content is only borrowed, so it is not freed here. */
static void free_frame_context(frame_context_t *const context) {
    HUF_free_dtable(&context->literals_dtable);

    FSE_dtable *const fse_tables[] = {&context->ll_dtable, &context->ml_dtable,
                                      &context->of_dtable};
    for (size_t t = 0; t < sizeof(fse_tables) / sizeof(fse_tables[0]); t++) {
        FSE_free_dtable(fse_tables[t]);
    }

    memset(context, 0, sizeof(*context));
}
/*
 * Parse a frame header (everything after the magic number) into `header`.
 *
 * Frame_Header_Descriptor layout, most significant bit first:
 *   7-6 Frame_Content_Size_flag   5 Single_Segment_flag   4 (not examined)
 *   3   reserved, must be 0       2 Content_Checksum_flag 1-0 Dictionary_ID_flag
 * It is followed by the optional Window_Descriptor, Dictionary_ID and
 * Frame_Content_Size fields, in that order.
 */
static void parse_frame_header(frame_header_t *const header,
                               istream_t *const in) {
    const u8 descriptor = IO_read_bits(in, 8);
    const u8 fcs_flag = descriptor >> 6;
    const u8 single_segment = (descriptor >> 5) & 1;
    const u8 checksum = (descriptor >> 2) & 1;
    const u8 dict_id_flag = descriptor & 3;

    if ((descriptor >> 3) & 1) {
        CORRUPTION();
    }
    header->single_segment_flag = single_segment;
    header->content_checksum_flag = checksum;

    // Window_Descriptor: 5-bit exponent, 3-bit mantissa in eighths of the base.
    if (!single_segment) {
        const u8 window_descriptor = IO_read_bits(in, 8);
        const size_t base = (size_t)1 << (10 + (window_descriptor >> 3));
        header->window_size = base + (base / 8) * (window_descriptor & 7);
    }

    static const int dict_id_bytes[4] = {0, 1, 2, 4};
    header->dictionary_id =
        dict_id_flag ? (u32)IO_read_bits(in, dict_id_bytes[dict_id_flag] * 8)
                     : 0;

    if (fcs_flag == 0 && !single_segment) {
        // Field absent: content size unknown.
        header->frame_content_size = 0;
    } else {
        static const int fcs_bytes[4] = {1, 2, 4, 8};
        const int nbytes = fcs_bytes[fcs_flag];
        header->frame_content_size = IO_read_bits(in, nbytes * 8);
        // The 2-byte encoding is stored with a bias of 256.
        if (nbytes == 2) {
            header->frame_content_size += 256;
        }
    }

    // Single-segment frames have no Window_Descriptor; the window is the content.
    if (single_segment) {
        header->window_size = header->frame_content_size;
    }
}
/* Attach a dictionary to the frame context. Content is always shared when
 * present. Entropy tables (deep-copied) and repeat offsets are installed only
 * when dictionary_id is non-zero; raw-content dictionaries leave it at 0. */
static void frame_context_apply_dict(frame_context_t *const ctx,
                                     const dictionary_t *const dict) {
    // No dictionary, or an empty one: nothing to apply.
    if (dict == NULL || dict->content == NULL) {
        return;
    }

    const u32 wanted = ctx->header.dictionary_id;
    if (wanted != 0 && wanted != dict->dictionary_id) {
        ERROR("Wrong dictionary provided");
    }

    ctx->dict_content = dict->content;
    ctx->dict_content_len = dict->content_size;

    if (dict->dictionary_id == 0) {
        return;
    }
    HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable);
    FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable);
    FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable);
    FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable);
    memcpy(ctx->previous_offsets, dict->previous_offsets,
           sizeof(ctx->previous_offsets));
}
/* Decode the blocks of a frame until the one flagged as last, then skip the
 * optional 4-byte content checksum (it is not verified). */
static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
                            istream_t *const in) {
    for (;;) {
        // 3-byte block header: Last_Block (1 bit), Block_Type (2), Block_Size (21).
        const int is_last = IO_read_bits(in, 1);
        const int kind = IO_read_bits(in, 2);
        const size_t size = IO_read_bits(in, 21);

        if (kind == 0) {
            // Raw block: copy the payload verbatim.
            const u8 *const from = IO_get_read_ptr(in, size);
            u8 *const to = IO_get_write_ptr(out, size);
            memcpy(to, from, size);
            ctx->current_total_output += size;
        } else if (kind == 1) {
            // RLE block: one payload byte repeated `size` times.
            const u8 *const byte = IO_get_read_ptr(in, 1);
            u8 *const to = IO_get_write_ptr(out, size);
            memset(to, byte[0], size);
            ctx->current_total_output += size;
        } else if (kind == 2) {
            // Compressed block: decode from a sub-stream of exactly `size` bytes.
            istream_t block = IO_make_sub_istream(in, size);
            decompress_block(ctx, out, &block);
        } else if (kind == 3) {
            // Reserved block type.
            CORRUPTION();
        } else {
            IMPOSSIBLE();
        }

        if (is_last) {
            break;
        }
    }

    if (ctx->header.content_checksum_flag) {
        IO_advance_input(in, 4);
    }
}
/* Decode one compressed block: the literals section, then the sequences
 * section, then replay the sequences to produce output. */
static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
                             istream_t *const in) {
    u8 *lits = NULL;
    sequence_command_t *seqs = NULL;

    const size_t lits_len = decode_literals(ctx, in, &lits);
    const size_t seq_count = decode_sequences(ctx, in, &seqs);
    execute_sequences(ctx, out, lits, lits_len, seqs, seq_count);

    // Both buffers were heap-allocated by the decoders above.
    free(seqs);
    free(lits);
}
/* Literals-section helpers. decode_literals_simple handles raw/RLE sections.
 * decode_literals_compressed handles Huffman-coded ones. decode_huf_table
 * reads a Huffman table description. fse_decode_hufweights decodes
 * FSE-compressed Huffman weights. */
static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
const int block_type,
const int size_format);
static size_t decode_literals_compressed(frame_context_t *const ctx,
istream_t *const in,
u8 **const literals,
const int block_type,
const int size_format);
static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in);
static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
int *const num_symbs);
/* Decode a block's literals section into a new heap buffer stored in
 * *literals (caller frees). Returns the number of literal bytes. */
static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
                              u8 **const literals) {
    // Literals_Section_Header begins with Literals_Block_Type, then Size_Format.
    const int block_type = IO_read_bits(in, 2);
    const int size_format = IO_read_bits(in, 2);

    // Types 0 (raw) and 1 (RLE) are stored directly. Types 2 and 3 are
    // Huffman coded (3 reuses the previous table).
    return block_type > 1
               ? decode_literals_compressed(ctx, in, literals, block_type,
                                            size_format)
               : decode_literals_simple(in, literals, block_type, size_format);
}
/*
 * Decode a Raw (block_type 0) or RLE (block_type 1) literals section.
 *
 * size_format selects the width of Regenerated_Size:
 *   0 or 2: only one format bit is used, so the second bit already read is
 *           given back and becomes the low bit of a 5-bit size;
 *   1:      12-bit size;   3: 20-bit size.
 *
 * On return *literals owns a heap buffer (caller frees) holding the
 * regenerated literals. The return value is its length in bytes.
 */
static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
                                     const int block_type,
                                     const int size_format) {
    size_t size;
    switch (size_format) {
    case 0:
    case 2:
        IO_rewind_bits(in, 1);
        size = IO_read_bits(in, 5);
        break;
    case 1:
        size = IO_read_bits(in, 12);
        break;
    case 3:
        size = IO_read_bits(in, 20);
        break;
    default:
        IMPOSSIBLE();
    }

    if (size > MAX_LITERALS_SIZE) {
        CORRUPTION();
    }

    // Empty literals sections are valid. malloc(0) may return NULL, which
    // would be misreported as an allocation failure (and a NULL buffer must
    // not be handed to memcpy/memset), so always request at least one byte.
    *literals = malloc(size ? size : 1);
    if (!*literals) {
        BAD_ALLOC();
    }

    switch (block_type) {
    case 0: {
        // Raw: the literals follow the header verbatim.
        const u8 *const read_ptr = IO_get_read_ptr(in, size);
        memcpy(*literals, read_ptr, size);
        break;
    }
    case 1: {
        // RLE: a single byte repeated `size` times.
        const u8 *const read_ptr = IO_get_read_ptr(in, 1);
        memset(*literals, read_ptr[0], size);
        break;
    }
    default:
        IMPOSSIBLE();
    }
    return size;
}
/*
 * Decode a Huffman-compressed literals section.
 *
 * block_type 2 carries a new Huffman table, which replaces
 * ctx->literals_dtable. block_type 3 reuses the table from a previous block
 * and is corrupt if none exists. size_format 0 means a single Huffman stream
 * with 10-bit sizes. Formats 1, 2 and 3 mean four streams with 10-, 14- and
 * 18-bit sizes.
 *
 * On return *literals owns a heap buffer of the returned length (caller frees).
 */
static size_t decode_literals_compressed(frame_context_t *const ctx,
                                         istream_t *const in,
                                         u8 **const literals,
                                         const int block_type,
                                         const int size_format) {
    size_t regenerated_size, compressed_size;
    // Every format except 0 splits the literals into four streams.
    int num_streams = 4;
    switch (size_format) {
    case 0:
        num_streams = 1;
        // Same 10-bit size fields as format 1.
        /* fallthrough */
    case 1:
        regenerated_size = IO_read_bits(in, 10);
        compressed_size = IO_read_bits(in, 10);
        break;
    case 2:
        regenerated_size = IO_read_bits(in, 14);
        compressed_size = IO_read_bits(in, 14);
        break;
    case 3:
        regenerated_size = IO_read_bits(in, 18);
        compressed_size = IO_read_bits(in, 18);
        break;
    default:
        IMPOSSIBLE();
    }
    if (regenerated_size > MAX_LITERALS_SIZE ||
        compressed_size >= regenerated_size) {
        CORRUPTION();
    }

    *literals = malloc(regenerated_size);
    if (!*literals) {
        BAD_ALLOC();
    }

    ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size);
    istream_t huf_stream = IO_make_sub_istream(in, compressed_size);

    if (block_type == 2) {
        // A new table description precedes the Huffman streams.
        HUF_free_dtable(&ctx->literals_dtable);
        decode_huf_table(&ctx->literals_dtable, &huf_stream);
    } else if (!ctx->literals_dtable.symbols) {
        // Repeat mode requires a table from an earlier block.
        CORRUPTION();
    }

    size_t symbols_decoded;
    if (num_streams == 1) {
        symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
    } else {
        symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
    }

    if (symbols_decoded != regenerated_size) {
        CORRUPTION();
    }
    return regenerated_size;
}
/* Read a Huffman tree description from `in` and build `dtable` from it.
 * A header byte < 128 is the byte length of FSE-compressed weights. A header
 * byte >= 128 introduces (header - 127) raw 4-bit weights. */
static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) {
    const u8 header = IO_read_bits(in, 8);
    u8 weights[HUF_MAX_SYMBS] = {0};
    int num_symbs;

    if (header < 128) {
        istream_t fse_stream = IO_make_sub_istream(in, header);
        ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS);
        fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs);
    } else {
        // Direct representation: two weights per byte, high nibble first.
        num_symbs = header - 127;
        const u8 *const packed =
            IO_get_read_ptr(in, (size_t)((num_symbs + 1) / 2));
        for (int i = 0; i < num_symbs; i++) {
            const u8 byte = packed[i / 2];
            weights[i] = (i & 1) ? (byte & 0xf) : (byte >> 4);
        }
    }

    HUF_init_dtable_usingweights(dtable, weights, num_symbs);
}
/* Decode FSE-compressed Huffman weights from `in` into `weights`, storing
 * how many weights were produced in *num_symbs. */
static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
                                  int *const num_symbs) {
    // Accuracy-log ceiling passed to FSE_decode_header for Huffman weights.
    const int weight_max_accuracy_log = 7;

    FSE_dtable table;
    FSE_decode_header(&table, in, weight_max_accuracy_log);
    const size_t decoded = FSE_decompress_interleaved2(&table, weights, in);
    *num_symbs = decoded;
    FSE_free_dtable(&table);
}
/* FSE states for the three sequence fields. The tables are struct copies of
 * the frame context's tables, so they share its buffers. */
typedef struct {
    FSE_dtable ll_table;
    FSE_dtable of_table;
    FSE_dtable ml_table;
    u16 ll_state;
    u16 of_state;
    u16 ml_state;
} sequence_states_t;

/* Which sequence field a table decodes; indexes the per-field constants. */
typedef enum {
    seq_literal_length = 0,
    seq_offset = 1,
    seq_match_length = 2,
} seq_part_t;

/* Symbol compression mode from the sequences header (2 bits per field). */
typedef enum {
    seq_predefined = 0,
    seq_rle = 1,
    seq_fse = 2,
    seq_repeat = 3,
} seq_mode_t;

/* Predefined normalized distributions (accuracy logs 6, 5 and 6). -1 marks
 * a "less than 1" probability, per the format specification. */
static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = {
    4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1};
static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = {
    1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1};
static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = {
    1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1};

/* Literals length = baseline[code] + (extra_bits[code] bits from the stream).
 * Code 35 has baseline 65536 with 16 extra bits. Each baseline must equal the
 * previous baseline plus 2^(previous extra bits): 32768 + 2^15 = 65536. */
static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = {
    0,  1,  2,   3,   4,   5,    6,    7,    8,    9,     10,    11,
    12, 13, 14,  15,  16,  18,   20,   22,   24,   28,    32,    40,
    48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536};
static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = {
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  1,  1,
    1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

/* Match length = baseline[code] + (extra_bits[code] bits from the stream). */
static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = {
    3,  4,   5,   6,   7,    8,    9,    10,   11,    12,    13,   14, 15, 16,
    17, 18,  19,  20,  21,   22,   23,   24,   25,    26,    27,   28, 29, 30,
    31, 32,  33,  34,  35,   37,   39,   41,   43,    47,    51,   59, 67, 83,
    99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539};
static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = {
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,  1,  1,  1,
    2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

/* Largest valid code per seq_part_t. The offset entry (-1, i.e. 255 as u8)
 * is unused: offset codes are not looked up in a baseline table. */
static const u8 SEQ_MAX_CODES[3] = {35, -1, 52};
/* Sequence-section helpers. decompress_sequences decodes every sequence of a
 * block. decode_sequence decodes one sequence from the backward bitstream.
 * decode_seq_table builds or reuses one of the LL/OF/ML decoding tables. */
static void decompress_sequences(frame_context_t *const ctx,
istream_t *const in,
sequence_command_t *const sequences,
const size_t num_sequences);
static sequence_command_t decode_sequence(sequence_states_t *const state,
const u8 *const src,
i64 *const offset);
static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
const seq_part_t type, const seq_mode_t mode);
/* Parse Number_of_Sequences and decode that many sequences into a new heap
 * array stored in *sequences (caller frees; NULL when there are none).
 * Returns the sequence count. */
static size_t decode_sequences(frame_context_t *const ctx, istream_t *in,
                               sequence_command_t **const sequences) {
    // Number_of_Sequences is a 1-3 byte variable-length field keyed on byte0.
    const u8 byte0 = IO_read_bits(in, 8);
    if (byte0 == 0) {
        *sequences = NULL;
        return 0;
    }

    size_t count;
    if (byte0 < 128) {
        count = byte0;
    } else if (byte0 != 255) {
        count = ((size_t)(byte0 - 128) << 8) + IO_read_bits(in, 8);
    } else {
        count = 0x7F00 + IO_read_bits(in, 16);
    }

    sequence_command_t *const buf = malloc(count * sizeof(*buf));
    if (buf == NULL) {
        BAD_ALLOC();
    }
    *sequences = buf;
    decompress_sequences(ctx, in, buf, count);
    return count;
}
/*
 * Decode `num_sequences` (> 0) sequences into `sequences`.
 *
 * Reads the Symbol_Compression_Modes byte and rebuilds or reuses the LL/OF/ML
 * tables in ctx accordingly. The rest of `in` is then consumed as one
 * backward bitstream. That stream must be consumed exactly; leftover or
 * missing bits are treated as corruption.
 */
static void decompress_sequences(frame_context_t *const ctx, istream_t *in,
                                 sequence_command_t *const sequences,
                                 const size_t num_sequences) {
    // Bits 7-6: LL mode, 5-4: OF mode, 3-2: ML mode, 1-0: reserved (must be 0).
    u8 compression_modes = IO_read_bits(in, 8);
    if ((compression_modes & 3) != 0) {
        CORRUPTION();
    }

    // Tables are described in this order: literal lengths, offsets, match lengths.
    decode_seq_table(&ctx->ll_dtable, in, seq_literal_length,
                     (compression_modes >> 6) & 3);
    decode_seq_table(&ctx->of_dtable, in, seq_offset,
                     (compression_modes >> 4) & 3);
    decode_seq_table(&ctx->ml_dtable, in, seq_match_length,
                     (compression_modes >> 2) & 3);

    // Struct copies: the states share (do not own) the context's table buffers.
    sequence_states_t states;
    states.ll_table = ctx->ll_dtable;
    states.of_table = ctx->of_dtable;
    states.ml_table = ctx->ml_dtable;

    const size_t len = IO_istream_len(in);
    // At least one byte is needed to hold the end-mark bit. Without this
    // check, src[len - 1] below would read before the buffer.
    if (len == 0) {
        CORRUPTION();
    }
    const u8 *const src = IO_get_read_ptr(in, len);
    // The last byte must contain the 1-bit end marker. A zero byte would make
    // highest_set_bit() return -1 and produce a bogus 9-bit padding.
    if (src[len - 1] == 0) {
        CORRUPTION();
    }
    const int padding = 8 - highest_set_bit(src[len - 1]);
    // The stream is read backwards, starting just below the end marker.
    i64 bit_offset = len * 8 - padding;

    // Initial states are stored in the order LL, OF, ML.
    FSE_init_state(&states.ll_table, &states.ll_state, src, &bit_offset);
    FSE_init_state(&states.of_table, &states.of_state, src, &bit_offset);
    FSE_init_state(&states.ml_table, &states.ml_state, src, &bit_offset);

    for (size_t i = 0; i < num_sequences; i++) {
        sequences[i] = decode_sequence(&states, src, &bit_offset);
    }

    if (bit_offset != 0) {
        CORRUPTION();
    }
}
/*
 * Decode one sequence from the backward bitstream `src` at *offset.
 *
 * Codes are peeked from the current FSE states. Their extra bits are read in
 * the order offset, match length, literal length. The states are then
 * advanced (LL, ML, OF) unless the bitstream is already fully consumed; in a
 * well-formed stream that only happens after the last sequence.
 */
static sequence_command_t decode_sequence(sequence_states_t *const states,
                                          const u8 *const src,
                                          i64 *const offset) {
    const u8 of_code = FSE_peek_symbol(&states->of_table, states->of_state);
    const u8 ll_code = FSE_peek_symbol(&states->ll_table, states->ll_state);
    const u8 ml_code = FSE_peek_symbol(&states->ml_table, states->ml_state);

    // LL/ML codes index fixed-size baseline tables. The offset code is used
    // as a shift count on a 32-bit value, so a code >= 32 (possible with a
    // corrupt table) would be undefined behaviour and cannot describe a u32
    // offset.
    if (ll_code > SEQ_MAX_CODES[seq_literal_length] ||
        ml_code > SEQ_MAX_CODES[seq_match_length] || of_code > 31) {
        CORRUPTION();
    }

    sequence_command_t seq;
    // Offset_Value = 2^of_code + of_code extra bits.
    seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset);
    seq.match_length =
        SEQ_MATCH_LENGTH_BASELINES[ml_code] +
        STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset);
    seq.literal_length =
        SEQ_LITERAL_LENGTH_BASELINES[ll_code] +
        STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset);

    // Don't read state-update bits once the stream is exhausted.
    if (*offset != 0) {
        FSE_update_state(&states->ll_table, &states->ll_state, src, offset);
        FSE_update_state(&states->ml_table, &states->ml_state, src, offset);
        FSE_update_state(&states->of_table, &states->of_state, src, offset);
    }
    return seq;
}
/* Prepare the decoding table for one sequence field according to `mode`:
 * predefined distribution, single-symbol RLE, an FSE header read from `in`,
 * or (repeat) the table left by a previous block, which must exist. */
static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
                             const seq_part_t type, const seq_mode_t mode) {
    // Per-field parameters, indexed by seq_part_t (LL, OF, ML).
    static const i16 *const predefined_dist[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST,
                                                  SEQ_OFFSET_DEFAULT_DIST,
                                                  SEQ_MATCH_LENGTH_DEFAULT_DIST};
    static const size_t predefined_len[] = {36, 29, 53};
    static const size_t predefined_log[] = {6, 5, 6};
    static const size_t max_log[] = {9, 8, 9};

    if (mode == seq_repeat) {
        if (!table->symbols) {
            CORRUPTION();
        }
        return;
    }

    // Every other mode replaces the table, so release the old one first.
    FSE_free_dtable(table);

    if (mode == seq_predefined) {
        FSE_init_dtable(table, predefined_dist[type], predefined_len[type],
                        predefined_log[type]);
    } else if (mode == seq_rle) {
        const u8 symb = IO_get_read_ptr(in, 1)[0];
        FSE_init_dtable_rle(table, symb);
    } else if (mode == seq_fse) {
        FSE_decode_header(table, in, max_log[type]);
    } else {
        IMPOSSIBLE();
    }
}
/* Replay decoded sequences: for each one, copy its literals, resolve its
 * offset (updating the repeat history), then copy the match. Any literals
 * left after the last sequence are appended at the end. */
static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
                              const u8 *const literals,
                              const size_t literals_len,
                              const sequence_command_t *const sequences,
                              const size_t num_sequences) {
    istream_t lits = IO_make_istream(literals, literals_len);
    u64 *const history = ctx->previous_offsets;
    size_t produced = ctx->current_total_output;

    for (size_t n = 0; n < num_sequences; n++) {
        const sequence_command_t cmd = sequences[n];
        produced += copy_literals(cmd.literal_length, &lits, out);
        const size_t off = compute_offset(cmd, history);
        execute_match_copy(ctx, off, cmd.match_length, produced, out);
        produced += cmd.match_length;
    }

    const size_t tail = IO_istream_len(&lits);
    copy_literals(tail, &lits, out);
    produced += tail;

    ctx->current_total_output = produced;
}
/* Move literal_length bytes from the literals stream to the output.
 * Returns literal_length. Requesting more than remain is corruption. */
static u32 copy_literals(const size_t literal_length, istream_t *litstream,
                         ostream_t *const out) {
    if (IO_istream_len(litstream) < literal_length) {
        CORRUPTION();
    }
    u8 *const dst = IO_get_write_ptr(out, literal_length);
    memcpy(dst, IO_get_read_ptr(litstream, literal_length), literal_length);
    return literal_length;
}
/*
 * Resolve a sequence's raw offset value into an actual match distance and
 * update the three-entry repeat-offset history (most recent first).
 *
 * Values > 3 encode a new offset of (value - 3). Values 1-3 select a repeat
 * offset. When literal_length is 0 the selection shifts by one, and the
 * shifted value 3 means "most recent offset minus one".
 */
static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) {
    if (seq.offset > 3) {
        const size_t fresh = seq.offset - 3;
        offset_hist[2] = offset_hist[1];
        offset_hist[1] = offset_hist[0];
        offset_hist[0] = fresh;
        return fresh;
    }

    const u32 slot = seq.offset - 1 + (seq.literal_length == 0 ? 1 : 0);
    if (slot == 0) {
        // Reusing the most recent offset leaves the history unchanged.
        return offset_hist[0];
    }

    const size_t chosen = (slot < 3) ? offset_hist[slot] : offset_hist[0] - 1;
    // Slot 1 only swaps the first two entries; deeper slots also shift entry 2.
    if (slot >= 2) {
        offset_hist[2] = offset_hist[1];
    }
    offset_hist[1] = offset_hist[0];
    offset_hist[0] = chosen;
    return chosen;
}
/*
 * Append match_length bytes copied from `offset` bytes back in the output.
 *
 * total_output is the number of bytes produced in this frame so far. While
 * it does not exceed the window size, the offset may reach back past the
 * start of the frame into the dictionary content, whose tail is copied first.
 * Otherwise the offset must lie within the window.
 */
static void execute_match_copy(frame_context_t *const ctx, size_t offset,
                               size_t match_length, size_t total_output,
                               ostream_t *const out) {
    u8 *write_ptr = IO_get_write_ptr(out, match_length);

    // An offset of 0 would copy each output byte onto itself, reading bytes
    // that were never written. A valid stream never yields it, but corrupted
    // repeat offsets can (e.g. "most recent offset - 1" when that offset is 1).
    if (offset == 0) {
        CORRUPTION();
    }

    if (total_output <= ctx->header.window_size) {
        // The offset may extend into the dictionary, but not beyond it.
        if (offset > total_output + ctx->dict_content_len) {
            CORRUPTION();
        }
        if (offset > total_output) {
            const size_t dict_copy =
                MIN(offset - total_output, match_length);
            const size_t dict_offset =
                ctx->dict_content_len - (offset - total_output);
            memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy);
            write_ptr += dict_copy;
            match_length -= dict_copy;
        }
    } else if (offset > ctx->header.window_size) {
        CORRUPTION();
    }

    // Byte-by-byte on purpose: when match_length > offset the source overlaps
    // bytes written earlier in this same loop (e.g. "abc" + offset 3, length 6
    // yields "abcabcabc"), which memcpy/memmove would not reproduce.
    for (size_t j = 0; j < match_length; j++) {
        *write_ptr = *(write_ptr - offset);
        write_ptr++;
    }
}
/* Return the Frame_Content_Size recorded in the header of the zstd frame at
 * src, or (size_t)-1 when it is unknown. Only the header is parsed. Aborts if
 * src does not start with the zstd magic number. */
size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) {
    istream_t in = IO_make_istream(src, src_len);

    const u32 magic = IO_read_bits(&in, 32);
    if (magic != 0xFD2FB528U) {
        ERROR("ZSTD frame magic number did not match");
    }

    frame_header_t header;
    parse_frame_header(&header, &in);

    // A stored 0 is only trusted for single-segment frames, where the size
    // field is always present; otherwise 0 is reported as unknown.
    const int size_known =
        header.frame_content_size != 0 || header.single_segment_flag;
    return size_known ? header.frame_content_size : (size_t)-1;
}
#define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes")
/* No trailing semicolon: the call site supplies it, so `NULL_SRC();` acts as
 * a single statement (also inside an unbraced if/else). */
#define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src")

/* Allocate a zero-initialized, empty dictionary (no content, no tables).
 * Release it with free_dictionary(). Aborts on allocation failure. */
dictionary_t* create_dictionary(void) {
    dictionary_t* dict = calloc(1, sizeof(dictionary_t));
    if (!dict) {
        BAD_ALLOC();
    }
    return dict;
}
/* Copies the remaining bytes of `in` into a new dict->content buffer. */
static void init_dictionary_content(dictionary_t *const dict,
istream_t *const in);
/*
 * Parse `src` into `dict`, which is reset first.
 *
 * Without the dictionary magic number, the whole buffer is raw content and
 * dictionary_id stays 0. Otherwise the layout is: magic, dictionary ID,
 * Huffman literals table, FSE tables for offsets, match lengths and literal
 * lengths, three 4-byte recent offsets, then the content.
 */
void parse_dictionary(dictionary_t *const dict, const void *src,
                      size_t src_len) {
    const u8 *byte_src = (const u8 *)src;
    memset(dict, 0, sizeof(dictionary_t));
    if (src == NULL) {
        NULL_SRC();
    }
    if (src_len < 8) {
        DICT_SIZE_ERROR();
    }

    istream_t in = IO_make_istream(byte_src, src_len);
    const u32 magic_number = IO_read_bits(&in, 32);
    if (magic_number != 0xEC30A437) {
        // Raw-content dictionary: give back the 4 bytes just read.
        IO_rewind_bits(&in, 32);
        init_dictionary_content(dict, &in);
        return;
    }

    dict->dictionary_id = IO_read_bits(&in, 32);
    decode_huf_table(&dict->literals_dtable, &in);
    decode_seq_table(&dict->of_dtable, &in, seq_offset, seq_fse);
    decode_seq_table(&dict->ml_dtable, &in, seq_match_length, seq_fse);
    decode_seq_table(&dict->ll_dtable, &in, seq_literal_length, seq_fse);

    dict->previous_offsets[0] = IO_read_bits(&in, 32);
    dict->previous_offsets[1] = IO_read_bits(&in, 32);
    dict->previous_offsets[2] = IO_read_bits(&in, 32);

    init_dictionary_content(dict, &in);

    // Recent offsets must be non-zero and may not reach past the start of the
    // dictionary content. The check is against the content size, not the
    // whole dictionary buffer, and matches the reference library's validation.
    for (int i = 0; i < 3; i++) {
        if (dict->previous_offsets[i] == 0 ||
            dict->previous_offsets[i] > dict->content_size) {
            ERROR("Dictionary corrupted");
        }
    }
}
/* Copy all remaining bytes of `in` into a newly allocated dict->content
 * buffer (owned by dict and released by free_dictionary) and record its
 * size. `in` must be byte aligned. */
static void init_dictionary_content(dictionary_t *const dict,
                                    istream_t *const in) {
    dict->content_size = IO_istream_len(in);
    // Allocate at least one byte. malloc(0) may return NULL, which would be
    // misreported as out-of-memory, and a NULL content pointer is treated as
    // "no dictionary" when a frame context applies it.
    dict->content = malloc(dict->content_size ? dict->content_size : 1);
    if (!dict->content) {
        BAD_ALLOC();
    }
    const u8 *const content = IO_get_read_ptr(in, dict->content_size);
    memcpy(dict->content, content, dict->content_size);
}
/* Release everything owned by a dictionary from create_dictionary(), then the
 * dictionary itself. Passing NULL is a no-op, mirroring free(). */
void free_dictionary(dictionary_t *const dict) {
    if (!dict) {
        return;
    }
    HUF_free_dtable(&dict->literals_dtable);
    FSE_free_dtable(&dict->ll_dtable);
    FSE_free_dtable(&dict->of_dtable);
    FSE_free_dtable(&dict->ml_dtable);
    free(dict->content);
    // Wipe before freeing so stale pointers in the struct are not reused.
    memset(dict, 0, sizeof(dictionary_t));
    free(dict);
}
#define UNALIGNED() ERROR("Attempting to operate on a non-byte aligned stream")

/* Read num_bits (1..64) bits, least significant first, from the current bit
 * position of `in` and advance past them. Aborts on an invalid bit count or
 * when the stream holds too few bytes. */
static inline u64 IO_read_bits(istream_t *const in, const int num_bits) {
    if (num_bits <= 0 || num_bits > 64) {
        ERROR("Attempt to read an invalid number of bits");
    }

    const int end_bit = in->bit_offset + num_bits;
    // Every byte the read touches (rounded up) must be present.
    if ((size_t)((end_bit + 7) / 8) > in->len) {
        INP_SIZE();
    }

    const u64 value = read_bits_LE(in->ptr, num_bits, in->bit_offset);

    // Only fully consumed bytes are skipped; a partial byte stays current.
    const size_t consumed = end_bit / 8;
    in->ptr += consumed;
    in->len -= consumed;
    in->bit_offset = end_bit % 8;
    return value;
}
/* Move the read position of `in` back by num_bits bits (num_bits >= 0).
 * ptr/len are adjusted by whole bytes and bit_offset stays in [0, 7].
 * There is no bounds check: callers must only rewind over bits already read
 * from this stream. */
static inline void IO_rewind_bits(istream_t *const in, int num_bits) {
if (num_bits < 0) {
ERROR("Attempting to rewind stream by a negative number of bits");
}
// Bit position relative to the current byte; may become negative.
const int new_offset = in->bit_offset - num_bits;
// Whole bytes to step back: 0 when new_offset is in [0, 7], otherwise
// ceil(-new_offset / 8), e.g. -5 -> 1, -8 -> 1, -9 -> 2.
const i64 bytes = -(new_offset - 7) / 8;
in->ptr -= bytes;
in->len += bytes;
// C's % can yield a negative result (e.g. -22 % 8 == -6); normalize to [0, 7].
in->bit_offset = ((new_offset % 8) + 8) % 8;
}
/* Skip the rest of a partially consumed byte so `in` is byte aligned. */
static inline void IO_align_stream(istream_t *const in) {
    if (in->bit_offset == 0) {
        return;  // already on a byte boundary
    }
    if (in->len == 0) {
        INP_SIZE();
    }
    in->bit_offset = 0;
    in->len--;
    in->ptr++;
}
/* Append one byte to `out`, aborting if no space remains. */
static inline void IO_write_byte(ostream_t *const out, u8 symb) {
    if (!out->len) {
        OUT_SIZE();
    }
    *out->ptr++ = symb;
    out->len--;
}
/* Bytes left in the input stream; a partially consumed byte still counts. */
static inline size_t IO_istream_len(const istream_t *const in) {
    const size_t remaining = in->len;
    return remaining;
}
/* Claim the next `len` bytes of a byte-aligned input stream and return a
 * pointer to them. Aborts if too few bytes remain or the stream is mid-byte. */
static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) {
    if (in->len < len) {
        INP_SIZE();
    }
    if (in->bit_offset) {
        UNALIGNED();
    }
    const u8 *const start = in->ptr;
    in->len -= len;
    in->ptr = start + len;
    return start;
}
/* Reserve the next `len` bytes of the output stream and return a pointer to
 * them for the caller to fill. Aborts if the output is too small. */
static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) {
    if (out->len < len) {
        OUT_SIZE();
    }
    u8 *const start = out->ptr;
    out->len -= len;
    out->ptr = start + len;
    return start;
}
/* Skip `len` bytes of a byte-aligned input stream without reading them. */
static inline void IO_advance_input(istream_t *const in, size_t len) {
    if (in->len < len) {
        INP_SIZE();
    }
    if (in->bit_offset) {
        UNALIGNED();
    }
    in->len -= len;
    in->ptr += len;
}
/* Wrap a caller-owned buffer of `len` bytes as an output stream. */
static inline ostream_t IO_make_ostream(u8 *out, size_t len) {
    ostream_t stream;
    stream.ptr = out;
    stream.len = len;
    return stream;
}
/* Wrap `len` bytes at `in` as a byte-aligned input stream. */
static inline istream_t IO_make_istream(const u8 *in, size_t len) {
    return (istream_t){.ptr = in, .len = len, .bit_offset = 0};
}
/* Carve the next `len` bytes off a byte-aligned parent stream and return
 * them as an independent input stream; the parent advances past them. */
static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) {
    istream_t sub = {IO_get_read_ptr(in, len), len, 0};
    return sub;
}
/* Read num_bits bits starting `offset` bits into src, treating src as a
 * little-endian bit sequence (bit 0 is the LSB of src[0]). The result is
 * right-aligned in a u64. num_bits <= 0 returns 0 and num_bits > 64 aborts.
 * src is not bounds-checked: callers must guarantee that every byte touched
 * exists. */
static inline u64 read_bits_LE(const u8 *src, const int num_bits,
const size_t offset) {
if (num_bits > 64) {
ERROR("Attempt to read an invalid number of bits");
}
// Skip whole bytes and keep the sub-byte remainder.
src += offset / 8;
size_t bit_offset = offset % 8;
u64 res = 0;
int shift = 0;
int left = num_bits;
while (left > 0) {
// Keep at most `left` bits of this byte (0xff when 8 or more remain; the
// right shift below has already dropped the bits under bit_offset).
u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1);
res += (((u64)*src++ >> bit_offset) & mask) << shift;
// Only the first byte can start mid-byte; later bytes start at bit 0.
shift += 8 - bit_offset;
left -= 8 - bit_offset;
bit_offset = 0;
}
return res;
}
/* Read `bits` bits from a backward bitstream. *offset is the bit index just
 * past the next unread bit and is decremented by `bits` before reading.
 * If that takes *offset below 0, only the bits that exist are read, and the
 * result is shifted left so the missing low-order bits read as zero. The
 * negative *offset is kept so callers can check how far past the start of
 * the stream they went. */
static inline u64 STREAM_read_bits(const u8 *const src, const int bits,
i64 *const offset) {
*offset = *offset - bits;
// Converted to size_t here; replaced with 0 below when *offset is negative.
size_t actual_off = *offset;
size_t actual_bits = bits;
if (*offset < 0) {
actual_bits += *offset;
actual_off = 0;
}
u64 res = read_bits_LE(src, actual_bits, actual_off);
if (*offset < 0) {
// Shifting a u64 by 64 or more is undefined, hence the explicit 0.
res = -*offset >= 64 ? 0 : (res << -*offset);
}
return res;
}
/* Index (0-63) of the most significant set bit of num, or -1 if num is 0. */
static inline int highest_set_bit(const u64 num) {
    int bit = 63;
    // Short-circuit stops before the shift when bit reaches -1.
    while (bit >= 0 && ((u64)1 << bit) > num) {
        bit--;
    }
    return bit;
}
/* Decode one Huffman symbol in the current state, then shift out the bits its
 * code used and shift in that many fresh bits from the backward stream. */
static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
                                   u16 *const state, const u8 *const src,
                                   i64 *const offset) {
    const u16 current = *state;
    const u8 out_symbol = dtable->symbols[current];
    const u8 used_bits = dtable->num_bits[current];
    const u16 fresh_bits = STREAM_read_bits(src, used_bits, offset);
    // The state is always kept to max_bits bits.
    const u16 state_mask = ((u16)1 << dtable->max_bits) - 1;
    *state = ((current << used_bits) + fresh_bits) & state_mask;
    return out_symbol;
}
/* The initial Huffman state is the first max_bits bits of the backward
 * stream. */
static inline void HUF_init_state(const HUF_dtable *const dtable,
                                  u16 *const state, const u8 *const src,
                                  i64 *const offset) {
    const u8 width = dtable->max_bits;
    const u16 initial = STREAM_read_bits(src, width, offset);
    *state = initial;
}
/*
 * Decode one Huffman bitstream (all of `in`) into `out` and return the number
 * of symbols written. The stream is read backwards from just below its final
 * 1-bit marker. Decoding stops once the state has consumed every real bit,
 * and the stream must end exactly there.
 */
static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
                                     ostream_t *const out,
                                     istream_t *const in) {
    const size_t len = IO_istream_len(in);
    if (len == 0) {
        INP_SIZE();
    }
    const u8 *const src = IO_get_read_ptr(in, len);

    // The highest set bit of the last byte is the end marker, so that byte
    // can never be 0. Reject it explicitly: highest_set_bit(0) is -1, which
    // would otherwise produce a bogus 9-bit padding.
    if (src[len - 1] == 0) {
        CORRUPTION();
    }
    const int padding = 8 - highest_set_bit(src[len - 1]);
    // Bit index just past the last payload bit; decreases as bits are read.
    i64 bit_offset = len * 8 - padding;

    u16 state;
    HUF_init_state(dtable, &state, src, &bit_offset);

    size_t symbols_written = 0;
    // The state carries max_bits bits of lookahead, so every real bit has been
    // decoded once bit_offset is max_bits below the start of the stream.
    while (bit_offset > -dtable->max_bits) {
        IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &bit_offset));
        symbols_written++;
    }
    if (bit_offset != -dtable->max_bits) {
        CORRUPTION();
    }
    return symbols_written;
}
/* Decode the 4-stream Huffman layout: three little-endian 16-bit stream sizes,
 * then the four streams, the last one taking whatever remains. Streams are
 * decoded one after another into `out`; returns the total symbol count. */
static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
                                     ostream_t *const out, istream_t *const in) {
    size_t sizes[3];
    for (int k = 0; k < 3; k++) {
        sizes[k] = IO_read_bits(in, 16);
    }

    istream_t streams[4];
    for (int k = 0; k < 3; k++) {
        streams[k] = IO_make_sub_istream(in, sizes[k]);
    }
    streams[3] = IO_make_sub_istream(in, IO_istream_len(in));

    size_t decoded = 0;
    for (int k = 0; k < 4; k++) {
        decoded += HUF_decompress_1stream(dtable, out, &streams[k]);
    }
    return decoded;
}
/*
 * Build a canonical Huffman decoding table from per-symbol code lengths.
 *
 * bits[i] is the code length of symbol i (0 = symbol absent). The table has
 * 1 << max_bits entries, and a symbol of length b occupies
 * 1 << (max_bits - b) consecutive entries. Longer codes are placed first, and
 * symbols of equal length keep their natural order. Aborts if the lengths do
 * not fill the table exactly.
 */
static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
                            const int num_symbs) {
    memset(table, 0, sizeof(HUF_dtable));
    if (num_symbs > HUF_MAX_SYMBS) {
        ERROR("Too many symbols for Huffman");
    }

    // Count symbols per code length and find the deepest code.
    u8 max_bits = 0;
    u16 rank_count[HUF_MAX_BITS + 1];
    memset(rank_count, 0, sizeof(rank_count));
    for (int i = 0; i < num_symbs; i++) {
        if (bits[i] > HUF_MAX_BITS) {
            ERROR("Huffman table depth too large");
        }
        max_bits = MAX(max_bits, bits[i]);
        rank_count[bits[i]]++;
    }

    const size_t table_size = 1 << max_bits;
    table->max_bits = max_bits;
    table->symbols = malloc(table_size);
    table->num_bits = malloc(table_size);
    if (!table->symbols || !table->num_bits) {
        free(table->symbols);
        free(table->num_bits);
        BAD_ALLOC();
    }

    // rank_idx[b] is the first table index for codes of length b, and
    // rank_idx[b - 1] is one past the last such index.
    u32 rank_idx[HUF_MAX_BITS + 1];
    rank_idx[max_bits] = 0;
    for (int i = max_bits; i >= 1; i--) {
        rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i));
    }

    // Validate before writing: if the lengths over-subscribe the table, the
    // ranges would run past table_size and the fills below would overflow.
    if (rank_idx[0] != table_size) {
        CORRUPTION();
    }

    // Every entry in a length-b range consumes b bits.
    for (int i = max_bits; i >= 1; i--) {
        memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]);
    }

    // Hand out code ranges; the low (max_bits - b) state bits are "don't care".
    for (int i = 0; i < num_symbs; i++) {
        if (bits[i] != 0) {
            const u16 code = rank_idx[bits[i]];
            const u16 len = 1 << (max_bits - bits[i]);
            memset(&table->symbols[code], i, len);
            rank_idx[bits[i]] += len;
        }
    }
}
/*
 * Build a Huffman decoding table from transmitted symbol weights.
 *
 * `weights` holds num_symbs weights. The weight of one further symbol (index
 * num_symbs) is not transmitted: it is reconstructed as the amount needed to
 * bring the sum of 2^(w-1) up to the next power of two. A weight w > 0 maps
 * to a code length of max_bits + 1 - w, and a weight of 0 means the symbol is
 * absent. The resulting lengths are passed to HUF_init_dtable().
 *
 * Aborts via ERROR() on too many symbols and via CORRUPTION() on out-of-range
 * weights or when the missing weight cannot be reconstructed.
 */
static void HUF_init_dtable_usingweights(HUF_dtable *const table,
                                         const u8 *const weights,
                                         const int num_symbs) {
    // One extra slot is required for the implied final symbol.
    if (num_symbs >= HUF_MAX_SYMBS) {
        ERROR("Too many symbols for Huffman");
    }

    u64 total = 0;
    for (int s = 0; s < num_symbs; s++) {
        const u8 w = weights[s];
        if (w > HUF_MAX_BITS) {
            CORRUPTION();
        }
        if (w != 0) {
            total += (u64)1 << (w - 1);
        }
    }

    // For a nonzero total, 1 << max_bits is the smallest power of two above it.
    const int max_bits = highest_set_bit(total) + 1;
    const u64 gap = ((u64)1 << max_bits) - total;
    // The implied symbol fills the gap by itself, so the gap must be a power of two.
    if ((gap & (gap - 1)) != 0) {
        CORRUPTION();
    }
    const int implied_weight = highest_set_bit(gap) + 1;

    u8 bits[HUF_MAX_SYMBS];
    for (int s = 0; s < num_symbs; s++) {
        bits[s] = weights[s] == 0 ? 0 : (u8)(max_bits + 1 - weights[s]);
    }
    bits[num_symbs] = (u8)(max_bits + 1 - implied_weight);

    HUF_init_dtable(table, bits, num_symbs + 1);
}
/* Release both lookup arrays owned by the table and reset it to all-zero. */
static void HUF_free_dtable(HUF_dtable *const dtable) {
    free(dtable->num_bits);
    free(dtable->symbols);
    memset(dtable, 0, sizeof *dtable);
}
/*
 * Deep-copy the Huffman decoding table `src` into `dst`.
 *
 * A source with max_bits == 0 is treated as empty and produces an all-zero
 * `dst`. Otherwise fresh arrays of 1 << max_bits bytes are allocated and
 * filled. `dst`'s previous arrays are not freed. Aborts via BAD_ALLOC() on
 * allocation failure.
 */
static void HUF_copy_dtable(HUF_dtable *const dst,
                            const HUF_dtable *const src) {
    if (src->max_bits != 0) {
        const size_t entries = (size_t)1 << src->max_bits;
        dst->max_bits = src->max_bits;
        dst->symbols = malloc(entries);
        dst->num_bits = malloc(entries);
        if (dst->symbols == NULL || dst->num_bits == NULL) {
            BAD_ALLOC();
        }
        memcpy(dst->symbols, src->symbols, entries);
        memcpy(dst->num_bits, src->num_bits, entries);
    } else {
        memset(dst, 0, sizeof(HUF_dtable));
    }
}
/* Return the symbol associated with `state` without advancing the decoder. */
static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
                                 const u16 state) {
    const u8 symbol = dtable->symbols[state];
    return symbol;
}
/*
 * Advance an FSE state: read num_bits[state] bits from the backward bitstream
 * at *offset and add them to new_state_base[state] to form the next state.
 */
static inline void FSE_update_state(const FSE_dtable *const dtable,
                                    u16 *const state, const u8 *const src,
                                    i64 *const offset) {
    const u16 cur = *state;
    const u8 nbits = dtable->num_bits[cur];
    const u16 extra = STREAM_read_bits(src, nbits, offset);
    *state = dtable->new_state_base[cur] + extra;
}
/*
 * Return the symbol for the current state, then advance the state by
 * reading its required bits from the backward bitstream.
 */
static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
                                   u16 *const state, const u8 *const src,
                                   i64 *const offset) {
    const u8 out_symbol = dtable->symbols[*state];
    FSE_update_state(dtable, state, src, offset);
    return out_symbol;
}
/*
 * Seed an FSE decoder state with dtable->accuracy_log bits read from the
 * backward bitstream at *offset.
 */
static inline void FSE_init_state(const FSE_dtable *const dtable,
                                  u16 *const state, const u8 *const src,
                                  i64 *const offset) {
    const u8 width = dtable->accuracy_log;
    const u16 seed = STREAM_read_bits(src, width, offset);
    *state = seed;
}
/*
 * Decode an FSE bitstream that interleaves two decoder states, appending each
 * decoded symbol to `out`. The stream occupies all remaining bytes of `in`.
 *
 * The stream is read backward. The highest set bit of its last byte is a
 * final-bit flag, so a last byte of 0 is invalid. State 1 is initialised
 * first, then state 2. The states then alternate, starting with state 1, so
 * state 1 yields the even-indexed outputs and state 2 the odd-indexed ones.
 * When a state update drives the read position below 0, the other state's
 * current symbol is emitted as the final symbol and decoding stops.
 *
 * Returns the number of symbols written. Aborts via INP_SIZE() on an empty
 * stream and via CORRUPTION() when the final-bit flag is missing.
 */
static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
                                          ostream_t *const out,
                                          istream_t *const in) {
    const size_t len = IO_istream_len(in);
    if (len == 0) {
        INP_SIZE();
    }
    const u8 *const src = IO_get_read_ptr(in, len);

    // Without any set bit in the last byte there is no final-bit flag, so the
    // padding computed below would not describe a real end of stream.
    if (src[len - 1] == 0) {
        CORRUPTION();
    }
    const int padding = 8 - highest_set_bit(src[len - 1]);
    i64 offset = len * 8 - padding;

    u16 state1, state2;
    FSE_init_state(dtable, &state1, src, &offset);
    FSE_init_state(dtable, &state2, src, &offset);

    size_t symbols_written = 0;
    while (1) {
        IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset));
        symbols_written++;
        if (offset < 0) {
            // Stream exhausted: state2 still holds one undecoded symbol.
            IO_write_byte(out, FSE_peek_symbol(dtable, state2));
            symbols_written++;
            break;
        }

        IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset));
        symbols_written++;
        if (offset < 0) {
            // Stream exhausted: state1 still holds one undecoded symbol.
            IO_write_byte(out, FSE_peek_symbol(dtable, state1));
            symbols_written++;
            break;
        }
    }
    return symbols_written;
}
/*
 * Build an FSE decoding table from normalized symbol counts.
 *
 * For s in [0, num_symbs), norm_freqs[s] is interpreted as follows:
 *   - a positive value gives symbol s that many cells, spread across the table;
 *   - 0 gives it no cells;
 *   - -1 ("less than 1") gives it a single cell, allocated downward from the
 *     top of the table.
 * The counts, with -1 counting as 1, must add up to exactly 1 << accuracy_log.
 *
 * Allocates dtable->symbols, dtable->num_bits and dtable->new_state_base,
 * each with 1 << accuracy_log entries. Aborts via ERROR() on an oversized
 * accuracy/symbol count, via CORRUPTION() on an inconsistent distribution,
 * and via BAD_ALLOC() on allocation failure.
 */
static void FSE_init_dtable(FSE_dtable *const dtable,
                            const i16 *const norm_freqs, const int num_symbs,
                            const int accuracy_log) {
    if (accuracy_log > FSE_MAX_ACCURACY_LOG) {
        ERROR("FSE accuracy too large");
    }
    if (num_symbs > FSE_MAX_SYMBS) {
        ERROR("Too many symbols for FSE");
    }
    const size_t size = (size_t)1 << accuracy_log;

    // Validate the distribution before building anything.
    // - More -1 entries than `size` would drive high_threshold negative and
    //   index before the start of `symbols`.
    // - high_threshold == 0 with a positive count would make the spreading
    //   do/while below spin forever.
    // - An under-full total would leave cells that are read but never set.
    size_t total = 0;
    for (int s = 0; s < num_symbs; s++) {
        if (norm_freqs[s] < -1) {
            CORRUPTION();
        }
        total += norm_freqs[s] == -1 ? 1 : (size_t)norm_freqs[s];
    }
    if (total != size) {
        CORRUPTION();
    }

    dtable->accuracy_log = accuracy_log;
    dtable->symbols = malloc(size * sizeof(u8));
    dtable->num_bits = malloc(size * sizeof(u8));
    dtable->new_state_base = malloc(size * sizeof(u16));
    if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
        BAD_ALLOC();
    }

    // Per-symbol counter used to derive each cell's bit count and next-state
    // base. It is u16 because it can grow past the original count.
    u16 state_desc[FSE_MAX_SYMBS];

    // "Less than 1" symbols each take one cell, allocated from the top down.
    int high_threshold = size;
    for (int s = 0; s < num_symbs; s++) {
        if (norm_freqs[s] == -1) {
            dtable->symbols[--high_threshold] = s;
            state_desc[s] = 1;
        }
    }

    // Spread the remaining symbols over cells [0, high_threshold) by stepping
    // a fixed stride modulo `size` and skipping cells already taken at the top.
    const u16 step = (size >> 1) + (size >> 3) + 3;
    const u16 mask = size - 1;
    u16 pos = 0;
    for (int s = 0; s < num_symbs; s++) {
        if (norm_freqs[s] <= 0) {
            continue;
        }
        state_desc[s] = norm_freqs[s];
        for (int i = 0; i < norm_freqs[s]; i++) {
            dtable->symbols[pos] = s;
            do {
                pos = (pos + step) & mask;
            } while (pos >= high_threshold);
        }
    }
    // A complete spread must bring the walk back to its starting cell.
    if (pos != 0) {
        CORRUPTION();
    }

    // Walk the cells in table order. Each symbol's counter starts at its count
    // and increments once per cell, which fixes how many bits the cell reads
    // and the base of the next state.
    for (size_t i = 0; i < size; i++) {
        const u8 symbol = dtable->symbols[i];
        const u16 next_state_desc = state_desc[symbol]++;
        dtable->num_bits[i] = (u8)(accuracy_log - highest_set_bit(next_state_desc));
        dtable->new_state_base[i] =
            ((u16)next_state_desc << dtable->num_bits[i]) - size;
    }
}
/*
 * Parse a serialized FSE distribution from `in` and build `dtable` from it.
 *
 * Layout as read below:
 *   - 4 bits: accuracy_log - 5. The result must not exceed max_accuracy_log.
 *   - For each symbol, a variable-width value whose width depends on how many
 *     probability points remain. The probability is that value minus 1, and
 *     -1 means "less than 1", counting as 1 toward the total.
 *   - After a probability of 0, 2-bit repeat flags give the number of further
 *     zero-probability symbols. A flag of 3 is followed by another flag.
 * Parsing stops once the remaining points reach 0 or go below it, and the
 * stream is then aligned via IO_align_stream().
 *
 * Aborts via ERROR() on an oversized accuracy log, and via CORRUPTION() when
 * the probabilities do not total exactly 1 << accuracy_log or the symbol
 * count reaches FSE_MAX_SYMBS.
 */
static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
const int max_accuracy_log) {
if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) {
ERROR("FSE accuracy too large");
}
// Accuracy log is stored as a 4-bit value offset by 5.
const int accuracy_log = 5 + IO_read_bits(in, 4);
if (accuracy_log > max_accuracy_log) {
ERROR("FSE accuracy too large");
}
// Probability points still to be distributed among symbols.
i32 remaining = 1 << accuracy_log;
i16 frequencies[FSE_MAX_SYMBS];
int symb = 0;
while (remaining > 0 && symb < FSE_MAX_SYMBS) {
// Enough bits to represent any value in [0, remaining + 1].
int bits = highest_set_bit(remaining + 1) + 1;
u16 val = IO_read_bits(in, bits);
const u16 lower_mask = ((u16)1 << (bits - 1)) - 1;
// Number of unused codes in a `bits`-wide field. Values below this
// threshold are encoded with one bit fewer.
const u16 threshold = ((u16)1 << bits) - 1 - (remaining + 1);
if ((val & lower_mask) < threshold) {
// Short encoding: give back the top bit that was read and keep only the
// low bits (presumably IO_rewind_bits moves the read position back).
IO_rewind_bits(in, 1);
val = val & lower_mask;
} else if (val > lower_mask) {
// Long encoding with the top bit set: remove the short-code offset.
val = val - threshold;
}
// Stored value is probability + 1; -1 means "less than 1" and counts as 1.
const i16 proba = (i16)val - 1;
remaining -= proba < 0 ? -proba : proba;
frequencies[symb] = proba;
symb++;
if (proba == 0) {
// A zero probability is followed by 2-bit repeat flags giving how many
// more zero-probability symbols follow; a flag of 3 chains another flag.
int repeat = IO_read_bits(in, 2);
while (1) {
for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) {
frequencies[symb++] = 0;
}
if (repeat == 3) {
repeat = IO_read_bits(in, 2);
} else {
break;
}
}
}
}
IO_align_stream(in);
// The distribution must total exactly 1 << accuracy_log. A negative
// `remaining` means the last symbol overshot the total.
if (remaining != 0 || symb >= FSE_MAX_SYMBS) {
CORRUPTION();
}
FSE_init_dtable(dtable, frequencies, symb, accuracy_log);
}
/*
 * Build a single-state FSE table for run-length (RLE) mode.
 *
 * The only entry (state 0) yields `symb`, has num_bits 0 and
 * new_state_base 0. accuracy_log is set to 0, so initialising a state from
 * this table requests 0 bits. Aborts via BAD_ALLOC() on allocation failure.
 */
static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) {
    dtable->symbols = malloc(sizeof(u8));
    dtable->num_bits = malloc(sizeof(u8));
    dtable->new_state_base = malloc(sizeof(u16));
    const int all_allocated = dtable->symbols != NULL &&
                              dtable->num_bits != NULL &&
                              dtable->new_state_base != NULL;
    if (!all_allocated) {
        BAD_ALLOC();
    }
    *dtable->symbols = symb;
    *dtable->num_bits = 0;
    *dtable->new_state_base = 0;
    dtable->accuracy_log = 0;
}
/* Release all three arrays owned by the table and reset it to all-zero. */
static void FSE_free_dtable(FSE_dtable *const dtable) {
    free(dtable->new_state_base);
    free(dtable->num_bits);
    free(dtable->symbols);
    memset(dtable, 0, sizeof *dtable);
}
/*
 * Deep-copy the FSE decoding table `src` into `dst`.
 *
 * A table with no allocated arrays (symbols == NULL), such as one that was
 * zeroed or released with FSE_free_dtable(), is copied as an all-zero table.
 * Every other table gets freshly allocated arrays of 1 << accuracy_log
 * entries. This includes the one-entry table built by FSE_init_dtable_rle(),
 * whose accuracy_log is 0. `dst`'s previous arrays are not freed. Aborts via
 * BAD_ALLOC() on allocation failure.
 */
static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) {
    // Test for allocated arrays rather than accuracy_log == 0. An RLE table has
    // accuracy_log 0 but still owns a one-entry table, and treating it as empty
    // would leave `dst` with NULL arrays.
    if (src->symbols == NULL) {
        memset(dst, 0, sizeof(FSE_dtable));
        return;
    }
    const size_t size = (size_t)1 << src->accuracy_log;
    dst->accuracy_log = src->accuracy_log;
    dst->symbols = malloc(size);
    dst->num_bits = malloc(size);
    dst->new_state_base = malloc(size * sizeof(u16));
    if (!dst->symbols || !dst->num_bits || !dst->new_state_base) {
        BAD_ALLOC();
    }
    memcpy(dst->symbols, src->symbols, size);
    memcpy(dst->num_bits, src->num_bits, size);
    memcpy(dst->new_state_base, src->new_state_base, size * sizeof(u16));
}