use async_trait::async_trait;
use super::{Builtin, Context, resolve_path};
use crate::error::Result;
use crate::interpreter::ExecResult;
/// The `patch` builtin: applies a unified diff (read from stdin) to files in
/// the interpreter's virtual filesystem.
pub struct Patch;
/// Parsed command-line options for the `patch` builtin.
struct PatchOptions {
// Number of leading path components to strip from diff paths (-pNUM).
strip: usize,
// Report what would change without writing any files (--dry-run).
dry_run: bool,
// Apply the diff backwards (-R / --reverse).
reverse: bool,
// Explicit target file from the positional argument; overrides diff headers.
target_file: Option<String>,
}
/// Parse `patch` command-line arguments into a `PatchOptions`.
///
/// Defaults: no path stripping (strip = 0), real application (no dry run),
/// forward application (not reversed), and the target file taken from the
/// diff headers. Unknown flags are skipped via `p.advance()`.
/// NOTE(review): flag handling is delegated to `super::arg_parser::ArgParser`;
/// assumes `flag_value_opt("-p")` accepts attached values like "-p1" — confirm
/// against the arg_parser implementation.
fn parse_patch_args(args: &[String]) -> PatchOptions {
let mut opts = PatchOptions {
strip: 0,
dry_run: false,
reverse: false,
target_file: None,
};
let mut p = super::arg_parser::ArgParser::new(args);
while !p.is_done() {
if let Some(val) = p.flag_value_opt("-p") {
// A malformed -p value degrades to 0 (strip nothing).
opts.strip = val.parse().unwrap_or(0);
} else if p.flag("--dry-run") {
opts.dry_run = true;
} else if p.flag_any(&["-R", "--reverse"]) {
opts.reverse = true;
} else if let Some(arg) = p.positional() {
// NOTE(review): a later positional overwrites an earlier one — confirm intended.
opts.target_file = Some(arg.to_string());
} else {
p.advance();
}
}
opts
}
/// One hunk of a unified diff: 1-based start positions and line counts for
/// both sides, plus the ordered body lines.
#[derive(Debug)]
struct Hunk {
// First affected line in the old file (from "@@ -start,count").
old_start: usize,
#[allow(dead_code)]
old_count: usize,
// First affected line in the new file (from "@@ +start,count").
new_start: usize,
#[allow(dead_code)]
new_count: usize,
// Hunk body lines in order of appearance.
lines: Vec<HunkLine>,
}
/// A single line within a hunk body.
#[derive(Debug, Clone)]
enum HunkLine {
// Unchanged line, present in both the old and new file (' ' prefix).
Context(String),
// Line added by the patch ('+' prefix in the diff).
Add(String),
// Line removed by the patch ('-' prefix in the diff).
Remove(String),
}
/// All hunks targeting a single file, with the paths from the diff headers.
#[derive(Debug)]
struct FileDiff {
// Path from the "--- " header ("/dev/null" means the file is being created).
old_path: String,
// Path from the "+++ " header ("/dev/null" means the file is being deleted).
new_path: String,
hunks: Vec<Hunk>,
}
/// Drop the first `strip` leading path components, mirroring `patch -pNUM`.
///
/// Stripping zero components returns the path unchanged; stripping as many
/// components as the path has (or more) leaves just the final component.
fn strip_path(path: &str, strip: usize) -> String {
    if strip == 0 {
        return path.to_string();
    }
    let components: Vec<&str> = path.split('/').collect();
    match components.get(strip..) {
        Some(rest) if !rest.is_empty() => rest.join("/"),
        _ => components.last().copied().unwrap_or("").to_string(),
    }
}
/// Reject any path containing a `..` component (directory traversal) so a
/// malicious diff cannot escape the working tree.
///
/// Only whole components count: "a..b/c" is fine, "foo/../bar" is not.
fn validate_path(path: &str) -> std::result::Result<(), String> {
    if path.split('/').any(|component| component == "..") {
        return Err(format!(
            "patch: rejecting path '{}': contains '..' traversal",
            path
        ));
    }
    Ok(())
}
/// Parse unified-diff text into one `FileDiff` per "---"/"+++" header pair.
///
/// Hunk bodies are consumed according to the old/new line counts declared in
/// the "@@" header, so unrelated lines between file sections (e.g. git's
/// "diff --git" / "index" headers) are not absorbed into the previous hunk as
/// bogus context lines. A "@@" or "---" line inside a hunk still terminates
/// it early as a safety net against malformed counts.
fn parse_unified_diff(input: &str) -> Vec<FileDiff> {
    let mut diffs = Vec::new();
    let lines: Vec<&str> = input.lines().collect();
    let mut i = 0;
    while i < lines.len() {
        if lines[i].starts_with("--- ") && i + 1 < lines.len() && lines[i + 1].starts_with("+++ ") {
            // Header paths may carry a tab-separated timestamp; keep only the path.
            let old_path = header_path(lines[i], "--- ");
            let new_path = header_path(lines[i + 1], "+++ ");
            i += 2;
            let mut hunks = Vec::new();
            while i < lines.len() && lines[i].starts_with("@@ ") {
                if let Some(mut hunk) = parse_hunk_header(lines[i]) {
                    i += 1;
                    // Consume exactly the number of old/new lines the header declared.
                    let mut old_seen = 0usize;
                    let mut new_seen = 0usize;
                    while i < lines.len() && (old_seen < hunk.old_count || new_seen < hunk.new_count)
                    {
                        let line = lines[i];
                        // Truncated hunk: stop rather than eat the next section.
                        if line.starts_with("@@ ") || line.starts_with("--- ") {
                            break;
                        }
                        if line == "\\ No newline at end of file" {
                            // Marker only; contributes to neither side's count.
                        } else if let Some(rest) = line.strip_prefix('+') {
                            hunk.lines.push(HunkLine::Add(rest.to_string()));
                            new_seen += 1;
                        } else if let Some(rest) = line.strip_prefix('-') {
                            hunk.lines.push(HunkLine::Remove(rest.to_string()));
                            old_seen += 1;
                        } else {
                            // Context line; tolerate a missing leading space
                            // (e.g. empty lines emitted by some tools).
                            let text = line.strip_prefix(' ').unwrap_or(line);
                            hunk.lines.push(HunkLine::Context(text.to_string()));
                            old_seen += 1;
                            new_seen += 1;
                        }
                        i += 1;
                    }
                    hunks.push(hunk);
                } else {
                    i += 1;
                }
            }
            diffs.push(FileDiff {
                old_path,
                new_path,
                hunks,
            });
        } else {
            i += 1;
        }
    }
    diffs
}

/// Strip `prefix` from a "---"/"+++" header line and drop any tab-separated
/// timestamp suffix, returning just the path portion.
fn header_path(line: &str, prefix: &str) -> String {
    line.strip_prefix(prefix)
        .unwrap_or("")
        .split('\t')
        .next()
        .unwrap_or("")
        .to_string()
}
/// Parse a hunk header of the form "@@ -old_start,old_count +new_start,new_count @@"
/// (trailing section text after the closing "@@" is ignored) into a `Hunk`
/// with an empty body. Returns `None` for malformed headers.
fn parse_hunk_header(line: &str) -> Option<Hunk> {
    let ranges = line.strip_prefix("@@ ")?.split(" @@").next()?;
    let mut fields = ranges.split(' ');
    let old_range = fields.next()?.strip_prefix('-')?;
    let new_range = fields.next()?.strip_prefix('+')?;
    let (old_start, old_count) = parse_range(old_range);
    let (new_start, new_count) = parse_range(new_range);
    Some(Hunk {
        old_start,
        old_count,
        new_start,
        new_count,
        lines: Vec::new(),
    })
}
/// Parse one side of a hunk header range: either "start,count" or a bare
/// "start" (implying a count of 1). Unparsable numbers fall back to 1.
fn parse_range(s: &str) -> (usize, usize) {
    let (start, count) = s.split_once(',').unwrap_or((s, "1"));
    (start.parse().unwrap_or(1), count.parse().unwrap_or(1))
}
/// Apply `hunks` to `content` and return the patched text, or an error
/// message when the file does not match the hunk's expected lines.
///
/// Hunks are applied bottom-up (`.rev()`) so splices performed for later
/// hunks never invalidate the line numbers of earlier ones. With `reverse`,
/// additions and removals swap roles and positions come from the new-file
/// side of the header.
fn apply_hunks(
    content: &str,
    hunks: &[Hunk],
    reverse: bool,
) -> std::result::Result<String, String> {
    let mut lines: Vec<String> = content.lines().map(str::to_string).collect();
    // `lines()` drops the final newline; remember whether to restore it.
    let had_trailing_newline = content.ends_with('\n') || content.is_empty();
    for hunk in hunks.iter().rev() {
        // 1-based anchor line in the file being patched.
        let start = if reverse {
            hunk.new_start
        } else {
            hunk.old_start
        };
        // Split the body into the text we expect to find (old) and the text
        // that replaces it (new); context lines belong to both sides.
        let mut old_lines = Vec::new();
        let mut new_lines = Vec::new();
        for hl in &hunk.lines {
            match hl {
                HunkLine::Context(l) => {
                    old_lines.push(l.clone());
                    new_lines.push(l.clone());
                }
                HunkLine::Add(l) => {
                    if reverse {
                        old_lines.push(l.clone());
                    } else {
                        new_lines.push(l.clone());
                    }
                }
                HunkLine::Remove(l) => {
                    if reverse {
                        new_lines.push(l.clone());
                    } else {
                        old_lines.push(l.clone());
                    }
                }
            }
        }
        // Unified-diff convention: a zero-length source range ("-N,0") names
        // the line *after which* to insert, so the 0-based splice index is N
        // itself; a non-empty range starting at line N maps to index N-1.
        let start_idx = if old_lines.is_empty() {
            start
        } else {
            start.saturating_sub(1)
        };
        let end_idx = start_idx + old_lines.len();
        if end_idx > lines.len() {
            return Err(format!(
                "hunk at line {} does not match (file too short)",
                start
            ));
        }
        // Verify the old side matches before splicing in the new side.
        for (expected, actual) in old_lines.iter().zip(&lines[start_idx..end_idx]) {
            if actual != expected {
                return Err(format!(
                    "hunk at line {} does not match: expected '{}', got '{}'",
                    start, expected, actual
                ));
            }
        }
        lines.splice(start_idx..end_idx, new_lines);
    }
    let mut result = lines.join("\n");
    if had_trailing_newline && !result.is_empty() {
        result.push('\n');
    }
    Ok(result)
}
#[async_trait]
impl Builtin for Patch {
    /// Run `patch`: read a unified diff from stdin, resolve each target file
    /// (honoring -p/-R and an optional positional FILE override), and apply,
    /// create, or remove files in the VFS accordingly.
    ///
    /// Exit code is 1 when any file fails (path rejected or hunk mismatch);
    /// all per-file messages are accumulated into one output buffer.
    async fn execute(&self, ctx: Context<'_>) -> Result<ExecResult> {
        if let Some(r) = super::check_help_version(
            ctx.args,
            "Usage: patch [OPTION]... [FILE]\nApply a unified diff patch to FILE(s).\n\n -pNUM\tstrip NUM leading path components\n --dry-run\tprint results without modifying files\n -R, --reverse\treverse the patch\n --help\tdisplay this help and exit\n --version\toutput version information and exit\n",
            Some("patch (bashkit) 0.1"),
        ) {
            return Ok(r);
        }
        let opts = parse_patch_args(ctx.args);
        let input = match ctx.stdin {
            Some(s) if !s.is_empty() => s.to_string(),
            _ => {
                return Ok(ExecResult::err(
                    "patch: no input (expected unified diff on stdin)\n".to_string(),
                    1,
                ));
            }
        };
        let file_diffs = parse_unified_diff(&input);
        if file_diffs.is_empty() {
            return Ok(ExecResult::err(
                "patch: no valid diff found in input\n".to_string(),
                1,
            ));
        }
        let mut output = String::new();
        let mut had_error = false;
        for diff in &file_diffs {
            // Effective creation/deletion after accounting for -R:
            // reversing a deletion re-creates the file, and reversing a
            // creation deletes it.
            let creates_file = if opts.reverse {
                diff.new_path == "/dev/null"
            } else {
                diff.old_path == "/dev/null"
            };
            let deletes_file = if opts.reverse {
                diff.old_path == "/dev/null"
            } else {
                diff.new_path == "/dev/null"
            };
            let target = if let Some(ref t) = opts.target_file {
                t.clone()
            } else {
                // Prefer the post-patch path; fall back to the old path when
                // the new side is /dev/null. This holds for both directions:
                // in particular, reversing a deletion diff must target
                // old_path, never the literal "/dev/null".
                let raw_path = if diff.new_path == "/dev/null" {
                    &diff.old_path
                } else {
                    &diff.new_path
                };
                let stripped = strip_path(raw_path, opts.strip);
                if let Err(e) = validate_path(&stripped) {
                    output.push_str(&format!("{}\n", e));
                    had_error = true;
                    continue;
                }
                stripped
            };
            let path = resolve_path(ctx.cwd, &target);
            let content = if creates_file {
                // The file does not exist yet; start from empty content.
                String::new()
            } else {
                match ctx.fs.read_file(&path).await {
                    Ok(bytes) => String::from_utf8_lossy(&bytes).to_string(),
                    // Missing file: let the hunk-match check report any problem.
                    Err(_) => String::new(),
                }
            };
            match apply_hunks(&content, &diff.hunks, opts.reverse) {
                Ok(patched) => {
                    if opts.dry_run {
                        output.push_str(&format!("checking file {}\n", target));
                    } else if deletes_file {
                        output.push_str(&format!("patching file {} (removed)\n", target));
                        ctx.fs.remove(&path, false).await?;
                    } else {
                        ctx.fs.write_file(&path, patched.as_bytes()).await?;
                        output.push_str(&format!("patching file {}\n", target));
                    }
                }
                Err(e) => {
                    output.push_str(&format!("patch: {}: {}\n", target, e));
                    had_error = true;
                }
            }
        }
        if had_error {
            Ok(ExecResult::err(output, 1))
        } else {
            Ok(ExecResult::ok(output))
        }
    }
}
#[cfg(test)]
mod tests {
// End-to-end tests: run the builtin against an in-memory VFS and inspect
// both the ExecResult and the resulting filesystem state.
use super::*;
use crate::fs::{FileSystem, InMemoryFs};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
// Test harness: seed `files` into a fresh InMemoryFs, build a Context with
// `args` and `stdin`, run Patch, and return the result plus the fs handle
// so callers can assert on post-patch file contents.
async fn run_patch(
args: &[&str],
stdin: &str,
files: &[(&str, &[u8])],
) -> (ExecResult, Arc<InMemoryFs>) {
let fs = Arc::new(InMemoryFs::new());
for (path, content) in files {
let fs_trait = fs.clone() as Arc<dyn FileSystem>;
fs_trait.write_file(Path::new(path), content).await.unwrap();
}
let args: Vec<String> = args.iter().map(|s| s.to_string()).collect();
let env = HashMap::new();
let mut variables = HashMap::new();
let mut cwd = PathBuf::from("/");
let fs_dyn = fs.clone() as Arc<dyn FileSystem>;
let ctx = Context {
args: &args,
env: &env,
variables: &mut variables,
cwd: &mut cwd,
fs: fs_dyn,
stdin: Some(stdin),
#[cfg(feature = "http_client")]
http_client: None,
#[cfg(feature = "git")]
git_client: None,
#[cfg(feature = "ssh")]
ssh_client: None,
shell: None,
};
let result = Patch.execute(ctx).await.unwrap();
(result, fs)
}
// Single-line replacement is applied in place.
#[tokio::test]
async fn test_patch_simple_change() {
let diff = "\
--- a/test.txt
+++ b/test.txt
@@ -1,3 +1,3 @@
line1
-line2
+modified
line3
";
let (result, fs) =
run_patch(&["-p1"], diff, &[("/test.txt", b"line1\nline2\nline3\n")]).await;
assert_eq!(result.exit_code, 0);
let fs_trait = fs as Arc<dyn FileSystem>;
let content = fs_trait.read_file(Path::new("/test.txt")).await.unwrap();
let text = String::from_utf8_lossy(&content);
assert!(text.contains("modified"));
assert!(!text.contains("line2"));
}
// Pure additions between context lines are inserted.
#[tokio::test]
async fn test_patch_add_lines() {
let diff = "\
--- a/test.txt
+++ b/test.txt
@@ -1,2 +1,4 @@
line1
+added1
+added2
line2
";
let (result, fs) = run_patch(&["-p1"], diff, &[("/test.txt", b"line1\nline2\n")]).await;
assert_eq!(result.exit_code, 0);
let fs_trait = fs as Arc<dyn FileSystem>;
let content = fs_trait.read_file(Path::new("/test.txt")).await.unwrap();
let text = String::from_utf8_lossy(&content);
assert!(text.contains("added1"));
assert!(text.contains("added2"));
}
// Removals shrink the file to just the surviving context.
#[tokio::test]
async fn test_patch_remove_lines() {
let diff = "\
--- a/test.txt
+++ b/test.txt
@@ -1,3 +1,1 @@
line1
-line2
-line3
";
let (result, fs) =
run_patch(&["-p1"], diff, &[("/test.txt", b"line1\nline2\nline3\n")]).await;
assert_eq!(result.exit_code, 0);
let fs_trait = fs as Arc<dyn FileSystem>;
let content = fs_trait.read_file(Path::new("/test.txt")).await.unwrap();
let text = String::from_utf8_lossy(&content);
assert_eq!(text.trim(), "line1");
}
// --dry-run reports "checking file" and leaves the file untouched.
#[tokio::test]
async fn test_patch_dry_run() {
let diff = "\
--- a/test.txt
+++ b/test.txt
@@ -1,2 +1,2 @@
line1
-line2
+changed
";
let (result, fs) = run_patch(
&["--dry-run", "-p1"],
diff,
&[("/test.txt", b"line1\nline2\n")],
)
.await;
assert_eq!(result.exit_code, 0);
assert!(result.stdout.contains("checking file"));
let fs_trait = fs as Arc<dyn FileSystem>;
let content = fs_trait.read_file(Path::new("/test.txt")).await.unwrap();
let text = String::from_utf8_lossy(&content);
assert!(text.contains("line2"));
}
// -R applies the diff backwards, restoring the original line.
#[tokio::test]
async fn test_patch_reverse() {
let diff = "\
--- a/test.txt
+++ b/test.txt
@@ -1,2 +1,2 @@
line1
-original
+changed
";
let (result, fs) =
run_patch(&["-R", "-p1"], diff, &[("/test.txt", b"line1\nchanged\n")]).await;
assert_eq!(result.exit_code, 0);
let fs_trait = fs as Arc<dyn FileSystem>;
let content = fs_trait.read_file(Path::new("/test.txt")).await.unwrap();
let text = String::from_utf8_lossy(&content);
assert!(text.contains("original"));
}
// strip_path drops leading components; over-stripping keeps the basename.
#[tokio::test]
async fn test_patch_strip_path() {
assert_eq!(strip_path("a/b/c.txt", 0), "a/b/c.txt");
assert_eq!(strip_path("a/b/c.txt", 1), "b/c.txt");
assert_eq!(strip_path("a/b/c.txt", 2), "c.txt");
assert_eq!(strip_path("a/b/c.txt", 5), "c.txt");
}
// Empty stdin is an error, not a hang.
#[tokio::test]
async fn test_patch_no_input() {
let (result, _fs) = run_patch(&[], "", &[]).await;
assert_eq!(result.exit_code, 1);
assert!(result.stderr.contains("no input"));
}
// Input without any ---/+++ header pair is rejected.
#[tokio::test]
async fn test_patch_invalid_diff() {
let (result, _fs) = run_patch(&[], "this is not a diff\n", &[]).await;
assert_eq!(result.exit_code, 1);
assert!(result.stderr.contains("no valid diff"));
}
// A hunk whose expected lines differ from the file fails with a mismatch.
#[tokio::test]
async fn test_patch_hunk_mismatch() {
let diff = "\
--- a/test.txt
+++ b/test.txt
@@ -1,2 +1,2 @@
line1
-wrong_content
+changed
";
let (result, _fs) =
run_patch(&["-p1"], diff, &[("/test.txt", b"line1\nactual_content\n")]).await;
assert_eq!(result.exit_code, 1);
assert!(result.stderr.contains("does not match"));
}
// "--- /dev/null" creates a new file from the added lines.
#[tokio::test]
async fn test_patch_new_file() {
let diff = "\
--- /dev/null
+++ b/newfile.txt
@@ -0,0 +1,2 @@
+hello
+world
";
let (result, fs) = run_patch(&["-p1"], diff, &[]).await;
assert_eq!(result.exit_code, 0);
let fs_trait = fs as Arc<dyn FileSystem>;
let content = fs_trait.read_file(Path::new("/newfile.txt")).await.unwrap();
let text = String::from_utf8_lossy(&content);
assert!(text.contains("hello"));
assert!(text.contains("world"));
}
// A positional FILE argument overrides the path from the diff headers.
#[tokio::test]
async fn test_patch_target_file_override() {
let diff = "\
--- a/original.txt
+++ b/original.txt
@@ -1,2 +1,2 @@
line1
-old
+new
";
let (result, fs) = run_patch(
&["-p1", "target.txt"],
diff,
&[("/target.txt", b"line1\nold\n")],
)
.await;
assert_eq!(result.exit_code, 0);
let fs_trait = fs as Arc<dyn FileSystem>;
let content = fs_trait.read_file(Path::new("/target.txt")).await.unwrap();
let text = String::from_utf8_lossy(&content);
assert!(text.contains("new"));
}
// Full "start,count" ranges parse on both sides.
#[tokio::test]
async fn test_parse_hunk_header() {
let hunk = parse_hunk_header("@@ -1,3 +1,4 @@").unwrap();
assert_eq!(hunk.old_start, 1);
assert_eq!(hunk.old_count, 3);
assert_eq!(hunk.new_start, 1);
assert_eq!(hunk.new_count, 4);
}
// A bare "start" (no comma) implies a count of 1.
#[tokio::test]
async fn test_parse_hunk_header_single_line() {
let hunk = parse_hunk_header("@@ -5 +5,2 @@").unwrap();
assert_eq!(hunk.old_start, 5);
assert_eq!(hunk.old_count, 1);
assert_eq!(hunk.new_start, 5);
assert_eq!(hunk.new_count, 2);
}
// Security: a diff targeting a path with leading ".." must be rejected.
#[tokio::test]
async fn test_patch_rejects_path_traversal() {
let diff = "\
--- a/../../../etc/passwd
+++ b/../../../etc/passwd
@@ -1,1 +1,1 @@
-root:x:0:0:root:/root:/bin/bash
+pwned
";
let (result, _fs) = run_patch(&["-p1"], diff, &[]).await;
assert_eq!(result.exit_code, 1);
assert!(
result.stderr.contains(".."),
"error should mention path traversal: {}",
result.stderr
);
}
// Security: ".." anywhere in the path, not just at the front, is rejected.
#[tokio::test]
async fn test_patch_rejects_embedded_dotdot() {
let diff = "\
--- a/foo/../../secret.txt
+++ b/foo/../../secret.txt
@@ -1,1 +1,1 @@
-old
+new
";
let (result, _fs) = run_patch(&["-p1"], diff, &[("/secret.txt", b"old\n")]).await;
assert_eq!(result.exit_code, 1);
assert!(result.stderr.contains(".."));
}
// A normal relative path survives stripping and is accepted.
#[tokio::test]
async fn test_patch_allows_clean_path_after_strip() {
let diff = "\
--- a/main.rs
+++ b/main.rs
@@ -1,1 +1,1 @@
-old
+new
";
let (result, _fs) = run_patch(&["-p1"], diff, &[("/main.rs", b"old\n")]).await;
assert_eq!(result.exit_code, 0);
}
// strip_path must not collapse ".." components (validate_path catches them).
#[tokio::test]
async fn test_strip_path_preserves_dotdot() {
assert_eq!(
strip_path("a/../../../etc/passwd", 1),
"../../../etc/passwd"
);
}
// validate_path rejects ".." as a whole component wherever it appears.
#[tokio::test]
async fn test_validate_path_rejects_dotdot() {
assert!(validate_path("../../../etc/passwd").is_err());
assert!(validate_path("foo/../../bar").is_err());
assert!(validate_path("..").is_err());
}
// validate_path accepts ordinary relative paths.
#[tokio::test]
async fn test_validate_path_allows_clean() {
assert!(validate_path("foo/bar/baz.txt").is_ok());
assert!(validate_path("file.txt").is_ok());
assert!(validate_path("a/b/c").is_ok());
}
// "+++ /dev/null" deletes the target file from the VFS.
#[tokio::test]
async fn test_patch_delete_file_removes_from_vfs() {
let diff = "--- a/to_delete.txt\n\
+++ /dev/null\n\
@@ -1,2 +0,0 @@\n\
-hello\n\
-world\n";
let (result, fs) =
run_patch(&["-p1"], diff, &[("/to_delete.txt", b"hello\nworld\n")]).await;
assert_eq!(result.exit_code, 0, "stderr: {}", result.stderr);
let fs_dyn = fs as Arc<dyn FileSystem>;
assert!(
!fs_dyn.exists(Path::new("/to_delete.txt")).await.unwrap(),
"deleted file should not exist in VFS"
);
}
}