datafusion_datasource/
file_compression_type.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! File Compression type abstraction
19
20use std::str::FromStr;
21
22use datafusion_common::error::{DataFusionError, Result};
23
24use datafusion_common::GetExt;
25use datafusion_common::parsers::CompressionTypeVariant::{self, *};
26
27#[cfg(feature = "compression")]
28use async_compression::tokio::bufread::{
29    BzDecoder as AsyncBzDecoder, BzEncoder as AsyncBzEncoder,
30    GzipDecoder as AsyncGzDecoder, GzipEncoder as AsyncGzEncoder,
31    XzDecoder as AsyncXzDecoder, XzEncoder as AsyncXzEncoder,
32    ZstdDecoder as AsyncZstdDecoer, ZstdEncoder as AsyncZstdEncoder,
33};
34
35#[cfg(feature = "compression")]
36use async_compression::tokio::write::{BzEncoder, GzipEncoder, XzEncoder, ZstdEncoder};
37use bytes::Bytes;
38#[cfg(feature = "compression")]
39use bzip2::read::MultiBzDecoder;
40#[cfg(feature = "compression")]
41use flate2::read::MultiGzDecoder;
42use futures::StreamExt;
43#[cfg(feature = "compression")]
44use futures::TryStreamExt;
45use futures::stream::BoxStream;
46#[cfg(feature = "compression")]
47use liblzma::read::XzDecoder;
48use object_store::buffered::BufWriter;
49use tokio::io::AsyncWrite;
50#[cfg(feature = "compression")]
51use tokio_util::io::{ReaderStream, StreamReader};
52#[cfg(feature = "compression")]
53use zstd::Decoder as ZstdDecoder;
54
55/// Readable file compression type
56#[derive(Debug, Clone, Copy, PartialEq, Eq)]
57pub struct FileCompressionType {
58    variant: CompressionTypeVariant,
59}
60
61impl GetExt for FileCompressionType {
62    fn get_ext(&self) -> String {
63        match self.variant {
64            GZIP => ".gz".to_owned(),
65            BZIP2 => ".bz2".to_owned(),
66            XZ => ".xz".to_owned(),
67            ZSTD => ".zst".to_owned(),
68            UNCOMPRESSED => "".to_owned(),
69        }
70    }
71}
72
73impl From<CompressionTypeVariant> for FileCompressionType {
74    fn from(t: CompressionTypeVariant) -> Self {
75        Self { variant: t }
76    }
77}
78
79impl From<FileCompressionType> for CompressionTypeVariant {
80    fn from(t: FileCompressionType) -> Self {
81        t.variant
82    }
83}
84
85impl FromStr for FileCompressionType {
86    type Err = DataFusionError;
87
88    fn from_str(s: &str) -> Result<Self> {
89        let variant = CompressionTypeVariant::from_str(s).map_err(|_| {
90            DataFusionError::NotImplemented(format!("Unknown FileCompressionType: {s}"))
91        })?;
92        Ok(Self { variant })
93    }
94}
95
96/// `FileCompressionType` implementation
97impl FileCompressionType {
98    /// Gzip-ed file
99    pub const GZIP: Self = Self { variant: GZIP };
100
101    /// Bzip2-ed file
102    pub const BZIP2: Self = Self { variant: BZIP2 };
103
104    /// Xz-ed file (liblzma)
105    pub const XZ: Self = Self { variant: XZ };
106
107    /// Zstd-ed file
108    pub const ZSTD: Self = Self { variant: ZSTD };
109
110    /// Uncompressed file
111    pub const UNCOMPRESSED: Self = Self {
112        variant: UNCOMPRESSED,
113    };
114
115    /// Read only access to self.variant
116    pub fn get_variant(&self) -> &CompressionTypeVariant {
117        &self.variant
118    }
119
120    /// The file is compressed or not
121    pub const fn is_compressed(&self) -> bool {
122        self.variant.is_compressed()
123    }
124
125    /// Given a `Stream`, create a `Stream` which data are compressed with `FileCompressionType`.
126    pub fn convert_to_compress_stream<'a>(
127        &self,
128        s: BoxStream<'a, Result<Bytes>>,
129    ) -> Result<BoxStream<'a, Result<Bytes>>> {
130        Ok(match self.variant {
131            #[cfg(feature = "compression")]
132            GZIP => ReaderStream::new(AsyncGzEncoder::new(StreamReader::new(s)))
133                .map_err(DataFusionError::from)
134                .boxed(),
135            #[cfg(feature = "compression")]
136            BZIP2 => ReaderStream::new(AsyncBzEncoder::new(StreamReader::new(s)))
137                .map_err(DataFusionError::from)
138                .boxed(),
139            #[cfg(feature = "compression")]
140            XZ => ReaderStream::new(AsyncXzEncoder::new(StreamReader::new(s)))
141                .map_err(DataFusionError::from)
142                .boxed(),
143            #[cfg(feature = "compression")]
144            ZSTD => ReaderStream::new(AsyncZstdEncoder::new(StreamReader::new(s)))
145                .map_err(DataFusionError::from)
146                .boxed(),
147            #[cfg(not(feature = "compression"))]
148            GZIP | BZIP2 | XZ | ZSTD => {
149                return Err(DataFusionError::NotImplemented(
150                    "Compression feature is not enabled".to_owned(),
151                ));
152            }
153            UNCOMPRESSED => s.boxed(),
154        })
155    }
156
157    /// Wrap the given `BufWriter` so that it performs compressed writes
158    /// according to this `FileCompressionType` using the default compression level.
159    pub fn convert_async_writer(
160        &self,
161        w: BufWriter,
162    ) -> Result<Box<dyn AsyncWrite + Send + Unpin>> {
163        self.convert_async_writer_with_level(w, None)
164    }
165
166    /// Wrap the given `BufWriter` so that it performs compressed writes
167    /// according to this `FileCompressionType`.
168    ///
169    /// If `compression_level` is `Some`, the encoder will use the specified
170    /// compression level. If `None`, the default level for each algorithm is used.
171    pub fn convert_async_writer_with_level(
172        &self,
173        w: BufWriter,
174        compression_level: Option<u32>,
175    ) -> Result<Box<dyn AsyncWrite + Send + Unpin>> {
176        #[cfg(feature = "compression")]
177        use async_compression::Level;
178
179        Ok(match self.variant {
180            #[cfg(feature = "compression")]
181            GZIP => match compression_level {
182                Some(level) => {
183                    Box::new(GzipEncoder::with_quality(w, Level::Precise(level as i32)))
184                }
185                None => Box::new(GzipEncoder::new(w)),
186            },
187            #[cfg(feature = "compression")]
188            BZIP2 => match compression_level {
189                Some(level) => {
190                    Box::new(BzEncoder::with_quality(w, Level::Precise(level as i32)))
191                }
192                None => Box::new(BzEncoder::new(w)),
193            },
194            #[cfg(feature = "compression")]
195            XZ => match compression_level {
196                Some(level) => {
197                    Box::new(XzEncoder::with_quality(w, Level::Precise(level as i32)))
198                }
199                None => Box::new(XzEncoder::new(w)),
200            },
201            #[cfg(feature = "compression")]
202            ZSTD => match compression_level {
203                Some(level) => {
204                    Box::new(ZstdEncoder::with_quality(w, Level::Precise(level as i32)))
205                }
206                None => Box::new(ZstdEncoder::new(w)),
207            },
208            #[cfg(not(feature = "compression"))]
209            GZIP | BZIP2 | XZ | ZSTD => {
210                // compression_level is not used when compression feature is disabled
211                let _ = compression_level;
212                return Err(DataFusionError::NotImplemented(
213                    "Compression feature is not enabled".to_owned(),
214                ));
215            }
216            UNCOMPRESSED => Box::new(w),
217        })
218    }
219
220    /// Given a `Stream`, create a `Stream` which data are decompressed with `FileCompressionType`.
221    pub fn convert_stream<'a>(
222        &self,
223        s: BoxStream<'a, Result<Bytes>>,
224    ) -> Result<BoxStream<'a, Result<Bytes>>> {
225        Ok(match self.variant {
226            #[cfg(feature = "compression")]
227            GZIP => {
228                let mut decoder = AsyncGzDecoder::new(StreamReader::new(s));
229                decoder.multiple_members(true);
230
231                ReaderStream::new(decoder)
232                    .map_err(DataFusionError::from)
233                    .boxed()
234            }
235            #[cfg(feature = "compression")]
236            BZIP2 => ReaderStream::new(AsyncBzDecoder::new(StreamReader::new(s)))
237                .map_err(DataFusionError::from)
238                .boxed(),
239            #[cfg(feature = "compression")]
240            XZ => ReaderStream::new(AsyncXzDecoder::new(StreamReader::new(s)))
241                .map_err(DataFusionError::from)
242                .boxed(),
243            #[cfg(feature = "compression")]
244            ZSTD => ReaderStream::new(AsyncZstdDecoer::new(StreamReader::new(s)))
245                .map_err(DataFusionError::from)
246                .boxed(),
247            #[cfg(not(feature = "compression"))]
248            GZIP | BZIP2 | XZ | ZSTD => {
249                return Err(DataFusionError::NotImplemented(
250                    "Compression feature is not enabled".to_owned(),
251                ));
252            }
253            UNCOMPRESSED => s.boxed(),
254        })
255    }
256
257    /// Given a `Read`, create a `Read` which data are decompressed with `FileCompressionType`.
258    pub fn convert_read<T: std::io::Read + Send + 'static>(
259        &self,
260        r: T,
261    ) -> Result<Box<dyn std::io::Read + Send>> {
262        Ok(match self.variant {
263            #[cfg(feature = "compression")]
264            GZIP => Box::new(MultiGzDecoder::new(r)),
265            #[cfg(feature = "compression")]
266            BZIP2 => Box::new(MultiBzDecoder::new(r)),
267            #[cfg(feature = "compression")]
268            XZ => Box::new(XzDecoder::new_multi_decoder(r)),
269            #[cfg(feature = "compression")]
270            ZSTD => match ZstdDecoder::new(r) {
271                Ok(decoder) => Box::new(decoder),
272                Err(e) => return Err(DataFusionError::External(Box::new(e))),
273            },
274            #[cfg(not(feature = "compression"))]
275            GZIP | BZIP2 | XZ | ZSTD => {
276                return Err(DataFusionError::NotImplemented(
277                    "Compression feature is not enabled".to_owned(),
278                ));
279            }
280            UNCOMPRESSED => Box::new(r),
281        })
282    }
283}
284
285/// Trait for extending the functionality of the `FileType` enum.
286pub trait FileTypeExt {
287    /// Given a `FileCompressionType`, return the `FileType`'s extension with compression suffix
288    fn get_ext_with_compression(&self, c: FileCompressionType) -> Result<String>;
289}
290
291#[cfg(test)]
292mod tests {
293    use std::str::FromStr;
294
295    use super::FileCompressionType;
296    use datafusion_common::error::DataFusionError;
297
298    use bytes::Bytes;
299    use futures::StreamExt;
300
301    #[test]
302    fn from_str() {
303        for (ext, compression_type) in [
304            ("gz", FileCompressionType::GZIP),
305            ("GZ", FileCompressionType::GZIP),
306            ("gzip", FileCompressionType::GZIP),
307            ("GZIP", FileCompressionType::GZIP),
308            ("xz", FileCompressionType::XZ),
309            ("XZ", FileCompressionType::XZ),
310            ("bz2", FileCompressionType::BZIP2),
311            ("BZ2", FileCompressionType::BZIP2),
312            ("bzip2", FileCompressionType::BZIP2),
313            ("BZIP2", FileCompressionType::BZIP2),
314            ("zst", FileCompressionType::ZSTD),
315            ("ZST", FileCompressionType::ZSTD),
316            ("zstd", FileCompressionType::ZSTD),
317            ("ZSTD", FileCompressionType::ZSTD),
318            ("", FileCompressionType::UNCOMPRESSED),
319        ] {
320            assert_eq!(
321                FileCompressionType::from_str(ext).unwrap(),
322                compression_type
323            );
324        }
325
326        assert!(matches!(
327            FileCompressionType::from_str("Unknown"),
328            Err(DataFusionError::NotImplemented(_))
329        ));
330    }
331
332    #[tokio::test]
333    async fn test_bgzip_stream_decoding() -> Result<(), DataFusionError> {
334        // As described in https://samtools.github.io/hts-specs/SAMv1.pdf ("The BGZF compression format")
335
336        // Ignore rust formatting so the byte array is easier to read
337        #[rustfmt::skip]
338        let data = [
339            // Block header
340            0x1f, 0x8b, 0x08, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x06, 0x00, 0x42, 0x43,
341            0x02, 0x00,
342            // Block 0, literal: 42
343            0x1e, 0x00, 0x33, 0x31, 0xe2, 0x02, 0x00, 0x31, 0x29, 0x86, 0xd1, 0x03, 0x00, 0x00, 0x00,
344            // Block header
345            0x1f, 0x8b, 0x08, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x06, 0x00, 0x42, 0x43,
346            0x02, 0x00,
347            // Block 1, literal: 42
348            0x1e, 0x00, 0x33, 0x31, 0xe2, 0x02, 0x00, 0x31, 0x29, 0x86, 0xd1, 0x03, 0x00, 0x00, 0x00,
349            // EOF
350            0x1f, 0x8b, 0x08, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x06, 0x00, 0x42, 0x43,
351            0x02, 0x00, 0x1b, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
352        ];
353
354        // Create a byte stream
355        let stream = futures::stream::iter(vec![Ok::<Bytes, DataFusionError>(
356            Bytes::from(data.to_vec()),
357        )]);
358        let converted_stream =
359            FileCompressionType::GZIP.convert_stream(stream.boxed())?;
360
361        let vec = converted_stream
362            .map(|r| r.unwrap())
363            .collect::<Vec<Bytes>>()
364            .await;
365
366        let string_value = String::from_utf8_lossy(&vec[0]);
367
368        assert_eq!(string_value, "42\n42\n");
369
370        Ok(())
371    }
372}