lzfse_rust/fse/
literals.rs1use crate::bits::{BitDst, BitReader, BitSrc, BitWriter};
2use crate::kit::{CopyTypeIndex, WIDE};
3use crate::lmd::LMax;
4use crate::types::ShortBuffer;
5
6use super::block::LiteralParam;
7use super::constants::*;
8use super::decoder::{self, Decoder};
9use super::encoder::{self, Encoder};
10use super::error_kind::FseErrorKind;
11use super::Fse;
12
13use std::io;
14use std::usize;
15
16const BUF_LEN: usize = LITERALS_PER_BLOCK as usize + MAX_L_VALUE as usize + WIDE;
17
/// Literal byte buffer.
///
/// * Field `0` is the heap allocated backing storage.
/// * Field `1` is the number of literal bytes currently stored at the front of field `0`.
#[repr(C)]
pub struct Literals(Box<[u8]>, pub usize);
20
21impl Literals {
22 #[inline(always)]
23 pub unsafe fn push_unchecked_max<I>(&mut self, literals: &mut I)
24 where
25 I: ShortBuffer,
26 {
27 assert!(Fse::MAX_LITERAL_LEN as u32 <= I::SHORT_LIMIT);
28 debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
29 debug_assert!(self.1 + Fse::MAX_LITERAL_LEN as usize <= LITERALS_PER_BLOCK as usize);
30 let ptr = self.0.as_mut_ptr().add(self.1);
31 literals.read_short_raw::<CopyTypeIndex>(ptr, Fse::MAX_LITERAL_LEN as usize);
32 self.1 += Fse::MAX_LITERAL_LEN as usize;
33 }
34
35 #[inline(always)]
36 pub unsafe fn push_unchecked<I>(&mut self, literals: &mut I, n_literals: u32)
37 where
38 I: ShortBuffer,
39 {
40 debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
41 debug_assert!(self.1 + n_literals as usize <= LITERALS_PER_BLOCK as usize);
42 debug_assert!(n_literals <= I::SHORT_LIMIT);
43 let ptr = self.0.as_mut_ptr().add(self.1);
44 literals.read_short_raw::<CopyTypeIndex>(ptr, n_literals as usize);
45 self.1 += n_literals as usize;
46 }
47
48 #[allow(clippy::identity_op)]
49 pub fn load<T>(&mut self, src: T, decoder: &Decoder, param: &LiteralParam) -> crate::Result<()>
50 where
51 T: BitSrc,
52 {
53 let mut reader = BitReader::new(src, param.bits() as usize)?;
54 let state = param.state();
55 let mut state = (
56 decoder::U::new(state[0] as usize),
57 decoder::U::new(state[1] as usize),
58 decoder::U::new(state[2] as usize),
59 decoder::U::new(state[3] as usize),
60 );
61 let ptr = self.0.as_mut_ptr().cast::<u8>();
62 let n_literals = param.num() as usize;
63 debug_assert!(n_literals <= LITERALS_PER_BLOCK as usize);
64 let mut i = 0;
65 while i != n_literals {
66 unsafe { *ptr.add(i + 0) = decoder.u(&mut reader, &mut state.0) };
70 unsafe { *ptr.add(i + 1) = decoder.u(&mut reader, &mut state.1) };
71 #[cfg(target_pointer_width = "32")]
72 reader.flush();
73 unsafe { *ptr.add(i + 2) = decoder.u(&mut reader, &mut state.2) };
74 unsafe { *ptr.add(i + 3) = decoder.u(&mut reader, &mut state.3) };
75 reader.flush();
76 i += 4;
77 }
78 reader.finalize()?;
79 if state
80 != (
81 decoder::U::default(),
82 decoder::U::default(),
83 decoder::U::default(),
84 decoder::U::default(),
85 )
86 {
87 return Err(FseErrorKind::BadLmdPayload.into());
88 }
89 self.1 = n_literals;
90 Ok(())
91 }
92
93 pub fn store<T>(&self, dst: &mut T, encoder: &Encoder) -> io::Result<LiteralParam>
94 where
95 T: BitDst,
96 {
97 debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
98 let mark = dst.pos();
99 let n_literals = (self.1 + 3) / 4 * 4;
100 let n_bytes = (n_literals * MAX_U_BITS as usize + 7) / 8;
101 let mut writer = BitWriter::new(dst, n_bytes)?;
102 let mut state = (
103 encoder::U::default(),
104 encoder::U::default(),
105 encoder::U::default(),
106 encoder::U::default(),
107 );
108 let ptr = self.0.as_ptr();
109 let mut i = n_literals;
110 while i != 0 {
111 unsafe { encoder.u(&mut writer, &mut state.3, *ptr.add(i - 1)) };
115 unsafe { encoder.u(&mut writer, &mut state.2, *ptr.add(i - 2)) };
116 #[cfg(target_pointer_width = "32")]
117 writer.flush();
118 unsafe { encoder.u(&mut writer, &mut state.1, *ptr.add(i - 3)) };
119 unsafe { encoder.u(&mut writer, &mut state.0, *ptr.add(i - 4)) };
120 writer.flush();
121 i -= 4;
122 }
123 let state = [
124 u32::from(state.0) as u16,
125 u32::from(state.1) as u16,
126 u32::from(state.2) as u16,
127 u32::from(state.3) as u16,
128 ];
129 let bits = writer.finalize()? as u32;
130 let n_payload_bytes = (dst.pos() - mark) as u32;
131 let n_literals = (self.1 as u32 + 3) / 4 * 4;
132 Ok(LiteralParam::new(n_literals, n_payload_bytes, bits, state).expect("internal error"))
133 }
134
135 #[inline(always)]
136 pub fn pad(&mut self) {
137 debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
138 self.pad_u(unsafe { *self.0.get_unchecked(0) });
139 }
140
141 #[inline(always)]
142 pub fn pad_u(&mut self, u: u8) {
143 debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
144 unsafe { self.0.get_unchecked_mut(self.1..).get_unchecked_mut(..4) }.fill(u);
145 }
146
147 #[inline(always)]
148 pub fn len(&self) -> usize {
149 debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
150 self.1
151 }
152
153 #[inline(always)]
154 pub fn reset(&mut self) {
155 debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
156 self.1 = 0;
157 }
158
159 #[inline(always)]
160 pub fn as_ptr(&self) -> *const u8 {
161 self.0.as_ptr()
162 }
163}
164
165impl AsRef<[u8]> for Literals {
166 #[inline(always)]
167 fn as_ref(&self) -> &[u8] {
168 debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
169 unsafe { self.0.get_unchecked(..self.1) }
170 }
171}
172
173impl Default for Literals {
174 fn default() -> Self {
175 Self(vec![0u8; BUF_LEN].into_boxed_slice(), 0)
176 }
177}
178
#[cfg(test)]
mod tests {
    use crate::fse::Weights;

    use test_kit::{Rng, Seq};

    use super::*;

    /// Encode/decode round-trip harness: `src` holds the input literals, `enc` the encoded
    /// bytes and `dst` the decoded literals.
    #[derive(Default)]
    struct Buddy {
        weights: Weights,
        encoder: Encoder,
        decoder: Decoder,
        src: Literals,
        dst: Literals,
        param: LiteralParam,
        enc: Vec<u8>,
        n_literals: usize,
    }

    impl Buddy {
        /// Replaces the contents of `src` with `literals`.
        #[allow(dead_code)]
        pub fn push(&mut self, mut literals: &[u8]) {
            self.src.reset();
            let count = literals.len();
            self.n_literals = count;
            assert!(count <= LITERALS_PER_BLOCK as usize);
            unsafe { self.src.push_unchecked(&mut literals, count as u32) }
            assert_eq!(literals.len(), 0);
        }

        /// Loads the weights from `src`, pads `src` and encodes it into `enc` behind 8 zero
        /// bytes.
        fn encode(&mut self) -> io::Result<()> {
            const PREFIX_LEN: usize = 8;
            let filler = self.weights.load(&[], self.src.as_ref());
            self.src.pad_u(filler);
            self.encoder.init(&self.weights);
            self.enc.clear();
            self.enc.resize(PREFIX_LEN, 0);
            self.param = self.src.store(&mut self.enc, &self.encoder)?;
            let payload_len = self.param.n_payload_bytes() as usize;
            assert_eq!(self.enc.len(), PREFIX_LEN + payload_len);
            Ok(())
        }

        /// Decodes `enc` into `dst` using the current weights and parameters.
        fn decode(&mut self) -> io::Result<()> {
            self.decoder.init(&self.weights);
            Ok(self.dst.load(self.enc.as_slice(), &self.decoder, &self.param)?)
        }

        /// Returns `true` if the first `n_literals` bytes of `src` and `dst` are equal.
        fn check(&self) -> bool {
            let count = self.n_literals;
            assert!(count <= self.src.len());
            assert!(count <= self.dst.len());
            let expected = &self.src.as_ref()[..count];
            let actual = &self.dst.as_ref()[..count];
            expected == actual
        }

        /// Full round trip: push, encode, decode, compare.
        fn check_encode_decode(&mut self, literals: &[u8]) -> io::Result<bool> {
            self.push(literals);
            self.encode()?;
            self.decode()?;
            Ok(self.check())
        }
    }

    #[test]
    fn empty() -> io::Result<()> {
        let mut buddy = Buddy::default();
        assert!(buddy.check_encode_decode(&[])?);
        Ok(())
    }

    #[test]
    #[ignore = "expensive"]
    fn incremental() -> io::Result<()> {
        let limit = LITERALS_PER_BLOCK as usize;
        let bytes: Vec<_> = Seq::default().take(limit + 1).collect();
        let mut buddy = Buddy::default();
        for literal_len in 1..bytes.len() {
            assert!(buddy.check_encode_decode(&bytes[..literal_len])?);
        }
        Ok(())
    }

    #[test]
    #[ignore = "expensive"]
    fn rng_1() -> io::Result<()> {
        let limit = LITERALS_PER_BLOCK as usize;
        let mut bytes: Vec<u8> = Vec::with_capacity(limit);
        let mut buddy = Buddy::default();
        for literal_len in 0..limit {
            bytes.clear();
            bytes.extend(Seq::new(Rng::new(literal_len as u32)).take(literal_len));
            assert!(buddy.check_encode_decode(&bytes[..literal_len])?);
        }
        Ok(())
    }

    #[test]
    #[ignore = "expensive"]
    fn rng_2() -> io::Result<()> {
        let limit = 0x1000;
        let mut bytes: Vec<u8> = Vec::with_capacity(limit);
        let mut buddy = Buddy::default();
        for entropy in 0..0xFF {
            // Replicate the entropy byte into every byte lane of the mask.
            let mask = entropy * 0x0101_0101;
            for literal_len in 0..limit {
                bytes.clear();
                bytes.extend(Seq::masked(Rng::new(literal_len as u32), mask).take(literal_len));
                assert!(buddy.check_encode_decode(&bytes[..literal_len])?);
            }
        }
        Ok(())
    }

    /// Flips every single bit of the encoded bytes in turn; decoding may fail but is expected
    /// to return, and the harness must still round trip afterwards.
    #[test]
    #[ignore = "expensive"]
    fn mutate_1() -> io::Result<()> {
        let mut buddy = Buddy::default();
        let mut bytes: Vec<u8> = Vec::default();
        for seed in 0..0x0100 {
            bytes.clear();
            bytes.extend(Seq::new(Rng::new(seed)).take(0x1000));
            assert!(buddy.check_encode_decode(&bytes)?);
            for index in 0..buddy.enc.len() {
                for n_bit in 0..8 {
                    let flip = 1 << n_bit;
                    buddy.enc[index] ^= flip;
                    let _ = buddy.decode();
                    buddy.enc[index] ^= flip;
                }
            }
            assert!(buddy.check_encode_decode(&bytes)?);
        }
        Ok(())
    }

    /// XORs every encoded byte with every possible byte value in turn; decoding may fail but
    /// is expected to return, and the harness must still round trip afterwards.
    #[test]
    #[ignore = "expensive"]
    fn mutate_2() -> io::Result<()> {
        let mut buddy = Buddy::default();
        let mut bytes: Vec<u8> = Vec::default();
        for seed in 0..0x0100 {
            bytes.clear();
            bytes.extend(Seq::new(Rng::new(seed)).take(0x0100));
            assert!(buddy.check_encode_decode(&bytes)?);
            for index in 0..buddy.enc.len() {
                for delta in 0..=0xFF {
                    buddy.enc[index] ^= delta;
                    let _ = buddy.decode();
                    buddy.enc[index] ^= delta;
                }
            }
            assert!(buddy.check_encode_decode(&bytes)?);
        }
        Ok(())
    }
}