use crate::FaucetError;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use tokio::io::{AsyncBufRead, AsyncWrite};
// User-facing compression setting, as written in configuration.
//
// Serialised in lowercase ("none", "gzip", "zstd", "auto") and defaults to
// `Auto`. Convert it to a concrete `Compression` with `resolve`.
//
// NOTE: plain `//` comments are used here on purpose: this type derives
// `JsonSchema`, and schemars turns `///` doc comments into schema
// descriptions, which would change the generated schema.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "lowercase")]
pub enum CompressionConfig {
    // Treat data as uncompressed, regardless of the path.
    None,
    // Always use gzip, regardless of the path.
    Gzip,
    // Always use zstd, regardless of the path.
    Zstd,
    // Infer the codec from the path's extension via `detect_from_path`.
    #[default]
    Auto,
}
/// A concrete codec, resolved from a [`CompressionConfig`] (see
/// [`CompressionConfig::resolve`]).
///
/// `Eq` + `Hash` let it be part of the dedup key used by [`warn_mismatch`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// Bytes are passed through unchanged.
    None,
    /// Gzip (flate2 in sync code, async-compression in async code).
    Gzip,
    /// Zstandard (the `zstd` crate in sync code, async-compression in async code).
    Zstd,
}
impl CompressionConfig {
pub fn resolve(self, path: &str) -> Compression {
match self {
Self::None => Compression::None,
Self::Gzip => Compression::Gzip,
Self::Zstd => Compression::Zstd,
Self::Auto => detect_from_path(path),
}
}
}
/// Guesses the codec from the file extension of `path`, ignoring ASCII case.
///
/// `.gz` maps to [`Compression::Gzip`], `.zst` to [`Compression::Zstd`], and
/// anything else (including the empty string) to [`Compression::None`]. Only
/// the final extension matters, so `a.gz.zst` is zstd.
pub fn detect_from_path(path: &str) -> Compression {
    let bytes = path.as_bytes();
    // Compare just the trailing bytes case-insensitively instead of
    // lower-casing a copy of the whole path.
    let ends_with_ci = |ext: &str| {
        let ext = ext.as_bytes();
        bytes.len() >= ext.len() && bytes[bytes.len() - ext.len()..].eq_ignore_ascii_case(ext)
    };
    if ends_with_ci(".gz") {
        Compression::Gzip
    } else if ends_with_ci(".zst") {
        Compression::Zstd
    } else {
        Compression::None
    }
}
/// Wraps an async buffered reader so that reads yield decompressed bytes.
///
/// [`Compression::None`] returns `r` unchanged (boxed). For gzip and zstd the
/// decoder is wrapped in a [`tokio::io::BufReader`] so the boxed value
/// satisfies [`AsyncBufRead`].
///
/// Both decoders are put into multi-member / multi-frame mode. A stream made
/// by concatenating several compressed streams therefore decodes to the
/// concatenation of their contents. Examples are `cat a.gz b.gz`, or zstd
/// frames appended to an existing file. Without this mode, decoding would
/// stop silently after the first member/frame. This mirrors
/// `wrap_sync_reader`, whose gzip (`MultiGzDecoder`) and zstd decoders also
/// read every member/frame.
pub fn wrap_async_reader<'a, R>(
    r: R,
    c: Compression,
) -> Pin<Box<dyn AsyncBufRead + Send + Unpin + 'a>>
where
    R: AsyncBufRead + Send + Unpin + 'a,
{
    match c {
        Compression::None => Box::pin(r),
        Compression::Gzip => {
            let mut dec = async_compression::tokio::bufread::GzipDecoder::new(r);
            // Keep decoding after the first gzip member.
            dec.multiple_members(true);
            Box::pin(tokio::io::BufReader::new(dec))
        }
        Compression::Zstd => {
            let mut dec = async_compression::tokio::bufread::ZstdDecoder::new(r);
            // Keep decoding after the first zstd frame; otherwise the decoder
            // reports EOF at the end of frame one and the remaining frames
            // are silently dropped.
            dec.multiple_members(true);
            Box::pin(tokio::io::BufReader::new(dec))
        }
    }
}
/// Wraps an async writer so that bytes written to it are compressed with `c`.
///
/// [`Compression::None`] passes bytes through unchanged. The gzip and zstd
/// encoders write their trailer when the returned writer is shut down, so
/// callers should `shutdown().await` it rather than just dropping it.
pub fn wrap_async_writer<'a, W>(
    w: W,
    c: Compression,
) -> Pin<Box<dyn AsyncWrite + Send + Unpin + 'a>>
where
    W: AsyncWrite + Send + Unpin + 'a,
{
    use async_compression::tokio::write::{GzipEncoder, ZstdEncoder};

    match c {
        Compression::Gzip => Box::pin(GzipEncoder::new(w)),
        Compression::Zstd => Box::pin(ZstdEncoder::new(w)),
        Compression::None => Box::pin(w),
    }
}
/// Wraps a blocking reader so that reads yield decompressed bytes.
///
/// Gzip input is decoded with `MultiGzDecoder`, so concatenated gzip members
/// are all read. [`Compression::None`] returns `r` unchanged (boxed).
///
/// # Panics
///
/// Panics if the zstd decoder cannot be constructed; this is treated as an
/// invariant that cannot fail.
pub fn wrap_sync_reader<'a, R>(r: R, c: Compression) -> Box<dyn std::io::Read + Send + 'a>
where
    R: std::io::Read + Send + 'a,
{
    use flate2::read::MultiGzDecoder;
    use zstd::stream::read::Decoder as ZstdDecoder;

    match c {
        Compression::Gzip => Box::new(MultiGzDecoder::new(r)),
        Compression::Zstd => {
            let decoder = ZstdDecoder::new(r)
                .expect("zstd decoder construction is infallible for any Read");
            Box::new(decoder)
        }
        Compression::None => Box::new(r),
    }
}
/// Wraps a blocking writer so that bytes written to it are compressed with `c`.
///
/// The stream is finalised when the returned box is dropped:
/// - flate2's gzip encoder finishes on drop;
/// - the zstd encoder is wrapped with `auto_finish`.
///
/// `Drop` cannot report errors, so a failure while writing the trailer goes
/// unnoticed. Use [`sync_compress_writer`] and [`SyncCompressWriter::finish`]
/// when that matters. Gzip uses flate2's default level and zstd uses level
/// `0`.
///
/// # Panics
///
/// Panics if the zstd encoder cannot be constructed.
pub fn wrap_sync_writer<'a, W>(w: W, c: Compression) -> Box<dyn std::io::Write + Send + 'a>
where
    W: std::io::Write + Send + 'a,
{
    use flate2::write::GzEncoder;
    use zstd::stream::write::Encoder as ZstdEncoder;

    match c {
        Compression::Gzip => {
            let level = flate2::Compression::default();
            Box::new(GzEncoder::new(w, level))
        }
        Compression::Zstd => {
            let encoder =
                ZstdEncoder::new(w, 0).expect("zstd encoder construction is infallible");
            Box::new(encoder.auto_finish())
        }
        Compression::None => Box::new(w),
    }
}
/// A compressing [`std::io::Write`] adapter with an explicit, fallible `finish`.
///
/// It keeps the concrete encoder (instead of a boxed trait object) so callers
/// can finalise the stream, observe any error from writing the trailer, and
/// get the inner writer back.
///
/// Always call [`SyncCompressWriter::finish`]. The zstd variant holds a plain
/// `Encoder`, not an auto-finishing one, so dropping it unfinished may leave
/// the compressed stream incomplete.
pub enum SyncCompressWriter<W: std::io::Write> {
    /// No compression: bytes are forwarded unchanged.
    Plain(W),
    /// Gzip via flate2.
    Gzip(flate2::write::GzEncoder<W>),
    /// Zstd via the `zstd` crate.
    Zstd(zstd::stream::write::Encoder<'static, W>),
}

impl<W: std::io::Write> SyncCompressWriter<W> {
    /// Finalises the compressed stream and returns the underlying writer.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while flushing buffered data or writing
    /// the codec trailer.
    pub fn finish(self) -> std::io::Result<W> {
        match self {
            Self::Plain(inner) => Ok(inner),
            Self::Gzip(encoder) => encoder.finish(),
            Self::Zstd(encoder) => encoder.finish(),
        }
    }

    /// Borrows whichever writer is active, so the `Write` impl can forward
    /// to it without repeating the match.
    fn active(&mut self) -> &mut dyn std::io::Write {
        match self {
            Self::Plain(inner) => inner,
            Self::Gzip(encoder) => encoder,
            Self::Zstd(encoder) => encoder,
        }
    }
}

impl<W: std::io::Write> std::io::Write for SyncCompressWriter<W> {
    fn write(&mut self, data: &[u8]) -> std::io::Result<usize> {
        std::io::Write::write(self.active(), data)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        std::io::Write::flush(self.active())
    }
}
/// Builds a [`SyncCompressWriter`] that compresses into `w` with codec `c`.
///
/// Gzip uses flate2's default level. Zstd uses level `0`, which the zstd
/// crate documents as "use the library default". Call
/// [`SyncCompressWriter::finish`] to finalise the stream.
///
/// # Panics
///
/// Panics if the zstd encoder cannot be constructed.
pub fn sync_compress_writer<W: std::io::Write>(w: W, c: Compression) -> SyncCompressWriter<W> {
    match c {
        Compression::Zstd => {
            let encoder = zstd::stream::write::Encoder::new(w, 0)
                .expect("zstd encoder construction is infallible");
            SyncCompressWriter::Zstd(encoder)
        }
        Compression::Gzip => {
            let level = flate2::Compression::default();
            SyncCompressWriter::Gzip(flate2::write::GzEncoder::new(w, level))
        }
        Compression::None => SyncCompressWriter::Plain(w),
    }
}
/// Compresses `data` in memory with codec `c` and returns the encoded bytes.
///
/// [`Compression::None`] returns an owned copy of the input. Gzip uses
/// flate2's default level and zstd uses level `0`.
///
/// # Errors
///
/// Returns `FaucetError::Sink` if the encoder fails while compressing or
/// finalising.
pub fn compress_buf(data: &[u8], c: Compression) -> Result<Vec<u8>, FaucetError> {
    use std::io::Write;

    match c {
        Compression::Zstd => zstd::stream::encode_all(data, 0)
            .map_err(|e| FaucetError::Sink(format!("zstd compress failed: {e}"))),
        Compression::Gzip => {
            let level = flate2::Compression::default();
            let mut gz = flate2::write::GzEncoder::new(Vec::new(), level);
            if let Err(e) = gz.write_all(data) {
                return Err(FaucetError::Sink(format!("gzip compress failed: {e}")));
            }
            gz.finish()
                .map_err(|e| FaucetError::Sink(format!("gzip finalise failed: {e}")))
        }
        Compression::None => Ok(data.to_vec()),
    }
}
/// Logs a warning when an explicitly declared codec disagrees with the codec
/// implied by `path`'s extension.
///
/// Nothing is logged when [`detect_from_path`] agrees with `declared`.
///
/// Deduplication works as follows:
/// - Each `(path, declared)` pair is reported at most once per process.
/// - The dedup set is capped at `MAX_SEEN` entries, so a long-running job
///   touching many distinct paths cannot grow it without bound.
/// - Once the cap is reached, pairs already recorded stay silent.
/// - New pairs are then logged every time, since they can no longer be
///   remembered.
pub fn warn_mismatch(path: &str, declared: Compression) {
    use std::collections::HashSet;
    use std::sync::{Mutex, OnceLock, PoisonError};
    const MAX_SEEN: usize = 4096;
    static SEEN: OnceLock<Mutex<HashSet<(String, Compression)>>> = OnceLock::new();
    let detected = detect_from_path(path);
    if detected == declared {
        return;
    }
    let key = (path.to_string(), declared);
    let should_warn = {
        // The set holds no invariant a panic elsewhere could break, so a
        // poisoned lock is recovered instead of crashing a logging helper.
        let mut seen = SEEN
            .get_or_init(|| Mutex::new(HashSet::new()))
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        // Check membership before the cap: previously-seen pairs must stay
        // deduplicated even after the set has filled up.
        if seen.contains(&key) {
            false
        } else {
            if seen.len() < MAX_SEEN {
                seen.insert(key);
            }
            true
        }
    };
    if should_warn {
        tracing::warn!(
            path = %path,
            declared = ?declared,
            detected = ?detected,
            "compression codec mismatch — explicit config wins, filename extension ignored",
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};

    #[test]
    fn detect_extensions() {
        assert_eq!(detect_from_path("foo.jsonl"), Compression::None);
        assert_eq!(detect_from_path("foo.json.gz"), Compression::Gzip);
        assert_eq!(detect_from_path("foo.csv.zst"), Compression::Zstd);
        assert_eq!(detect_from_path("FOO.GZ"), Compression::Gzip);
        assert_eq!(detect_from_path("a.gz.zst"), Compression::Zstd);
        assert_eq!(detect_from_path(""), Compression::None);
    }

    #[test]
    fn resolve_auto_uses_path() {
        assert_eq!(CompressionConfig::Auto.resolve("foo.gz"), Compression::Gzip);
        assert_eq!(
            CompressionConfig::Auto.resolve("foo.zst"),
            Compression::Zstd
        );
        assert_eq!(CompressionConfig::Auto.resolve("foo"), Compression::None);
    }

    #[test]
    fn resolve_explicit_ignores_path() {
        assert_eq!(
            CompressionConfig::Gzip.resolve("foo.txt"),
            Compression::Gzip
        );
        assert_eq!(CompressionConfig::None.resolve("foo.gz"), Compression::None);
    }

    #[test]
    fn config_default_is_auto() {
        assert_eq!(CompressionConfig::default(), CompressionConfig::Auto);
    }

    #[test]
    fn config_serde_lowercase() {
        for (variant, expected) in [
            (CompressionConfig::None, "\"none\""),
            (CompressionConfig::Gzip, "\"gzip\""),
            (CompressionConfig::Zstd, "\"zstd\""),
            (CompressionConfig::Auto, "\"auto\""),
        ] {
            let serialised = serde_json::to_string(&variant).unwrap();
            assert_eq!(serialised, expected);
            let deserialised: CompressionConfig = serde_json::from_str(expected).unwrap();
            assert_eq!(deserialised, variant);
        }
    }

    #[tokio::test]
    async fn async_roundtrip_gzip() {
        let original = b"hello, compressed world!\n".repeat(100);
        let mut buf = Vec::new();
        {
            let mut w = wrap_async_writer(&mut buf, Compression::Gzip);
            w.write_all(&original).await.unwrap();
            w.shutdown().await.unwrap();
        }
        let mut decompressed = Vec::new();
        let mut r = wrap_async_reader(BufReader::new(&buf[..]), Compression::Gzip);
        r.read_to_end(&mut decompressed).await.unwrap();
        assert_eq!(decompressed, original);
    }

    #[tokio::test]
    async fn async_roundtrip_zstd() {
        let original = b"zstd payload\n".repeat(50);
        let mut buf = Vec::new();
        {
            let mut w = wrap_async_writer(&mut buf, Compression::Zstd);
            w.write_all(&original).await.unwrap();
            w.shutdown().await.unwrap();
        }
        let mut decompressed = Vec::new();
        let mut r = wrap_async_reader(BufReader::new(&buf[..]), Compression::Zstd);
        r.read_to_end(&mut decompressed).await.unwrap();
        assert_eq!(decompressed, original);
    }

    #[tokio::test]
    async fn async_none_passthrough() {
        let original = b"plain text";
        let mut buf = Vec::new();
        {
            let mut w = wrap_async_writer(&mut buf, Compression::None);
            w.write_all(original).await.unwrap();
            w.shutdown().await.unwrap();
        }
        assert_eq!(&buf[..], original);
    }

    /// Two independently compressed gzip members, concatenated, must decode
    /// to both payloads in order (multi-member gzip).
    #[tokio::test]
    async fn async_gzip_reads_concatenated_members() {
        let mut buf = compress_buf(b"first member\n", Compression::Gzip).unwrap();
        buf.extend(compress_buf(b"second member\n", Compression::Gzip).unwrap());
        let mut decompressed = Vec::new();
        let mut r = wrap_async_reader(BufReader::new(&buf[..]), Compression::Gzip);
        r.read_to_end(&mut decompressed).await.unwrap();
        assert_eq!(decompressed, b"first member\nsecond member\n");
    }

    #[test]
    fn sync_roundtrip_gzip() {
        use std::io::{Read, Write};
        let original = b"sync gzip data".repeat(20);
        let mut buf = Vec::new();
        {
            let mut w = wrap_sync_writer(&mut buf, Compression::Gzip);
            w.write_all(&original).unwrap();
            w.flush().unwrap();
        }
        let mut r = wrap_sync_reader(&buf[..], Compression::Gzip);
        let mut decompressed = Vec::new();
        r.read_to_end(&mut decompressed).unwrap();
        assert_eq!(decompressed, original);
    }

    #[test]
    fn sync_roundtrip_zstd() {
        use std::io::{Read, Write};
        let original = b"sync zstd data".repeat(20);
        let mut buf = Vec::new();
        {
            let mut w = wrap_sync_writer(&mut buf, Compression::Zstd);
            w.write_all(&original).unwrap();
            w.flush().unwrap();
        }
        let mut r = wrap_sync_reader(&buf[..], Compression::Zstd);
        let mut decompressed = Vec::new();
        r.read_to_end(&mut decompressed).unwrap();
        assert_eq!(decompressed, original);
    }

    /// Two independently compressed zstd frames, concatenated, must decode to
    /// both payloads in order on the sync path.
    #[test]
    fn sync_zstd_reads_concatenated_frames() {
        use std::io::Read;
        let mut buf = compress_buf(b"frame one\n", Compression::Zstd).unwrap();
        buf.extend(compress_buf(b"frame two\n", Compression::Zstd).unwrap());
        let mut r = wrap_sync_reader(&buf[..], Compression::Zstd);
        let mut decompressed = Vec::new();
        r.read_to_end(&mut decompressed).unwrap();
        assert_eq!(decompressed, b"frame one\nframe two\n");
    }

    /// `SyncCompressWriter::finish` must yield a complete stream for every
    /// codec, readable by the matching sync reader.
    #[test]
    fn sync_compress_writer_finish_roundtrip() {
        use std::io::{Read, Write};
        let original = b"explicitly finished".repeat(15);
        for codec in [Compression::None, Compression::Gzip, Compression::Zstd] {
            let mut w = sync_compress_writer(Vec::new(), codec);
            w.write_all(&original).unwrap();
            let buf = w.finish().unwrap();
            let mut r = wrap_sync_reader(&buf[..], codec);
            let mut decompressed = Vec::new();
            r.read_to_end(&mut decompressed).unwrap();
            assert_eq!(decompressed, original, "codec {codec:?}");
        }
    }

    #[test]
    fn compress_buf_roundtrip_gzip() {
        use std::io::Read;
        let original = b"buffer compression".repeat(10);
        let compressed = compress_buf(&original, Compression::Gzip).unwrap();
        assert_ne!(compressed, original);
        let mut r = wrap_sync_reader(&compressed[..], Compression::Gzip);
        let mut decompressed = Vec::new();
        r.read_to_end(&mut decompressed).unwrap();
        assert_eq!(decompressed, original);
    }

    #[test]
    fn compress_buf_roundtrip_zstd() {
        use std::io::Read;
        let original = b"buffer zstd".repeat(10);
        let compressed = compress_buf(&original, Compression::Zstd).unwrap();
        assert_ne!(compressed, original);
        let mut r = wrap_sync_reader(&compressed[..], Compression::Zstd);
        let mut decompressed = Vec::new();
        r.read_to_end(&mut decompressed).unwrap();
        assert_eq!(decompressed, original);
    }

    #[test]
    fn compress_buf_none_is_clone() {
        let original = b"unchanged";
        let out = compress_buf(original, Compression::None).unwrap();
        assert_eq!(out, original);
    }

    #[tokio::test]
    async fn empty_compressed_stream_yields_zero_bytes() {
        let mut buf = Vec::new();
        {
            let mut w = wrap_async_writer(&mut buf, Compression::Gzip);
            w.shutdown().await.unwrap();
        }
        let mut decompressed = Vec::new();
        let mut r = wrap_async_reader(BufReader::new(&buf[..]), Compression::Gzip);
        r.read_to_end(&mut decompressed).await.unwrap();
        assert!(decompressed.is_empty());
    }

    #[tokio::test]
    async fn truncated_gzip_stream_errors() {
        let original = b"this will be truncated mid-stream".repeat(50);
        let mut buf = Vec::new();
        {
            let mut w = wrap_async_writer(&mut buf, Compression::Gzip);
            w.write_all(&original).await.unwrap();
            w.shutdown().await.unwrap();
        }
        buf.truncate(buf.len() / 2);
        let mut decompressed = Vec::new();
        let mut r = wrap_async_reader(BufReader::new(&buf[..]), Compression::Gzip);
        let err = r.read_to_end(&mut decompressed).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    /// Smoke test only: the dedup decision is not observable without a
    /// tracing subscriber, so this just checks repeated calls do not panic.
    #[test]
    fn warn_mismatch_dedups_per_path_and_codec() {
        let unique_path = format!("warn_mismatch_dedup_fixture_{}.txt", line!());
        warn_mismatch(&unique_path, Compression::Gzip);
        warn_mismatch(&unique_path, Compression::Gzip);
        warn_mismatch(&unique_path, Compression::Zstd);
        warn_mismatch("file.gz", Compression::Gzip);
    }
}