1use std::io::{self, Read};
39use std::path::Path;
40
/// Compression scheme applied to a byte stream.
///
/// [`CompressionFormat::None`] always exists; every other variant is compiled in
/// only when the Cargo feature named on it is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum CompressionFormat {
    /// No compression. This is the default value.
    #[default]
    None,

    /// Gzip; requires the `compression` feature.
    #[cfg(feature = "compression")]
    Gzip,

    /// Zstandard; requires the `compression-zstd` feature.
    #[cfg(feature = "compression-zstd")]
    Zstd,

    /// LZ4; requires the `compression-lz4` feature.
    #[cfg(feature = "compression-lz4")]
    Lz4,
}
62
63impl CompressionFormat {
64 pub fn from_path<P: AsRef<Path>>(path: P) -> Self {
66 match path.as_ref().extension().and_then(|s| s.to_str()) {
67 #[cfg(feature = "compression")]
68 Some("gz" | "gzip") => CompressionFormat::Gzip,
69
70 #[cfg(feature = "compression-zstd")]
71 Some("zst" | "zstd") => CompressionFormat::Zstd,
72
73 #[cfg(feature = "compression-lz4")]
74 Some("lz4") => CompressionFormat::Lz4,
75
76 _ => CompressionFormat::None,
77 }
78 }
79
80 #[must_use]
82 pub fn from_magic_bytes(bytes: &[u8]) -> Self {
83 if bytes.len() < 2 {
84 return CompressionFormat::None;
85 }
86
87 #[cfg(feature = "compression")]
88 if bytes[0] == 0x1f && bytes[1] == 0x8b {
89 return CompressionFormat::Gzip;
90 }
91
92 if bytes.len() >= 4 {
93 #[cfg(feature = "compression-zstd")]
94 if bytes[0] == 0x28 && bytes[1] == 0xb5 && bytes[2] == 0x2f && bytes[3] == 0xfd {
95 return CompressionFormat::Zstd;
96 }
97
98 #[cfg(feature = "compression-lz4")]
99 if bytes[0] == 0x04 && bytes[1] == 0x22 && bytes[2] == 0x4d && bytes[3] == 0x18 {
100 return CompressionFormat::Lz4;
101 }
102 }
103
104 CompressionFormat::None
105 }
106
107 #[must_use]
109 pub fn is_compressed(&self) -> bool {
110 !matches!(self, CompressionFormat::None)
111 }
112
113 #[must_use]
115 pub fn extension(&self) -> Option<&'static str> {
116 match self {
117 CompressionFormat::None => None,
118 #[cfg(feature = "compression")]
119 CompressionFormat::Gzip => Some("gz"),
120 #[cfg(feature = "compression-zstd")]
121 CompressionFormat::Zstd => Some("zst"),
122 #[cfg(feature = "compression-lz4")]
123 CompressionFormat::Lz4 => Some("lz4"),
124 }
125 }
126}
127
128pub struct CompressionReader<R: Read> {
132 inner: Box<dyn Read>,
133 format: CompressionFormat,
134 _phantom: std::marker::PhantomData<R>,
136}
137
138impl<R: Read + 'static> CompressionReader<R> {
139 pub fn new(mut reader: R) -> io::Result<Self> {
143 let mut magic = [0u8; 4];
145 let bytes_read = Self::read_partial(&mut reader, &mut magic)?;
146
147 let format = CompressionFormat::from_magic_bytes(&magic[..bytes_read]);
149
150 Self::create_decoder(reader, format, Some(magic))
152 }
153
154 pub fn with_format(reader: R, format: CompressionFormat) -> io::Result<Self> {
156 Self::create_decoder(reader, format, None)
157 }
158
159 #[must_use]
161 pub fn format(&self) -> CompressionFormat {
162 self.format
163 }
164
165 fn read_partial(reader: &mut R, buf: &mut [u8]) -> io::Result<usize> {
167 let mut total = 0;
168 while total < buf.len() {
169 match reader.read(&mut buf[total..]) {
170 Ok(0) => break,
171 Ok(n) => total += n,
172 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
173 Err(e) => return Err(e),
174 }
175 }
176 Ok(total)
177 }
178
179 fn create_decoder(
181 reader: R,
182 format: CompressionFormat,
183 magic_prefix: Option<[u8; 4]>,
184 ) -> io::Result<Self> {
185 let inner: Box<dyn Read> = match (format, magic_prefix) {
186 (CompressionFormat::None, Some(magic)) => {
188 let chained = std::io::Cursor::new(magic).chain(reader);
189 Box::new(chained)
190 }
191 (CompressionFormat::None, None) => Box::new(reader),
192
193 #[cfg(feature = "compression")]
195 (CompressionFormat::Gzip, Some(magic)) => {
196 let chained = std::io::Cursor::new(magic).chain(reader);
197 Box::new(flate2::read::GzDecoder::new(chained))
198 }
199 #[cfg(feature = "compression")]
200 (CompressionFormat::Gzip, None) => Box::new(flate2::read::GzDecoder::new(reader)),
201
202 #[cfg(feature = "compression-zstd")]
204 (CompressionFormat::Zstd, Some(magic)) => {
205 let chained = std::io::Cursor::new(magic).chain(reader);
206 let decoder = zstd::Decoder::new(chained)
207 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
208 Box::new(decoder)
209 }
210 #[cfg(feature = "compression-zstd")]
211 (CompressionFormat::Zstd, None) => {
212 let decoder = zstd::Decoder::new(reader)
213 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
214 Box::new(decoder)
215 }
216
217 #[cfg(feature = "compression-lz4")]
219 (CompressionFormat::Lz4, Some(magic)) => {
220 let chained = std::io::Cursor::new(magic).chain(reader);
221 Box::new(lz4_flex::frame::FrameDecoder::new(chained))
222 }
223 #[cfg(feature = "compression-lz4")]
224 (CompressionFormat::Lz4, None) => Box::new(lz4_flex::frame::FrameDecoder::new(reader)),
225 };
226
227 Ok(Self {
228 inner,
229 format,
230 _phantom: std::marker::PhantomData,
231 })
232 }
233}
234
235impl<R: Read> Read for CompressionReader<R> {
236 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
237 self.inner.read(buf)
238 }
239}
240
/// Writer wrapper that forwards bytes to `W`, either directly or through an
/// encoder selected by a [`CompressionFormat`].
///
/// Only available when the `compression` feature is enabled.
#[cfg(feature = "compression")]
pub struct CompressionWriter<W>
where
    W: std::io::Write + 'static,
{
    /// The active sink: the plain writer or an encoder wrapped around it.
    inner: CompressionWriterInner<W>,
    /// Format recorded for this writer.
    format: CompressionFormat,
}
247
/// Concrete sink behind a `CompressionWriter`: one variant per supported format.
///
/// The encoder variants are boxed; the plain writer is stored inline.
#[cfg(feature = "compression")]
enum CompressionWriterInner<W>
where
    W: std::io::Write,
{
    /// The writer itself, with no encoder in front of it.
    Plain(W),
    /// Gzip encoder wrapped around the writer.
    Gzip(Box<flate2::write::GzEncoder<W>>),
    /// Zstandard encoder wrapped around the writer.
    #[cfg(feature = "compression-zstd")]
    Zstd(Box<zstd::Encoder<'static, W>>),
    /// LZ4 frame encoder wrapped around the writer.
    #[cfg(feature = "compression-lz4")]
    Lz4(Box<lz4_flex::frame::FrameEncoder<W>>),
}
258
#[cfg(feature = "compression")]
impl<W: std::io::Write + 'static> CompressionWriter<W> {
    /// Wraps `writer` with an encoder for `format`, using that format's default
    /// compression level.
    ///
    /// # Errors
    ///
    /// Returns an error if the encoder cannot be constructed.
    pub fn new(writer: W, format: CompressionFormat) -> io::Result<Self> {
        Self::with_level(writer, format, None)
    }

    /// Wraps `writer` with an encoder for `format` at the requested `level`.
    ///
    /// When `level` is `None`, gzip uses level 6 and zstd uses level 3. The
    /// level is ignored for [`CompressionFormat::None`] and for LZ4.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidInput` error if the zstd encoder cannot be
    /// constructed.
    pub fn with_level(
        writer: W,
        format: CompressionFormat,
        level: Option<u32>,
    ) -> io::Result<Self> {
        let inner = match format {
            CompressionFormat::None => CompressionWriterInner::Plain(writer),

            CompressionFormat::Gzip => {
                // NOTE(review): flate2 documents levels as "typically 0-9"; the
                // value is passed through unchanged here. Confirm how the
                // configured backend treats larger values.
                let level = flate2::Compression::new(level.unwrap_or(6));
                CompressionWriterInner::Gzip(Box::new(flate2::write::GzEncoder::new(writer, level)))
            }

            #[cfg(feature = "compression-zstd")]
            CompressionFormat::Zstd => {
                // zstd takes an `i32`. Saturate instead of casting with `as`,
                // which would wrap values above `i32::MAX` to negative levels.
                let level = level.map_or(3, |requested| {
                    <i32 as std::convert::TryFrom<u32>>::try_from(requested).unwrap_or(i32::MAX)
                });
                CompressionWriterInner::Zstd(Box::new(
                    zstd::Encoder::new(writer, level)
                        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?,
                ))
            }

            #[cfg(feature = "compression-lz4")]
            CompressionFormat::Lz4 => {
                CompressionWriterInner::Lz4(Box::new(lz4_flex::frame::FrameEncoder::new(writer)))
            }
        };

        Ok(Self { inner, format })
    }

    /// Returns the format this writer was constructed with.
    #[must_use]
    pub fn format(&self) -> CompressionFormat {
        self.format
    }

    /// Consumes the wrapper, finishes the active encoder (if any) and returns
    /// the underlying writer.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the encoder's `finish` call; zstd and LZ4
    /// failures are re-wrapped with `ErrorKind::Other`.
    pub fn finish(self) -> io::Result<W> {
        match self.inner {
            CompressionWriterInner::Plain(w) => Ok(w),
            CompressionWriterInner::Gzip(w) => w.finish(),

            #[cfg(feature = "compression-zstd")]
            CompressionWriterInner::Zstd(w) => w
                .finish()
                .map_err(|e| io::Error::new(io::ErrorKind::Other, e)),

            #[cfg(feature = "compression-lz4")]
            CompressionWriterInner::Lz4(w) => w
                .finish()
                .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string())),
        }
    }
}
321
#[cfg(feature = "compression")]
impl<W> std::io::Write for CompressionWriter<W>
where
    W: std::io::Write + 'static,
{
    /// Hands `buf` to whichever sink is active and returns the byte count that
    /// sink reports.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self.inner {
            CompressionWriterInner::Plain(ref mut sink) => sink.write(buf),
            CompressionWriterInner::Gzip(ref mut encoder) => encoder.write(buf),

            #[cfg(feature = "compression-zstd")]
            CompressionWriterInner::Zstd(ref mut encoder) => encoder.write(buf),

            #[cfg(feature = "compression-lz4")]
            CompressionWriterInner::Lz4(ref mut encoder) => encoder.write(buf),
        }
    }

    /// Flushes whichever sink is active.
    fn flush(&mut self) -> io::Result<()> {
        match self.inner {
            CompressionWriterInner::Plain(ref mut sink) => sink.flush(),
            CompressionWriterInner::Gzip(ref mut encoder) => encoder.flush(),

            #[cfg(feature = "compression-zstd")]
            CompressionWriterInner::Zstd(ref mut encoder) => encoder.flush(),

            #[cfg(feature = "compression-lz4")]
            CompressionWriterInner::Lz4(ref mut encoder) => encoder.flush(),
        }
    }
}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Extensions that match no format are reported as uncompressed.
    #[test]
    fn test_format_from_path_uncompressed() {
        assert_eq!(
            CompressionFormat::from_path("data.hedl"),
            CompressionFormat::None
        );
        assert_eq!(
            CompressionFormat::from_path("data.txt"),
            CompressionFormat::None
        );
    }

    #[cfg(feature = "compression")]
    #[test]
    fn test_format_from_path_gzip() {
        assert_eq!(
            CompressionFormat::from_path("data.hedl.gz"),
            CompressionFormat::Gzip
        );
    }

    #[cfg(feature = "compression-zstd")]
    #[test]
    fn test_format_from_path_zstd() {
        assert_eq!(
            CompressionFormat::from_path("data.zst"),
            CompressionFormat::Zstd
        );
    }

    #[cfg(feature = "compression")]
    #[test]
    fn test_format_from_magic_gzip() {
        assert_eq!(
            CompressionFormat::from_magic_bytes(&[0x1f, 0x8b, 0x08, 0x00]),
            CompressionFormat::Gzip
        );
    }

    /// Slices shorter than two bytes can never match a signature.
    #[test]
    fn test_format_from_magic_too_short() {
        assert_eq!(
            CompressionFormat::from_magic_bytes(&[]),
            CompressionFormat::None
        );
        assert_eq!(
            CompressionFormat::from_magic_bytes(&[0x1f]),
            CompressionFormat::None
        );
    }

    #[test]
    fn test_uncompressed_format_properties() {
        assert!(!CompressionFormat::None.is_compressed());
        assert_eq!(CompressionFormat::None.extension(), None);
    }

    #[test]
    fn test_compression_reader_uncompressed() {
        let data = b"Hello, World!";
        let reader = CompressionReader::new(std::io::Cursor::new(data.to_vec())).unwrap();
        assert_eq!(reader.format(), CompressionFormat::None);

        let mut output = String::new();
        std::io::BufReader::new(reader)
            .read_to_string(&mut output)
            .unwrap();
        // Compare the whole payload: a prefix check would not notice lost,
        // duplicated or extra bytes after the sniffed header.
        assert_eq!(output, "Hello, World!");
    }

    /// With an explicit format nothing is sniffed, so bytes pass straight through.
    #[test]
    fn test_compression_reader_with_format_passthrough() {
        let reader = CompressionReader::with_format(
            std::io::Cursor::new(b"Hi".to_vec()),
            CompressionFormat::None,
        )
        .unwrap();
        assert_eq!(reader.format(), CompressionFormat::None);

        let mut output = String::new();
        std::io::BufReader::new(reader)
            .read_to_string(&mut output)
            .unwrap();
        assert_eq!(output, "Hi");
    }

    #[cfg(feature = "compression")]
    #[test]
    fn test_compression_reader_gzip_roundtrip() {
        use std::io::Write;

        let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::fast());
        encoder.write_all(b"Hello, HEDL!").unwrap();
        let compressed = encoder.finish().unwrap();

        let reader = CompressionReader::new(std::io::Cursor::new(compressed)).unwrap();
        assert_eq!(reader.format(), CompressionFormat::Gzip);

        let mut output = String::new();
        std::io::BufReader::new(reader)
            .read_to_string(&mut output)
            .unwrap();
        assert_eq!(output, "Hello, HEDL!");
    }

    #[cfg(feature = "compression")]
    #[test]
    fn test_compression_writer_gzip_roundtrip() {
        use std::io::Write;

        let mut writer = CompressionWriter::new(Vec::new(), CompressionFormat::Gzip).unwrap();
        assert_eq!(writer.format(), CompressionFormat::Gzip);
        write!(writer, "Hello, HEDL!").unwrap();
        let compressed = writer.finish().unwrap();

        let mut decoder = flate2::read::GzDecoder::new(std::io::Cursor::new(compressed));
        let mut output = String::new();
        decoder.read_to_string(&mut output).unwrap();
        assert_eq!(output, "Hello, HEDL!");
    }
}