initramfs_builder/initramfs/
compress.rs1use anyhow::{Context, Result};
2use flate2::write::GzEncoder;
3use flate2::Compression as GzCompression;
4use std::fs::File;
5use std::io::{BufWriter, Write};
6use std::path::Path;
7use tracing::info;
8
/// Output compression format for the generated archive.
///
/// Parsed from user-supplied strings via `FromStr` and rendered via `Display`;
/// [`Compression::Gzip`] is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Compression {
    /// gzip stream.
    #[default]
    Gzip,
    /// Zstandard stream.
    Zstd,
    /// No compression: the archive bytes are written unchanged.
    None,
}
16
17impl std::str::FromStr for Compression {
18 type Err = String;
19
20 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
21 match s.to_lowercase().as_str() {
22 "gzip" | "gz" => Ok(Compression::Gzip),
23 "zstd" | "zst" => Ok(Compression::Zstd),
24 "none" | "raw" => Ok(Compression::None),
25 _ => Err(format!("Unknown compression: {}", s)),
26 }
27 }
28}
29
30impl std::fmt::Display for Compression {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 Compression::Gzip => write!(f, "gzip"),
34 Compression::Zstd => write!(f, "zstd"),
35 Compression::None => write!(f, "none"),
36 }
37 }
38}
39
40pub fn compress_archive(data: &[u8], output_path: &Path, compression: Compression) -> Result<u64> {
42 info!(
43 "Compressing {} bytes with {} to {:?}",
44 data.len(),
45 compression,
46 output_path
47 );
48
49 let file = File::create(output_path)
50 .with_context(|| format!("Failed to create output file: {:?}", output_path))?;
51 let mut writer = BufWriter::new(file);
52
53 match compression {
54 Compression::Gzip => {
55 let mut encoder = GzEncoder::new(&mut writer, GzCompression::default());
56 encoder.write_all(data)?;
57 encoder.finish()?;
58 }
59 Compression::Zstd => {
60 let mut encoder = zstd::stream::Encoder::new(&mut writer, 3)?;
61 encoder.write_all(data)?;
62 encoder.finish()?;
63 }
64 Compression::None => {
65 writer.write_all(data)?;
66 }
67 }
68
69 writer.flush()?;
70
71 let output_size = std::fs::metadata(output_path)?.len();
72 info!(
73 "Compressed {} bytes -> {} bytes ({:.1}% ratio)",
74 data.len(),
75 output_size,
76 (output_size as f64 / data.len() as f64) * 100.0
77 );
78
79 Ok(output_size)
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::io::Read;
    use tempfile::TempDir;

    /// Highly repetitive payload, large enough that each codec's framing overhead is far
    /// smaller than the savings. This keeps "output is smaller than input" assertions robust.
    fn repetitive_payload() -> Vec<u8> {
        b"hello world ".repeat(100)
    }

    /// Decompresses a gzip file produced by `compress_archive`.
    fn read_gzip(path: &Path) -> Vec<u8> {
        let mut decoder = flate2::read::GzDecoder::new(File::open(path).unwrap());
        let mut out = Vec::new();
        decoder.read_to_end(&mut out).unwrap();
        out
    }

    /// Decompresses a zstd file produced by `compress_archive`.
    fn read_zstd(path: &Path) -> Vec<u8> {
        zstd::decode_all(&fs::read(path).unwrap()[..]).unwrap()
    }

    #[test]
    fn test_compression_from_str() {
        assert_eq!("gzip".parse::<Compression>().unwrap(), Compression::Gzip);
        assert_eq!("gz".parse::<Compression>().unwrap(), Compression::Gzip);
        assert_eq!("zstd".parse::<Compression>().unwrap(), Compression::Zstd);
        assert_eq!("zst".parse::<Compression>().unwrap(), Compression::Zstd);
        assert_eq!("none".parse::<Compression>().unwrap(), Compression::None);
        assert_eq!("raw".parse::<Compression>().unwrap(), Compression::None);
        assert!("invalid".parse::<Compression>().is_err());
    }

    #[test]
    fn test_compression_from_str_is_case_insensitive() {
        assert_eq!("GZIP".parse::<Compression>().unwrap(), Compression::Gzip);
        assert_eq!("Zst".parse::<Compression>().unwrap(), Compression::Zstd);
        assert_eq!("RAW".parse::<Compression>().unwrap(), Compression::None);
    }

    #[test]
    fn test_compression_from_str_error_names_input() {
        let err = "lz4".parse::<Compression>().unwrap_err();
        assert_eq!(err, "Unknown compression: lz4");
    }

    #[test]
    fn test_compression_display() {
        assert_eq!(format!("{}", Compression::Gzip), "gzip");
        assert_eq!(format!("{}", Compression::Zstd), "zstd");
        assert_eq!(format!("{}", Compression::None), "none");
    }

    #[test]
    fn test_compression_display_round_trips_through_from_str() {
        for compression in [Compression::Gzip, Compression::Zstd, Compression::None] {
            assert_eq!(
                compression.to_string().parse::<Compression>().unwrap(),
                compression
            );
        }
    }

    #[test]
    fn test_compression_default() {
        assert_eq!(Compression::default(), Compression::Gzip);
    }

    #[test]
    fn test_gzip_compression() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test.gz");
        let data = repetitive_payload();

        let size = compress_archive(&data, &output_path, Compression::Gzip).unwrap();

        assert!(output_path.exists());
        assert!(size > 0);
        assert!(size < data.len() as u64, "Compressed size should be smaller");
        assert_eq!(size, fs::metadata(&output_path).unwrap().len());
        assert_eq!(read_gzip(&output_path), data);
    }

    #[test]
    fn test_zstd_compression() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test.zst");
        let data = repetitive_payload();

        let size = compress_archive(&data, &output_path, Compression::Zstd).unwrap();

        assert!(output_path.exists());
        assert!(size < data.len() as u64, "Compressed size should be smaller");
        assert_eq!(size, fs::metadata(&output_path).unwrap().len());
        assert_eq!(read_zstd(&output_path), data);
    }

    #[test]
    fn test_no_compression() {
        let temp_dir = TempDir::new().unwrap();
        let output_path = temp_dir.path().join("test.cpio");
        let data = b"hello world";

        let size = compress_archive(data, &output_path, Compression::None).unwrap();

        assert_eq!(size, data.len() as u64);
        assert_eq!(fs::read(&output_path).unwrap(), data);
    }

    #[test]
    fn test_empty_input_round_trips() {
        let temp_dir = TempDir::new().unwrap();
        let data: &[u8] = b"";

        let gz_path = temp_dir.path().join("empty.gz");
        compress_archive(data, &gz_path, Compression::Gzip).unwrap();
        assert!(read_gzip(&gz_path).is_empty());

        let zst_path = temp_dir.path().join("empty.zst");
        compress_archive(data, &zst_path, Compression::Zstd).unwrap();
        assert!(read_zstd(&zst_path).is_empty());

        let raw_path = temp_dir.path().join("empty.cpio");
        let size = compress_archive(data, &raw_path, Compression::None).unwrap();
        assert_eq!(size, 0);
        assert!(fs::read(&raw_path).unwrap().is_empty());
    }
}