prs_rs/impls/comp/
comp_dict.rs

1use crate::prelude::{Allocator, Box, Global, Layout};
2use core::ptr::{write, NonNull};
3use core::slice;
4use core::{mem::size_of, ptr::read_unaligned};
5
6pub(crate) type MaxOffset = u32;
7type FreqCountType = u32;
8const MAX_U16: usize = 65536;
9const ALLOC_ALIGNMENT: usize = 64; // x86 cache line
10
11// Round up to next multiple of ALLOC_ALIGNMENT
12const DICTIONARY_PADDING: usize =
13    (ALLOC_ALIGNMENT - (size_of::<[CompDictEntry; MAX_U16]>() % ALLOC_ALIGNMENT)) % ALLOC_ALIGNMENT;
14
15const ENTRY_SECTION_LEN: usize = size_of::<[CompDictEntry; MAX_U16]>();
16
/// Dictionary for PRS compression.
///
/// This dictionary stores every location at which a given 2-byte sequence can be found,
/// with the 2-byte combination being the dictionary 'key'. The values (locations) are
/// stored inside a shared buffer, where [`CompDictEntry`] dictates the file offsets of the locations
/// which start with this 2 byte combination. The items are stored in ascending order.
///
/// When the compressor is looking for the longest match at a given address, it reads the 2 bytes at the
/// address and uses them as the key for [`CompDict::get_item`]. The offsets inside the returned entry
/// are then used to greatly speed up the search.
///
/// `L` provides the long-lived allocation (`buf`, released on drop); `S` provides the
/// temporary scratch buffers that [`CompDict::init`] allocates and frees while it runs.
pub struct CompDict<L: Allocator + Copy = Global, S: Allocator + Copy = Global> {
    /// Our memory allocation is here.
    /// Layout:
    /// - [CompDictEntry; MAX_U16] (dict), constant size
    /// - `DICTIONARY_PADDING` bytes, so the next section starts on an `ALLOC_ALIGNMENT` boundary
    /// - [MaxOffset; data_len] (offsets), variable size. This buffer stores offsets of all items of 2 byte combinations.
    buf: NonNull<u8>,
    /// Total size in bytes of the `buf` allocation (entries + padding + offsets). It is kept so
    /// that the same [`Layout`] can be rebuilt when the buffer is deallocated.
    alloc_length: usize,
    /// Allocator that owns `buf`.
    long_lived_allocator: L,
    /// Allocator for temporary buffers used during [`CompDict::init`].
    short_lived_allocator: S,
}
37
38impl<L: Allocator + Copy, S: Allocator + Copy> Drop for CompDict<L, S> {
39    fn drop(&mut self) {
40        unsafe {
41            // dealloc buffer and box
42            let layout = Layout::from_size_align_unchecked(self.alloc_length, ALLOC_ALIGNMENT);
43            self.long_lived_allocator.deallocate(self.buf, layout);
44        }
45    }
46}
47
/// An entry in [Compression Dictionary][`CompDict`].
///
/// Each entry covers one 2-byte key and holds three pointers into the offsets section of
/// [`CompDict::buf`]: a lower read cursor ([`CompDictEntry::last_read_item`]), an upper read
/// cursor ([`CompDictEntry::last_read_item_max`]), and a pointer one past the key's last
/// stored offset ([`CompDictEntry::last_item`]).
///
/// The cursors are advanced as items are sequentially read, i.e. when [`CompDict::get_item`]
/// is called. After a call, `last_read_item` points at the first item whose offset is
/// greater than or equal to the `min_ofs` parameter of that call (or at `last_item` if none is).
///
/// When compressing, this means we can find next matching offset in LZ77 search window
/// in (effectively) O(1) time.
#[derive(Clone)]
#[allow(rustdoc::private_intra_doc_links)]
pub struct CompDictEntry {
    /// Lower cursor: first offset `>= min_ofs` from the previous [`CompDict::get_item`] call
    /// (or `last_item`).
    last_read_item: *mut MaxOffset,
    /// Upper cursor: first offset `> max_ofs` from the previous [`CompDict::get_item`] call
    /// (or `last_item`).
    last_read_item_max: *mut MaxOffset,
    /// One past the last item of this key within the offsets section of [`CompDict::buf`].
    last_item: *mut MaxOffset,
}
69
70impl<L: Allocator + Copy, S: Allocator + Copy> CompDict<L, S> {
71    /// Create a new [`CompDict`] without initializing it.
72    ///
73    /// # Parameters
74    ///
75    /// - `data_len`: The length of the data that will be used to initialize the dictionary.
76    /// - `long_lived_allocator`: The allocator to use for long-lived memory allocation.
77    /// - `short_lived_allocator`: The allocator to use for short-lived memory allocation.
78    #[inline(always)]
79    pub fn new_in(data_len: usize, long_lived_allocator: L, short_lived_allocator: S) -> Self {
80        unsafe {
81            // constant
82            let offset_section_len = size_of::<MaxOffset>() * data_len;
83            let alloc_size = ENTRY_SECTION_LEN + DICTIONARY_PADDING + offset_section_len;
84
85            let layout = Layout::from_size_align_unchecked(alloc_size, ALLOC_ALIGNMENT);
86            let buf = long_lived_allocator.allocate(layout).unwrap();
87
88            CompDict {
89                buf: NonNull::new_unchecked(buf.as_ptr() as *mut u8),
90                alloc_length: alloc_size,
91                long_lived_allocator,
92                short_lived_allocator,
93            }
94        }
95    }
96
97    /// Initialize the [`CompDict`] with the given data and offset.
98    ///
99    /// # Parameters
100    ///
101    /// - `data`: The data to create the dictionary from.
102    /// - `offset`: The offset to add to the offsets in the dictionary.
103    ///
104    /// # Safety
105    ///
106    /// This function is unsafe as it operates on raw pointers and assumes that
107    /// the `CompDict` has been properly allocated with enough space for `data`.
108    #[inline(always)]
109    pub unsafe fn init(&mut self, data: &[u8], offset: usize) {
110        let dict_entry_ptr = self.buf.as_ptr() as *mut CompDictEntry;
111        let max_ofs_ptr =
112            self.buf
113                .as_ptr()
114                .add(ENTRY_SECTION_LEN + DICTIONARY_PADDING) as *mut MaxOffset;
115
116        // We will use this later to populate the dictionary.
117        // The `dict_insert_entry_ptrs` is a buffer which stores the pointer to the current location
118        // where we need to insert the offset for a given 2 byte sequence (hence length MAX_U16).
119        let alloc = self
120            .short_lived_allocator
121            .allocate(Layout::new::<[*mut MaxOffset; MAX_U16]>())
122            .unwrap()
123            .as_ptr() as *mut [*mut MaxOffset; MAX_U16];
124
125        let mut dict_insert_entry_ptrs =
126            Box::<[*mut MaxOffset; MAX_U16], S>::from_raw_in(alloc, self.short_lived_allocator);
127
128        // dict_insert_entry_ptrs is now a Box, so it will be deallocated when it goes out of scope.
129
130        // Initialize all CompDictEntries
131        let freq_table = self.create_frequency_table(data);
132        let mut cur_ofs_addr = max_ofs_ptr;
133        let mut cur_dict_entry = dict_entry_ptr;
134        let mut cur_freq_tbl_entry = freq_table.as_ptr();
135        let mut cur_ofs_insert_ptr = dict_insert_entry_ptrs.as_mut_ptr();
136        let max_dict_entry = cur_dict_entry.add(MAX_U16);
137
138        // This loop initializes each CompDictEntry (ies) based on the frequency table.
139        // It sets up the pointers for where the offsets for each 2-byte sequence will be stored.
140        // This also populates `dict_insert_entry_ptrs` (via `cur_ofs_insert_ptr`) setting each
141        // entry to the value of `cur_ofs_addr` (the current offset address).
142        while cur_dict_entry < max_dict_entry {
143            let num_items = *cur_freq_tbl_entry;
144            *cur_ofs_insert_ptr = cur_ofs_addr;
145
146            write(
147                cur_dict_entry,
148                CompDictEntry {
149                    last_read_item: cur_ofs_addr,
150                    last_read_item_max: cur_ofs_addr,
151                    last_item: cur_ofs_addr.add(num_items as usize),
152                },
153            );
154
155            cur_ofs_addr = cur_ofs_addr.add(num_items as usize);
156            cur_freq_tbl_entry = cur_freq_tbl_entry.add(1);
157            cur_dict_entry = cur_dict_entry.add(1);
158            cur_ofs_insert_ptr = cur_ofs_insert_ptr.add(1);
159        }
160
161        // The rest of the function is dedicated to actually populating the dictionary with offsets.
162        // Here we do the following:
163        // - Read Each 2 Byte Sequence
164        // - Use 2 Byte Sequence as Key
165        // - Gets insert location via `dict_insert_entry_ptrs` (**insert_entry_ptr)
166        // - Advance insert location for given key (*insert_entry_ptr)
167
168        // Iterate over the data, and add each 2-byte sequence to the dictionary.
169        #[cfg(not(target_pointer_width = "64"))]
170        {
171            let data_ptr_start = data.as_ptr();
172            let mut data_ptr = data.as_ptr();
173            let data_ptr_max = data.as_ptr().add(data.len().saturating_sub(1));
174            debug_assert!(data.len() as MaxOffset <= MaxOffset::MAX);
175
176            while data_ptr < data_ptr_max {
177                let key = read_unaligned(data_ptr as *const u16);
178                let insert_entry_ptr = dict_insert_entry_ptrs.as_mut_ptr().add(key as usize);
179
180                // Insert the offset into the dictionary
181                **insert_entry_ptr = (data_ptr.sub(data_ptr_start as usize) as MaxOffset)
182                    .wrapping_add(offset as MaxOffset);
183
184                *insert_entry_ptr = (*insert_entry_ptr).add(1); // advance to next entry
185
186                data_ptr = data_ptr.add(1);
187            }
188        }
189
190        #[cfg(target_pointer_width = "64")]
191        {
192            let mut data_ofs = 0;
193            let data_len = data.len();
194
195            while data_ofs < data_len.saturating_sub(16) {
196                // Doing a lot of the `data.as_ptr().add()` is ugly, but it makes LLVM do a better job.
197                let chunk = read_unaligned(data.as_ptr().add(data_ofs) as *const u64);
198
199                // Process every 16-bit sequence starting at each byte within the 64-bit chunk
200                for shift in 0..7 {
201                    // Successfully unrolled by LLVM
202                    let key = ((chunk >> (shift * 8)) & 0xFFFF) as u16;
203                    let insert_entry_ptr = dict_insert_entry_ptrs.as_mut_ptr().add(key as usize);
204
205                    **insert_entry_ptr = ((data.as_ptr().add(data_ofs + shift) as usize
206                        - data.as_ptr() as usize)
207                        as MaxOffset)
208                        .wrapping_add(offset as MaxOffset);
209
210                    *insert_entry_ptr = (*insert_entry_ptr).add(1);
211                }
212
213                // Handle the 16-bit number that spans the boundary between this chunk and the next
214                // Note: LLVM puts next_chunk in register and reuses it for next loop iteration (under x64), nothing special to do here.
215                let next_chunk = read_unaligned(data.as_ptr().add(data_ofs + 8) as *const u64);
216                let next_chunk_byte = (next_chunk & 0xFF) << 8;
217                let key = ((chunk >> 56) | next_chunk_byte) as u16;
218                let insert_entry_ptr = dict_insert_entry_ptrs.as_mut_ptr().add(key as usize);
219
220                **insert_entry_ptr = ((data.as_ptr().add(data_ofs + 7) as usize
221                    - data.as_ptr() as usize) as MaxOffset)
222                    .wrapping_add(offset as MaxOffset);
223                *insert_entry_ptr = (*insert_entry_ptr).add(1);
224
225                data_ofs += 8;
226            }
227
228            // Process any remaining bytes in the data.
229            while data_ofs < data_len.saturating_sub(1) {
230                let key = read_unaligned(data.as_ptr().add(data_ofs) as *const u16);
231                let insert_entry_ptr = dict_insert_entry_ptrs.as_mut_ptr().add(key as usize);
232
233                **insert_entry_ptr = ((data.as_ptr().add(data_ofs) as usize
234                    - data.as_ptr() as usize) as MaxOffset)
235                    .wrapping_add(offset as MaxOffset);
236                *insert_entry_ptr = (*insert_entry_ptr).add(1);
237                data_ofs += 1;
238            }
239        }
240    }
241
242    /// Creates a frequency table for the given data.
243    ///
244    /// # Parameters
245    /// - `data`: The data to create the frequency table from.
246    pub(crate) unsafe fn create_frequency_table(&self, data: &[u8]) -> Box<[FreqCountType], S> {
247        // This actually has no overhead.
248
249        let result =
250            Box::<[FreqCountType], S>::new_zeroed_slice_in(MAX_U16, self.short_lived_allocator);
251        let mut result = result.assume_init();
252
253        #[cfg(not(target_pointer_width = "64"))]
254        {
255            // Iterate over the data, and add each 2-byte sequence to the dictionary.
256            let data_ptr = data.as_ptr();
257            let data_ofs_max = data.len().saturating_sub(1);
258            let mut data_ofs = 0;
259            while data_ofs < data_ofs_max {
260                // LLVM successfully unrolls this
261                let index = read_unaligned(data_ptr.add(data_ofs) as *const u16);
262                result[index as usize] += 1;
263                data_ofs += 1;
264            }
265
266            result
267        }
268
269        #[cfg(target_pointer_width = "64")]
270        {
271            let data_len = data.len();
272            let mut data_ofs = 0;
273
274            while data_ofs < data_len.saturating_sub(16) {
275                let chunk = read_unaligned(data.as_ptr().add(data_ofs) as *const u64);
276
277                // Process every 16-bit sequence starting at each byte within the 64-bit chunk
278                for shift in 0..7 {
279                    let index = ((chunk >> (shift * 8)) & 0xFFFF) as u16;
280                    result[index as usize] += 1;
281                }
282
283                // Handle the 16-bit number that spans the boundary between this chunk and the next
284                // Note: LLVM puts next_chunk in register and reuses it for next loop iteration (under x64), nothing special to do here.
285                let next_chunk = read_unaligned(data.as_ptr().add(data_ofs + 8) as *const u64);
286                let key = ((chunk >> 56) | ((next_chunk & 0xFF) << 8)) as u16;
287                result[key as usize] += 1;
288
289                data_ofs += 8;
290            }
291
292            // Process any remaining bytes in the data.
293            while data_ofs < data_len.saturating_sub(1) {
294                let index = read_unaligned(data.as_ptr().add(data_ofs) as *const u16);
295                result[index as usize] += 1;
296                data_ofs += 1;
297            }
298
299            result
300        }
301    }
302
303    /// Returns a slice of offsets for the given key which are greater than or equal to `min_ofs`
304    /// and less than or equal to `max_ofs`.
305    ///
306    /// # Parameters
307    ///
308    /// - `key`: The key to search for.
309    /// - `min_ofs`: The minimum offset returned in the slice.
310    /// - `max_ofs`: The maximum offset returned in the slice.
311    ///
312    /// # Safety
313    ///
314    /// This function is unsafe as it operates on raw pointers.
315    #[inline(always)]
316    pub unsafe fn get_item(&mut self, key: u16, min_ofs: usize, max_ofs: usize) -> &[MaxOffset] {
317        // Ensure that the key is within the bounds of the dictionary.
318        debug_assert!(key as usize <= MAX_U16, "Key is out of range!");
319
320        let entry = &mut self.get_dict_mut()[key as usize];
321        let mut cur_last_read_item = entry.last_read_item;
322
323        // Advance the 'last_read_item' pointer to the first offset greater than or equal to min_ofs
324        while cur_last_read_item < entry.last_item && *cur_last_read_item < min_ofs as MaxOffset {
325            cur_last_read_item = cur_last_read_item.add(1);
326        }
327        entry.last_read_item = cur_last_read_item;
328
329        // Find the end of the range - the first offset greater than max_ofs
330        // TODO: Try last read max item.
331        let mut end = entry.last_read_item_max;
332        while end < entry.last_item && *end <= max_ofs as MaxOffset {
333            end = end.add(1);
334        }
335        entry.last_read_item_max = end;
336
337        // Create a slice from the updated range
338        slice::from_raw_parts(
339            cur_last_read_item,
340            end.offset_from(cur_last_read_item) as usize,
341        )
342    }
343
344    /// Retrieves the dictionary entries section of this [`CompDict`].
345    pub fn get_dict_mut(&mut self) -> &mut [CompDictEntry; MAX_U16] {
346        unsafe {
347            let first_item = self.buf.as_ptr() as *mut CompDictEntry;
348            &mut *(first_item as *mut [CompDictEntry; MAX_U16])
349        }
350    }
351}
352
353impl CompDict {
354    /// Create a new [`CompDict`] without initializing it.
355    ///
356    /// # Parameters
357    ///
358    /// - `data_len`: The length of the data that will be used to initialize the dictionary.
359    pub fn new(data_len: usize) -> Self {
360        Self::new_in(data_len, Global, Global)
361    }
362}
363
364impl CompDictEntry {
365    /// Returns a slice of offsets between `last_read_item` and `last_item`.
366    ///
367    /// # Safety
368    ///
369    /// This function is unsafe as it operates on raw pointers.
370    /// The caller must ensure that `last_read_item` and `last_item` are valid.
371    #[cfg(test)]
372    pub unsafe fn get_items(&mut self) -> &[MaxOffset] {
373        // Calculate the length of the slice by finding the distance between the pointers.
374        let length = self.last_item.offset_from(self.last_read_item) as usize;
375
376        // Create and return a slice from the raw pointers.
377        slice::from_raw_parts(self.last_read_item, length)
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn can_create_dict() {
        unsafe {
            let data = &[0x41, 0x42, 0x43, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41];
            let mut comp_dict = CompDict::new(data.len());
            comp_dict.init(data, 0);

            // Assert that the items were correctly inserted.
            assert_eq!(
                comp_dict.get_dict_mut()[0x4241_u16.to_le() as usize].get_items(),
                &[0]
            );
            assert_eq!(
                comp_dict.get_dict_mut()[0x4342_u16.to_le() as usize].get_items(),
                &[1]
            );

            // Ensure we can get a slice.
            let result = comp_dict.get_item(0x4141, 3, 4);
            assert_eq!(&[3, 4], result);

            // Access the next in sequence, and ensure it was correctly advanced.
            let result = comp_dict.get_item(0x4141, 4, 5);
            assert_eq!(&[4, 5], result);
            assert_eq!(*comp_dict.get_dict_mut()[0x4141].last_read_item, 4);

            // Access beyond end of sequence
            let result = comp_dict.get_item(0x4141, 5, 99);
            assert_eq!(&[5, 6, 7], result);
        }
    }

    #[test]
    fn can_create_dict_with_offset() {
        unsafe {
            let data = &[0x41, 0x42, 0x43, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41];
            let offset = 1000;
            let mut comp_dict = CompDict::new(data.len());
            comp_dict.init(data, offset);

            // Assert that the items were correctly inserted with the offset.
            assert_eq!(
                comp_dict.get_dict_mut()[0x4241_u16.to_le() as usize].get_items(),
                &[1000]
            );
            assert_eq!(
                comp_dict.get_dict_mut()[0x4342_u16.to_le() as usize].get_items(),
                &[1001]
            );

            // Ensure we can get a slice with offsets.
            let result = comp_dict.get_item(0x4141, 1003, 1004);
            assert_eq!(&[1003, 1004], result);

            // Access the next in sequence, and ensure it was correctly advanced.
            let result = comp_dict.get_item(0x4141, 1004, 1005);
            assert_eq!(&[1004, 1005], result);
            assert_eq!(*comp_dict.get_dict_mut()[0x4141].last_read_item, 1004);

            // Access beyond end of sequence
            let result = comp_dict.get_item(0x4141, 1005, 1099);
            assert_eq!(&[1005, 1006, 1007], result);
        }
    }

    /// Inputs longer than 16 bytes take the 8-byte chunked loops on 64-bit targets, which the
    /// short inputs above never reach. Compare against a naive per-position scan, using the
    /// native-endian key that `read_unaligned::<u16>` would produce.
    #[test]
    fn long_input_matches_naive_scan() {
        unsafe {
            // Small alphabet so that keys repeat; 77 bytes leaves a tail that is not a
            // multiple of 8.
            let data: [u8; 77] = core::array::from_fn(|i| ((i * 31 + 7) % 5) as u8);
            let key_at = |j: usize| u16::from_ne_bytes([data[j], data[j + 1]]);
            let windows = data.len() - 1;

            let mut comp_dict = CompDict::new(data.len());
            comp_dict.init(&data, 0);

            // Every key present in the data maps to exactly its positions, in ascending order.
            for i in 0..windows {
                let key = key_at(i);
                let items = comp_dict.get_dict_mut()[key as usize].get_items();
                let expected = (0..windows)
                    .filter(|&j| key_at(j) == key)
                    .map(|j| j as MaxOffset);
                assert!(items.iter().copied().eq(expected), "mismatch for key {key:#06x}");
            }

            // A sliding window with non-decreasing bounds returns exactly the in-range offsets.
            let key = key_at(0);
            for start in 0..data.len() {
                let end = start + 10;
                let got = comp_dict.get_item(key, start, end);
                let expected = (0..windows)
                    .filter(|&j| key_at(j) == key && (start..=end).contains(&j))
                    .map(|j| j as MaxOffset);
                assert!(got.iter().copied().eq(expected), "mismatch for window {start}");
            }
        }
    }
}