1use anyhow::{anyhow, Context, Result};
16use cpio::{write_cpio, NewcBuilder, NewcReader};
17use lazy_static::lazy_static;
18use std::collections::BTreeMap;
19use std::io::{BufRead, Cursor, Read};
20use xz2::stream::{Check, Stream};
21use xz2::write::XzEncoder;
22
23use crate::io::*;
24
25lazy_static! {
26 static ref ALL_GLOB: GlobMatcher = GlobMatcher::new(&["*"]).unwrap();
27}
28
/// In-memory set of initrd members, keyed by their path inside the archive.
///
/// The map is ordered, so iteration visits members in sorted path order.
#[derive(Debug, Default)]
pub struct Initrd {
    /// Member path -> raw file contents.
    members: BTreeMap<String, Vec<u8>>,
}
33
34impl Initrd {
35 pub fn to_bytes(&self) -> Result<Vec<u8>> {
37 let mut members = Vec::new();
38 let mut cwd: Vec<&str> = Vec::new();
46 for (path, contents) in &self.members {
47 let mut parent: Vec<&str> = path.split('/').collect();
49 parent.pop();
50 cwd = cwd
51 .iter()
52 .zip(&parent)
53 .take_while(|(a, b)| a == b)
54 .map(|(a, _)| *a)
55 .collect();
56 for component in parent.iter().skip(cwd.len()) {
59 cwd.push(component);
60 members.push((
62 NewcBuilder::new(&cwd.join("/")).mode(0o40_755),
63 Cursor::new(&[][..]),
64 ));
65 }
66 members.push((
68 NewcBuilder::new(path).mode(0o100_600),
69 Cursor::new(contents),
70 ));
71 }
72 let mut encoder = XzEncoder::new_stream(
74 Vec::new(),
75 Stream::new_easy_encoder(9, Check::Crc32).context("creating XZ encoder")?,
76 );
77 write_cpio(members.drain(..), &mut encoder).context("writing CPIO archive")?;
78 encoder.finish().context("closing XZ compressor")
79 }
80
81 pub fn from_reader<R: Read>(source: R) -> Result<Self> {
83 Self::from_reader_filtered(source, &ALL_GLOB)
84 }
85
86 pub fn from_reader_filtered<R: Read>(source: R, filter: &GlobMatcher) -> Result<Self> {
89 let mut source = PeekReader::with_capacity(BUFFER_SIZE, source);
90 let mut result = Self::default();
91 while !source
93 .fill_buf()
94 .context("checking for data in initrd")?
95 .is_empty()
96 {
97 let mut decompressor = DecompressReader::for_concatenated(source)?;
99 loop {
100 let mut reader = NewcReader::new(decompressor).context("reading CPIO entry")?;
101 let entry = reader.entry();
102 if entry.is_trailer() {
103 decompressor = reader.finish().context("finishing reading CPIO trailer")?;
104 break;
105 }
106 let name = entry.name().to_string();
107 if entry.mode() & 0o170_000 == 0o100_000 && filter.matches(&name) {
108 let mut buf = Vec::with_capacity(entry.file_size() as usize);
110 reader
111 .read_to_end(&mut buf)
112 .context("reading CPIO entry contents")?;
113 result.members.insert(name, buf);
114 }
115 decompressor = reader.finish().context("finishing reading CPIO entry")?;
116 }
117
118 if decompressor.compressed() {
120 let mut trailing = Vec::new();
121 decompressor
122 .read_to_end(&mut trailing)
123 .context("finishing reading compressed archive")?;
124 if trailing.iter().any(|v| *v != 0) {
126 bail!("found trailing garbage inside compressed archive");
127 }
128 }
129 source = decompressor.into_inner();
130
131 loop {
133 let buf = source
134 .fill_buf()
135 .context("checking for padding in initrd")?;
136 if buf.is_empty() {
137 break;
139 }
140 match buf.iter().position(|v| *v != 0) {
141 Some(pos) => {
142 source.consume(pos);
143 break;
144 }
145 None => {
146 let len = buf.len();
147 source.consume(len);
148 }
149 }
150 }
151 }
152 Ok(result)
153 }
154
155 pub fn get(&self, path: &str) -> Option<&[u8]> {
156 self.members.get(path).map(|v| v.as_slice())
157 }
158
159 pub fn find(&self, filter: &GlobMatcher) -> BTreeMap<&str, &[u8]> {
160 self.members
161 .iter()
162 .filter(|(p, _)| filter.matches(p))
163 .map(|(p, c)| (p.as_str(), c.as_slice()))
164 .collect()
165 }
166
167 pub fn add(&mut self, path: &str, contents: Vec<u8>) {
168 self.members.insert(path.into(), contents);
169 }
170
171 pub fn remove(&mut self, path: &str) {
172 self.members.remove(path);
173 }
174
175 pub fn is_empty(&self) -> bool {
176 self.members.is_empty()
177 }
178}
179
180pub struct GlobMatcher {
181 patterns: Vec<glob::Pattern>,
182}
183
184impl GlobMatcher {
185 pub fn new(globs: &[&str]) -> Result<Self> {
186 Ok(Self {
187 patterns: globs
188 .iter()
189 .map(|p| glob::Pattern::new(p).map_err(|e| anyhow!(e)))
190 .collect::<Result<_>>()?,
191 })
192 }
193
194 fn matches(&self, path: &str) -> bool {
195 self.patterns.iter().any(|p| p.matches(path))
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use maplit::btreemap;
    use xz2::read::XzDecoder;

    /// Fully decompresses an XZ-compressed byte slice into memory.
    fn unxz(compressed: &[u8]) -> Vec<u8> {
        let mut plain = Vec::new();
        XzDecoder::new(compressed)
            .read_to_end(&mut plain)
            .unwrap();
        plain
    }

    /// A member written with `to_bytes` is read back unchanged.
    #[test]
    fn roundtrip() {
        let payload = r#"{}"#;
        let mut initrd = Initrd::default();
        initrd.add("z", payload.as_bytes().into());

        let encoded = initrd.to_bytes().unwrap();
        let decoded = Initrd::from_reader(&*encoded).unwrap();
        assert_eq!(payload.as_bytes(), decoded.get("z").unwrap());
    }

    /// The fixture's archives are all read, whatever their directory name
    /// says about how they were stored.
    #[test]
    fn compression() {
        let archive = unxz(&include_bytes!("../../fixtures/initrd/compressed.img.xz")[..]);
        let initrd = Initrd::from_reader(&*archive).unwrap();
        let expected: BTreeMap<String, Vec<u8>> = btreemap! {
            "uncompressed-1/hello".into() => b"HELLO\n".to_vec(),
            "uncompressed-1/world".into() => b"WORLD\n".to_vec(),
            "uncompressed-2/hello".into() => b"HELLO\n".to_vec(),
            "uncompressed-2/world".into() => b"WORLD\n".to_vec(),
            "gzip/hello".into() => b"HELLO\n".to_vec(),
            "gzip/world".into() => b"WORLD\n".to_vec(),
            "xz/hello".into() => b"HELLO\n".to_vec(),
            "xz/world".into() => b"WORLD\n".to_vec(),
            "zstd/hello".into() => b"HELLO\n".to_vec(),
            "zstd/world".into() => b"WORLD\n".to_vec(),
        };
        assert_eq!(initrd.members, expected);
    }

    /// Reading the redundant fixture yields `third\n` for `data/file`.
    #[test]
    fn redundancy() {
        let archive = unxz(&include_bytes!("../../fixtures/initrd/redundant.img.xz")[..]);
        let initrd = Initrd::from_reader(&*archive).unwrap();
        assert_eq!(initrd.get("data/file").unwrap(), b"third\n");
    }

    /// Globs select the expected number of members, both when searching a
    /// fully loaded initrd and when filtering during the read.
    #[test]
    fn matching() {
        let archive = unxz(&include_bytes!("../../fixtures/initrd/compressed.img.xz")[..]);
        let matcher = |glob: &str| GlobMatcher::new(&[glob]).unwrap();

        // Search within an unfiltered read.
        let everything = Initrd::from_reader(&*archive).unwrap();
        let search_cases = [
            ("gzip/hello", 1),
            ("gzip/*", 2),
            ("*/hello", 5),
            ("*", 10),
            ("z", 0),
        ];
        for (glob, expected) in search_cases {
            assert_eq!(
                everything.find(&matcher(glob)).len(),
                expected,
                "glob {glob}"
            );
        }

        // Filter while reading, then count everything that was kept.
        let none = Initrd::from_reader_filtered(&*archive, &matcher("z")).unwrap();
        assert_eq!(none.find(&matcher("*")).len(), 0);
        let gzip_only = Initrd::from_reader_filtered(&*archive, &matcher("gzip/*")).unwrap();
        assert_eq!(gzip_only.find(&matcher("*")).len(), 2);
        let uncompressed =
            Initrd::from_reader_filtered(&*archive, &matcher("uncompressed-*")).unwrap();
        assert_eq!(uncompressed.find(&matcher("*")).len(), 4);
    }

    /// `to_bytes` emits each directory once, ahead of the files inside it,
    /// in sorted path order.
    #[test]
    fn directories() {
        let mut initrd = Initrd::default();
        for path in ["c/f", "c/a/b/d/f", "d", "c/g", "c/a/c/f", "a/b/c/d/e/f"] {
            initrd.add(path, vec![]);
        }

        let cpio = unxz(&initrd.to_bytes().unwrap());

        // Walk the raw archive and record (is_directory, name) per entry.
        let mut remaining = cpio.as_slice();
        let mut listing = Vec::new();
        loop {
            let reader = NewcReader::new(remaining).unwrap();
            let header = reader.entry();
            if header.is_trailer() {
                break;
            }
            let is_dir = header.mode() & 0o170_000 == 0o40_000;
            listing.push((is_dir, header.name().to_string()));
            remaining = reader.finish().unwrap();
        }

        let observed: Vec<(bool, &str)> = listing
            .iter()
            .map(|(is_dir, name)| (*is_dir, name.as_str()))
            .collect();
        let expected = vec![
            (true, "a"),
            (true, "a/b"),
            (true, "a/b/c"),
            (true, "a/b/c/d"),
            (true, "a/b/c/d/e"),
            (false, "a/b/c/d/e/f"),
            (true, "c"),
            (true, "c/a"),
            (true, "c/a/b"),
            (true, "c/a/b/d"),
            (false, "c/a/b/d/f"),
            (true, "c/a/c"),
            (false, "c/a/c/f"),
            (false, "c/f"),
            (false, "c/g"),
            (false, "d"),
        ];
        assert_eq!(observed, expected);
    }

    /// The padded-names fixture is read back with the expected member
    /// names and contents.
    #[test]
    fn padded_filenames() {
        let archive = unxz(&include_bytes!("../../fixtures/initrd/padded-names.img.xz")[..]);
        let initrd = Initrd::from_reader(&*archive).unwrap();
        let expected: BTreeMap<String, Vec<u8>> = btreemap! {
            "dir/hello".into() => vec![b'z'; 5000],
            "dir/world".into() => vec![b'q'; 4500],
        };
        assert_eq!(initrd.members, expected);
    }
}