1use flate2::read::GzDecoder;
22use std::fs::{create_dir_all, File};
23use std::io::{Cursor, Read, Write};
24use std::path::Path;
25use tar::Archive;
26
27use crate::path_safety::{contained_target, safe_join};
28
/// Chooses which archive entries are extracted and where each one is written.
///
/// All paths here are relative: the source side is the entry path after any
/// `strip_prefix` has been removed, and the destination side is relative to
/// the extraction directory.
pub enum Select<'sel> {
    /// Extract every regular file at its own relative path. For tar archives,
    /// directory entries are also created in this mode.
    All,
    /// Extract only the listed files. Each pair is
    /// `(path inside the archive, destination path)`; anything not listed is
    /// skipped.
    Files(&'sel [(&'sel str, &'sel str)]),
    /// Ask a callback about each regular file: `Some(dest)` extracts it to
    /// `dest`, `None` skips it.
    Matching(&'sel dyn Fn(&str) -> Option<String>),
}
41
42impl Select<'_> {
43 fn dest_for(&self, rel: &str) -> Option<String> {
46 match self {
47 Select::All => Some(rel.to_string()),
48 Select::Files(files) => files
49 .iter()
50 .find(|(src, _)| *src == rel)
51 .map(|(_, dst)| dst.to_string()),
52 Select::Matching(f) => f(rel),
53 }
54 }
55}
56
57pub fn tar_gz(
59 bytes: &[u8],
60 dest: &Path,
61 strip_prefix: Option<&str>,
62 select: Select<'_>,
63) -> Result<usize, Box<dyn std::error::Error>> {
64 let mut archive = Archive::new(GzDecoder::new(Cursor::new(bytes)));
65 let mut count = 0;
66 let mut total: u64 = 0;
67 let mut entries: u64 = 0;
68 create_dir_all(dest)?;
70 let root = dest.canonicalize()?;
71 for entry in archive.entries()? {
72 let mut entry = entry?;
73 entries += 1;
74 if entries > MAX_ENTRIES {
75 return Err(too_many_entries());
76 }
77 let entry_type = entry.header().entry_type();
78 let is_dir = entry_type.is_dir();
79 if !is_dir && !entry_type.is_file() {
83 continue;
84 }
85 let path = entry.path()?;
86 let path_str = path.to_string_lossy().into_owned();
87 let rel = strip(&path_str, strip_prefix);
88 if is_root_entry(rel) {
91 continue;
92 }
93 if is_dir {
94 if matches!(select, Select::All) {
95 create_dir_all(safe_join(dest, rel)?)?;
96 }
97 continue;
98 }
99 let Some(dest_rel) = select.dest_for(rel) else {
100 continue;
101 };
102 let out = safe_join(dest, &dest_rel)?;
103 let target = contained_target(&root, &out)?;
104 let mut file = File::create(&target)?;
105 total += copy_capped(&mut entry, &mut file, MAX_TOTAL_BYTES.saturating_sub(total))?;
106 count += 1;
107 }
108 Ok(count)
109}
110
111pub fn zip(
113 bytes: &[u8],
114 dest: &Path,
115 strip_prefix: Option<&str>,
116 select: Select<'_>,
117) -> Result<usize, Box<dyn std::error::Error>> {
118 let mut archive = zip::ZipArchive::new(Cursor::new(bytes))?;
119 if archive.len() as u64 > MAX_ENTRIES {
120 return Err(too_many_entries());
121 }
122 let mut count = 0;
123 let mut total: u64 = 0;
124 create_dir_all(dest)?;
126 let root = dest.canonicalize()?;
127 for i in 0..archive.len() {
128 let mut file = archive.by_index(i)?;
129 if file.is_dir() || file.is_symlink() {
130 continue;
131 }
132 let name = match file.enclosed_name() {
133 Some(n) => n.to_string_lossy().into_owned(),
134 None => return Err("unsafe zip entry name (escapes destination)".into()),
135 };
136 let rel = strip(&name, strip_prefix);
137 if is_root_entry(rel) {
139 continue;
140 }
141 let Some(dest_rel) = select.dest_for(rel) else {
142 continue;
143 };
144 let out = safe_join(dest, &dest_rel)?;
145 let target = contained_target(&root, &out)?;
146 let mut writer = File::create(&target)?;
147 total += copy_capped(
148 &mut file,
149 &mut writer,
150 MAX_TOTAL_BYTES.saturating_sub(total),
151 )?;
152 count += 1;
153 }
154 Ok(count)
155}
156
/// Removes `prefix` from the front of `path` when it is present. When no
/// prefix is configured, or `path` does not start with it, `path` is returned
/// unchanged.
fn strip<'a>(path: &'a str, prefix: Option<&str>) -> &'a str {
    prefix
        .and_then(|p| path.strip_prefix(p))
        .unwrap_or(path)
}
163
/// Whether `rel` names the extraction root itself rather than something
/// inside it.
///
/// After prefix stripping, the root can appear in several forms:
/// - `""`, e.g. `package/` with a `package/` prefix;
/// - `"."`;
/// - `"./"`, the first member of tarballs built with `tar -C dir .`.
///
/// Trailing slashes are ignored so that all of these match; a bare `/` is
/// likewise treated as the root and skipped. Entries such as `"./index.js"`
/// or `"..."` are *not* the root.
fn is_root_entry(rel: &str) -> bool {
    let trimmed = rel.trim_end_matches('/');
    trimmed.is_empty() || trimmed == "."
}
170
/// Upper bound on the combined size of all files written by a single
/// extraction: 4 GiB.
const MAX_TOTAL_BYTES: u64 = 4 << 30;

/// Upper bound on the number of entries an archive may contain.
const MAX_ENTRIES: u64 = 200_000;
181
182fn too_many_entries() -> Box<dyn std::error::Error> {
183 format!("archive has more than {MAX_ENTRIES} entries (possible archive bomb)").into()
184}
185
/// Streams `reader` into `writer`, allowing at most `budget` bytes.
///
/// Returns the number of bytes copied. When the reader holds more than
/// `budget` bytes, at most `budget + 1` bytes are read and written before an
/// error is returned, so an oversized stream is never read to its end.
fn copy_capped<R: Read, W: Write>(
    reader: &mut R,
    writer: &mut W,
    budget: u64,
) -> Result<u64, Box<dyn std::error::Error>> {
    // One byte of headroom is enough to tell "exactly at budget" apart from
    // "over budget".
    let mut limited = reader.take(budget.saturating_add(1));
    let copied = std::io::copy(&mut limited, writer)?;
    if copied <= budget {
        Ok(copied)
    } else {
        Err("archive exceeds the extraction size limit (possible decompression bomb)".into())
    }
}
205
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use std::io::Cursor as IoCursor;
    use tempfile::tempdir;

    /// Builds an in-memory `.tar.gz` containing only regular files (mode
    /// 0o644), written in the given `(path, contents)` order.
    fn make_tar_gz(entries: &[(&str, &[u8])]) -> Vec<u8> {
        let mut builder = tar::Builder::new(GzEncoder::new(Vec::new(), Compression::fast()));
        for (path, contents) in entries {
            let mut header = tar::Header::new_gnu();
            header.set_size(contents.len() as u64);
            header.set_mode(0o644);
            header.set_entry_type(tar::EntryType::Regular);
            builder
                .append_data(&mut header, *path, IoCursor::new(*contents))
                .unwrap();
        }
        // Finish the tar stream, then the gzip stream, and return the bytes.
        builder.finish().unwrap();
        builder.into_inner().unwrap().finish().unwrap()
    }

    // `Select::All` writes every file, with the `package/` prefix removed.
    #[test]
    fn tar_gz_all_strips_prefix() {
        let tgz = make_tar_gz(&[("package/index.js", b"a"), ("package/sub/util.js", b"b")]);
        let tmp = tempdir().unwrap();
        let n = tar_gz(&tgz, tmp.path(), Some("package/"), Select::All).unwrap();
        assert_eq!(n, 2);
        assert!(tmp.path().join("index.js").exists());
        assert!(tmp.path().join("sub/util.js").exists());
    }

    // `Select::Files` writes only the listed entry, renamed to its
    // destination path; unlisted entries are ignored.
    #[test]
    fn tar_gz_files_picks_named_entries() {
        let tgz = make_tar_gz(&[
            ("package/dist/sprite.svg", b"<svg/>"),
            ("package/readme.md", b"x"),
        ]);
        let tmp = tempdir().unwrap();
        let n = tar_gz(
            &tgz,
            tmp.path(),
            Some("package/"),
            Select::Files(&[("dist/sprite.svg", "icons/sprite.svg")]),
        )
        .unwrap();
        assert_eq!(n, 1);
        assert!(tmp.path().join("icons/sprite.svg").exists());
        assert!(!tmp.path().join("readme.md").exists());
    }

    // `Select::Matching` sees the prefix-stripped path and can both filter
    // entries (by extension here) and relocate them (under `lit/`).
    #[test]
    fn tar_gz_matching_predicate_and_prefix() {
        let tgz = make_tar_gz(&[
            ("package/a.js", b"x"),
            ("package/b.css", b"y"),
            ("package/c.mjs", b"z"),
        ]);
        let tmp = tempdir().unwrap();
        let keep_js = |rel: &str| -> Option<String> {
            (rel.ends_with(".js") || rel.ends_with(".mjs")).then(|| format!("lit/{rel}"))
        };
        let n = tar_gz(
            &tgz,
            tmp.path(),
            Some("package/"),
            Select::Matching(&keep_js),
        )
        .unwrap();
        assert_eq!(n, 2);
        assert!(tmp.path().join("lit/a.js").exists());
        assert!(tmp.path().join("lit/c.mjs").exists());
        assert!(!tmp.path().join("lit/b.css").exists());
    }

    #[test]
    fn tar_gz_errors_when_selection_escapes_dest() {
        let tgz = make_tar_gz(&[("package/x.js", b"x")]);
        // The entry itself is harmless; it is the caller-supplied destination
        // (`../escape.js`) that tries to leave the extraction directory.
        let tmp = tempdir().unwrap();
        let escape = |_rel: &str| -> Option<String> { Some("../escape.js".to_string()) };
        let result = tar_gz(
            &tgz,
            tmp.path(),
            Some("package/"),
            Select::Matching(&escape),
        );
        assert!(result.is_err(), "extraction must error when a dest escapes");
    }

    #[test]
    #[cfg(unix)]
    fn rejects_writing_through_a_preexisting_symlink() {
        use std::os::unix::fs::symlink;
        let tmp = tempdir().unwrap();
        // Layout: tmp/dest/evil -> tmp/outside. The archive path `evil/pwned`
        // is lexically inside dest, but it resolves outside through the
        // symlink.
        let dest = tmp.path().join("dest");
        let outside = tmp.path().join("outside");
        std::fs::create_dir_all(&dest).unwrap();
        std::fs::create_dir_all(&outside).unwrap();
        symlink(&outside, dest.join("evil")).unwrap();

        let tgz = make_tar_gz(&[("package/evil/pwned", b"owned")]);
        let result = tar_gz(&tgz, &dest, Some("package/"), Select::All);

        assert!(
            result.is_err(),
            "must refuse to write through an escaping symlink"
        );
        assert!(
            !outside.join("pwned").exists(),
            "nothing may be written outside the extract dir"
        );
    }

    #[test]
    fn odd_but_legal_entry_names_stay_contained() {
        let tmp = tempdir().unwrap();
        // Names that look special to a shell or URL parser (`...`, `~`,
        // `file://`) are ordinary path components to the extractor. They
        // must land inside dest, not at the place they appear to reference.
        let dest = tmp.path().join("dest");
        let tgz = make_tar_gz(&[
            (".../flag.txt", b"a"),
            ("~/flag.txt", b"b"),
            ("file:///tmp/flag.txt", b"c"),
        ]);
        let n = tar_gz(&tgz, &dest, None, Select::All).unwrap();
        assert_eq!(n, 3);
        assert!(dest.join("...").join("flag.txt").is_file());
        assert!(dest.join("~").join("flag.txt").is_file());
        assert!(dest.join("file:").join("tmp").join("flag.txt").is_file());
        // Nothing leaked into the parent of dest.
        assert!(!tmp.path().join("flag.txt").exists());
    }

    /// Builds a `.tar.gz` with one regular file (`real.txt`), plus a symlink
    /// and a hard link that both point at it.
    fn tar_with_links() -> Vec<u8> {
        let mut b = tar::Builder::new(GzEncoder::new(Vec::new(), Compression::fast()));
        let mut reg = tar::Header::new_gnu();
        reg.set_size(4);
        reg.set_mode(0o644);
        reg.set_entry_type(tar::EntryType::Regular);
        b.append_data(&mut reg, "real.txt", IoCursor::new(&b"data"[..]))
            .unwrap();

        let mut sym = tar::Header::new_gnu();
        sym.set_size(0);
        sym.set_mode(0o777);
        sym.set_entry_type(tar::EntryType::Symlink);
        b.append_link(&mut sym, "evil-symlink", "real.txt").unwrap();

        let mut hard = tar::Header::new_gnu();
        hard.set_size(0);
        hard.set_mode(0o644);
        hard.set_entry_type(tar::EntryType::Link);
        b.append_link(&mut hard, "evil-hardlink", "real.txt")
            .unwrap();

        b.finish().unwrap();
        b.into_inner().unwrap().finish().unwrap()
    }

    #[test]
    fn skips_symlink_and_hardlink_entries() {
        let tmp = tempdir().unwrap();
        // dest does not exist yet; tar_gz is expected to create it.
        let dest = tmp.path().join("dest");
        let n = tar_gz(&tar_with_links(), &dest, None, Select::All).unwrap();
        assert_eq!(n, 1, "only the regular file is written");
        assert!(dest.join("real.txt").is_file());
        assert!(!dest.join("evil-symlink").exists());
        assert!(!dest.join("evil-hardlink").exists());
    }

    #[test]
    fn copy_capped_streams_within_budget_and_rejects_a_bomb() {
        let src = vec![7u8; 1000];
        let mut ok = Vec::new();
        // A budget larger than the input copies everything.
        assert_eq!(
            copy_capped(&mut src.as_slice(), &mut ok, 2000).unwrap(),
            1000
        );
        assert_eq!(ok, src);
        let mut overflow = Vec::new();
        // A budget smaller than the input must fail.
        assert!(copy_capped(&mut src.as_slice(), &mut overflow, 100).is_err());
    }

    #[test]
    fn is_root_entry_flags_dot_and_empty() {
        assert!(is_root_entry("."));
        // "" is what a `package/` directory entry becomes after stripping a
        // `package/` prefix.
        assert!(is_root_entry(""));
        assert!(!is_root_entry("index.js"));
        assert!(!is_root_entry("./index.js"));
        assert!(!is_root_entry("..."));
    }

    #[test]
    fn refuses_to_write_at_the_destination_root() {
        let tmp = tempdir().unwrap();
        // A selection that maps a file onto "." would replace the
        // destination directory itself. That must be an error, and the
        // directory must survive.
        let dest = tmp.path().join("dest");
        let tgz = make_tar_gz(&[("package/x.js", b"x")]);
        let onto_root = |_rel: &str| -> Option<String> { Some(".".to_string()) };
        let result = tar_gz(&tgz, &dest, Some("package/"), Select::Matching(&onto_root));
        assert!(result.is_err(), "writing onto the root must be refused");
        assert!(
            dest.is_dir(),
            "the destination root remains a real directory"
        );
    }
}