Skip to main content

include_flate_compress/
lib.rs

1// include-flate
2// Copyright (C) SOFe, Kento Oki
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#[cfg(not(any(feature = "zstd", feature = "deflate")))]
17compile_error!("You must enable either the `deflate` or `zstd` feature.");
18
19use std::{
20    fmt,
21    io::{self, Read, Write},
22};
23
24#[cfg(feature = "deflate")]
25use libflate::deflate::Decoder as DeflateDecoder;
26#[cfg(feature = "deflate")]
27use libflate::deflate::Encoder as DeflateEncoder;
28#[cfg(feature = "zstd")]
29use zstd::Decoder as ZstdDecoder;
30#[cfg(feature = "zstd")]
31use zstd::Encoder as ZstdEncoder;
32
/// Error produced while building or driving a compressor/decompressor.
///
/// The backend-specific variants exist only when the matching cargo feature
/// (`deflate` / `zstd`) is enabled; `IoError` is always present.
#[derive(Debug)]
pub enum CompressionError {
    /// Failure reported by the deflate backend.
    #[cfg(feature = "deflate")]
    DeflateError(io::Error),
    /// Failure reported by the zstd backend.
    #[cfg(feature = "zstd")]
    ZstdError(io::Error),
    /// Any other I/O failure, e.g. while copying between reader and writer.
    IoError(io::Error),
}

impl From<io::Error> for CompressionError {
    /// Wraps a plain I/O error as [`CompressionError::IoError`]; this is the
    /// conversion `?` uses on `io::Result` values.
    fn from(err: io::Error) -> Self {
        CompressionError::IoError(err)
    }
}

impl fmt::Display for CompressionError {
    /// Prefixes the wrapped I/O error's message with the kind of failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            #[cfg(feature = "deflate")]
            CompressionError::DeflateError(err) => write!(f, "Deflate error: {}", err),
            #[cfg(feature = "zstd")]
            CompressionError::ZstdError(err) => write!(f, "Zstd error: {}", err),
            CompressionError::IoError(err) => write!(f, "I/O error: {}", err),
        }
    }
}

// Lets callers box this as `dyn Error` / propagate it with `?` into generic
// error types. `source()` is intentionally left as the default (`None`):
// `Display` already includes the inner error's message, and reporting it as a
// source too would print it twice in error-chain output.
impl std::error::Error for CompressionError {}
59
/// Compression algorithm used to encode or decode a payload.
///
/// Only the variants whose cargo feature is enabled exist. The type is a
/// small `Copy` value, so it also derives `Eq` and `Hash` so it can be
/// compared exhaustively and used as a map/set key.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum CompressionMethod {
    /// DEFLATE, available with the `deflate` feature.
    #[cfg(feature = "deflate")]
    Deflate,
    /// Zstandard, available with the `zstd` feature.
    #[cfg(feature = "zstd")]
    Zstd,
}
67
68impl CompressionMethod {
69    pub fn encoder<W: Write>(
70        &self,
71        write: W,
72    ) -> Result<FlateEncoder<W>, CompressionError> {
73        FlateEncoder::new(*self, write)
74    }
75
76    pub fn decoder<R: Read>(
77        &self,
78        read: R,
79    ) -> Result<FlateDecoder<R>, CompressionError> {
80        FlateDecoder::new(*self, read)
81    }
82}
83
#[cfg(any(feature = "deflate", feature = "zstd"))]
#[expect(clippy::derivable_impls, reason = "cfg_attr on defaults could be confusing")]
impl Default for CompressionMethod {
    /// Prefers deflate whenever it is compiled in, and falls back to zstd
    /// otherwise.
    fn default() -> Self {
        #[cfg(feature = "deflate")]
        {
            Self::Deflate
        }
        // This impl only exists when at least one backend is enabled, so
        // "deflate is off" here already implies zstd is on.
        #[cfg(not(feature = "deflate"))]
        {
            Self::Zstd
        }
    }
}
98
99impl fmt::Display for CompressionMethod {
100    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101        f.write_str(match self {
102            #[cfg(feature = "deflate")]
103            Self::Deflate => "deflate",
104            #[cfg(feature = "zstd")]
105            Self::Zstd => "zstd",
106        })
107    }
108}
109
110pub enum FlateEncoder<W: Write> {
111    #[cfg(feature = "deflate")]
112    Deflate(DeflateEncoder<W>),
113    #[cfg(feature = "zstd")]
114    Zstd(ZstdEncoder<'static, W>),
115}
116
117impl<W: Write> FlateEncoder<W> {
118    pub fn new(
119        method: CompressionMethod,
120        write: W,
121    ) -> Result<FlateEncoder<W>, CompressionError> {
122        match method {
123            #[cfg(feature = "deflate")]
124            CompressionMethod::Deflate => Ok(FlateEncoder::Deflate(DeflateEncoder::new(write))),
125            #[cfg(feature = "zstd")]
126            CompressionMethod::Zstd => ZstdEncoder::new(write, 0)
127                .map(FlateEncoder::Zstd)
128                .map_err(CompressionError::ZstdError),
129        }
130    }
131}
132
133impl<W: Write> Write for FlateEncoder<W> {
134    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
135        match self {
136            #[cfg(feature = "deflate")]
137            FlateEncoder::Deflate(encoder) => encoder.write(buf),
138            #[cfg(feature = "zstd")]
139            FlateEncoder::Zstd(encoder) => encoder.write(buf),
140        }
141    }
142
143    fn flush(&mut self) -> io::Result<()> {
144        match self {
145            #[cfg(feature = "deflate")]
146            FlateEncoder::Deflate(encoder) => encoder.flush(),
147            #[cfg(feature = "zstd")]
148            FlateEncoder::Zstd(encoder) => encoder.flush(),
149        }
150    }
151}
152
153impl<W: Write> FlateEncoder<W> {
154    fn finish_encode(self) -> Result<W, CompressionError> {
155        match self {
156            #[cfg(feature = "deflate")]
157            FlateEncoder::Deflate(encoder) => encoder
158                .finish()
159                .into_result()
160                .map_err(CompressionError::DeflateError),
161            #[cfg(feature = "zstd")]
162            FlateEncoder::Zstd(encoder) => {
163                encoder.finish().map_err(CompressionError::ZstdError)
164            }
165        }
166    }
167}
168
169pub enum FlateDecoder<R> {
170    #[cfg(feature = "deflate")]
171    Deflate(DeflateDecoder<R>),
172    #[cfg(feature = "zstd")]
173    Zstd(ZstdDecoder<'static, std::io::BufReader<R>>),
174}
175
176impl<R: Read> FlateDecoder<R> {
177    pub fn new(
178        method: CompressionMethod,
179        read: R,
180    ) -> Result<FlateDecoder<R>, CompressionError> {
181        match method {
182            #[cfg(feature = "deflate")]
183            CompressionMethod::Deflate => Ok(FlateDecoder::Deflate(DeflateDecoder::new(read))),
184            #[cfg(feature = "zstd")]
185            CompressionMethod::Zstd => {
186                let decoder = ZstdDecoder::new(read)?;
187                Ok(FlateDecoder::Zstd(decoder))
188            }
189        }
190    }
191}
192
193impl<R: Read> Read for FlateDecoder<R> {
194    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
195        match self {
196            #[cfg(feature = "deflate")]
197            FlateDecoder::Deflate(decoder) => decoder.read(buf),
198            #[cfg(feature = "zstd")]
199            FlateDecoder::Zstd(decoder) => decoder.read(buf),
200        }
201    }
202}
203
204pub fn apply_compression<R, W>(
205    reader: &mut R,
206    writer: &mut W,
207    method: CompressionMethod,
208) -> Result<(), CompressionError>
209where
210    R: Read,
211    W: Write,
212{
213    let mut encoder = method.encoder(writer)?;
214    io::copy(reader, &mut encoder)?;
215    encoder.finish_encode().map(|_| ())
216}
217
218pub fn apply_decompression(
219    reader: impl Read,
220    mut writer: impl Write,
221    method: CompressionMethod,
222) -> Result<(), CompressionError>
223{
224    let mut decoder = method.decoder(reader)?;
225    io::copy(&mut decoder, &mut writer)?;
226    Ok(())
227}