Skip to main content

datasynth_output/
compressed.rs

//! Compressed output writers using zstd for CSV/JSON files.
//!
//! Provides a transparent compression wrapper that implements `Write`.
//! zstd multithreaded encoding is available on request (via `CompressionConfig::threads`)
//! for parallel compression on multi-core systems.
5
6use std::fs::File;
7use std::io::{self, BufWriter, Write};
8use std::path::{Path, PathBuf};
9
/// Compression configuration for output files.
///
/// Start from [`CompressionConfig::default`] and refine it with the `with_*`
/// builder methods, which validate their inputs. The fields are public, so
/// values assigned to them directly are used as-is (no clamping).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompressionConfig {
    /// Zstd compression level (1-22, default 3).
    ///
    /// [`CompressionConfig::with_level`] clamps into this range; direct
    /// assignment does not.
    pub level: i32,
    /// Number of zstd worker threads for parallel compression (default 0).
    ///
    /// `0` does NOT auto-detect a thread count: no worker parameter is set, so
    /// zstd stays in its single-threaded default mode. Values above `0` are
    /// passed to zstd as `NbWorkers`; a zstd build without multithreading
    /// support rejects them, and the writer constructor reports an error.
    pub threads: u32,
}

impl Default for CompressionConfig {
    /// Level 3 with no worker threads (single-threaded compression).
    fn default() -> Self {
        Self {
            level: 3,
            threads: 0,
        }
    }
}

impl CompressionConfig {
    /// Return this config with the compression level set to `level`,
    /// clamped into zstd's supported range `1..=22`.
    #[must_use]
    pub fn with_level(mut self, level: i32) -> Self {
        self.level = level.clamp(1, 22);
        self
    }

    /// Return this config with `threads` zstd worker threads
    /// (`0` keeps compression single-threaded).
    #[must_use]
    pub fn with_threads(mut self, threads: u32) -> Self {
        self.threads = threads;
        self
    }
}
41
42/// A writer that transparently compresses output using zstd.
43///
44/// Wraps a `BufWriter<File>` with zstd compression. The compressed data
45/// is written to a file with a `.zst` extension appended to the original path.
46pub struct CompressedWriter<'a> {
47    encoder: zstd::Encoder<'a, BufWriter<File>>,
48    bytes_written: u64,
49}
50
51impl<'a> CompressedWriter<'a> {
52    /// Create a new compressed writer for the given path.
53    pub fn new(path: &Path, config: &CompressionConfig) -> io::Result<Self> {
54        let file = File::create(path)?;
55        let buf_writer = BufWriter::with_capacity(256 * 1024, file);
56        let mut encoder = zstd::Encoder::new(buf_writer, config.level)?;
57
58        // Enable multithreaded compression if requested
59        if config.threads > 0 {
60            encoder
61                .set_parameter(zstd::zstd_safe::CParameter::NbWorkers(config.threads))
62                .map_err(|_| io::Error::other("Failed to set zstd worker threads"))?;
63        }
64
65        Ok(Self {
66            encoder,
67            bytes_written: 0,
68        })
69    }
70
71    /// Get total uncompressed bytes written.
72    pub fn bytes_written(&self) -> u64 {
73        self.bytes_written
74    }
75
76    /// Finish compression and flush all remaining data.
77    pub fn finish(self) -> io::Result<()> {
78        self.encoder.finish()?;
79        Ok(())
80    }
81}
82
83impl Write for CompressedWriter<'_> {
84    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
85        let n = self.encoder.write(buf)?;
86        self.bytes_written += n as u64;
87        Ok(n)
88    }
89
90    fn flush(&mut self) -> io::Result<()> {
91        self.encoder.flush()
92    }
93}
94
/// Build the compressed output path by appending `.zst` to the whole path.
///
/// Any existing extension is kept, so `data.csv` becomes `data.csv.zst`.
pub fn compressed_path(path: &Path) -> PathBuf {
    let mut with_suffix = path.to_path_buf().into_os_string();
    with_suffix.push(".zst");
    with_suffix.into()
}
101
#[cfg(test)]
#[allow(clippy::unwrap_used)] // a panic on unexpected failure is exactly what a test wants
mod tests {
    use super::*;
    use std::io::Read;
    use tempfile::tempdir;

    /// Read a zstd-compressed file from disk and decode it as UTF-8 text.
    fn decompress_file(path: &Path) -> String {
        let raw = std::fs::read(path).unwrap();
        let mut decoder = zstd::Decoder::new(raw.as_slice()).unwrap();
        let mut text = String::new();
        decoder.read_to_string(&mut text).unwrap();
        text
    }

    #[test]
    fn test_compressed_writer_roundtrip() {
        let tmp = tempdir().unwrap();
        let target = tmp.path().join("test.csv.zst");
        let payload = "id,name,value\n1,hello,42.5\n2,world,99.9\n";

        let mut writer = CompressedWriter::new(&target, &CompressionConfig::default()).unwrap();
        writer.write_all(payload.as_bytes()).unwrap();
        writer.finish().unwrap();

        // Decompress and verify the text survived unchanged.
        assert_eq!(decompress_file(&target), payload);
    }

    #[test]
    fn test_compressed_writer_large_data() {
        let tmp = tempdir().unwrap();
        let target = tmp.path().join("large.csv.zst");

        let config = CompressionConfig::default().with_level(3);
        let mut writer = CompressedWriter::new(&target, &config).unwrap();

        // Header followed by 10K synthetic CSV rows.
        writer.write_all(b"id,name,value\n").unwrap();
        for i in 0..10_000u32 {
            let row = format!("{},item_{},{}.{:02}\n", i, i, i * 100, i % 100);
            writer.write_all(row.as_bytes()).unwrap();
        }
        let uncompressed = writer.bytes_written();
        writer.finish().unwrap();

        // The file on disk must be smaller than the data that went in.
        let on_disk = std::fs::metadata(&target).unwrap().len();
        assert!(
            on_disk < uncompressed,
            "Compressed size {} should be less than uncompressed {}",
            on_disk,
            uncompressed
        );

        // And it must decompress back to every row.
        let text = decompress_file(&target);
        assert!(text.starts_with("id,name,value\n"));
        assert_eq!(text.lines().count(), 10_001); // header + 10K rows
    }

    #[test]
    fn test_compressed_path() {
        assert_eq!(
            compressed_path(Path::new("/tmp/output/data.csv")),
            PathBuf::from("/tmp/output/data.csv.zst")
        );
    }

    #[test]
    fn test_compression_config() {
        let CompressionConfig { level, threads } =
            CompressionConfig::default().with_level(6).with_threads(4);
        assert_eq!(level, 6);
        assert_eq!(threads, 4);
    }

    #[test]
    fn test_compression_level_clamp() {
        // (requested level, level after clamping)
        for (requested, expected) in [(50, 22), (-5, 1)] {
            assert_eq!(
                CompressionConfig::default().with_level(requested).level,
                expected
            );
        }
    }
}