1#![forbid(unsafe_code)]
2
3use std::io::{self, Read};
4
5use crate::BlockDecodeWorkspace;
6use crate::literals::decode_literals_ws;
7use crate::sequences::{SequenceDecodeTables, parse_sequence_count, parse_sequence_tables_ws};
8use zrip_core::block::{BlockType, parse_block_header};
9use zrip_core::dict::Dictionary;
10use zrip_core::error::DecompressError;
11use zrip_core::frame::MAX_BLOCK_SIZE;
12use zrip_core::frame::header::parse_frame_header;
13use zrip_core::fse::{promote_ll_table, promote_ml_table, promote_of_table};
14use zrip_core::xxhash::Xxh64State;
15
16#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
17use zrip_core::simd::CpuTier;
18
19enum State {
20 FrameHeader,
21 BlockHeader,
22 BlockData {
23 block_type: BlockType,
24 block_size: usize,
25 last: bool,
26 },
27 Checksum,
28 Done,
29}
30
31pub struct FrameDecoder<R: Read> {
48 inner: R,
49 state: State,
50 read_buf: Vec<u8>,
51 output_buf: Vec<u8>,
52 output_pos: usize,
53 ws: Box<BlockDecodeWorkspace>,
54 seq_tables: SequenceDecodeTables,
55 rep_offsets: [u32; 3],
56 hasher: Option<Xxh64State>,
57 content_checksum: bool,
58 max_output: usize,
59 bytes_output: usize,
60 dict: Option<Dictionary>,
61}
62
63impl<R: Read> FrameDecoder<R> {
64 pub fn new(reader: R) -> Self {
66 Self::with_limit(reader, zrip_core::DEFAULT_DECOMPRESS_LIMIT)
67 }
68
69 pub fn with_limit(reader: R, max_output: usize) -> Self {
71 Self {
72 inner: reader,
73 state: State::FrameHeader,
74 read_buf: Vec::new(),
75 output_buf: Vec::new(),
76 output_pos: 0,
77 ws: Box::new(BlockDecodeWorkspace::new()),
78 seq_tables: SequenceDecodeTables::new_default(),
79 rep_offsets: [1, 4, 8],
80 hasher: None,
81 content_checksum: false,
82 max_output,
83 bytes_output: 0,
84 dict: None,
85 }
86 }
87
88 pub fn with_dict(reader: R, dict: Dictionary) -> Self {
90 Self::with_dict_and_limit(reader, dict, zrip_core::DEFAULT_DECOMPRESS_LIMIT)
91 }
92
93 pub fn with_dict_and_limit(reader: R, dict: Dictionary, max_output: usize) -> Self {
95 Self {
96 inner: reader,
97 state: State::FrameHeader,
98 read_buf: Vec::new(),
99 output_buf: Vec::new(),
100 output_pos: 0,
101 ws: Box::new(BlockDecodeWorkspace::new()),
102 seq_tables: SequenceDecodeTables::new_default(),
103 rep_offsets: [1, 4, 8],
104 hasher: None,
105 content_checksum: false,
106 max_output,
107 bytes_output: 0,
108 dict: Some(dict),
109 }
110 }
111
112 pub fn into_inner(self) -> R {
114 self.inner
115 }
116
117 pub fn reset(&mut self, new_reader: R) -> R {
120 let old = core::mem::replace(&mut self.inner, new_reader);
121 self.state = State::FrameHeader;
122 self.output_buf.clear();
123 self.output_pos = 0;
124 self.rep_offsets = [1, 4, 8];
125 self.seq_tables = SequenceDecodeTables::new_default();
126 self.ws.huf_valid = false;
127 self.hasher = None;
128 self.content_checksum = false;
129 self.bytes_output = 0;
130 old
131 }
132
133 fn fill_output(&mut self) -> io::Result<()> {
134 loop {
135 match self.state {
136 State::Done => return Ok(()),
137 State::FrameHeader => self.read_frame_header()?,
138 State::BlockHeader => self.read_block_header()?,
139 State::BlockData {
140 block_type,
141 block_size,
142 last,
143 } => {
144 self.read_block_data(block_type, block_size, last)?;
145 if self.output_pos < self.output_buf.len() {
146 return Ok(());
147 }
148 }
149 State::Checksum => self.read_checksum()?,
150 }
151 }
152 }
153
154 fn read_frame_header(&mut self) -> io::Result<()> {
155 self.read_buf.resize(18, 0);
156 self.inner.read_exact(&mut self.read_buf[..5])?;
157
158 let magic = u32::from_le_bytes([
159 self.read_buf[0],
160 self.read_buf[1],
161 self.read_buf[2],
162 self.read_buf[3],
163 ]);
164
165 if (magic & 0xFFFFFFF0) == 0x184D2A50 {
166 self.inner.read_exact(&mut self.read_buf[5..9])?;
167 let skip_size = u32::from_le_bytes([
168 self.read_buf[5],
169 self.read_buf[6],
170 self.read_buf[7],
171 self.read_buf[8],
172 ]) as usize;
173 io::copy(
174 &mut self.inner.by_ref().take(skip_size as u64),
175 &mut io::sink(),
176 )?;
177 return Ok(());
178 }
179
180 let descriptor = self.read_buf[4];
181 let single_segment = (descriptor & 0x20) != 0;
182 let dict_id_flag = descriptor & 0x03;
183 let fcs_flag = (descriptor >> 6) & 0x03;
184
185 let mut hdr_len = 5usize;
186 if !single_segment {
187 hdr_len += 1;
188 }
189 hdr_len += match dict_id_flag {
190 0 => 0,
191 1 => 1,
192 2 => 2,
193 3 => 4,
194 _ => unreachable!(),
195 };
196 hdr_len += match fcs_flag {
197 0 if single_segment => 1,
198 0 => 0,
199 1 => 2,
200 2 => 4,
201 3 => 8,
202 _ => unreachable!(),
203 };
204
205 if hdr_len > 5 {
206 self.inner.read_exact(&mut self.read_buf[5..hdr_len])?;
207 }
208
209 let header = parse_frame_header(&self.read_buf[..hdr_len])
210 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
211
212 if let Some(frame_dict_id) = header.dict_id {
213 match &self.dict {
214 Some(d) if d.id() == frame_dict_id => {}
215 Some(d) => {
216 return Err(io::Error::new(
217 io::ErrorKind::InvalidData,
218 DecompressError::DictMismatch {
219 expected: frame_dict_id,
220 got: d.id(),
221 },
222 ));
223 }
224 None => {
225 return Err(io::Error::new(
226 io::ErrorKind::InvalidData,
227 DecompressError::DictRequired,
228 ));
229 }
230 }
231 }
232
233 if let Some(fcs) = header.frame_content_size {
234 if fcs as usize > self.max_output {
235 return Err(io::Error::new(
236 io::ErrorKind::InvalidData,
237 DecompressError::OutputTooSmall,
238 ));
239 }
240 }
241
242 self.content_checksum = header.content_checksum;
243 self.hasher = if header.content_checksum {
244 Some(Xxh64State::new(0))
245 } else {
246 None
247 };
248
249 if let Some(ref d) = self.dict {
250 self.rep_offsets = *d.rep_offsets();
251 let mut st = SequenceDecodeTables::new_default();
252 if let Some((t, l)) = d.of_table() {
253 st.of_table = promote_of_table(t);
254 st.of_accuracy = l;
255 }
256 if let Some((t, l)) = d.ml_table() {
257 st.ml_table = promote_ml_table(t);
258 st.ml_accuracy = l;
259 }
260 if let Some((t, l)) = d.ll_table() {
261 st.ll_table = promote_ll_table(t);
262 st.ll_accuracy = l;
263 }
264 self.seq_tables = st;
265 self.ws.huf_valid = false;
266 if let Some((t, l)) = d.huf_table() {
267 self.ws.huf_table.clear();
268 self.ws.huf_table.extend_from_slice(t);
269 self.ws.huf_table_log = l;
270 self.ws.huf_valid = true;
271 }
272 } else {
273 self.rep_offsets = [1, 4, 8];
274 self.seq_tables = SequenceDecodeTables::new_default();
275 self.ws.huf_valid = false;
276 }
277
278 self.state = State::BlockHeader;
279 Ok(())
280 }
281
282 fn read_block_header(&mut self) -> io::Result<()> {
283 let mut hdr = [0u8; 3];
284 self.inner.read_exact(&mut hdr)?;
285 let block_header =
286 parse_block_header(&hdr).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
287
288 let block_size = block_header.block_size as usize;
289
290 match block_header.block_type {
291 BlockType::Raw | BlockType::Rle if block_size > MAX_BLOCK_SIZE => {
292 return Err(io::Error::new(
293 io::ErrorKind::InvalidData,
294 DecompressError::CorruptSequences,
295 ));
296 }
297 _ => {}
298 }
299
300 self.state = State::BlockData {
301 block_type: block_header.block_type,
302 block_size,
303 last: block_header.last_block,
304 };
305 Ok(())
306 }
307
308 fn read_block_data(
309 &mut self,
310 block_type: BlockType,
311 block_size: usize,
312 last: bool,
313 ) -> io::Result<()> {
314 self.output_buf.clear();
315 self.output_pos = 0;
316
317 match block_type {
318 BlockType::Raw => {
319 self.output_buf.resize(block_size, 0);
320 self.inner.read_exact(&mut self.output_buf)?;
321 }
322 BlockType::Rle => {
323 let mut byte = [0u8; 1];
324 self.inner.read_exact(&mut byte)?;
325 self.output_buf.resize(block_size, byte[0]);
326 }
327 BlockType::Compressed => {
328 self.read_buf.resize(block_size, 0);
329 self.inner.read_exact(&mut self.read_buf[..block_size])?;
330 self.decode_compressed_block(block_size)?;
331 }
332 }
333
334 if let Some(ref mut hasher) = self.hasher {
335 hasher.update(&self.output_buf);
336 }
337 self.bytes_output += self.output_buf.len();
338 if self.bytes_output > self.max_output {
339 return Err(io::Error::new(
340 io::ErrorKind::InvalidData,
341 DecompressError::OutputTooSmall,
342 ));
343 }
344
345 self.state = if last {
346 if self.content_checksum {
347 State::Checksum
348 } else {
349 State::FrameHeader
350 }
351 } else {
352 State::BlockHeader
353 };
354
355 Ok(())
356 }
357
358 fn decode_compressed_block(&mut self, block_size: usize) -> io::Result<()> {
359 let dict_history: &[u8] = match &self.dict {
360 Some(d) => d.content(),
361 None => &[],
362 };
363 let block_data = &self.read_buf[..block_size];
364
365 let lit_consumed = decode_literals_ws(block_data, &mut self.ws)
366 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
367
368 let remaining = &block_data[lit_consumed..];
369
370 if remaining.is_empty() {
371 self.output_buf.extend_from_slice(&self.ws.literal_buf);
372 return Ok(());
373 }
374
375 let (num_sequences, seq_count_size) = parse_sequence_count(remaining)
376 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
377
378 if num_sequences == 0 {
379 self.output_buf.extend_from_slice(&self.ws.literal_buf);
380 return Ok(());
381 }
382
383 let table_data = &remaining[seq_count_size..];
384 let tables_consumed =
385 parse_sequence_tables_ws(table_data, &mut self.seq_tables, &mut self.ws)
386 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
387
388 let seq_data = &table_data[tables_consumed..];
389
390 #[cfg(target_arch = "x86_64")]
391 {
392 if zrip_core::simd::cpu_tier() >= CpuTier::Avx2 {
393 let before = self.output_buf.len();
394 crate::simd_decode::x86_64::decode::decode_execute_avx2_safe(
395 seq_data,
396 num_sequences,
397 &self.seq_tables,
398 &mut self.rep_offsets,
399 &self.ws.literal_buf,
400 &mut self.output_buf,
401 dict_history,
402 )
403 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
404 if self.output_buf.len() - before > MAX_BLOCK_SIZE {
405 return Err(io::Error::new(
406 io::ErrorKind::InvalidData,
407 DecompressError::CorruptSequences,
408 ));
409 }
410 return Ok(());
411 }
412 }
413
414 #[cfg(target_arch = "aarch64")]
415 {
416 if zrip_core::simd::cpu_tier() >= CpuTier::Neon {
417 let before = self.output_buf.len();
418 crate::simd_decode::aarch64::decode::decode_execute_neon_safe(
419 seq_data,
420 num_sequences,
421 &self.seq_tables,
422 &mut self.rep_offsets,
423 &self.ws.literal_buf,
424 &mut self.output_buf,
425 dict_history,
426 )
427 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
428 if self.output_buf.len() - before > MAX_BLOCK_SIZE {
429 return Err(io::Error::new(
430 io::ErrorKind::InvalidData,
431 DecompressError::CorruptSequences,
432 ));
433 }
434 return Ok(());
435 }
436 }
437
438 let before = self.output_buf.len();
439 crate::exec::decode_execute_sequences(
440 seq_data,
441 num_sequences,
442 &self.seq_tables,
443 &mut self.rep_offsets,
444 &self.ws.literal_buf,
445 &mut self.output_buf,
446 dict_history,
447 )
448 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
449 if self.output_buf.len() - before > MAX_BLOCK_SIZE {
450 return Err(io::Error::new(
451 io::ErrorKind::InvalidData,
452 DecompressError::CorruptSequences,
453 ));
454 }
455 Ok(())
456 }
457
458 fn read_checksum(&mut self) -> io::Result<()> {
459 let mut buf = [0u8; 4];
460 self.inner.read_exact(&mut buf)?;
461 let stored = u32::from_le_bytes(buf);
462
463 if let Some(ref hasher) = self.hasher {
464 let hash = hasher.finish();
465 let expected = (hash & 0xFFFFFFFF) as u32;
466 if expected != stored {
467 return Err(io::Error::new(
468 io::ErrorKind::InvalidData,
469 DecompressError::ChecksumMismatch {
470 expected: stored,
471 got: expected,
472 },
473 ));
474 }
475 }
476
477 self.state = State::FrameHeader;
478 Ok(())
479 }
480}
481
482impl<R: Read> Read for FrameDecoder<R> {
483 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
484 if self.output_pos >= self.output_buf.len() {
485 if let State::Done = &self.state {
486 return Ok(0);
487 }
488
489 self.output_buf.clear();
490 self.output_pos = 0;
491
492 match self.fill_output() {
493 Ok(()) => {}
494 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => match &self.state {
495 State::FrameHeader => {
496 self.state = State::Done;
497 return Ok(0);
498 }
499 _ => return Err(e),
500 },
501 Err(e) => return Err(e),
502 }
503 }
504
505 let available = &self.output_buf[self.output_pos..];
506 let n = buf.len().min(available.len());
507 buf[..n].copy_from_slice(&available[..n]);
508 self.output_pos += n;
509 Ok(n)
510 }
511}