Skip to main content

include_flate_compress/
lib.rs

1// include-flate
2// Copyright (C) SOFe, Kento Oki
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#[cfg(not(any(feature = "zstd", feature = "deflate")))]
17compile_error!("You must enable either the `deflate` or `zstd` feature.");
18
19use std::{
20    fmt,
21    io::{self, Read, Write},
22};
23
24#[cfg(feature = "deflate")]
25use libflate::deflate::Decoder as DeflateDecoder;
26#[cfg(feature = "deflate")]
27use libflate::deflate::Encoder as DeflateEncoder;
28#[cfg(feature = "zstd")]
29use zstd::Decoder as ZstdDecoder;
30#[cfg(feature = "zstd")]
31use zstd::Encoder as ZstdEncoder;
32
33#[derive(Debug)]
34pub enum CompressionError {
35    #[cfg(feature = "deflate")]
36    DeflateError(io::Error),
37    #[cfg(feature = "zstd")]
38    ZstdError(io::Error),
39    IoError(io::Error),
40}
41
42impl From<io::Error> for CompressionError {
43    fn from(err: io::Error) -> Self {
44        CompressionError::IoError(err)
45    }
46}
47
48impl fmt::Display for CompressionError {
49    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
50        match self {
51            #[cfg(feature = "deflate")]
52            CompressionError::DeflateError(err) => write!(f, "Deflate error: {}", err),
53            #[cfg(feature = "zstd")]
54            CompressionError::ZstdError(err) => write!(f, "Zstd error: {}", err),
55            CompressionError::IoError(err) => write!(f, "I/O error: {}", err),
56        }
57    }
58}
59
60#[derive(Debug, Copy, Clone, PartialEq)]
61pub enum CompressionMethod {
62    #[cfg(feature = "deflate")]
63    Deflate,
64    #[cfg(feature = "zstd")]
65    Zstd,
66}
67
68impl CompressionMethod {
69    pub fn encoder<W: Write>(&self, write: W) -> Result<FlateEncoder<W>, CompressionError> {
70        FlateEncoder::new(*self, write)
71    }
72
73    pub fn decoder<R: Read>(&self, read: R) -> Result<FlateDecoder<R>, CompressionError> {
74        FlateDecoder::new(*self, read)
75    }
76}
77
#[cfg(any(feature = "deflate", feature = "zstd"))]
#[expect(
    clippy::derivable_impls,
    reason = "cfg_attr on defaults could be confusing"
)]
impl Default for CompressionMethod {
    /// Returns the preferred backend.
    ///
    /// Deflate wins whenever it is compiled in; zstd is only the default when
    /// it is the sole enabled backend.
    fn default() -> Self {
        #[cfg(feature = "deflate")]
        {
            Self::Deflate
        }
        #[cfg(all(feature = "zstd", not(feature = "deflate")))]
        {
            Self::Zstd
        }
    }
}
95
96impl fmt::Display for CompressionMethod {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        f.write_str(match self {
99            #[cfg(feature = "deflate")]
100            Self::Deflate => "deflate",
101            #[cfg(feature = "zstd")]
102            Self::Zstd => "zstd",
103        })
104    }
105}
106
107pub enum FlateEncoder<W: Write> {
108    #[cfg(feature = "deflate")]
109    Deflate(DeflateEncoder<W>),
110    #[cfg(feature = "zstd")]
111    Zstd(ZstdEncoder<'static, W>),
112}
113
114impl<W: Write> FlateEncoder<W> {
115    pub fn new(method: CompressionMethod, write: W) -> Result<FlateEncoder<W>, CompressionError> {
116        match method {
117            #[cfg(feature = "deflate")]
118            CompressionMethod::Deflate => Ok(FlateEncoder::Deflate(DeflateEncoder::new(write))),
119            #[cfg(feature = "zstd")]
120            CompressionMethod::Zstd => ZstdEncoder::new(write, 0)
121                .map(FlateEncoder::Zstd)
122                .map_err(CompressionError::ZstdError),
123        }
124    }
125}
126
127impl<W: Write> Write for FlateEncoder<W> {
128    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
129        match self {
130            #[cfg(feature = "deflate")]
131            FlateEncoder::Deflate(encoder) => encoder.write(buf),
132            #[cfg(feature = "zstd")]
133            FlateEncoder::Zstd(encoder) => encoder.write(buf),
134        }
135    }
136
137    fn flush(&mut self) -> io::Result<()> {
138        match self {
139            #[cfg(feature = "deflate")]
140            FlateEncoder::Deflate(encoder) => encoder.flush(),
141            #[cfg(feature = "zstd")]
142            FlateEncoder::Zstd(encoder) => encoder.flush(),
143        }
144    }
145}
146
147impl<W: Write> FlateEncoder<W> {
148    fn finish_encode(self) -> Result<W, CompressionError> {
149        match self {
150            #[cfg(feature = "deflate")]
151            FlateEncoder::Deflate(encoder) => encoder
152                .finish()
153                .into_result()
154                .map_err(CompressionError::DeflateError),
155            #[cfg(feature = "zstd")]
156            FlateEncoder::Zstd(encoder) => encoder.finish().map_err(CompressionError::ZstdError),
157        }
158    }
159}
160
161pub enum FlateDecoder<R> {
162    #[cfg(feature = "deflate")]
163    Deflate(DeflateDecoder<R>),
164    #[cfg(feature = "zstd")]
165    Zstd(ZstdDecoder<'static, std::io::BufReader<R>>),
166}
167
168impl<R: Read> FlateDecoder<R> {
169    pub fn new(method: CompressionMethod, read: R) -> Result<FlateDecoder<R>, CompressionError> {
170        match method {
171            #[cfg(feature = "deflate")]
172            CompressionMethod::Deflate => Ok(FlateDecoder::Deflate(DeflateDecoder::new(read))),
173            #[cfg(feature = "zstd")]
174            CompressionMethod::Zstd => {
175                let decoder = ZstdDecoder::new(read)?;
176                Ok(FlateDecoder::Zstd(decoder))
177            }
178        }
179    }
180}
181
182impl<R: Read> Read for FlateDecoder<R> {
183    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
184        match self {
185            #[cfg(feature = "deflate")]
186            FlateDecoder::Deflate(decoder) => decoder.read(buf),
187            #[cfg(feature = "zstd")]
188            FlateDecoder::Zstd(decoder) => decoder.read(buf),
189        }
190    }
191}
192
193pub fn apply_compression<R, W>(
194    reader: &mut R,
195    writer: &mut W,
196    method: CompressionMethod,
197) -> Result<(), CompressionError>
198where
199    R: Read,
200    W: Write,
201{
202    let mut encoder = method.encoder(writer)?;
203    io::copy(reader, &mut encoder)?;
204    encoder.finish_encode().map(|_| ())
205}
206
207pub fn apply_decompression(
208    reader: impl Read,
209    mut writer: impl Write,
210    method: CompressionMethod,
211) -> Result<(), CompressionError> {
212    let mut decoder = method.decoder(reader)?;
213    io::copy(&mut decoder, &mut writer)?;
214    Ok(())
215}