Skip to main content

omne_cli/
tarball.rs

1//! Safe tarball extraction with defense-in-depth against path traversal.
2//!
3//! `extract_safe` takes any `impl Read` source and a target directory,
4//! iterates entries manually, and rejects any entry that would escape the
5//! target directory, is not a regular file or directory, or would be
6//! written through a pre-planted symlink.
7//!
8//! This module has **no** dependency on HTTP, ureq, or `GithubClient`.
9//! The split seam from the network layer is at the `impl Read` boundary.
10
11#![allow(dead_code)]
12
13use std::io::Read;
14use std::path::{Component, Path, PathBuf};
15
16use path_clean::PathClean;
17use thiserror::Error;
18
19/// Errors returned by tarball extraction operations.
20#[derive(Debug, Error)]
21pub enum Error {
22    /// A tarball entry attempted to escape the target directory via `..`,
23    /// absolute paths, or Windows prefix components.
24    #[error("tarball path traversal attempt: {entry_path}")]
25    Escape { entry_path: PathBuf },
26
27    /// A tarball entry is not a regular file or directory (e.g. symlink,
28    /// hardlink, character device, FIFO).
29    #[error("unsupported tarball entry type '{kind}': {entry_path}")]
30    UnsupportedEntry {
31        entry_path: PathBuf,
32        kind: &'static str,
33    },
34
35    /// A pre-existing symlink was found in the target tree before extraction.
36    /// Extraction is refused to prevent symlink-following attacks.
37    #[error("pre-existing symlink in target tree: {path}")]
38    PrePlantedSymlink { path: PathBuf },
39
40    /// Underlying I/O error during extraction.
41    #[error(transparent)]
42    Io(#[from] std::io::Error),
43}
44
45/// Extract a `.tar.gz` archive from `source` into `target`, rejecting
46/// any entry that would escape the target or is not a regular file/directory.
47///
48/// # Pre-extraction checks
49///
50/// Before iterating entries, this function:
51/// 1. Canonicalizes `target` (caller must create it first).
52/// 2. Walks every ancestor of `target` up to the filesystem root and
53///    verifies none is a symlink.
54/// 3. Lists every direct child of `target` and verifies none is a symlink.
55///
56/// # Per-entry checks
57///
58/// For each entry:
59/// - Only `Regular` and `Directory` types are accepted; all others
60///   (symlinks, hardlinks, devices, etc.) are rejected.
61/// - Path components are scanned: `..`, `RootDir`, and `Prefix` are rejected.
62/// - Absolute paths are rejected.
63/// - A lexical prefix check confirms the resolved path stays within `target`
64///   (case-insensitive on Windows).
65pub fn extract_safe(source: impl Read, target: &Path) -> Result<(), Error> {
66    // ── Pre-extraction symlink precondition ──────────────────────────
67    let canon_target = target.canonicalize()?;
68
69    // Check ancestors of target for symlinks.
70    check_ancestors_not_symlinks(&canon_target)?;
71
72    // Check direct children of target for symlinks.
73    check_children_not_symlinks(&canon_target)?;
74
75    // ── Extraction loop ──────────────────────────────────────────────
76    let gz = flate2::read::GzDecoder::new(source);
77    let mut archive = tar::Archive::new(gz);
78
79    for entry_result in archive.entries()? {
80        let mut entry = entry_result?;
81        let header = entry.header();
82
83        // Filter entry type: accept only Regular and Directory.
84        let entry_type = header.entry_type();
85        if !entry_type.is_file() && !entry_type.is_dir() {
86            let kind = classify_entry_type(entry_type);
87            let entry_path = entry.path()?.into_owned();
88            return Err(Error::UnsupportedEntry { entry_path, kind });
89        }
90
91        // Read and validate the path.
92        let raw_path = entry.path()?.into_owned();
93
94        // Reject absolute paths.
95        if raw_path.is_absolute() {
96            return Err(Error::Escape {
97                entry_path: raw_path,
98            });
99        }
100
101        // Reject dangerous components.
102        for component in raw_path.components() {
103            match component {
104                Component::ParentDir | Component::RootDir | Component::Prefix(_) => {
105                    return Err(Error::Escape {
106                        entry_path: raw_path,
107                    });
108                }
109                _ => {}
110            }
111        }
112
113        // Lexical prefix check: resolved path must stay within target.
114        let resolved = canon_target.join(&raw_path).clean();
115        if !starts_with_normalized(&resolved, &canon_target) {
116            return Err(Error::Escape {
117                entry_path: raw_path,
118            });
119        }
120
121        // Extract the entry. tar's own guards are defense-in-depth.
122        entry.unpack_in(&canon_target)?;
123    }
124
125    Ok(())
126}
127
128/// Check that no ancestor of `path` is a symlink.
129fn check_ancestors_not_symlinks(path: &Path) -> Result<(), Error> {
130    let mut current = path.to_path_buf();
131    while let Some(parent) = current.parent() {
132        if parent.as_os_str().is_empty() {
133            break;
134        }
135        if parent.symlink_metadata()?.file_type().is_symlink() {
136            return Err(Error::PrePlantedSymlink {
137                path: parent.to_path_buf(),
138            });
139        }
140        current = parent.to_path_buf();
141    }
142    Ok(())
143}
144
145/// Check that no direct child of `path` is a symlink.
146fn check_children_not_symlinks(path: &Path) -> Result<(), Error> {
147    if !path.is_dir() {
148        return Ok(());
149    }
150    for entry in std::fs::read_dir(path)? {
151        let entry = entry?;
152        if entry.metadata()?.file_type().is_symlink() {
153            return Err(Error::PrePlantedSymlink { path: entry.path() });
154        }
155    }
156    Ok(())
157}
158
/// Case-aware, component-wise `starts_with` check.
///
/// Returns `true` when every component of `prefix` matches the leading
/// components of `path`. On Windows each component is lowercased before
/// comparison; elsewhere this is exactly [`Path::starts_with`].
///
/// Matching whole components (rather than raw string prefixes) ensures a
/// sibling such as `dest-evil` is never treated as being inside `dest`.
fn starts_with_normalized(path: &Path, prefix: &Path) -> bool {
    #[cfg(windows)]
    {
        let fold = |c: Component<'_>| c.as_os_str().to_string_lossy().to_lowercase();

        let mut have = path.components();
        for want in prefix.components() {
            match have.next() {
                Some(got) if fold(got) == fold(want) => {}
                // `path` ran out of components or diverged from `prefix`.
                _ => return false,
            }
        }
        true
    }
    #[cfg(not(windows))]
    {
        path.starts_with(prefix)
    }
}
173
174/// Map tar entry type to a human-readable label.
175fn classify_entry_type(t: tar::EntryType) -> &'static str {
176    match t {
177        tar::EntryType::Symlink => "symlink",
178        tar::EntryType::Link => "hardlink",
179        tar::EntryType::Char => "char device",
180        tar::EntryType::Block => "block device",
181        tar::EntryType::Fifo => "fifo",
182        tar::EntryType::GNUSparse => "sparse",
183        _ => "unknown",
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use tempfile::TempDir;

    // ── Fixtures ─────────────────────────────────────────────────────

    /// Create a temporary directory containing an empty `dest` directory.
    ///
    /// The `TempDir` guard is returned alongside the path so the directory
    /// stays alive for as long as the caller holds the guard.
    fn fresh_target() -> (TempDir, PathBuf) {
        let tmp = TempDir::new().unwrap();
        let target = tmp.path().join("dest");
        std::fs::create_dir_all(&target).unwrap();
        (tmp, target)
    }

    /// Gzip-compress raw tar bytes so they can be fed to `extract_safe`.
    fn gzip(tar_bytes: &[u8]) -> Vec<u8> {
        use flate2::write::GzEncoder;
        use flate2::Compression;
        use std::io::Write;

        let mut encoder = GzEncoder::new(Vec::new(), Compression::fast());
        encoder.write_all(tar_bytes).unwrap();
        encoder.finish().unwrap()
    }

    /// Build a .tar.gz in memory with the given entries.
    /// Each entry is (path, content_bytes, is_dir); content is ignored for
    /// directories.
    fn build_tar_gz(entries: &[(&str, &[u8], bool)]) -> Vec<u8> {
        let mut builder = tar::Builder::new(Vec::new());
        for &(path, content, is_dir) in entries {
            let mut header = tar::Header::new_gnu();
            let data: &[u8] = if is_dir {
                header.set_entry_type(tar::EntryType::Directory);
                header.set_mode(0o755);
                &[]
            } else {
                header.set_entry_type(tar::EntryType::Regular);
                header.set_mode(0o644);
                content
            };
            header.set_size(data.len() as u64);
            header.set_cksum();
            builder.append_data(&mut header, path, data).unwrap();
        }
        gzip(&builder.into_inner().unwrap())
    }

    /// Build a .tar.gz holding a single link entry of the given type.
    fn build_tar_gz_with_link(
        kind: tar::EntryType,
        mode: u32,
        link_name: &str,
        target: &str,
    ) -> Vec<u8> {
        let mut builder = tar::Builder::new(Vec::new());
        let mut header = tar::Header::new_gnu();
        header.set_entry_type(kind);
        header.set_size(0);
        header.set_mode(mode);
        header.set_cksum();
        builder.append_link(&mut header, link_name, target).unwrap();
        gzip(&builder.into_inner().unwrap())
    }

    /// Build a .tar.gz with a symlink entry.
    fn build_tar_gz_with_symlink(link_name: &str, target: &str) -> Vec<u8> {
        build_tar_gz_with_link(tar::EntryType::Symlink, 0o777, link_name, target)
    }

    /// Build a .tar.gz with a hardlink entry.
    fn build_tar_gz_with_hardlink(link_name: &str, target: &str) -> Vec<u8> {
        build_tar_gz_with_link(tar::EntryType::Link, 0o644, link_name, target)
    }

    /// Build a .tar.gz with a single file at a forged path.
    ///
    /// The tar crate's builder validates paths and rejects `..` and absolute
    /// paths. To test our extraction guards, we need to forge a tarball with
    /// a malicious path at the raw byte level. This builds a valid tar entry
    /// with a placeholder path, overwrites the placeholder bytes in the raw
    /// tar data, recomputes the header checksum, then gzip-compresses.
    ///
    /// Panics if the placeholder cannot be located, so a broken fixture can
    /// never silently produce a harmless archive.
    fn build_tar_gz_with_forged_path(malicious_path: &str) -> Vec<u8> {
        const PLACEHOLDER: &str = "PLACEHOLDER_PATH_FOR_FORGING";
        const BLOCK_SIZE: usize = 512;
        // Byte range of the checksum field within a header block.
        const CKSUM_FIELD: std::ops::Range<usize> = 148..156;

        assert!(
            malicious_path.len() <= PLACEHOLDER.len(),
            "malicious path too long for placeholder"
        );

        let content = b"evil";
        let mut builder = tar::Builder::new(Vec::new());
        let mut header = tar::Header::new_gnu();
        header.set_entry_type(tar::EntryType::Regular);
        header.set_size(content.len() as u64);
        header.set_mode(0o644);
        header.set_cksum();
        builder
            .append_data(&mut header, PLACEHOLDER, &content[..])
            .unwrap();
        let mut tar_bytes = builder.into_inner().unwrap();

        // Locate the placeholder path in the raw tar bytes.
        let pos = tar_bytes
            .windows(PLACEHOLDER.len())
            .position(|w| w == PLACEHOLDER.as_bytes())
            .expect("placeholder path must be present in the freshly built tar");

        // Zero the placeholder, then write the malicious path over it.
        let name_bytes = &mut tar_bytes[pos..pos + PLACEHOLDER.len()];
        name_bytes.fill(0);
        name_bytes[..malicious_path.len()].copy_from_slice(malicious_path.as_bytes());

        // Recompute the header checksum: the sum of all header bytes with
        // the checksum field itself counted as 8 spaces (0x20).
        let header_start = pos - (pos % BLOCK_SIZE); // Align to block start.
        let header_block = &mut tar_bytes[header_start..header_start + BLOCK_SIZE];
        let sum: u64 = header_block
            .iter()
            .enumerate()
            .map(|(i, &b)| {
                if CKSUM_FIELD.contains(&i) {
                    0x20
                } else {
                    u64::from(b)
                }
            })
            .sum();
        let cksum = format!("{sum:06o}\0 ");
        header_block[CKSUM_FIELD].copy_from_slice(cksum.as_bytes());

        gzip(&tar_bytes)
    }

    // ── Happy path ───────────────────────────────────────────────────

    #[test]
    fn extracts_regular_files_to_target() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz(&[
            ("core/", b"", true),
            ("core/manifest.json", b"{\"version\":\"0.1\"}", false),
        ]);

        extract_safe(Cursor::new(gz), &target).unwrap();

        let manifest = target.join("core/manifest.json");
        assert!(manifest.exists(), "manifest.json should exist");
        assert_eq!(
            std::fs::read_to_string(manifest).unwrap(),
            "{\"version\":\"0.1\"}"
        );
    }

    #[test]
    fn extracts_nested_directories() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz(&[
            ("core/", b"", true),
            ("core/skills/", b"", true),
            ("core/skills/query-installation/", b"", true),
            ("core/skills/query-installation/SKILL.md", b"# Query", false),
        ]);

        extract_safe(Cursor::new(gz), &target).unwrap();

        let file = target.join("core/skills/query-installation/SKILL.md");
        assert!(file.exists());
        assert_eq!(std::fs::read_to_string(file).unwrap(), "# Query");
    }

    #[test]
    fn empty_tarball_succeeds() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz(&[]);
        extract_safe(Cursor::new(gz), &target).unwrap();
        // No panic, no error — target remains empty.
    }

    // ── Path traversal rejection ─────────────────────────────────────

    #[test]
    fn rejects_dotdot_path_traversal() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_forged_path("../../etc/passwd");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        assert!(
            matches!(err, Error::Escape { .. }),
            "Expected Escape, got: {err:?}"
        );
    }

    #[test]
    fn rejects_absolute_path() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_forged_path("/tmp/evil");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        assert!(
            matches!(err, Error::Escape { .. }),
            "Expected Escape, got: {err:?}"
        );
    }

    // ── Entry type rejection ─────────────────────────────────────────

    #[test]
    fn rejects_symlink_entry() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_symlink("evil-link", "/etc/passwd");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        match err {
            Error::UnsupportedEntry { kind, .. } => assert_eq!(kind, "symlink"),
            other => panic!("Expected UnsupportedEntry(symlink), got: {other:?}"),
        }
    }

    #[test]
    fn rejects_hardlink_entry() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_hardlink("evil-link", "core/manifest.json");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        match err {
            Error::UnsupportedEntry { kind, .. } => assert_eq!(kind, "hardlink"),
            other => panic!("Expected UnsupportedEntry(hardlink), got: {other:?}"),
        }
    }

    // ── Pre-planted symlink precondition ─────────────────────────────

    // Note: this test plants a *directory* symlink. On Windows, creating one
    // requires elevated privileges or Developer Mode, so the Windows branch
    // attempts the creation and skips the test at runtime if it fails.

    #[test]
    fn rejects_pre_planted_symlink_child() {
        let (tmp, target) = fresh_target();

        let decoy = tmp.path().join("decoy");
        std::fs::create_dir_all(&decoy).unwrap();

        // Plant a symlink as a direct child of target.
        let symlink_path = target.join("evil");
        #[cfg(unix)]
        std::os::unix::fs::symlink(&decoy, &symlink_path).unwrap();
        #[cfg(windows)]
        {
            if std::os::windows::fs::symlink_dir(&decoy, &symlink_path).is_err() {
                // Symlink creation requires Developer Mode on Windows.
                // Skip this test if we can't create symlinks.
                eprintln!("Skipping: symlink creation requires Developer Mode");
                return;
            }
        }

        let gz = build_tar_gz(&[("core/manifest.json", b"data", false)]);
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        assert!(
            matches!(err, Error::PrePlantedSymlink { .. }),
            "Expected PrePlantedSymlink, got: {err:?}"
        );
    }

    // ── Error Display tests ──────────────────────────────────────────

    #[test]
    fn escape_error_names_path() {
        let err = Error::Escape {
            entry_path: PathBuf::from("../../etc/passwd"),
        };
        let display = format!("{err}");
        assert!(display.contains("../../etc/passwd"));
    }

    #[test]
    fn unsupported_entry_names_kind_and_path() {
        let err = Error::UnsupportedEntry {
            entry_path: PathBuf::from("evil-link"),
            kind: "symlink",
        };
        let display = format!("{err}");
        assert!(display.contains("symlink"));
        assert!(display.contains("evil-link"));
    }

    #[test]
    fn pre_planted_symlink_error_names_path() {
        let err = Error::PrePlantedSymlink {
            path: PathBuf::from("/some/path"),
        };
        let display = format!("{err}");
        assert!(display.contains("/some/path"));
    }
}