// hedl_stream/compression.rs

// Dweve HEDL - Hierarchical Entity Data Language
//
// Copyright (c) 2025 Dweve IP B.V. and individual contributors.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License in the LICENSE file at the
// root of this repository or at: http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Transparent compression support for streaming HEDL parsing.
//!
//! This module provides automatic compression format detection and transparent
//! decompression for HEDL files. Supported formats:
//!
//! - **GZIP** (`.gz`, `.gzip`) - Wide compatibility, HTTP standard
//! - **ZSTD** (`.zst`, `.zstd`) - Best compression ratio/speed balance (optional)
//! - **LZ4** (`.lz4`) - Fastest decompression (optional)
//!
//! # Examples
//!
//! ```rust,no_run
//! use hedl_stream::compression::{CompressionFormat, CompressionReader};
//! use std::fs::File;
//!
//! // Auto-detect from file extension
//! let format = CompressionFormat::from_path("data.hedl.gz");
//! assert!(matches!(format, CompressionFormat::Gzip));
//! ```

use std::io::{self, Read};
use std::path::Path;

/// Compression format for HEDL files.
///
/// Detected automatically from file extension or magic bytes. Every variant
/// other than [`CompressionFormat::None`] only exists when the corresponding
/// Cargo feature is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionFormat {
    /// No compression (plain HEDL text).
    #[default]
    None,

    /// GZIP compression (RFC 1952).
    #[cfg(feature = "compression")]
    Gzip,

    /// Zstandard compression (RFC 8878).
    #[cfg(feature = "compression-zstd")]
    Zstd,

    /// LZ4 frame compression.
    #[cfg(feature = "compression-lz4")]
    Lz4,
}

impl CompressionFormat {
    /// Detect compression format from file path extension.
    ///
    /// Only the final extension is inspected (`data.hedl.gz` is matched on
    /// `gz`), and the comparison ignores ASCII case so that `DATA.GZ` is
    /// recognised as well. Paths without an extension, with a non-UTF-8
    /// extension, or with an extension whose feature is disabled yield
    /// [`CompressionFormat::None`].
    #[must_use]
    pub fn from_path<P: AsRef<Path>>(path: P) -> Self {
        let ext = path
            .as_ref()
            .extension()
            .and_then(|s| s.to_str())
            .map(str::to_ascii_lowercase);

        match ext.as_deref() {
            #[cfg(feature = "compression")]
            Some("gz" | "gzip") => CompressionFormat::Gzip,

            #[cfg(feature = "compression-zstd")]
            Some("zst" | "zstd") => CompressionFormat::Zstd,

            #[cfg(feature = "compression-lz4")]
            Some("lz4") => CompressionFormat::Lz4,

            _ => CompressionFormat::None,
        }
    }

    /// Detect compression format from magic bytes.
    ///
    /// `bytes` is the start of the stream; it may be shorter than the longest
    /// signature. Anything that does not begin with a known signature (for an
    /// enabled feature) is reported as [`CompressionFormat::None`].
    #[must_use]
    pub fn from_magic_bytes(bytes: &[u8]) -> Self {
        // The shortest signature checked below (gzip) is two bytes long.
        if bytes.len() < 2 {
            return CompressionFormat::None;
        }

        #[cfg(feature = "compression")]
        if bytes.starts_with(&[0x1f, 0x8b]) {
            return CompressionFormat::Gzip;
        }

        #[cfg(feature = "compression-zstd")]
        if bytes.starts_with(&[0x28, 0xb5, 0x2f, 0xfd]) {
            return CompressionFormat::Zstd;
        }

        #[cfg(feature = "compression-lz4")]
        if bytes.starts_with(&[0x04, 0x22, 0x4d, 0x18]) {
            return CompressionFormat::Lz4;
        }

        CompressionFormat::None
    }

    /// Returns whether compression is enabled for this format.
    ///
    /// `false` only for [`CompressionFormat::None`].
    #[must_use]
    pub fn is_compressed(&self) -> bool {
        !matches!(self, CompressionFormat::None)
    }

    /// Returns the file extension typically used for this format.
    ///
    /// The extension is returned without a leading dot;
    /// [`CompressionFormat::None`] has no extension.
    #[must_use]
    pub fn extension(&self) -> Option<&'static str> {
        match self {
            CompressionFormat::None => None,
            #[cfg(feature = "compression")]
            CompressionFormat::Gzip => Some("gz"),
            #[cfg(feature = "compression-zstd")]
            CompressionFormat::Zstd => Some("zst"),
            #[cfg(feature = "compression-lz4")]
            CompressionFormat::Lz4 => Some("lz4"),
        }
    }
}

128/// A reader that transparently decompresses data based on format.
129///
130/// Uses boxed trait objects for simplicity and type erasure.
131pub struct CompressionReader<R: Read> {
132    inner: Box<dyn Read>,
133    format: CompressionFormat,
134    // Keep the phantom to maintain the type parameter in the signature
135    _phantom: std::marker::PhantomData<R>,
136}
137
138impl<R: Read + 'static> CompressionReader<R> {
139    /// Create a compression reader with automatic format detection.
140    ///
141    /// Reads the first 4 bytes to detect the compression format.
142    pub fn new(mut reader: R) -> io::Result<Self> {
143        // Read magic bytes for format detection
144        let mut magic = [0u8; 4];
145        let bytes_read = Self::read_partial(&mut reader, &mut magic)?;
146
147        // Detect format from magic bytes
148        let format = CompressionFormat::from_magic_bytes(&magic[..bytes_read]);
149
150        // Create the appropriate decoder
151        Self::create_decoder(reader, format, Some(magic))
152    }
153
154    /// Create a compression reader with explicit format specification.
155    pub fn with_format(reader: R, format: CompressionFormat) -> io::Result<Self> {
156        Self::create_decoder(reader, format, None)
157    }
158
159    /// Get the detected or specified compression format.
160    #[must_use]
161    pub fn format(&self) -> CompressionFormat {
162        self.format
163    }
164
165    /// Read up to `buf.len()` bytes, returning actual bytes read.
166    fn read_partial(reader: &mut R, buf: &mut [u8]) -> io::Result<usize> {
167        let mut total = 0;
168        while total < buf.len() {
169            match reader.read(&mut buf[total..]) {
170                Ok(0) => break,
171                Ok(n) => total += n,
172                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
173                Err(e) => return Err(e),
174            }
175        }
176        Ok(total)
177    }
178
179    /// Create the appropriate decoder based on format.
180    fn create_decoder(
181        reader: R,
182        format: CompressionFormat,
183        magic_prefix: Option<[u8; 4]>,
184    ) -> io::Result<Self> {
185        let inner: Box<dyn Read> = match (format, magic_prefix) {
186            // Uncompressed - chain magic bytes back if we read them
187            (CompressionFormat::None, Some(magic)) => {
188                let chained = std::io::Cursor::new(magic).chain(reader);
189                Box::new(chained)
190            }
191            (CompressionFormat::None, None) => Box::new(reader),
192
193            // GZIP
194            #[cfg(feature = "compression")]
195            (CompressionFormat::Gzip, Some(magic)) => {
196                let chained = std::io::Cursor::new(magic).chain(reader);
197                Box::new(flate2::read::GzDecoder::new(chained))
198            }
199            #[cfg(feature = "compression")]
200            (CompressionFormat::Gzip, None) => Box::new(flate2::read::GzDecoder::new(reader)),
201
202            // ZSTD
203            #[cfg(feature = "compression-zstd")]
204            (CompressionFormat::Zstd, Some(magic)) => {
205                let chained = std::io::Cursor::new(magic).chain(reader);
206                let decoder = zstd::Decoder::new(chained)
207                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
208                Box::new(decoder)
209            }
210            #[cfg(feature = "compression-zstd")]
211            (CompressionFormat::Zstd, None) => {
212                let decoder = zstd::Decoder::new(reader)
213                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
214                Box::new(decoder)
215            }
216
217            // LZ4
218            #[cfg(feature = "compression-lz4")]
219            (CompressionFormat::Lz4, Some(magic)) => {
220                let chained = std::io::Cursor::new(magic).chain(reader);
221                Box::new(lz4_flex::frame::FrameDecoder::new(chained))
222            }
223            #[cfg(feature = "compression-lz4")]
224            (CompressionFormat::Lz4, None) => Box::new(lz4_flex::frame::FrameDecoder::new(reader)),
225        };
226
227        Ok(Self {
228            inner,
229            format,
230            _phantom: std::marker::PhantomData,
231        })
232    }
233}
234
235impl<R: Read> Read for CompressionReader<R> {
236    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
237        self.inner.read(buf)
238    }
239}
240
/// A writer that compresses data as it is written.
///
/// Bytes written through this type are passed to the encoder selected by the
/// [`CompressionFormat`] given at construction. Call
/// [`CompressionWriter::finish`] to finish compression and get the underlying
/// writer back.
#[cfg(feature = "compression")]
pub struct CompressionWriter<W: std::io::Write + 'static> {
    /// The active encoder (or the plain writer when no compression is used).
    inner: CompressionWriterInner<W>,
    /// The format this writer was created with.
    format: CompressionFormat,
}

/// The concrete sink behind a [`CompressionWriter`], one variant per format.
#[cfg(feature = "compression")]
enum CompressionWriterInner<W: std::io::Write> {
    /// Pass-through: bytes go to the writer unmodified.
    Plain(W),
    // Box the large encoder types to reduce enum variant size
    /// GZIP encoder.
    Gzip(Box<flate2::write::GzEncoder<W>>),
    /// Zstandard encoder.
    #[cfg(feature = "compression-zstd")]
    Zstd(Box<zstd::Encoder<'static, W>>),
    /// LZ4 frame encoder.
    #[cfg(feature = "compression-lz4")]
    Lz4(Box<lz4_flex::frame::FrameEncoder<W>>),
}

#[cfg(feature = "compression")]
impl<W: std::io::Write + 'static> CompressionWriter<W> {
    /// Create a compression writer with the specified format.
    ///
    /// Equivalent to [`CompressionWriter::with_level`] with `level` set to
    /// `None`, i.e. each format's default level.
    ///
    /// # Errors
    ///
    /// Returns an error if the encoder for `format` cannot be constructed.
    pub fn new(writer: W, format: CompressionFormat) -> io::Result<Self> {
        Self::with_level(writer, format, None)
    }

    /// Create a compression writer with a specific compression level.
    ///
    /// When `level` is `None`, GZIP uses level 6 and Zstandard uses level 3.
    /// The level is not used for [`CompressionFormat::None`] or LZ4.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidInput` if the Zstandard encoder cannot be
    /// created or if the requested Zstandard level does not fit in an `i32`.
    pub fn with_level(
        writer: W,
        format: CompressionFormat,
        level: Option<u32>,
    ) -> io::Result<Self> {
        let inner = match format {
            CompressionFormat::None => CompressionWriterInner::Plain(writer),

            CompressionFormat::Gzip => {
                let level = flate2::Compression::new(level.unwrap_or(6));
                CompressionWriterInner::Gzip(Box::new(flate2::write::GzEncoder::new(writer, level)))
            }

            #[cfg(feature = "compression-zstd")]
            CompressionFormat::Zstd => {
                // A checked conversion: an `as` cast would silently wrap
                // out-of-range levels into negative values.
                let level = i32::try_from(level.unwrap_or(3))
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
                CompressionWriterInner::Zstd(Box::new(
                    zstd::Encoder::new(writer, level)
                        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?,
                ))
            }

            #[cfg(feature = "compression-lz4")]
            CompressionFormat::Lz4 => {
                CompressionWriterInner::Lz4(Box::new(lz4_flex::frame::FrameEncoder::new(writer)))
            }
        };

        Ok(Self { inner, format })
    }

    /// Get the compression format being used.
    #[must_use]
    pub fn format(&self) -> CompressionFormat {
        self.format
    }

    /// Finish compression and return the underlying writer.
    ///
    /// Consumes the writer. For the plain (uncompressed) case the wrapped
    /// writer is returned directly.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the encoder's own `finish` call;
    /// Zstandard and LZ4 failures are reported with `ErrorKind::Other`.
    pub fn finish(self) -> io::Result<W> {
        match self.inner {
            CompressionWriterInner::Plain(w) => Ok(w),
            CompressionWriterInner::Gzip(w) => w.finish(),

            #[cfg(feature = "compression-zstd")]
            CompressionWriterInner::Zstd(w) => w
                .finish()
                .map_err(|e| io::Error::new(io::ErrorKind::Other, e)),

            #[cfg(feature = "compression-lz4")]
            CompressionWriterInner::Lz4(w) => w
                .finish()
                .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string())),
        }
    }
}

#[cfg(feature = "compression")]
impl<W: std::io::Write + 'static> std::io::Write for CompressionWriter<W> {
    /// Write `buf` to the active encoder (or straight to the writer when no
    /// compression is in use) and return how many bytes it accepted.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match &mut self.inner {
            CompressionWriterInner::Plain(w) => w.write(buf),
            CompressionWriterInner::Gzip(w) => w.write(buf),

            #[cfg(feature = "compression-zstd")]
            CompressionWriterInner::Zstd(w) => w.write(buf),

            #[cfg(feature = "compression-lz4")]
            CompressionWriterInner::Lz4(w) => w.write(buf),
        }
    }

    /// Flush the active encoder (or the plain writer).
    ///
    /// This only forwards `flush`; use [`CompressionWriter::finish`] to
    /// finish compression.
    fn flush(&mut self) -> io::Result<()> {
        match &mut self.inner {
            CompressionWriterInner::Plain(w) => w.flush(),
            CompressionWriterInner::Gzip(w) => w.flush(),

            #[cfg(feature = "compression-zstd")]
            CompressionWriterInner::Zstd(w) => w.flush(),

            #[cfg(feature = "compression-lz4")]
            CompressionWriterInner::Lz4(w) => w.flush(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Extensions that are not compression suffixes map to `None`.
    #[test]
    fn test_format_from_path_uncompressed() {
        assert_eq!(
            CompressionFormat::from_path("data.hedl"),
            CompressionFormat::None
        );
        assert_eq!(
            CompressionFormat::from_path("data.txt"),
            CompressionFormat::None
        );
    }

    #[cfg(feature = "compression")]
    #[test]
    fn test_format_from_path_gzip() {
        assert_eq!(
            CompressionFormat::from_path("data.hedl.gz"),
            CompressionFormat::Gzip
        );
        assert_eq!(
            CompressionFormat::from_path("data.hedl.gzip"),
            CompressionFormat::Gzip
        );
    }

    #[cfg(feature = "compression-zstd")]
    #[test]
    fn test_format_from_path_zstd() {
        assert_eq!(
            CompressionFormat::from_path("data.zst"),
            CompressionFormat::Zstd
        );
        assert_eq!(
            CompressionFormat::from_path("data.zstd"),
            CompressionFormat::Zstd
        );
    }

    #[cfg(feature = "compression-lz4")]
    #[test]
    fn test_format_from_path_lz4() {
        assert_eq!(
            CompressionFormat::from_path("data.lz4"),
            CompressionFormat::Lz4
        );
    }

    /// Too-short or unrecognised prefixes are treated as plain data.
    #[test]
    fn test_format_from_magic_unknown() {
        assert_eq!(
            CompressionFormat::from_magic_bytes(&[]),
            CompressionFormat::None
        );
        assert_eq!(
            CompressionFormat::from_magic_bytes(&[0x1f]),
            CompressionFormat::None
        );
        assert_eq!(
            CompressionFormat::from_magic_bytes(b"Hell"),
            CompressionFormat::None
        );
    }

    #[cfg(feature = "compression")]
    #[test]
    fn test_format_from_magic_gzip() {
        assert_eq!(
            CompressionFormat::from_magic_bytes(&[0x1f, 0x8b, 0x08, 0x00]),
            CompressionFormat::Gzip
        );
    }

    #[cfg(feature = "compression-zstd")]
    #[test]
    fn test_format_from_magic_zstd() {
        assert_eq!(
            CompressionFormat::from_magic_bytes(&[0x28, 0xb5, 0x2f, 0xfd]),
            CompressionFormat::Zstd
        );
    }

    #[cfg(feature = "compression-lz4")]
    #[test]
    fn test_format_from_magic_lz4() {
        assert_eq!(
            CompressionFormat::from_magic_bytes(&[0x04, 0x22, 0x4d, 0x18]),
            CompressionFormat::Lz4
        );
    }

    #[test]
    fn test_compression_reader_uncompressed() {
        let data = b"Hello, World!";
        let mut reader = CompressionReader::new(std::io::Cursor::new(data.to_vec())).unwrap();
        assert_eq!(reader.format(), CompressionFormat::None);

        let mut output = String::new();
        reader.read_to_string(&mut output).unwrap();
        // The bytes consumed for detection are chained back, so the whole
        // payload must come through, not just its first four bytes.
        assert_eq!(output, "Hello, World!");
    }

    #[cfg(feature = "compression")]
    #[test]
    fn test_compression_reader_gzip_roundtrip() {
        use std::io::Write;

        // Create compressed data
        let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::fast());
        encoder.write_all(b"Hello, HEDL!").unwrap();
        let compressed = encoder.finish().unwrap();

        // Read it back
        let mut reader = CompressionReader::new(std::io::Cursor::new(compressed)).unwrap();
        assert_eq!(reader.format(), CompressionFormat::Gzip);

        let mut output = String::new();
        reader.read_to_string(&mut output).unwrap();
        assert_eq!(output, "Hello, HEDL!");
    }

    #[cfg(feature = "compression")]
    #[test]
    fn test_compression_writer_gzip_roundtrip() {
        use std::io::Write;

        // Write compressed data
        let mut writer = CompressionWriter::new(Vec::new(), CompressionFormat::Gzip).unwrap();
        assert_eq!(writer.format(), CompressionFormat::Gzip);
        write!(writer, "Hello, HEDL!").unwrap();
        let compressed = writer.finish().unwrap();

        // Read it back with flate2 directly
        let mut decoder = flate2::read::GzDecoder::new(std::io::Cursor::new(compressed));
        let mut output = String::new();
        decoder.read_to_string(&mut output).unwrap();
        assert_eq!(output, "Hello, HEDL!");
    }
}