Skip to main content

zrip_decode/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![cfg_attr(feature = "nightly", feature(optimize_attribute))]
3#![cfg_attr(feature = "paranoid", forbid(unsafe_code))]
4
5#[cfg(feature = "alloc")]
6extern crate alloc;
7
8pub(crate) mod block_decoder;
9#[cfg(feature = "std")]
10pub mod context;
11pub(crate) mod exec;
12pub(crate) mod literals;
13pub(crate) mod ring_buffer;
14pub(crate) mod sequences;
15#[cfg(feature = "std")]
16pub mod streaming;
17
18#[cfg(all(
19    any(target_arch = "x86_64", target_arch = "aarch64"),
20    not(feature = "paranoid")
21))]
22pub(crate) mod simd_decode;
23
24#[cfg(feature = "alloc")]
25use alloc::boxed::Box;
26#[cfg(feature = "alloc")]
27use alloc::vec::Vec;
28
29use crate::exec::decode_execute_sequences;
30use crate::literals::decode_literals_ws;
31use crate::sequences::{SequenceDecodeTables, parse_sequence_count, parse_sequence_tables_ws};
32use zrip_core::block::{BlockType, parse_block_header};
33use zrip_core::error::DecompressError;
34use zrip_core::frame::MAX_WINDOW_SIZE;
35use zrip_core::frame::header::parse_frame_header;
36use zrip_core::huffman::HuffmanDecodeEntry;
37#[cfg(all(
38    any(target_arch = "x86_64", target_arch = "aarch64"),
39    not(feature = "paranoid")
40))]
41use zrip_core::simd::CpuTier;
42use zrip_core::xxhash::Xxh64State;
43
44pub(crate) struct BlockDecodeWorkspace {
45    pub literal_buf: Vec<u8>,
46    pub huf_table: Vec<HuffmanDecodeEntry>,
47    pub huf_table_log: u8,
48    pub huf_valid: bool,
49    pub huf_all_weights: Vec<u8>,
50    pub huf_rank_count: Vec<u32>,
51    pub huf_rank_start: Vec<u32>,
52    pub fse_dist: Vec<i16>,
53    pub fse_symbol_next: Vec<u16>,
54    pub fse_build_buf: Vec<zrip_core::fse::FseDecodeEntry>,
55}
56
57impl BlockDecodeWorkspace {
58    pub(crate) fn new() -> Self {
59        Self {
60            literal_buf: Vec::new(),
61            huf_table: Vec::new(),
62            huf_table_log: 0,
63            huf_valid: false,
64            huf_all_weights: Vec::new(),
65            huf_rank_count: Vec::new(),
66            huf_rank_start: Vec::new(),
67            fse_dist: Vec::new(),
68            fse_symbol_next: Vec::new(),
69            fse_build_buf: Vec::new(),
70        }
71    }
72}
73
/// Returns the total length of the skippable frame at the start of `data`.
///
/// A skippable frame consists of three parts:
/// - a 4-byte little-endian magic number in `0x184D_2A50..=0x184D_2A5F`;
/// - a 4-byte little-endian payload length;
/// - that many payload bytes.
///
/// Returns `None` in three cases: `data` does not start with such a magic
/// number, the header is incomplete, or the payload extends past the end of
/// `data`. The caller then tries to parse the bytes as a regular frame, which
/// reports the malformed input.
pub(crate) fn skip_skippable_frame(data: &[u8]) -> Option<usize> {
    let header = data.get(..8)?;
    let magic = u32::from_le_bytes([header[0], header[1], header[2], header[3]]);
    // The low nibble of the magic number is user-defined; only the top 28 bits identify the frame type.
    if magic & 0xFFFF_FFF0 != 0x184D_2A50 {
        return None;
    }
    let frame_size = u32::from_le_bytes([header[4], header[5], header[6], header[7]]);
    // Checked arithmetic: on 32-bit targets `8 + u32::MAX` overflows `usize`.
    // That would panic in debug builds and wrap to a bogus short length in release builds.
    let total = usize::try_from(frame_size).ok()?.checked_add(8)?;
    (total <= data.len()).then_some(total)
}
89
90pub fn decompress(input: &[u8]) -> Result<Vec<u8>, DecompressError> {
91    decompress_with_dict(input, None)
92}
93
94/// Decompress with an explicit output size limit.
95///
96/// Returns [`DecompressError::OutputTooSmall`] if the decompressed output would
97/// exceed `max_output_size` bytes. Use [`SAFE_DECOMPRESS_LIMIT`](zrip_core::SAFE_DECOMPRESS_LIMIT)
98/// when processing untrusted input to prevent memory exhaustion attacks.
99pub fn decompress_with_limit(
100    input: &[u8],
101    max_output_size: usize,
102) -> Result<Vec<u8>, DecompressError> {
103    let mut output = Vec::new();
104    let mut ws = Box::new(BlockDecodeWorkspace::new());
105    let mut offset = 0;
106    while offset < input.len() {
107        let remaining = &input[offset..];
108        if let Some(skip_len) = skip_skippable_frame(remaining) {
109            offset += skip_len;
110            continue;
111        }
112        let consumed = decompress_frame(remaining, &mut output, max_output_size, None, &mut ws)?;
113        offset += consumed;
114    }
115    Ok(output)
116}
117
118pub fn decompress_into(input: &[u8], output: &mut Vec<u8>) -> Result<usize, DecompressError> {
119    let max_output = zrip_core::DEFAULT_DECOMPRESS_LIMIT;
120    let mut ws = Box::new(BlockDecodeWorkspace::new());
121    let start = output.len();
122    let mut offset = 0;
123    while offset < input.len() {
124        let remaining = &input[offset..];
125        if let Some(skip_len) = skip_skippable_frame(remaining) {
126            offset += skip_len;
127            continue;
128        }
129        let consumed = decompress_frame(remaining, output, max_output, None, &mut ws)?;
130        offset += consumed;
131    }
132    Ok(output.len() - start)
133}
134
135pub fn decompress_with_dict(
136    input: &[u8],
137    dict: Option<&zrip_core::dict::Dictionary>,
138) -> Result<Vec<u8>, DecompressError> {
139    let max_output = zrip_core::DEFAULT_DECOMPRESS_LIMIT;
140    let mut output = Vec::new();
141    let mut ws = Box::new(BlockDecodeWorkspace::new());
142    let mut offset = 0;
143
144    while offset < input.len() {
145        let remaining = &input[offset..];
146        if let Some(skip_len) = skip_skippable_frame(remaining) {
147            offset += skip_len;
148            continue;
149        }
150        let consumed = decompress_frame(remaining, &mut output, max_output, dict, &mut ws)?;
151        offset += consumed;
152    }
153
154    Ok(output)
155}
156
157pub(crate) fn decompress_frame(
158    input: &[u8],
159    output: &mut Vec<u8>,
160    max_output: usize,
161    dict: Option<&zrip_core::dict::Dictionary>,
162    ws: &mut BlockDecodeWorkspace,
163) -> Result<usize, DecompressError> {
164    let header = parse_frame_header(input)?;
165
166    if header.window_size > MAX_WINDOW_SIZE {
167        return Err(DecompressError::WindowTooLarge {
168            requested: header.window_size,
169            max: MAX_WINDOW_SIZE,
170        });
171    }
172
173    if let Some(frame_dict_id) = header.dict_id {
174        match dict {
175            Some(d) if d.id() == frame_dict_id => {}
176            Some(d) => {
177                return Err(DecompressError::DictMismatch {
178                    expected: frame_dict_id,
179                    got: d.id(),
180                });
181            }
182            None => return Err(DecompressError::DictRequired),
183        }
184    }
185
186    if let Some(fcs) = header.frame_content_size {
187        if max_output < usize::MAX && fcs as usize > max_output {
188            return Err(DecompressError::OutputTooSmall);
189        }
190        let hint = (fcs as usize).min(MAX_WINDOW_SIZE as usize);
191        output.reserve(hint + 32);
192    }
193
194    let mut offset = header.header_size;
195    let output_start = output.len();
196
197    let dict_history: &[u8] = if let Some(d) = dict { d.content() } else { &[] };
198
199    let mut seq_tables = if let Some(d) = dict {
200        let mut st = SequenceDecodeTables::new_default();
201        if let Some((t, l)) = d.of_table() {
202            st.of_table = zrip_core::fse::promote_of_table(t);
203            st.of_accuracy = l;
204        }
205        if let Some((t, l)) = d.ml_table() {
206            st.ml_table = zrip_core::fse::promote_ml_table(t);
207            st.ml_accuracy = l;
208        }
209        if let Some((t, l)) = d.ll_table() {
210            st.ll_table = zrip_core::fse::promote_ll_table(t);
211            st.ll_accuracy = l;
212        }
213        st
214    } else {
215        SequenceDecodeTables::new_default()
216    };
217    let mut rep_offsets: [u32; 3] = if let Some(d) = dict {
218        *d.rep_offsets()
219    } else {
220        [1, 4, 8]
221    };
222    ws.huf_valid = false;
223    if let Some(d) = dict {
224        if let Some((t, l)) = d.huf_table() {
225            ws.huf_table.clear();
226            ws.huf_table.extend_from_slice(t);
227            ws.huf_table_log = l;
228            ws.huf_valid = true;
229        }
230    }
231
232    let mut hasher = if header.content_checksum {
233        Some(Xxh64State::new(0))
234    } else {
235        None
236    };
237
238    loop {
239        if offset + 3 > input.len() {
240            return Err(DecompressError::InputExhausted);
241        }
242        let block_header = parse_block_header(&input[offset..])?;
243        offset += 3;
244
245        let block_size = block_header.block_size as usize;
246
247        if block_size > zrip_core::frame::MAX_BLOCK_SIZE {
248            match block_header.block_type {
249                BlockType::Raw | BlockType::Rle => {
250                    return Err(DecompressError::BlockTooLarge);
251                }
252                BlockType::Compressed => {}
253            }
254        }
255
256        match block_header.block_type {
257            BlockType::Raw => {
258                if offset + block_size > input.len() {
259                    return Err(DecompressError::InputExhausted);
260                }
261                if output.len() - output_start + block_size > max_output {
262                    return Err(DecompressError::OutputTooSmall);
263                }
264                output.extend_from_slice(&input[offset..offset + block_size]);
265                offset += block_size;
266            }
267            BlockType::Rle => {
268                if offset >= input.len() {
269                    return Err(DecompressError::InputExhausted);
270                }
271                if output.len() - output_start + block_size > max_output {
272                    return Err(DecompressError::OutputTooSmall);
273                }
274                let byte = input[offset];
275                output.resize(output.len() + block_size, byte);
276                offset += 1;
277            }
278            BlockType::Compressed => {
279                if offset + block_size > input.len() {
280                    return Err(DecompressError::InputExhausted);
281                }
282                let block_data = &input[offset..offset + block_size];
283                decode_compressed_block(
284                    block_data,
285                    output,
286                    output_start,
287                    max_output,
288                    &mut seq_tables,
289                    &mut rep_offsets,
290                    ws,
291                    dict_history,
292                )?;
293                offset += block_size;
294            }
295        }
296
297        if block_header.last_block {
298            break;
299        }
300    }
301
302    if let Some(ref mut hasher) = hasher {
303        hasher.update(&output[output_start..]);
304        let hash = hasher.finish();
305        let expected_checksum = (hash & 0xFFFF_FFFF) as u32;
306
307        if offset + 4 > input.len() {
308            return Err(DecompressError::InputExhausted);
309        }
310        let stored_checksum = u32::from_le_bytes([
311            input[offset],
312            input[offset + 1],
313            input[offset + 2],
314            input[offset + 3],
315        ]);
316        offset += 4;
317
318        if expected_checksum != stored_checksum {
319            return Err(DecompressError::ChecksumMismatch {
320                expected: stored_checksum,
321                got: expected_checksum,
322            });
323        }
324    }
325
326    if let Some(fcs) = header.frame_content_size {
327        if (output.len() - output_start) as u64 != fcs {
328            return Err(DecompressError::FrameSizeMismatch);
329        }
330    }
331
332    Ok(offset)
333}
334
335#[allow(clippy::too_many_arguments)]
336fn decode_compressed_block(
337    data: &[u8],
338    output: &mut Vec<u8>,
339    output_start: usize,
340    max_output: usize,
341    seq_tables: &mut SequenceDecodeTables,
342    rep_offsets: &mut [u32; 3],
343    ws: &mut BlockDecodeWorkspace,
344    dict_history: &[u8],
345) -> Result<(), DecompressError> {
346    let lit_consumed = decode_literals_ws(data, ws)?;
347
348    let remaining = &data[lit_consumed..];
349
350    if remaining.is_empty() {
351        if output.len() - output_start + ws.literal_buf.len() > max_output {
352            return Err(DecompressError::OutputTooSmall);
353        }
354        output.extend_from_slice(&ws.literal_buf);
355        return Ok(());
356    }
357
358    let (num_sequences, seq_count_size) = parse_sequence_count(remaining)?;
359
360    if num_sequences == 0 {
361        if output.len() - output_start + ws.literal_buf.len() > max_output {
362            return Err(DecompressError::OutputTooSmall);
363        }
364        output.extend_from_slice(&ws.literal_buf);
365        return Ok(());
366    }
367
368    let table_data = &remaining[seq_count_size..];
369    let tables_consumed = parse_sequence_tables_ws(table_data, seq_tables, ws)?;
370
371    let seq_data = &table_data[tables_consumed..];
372
373    let before = output.len();
374
375    #[cfg(all(target_arch = "x86_64", not(feature = "paranoid")))]
376    {
377        if zrip_core::simd::cpu_tier() >= CpuTier::Avx2 {
378            decode_execute_block_avx2(
379                seq_data,
380                num_sequences,
381                seq_tables,
382                rep_offsets,
383                &ws.literal_buf,
384                output,
385                dict_history,
386            )?;
387            if output.len() - before > zrip_core::frame::MAX_BLOCK_SIZE {
388                return Err(DecompressError::BlockTooLarge);
389            }
390            return Ok(());
391        }
392    }
393    #[cfg(all(target_arch = "aarch64", not(feature = "paranoid")))]
394    {
395        if zrip_core::simd::cpu_tier() >= CpuTier::Neon {
396            decode_execute_block_neon(
397                seq_data,
398                num_sequences,
399                seq_tables,
400                rep_offsets,
401                &ws.literal_buf,
402                output,
403                dict_history,
404            )?;
405            if output.len() - before > zrip_core::frame::MAX_BLOCK_SIZE {
406                return Err(DecompressError::BlockTooLarge);
407            }
408            return Ok(());
409        }
410    }
411
412    decode_execute_sequences(
413        seq_data,
414        num_sequences,
415        seq_tables,
416        rep_offsets,
417        &ws.literal_buf,
418        output,
419        dict_history,
420    )?;
421    if output.len() - before > zrip_core::frame::MAX_BLOCK_SIZE {
422        return Err(DecompressError::BlockTooLarge);
423    }
424
425    Ok(())
426}
427
428#[cfg(all(target_arch = "x86_64", not(feature = "paranoid")))]
429fn decode_execute_block_avx2(
430    seq_data: &[u8],
431    num_sequences: u32,
432    tables: &mut SequenceDecodeTables,
433    rep_offsets: &mut [u32; 3],
434    literals: &[u8],
435    output: &mut Vec<u8>,
436    history: &[u8],
437) -> Result<(), DecompressError> {
438    crate::simd_decode::x86_64::decode::decode_execute_avx2_safe(
439        seq_data,
440        num_sequences,
441        tables,
442        rep_offsets,
443        literals,
444        output,
445        history,
446    )
447}
448
/// Decodes and executes one block's sequences with the NEON implementation.
///
/// Forwards unchanged to `simd_decode::aarch64::decode::decode_execute_neon_safe`.
/// In this file it is only reached after `cpu_tier()` reports NEON or better.
#[cfg(all(target_arch = "aarch64", not(feature = "paranoid")))]
fn decode_execute_block_neon(
    sequences: &[u8],
    sequence_count: u32,
    decode_tables: &mut SequenceDecodeTables,
    repeat_offsets: &mut [u32; 3],
    literal_bytes: &[u8],
    out: &mut Vec<u8>,
    window_history: &[u8],
) -> Result<(), DecompressError> {
    use crate::simd_decode::aarch64::decode::decode_execute_neon_safe as execute;

    execute(
        sequences,
        sequence_count,
        decode_tables,
        repeat_offsets,
        literal_bytes,
        out,
        window_history,
    )
}