// ralph_workflow/checkpoint/io/compression.rs

/// Largest decompressed payload, in bytes, that `check_size_limit` accepts (1 MiB).
const MAX_DECOMPRESSED_SNAPSHOT_BYTES: usize = 1 << 20;
2
3pub fn compress(data: &[u8]) -> Result<String, std::io::Error> {
4    use base64::{engine::general_purpose::STANDARD, Engine};
5    use flate2::write::GzEncoder;
6    use flate2::Compression;
7    use std::io::Write;
8
9    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
10    encoder.write_all(data)?;
11    let compressed = encoder.finish()?;
12
13    Ok(STANDARD.encode(&compressed))
14}
15
16fn base64_decode(encoded: &str) -> Result<Vec<u8>, std::io::Error> {
17    use base64::{engine::general_purpose::STANDARD, Engine};
18    STANDARD.decode(encoded).map_err(|e| {
19        std::io::Error::new(
20            std::io::ErrorKind::InvalidData,
21            format!("Base64 decode error: {e}"),
22        )
23    })
24}
25
26fn check_size_limit(current_len: usize, n: usize) -> Result<(), std::io::Error> {
27    if current_len.saturating_add(n) > MAX_DECOMPRESSED_SNAPSHOT_BYTES {
28        return Err(std::io::Error::new(
29            std::io::ErrorKind::InvalidData,
30            format!(
31                "Decompressed payload exceeds max size ({MAX_DECOMPRESSED_SNAPSHOT_BYTES} bytes)"
32            ),
33        ));
34    }
35    Ok(())
36}
37
38fn read_chunk<R: std::io::Read>(
39    reader: &mut R,
40    buf: &mut [u8],
41    decompressed: &mut Vec<u8>,
42) -> Result<bool, std::io::Error> {
43    let n = reader.read(buf)?;
44    if n == 0 {
45        return Ok(false);
46    }
47    check_size_limit(decompressed.len(), n)?;
48    decompressed.extend_from_slice(&buf[..n]);
49    Ok(true)
50}
51
52fn gz_decompress(compressed: &[u8]) -> Result<Vec<u8>, std::io::Error> {
53    use flate2::read::GzDecoder;
54
55    let mut decoder = GzDecoder::new(compressed);
56    let mut decompressed = Vec::new();
57    let mut buf = [0u8; 8 * 1024];
58
59    while read_chunk(&mut decoder, &mut buf, &mut decompressed)? {}
60
61    Ok(decompressed)
62}
63
/// Converts `decompressed` into a `String`, taking ownership of the bytes.
///
/// # Errors
///
/// Returns an `InvalidData` error whose message starts with
/// `UTF-8 decode error: ` when the bytes are not valid UTF-8.
fn bytes_to_utf8(decompressed: Vec<u8>) -> Result<String, std::io::Error> {
    match String::from_utf8(decompressed) {
        Ok(text) => Ok(text),
        Err(err) => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("UTF-8 decode error: {err}"),
        )),
    }
}
72
73pub fn decompress(encoded: &str) -> Result<String, std::io::Error> {
74    let compressed = base64_decode(encoded)?;
75    let decompressed = gz_decompress(&compressed)?;
76    bytes_to_utf8(decompressed)
77}