include_flate_compress/
lib.rs

// include-flate
// Copyright (C) SOFe, Kento Oki
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#[cfg(not(any(feature = "zstd", feature = "deflate")))]
17compile_error!("You must enable either the `deflate` or `zstd` feature.");
18
19use std::{
20    fmt,
21    io::{self, BufRead, BufReader, Read, Seek, Write},
22};
23
24#[cfg(feature = "deflate")]
25use libflate::deflate::Decoder as DeflateDecoder;
26#[cfg(feature = "deflate")]
27use libflate::deflate::Encoder as DeflateEncoder;
28#[cfg(feature = "zstd")]
29use zstd::Decoder as ZstdDecoder;
30#[cfg(feature = "zstd")]
31use zstd::Encoder as ZstdEncoder;
32
/// Errors produced while setting up or finishing compression/decompression.
///
/// Backend-specific variants only exist when the matching Cargo feature is
/// enabled; every variant wraps the underlying [`io::Error`].
#[derive(Debug)]
pub enum FlateCompressionError {
    /// Failure reported by the deflate backend (feature `deflate`).
    #[cfg(feature = "deflate")]
    DeflateError(io::Error),
    /// Failure reported by the zstd backend (feature `zstd`).
    #[cfg(feature = "zstd")]
    ZstdError(io::Error),
    /// Any other I/O failure, e.g. while copying data through an encoder.
    IoError(io::Error),
}

impl From<io::Error> for FlateCompressionError {
    /// Plain I/O errors (e.g. those propagated with `?`) become [`Self::IoError`].
    fn from(err: io::Error) -> Self {
        FlateCompressionError::IoError(err)
    }
}

impl fmt::Display for FlateCompressionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            #[cfg(feature = "deflate")]
            FlateCompressionError::DeflateError(err) => write!(f, "Deflate error: {}", err),
            #[cfg(feature = "zstd")]
            FlateCompressionError::ZstdError(err) => write!(f, "Zstd error: {}", err),
            FlateCompressionError::IoError(err) => write!(f, "I/O error: {}", err),
        }
    }
}

// Implementing `Error` lets callers box this type (`Box<dyn Error>`), use it
// with `?` in functions returning generic error types, and walk the cause
// chain via `source()`.
impl std::error::Error for FlateCompressionError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            #[cfg(feature = "deflate")]
            FlateCompressionError::DeflateError(err) => Some(err),
            #[cfg(feature = "zstd")]
            FlateCompressionError::ZstdError(err) => Some(err),
            FlateCompressionError::IoError(err) => Some(err),
        }
    }
}
59
/// Selects which compression backend to use.
///
/// Each variant only exists when its Cargo feature is enabled: `Deflate`
/// requires `deflate` and `Zstd` requires `zstd`. The type is a plain
/// fieldless selector, so it is cheap to copy, compare and hash.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum CompressionMethod {
    /// DEFLATE, implemented with `libflate` (feature `deflate`).
    #[cfg(feature = "deflate")]
    Deflate,
    /// Zstandard, implemented with the `zstd` crate (feature `zstd`).
    #[cfg(feature = "zstd")]
    Zstd,
}
67
68impl CompressionMethod {
69    pub fn encoder<'a, W: BufRead + Write + Seek + 'a>(
70        &'a self,
71        write: W,
72    ) -> Result<FlateEncoder<W>, FlateCompressionError> {
73        FlateEncoder::new(*self, write)
74    }
75
76    pub fn decoder<'a, R: ReadSeek + 'a>(
77        &'a self,
78        read: R,
79    ) -> Result<FlateDecoder<'a>, FlateCompressionError> {
80        FlateDecoder::new(*self, Box::new(read))
81    }
82}
83
/// The default method is DEFLATE whenever the `deflate` feature is enabled.
#[cfg(feature = "deflate")]
impl Default for CompressionMethod {
    fn default() -> Self {
        CompressionMethod::Deflate
    }
}

/// When `zstd` is the only enabled backend, it becomes the default method.
#[cfg(all(feature = "zstd", not(feature = "deflate")))]
impl Default for CompressionMethod {
    fn default() -> Self {
        CompressionMethod::Zstd
    }
}
97
98pub enum FlateEncoder<W: Write> {
99    #[cfg(feature = "deflate")]
100    Deflate(DeflateEncoder<W>),
101    #[cfg(feature = "zstd")]
102    Zstd(ZstdEncoder<'static, W>),
103}
104
105impl<'a, W: BufRead + Write + Seek + 'a> FlateEncoder<W> {
106    pub fn new(
107        method: CompressionMethod,
108        write: W,
109    ) -> Result<FlateEncoder<W>, FlateCompressionError> {
110        match method {
111            #[cfg(feature = "deflate")]
112            CompressionMethod::Deflate => Ok(FlateEncoder::Deflate(DeflateEncoder::new(write))),
113            #[cfg(feature = "zstd")]
114            CompressionMethod::Zstd => ZstdEncoder::new(write, 0)
115                .map(FlateEncoder::Zstd)
116                .map_err(FlateCompressionError::ZstdError),
117        }
118    }
119}
120
121impl<'a, W: Write + 'a> Write for FlateEncoder<W> {
122    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
123        match self {
124            #[cfg(feature = "deflate")]
125            FlateEncoder::Deflate(encoder) => encoder.write(buf),
126            #[cfg(feature = "zstd")]
127            FlateEncoder::Zstd(encoder) => encoder.write(buf),
128        }
129    }
130
131    fn flush(&mut self) -> io::Result<()> {
132        match self {
133            #[cfg(feature = "deflate")]
134            FlateEncoder::Deflate(encoder) => encoder.flush(),
135            #[cfg(feature = "zstd")]
136            FlateEncoder::Zstd(encoder) => encoder.flush(),
137        }
138    }
139}
140
141impl<'a, W: Write + 'a> FlateEncoder<W> {
142    fn finish_encode(self) -> Result<W, FlateCompressionError> {
143        match self {
144            #[cfg(feature = "deflate")]
145            FlateEncoder::Deflate(encoder) => encoder
146                .finish()
147                .into_result()
148                .map_err(FlateCompressionError::DeflateError),
149            #[cfg(feature = "zstd")]
150            FlateEncoder::Zstd(encoder) => {
151                encoder.finish().map_err(FlateCompressionError::ZstdError)
152            }
153        }
154    }
155}
156
/// Convenience alias trait for sources that are both buffered-readable and
/// seekable.
pub trait ReadSeek: BufRead + Seek {}

/// Every sized `BufRead + Seek` type is automatically a [`ReadSeek`].
impl<T> ReadSeek for T where T: BufRead + Seek {}
160
161pub enum FlateDecoder<'a> {
162    #[cfg(feature = "deflate")]
163    Deflate(DeflateDecoder<Box<dyn BufRead + 'a>>),
164    #[cfg(feature = "zstd")]
165    Zstd(ZstdDecoder<'a, BufReader<Box<dyn BufRead + 'a>>>),
166}
167
168impl<'a> FlateDecoder<'a> {
169    pub fn new(
170        method: CompressionMethod,
171        read: Box<dyn BufRead + 'a>,
172    ) -> Result<FlateDecoder<'a>, FlateCompressionError> {
173        match method {
174            #[cfg(feature = "deflate")]
175            CompressionMethod::Deflate => Ok(FlateDecoder::Deflate(DeflateDecoder::new(read))),
176            #[cfg(feature = "zstd")]
177            CompressionMethod::Zstd => {
178                let decoder = ZstdDecoder::new(read)?;
179                Ok(FlateDecoder::Zstd(decoder))
180            }
181        }
182    }
183}
184
185impl<'a> Read for FlateDecoder<'a> {
186    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
187        match self {
188            #[cfg(feature = "deflate")]
189            FlateDecoder::Deflate(decoder) => decoder.read(buf),
190            #[cfg(feature = "zstd")]
191            FlateDecoder::Zstd(decoder) => decoder.read(buf),
192        }
193    }
194}
195
196pub fn apply_compression<R: Sized, W: Sized + BufRead + Seek>(
197    reader: &mut R,
198    writer: &mut W,
199    method: CompressionMethod,
200) -> Result<(), FlateCompressionError>
201where
202    R: Read,
203    W: Write,
204{
205    let mut encoder = method.encoder(writer)?;
206    io::copy(reader, &mut encoder)?;
207    encoder.finish_encode().map(|_| ())
208}
209
210pub fn apply_decompression<R: Sized + BufRead + Seek, W: Sized>(
211    reader: &mut R,
212    writer: &mut W,
213    method: CompressionMethod,
214) -> Result<(), FlateCompressionError>
215where
216    R: Read,
217    W: Write,
218{
219    let mut decoder = method.decoder(reader)?;
220    io::copy(&mut decoder, writer)?;
221    Ok(())
222}