ages_prs/
compress.rs

1//! Compression routine for PRS
2
3use crate::Variant;
4
5use std::fmt;
6use std::error;
7use std::io::{self, Write};
8
9use libflate_lz77::{
10    Code,
11    DefaultLz77Encoder,
12    DefaultLz77EncoderBuilder,
13    Lz77Encode,
14    Sink,
15    MAX_LENGTH,
16};
17
18/// An IO sink for compressing and encoding a stream to PRS.
19pub struct PrsEncoder<W: Write, V: Variant> {
20    sink: Option<PrsSink<V>>,
21    inner: Option<W>,
22    encoder: DefaultLz77Encoder,
23    _pd: std::marker::PhantomData<V>,
24}
25
/// Error returned when [`PrsEncoder::into_inner`] fails.
///
/// It carries the wrapped writer (so the caller can inspect or reuse it)
/// together with the I/O error that prevented the PRS stream from being
/// completed.
#[derive(Debug)]
pub struct IntoInnerError<W>(
    /// The writer that the encoder was wrapping.
    W,
    /// The I/O error that interrupted the operation.
    io::Error,
);
29
30impl<W: Write, V: Variant> PrsEncoder<W, V> {
31    /// Wraps a Write sink, initializing the encoder state
32    pub fn new(inner: W) -> PrsEncoder<W, V> {
33        let encoder = DefaultLz77EncoderBuilder::new()
34            .window_size(8191)
35            .max_length(std::cmp::min(MAX_LENGTH, V::MAX_COPY_LENGTH))
36            .build();
37        
38        PrsEncoder {
39            sink: Some(PrsSink::new(32)),
40            inner: Some(inner),
41            encoder,
42            _pd: std::marker::PhantomData,
43        }
44    }
45
46    /// Finish encoding the PRS stream, returning the inner Write.
47    ///
48    /// Errors will leave the PRS stream in an incomplete state; the E type is
49    /// only present to capture the inner Write for inspection. There is no way
50    /// to recover the broken PRS stream if this operation fails.
51    pub fn into_inner(mut self) -> Result<W, IntoInnerError<W>> {
52        match self.flush_buf() {
53            Err(e) => Err(IntoInnerError(self.inner.take().unwrap(), e)),
54            Ok(()) => {
55                let mut sink = self.sink.take().unwrap();
56                let mut inner = self.inner.take().unwrap();
57                self.encoder.flush(&mut sink);
58                let buf = sink.finish();
59
60                match inner.write_all(&buf[..]) {
61                    Err(e) => Err(IntoInnerError(inner, e)),
62                    Ok(_) => Ok(inner),
63                }
64            },
65        }
66    }
67
68    /// Attempt to flush the intermediary buffer to the sink
69    fn flush_buf(&mut self) -> io::Result<()> {
70        let mut sink = self.sink.as_mut().unwrap();
71        let inner = self.inner.as_mut().unwrap();
72
73        // everything before the current cmd index is safe to write
74        let high_water = sink.cmd_index;
75        if high_water == 0 {
76            // don't flush; we don't have a saturated command byte yet
77            return Ok(());
78        }
79
80        let mut written = 0;
81        let len = high_water;
82        let mut ret: io::Result<()> = Ok(());
83
84        while written < len {
85            // only write up to len bytes this flush
86            let r = inner.write(&sink.out[written..len]);
87
88            match r {
89                Ok(0) => {
90                    ret = Err(io::Error::new(
91                        io::ErrorKind::WriteZero,
92                        "failed to write the buffered data"
93                    ));
94                    break;
95                },
96                Ok(n) => written += n,
97                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {},
98                Err(e) => {
99                    ret = Err(e);
100                    break;
101                }
102            }
103        }
104        if written > 0 {
105            sink.out.drain(..written);
106            sink.cmd_index -= written;
107        }
108        ret
109    }
110}
111
112impl<W: Write, V: Variant> Write for PrsEncoder<W, V> {
113    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
114        // unlike BufWriter we can't flush when buffer capacity is hit
115        {
116            self.encoder.encode(buf, self.sink.as_mut().unwrap());
117        }
118        // we'll try to flush as much as possible since buffer perf is not
119        // the goal here; PrsEncoder<BufWriter<_>, _> is fine for that
120        self.flush_buf()?;
121        Ok(buf.len())
122    }
123
124    fn flush(&mut self) -> io::Result<()> {
125        self.flush_buf().and_then(|()| self.inner.as_mut().unwrap().flush())
126    }
127}
128
129impl<W: Write, V: Variant> Drop for PrsEncoder<W, V> {
130    fn drop(&mut self) {
131        if self.inner.is_some() && self.sink.is_some() {
132            let _r = self.flush_buf();
133            let mut sink = self.sink.take().unwrap();
134            let mut inner = self.inner.take().unwrap();
135            self.encoder.flush(&mut sink);
136            let buf = sink.finish();
137
138            // we'll try to finish the stream but it is impossible to report
139            // errors from a Drop
140            let _r = inner.write_all(&buf[..]);
141        }
142    }
143}
144
145impl<W: Write, V: Variant> fmt::Debug for PrsEncoder<W, V>
146where
147    W: fmt::Debug,
148{
149    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
150        fmt.debug_struct("PrsEncoder")
151            .field("writer", &self.inner.as_ref().unwrap())
152            .field("buffer", &self.sink.as_ref().unwrap().out)
153            .finish()
154    }
155}
156
157impl<W> fmt::Display for IntoInnerError<W> {
158    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> fmt::Result {
159        write!(fmt, "Failed to complete PRS stream: {}", self.1)
160    }
161}
162
163impl<W: Send + fmt::Debug> error::Error for IntoInnerError<W> {
164    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
165        Some(&self.1)
166    }
167}
168
169impl<W> IntoInnerError<W> {
170    /// Reference the IO error that failed the operation.
171    pub fn error(&self) -> &io::Error {
172        &self.1
173    }
174
175    /// Retrieve the inner type.
176    pub fn into_inner(self) -> W {
177        self.0
178    }
179}
180
181// ---- LZ77 Sink implementation ----
182
183struct PrsSink<V: Variant> {
184    /// index into `out` which is the current cmd stream head
185    cmd_index: usize,
186    /// how many cmd bits can we still write
187    cmd_bits_rem: u8,
188    /// the output buffer
189    out: Vec<u8>,
190
191    _pd: std::marker::PhantomData<V>,
192}
193
194impl<V: Variant> PrsSink<V> {
195    fn new(capacity: usize) -> PrsSink<V> {
196        PrsSink {
197            cmd_index: 0,
198            cmd_bits_rem: 0,
199            out: Vec::with_capacity(capacity),
200            _pd: std::marker::PhantomData,
201        }
202    }
203
204    fn write_bit(&mut self, bit: bool) {
205        if self.cmd_bits_rem == 0 {
206            self.cmd_index = self.out.len();
207            self.cmd_bits_rem = 8;
208            self.out.push(0);
209        }
210
211        if bit {
212            self.out[self.cmd_index] |= 1 << (8 - self.cmd_bits_rem);
213        }
214
215        self.cmd_bits_rem -= 1;
216    }
217
218    fn finish(mut self) -> Vec<u8> {
219        self.write_bit(false);
220        self.write_bit(true); // long ptr
221        self.out.push(0); // zero offset = EOF
222        self.out.push(0);
223
224        self.out
225    }
226}
227
228impl<V: Variant> Sink for PrsSink<V> {
229    fn consume(&mut self, code: Code) {
230        match code {
231            Code::Literal(b) => {
232                self.write_bit(true);
233                self.out.push(b);
234            },
235            Code::Pointer { length, backward_distance } => {
236                // preconditions
237                if length < 2 {
238                    panic!("copy length too small (< 2)");
239                }
240                if length > V::MAX_COPY_LENGTH {
241                    panic!("copy length too large");
242                }
243                if backward_distance >= 8192 {
244                    panic!("copy distance too far (>8191)");
245                }
246
247                if backward_distance >= 256 || length > 5 {
248                    // long ptr
249                    self.write_bit(false);
250                    self.write_bit(true);
251
252                    let mut offset = backward_distance as i32;
253                    
254                    offset = -offset;
255                    offset <<= 3;
256                    if (length - 2) < 8 {
257                        offset |= (length - 2) as i32;
258                    }
259
260                    self.out.extend_from_slice(&(offset as u16).to_le_bytes());
261                    
262                    if (length - 2) >= 8 {
263                        let size = (
264                            length - (V::MIN_LONG_COPY_LENGTH as u16)
265                        ) as u8;
266                        self.out.push(size);
267                    }
268                } else {
269                    // short ptr
270                    self.write_bit(false);
271                    self.write_bit(false);
272
273                    let offset = backward_distance as i32;
274                    let size = (length - 2) as i32;
275                    
276                    self.write_bit(size & 0b10 > 0);
277                    self.write_bit(size & 0b01 > 0);
278                    self.out.push((-offset & 0xFF) as u8);
279                }
280            },
281        }
282    }
283}