use anyhow::{anyhow, Result};
use flate2::read::MultiGzDecoder;
use std::fs::File;
use std::io::{BufRead, BufReader, Chain, Cursor, Read};
use std::path::Path;
/// Buffered reader over a file that is either gzip-compressed or plain.
///
/// Both variants read from a `Chain<Cursor<Vec<u8>>, File>`: an in-memory
/// prefix followed by the file itself. `maybe_gzip_file` fills that prefix
/// with the bytes it consumed while sniffing the header, so nothing is lost.
#[derive(Debug)]
pub enum DecompressionReader {
/// Input started with the gzip signature; bytes pass through `MultiGzDecoder`.
Gzip(BufReader<MultiGzDecoder<Chain<Cursor<Vec<u8>>, File>>>),
/// Input did not start with the gzip signature; bytes are returned unchanged.
Plain(BufReader<Chain<Cursor<Vec<u8>>, File>>),
}
// `Send` is an auto trait: the compiler implements it for `DecompressionReader`
// whenever every variant payload is `Send`. A hand-written `unsafe impl` would
// keep promising `Send` even if a non-`Send` variant were added later, so the
// guarantee is checked at compile time instead of being asserted unsafely.
const _: fn() = || {
    fn requires_send<T: Send>() {}
    requires_send::<DecompressionReader>();
};
impl BufRead for DecompressionReader {
fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
match self {
DecompressionReader::Gzip(reader) => reader.fill_buf(),
DecompressionReader::Plain(reader) => reader.fill_buf(),
}
}
fn consume(&mut self, amt: usize) {
match self {
DecompressionReader::Gzip(reader) => reader.consume(amt),
DecompressionReader::Plain(reader) => reader.consume(amt),
}
}
}
impl Read for DecompressionReader {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
DecompressionReader::Gzip(reader) => reader.read(buf),
DecompressionReader::Plain(reader) => reader.read(buf),
}
}
}
/// Wraps an opened file in a [`DecompressionReader`], picking the variant by
/// sniffing the file's leading bytes.
///
/// Up to three bytes are consumed to look for the gzip signature (`1F 8B`
/// followed by the deflate method byte `08`). The consumed bytes are replayed
/// in front of the rest of the file, so the returned reader yields the content
/// from its very first byte. Files shorter than three bytes are treated as
/// plain.
///
/// # Errors
/// Returns any I/O error hit while reading the header. `Interrupted` is
/// retried rather than reported.
fn maybe_gzip_file(mut file: File) -> std::io::Result<DecompressionReader> {
    const GZIP_MAGIC: [u8; 3] = [0x1F, 0x8B, 0x08];

    let mut head = [0u8; 3];
    let mut filled = 0;
    // A single `read` may legally return fewer bytes than requested before
    // EOF, so keep reading until the header is complete or the file ends.
    while filled < head.len() {
        match file.read(&mut head[filled..]) {
            Ok(0) => break,
            Ok(count) => filled += count,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }

    let is_gzip = head[..filled] == GZIP_MAGIC;
    let chained = Cursor::new(head[..filled].to_vec()).chain(file);
    if is_gzip {
        let decoder = MultiGzDecoder::new(chained);
        Ok(DecompressionReader::Gzip(BufReader::new(decoder)))
    } else {
        Ok(DecompressionReader::Plain(BufReader::new(chained)))
    }
}
/// Wraps any reader so that gzip input is transparently decompressed.
///
/// Up to three leading bytes are consumed to look for the gzip signature
/// (`1F 8B` followed by the deflate method byte `08`). The consumed bytes are
/// replayed in front of the remaining input, so the returned reader yields the
/// stream from its very first byte. Input shorter than three bytes is passed
/// through unchanged.
///
/// The header is read in a loop because readers such as pipes or sockets may
/// hand back fewer bytes than requested per call; a single short read must not
/// cause a gzip stream to be misclassified as plain data.
///
/// # Errors
/// Returns any I/O error hit while reading the header. `Interrupted` is
/// retried rather than reported.
pub fn maybe_gzip<R: Read + Send + 'static>(
    mut reader: R,
) -> std::io::Result<Box<dyn Read + Send>> {
    const GZIP_MAGIC: [u8; 3] = [0x1F, 0x8B, 0x08];

    let mut head = [0u8; 3];
    let mut filled = 0;
    while filled < head.len() {
        match reader.read(&mut head[filled..]) {
            // EOF before a full header: whatever was read is replayed as-is.
            Ok(0) => break,
            Ok(count) => filled += count,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }

    let is_gzip = head[..filled] == GZIP_MAGIC;
    let chained: Chain<Cursor<Vec<u8>>, R> = Cursor::new(head[..filled].to_vec()).chain(reader);
    if is_gzip {
        Ok(Box::new(MultiGzDecoder::new(chained)))
    } else {
        Ok(Box::new(chained))
    }
}
impl DecompressionReader {
    /// Opens `path` and returns a reader that decompresses gzip content on the
    /// fly, or passes the bytes through untouched when the file is not gzip.
    ///
    /// # Errors
    /// - the I/O error from opening the file, if it cannot be opened;
    /// - a descriptive error when the path has a `.zip` extension (any case),
    ///   since ZIP archives are not supported;
    /// - an error wrapping the I/O failure if the header cannot be read.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
        let source = path.as_ref();
        // The file is opened before the extension check, so an unopenable path
        // reports the open failure even when it ends in `.zip`.
        let handle = File::open(source)?;

        let has_zip_extension = source
            .extension()
            .and_then(|ext| ext.to_str())
            .map_or(false, |ext| ext.eq_ignore_ascii_case("zip"));
        if has_zip_extension {
            return Err(anyhow!("ZIP file decompression is not supported. Only gzip files are supported for streaming decompression. Extract the ZIP file first: unzip {}", source.display()));
        }

        match maybe_gzip_file(handle) {
            Ok(reader) => Ok(reader),
            Err(e) => Err(anyhow!("Failed to detect compression format: {}", e)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use std::io::{Read, Write};
    use tempfile::NamedTempFile;

    /// Compresses `payload` into a single in-memory gzip member.
    fn gzip_bytes(payload: &[u8]) -> std::io::Result<Vec<u8>> {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(payload)?;
        encoder.finish()
    }

    /// Creates a temp file containing exactly `bytes`; it is removed on drop.
    fn temp_file_with(bytes: &[u8]) -> std::io::Result<NamedTempFile> {
        let mut file = NamedTempFile::new()?;
        file.write_all(bytes)?;
        file.flush()?;
        Ok(file)
    }

    #[test]
    fn test_plain_file_passthrough() -> Result<()> {
        let temp_file = temp_file_with(b"test line 1\ntest line 2\n")?;
        let mut reader = DecompressionReader::new(temp_file.path())?;
        assert!(matches!(reader, DecompressionReader::Plain(_)));
        let mut content = String::new();
        reader.read_to_string(&mut content)?;
        assert_eq!(content, "test line 1\ntest line 2\n");
        Ok(())
    }

    #[test]
    fn test_zip_file_rejection() -> Result<()> {
        // Exercise both lower- and upper-case extensions; the temp files clean
        // themselves up even when an assertion below fails.
        for suffix in [".zip", ".ZIP"] {
            let zip_file = tempfile::Builder::new()
                .prefix("archive")
                .suffix(suffix)
                .tempfile()?;
            std::fs::write(zip_file.path(), b"fake zip content")?;
            let error_msg = DecompressionReader::new(zip_file.path())
                .expect_err("ZIP input must be rejected")
                .to_string();
            assert!(error_msg.contains("ZIP file decompression is not supported"));
            assert!(error_msg.contains("Only gzip files are supported"));
        }
        Ok(())
    }

    #[test]
    fn test_magic_bytes_detection() -> Result<()> {
        let plain = temp_file_with(b"plain text file\n")?;
        let mut reader = DecompressionReader::new(plain.path())?;
        assert!(matches!(reader, DecompressionReader::Plain(_)));
        let mut content = String::new();
        reader.read_to_string(&mut content)?;
        assert_eq!(content, "plain text file\n");

        // The signature alone selects the gzip path; the bogus payload is only
        // discovered once the stream is actually decoded.
        let fake = temp_file_with(b"\x1F\x8B\x08fake gzip data")?;
        let mut reader = DecompressionReader::new(fake.path())?;
        assert!(matches!(reader, DecompressionReader::Gzip(_)));
        let mut sink = Vec::new();
        assert!(reader.read_to_end(&mut sink).is_err());
        Ok(())
    }

    #[test]
    fn test_gzip_round_trip() -> Result<()> {
        let compressed = temp_file_with(&gzip_bytes(b"hello gzip\nsecond line\n")?)?;
        let mut reader = DecompressionReader::new(compressed.path())?;
        assert!(matches!(reader, DecompressionReader::Gzip(_)));
        let mut content = String::new();
        reader.read_to_string(&mut content)?;
        assert_eq!(content, "hello gzip\nsecond line\n");
        Ok(())
    }

    #[test]
    fn test_multi_member_gzip() -> Result<()> {
        let mut bytes = gzip_bytes(b"first\n")?;
        bytes.extend(gzip_bytes(b"second\n")?);
        let compressed = temp_file_with(&bytes)?;
        let mut reader = DecompressionReader::new(compressed.path())?;
        let mut content = String::new();
        reader.read_to_string(&mut content)?;
        assert_eq!(content, "first\nsecond\n");
        Ok(())
    }

    #[test]
    fn test_short_and_empty_files_are_plain() -> Result<()> {
        // Fewer than three bytes can never match the signature, even when the
        // bytes present are a prefix of it.
        for bytes in [&b""[..], &b"\x1F"[..], &b"\x1F\x8B"[..]] {
            let file = temp_file_with(bytes)?;
            let mut reader = DecompressionReader::new(file.path())?;
            assert!(matches!(reader, DecompressionReader::Plain(_)));
            let mut content = Vec::new();
            reader.read_to_end(&mut content)?;
            assert_eq!(content, bytes);
        }
        Ok(())
    }

    #[test]
    fn test_maybe_gzip_generic_reader() -> Result<()> {
        let mut plain = maybe_gzip(Cursor::new(b"plain bytes".to_vec()))?;
        let mut content = Vec::new();
        plain.read_to_end(&mut content)?;
        assert_eq!(content, b"plain bytes");

        let mut gz = maybe_gzip(Cursor::new(gzip_bytes(b"packed bytes")?))?;
        let mut content = Vec::new();
        gz.read_to_end(&mut content)?;
        assert_eq!(content, b"packed bytes");
        Ok(())
    }
}