lz_str/compress.rs

use crate::constants::BASE64_KEY;
use crate::constants::CLOSE_CODE;
use crate::constants::START_CODE_BITS;
use crate::constants::U16_CODE;
use crate::constants::U8_CODE;
use crate::constants::URI_KEY;
use crate::IntoWideIter;
use std::collections::hash_map::Entry as HashMapEntry;
use std::convert::TryInto;

#[cfg(not(feature = "rustc-hash"))]
type HashMap<K, V> = std::collections::HashMap<K, V>;

#[cfg(not(feature = "rustc-hash"))]
type HashSet<T> = std::collections::HashSet<T>;

#[cfg(feature = "rustc-hash")]
type HashMap<K, V> = rustc_hash::FxHashMap<K, V>;

#[cfg(feature = "rustc-hash")]
type HashSet<T> = rustc_hash::FxHashSet<T>;

/// The number of "base codes",
/// the default codes of all streams.
///
/// These are U8_CODE, U16_CODE, and CLOSE_CODE.
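///
/// Dictionary codes assigned during compression start at this value.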
const NUM_BASE_CODES: usize = 3;

#[derive(Debug)]
pub(crate) struct CompressContext<'a, F> {
    dictionary: HashMap<&'a [u16], u32>,
    dictionary_to_create: HashSet<u16>,

    /// The current word, w,
    /// in terms of indexes into the input.
    w_start_idx: usize,
    w_end_idx: usize,

    /// The counter for increasing the current number of bits in a code.
    /// The max size of this is 1 << max(num_bits) == 1 + u32::MAX, so we use u64.
    enlarge_in: u64,

    /// The input buffer.
    input: &'a [u16],

    /// The output buffer.
    output: Vec<u16>,

    /// The bit buffer.
    bit_buffer: u16,

    /// The current number of bits in a code.
    ///
    /// This is a u8,
    /// because we currently assume the max code size is 32 bits.
    /// 32 < u8::MAX
    num_bits: u8,

    /// The current bit position.
    bit_position: u8,

    /// The maximum # of bits per char.
    ///
    /// This value may not exceed 16,
    /// as the reference implementation will also not handle values over 16.
    bits_per_char: u8,

    /// A transformation function to map a u16 to another u16,
    /// before appending it to the output buffer.
    to_char: F,
}

impl<'a, F> CompressContext<'a, F>
where
    F: Fn(u16) -> u16,
{
    /// Make a new [`CompressContext`].
    ///
    /// # Panics
    /// Panics if `bits_per_char` exceeds the number of bits in a u16.
    #[inline]
    pub fn new(input: &'a [u16], bits_per_char: u8, to_char: F) -> Self {
        assert!(usize::from(bits_per_char) <= std::mem::size_of::<u16>() * 8);

        CompressContext {
            dictionary: HashMap::default(),
            dictionary_to_create: HashSet::default(),

            w_start_idx: 0,
            w_end_idx: 0,

            enlarge_in: 2,

            input,
            output: Vec::with_capacity(input.len() >> 1), // Lowball, assume we can get a 50% reduction in size.

            bit_buffer: 0,

            num_bits: START_CODE_BITS,

            bit_position: 0,
            bits_per_char,
            to_char,
        }
    }

    /// Emit the code for the current word `w`.
    ///
    /// If the first character of `w` has not yet been emitted as a literal,
    /// a `U8_CODE` or `U16_CODE` marker is written followed by the raw character;
    /// otherwise the dictionary code previously assigned to `w` is written.
    #[inline]
    pub fn produce_w(&mut self) {
        let w = &self.input[self.w_start_idx..self.w_end_idx];

        match w
            .first()
            .map(|first_w_char| self.dictionary_to_create.take(first_w_char))
        {
            Some(Some(first_w_char)) => {
                if first_w_char < 256 {
                    self.write_bits(self.num_bits, U8_CODE.into());
                    self.write_bits(8, first_w_char.into());
                } else {
                    self.write_bits(self.num_bits, U16_CODE.into());
                    self.write_bits(16, first_w_char.into());
                }
                self.decrement_enlarge_in();
            }
            None | Some(None) => {
                self.write_bits(self.num_bits, *self.dictionary.get(w).unwrap());
            }
        }
        self.decrement_enlarge_in();
    }

    /// Append a bit to the bit buffer.
    ///
    /// Bits are shifted in from the least-significant side.
    /// Once `bits_per_char` bits have accumulated, the buffer is passed through
    /// `to_char` and flushed to the output.
    #[inline]
    pub fn write_bit(&mut self, bit: bool) {
        self.bit_buffer = (self.bit_buffer << 1) | u16::from(bit);
        self.bit_position += 1;

        if self.bit_position == self.bits_per_char {
            self.bit_position = 0;
            let output_char = (self.to_char)(self.bit_buffer);
            self.bit_buffer = 0;

            self.output.push(output_char);
        }
    }

    /// Write the low `n` bits of `value`, least-significant bit first.
    #[inline]
    pub fn write_bits(&mut self, n: u8, mut value: u32) {
        for _ in 0..n {
            self.write_bit(value & 1 == 1);
            value >>= 1;
        }
    }

    /// Count down towards the next code-width increase.
    ///
    /// When the counter reaches zero, it is reset to `1 << num_bits` and the
    /// code width grows by one bit.
    #[inline]
    pub fn decrement_enlarge_in(&mut self) {
        self.enlarge_in -= 1;
        if self.enlarge_in == 0 {
            self.enlarge_in = 1 << self.num_bits;
            self.num_bits += 1;
        }
    }

    /// Compress the `u16` at index `i` of the input. This represents a wide char.
    #[inline]
    pub fn write_u16(&mut self, i: usize) {
        let c = &self.input[i];

        let dictionary_len = self.dictionary.len();
        if let HashMapEntry::Vacant(entry) = self.dictionary.entry(std::slice::from_ref(c)) {
            entry.insert((dictionary_len + NUM_BASE_CODES).try_into().unwrap());
            self.dictionary_to_create.insert(*c);
        }

        // wc = w + c.
        let wc = &self.input[self.w_start_idx..self.w_end_idx + 1];

        let dictionary_len = self.dictionary.len();
        match self.dictionary.entry(wc) {
            HashMapEntry::Occupied(_entry) => {
                // w = wc.
                self.w_end_idx += 1;
            }
            HashMapEntry::Vacant(entry) => {
                // Add wc to the dictionary.
                entry.insert((dictionary_len + NUM_BASE_CODES).try_into().unwrap());

                // In the reference implementation, this call happens before wc is added to the dictionary.
                // However, produce_w only looks up w (never wc), so calling it after the insertion is equivalent.
                self.produce_w();

                // w = c.
                self.w_start_idx = i;
                self.w_end_idx = i + 1;
            }
        }
    }

    /// Finish the stream and get the final result.
    #[inline]
    pub fn finish(mut self) -> Vec<u16> {
        let w = &self.input[self.w_start_idx..self.w_end_idx];

        // Output the code for w.
        if !w.is_empty() {
            self.produce_w();
        }

        // Mark the end of the stream
        self.write_bits(self.num_bits, CLOSE_CODE.into());

        let str_len = self.output.len();
        // Flush the last char
        while self.output.len() == str_len {
            self.write_bit(false);
        }

        self.output
    }

    /// Perform the compression and return the result.
    pub fn compress(mut self) -> Vec<u16> {
        for i in 0..self.input.len() {
            self.write_u16(i);
        }
        self.finish()
    }
}

/// Compress a string into a [`Vec<u16>`].
///
/// The resulting [`Vec`] may contain invalid UTF16.
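///
/// # Example
///
/// A minimal usage sketch. It assumes this function is re-exported from the
/// crate root and that `&str` implements [`IntoWideIter`]; the exact numeric
/// output depends on the input.
///
/// ```ignore
/// let compressed: Vec<u16> = lz_str::compress("The quick brown fox");
/// assert!(!compressed.is_empty());
/// ```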
#[inline]
pub fn compress(data: impl IntoWideIter) -> Vec<u16> {
    let data: Vec<u16> = data.into_wide_iter().collect();
    compress_internal(&data, 16, std::convert::identity)
}

/// Compress a string into a valid [`String`].
///
/// This function converts the result back into a Rust [`String`] since it is guaranteed to be valid UTF16.
/// A trailing space is appended to the output.
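///
/// # Example
///
/// A minimal usage sketch, assuming a crate-root re-export:
///
/// ```ignore
/// let compressed: String = lz_str::compress_to_utf16("The quick brown fox");
/// // Each output char carries 15 payload bits; a trailing space is always appended.
/// assert!(compressed.ends_with(' '));
/// ```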
#[inline]
pub fn compress_to_utf16(data: impl IntoWideIter) -> String {
    let data: Vec<u16> = data.into_wide_iter().collect();
    let compressed = compress_internal(&data, 15, |n| n + 32);
    let mut compressed = String::from_utf16(&compressed)
        .expect("`compress_to_utf16` output was not valid unicode");
    compressed.push(' ');

    compressed
}

/// Compress a string into a [`String`], which can safely be used in a URI.
///
/// This function converts the result back into a Rust [`String`] since it is guaranteed to be valid unicode.
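///
/// # Example
///
/// A minimal usage sketch, assuming a crate-root re-export:
///
/// ```ignore
/// let compressed = lz_str::compress_to_encoded_uri_component("The quick brown fox");
/// // Every output character comes from the URI-safe alphabet in `URI_KEY`.
/// assert!(compressed.is_ascii());
/// ```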
#[inline]
pub fn compress_to_encoded_uri_component(data: impl IntoWideIter) -> String {
    let data: Vec<u16> = data.into_wide_iter().collect();
    let compressed = compress_internal(&data, 6, |n| u16::from(URI_KEY[usize::from(n)]));

    String::from_utf16(&compressed)
        .expect("`compress_to_encoded_uri_component` output was not valid unicode")
}

/// Compress a string into a [`String`], which is valid base64.
///
/// This function converts the result back into a Rust [`String`] since it is guaranteed to be valid unicode.
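///
/// # Example
///
/// A minimal usage sketch, assuming a crate-root re-export:
///
/// ```ignore
/// let compressed = lz_str::compress_to_base64("The quick brown fox");
/// // The output is padded with '=' so its length is a multiple of four.
/// assert_eq!(compressed.len() % 4, 0);
/// ```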
pub fn compress_to_base64(data: impl IntoWideIter) -> String {
    let data: Vec<u16> = data.into_wide_iter().collect();
    let mut compressed = compress_internal(&data, 6, |n| u16::from(BASE64_KEY[usize::from(n)]));

    let mod_4 = compressed.len() % 4;

    if mod_4 != 0 {
        // Pad the output with '=' up to a multiple of 4 characters.
        for _ in mod_4..4 {
            compressed.push(u16::from(b'='));
        }
    }

    String::from_utf16(&compressed).expect("`compress_to_base64` output was not valid unicode")
}

/// Compress a string into a [`Vec<u8>`].
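///
/// Each `u16` produced by [`compress`] is split into two big-endian bytes,
/// so the output length is always even.
///
/// # Example
///
/// A minimal usage sketch, assuming a crate-root re-export:
///
/// ```ignore
/// let bytes: Vec<u8> = lz_str::compress_to_uint8_array("The quick brown fox");
/// assert_eq!(bytes.len() % 2, 0);
/// ```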
pub fn compress_to_uint8_array(data: impl IntoWideIter) -> Vec<u8> {
    compress(data)
        .into_iter()
        .flat_map(|value| value.to_be_bytes())
        .collect()
}

/// The internal function for compressing data.
///
/// All other compression functions are built on top of this.
/// It generally should not be used directly.
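///
/// # Example
///
/// A sketch of what [`compress`] does internally. The `lz_str::` path is an
/// assumption; adjust it to wherever the crate actually exposes this function.
///
/// ```ignore
/// let data: Vec<u16> = "The quick brown fox".encode_utf16().collect();
/// let compressed = lz_str::compress_internal(&data, 16, std::convert::identity);
/// // Equivalent to `compress`, which uses 16 bits per char and the identity mapping.
/// ```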
#[inline]
pub fn compress_internal<F>(data: &[u16], bits_per_char: u8, to_char: F) -> Vec<u16>
where
    F: Fn(u16) -> u16,
{
    let ctx = CompressContext::new(data, bits_per_char, to_char);
    ctx.compress()
}
303}