lzfse_rust/fse/
literals.rs

1use crate::bits::{BitDst, BitReader, BitSrc, BitWriter};
2use crate::kit::{CopyTypeIndex, WIDE};
3use crate::lmd::LMax;
4use crate::types::ShortBuffer;
5
6use super::block::LiteralParam;
7use super::constants::*;
8use super::decoder::{self, Decoder};
9use super::encoder::{self, Encoder};
10use super::error_kind::FseErrorKind;
11use super::Fse;
12
13use std::io;
14use std::usize;
15
/// Capacity of the literal buffer: one full block of literals plus slack past the end.
// NOTE(review): the `MAX_L_VALUE + WIDE` slack presumably absorbs wide over-copies made by
// `push_unchecked*` and the four padding bytes written by `pad_u` — confirm.
const BUF_LEN: usize = LITERALS_PER_BLOCK as usize + MAX_L_VALUE as usize + WIDE;

/// Literal buffer for a single FSE block.
///
/// Field `0` is the backing storage (`BUF_LEN` bytes when built via `Default`); field `1` is the
/// number of literals currently held, kept at or below `LITERALS_PER_BLOCK`.
#[repr(C)]
pub struct Literals(Box<[u8]>, pub usize);
20
21impl Literals {
22    #[inline(always)]
23    pub unsafe fn push_unchecked_max<I>(&mut self, literals: &mut I)
24    where
25        I: ShortBuffer,
26    {
27        assert!(Fse::MAX_LITERAL_LEN as u32 <= I::SHORT_LIMIT);
28        debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
29        debug_assert!(self.1 + Fse::MAX_LITERAL_LEN as usize <= LITERALS_PER_BLOCK as usize);
30        let ptr = self.0.as_mut_ptr().add(self.1);
31        literals.read_short_raw::<CopyTypeIndex>(ptr, Fse::MAX_LITERAL_LEN as usize);
32        self.1 += Fse::MAX_LITERAL_LEN as usize;
33    }
34
35    #[inline(always)]
36    pub unsafe fn push_unchecked<I>(&mut self, literals: &mut I, n_literals: u32)
37    where
38        I: ShortBuffer,
39    {
40        debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
41        debug_assert!(self.1 + n_literals as usize <= LITERALS_PER_BLOCK as usize);
42        debug_assert!(n_literals <= I::SHORT_LIMIT);
43        let ptr = self.0.as_mut_ptr().add(self.1);
44        literals.read_short_raw::<CopyTypeIndex>(ptr, n_literals as usize);
45        self.1 += n_literals as usize;
46    }
47
48    #[allow(clippy::identity_op)]
49    pub fn load<T>(&mut self, src: T, decoder: &Decoder, param: &LiteralParam) -> crate::Result<()>
50    where
51        T: BitSrc,
52    {
53        let mut reader = BitReader::new(src, param.bits() as usize)?;
54        let state = param.state();
55        let mut state = (
56            decoder::U::new(state[0] as usize),
57            decoder::U::new(state[1] as usize),
58            decoder::U::new(state[2] as usize),
59            decoder::U::new(state[3] as usize),
60        );
61        let ptr = self.0.as_mut_ptr().cast::<u8>();
62        let n_literals = param.num() as usize;
63        debug_assert!(n_literals <= LITERALS_PER_BLOCK as usize);
64        let mut i = 0;
65        while i != n_literals {
66            // `flush` constraints:
67            // 32 bit systems: maximum of x2 10 bit pushes.
68            // 64 bit systems: maximum of x5 10 bit pushes (although we only push 4 for simplicity).
69            unsafe { *ptr.add(i + 0) = decoder.u(&mut reader, &mut state.0) };
70            unsafe { *ptr.add(i + 1) = decoder.u(&mut reader, &mut state.1) };
71            #[cfg(target_pointer_width = "32")]
72            reader.flush();
73            unsafe { *ptr.add(i + 2) = decoder.u(&mut reader, &mut state.2) };
74            unsafe { *ptr.add(i + 3) = decoder.u(&mut reader, &mut state.3) };
75            reader.flush();
76            i += 4;
77        }
78        reader.finalize()?;
79        if state
80            != (
81                decoder::U::default(),
82                decoder::U::default(),
83                decoder::U::default(),
84                decoder::U::default(),
85            )
86        {
87            return Err(FseErrorKind::BadLmdPayload.into());
88        }
89        self.1 = n_literals;
90        Ok(())
91    }
92
93    pub fn store<T>(&self, dst: &mut T, encoder: &Encoder) -> io::Result<LiteralParam>
94    where
95        T: BitDst,
96    {
97        debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
98        let mark = dst.pos();
99        let n_literals = (self.1 + 3) / 4 * 4;
100        let n_bytes = (n_literals * MAX_U_BITS as usize + 7) / 8;
101        let mut writer = BitWriter::new(dst, n_bytes)?;
102        let mut state = (
103            encoder::U::default(),
104            encoder::U::default(),
105            encoder::U::default(),
106            encoder::U::default(),
107        );
108        let ptr = self.0.as_ptr();
109        let mut i = n_literals;
110        while i != 0 {
111            // `flush` constraints:
112            // 32 bit systems: maximum of x2 10 bit pushes.
113            // 64 bit systems: maximum of x5 10 bit pushes (although we only push 4 for simplicity).
114            unsafe { encoder.u(&mut writer, &mut state.3, *ptr.add(i - 1)) };
115            unsafe { encoder.u(&mut writer, &mut state.2, *ptr.add(i - 2)) };
116            #[cfg(target_pointer_width = "32")]
117            writer.flush();
118            unsafe { encoder.u(&mut writer, &mut state.1, *ptr.add(i - 3)) };
119            unsafe { encoder.u(&mut writer, &mut state.0, *ptr.add(i - 4)) };
120            writer.flush();
121            i -= 4;
122        }
123        let state = [
124            u32::from(state.0) as u16,
125            u32::from(state.1) as u16,
126            u32::from(state.2) as u16,
127            u32::from(state.3) as u16,
128        ];
129        let bits = writer.finalize()? as u32;
130        let n_payload_bytes = (dst.pos() - mark) as u32;
131        let n_literals = (self.1 as u32 + 3) / 4 * 4;
132        Ok(LiteralParam::new(n_literals, n_payload_bytes, bits, state).expect("internal error"))
133    }
134
135    #[inline(always)]
136    pub fn pad(&mut self) {
137        debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
138        self.pad_u(unsafe { *self.0.get_unchecked(0) });
139    }
140
141    #[inline(always)]
142    pub fn pad_u(&mut self, u: u8) {
143        debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
144        unsafe { self.0.get_unchecked_mut(self.1..).get_unchecked_mut(..4) }.fill(u);
145    }
146
147    #[inline(always)]
148    pub fn len(&self) -> usize {
149        debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
150        self.1
151    }
152
153    #[inline(always)]
154    pub fn reset(&mut self) {
155        debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
156        self.1 = 0;
157    }
158
159    #[inline(always)]
160    pub fn as_ptr(&self) -> *const u8 {
161        self.0.as_ptr()
162    }
163}
164
165impl AsRef<[u8]> for Literals {
166    #[inline(always)]
167    fn as_ref(&self) -> &[u8] {
168        debug_assert!(self.1 <= LITERALS_PER_BLOCK as usize);
169        unsafe { self.0.get_unchecked(..self.1) }
170    }
171}
172
173impl Default for Literals {
174    fn default() -> Self {
175        Self(vec![0u8; BUF_LEN].into_boxed_slice(), 0)
176    }
177}
178
#[cfg(test)]
mod tests {
    use crate::fse::Weights;

    use test_kit::{Rng, Seq};

    use super::*;

    /// Round-trip harness. It owns the FSE tables, the source and destination literal buffers,
    /// and the encoded bytes, so all of them are reused from one iteration to the next.
    #[derive(Default)]
    struct Buddy {
        weights: Weights,
        encoder: Encoder,
        decoder: Decoder,
        src: Literals,
        dst: Literals,
        param: LiteralParam,
        enc: Vec<u8>,
        n_literals: usize,
    }

    impl Buddy {
        /// Replaces the source buffer contents with `literals`.
        #[allow(dead_code)]
        pub fn push(&mut self, literals: &[u8]) {
            self.src.reset();
            let count = literals.len();
            self.n_literals = count;
            assert!(count <= LITERALS_PER_BLOCK as usize);
            let mut remaining = literals;
            unsafe { self.src.push_unchecked(&mut remaining, count as u32) }
            assert_eq!(remaining.len(), 0);
        }

        /// Builds weights from the source, pads it, and encodes it into `enc` behind an 8 byte
        /// zero prefix.
        // NOTE(review): the purpose of the 8 byte prefix is not visible here; presumably the bit
        // writer/reader needs the headroom — confirm.
        fn encode(&mut self) -> io::Result<()> {
            let pad_byte = self.weights.load(&[], self.src.as_ref());
            self.src.pad_u(pad_byte);
            self.encoder.init(&self.weights);
            self.enc.clear();
            self.enc.extend_from_slice(&[0u8; 8]);
            self.param = self.src.store(&mut self.enc, &self.encoder)?;
            let expected_len = 8 + self.param.n_payload_bytes() as usize;
            assert_eq!(self.enc.len(), expected_len);
            Ok(())
        }

        /// Decodes `enc` into the destination buffer using the current weights and params.
        fn decode(&mut self) -> io::Result<()> {
            self.decoder.init(&self.weights);
            self.dst
                .load(self.enc.as_slice(), &self.decoder, &self.param)
                .map_err(Into::into)
        }

        /// True when the first `n_literals` bytes of source and destination agree.
        fn check(&self) -> bool {
            let count = self.n_literals;
            assert!(count <= self.src.len());
            assert!(count <= self.dst.len());
            let (lhs, rhs) = (self.src.as_ref(), self.dst.as_ref());
            lhs[..count] == rhs[..count]
        }

        /// Push, encode, decode and compare in one step.
        fn check_encode_decode(&mut self, literals: &[u8]) -> io::Result<bool> {
            self.push(literals);
            self.encode()?;
            self.decode()?;
            Ok(self.check())
        }
    }

    #[test]
    fn empty() -> io::Result<()> {
        let round_trips = Buddy::default().check_encode_decode(&[])?;
        assert!(round_trips);
        Ok(())
    }

    #[test]
    #[ignore = "expensive"]
    fn incremental() -> io::Result<()> {
        let bytes: Vec<u8> = Seq::default().take(LITERALS_PER_BLOCK as usize + 1).collect();
        let mut buddy = Buddy::default();
        for len in 1..bytes.len() {
            assert!(buddy.check_encode_decode(&bytes[..len])?);
        }
        Ok(())
    }

    // Random literals.
    #[test]
    #[ignore = "expensive"]
    fn rng_1() -> io::Result<()> {
        let max_len = LITERALS_PER_BLOCK as usize;
        let mut bytes = Vec::with_capacity(max_len);
        let mut buddy = Buddy::default();
        for len in 0..max_len {
            bytes.clear();
            bytes.extend(Seq::new(Rng::new(len as u32)).take(len));
            assert!(buddy.check_encode_decode(&bytes[..len])?);
        }
        Ok(())
    }

    // Random literals, incremental entropy.
    #[test]
    #[ignore = "expensive"]
    fn rng_2() -> io::Result<()> {
        const MAX_LEN: usize = 0x1000;
        let mut bytes = Vec::with_capacity(MAX_LEN);
        let mut buddy = Buddy::default();
        for entropy in 0..0xFF {
            let mask = entropy * 0x0101_0101;
            for len in 0..MAX_LEN {
                bytes.clear();
                bytes.extend(Seq::masked(Rng::new(len as u32), mask).take(len));
                assert!(buddy.check_encode_decode(&bytes[..len])?);
            }
        }
        Ok(())
    }

    // Bitwise mutation. We are looking to break the decoder. In all cases the
    // decoder should reject invalid data via `Err(error)` and exit gracefully. It should not hang/
    // segfault/ panic/ trip debug assertions or break in any other fashion.
    #[test]
    #[ignore = "expensive"]
    fn mutate_1() -> io::Result<()> {
        let mut buddy = Buddy::default();
        let mut bytes = Vec::with_capacity(0x1000);
        for seed in 0..0x0100 {
            bytes.clear();
            bytes.extend(Seq::new(Rng::new(seed)).take(0x1000));
            assert!(buddy.check_encode_decode(&bytes)?);
            for index in 0..buddy.enc.len() {
                for flip in (0..8u32).map(|n_bit| 1u8 << n_bit) {
                    buddy.enc[index] ^= flip;
                    let _ = buddy.decode();
                    buddy.enc[index] ^= flip;
                }
            }
            assert!(buddy.check_encode_decode(&bytes)?);
        }
        Ok(())
    }

    // Byte mutation. We are looking to break the decoder. In all cases the
    // decoder should reject invalid data via `Err(error)` and exit gracefully. It should not hang/
    // segfault/ panic/ trip debug assertions or break in any other fashion.
    #[test]
    #[ignore = "expensive"]
    fn mutate_2() -> io::Result<()> {
        let mut buddy = Buddy::default();
        let mut bytes = Vec::with_capacity(0x0100);
        for seed in 0..0x0100 {
            bytes.clear();
            bytes.extend(Seq::new(Rng::new(seed)).take(0x0100));
            assert!(buddy.check_encode_decode(&bytes)?);
            for index in 0..buddy.enc.len() {
                for flip in 0..=u8::MAX {
                    buddy.enc[index] ^= flip;
                    let _ = buddy.decode();
                    buddy.enc[index] ^= flip;
                }
            }
            assert!(buddy.check_encode_decode(&bytes)?);
        }
        Ok(())
    }
}