Skip to main content

omne_cli/
tarball.rs

1//! Safe tarball extraction with defense-in-depth against path traversal.
2//!
3//! `extract_safe` takes any `impl Read` source and a target directory,
4//! iterates entries manually, and rejects any entry that would escape the
5//! target directory, is not a regular file or directory, or would be
6//! written through a pre-planted symlink.
7//!
8//! This module has **no** dependency on HTTP, ureq, or `GithubClient`.
9//! The split seam from the network layer is at the `impl Read` boundary.
10
11#![allow(dead_code)]
12
13use std::io::Read;
14use std::path::{Component, Path, PathBuf};
15
16use path_clean::PathClean;
17use thiserror::Error;
18
/// Errors returned by tarball extraction operations.
///
/// Every variant other than [`Error::Io`] carries the offending path so the
/// caller can report exactly what was refused.
#[derive(Debug, Error)]
pub enum Error {
    /// A tarball entry attempted to escape the target directory via `..`,
    /// absolute paths, or Windows prefix components.
    ///
    /// `entry_path` is the path as recorded in the archive, not the path it
    /// would have resolved to on disk.
    #[error("tarball path traversal attempt: {entry_path}")]
    Escape { entry_path: PathBuf },

    /// A tarball entry is not a regular file or directory (e.g. symlink,
    /// hardlink, character device, FIFO).
    ///
    /// `kind` is the human-readable label produced by `classify_entry_type`;
    /// `entry_path` is the path as recorded in the archive.
    #[error("unsupported tarball entry type '{kind}': {entry_path}")]
    UnsupportedEntry {
        entry_path: PathBuf,
        kind: &'static str,
    },

    /// A pre-existing symlink was found in the target tree before extraction.
    /// Extraction is refused to prevent symlink-following attacks.
    ///
    /// `path` is the location of the symlink itself. This is raised by the
    /// pre-extraction checks, before any archive entry is read.
    #[error("pre-existing symlink in target tree: {path}")]
    PrePlantedSymlink { path: PathBuf },

    /// Underlying I/O error during extraction.
    ///
    /// `#[from]` lets `?` convert any `std::io::Error` into this variant, and
    /// `transparent` forwards the message and source of the wrapped error.
    #[error(transparent)]
    Io(#[from] std::io::Error),
}
44
45/// Extract a `.tar.gz` archive from `source` into `target`, rejecting
46/// any entry that would escape the target or is not a regular file/directory.
47///
48/// # Pre-extraction checks
49///
50/// Before iterating entries, this function:
51/// 1. Canonicalizes `target` (caller must create it first).
52/// 2. Walks every ancestor of `target` up to the filesystem root and
53///    verifies none is a symlink.
54/// 3. Lists every direct child of `target` and verifies none is a symlink.
55///
56/// # Per-entry checks
57///
58/// For each entry:
59/// - Only `Regular` and `Directory` types are accepted; all others
60///   (symlinks, hardlinks, devices, etc.) are rejected.
61/// - Path components are scanned: `..`, `RootDir`, and `Prefix` are rejected.
62/// - Absolute paths are rejected.
63/// - A lexical prefix check confirms the resolved path stays within `target`
64///   (case-insensitive on Windows).
65pub fn extract_safe(source: impl Read, target: &Path) -> Result<(), Error> {
66    // ── Pre-extraction symlink precondition ──────────────────────────
67    let canon_target = target.canonicalize()?;
68
69    // Check ancestors of target for symlinks.
70    check_ancestors_not_symlinks(&canon_target)?;
71
72    // Check direct children of target for symlinks.
73    check_children_not_symlinks(&canon_target)?;
74
75    // ── Extraction loop ──────────────────────────────────────────────
76    let gz = flate2::read::GzDecoder::new(source);
77    let mut archive = tar::Archive::new(gz);
78
79    for entry_result in archive.entries()? {
80        let mut entry = entry_result?;
81        let header = entry.header();
82
83        // Filter entry type: accept only Regular and Directory.
84        let entry_type = header.entry_type();
85        if !entry_type.is_file() && !entry_type.is_dir() {
86            let kind = classify_entry_type(entry_type);
87            let entry_path = entry.path()?.into_owned();
88            return Err(Error::UnsupportedEntry { entry_path, kind });
89        }
90
91        // Read and validate the path.
92        let raw_path = entry.path()?.into_owned();
93
94        // Reject absolute paths.
95        if raw_path.is_absolute() {
96            return Err(Error::Escape {
97                entry_path: raw_path,
98            });
99        }
100
101        // Reject dangerous components.
102        for component in raw_path.components() {
103            match component {
104                Component::ParentDir | Component::RootDir | Component::Prefix(_) => {
105                    return Err(Error::Escape {
106                        entry_path: raw_path,
107                    });
108                }
109                _ => {}
110            }
111        }
112
113        // Lexical prefix check: resolved path must stay within target.
114        let resolved = canon_target.join(&raw_path).clean();
115        if !starts_with_normalized(&resolved, &canon_target) {
116            return Err(Error::Escape {
117                entry_path: raw_path,
118            });
119        }
120
121        // Extract the entry. tar's own guards are defense-in-depth.
122        entry.unpack_in(&canon_target)?;
123    }
124
125    Ok(())
126}
127
128/// Check that no ancestor of `path` is a symlink.
129fn check_ancestors_not_symlinks(path: &Path) -> Result<(), Error> {
130    let mut current = path.to_path_buf();
131    while let Some(parent) = current.parent() {
132        if parent.as_os_str().is_empty() {
133            break;
134        }
135        if parent.symlink_metadata()?.file_type().is_symlink() {
136            return Err(Error::PrePlantedSymlink {
137                path: parent.to_path_buf(),
138            });
139        }
140        current = parent.to_path_buf();
141    }
142    Ok(())
143}
144
145/// Check that no direct child of `path` is a symlink.
146fn check_children_not_symlinks(path: &Path) -> Result<(), Error> {
147    if !path.is_dir() {
148        return Ok(());
149    }
150    for entry in std::fs::read_dir(path)? {
151        let entry = entry?;
152        if entry.metadata()?.file_type().is_symlink() {
153            return Err(Error::PrePlantedSymlink { path: entry.path() });
154        }
155    }
156    Ok(())
157}
158
/// Case-aware, component-wise `starts_with` check.
///
/// Returns `true` when every component of `prefix` matches the leading
/// components of `path`. On Windows each component is compared after
/// lowercasing; on other platforms the comparison is exact
/// (`Path::starts_with`).
///
/// The comparison is always made on whole components, so a sibling such as
/// `C:\dest-evil` is never treated as being inside `C:\dest`, which a plain
/// string-prefix test would wrongly accept.
fn starts_with_normalized(path: &Path, prefix: &Path) -> bool {
    // `cfg!` (rather than `#[cfg]`) keeps both branches type-checked on
    // every platform.
    if !cfg!(windows) {
        return path.starts_with(prefix);
    }

    // Lossy conversion is acceptable here: it is only used to fold case for
    // an equality test between two paths derived from the same root.
    let fold = |part: &std::ffi::OsStr| part.to_string_lossy().to_lowercase();

    let mut remaining = path.components();
    prefix.components().all(|want| {
        remaining
            .next()
            .map_or(false, |have| fold(have.as_os_str()) == fold(want.as_os_str()))
    })
}
173
174/// Map tar entry type to a human-readable label.
175fn classify_entry_type(t: tar::EntryType) -> &'static str {
176    match t {
177        tar::EntryType::Symlink => "symlink",
178        tar::EntryType::Link => "hardlink",
179        tar::EntryType::Char => "char device",
180        tar::EntryType::Block => "block device",
181        tar::EntryType::Fifo => "fifo",
182        tar::EntryType::GNUSparse => "sparse",
183        _ => "unknown",
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use std::io::{Cursor, Write};
    use tempfile::TempDir;

    // ── Fixture helpers ──────────────────────────────────────────────

    /// Gzip-compress a finished tar byte stream.
    fn gzip(tar_bytes: &[u8]) -> Vec<u8> {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::fast());
        encoder.write_all(tar_bytes).unwrap();
        encoder.finish().unwrap()
    }

    /// Create a temporary directory containing an empty `dest` directory.
    ///
    /// The returned `TempDir` guard must be kept alive for the duration of
    /// the test; dropping it deletes the directory tree.
    fn fresh_target() -> (TempDir, PathBuf) {
        let tmp = TempDir::new().unwrap();
        let target = tmp.path().join("dest");
        std::fs::create_dir_all(&target).unwrap();
        (tmp, target)
    }

    /// Build a .tar.gz in memory with the given entries.
    /// Each entry is (path, content_bytes, is_dir); content is ignored for
    /// directories.
    fn build_tar_gz(entries: &[(&str, &[u8], bool)]) -> Vec<u8> {
        let mut builder = tar::Builder::new(Vec::new());
        for &(path, content, is_dir) in entries {
            let (kind, mode) = if is_dir {
                (tar::EntryType::Directory, 0o755)
            } else {
                (tar::EntryType::Regular, 0o644)
            };
            let data: &[u8] = if is_dir { &[] } else { content };

            let mut header = tar::Header::new_gnu();
            header.set_entry_type(kind);
            header.set_size(data.len() as u64);
            header.set_mode(mode);
            header.set_cksum();
            builder.append_data(&mut header, path, data).unwrap();
        }
        gzip(&builder.into_inner().unwrap())
    }

    /// Build a .tar.gz holding a single link entry of the given kind.
    fn build_link_tar_gz(
        kind: tar::EntryType,
        mode: u32,
        link_name: &str,
        target: &str,
    ) -> Vec<u8> {
        let mut builder = tar::Builder::new(Vec::new());
        let mut header = tar::Header::new_gnu();
        header.set_entry_type(kind);
        header.set_size(0);
        header.set_mode(mode);
        header.set_cksum();
        builder.append_link(&mut header, link_name, target).unwrap();
        gzip(&builder.into_inner().unwrap())
    }

    /// Build a .tar.gz with a symlink entry.
    fn build_tar_gz_with_symlink(link_name: &str, target: &str) -> Vec<u8> {
        build_link_tar_gz(tar::EntryType::Symlink, 0o777, link_name, target)
    }

    /// Build a .tar.gz with a hardlink entry.
    fn build_tar_gz_with_hardlink(link_name: &str, target: &str) -> Vec<u8> {
        build_link_tar_gz(tar::EntryType::Link, 0o644, link_name, target)
    }

    /// Build a .tar.gz with a single file at a forged path.
    ///
    /// The tar crate's builder validates paths and rejects `..` and absolute
    /// paths. To test our extraction guards, we need to forge a tarball with
    /// a malicious path at the raw byte level. This builds a valid tar entry
    /// with a safe path, then patches the path bytes in the raw tar data
    /// before gzip-compressing.
    fn build_tar_gz_with_forged_path(malicious_path: &str) -> Vec<u8> {
        // Build a tar with a placeholder path.
        let placeholder = "PLACEHOLDER_PATH_FOR_FORGING";
        assert!(
            malicious_path.len() <= placeholder.len(),
            "malicious path too long for placeholder"
        );

        let content = b"evil";
        let mut builder = tar::Builder::new(Vec::new());
        let mut header = tar::Header::new_gnu();
        header.set_entry_type(tar::EntryType::Regular);
        header.set_size(content.len() as u64);
        header.set_mode(0o644);
        header.set_cksum();
        builder
            .append_data(&mut header, placeholder, &content[..])
            .unwrap();
        let mut tar_bytes = builder.into_inner().unwrap();

        // Locate the placeholder in the raw tar bytes. Fail loudly if it is
        // missing: silently returning an unforged archive would make the
        // traversal tests exercise nothing.
        let pos = tar_bytes
            .windows(placeholder.len())
            .position(|w| w == placeholder.as_bytes())
            .expect("placeholder path not found in tar bytes");

        // Zero out the placeholder, then write the malicious path over it.
        let name_field = &mut tar_bytes[pos..pos + placeholder.len()];
        name_field.fill(0);
        name_field[..malicious_path.len()].copy_from_slice(malicious_path.as_bytes());

        // Recompute the header checksum (bytes 148..156).
        // The checksum is the sum of all header bytes treating the
        // checksum field itself as 8 spaces (0x20).
        let header_start = pos - (pos % 512); // Align to 512-byte block.
        let raw_header = &mut tar_bytes[header_start..header_start + 512];
        let sum: u64 = raw_header
            .iter()
            .enumerate()
            .map(|(i, &b)| {
                if (148..156).contains(&i) {
                    0x20
                } else {
                    u64::from(b)
                }
            })
            .sum();
        let cksum = format!("{sum:06o}\0 ");
        raw_header[148..156].copy_from_slice(cksum.as_bytes());

        gzip(&tar_bytes)
    }

    // ── Happy path ───────────────────────────────────────────────────

    #[test]
    fn extracts_regular_files_to_target() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz(&[
            ("core/", b"", true),
            ("core/manifest.json", b"{\"version\":\"0.1\"}", false),
        ]);

        extract_safe(Cursor::new(gz), &target).unwrap();

        let manifest = target.join("core/manifest.json");
        assert!(manifest.exists(), "manifest.json should exist");
        assert_eq!(
            std::fs::read_to_string(manifest).unwrap(),
            "{\"version\":\"0.1\"}"
        );
    }

    #[test]
    fn extracts_nested_directories() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz(&[
            ("core/", b"", true),
            ("core/skills/", b"", true),
            ("core/skills/query-installation.md", b"# Query", false),
        ]);

        extract_safe(Cursor::new(gz), &target).unwrap();

        let file = target.join("core/skills/query-installation.md");
        assert!(file.exists());
        assert_eq!(std::fs::read_to_string(file).unwrap(), "# Query");
    }

    #[test]
    fn empty_tarball_succeeds() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz(&[]);
        extract_safe(Cursor::new(gz), &target).unwrap();
        // No panic, no error — target remains empty.
    }

    // ── Path traversal rejection ─────────────────────────────────────

    #[test]
    fn rejects_dotdot_path_traversal() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_forged_path("../../etc/passwd");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        assert!(
            matches!(err, Error::Escape { .. }),
            "Expected Escape, got: {err:?}"
        );
    }

    #[test]
    fn rejects_absolute_path() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_forged_path("/tmp/evil");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        assert!(
            matches!(err, Error::Escape { .. }),
            "Expected Escape, got: {err:?}"
        );
    }

    // ── Entry type rejection ─────────────────────────────────────────

    #[test]
    fn rejects_symlink_entry() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_symlink("evil-link", "/etc/passwd");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        match err {
            Error::UnsupportedEntry { kind, .. } => assert_eq!(kind, "symlink"),
            other => panic!("Expected UnsupportedEntry(symlink), got: {other:?}"),
        }
    }

    #[test]
    fn rejects_hardlink_entry() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_hardlink("evil-link", "core/manifest.json");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        match err {
            Error::UnsupportedEntry { kind, .. } => assert_eq!(kind, "hardlink"),
            other => panic!("Expected UnsupportedEntry(hardlink), got: {other:?}"),
        }
    }

    // ── Pre-planted symlink precondition ─────────────────────────────

    // Note: this test plants a directory symlink. On Windows, creating one
    // requires Developer Mode or elevated privileges, so the test skips
    // itself at runtime when the symlink cannot be created.

    #[test]
    fn rejects_pre_planted_symlink_child() {
        let (tmp, target) = fresh_target();

        let decoy = tmp.path().join("decoy");
        std::fs::create_dir_all(&decoy).unwrap();

        // Plant a symlink as a direct child of target.
        let symlink_path = target.join("evil");
        #[cfg(unix)]
        std::os::unix::fs::symlink(&decoy, &symlink_path).unwrap();
        #[cfg(windows)]
        {
            if std::os::windows::fs::symlink_dir(&decoy, &symlink_path).is_err() {
                // Symlink creation requires Developer Mode on Windows.
                // Skip this test if we can't create symlinks.
                eprintln!("Skipping: symlink creation requires Developer Mode");
                return;
            }
        }

        let gz = build_tar_gz(&[("core/manifest.json", b"data", false)]);
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        assert!(
            matches!(err, Error::PrePlantedSymlink { .. }),
            "Expected PrePlantedSymlink, got: {err:?}"
        );
    }

    // ── Error Display tests ──────────────────────────────────────────

    #[test]
    fn escape_error_names_path() {
        let err = Error::Escape {
            entry_path: PathBuf::from("../../etc/passwd"),
        };
        let display = format!("{err}");
        assert!(display.contains("../../etc/passwd"));
    }

    #[test]
    fn unsupported_entry_names_kind_and_path() {
        let err = Error::UnsupportedEntry {
            entry_path: PathBuf::from("evil-link"),
            kind: "symlink",
        };
        let display = format!("{err}");
        assert!(display.contains("symlink"));
        assert!(display.contains("evil-link"));
    }

    #[test]
    fn pre_planted_symlink_error_names_path() {
        let err = Error::PrePlantedSymlink {
            path: PathBuf::from("/some/path"),
        };
        let display = format!("{err}");
        assert!(display.contains("/some/path"));
    }
}