1use crate::common::{Checksum, ErrorInternal, Result};
2use crate::compression::{Compression, Context};
3use crate::lz4_block_header::{CompressionLevel, CompressionMethod, Lz4BlockHeader};
4
5use std::cmp::min;
6use std::io::Write;
7use std::result::Result as StdResult;
8
9pub type Lz4BlockOutput<'a, R> = Lz4BlockOutputBase<'a, R, Context>;
27
28impl<'a, W: Write> Lz4BlockOutput<'a, W> {
29 #[inline]
33 pub fn new(w: &'a mut W) -> Self {
34 Self::with_context(w, Context::default(), Self::default_block_size()).unwrap()
35 }
36}
37
38#[derive(Debug)]
42pub struct Lz4BlockOutputBase<'a, W: Write + Sized, C: Compression> {
43 writer: &'a mut W,
44 compression: C,
45 compression_level: CompressionLevel,
46 write_ptr: usize,
47 decompressed_buf: Vec<u8>,
48 compressed_buf: Vec<u8>,
49 checksum: Checksum,
50}
51
52impl<'a, W: Write, C: Compression> Lz4BlockOutputBase<'a, W, C> {
53 #[inline]
55 pub fn default_block_size() -> usize {
56 1 << 16
57 }
58
59 #[inline]
63 pub fn with_context(w: &'a mut W, c: C, block_size: usize) -> std::io::Result<Self> {
64 Self::with_checksum(w, c, block_size, Lz4BlockHeader::default_checksum)
65 }
66
67 pub fn with_checksum(
76 w: &'a mut W,
77 c: C,
78 block_size: usize,
79 checksum: fn(&[u8]) -> u32,
80 ) -> std::io::Result<Self> {
81 let compression_level = CompressionLevel::from_block_size(block_size)?;
82 let compressed_buf_len = c
83 .get_maximum_compressed_buffer_len(compression_level.get_max_decompressed_buffer_len());
84 Ok(Self {
85 writer: w,
86 compression: c,
87 compression_level,
88 write_ptr: 0,
89 compressed_buf: vec![0u8; compressed_buf_len],
90 decompressed_buf: vec![0u8; block_size],
91 checksum: Checksum::new(checksum),
92 })
93 }
94
95 fn copy_to_buf(&mut self, buf: &[u8]) -> StdResult<usize, ErrorInternal> {
96 let buf_into = &mut self.decompressed_buf[self.write_ptr..];
97 if buf.len() > buf_into.len() {
98 return ErrorInternal::new_error(
99 "Attempt to write a bigger buffer than the available one",
100 );
101 }
102
103 buf_into[..buf.len()].copy_from_slice(buf);
104 self.write_ptr += buf.len();
105
106 Ok(buf.len())
107 }
108
109 fn remaining_buf_len(&self) -> StdResult<usize, ErrorInternal> {
110 if self.write_ptr <= self.decompressed_buf.len() {
111 Ok(self.decompressed_buf.len() - self.write_ptr)
112 } else {
113 ErrorInternal::new_error("Could not determine the buffer size")
114 }
115 }
116
117 fn write(&mut self, buf: &[u8]) -> Result<usize> {
118 if self.write_ptr == self.decompressed_buf.len() {
119 self.flush()?;
120 }
121 let size_to_copy = min(buf.len(), self.remaining_buf_len()?);
122 Ok(self.copy_to_buf(&buf[..size_to_copy])?)
123 }
124
125 fn flush(&mut self) -> Result<()> {
126 if self.write_ptr > 0 {
127 let decompressed_buf = &self.decompressed_buf[..self.write_ptr];
128 let compressed_buf = match self
129 .compression
130 .compress(decompressed_buf, self.compressed_buf.as_mut())
131 {
132 Ok(s) => &self.compressed_buf[..s],
133 Err(err) => return Err(err.into()),
134 };
135 let (compression_method, buf_to_write) =
136 if compressed_buf.len() < decompressed_buf.len() {
137 (CompressionMethod::Lz4, compressed_buf)
138 } else {
139 (CompressionMethod::Raw, decompressed_buf)
140 };
141 Lz4BlockHeader {
142 compression_method,
143 compression_level: self.compression_level,
144 compressed_len: buf_to_write.len() as u32,
145 decompressed_len: decompressed_buf.len() as u32,
146 checksum: self.checksum.run(decompressed_buf),
147 }
148 .write(&mut self.writer)?;
149 self.writer.write_all(buf_to_write)?;
150 }
151 self.write_ptr = 0;
152 self.writer.flush()?;
153 Ok(())
154 }
155}
156
157impl<'a, W: Write, C: Compression> Write for Lz4BlockOutputBase<'a, W, C> {
158 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
159 Ok(Self::write(self, buf)?)
160 }
161
162 fn flush(&mut self) -> std::io::Result<()> {
163 Ok(Self::flush(self)?)
164 }
165}
166
167impl<'a, W: Write, C: Compression> Drop for Lz4BlockOutputBase<'a, W, C> {
168 fn drop(&mut self) {
169 let _ = self.flush();
170 }
171}
172
#[cfg(test)]
mod test_lz4_block_output {
    use super::{CompressionLevel, Context, Lz4BlockOutput};
    use crate::lz4_block_header::data::VALID_DATA;

    use std::io::Write;

    /// Counts the occurrences of the first 8 bytes of `VALID_DATA` in `out`,
    /// i.e. how many blocks were emitted.
    fn count_block_starts(out: &[u8]) -> usize {
        let needle = &VALID_DATA[..8];
        out.windows(needle.len())
            .filter(|window| *window == needle)
            .count()
    }

    /// The default block size must be accepted by `CompressionLevel`.
    #[test]
    fn valid_default_block_size() {
        let default_block_size = Lz4BlockOutput::<Vec<u8>>::default_block_size();
        assert!(CompressionLevel::from_block_size(default_block_size).is_ok());
    }

    /// Creating and dropping an output without writing emits nothing.
    #[test]
    fn write_empty() {
        let mut out = Vec::<u8>::new();
        let writer = Lz4BlockOutput::with_context(&mut out, Context::default(), 128).unwrap();
        drop(writer);
        assert!(out.is_empty());
    }

    /// A single short write produces exactly the reference block.
    #[test]
    fn write_basic() {
        let mut out = Vec::<u8>::new();
        let mut writer = Lz4BlockOutput::with_context(&mut out, Context::default(), 128).unwrap();
        writer.write_all(b"...").unwrap();
        drop(writer);
        assert_eq!(out, VALID_DATA);
    }

    /// Many writes that together fill one block yield a single block.
    #[test]
    fn write_several_small_blocks() {
        let mut out = Vec::<u8>::new();
        let buf = [b'.'; 1024];
        let loops = 1024;

        let mut writer =
            Lz4BlockOutput::with_context(&mut out, Context::default(), buf.len() * loops).unwrap();
        for _ in 0..loops {
            writer.write_all(&buf).unwrap();
        }
        drop(writer);

        assert_eq!(count_block_starts(&out), 1);
    }

    /// Writes that each fill a whole block yield one block per write.
    #[test]
    fn write_several_big_blocks() {
        let mut out = Vec::<u8>::new();
        let buf = [b'.'; 128];
        let loops = 1234;

        let mut writer =
            Lz4BlockOutput::with_context(&mut out, Context::default(), buf.len()).unwrap();
        for _ in 0..loops {
            writer.write_all(&buf).unwrap();
        }
        drop(writer);

        assert_eq!(count_block_starts(&out), loops);
    }

    /// An explicit flush closes the current block; later data starts a new one.
    #[test]
    fn flush_basic() {
        let mut out = Vec::<u8>::new();

        let mut writer = Lz4BlockOutput::with_context(&mut out, Context::default(), 128).unwrap();
        writer.write_all(b"...").unwrap();
        writer.flush().unwrap();
        writer.write_all(b"...").unwrap();
        drop(writer);

        let expected = [&VALID_DATA[..], &VALID_DATA[..]].concat();
        assert_eq!(out, expected);
    }
}