nlzss11/
lib.rs

1use byteorder::{ByteOrder, LE};
2
3#[derive(thiserror::Error, Debug)]
4#[non_exhaustive]
5pub enum DecompressError {
6    #[error("invalid magic")]
7    InvalidMagic,
8    #[error("invalid index: {0}")]
9    InvalidIndex(usize),
10    // TODO make better
11    #[error("other error: {0}")]
12    LibraryError(&'static str),
13}
14
/// A single back-reference: copy `length` bytes starting `distance` bytes
/// before the current end of the output.
struct LzssCode {
    /// How far back the copy starts. The encoding stores `distance - 1` in
    /// 12 bits, so the representable range is `1..=0x1000`.
    distance: u32,
    /// Number of bytes to copy. The three encodings together cover
    /// `3..=0x10110`.
    length: u32,
}

impl LzssCode {
    /// Decodes one back-reference from the start of `buf`.
    ///
    /// Returns the decoded code together with the number of bytes it occupied
    /// (2, 3 or 4). Returns `None` when `buf` is too short to contain the
    /// complete code; it never panics on truncated input.
    fn read(buf: &[u8]) -> Option<(LzssCode, usize)> {
        // Checked access: slicing `buf[..2]` would panic on a short buffer.
        let pair = u32::from(u16::from_be_bytes([*buf.first()?, *buf.get(1)?]));
        // The top nibble selects the encoding.
        Some(match pair & 0xF000 {
            0 => {
                // 0000LLLL LLLLDDDD DDDDDDDD
                // L + 0x11, D + 1
                // 255 + 17 >= length >= 17
                let length = (pair >> 4) + 0x11;
                let distance = (((pair & 0xF) << 8) | u32::from(*buf.get(2)?)) + 1;
                (LzssCode { distance, length }, 3)
            }
            0x1000 => {
                // 0001LLLL LLLLLLLL LLLLDDDD DDDDDDDD
                // L + 0x111, D + 1
                // 2^16 + 255 + 17 >= length >= 256 + 17
                let ext_pair = u32::from(u16::from_be_bytes([*buf.get(2)?, *buf.get(3)?]));
                let length = (((pair & 0xFFF) << 4) | (ext_pair >> 12)) + 0x111;
                let distance = (ext_pair & 0xFFF) + 1;
                (LzssCode { distance, length }, 4)
            }
            _ => {
                // LLLLDDDD DDDDDDDD
                // L + 1, D + 1
                // 15 + 1 >= length >= 3
                let length = (pair >> 12) + 1;
                let distance = (pair & 0xFFF) + 1;
                (LzssCode { distance, length }, 2)
            }
        })
    }

    /// Appends the shortest encoding of this back-reference to `out_buf`.
    ///
    /// The caller must supply a `distance` in `1..=0x1000` and a `length` in
    /// `3..=0x10110`; values outside those ranges cannot be represented and
    /// are rejected by debug assertions.
    pub fn write(&self, out_buf: &mut Vec<u8>) {
        debug_assert!(
            self.distance >= 1 && self.distance <= 0x1000,
            "distance out of range: {}",
            self.distance
        );
        debug_assert!(
            self.length >= 3 && self.length <= 0x10110,
            "length out of range: {}",
            self.length
        );
        let adj_dist = self.distance - 1;
        let dist_low = (adj_dist & 0xFF) as u8;
        if self.length >= 0x111 {
            // Four-byte form: selector nibble 1, 16-bit length, 12-bit distance.
            let adj_len = self.length - 0x111;
            out_buf.push(((1 << 4) + (adj_len >> 12)) as u8);
            out_buf.push((adj_len >> 4) as u8);
            out_buf.push((((adj_len & 0xF) << 4) + (adj_dist >> 8)) as u8);
            out_buf.push(dist_low);
        } else if self.length >= 0x11 {
            // Three-byte form: selector nibble 0, 8-bit length, 12-bit distance.
            let adj_len = self.length - 0x11;
            out_buf.push((adj_len >> 4) as u8);
            out_buf.push((((adj_len & 0xF) << 4) + (adj_dist >> 8)) as u8);
            out_buf.push(dist_low);
        } else {
            // Two-byte form: the length nibble doubles as the selector, which
            // is why lengths below 3 (nibbles 0 and 1) are not representable.
            let adj_len = self.length - 1;
            out_buf.push(((adj_len << 4) + (adj_dist >> 8)) as u8);
            out_buf.push(dist_low);
        }
    }
}
72
73#[inline(always)]
74fn get_or_oob_err(data: &[u8], pos: usize) -> Result<u8, DecompressError> {
75    data.get(pos)
76        .copied()
77        .ok_or(DecompressError::InvalidIndex(pos))
78}
79
80pub fn decompress(data: &[u8]) -> Result<Vec<u8>, DecompressError> {
81    if data.len() < 4 {
82        return Err(DecompressError::LibraryError("Too short"));
83    }
84    if data[0] != 0x11 {
85        return Err(DecompressError::InvalidMagic);
86    }
87    let mut pos = 4;
88    let mut out_size: usize = LE::read_u24(&data[1..]) as usize;
89    if out_size == 0 {
90        if data.len() < 8 {
91            return Err(DecompressError::LibraryError("Too short"));
92        }
93        out_size = LE::read_u32(&data[4..]) as usize;
94    }
95    let mut out_buf = Vec::with_capacity(out_size);
96
97    let mut group_header = 0;
98    let mut remaining_chunks = 0;
99    while out_buf.len() < out_buf.capacity() {
100        // one byte indicates if the next 8 blocks are literals or backreferences
101        if remaining_chunks == 0 {
102            group_header = get_or_oob_err(data, pos)?;
103            pos += 1;
104            remaining_chunks = 8;
105        }
106        if (group_header & 0x80) == 0 {
107            out_buf.push(get_or_oob_err(data, pos)?);
108            pos += 1;
109        } else {
110            let (LzssCode { distance, length }, advance) =
111                LzssCode::read(&data[pos..]).ok_or(DecompressError::InvalidIndex(data.len()))?;
112
113            pos += advance;
114
115            let cpy_start = out_buf
116                .len()
117                .checked_sub(distance as usize)
118                .ok_or(DecompressError::InvalidIndex(0))?;
119            if distance > length {
120                // region to copy doesn't overlap the region it's copied to
121                out_buf.extend_from_within(cpy_start..cpy_start + length as usize);
122            } else {
123                for cpy_pos in cpy_start..cpy_start + length as usize {
124                    // it shouldn't be possible to end up in the default of unwrap_or
125                    out_buf.push(out_buf.get(cpy_pos).copied().unwrap_or(0));
126                }
127            }
128        }
129
130        group_header <<= 1;
131        remaining_chunks -= 1;
132    }
133    Ok(out_buf)
134}
135
// https://github.com/PSeitz/lz4_flex/blob/c17d3b110325211f9e63c897add5fad09ddd8ef1/src/block/hashtable.rs#L16
/// Multiplicative hash of a four-byte sequence, reduced to the upper 16 bits
/// of the 32-bit product.
#[inline]
fn make_hash(sequence: [u8; 4]) -> u32 {
    const MULTIPLIER: u32 = 2654435761;
    let word = u32::from_ne_bytes(sequence);
    word.wrapping_mul(MULTIPLIER) >> 16
}
141
142const HASH_COUNT: usize = 4096 * 16; // has to be power of 2
143
144struct MatchSearcher {
145    search_dict: [u32; HASH_COUNT],
146}
147
148impl MatchSearcher {
149    pub fn new() -> Self {
150        MatchSearcher {
151            search_dict: [u32::MAX; HASH_COUNT],
152        }
153    }
154    pub fn submit_val(&mut self, data: &[u8], cur_pos: u32) {
155        let rest = &data[cur_pos as usize..];
156        if rest.len() < 4 {
157            return;
158        }
159        let hash = make_hash(rest[..4].try_into().unwrap()) % HASH_COUNT as u32;
160        self.search_dict[hash as usize] = cur_pos;
161    }
162
163    pub fn get_lz_code(&self, data: &[u8], cur_pos: u32) -> Option<(u32, u32)> {
164        let rest = &data[cur_pos as usize..];
165        if rest.len() < 4 {
166            return None;
167        }
168        let hash = make_hash(rest[..4].try_into().unwrap()) % HASH_COUNT as u32;
169        let prev = self.search_dict[hash as usize];
170        if prev == u32::MAX {
171            return None;
172        }
173        let match_backref = cur_pos.wrapping_sub(prev);
174        if match_backref > TOTAL_BACKREF_POS {
175            return None;
176        }
177        let match_len = data[cur_pos as usize..]
178            .iter()
179            .zip(data[prev as usize..].iter())
180            .take_while(|&(a, b)| a == b)
181            .count();
182        if match_len < 4 {
183            return None;
184        }
185        Some((match_backref, (match_len as u32).min(TOTAL_BACKREF_LEN)))
186        // None
187    }
188}
189
190const TOTAL_BACKREF_LEN: u32 = 0x10110;
191const TOTAL_BACKREF_POS: u32 = 0xFFF;
192
193#[cfg(feature = "zlib")]
194pub use nlzss11_zlib::compress_with_zlib_into;
/// Compresses `data` by delegating to [`compress_with_zlib_into`] and returns
/// the freshly allocated output buffer.
///
/// NOTE(review): the trailing argument `7` is presumably the zlib compression
/// level; confirm against `nlzss11_zlib`.
#[cfg(feature = "zlib")]
pub fn compress(data: &[u8]) -> Vec<u8> {
    // Reserve one output byte per input byte up front.
    let mut compressed = Vec::with_capacity(data.len());
    compress_with_zlib_into(data, &mut compressed, 7);
    compressed
}
201
202#[cfg(not(feature = "zlib"))]
203pub fn compress(data: &[u8]) -> Vec<u8> {
204    let mut searcher = MatchSearcher::new();
205
206    let mut out_buf: Vec<u8> = Vec::with_capacity(data.len());
207    // write magic
208    out_buf.push(0x11);
209    // very big archives
210    // little endian data length
211    if data.len() < 0xFFFFFF {
212        let mut len_buf = [0; 3];
213        LE::write_u24(&mut len_buf, data.len() as u32);
214        out_buf.extend_from_slice(&len_buf);
215    } else if data.len() < 0xFFFFFFFF {
216        out_buf.extend([0, 0, 0]);
217        out_buf.extend_from_slice(&(data.len() as u32).to_le_bytes());
218    }
219
220    let mut group_header_pos = out_buf.len();
221    out_buf.push(0);
222    let mut group_header = 0;
223    let mut group_header_count = 0;
224
225    // go through the input in 3 byte chunks
226    let mut pos: usize = 0;
227
228    while pos < data.len() {
229        if group_header_count == 8 {
230            out_buf[group_header_pos] = group_header;
231            group_header_pos = out_buf.len();
232            out_buf.push(0);
233            group_header = 0;
234            group_header_count = 0;
235        }
236        if let Some((backref_dist, backref_len)) = searcher.get_lz_code(data, pos as u32) {
237            group_header <<= 1;
238            group_header += 1;
239            group_header_count += 1;
240            LzssCode {
241                length: backref_len,
242                distance: backref_dist,
243            }
244            .write(&mut out_buf);
245            for p in pos..(pos + backref_len as usize) {
246                searcher.submit_val(data, p as u32);
247            }
248            pos += backref_len as usize;
249            // TODO: submit vals?
250        } else {
251            group_header <<= 1;
252            group_header_count += 1;
253            out_buf.push(data[pos]);
254            searcher.submit_val(data, pos as u32);
255            pos += 1;
256        }
257    }
258    if group_header_count != 0 {
259        group_header <<= 8 - group_header_count;
260        out_buf[group_header_pos] = group_header;
261    }
262    out_buf
263}
264
#[cfg(test)]
mod test {
    use super::LzssCode;

    /// Encodes one back-reference into `buf`, decodes it again and checks that
    /// distance, length and encoded size all survive the round trip.
    fn assert_roundtrip(buf: &mut Vec<u8>, dist: u32, len: u32) {
        buf.clear();
        LzssCode {
            distance: dist,
            length: len,
        }
        .write(buf);
        let (LzssCode { distance, length }, read_bytes) = LzssCode::read(buf.as_slice())
            .unwrap_or_else(|| panic!("dist: {}, len: {}", dist, len));
        assert_eq!(
            (length, distance, read_bytes),
            (len, dist, buf.len()),
            "err len: {}, dist: {}",
            len,
            dist
        );
    }

    /// Exhaustive round trip over the short lengths and every encodable
    /// distance, including the largest one (0x1000).
    #[test]
    pub fn test_roundtrip() {
        let mut buf = Vec::new();
        for len in 3..0x1011 {
            for dist in 1..=0x1000 {
                assert_roundtrip(&mut buf, dist, len);
            }
        }
    }

    /// Spot-checks the boundaries of the four-byte form, which is too large to
    /// test exhaustively, up to the maximum length (0x10110).
    #[test]
    pub fn test_roundtrip_long_form_boundaries() {
        let mut buf = Vec::new();
        for &len in &[0x110, 0x111, 0x112, 0x1011, 0xFFFF, 0x1_0000, 0x1_010F, 0x1_0110] {
            for &dist in &[1, 2, 0xFF, 0x100, 0xFFF, 0x1000] {
                assert_roundtrip(&mut buf, dist, len);
            }
        }
    }
}