1use super::constants::*;
7
8pub fn compress(data: &[u8], level: i32) -> Vec<u8> {
19 let mut out = Vec::with_capacity(data.len() + 64);
20 write_frame_header(&mut out, data.len() as u64);
21
22 if data.is_empty() {
23 write_raw_block(&mut out, &[], true);
24 return out;
25 }
26
27 if level <= 0 {
28 let blocks: Vec<&[u8]> = data.chunks(ZSTD_BLOCKSIZE_MAX).collect();
30 let n_blocks = blocks.len();
31 for (i, block) in blocks.iter().enumerate() {
32 if is_rle_block(block) {
33 write_rle_block(&mut out, block[0], block.len(), i == n_blocks - 1);
34 } else {
35 write_raw_block(&mut out, block, i == n_blocks - 1);
36 }
37 }
38 return out;
39 }
40
41 let params = MatchParams::from_level(level);
46 let mut all_sequences = find_matches(data, ¶ms);
47
48 if params.lazy_depth == 0 {
50 let lazy_params = MatchParams {
51 lazy_depth: 1,
52 hash_log: 17,
53 hash_bytes: 5,
54 search_depth: 8,
55 };
56
57 #[cfg(feature = "parallel")]
58 let lazy_seqs = {
59 let lazy_params_clone = lazy_params;
61 rayon::spawn(|| {}); let (_, lazy) = rayon::join(
63 || (), || find_matches_lazy(data, &lazy_params_clone),
65 );
66 lazy
67 };
68 #[cfg(not(feature = "parallel"))]
69 let lazy_seqs = find_matches_lazy(data, &lazy_params);
70
71 let fast_enc = resolve_repeat_offsets(&all_sequences);
72 let lazy_enc = resolve_repeat_offsets(&lazy_seqs);
73 let fast_cost = estimate_seq_cost(&all_sequences, &fast_enc);
74 let lazy_cost = estimate_seq_cost(&lazy_seqs, &lazy_enc);
75 if lazy_cost < fast_cost {
76 all_sequences = lazy_seqs;
77 }
78 }
79
80 let all_sequences = split_long_raw_sequences(all_sequences);
82 let all_encoded = resolve_repeat_offsets(&all_sequences);
83
84 let mut block_ranges: Vec<(usize, usize, usize)> = Vec::new(); let mut seq_start = 0usize;
89 let mut data_pos = 0usize;
90 let mut block_output = 0usize;
91
92 for (i, raw) in all_sequences.iter().enumerate() {
93 let seq_output = raw.ll as usize + raw.ml as usize;
94 if block_output + seq_output > ZSTD_BLOCKSIZE_MAX && i > seq_start {
95 block_ranges.push((seq_start, i, data_pos));
96 seq_start = i;
97 block_output = 0;
98 }
99 block_output += seq_output;
100 data_pos += seq_output;
101 }
102
103 let trailing_lits = data.len() - data_pos;
106 if block_output + trailing_lits > ZSTD_BLOCKSIZE_MAX && block_output > 0 {
107 block_ranges.push((seq_start, all_sequences.len(), data_pos));
109 block_ranges.push((all_sequences.len(), all_sequences.len(), data.len()));
111 } else {
112 block_ranges.push((seq_start, all_sequences.len(), data.len()));
113 }
114
115 let n_blocks = block_ranges.len();
116
117 let mut d_starts = Vec::with_capacity(n_blocks);
119 for &(s_start, _, _) in &block_ranges {
120 if s_start == 0 {
121 d_starts.push(0);
122 } else {
123 let mut p = 0usize;
124 for s in &all_sequences[..s_start] {
125 p += s.ll as usize + s.ml as usize;
126 }
127 d_starts.push(p);
128 }
129 }
130
131 #[cfg(feature = "parallel")]
133 let encode_block_content = |bi: usize| -> (Vec<u8>, Vec<u8>) {
134 let (s_start, s_end, d_end) = block_ranges[bi];
135 let block_seqs = &all_encoded[s_start..s_end];
136 let raw_seqs = &all_sequences[s_start..s_end];
137 let d_start = d_starts[bi];
138 let block_data = &data[d_start..d_end];
139
140 if block_seqs.is_empty() {
141 return (vec![], block_data.to_vec()); }
143
144 let mut literals = Vec::with_capacity(block_data.len());
145 let mut pos = 0usize;
146 for seq in raw_seqs {
147 literals.extend_from_slice(&block_data[pos..pos + seq.ll as usize]);
148 pos += seq.ll as usize + seq.ml as usize;
149 }
150 literals.extend_from_slice(&block_data[pos..]);
151
152 let mut block = Vec::with_capacity(block_data.len());
153 let mut used_huf = false;
154 if !literals.is_empty() && literals.iter().all(|&b| b == literals[0]) {
155 encode_literals_rle(&mut block, literals[0], literals.len());
156 used_huf = true;
157 } else if literals.len() >= 64 {
158 if let Some(huf) = encode_literals_huffman(&literals) {
159 if huf.len() < literals.len() {
160 block.extend_from_slice(&huf);
161 used_huf = true;
162 }
163 }
164 }
165 if !used_huf {
166 encode_literals_raw(&mut block, &literals);
167 }
168 encode_sequences_section(&mut block, block_seqs);
169 (block, block_data.to_vec())
170 };
171
172 #[cfg(feature = "parallel")]
176 let encoded_blocks: Vec<(Vec<u8>, Vec<u8>)> = {
177 use rayon::prelude::*;
178 (0..n_blocks)
179 .into_par_iter()
180 .map(&encode_block_content)
181 .collect()
182 };
183
184 #[cfg(not(feature = "parallel"))]
185 let encoded_blocks: Vec<(Vec<u8>, Vec<u8>)> = {
186 let mut prev_huf: Option<([(u32, u8); 256], u8)> = None;
187 let mut results = Vec::with_capacity(n_blocks);
188 for bi in 0..n_blocks {
189 let (s_start, s_end, d_end) = block_ranges[bi];
190 let block_seqs = &all_encoded[s_start..s_end];
191 let raw_seqs = &all_sequences[s_start..s_end];
192 let d_start = d_starts[bi];
193 let block_data = &data[d_start..d_end];
194
195 if block_seqs.is_empty() {
196 results.push((vec![], block_data.to_vec()));
197 continue;
198 }
199
200 let mut literals = Vec::with_capacity(block_data.len());
201 let mut pos = 0usize;
202 for seq in raw_seqs {
203 literals.extend_from_slice(&block_data[pos..pos + seq.ll as usize]);
204 pos += seq.ll as usize + seq.ml as usize;
205 }
206 literals.extend_from_slice(&block_data[pos..]);
207
208 let mut block = Vec::with_capacity(block_data.len());
209 let mut used_huf = false;
210
211 if !literals.is_empty() && literals.iter().all(|&b| b == literals[0]) {
212 encode_literals_rle(&mut block, literals[0], literals.len());
213 used_huf = true;
214 } else if literals.len() >= 64 {
215 let new_result = encode_literals_huffman(&literals);
216
217 let treeless_result = if let Some((prev_codes, _)) = &prev_huf {
219 if literals.iter().all(|&b| prev_codes[b as usize].1 > 0) {
220 encode_literals_treeless(&literals, prev_codes)
221 } else {
222 None
223 }
224 } else {
225 None
226 };
227
228 let mut used_treeless = false;
229 match (&new_result, &treeless_result) {
230 (Some(_), Some(te))
231 if te.len() <= new_result.as_ref().unwrap().len()
232 && te.len() < literals.len() =>
233 {
234 block.extend_from_slice(te);
235 used_huf = true;
236 used_treeless = true;
237 }
238 (Some(ne), _) if ne.len() < literals.len() => {
239 block.extend_from_slice(ne);
240 used_huf = true;
241 }
242 (None, Some(te)) if te.len() < literals.len() => {
243 block.extend_from_slice(te);
244 used_huf = true;
245 used_treeless = true;
246 }
247 _ => {}
248 }
249
250 if used_huf && !used_treeless {
253 let mut counts = [0u32; 256];
254 let mut ms = 0u8;
255 for &b in &literals {
256 counts[b as usize] += 1;
257 if b > ms {
258 ms = b;
259 }
260 }
261 if let Some((codes, mb)) = build_huffman_codes(&counts, ms as usize) {
262 prev_huf = Some((codes, mb));
263 }
264 }
265 }
266
267 if !used_huf {
268 encode_literals_raw(&mut block, &literals);
269 }
270 encode_sequences_section(&mut block, block_seqs);
271 results.push((block, block_data.to_vec()));
272 }
273 results
274 };
275
276 for (bi, (block, block_data)) in encoded_blocks.iter().enumerate() {
278 let is_last = bi == n_blocks - 1;
279
280 if block.is_empty() {
281 let mut rem = &block_data[..];
283 while !rem.is_empty() {
284 let sz = std::cmp::min(rem.len(), ZSTD_BLOCKSIZE_MAX);
285 let chunk = &rem[..sz];
286 let last = is_last && sz == rem.len();
287 if is_rle_block(chunk) {
288 write_rle_block(&mut out, chunk[0], chunk.len(), last);
289 } else {
290 write_raw_block(&mut out, chunk, last);
291 }
292 rem = &rem[sz..];
293 }
294 continue;
295 }
296
297 if block_data.len() <= ZSTD_BLOCKSIZE_MAX {
298 if is_rle_block(block_data) {
299 write_rle_block(&mut out, block_data[0], block_data.len(), is_last);
300 } else if block.len() < block_data.len() && block.len() <= ZSTD_BLOCKSIZE_MAX {
301 write_compressed_block(&mut out, block, is_last);
302 } else {
303 write_raw_block(&mut out, block_data, is_last);
304 }
305 } else if block.len() < block_data.len() && block.len() <= ZSTD_BLOCKSIZE_MAX {
306 write_compressed_block(&mut out, block, is_last);
307 } else {
308 let mut remaining = &block_data[..];
309 while !remaining.is_empty() {
310 let sz = std::cmp::min(remaining.len(), ZSTD_BLOCKSIZE_MAX);
311 let chunk = &remaining[..sz];
312 let last = is_last && sz == remaining.len();
313 if is_rle_block(chunk) {
314 write_rle_block(&mut out, chunk[0], chunk.len(), last);
315 } else {
316 write_raw_block(&mut out, chunk, last);
317 }
318 remaining = &remaining[sz..];
319 }
320 }
321 }
322
323 out
324}
325
326pub fn compress_to_vec(data: &[u8]) -> Vec<u8> {
328 compress(data, 1)
329}
330
331fn write_frame_header(out: &mut Vec<u8>, content_size: u64) {
336 out.extend_from_slice(&ZSTD_MAGIC.to_le_bytes());
338
339 let (fcs_flag, fcs_bytes) = if content_size <= 255 {
348 (0u8, 1) } else if content_size <= 65535 + 256 {
350 (1u8, 2) } else if content_size <= u32::MAX as u64 {
352 (2u8, 4)
353 } else {
354 (3u8, 8)
355 };
356
357 let single_segment = 1u8; let descriptor = (fcs_flag << 6) | (single_segment << 5);
360 out.push(descriptor);
361
362 match fcs_bytes {
366 1 => out.push(content_size as u8),
367 2 => out.extend_from_slice(&((content_size - 256) as u16).to_le_bytes()),
368 4 => out.extend_from_slice(&(content_size as u32).to_le_bytes()),
369 8 => out.extend_from_slice(&content_size.to_le_bytes()),
370 _ => {}
371 }
372}
373
374fn write_raw_block(out: &mut Vec<u8>, data: &[u8], is_last: bool) {
379 let header = (is_last as u32) | ((BLOCK_TYPE_RAW as u32) << 1) | ((data.len() as u32) << 3);
380 out.extend_from_slice(&header.to_le_bytes()[..3]);
381 out.extend_from_slice(data);
382}
383
384fn write_rle_block(out: &mut Vec<u8>, byte: u8, repeat_count: usize, is_last: bool) {
385 let header = (is_last as u32) | ((BLOCK_TYPE_RLE as u32) << 1) | ((repeat_count as u32) << 3);
386 out.extend_from_slice(&header.to_le_bytes()[..3]);
387 out.push(byte);
388}
389
/// Returns `true` when `data` is non-empty and every byte equals the first one.
fn is_rle_block(data: &[u8]) -> bool {
    match data.split_first() {
        Some((&first, rest)) => rest.iter().all(|&b| b == first),
        None => false,
    }
}
398
399fn write_compressed_block(out: &mut Vec<u8>, compressed: &[u8], is_last: bool) {
400 let header =
401 (is_last as u32) | ((BLOCK_TYPE_COMPRESSED as u32) << 1) | ((compressed.len() as u32) << 3);
402 out.extend_from_slice(&header.to_le_bytes()[..3]);
403 out.extend_from_slice(compressed);
404}
405
/// One match produced by a match finder, with its offset still in raw form.
struct Sequence {
    /// Number of literal bytes that precede the match.
    ll: u32,
    /// Distance in bytes from the match start back to where the copy is taken from.
    off: u32,
    /// Number of bytes copied by the match.
    ml: u32,
}
418
/// A `Sequence` whose offset has been translated to the zstd offset value.
struct EncodedSequence {
    /// Number of literal bytes that precede the match.
    ll: u32,
    /// 1..=3 select a repeat offset; larger values encode `raw offset + 3`.
    of_value: u32,
    /// Number of bytes copied by the match.
    ml: u32,
}
426
/// Tuning knobs for the match finders.
struct MatchParams {
    /// log2 of the primary hash table size.
    hash_log: u32,
    /// Bytes hashed per position by the fast matcher; the lazy matcher caps this at 4.
    hash_bytes: usize,
    /// 0 selects the fast greedy matcher; 1 or more selects the lazy matcher.
    lazy_depth: u32,
    /// Maximum hash-chain candidates examined per position by the lazy matcher.
    search_depth: u32,
}

impl MatchParams {
    /// Maps a compression level to matcher parameters.
    ///
    /// Levels 0-2 use the fast matcher. Higher levels use the lazy matcher with
    /// progressively larger tables and deeper searches. Any level outside 0..=8,
    /// including negative levels, gets the strongest settings.
    fn from_level(level: i32) -> Self {
        let (hash_log, hash_bytes, lazy_depth, search_depth) = match level {
            0..=2 => (14, 7, 0, 4),
            3..=5 => (18, 5, 1, 16),
            6..=8 => (19, 5, 1, 64),
            _ => (20, 5, 1, 256),
        };
        Self {
            hash_log,
            hash_bytes,
            lazy_depth,
            search_depth,
        }
    }
}
472
473fn resolve_repeat_offsets(sequences: &[Sequence]) -> Vec<EncodedSequence> {
485 let mut rep = [1u32, 4, 8]; let mut out = Vec::with_capacity(sequences.len());
487
488 for seq in sequences {
489 let raw_off = seq.off;
490 let of_value;
491
492 if seq.ll > 0 {
493 if raw_off == rep[0] {
495 of_value = 1;
496 } else if raw_off == rep[1] {
498 of_value = 2;
499 rep = [rep[1], rep[0], rep[2]];
501 } else if raw_off == rep[2] {
502 of_value = 3;
503 rep = [rep[2], rep[0], rep[1]];
505 } else {
506 of_value = raw_off + 3;
507 rep = [raw_off, rep[0], rep[1]];
509 }
510 } else {
511 if raw_off == rep[1] {
514 of_value = 1;
515 rep = [rep[1], rep[0], rep[2]];
516 } else if raw_off == rep[2] {
517 of_value = 2;
518 rep = [rep[2], rep[0], rep[1]];
519 } else if raw_off == rep[0].wrapping_sub(1) && rep[0] > 1 {
520 of_value = 3;
521 rep = [rep[0] - 1, rep[0], rep[1]];
522 } else {
523 of_value = raw_off + 3;
524 rep = [raw_off, rep[0], rep[1]];
525 }
526 }
527
528 out.push(EncodedSequence {
529 ll: seq.ll,
530 of_value,
531 ml: seq.ml,
532 });
533 }
534
535 out
536}
537
538fn split_long_raw_sequences(sequences: Vec<Sequence>) -> Vec<Sequence> {
540 let needs_split = sequences
541 .iter()
542 .any(|s| s.ll as usize + s.ml as usize > ZSTD_BLOCKSIZE_MAX);
543 if !needs_split {
544 return sequences;
545 }
546
547 let mut out = Vec::with_capacity(sequences.len() + 16);
548 for seq in sequences {
549 let total = seq.ll as usize + seq.ml as usize;
550 if total <= ZSTD_BLOCKSIZE_MAX {
551 out.push(seq);
552 } else {
553 let max_ml = ZSTD_BLOCKSIZE_MAX.saturating_sub(seq.ll as usize);
555 let ml_first = std::cmp::max(ZSTD_MINMATCH, std::cmp::min(seq.ml as usize, max_ml));
556 out.push(Sequence {
557 ll: seq.ll,
558 off: seq.off,
559 ml: ml_first as u32,
560 });
561 let mut remaining = seq.ml as usize - ml_first;
562 while remaining > ZSTD_MINMATCH {
566 let ml = std::cmp::min(remaining - 1, ZSTD_BLOCKSIZE_MAX - 1);
567 out.push(Sequence {
568 ll: 1,
569 off: seq.off,
570 ml: ml as u32,
571 });
572 remaining -= ml + 1; }
574 }
575 }
576 out
577}
578
579fn find_matches(data: &[u8], params: &MatchParams) -> Vec<Sequence> {
594 if params.lazy_depth == 0 {
595 find_matches_fast(data, params)
596 } else {
597 find_matches_lazy(data, params)
598 }
599}
600
601fn estimate_seq_cost(raw_seqs: &[Sequence], _enc_seqs: &[EncodedSequence]) -> u64 {
604 let mut literal_bytes = 0u64;
605 for seq in raw_seqs {
606 literal_bytes += seq.ll as u64;
607 }
608 raw_seqs.len() as u64 * 20 + literal_bytes * 5
609}
610
/// Greedy single-hash-table match finder, used when `params.lazy_depth == 0`.
///
/// Scans `data` with two interleaved probe positions (`ip0`, `ip1`). Their step grows by
/// one each time `ip0` advances another `K_STEP_INCR` bytes without a match, so
/// incompressible regions are skipped faster. At each step it tries, in this order:
/// 1. the most recent offset (`rep1`) at `ip2`;
/// 2. the hash-table candidate for `ip0`;
/// 3. the hash-table candidate for `ip1`.
/// An accepted match is extended backwards towards `anchor`, and then `rep_chain`
/// emits any immediately following matches at `rep2`.
///
/// Returns sequences with raw byte offsets. Input shorter than `MIN_INPUT` yields no
/// sequences, so all of it is left to the caller as literals.
fn find_matches_fast(data: &[u8], params: &MatchParams) -> Vec<Sequence> {
    const MIN_INPUT: usize = 8;
    if data.len() < MIN_INPUT {
        return vec![];
    }

    let hlog = params.hash_log;
    let hash_size = 1usize << hlog;
    let mls = params.hash_bytes;
    // Each slot holds the latest position hashed there (0 doubles as "empty").
    let mut ht = vec![0u32; hash_size];
    let mut sequences = Vec::new();

    // Probe positions never exceed this, so hash reads (at most 7 bytes) stay in bounds.
    let ilimit = data.len() - MIN_INPUT;
    let mut anchor = 0usize;
    let mut ip0 = 0usize;

    // Most recent and second most recent match offsets (0 = none yet).
    let mut rep1 = 0u32;
    let mut rep2 = 0u32;
    const K_SEARCH_STRENGTH: u32 = 8;
    const K_STEP_INCR: usize = 1 << (K_SEARCH_STRENGTH - 1);
    'outer: loop {
        // Restart with the smallest step after every match.
        let mut step: usize = 2;
        let mut next_step = ip0 + K_STEP_INCR;

        let mut ip1 = ip0 + 1;

        if ip1 > ilimit {
            break;
        }

        let mut h0 = hash_n(&data[ip0..], (hash_size - 1) as u32, mls);

        loop {
            let ip2 = ip0 + step;
            let ip3 = ip1 + step;
            if ip3 > ilimit {
                break 'outer;
            }

            // Read the candidate before overwriting the slot with ip0.
            let match_idx0 = ht[h0] as usize;

            let h1 = hash_n(&data[ip1..], (hash_size - 1) as u32, mls);
            ht[h0] = ip0 as u32;

            // (1) Repeat-offset candidate at ip2.
            if rep1 > 0 && ip2 >= rep1 as usize {
                let rep_cand = ip2 - rep1 as usize;
                if rep_cand + 4 <= data.len()
                    && ip2 + 4 <= data.len()
                    && read32(data, ip2) == read32(data, rep_cand)
                {
                    ht[h1] = ip1 as u32;

                    let mut mlen = 4 + count_match(data, ip2 + 4, rep_cand + 4);
                    let mut start = ip2;
                    let mut mstart = rep_cand;
                    // Extend backwards over bytes not yet covered by a sequence.
                    while start > anchor
                        && mstart > 0
                        && mlen < MAX_MATCH_LEN
                        && data[start - 1] == data[mstart - 1]
                    {
                        start -= 1;
                        mstart -= 1;
                        mlen += 1;
                    }

                    let ll = (start - anchor) as u32;
                    // Reusing rep1 leaves the repeat history unchanged.
                    sequences.push(Sequence {
                        ll,
                        off: rep1,
                        ml: mlen as u32,
                    });
                    ip0 = start + mlen;
                    anchor = ip0;

                    rep_chain(
                        data,
                        &mut ip0,
                        &mut anchor,
                        &mut sequences,
                        &mut rep1,
                        &mut rep2,
                        &mut ht,
                        hlog,
                        mls,
                        ilimit,
                    );
                    continue 'outer;
                }
            }

            // (2) Hash candidate for ip0 (within a 16 MiB window, first 4 bytes equal).
            if match_idx0 < ip0
                && ip0 - match_idx0 <= (1 << 24)
                && match_idx0 + 4 <= data.len()
                && read32(data, ip0) == read32(data, match_idx0)
            {
                ht[h1] = ip1 as u32;
                let match0 = match_idx0;

                let mut mlen = 4 + count_match(data, ip0 + 4, match0 + 4);
                let mut start = ip0;
                let mut mstart = match0;
                // NOTE(review): unlike the repeat path, this backward extension is not
                // capped by MAX_MATCH_LEN, so `mlen` can exceed it; compress() runs
                // split_long_raw_sequences on the result afterwards.
                while start > anchor && mstart > 0 && data[start - 1] == data[mstart - 1] {
                    start -= 1;
                    mstart -= 1;
                    mlen += 1;
                }

                let offset = (start - mstart) as u32;
                rep2 = rep1;
                rep1 = offset;

                let ll = (start - anchor) as u32;
                sequences.push(Sequence {
                    ll,
                    off: offset,
                    ml: mlen as u32,
                });
                ip0 = start + mlen;
                anchor = ip0;

                // Index a position just before the match end so it can be found later.
                if ip0 > 2 && ip0 + MIN_INPUT <= data.len() {
                    ht[hash_n(&data[ip0 - 2..], (hash_size - 1) as u32, mls)] = (ip0 - 2) as u32;
                }

                rep_chain(
                    data,
                    &mut ip0,
                    &mut anchor,
                    &mut sequences,
                    &mut rep1,
                    &mut rep2,
                    &mut ht,
                    hlog,
                    mls,
                    ilimit,
                );
                continue 'outer;
            }

            // (3) Hash candidate for ip1; also pre-compute the hash for the next ip0 (ip2).
            let match_idx1 = ht[h1] as usize;
            h0 = hash_n(&data[ip2..], (hash_size - 1) as u32, mls);
            ht[h1] = ip1 as u32;

            if match_idx1 < ip1
                && ip1 - match_idx1 <= (1 << 24)
                && match_idx1 + 4 <= data.len()
                && read32(data, ip1) == read32(data, match_idx1)
            {
                let match0 = match_idx1;
                let mut mlen = 4 + count_match(data, ip1 + 4, match0 + 4);
                let mut start = ip1;
                let mut mstart = match0;
                while start > anchor && mstart > 0 && data[start - 1] == data[mstart - 1] {
                    start -= 1;
                    mstart -= 1;
                    mlen += 1;
                }

                let offset = (start - mstart) as u32;
                rep2 = rep1;
                rep1 = offset;

                let ll = (start - anchor) as u32;
                sequences.push(Sequence {
                    ll,
                    off: offset,
                    ml: mlen as u32,
                });
                ip0 = start + mlen;
                anchor = ip0;

                if ip0 > 2 && ip0 + MIN_INPUT <= data.len() {
                    ht[hash_n(&data[ip0 - 2..], (hash_size - 1) as u32, mls)] = (ip0 - 2) as u32;
                }

                rep_chain(
                    data,
                    &mut ip0,
                    &mut anchor,
                    &mut sequences,
                    &mut rep1,
                    &mut rep2,
                    &mut ht,
                    hlog,
                    mls,
                    ilimit,
                );
                continue 'outer;
            }

            // No match: advance both probes by `step`.
            ip0 = ip2;
            ip1 = ip3;

            // Accelerate through regions that keep failing to match.
            if ip0 >= next_step {
                step += 1;
                next_step += K_STEP_INCR;
            }
        }
    }

    sequences
}
829
830#[allow(clippy::too_many_arguments)]
832#[inline]
833fn rep_chain(
834 data: &[u8],
835 ip: &mut usize,
836 anchor: &mut usize,
837 sequences: &mut Vec<Sequence>,
838 rep1: &mut u32,
839 rep2: &mut u32,
840 ht: &mut [u32],
841 hlog: u32,
842 mls: usize,
843 _ilimit: usize,
844) {
845 let hash_size = 1usize << hlog;
846 while *rep2 > 0 && *ip + 4 <= data.len() && *ip >= *rep2 as usize {
847 let cand = *ip - *rep2 as usize;
848 if cand + 4 > data.len() || read32(data, *ip) != read32(data, cand) {
849 break;
850 }
851 let mlen = 4 + count_match(data, *ip + 4, cand + 4);
852 std::mem::swap(rep1, rep2);
854
855 if *ip + 8 <= data.len() {
857 ht[hash_n(&data[*ip..], (hash_size - 1) as u32, mls)] = *ip as u32;
858 }
859
860 sequences.push(Sequence {
861 ll: 0,
862 off: *rep1,
863 ml: mlen as u32,
864 });
865 *ip += mlen;
866 *anchor = *ip;
867 }
868}
869
/// Reads the 4 bytes at `data[pos..pos + 4]` as a little-endian `u32`.
///
/// Panics if fewer than 4 bytes are available at `pos`.
#[inline]
fn read32(data: &[u8], pos: usize) -> u32 {
    let word = &data[pos..pos + 4];
    u32::from_le_bytes([word[0], word[1], word[2], word[3]])
}
874
/// Upper bound on match lengths measured by `count_match` (which returns at most
/// `MAX_MATCH_LEN - 4` bytes beyond the 4-byte prefix its callers already compared).
/// The fast matcher's repeat-offset path also uses it to cap backward extension.
const MAX_MATCH_LEN: usize = ZSTD_BLOCKSIZE_MAX;
878
879#[inline]
880fn count_match(data: &[u8], mut a: usize, mut b: usize) -> usize {
881 let start = a;
882 let max_extend = MAX_MATCH_LEN - 4;
883 let limit = std::cmp::min(data.len(), start + max_extend);
884 while a + 8 <= limit && b + 8 <= data.len() {
885 let va = u64::from_le_bytes(data[a..a + 8].try_into().unwrap());
886 let vb = u64::from_le_bytes(data[b..b + 8].try_into().unwrap());
887 if va != vb {
888 return a - start + (va ^ vb).trailing_zeros() as usize / 8;
889 }
890 a += 8;
891 b += 8;
892 }
893 while a < limit && b < data.len() && data[a] == data[b] {
894 a += 1;
895 b += 1;
896 }
897 a - start
898}
899
900fn find_matches_lazy(data: &[u8], params: &MatchParams) -> Vec<Sequence> {
907 if data.len() < 8 {
908 return vec![];
909 }
910
911 let hash_size = 1usize << params.hash_log;
912 let hash_mask = (hash_size - 1) as u32;
913 let long_hash_size = 1usize << std::cmp::min(params.hash_log, 17);
914 let long_hash_mask = (long_hash_size - 1) as u32;
915 let mut ht_short = vec![0u32; hash_size]; let mut ht_long = vec![0u32; long_hash_size]; let mut chain = vec![0u32; data.len()];
918 let mut sequences = Vec::new();
919 let mut anchor = 0usize;
920 let mut ip = 0usize;
921 let mut rep1 = 0u32;
922 let mut rep2 = 0u32;
923 let lazy = params.lazy_depth >= 1;
924
925 while ip + 8 <= data.len() {
926 let rep_match = if rep1 > 0 && ip >= rep1 as usize && ip + 4 <= data.len() {
931 let cand = ip - rep1 as usize;
932 if cand + 4 <= data.len() && read32(data, ip) == read32(data, cand) {
933 let ml = 4 + count_match(data, ip + 4, cand + 4);
934 if ml * 5 > 40 {
943 Some((rep1 as usize, ml))
944 } else {
945 None
946 }
947 } else {
948 None
949 }
950 } else {
951 None
952 };
953
954 let short_match = find_best_at_n(
956 data,
957 ip,
958 &ht_short,
959 &chain,
960 hash_mask,
961 params.search_depth,
962 std::cmp::min(params.hash_bytes, 4),
963 );
964
965 let long_match = if ip + 7 <= data.len() {
967 let lh = hash7(&data[ip..], long_hash_mask);
968 let lidx = ht_long[lh] as usize;
969 ht_long[lh] = ip as u32;
970 if lidx < ip
971 && ip - lidx <= (1 << 24)
972 && lidx + 4 <= data.len()
973 && read32(data, ip) == read32(data, lidx)
974 {
975 let ml = 4 + count_match(data, ip + 4, lidx + 4);
976 Some((ip - lidx, ml))
977 } else {
978 None
979 }
980 } else {
981 None
982 };
983
984 let short_match = short_match.filter(|&(off, ml)| is_match_profitable(ml, off));
987 let long_match = long_match.filter(|&(off, ml)| is_match_profitable(ml, off));
988
989 let best_hash = match (short_match, long_match) {
990 (Some((so, sl)), Some((lo, ll))) => {
991 if ll >= sl + 2 {
992 Some((lo, ll))
993 } else {
994 Some((so, sl))
995 }
996 }
997 (Some(s), None) => Some(s),
998 (None, Some(l)) => Some(l),
999 (None, None) => None,
1000 };
1001
1002 let chosen = match (rep_match, best_hash) {
1003 (Some((roff, rml)), Some((hoff, hml))) => {
1004 let off_bits = 32u32.saturating_sub((hoff as u32).leading_zeros());
1005 if rml + (off_bits as usize / 4) >= hml {
1006 Some((roff, rml))
1007 } else {
1008 Some((hoff, hml))
1009 }
1010 }
1011 (Some(r), None) => Some(r),
1012 (None, Some(h)) => Some(h),
1013 (None, None) => None,
1014 };
1015
1016 if let Some((offset, match_len)) = chosen {
1017 let mut final_off = offset;
1018 let mut final_len = match_len;
1019 let mut final_ip = ip;
1020
1021 if lazy && ip + 1 + 8 <= data.len() {
1023 insert_hash_n(&mut ht_short, &mut chain, data, ip, hash_mask, 4);
1024 let mut next_best = None;
1025 if rep1 > 0 && ip + 1 >= rep1 as usize && ip + 5 <= data.len() {
1027 let c = ip + 1 - rep1 as usize;
1028 if c + 4 <= data.len() && read32(data, ip + 1) == read32(data, c) {
1029 let rl = 4 + count_match(data, ip + 5, c + 4);
1030 if rl > final_len {
1031 next_best = Some((rep1 as usize, rl));
1032 }
1033 }
1034 }
1035 if let Some((off2, len2)) = find_best_at_n(
1037 data,
1038 ip + 1,
1039 &ht_short,
1040 &chain,
1041 hash_mask,
1042 params.search_depth,
1043 4,
1044 ) {
1045 if len2 > final_len + 1 && (next_best.is_none() || len2 > next_best.unwrap().1)
1046 {
1047 next_best = Some((off2, len2));
1048 }
1049 }
1050 if let Some((off2, len2)) = next_best {
1051 if len2 > final_len + 1 {
1052 final_off = off2;
1053 final_len = len2;
1054 final_ip = ip + 1;
1055 }
1056 }
1057 }
1058
1059 let ll = (final_ip - anchor) as u32;
1060 sequences.push(Sequence {
1061 ll,
1062 off: final_off as u32,
1063 ml: final_len as u32,
1064 });
1065
1066 if final_off as u32 != rep1 {
1067 rep2 = rep1;
1068 rep1 = final_off as u32;
1069 }
1070
1071 let end = std::cmp::min(final_ip + final_len, data.len().saturating_sub(4));
1072 for p in ip..end {
1073 insert_hash_n(&mut ht_short, &mut chain, data, p, hash_mask, 4);
1074 if p + 7 <= data.len() {
1075 ht_long[hash7(&data[p..], long_hash_mask)] = p as u32;
1076 }
1077 }
1078
1079 ip = final_ip + final_len;
1080 anchor = ip;
1081
1082 while rep2 > 0 && ip + 4 <= data.len() && ip >= rep2 as usize {
1084 let cand = ip - rep2 as usize;
1085 if cand + 4 > data.len() || read32(data, ip) != read32(data, cand) {
1086 break;
1087 }
1088 let mlen = 4 + count_match(data, ip + 4, cand + 4);
1089 if mlen < ZSTD_MINMATCH {
1090 break;
1091 }
1092
1093 std::mem::swap(&mut rep2, &mut rep1);
1094 sequences.push(Sequence {
1095 ll: 0,
1096 off: rep1,
1097 ml: mlen as u32,
1098 });
1099
1100 let end2 = std::cmp::min(ip + mlen, data.len().saturating_sub(4));
1101 for p in ip..end2 {
1102 insert_hash_n(&mut ht_short, &mut chain, data, p, hash_mask, 4);
1103 if p + 7 <= data.len() {
1104 ht_long[hash7(&data[p..], long_hash_mask)] = p as u32;
1105 }
1106 }
1107 ip += mlen;
1108 anchor = ip;
1109 }
1110 } else {
1111 insert_hash_n(&mut ht_short, &mut chain, data, ip, hash_mask, 4);
1112 if ip + 7 <= data.len() {
1113 ht_long[hash7(&data[ip..], long_hash_mask)] = ip as u32;
1114 }
1115 ip += 1;
1116 }
1117 }
1118
1119 sequences
1120}
1121
/// Heuristic: is a match of `match_len` bytes at `offset` worth encoding?
///
/// Credits 5 bits per matched byte against a fixed 17-bit sequence overhead, plus the
/// bit length of the offset (offsets 0 and 1 count as 1 bit), plus an 8-bit margin.
#[inline]
fn is_match_profitable(match_len: usize, offset: usize) -> bool {
    // OR-ing in 1 leaves the bit length of offsets >= 2 unchanged and maps 0 and 1 to 1.
    let off_code = 32 - ((offset as u32) | 1).leading_zeros();
    let cost_bits = 17 + off_code + 8;
    match_len as u32 * 5 > cost_bits
}
1146
1147fn insert_hash_n(
1151 hash_table: &mut [u32],
1152 chain: &mut [u32],
1153 data: &[u8],
1154 pos: usize,
1155 mask: u32,
1156 hash_bytes: usize,
1157) {
1158 if pos + hash_bytes > data.len() {
1159 return;
1160 }
1161 let h = hash_n(&data[pos..], mask, hash_bytes);
1162 chain[pos] = hash_table[h];
1163 hash_table[h] = pos as u32;
1164}
1165
1166#[inline]
1168fn hash_n(data: &[u8], mask: u32, n: usize) -> usize {
1169 match n {
1170 5 => hash5(data, mask),
1171 6 => hash6(data, mask),
1172 7 => hash7(data, mask),
1173 _ => hash4(data, mask),
1174 }
1175}
1176
/// Hashes the first 6 bytes of `data` into `0..=mask`.
///
/// Multiplicative hashing: bit k of the product depends only on input bits 0..=k. So
/// the bytes are placed in the top of the 64-bit word, and the result is taken from
/// the top of the product, where every input bit has an effect. `mask` is expected to
/// be `2^k - 1`; the result is always `<= mask`, and 0 when `mask == 0`.
#[inline]
fn hash6(data: &[u8], mask: u32) -> usize {
    let v = u64::from_le_bytes([0, 0, data[0], data[1], data[2], data[3], data[4], data[5]]);
    let h = v
        .wrapping_mul(227718039650203u64)
        .checked_shr(32 + mask.leading_zeros())
        .unwrap_or(0);
    h as usize & mask as usize
}
1183
/// Hashes the first 7 bytes of `data` into `0..=mask`.
///
/// Multiplicative hashing: bit k of the product depends only on input bits 0..=k. So
/// the bytes are placed in the top of the 64-bit word, and the result is taken from
/// the top of the product, where every input byte (including the 7th) has an effect.
/// `mask` is expected to be `2^k - 1`; the result is always `<= mask`, and 0 when
/// `mask == 0`.
#[inline]
fn hash7(data: &[u8], mask: u32) -> usize {
    let v = u64::from_le_bytes([
        0, data[0], data[1], data[2], data[3], data[4], data[5], data[6],
    ]);
    let h = v
        .wrapping_mul(58295818150454627u64)
        .checked_shr(32 + mask.leading_zeros())
        .unwrap_or(0);
    h as usize & mask as usize
}
1193
1194fn find_best_at_n(
1196 data: &[u8],
1197 pos: usize,
1198 hash_table: &[u32],
1199 chain: &[u32],
1200 mask: u32,
1201 max_depth: u32,
1202 hash_bytes: usize,
1203) -> Option<(usize, usize)> {
1204 if pos + hash_bytes > data.len() {
1205 return None;
1206 }
1207 let h = hash_n(&data[pos..], mask, hash_bytes);
1208 let mut candidate = hash_table[h] as usize;
1209 let mut best_len = ZSTD_MINMATCH - 1;
1210 let mut best_off = 0;
1211
1212 for _ in 0..max_depth {
1213 if candidate >= pos || pos - candidate > (1 << 24) {
1214 break;
1215 }
1216 if candidate + ZSTD_MINMATCH > data.len() {
1217 break;
1218 }
1219
1220 if data[candidate..candidate + 4] == data[pos..pos + 4] {
1222 let max_ml = std::cmp::min(ZSTD_BLOCKSIZE_MAX, data.len() - pos);
1224 let cand_max = std::cmp::min(max_ml, data.len() - candidate);
1225 let ml = common_prefix_len(
1226 &data[candidate..candidate + cand_max],
1227 &data[pos..pos + cand_max],
1228 );
1229 if ml > best_len {
1230 best_len = ml;
1231 best_off = pos - candidate;
1232 }
1233 }
1234
1235 let next = chain[candidate] as usize;
1236 if next >= candidate {
1237 break;
1238 }
1239 candidate = next;
1240 }
1241
1242 if best_len >= ZSTD_MINMATCH {
1243 Some((best_off, best_len))
1244 } else {
1245 None
1246 }
1247}
1248
/// Returns the length of the longest common prefix of `a` and `b`.
///
/// Compares 8 bytes at a time (little-endian words, so the lowest differing byte is the
/// first mismatch), then finishes the tail byte by byte.
#[inline]
fn common_prefix_len(a: &[u8], b: &[u8]) -> usize {
    let limit = std::cmp::min(a.len(), b.len());
    let (a, b) = (&a[..limit], &b[..limit]);

    let mut matched = 0usize;
    for (wa, wb) in a.chunks_exact(8).zip(b.chunks_exact(8)) {
        let diff = u64::from_le_bytes(wa.try_into().unwrap())
            ^ u64::from_le_bytes(wb.try_into().unwrap());
        if diff != 0 {
            return matched + (diff.trailing_zeros() / 8) as usize;
        }
        matched += 8;
    }
    matched
        + a[matched..]
            .iter()
            .zip(&b[matched..])
            .take_while(|(x, y)| x == y)
            .count()
}
1267
/// Hashes the first 5 bytes of `data` into `0..=mask`.
///
/// Multiplicative hashing: bit k of the product depends only on input bits 0..=k. So
/// the bytes are placed in the top of the 64-bit word, and the result is taken from
/// the top of the product, where every input bit has an effect. `mask` is expected to
/// be `2^k - 1`; the result is always `<= mask`, and 0 when `mask == 0`.
#[inline]
fn hash5(data: &[u8], mask: u32) -> usize {
    let v = u64::from_le_bytes([0, 0, 0, data[0], data[1], data[2], data[3], data[4]]);
    let h = v
        .wrapping_mul(889523592379u64)
        .checked_shr(32 + mask.leading_zeros())
        .unwrap_or(0);
    h as usize & mask as usize
}
1275
/// Hashes the first 4 bytes of `data` into `0..=mask` (Fibonacci hashing).
///
/// The well-mixed bits of `v * 0x9E3779B1` are the high ones: bit k of the product
/// depends only on input bits 0..=k. So the result is the top bits of the product,
/// which depend on all four bytes. `mask` is expected to be `2^k - 1`; the result is
/// always `<= mask`, and 0 when `mask == 0`.
#[inline]
fn hash4(data: &[u8], mask: u32) -> usize {
    let v = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
    let h = v
        .wrapping_mul(0x9E3779B1)
        .checked_shr(mask.leading_zeros())
        .unwrap_or(0);
    (h & mask) as usize
}
1282
1283#[cfg(not(feature = "parallel"))]
1291fn encode_literals_treeless(literals: &[u8], prev_codes: &[(u32, u8); 256]) -> Option<Vec<u8>> {
1292 let use_4 = literals.len() >= 1024;
1293 let streams = if use_4 {
1294 encode_huf_4streams(literals, prev_codes)
1295 } else {
1296 encode_huf_1stream(literals, prev_codes)
1297 };
1298 let regen = literals.len();
1299 let comp = streams.len();
1300 if comp >= regen {
1301 return None;
1302 }
1303 let lh_size = 3 + (regen >= 1024) as usize + (regen >= 16384) as usize;
1304 let mut out = Vec::with_capacity(lh_size + comp);
1305 let htype = LIT_TYPE_TREELESS as u32;
1306 match lh_size {
1307 3 => {
1308 let sf = if use_4 { 1u32 } else { 0 };
1309 out.extend_from_slice(
1310 &(htype | (sf << 2) | ((regen as u32) << 4) | ((comp as u32) << 14)).to_le_bytes()
1311 [..3],
1312 );
1313 }
1314 4 => out.extend_from_slice(
1315 &(htype | (2u32 << 2) | ((regen as u32) << 4) | ((comp as u32) << 18)).to_le_bytes()
1316 [..4],
1317 ),
1318 _ => {
1319 let v = htype | (3u32 << 2) | ((regen as u32) << 4) | ((comp as u32) << 22);
1320 out.extend_from_slice(&v.to_le_bytes()[..4]);
1321 out.push((comp >> 10) as u8);
1322 }
1323 }
1324 out.extend_from_slice(&streams);
1325 Some(out)
1326}
1327
1328fn encode_literals_huffman(literals: &[u8]) -> Option<Vec<u8>> {
1329 let mut counts = [0u32; 256];
1331 let mut max_sym = 0u8;
1332 for &b in literals {
1333 counts[b as usize] += 1;
1334 if b > max_sym {
1335 max_sym = b;
1336 }
1337 }
1338 let n_used = counts.iter().filter(|&&c| c > 0).count();
1339 if n_used < 2 {
1340 return None;
1341 }
1342
1343 let (codes, max_bits) = build_huffman_codes(&counts, max_sym as usize)?;
1345
1346 let tree_desc = encode_huffman_tree(&codes, max_bits, max_sym as usize);
1348 if tree_desc.is_empty() {
1349 return None;
1350 }
1351
1352 let use_4 = literals.len() >= 1024;
1354 let streams = if use_4 {
1355 encode_huf_4streams(literals, &codes)
1356 } else {
1357 encode_huf_1stream(literals, &codes)
1358 };
1359
1360 let regen = literals.len();
1361 let comp = tree_desc.len() + streams.len();
1362 let lh_size = 3 + (regen >= 1024) as usize + (regen >= 16384) as usize;
1363
1364 let mut out = Vec::with_capacity(lh_size + comp);
1365 let htype = LIT_TYPE_COMPRESSED as u32;
1366
1367 match lh_size {
1368 3 => {
1369 let sf = if use_4 { 1u32 } else { 0u32 };
1371 let lhc = htype | (sf << 2) | ((regen as u32) << 4) | ((comp as u32) << 14);
1372 out.extend_from_slice(&lhc.to_le_bytes()[..3]);
1373 }
1374 4 => {
1375 let lhc = htype | (2u32 << 2) | ((regen as u32) << 4) | ((comp as u32) << 18);
1376 out.extend_from_slice(&lhc.to_le_bytes()[..4]);
1377 }
1378 _ => {
1379 let lhc = htype | (3u32 << 2) | ((regen as u32) << 4) | ((comp as u32) << 22);
1380 out.extend_from_slice(&lhc.to_le_bytes()[..4]);
1381 out.push((comp >> 10) as u8);
1382 }
1383 }
1384
1385 out.extend_from_slice(&tree_desc);
1386 out.extend_from_slice(&streams);
1387 Some(out)
1388}
1389
/// Builds length-limited canonical Huffman codes for the symbols `0..=max_sym`
/// that have a non-zero entry in `counts`.
///
/// Returns `Some((codes, max_bits))`, where `codes[s]` is `(code_value, code_length)`
/// for symbol `s` (`(0, 0)` when `s` is unused) and `max_bits` is the longest code
/// length. Returns `None` when fewer than two symbols are used, when the
/// length-limiting step cannot rebalance the lengths, or when the final lengths do
/// not satisfy the Kraft equality (i.e. do not form a complete prefix code).
fn build_huffman_codes(counts: &[u32; 256], max_sym: usize) -> Option<([(u32, u8); 256], u8)> {
    // Upper bound enforced on every code length.
    const MAX_BITS: u8 = 11;

    // Leaves as (count, symbol), most frequent first. `sort_by` is stable, so
    // symbols with equal counts stay in ascending symbol order.
    let mut syms: Vec<(u32, u8)> = (0..=max_sym)
        .filter(|&s| counts[s] > 0)
        .map(|s| (counts[s], s as u8))
        .collect();
    syms.sort_by(|a, b| b.0.cmp(&a.0));
    let n = syms.len();
    // Zero or one used symbol cannot be Huffman-coded.
    if n < 2 {
        return None;
    }

    // Node arrays: indices `0..n` are the leaves in `syms` order; indices `n..`
    // are internal nodes in order of creation. Unfilled internal slots carry a
    // large sentinel count.
    let mut node_count = vec![0u64; 2 * n];
    let mut node_parent = vec![0u32; 2 * n];
    let mut node_nbits = vec![0u8; 2 * n];
    for i in 0..n {
        node_count[i] = syms[i].0 as u64;
    }
    for i in n..2 * n {
        node_count[i] = u64::MAX / 2;
    }

    // Two-queue construction: `low_s` walks the leaves from the least frequent
    // end (index n-1 downward), `low_n` walks the internal nodes in creation
    // order, and `next_node` is the next internal slot to fill.
    let mut low_s = n as i32 - 1;
    let mut low_n = n;
    let mut next_node = n;

    // Returns the index of the smaller queue head and advances that queue. A leaf
    // is taken only when its count is strictly smaller, so ties go to the
    // internal node. Returns `usize::MAX` when both queues are empty.
    let pick_smallest =
        |node_count: &[u64], low_s: &mut i32, low_n: &mut usize, next_node: usize| -> usize {
            if *low_s >= 0
                && (*low_n >= next_node || node_count[*low_s as usize] < node_count[*low_n])
            {
                let r = *low_s as usize;
                *low_s -= 1;
                r
            } else if *low_n < next_node {
                let r = *low_n;
                *low_n += 1;
                r
            } else {
                usize::MAX
            }
        };

    // Repeatedly merge the two smallest nodes until all n-1 internal nodes exist.
    while next_node < 2 * n - 1 {
        let n1 = pick_smallest(&node_count, &mut low_s, &mut low_n, next_node);
        let n2 = pick_smallest(&node_count, &mut low_s, &mut low_n, next_node);
        if n1 == usize::MAX || n2 == usize::MAX {
            break;
        }
        node_count[next_node] = node_count[n1] + node_count[n2];
        node_parent[n1] = next_node as u32;
        node_parent[n2] = next_node as u32;
        next_node += 1;
    }
    let root = next_node - 1;

    // Depths. A parent is always created after its children and so has a higher
    // index. Walking internal nodes in descending index order therefore sets a
    // parent's depth before any child reads it.
    node_nbits[root] = 0;
    for i in (n..=root).rev() {
        if i < root {
            node_nbits[i] = node_nbits[node_parent[i] as usize] + 1;
        }
    }
    // A leaf's depth is its code length.
    for i in 0..n {
        node_nbits[i] = node_nbits[node_parent[i] as usize] + 1;
    }

    let largest_bits = *node_nbits[..n].iter().max().unwrap_or(&0);

    // Length limiting, structured like zstd's `HUF_setMaxHeight`. The clamp loop
    // below stops at the first leaf (scanning from the end) that already fits, so
    // it assumes code lengths never decrease along `syms` order.
    if largest_bits > MAX_BITS {
        let target = MAX_BITS;

        // Clamp every over-long leaf to `target`. Accumulate the Kraft-sum excess
        // this creates, in units of 2^-largest_bits.
        let base_cost = 1i32 << (largest_bits - target);
        let mut total_cost = 0i32;
        let mut last_non_null = n - 1;
        while node_nbits[last_non_null] > target {
            total_cost += base_cost - (1i32 << (largest_bits - node_nbits[last_non_null]));
            node_nbits[last_non_null] = target;
            if last_non_null == 0 {
                break;
            }
            last_non_null -= 1;
        }
        // Rescale the excess to units of 2^-target.
        total_cost >>= largest_bits - target;

        // `rank_last[k]` is the index of the last (least frequent) leaf whose
        // length is `target - k`, or NO_SYMBOL when there is none.
        const NO_SYMBOL: u32 = 0xF0F0F0F0;
        let mut rank_last = [NO_SYMBOL; 16];
        {
            let mut current_bits = target;
            for pos in (0..=last_non_null).rev() {
                if node_nbits[pos] >= current_bits {
                    continue;
                }
                current_bits = node_nbits[pos];
                rank_last[(target - current_bits) as usize] = pos as u32;
            }
        }

        // Repay the excess by lengthening shorter codes. Lengthening a leaf of
        // length `target - k` by one bit frees 2^(k-1) units.
        while total_cost > 0 {
            let mut n_bits_to_decrease = 32 - (total_cost as u32).leading_zeros();
            if n_bits_to_decrease > largest_bits as u32 - target as u32 + 1 {
                n_bits_to_decrease = largest_bits as u32 - target as u32 + 1;
            }

            // Step down a rank while the rank-k candidate's count exceeds twice
            // the rank-(k-1) candidate's count.
            while n_bits_to_decrease > 1 {
                let high_pos = rank_last[n_bits_to_decrease as usize];
                let low_pos = rank_last[n_bits_to_decrease as usize - 1];
                if high_pos == NO_SYMBOL {
                    n_bits_to_decrease -= 1;
                    continue;
                }
                if low_pos == NO_SYMBOL {
                    break;
                }
                let high_total = syms[high_pos as usize].0;
                let low_total = 2 * syms[low_pos as usize].0;
                if high_total <= low_total {
                    break;
                }
                n_bits_to_decrease -= 1;
            }

            // Move up to the next rank that actually has a candidate. Give up if
            // none is left; the `total_cost != 0` check below then rejects the result.
            while n_bits_to_decrease as usize <= 14
                && rank_last[n_bits_to_decrease as usize] == NO_SYMBOL
            {
                n_bits_to_decrease += 1;
            }
            if n_bits_to_decrease as usize > 14
                || rank_last[n_bits_to_decrease as usize] == NO_SYMBOL
            {
                break;
            }

            total_cost -= 1i32 << (n_bits_to_decrease - 1);
            let pos = rank_last[n_bits_to_decrease as usize] as usize;
            node_nbits[pos] += 1;

            // The lengthened leaf now belongs to the next rank down. Record it
            // there if that rank was empty.
            if rank_last[n_bits_to_decrease as usize - 1] == NO_SYMBOL {
                rank_last[n_bits_to_decrease as usize - 1] = rank_last[n_bits_to_decrease as usize];
            }

            // Move this rank's pointer to the previous leaf, or clear it when that
            // leaf has a different length.
            if rank_last[n_bits_to_decrease as usize] == 0 {
                rank_last[n_bits_to_decrease as usize] = NO_SYMBOL;
            } else {
                let prev = rank_last[n_bits_to_decrease as usize] - 1;
                rank_last[n_bits_to_decrease as usize] = prev;
                if node_nbits[prev as usize] != target - n_bits_to_decrease as u8 {
                    rank_last[n_bits_to_decrease as usize] = NO_SYMBOL;
                }
            }
        }

        // Over-repaid: shorten leaves that are currently at `target` by one bit.
        // Each shortening gives back one unit.
        while total_cost < 0 {
            if rank_last[1] == NO_SYMBOL {
                let mut p = last_non_null;
                while p > 0 && node_nbits[p] == target {
                    p -= 1;
                }
                if p + 1 < n && node_nbits[p + 1] == target {
                    node_nbits[p + 1] -= 1;
                    rank_last[1] = (p + 1) as u32;
                    total_cost += 1;
                } else {
                    break;
                }
            } else {
                let next = rank_last[1] as usize + 1;
                if next < n && node_nbits[next] == target {
                    node_nbits[next] -= 1;
                    rank_last[1] += 1;
                    total_cost += 1;
                } else {
                    rank_last[1] = NO_SYMBOL;
                }
            }
        }

        // Could not balance exactly: zero every length so the function returns
        // `None` below.
        if total_cost != 0 {
            for i in 0..n {
                node_nbits[i] = 0;
            }
        }
    }

    // Scatter the leaf lengths back into symbol order.
    let mut lengths = [0u8; 256];
    for i in 0..n {
        lengths[syms[i].1 as usize] = node_nbits[i];
    }

    // All-zero lengths mean the limiter gave up.
    let max_bits = *lengths.iter().max().unwrap_or(&0);
    if max_bits == 0 {
        return None;
    }

    // Kraft equality: the lengths must exactly fill a 2^max_bits decoding table.
    let kraft: u64 = (0..=max_sym)
        .filter(|&s| lengths[s] > 0)
        .map(|s| 1u64 << (max_bits - lengths[s]))
        .sum();
    if kraft != (1u64 << max_bits) {
        return None;
    }

    // Number of codes of each length.
    let mut nb_per_rank = [0u32; 16];
    for &l in &lengths {
        if l > 0 {
            nb_per_rank[l as usize] += 1;
        }
    }

    // `rank_indexes[b]` is the first slot, in a 2^max_bits-entry table, taken by
    // codes of length `b`. The longest codes start at slot 0 and shorter lengths
    // follow.
    let mut rank_indexes = [0u32; 16];
    rank_indexes[max_bits as usize] = 0;
    for bits in (1..=max_bits as usize).rev() {
        rank_indexes[bits - 1] =
            rank_indexes[bits] + nb_per_rank[bits] * (1u32 << (max_bits as usize - bits));
    }

    // First code value of each length: that rank's first slot, shifted down to
    // `bits` bits.
    let mut next_code = [0u32; 16];
    for bits in 1..=max_bits as usize {
        next_code[bits] = rank_indexes[bits] >> (max_bits as usize - bits);
    }

    // Within one length, code values are handed out in ascending symbol order.
    let mut codes = [(0u32, 0u8); 256];
    for s in 0..=max_sym {
        if lengths[s] > 0 {
            codes[s] = (next_code[lengths[s] as usize], lengths[s]);
            next_code[lengths[s] as usize] += 1;
        }
    }

    Some((codes, max_bits))
}
1666
1667fn encode_huffman_tree(codes: &[(u32, u8); 256], max_bits: u8, max_sym: usize) -> Vec<u8> {
1668 if max_bits == 0 {
1669 return vec![];
1670 }
1671 let mut weights: Vec<u8> = (0..=max_sym)
1672 .map(|s| {
1673 if codes[s].1 > 0 {
1674 max_bits + 1 - codes[s].1
1675 } else {
1676 0
1677 }
1678 })
1679 .collect();
1680 while weights.last() == Some(&0) && weights.len() > 1 {
1681 weights.pop();
1682 }
1683 if !weights.is_empty() {
1684 weights.pop();
1685 } if weights.is_empty() || weights.len() > 255 {
1687 return vec![];
1688 }
1689
1690 if weights.iter().any(|&w| w > 12) {
1692 return vec![];
1693 }
1694
1695 let num = weights.len();
1696
1697 if num <= 128 {
1698 let mut desc = Vec::with_capacity(1 + num.div_ceil(2));
1700 desc.push((num as u8) + 127);
1701 for pair in weights.chunks(2) {
1702 let w0 = pair[0];
1703 let w1 = if pair.len() > 1 { pair[1] } else { 0 };
1704 desc.push((w0 << 4) | (w1 & 0x0F));
1705 }
1706 desc
1707 } else {
1708 let fse_result = encode_weights_fse(&weights);
1710 match fse_result {
1711 Some(compressed) if compressed.len() < 127 => {
1712 let header_byte = compressed.len() as u8;
1713 let verify = crate::decode::decode_huf_weights_from_fse(&compressed, header_byte);
1715 if let Ok(ref dw) = verify {
1716 if *dw == weights {
1717 let mut desc = Vec::with_capacity(1 + compressed.len());
1718 desc.push(header_byte);
1719 desc.extend_from_slice(&compressed);
1720 return desc;
1721 }
1722 }
1723 if num <= 128 {
1725 let mut desc = Vec::with_capacity(1 + num.div_ceil(2));
1726 desc.push((num as u8) + 127);
1727 for pair in weights.chunks(2) {
1728 let w0 = pair[0];
1729 let w1 = if pair.len() > 1 { pair[1] } else { 0 };
1730 desc.push((w0 << 4) | (w1 & 0x0F));
1731 }
1732 desc
1733 } else {
1734 vec![]
1735 }
1736 }
1737 _ => {
1738 if num <= 128 {
1740 let mut desc = Vec::with_capacity(1 + num.div_ceil(2));
1741 desc.push((num as u8) + 127);
1742 for pair in weights.chunks(2) {
1743 let w0 = pair[0];
1744 let w1 = if pair.len() > 1 { pair[1] } else { 0 };
1745 desc.push((w0 << 4) | (w1 & 0x0F));
1746 }
1747 desc
1748 } else {
1749 vec![] }
1751 }
1752 }
1753 }
1754}
1755
1756fn encode_weights_fse(weights: &[u8]) -> Option<Vec<u8>> {
1761 let mut counts = [0u32; 13];
1762 let mut max_w = 0u8;
1763 for &w in weights {
1764 counts[w as usize] += 1;
1765 if w > max_w {
1766 max_w = w;
1767 }
1768 }
1769 if max_w == 0 {
1770 return None;
1771 }
1772
1773 let table_log = 6u32;
1774 let table_size = 1u32 << table_log;
1775 let total = weights.len() as u32;
1776
1777 let mut norm = [0i16; 13];
1779 let mut dist = 0u32;
1780 for s in 0..=max_w as usize {
1781 if counts[s] == 0 {
1782 continue;
1783 }
1784 norm[s] = std::cmp::max(
1785 1,
1786 (counts[s] as u64 * table_size as u64 / total as u64) as i16,
1787 );
1788 dist += norm[s] as u32;
1789 }
1790 while dist > table_size {
1791 for s in 0..=max_w as usize {
1792 if norm[s] > 1 {
1793 norm[s] -= 1;
1794 dist -= 1;
1795 break;
1796 }
1797 }
1798 }
1799 while dist < table_size {
1800 let best = (0..=max_w as usize).max_by_key(|&s| counts[s]).unwrap_or(0);
1801 norm[best] += 1;
1802 dist += 1;
1803 }
1804
1805 let fse = super::fse::FseCTable::build(&norm, max_w as usize, table_log);
1806
1807 let mut hdr = Vec::with_capacity(16);
1814 let mut bb: u64 = (table_log - 5) as u64;
1815 let mut bp = 4u32;
1816 let prob_sum = table_size;
1817 let mut counter = 0u32;
1818
1819 let mut s = 0usize;
1820 while s <= max_w as usize && counter < prob_sum {
1821 let prob = norm[s] as i32;
1822 let value = (prob + 1) as u32;
1823
1824 let max_remaining = prob_sum - counter + 1;
1825 let bits_to_read = 32 - max_remaining.leading_zeros();
1826 let low_threshold = ((1u32 << bits_to_read) - 1) - max_remaining;
1827 let mask = (1u32 << (bits_to_read - 1)) - 1;
1828
1829 if value < low_threshold {
1830 bb |= (value as u64) << bp;
1831 bp += bits_to_read - 1;
1832 } else if value <= mask {
1833 bb |= (value as u64) << bp;
1834 bp += bits_to_read;
1835 } else {
1836 bb |= ((value + low_threshold) as u64) << bp;
1837 bp += bits_to_read;
1838 }
1839 while bp >= 8 {
1840 hdr.push(bb as u8);
1841 bb >>= 8;
1842 bp -= 8;
1843 }
1844
1845 if prob > 0 {
1846 counter += prob as u32;
1847 } else if prob == -1 {
1848 counter += 1;
1849 }
1850
1851 if prob == 0 {
1852 let mut repeat = 0u32;
1853 while s + 1 + repeat as usize <= max_w as usize
1854 && norm[s + 1 + repeat as usize] == 0
1855 && repeat < 3
1856 {
1857 repeat += 1;
1858 }
1859 bb |= (repeat as u64) << bp;
1860 bp += 2;
1861 while bp >= 8 {
1862 hdr.push(bb as u8);
1863 bb >>= 8;
1864 bp -= 8;
1865 }
1866 s += repeat as usize;
1867 while repeat == 3 {
1868 repeat = 0;
1869 while s + 1 + repeat as usize <= max_w as usize
1870 && norm[s + 1 + repeat as usize] == 0
1871 && repeat < 3
1872 {
1873 repeat += 1;
1874 }
1875 bb |= (repeat as u64) << bp;
1876 bp += 2;
1877 while bp >= 8 {
1878 hdr.push(bb as u8);
1879 bb >>= 8;
1880 bp -= 8;
1881 }
1882 s += repeat as usize;
1883 }
1884 }
1885 s += 1;
1886 }
1887 if bp > 0 {
1888 hdr.push(bb as u8);
1889 }
1890
1891 let ts = table_size as usize;
1894
1895 let mut dec_symbol = vec![0u8; ts];
1897 let mut dec_baseline = vec![0u32; ts];
1898 let mut dec_numbits = vec![0u8; ts];
1899
1900 let mut neg_idx = ts;
1902 for s in 0..=max_w as usize {
1903 if norm[s] == -1 {
1904 neg_idx -= 1;
1905 dec_symbol[neg_idx] = s as u8;
1906 dec_baseline[neg_idx] = 0;
1907 dec_numbits[neg_idx] = table_log as u8;
1908 }
1909 }
1910
1911 let mut pos = 0usize;
1913 for s in 0..=max_w as usize {
1914 if norm[s] <= 0 {
1915 continue;
1916 }
1917 for _ in 0..norm[s] {
1918 dec_symbol[pos] = s as u8;
1919 pos += (ts >> 1) + (ts >> 3) + 3;
1920 pos &= ts - 1;
1921 while pos >= neg_idx {
1922 pos += (ts >> 1) + (ts >> 3) + 3;
1923 pos &= ts - 1;
1924 }
1925 }
1926 }
1927
1928 let stream1: Vec<u8> = weights.iter().step_by(2).copied().collect();
1932 let stream2: Vec<u8> = weights.iter().skip(1).step_by(2).copied().collect();
1933 let len1 = stream1.len();
1934 let len2 = stream2.len();
1935
1936 let mut st1 = fse.init_state(*stream1.last().unwrap() as usize);
1938 let mut st2 = if len2 > 0 {
1939 fse.init_state(*stream2.last().unwrap() as usize)
1940 } else {
1941 0
1942 };
1943
1944 let mut bw = super::bitstream::BackwardBitWriter::new();
1945
1946 let _max_idx = std::cmp::max(len1, len2);
1950 let max_encode = std::cmp::max(len1.saturating_sub(1), len2.saturating_sub(1));
1954 for i in (0..max_encode).rev() {
1955 if i < len2.saturating_sub(1) {
1956 let (bits, nb, ns) = fse.encode_symbol(st2, stream2[i] as usize);
1957 bw.add_bits(bits as u64, nb);
1958 bw.flush_bits();
1959 st2 = ns;
1960 }
1961 if i < len1.saturating_sub(1) {
1962 let (bits, nb, ns) = fse.encode_symbol(st1, stream1[i] as usize);
1963 bw.add_bits(bits as u64, nb);
1964 bw.flush_bits();
1965 st1 = ns;
1966 }
1967 }
1968
1969 bw.add_bits((st2 - table_size) as u64, table_log);
1971 bw.flush_bits();
1972 bw.add_bits((st1 - table_size) as u64, table_log);
1973 bw.flush_bits();
1974
1975 let bitstream = bw.finish();
1976 let mut out = hdr;
1977 out.extend_from_slice(&bitstream);
1978 Some(out)
1979}
1980
1981fn encode_huf_1stream(data: &[u8], codes: &[(u32, u8); 256]) -> Vec<u8> {
1983 let mut bw = super::bitstream::BackwardBitWriter::new();
1984 for &sym in data.iter().rev() {
1986 let (code, nb) = codes[sym as usize];
1987 if nb == 0 {
1988 continue;
1989 }
1990 bw.add_bits(code as u64, nb as u32);
1991 bw.flush_bits();
1992 }
1993 bw.finish() }
1995
1996fn encode_huf_4streams(data: &[u8], codes: &[(u32, u8); 256]) -> Vec<u8> {
1997 let q = data.len().div_ceil(4);
1998 let ends = [
1999 q,
2000 std::cmp::min(q * 2, data.len()),
2001 std::cmp::min(q * 3, data.len()),
2002 data.len(),
2003 ];
2004 let starts = [0, q, ends[1], ends[2]];
2005
2006 let c: Vec<Vec<u8>> = (0..4)
2007 .map(|i| encode_huf_1stream(&data[starts[i]..ends[i]], codes))
2008 .collect();
2009
2010 let mut out = Vec::with_capacity(6 + c.iter().map(|v| v.len()).sum::<usize>());
2011 for i in 0..3 {
2013 out.extend_from_slice(&(c[i].len() as u16).to_le_bytes());
2014 }
2015 for stream in &c {
2016 out.extend_from_slice(stream);
2017 }
2018 out
2019}
2020
2021fn encode_literals_rle(out: &mut Vec<u8>, byte: u8, size: usize) {
2026 if size <= 31 {
2027 out.push(LIT_TYPE_RLE | ((size as u8) << 3));
2028 } else if size <= 4095 {
2029 let h = (LIT_TYPE_RLE as u16) | (1 << 2) | ((size as u16) << 4);
2030 out.extend_from_slice(&h.to_le_bytes());
2031 } else {
2032 let h = (LIT_TYPE_RLE as u32) | (3 << 2) | ((size as u32) << 4);
2033 out.extend_from_slice(&h.to_le_bytes()[..3]);
2034 }
2035 out.push(byte);
2036}
2037
2038fn encode_literals_raw(out: &mut Vec<u8>, literals: &[u8]) {
2039 let size = literals.len();
2040
2041 if size <= 31 {
2042 out.push(LIT_TYPE_RAW | ((size as u8) << 3));
2044 } else if size <= 4095 {
2045 let h = (LIT_TYPE_RAW as u16) | (1 << 2) | ((size as u16) << 4);
2047 out.extend_from_slice(&h.to_le_bytes());
2048 } else {
2049 let h = (LIT_TYPE_RAW as u32) | (3 << 2) | ((size as u32) << 4);
2051 out.extend_from_slice(&h.to_le_bytes()[..3]);
2052 }
2053
2054 out.extend_from_slice(literals);
2055}
2056
2057fn encode_sequences_section(out: &mut Vec<u8>, sequences: &[EncodedSequence]) {
2065 let nb_seq = sequences.len();
2066
2067 if nb_seq < 128 {
2069 out.push(nb_seq as u8);
2070 } else if nb_seq < 0x7F00 {
2071 out.push(((nb_seq >> 8) as u8) + 128);
2072 out.push(nb_seq as u8);
2073 } else {
2074 out.push(255);
2075 out.extend_from_slice(&((nb_seq - 0x7F00) as u16).to_le_bytes());
2076 }
2077
2078 if nb_seq == 0 {
2079 return;
2080 }
2081
2082 let mut ll_codes_v = Vec::with_capacity(nb_seq);
2084 let mut ml_codes_v = Vec::with_capacity(nb_seq);
2085 let mut off_codes_v = Vec::with_capacity(nb_seq);
2086 let mut ll_values = Vec::with_capacity(nb_seq);
2087 let mut ml_values = Vec::with_capacity(nb_seq);
2088 let mut off_values = Vec::with_capacity(nb_seq);
2089
2090 for seq in sequences {
2091 let llc = ll_code(seq.ll);
2092 let ml_base = seq.ml - ZSTD_MINMATCH as u32;
2093 let mlc = ml_code(ml_base);
2094 let ofc = off_code(seq.of_value);
2095
2096 ll_codes_v.push(llc);
2097 ml_codes_v.push(mlc);
2098 off_codes_v.push(ofc);
2099 ll_values.push(seq.ll - LL_BASE[llc as usize]);
2100 ml_values.push(seq.ml - ML_BASE[mlc as usize]);
2101 off_values.push(if ofc > 0 {
2102 seq.of_value - (1u32 << ofc)
2103 } else {
2104 0
2105 });
2106 }
2107
2108 let ll_mode = choose_seq_mode(
2110 &ll_codes_v,
2111 MAX_LL,
2112 LL_DEFAULT_NORM_LOG,
2113 &LL_DEFAULT_NORM,
2114 LL_FSE_LOG,
2115 );
2116 let of_mode = choose_seq_mode(
2117 &off_codes_v,
2118 OF_DEFAULT_NORM.len() - 1,
2119 OF_DEFAULT_NORM_LOG,
2120 &OF_DEFAULT_NORM,
2121 OFF_FSE_LOG,
2122 );
2123 let ml_mode = choose_seq_mode(
2124 &ml_codes_v,
2125 MAX_ML,
2126 ML_DEFAULT_NORM_LOG,
2127 &ML_DEFAULT_NORM,
2128 ML_FSE_LOG,
2129 );
2130
2131 let mode_byte = (ll_mode.tag() << 6) | (of_mode.tag() << 4) | (ml_mode.tag() << 2);
2133 out.push(mode_byte);
2134
2135 let ll_table =
2137 write_seq_table_and_build(out, &ll_mode, &LL_DEFAULT_NORM, MAX_LL, LL_DEFAULT_NORM_LOG);
2138 let of_table = write_seq_table_and_build(
2139 out,
2140 &of_mode,
2141 &OF_DEFAULT_NORM,
2142 OF_DEFAULT_NORM.len() - 1,
2143 OF_DEFAULT_NORM_LOG,
2144 );
2145 let ml_table =
2146 write_seq_table_and_build(out, &ml_mode, &ML_DEFAULT_NORM, MAX_ML, ML_DEFAULT_NORM_LOG);
2147
2148 let bitstream = super::fse::encode_sequences(
2150 &ll_table,
2151 &of_table,
2152 &ml_table,
2153 &ll_codes_v,
2154 &off_codes_v,
2155 &ml_codes_v,
2156 &ll_values,
2157 &ml_values,
2158 &off_values,
2159 );
2160 out.extend_from_slice(&bitstream);
2161}
2162
/// How one sequence-code table (literal lengths, offsets or match lengths) is
/// described in the sequences section.
#[derive(Debug, Clone)]
enum SeqTableMode {
    /// Use the format's predefined distribution. No table bytes are written.
    Predefined,
    /// Every sequence uses this one code. The byte is written as the table description.
    Rle(u8),
    /// Custom FSE distribution, described by a serialized FSE table header.
    Fse {
        /// Normalized counts; `-1` marks a "less than one" probability.
        norm: Vec<i16>,
        /// Largest symbol covered by `norm`.
        max_symbol: usize,
        /// log2 of the FSE table size.
        table_log: u32,
        /// Pre-serialized table header (as produced by `encode_fse_header`).
        header_bytes: Vec<u8>,
    },
}
2178
2179impl SeqTableMode {
2180 fn tag(&self) -> u8 {
2181 match self {
2182 SeqTableMode::Predefined => SEQ_MODE_PREDEFINED,
2183 SeqTableMode::Rle(_) => SEQ_MODE_RLE,
2184 SeqTableMode::Fse { .. } => SEQ_MODE_FSE,
2185 }
2186 }
2187}
2188
/// Normalizes `counts[..=max_symbol]` to an FSE distribution of `1 << table_log`
/// slots. A `-1` entry ("less than one" probability) occupies one slot.
///
/// The fixed-point scaling and the `RTB_TABLE` round-up thresholds appear to
/// mirror zstd's `FSE_normalizeCount`.
///
/// * Returns all zeros when every count is zero.
/// * A symbol holding every count gets the whole table.
/// * Symbols with `count <= total >> table_log` get `-1` when `total >= 2048`,
///   and `1` otherwise.
/// * Other symbols get their floor share of the table. Shares below 8 are rounded
///   up when the fractional part beats the `RTB_TABLE` threshold. Each share is
///   capped at half the table and is at least 1.
/// * The leftover (positive or negative) goes to the symbol with the largest share.
///   If the over-commitment is at least half of that share, the whole distribution
///   is recomputed by `normalize_counts_m2` instead.
fn normalize_counts(counts: &[u32], max_symbol: usize, table_log: u32) -> Vec<i16> {
    let table_size = 1u32 << table_log;
    let total: u64 = counts[..=max_symbol].iter().map(|&c| c as u64).sum();
    if total == 0 {
        return vec![0i16; max_symbol + 1];
    }

    let mut norm = vec![0i16; max_symbol + 1];

    // count * step >> scale == count * table_size / total in 62-bit fixed point.
    let scale: u32 = 62 - table_log;
    let step: u64 = (1u64 << 62) / total;
    // RTB_TABLE entries are fractions of one slot scaled by 2^20. Multiplying by
    // v_step brings them to the 2^scale fixed-point unit.
    let v_step: u64 = 1u64 << (scale - 20);
    let low_threshold: u64 = total >> table_log;

    static RTB_TABLE: [u32; 8] = [0, 473195, 504333, 520860, 550000, 700000, 750000, 830000];

    let use_low_prob_count = total >= 2048;
    let low_prob_count: i16 = if use_low_prob_count { -1 } else { 1 };

    let mut still_to_distribute = table_size as i32;
    let mut largest_sym = 0usize;
    let mut largest_prob = 0i16;

    for s in 0..=max_symbol {
        // A single symbol holds every count: it takes the whole table.
        if counts[s] as u64 == total {
            norm[s] = table_size as i16;
            return norm;
        }
        if counts[s] == 0 {
            continue;
        }

        if (counts[s] as u64) <= low_threshold {
            norm[s] = low_prob_count;
            still_to_distribute -= 1;
        } else {
            let mut proba = ((counts[s] as u64 * step) >> scale) as i16;
            if proba < 8 {
                // Round small shares up when the truncated remainder is large enough.
                let rest_to_beat = v_step as u128 * RTB_TABLE[proba as usize] as u128;
                let actual = (counts[s] as u128 * step as u128) - ((proba as u128) << scale);
                if actual > rest_to_beat {
                    proba += 1;
                }
            }
            // Never give one symbol more than half of the table here.
            if proba > (table_size >> 1) as i16 {
                proba = (table_size >> 1) as i16;
            }
            norm[s] = std::cmp::max(1, proba);
            still_to_distribute -= norm[s] as i32;
        }

        if norm[s] > largest_prob {
            largest_prob = norm[s];
            largest_sym = s;
        }
    }

    // Too much over-commitment to absorb into the largest symbol: recompute the
    // distribution with the fallback method.
    if -still_to_distribute >= (norm[largest_sym] >> 1) as i32 {
        normalize_counts_m2(&mut norm, counts, max_symbol, table_log, total);
    } else {
        norm[largest_sym] += still_to_distribute as i16;
    }

    norm
}
2263
/// Fallback normalization ("method 2"), modelled on zstd's `FSE_normalizeM2`.
///
/// [`normalize_counts`] switches to this when plain rounding over-commits the
/// table. On return, `norm[..=max_symbol]` holds:
///
/// * `0` for symbols with a zero count;
/// * `-1` ("less than one", one slot) for `count <= total >> table_log`;
/// * `1` for `count <= 1.5 * total / 2^table_log`, a bound that is raised once
///   more if the remaining average share is large;
/// * for the remaining symbols, shares of the leftover slots proportional to their
///   counts. Cumulative fixed-point rounding with a half-step offset makes the
///   shares sum exactly to the leftover.
///
/// The slot total (each `-1` counting as one) is exactly `1 << table_log` whenever
/// fewer than `1 << table_log` symbols are used. Otherwise no valid distribution
/// exists; the function returns early and leaves unassigned symbols at `0`, so the
/// caller's coverage check rejects the table.
///
/// `total` must equal the sum of `counts[..=max_symbol]`.
fn normalize_counts_m2(
    norm: &mut [i16],
    counts: &[u32],
    max_symbol: usize,
    table_log: u32,
    total: u64,
) {
    /// Marker for symbols that receive a proportional share in the final pass.
    const NOT_YET_ASSIGNED: i16 = -2;

    /// Adds `extra` slots to the most frequent symbol (the first one on ties).
    fn give_to_most_frequent(norm: &mut [i16], counts: &[u32], max_symbol: usize, extra: i64) {
        let best = (0..=max_symbol).fold(0, |best, s| if counts[s] > counts[best] { s } else { best });
        // A `-1` entry already occupies one slot.
        let current = if norm[best] == -1 { 1 } else { norm[best] };
        norm[best] = current + extra as i16;
    }

    let table_size = 1i64 << table_log;
    let low_threshold = total >> table_log;
    let mut low_one = (total * 3) >> (table_log + 1);

    // Pass 1: classify every symbol. `remaining` tracks the counts still to be
    // shared proportionally.
    let mut distributed = 0i64;
    let mut remaining = total;
    for s in 0..=max_symbol {
        let c = u64::from(counts[s]);
        norm[s] = if c == 0 {
            0
        } else if c <= low_threshold {
            -1
        } else if c <= low_one {
            1
        } else {
            NOT_YET_ASSIGNED
        };
        if norm[s] == -1 || norm[s] == 1 {
            distributed += 1;
            remaining = remaining.saturating_sub(c);
        }
    }
    let mut to_distribute = table_size - distributed;

    // Pass 2: if the average remaining share is large relative to `low_one`, small
    // symbols risk rounding to zero. Raise `low_one` relative to what is left and
    // give those symbols one slot each.
    if to_distribute > 0 && remaining / to_distribute as u64 > low_one {
        low_one = (remaining * 3) / (to_distribute as u64 * 2);
        for s in 0..=max_symbol {
            let c = u64::from(counts[s]);
            if norm[s] == NOT_YET_ASSIGNED && c <= low_one {
                norm[s] = 1;
                distributed += 1;
                remaining = remaining.saturating_sub(c);
            }
        }
        to_distribute = table_size - distributed;
    }

    if to_distribute <= 0 {
        // Nothing left to share. Any symbol still unassigned cannot be represented.
        for p in norm[..=max_symbol].iter_mut() {
            if *p == NOT_YET_ASSIGNED {
                *p = 0;
            }
        }
        return;
    }

    if distributed == max_symbol as i64 + 1 {
        // Every symbol already has its minimal share; the rest goes to the most frequent.
        give_to_most_frequent(norm, counts, max_symbol, to_distribute);
        return;
    }

    if remaining == 0 {
        // All used symbols are low, but zero-count symbols kept `distributed`
        // short of `max_symbol + 1`. Spread the leftover round-robin over the
        // symbols that already own a full slot.
        let owners: Vec<usize> = (0..=max_symbol).filter(|&s| norm[s] > 0).collect();
        if owners.is_empty() {
            give_to_most_frequent(norm, counts, max_symbol, to_distribute);
        } else {
            for &s in owners.iter().cycle().take(to_distribute as usize) {
                norm[s] += 1;
            }
        }
        return;
    }

    // Proportional pass in fixed point. Starting the running total at just under
    // half a unit makes the telescoped shares sum exactly to `to_distribute`.
    let v_step_log = 62 - table_log;
    let mid = (1u128 << (v_step_log - 1)) - 1;
    let r_step = ((1u128 << v_step_log) * to_distribute as u128 + mid) / u128::from(remaining);
    let mut cumulative = mid;
    for s in 0..=max_symbol {
        if norm[s] == NOT_YET_ASSIGNED {
            let end = cumulative + u128::from(counts[s]) * r_step;
            let share = ((end >> v_step_log) - (cumulative >> v_step_log)) as i16;
            // zstd reports a zero share as an error. Here the symbol is kept
            // encodable instead.
            norm[s] = share.max(1);
            cumulative = end;
        }
    }
}
2318
/// Serializes a normalized FSE distribution as a Zstandard FSE table header.
///
/// The bit order is little-endian.
/// * The first 4 bits hold `table_log - 5`.
/// * Then one variable-width field per symbol holds `prob + 1`. Its width depends
///   on how many probability points are still unassigned; small values save one bit.
/// * After a zero probability, the run of following zero-probability symbols is
///   written as 2-bit repeat counts. A count of 3 means "three more, and another
///   count follows".
///
/// Encoding stops once the probabilities cover the whole table or `max_symbol`
/// has been passed. The final partial byte is zero-padded.
fn encode_fse_header(norm: &[i16], max_symbol: usize, table_log: u32) -> Vec<u8> {
    /// Little-endian bit accumulator that spills whole bytes as soon as they fill.
    struct BitSink {
        acc: u64,
        used: u32,
        bytes: Vec<u8>,
    }

    impl BitSink {
        fn put(&mut self, value: u64, width: u32) {
            self.acc |= value << self.used;
            self.used += width;
            while self.used >= 8 {
                self.bytes.push(self.acc as u8);
                self.acc >>= 8;
                self.used -= 8;
            }
        }

        fn finish(mut self) -> Vec<u8> {
            if self.used > 0 {
                self.bytes.push(self.acc as u8);
            }
            self.bytes
        }
    }

    let table_size = 1u32 << table_log;
    let mut sink = BitSink {
        acc: 0,
        used: 0,
        bytes: Vec::with_capacity(32),
    };
    sink.put((table_log - 5) as u64, 4);

    let mut assigned = 0u32;
    let mut sym = 0usize;
    while sym <= max_symbol && assigned < table_size {
        let prob = norm[sym] as i32;
        let value = (prob + 1) as u32;

        // Field width is set by the largest value still possible.
        let remaining = table_size - assigned + 1;
        let nb_bits = 32 - remaining.leading_zeros();
        let low_threshold = ((1u32 << nb_bits) - 1) - remaining;
        let half_mask = (1u32 << (nb_bits - 1)) - 1;
        let (field, width) = if value < low_threshold {
            (value, nb_bits - 1)
        } else if value <= half_mask {
            (value, nb_bits)
        } else {
            (value + low_threshold, nb_bits)
        };
        sink.put(field as u64, width);

        match prob {
            p if p > 0 => assigned += p as u32,
            -1 => assigned += 1,
            _ => {}
        }

        if prob == 0 {
            loop {
                let run = (1..=3usize)
                    .take_while(|&k| sym + k <= max_symbol && norm[sym + k] == 0)
                    .count();
                sink.put(run as u64, 2);
                sym += run;
                if run < 3 {
                    break;
                }
            }
        }

        sym += 1;
    }

    sink.finish()
}
2410
/// Estimates the bits needed to encode `counts[..=max_sym]` with the normalized
/// distribution `norm` (table size `1 << table_log`). It adds `table_log` bits
/// for the final state.
///
/// Each occurrence costs `table_log - floor(log2(prob))` bits. A `-1` ("less than
/// one") entry counts as probability 1. Returns `u64::MAX` if any symbol with a
/// non-zero count is missing from `norm` or has probability 0.
fn cross_entropy_cost(norm: &[i16], table_log: u32, counts: &[u32; 256], max_sym: usize) -> u64 {
    let log = u64::from(table_log);
    let mut bits = 0u64;
    for (sym, &count) in counts[..=max_sym].iter().enumerate() {
        if count == 0 {
            continue;
        }
        let prob: u64 = match norm.get(sym).copied() {
            None | Some(0) => return u64::MAX,
            Some(-1) => 1,
            Some(p) => p as u64,
        };
        bits += u64::from(count) * (log - u64::from(prob.ilog2()));
    }
    bits + log
}
2430
2431fn choose_seq_mode(
2433 codes: &[u8],
2434 max_symbol_default: usize,
2435 default_log: u32,
2436 default_norm: &[i16],
2437 max_log: u32,
2438) -> SeqTableMode {
2439 if codes.is_empty() {
2440 return SeqTableMode::Predefined;
2441 }
2442
2443 let mut counts = [0u32; 256];
2445 let mut max_sym = 0usize;
2446 for &c in codes {
2447 counts[c as usize] += 1;
2448 if c as usize > max_sym {
2449 max_sym = c as usize;
2450 }
2451 }
2452
2453 let n_used = counts[..=max_sym].iter().filter(|&&c| c > 0).count();
2454
2455 if n_used == 1 {
2457 let sym = codes[0];
2458 return SeqTableMode::Rle(sym);
2459 }
2460
2461 let predefined_ok = max_sym <= max_symbol_default
2463 && codes.iter().all(|&c| {
2464 let s = c as usize;
2465 s < default_norm.len() && default_norm[s] != 0
2466 });
2467
2468 let table_log = {
2471 let min_log = 5u32;
2472 let symbol_log = if n_used <= 2 {
2473 min_log
2474 } else {
2475 std::cmp::min(max_log, (32 - (n_used as u32).leading_zeros()).max(min_log))
2476 };
2477 std::cmp::min(max_log, std::cmp::max(min_log, symbol_log))
2478 };
2479
2480 let custom_norm = normalize_counts(&counts, max_sym, table_log);
2481
2482 let all_covered = codes.iter().all(|&c| {
2484 let s = c as usize;
2485 s <= max_sym && custom_norm[s] != 0
2486 });
2487
2488 if !all_covered {
2489 if predefined_ok {
2490 return SeqTableMode::Predefined;
2491 }
2492 return SeqTableMode::Predefined;
2494 }
2495
2496 let header_bytes = encode_fse_header(&custom_norm, max_sym, table_log);
2497
2498 let _nb_seq = codes.len();
2500
2501 let predefined_cost = if predefined_ok {
2503 cross_entropy_cost(default_norm, default_log, &counts, max_sym)
2504 } else {
2505 u64::MAX
2506 };
2507
2508 let _custom_table_size = 1u64 << table_log;
2510 let mut custom_stream_cost = 0u64;
2511 for s in 0..=max_sym {
2512 if counts[s] > 0 {
2513 let prob = if custom_norm[s] == -1 {
2514 1u64
2515 } else {
2516 custom_norm[s] as u64
2517 };
2518 if prob == 0 {
2519 custom_stream_cost = u64::MAX;
2520 break;
2521 }
2522 let log2_prob = 63 - prob.leading_zeros() as u64;
2525 custom_stream_cost += counts[s] as u64 * (table_log as u64 - log2_prob);
2526 }
2527 }
2528 let custom_header_cost = header_bytes.len() as u64 * 8;
2529 let custom_total_cost = custom_header_cost + custom_stream_cost + table_log as u64;
2530
2531 if predefined_ok && predefined_cost <= custom_total_cost {
2532 SeqTableMode::Predefined
2533 } else {
2534 SeqTableMode::Fse {
2535 norm: custom_norm,
2536 max_symbol: max_sym,
2537 table_log,
2538 header_bytes,
2539 }
2540 }
2541}
2542
2543fn write_seq_table_and_build(
2545 out: &mut Vec<u8>,
2546 mode: &SeqTableMode,
2547 default_norm: &[i16],
2548 default_max_symbol: usize,
2549 default_log: u32,
2550) -> super::fse::FseCTable {
2551 match mode {
2552 SeqTableMode::Predefined => {
2553 super::fse::FseCTable::build(default_norm, default_max_symbol, default_log)
2554 }
2555 SeqTableMode::Rle(sym) => {
2556 out.push(*sym);
2557 super::fse::FseCTable::build_rle(*sym)
2558 }
2559 SeqTableMode::Fse {
2560 norm,
2561 max_symbol,
2562 table_log,
2563 header_bytes,
2564 } => {
2565 out.extend_from_slice(header_bytes);
2566 super::fse::FseCTable::build(norm, *max_symbol, *table_log)
2567 }
2568 }
2569}
2570
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `frame` begins with the zstd magic number.
    fn assert_magic(frame: &[u8]) {
        assert_eq!(&frame[..4], &ZSTD_MAGIC.to_le_bytes());
    }

    #[test]
    fn compress_empty() {
        let frame = compress(&[], 1);
        assert!(frame.len() >= 5);
        assert_magic(&frame);
    }

    #[test]
    fn compress_small() {
        let frame = compress(b"hello world", 1);
        assert_magic(&frame);
        assert!(frame.len() > 5);
    }

    #[test]
    fn compress_repetitive() {
        let input = [42u8; 4096];
        assert_magic(&compress(&input, 1));
    }

    #[test]
    fn compress_real_data() {
        let input: Vec<u8> = (0..1024u32)
            .map(|i| i as f32)
            .flat_map(f32::to_le_bytes)
            .collect();
        assert_magic(&compress(&input, 1));
    }

    /// Compresses a handful of inputs and checks that decompression restores them.
    #[test]
    fn roundtrip_self_contained() {
        let cases: [(&str, Vec<u8>); 5] = [
            ("zeros", vec![0u8; 4096]),
            (
                "sequential",
                (0..4096u32).flat_map(|i| i.to_le_bytes()).collect(),
            ),
            (
                "f32_data",
                (0..256u32)
                    .flat_map(|i| (i as f32 * 1.5).to_le_bytes())
                    .collect(),
            ),
            ("repetitive", b"hello world! ".repeat(100)),
            ("small", b"abc".to_vec()),
        ];

        for (name, data) in &cases {
            let frame = compress(data, 1);
            let restored = crate::decompress(&frame)
                .unwrap_or_else(|e| panic!("{}: decompress failed: {}", name, e));

            assert_eq!(restored.len(), data.len(), "{}: length mismatch", name);
            assert_eq!(&restored, data, "{}: data mismatch", name);
        }
    }
}