1use std::io::{self, Read};
39use std::path::Path;
40
/// Compression container formats known to this module.
///
/// Every variant other than [`CompressionFormat::None`] is compiled in only
/// when its Cargo feature is enabled, so the set of variants depends on the
/// build configuration.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CompressionFormat {
    /// Plain, uncompressed data. This is the [`Default`] value.
    #[default]
    None,

    /// Gzip data (only with the `compression` feature).
    #[cfg(feature = "compression")]
    Gzip,

    /// Zstandard data (only with the `compression-zstd` feature).
    #[cfg(feature = "compression-zstd")]
    Zstd,

    /// LZ4 data (only with the `compression-lz4` feature).
    #[cfg(feature = "compression-lz4")]
    Lz4,
}
62
63impl CompressionFormat {
64 pub fn from_path<P: AsRef<Path>>(path: P) -> Self {
66 match path.as_ref().extension().and_then(|s| s.to_str()) {
67 #[cfg(feature = "compression")]
68 Some("gz" | "gzip") => CompressionFormat::Gzip,
69
70 #[cfg(feature = "compression-zstd")]
71 Some("zst" | "zstd") => CompressionFormat::Zstd,
72
73 #[cfg(feature = "compression-lz4")]
74 Some("lz4") => CompressionFormat::Lz4,
75
76 _ => CompressionFormat::None,
77 }
78 }
79
80 #[must_use]
82 pub fn from_magic_bytes(bytes: &[u8]) -> Self {
83 if bytes.len() < 2 {
84 return CompressionFormat::None;
85 }
86
87 #[cfg(feature = "compression")]
88 if bytes[0] == 0x1f && bytes[1] == 0x8b {
89 return CompressionFormat::Gzip;
90 }
91
92 if bytes.len() >= 4 {
93 #[cfg(feature = "compression-zstd")]
94 if bytes[0] == 0x28 && bytes[1] == 0xb5 && bytes[2] == 0x2f && bytes[3] == 0xfd {
95 return CompressionFormat::Zstd;
96 }
97
98 #[cfg(feature = "compression-lz4")]
99 if bytes[0] == 0x04 && bytes[1] == 0x22 && bytes[2] == 0x4d && bytes[3] == 0x18 {
100 return CompressionFormat::Lz4;
101 }
102 }
103
104 CompressionFormat::None
105 }
106
107 #[must_use]
109 pub fn is_compressed(&self) -> bool {
110 !matches!(self, CompressionFormat::None)
111 }
112
113 #[must_use]
115 pub fn extension(&self) -> Option<&'static str> {
116 match self {
117 CompressionFormat::None => None,
118 #[cfg(feature = "compression")]
119 CompressionFormat::Gzip => Some("gz"),
120 #[cfg(feature = "compression-zstd")]
121 CompressionFormat::Zstd => Some("zst"),
122 #[cfg(feature = "compression-lz4")]
123 CompressionFormat::Lz4 => Some("lz4"),
124 }
125 }
126}
127
128pub struct CompressionReader<R: Read> {
132 inner: Box<dyn Read>,
133 format: CompressionFormat,
134 _phantom: std::marker::PhantomData<R>,
136}
137
138impl<R: Read + 'static> CompressionReader<R> {
139 pub fn new(mut reader: R) -> io::Result<Self> {
143 let mut magic = [0u8; 4];
145 let bytes_read = Self::read_partial(&mut reader, &mut magic)?;
146
147 let format = CompressionFormat::from_magic_bytes(&magic[..bytes_read]);
149
150 Self::create_decoder(reader, format, Some(magic))
152 }
153
154 pub fn with_format(reader: R, format: CompressionFormat) -> io::Result<Self> {
156 Self::create_decoder(reader, format, None)
157 }
158
159 #[must_use]
161 pub fn format(&self) -> CompressionFormat {
162 self.format
163 }
164
165 fn read_partial(reader: &mut R, buf: &mut [u8]) -> io::Result<usize> {
167 let mut total = 0;
168 while total < buf.len() {
169 match reader.read(&mut buf[total..]) {
170 Ok(0) => break,
171 Ok(n) => total += n,
172 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
173 Err(e) => return Err(e),
174 }
175 }
176 Ok(total)
177 }
178
179 fn create_decoder(
181 reader: R,
182 format: CompressionFormat,
183 magic_prefix: Option<[u8; 4]>,
184 ) -> io::Result<Self> {
185 let inner: Box<dyn Read> = match (format, magic_prefix) {
186 (CompressionFormat::None, Some(magic)) => {
188 let chained = std::io::Cursor::new(magic).chain(reader);
189 Box::new(chained)
190 }
191 (CompressionFormat::None, None) => Box::new(reader),
192
193 #[cfg(feature = "compression")]
195 (CompressionFormat::Gzip, Some(magic)) => {
196 let chained = std::io::Cursor::new(magic).chain(reader);
197 Box::new(flate2::read::GzDecoder::new(chained))
198 }
199 #[cfg(feature = "compression")]
200 (CompressionFormat::Gzip, None) => Box::new(flate2::read::GzDecoder::new(reader)),
201
202 #[cfg(feature = "compression-zstd")]
204 (CompressionFormat::Zstd, Some(magic)) => {
205 let chained = std::io::Cursor::new(magic).chain(reader);
206 let decoder = zstd::Decoder::new(chained)
207 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
208 Box::new(decoder)
209 }
210 #[cfg(feature = "compression-zstd")]
211 (CompressionFormat::Zstd, None) => {
212 let decoder = zstd::Decoder::new(reader)
213 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
214 Box::new(decoder)
215 }
216
217 #[cfg(feature = "compression-lz4")]
219 (CompressionFormat::Lz4, Some(magic)) => {
220 let chained = std::io::Cursor::new(magic).chain(reader);
221 Box::new(lz4_flex::frame::FrameDecoder::new(chained))
222 }
223 #[cfg(feature = "compression-lz4")]
224 (CompressionFormat::Lz4, None) => Box::new(lz4_flex::frame::FrameDecoder::new(reader)),
225 };
226
227 Ok(Self {
228 inner,
229 format,
230 _phantom: std::marker::PhantomData,
231 })
232 }
233}
234
235impl<R: Read> Read for CompressionReader<R> {
236 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
237 self.inner.read(buf)
238 }
239}
240
/// A writer that routes data through the encoder selected at construction.
///
/// Only available with the `compression` feature.
#[cfg(feature = "compression")]
pub struct CompressionWriter<W>
where
    W: std::io::Write + 'static,
{
    /// Encoder state (or the bare writer when no compression is used).
    inner: CompressionWriterInner<W>,
    /// Format this writer was created for.
    format: CompressionFormat,
}
247
/// Internal state of a `CompressionWriter`: one variant per supported
/// encoder, plus a terminal marker.
#[cfg(feature = "compression")]
enum CompressionWriterInner<W>
where
    W: std::io::Write,
{
    /// No compression: the destination writer is used directly.
    Plain(W),
    /// Gzip encoder wrapping the destination writer.
    Gzip(Box<flate2::write::GzEncoder<W>>),
    /// Zstandard encoder wrapping the destination writer.
    #[cfg(feature = "compression-zstd")]
    Zstd(Box<zstd::Encoder<'static, W>>),
    /// LZ4 frame encoder wrapping the destination writer.
    #[cfg(feature = "compression-lz4")]
    Lz4(Box<lz4_flex::frame::FrameEncoder<W>>),
    /// Terminal state. Nothing in the visible code constructs it, hence the
    /// `dead_code` allowance; the match arms that handle it are kept.
    #[allow(dead_code)]
    Finished,
}
261
#[cfg(feature = "compression")]
impl<W: std::io::Write + 'static> CompressionWriter<W> {
    /// Creates a writer for `format` using that format's default level.
    ///
    /// Equivalent to [`CompressionWriter::with_level`] with `level = None`.
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the underlying encoder.
    pub fn new(writer: W, format: CompressionFormat) -> io::Result<Self> {
        Self::with_level(writer, format, None)
    }

    /// Creates a writer for `format` with an optional compression `level`.
    ///
    /// When `level` is `None`, gzip uses level 6 and zstd uses level 3. The
    /// level is ignored for [`CompressionFormat::None`] and for lz4. A zstd
    /// level too large to fit in an `i32` saturates at `i32::MAX` instead of
    /// wrapping around to a negative level.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::InvalidInput`] error if the zstd encoder
    /// cannot be constructed.
    pub fn with_level(
        writer: W,
        format: CompressionFormat,
        level: Option<u32>,
    ) -> io::Result<Self> {
        let inner = match format {
            CompressionFormat::None => CompressionWriterInner::Plain(writer),

            CompressionFormat::Gzip => {
                let level = flate2::Compression::new(level.unwrap_or(6));
                CompressionWriterInner::Gzip(Box::new(flate2::write::GzEncoder::new(writer, level)))
            }

            #[cfg(feature = "compression-zstd")]
            CompressionFormat::Zstd => {
                // A bare `as i32` would turn very large requests into
                // negative levels; clamp first so the cast is lossless.
                let level = level.map_or(3, |requested| requested.min(i32::MAX as u32) as i32);
                CompressionWriterInner::Zstd(Box::new(
                    zstd::Encoder::new(writer, level)
                        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?,
                ))
            }

            #[cfg(feature = "compression-lz4")]
            CompressionFormat::Lz4 => {
                CompressionWriterInner::Lz4(Box::new(lz4_flex::frame::FrameEncoder::new(writer)))
            }
        };

        Ok(Self { inner, format })
    }

    /// Returns the format this writer was created for.
    #[must_use]
    pub fn format(&self) -> CompressionFormat {
        self.format
    }

    /// Consumes the writer, finalises the encoder, and returns the
    /// destination writer.
    ///
    /// # Errors
    ///
    /// Returns the encoder's finalisation error, or an
    /// [`io::ErrorKind::Other`] error if the writer was already finished.
    pub fn finish(self) -> io::Result<W> {
        match self.inner {
            CompressionWriterInner::Plain(w) => Ok(w),
            CompressionWriterInner::Gzip(w) => w.finish(),

            #[cfg(feature = "compression-zstd")]
            CompressionWriterInner::Zstd(w) => w
                .finish()
                .map_err(|e| io::Error::new(io::ErrorKind::Other, e)),

            #[cfg(feature = "compression-lz4")]
            CompressionWriterInner::Lz4(w) => w
                .finish()
                .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string())),

            CompressionWriterInner::Finished => Err(io::Error::new(
                io::ErrorKind::Other,
                "Writer already finished",
            )),
        }
    }
}
329
#[cfg(feature = "compression")]
impl<W: std::io::Write + 'static> std::io::Write for CompressionWriter<W> {
    /// Writes `buf` to the active encoder (or the bare writer).
    ///
    /// Fails with an [`io::ErrorKind::Other`] error if the writer is in the
    /// finished state.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // Pick the active sink once, then make a single call on it.
        let sink: &mut dyn std::io::Write = match &mut self.inner {
            CompressionWriterInner::Plain(plain) => plain,
            CompressionWriterInner::Gzip(gzip) => &mut **gzip,

            #[cfg(feature = "compression-zstd")]
            CompressionWriterInner::Zstd(zstd_encoder) => &mut **zstd_encoder,

            #[cfg(feature = "compression-lz4")]
            CompressionWriterInner::Lz4(lz4_encoder) => &mut **lz4_encoder,

            CompressionWriterInner::Finished => {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    "Writer already finished",
                ));
            }
        };

        sink.write(buf)
    }

    /// Flushes the active encoder (or the bare writer).
    ///
    /// Flushing a finished writer is a no-op that succeeds.
    fn flush(&mut self) -> io::Result<()> {
        let sink: &mut dyn std::io::Write = match &mut self.inner {
            CompressionWriterInner::Plain(plain) => plain,
            CompressionWriterInner::Gzip(gzip) => &mut **gzip,

            #[cfg(feature = "compression-zstd")]
            CompressionWriterInner::Zstd(zstd_encoder) => &mut **zstd_encoder,

            #[cfg(feature = "compression-lz4")]
            CompressionWriterInner::Lz4(lz4_encoder) => &mut **lz4_encoder,

            CompressionWriterInner::Finished => return Ok(()),
        };

        sink.flush()
    }
}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Extensions that are not compression suffixes map to `None`.
    #[test]
    fn test_format_from_path_uncompressed() {
        assert_eq!(
            CompressionFormat::from_path("data.hedl"),
            CompressionFormat::None
        );
        assert_eq!(
            CompressionFormat::from_path("data.txt"),
            CompressionFormat::None
        );
    }

    #[cfg(feature = "compression")]
    #[test]
    fn test_format_from_path_gzip() {
        assert_eq!(
            CompressionFormat::from_path("data.hedl.gz"),
            CompressionFormat::Gzip
        );
    }

    #[cfg(feature = "compression-zstd")]
    #[test]
    fn test_format_from_path_zstd() {
        assert_eq!(
            CompressionFormat::from_path("data.zst"),
            CompressionFormat::Zstd
        );
    }

    #[cfg(feature = "compression-lz4")]
    #[test]
    fn test_format_from_path_lz4() {
        assert_eq!(
            CompressionFormat::from_path("data.lz4"),
            CompressionFormat::Lz4
        );
    }

    #[cfg(feature = "compression")]
    #[test]
    fn test_format_from_magic_gzip() {
        assert_eq!(
            CompressionFormat::from_magic_bytes(&[0x1f, 0x8b, 0x08, 0x00]),
            CompressionFormat::Gzip
        );
    }

    #[cfg(feature = "compression-zstd")]
    #[test]
    fn test_format_from_magic_zstd() {
        assert_eq!(
            CompressionFormat::from_magic_bytes(&[0x28, 0xb5, 0x2f, 0xfd]),
            CompressionFormat::Zstd
        );
    }

    /// Fewer than two bytes can never match a signature.
    #[test]
    fn test_format_from_magic_too_short() {
        assert_eq!(
            CompressionFormat::from_magic_bytes(&[]),
            CompressionFormat::None
        );
        assert_eq!(
            CompressionFormat::from_magic_bytes(&[0x1f]),
            CompressionFormat::None
        );
    }

    /// Uncompressed data must come back complete, not just its sniffed prefix.
    #[test]
    fn test_compression_reader_uncompressed() {
        let data = b"Hello, World!";
        let reader = CompressionReader::new(std::io::Cursor::new(data.to_vec())).unwrap();
        assert_eq!(reader.format(), CompressionFormat::None);

        let mut output = String::new();
        std::io::BufReader::new(reader)
            .read_to_string(&mut output)
            .unwrap();
        assert_eq!(output, "Hello, World!");
    }

    /// An explicit `None` format passes the stream through untouched.
    #[test]
    fn test_compression_reader_with_format_passthrough() {
        let data = b"Hello, World!";
        let mut reader = CompressionReader::with_format(
            std::io::Cursor::new(data.to_vec()),
            CompressionFormat::None,
        )
        .unwrap();
        assert_eq!(reader.format(), CompressionFormat::None);

        let mut output = String::new();
        reader.read_to_string(&mut output).unwrap();
        assert_eq!(output, "Hello, World!");
    }

    #[cfg(feature = "compression")]
    #[test]
    fn test_compression_reader_gzip_roundtrip() {
        use std::io::Write;

        let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::fast());
        encoder.write_all(b"Hello, HEDL!").unwrap();
        let compressed = encoder.finish().unwrap();

        let reader = CompressionReader::new(std::io::Cursor::new(compressed)).unwrap();
        assert_eq!(reader.format(), CompressionFormat::Gzip);

        let mut output = String::new();
        std::io::BufReader::new(reader)
            .read_to_string(&mut output)
            .unwrap();
        assert_eq!(output, "Hello, HEDL!");
    }

    #[cfg(feature = "compression")]
    #[test]
    fn test_compression_writer_gzip_roundtrip() {
        use std::io::Write;

        let mut writer = CompressionWriter::new(Vec::new(), CompressionFormat::Gzip).unwrap();
        write!(writer, "Hello, HEDL!").unwrap();
        let compressed = writer.finish().unwrap();

        let mut decoder = flate2::read::GzDecoder::new(std::io::Cursor::new(compressed));
        let mut output = String::new();
        decoder.read_to_string(&mut output).unwrap();
        assert_eq!(output, "Hello, HEDL!");
    }
}