#include "lzx.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* When building with a GNU-compatible compiler, route every memcpy() call in
 * this file to the compiler's builtin. */
#ifdef __GNUC__
#define memcpy __builtin_memcpy
#endif
/* Short integer aliases used throughout the decoder. */
typedef unsigned char  UBYTE; /* raw bytes: window, input data, code lengths */
typedef unsigned short UWORD; /* Huffman decode table entries and small counters */
typedef unsigned int   ULONG; /* bit buffer, window offsets and sizes */
typedef   signed int    LONG; /* signed offsets used by the E8 call translation */
#define LZX_MIN_MATCH                (2)
#define LZX_MAX_MATCH                (257)
#define LZX_NUM_CHARS                (256)
#define LZX_BLOCKTYPE_INVALID        (0)   
#define LZX_BLOCKTYPE_VERBATIM       (1)
#define LZX_BLOCKTYPE_ALIGNED        (2)
#define LZX_BLOCKTYPE_UNCOMPRESSED   (3)
#define LZX_PRETREE_NUM_ELEMENTS     (20)
#define LZX_ALIGNED_NUM_ELEMENTS     (8)   
#define LZX_NUM_PRIMARY_LENGTHS      (7)   
#define LZX_NUM_SECONDARY_LENGTHS    (249) 
#define LZX_PRETREE_MAXSYMBOLS  (LZX_PRETREE_NUM_ELEMENTS)
#define LZX_PRETREE_TABLEBITS   (6)
#define LZX_MAINTREE_MAXSYMBOLS (LZX_NUM_CHARS + 50*8)
#define LZX_MAINTREE_TABLEBITS  (12)
#define LZX_LENGTH_MAXSYMBOLS   (LZX_NUM_SECONDARY_LENGTHS+1)
#define LZX_LENGTH_TABLEBITS    (12)
#define LZX_ALIGNED_MAXSYMBOLS  (LZX_ALIGNED_NUM_ELEMENTS)
#define LZX_ALIGNED_TABLEBITS   (7)
#define LZX_LENTABLE_SAFETY (64) 
#define LZX_DECLARE_TABLE(tbl) \
  UWORD tbl##_table[(1<<LZX_##tbl##_TABLEBITS) + (LZX_##tbl##_MAXSYMBOLS<<1)];\
  UBYTE tbl##_len  [LZX_##tbl##_MAXSYMBOLS + LZX_LENTABLE_SAFETY]
/* Complete decoder state; created by LZXinit() and carried across
 * LZXdecompress() calls. */
struct LZXstate
{
    UBYTE *window;         /* decoding window, allocated by LZXinit() */
    ULONG window_size;     /* window size in bytes (1 << window bits) */
    ULONG actual_size;     /* size the window was allocated with; only set in LZXinit() here */
    ULONG window_posn;     /* current write offset within the window */
    ULONG R0, R1, R2;      /* the three most recently used match offsets */
    UWORD main_elements;   /* number of main-tree symbols: 256 + 8 * position slots */
    int   header_read;     /* non-zero once the stream header has been consumed */
    UWORD block_type;      /* LZX_BLOCKTYPE_* of the block being decoded */
    ULONG block_length;    /* uncompressed length of the current block */
    ULONG block_remaining; /* uncompressed bytes of the current block not yet decoded */
    ULONG frames_read;     /* frames decoded so far; E8 translation stops after 32768 */
    LONG  intel_filesize;  /* 32-bit value from the stream header; 0 disables E8 translation */
    LONG  intel_curpos;    /* running output position used by the E8 translation */
    int   intel_started;   /* set once a block may contain 0xE8 bytes to translate */
    LZX_DECLARE_TABLE(PRETREE);  /* tree that encodes the code lengths of the other trees */
    LZX_DECLARE_TABLE(MAINTREE); /* literals and match headers */
    LZX_DECLARE_TABLE(LENGTH);   /* extra match length symbols */
    LZX_DECLARE_TABLE(ALIGNED);  /* low 3 offset bits in aligned-offset blocks */
};
/* Number of extra offset bits that follow a match header, indexed by the
 * position slot taken from the main-tree symbol. */
static const UBYTE extra_bits[51] = {
     0,  0,  0,  0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,
     7,  7,  8,  8,  9,  9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14,
    15, 15, 16, 16, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17,
    17, 17, 17
};
/* Base value of each position slot. The decoder computes a match offset as
 * position_base[slot] - 2 plus the value of the slot's extra bits. */
static const ULONG position_base[51] = {
          0,       1,       2,      3,      4,      6,      8,     12,     16,     24,     32,       48,      64,      96,     128,     192,
        256,     384,     512,    768,   1024,   1536,   2048,   3072,   4096,   6144,   8192,    12288,   16384,   24576,   32768,   49152,
      65536,   98304,  131072, 196608, 262144, 393216, 524288, 655360, 786432, 917504, 1048576, 1179648, 1310720, 1441792, 1572864, 1703936,
    1835008, 1966080, 2097152
};
struct LZXstate *LZXinit(int window)
{
    struct LZXstate *pState=NULL;
    ULONG wndsize = 1 << window;
    int i, posn_slots;
    
    
    if (window < 15 || window > 21) return NULL;
    
    pState = (struct LZXstate *)malloc(sizeof(struct LZXstate));
    if (!(pState->window = (UBYTE *)malloc(wndsize)))
    {
        free(pState);
        return NULL;
    }
    pState->actual_size = wndsize;
    pState->window_size = wndsize;
    
    if (window == 20) posn_slots = 42;
    else if (window == 21) posn_slots = 50;
    else posn_slots = window << 1;
    
    
    
    pState->R0  =  pState->R1  = pState->R2 = 1;
    pState->main_elements   = LZX_NUM_CHARS + (posn_slots << 3);
    pState->header_read     = 0;
    pState->frames_read     = 0;
    pState->block_remaining = 0;
    pState->block_type      = LZX_BLOCKTYPE_INVALID;
    pState->intel_curpos    = 0;
    pState->intel_started   = 0;
    pState->window_posn     = 0;
    
    for (i = 0; i < LZX_MAINTREE_MAXSYMBOLS; i++) pState->MAINTREE_len[i] = 0;
    for (i = 0; i < LZX_LENGTH_MAXSYMBOLS; i++)   pState->LENGTH_len[i]   = 0;
    return pState;
}
/*
 * Release a decoder state created by LZXinit(), including its window.
 * Passing NULL is allowed and does nothing.
 */
void LZXteardown(struct LZXstate *pState)
{
    if (!pState)
        return;

    /* free() accepts NULL, so the window needs no separate test */
    free(pState->window);
    free(pState);
}
/*
 * Return an existing decoder state to its just-initialised condition so a
 * new stream can be decoded with the same window. The window contents and
 * size are left untouched. Always returns DECR_OK.
 */
int LZXreset(struct LZXstate *pState)
{
    /* position and block bookkeeping */
    pState->window_posn     = 0;
    pState->block_type      = LZX_BLOCKTYPE_INVALID;
    pState->block_remaining = 0;
    pState->frames_read     = 0;
    pState->header_read     = 0;

    /* E8 translation bookkeeping */
    pState->intel_started   = 0;
    pState->intel_curpos    = 0;

    /* repeated-offset registers */
    pState->R2 = 1;
    pState->R1 = 1;
    pState->R0 = 1;

    /* code lengths are delta-coded, so both arrays (including their safety
     * margin) go back to zero */
    memset(pState->MAINTREE_len, 0, sizeof(pState->MAINTREE_len));
    memset(pState->LENGTH_len,   0, sizeof(pState->LENGTH_len));

    return DECR_OK;
}
/*
 * Bitstream macros.
 *
 * They operate on locals that the enclosing function must declare:
 *   bitbuf   - ULONG bit buffer; unread bits are kept at the most
 *              significant end
 *   bitsleft - number of valid bits in bitbuf
 *   inpos    - pointer to the next unread input byte
 *
 * Input is consumed as 16-bit little-endian words. None of these macros
 * checks inpos against the end of the input buffer.
 */

/* width of the bit buffer in bits */
#define ULONG_BITS (sizeof(ULONG)<<3)
/* empty the bit buffer */
#define INIT_BITSTREAM do { bitsleft = 0; bitbuf = 0; } while (0)
/* refill until at least n bits are available */
#define ENSURE_BITS(n)							\
  while (bitsleft < (n)) {						\
    bitbuf |= ((inpos[1]<<8)|inpos[0]) << (ULONG_BITS-16 - bitsleft);	\
    bitsleft += 16; inpos+=2;						\
  }
/* value of the next n bits without consuming them; n must be non-zero */
#define PEEK_BITS(n)   (bitbuf >> (ULONG_BITS - (n)))
/* consume n bits */
#define REMOVE_BITS(n) ((bitbuf <<= (n)), (bitsleft -= (n)))
/* read and consume n bits into v */
#define READ_BITS(v,n) do {						\
  ENSURE_BITS(n);							\
  (v) = PEEK_BITS(n);							\
  REMOVE_BITS(n);							\
} while (0)

/*
 * Huffman macros.
 *
 * `tbl` is one of PRETREE, MAINTREE, LENGTH or ALIGNED. These macros
 * additionally need a `pState` pointer in scope, and on bad data they
 * execute `return DECR_ILLEGALDATA;` in the enclosing function.
 */
#define TABLEBITS(tbl)   (LZX_##tbl##_TABLEBITS)
#define MAXSYMBOLS(tbl)  (LZX_##tbl##_MAXSYMBOLS)
#define SYMTABLE(tbl)    (pState->tbl##_table)
#define LENTABLE(tbl)    (pState->tbl##_len)
/* build the decode table of `tbl` from its length array */
#define BUILD_TABLE(tbl)						\
  if (make_decode_table(						\
    MAXSYMBOLS(tbl), TABLEBITS(tbl), LENTABLE(tbl), SYMTABLE(tbl)	\
  )) { return DECR_ILLEGALDATA; }
/* Decode one symbol of `tbl` into var. Uses the caller's locals i, j and
 * hufftbl as scratch. A direct-table entry >= MAXSYMBOLS is a node index;
 * the loop then walks the tree one input bit at a time and fails once the
 * bit mask j has run off the end of the bit buffer. */
#define READ_HUFFSYM(tbl,var) do {					\
  ENSURE_BITS(16);							\
  hufftbl = SYMTABLE(tbl);						\
  if ((i = hufftbl[PEEK_BITS(TABLEBITS(tbl))]) >= MAXSYMBOLS(tbl)) {	\
    j = 1 << (ULONG_BITS - TABLEBITS(tbl));				\
    do {								\
      j >>= 1; i <<= 1; i |= (bitbuf & j) ? 1 : 0;			\
      if (!j) { return DECR_ILLEGALDATA; }	                        \
    } while ((i = hufftbl[i]) >= MAXSYMBOLS(tbl));			\
  }									\
  j = LENTABLE(tbl)[(var) = i];						\
  REMOVE_BITS(j);							\
} while (0)
/* Read the code lengths of symbols first..last-1 of `tbl` through
 * lzx_read_lens(). The bitstream locals are handed over and taken back via
 * the caller's `lb` (struct lzx_bits). */
#define READ_LENGTHS(tbl,first,last) do { \
  lb.bb = bitbuf; lb.bl = bitsleft; lb.ip = inpos; \
  if (lzx_read_lens(pState, LENTABLE(tbl),(first),(last),&lb)) { \
    return DECR_ILLEGALDATA; \
  } \
  bitbuf = lb.bb; bitsleft = lb.bl; inpos = lb.ip; \
} while (0)
/*
 * Build a Huffman decoding table from a list of code lengths.
 *
 * nsyms:  number of symbols in the tree
 * nbits:  width in bits of the direct lookup part of the table
 * length: code length of each symbol (0 = symbol unused)
 * table:  output. The first (1 << nbits) entries are indexed by the next
 *         nbits input bits. Codes longer than nbits are stored as a binary
 *         tree whose nodes are allocated after the direct part; a node value
 *         n has its two children at table[2n] and table[2n+1]. Callers in
 *         this file provide (1 << nbits) + 2 * nsyms entries.
 *
 * Returns 0 on success (including the case where every length is zero) and
 * 1 when the lengths over-subscribe the code space or leave it incomplete.
 */
static int make_decode_table(ULONG nsyms, ULONG nbits, UBYTE *length, UWORD *table) {
    register UWORD sym;
    register ULONG leaf;
    register UBYTE bit_num = 1;
    ULONG fill;
    ULONG pos         = 0; /* next free position in the code space */
    ULONG table_mask  = 1 << nbits;
    ULONG bit_mask    = table_mask >> 1; /* direct entries covered by a 1-bit code */
    ULONG next_symbol = bit_mask; /* next tree node to allocate; children start at index table_mask */
    
    /* pass 1: codes short enough for the direct table, shortest first */
    while (bit_num <= nbits) {
        for (sym = 0; sym < nsyms; sym++) {
            if (length[sym] == bit_num) {
                leaf = pos;
                if((pos += bit_mask) > table_mask) return 1; /* table overrun */
                /* every index that starts with this code maps to the symbol */
                fill = bit_mask;
                while (fill-- > 0) table[leaf++] = sym;
            }
        }
        bit_mask >>= 1;
        bit_num++;
    }
    
    /* pass 2: only needed when the short codes did not fill the table */
    if (pos != table_mask) {
        /* clear the unused direct entries; 0 marks "no node allocated yet" */
        for (sym = pos; sym < table_mask; sym++) table[sym] = 0;
        /* switch to 16.16 fixed point so a code longer than nbits can
         * advance pos by a fraction of one direct entry */
        pos <<= 16;
        table_mask <<= 16;
        bit_mask = 1 << 15;
        while (bit_num <= 16) {
            for (sym = 0; sym < nsyms; sym++) {
                if (length[sym] == bit_num) {
                    leaf = pos >> 16;
                    for (fill = 0; fill < bit_num - nbits; fill++) {
                        /* allocate a new node with two empty children if this slot has none */
                        if (table[leaf] == 0) {
                            table[(next_symbol << 1)] = 0;
                            table[(next_symbol << 1) + 1] = 0;
                            table[leaf] = next_symbol++;
                        }
                        /* descend: pick the child selected by the next bit of the code */
                        leaf = table[leaf] << 1;
                        if ((pos >> (15-fill)) & 1) leaf++;
                    }
                    table[leaf] = sym;
                    if ((pos += bit_mask) > table_mask) return 1; /* table overflow */
                }
            }
            bit_mask >>= 1;
            bit_num++;
        }
    }
    
    /* the code space is exactly full: a complete table */
    if (pos == table_mask) return 0;
    
    /* an incomplete table is only acceptable when no symbol has a code at all */
    for (sym = 0; sym < nsyms; sym++) if (length[sym]) return 1;
    return 0;
}
/* Snapshot of the bitstream locals, used by READ_LENGTHS to hand the reader
 * state to lzx_read_lens() and take it back afterwards. */
struct lzx_bits {
  ULONG bb;   /* bit buffer (bitbuf) */
  int bl;     /* valid bits in the buffer (bitsleft) */
  UBYTE *ip;  /* next unread input byte (inpos) */
};
/*
 * Read the pre-tree and use it to decode the code lengths lens[first] ..
 * lens[last-1] in place. Lengths are coded as differences (mod 17) from the
 * values already stored in lens[].
 *
 * The reader state comes in through *lb and is written back on success.
 * Returns 0 on success; on malformed data the Huffman macros return
 * DECR_ILLEGALDATA from this function and *lb is left unchanged.
 *
 * Zero and repeat runs are not clipped at `last`; the length arrays carry
 * LZX_LENTABLE_SAFETY spare bytes to absorb the overrun.
 */
static int lzx_read_lens(struct LZXstate *pState, UBYTE *lens, ULONG first, ULONG last, struct lzx_bits *lb) {
    /* i, j and hufftbl are scratch space required by READ_HUFFSYM */
    ULONG i, j;
    UWORD *hufftbl;
    ULONG idx, count;
    int code;
    register ULONG bitbuf = lb->bb;
    register int bitsleft = lb->bl;
    UBYTE *inpos = lb->ip;

    /* the pre-tree itself: twenty 4-bit code lengths */
    for (idx = 0; idx < 20; idx++) {
        READ_BITS(count, 4);
        LENTABLE(PRETREE)[idx] = count;
    }
    BUILD_TABLE(PRETREE);

    idx = first;
    while (idx < last) {
        READ_HUFFSYM(PRETREE, code);
        switch (code) {
        case 17:
            /* run of 4..19 zero lengths */
            READ_BITS(count, 4);
            count += 4;
            for (; count > 0; count--) lens[idx++] = 0;
            break;

        case 18:
            /* run of 20..51 zero lengths */
            READ_BITS(count, 5);
            count += 20;
            for (; count > 0; count--) lens[idx++] = 0;
            break;

        case 19:
            /* run of 4..5 copies of one delta-coded length */
            READ_BITS(count, 1);
            count += 4;
            READ_HUFFSYM(PRETREE, code);
            code = lens[idx] - code;
            if (code < 0) code += 17;
            for (; count > 0; count--) lens[idx++] = code;
            break;

        default:
            /* a single delta-coded length */
            code = lens[idx] - code;
            if (code < 0) code += 17;
            lens[idx++] = code;
            break;
        }
    }

    /* hand the advanced reader state back to the caller */
    lb->ip = inpos;
    lb->bl = bitsleft;
    lb->bb = bitbuf;
    return 0;
}
/*
 * Decode one frame of LZX data.
 *
 * pState: decoder state from LZXinit(); updated on success
 * inpos:  compressed input for this frame
 * outpos: receives exactly `outlen` decoded bytes
 * inlen:  number of valid bytes at inpos
 * outlen: number of bytes this frame decodes to
 *
 * Returns DECR_OK on success, DECR_ILLEGALDATA when the stream is corrupt or
 * truncated, and DECR_DATAFORMAT when the frame does not fit the window.
 *
 * The bit buffer is restarted on every call, so each call must begin on a
 * frame boundary of the compressed stream. The bit reader fetches input 16
 * bits at a time without bounds checks and may touch a few bytes beyond
 * inpos + inlen; callers should keep the input buffer padded.
 */
int LZXdecompress(struct LZXstate *pState, unsigned char *inpos, unsigned char *outpos, int inlen, int outlen) {
    UBYTE *endinp = inpos + inlen;
    UBYTE *window = pState->window;
    UBYTE *runsrc, *rundest;
    UWORD *hufftbl; /* used by READ_HUFFSYM as the current decode table */
    ULONG window_posn = pState->window_posn;
    ULONG window_size = pState->window_size;
    ULONG R0 = pState->R0;
    ULONG R1 = pState->R1;
    ULONG R2 = pState->R2;
    register ULONG bitbuf;
    register int bitsleft;
    ULONG match_offset, i,j,k; /* i, j, k are scratch for the bitstream macros */
    ULONG frame_end;
    struct lzx_bits lb; /* used by READ_LENGTHS */
    int togo = outlen, this_run, main_element, aligned_bits;
    int match_length, length_footer, extra, verbatim_bits;
    INIT_BITSTREAM;

    /* read the stream header on the first call */
    if (!pState->header_read) {
        i = j = 0;
        READ_BITS(k, 1); if (k) { READ_BITS(i,16); READ_BITS(j,16); }
        pState->intel_filesize = (i << 16) | j; /* stays 0 when the flag bit is clear */
        pState->header_read = 1;
    }

    /* main decoding loop */
    while (togo > 0) {
        /* the previous block is finished: read the next block header */
        if (pState->block_remaining == 0) {
            if (pState->block_type == LZX_BLOCKTYPE_UNCOMPRESSED) {
                if (pState->block_length & 1) inpos++; /* skip the pad byte after an odd-length block */
                INIT_BITSTREAM;
            }
            READ_BITS(pState->block_type, 3);
            READ_BITS(i, 16);
            READ_BITS(j, 8);
            pState->block_remaining = pState->block_length = (i << 8) | j;
            switch (pState->block_type) {
                case LZX_BLOCKTYPE_ALIGNED:
                    for (i = 0; i < 8; i++) { READ_BITS(j, 3); LENTABLE(ALIGNED)[i] = j; }
                    BUILD_TABLE(ALIGNED);
                    /* the rest of the aligned header is the same as verbatim */
                    /* fallthrough */
                case LZX_BLOCKTYPE_VERBATIM:
                    READ_LENGTHS(MAINTREE, 0, 256);
                    READ_LENGTHS(MAINTREE, 256, pState->main_elements);
                    BUILD_TABLE(MAINTREE);
                    if (LENTABLE(MAINTREE)[0xE8] != 0) pState->intel_started = 1;
                    READ_LENGTHS(LENGTH, 0, LZX_NUM_SECONDARY_LENGTHS);
                    BUILD_TABLE(LENGTH);
                    break;
                case LZX_BLOCKTYPE_UNCOMPRESSED:
                    pState->intel_started = 1; /* raw data may contain 0xE8 bytes */
                    ENSURE_BITS(16); /* pull the padding bits into the buffer */
                    if (bitsleft > 16) inpos -= 2; /* un-read the look-ahead word: byte-align */
                    /* R0..R2 follow as three little-endian 32-bit words */
                    if (inpos > endinp || (endinp - inpos) < 12) return DECR_ILLEGALDATA;
                    R0 = (ULONG)inpos[0]|((ULONG)inpos[1]<<8)|((ULONG)inpos[2]<<16)|((ULONG)inpos[3]<<24);inpos+=4;
                    R1 = (ULONG)inpos[0]|((ULONG)inpos[1]<<8)|((ULONG)inpos[2]<<16)|((ULONG)inpos[3]<<24);inpos+=4;
                    R2 = (ULONG)inpos[0]|((ULONG)inpos[1]<<8)|((ULONG)inpos[2]<<16)|((ULONG)inpos[3]<<24);inpos+=4;
                    break;
                default:
                    return DECR_ILLEGALDATA;
            }
        }

        /* buffer exhaustion check */
        if (inpos > endinp) {
            /* Being up to two bytes past the end is fine as long as those
             * 16 look-ahead bits are still unread in the bit buffer. */
            if (inpos > (endinp+2) || bitsleft < 16) return DECR_ILLEGALDATA;
        }
        while ((this_run = pState->block_remaining) > 0 && togo > 0) {
            if (this_run > togo) this_run = togo;
            togo -= this_run;
            pState->block_remaining -= this_run;

            /* apply the wrap deferred from the end of the previous run */
            window_posn &= window_size - 1;
            /* a run must not cross the end of the window */
            if ((window_posn + this_run) > window_size)
                return DECR_DATAFORMAT;
            switch (pState->block_type) {
                case LZX_BLOCKTYPE_VERBATIM:
                    while (this_run > 0) {
                        READ_HUFFSYM(MAINTREE, main_element);
                        if (main_element < LZX_NUM_CHARS) {
                            /* literal byte */
                            window[window_posn++] = main_element;
                            this_run--;
                        }
                        else {
                            /* match: low 3 bits = length header, rest = position slot */
                            main_element -= LZX_NUM_CHARS;
                            match_length = main_element & LZX_NUM_PRIMARY_LENGTHS;
                            if (match_length == LZX_NUM_PRIMARY_LENGTHS) {
                                READ_HUFFSYM(LENGTH, length_footer);
                                match_length += length_footer;
                            }
                            match_length += LZX_MIN_MATCH;
                            match_offset = main_element >> 3;
                            if (match_offset > 2) {
                                /* not a repeated offset */
                                if (match_offset != 3) {
                                    extra = extra_bits[match_offset];
                                    READ_BITS(verbatim_bits, extra);
                                    match_offset = position_base[match_offset] - 2 + verbatim_bits;
                                }
                                else {
                                    match_offset = 1;
                                }
                                /* push onto the recent-offset list */
                                R2 = R1; R1 = R0; R0 = match_offset;
                            }
                            else if (match_offset == 0) {
                                match_offset = R0;
                            }
                            else if (match_offset == 1) {
                                match_offset = R1;
                                R1 = R0; R0 = match_offset;
                            }
                            else  {
                                match_offset = R2;
                                R2 = R0; R0 = match_offset;
                            }
                            /* An offset larger than the window would make the
                             * wrapped copy below read outside the window. */
                            if (match_offset > window_size) return DECR_ILLEGALDATA;
                            rundest = window + window_posn;
                            runsrc  = rundest - match_offset;
                            window_posn += match_length;
                            if (window_posn > window_size) return DECR_ILLEGALDATA;
                            this_run -= match_length;
                            /* source starts before the window: copy from the wrapped-around tail */
                            while ((runsrc < window) && (match_length-- > 0)) {
                                *rundest++ = *(runsrc + window_size); runsrc++;
                            }
                            /* byte-wise copy: source and destination may overlap */
                            while (match_length-- > 0) *rundest++ = *runsrc++;
                        }
                    }
                    break;
                case LZX_BLOCKTYPE_ALIGNED:
                    while (this_run > 0) {
                        READ_HUFFSYM(MAINTREE, main_element);
                        if (main_element < LZX_NUM_CHARS) {
                            /* literal byte */
                            window[window_posn++] = main_element;
                            this_run--;
                        }
                        else {
                            /* match: low 3 bits = length header, rest = position slot */
                            main_element -= LZX_NUM_CHARS;
                            match_length = main_element & LZX_NUM_PRIMARY_LENGTHS;
                            if (match_length == LZX_NUM_PRIMARY_LENGTHS) {
                                READ_HUFFSYM(LENGTH, length_footer);
                                match_length += length_footer;
                            }
                            match_length += LZX_MIN_MATCH;
                            match_offset = main_element >> 3;
                            if (match_offset > 2) {
                                /* not a repeated offset */
                                extra = extra_bits[match_offset];
                                match_offset = position_base[match_offset] - 2;
                                if (extra > 3) {
                                    /* verbatim bits for the high part, aligned tree for the low 3 bits */
                                    extra -= 3;
                                    READ_BITS(verbatim_bits, extra);
                                    match_offset += (verbatim_bits << 3);
                                    READ_HUFFSYM(ALIGNED, aligned_bits);
                                    match_offset += aligned_bits;
                                }
                                else if (extra == 3) {
                                    /* aligned bits only */
                                    READ_HUFFSYM(ALIGNED, aligned_bits);
                                    match_offset += aligned_bits;
                                }
                                else if (extra > 0) {
                                    /* 1 or 2 extra bits: verbatim bits only */
                                    READ_BITS(verbatim_bits, extra);
                                    match_offset += verbatim_bits;
                                }
                                else  {
                                    /* slot 3 has no extra bits */
                                    match_offset = 1;
                                }
                                /* push onto the recent-offset list */
                                R2 = R1; R1 = R0; R0 = match_offset;
                            }
                            else if (match_offset == 0) {
                                match_offset = R0;
                            }
                            else if (match_offset == 1) {
                                match_offset = R1;
                                R1 = R0; R0 = match_offset;
                            }
                            else  {
                                match_offset = R2;
                                R2 = R0; R0 = match_offset;
                            }
                            /* An offset larger than the window would make the
                             * wrapped copy below read outside the window. */
                            if (match_offset > window_size) return DECR_ILLEGALDATA;
                            rundest = window + window_posn;
                            runsrc  = rundest - match_offset;
                            window_posn += match_length;
                            if (window_posn > window_size) return DECR_ILLEGALDATA;
                            this_run -= match_length;
                            /* source starts before the window: copy from the wrapped-around tail */
                            while ((runsrc < window) && (match_length-- > 0)) {
                                *rundest++ = *(runsrc + window_size); runsrc++;
                            }
                            /* byte-wise copy: source and destination may overlap */
                            while (match_length-- > 0) *rundest++ = *runsrc++;
                        }
                    }
                    break;
                case LZX_BLOCKTYPE_UNCOMPRESSED:
                    if ((inpos + this_run) > endinp) return DECR_ILLEGALDATA;
                    memcpy(window + window_posn, inpos, (size_t) this_run);
                    inpos += this_run; window_posn += this_run;
                    break;
                default:
                    return DECR_ILLEGALDATA; /* cannot happen: the header switch rejects other types */
            }
        }
    }
    if (togo != 0) return DECR_ILLEGALDATA;

    /* The frame is the last outlen bytes written; a position of 0 means the
     * window was filled exactly. Refuse a frame that wrapped around the
     * window, since the copy would start before the window. */
    frame_end = window_posn ? window_posn : window_size;
    if ((ULONG)outlen > frame_end) return DECR_DATAFORMAT;
    memcpy(outpos, window + frame_end - outlen, (size_t) outlen);
    pState->window_posn = window_posn;
    pState->R0 = R0;
    pState->R1 = R1;
    pState->R2 = R2;

    /* Intel E8 translation: rewrite x86 CALL operands in the output */
    if ((pState->frames_read++ < 32768) && pState->intel_filesize != 0) {
        /* Frames of 10 bytes or fewer have no room for a translated CALL;
         * only the running position advances. */
        if (outlen <= 10 || !pState->intel_started) {
            pState->intel_curpos += outlen;
        }
        else {
            UBYTE *data    = outpos;
            UBYTE *dataend = data + outlen - 10;
            LONG curpos    = pState->intel_curpos;
            LONG filesize  = pState->intel_filesize;
            LONG abs_off, rel_off;
            pState->intel_curpos = curpos + outlen;
            while (data < dataend) {
                if (*data++ != 0xE8) { curpos++; continue; }
                /* assemble unsigned to avoid signed overflow in byte << 24 */
                abs_off = (LONG)((ULONG)data[0] | ((ULONG)data[1]<<8) | ((ULONG)data[2]<<16) | ((ULONG)data[3]<<24));
                if ((abs_off >= -curpos) && (abs_off < filesize)) {
                    rel_off = (abs_off >= 0) ? abs_off - curpos : abs_off + filesize;
                    data[0] = (UBYTE) rel_off;
                    data[1] = (UBYTE) (rel_off >> 8);
                    data[2] = (UBYTE) (rel_off >> 16);
                    data[3] = (UBYTE) (rel_off >> 24);
                }
                data += 4;
                curpos += 5;
            }
        }
    }
    return DECR_OK;
}
#ifdef LZX_CHM_TESTDRIVER
/*
 * Test driver.
 *
 * Usage: <prog> <window-bits> <outfile> <infile>...
 *
 * Each input file is read (up to 16384 bytes) as one compressed frame and
 * decoded to 32768 bytes, which are appended to <outfile> when decoding
 * succeeds. The decoder is reset after every second frame.
 *
 * Returns 0 when the run completed, 1 on a usage, allocation or file error.
 */
int main(int c, char **v)
{
    FILE *fin, *fout;
    struct LZXstate *state;
    UBYTE ibuf[16384];
    UBYTE obuf[32768];
    int ilen;
    int status;
    int i;
    int count=0;
    int rc=0;
    int w;

    if (c < 3)
    {
        fprintf(stderr, "usage: %s window-bits outfile [infile ...]\n",
                (c > 0) ? v[0] : "lzx");
        return 1;
    }

    w = atoi(v[1]);
    /* LZXinit() allocates the state and rejects windows outside 15..21 */
    state = LZXinit(w);
    if (!state)
    {
        fprintf(stderr, "cannot initialise decoder (window bits %d)\n", w);
        return 1;
    }

    fout = fopen(v[2], "wb");
    if (!fout)
    {
        fprintf(stderr, "cannot open %s for writing\n", v[2]);
        LZXteardown(state);
        return 1;
    }

    for (i=3; i<c; i++)
    {
        fin = fopen(v[i], "rb");
        if (!fin)
        {
            /* a missing frame breaks the stream, so stop here */
            fprintf(stderr, "cannot open %s\n", v[i]);
            rc = 1;
            break;
        }
        ilen = (int)fread(ibuf, 1, sizeof(ibuf), fin);
        fclose(fin);

        status = LZXdecompress(state, ibuf, obuf, ilen, 32768);
        switch (status)
        {
            case DECR_OK:
                printf("ok\n");
                if (fwrite(obuf, 1, 32768, fout) != 32768)
                {
                    fprintf(stderr, "write to %s failed\n", v[2]);
                    rc = 1;
                }
                break;
            case DECR_DATAFORMAT:
                printf("bad format\n");
                break;
            case DECR_ILLEGALDATA:
                printf("illegal data\n");
                break;
            case DECR_NOMEMORY:
                printf("no memory\n");
                break;
            default:
                break;
        }

        if (++count == 2)
        {
            count = 0;
            LZXreset(state);
        }
    }

    if (fclose(fout) != 0)
        rc = 1;
    LZXteardown(state);
    return rc;
}
#endif