lz4_java_wrc/
lz4_block_output.rs

1use crate::common::{Checksum, ErrorInternal, Result};
2use crate::compression::{Compression, Context};
3use crate::lz4_block_header::{CompressionLevel, CompressionMethod, Lz4BlockHeader};
4
5use std::cmp::min;
6use std::io::Write;
7use std::result::Result as StdResult;
8
9/// Wrapper around a [`Write`] object to compress data.
10///
11/// The data written to [`Lz4BlockOutput`] is be compressed and then written to the wrapped [`Write`].
12///
13/// # Example
14///
15/// ```rust
16/// use lz4_java_wrc::Lz4BlockOutput;
17/// use std::io::Write;
18///
19/// fn main() -> std::io::Result<()> {
20///     let mut output = Vec::new(); // Vec<u8> implements the Write trait
21///     Lz4BlockOutput::new(&mut output).write_all("...".as_bytes())?;
22///     println!("{:?}", output);
23///     Ok(())
24/// }
25/// ```
26pub type Lz4BlockOutput<'a, R> = Lz4BlockOutputBase<'a, R, Context>;
27
28impl<'a, W: Write> Lz4BlockOutput<'a, W> {
29    /// Create a new [`Lz4BlockOutput`] with the default parameters.
30    ///
31    /// See [`Self::with_context()`]
32    #[inline]
33    pub fn new(w: &'a mut W) -> Self {
34        Self::with_context(w, Context::default(), Self::default_block_size()).unwrap()
35    }
36}
37
38/// Wrapper around a [`Write`] object to compress data.
39///
40/// Use this struct only if you want to provide your own Compression implementation. Otherwise use the alias [`Lz4BlockOutput`].
41#[derive(Debug)]
42pub struct Lz4BlockOutputBase<'a, W: Write + Sized, C: Compression> {
43    writer: &'a mut W,
44    compression: C,
45    compression_level: CompressionLevel,
46    write_ptr: usize,
47    decompressed_buf: Vec<u8>,
48    compressed_buf: Vec<u8>,
49    checksum: Checksum,
50}
51
52impl<'a, W: Write, C: Compression> Lz4BlockOutputBase<'a, W, C> {
53    /// Get the default block size: 65536B.
54    #[inline]
55    pub fn default_block_size() -> usize {
56        1 << 16
57    }
58
59    /// Create a new [`Lz4BlockOutputBase`] with the default checksum implementation which is compatible with the Java's default implementation, including the missing 4 bits bug.
60    ///
61    /// See [`Self::with_checksum()`]
62    #[inline]
63    pub fn with_context(w: &'a mut W, c: C, block_size: usize) -> std::io::Result<Self> {
64        Self::with_checksum(w, c, block_size, Lz4BlockHeader::default_checksum)
65    }
66
67    /// Create a new [`Lz4BlockOutputBase`].
68    ///
69    /// The `block_size` must be between `64` and `33554432` bytes.
70    /// The checksum must return a [`u32`].
71    ///
72    /// # Errors
73    ///
74    /// It will return an error if the `block_size` is out of range
75    pub fn with_checksum(
76        w: &'a mut W,
77        c: C,
78        block_size: usize,
79        checksum: fn(&[u8]) -> u32,
80    ) -> std::io::Result<Self> {
81        let compression_level = CompressionLevel::from_block_size(block_size)?;
82        let compressed_buf_len = c
83            .get_maximum_compressed_buffer_len(compression_level.get_max_decompressed_buffer_len());
84        Ok(Self {
85            writer: w,
86            compression: c,
87            compression_level,
88            write_ptr: 0,
89            compressed_buf: vec![0u8; compressed_buf_len],
90            decompressed_buf: vec![0u8; block_size],
91            checksum: Checksum::new(checksum),
92        })
93    }
94
95    fn copy_to_buf(&mut self, buf: &[u8]) -> StdResult<usize, ErrorInternal> {
96        let buf_into = &mut self.decompressed_buf[self.write_ptr..];
97        if buf.len() > buf_into.len() {
98            return ErrorInternal::new_error(
99                "Attempt to write a bigger buffer than the available one",
100            );
101        }
102
103        buf_into[..buf.len()].copy_from_slice(buf);
104        self.write_ptr += buf.len();
105
106        Ok(buf.len())
107    }
108
109    fn remaining_buf_len(&self) -> StdResult<usize, ErrorInternal> {
110        if self.write_ptr <= self.decompressed_buf.len() {
111            Ok(self.decompressed_buf.len() - self.write_ptr)
112        } else {
113            ErrorInternal::new_error("Could not determine the buffer size")
114        }
115    }
116
117    fn write(&mut self, buf: &[u8]) -> Result<usize> {
118        if self.write_ptr == self.decompressed_buf.len() {
119            self.flush()?;
120        }
121        let size_to_copy = min(buf.len(), self.remaining_buf_len()?);
122        Ok(self.copy_to_buf(&buf[..size_to_copy])?)
123    }
124
125    fn flush(&mut self) -> Result<()> {
126        if self.write_ptr > 0 {
127            let decompressed_buf = &self.decompressed_buf[..self.write_ptr];
128            let compressed_buf = match self
129                .compression
130                .compress(decompressed_buf, self.compressed_buf.as_mut())
131            {
132                Ok(s) => &self.compressed_buf[..s],
133                Err(err) => return Err(err.into()),
134            };
135            let (compression_method, buf_to_write) =
136                if compressed_buf.len() < decompressed_buf.len() {
137                    (CompressionMethod::Lz4, compressed_buf)
138                } else {
139                    (CompressionMethod::Raw, decompressed_buf)
140                };
141            Lz4BlockHeader {
142                compression_method,
143                compression_level: self.compression_level,
144                compressed_len: buf_to_write.len() as u32,
145                decompressed_len: decompressed_buf.len() as u32,
146                checksum: self.checksum.run(decompressed_buf),
147            }
148            .write(&mut self.writer)?;
149            self.writer.write_all(buf_to_write)?;
150        }
151        self.write_ptr = 0;
152        self.writer.flush()?;
153        Ok(())
154    }
155}
156
157impl<'a, W: Write, C: Compression> Write for Lz4BlockOutputBase<'a, W, C> {
158    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
159        Ok(Self::write(self, buf)?)
160    }
161
162    fn flush(&mut self) -> std::io::Result<()> {
163        Ok(Self::flush(self)?)
164    }
165}
166
167impl<'a, W: Write, C: Compression> Drop for Lz4BlockOutputBase<'a, W, C> {
168    fn drop(&mut self) {
169        let _ = self.flush();
170    }
171}
172
#[cfg(test)]
mod test_lz4_block_output {
    use super::{CompressionLevel, Context, Lz4BlockOutput};
    use crate::lz4_block_header::data::VALID_DATA;

    use std::io::Write;

    /// Count the block headers in `out` by counting occurrences of the first 8 bytes of
    /// `VALID_DATA` (assumed to be the per-block magic prefix — confirm against the header
    /// format).
    fn count_blocks(out: &[u8]) -> usize {
        let needle = &VALID_DATA[..8];
        out.windows(needle.len())
            .filter(|window| *window == needle)
            .count()
    }

    #[test]
    fn valid_default_block_size() {
        let default_block_size = Lz4BlockOutput::<Vec<u8>>::default_block_size();
        assert!(CompressionLevel::from_block_size(default_block_size).is_ok());
    }

    #[test]
    fn write_empty() {
        let mut out = Vec::<u8>::new();
        Lz4BlockOutput::with_context(&mut out, Context::default(), 128).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn write_basic() {
        let mut out = Vec::<u8>::new();
        Lz4BlockOutput::with_context(&mut out, Context::default(), 128)
            .unwrap()
            .write_all("...".as_bytes())
            .unwrap();
        assert_eq!(out, VALID_DATA);
    }

    #[test]
    fn write_several_small_blocks() {
        let mut out = Vec::<u8>::new();
        let buf = [b'.'; 1024];
        let loops = 1024;
        {
            // The block size equals the total amount written, so everything fits in one block.
            let mut writer =
                Lz4BlockOutput::with_context(&mut out, Context::default(), buf.len() * loops)
                    .unwrap();
            for _ in 0..loops {
                writer.write_all(&buf).unwrap();
            }
        }
        assert_eq!(count_blocks(&out), 1);
    }

    #[test]
    fn write_several_big_blocks() {
        let mut out = Vec::<u8>::new();
        let buf = [b'.'; 128];
        let loops = 1234;
        {
            // Each write fills exactly one block.
            let mut writer =
                Lz4BlockOutput::with_context(&mut out, Context::default(), buf.len()).unwrap();
            for _ in 0..loops {
                writer.write_all(&buf).unwrap();
            }
        }
        assert_eq!(count_blocks(&out), loops);
    }

    #[test]
    fn flush_basic() {
        let mut out = Vec::<u8>::new();
        {
            let mut writer =
                Lz4BlockOutput::with_context(&mut out, Context::default(), 128).unwrap();
            writer.write_all("...".as_bytes()).unwrap();
            writer.flush().unwrap();
            writer.write_all("...".as_bytes()).unwrap();
        }
        let mut expected = VALID_DATA.to_vec();
        expected.extend_from_slice(&VALID_DATA[..]);
        assert_eq!(out, expected);
    }
}