1use crate::FaucetError;
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15use std::pin::Pin;
16use tokio::io::{AsyncBufRead, AsyncWrite};
17
/// Compression setting as written in configuration.
///
/// Variants are (de)serialized in lowercase: `"none"`, `"gzip"`, `"zstd"`,
/// `"auto"`. The default is [`CompressionConfig::Auto`]. Call
/// [`CompressionConfig::resolve`] to turn a setting into a concrete
/// [`Compression`] codec.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "lowercase")]
pub enum CompressionConfig {
    /// Explicitly no compression, whatever the path looks like.
    None,
    /// Explicitly gzip, whatever the path looks like.
    Gzip,
    /// Explicitly zstd, whatever the path looks like.
    Zstd,
    /// Choose the codec from the path's suffix (see [`detect_from_path`]).
    #[default]
    Auto,
}
29
/// A concrete, fully resolved compression codec.
///
/// Unlike [`CompressionConfig`] there is no "auto" variant: this is the value
/// the reader/writer wrappers in this module dispatch on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// Data is passed through unchanged.
    None,
    /// Data is gzip-encoded.
    Gzip,
    /// Data is zstd-encoded.
    Zstd,
}
37
38impl CompressionConfig {
39 pub fn resolve(self, path: &str) -> Compression {
41 match self {
42 Self::None => Compression::None,
43 Self::Gzip => Compression::Gzip,
44 Self::Zstd => Compression::Zstd,
45 Self::Auto => detect_from_path(path),
46 }
47 }
48}
49
50pub fn detect_from_path(path: &str) -> Compression {
53 let lower = path.to_ascii_lowercase();
54 if lower.ends_with(".gz") {
55 Compression::Gzip
56 } else if lower.ends_with(".zst") {
57 Compression::Zstd
58 } else {
59 Compression::None
60 }
61}
62
63pub fn wrap_async_reader<'a, R>(
65 r: R,
66 c: Compression,
67) -> Pin<Box<dyn AsyncBufRead + Send + Unpin + 'a>>
68where
69 R: AsyncBufRead + Send + Unpin + 'a,
70{
71 match c {
72 Compression::None => Box::pin(r),
73 Compression::Gzip => {
74 let mut dec = async_compression::tokio::bufread::GzipDecoder::new(r);
75 dec.multiple_members(true);
76 Box::pin(tokio::io::BufReader::new(dec))
77 }
78 Compression::Zstd => {
81 let dec = async_compression::tokio::bufread::ZstdDecoder::new(r);
82 Box::pin(tokio::io::BufReader::new(dec))
83 }
84 }
85}
86
87pub fn wrap_async_writer<'a, W>(
91 w: W,
92 c: Compression,
93) -> Pin<Box<dyn AsyncWrite + Send + Unpin + 'a>>
94where
95 W: AsyncWrite + Send + Unpin + 'a,
96{
97 match c {
98 Compression::None => Box::pin(w),
99 Compression::Gzip => Box::pin(async_compression::tokio::write::GzipEncoder::new(w)),
100 Compression::Zstd => Box::pin(async_compression::tokio::write::ZstdEncoder::new(w)),
101 }
102}
103
104pub fn wrap_sync_reader<'a, R>(r: R, c: Compression) -> Box<dyn std::io::Read + Send + 'a>
106where
107 R: std::io::Read + Send + 'a,
108{
109 match c {
110 Compression::None => Box::new(r),
111 Compression::Gzip => Box::new(flate2::read::MultiGzDecoder::new(r)),
112 Compression::Zstd => Box::new(
113 zstd::stream::read::Decoder::new(r)
114 .expect("zstd decoder construction is infallible for any Read"),
115 ),
116 }
117}
118
119pub fn wrap_sync_writer<'a, W>(w: W, c: Compression) -> Box<dyn std::io::Write + Send + 'a>
134where
135 W: std::io::Write + Send + 'a,
136{
137 match c {
138 Compression::None => Box::new(w),
139 Compression::Gzip => Box::new(flate2::write::GzEncoder::new(
140 w,
141 flate2::Compression::default(),
142 )),
143 Compression::Zstd => Box::new(
144 zstd::stream::write::Encoder::new(w, 0)
145 .expect("zstd encoder construction is infallible")
146 .auto_finish(),
147 ),
148 }
149}
150
/// A blocking writer that compresses into an inner `W`, with the codec chosen
/// at runtime but the inner writer type preserved.
///
/// Unlike the boxed writer from [`wrap_sync_writer`], this keeps `W` visible so
/// [`SyncCompressWriter::finish`] can hand the inner writer back.
pub enum SyncCompressWriter<W: std::io::Write> {
    /// No compression: writes go straight to `W`.
    Plain(W),
    /// Gzip encoder writing into `W`.
    Gzip(flate2::write::GzEncoder<W>),
    /// Zstd encoder writing into `W`.
    Zstd(zstd::stream::write::Encoder<'static, W>),
}
165
166impl<W: std::io::Write> SyncCompressWriter<W> {
167 pub fn finish(self) -> std::io::Result<W> {
170 match self {
171 SyncCompressWriter::Plain(w) => Ok(w),
172 SyncCompressWriter::Gzip(e) => e.finish(),
173 SyncCompressWriter::Zstd(e) => e.finish(),
174 }
175 }
176}
177
178impl<W: std::io::Write> std::io::Write for SyncCompressWriter<W> {
179 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
180 match self {
181 SyncCompressWriter::Plain(w) => w.write(buf),
182 SyncCompressWriter::Gzip(e) => e.write(buf),
183 SyncCompressWriter::Zstd(e) => e.write(buf),
184 }
185 }
186
187 fn flush(&mut self) -> std::io::Result<()> {
188 match self {
189 SyncCompressWriter::Plain(w) => w.flush(),
190 SyncCompressWriter::Gzip(e) => e.flush(),
191 SyncCompressWriter::Zstd(e) => e.flush(),
192 }
193 }
194}
195
196pub fn sync_compress_writer<W: std::io::Write>(w: W, c: Compression) -> SyncCompressWriter<W> {
200 match c {
201 Compression::None => SyncCompressWriter::Plain(w),
202 Compression::Gzip => SyncCompressWriter::Gzip(flate2::write::GzEncoder::new(
203 w,
204 flate2::Compression::default(),
205 )),
206 Compression::Zstd => SyncCompressWriter::Zstd(
207 zstd::stream::write::Encoder::new(w, 0)
208 .expect("zstd encoder construction is infallible"),
209 ),
210 }
211}
212
213pub fn compress_buf(data: &[u8], c: Compression) -> Result<Vec<u8>, FaucetError> {
216 use std::io::Write;
217 match c {
218 Compression::None => Ok(data.to_vec()),
219 Compression::Gzip => {
220 let mut enc = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
221 enc.write_all(data)
222 .map_err(|e| FaucetError::Sink(format!("gzip compress failed: {e}")))?;
223 enc.finish()
224 .map_err(|e| FaucetError::Sink(format!("gzip finalise failed: {e}")))
225 }
226 Compression::Zstd => zstd::stream::encode_all(data, 0)
227 .map_err(|e| FaucetError::Sink(format!("zstd compress failed: {e}"))),
228 }
229}
230
231pub fn warn_mismatch(path: &str, declared: Compression) {
235 use std::collections::HashSet;
236 use std::sync::{Mutex, OnceLock};
237 const MAX_SEEN: usize = 4096;
242 static SEEN: OnceLock<Mutex<HashSet<(String, Compression)>>> = OnceLock::new();
243 let detected = detect_from_path(path);
244 if detected == declared {
245 return;
246 }
247 let key = (path.to_string(), declared);
248 let should_warn = {
249 let mut seen = SEEN
250 .get_or_init(|| Mutex::new(HashSet::new()))
251 .lock()
252 .expect("compression mismatch log mutex poisoned");
253 if seen.len() >= MAX_SEEN {
254 true
255 } else {
256 seen.insert(key)
257 }
258 };
259 if should_warn {
260 tracing::warn!(
261 path = %path,
262 declared = ?declared,
263 detected = ?detected,
264 "compression codec mismatch — explicit config wins, filename extension ignored",
265 );
266 }
267}
268
// Unit tests for codec detection, config (de)serialization, the async and
// sync reader/writer wrappers, one-shot buffer compression and the mismatch
// warning helper.
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};

    // Suffix detection: case-insensitive, last suffix wins, empty path is `None`.
    #[test]
    fn detect_extensions() {
        assert_eq!(detect_from_path("foo.jsonl"), Compression::None);
        assert_eq!(detect_from_path("foo.json.gz"), Compression::Gzip);
        assert_eq!(detect_from_path("foo.csv.zst"), Compression::Zstd);
        assert_eq!(detect_from_path("FOO.GZ"), Compression::Gzip);
        assert_eq!(detect_from_path("a.gz.zst"), Compression::Zstd);
        assert_eq!(detect_from_path(""), Compression::None);
    }

    // `Auto` defers to the path's suffix.
    #[test]
    fn resolve_auto_uses_path() {
        assert_eq!(CompressionConfig::Auto.resolve("foo.gz"), Compression::Gzip);
        assert_eq!(
            CompressionConfig::Auto.resolve("foo.zst"),
            Compression::Zstd
        );
        assert_eq!(CompressionConfig::Auto.resolve("foo"), Compression::None);
    }

    // An explicit setting wins even when the suffix says otherwise.
    #[test]
    fn resolve_explicit_ignores_path() {
        assert_eq!(
            CompressionConfig::Gzip.resolve("foo.txt"),
            Compression::Gzip
        );
        assert_eq!(CompressionConfig::None.resolve("foo.gz"), Compression::None);
    }

    // The derived `Default` must be `Auto`.
    #[test]
    fn config_default_is_auto() {
        assert_eq!(CompressionConfig::default(), CompressionConfig::Auto);
    }

    // Every variant serializes to its lowercase name and parses back from it.
    #[test]
    fn config_serde_lowercase() {
        for (variant, expected) in [
            (CompressionConfig::None, "\"none\""),
            (CompressionConfig::Gzip, "\"gzip\""),
            (CompressionConfig::Zstd, "\"zstd\""),
            (CompressionConfig::Auto, "\"auto\""),
        ] {
            let serialised = serde_json::to_string(&variant).unwrap();
            assert_eq!(serialised, expected);
            let deserialised: CompressionConfig = serde_json::from_str(expected).unwrap();
            assert_eq!(deserialised, variant);
        }
    }

    // Async gzip: write, shut down, then read the same bytes back.
    #[tokio::test]
    async fn async_roundtrip_gzip() {
        let original = b"hello, compressed world!\n".repeat(100);
        let mut buf = Vec::new();
        {
            // Scoped so the writer's mutable borrow of `buf` ends before reading.
            let mut w = wrap_async_writer(&mut buf, Compression::Gzip);
            w.write_all(&original).await.unwrap();
            w.shutdown().await.unwrap();
        }
        let mut decompressed = Vec::new();
        let mut r = wrap_async_reader(BufReader::new(&buf[..]), Compression::Gzip);
        r.read_to_end(&mut decompressed).await.unwrap();
        assert_eq!(decompressed, original);
    }

    // Async zstd: write, shut down, then read the same bytes back.
    #[tokio::test]
    async fn async_roundtrip_zstd() {
        let original = b"zstd payload\n".repeat(50);
        let mut buf = Vec::new();
        {
            let mut w = wrap_async_writer(&mut buf, Compression::Zstd);
            w.write_all(&original).await.unwrap();
            w.shutdown().await.unwrap();
        }
        let mut decompressed = Vec::new();
        let mut r = wrap_async_reader(BufReader::new(&buf[..]), Compression::Zstd);
        r.read_to_end(&mut decompressed).await.unwrap();
        assert_eq!(decompressed, original);
    }

    // With `Compression::None` the bytes reach the sink unmodified.
    #[tokio::test]
    async fn async_none_passthrough() {
        let original = b"plain text";
        let mut buf = Vec::new();
        {
            let mut w = wrap_async_writer(&mut buf, Compression::None);
            w.write_all(original).await.unwrap();
            w.shutdown().await.unwrap();
        }
        assert_eq!(&buf[..], original);
    }

    // Sync gzip round-trip through the boxed writer and reader.
    #[test]
    fn sync_roundtrip_gzip() {
        use std::io::{Read, Write};
        let original = b"sync gzip data".repeat(20);
        let mut buf = Vec::new();
        {
            let mut w = wrap_sync_writer(&mut buf, Compression::Gzip);
            w.write_all(&original).unwrap();
            w.flush().unwrap();
            // The boxed writer has no `finish`; it is dropped at the end of this
            // scope, and the read below expects the stream to be complete by then.
        }
        let mut r = wrap_sync_reader(&buf[..], Compression::Gzip);
        let mut decompressed = Vec::new();
        r.read_to_end(&mut decompressed).unwrap();
        assert_eq!(decompressed, original);
    }

    // Sync zstd round-trip through the boxed writer and reader.
    #[test]
    fn sync_roundtrip_zstd() {
        use std::io::{Read, Write};
        let original = b"sync zstd data".repeat(20);
        let mut buf = Vec::new();
        {
            let mut w = wrap_sync_writer(&mut buf, Compression::Zstd);
            w.write_all(&original).unwrap();
            w.flush().unwrap();
        }
        let mut r = wrap_sync_reader(&buf[..], Compression::Zstd);
        let mut decompressed = Vec::new();
        r.read_to_end(&mut decompressed).unwrap();
        assert_eq!(decompressed, original);
    }

    // `compress_buf` gzip output differs from the input and decodes back to it.
    #[test]
    fn compress_buf_roundtrip_gzip() {
        use std::io::Read;
        let original = b"buffer compression".repeat(10);
        let compressed = compress_buf(&original, Compression::Gzip).unwrap();
        assert_ne!(compressed, original);
        let mut r = wrap_sync_reader(&compressed[..], Compression::Gzip);
        let mut decompressed = Vec::new();
        r.read_to_end(&mut decompressed).unwrap();
        assert_eq!(decompressed, original);
    }

    // `compress_buf` zstd output differs from the input and decodes back to it.
    #[test]
    fn compress_buf_roundtrip_zstd() {
        use std::io::Read;
        let original = b"buffer zstd".repeat(10);
        let compressed = compress_buf(&original, Compression::Zstd).unwrap();
        assert_ne!(compressed, original);
        let mut r = wrap_sync_reader(&compressed[..], Compression::Zstd);
        let mut decompressed = Vec::new();
        r.read_to_end(&mut decompressed).unwrap();
        assert_eq!(decompressed, original);
    }

    // With `Compression::None`, `compress_buf` returns the input bytes as-is.
    #[test]
    fn compress_buf_none_is_clone() {
        let original = b"unchanged";
        let out = compress_buf(original, Compression::None).unwrap();
        assert_eq!(out, original);
    }

    // A gzip stream with no payload must decode to zero bytes without error.
    #[tokio::test]
    async fn empty_compressed_stream_yields_zero_bytes() {
        let mut buf = Vec::new();
        {
            // Shut down immediately, without writing any data.
            let mut w = wrap_async_writer(&mut buf, Compression::Gzip);
            w.shutdown().await.unwrap();
        }
        let mut decompressed = Vec::new();
        let mut r = wrap_async_reader(BufReader::new(&buf[..]), Compression::Gzip);
        r.read_to_end(&mut decompressed).await.unwrap();
        assert!(decompressed.is_empty());
    }

    // Cutting a gzip stream in half must surface as `UnexpectedEof`, not as a
    // silently short read.
    #[tokio::test]
    async fn truncated_gzip_stream_errors() {
        let original = b"this will be truncated mid-stream".repeat(50);
        let mut buf = Vec::new();
        {
            let mut w = wrap_async_writer(&mut buf, Compression::Gzip);
            w.write_all(&original).await.unwrap();
            w.shutdown().await.unwrap();
        }
        // Keep only the first half of the compressed bytes.
        buf.truncate(buf.len() / 2);
        let mut decompressed = Vec::new();
        let mut r = wrap_async_reader(BufReader::new(&buf[..]), Compression::Gzip);
        let err = r.read_to_end(&mut decompressed).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    // Smoke test: there are no assertions on log output, so this only checks
    // that each call path runs without panicking.
    #[test]
    fn warn_mismatch_dedups_per_path_and_codec() {
        // `line!()` makes the fixture path specific to this test.
        let unique_path = format!("warn_mismatch_dedup_fixture_{}.txt", line!());
        // First call for this (path, codec) pair.
        warn_mismatch(&unique_path, Compression::Gzip);
        // Same pair again.
        warn_mismatch(&unique_path, Compression::Gzip);
        // Same path with a different declared codec is a distinct key.
        warn_mismatch(&unique_path, Compression::Zstd);
        // Declared codec matches the suffix, so this returns before the dedup set.
        warn_mismatch("file.gz", Compression::Gzip);
    }
}