Skip to main content

npm_utils/
extract.rs

1//! Archive extraction with path-traversal protection.
2//!
3//! Both [`tar_gz`] and [`zip`] iterate an archive in memory and write selected
4//! entries beneath `dest`. `strip_prefix` (e.g. `Some("package/")` for npm
5//! tarballs) is removed from each entry path before [`Select`] is applied.
6
7use flate2::read::GzDecoder;
8use std::fs::{create_dir_all, File};
9use std::io::{Cursor, Read, Write};
10use std::path::{Component, Path, PathBuf};
11use tar::Archive;
12
/// Which archive entries to extract, and where each lands (relative to `dest`).
///
/// Every variant holds only borrowed data, so `Select` is `Copy`: a single value
/// can be handed to several [`tar_gz`] / [`zip`] calls without rebuilding it.
#[derive(Clone, Copy)]
pub enum Select<'a> {
    /// Every file, keeping its (prefix-stripped) path. Directory entries create
    /// directories; symlinks are skipped.
    All,
    /// Only entries whose (prefix-stripped) path equals the source of a listed
    /// `(source, destination)` pair; written to the paired destination. If the
    /// same source is listed more than once, the first pair wins.
    Files(&'a [(&'a str, &'a str)]),
    /// Each entry's (prefix-stripped) path is handed to the closure, which
    /// returns the destination path or `None` to skip the entry.
    Matching(&'a dyn Fn(&str) -> Option<String>),
}

// `Debug` cannot be derived because `dyn Fn` is not `Debug`; the closure is
// rendered as an opaque `<fn>` placeholder instead.
impl std::fmt::Debug for Select<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Select::All => f.write_str("All"),
            Select::Files(files) => f.debug_tuple("Files").field(files).finish(),
            Select::Matching(_) => f.write_str("Matching(<fn>)"),
        }
    }
}
25
26impl Select<'_> {
27    /// Resolve an entry's (prefix-stripped) archive path to a destination
28    /// relative path, or `None` to skip it.
29    fn dest_for(&self, rel: &str) -> Option<String> {
30        match self {
31            Select::All => Some(rel.to_string()),
32            Select::Files(files) => files
33                .iter()
34                .find(|(src, _)| *src == rel)
35                .map(|(_, dst)| dst.to_string()),
36            Select::Matching(f) => f(rel),
37        }
38    }
39}
40
41/// Extract a gzipped tarball into `dest`. Returns the number of files written.
42pub fn tar_gz(
43    bytes: &[u8],
44    dest: &Path,
45    strip_prefix: Option<&str>,
46    select: Select<'_>,
47) -> Result<usize, Box<dyn std::error::Error>> {
48    let mut archive = Archive::new(GzDecoder::new(Cursor::new(bytes)));
49    let mut count = 0;
50    for entry in archive.entries()? {
51        let mut entry = entry?;
52        let entry_type = entry.header().entry_type();
53        if entry_type.is_symlink() {
54            continue;
55        }
56        let path = entry.path()?;
57        let path_str = path.to_string_lossy().into_owned();
58        let rel = strip(&path_str, strip_prefix);
59        if rel.is_empty() {
60            continue;
61        }
62        let is_dir = entry_type.is_dir();
63        if is_dir {
64            if matches!(select, Select::All) {
65                create_dir_all(safe_join(dest, rel)?)?;
66            }
67            continue;
68        }
69        let Some(dest_rel) = select.dest_for(rel) else {
70            continue;
71        };
72        let out = safe_join(dest, &dest_rel)?;
73        if let Some(parent) = out.parent() {
74            create_dir_all(parent)?;
75        }
76        let mut content = Vec::new();
77        entry.read_to_end(&mut content)?;
78        File::create(&out)?.write_all(&content)?;
79        count += 1;
80    }
81    Ok(count)
82}
83
84/// Extract a zip archive into `dest`. Returns the number of files written.
85pub fn zip(
86    bytes: &[u8],
87    dest: &Path,
88    strip_prefix: Option<&str>,
89    select: Select<'_>,
90) -> Result<usize, Box<dyn std::error::Error>> {
91    let mut archive = zip::ZipArchive::new(Cursor::new(bytes))?;
92    let mut count = 0;
93    for i in 0..archive.len() {
94        let mut file = archive.by_index(i)?;
95        if file.is_dir() || file.is_symlink() {
96            continue;
97        }
98        let name = match file.enclosed_name() {
99            Some(n) => n.to_string_lossy().into_owned(),
100            None => return Err("unsafe zip entry name (escapes destination)".into()),
101        };
102        let rel = strip(&name, strip_prefix);
103        if rel.is_empty() {
104            continue;
105        }
106        let Some(dest_rel) = select.dest_for(rel) else {
107            continue;
108        };
109        let out = safe_join(dest, &dest_rel)?;
110        if let Some(parent) = out.parent() {
111            create_dir_all(parent)?;
112        }
113        let mut content = Vec::new();
114        file.read_to_end(&mut content)?;
115        File::create(&out)?.write_all(&content)?;
116        count += 1;
117    }
118    Ok(count)
119}
120
/// Remove `prefix` from the front of `path` when `path` starts with it.
///
/// A `None` prefix, or one that `path` does not start with, leaves `path`
/// unchanged, so entries outside the expected top-level directory keep their
/// full archive path.
fn strip<'a>(path: &'a str, prefix: Option<&str>) -> &'a str {
    prefix
        .and_then(|p| path.strip_prefix(p))
        .unwrap_or(path)
}
127
/// Join `relative` onto `base`, refusing anything that could land outside `base`.
///
/// `relative` is rejected when it is empty, when it contains the substring `..`
/// anywhere (deliberately conservative: a harmless name such as `a..b` is
/// refused too), or when any of its path components is a parent-dir (`..`),
/// a root (`/`), or a drive/UNC prefix. Callers propagate the error, so
/// extraction aborts on such an entry instead of silently dropping it. Files
/// already written before the bad entry remain on disk.
///
/// # Errors
///
/// Returns `unsafe archive entry path: "<relative>"` for any rejected input.
fn safe_join(base: &Path, relative: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
    let unsafe_path = relative.is_empty()
        || relative.contains("..")
        || Path::new(relative).components().any(|component| {
            matches!(
                component,
                Component::ParentDir | Component::Prefix(_) | Component::RootDir
            )
        });
    if unsafe_path {
        Err(format!("unsafe archive entry path: {relative:?}").into())
    } else {
        Ok(base.join(relative))
    }
}
150
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use std::io::Cursor as IoCursor;
    use tempfile::tempdir;

    /// Build an in-memory gzipped tarball holding one regular file (mode 0644)
    /// per `(path, contents)` pair, in the given order.
    fn make_tar_gz(entries: &[(&str, &[u8])]) -> Vec<u8> {
        let encoder = GzEncoder::new(Vec::new(), Compression::fast());
        let mut builder = tar::Builder::new(encoder);
        for &(path, contents) in entries {
            let mut header = tar::Header::new_gnu();
            header.set_entry_type(tar::EntryType::Regular);
            header.set_mode(0o644);
            header.set_size(contents.len() as u64);
            builder
                .append_data(&mut header, path, IoCursor::new(contents))
                .unwrap();
        }
        builder.finish().unwrap();
        let encoder = builder.into_inner().unwrap();
        encoder.finish().unwrap()
    }

    #[test]
    fn tar_gz_all_strips_prefix() {
        let archive = make_tar_gz(&[("package/index.js", b"a"), ("package/sub/util.js", b"b")]);
        let out = tempdir().unwrap();
        let root = out.path();
        let written = tar_gz(&archive, root, Some("package/"), Select::All).unwrap();
        assert_eq!(written, 2);
        for expected in ["index.js", "sub/util.js"] {
            assert!(root.join(expected).exists());
        }
    }

    #[test]
    fn tar_gz_files_picks_named_entries() {
        let archive = make_tar_gz(&[
            ("package/dist/sprite.svg", b"<svg/>"),
            ("package/readme.md", b"x"),
        ]);
        let out = tempdir().unwrap();
        let mapping = [("dist/sprite.svg", "icons/sprite.svg")];
        let written = tar_gz(
            &archive,
            out.path(),
            Some("package/"),
            Select::Files(&mapping),
        )
        .unwrap();
        assert_eq!(written, 1);
        assert!(out.path().join("icons/sprite.svg").exists());
        assert!(!out.path().join("readme.md").exists());
    }

    #[test]
    fn tar_gz_matching_predicate_and_prefix() {
        let archive = make_tar_gz(&[
            ("package/a.js", b"x"),
            ("package/b.css", b"y"),
            ("package/c.mjs", b"z"),
        ]);
        let out = tempdir().unwrap();
        let scripts_only = |rel: &str| -> Option<String> {
            if rel.ends_with(".js") || rel.ends_with(".mjs") {
                Some(format!("lit/{rel}"))
            } else {
                None
            }
        };
        let written = tar_gz(
            &archive,
            out.path(),
            Some("package/"),
            Select::Matching(&scripts_only),
        )
        .unwrap();
        assert_eq!(written, 2);
        let lit = out.path().join("lit");
        assert!(lit.join("a.js").exists());
        assert!(lit.join("c.mjs").exists());
        assert!(!lit.join("b.css").exists());
    }

    #[test]
    fn safe_join_rejects_escapes() {
        let base = Path::new("/tmp/base");
        for bad in ["../escape", "/abs", "a/../b", ""] {
            assert!(safe_join(base, bad).is_err());
        }
        assert!(safe_join(base, "a/b.js").is_ok());
    }

    #[test]
    fn tar_gz_errors_when_selection_escapes_dest() {
        // Benign archive, but the selection maps an entry to a path that escapes
        // `dest` — extraction must abort, not silently skip.
        let archive = make_tar_gz(&[("package/x.js", b"x")]);
        let out = tempdir().unwrap();
        let outside = |_rel: &str| -> Option<String> { Some(String::from("../escape.js")) };
        let outcome = tar_gz(&archive, out.path(), Some("package/"), Select::Matching(&outside));
        assert!(outcome.is_err(), "extraction must error when a dest escapes");
    }
}