1#![forbid(unsafe_code)]
2
3use std::io::{self, Read};
4
5use crate::BlockDecodeWorkspace;
6use crate::literals::decode_literals_ws;
7use crate::sequences::{SequenceDecodeTables, parse_sequence_count, parse_sequence_tables_ws};
8
9use crate::exec::decode_execute_sequences;
10use zrip_core::block::{BlockType, parse_block_header};
11use zrip_core::dict::Dictionary;
12use zrip_core::error::DecompressError;
13use zrip_core::frame::MAX_BLOCK_SIZE;
14use zrip_core::frame::header::parse_frame_header;
15use zrip_core::fse::{promote_ll_table, promote_ml_table, promote_of_table};
16use zrip_core::xxhash::Xxh64State;
17
18#[cfg(all(
19 any(target_arch = "x86_64", target_arch = "aarch64"),
20 not(feature = "paranoid")
21))]
22use zrip_core::simd::CpuTier;
23
24enum State {
25 FrameHeader,
26 BlockHeader,
27 BlockData {
28 block_type: BlockType,
29 block_size: usize,
30 last: bool,
31 },
32 Checksum,
33 Done,
34}
35
36pub struct FrameDecoder<R: Read> {
53 inner: R,
54 state: State,
55 read_buf: Vec<u8>,
56 output_buf: Vec<u8>,
57 output_pos: usize,
58 ws: Box<BlockDecodeWorkspace>,
59 seq_tables: SequenceDecodeTables,
60 rep_offsets: [u32; 3],
61 hasher: Option<Xxh64State>,
62 content_checksum: bool,
63 max_output: usize,
64 bytes_output: usize,
65 frame_content_size: Option<u64>,
66 frame_bytes: usize,
67 dict: Option<Dictionary>,
68}
69
70impl<R: Read> FrameDecoder<R> {
71 pub fn new(reader: R) -> Self {
73 Self::with_limit(reader, zrip_core::DEFAULT_DECOMPRESS_LIMIT)
74 }
75
76 pub fn with_limit(reader: R, max_output: usize) -> Self {
78 Self {
79 inner: reader,
80 state: State::FrameHeader,
81 read_buf: Vec::new(),
82 output_buf: Vec::new(),
83 output_pos: 0,
84 ws: Box::new(BlockDecodeWorkspace::new()),
85 seq_tables: SequenceDecodeTables::new_default(),
86 rep_offsets: [1, 4, 8],
87 hasher: None,
88 content_checksum: false,
89 max_output,
90 bytes_output: 0,
91 frame_content_size: None,
92 frame_bytes: 0,
93 dict: None,
94 }
95 }
96
97 pub fn with_dict(reader: R, dict: Dictionary) -> Self {
99 Self::with_dict_and_limit(reader, dict, zrip_core::DEFAULT_DECOMPRESS_LIMIT)
100 }
101
102 pub fn with_dict_and_limit(reader: R, dict: Dictionary, max_output: usize) -> Self {
104 Self {
105 inner: reader,
106 state: State::FrameHeader,
107 read_buf: Vec::new(),
108 output_buf: Vec::new(),
109 output_pos: 0,
110 ws: Box::new(BlockDecodeWorkspace::new()),
111 seq_tables: SequenceDecodeTables::new_default(),
112 rep_offsets: [1, 4, 8],
113 hasher: None,
114 content_checksum: false,
115 max_output,
116 bytes_output: 0,
117 frame_content_size: None,
118 frame_bytes: 0,
119 dict: Some(dict),
120 }
121 }
122
123 pub fn into_inner(self) -> R {
125 self.inner
126 }
127
128 pub fn reset(&mut self, new_reader: R) -> R {
131 let old = core::mem::replace(&mut self.inner, new_reader);
132 self.state = State::FrameHeader;
133 self.output_buf.clear();
134 self.output_pos = 0;
135 self.rep_offsets = [1, 4, 8];
136 self.seq_tables = SequenceDecodeTables::new_default();
137 self.ws.huf_valid = false;
138 self.hasher = None;
139 self.content_checksum = false;
140 self.bytes_output = 0;
141 self.frame_content_size = None;
142 self.frame_bytes = 0;
143 old
144 }
145
146 fn fill_output(&mut self) -> io::Result<()> {
147 loop {
148 match self.state {
149 State::Done => return Ok(()),
150 State::FrameHeader => self.read_frame_header()?,
151 State::BlockHeader => self.read_block_header()?,
152 State::BlockData {
153 block_type,
154 block_size,
155 last,
156 } => {
157 self.read_block_data(block_type, block_size, last)?;
158 if self.output_pos < self.output_buf.len() {
159 return Ok(());
160 }
161 }
162 State::Checksum => self.read_checksum()?,
163 }
164 }
165 }
166
167 fn read_frame_header(&mut self) -> io::Result<()> {
168 self.read_buf.resize(18, 0);
169 self.inner.read_exact(&mut self.read_buf[..5])?;
170
171 let magic = u32::from_le_bytes([
172 self.read_buf[0],
173 self.read_buf[1],
174 self.read_buf[2],
175 self.read_buf[3],
176 ]);
177
178 if (magic & 0xFFFF_FFF0) == 0x184D_2A50 {
179 self.inner.read_exact(&mut self.read_buf[5..9])?;
180 let skip_size = u32::from_le_bytes([
181 self.read_buf[5],
182 self.read_buf[6],
183 self.read_buf[7],
184 self.read_buf[8],
185 ]) as usize;
186 io::copy(
187 &mut self.inner.by_ref().take(skip_size as u64),
188 &mut io::sink(),
189 )?;
190 return Ok(());
191 }
192
193 let descriptor = self.read_buf[4];
194 let single_segment = (descriptor & 0x20) != 0;
195 let dict_id_flag = descriptor & 0x03;
196 let fcs_flag = (descriptor >> 6) & 0x03;
197
198 let mut hdr_len = 5usize;
199 if !single_segment {
200 hdr_len += 1;
201 }
202 hdr_len += match dict_id_flag {
203 0 => 0,
204 1 => 1,
205 2 => 2,
206 3 => 4,
207 _ => unreachable!(),
208 };
209 hdr_len += match fcs_flag {
210 0 if single_segment => 1,
211 0 => 0,
212 1 => 2,
213 2 => 4,
214 3 => 8,
215 _ => unreachable!(),
216 };
217
218 if hdr_len > 5 {
219 self.inner.read_exact(&mut self.read_buf[5..hdr_len])?;
220 }
221
222 let header = parse_frame_header(&self.read_buf[..hdr_len])
223 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
224
225 if let Some(frame_dict_id) = header.dict_id {
226 match &self.dict {
227 Some(d) if d.id() == frame_dict_id => {}
228 Some(d) => {
229 return Err(io::Error::new(
230 io::ErrorKind::InvalidData,
231 DecompressError::DictMismatch {
232 expected: frame_dict_id,
233 got: d.id(),
234 },
235 ));
236 }
237 None => {
238 return Err(io::Error::new(
239 io::ErrorKind::InvalidData,
240 DecompressError::DictRequired,
241 ));
242 }
243 }
244 }
245
246 if let Some(fcs) = header.frame_content_size {
247 if fcs as usize > self.max_output {
248 return Err(io::Error::new(
249 io::ErrorKind::InvalidData,
250 DecompressError::OutputTooSmall,
251 ));
252 }
253 }
254
255 self.frame_content_size = header.frame_content_size;
256 self.frame_bytes = 0;
257 self.content_checksum = header.content_checksum;
258 self.hasher = if header.content_checksum {
259 Some(Xxh64State::new(0))
260 } else {
261 None
262 };
263
264 if let Some(ref d) = self.dict {
265 self.rep_offsets = *d.rep_offsets();
266 let mut st = SequenceDecodeTables::new_default();
267 if let Some((t, l)) = d.of_table() {
268 st.of_table = promote_of_table(t);
269 st.of_accuracy = l;
270 st.of_set = true;
271 }
272 if let Some((t, l)) = d.ml_table() {
273 st.ml_table = promote_ml_table(t);
274 st.ml_accuracy = l;
275 st.ml_set = true;
276 }
277 if let Some((t, l)) = d.ll_table() {
278 st.ll_table = promote_ll_table(t);
279 st.ll_accuracy = l;
280 st.ll_set = true;
281 }
282 self.seq_tables = st;
283 self.ws.huf_valid = false;
284 if let Some((t, l)) = d.huf_table() {
285 self.ws.huf_table.clear();
286 self.ws.huf_table.extend_from_slice(t);
287 self.ws.huf_table_log = l;
288 self.ws.huf_valid = true;
289 }
290 } else {
291 self.rep_offsets = [1, 4, 8];
292 self.seq_tables = SequenceDecodeTables::new_default();
293 self.ws.huf_valid = false;
294 }
295
296 self.state = State::BlockHeader;
297 Ok(())
298 }
299
300 fn read_block_header(&mut self) -> io::Result<()> {
301 let mut hdr = [0u8; 3];
302 self.inner.read_exact(&mut hdr)?;
303 let block_header =
304 parse_block_header(&hdr).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
305
306 let block_size = block_header.block_size as usize;
307
308 match block_header.block_type {
309 BlockType::Raw | BlockType::Rle if block_size > MAX_BLOCK_SIZE => {
310 return Err(io::Error::new(
311 io::ErrorKind::InvalidData,
312 DecompressError::BlockTooLarge,
313 ));
314 }
315 _ => {}
316 }
317
318 self.state = State::BlockData {
319 block_type: block_header.block_type,
320 block_size,
321 last: block_header.last_block,
322 };
323 Ok(())
324 }
325
326 fn read_block_data(
327 &mut self,
328 block_type: BlockType,
329 block_size: usize,
330 last: bool,
331 ) -> io::Result<()> {
332 self.output_buf.clear();
333 self.output_pos = 0;
334
335 match block_type {
336 BlockType::Raw => {
337 self.output_buf.resize(block_size, 0);
338 self.inner.read_exact(&mut self.output_buf)?;
339 }
340 BlockType::Rle => {
341 let mut byte = [0u8; 1];
342 self.inner.read_exact(&mut byte)?;
343 self.output_buf.resize(block_size, byte[0]);
344 }
345 BlockType::Compressed => {
346 self.read_buf.resize(block_size, 0);
347 self.inner.read_exact(&mut self.read_buf[..block_size])?;
348 self.decode_compressed_block(block_size)?;
349 }
350 }
351
352 if let Some(ref mut hasher) = self.hasher {
353 hasher.update(&self.output_buf);
354 }
355 self.bytes_output += self.output_buf.len();
356 self.frame_bytes += self.output_buf.len();
357 if self.bytes_output > self.max_output {
358 return Err(io::Error::new(
359 io::ErrorKind::InvalidData,
360 DecompressError::OutputTooSmall,
361 ));
362 }
363
364 self.state = if last {
365 if let Some(fcs) = self.frame_content_size {
366 if self.frame_bytes as u64 != fcs {
367 return Err(io::Error::new(
368 io::ErrorKind::InvalidData,
369 DecompressError::FrameSizeMismatch,
370 ));
371 }
372 }
373 if self.content_checksum {
374 State::Checksum
375 } else {
376 State::FrameHeader
377 }
378 } else {
379 State::BlockHeader
380 };
381
382 Ok(())
383 }
384
385 fn decode_compressed_block(&mut self, block_size: usize) -> io::Result<()> {
386 let dict_history: &[u8] = match &self.dict {
387 Some(d) => d.content(),
388 None => &[],
389 };
390 let block_data = &self.read_buf[..block_size];
391
392 let lit_consumed = decode_literals_ws(block_data, &mut self.ws)
393 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
394
395 let remaining = &block_data[lit_consumed..];
396
397 if remaining.is_empty() {
398 self.output_buf.extend_from_slice(&self.ws.literal_buf);
399 return Ok(());
400 }
401
402 let (num_sequences, seq_count_size) = parse_sequence_count(remaining)
403 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
404
405 if num_sequences == 0 {
406 self.output_buf.extend_from_slice(&self.ws.literal_buf);
407 return Ok(());
408 }
409
410 let table_data = &remaining[seq_count_size..];
411 let tables_consumed =
412 parse_sequence_tables_ws(table_data, &mut self.seq_tables, &mut self.ws)
413 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
414
415 let seq_data = &table_data[tables_consumed..];
416
417 #[cfg(all(target_arch = "x86_64", not(feature = "paranoid")))]
418 {
419 if zrip_core::simd::cpu_tier() >= CpuTier::Avx2 {
420 let before = self.output_buf.len();
421 crate::simd_decode::x86_64::decode::decode_execute_avx2_safe(
422 seq_data,
423 num_sequences,
424 &self.seq_tables,
425 &mut self.rep_offsets,
426 &self.ws.literal_buf,
427 &mut self.output_buf,
428 dict_history,
429 )
430 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
431 if self.output_buf.len() - before > MAX_BLOCK_SIZE {
432 return Err(io::Error::new(
433 io::ErrorKind::InvalidData,
434 DecompressError::BlockTooLarge,
435 ));
436 }
437 return Ok(());
438 }
439 }
440
441 #[cfg(all(target_arch = "aarch64", not(feature = "paranoid")))]
442 {
443 if zrip_core::simd::cpu_tier() >= CpuTier::Neon {
444 let before = self.output_buf.len();
445 crate::simd_decode::aarch64::decode::decode_execute_neon_safe(
446 seq_data,
447 num_sequences,
448 &self.seq_tables,
449 &mut self.rep_offsets,
450 &self.ws.literal_buf,
451 &mut self.output_buf,
452 dict_history,
453 )
454 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
455 if self.output_buf.len() - before > MAX_BLOCK_SIZE {
456 return Err(io::Error::new(
457 io::ErrorKind::InvalidData,
458 DecompressError::BlockTooLarge,
459 ));
460 }
461 return Ok(());
462 }
463 }
464
465 let before = self.output_buf.len();
466 decode_execute_sequences(
467 seq_data,
468 num_sequences,
469 &self.seq_tables,
470 &mut self.rep_offsets,
471 &self.ws.literal_buf,
472 &mut self.output_buf,
473 dict_history,
474 )
475 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
476 if self.output_buf.len() - before > MAX_BLOCK_SIZE {
477 return Err(io::Error::new(
478 io::ErrorKind::InvalidData,
479 DecompressError::BlockTooLarge,
480 ));
481 }
482 Ok(())
483 }
484
485 fn read_checksum(&mut self) -> io::Result<()> {
486 let mut buf = [0u8; 4];
487 self.inner.read_exact(&mut buf)?;
488 let stored = u32::from_le_bytes(buf);
489
490 if let Some(ref hasher) = self.hasher {
491 let hash = hasher.finish();
492 let expected = (hash & 0xFFFF_FFFF) as u32;
493 if expected != stored {
494 return Err(io::Error::new(
495 io::ErrorKind::InvalidData,
496 DecompressError::ChecksumMismatch {
497 expected: stored,
498 got: expected,
499 },
500 ));
501 }
502 }
503
504 self.state = State::FrameHeader;
505 Ok(())
506 }
507}
508
509impl<R: Read> Read for FrameDecoder<R> {
510 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
511 if self.output_pos >= self.output_buf.len() {
512 if let State::Done = &self.state {
513 return Ok(0);
514 }
515
516 self.output_buf.clear();
517 self.output_pos = 0;
518
519 match self.fill_output() {
520 Ok(()) => {}
521 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => match &self.state {
522 State::FrameHeader => {
523 self.state = State::Done;
524 return Ok(0);
525 }
526 _ => return Err(e),
527 },
528 Err(e) => return Err(e),
529 }
530 }
531
532 let available = &self.output_buf[self.output_pos..];
533 let n = buf.len().min(available.len());
534 buf[..n].copy_from_slice(&available[..n]);
535 self.output_pos += n;
536 Ok(n)
537 }
538}