// rar_stream/decompress/lzss.rs
1//! LZSS sliding window decoder.
2//!
3//! Implements the dictionary-based decompression used in RAR.
4
5use super::{DecompressError, Result};
6
/// Window size for RAR29 (2MB).
pub const WINDOW_SIZE_29: usize = 0x200000;

/// Window size for RAR50 (format permits up to 4GB; capped at 64MB here to bound memory).
pub const WINDOW_SIZE_50: usize = 0x4000000;

/// LZSS sliding window decoder.
///
/// Maintains a power-of-two circular dictionary (`window`) that literals and
/// back-reference copies are written into, plus an optional linear `output`
/// buffer that accumulates the final decompressed stream for files larger
/// than the window.
pub struct LzssDecoder {
    /// Sliding window buffer; length is always a power of two.
    window: Vec<u8>,
    /// Window mask (`window.len() - 1`) used for cheap wrap-around indexing.
    mask: usize,
    /// Current write position in window; invariant: always `< window.len()`.
    pos: usize,
    /// Total bytes ever written to the window (monotonic, never wraps).
    total_written: u64,
    /// Absolute stream position up to which data has been flushed to `output`.
    flushed_pos: u64,
    /// Output buffer accumulating the final result (empty unless enabled/used).
    output: Vec<u8>,
}
28
29impl LzssDecoder {
30    /// Create a new LZSS decoder with the specified window size.
31    pub fn new(window_size: usize) -> Self {
32        debug_assert!(window_size.is_power_of_two());
33        Self {
34            window: vec![0; window_size],
35            mask: window_size - 1,
36            pos: 0,
37            total_written: 0,
38            flushed_pos: 0,
39            output: Vec::new(),
40        }
41    }
42
43    /// Create decoder for RAR 2.9 format.
44    pub fn rar29() -> Self {
45        Self::new(WINDOW_SIZE_29)
46    }
47
48    /// Create decoder for RAR 5.0 format.
49    pub fn rar50() -> Self {
50        Self::new(WINDOW_SIZE_50)
51    }
52
53    /// Reset the decoder for reuse, avoiding reallocation.
54    /// Note: Window contents are NOT cleared - we only read after writing.
55    #[inline]
56    pub fn reset(&mut self) {
57        self.pos = 0;
58        self.total_written = 0;
59        self.output.clear();
60        // No need to clear window - we validate reads against total_written
61    }
62
63    /// Enable output accumulation for extracting files larger than window.
64    #[inline]
65    pub fn enable_output(&mut self, capacity: usize) {
66        self.output = Vec::with_capacity(capacity);
67    }
68
69    /// Write a literal byte to the window.
70    #[inline(always)]
71    pub fn write_literal(&mut self, byte: u8) {
72        // SAFETY: pos is always < window.len() due to mask
73        unsafe {
74            *self.window.get_unchecked_mut(self.pos) = byte;
75        }
76        self.pos = (self.pos + 1) & self.mask;
77        self.total_written += 1;
78    }
79
80    /// Flush data from window to output, up to the given absolute position.
81    /// This is called after filters have been applied.
82    pub fn flush_to_output(&mut self, up_to: u64) {
83        let current_output_len = self.output.len() as u64;
84        if up_to <= current_output_len {
85            return; // Already flushed
86        }
87
88        let flush_start = current_output_len as usize;
89        let flush_end = up_to as usize;
90        let flush_len = flush_end - flush_start;
91        let window_start = flush_start & self.mask;
92
93        // Reserve space upfront
94        self.output.reserve(flush_len);
95
96        // Check if we can do a contiguous copy (no wrap)
97        if window_start + flush_len <= self.window.len() {
98            // Fast path: contiguous copy
99            self.output
100                .extend_from_slice(&self.window[window_start..window_start + flush_len]);
101        } else {
102            // Slow path: wrapping copy in two parts
103            let first_part = self.window.len() - window_start;
104            self.output.extend_from_slice(&self.window[window_start..]);
105            self.output
106                .extend_from_slice(&self.window[..flush_len - first_part]);
107        }
108
109        self.flushed_pos = up_to;
110    }
111
112    /// Get mutable access to the window for filter execution.
113    pub fn window_mut(&mut self) -> &mut [u8] {
114        &mut self.window
115    }
116
117    /// Get the window mask (for filter positioning).
118    pub fn window_mask(&self) -> u32 {
119        self.mask as u32
120    }
121
122    /// Get how much has been flushed to output.
123    pub fn flushed_pos(&self) -> u64 {
124        self.flushed_pos
125    }
126
127    /// Write filtered data directly to output, bypassing the window.
128    /// This is used for VM filter output which should NOT modify the window.
129    pub fn write_filtered_to_output(&mut self, data: &[u8], position: u64) {
130        // Ensure we're at the right position - if not, we might have missed a flush
131        let current_len = self.output.len() as u64;
132        if current_len < position {
133            // Flush unfiltered data from window up to this position
134            let window_start = current_len as usize;
135            let flush_len = (position - current_len) as usize;
136            self.output.reserve(flush_len);
137            let ws = window_start & self.mask;
138            if ws + flush_len <= self.window.len() {
139                self.output
140                    .extend_from_slice(&self.window[ws..ws + flush_len]);
141            } else {
142                let first = self.window.len() - ws;
143                self.output.extend_from_slice(&self.window[ws..]);
144                self.output
145                    .extend_from_slice(&self.window[..flush_len - first]);
146            }
147        }
148        self.output.extend_from_slice(data);
149        self.flushed_pos = position + data.len() as u64;
150    }
151
152    /// Get read-only access to the window for filter execution.
153    pub fn window(&self) -> &[u8] {
154        &self.window
155    }
156
157    /// Copy bytes from a previous position in the window.
158    /// Optimized for both overlapping and non-overlapping copies.
159    #[inline(always)]
160    pub fn copy_match(&mut self, distance: u32, length: u32) -> Result<()> {
161        // Validate distance against bytes actually written, not window size
162        if distance == 0 || distance as u64 > self.total_written {
163            return self.copy_match_error(distance);
164        }
165
166        let len = length as usize;
167        let dist = distance as usize;
168
169        // Fast path: copy doesn't wrap around window boundary and doesn't overlap
170        if dist >= len && self.pos + len <= self.window.len() && self.pos >= dist {
171            // Non-overlapping, non-wrapping: use copy_within for speed
172            let src_start = self.pos - dist;
173            self.window
174                .copy_within(src_start..src_start + len, self.pos);
175            self.pos += len;
176            self.total_written += length as u64;
177            return Ok(());
178        }
179
180        // Medium path: overlapping but no wrapping - use copy_within in chunks
181        // Only worthwhile if we can copy at least 8 bytes at a time
182        if self.pos + len <= self.window.len() && self.pos >= dist && dist >= 8 {
183            let src_start = self.pos - dist;
184            let mut copied = 0;
185            while copied < len {
186                let chunk = (len - copied).min(dist);
187                self.window
188                    .copy_within(src_start..src_start + chunk, self.pos + copied);
189                copied += chunk;
190            }
191            self.pos += len;
192            self.total_written += length as u64;
193            return Ok(());
194        }
195
196        // Slow path: handle wrapping or very short distance copies
197        // Chunk into non-wrapping segments where possible
198        let src_pos = (self.pos.wrapping_sub(dist)) & self.mask;
199        let window_len = self.window.len();
200
201        let mut remaining = len;
202        let mut si = src_pos;
203        let mut di = self.pos;
204
205        while remaining > 0 {
206            // How many bytes until src or dest wraps?
207            let src_avail = window_len - si;
208            let dst_avail = window_len - di;
209            let chunk = remaining.min(src_avail).min(dst_avail);
210
211            // For overlapping copies with very short distance, fall back to byte-by-byte
212            if si < di && di < si + chunk || di < si && si < di + chunk {
213                // Overlapping within this chunk — byte-by-byte
214                let window_ptr = self.window.as_mut_ptr();
215                for _ in 0..chunk {
216                    // SAFETY: si and di are always < window.len() due to mask
217                    unsafe {
218                        *window_ptr.add(di) = *window_ptr.add(si);
219                    }
220                    si = (si + 1) & self.mask;
221                    di = (di + 1) & self.mask;
222                }
223            } else {
224                self.window.copy_within(si..si + chunk, di);
225                si = (si + chunk) & self.mask;
226                di = (di + chunk) & self.mask;
227            }
228            remaining -= chunk;
229        }
230        self.pos = (self.pos + len) & self.mask;
231
232        self.total_written += length as u64;
233        Ok(())
234    }
235
236    /// Cold path for error handling - keeps hot path small
237    #[cold]
238    #[inline(never)]
239    fn copy_match_error(&self, distance: u32) -> Result<()> {
240        Err(DecompressError::InvalidBackReference {
241            offset: distance,
242            position: self.pos as u32,
243        })
244    }
245
246    /// Get the current window position.
247    pub fn position(&self) -> usize {
248        self.pos
249    }
250
251    /// Get total bytes written.
252    pub fn total_written(&self) -> u64 {
253        self.total_written
254    }
255
256    /// Get a byte at the specified offset from current position (going back).
257    /// Call this after decompression to get the output.
258    pub fn get_output(&self, start: u64, len: usize) -> Vec<u8> {
259        // If we have accumulated output, use it
260        if !self.output.is_empty() {
261            let start = start as usize;
262            let end = (start + len).min(self.output.len());
263            return self.output[start..end].to_vec();
264        }
265
266        let window_len = self.window.len();
267
268        // Calculate start position in window
269        let start_pos = if self.total_written <= window_len as u64 {
270            start as usize
271        } else {
272            // Window has wrapped
273            let offset = (self.total_written - start) as usize;
274            if offset > window_len {
275                return Vec::new(); // Data no longer in window
276            }
277            (self.pos.wrapping_sub(offset)) & self.mask
278        };
279
280        self.copy_from_window(start_pos, len)
281    }
282
283    /// Take ownership of the accumulated output buffer.
284    /// More efficient than get_output() when you need all output.
285    pub fn take_output(&mut self) -> Vec<u8> {
286        std::mem::take(&mut self.output)
287    }
288
289    /// Get read access to the accumulated output buffer.
290    pub fn output(&self) -> &[u8] {
291        &self.output
292    }
293
294    /// Get mutable access to the output buffer for filter execution.
295    pub fn output_mut(&mut self) -> &mut [u8] {
296        &mut self.output
297    }
298
299    /// Get the most recent `len` bytes from the window.
300    pub fn get_recent(&self, len: usize) -> Vec<u8> {
301        let actual_len = len.min(self.total_written as usize);
302        let start = (self.pos.wrapping_sub(actual_len)) & self.mask;
303        self.copy_from_window(start, actual_len)
304    }
305
306    /// Bulk copy from window handling wrap-around with extend_from_slice.
307    fn copy_from_window(&self, start: usize, len: usize) -> Vec<u8> {
308        let mut out = Vec::with_capacity(len);
309        let ws = start & self.mask;
310        if ws + len <= self.window.len() {
311            out.extend_from_slice(&self.window[ws..ws + len]);
312        } else {
313            let first = self.window.len() - ws;
314            out.extend_from_slice(&self.window[ws..]);
315            out.extend_from_slice(&self.window[..len - first]);
316        }
317        out
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Feed each byte of `bytes` into the decoder as a literal.
    fn push_literals(dec: &mut LzssDecoder, bytes: &[u8]) {
        for &b in bytes {
            dec.write_literal(b);
        }
    }

    #[test]
    fn test_literal_output() {
        let mut dec = LzssDecoder::new(256);
        push_literals(&mut dec, b"Hello");

        assert_eq!(dec.total_written(), 5);
        assert_eq!(dec.get_recent(5), b"Hello");
    }

    #[test]
    fn test_copy_match() {
        let mut dec = LzssDecoder::new(256);
        push_literals(&mut dec, b"abc");

        // Distance 3, length 6 repeats the "abc" pattern twice more.
        dec.copy_match(3, 6).unwrap();

        assert_eq!(dec.total_written(), 9);
        assert_eq!(dec.get_recent(9), b"abcabcabc");
    }

    #[test]
    fn test_overlapping_copy() {
        let mut dec = LzssDecoder::new(256);
        dec.write_literal(b'a');

        // Distance 1 self-overlaps: each copied byte feeds the next read.
        dec.copy_match(1, 5).unwrap();

        assert_eq!(dec.get_recent(6), b"aaaaaa");
    }

    #[test]
    fn test_invalid_distance() {
        let mut dec = LzssDecoder::new(256);
        dec.write_literal(b'a');

        // A zero distance can never reference valid history.
        assert!(dec.copy_match(0, 1).is_err());
    }
}