Skip to main content

rar_stream/decompress/
lzss.rs

//! LZSS sliding window decoder.
//!
//! Implements the sliding-window (dictionary) back-reference decompression used by RAR.

use super::{DecompressError, Result};

/// Dictionary (window) size used for RAR 2.9 archives: 2 MiB.
pub const WINDOW_SIZE_29: usize = 2 * 1024 * 1024;

/// Dictionary (window) size used for RAR 5.0 archives: 64 MiB.
///
/// The format allows dictionaries of up to 4 GB; the window is capped at
/// 64 MiB here to bound memory use.
pub const WINDOW_SIZE_50: usize = 64 * 1024 * 1024;

/// LZSS sliding-window decoder state.
///
/// Decoded bytes are written into a fixed-size ring buffer (`window`). Its
/// length is a power of two, so positions wrap by AND-ing with `mask`.
pub struct LzssDecoder {
    /// Ring buffer holding the most recently decoded bytes.
    window: Vec<u8>,
    /// Decompressed bytes copied out of the window (or written by filters).
    output: Vec<u8>,
    /// `window.len() - 1`; AND-ing a position with this wraps it into the window.
    mask: usize,
    /// Index in `window` where the next decoded byte is stored.
    pos: usize,
    /// Number of bytes decoded into the window since creation or the last reset.
    total_written: u64,
    /// Absolute stream position up to which data has been flushed to output.
    flushed_pos: u64,
}

29impl LzssDecoder {
30    /// Create a new LZSS decoder with the specified window size.
31    pub fn new(window_size: usize) -> Self {
32        debug_assert!(window_size.is_power_of_two());
33        Self {
34            window: vec![0; window_size],
35            mask: window_size - 1,
36            pos: 0,
37            total_written: 0,
38            flushed_pos: 0,
39            output: Vec::new(),
40        }
41    }
42
43    /// Create decoder for RAR 2.9 format.
44    pub fn rar29() -> Self {
45        Self::new(WINDOW_SIZE_29)
46    }
47
48    /// Create decoder for RAR 5.0 format.
49    pub fn rar50() -> Self {
50        Self::new(WINDOW_SIZE_50)
51    }
52
53    /// Reset the decoder for reuse, avoiding reallocation.
54    /// Note: Window contents are NOT cleared - we only read after writing.
55    #[inline]
56    pub fn reset(&mut self) {
57        self.pos = 0;
58        self.total_written = 0;
59        self.output.clear();
60        // No need to clear window - we validate reads against total_written
61    }
62
63    /// Enable output accumulation for extracting files larger than window.
64    pub fn enable_output(&mut self, capacity: usize) {
65        self.output = Vec::with_capacity(capacity);
66    }
67
68    /// Write a literal byte to the window.
69    #[inline(always)]
70    pub fn write_literal(&mut self, byte: u8) {
71        // SAFETY: pos is always < window.len() due to mask
72        unsafe {
73            *self.window.get_unchecked_mut(self.pos) = byte;
74        }
75        self.pos = (self.pos + 1) & self.mask;
76        self.total_written += 1;
77    }
78
79    /// Flush data from window to output, up to the given absolute position.
80    /// This is called after filters have been applied.
81    pub fn flush_to_output(&mut self, up_to: u64) {
82        let current_output_len = self.output.len() as u64;
83        if up_to <= current_output_len {
84            return; // Already flushed
85        }
86
87        let flush_start = current_output_len as usize;
88        let flush_end = up_to as usize;
89        let flush_len = flush_end - flush_start;
90        let window_start = flush_start & self.mask;
91
92        // Reserve space upfront
93        self.output.reserve(flush_len);
94
95        // Check if we can do a contiguous copy (no wrap)
96        if window_start + flush_len <= self.window.len() {
97            // Fast path: contiguous copy
98            self.output
99                .extend_from_slice(&self.window[window_start..window_start + flush_len]);
100        } else {
101            // Slow path: wrapping copy in two parts
102            let first_part = self.window.len() - window_start;
103            self.output.extend_from_slice(&self.window[window_start..]);
104            self.output
105                .extend_from_slice(&self.window[..flush_len - first_part]);
106        }
107
108        self.flushed_pos = up_to;
109    }
110
111    /// Get mutable access to the window for filter execution.
112    pub fn window_mut(&mut self) -> &mut [u8] {
113        &mut self.window
114    }
115
116    /// Get the window mask (for filter positioning).
117    pub fn window_mask(&self) -> u32 {
118        self.mask as u32
119    }
120
121    /// Get how much has been flushed to output.
122    pub fn flushed_pos(&self) -> u64 {
123        self.flushed_pos
124    }
125
126    /// Write filtered data directly to output, bypassing the window.
127    /// This is used for VM filter output which should NOT modify the window.
128    pub fn write_filtered_to_output(&mut self, data: &[u8], position: u64) {
129        // Ensure we're at the right position - if not, we might have missed a flush
130        let current_len = self.output.len() as u64;
131        if current_len < position {
132            // Need to flush unfiltered data from window up to this position first
133            // This can happen if there's data between the last flush and the filter start
134            let window_start = current_len as usize;
135            let flush_len = (position - current_len) as usize;
136            self.output.reserve(flush_len);
137            for i in 0..flush_len {
138                let window_idx = (window_start + i) & self.mask;
139                self.output.push(self.window[window_idx]);
140            }
141        }
142        self.output.extend_from_slice(data);
143        self.flushed_pos = position + data.len() as u64;
144    }
145
146    /// Get read-only access to the window for filter execution.
147    pub fn window(&self) -> &[u8] {
148        &self.window
149    }
150
151    /// Copy bytes from a previous position in the window.
152    /// Optimized for both overlapping and non-overlapping copies.
153    #[inline(always)]
154    pub fn copy_match(&mut self, distance: u32, length: u32) -> Result<()> {
155        // Validate distance against bytes actually written, not window size
156        if distance == 0 || distance as u64 > self.total_written {
157            return self.copy_match_error(distance);
158        }
159
160        let len = length as usize;
161        let dist = distance as usize;
162
163        // Fast path: copy doesn't wrap around window boundary and doesn't overlap
164        if dist >= len && self.pos + len <= self.window.len() && self.pos >= dist {
165            // Non-overlapping, non-wrapping: use copy_within for speed
166            let src_start = self.pos - dist;
167            self.window
168                .copy_within(src_start..src_start + len, self.pos);
169            self.pos += len;
170            self.total_written += length as u64;
171            return Ok(());
172        }
173
174        // Medium path: overlapping but no wrapping - use copy_within in chunks
175        // Only worthwhile if we can copy at least 8 bytes at a time
176        if self.pos + len <= self.window.len() && self.pos >= dist && dist >= 8 {
177            let src_start = self.pos - dist;
178            let mut copied = 0;
179            while copied < len {
180                let chunk = (len - copied).min(dist);
181                self.window
182                    .copy_within(src_start..src_start + chunk, self.pos + copied);
183                copied += chunk;
184            }
185            self.pos += len;
186            self.total_written += length as u64;
187            return Ok(());
188        }
189
190        // Slow path: handle wrapping or very short distance copies byte-by-byte
191        // Use unchecked access since we've already validated distance
192        let src_pos = (self.pos.wrapping_sub(dist)) & self.mask;
193        let window_ptr = self.window.as_mut_ptr();
194
195        for i in 0..len {
196            let src_idx = (src_pos + i) & self.mask;
197            let dest_idx = (self.pos + i) & self.mask;
198            // SAFETY: src_idx and dest_idx are always < window.len() due to mask
199            unsafe {
200                let byte = *window_ptr.add(src_idx);
201                *window_ptr.add(dest_idx) = byte;
202            }
203        }
204        self.pos = (self.pos + len) & self.mask;
205
206        self.total_written += length as u64;
207        Ok(())
208    }
209
210    /// Cold path for error handling - keeps hot path small
211    #[cold]
212    #[inline(never)]
213    fn copy_match_error(&self, distance: u32) -> Result<()> {
214        Err(DecompressError::InvalidBackReference {
215            offset: distance,
216            position: self.pos as u32,
217        })
218    }
219
220    /// Get the current window position.
221    pub fn position(&self) -> usize {
222        self.pos
223    }
224
225    /// Get total bytes written.
226    pub fn total_written(&self) -> u64 {
227        self.total_written
228    }
229
230    /// Get a byte at the specified offset from current position (going back).
231    /// Call this after decompression to get the output.
232    pub fn get_output(&self, start: u64, len: usize) -> Vec<u8> {
233        // If we have accumulated output, use it
234        if !self.output.is_empty() {
235            let start = start as usize;
236            let end = (start + len).min(self.output.len());
237            return self.output[start..end].to_vec();
238        }
239
240        let mut output = Vec::with_capacity(len);
241        let window_len = self.window.len();
242
243        // Calculate start position in window
244        let start_pos = if self.total_written <= window_len as u64 {
245            start as usize
246        } else {
247            // Window has wrapped
248            let _written_in_window = self.total_written as usize % window_len;
249            let offset = (self.total_written - start) as usize;
250            if offset > window_len {
251                return output; // Data no longer in window
252            }
253            (self.pos.wrapping_sub(offset)) & self.mask
254        };
255
256        for i in 0..len {
257            let idx = (start_pos + i) & self.mask;
258            output.push(self.window[idx]);
259        }
260
261        output
262    }
263
264    /// Take ownership of the accumulated output buffer.
265    /// More efficient than get_output() when you need all output.
266    pub fn take_output(&mut self) -> Vec<u8> {
267        std::mem::take(&mut self.output)
268    }
269
270    /// Get read access to the accumulated output buffer.
271    pub fn output(&self) -> &[u8] {
272        &self.output
273    }
274
275    /// Get mutable access to the output buffer for filter execution.
276    pub fn output_mut(&mut self) -> &mut [u8] {
277        &mut self.output
278    }
279
280    /// Get the most recent `len` bytes from the window.
281    pub fn get_recent(&self, len: usize) -> Vec<u8> {
282        let actual_len = len.min(self.total_written as usize);
283        let mut output = Vec::with_capacity(actual_len);
284
285        let start = (self.pos.wrapping_sub(actual_len)) & self.mask;
286        for i in 0..actual_len {
287            let idx = (start + i) & self.mask;
288            output.push(self.window[idx]);
289        }
290
291        output
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a 256-byte-window decoder pre-loaded with `bytes` as literals.
    fn decoder_with(bytes: &[u8]) -> LzssDecoder {
        let mut decoder = LzssDecoder::new(256);
        for &byte in bytes {
            decoder.write_literal(byte);
        }
        decoder
    }

    #[test]
    fn test_literal_output() {
        let decoder = decoder_with(b"Hello");

        assert_eq!(decoder.total_written(), 5);
        assert_eq!(decoder.get_recent(5), b"Hello");
    }

    #[test]
    fn test_copy_match() {
        // "abc" plus a 6-byte match at distance 3 repeats the pattern twice.
        let mut decoder = decoder_with(b"abc");
        decoder.copy_match(3, 6).unwrap();

        assert_eq!(decoder.total_written(), 9);
        assert_eq!(decoder.get_recent(9), b"abcabcabc");
    }

    #[test]
    fn test_overlapping_copy() {
        // Distance 1 replicates the single preceding byte.
        let mut decoder = decoder_with(b"a");
        decoder.copy_match(1, 5).unwrap();

        assert_eq!(decoder.get_recent(6), b"aaaaaa");
    }

    #[test]
    fn test_invalid_distance() {
        // A zero distance never refers to a previous byte.
        let mut decoder = decoder_with(b"a");

        assert!(decoder.copy_match(0, 1).is_err());
    }
}