1use crate::bits::BitDst;
2use crate::lmd::{LMax, LmdPack, MMax, MatchDistance};
3use crate::ops::WriteShort;
4use crate::types::ShortBuffer;
5
6use super::block::FseBlock;
7use super::constants::*;
8use super::encoder::Encoder;
9use super::literals::Literals;
10use super::lmds::Lmds;
11use super::weights::Weights;
12use super::Fse;
13
14use std::convert::AsRef;
15use std::io;
16
17#[derive(Default)]
18pub struct Buffer {
19 literals: Literals,
20 lmds: Lmds,
21 n_match_bytes: u32,
22 match_distance: u32,
23}
24
25impl Buffer {
26 pub fn pad(&mut self) {
27 self.literals.pad();
28 }
29
30 pub fn init_weights(&self, weights: &mut Weights) -> u8 {
31 weights.load(self.lmds.as_ref(), self.literals.as_ref())
32 }
33
34 pub fn store<O>(&self, dst: &mut O, encoder: &Encoder) -> io::Result<FseBlock>
35 where
36 O: BitDst + WriteShort,
37 {
38 let literal_param = self.literals.store(dst, encoder)?;
39 let lmd_param = self.lmds.store(dst, encoder)?;
40 Ok(FseBlock::new(self.n_raw_bytes(), literal_param, lmd_param).expect("internal error"))
41 }
42
43 #[inline(always)]
45 pub fn push<I>(
46 &mut self,
47 literals: &mut I,
48 match_len: &mut u32,
49 match_distance: MatchDistance<Fse>,
50 ) -> bool
51 where
52 I: ShortBuffer,
53 {
54 let match_distance = match_distance.get();
55 debug_assert!(literals.len() != 0 || *match_len != 0);
56 while literals.len() > Fse::MAX_LITERAL_LEN as usize {
57 if self.lmds.len() == LMDS_PER_BLOCK as usize {
58 return false;
59 }
60 let limit = LITERALS_PER_BLOCK - self.literals.len() as u32;
61 if Fse::MAX_LITERAL_LEN as u32 <= limit {
62 unsafe { self.literals.push_unchecked_max(literals) };
63 unsafe { self.push_l(Fse::MAX_LITERAL_LEN) };
64 } else if limit != 0 {
65 unsafe { self.literals.push_unchecked(literals, limit) };
66 unsafe { self.push_l(limit as u16) };
67 return false;
68 } else {
69 return false;
70 }
71 }
72 if self.lmds.len() == LMDS_PER_BLOCK as usize {
73 return false;
74 }
75 let mut literal_len = literals.len();
76 let limit = LITERALS_PER_BLOCK - self.literals.len() as u32;
77 if literal_len <= limit as usize {
78 unsafe { self.literals.push_unchecked(literals, literal_len as u32) };
79 } else if limit != 0 {
80 unsafe { self.literals.push_unchecked(literals, limit) };
81 unsafe { self.push_l(limit as u16) };
82 return false;
83 } else {
84 return false;
85 }
86 while *match_len > Fse::MAX_MATCH_LEN as u32 {
87 unsafe { self.push_lmd(literal_len as u16, Fse::MAX_MATCH_LEN, match_distance) };
88 *match_len -= Fse::MAX_MATCH_LEN as u32;
89 literal_len = 0;
90 if self.lmds.len() == LMDS_PER_BLOCK as usize {
91 return false;
92 }
93 }
94 unsafe { self.push_lmd(literal_len as u16, *match_len as u16, match_distance) };
95 *match_len = 0;
96 true
97 }
98
99 #[inline(always)]
100 unsafe fn push_l(&mut self, l: u16) {
101 debug_assert!(l <= Fse::MAX_LITERAL_LEN);
102 self.match_distance = 1;
103 self.lmds.push_unchecked(LmdPack::<Fse>::new_unchecked(l, 0, 1));
104 }
105
106 #[inline(always)]
107 unsafe fn push_lmd(&mut self, l: u16, m: u16, mut d: u32) {
108 debug_assert_ne!(d, 0);
109 if self.match_distance == d {
110 self.match_distance = d;
111 d = 0;
112 } else {
113 self.match_distance = d;
114 }
115 self.lmds.push_unchecked(LmdPack::<Fse>::new_unchecked(l, m, d));
116 self.n_match_bytes += m as u32;
117 }
118
119 #[inline(always)]
120 pub fn reset(&mut self) {
121 self.literals.reset();
122 self.lmds.reset();
123 self.n_match_bytes = 0;
124 self.match_distance = 0;
125 }
126
127 #[inline(always)]
128 fn n_raw_bytes(&self) -> u32 {
129 self.literals.len() as u32 + self.n_match_bytes
130 }
131}
132
#[cfg(test)]
mod tests {
    use crate::fse::Fse;
    use crate::lmd::{LmdPack, MatchDistance, MatchDistanceUnpack, MatchLen};
    use crate::lz::LzWriter;
    use crate::{fse::constants::*, lmd::DMax};

    use test_kit::{Rng, Seq};

    use super::*;

    /// Generates an exhaustive-seed round-trip test for `Buffer::push`.
    ///
    /// * `$name` - name of the generated test function.
    /// * `$lm` - exclusive upper bound for the random literal run length offset; each run is
    ///   `rng % $lm + 1` bytes long (clamped to the bytes still available).
    /// * `$mm` - exclusive upper bound for the random match length (`rng % $mm`).
    ///
    /// For every seed, random literal/match pairs are pushed until the buffer reports full.
    /// Whatever `push` consumed is written to `dst_a`. The buffered literals and LMD packs are
    /// then replayed into `dst_b`, and both outputs must be identical.
    macro_rules! test_push {
        ($name:ident, $lm:expr, $mm:expr) => {
            #[test]
            #[ignore = "expensive"]
            fn $name() -> crate::Result<()> {
                let bytes =
                    Seq::default().take(LITERALS_PER_BLOCK as usize + 0x1000).collect::<Vec<_>>();
                let mut buffer = Buffer::default();
                let mut dst_a = Vec::default();
                let mut dst_b = Vec::default();
                for seed in 0..0x0001_0000 {
                    let mut rng = Rng::new(seed);
                    let mut bytes = bytes.as_slice();
                    loop {
                        // The macro arguments bound the random lengths so that each generated
                        // test exercises a different literal/match length mix.
                        let l = (rng.gen() as usize % $lm + 1).min(bytes.len());
                        let m = rng.gen() % $mm;
                        let d = (rng.gen() % Fse::MAX_MATCH_DISTANCE).min(dst_a.len() as u32) + 1;
                        let match_distance = MatchDistance::new(d);
                        let literals = &bytes[..l];
                        bytes = &bytes[l..];
                        let mut literals_mut = literals;
                        let mut match_len_mut = m;
                        let ok = buffer.push(&mut literals_mut, &mut match_len_mut, match_distance);
                        if ok {
                            assert_eq!(literals_mut.len(), 0);
                            assert_eq!(match_len_mut, 0);
                        }
                        // Mirror only the part that `push` actually consumed into `dst_a`.
                        let literals = &literals[..literals.len() - literals_mut.len()];
                        let mut match_len = m - match_len_mut;
                        dst_a.write_bytes_long(literals)?;
                        while match_len != 0 {
                            let match_len_m =
                                MatchLen::new(match_len.min(Fse::MAX_MATCH_LEN as u32));
                            dst_a.write_match(match_len_m, match_distance.into())?;
                            match_len -= match_len_m.get();
                        }
                        if !ok {
                            break;
                        }
                    }
                    // Replay the buffered packs and compare against the reference output.
                    let mut match_distance = MatchDistanceUnpack::default();
                    let mut bytes = buffer.literals.as_ref();
                    for &LmdPack(literal_len_pack, match_len_pack, match_distance_pack) in
                        buffer.lmds.as_ref()
                    {
                        let literals = &bytes[..literal_len_pack.get() as usize];
                        match_distance.substitute(match_distance_pack);
                        bytes = &bytes[literal_len_pack.get() as usize..];
                        dst_b.write_bytes_long(literals)?;
                        dst_b.write_match(match_len_pack.into(), match_distance)?;
                    }
                    assert!(dst_a == dst_b);
                    buffer.reset();
                    dst_a.clear();
                    dst_b.clear();
                }
                Ok(())
            }
        };
    }

    test_push!(push_0, 0x1000, 0x1000);
    test_push!(push_1, 0x1000, 0x0010);
    test_push!(push_2, 0x0010, 0x1000);
    test_push!(push_3, 0x0010, 0x0010);
}