Skip to main content

initramfs_builder/initramfs/
compress.rs

1use anyhow::{Context, Result};
2use flate2::write::GzEncoder;
3use flate2::Compression as GzCompression;
4use std::fs::File;
5use std::io::{BufWriter, Write};
6use std::path::Path;
7use tracing::info;
8
/// Compression format applied to an archive by [`compress_archive`].
///
/// Parsed case-insensitively from `gzip`/`gz`, `zstd`/`zst` or `none`/`raw`,
/// and displayed by its canonical lowercase name (see [`Compression::as_str`]).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum Compression {
    /// gzip stream; this is the default.
    #[default]
    Gzip,
    /// Zstandard stream.
    Zstd,
    /// No compression: the archive bytes are written unchanged.
    None,
}

impl Compression {
    /// Canonical lowercase name of this format, as produced by `Display`
    /// and accepted by `FromStr`.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Compression::Gzip => "gzip",
            Compression::Zstd => "zstd",
            Compression::None => "none",
        }
    }
}

impl std::str::FromStr for Compression {
    type Err = String;

    /// Parses a compression name or alias, ignoring ASCII case.
    ///
    /// # Errors
    ///
    /// Returns `"Unknown compression: <input>"` when `s` matches no alias.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // Every accepted spelling is ASCII, so an ASCII case-insensitive
        // comparison matches exactly what lowercasing then comparing would,
        // without allocating a lowered copy of the input.
        const ALIASES: [(&str, Compression); 6] = [
            ("gzip", Compression::Gzip),
            ("gz", Compression::Gzip),
            ("zstd", Compression::Zstd),
            ("zst", Compression::Zstd),
            ("none", Compression::None),
            ("raw", Compression::None),
        ];

        ALIASES
            .iter()
            .find(|(alias, _)| s.eq_ignore_ascii_case(alias))
            .map(|&(_, compression)| compression)
            .ok_or_else(|| format!("Unknown compression: {}", s))
    }
}

impl std::fmt::Display for Compression {
    /// Writes the canonical name.
    ///
    /// Uses `Formatter::pad`, so width, fill, alignment and precision flags
    /// (e.g. `{:<6}`) are honoured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
39
40/// Compress data and write to output path
41pub fn compress_archive(data: &[u8], output_path: &Path, compression: Compression) -> Result<u64> {
42    info!(
43        "Compressing {} bytes with {} to {:?}",
44        data.len(),
45        compression,
46        output_path
47    );
48
49    let file = File::create(output_path)
50        .with_context(|| format!("Failed to create output file: {:?}", output_path))?;
51    let mut writer = BufWriter::new(file);
52
53    match compression {
54        Compression::Gzip => {
55            let mut encoder = GzEncoder::new(&mut writer, GzCompression::default());
56            encoder.write_all(data)?;
57            encoder.finish()?;
58        }
59        Compression::Zstd => {
60            let mut encoder = zstd::stream::Encoder::new(&mut writer, 3)?;
61            encoder.write_all(data)?;
62            encoder.finish()?;
63        }
64        Compression::None => {
65            writer.write_all(data)?;
66        }
67    }
68
69    writer.flush()?;
70
71    let output_size = std::fs::metadata(output_path)?.len();
72    info!(
73        "Compressed {} bytes -> {} bytes ({:.1}% ratio)",
74        data.len(),
75        output_size,
76        (output_size as f64 / data.len() as f64) * 100.0
77    );
78
79    Ok(output_size)
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::io::Read;
    use tempfile::TempDir;

    /// Reads `path` back and reverses `compression`, returning the original bytes.
    fn decompress(path: &Path, compression: Compression) -> Vec<u8> {
        let compressed = fs::read(path).unwrap();
        match compression {
            Compression::Gzip => {
                let mut out = Vec::new();
                flate2::read::GzDecoder::new(&compressed[..])
                    .read_to_end(&mut out)
                    .unwrap();
                out
            }
            Compression::Zstd => zstd::decode_all(&compressed[..]).unwrap(),
            Compression::None => compressed,
        }
    }

    #[test]
    fn test_compression_from_str() {
        assert_eq!("gzip".parse::<Compression>().unwrap(), Compression::Gzip);
        assert_eq!("gz".parse::<Compression>().unwrap(), Compression::Gzip);
        assert_eq!("zstd".parse::<Compression>().unwrap(), Compression::Zstd);
        assert_eq!("zst".parse::<Compression>().unwrap(), Compression::Zstd);
        assert_eq!("none".parse::<Compression>().unwrap(), Compression::None);
        assert_eq!("raw".parse::<Compression>().unwrap(), Compression::None);
        assert!("invalid".parse::<Compression>().is_err());
    }

    #[test]
    fn test_compression_from_str_is_case_insensitive() {
        assert_eq!("GZIP".parse::<Compression>().unwrap(), Compression::Gzip);
        assert_eq!("Zst".parse::<Compression>().unwrap(), Compression::Zstd);
        assert_eq!("RAW".parse::<Compression>().unwrap(), Compression::None);
    }

    #[test]
    fn test_compression_display() {
        assert_eq!(format!("{}", Compression::Gzip), "gzip");
        assert_eq!(format!("{}", Compression::Zstd), "zstd");
        assert_eq!(format!("{}", Compression::None), "none");
    }

    #[test]
    fn test_compression_default() {
        assert_eq!(Compression::default(), Compression::Gzip);
    }

    #[test]
    fn test_gzip_compression() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test.gz");
        // Repetitive data that compresses well; `repeat` already yields a Vec.
        let data: Vec<u8> = b"hello world ".repeat(100);

        let size = compress_archive(&data, &output_path, Compression::Gzip).unwrap();

        assert!(output_path.exists());
        assert!(size > 0);
        assert_eq!(size, fs::metadata(&output_path).unwrap().len());
        assert_eq!(decompress(&output_path, Compression::Gzip), data);
    }

    #[test]
    fn test_zstd_compression() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test.zst");
        let data = b"hello world hello world hello world";

        let size = compress_archive(data, &output_path, Compression::Zstd).unwrap();

        assert!(output_path.exists());
        assert!(
            size < data.len() as u64,
            "Compressed size should be smaller"
        );
        assert_eq!(size, fs::metadata(&output_path).unwrap().len());
        assert_eq!(decompress(&output_path, Compression::Zstd), data);
    }

    #[test]
    fn test_no_compression() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test.cpio");
        let data = b"hello world";

        let size = compress_archive(data, &output_path, Compression::None).unwrap();

        assert_eq!(size, data.len() as u64);
        assert_eq!(fs::read(&output_path).unwrap(), data);
    }

    #[test]
    fn test_empty_input_round_trips() {
        let temp_dir = TempDir::new().unwrap();
        for compression in [Compression::Gzip, Compression::Zstd, Compression::None] {
            let output_path = temp_dir.path().join(format!("empty.{}", compression));

            let size = compress_archive(&[], &output_path, compression).unwrap();

            assert_eq!(size, fs::metadata(&output_path).unwrap().len());
            assert!(
                decompress(&output_path, compression).is_empty(),
                "{} round-trip of empty input should be empty",
                compression
            );
        }
    }
}