Skip to main content

zrip_encode/
context.rs

1#[cfg(feature = "alloc")]
2use alloc::borrow::Cow;
3#[cfg(feature = "alloc")]
4use alloc::vec;
5#[cfg(feature = "alloc")]
6use alloc::vec::Vec;
7
8use crate::block_encoder::{self, BlockEncodeWorkspace};
9use crate::strategy::{self, LevelParams, Strategy};
10use crate::{block_looks_incompressible, dfast, fast, write_frame_header};
11use zrip_core::Sequence;
12use zrip_core::dict::Dictionary;
13use zrip_core::error::CompressError;
14use zrip_core::frame::MAX_BLOCK_SIZE;
15use zrip_core::huffman::encode::HuffmanEncodeTable;
16use zrip_core::xxhash::xxh64;
17
18/// Pre-computed dictionary state for hot-loop compression.
19///
20/// Built once from a [`Dictionary`] + [`LevelParams`]. Caches the pre-filled
21/// hash table(s) and a combined buffer with the dict prefix already loaded,
22/// plus encode-side entropy tables built from the dict's decode tables.
23pub(crate) struct PreparedDict {
24    combined: Vec<u8>,
25    hash_snapshot: Vec<u32>,
26    hash_long_snapshot: Vec<u32>,
27    prefix_len: usize,
28    rep_offsets: [u32; 3],
29    dict_id: u32,
30    huf_table: Option<HuffmanEncodeTable>,
31    ll_table: Option<block_encoder::FseEncodeTable>,
32    of_table: Option<block_encoder::FseEncodeTable>,
33    ml_table: Option<block_encoder::FseEncodeTable>,
34}
35
36impl PreparedDict {
37    pub fn new(dict: &Dictionary, params: &LevelParams) -> Self {
38        let prefix = dict.content();
39        let prefix_len = prefix.len();
40
41        let mut combined = Vec::with_capacity(prefix_len + MAX_BLOCK_SIZE);
42        combined.extend_from_slice(prefix);
43
44        let (hash_snapshot, hash_long_snapshot) = match params.strategy {
45            Strategy::Fast => {
46                let hash_size = 1usize << params.hash_log;
47                let mut hash_table = vec![0u32; hash_size];
48                fast::prefill_hash_table(&combined, prefix_len, params.hash_log, &mut hash_table);
49                (hash_table, Vec::new())
50            }
51            Strategy::DFast => {
52                let short_size = 1usize << params.chain_log;
53                let long_size = 1usize << params.hash_log;
54                let mut hash_short = vec![0u32; short_size];
55                let mut hash_long = vec![0u32; long_size];
56                dfast::prefill_hash_tables(
57                    &combined,
58                    prefix_len,
59                    params.hash_log,
60                    params.chain_log,
61                    params.min_match,
62                    &mut hash_short,
63                    &mut hash_long,
64                );
65                (hash_short, hash_long)
66            }
67        };
68
69        let huf_table = dict
70            .huf_table()
71            .and_then(|(dt, tl)| HuffmanEncodeTable::from_decode_table(dt, tl));
72
73        let ll_table = dict
74            .ll_table()
75            .map(|(dt, al)| block_encoder::FseEncodeTable::from_decode_table(dt, al, 35));
76        let of_table = dict
77            .of_table()
78            .map(|(dt, al)| block_encoder::FseEncodeTable::from_decode_table(dt, al, 31));
79        let ml_table = dict
80            .ml_table()
81            .map(|(dt, al)| block_encoder::FseEncodeTable::from_decode_table(dt, al, 52));
82
83        Self {
84            combined,
85            hash_snapshot,
86            hash_long_snapshot,
87            prefix_len,
88            rep_offsets: *dict.rep_offsets(),
89            dict_id: dict.id(),
90            huf_table,
91            ll_table,
92            of_table,
93            ml_table,
94        }
95    }
96}
97
98/// Reusable compression context that amortizes hash table and buffer allocations.
99///
100/// Holds internal state (hash tables, output buffer, block encoder workspace)
101/// across calls. Useful when compressing many small inputs in a loop.
102///
103/// ```
104/// let mut ctx = zrip::CompressContext::new(1).unwrap();
105/// for i in 0..10 {
106///     let data = format!("message {i}").repeat(100);
107///     let compressed = ctx.compress(data.as_bytes()).unwrap();
108///     assert!(compressed.len() < data.len());
109/// }
110/// ```
111pub struct CompressContext {
112    level: i32,
113    prepared: Option<PreparedDict>,
114    hash_table: Vec<u32>,
115    hash_long: Vec<u32>,
116    dict_hash: Vec<u32>,
117    sequences: Vec<Sequence>,
118    output: Vec<u8>,
119    workspace: BlockEncodeWorkspace,
120    combined: Vec<u8>,
121}
122
123impl CompressContext {
124    /// Creates a new context for the given compression level (-7..=4).
125    pub fn new(level: i32) -> Result<Self, CompressError> {
126        let params = strategy::level_params(level).ok_or(CompressError::InvalidLevel(level))?;
127        let max_log = strategy::max_hash_log(level).expect("level validated above");
128        let alloc_size = 1usize << max_log;
129        let (hash_table, hash_long) = match params.strategy {
130            Strategy::Fast => (vec![0u32; alloc_size], Vec::new()),
131            Strategy::DFast => (vec![0u32; alloc_size], vec![0u32; alloc_size]),
132        };
133        Ok(Self {
134            level,
135            prepared: None,
136            hash_table,
137            hash_long,
138            dict_hash: Vec::new(),
139            sequences: Vec::new(),
140            output: Vec::new(),
141            workspace: BlockEncodeWorkspace::new(),
142            combined: Vec::new(),
143        })
144    }
145
146    /// Creates a new context with a pre-loaded dictionary.
147    ///
148    /// The prepared hash table snapshot is built for the T0 (>256 KB)
149    /// parameter tier. Inputs whose tiered params match T0's hash sizes
150    /// use the fast snapshot-restore path; others fall back to per-call
151    /// prefix hashing.
152    ///
153    /// Use [`with_dict_for_size`] to build the snapshot for a specific
154    /// input size tier.
155    pub fn with_dict(level: i32, dict: Dictionary) -> Result<Self, CompressError> {
156        Self::with_dict_for_size(level, dict, usize::MAX)
157    }
158
159    /// Creates a new context with a pre-loaded dictionary, optimized for
160    /// inputs of approximately `expected_size` bytes.
161    ///
162    /// The prepared hash table snapshot is built for the parameter tier
163    /// matching `expected_size`. Inputs in the same tier use the fast
164    /// snapshot-restore path.
165    pub fn with_dict_for_size(
166        level: i32,
167        dict: Dictionary,
168        expected_size: usize,
169    ) -> Result<Self, CompressError> {
170        let total_window = dict.content().len().saturating_add(expected_size);
171        let params = strategy::level_params_for_size(level, total_window)
172            .ok_or(CompressError::InvalidLevel(level))?;
173        let prepared = PreparedDict::new(&dict, &params);
174        let hash_table = vec![0u32; prepared.hash_snapshot.len()];
175        let hash_long = vec![0u32; prepared.hash_long_snapshot.len()];
176        Ok(Self {
177            level,
178            prepared: Some(prepared),
179            hash_table,
180            hash_long,
181            dict_hash: Vec::new(),
182            sequences: Vec::new(),
183            output: Vec::new(),
184            workspace: BlockEncodeWorkspace::new(),
185            combined: Vec::new(),
186        })
187    }
188
189    /// Compresses `input` using the context's level and optional dictionary.
190    pub fn compress(&mut self, input: &[u8]) -> Result<Cow<'_, [u8]>, CompressError> {
191        if self.prepared.is_some() {
192            return self.compress_with_prepared(input);
193        }
194        let params = strategy::level_params_for_size(self.level, input.len())
195            .expect("level validated at construction");
196        compress_core(
197            input,
198            params,
199            None,
200            &[],
201            [1u32, 4, 8],
202            &mut self.hash_table,
203            &mut self.hash_long,
204            &mut self.dict_hash,
205            &mut self.sequences,
206            &mut self.output,
207            &mut self.workspace,
208            &mut self.combined,
209        )?;
210        Ok(self.take_or_borrow_output())
211    }
212
213    /// Compresses `input` using an ad-hoc dictionary (overrides the stored one).
214    pub fn compress_with_dict(
215        &mut self,
216        input: &[u8],
217        dict: &Dictionary,
218    ) -> Result<Cow<'_, [u8]>, CompressError> {
219        let params = strategy::level_params_for_size(self.level, input.len())
220            .expect("level validated at construction");
221        compress_core(
222            input,
223            params,
224            Some(dict.id()),
225            dict.content(),
226            *dict.rep_offsets(),
227            &mut self.hash_table,
228            &mut self.hash_long,
229            &mut self.dict_hash,
230            &mut self.sequences,
231            &mut self.output,
232            &mut self.workspace,
233            &mut self.combined,
234        )?;
235        Ok(self.take_or_borrow_output())
236    }
237
238    fn compress_with_prepared(&mut self, input: &[u8]) -> Result<Cow<'_, [u8]>, CompressError> {
239        let prep = self.prepared.as_ref().unwrap();
240        let total_window = prep.prefix_len + input.len();
241        let params = strategy::level_params_for_size(self.level, total_window)
242            .expect("level validated at construction");
243        let snapshot_matches = match params.strategy {
244            Strategy::Fast => (1usize << params.hash_log) == prep.hash_snapshot.len(),
245            Strategy::DFast => {
246                (1usize << params.chain_log) == prep.hash_snapshot.len()
247                    && (1usize << params.hash_log) == prep.hash_long_snapshot.len()
248            }
249        };
250        let dict_id = prep.dict_id;
251        let prefix_len = prep.prefix_len;
252
253        if !snapshot_matches {
254            return self.compress_with_dict_fallback(input, dict_id, prefix_len);
255        }
256
257        let prep = self.prepared.as_mut().unwrap();
258        self.hash_table.copy_from_slice(&prep.hash_snapshot);
259        if !prep.hash_long_snapshot.is_empty() {
260            self.hash_long.copy_from_slice(&prep.hash_long_snapshot);
261        }
262
263        prep.combined.truncate(prep.prefix_len);
264        prep.combined.extend_from_slice(input);
265
266        if let Some(ref huf) = prep.huf_table {
267            self.workspace.prev_huffman = Some(huf.clone());
268        } else {
269            self.workspace.prev_huffman = None;
270        }
271
272        self.workspace.prev_ll = prep.ll_table.clone();
273        self.workspace.prev_of = prep.of_table.clone();
274        self.workspace.prev_ml = prep.ml_table.clone();
275
276        self.output.clear();
277        self.output.reserve(input.len() + 32);
278        write_frame_header(&mut self.output, input.len(), Some(prep.dict_id));
279
280        if input.is_empty() {
281            block_encoder::encode_raw_block(&[], true, &mut self.output);
282        } else {
283            let prefix_len = prep.prefix_len;
284            let combined = &prep.combined;
285            let mut rep_offsets = prep.rep_offsets;
286
287            if input.len() <= MAX_BLOCK_SIZE {
288                match params.strategy {
289                    Strategy::Fast => {
290                        fast::compress_fast_block(
291                            combined,
292                            prefix_len,
293                            prefix_len + input.len(),
294                            &params,
295                            &rep_offsets,
296                            &mut self.hash_table,
297                            &mut self.sequences,
298                        );
299                    }
300                    Strategy::DFast => {
301                        dfast::compress_dfast_block(
302                            combined,
303                            prefix_len,
304                            prefix_len + input.len(),
305                            &params,
306                            &rep_offsets,
307                            &mut self.hash_table,
308                            &mut self.hash_long,
309                            &mut self.sequences,
310                        );
311                    }
312                }
313                if params.force_raw_literals {
314                    block_encoder::encode_compressed_block_raw(
315                        input,
316                        &self.sequences,
317                        &mut rep_offsets,
318                        true,
319                        &mut self.output,
320                        &mut self.workspace,
321                    );
322                } else {
323                    block_encoder::encode_compressed_block(
324                        input,
325                        &self.sequences,
326                        &mut rep_offsets,
327                        true,
328                        &mut self.output,
329                        &mut self.workspace,
330                    );
331                }
332            } else {
333                let mut offset = 0;
334                while offset < input.len() {
335                    let chunk_size = (input.len() - offset).min(MAX_BLOCK_SIZE);
336                    let is_last = offset + chunk_size >= input.len();
337                    match params.strategy {
338                        Strategy::Fast => {
339                            fast::compress_fast_block(
340                                combined,
341                                prefix_len + offset,
342                                prefix_len + offset + chunk_size,
343                                &params,
344                                &rep_offsets,
345                                &mut self.hash_table,
346                                &mut self.sequences,
347                            );
348                        }
349                        Strategy::DFast => {
350                            dfast::compress_dfast_block(
351                                combined,
352                                prefix_len + offset,
353                                prefix_len + offset + chunk_size,
354                                &params,
355                                &rep_offsets,
356                                &mut self.hash_table,
357                                &mut self.hash_long,
358                                &mut self.sequences,
359                            );
360                        }
361                    }
362                    if params.force_raw_literals {
363                        block_encoder::encode_compressed_block_raw(
364                            &input[offset..offset + chunk_size],
365                            &self.sequences,
366                            &mut rep_offsets,
367                            is_last,
368                            &mut self.output,
369                            &mut self.workspace,
370                        );
371                    } else {
372                        block_encoder::encode_compressed_block(
373                            &input[offset..offset + chunk_size],
374                            &self.sequences,
375                            &mut rep_offsets,
376                            is_last,
377                            &mut self.output,
378                            &mut self.workspace,
379                        );
380                    }
381                    offset += chunk_size;
382                }
383            }
384        }
385
386        let hash = xxh64(input, 0);
387        let checksum = (hash & 0xFFFF_FFFF) as u32;
388        self.output.extend_from_slice(&checksum.to_le_bytes());
389
390        Ok(self.take_or_borrow_output())
391    }
392
393    fn compress_with_dict_fallback(
394        &mut self,
395        input: &[u8],
396        dict_id: u32,
397        prefix_len: usize,
398    ) -> Result<Cow<'_, [u8]>, CompressError> {
399        let prep = self.prepared.as_ref().unwrap();
400        let rep_offsets = prep.rep_offsets;
401        let prefix = &prep.combined[..prefix_len];
402
403        let params = strategy::level_params_for_size(self.level, input.len())
404            .expect("level validated at construction");
405        compress_core(
406            input,
407            params,
408            Some(dict_id),
409            prefix,
410            rep_offsets,
411            &mut self.hash_table,
412            &mut self.hash_long,
413            &mut self.dict_hash,
414            &mut self.sequences,
415            &mut self.output,
416            &mut self.workspace,
417            &mut self.combined,
418        )?;
419        Ok(self.take_or_borrow_output())
420    }
421
422    fn take_or_borrow_output(&mut self) -> Cow<'_, [u8]> {
423        if self.output.len() >= zrip_core::LARGE_OUTPUT_THRESHOLD {
424            Cow::Owned(core::mem::take(&mut self.output))
425        } else {
426            Cow::Borrowed(&self.output)
427        }
428    }
429}
430
431#[allow(clippy::too_many_arguments, clippy::unnecessary_wraps)]
432fn compress_core(
433    input: &[u8],
434    params: LevelParams,
435    dict_id: Option<u32>,
436    prefix: &[u8],
437    init_rep_offsets: [u32; 3],
438    hash_table: &mut Vec<u32>,
439    hash_long: &mut Vec<u32>,
440    dict_hash: &mut Vec<u32>,
441    sequences: &mut Vec<Sequence>,
442    output: &mut Vec<u8>,
443    workspace: &mut BlockEncodeWorkspace,
444    combined: &mut Vec<u8>,
445) -> Result<(), CompressError> {
446    let hash_size = match params.strategy {
447        Strategy::Fast => 1usize << params.hash_log,
448        Strategy::DFast => 1usize << params.chain_log,
449    };
450    let long_size = 1usize << params.hash_log;
451
452    workspace.prev_huffman = None;
453
454    output.clear();
455    output.reserve(input.len() + 32);
456    write_frame_header(output, input.len(), dict_id);
457
458    if input.is_empty() {
459        block_encoder::encode_raw_block(&[], true, output);
460    } else {
461        let has_prefix = !prefix.is_empty();
462        let mut rep_offsets = init_rep_offsets;
463        let mut offset = 0;
464
465        if hash_table.len() != hash_size {
466            hash_table.resize(hash_size, 0);
467        }
468
469        match params.strategy {
470            Strategy::Fast => {
471                if has_prefix && input.len() <= MAX_BLOCK_SIZE {
472                    if dict_hash.len() != hash_size {
473                        dict_hash.resize(hash_size, 0);
474                    }
475                    fast::compress_fast_with_prefix_reuse(
476                        input,
477                        &params,
478                        &rep_offsets,
479                        prefix,
480                        dict_hash,
481                        hash_table,
482                        sequences,
483                        combined,
484                    );
485                    if params.force_raw_literals {
486                        block_encoder::encode_compressed_block_raw(
487                            input,
488                            sequences,
489                            &mut rep_offsets,
490                            true,
491                            output,
492                            workspace,
493                        );
494                    } else {
495                        block_encoder::encode_compressed_block(
496                            input,
497                            sequences,
498                            &mut rep_offsets,
499                            true,
500                            output,
501                            workspace,
502                        );
503                    }
504                } else if has_prefix {
505                    combined.clear();
506                    combined.reserve(prefix.len() + input.len());
507                    combined.extend_from_slice(prefix);
508                    combined.extend_from_slice(input);
509                    let plen = prefix.len();
510                    fast::prefill_hash_table(combined, plen, params.hash_log, hash_table);
511
512                    while offset < input.len() {
513                        let chunk_size = (input.len() - offset).min(MAX_BLOCK_SIZE);
514                        let is_last = offset + chunk_size >= input.len();
515                        fast::compress_fast_block(
516                            combined,
517                            plen + offset,
518                            plen + offset + chunk_size,
519                            &params,
520                            &rep_offsets,
521                            hash_table,
522                            sequences,
523                        );
524                        if params.force_raw_literals {
525                            block_encoder::encode_compressed_block_raw(
526                                &input[offset..offset + chunk_size],
527                                sequences,
528                                &mut rep_offsets,
529                                is_last,
530                                output,
531                                workspace,
532                            );
533                        } else {
534                            block_encoder::encode_compressed_block(
535                                &input[offset..offset + chunk_size],
536                                sequences,
537                                &mut rep_offsets,
538                                is_last,
539                                output,
540                                workspace,
541                            );
542                        }
543                        offset += chunk_size;
544                    }
545                } else {
546                    hash_table.fill(0);
547                    while offset < input.len() {
548                        let chunk_size = (input.len() - offset).min(MAX_BLOCK_SIZE);
549                        let block_end = offset + chunk_size;
550                        let is_last = block_end >= input.len();
551
552                        if block_looks_incompressible(&input[offset..block_end]) {
553                            block_encoder::encode_raw_block(
554                                &input[offset..block_end],
555                                is_last,
556                                output,
557                            );
558                        } else {
559                            fast::compress_fast_block(
560                                input,
561                                offset,
562                                block_end,
563                                &params,
564                                &rep_offsets,
565                                hash_table,
566                                sequences,
567                            );
568                            if params.force_raw_literals {
569                                block_encoder::encode_compressed_block_raw(
570                                    &input[offset..block_end],
571                                    sequences,
572                                    &mut rep_offsets,
573                                    is_last,
574                                    output,
575                                    workspace,
576                                );
577                            } else {
578                                block_encoder::encode_compressed_block(
579                                    &input[offset..block_end],
580                                    sequences,
581                                    &mut rep_offsets,
582                                    is_last,
583                                    output,
584                                    workspace,
585                                );
586                            }
587                        }
588                        offset = block_end;
589                    }
590                }
591            }
592            Strategy::DFast => {
593                if hash_long.len() != long_size {
594                    hash_long.resize(long_size, 0);
595                }
596                if has_prefix && input.len() <= MAX_BLOCK_SIZE {
597                    dfast::compress_dfast_with_prefix_reuse(
598                        input,
599                        &params,
600                        &rep_offsets,
601                        prefix,
602                        hash_table,
603                        hash_long,
604                        sequences,
605                        combined,
606                    );
607                    block_encoder::encode_compressed_block(
608                        input,
609                        sequences,
610                        &mut rep_offsets,
611                        true,
612                        output,
613                        workspace,
614                    );
615                } else if has_prefix {
616                    combined.clear();
617                    combined.reserve(prefix.len() + input.len());
618                    combined.extend_from_slice(prefix);
619                    combined.extend_from_slice(input);
620                    let plen = prefix.len();
621                    dfast::prefill_hash_tables(
622                        combined,
623                        plen,
624                        params.hash_log,
625                        params.chain_log,
626                        params.min_match,
627                        hash_table,
628                        hash_long,
629                    );
630
631                    while offset < input.len() {
632                        let chunk_size = (input.len() - offset).min(MAX_BLOCK_SIZE);
633                        let is_last = offset + chunk_size >= input.len();
634                        dfast::compress_dfast_block(
635                            combined,
636                            plen + offset,
637                            plen + offset + chunk_size,
638                            &params,
639                            &rep_offsets,
640                            hash_table,
641                            hash_long,
642                            sequences,
643                        );
644                        block_encoder::encode_compressed_block(
645                            &input[offset..offset + chunk_size],
646                            sequences,
647                            &mut rep_offsets,
648                            is_last,
649                            output,
650                            workspace,
651                        );
652                        offset += chunk_size;
653                    }
654                } else {
655                    hash_table.fill(0);
656                    hash_long.fill(0);
657                    while offset < input.len() {
658                        let chunk_size = (input.len() - offset).min(MAX_BLOCK_SIZE);
659                        let block_end = offset + chunk_size;
660                        let is_last = block_end >= input.len();
661
662                        if block_looks_incompressible(&input[offset..block_end]) {
663                            block_encoder::encode_raw_block(
664                                &input[offset..block_end],
665                                is_last,
666                                output,
667                            );
668                        } else {
669                            dfast::compress_dfast_block(
670                                input,
671                                offset,
672                                block_end,
673                                &params,
674                                &rep_offsets,
675                                hash_table,
676                                hash_long,
677                                sequences,
678                            );
679                            block_encoder::encode_compressed_block(
680                                &input[offset..block_end],
681                                sequences,
682                                &mut rep_offsets,
683                                is_last,
684                                output,
685                                workspace,
686                            );
687                        }
688                        offset = block_end;
689                    }
690                }
691            }
692        }
693    }
694
695    let hash = xxh64(input, 0);
696    let checksum = (hash & 0xFFFF_FFFF) as u32;
697    output.extend_from_slice(&checksum.to_le_bytes());
698
699    Ok(())
700}